Skip to main content

gam_sae/manifold/
construction.rs

1use super::*;
2use gam_math::jet_scalar::JetScalar;
3
4// [#780] Softmax-entropy Gershgorin majorizer leaf helpers live in a sibling
5// cohesive module, inlined here so they share this module scope.
6include!("softmax_entropy_majorizer.rs");
7
8// [#780] The exact stationarity-Jacobian correction and exact-Hessian solve
9// methods live in a sibling file, inlined here so they share this `impl
10// SaeManifoldTerm` / module scope while keeping this file under the line-count
11// gate.
12include!("construction_exact_hessian.rs");
13
14// [#780] The outer-gradient error taxonomy (`OuterGradientError`), the
15// `ForcedRowLayout` override alias, the `COTRAIN_*` co-training weight
16// constants, and the `AmortizedEncoderConsistency` report were extracted
17// verbatim into the sibling `construction_aux_types` module to keep this file
18// under the per-file line-count gate. They re-enter this module's scope via the
19// parent's glob re-export (`use super::*;` above).
20
21impl SaeManifoldTerm {
22    #[must_use = "build error must be handled"]
23    pub fn new(atoms: Vec<SaeManifoldAtom>, assignment: SaeAssignment) -> Result<Self, String> {
24        if atoms.is_empty() {
25            return Err("SaeManifoldTerm::new: at least one atom required".into());
26        }
27        let n = atoms[0].n_obs();
28        let p = atoms[0].output_dim();
29        if assignment.n_obs() != n || assignment.k_atoms() != atoms.len() {
30            return Err(format!(
31                "SaeManifoldTerm::new: assignment shape ({}, {}) does not match atoms ({n}, {})",
32                assignment.n_obs(),
33                assignment.k_atoms(),
34                atoms.len()
35            ));
36        }
37        for (k, atom) in atoms.iter().enumerate() {
38            if atom.n_obs() != n {
39                return Err(format!(
40                    "SaeManifoldTerm::new: atom {k} has n_obs={} but atom 0 has {n}",
41                    atom.n_obs()
42                ));
43            }
44            if atom.output_dim() != p {
45                return Err(format!(
46                    "SaeManifoldTerm::new: atom {k} output_dim={} but atom 0 has {p}",
47                    atom.output_dim()
48                ));
49            }
50            if atom.latent_dim != assignment.coords[k].latent_dim() {
51                return Err(format!(
52                    "SaeManifoldTerm::new: atom {k} latent_dim={} but assignment coord has {}",
53                    atom.latent_dim,
54                    assignment.coords[k].latent_dim()
55                ));
56            }
57        }
58        Ok(Self {
59            atoms,
60            assignment,
61            temperature_schedule: None,
62            last_row_layout: None,
63            row_metric: None,
64            collapse_events: Vec::new(),
65            row_loss_weights: None,
66            last_frames_active: false,
67            assembly_chunk_override: None,
68            fixed_decoder_assembly: false,
69            softmax_active_cap: None,
70            border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
71            certificate_dispersion: None,
72            curvature_walk_report: None,
73            expected_evidence_gauge_deflated_directions: None,
74            evidence_gauge_deflation_reanchors: 0,
75            evidence_gauge_deflation_last_delta_sign: 0,
76            dictionary_cocollapse_reseeds: 0,
77            best_cocollapse_incumbent: None,
78            decoder_repulsion_gate: None,
79            barrier_coactivation_gate: None,
80            hybrid_split_report: None,
81            atom_inner_fits: None,
82            oos_linear_images: None,
83            separation_barrier_strength_override: None,
84        })
85    }
86
87    /// #1777 — apply the PER-FIT configuration overrides (the FFI-facing
88    /// [`SaeFitConfig`]) as the source of truth for this term's fit, isolating it
89    /// from the deprecated process-global barrier/α atomics.
90    ///
91    /// Distributes the config to its two authorities: the barrier strength override
92    /// onto the term (read by `separation_barrier_strength`), and the IBP-α
93    /// override onto the assignment (read by
94    /// [`SaeAssignment::resolved_ibp_alpha`]). Any `None` field leaves that axis on
95    /// its historical fallback (process-global override, then the
96    /// data-derived/mode default), so an all-`None` config is a strict no-op. Call
97    /// this after building the term (before the fit) so concurrent fits carrying
98    /// distinct configs stay isolated without any global writes.
99    pub fn set_fit_config(&mut self, config: SaeFitConfig) {
100        self.separation_barrier_strength_override = config.separation_barrier_strength_override;
101        self.assignment
102            .set_ibp_alpha_override(config.ibp_alpha_override);
103    }
104
105    /// #1777 — the per-fit configuration currently in force on this term,
106    /// reconstructed from its two authorities (the term's barrier override and the
107    /// assignment's α override). Round-trips with [`Self::set_fit_config`].
108    #[must_use]
109    pub fn fit_config(&self) -> SaeFitConfig {
110        SaeFitConfig {
111            separation_barrier_strength_override: self.separation_barrier_strength_override,
112            ibp_alpha_override: self.assignment.ibp_alpha_override,
113        }
114    }
115
116    /// #1408/#1409 — install the optional hard per-row active-atom cap for
117    /// Softmax mode (threaded from the fit/encode `top_k`). A `Some(k)` with
118    /// `1 <= k < K` makes the Softmax assignment optimize on the COMPACT
119    /// top-`k` row layout (see [`Self::softmax_active_cap`]); `Some(k) >= K`
120    /// and `None` are both no-ops (full support). Non-softmax modes ignore it.
121    pub fn set_softmax_active_cap(&mut self, top_k: Option<usize>) {
122        self.softmax_active_cap = match top_k {
123            Some(k) if k >= 1 && k < self.k_atoms() => Some(k),
124            _ => None,
125        };
126    }
127
128    /// Install the fitted reconstruction dispersion used by
129    /// [`dictionary_incoherence_report`]. This is a pure diagnostic scalar and
130    /// does not feed any loss, criterion, penalty, or optimizer state.
131    pub fn set_certificate_dispersion(&mut self, dispersion: f64) -> Result<(), String> {
132        if !dispersion.is_finite() || dispersion <= 0.0 {
133            return Err(format!(
134                "SaeManifoldTerm::set_certificate_dispersion: dispersion must be finite and positive, got {dispersion}"
135            ));
136        }
137        self.certificate_dispersion = Some(dispersion);
138        Ok(())
139    }
140
141    /// Harvest the per-atom inner-decoder-smooth byproducts (#1097 / #1103) the
142    /// residual-gauge certificate's post-PIRLS atom inference reports consume.
143    ///
144    /// This is the post-fit harness seam: it needs the reconstruction target `Z`
145    /// (`target`) and the fitted dispersion `φ` (`dispersion`), both available
146    /// only after the joint fit converges and the engine has discarded `Z` from
147    /// the objective. For each atom `k` it captures the Gaussian-identity
148    /// penalized smooth of the atom's leading decoder output channel `j`
149    /// (largest column 2-norm of `B_k`) against its partial residual
150    /// `e_{i} = z_i − fitted_i + a_{ik} g_k(t_i)` on channel `j`, holding all
151    /// other atoms and the assignment fixed at the fitted optimum — exactly the
152    /// fixed snapshot ([`crate::identifiability::AtomInnerFit`]) the Riesz
153    /// debiasing and split-LRT smooth-structure e-value read.
154    ///
155    /// A pure read of the fitted state: it mutates only the diagnostic
156    /// `atom_inner_fits` field, never a loss / criterion / penalty / optimizer
157    /// state. Atoms with no active rows or a degenerate (rank-deficient,
158    /// non-SPD) inner Hessian get a `None` slot — the genuine prerequisite (an
159    /// SPD penalized inner Hessian on a non-empty active set) is absent there.
160    pub fn set_atom_inner_fits(
161        &mut self,
162        target: ArrayView2<'_, f64>,
163        rho: &SaeManifoldRho,
164        dispersion: f64,
165    ) -> Result<(), String> {
166        if !dispersion.is_finite() || dispersion <= 0.0 {
167            return Err(format!(
168                "SaeManifoldTerm::set_atom_inner_fits: dispersion must be finite and positive, got {dispersion}"
169            ));
170        }
171        let n = self.n_obs();
172        let p = self.output_dim();
173        let k_atoms = self.k_atoms();
174        if target.dim() != (n, p) {
175            return Err(format!(
176                "SaeManifoldTerm::set_atom_inner_fits: target {:?} != ({n}, {p})",
177                target.dim()
178            ));
179        }
180
181        // #1026 — `atom_inner_fits` is a pure diagnostic; skip its dense (N×K×P)
182        // tensor (~256 GiB at K=32768,P=32) past a cell ceiling — all-None slots,
183        // never OOM. The fit is unaffected; only this audit field is absent.
184        if n.saturating_mul(k_atoms).saturating_mul(p) > 64_000_000 {
185            self.atom_inner_fits = Some((0..k_atoms).map(|_| None).collect());
186            return Ok(());
187        }
188
189        // Settled per-row assignments and per-(row, atom) decoded outputs, so the
190        // per-atom partial residual is `e_k = (z − fitted) + a_k decoded_k`.
191        let mut assignments = Vec::with_capacity(n);
192        for row in 0..n {
193            assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
194        }
195        let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
196        let mut dbuf = vec![0.0_f64; p];
197        for row in 0..n {
198            for atom_idx in 0..k_atoms {
199                self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
200                for c in 0..p {
201                    decoded[[row, atom_idx, c]] = dbuf[c];
202                }
203            }
204        }
205        let mut fitted = Array2::<f64>::zeros((n, p));
206        for row in 0..n {
207            for atom_idx in 0..k_atoms {
208                let a = assignments[row][atom_idx];
209                if a == 0.0 {
210                    continue;
211                }
212                for c in 0..p {
213                    fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
214                }
215            }
216        }
217
218        let mut inner_fits: Vec<Option<crate::identifiability::AtomInnerFit>> =
219            Vec::with_capacity(k_atoms);
220        for atom_idx in 0..k_atoms {
221            inner_fits.push(self.build_atom_inner_fit(
222                atom_idx,
223                target,
224                &assignments,
225                decoded.view(),
226                fitted.view(),
227                dispersion,
228            )?);
229        }
230        self.atom_inner_fits = Some(inner_fits);
231        Ok(())
232    }
233
234    /// Build one atom's fixed inner-smooth snapshot for the post-PIRLS atom
235    /// inference reports, or `None` when the atom has no active rows or the
236    /// penalized inner Hessian is not SPD. Returns `Err` only on a structural
237    /// inconsistency (shape mismatch), never on a benign degenerate atom.
238    pub(crate) fn build_atom_inner_fit(
239        &self,
240        atom_idx: usize,
241        target: ArrayView2<'_, f64>,
242        assignments: &[Array1<f64>],
243        decoded: ArrayView3<'_, f64>,
244        fitted: ArrayView2<'_, f64>,
245        dispersion: f64,
246    ) -> Result<Option<crate::identifiability::AtomInnerFit>, String> {
247        let atom = &self.atoms[atom_idx];
248        let n = atom.n_obs();
249        let m = atom.basis_size();
250        let p = atom.output_dim();
251        if m == 0 || p == 0 {
252            return Ok(None);
253        }
254
255        // Leading decoder output channel j = argmax_j ‖B_k[:, j]‖, the channel
256        // that carries the atom's signal.
257        let mut j_lead = 0usize;
258        let mut best_norm = -1.0_f64;
259        for col in 0..p {
260            let mut norm = 0.0_f64;
261            for r in 0..m {
262                let v = atom.decoder_coefficients[[r, col]];
263                norm += v * v;
264            }
265            if norm > best_norm {
266                best_norm = norm;
267                j_lead = col;
268            }
269        }
270        let beta = atom.decoder_coefficients.column(j_lead).to_owned();
271
272        // Active rows: a_{ik} > 0.
273        let active: Vec<usize> = (0..n)
274            .filter(|&row| assignments[row][atom_idx] > 0.0)
275            .collect();
276        let n_active = active.len();
277        // The penalized smooth needs at least as many active rows as it has
278        // basis columns to give a non-degenerate data Gram; below that the inner
279        // fit's SPD prerequisite is genuinely unmet.
280        if n_active == 0 {
281            return Ok(None);
282        }
283
284        let mut design = Array2::<f64>::zeros((n_active, m));
285        let mut derivative_design = Array2::<f64>::zeros((n_active, m));
286        let mut row_scores = Array2::<f64>::zeros((n_active, m));
287        let mut weights = Array1::<f64>::zeros(n_active);
288        for (slot, &row) in active.iter().enumerate() {
289            let a_ik = assignments[row][atom_idx];
290            let w_i = a_ik * a_ik;
291            weights[slot] = w_i;
292            for col in 0..m {
293                design[[slot, col]] = atom.basis_values[[row, col]];
294                // Leading latent axis (axis 0) is the atom's primary coordinate;
295                // it is the one the average-derivative functional integrates.
296                derivative_design[[slot, col]] = atom.basis_jacobian[[row, col, 0]];
297            }
298            // Partial residual on channel j, then the inner-smooth working
299            // response z_i = e_i / a_ik so that w_i (z_i − Φᵀβ) = a_ik r_i.
300            let e_i = target[[row, j_lead]] - fitted[[row, j_lead]]
301                + a_ik * decoded[[row, atom_idx, j_lead]];
302            let mu_hat = design.row(slot).dot(&beta);
303            let z_i = e_i / a_ik;
304            let res_i = z_i - mu_hat;
305            // Gaussian-identity score s_i = −w_i res_i Φ_i / φ.
306            let scale = -w_i * res_i / dispersion;
307            for col in 0..m {
308                row_scores[[slot, col]] = scale * design[[slot, col]];
309            }
310        }
311
312        // Penalized inner Hessian H = ΦᵀWΦ + S̃_k.
313        let mut xtwx = Array2::<f64>::zeros((m, m));
314        for slot in 0..n_active {
315            let w_i = weights[slot];
316            for a in 0..m {
317                let xa = design[[slot, a]];
318                if xa == 0.0 {
319                    continue;
320                }
321                for b in 0..m {
322                    xtwx[[a, b]] += w_i * xa * design[[slot, b]];
323                }
324            }
325        }
326        let penalty = atom.smooth_penalty.clone();
327        if penalty.dim() != (m, m) {
328            return Err(format!(
329                "build_atom_inner_fit: atom {atom_idx} smooth penalty {:?} != ({m}, {m})",
330                penalty.dim()
331            ));
332        }
333        let penalized_hessian = &xtwx + &penalty;
334
335        // SPD prerequisite: the inner penalized Hessian must factor, else the
336        // atom's inner-smooth fit is degenerate and no report is producible.
337        if penalized_hessian.cholesky(Side::Lower).is_err() {
338            return Ok(None);
339        }
340
341        // Peak (largest fitted |g_k| on channel j) and mode (largest assignment
342        // mass) design rows, over the active set.
343        let mut peak_slot = 0usize;
344        let mut peak_val = -1.0_f64;
345        let mut mode_slot = 0usize;
346        let mut mode_mass = -1.0_f64;
347        for (slot, &row) in active.iter().enumerate() {
348            let g_val = design.row(slot).dot(&beta).abs();
349            if g_val > peak_val {
350                peak_val = g_val;
351                peak_slot = slot;
352            }
353            let mass = assignments[row][atom_idx];
354            if mass > mode_mass {
355                mode_mass = mass;
356                mode_slot = slot;
357            }
358        }
359        let peak_design_row = design.row(peak_slot).to_owned();
360        let mode_design_row = design.row(mode_slot).to_owned();
361
362        Ok(Some(crate::identifiability::AtomInnerFit {
363            design,
364            derivative_design,
365            beta,
366            penalty,
367            penalized_hessian,
368            row_scores,
369            weights,
370            dispersion,
371            peak_design_row,
372            mode_design_row,
373        }))
374    }
375
376    /// Profile the Gaussian reconstruction dispersion at the current seed
377    /// state. This is the scale used to make SAE penalty seeds dimensionless
378    /// before the outer rho search starts.
379    pub fn seed_reconstruction_dispersion(
380        &self,
381        target: ArrayView2<'_, f64>,
382    ) -> Result<f64, String> {
383        let fitted = self.try_fitted()?;
384        if fitted.dim() != target.dim() {
385            return Err(format!(
386                "SaeManifoldTerm::seed_reconstruction_dispersion: fitted {:?} != target {:?}",
387                fitted.dim(),
388                target.dim()
389            ));
390        }
391        let n_scalar = (target.nrows() * target.ncols()).max(1) as f64;
392        let mut rss = 0.0_f64;
393        for row in 0..target.nrows() {
394            for col in 0..target.ncols() {
395                let r = target[[row, col]] - fitted[[row, col]];
396                rss += r * r;
397            }
398        }
399        if !rss.is_finite() || rss < 0.0 {
400            return Err(format!(
401                "SaeManifoldTerm::seed_reconstruction_dispersion: non-finite seed RSS {rss}"
402            ));
403        }
404        Ok((rss / n_scalar).max(SAE_SEED_DISPERSION_FLOOR))
405    }
406
407    /// Install per-row design honesty weights (#991) — the `1/π` inclusion
408    /// corrections of a designed corpus subsample (see the field docs on
409    /// `row_loss_weights` for exactly where they enter the objective).
410    ///
411    /// Weights must be finite and strictly positive, one per term row. They
412    /// are self-normalized to mean `1.0` here (only the *relative* design
413    /// correction matters at the fitted sample size; the absolute `n/budget`
414    /// scale would silently inflate the dispersion estimate against the
415    /// sample-sized dof). Weights that are identically equal after
416    /// normalization (an exact full pass, or any uniform design) are stored
417    /// as `None`, so the unweighted path stays bit-for-bit identical rather
418    /// than "multiplied by 1.0".
419    pub fn set_row_loss_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
420        if weights.len() != self.n_obs() {
421            return Err(format!(
422                "SaeManifoldTerm::set_row_loss_weights: {} weights for {} rows",
423                weights.len(),
424                self.n_obs()
425            ));
426        }
427        if weights.is_empty() {
428            self.row_loss_weights = None;
429            return Ok(());
430        }
431        if !weights.iter().all(|w| w.is_finite() && *w > 0.0) {
432            return Err(
433                "SaeManifoldTerm::set_row_loss_weights: weights must be finite and strictly \
434                 positive"
435                    .to_string(),
436            );
437        }
438        let first = weights[0];
439        if weights.iter().all(|w| *w == first) {
440            // Uniform design (full pass, or flat measure): the normalized
441            // weight is exactly 1 everywhere — take the unweighted path.
442            self.row_loss_weights = None;
443            return Ok(());
444        }
445        let mean = weights.iter().sum::<f64>() / weights.len() as f64;
446        self.row_loss_weights = Some(weights.into_iter().map(|w| w / mean).collect());
447        Ok(())
448    }
449
450    /// The installed (mean-1 normalized) design honesty weights, `None` on the
451    /// exact unweighted path.
452    pub fn row_loss_weights(&self) -> Option<&[f64]> {
453        self.row_loss_weights.as_deref()
454    }
455
456    /// Drop any installed per-row reconstruction weights, returning the term to
457    /// the exact unweighted (full-pass) path. Used by the #997 structure-search
458    /// wiring to clear the internal estimation/evaluation mask off the adopted
459    /// term before the payload reconstruction is read over all rows.
460    pub fn clear_row_loss_weights(&mut self) {
461        self.row_loss_weights = None;
462    }
463
464    /// Huber-style OUTLIER-ROBUST per-row weights from the target activation
465    /// norms — the missing default *policy* for the existing
466    /// [`set_row_loss_weights`](Self::set_row_loss_weights) mechanism.
467    ///
468    /// The SAE fits unweighted least squares, which weights each token by its
469    /// squared residual ∝ `‖z_i‖²`. On real LLM residual streams the per-token
470    /// norm distribution is heavy-tailed (e.g. an OLMo mixed-layer slice has
471    /// `p99/median ≈ 4.7`), so a small **coherent** cluster of high-norm tokens —
472    /// typically special / attention-sink tokens, not semantic content —
473    /// dominates the objective (measured: the top 5% of tokens carry ~31% of the
474    /// total `‖z‖²` budget) and pulls dictionary atoms toward their direction.
475    /// Mean-centering does NOT address this (it is per-feature, not per-token).
476    ///
477    /// This returns Huber weights `w_i = min(1, δ·m / ‖z_i‖)` where `m` is the
478    /// MEDIAN token norm: tokens at or below `δ·m` keep full weight, higher-norm
479    /// tokens are downweighted so their objective share grows only LINEARLY (not
480    /// quadratically) with norm. `δ` is the robustness knob (`δ=1` thresholds at
481    /// the median; larger `δ` only touches the extreme tail). The result is
482    /// mean-normalized (overall objective scale preserved). OPT-IN: the caller
483    /// installs it via `set_row_loss_weights` — the default fit is unchanged.
484    pub fn robust_norm_row_weights(
485        target: ArrayView2<'_, f64>,
486        delta: f64,
487    ) -> Result<Vec<f64>, String> {
488        if !(delta.is_finite() && delta > 0.0) {
489            return Err(format!(
490                "robust_norm_row_weights: delta must be finite and positive; got {delta}"
491            ));
492        }
493        let n = target.nrows();
494        if n == 0 {
495            return Ok(Vec::new());
496        }
497        let norms: Vec<f64> = (0..n)
498            .map(|i| {
499                let r = target.row(i);
500                r.dot(&r).sqrt()
501            })
502            .collect();
503        let mut sorted = norms.clone();
504        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
505        // Median token norm (lower-median for even n; floored off zero so an
506        // all-zero/degenerate slice yields uniform weights instead of NaN).
507        let median = sorted[n / 2].max(f64::MIN_POSITIVE);
508        let thresh = delta * median;
509        let raw: Vec<f64> = norms
510            .iter()
511            .map(|&nm| if nm <= thresh { 1.0 } else { thresh / nm })
512            .collect();
513        let mean = raw.iter().sum::<f64>() / n as f64;
514        if !(mean.is_finite() && mean > 0.0) {
515            return Err("robust_norm_row_weights: degenerate weight normalizer".to_string());
516        }
517        Ok(raw.into_iter().map(|w| w / mean).collect())
518    }
519
520    /// Install the single per-row [`RowMetric`](gam_problem::RowMetric)
521    /// that both the reconstruction likelihood and the isometry gauge read.
522    /// Installing per-row output-Fisher factors here flips the provenance to
523    /// `OutputFisher` *and* is the only way the gauge acquires a non-identity
524    /// weight, so the two inner products cannot diverge. Passing a Euclidean
525    /// metric (or never calling this) keeps the bit-identical isotropic path.
526    ///
527    /// The metric's row count and output dimension must match the term.
528    pub fn set_row_metric(
529        &mut self,
530        metric: gam_problem::RowMetric,
531    ) -> Result<(), String> {
532        if metric.n_rows() != self.n_obs() {
533            return Err(format!(
534                "SaeManifoldTerm::set_row_metric: metric has {} rows but term has {}",
535                metric.n_rows(),
536                self.n_obs()
537            ));
538        }
539        if metric.p_out() != self.output_dim() {
540            return Err(format!(
541                "SaeManifoldTerm::set_row_metric: metric output dim {} but term has {}",
542                metric.p_out(),
543                self.output_dim()
544            ));
545        }
546        self.row_metric = Some(metric);
547        Ok(())
548    }
549
550    /// The installed per-row metric, if any. `None` ⇒ Euclidean / isotropic.
551    /// Consumed by the gauge wiring (to build the matching `WeightField`) and by
552    /// Object 4 (to read the [`MetricProvenance`](gam_problem::MetricProvenance)).
553    pub fn row_metric(&self) -> Option<&gam_problem::RowMetric> {
554        self.row_metric.as_ref()
555    }
556
557    /// The per-row inner product the additive diagnostics read through: the
558    /// installed [`RowMetric`](gam_problem::RowMetric) when one
559    /// was set (output-Fisher harvest present), otherwise a freshly-built
560    /// Euclidean metric of the term's own `(n_obs, output_dim)` shape. Either way
561    /// a metric always exists, so the diagnostics are never gated by a flag — the
562    /// Euclidean fallback is the bit-identical isotropic path.
563    pub(crate) fn diagnostic_metric(
564        &self,
565    ) -> Result<gam_problem::RowMetric, String> {
566        match self.row_metric() {
567            Some(metric) => Ok(metric.clone()),
568            None => {
569                gam_problem::RowMetric::euclidean(self.n_obs(), self.output_dim())
570            }
571        }
572    }
573
574    /// Build the additive post-fit diagnostic report for this fitted term: the
575    /// two-score per-atom [`AtomTwoLensReport`](crate::inference::atom_lens::AtomTwoLensReport)
576    /// (presence / behavioral coupling / discrepancy) and the residual-gauge
577    /// [`ResidualGaugeReport`](crate::identifiability::ResidualGaugeReport)
578    /// certificate.
579    ///
580    /// Both reports are read through the same single metric
581    /// ([`Self::diagnostic_metric`]): under a Euclidean / no-harvest provenance
582    /// the lens coupling is `None` and the gauge is certified under Euclidean
583    /// provenance — never an error, never gated by a flag (magic-by-default,
584    /// mirroring the metric selection itself).
585    ///
586    /// `per_atom_ard_variances`, when supplied, is one ARD variance vector per
587    /// atom (length = `latent_dim_k`), threaded into the certificate's
588    /// equal-ARD-rotation detection. `None` (or a per-atom `None`) ⇒ no ARD prior
589    /// on that atom. `isometry_pin_active` records whether an isometry gauge
590    /// penalty was installed on the fit: `false` escalates the certificate to the
591    /// `diffeomorphism-unpinned` verdict (the honest "no metric pin" statement),
592    /// exactly as the certificate's own escalation flag specifies.
593    ///
594    /// Pure read: it never mutates the term, never touches a loss / criterion /
595    /// penalty / optimizer state.
596    pub fn fit_diagnostics_report(
597        &self,
598        per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
599        isometry_pin_active: bool,
600        reconstruction_dispersion: Option<f64>,
601        assignments_override: Option<ArrayView2<'_, f64>>,
602    ) -> Result<SaeManifoldFitDiagnostics, String> {
603        if let Some(view) = assignments_override {
604            let n = self.n_obs();
605            let k = self.k_atoms();
606            if view.dim() != (n, k) {
607                return Err(format!(
608                    "fit_diagnostics_report: assignments_override shape {:?} must be ({n}, {k})",
609                    view.dim()
610                ));
611            }
612        }
613        let metric = self.diagnostic_metric()?;
614        let atom_two_lens =
615            crate::inference::atom_lens::atom_two_lens(self, &metric, assignments_override);
616
617        let (certificate_model, streamed_curvature) =
618            self.to_residual_gauge_model(metric, per_atom_ard_variances, isometry_pin_active)?;
619        // #998: within-atom gauge families are certified on their EXACT orbits
620        // in the model's own (decoder, coordinate) parameter space — compensated
621        // symmetries are data-nulls by construction there, no lowering-error
622        // calibration involved. This now holds whether or not an isometry pin is
623        // active:
624        //   * pin INACTIVE ⇒ the orbit verdict is the data residual alone (no
625        //     penalty operator);
626        //   * pin ACTIVE ⇒ the orbit verdict adds the isometry pin's orbit-space
627        //     curvature through an [`OrbitPenaltyOperator`] lowered from the
628        //     atom's second jet `Φ''` (the pullback-metric change along the orbit
629        //     differentiates `J = Φ'B` through `t`). A model-class symmetry that
630        //     preserves the metric stays a certified freedom; a non-isometric
631        //     orbit (a basis not closed under the action) is genuinely pinned.
632        // The relative-curvature fraction `cost/stiffness²` is invariant to the
633        // pin strength μ (both faces scale with μ), so the operator is built at a
634        // canonical unit weight. An atom whose basis exposes no analytic second
635        // jet supplies no operator and falls back to the data residual — never an
636        // error. Magic-by-default either way: the choice is derived from the fit,
637        // never a flag.
638        let views = self.atom_parameter_views();
639        let ops: Vec<Option<crate::identifiability::OrbitPenaltyOperator>> =
640            if isometry_pin_active {
641                views
642                    .iter()
643                    .map(|view| {
644                        view.as_ref().and_then(|v| {
645                            crate::identifiability::isometry_orbit_penalty_operator(
646                                v, 1.0,
647                            )
648                        })
649                    })
650                    .collect()
651            } else {
652                (0..self.k_atoms()).map(|_| None).collect()
653            };
654        let residual_gauge = if isometry_pin_active {
655            // The pin-active path consumes the per-row Jacobian curvature
656            // directly (the certificate_model retains it under a pin), so route
657            // through the non-streamed exact entry point.
658            crate::identifiability::residual_gauge_exact(
659                &certificate_model,
660                &views,
661                &ops,
662            )?
663        } else {
664            let (curvature_gram, root_rows) = streamed_curvature.ok_or_else(|| {
665                "fit_diagnostics_report: missing streamed residual-gauge curvature for unpinned exact path"
666                    .to_string()
667            })?;
668            crate::identifiability::residual_gauge_exact_from_curvature_gram(
669                &certificate_model,
670                &views,
671                &ops,
672                curvature_gram,
673                root_rows,
674            )?
675        };
676
677        // #1097 / #1103: per-atom Riesz-debiased functionals and the any-n-valid
678        // split-LRT smooth-structure e-value (non-constant vs constant inner
679        // decoder), read straight off the certificate model — which carries
680        // each atom's `inner_fit` snapshot when the caller harvested it via
681        // [`Self::set_atom_inner_fits`] before this report. Atoms without a
682        // harvested inner fit degrade their inference fields to `None` inside
683        // `atom_inference_reports`, so this is always populated (one entry per
684        // atom) and never gated by a flag.
685        let atom_inference =
686            crate::identifiability::atom_inference_reports(&certificate_model);
687
688        Ok(SaeManifoldFitDiagnostics {
689            atom_two_lens,
690            residual_gauge,
691            incoherence_report: match reconstruction_dispersion.or(self.certificate_dispersion) {
692                Some(dispersion) => Some(dictionary_incoherence_report_with_dispersion(
693                    self, dispersion,
694                )?),
695                None => None,
696            },
697            atom_inference,
698        })
699    }
700
701    /// Build the trust-diagnostics producer for the Python `diagnostics` block.
702    ///
703    /// `assignments` is supplied by the payload assembly site so top-k projection,
704    /// when requested, is reflected in coverage/frequency and in the tangent
705    /// spectra. The active threshold is shared with the atom lens so all
706    /// assignment-support diagnostics agree on what "active" means.
707    pub fn trust_diagnostics_report(
708        &self,
709        assignments: ArrayView2<'_, f64>,
710    ) -> Result<SaeTrustDiagnostics, String> {
711        let n = self.n_obs();
712        let k_atoms = self.k_atoms();
713        if assignments.dim() != (n, k_atoms) {
714            return Err(format!(
715                "trust_diagnostics_report: assignments shape {:?} must be ({n}, {k_atoms})",
716                assignments.dim()
717            ));
718        }
719        if !assignments.iter().all(|v| v.is_finite()) {
720            return Err("trust_diagnostics_report: assignments must be finite".to_string());
721        }
722        let metric = self.diagnostic_metric()?;
723        let active_threshold = crate::inference::atom_lens::SAE_TRUST_ACTIVE_MASS_FLOOR;
724        let mut atoms = Vec::with_capacity(k_atoms);
725        let mut atom_trust = Vec::with_capacity(k_atoms);
726        for (atom_idx, atom) in self.atoms.iter().enumerate() {
727            let mut active_token_count = 0usize;
728            let mut activation_sum = 0.0_f64;
729            for row in 0..n {
730                let mass = assignments[[row, atom_idx]];
731                activation_sum += mass;
732                if mass > active_threshold {
733                    active_token_count += 1;
734                }
735            }
736            let coverage = if n > 0 {
737                active_token_count as f64 / n as f64
738            } else {
739                0.0
740            };
741            let activation_frequency = if n > 0 {
742                activation_sum / n as f64
743            } else {
744                0.0
745            };
746            let (sigma_min_tangent, sigma_max_tangent) = self
747                .atom_tangent_spectrum_from_assignments(
748                    atom_idx,
749                    assignments,
750                    &metric,
751                    active_threshold,
752                )?;
753            let tangent_condition_score = if sigma_max_tangent > 0.0 {
754                (sigma_min_tangent / sigma_max_tangent).clamp(0.0, 1.0)
755            } else {
756                0.0
757            };
758            let trust_score = tangent_condition_score;
759            atom_trust.push(trust_score);
760            atoms.push(SaeAtomTrustDiagnostics {
761                trust_score,
762                sigma_min_tangent,
763                sigma_max_tangent,
764                tangent_condition_score,
765                coverage,
766                activation_frequency,
767                untyped: matches!(atom.basis_kind, SaeAtomBasisKind::Precomputed(_)),
768                active_token_count,
769            });
770        }
771        Ok(SaeTrustDiagnostics { atom_trust, atoms })
772    }
773
774    pub(crate) fn atom_tangent_spectrum_from_assignments(
775        &self,
776        atom_idx: usize,
777        assignments: ArrayView2<'_, f64>,
778        metric: &gam_problem::RowMetric,
779        active_threshold: f64,
780    ) -> Result<(f64, f64), String> {
781        let atom = &self.atoms[atom_idx];
782        let d = atom.latent_dim;
783        let p = self.output_dim();
784        if d == 0 || p == 0 {
785            return Ok((0.0, 0.0));
786        }
787        let mut gram = Array2::<f64>::zeros((d, d));
788        let mut active_mass_sum = 0.0_f64;
789        let mut jac_row = vec![0.0_f64; p * d];
790        for row in 0..self.n_obs() {
791            let mass = assignments[[row, atom_idx]];
792            if !(mass > active_threshold) {
793                continue;
794            }
795            active_mass_sum += mass;
796            for axis in 0..d {
797                let start = axis;
798                let mut tangent = vec![0.0_f64; p];
799                atom.fill_decoded_derivative_row(row, axis, &mut tangent);
800                for out in 0..p {
801                    jac_row[out * d + start] = tangent[out];
802                }
803            }
804            let row_pullback = metric.pullback(row, &jac_row, d);
805            for axis_a in 0..d {
806                for axis_b in 0..=axis_a {
807                    gram[[axis_a, axis_b]] += mass * row_pullback[[axis_a, axis_b]];
808                }
809            }
810            jac_row.fill(0.0);
811        }
812        if !(active_mass_sum > 0.0) {
813            return Ok((0.0, 0.0));
814        }
815        let inv_mass = 1.0 / active_mass_sum;
816        for axis_a in 0..d {
817            for axis_b in 0..=axis_a {
818                let value = gram[[axis_a, axis_b]] * inv_mass;
819                gram[[axis_a, axis_b]] = value;
820                gram[[axis_b, axis_a]] = value;
821            }
822        }
823        let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
824            format!(
825                "trust_diagnostics_report: atom {atom_idx} tangent eigendecomposition failed: {e}"
826            )
827        })?;
828        let mut sigma_min = f64::INFINITY;
829        let mut sigma_max = 0.0_f64;
830        for value in evals.iter().copied() {
831            let clamped = value.max(0.0);
832            let sigma = clamped.sqrt();
833            sigma_min = sigma_min.min(sigma);
834            sigma_max = sigma_max.max(sigma);
835        }
836        if sigma_min.is_finite() {
837            Ok((sigma_min, sigma_max))
838        } else {
839            Ok((0.0, 0.0))
840        }
841    }
842
843    /// Per-atom exact parameter-space views for the #998 certificate path:
844    /// the basis values / first-derivative jet, decoder coefficients, latent
845    /// coordinates, and assignment mass each atom was actually fitted with.
846    /// Sphere atoms get `None` (their chart's group action is nonlinear, so
847    /// the exact-orbit realisation does not apply and they stay on the frame
848    /// path), as does any atom whose coordinate chart width disagrees with its
849    /// latent dimension (a structurally inconsistent atom must not masquerade
850    /// as exactly certified).
851    pub(crate) fn atom_parameter_views(
852        &self,
853    ) -> Vec<Option<crate::identifiability::AtomParameterView>> {
854        let assignments = self.assignment.assignments();
855        let n = self.n_obs();
856        self.atoms
857            .iter()
858            .enumerate()
859            .map(|(k, atom)| {
860                if matches!(atom.basis_kind, SaeAtomBasisKind::Sphere) {
861                    return None;
862                }
863                let coords = self.assignment.coords[k].as_matrix().to_owned();
864                if coords.nrows() != n || coords.ncols() != atom.latent_dim {
865                    return None;
866                }
867                let mut activations = Array1::<f64>::zeros(n);
868                for row in 0..n {
869                    activations[row] = assignments[[row, k]];
870                }
871                // Second jet Φ'' (#998): supplied when the atom's evaluator
872                // exposes an analytic Hessian, so a pin-active fit can lower its
873                // orbit-space isometry penalty operator (the metric-change of the
874                // pullback gram differentiates Φ' through t). Absent ⇒ the orbit
875                // verdict stays on the data residual / no-pin path, never an
876                // error.
877                let basis_second_jet = atom
878                    .basis_evaluator
879                    .as_ref()
880                    .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()))
881                    .and_then(|res| res.ok());
882                Some(crate::identifiability::AtomParameterView {
883                    basis_values: atom.basis_values.clone(),
884                    basis_jacobian: atom.basis_jacobian.clone(),
885                    decoder: atom.decoder_coefficients.clone(),
886                    coords,
887                    activations,
888                    basis_second_jet,
889                })
890            })
891            .collect()
892    }
893
894    /// Lower this fitted term into the self-contained
895    /// [`FittedSaeManifold`](crate::identifiability::FittedSaeManifold) the
896    /// residual-gauge certificate consumes.
897    ///
898    /// The certificate's parameter space is the per-atom decoder **frame** — the
899    /// `(output_dim, latent_dim)` image of the atom's latent axes in output space.
900    /// We realise it as the active-mass-weighted mean decoder tangent
901    /// `frame_k[:, a] = (Σ_n a_{nk} · ∂g_k/∂t_a(n)) / Σ_n a_{nk}` over the atom's
902    /// active rows (the centroid decoder Jacobian columns the certificate docs
903    /// name). The per-row pinning Jacobian block `J_n ∈ ℝ^{p × param_dim}` is the
904    /// assignment-weighted per-row decoder tangent placed at each atom's frame
905    /// slot: column `(k, i, a)` of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i]` — exactly
906    /// the directions the reconstruction data gives cost to, in the same metric
907    /// the fit used (whitened by the certificate through `RowMetric`).
908    ///
909    /// The flattened frame layout matches the certificate's
910    /// `vec(frame_0) ⊕ vec(frame_1) ⊕ …`, row-major within each frame
911    /// (`frame_k[i, a]` at offset `atom_offset(k) + i·latent_dim_k + a`).
912    pub(crate) fn to_residual_gauge_model(
913        &self,
914        metric: gam_problem::RowMetric,
915        per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
916        isometry_pin_active: bool,
917    ) -> Result<
918        (
919            crate::identifiability::FittedSaeManifold,
920            Option<(Array2<f64>, usize)>,
921        ),
922        String,
923    > {
924        use crate::identifiability::{AtomTopology, FittedAtom, FittedSaeManifold};
925
926        let n = self.n_obs();
927        let p = self.output_dim();
928        let k = self.k_atoms();
929        let assignments = self.assignment.assignments();
930
931        // Per-atom frame `(p, d)` = active-mass-weighted mean decoder tangent,
932        // and the flattened-frame column offset bookkeeping for the joint
933        // parameter vector (`vec(frame_0) ⊕ …`, row-major within each frame).
934        let mut fitted_atoms: Vec<FittedAtom> = Vec::with_capacity(k);
935        let mut atom_offsets: Vec<usize> = Vec::with_capacity(k);
936        let mut atom_axis_dim: Vec<usize> = Vec::with_capacity(k);
937        let mut cursor = 0usize;
938        for (atom_idx, atom) in self.atoms.iter().enumerate() {
939            let d = atom.latent_dim;
940            let topology = match (&atom.basis_kind, d) {
941                (SaeAtomBasisKind::Periodic, 1) | (SaeAtomBasisKind::Torus, 1) => {
942                    AtomTopology::Circle
943                }
944                (SaeAtomBasisKind::Periodic, _) | (SaeAtomBasisKind::Torus, _) => {
945                    AtomTopology::Torus { latent_dim: d }
946                }
947                (SaeAtomBasisKind::Sphere, _) => AtomTopology::Sphere,
948                // `Cylinder` (`S¹ × ℝ`) has exactly one continuous gauge: the
949                // rotation (shift) of the periodic axis. The unbounded line axis
950                // carries no rotational gauge, and its translation is already
951                // pinned by the design's constant column — so the identifiability
952                // gauge is that of a single circle. Fixing it as `Torus` would
953                // over-impose a second (nonexistent) circle shift; fixing it as
954                // `EuclideanPatch { 2 }` would over-impose a frame rotation
955                // mixing the periodic and linear axes. `Circle` fixes the one
956                // real continuous gauge and leaves the linear axis ungauged.
957                (SaeAtomBasisKind::Cylinder, _) => AtomTopology::Circle,
958                (
959                    SaeAtomBasisKind::Linear
960                    | SaeAtomBasisKind::Duchon
961                    | SaeAtomBasisKind::EuclideanPatch
962                    | SaeAtomBasisKind::Poincare
963                    | SaeAtomBasisKind::Precomputed(_),
964                    _,
965                ) => AtomTopology::EuclideanPatch { latent_dim: d },
966            };
967
968            let mut frame = Array2::<f64>::zeros((p, d));
969            let mut active_mass = 0.0_f64;
970            let mut tangent = vec![0.0_f64; p];
971            for row in 0..n {
972                let a_nk = assignments[[row, atom_idx]];
973                if !(a_nk > 0.0) {
974                    continue;
975                }
976                active_mass += a_nk;
977                for axis in 0..d {
978                    atom.fill_decoded_derivative_row(row, axis, &mut tangent);
979                    for i in 0..p {
980                        frame[[i, axis]] += a_nk * tangent[i];
981                    }
982                }
983            }
984            if active_mass > 0.0 {
985                let inv = 1.0 / active_mass;
986                frame.mapv_inplace(|v| v * inv);
987            }
988
989            // #995 lowering-error scale: mass-weighted relative dispersion of
990            // the per-row tangents around the mean frame just built,
991            //   Σ_n a_n Σ_ax ‖t_ax(n) − frame[:,ax]‖² / Σ_n a_n Σ_ax ‖t_ax(n)‖².
992            // 0 ⇒ the frame represents every active row exactly (flat
993            // decoder); → 1 ⇒ the tangent field disperses so strongly (e.g. a
994            // full circle, whose tangents average out) that the mean-frame
995            // compression cannot distinguish gauge motion from curvature. The
996            // certificate calibrates its per-generator verdict tolerance to
997            // this scale so it never claims a pin it cannot resolve.
998            let mut disp_num = 0.0_f64;
999            let mut disp_den = 0.0_f64;
1000            for row in 0..n {
1001                let a_nk = assignments[[row, atom_idx]];
1002                if !(a_nk > 0.0) {
1003                    continue;
1004                }
1005                for axis in 0..d {
1006                    atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1007                    for i in 0..p {
1008                        let dev = tangent[i] - frame[[i, axis]];
1009                        disp_num += a_nk * dev * dev;
1010                        disp_den += a_nk * tangent[i] * tangent[i];
1011                    }
1012                }
1013            }
1014            let lowering_error = if disp_den > 0.0 {
1015                (disp_num / disp_den).clamp(0.0, 1.0)
1016            } else {
1017                0.0
1018            };
1019
1020            let ard_variances = per_atom_ard_variances
1021                .and_then(|all| all.get(atom_idx))
1022                .and_then(|opt| opt.clone())
1023                .filter(|v| v.len() == d);
1024
1025            fitted_atoms.push(FittedAtom {
1026                name: atom.name.clone(),
1027                topology,
1028                frame,
1029                ard_variances,
1030                lowering_error,
1031                // #1019: post-fit chart canonicalization (arc length for
1032                // d = 1, isometry-flow for d = 2 torus, flat-reference
1033                // isometry-flow for d = 2 free/patch, round-sphere
1034                // conformal-boost flow for d = 2 sphere atoms) pins the chart;
1035                // the certificate downgrades this atom's chart freedom to the
1036                // finite isometry group with PinnedByCanonicalization
1037                // provenance.
1038                chart_canonicalized: atom.chart_canonicalized
1039                    && (d == 1
1040                        || (d == 2
1041                            && matches!(
1042                                atom.basis_kind,
1043                                SaeAtomBasisKind::Torus
1044                                    | SaeAtomBasisKind::Linear
1045                                    | SaeAtomBasisKind::Duchon
1046                                    | SaeAtomBasisKind::EuclideanPatch
1047                                    | SaeAtomBasisKind::Sphere
1048                            ))),
1049                // #1097 / #1103: the per-atom inner-decoder-smooth snapshot,
1050                // attached when the post-fit harness has run
1051                // [`Self::set_atom_inner_fits`] (it needs the reconstruction
1052                // target Z, dropped from the objective at fit end). `None` on a
1053                // bare certificate-only model, or for a degenerate atom whose
1054                // inner Hessian was not SPD.
1055                inner_fit: self
1056                    .atom_inner_fits
1057                    .as_ref()
1058                    .and_then(|fits| fits.get(atom_idx))
1059                    .and_then(|slot| slot.clone()),
1060            });
1061            atom_offsets.push(cursor);
1062            atom_axis_dim.push(d);
1063            cursor += p * d;
1064        }
1065        let param_dim = cursor;
1066
1067        // Per-row pinning Jacobian `J_n ∈ ℝ^{p × param_dim}` flattened row-major
1068        // (`J_n[i, c] = jacobian_rows[n][i · param_dim + c]`). Column `(k, i', a)`
1069        // of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i']` placed at the atom-k frame slot
1070        // and read out on output coordinate `i = i'` (a frame perturbation of
1071        // output `i'` moves only the row's output coordinate `i'`).
1072        //
1073        // The pinned certificate still consumes the legacy row-block contract.
1074        // The unpinned exact path consumes only `RᵀR`, so stream each transient
1075        // row Jacobian through the metric whitening and discard it immediately.
1076        let (jacobian_rows, streamed_curvature) = if isometry_pin_active {
1077            let mut jacobian_rows: Vec<Vec<f64>> = Vec::with_capacity(n);
1078            let mut tangent = vec![0.0_f64; p];
1079            for row in 0..n {
1080                let mut j_flat = vec![0.0_f64; p * param_dim];
1081                for (atom_idx, atom) in self.atoms.iter().enumerate() {
1082                    let a_nk = assignments[[row, atom_idx]];
1083                    if !(a_nk > 0.0) {
1084                        continue;
1085                    }
1086                    let d = atom_axis_dim[atom_idx];
1087                    let base = atom_offsets[atom_idx];
1088                    for axis in 0..d {
1089                        atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1090                        for i in 0..p {
1091                            // Frame coordinate `(k, i, axis)` sits at column
1092                            // `base + i·d + axis`; it sources output coordinate `i`.
1093                            j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1094                        }
1095                    }
1096                }
1097                jacobian_rows.push(j_flat);
1098            }
1099            (jacobian_rows, None)
1100        } else {
1101            let streamed = self.residual_gauge_streamed_data_curvature(
1102                &metric,
1103                &atom_offsets,
1104                &atom_axis_dim,
1105                param_dim,
1106            )?;
1107            (Vec::new(), Some(streamed))
1108        };
1109
1110        // Isometry-penalty curvature root over the frame parameter space. When
1111        // the isometry gauge pin is active it gives curvature along every fitted
1112        // frame direction (it resists deviation of the decoder image from its
1113        // arc-length parameterization), so its row space is the span of the
1114        // per-atom frame columns: one root row per `(k, axis)` carrying that
1115        // atom's frame column at the atom's frame slot. Empty (`0 × param_dim`)
1116        // when the pin is inactive — exactly the certificate's escalation
1117        // condition to `diffeomorphism-unpinned`.
1118        let isometry_penalty_root = if isometry_pin_active && param_dim > 0 {
1119            let mut root_rows: Vec<Array1<f64>> = Vec::new();
1120            for (atom_idx, fitted) in fitted_atoms.iter().enumerate() {
1121                let d = atom_axis_dim[atom_idx];
1122                let base = atom_offsets[atom_idx];
1123                for axis in 0..d {
1124                    let mut r = Array1::<f64>::zeros(param_dim);
1125                    let mut any = false;
1126                    for i in 0..p {
1127                        let v = fitted.frame[[i, axis]];
1128                        if v != 0.0 {
1129                            any = true;
1130                        }
1131                        r[base + i * d + axis] = v;
1132                    }
1133                    if any {
1134                        root_rows.push(r);
1135                    }
1136                }
1137            }
1138            let mut root = Array2::<f64>::zeros((root_rows.len(), param_dim));
1139            for (ri, r) in root_rows.iter().enumerate() {
1140                root.row_mut(ri).assign(r);
1141            }
1142            root
1143        } else {
1144            Array2::<f64>::zeros((0, param_dim))
1145        };
1146
1147        Ok((
1148            FittedSaeManifold {
1149                atoms: fitted_atoms,
1150                jacobian_rows,
1151                isometry_penalty_root,
1152                metric,
1153            },
1154            streamed_curvature,
1155        ))
1156    }
1157
1158    pub(crate) fn residual_gauge_streamed_data_curvature(
1159        &self,
1160        metric: &gam_problem::RowMetric,
1161        atom_offsets: &[usize],
1162        atom_axis_dim: &[usize],
1163        param_dim: usize,
1164    ) -> Result<(Array2<f64>, usize), String> {
1165        let n = self.n_obs();
1166        let p = self.output_dim();
1167        if metric.p_out() != p {
1168            return Err(format!(
1169                "residual_gauge_streamed_data_curvature: metric output dim {} but term has {p}",
1170                metric.p_out()
1171            ));
1172        }
1173        let rank = metric.metric_rank();
1174        let mut gram = Array2::<f64>::zeros((param_dim, param_dim));
1175        if param_dim == 0 || n == 0 || rank == 0 {
1176            return Ok((gram, n * rank));
1177        }
1178
1179        let assignments = self.assignment.assignments();
1180        let mut tangent = vec![0.0_f64; p];
1181        let mut j_flat = vec![0.0_f64; p * param_dim];
1182        let mut root_row = Array1::<f64>::zeros(param_dim);
1183        for row in 0..n {
1184            j_flat.fill(0.0);
1185            for (atom_idx, atom) in self.atoms.iter().enumerate() {
1186                let a_nk = assignments[[row, atom_idx]];
1187                if !(a_nk > 0.0) {
1188                    continue;
1189                }
1190                let d = atom_axis_dim[atom_idx];
1191                let base = atom_offsets[atom_idx];
1192                for axis in 0..d {
1193                    atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1194                    for i in 0..p {
1195                        j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1196                    }
1197                }
1198            }
1199
1200            if metric.drives_gauge() {
1201                for r in 0..rank {
1202                    root_row.fill(0.0);
1203                    for c in 0..param_dim {
1204                        let mut acc = 0.0_f64;
1205                        for i in 0..p {
1206                            acc += metric.factor_entry(row, i, r) * j_flat[i * param_dim + c];
1207                        }
1208                        root_row[c] = acc;
1209                    }
1210                    let row_slice = root_row.as_slice().ok_or_else(|| {
1211                        "residual_gauge_streamed_data_curvature: non-contiguous root row"
1212                            .to_string()
1213                    })?;
1214                    Self::accumulate_residual_gauge_gram_row(&mut gram, row_slice);
1215                }
1216            } else {
1217                for i in 0..p {
1218                    let start = i * param_dim;
1219                    let end = start + param_dim;
1220                    Self::accumulate_residual_gauge_gram_row(&mut gram, &j_flat[start..end]);
1221                }
1222            }
1223        }
1224
1225        for a in 0..param_dim {
1226            for b in 0..a {
1227                gram[[b, a]] = gram[[a, b]];
1228            }
1229        }
1230        Ok((gram, n * rank))
1231    }
1232
1233    pub(crate) fn accumulate_residual_gauge_gram_row(gram: &mut Array2<f64>, row: &[f64]) {
1234        for a in 0..row.len() {
1235            let va = row[a];
1236            if va == 0.0 {
1237                continue;
1238            }
1239            for b in 0..=a {
1240                let vb = row[b];
1241                if vb != 0.0 {
1242                    gram[[a, b]] += va * vb;
1243                }
1244            }
1245        }
1246    }
1247
1248    pub fn set_temperature_schedule(
1249        &mut self,
1250        sched: GumbelTemperatureSchedule,
1251    ) -> Result<(), String> {
1252        sched.validate()?;
1253        self.assignment
1254            .mode
1255            .set_temperature(sched.current_tau(sched.iter_count))?;
1256        self.temperature_schedule = Some(sched);
1257        Ok(())
1258    }
1259
1260    pub(crate) fn advance_temperature_schedule(&mut self) -> Result<Option<f64>, String> {
1261        let Some(schedule) = self.temperature_schedule.as_mut() else {
1262            return Ok(None);
1263        };
1264        schedule.validate()?;
1265        let tau = schedule.step();
1266        self.assignment.mode.set_temperature(tau)?;
1267        Ok(Some(tau))
1268    }
1269
1270    pub fn n_obs(&self) -> usize {
1271        self.assignment.n_obs()
1272    }
1273
1274    pub fn k_atoms(&self) -> usize {
1275        self.atoms.len()
1276    }
1277
1278    /// Auto-derived in-core vs streaming plan for SAE Arrow-Schur work.
1279    ///
1280    /// This is intentionally not user-configurable: the route follows the
1281    /// retained full-batch working-set estimate and the currently selected GPU
1282    /// memory budget when CUDA is usable, otherwise a conservative host budget.
1283    pub fn streaming_plan(&self) -> SaeStreamingPlan {
1284        let n_obs = self.n_obs();
1285        let total_basis: usize = self.atoms.iter().map(|atom| atom.basis_size()).sum();
1286        let d_max = self
1287            .atoms
1288            .iter()
1289            .map(|atom| atom.latent_dim)
1290            .max()
1291            .unwrap_or(0);
1292        let border_dim = if self.any_frame_active() {
1293            self.factored_border_dim()
1294        } else {
1295            self.beta_dim()
1296        };
1297        sae_streaming_plan_for_shape(n_obs, total_basis, self.k_atoms(), d_max, border_dim)
1298    }
1299
1300    /// Construction-time validation: every Psi-tier analytic penalty in the
1301    /// registry must be dispatchable into the SAE arrow-Schur row layout.
1302    ///
1303    /// Two invariants are enforced upfront so the dispatch loop in
1304    /// `add_sae_analytic_penalty_contributions` is total (no runtime
1305    /// "unsupported penalty" fallthrough, no per-call K-gating):
1306    ///
1307    /// 1. Every Psi-tier penalty is either in [`sae_penalty_is_row_block_supported`],
1308    ///    or `NuclearNorm` (which is redirected to the per-atom decoder (β) block
1309    ///    rather than the coord "t" row block). Assignment sparsity penalties
1310    ///    (`IBPAssignment`, `SoftmaxAssignmentSparsity`) are refused because the SAE
1311    ///    term already owns them through its built-in assignment path
1312    ///    (`loss.assignment_sparsity`). Penalty kinds with cross-row structure
1313    ///    (`TotalVariation`, `Monotonicity`, `BlockSparsity`,
1314    ///    `IvaeRidgeMeanGauge`, `Orthogonality`, `NestedPrefix`,
1315    ///    `SheafConsistency`) cannot be expressed in the SAE row-block layout
1316    ///    and are refused here.
1317    ///
1318    /// 2. If any Psi-tier row-block penalty is present, every atom shares
1319    ///    the same coord latent dim. The current registry model carries one
1320    ///    `latent_dim` per descriptor (the "t" latent block declares one
1321    ///    `d` value); per-atom dispatch with heterogeneous `d_k` would
1322    ///    require per-atom registry entries or per-kind in-place
1323    ///    reshaping. Mixed-d row-block fits are rejected with an actionable
1324    ///    error pointing at the configuration mismatch.
1325    ///
1326    /// The K=1 case trivially satisfies (2). Beta-tier and rho-tier
1327    /// penalties are not constrained here.
1328    pub(crate) fn validate_analytic_penalty_registry(
1329        &self,
1330        registry: &AnalyticPenaltyRegistry,
1331    ) -> Result<(), String> {
1332        let mut row_block_penalty_present = false;
1333        for penalty in &registry.penalties {
1334            if penalty.tier() != PenaltyTier::Psi {
1335                continue;
1336            }
1337            if matches!(
1338                penalty,
1339                AnalyticPenaltyKind::IBPAssignment(_)
1340                    | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
1341            ) {
1342                return Err(format!(
1343                    "SAE-manifold term refuses analytic penalty {:?}: assignment sparsity \
1344                     is owned by the built-in SAE assignment path (loss.assignment_sparsity). \
1345                     Registering it would double-count the objective and gradient",
1346                    penalty.name()
1347                ));
1348            }
1349            // NuclearNorm is redirected to the per-atom decoder (β) block in
1350            // `add_sae_beta_penalty` (it penalizes each atom's decoder matrix
1351            // singular spectrum, i.e. its embedding rank), so it bypasses the
1352            // coord "t" row-block requirement below.
1353            if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) {
1354                continue;
1355            }
1356            if !sae_penalty_is_row_block_supported(penalty) {
1357                return Err(format!(
1358                    "SAE-manifold term refuses analytic penalty {:?}: this kind \
1359                     has cross-row structure and cannot be expressed in the \
1360                     arrow-Schur row layout. Use only row-block-supported \
1361                     coord penalties (ARD, BlockOrthogonality, \
1362                     Sparsity/TopK/JumpReLU, RowPrecisionPrior, \
1363                     ParametricRowPrecisionPrior, ScadMcp, Isometry) on the \
1364                     coord latent block, or move the penalty to a non-SAE \
1365                     term",
1366                    penalty.name()
1367                ));
1368            }
1369            row_block_penalty_present = true;
1370        }
1371        if row_block_penalty_present {
1372            let mut dims = self.assignment.coords.iter().map(|c| c.latent_dim());
1373            if let Some(first) = dims.next() {
1374                if let Some(mismatch) = dims.find(|d| *d != first) {
1375                    return Err(format!(
1376                        "SAE-manifold term refuses row-block analytic penalty: \
1377                         atoms have heterogeneous coord latent dims (saw {first} \
1378                         and {mismatch}). Row-block penalties (ARD, \
1379                         BlockOrthogonality, ...) target the unified \"t\" \
1380                         latent block whose declared `d` matches one shape; \
1381                         per-atom dispatch with mixed `d_k` would silently \
1382                         truncate or expand axes. Configure all atoms with the \
1383                         same `atom_dim`, or split the row-block penalty into \
1384                         per-atom descriptors keyed to per-atom latent blocks"
1385                    ));
1386                }
1387            }
1388        }
1389        Ok(())
1390    }
1391
1392    pub fn output_dim(&self) -> usize {
1393        self.atoms[0].output_dim()
1394    }
1395
1396    pub fn beta_dim(&self) -> usize {
1397        let p = self.output_dim();
1398        self.atoms.iter().map(|a| a.basis_size() * p).sum()
1399    }
1400
1401    pub(crate) fn take_border_hbb_workspace(&mut self, border_dim: usize) -> Array2<f64> {
1402        let mut workspace =
1403            std::mem::replace(&mut self.border_hbb_workspace, Array2::<f64>::zeros((0, 0)));
1404        if workspace.dim() != (border_dim, border_dim) {
1405            workspace = Array2::<f64>::zeros((border_dim, border_dim));
1406        } else {
1407            workspace.fill(0.0);
1408        }
1409        workspace
1410    }
1411
1412    pub(crate) fn reclaim_border_hbb_workspace(&mut self, sys: &mut ArrowSchurSystem) {
1413        let workspace = std::mem::replace(&mut sys.hbb, Array2::<f64>::zeros((0, 0)));
1414        self.border_hbb_workspace = workspace;
1415    }
1416
1417    /// Factored arrow-Schur border dimension `Σ_k M_k · r_k` (issue #972): the
1418    /// number of decoder coordinates the border actually carries once the
1419    /// low-rank Grassmann frames are profiled out. Atoms with no active frame
1420    /// contribute their full `M_k · p` (`r_k == p`), so on the all-full-`B` path
1421    /// this equals [`Self::beta_dim`]. The border Cholesky / evidence log-det
1422    /// scale with THIS count, not `beta_dim`.
1423    pub fn factored_border_dim(&self) -> usize {
1424        self.atoms.iter().map(|a| a.border_coeff_count()).sum()
1425    }
1426
1427    /// Total profiled-out Grassmann manifold dimension `Σ_k r_k·(p − r_k)` across
1428    /// all active frames (issue #972). This is the count of decoder-frame degrees
1429    /// of freedom estimated OUTSIDE the border by closed-form polar steps, and it
1430    /// must enter the Laplace evidence dimension accounting (evidence honesty):
1431    /// the profiled frame is a MAP point on `∏_k Gr(r_k, p)`, contributing this
1432    /// many free dimensions to the model. `0` when every atom is on the full-`B`
1433    /// path. Threaded into [`Self::reml_occam_term`].
1434    pub fn grassmann_evidence_dimension(&self) -> usize {
1435        self.atoms
1436            .iter()
1437            .map(|a| a.frame_manifold_dimension())
1438            .sum()
1439    }
1440
1441    /// True iff any atom has an active low-rank Grassmann frame (issue #972).
1442    pub fn frames_active(&self) -> bool {
1443        self.atoms.iter().any(|a| a.decoder_frame.is_some())
1444    }
1445
1446    /// Alias of [`Self::frames_active`] (issue #972 / #977 T1): the predicate the
1447    /// assembly / step-lift branch on to decide whether the β-tier is built in
1448    /// the factored coordinate layout. Named to read as the question
1449    /// "is the factored path engaged?" at its call sites.
1450    pub fn any_frame_active(&self) -> bool {
1451        self.frames_active()
1452    }
1453
1454    /// Per-atom column offsets of the *factored* border (issue #972 / #977 T1):
1455    /// the running prefix sum of `M_k · r_k`, one entry per atom (the same
1456    /// convention as [`Self::beta_offsets`]). This is the start of each atom's
1457    /// `C_k` block in the reduced border vector; on the all-full-`B` path it
1458    /// equals `beta_offsets`. Distinct from [`Self::factored_border_offsets`]
1459    /// only in name (both compute the identical prefix sum) — this method is the
1460    /// one the frame transform reads, mirroring `beta_offsets` at the call site.
1461    pub fn factored_beta_offsets(&self) -> Vec<usize> {
1462        self.factored_border_offsets()
1463    }
1464
1465    /// Frame output matrix `U_k ∈ St(p, r_k)` for atom `k` (issue #972 / #977 T1).
1466    /// Returns the active frame `U_k` (`p × r_k`) when atom `k` is framed, else
1467    /// the identity `I_p` (the `r_k == p`, `U_k == I_p` full-`B` special case) so
1468    /// the projection / lift code is uniform across a mixed dictionary.
1469    pub fn frame_output_matrix(&self, atom_idx: usize) -> Array2<f64> {
1470        let atom = &self.atoms[atom_idx];
1471        match &atom.decoder_frame {
1472            Some(frame) => frame.frame().to_owned(),
1473            None => Array2::<f64>::eye(atom.output_dim()),
1474        }
1475    }
1476
1477    /// Per-pair frame factor `W_{ij} = U_iᵀ U_j` (`r_i × r_j`) used as the output
1478    /// factor of the factored data β-Hessian block `G_{ij} ⊗ W_{ij}` (issue #972
1479    /// / #977 T1). When both atoms are framed this is the dense principal-angle
1480    /// cosine matrix between the two frames; for `i == j` with an orthonormal
1481    /// frame it is exactly `I_{r_i}`; for any un-framed atom the corresponding
1482    /// `U` is `I_p`, so a same-atom un-framed pair gives `I_p` (the clean full-`B`
1483    /// `G ⊗ I_p` collapse) and a framed/un-framed cross pair gives the rectangular
1484    /// `U_iᵀ` / `U_j` overlap.
1485    pub fn frame_cross_factor(&self, atom_i: usize, atom_j: usize) -> Array2<f64> {
1486        let ui = self.frame_output_matrix(atom_i);
1487        let uj = self.frame_output_matrix(atom_j);
1488        // `U_iᵀ U_j`: `(r_i × p) · (p × r_j)`. `fast_atb` forms `U_iᵀ U_j` directly.
1489        fast_atb(&ui, &uj)
1490    }
1491
1492    /// Per-atom column offsets of the *factored* border (issue #972): the
1493    /// running prefix sum of `M_k · r_k`. The analogue of [`Self::beta_offsets`]
1494    /// for the reduced coordinate layout — atom `k`'s `C_k` occupies
1495    /// `[factored_border_offsets()[k] .. + M_k·r_k)`. On the full-`B` path this
1496    /// equals `beta_offsets`.
1497    pub fn factored_border_offsets(&self) -> Vec<usize> {
1498        let mut out = Vec::with_capacity(self.k_atoms());
1499        let mut cursor = 0usize;
1500        for atom in &self.atoms {
1501            out.push(cursor);
1502            cursor += atom.border_coeff_count();
1503        }
1504        out
1505    }
1506
1507    /// Assemble the factored border coordinate vector `C = [vec(C_1); …; vec(C_K)]`
1508    /// in row-major `C_k[m, j] → C[off_k + m·r_k + j]` layout (issue #972).
1509    ///
1510    /// This is the reduced state the arrow-Schur border carries when frames are
1511    /// active: its length is [`Self::factored_border_dim`] (`Σ M_k·r_k`), the
1512    /// border-size invariant verified by [`grassmann_assert_border_dim_invariant`].
1513    /// Atoms
1514    /// without an active frame contribute their full `vec(B_k)` (their `r_k == p`
1515    /// coordinates are the decoder itself), so on the all-full-`B` path this
1516    /// reproduces [`Self::flatten_beta`].
1517    pub fn flatten_factored_border(&self) -> Result<Array1<f64>, String> {
1518        let offsets = self.factored_border_offsets();
1519        let mut out = Array1::<f64>::zeros(self.factored_border_dim());
1520        for (atom_idx, atom) in self.atoms.iter().enumerate() {
1521            let off = offsets[atom_idx];
1522            let r = atom.border_frame_rank();
1523            let m = atom.basis_size();
1524            let coords = match atom.factored_coordinates()? {
1525                Some(c) => c,
1526                // Full-`B` path: the decoder itself is the coordinate matrix.
1527                None => atom.decoder_coefficients.clone(),
1528            };
1529            for basis_col in 0..m {
1530                for j in 0..r {
1531                    out[off + basis_col * r + j] = coords[[basis_col, j]];
1532                }
1533            }
1534        }
1535        Ok(out)
1536    }
1537
1538    /// Scatter a factored border coordinate vector `C` (length
1539    /// [`Self::factored_border_dim`]) back into the per-atom decoders, refreshing
1540    /// each `decoder_coefficients = C_k · U_kᵀ` so the full-`B` consumers stay
1541    /// consistent after a factored border solve (issue #972). The inverse of
1542    /// [`Self::flatten_factored_border`].
1543    pub fn scatter_factored_border(&mut self, border: ArrayView1<'_, f64>) -> Result<(), String> {
1544        let expected = self.factored_border_dim();
1545        if border.len() != expected {
1546            return Err(format!(
1547                "SaeManifoldTerm::scatter_factored_border: border length {} must equal \
1548                 factored border dim {expected}",
1549                border.len()
1550            ));
1551        }
1552        let offsets = self.factored_border_offsets();
1553        for atom_idx in 0..self.atoms.len() {
1554            let off = offsets[atom_idx];
1555            let (r, m, has_frame) = {
1556                let atom = &self.atoms[atom_idx];
1557                (
1558                    atom.border_frame_rank(),
1559                    atom.basis_size(),
1560                    atom.decoder_frame.is_some(),
1561                )
1562            };
1563            let mut coords = Array2::<f64>::zeros((m, r));
1564            for basis_col in 0..m {
1565                for j in 0..r {
1566                    coords[[basis_col, j]] = border[off + basis_col * r + j];
1567                }
1568            }
1569            if has_frame {
1570                self.atoms[atom_idx].set_factored_coordinates(coords.view())?;
1571            } else {
1572                // Full-`B` path: the coordinates ARE the decoder.
1573                self.atoms[atom_idx].decoder_coefficients = coords;
1574            }
1575        }
1576        Ok(())
1577    }
1578
1579    /// Auto-derive and install low-rank Grassmann decoder frames across all
1580    /// atoms (issue #972) — magic-by-default, no flag. Each atom independently
1581    /// activates its frame iff the factorization materially shrinks its border
1582    /// (see [`SaeManifoldAtom::maybe_activate_decoder_frame`]). Returns the
1583    /// number of atoms that activated a frame. Idempotent: re-running re-derives
1584    /// each frame from the current decoder.
1585    ///
1586    /// The decision keys on the *frontier* regime the issue targets: at large
1587    /// ambient `p` the full border `Σ M_k · p` reaches `10^7`–`10^8` and the
1588    /// border Cholesky dies, while the decoder's effective column rank `r` stays
1589    /// `≪ p`. Small-`p` atoms (where `r` cannot beat the activation margin)
1590    /// keep the bit-for-bit full-`B` path, so the small-model evidence is
1591    /// unchanged (verified by `factored_evidence_matches_full_b_at_small_p`).
1592    pub fn auto_activate_decoder_frames(&mut self) -> Result<usize, String> {
1593        let mut activated = 0usize;
1594        for atom in &mut self.atoms {
1595            let expected_rank = atom.decoder_frame_activation_rank()?;
1596            match (
1597                expected_rank,
1598                atom.decoder_frame.as_ref().map(GrassmannFrame::rank),
1599            ) {
1600                (Some(expected), Some(current)) if expected == current => {
1601                    continue;
1602                }
1603                (None, Some(_)) => {
1604                    atom.deactivate_decoder_frame();
1605                    continue;
1606                }
1607                (None, None) => {
1608                    continue;
1609                }
1610                (Some(_), _) => {}
1611            }
1612            if atom.maybe_activate_decoder_frame()?.is_some() {
1613                activated += 1;
1614            }
1615        }
1616        Ok(activated)
1617    }
1618
1619    /// Reconcile decoder-frame activation before a fit entry point. The
1620    /// user-facing `auto_activate_decoder_frames` contract returns only newly
1621    /// installed frames; this helper enforces the stronger invariant the large-p
1622    /// solver needs: every atom whose current decoder satisfies the activation
1623    /// predicate has an active frame after the pass.
1624    pub(crate) fn ensure_decoder_frames_active_for_current_decoder(
1625        &mut self,
1626    ) -> Result<(), String> {
1627        self.auto_activate_decoder_frames()?;
1628        for (atom_idx, atom) in self.atoms.iter().enumerate() {
1629            let expected_rank = atom.decoder_frame_activation_rank()?;
1630            if let Some(expected_rank) = expected_rank {
1631                match atom.decoder_frame.as_ref() {
1632                    Some(frame) if frame.rank() == expected_rank => {}
1633                    Some(frame) => {
1634                        return Err(format!(
1635                            "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1636                             atom {atom_idx} frame rank {} must equal audited rank {expected_rank}",
1637                            frame.rank()
1638                        ));
1639                    }
1640                    None => {
1641                        return Err(format!(
1642                            "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1643                             atom {atom_idx} has audited rank {expected_rank} but no active frame"
1644                        ));
1645                    }
1646                }
1647            } else if atom.decoder_frame.is_some() {
1648                return Err(format!(
1649                    "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1650                     atom {atom_idx} kept a frame after the full-B predicate won"
1651                ));
1652            }
1653        }
1654        Ok(())
1655    }
1656
1657    /// Closed-form streaming POLAR refresh of every ACTIVE decoder frame from the
1658    /// current data evidence (issue #972 / #977 T1) — the U-block of the
1659    /// alternating block-coordinate ascent that complements the border's
1660    /// C-block Newton step.
1661    ///
1662    /// For each framed atom `k` we accumulate the `p × r_k` cross-moment
1663    ///   `A_k = Σ_n a_{n,k} · e_{n,k} · ĉ_{n,k}ᵀ`,
1664    /// where `e_{n,k} = z_n − Σ_{k'≠k} a_{n,k'}·decoded_{k'}(n)` is the row's
1665    /// partial reconstruction residual (everything except atom `k`) and
1666    /// `ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^{r_k}` is atom `k`'s in-span decoded
1667    /// coordinate. The polar factor `U_new = polar(A_k)` is the closed-form MAP
1668    /// frame on `Gr(r_k, p)` given the C-coordinates held fixed — the same
1669    /// `O(p r²)` thin SVD the issue prescribes, run OUTSIDE the border. The frame
1670    /// is then re-installed and the decoder re-projected onto it so the
1671    /// authoritative `B_k = C_k U_newᵀ` and the `(C_k, U_new)` pair stay
1672    /// consistent (a no-op in span for a truly rank-`r` atom). Un-framed atoms
1673    /// are skipped. Returns the number of frames refreshed.
1674    pub(crate) fn refresh_active_frames_from_data(
1675        &mut self,
1676        target: ArrayView2<'_, f64>,
1677        rho: &SaeManifoldRho,
1678    ) -> Result<usize, String> {
1679        let n = self.n_obs();
1680        let p = self.output_dim();
1681        let k_atoms = self.k_atoms();
1682        if n == 0 {
1683            return Ok(0);
1684        }
1685        // Per-row assignments and per-(row, atom) decoded outputs, computed once.
1686        let mut assignments = Vec::with_capacity(n);
1687        for row in 0..n {
1688            assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
1689        }
1690        let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
1691        let mut dbuf = vec![0.0_f64; p];
1692        for row in 0..n {
1693            for atom_idx in 0..k_atoms {
1694                self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
1695                for c in 0..p {
1696                    decoded[[row, atom_idx, c]] = dbuf[c];
1697                }
1698            }
1699        }
1700        // Full fitted reconstruction `Σ_k a_k decoded_k`, so the per-atom partial
1701        // residual is `e_k = (z − fitted) + a_k decoded_k` (add atom k back in).
1702        let mut fitted = Array2::<f64>::zeros((n, p));
1703        for row in 0..n {
1704            for atom_idx in 0..k_atoms {
1705                let a = assignments[row][atom_idx];
1706                if a == 0.0 {
1707                    continue;
1708                }
1709                for c in 0..p {
1710                    fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
1711                }
1712            }
1713        }
1714        let mut refreshed = 0usize;
1715        for atom_idx in 0..k_atoms {
1716            // Only atoms with an active frame are refreshed.
1717            let Some(coords_c) = self.atoms[atom_idx].factored_coordinates()? else {
1718                continue;
1719            };
1720            let r = self.atoms[atom_idx].border_frame_rank();
1721            let m = self.atoms[atom_idx].basis_size();
1722            // Accumulate `A_k = Σ_n a_k · e_{n,k} · ĉ_{n,k}ᵀ` directly (p × r).
1723            let mut cross = GrassmannCrossMoment::new(p, r);
1724            // Build per-row p-target `a_k·e_k` and r-coord `a_k·ĉ` batched, then
1725            // accumulate as one outer-product sum. `accumulate` forms
1726            // `targetsᵀ·coords`, so scaling EITHER side by `a_k` once gives the
1727            // `a_k²` weight on the cross-moment that matches the C-block normal
1728            // equations (residual leg carries `a_k`, coordinate leg carries
1729            // `a_k`).
1730            let mut targets = Array2::<f64>::zeros((n, p));
1731            let mut rcoords = Array2::<f64>::zeros((n, r));
1732            for row in 0..n {
1733                let a = assignments[row][atom_idx];
1734                // Partial residual e_{n,k} = z_n − (fitted − a_k decoded_k).
1735                for c in 0..p {
1736                    let e = target[[row, c]] - fitted[[row, c]] + a * decoded[[row, atom_idx, c]];
1737                    targets[[row, c]] = a * e;
1738                }
1739                // In-span coordinate ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^r.
1740                for j in 0..r {
1741                    let mut acc = 0.0_f64;
1742                    for basis_col in 0..m {
1743                        acc += self.atoms[atom_idx].basis_values[[row, basis_col]]
1744                            * coords_c[[basis_col, j]];
1745                    }
1746                    rcoords[[row, j]] = a * acc;
1747                }
1748            }
1749            cross.accumulate(targets.view(), rcoords.view())?;
1750            // `polar(A_k)` is well-defined only when the moment is non-trivial;
1751            // a zero moment (e.g. a fully collapsed atom) leaves the frame as-is.
1752            if cross.moment().iter().all(|&v| v == 0.0) {
1753                continue;
1754            }
1755            self.atoms[atom_idx].refresh_frame_from_cross_moment(cross.moment())?;
1756            refreshed += 1;
1757        }
1758        Ok(refreshed)
1759    }
1760
1761    pub fn beta_offsets(&self) -> Vec<usize> {
1762        let p = self.output_dim();
1763        let mut out = Vec::with_capacity(self.k_atoms());
1764        let mut cursor = 0usize;
1765        for atom in &self.atoms {
1766            out.push(cursor);
1767            cursor += atom.basis_size() * p;
1768        }
1769        out
1770    }
1771
1772    /// Per-atom β column ranges for the block-Jacobi Schur preconditioner.
1773    ///
1774    /// Returns one `Range<usize>` per atom, covering that atom's decoder
1775    /// coefficients in the flat β vector:
1776    ///   `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
1777    ///
1778    /// Pass to [`ArrowSchurSystem::set_block_offsets`] so that
1779    /// [`gam_solve::arrow_schur::JacobiPreconditioner`] builds one dense
1780    /// Schur sub-block per atom instead of scalar-diagonal inversion.
1781    pub fn beta_block_offsets(&self) -> Arc<[std::ops::Range<usize>]> {
1782        let p = self.output_dim();
1783        let mut ranges: Vec<std::ops::Range<usize>> = Vec::with_capacity(self.k_atoms());
1784        let mut cursor = 0usize;
1785        for atom in &self.atoms {
1786            let width = atom.basis_size() * p;
1787            ranges.push(cursor..cursor + width);
1788            cursor += width;
1789        }
1790        Arc::from(ranges.into_boxed_slice())
1791    }
1792
1793    /// Decide whether the sparse per-row active-set layout is engaged for a
1794    /// dense-weight assignment mode, and if so derive the per-row active-atom
1795    /// cap and magnitude cutoff.
1796    ///
1797    /// #1408: this plan is mode-agnostic. `assemble_arrow_schur` consults it
1798    /// directly for IBP-MAP, and for `AssignmentMode::Softmax` via
1799    /// [`Self::softmax_active_plan`], which tightens it with an explicit `top_k`
1800    /// (`softmax_active_cap`). Softmax therefore engages the compact active-set
1801    /// layout whenever `top_k` or the budget bounds the active set (the
1802    /// active-sub-block Gershgorin majorizer + coherent logdet/θ-adjoint are
1803    /// landed — see `SaeRowLayout`'s doc); it keeps the full `K`-atom layout only
1804    /// when neither lever engages. The decision is auto-derived from
1805    /// the problem size and the device/host working-set budget — never a CLI flag
1806    /// or kwarg. JumpReLU is not handled here (it always uses its structural gate
1807    /// via [`SaeRowLayout::from_jumprelu`]). The dense Gauss-Newton data Gram `G`
1808    /// is `(m_total × m_total)` f64; if its dense form fits the budget we keep
1809    /// the exact full-support solve (every atom active per row), so small-`K`
1810    /// problems are bit-for-bit unchanged. Above that, we cap each row to the
1811    /// `k_active` atoms that make the *sparse* Gram fit the same budget, with a
1812    /// relative magnitude cutoff that drops assignment mass contributing
1813    /// negligible `O(a²)` curvature.
1814    ///
1815    /// Returns `Some((k_active_cap, cutoff))` to engage sparsity, or `None` to
1816    /// keep the dense full-support layout.
1817    pub(crate) fn sparse_active_plan(&self) -> Option<(usize, f64)> {
1818        // The per-row Riemannian tangent projection for non-Euclidean atom
1819        // latents is now applied directly on the compact active-set rows (see
1820        // the `Some(layout)` arm in `assemble_arrow_schur`, via
1821        // `compact_row_ext_manifold_and_point`), which rebuilds each row's
1822        // product manifold in its compact column order and applies the SAME
1823        // gt/htt/htbeta + Kronecker-Jacobian projections the dense path uses. So
1824        // the sparse plan may engage on curved ext-coord manifolds (circle /
1825        // torus / sphere atoms) — the affordability lever for manifold-SAE at
1826        // large `K`, where the dense `K²` co-assignment Gram is the cost. (The
1827        // former `is_euclidean()`-only restriction punted every curved atom to
1828        // the dense layout; it is lifted.) The host/device in-core budget is the
1829        // single gate now; it is parameterised in `sparse_active_plan_for_budget`
1830        // so the engagement regression can pin a small budget without allocating
1831        // a multi-GB dense Gram.
1832        let budget = match crate::gpu::device_runtime::GpuRuntime::global() {
1833            // Allow up to one quarter of the AGGREGATE device budget for the dense
1834            // Gram, matching the streaming dispatcher's in-core fraction. The
1835            // per-atom-pair Gram blocks fan out across the whole device pool, so
1836            // the in-core fraction sums every ordinal's budget, not just the
1837            // primary's.
1838            Some(rt) => {
1839                let aggregate: usize = rt
1840                    .device_ordinals()
1841                    .iter()
1842                    .map(|&ord| rt.memory_budget_for(ord))
1843                    .sum();
1844                aggregate / 4
1845            }
1846            None => sae_host_in_core_budget_bytes().0,
1847        };
1848        self.sparse_active_plan_for_budget(budget)
1849    }
1850
1851    /// Budget-parameterised core of [`Self::sparse_active_plan`]. The dense data
1852    /// Gram footprint `(m_total · m_total) f64` is compared against `budget`; a
1853    /// term whose dense Gram exceeds the budget engages the compact active-set
1854    /// plan (returns `Some((k_active_cap, cutoff))`), regardless of whether any
1855    /// atom latent is curved. Pulled out so the curved-atom engagement
1856    /// regression can pin a small budget deterministically.
1857    pub(crate) fn sparse_active_plan_for_budget(&self, budget: usize) -> Option<(usize, f64)> {
1858        // Relative magnitude cutoff: assignment mass below this fraction of the
1859        // row's peak `|a_k|` enters the Gram only as `O(a²)` curvature and is
1860        // dropped. Chosen so dropped terms are ~1e-6 of the peak self-coupling.
1861        const RELATIVE_CUTOFF: f64 = 1.0e-3;
1862
1863        let k_atoms = self.k_atoms();
1864        if k_atoms <= 1 {
1865            return None;
1866        }
1867        let p = self.output_dim();
1868        let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
1869        // Dense data Gram footprint: (m_total · m_total) f64.
1870        let dense_gram_bytes = m_total
1871            .saturating_mul(m_total)
1872            .saturating_mul(SAE_BYTES_PER_F64);
1873        if dense_gram_bytes <= budget {
1874            return None;
1875        }
1876
1877        // Sparse Gram footprint scales with the per-row active basis count
1878        // `k_active · m_atom`. Solve for the largest `k_active` whose sparse
1879        // Gram `(k_active · m_atom)²` still fits the budget.
1880        let m_atom = (m_total as f64 / k_atoms as f64).max(1.0);
1881        let max_active_basis = ((budget as f64 / SAE_BYTES_PER_F64 as f64).sqrt() / m_atom).floor();
1882        let k_active_cap = (max_active_basis as usize).clamp(1, k_atoms);
1883        // p does not enter the Gram dimension (it is carried by the `⊗ I_p`
1884        // structure), but a degenerate `p == 0` term has no decoder columns.
1885        if p == 0 {
1886            return None;
1887        }
1888        Some((k_active_cap, RELATIVE_CUTOFF))
1889    }
1890
1891    /// #1408/#1409 — per-row active-set plan for the Softmax assignment.
1892    ///
1893    /// Engages the compact top-`k` row layout when EITHER the user supplied a
1894    /// hard `top_k` cap ([`Self::softmax_active_cap`], `1 <= k < K`) OR the
1895    /// dense data Gram exceeds the in-core budget (the same memory lever the
1896    /// IBP path uses via [`Self::sparse_active_plan`]). The returned
1897    /// `k_active_cap` is the tighter of the two, so an explicit `top_k`
1898    /// genuinely bounds the optimization even below the memory threshold and a
1899    /// large-K budget breach still bounds it when no `top_k` is set. Returns
1900    /// `None` (keep the exact full-`K` dense softmax layout) when neither lever
1901    /// engages.
1902    ///
1903    /// The cutoff is the same relative magnitude floor as the budget plan
1904    /// (`1e-3` of the row peak); under an explicit `top_k` cap alone (no budget
1905    /// breach) it is `0.0` so exactly the top-`k` atoms are retained.
1906    pub(crate) fn softmax_active_plan(&self) -> Option<(usize, f64)> {
1907        if self.k_atoms() <= 1 {
1908            return None;
1909        }
1910        let budget_plan = self.sparse_active_plan();
1911        match (self.softmax_active_cap, budget_plan) {
1912            (Some(cap), Some((budget_cap, cutoff))) => Some((cap.min(budget_cap), cutoff)),
1913            // Explicit cap only: retain exactly the top-`cap` atoms (no extra
1914            // magnitude cutoff beyond the cap).
1915            (Some(cap), None) => Some((cap, 0.0)),
1916            (None, plan) => plan,
1917        }
1918    }
1919
1920    pub fn flatten_beta(&self) -> Array1<f64> {
1921        let p = self.output_dim();
1922        let offsets = self.beta_offsets();
1923        let mut out = Array1::<f64>::zeros(self.beta_dim());
1924        for (atom_idx, atom) in self.atoms.iter().enumerate() {
1925            let m = atom.basis_size();
1926            let off = offsets[atom_idx];
1927            for basis_col in 0..m {
1928                for out_col in 0..p {
1929                    out[off + basis_col * p + out_col] =
1930                        atom.decoder_coefficients[[basis_col, out_col]];
1931                }
1932            }
1933        }
1934        out
1935    }
1936
1937    pub fn set_flat_beta(&mut self, beta: ArrayView1<'_, f64>) -> Result<(), String> {
1938        if beta.len() != self.beta_dim() {
1939            return Err(format!(
1940                "set_flat_beta: beta length {} != expected {}",
1941                beta.len(),
1942                self.beta_dim()
1943            ));
1944        }
1945        let p = self.output_dim();
1946        let offsets = self.beta_offsets();
1947        for (atom_idx, atom) in self.atoms.iter_mut().enumerate() {
1948            let m = atom.basis_size();
1949            let off = offsets[atom_idx];
1950            for basis_col in 0..m {
1951                for out_col in 0..p {
1952                    atom.decoder_coefficients[[basis_col, out_col]] =
1953                        beta[off + basis_col * p + out_col];
1954                }
1955            }
1956        }
1957        Ok(())
1958    }
1959
1960    pub fn refit_decoder_least_squares_at_current_state(
1961        &mut self,
1962        target: ArrayView2<'_, f64>,
1963        rho: Option<&SaeManifoldRho>,
1964    ) -> Result<(), String> {
1965        let n = self.n_obs();
1966        let p = self.output_dim();
1967        if target.dim() != (n, p) {
1968            return Err(format!(
1969                "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: target shape {:?} != ({n}, {p})",
1970                target.dim()
1971            ));
1972        }
1973        let k_atoms = self.k_atoms();
1974        let offsets = self.beta_offsets();
1975        let m_total = self.beta_dim() / p;
1976        let mut design = Array2::<f64>::zeros((n, m_total));
1977        for row in 0..n {
1978            let assignments = match rho {
1979                Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
1980                None => self.assignment.try_assignments_row(row)?,
1981            };
1982            for atom_idx in 0..k_atoms {
1983                let atom = &self.atoms[atom_idx];
1984                let weight = assignments[atom_idx];
1985                let m = atom.basis_size();
1986                let off = offsets[atom_idx] / p;
1987                for basis_col in 0..m {
1988                    design[[row, off + basis_col]] = weight * atom.basis_values[[row, basis_col]];
1989                }
1990            }
1991        }
1992        let beta = solve_design_least_squares(design.view(), target)?;
1993        if beta.dim() != (m_total, p) {
1994            return Err(format!(
1995                "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: beta shape {:?} != ({m_total}, {p})",
1996                beta.dim()
1997            ));
1998        }
1999        for atom_idx in 0..k_atoms {
2000            let m = self.atoms[atom_idx].basis_size();
2001            let off = offsets[atom_idx] / p;
2002            for basis_col in 0..m {
2003                for out_col in 0..p {
2004                    self.atoms[atom_idx].decoder_coefficients[[basis_col, out_col]] =
2005                        beta[[off + basis_col, out_col]];
2006                }
2007            }
2008            self.atoms[atom_idx].refresh_intrinsic_smooth_penalty();
2009        }
2010        Ok(())
2011    }
2012
2013    pub fn fitted(&self) -> Array2<f64> {
2014        self.try_fitted().expect("assignment logits must be finite")
2015    }
2016
2017    /// The #1026 hybrid-collapse substitution map: `atom_idx → &AtomLinearImage`
2018    /// for every `d = 1` slot whose post-fit verdict selected its straight
2019    /// (`Θ → 0`) sub-model. Empty when no report has been computed
2020    /// (`hybrid_split_report == None`, e.g. mid-fit) or no slot collapsed. The
2021    /// SINGLE source of the collapse policy — every reconstruction path (the
2022    /// rho-keyed `try_fitted_with_rho`, the explicit-assignment
2023    /// [`Self::reconstruct_from_assignments`] used by the top-k projection)
2024    /// reads it so train, OOS, and top-k reconstructions decode collapsed slots
2025    /// identically (#1228, #1233).
2026    pub(crate) fn hybrid_linear_image_map(
2027        &self,
2028    ) -> std::collections::HashMap<usize, &crate::hybrid_split::AtomLinearImage> {
2029        // A fitted term carries its collapse policy on the post-fit
2030        // `hybrid_split_report`; an OOS term carries the same trained images on
2031        // `oos_linear_images` (#1228). At most one is `Some` in practice, but
2032        // prefer the report when both are present.
2033        if let Some(report) = self.hybrid_split_report.as_ref() {
2034            return report
2035                .verdicts
2036                .iter()
2037                .filter_map(|v| v.linear_image.as_ref().map(|img| (img.atom_idx, img)))
2038                .collect();
2039        }
2040        if let Some(images) = self.oos_linear_images.as_ref() {
2041            return images.iter().map(|img| (img.atom_idx, img)).collect();
2042        }
2043        std::collections::HashMap::new()
2044    }
2045
2046    /// #1228 — attach the trained dictionary's hybrid-collapsed linear images to
2047    /// this (typically OOS) term so its reconstruction (`fitted` / the top-k
2048    /// assembler) decodes verdict-linear `d = 1` slots by the SAME straight
2049    /// sub-model the training reconstruction used, instead of the original
2050    /// curved decoder. Each image's `atom_idx` must index a real slot; an image
2051    /// whose channel count `p` disagrees with this term's output dim, or whose
2052    /// `atom_idx` is out of range, is rejected so a stale/mismatched payload
2053    /// cannot silently corrupt the reconstruction. Pass an empty slice (or never
2054    /// call this) for an all-curved OOS reconstruction.
2055    ///
2056    /// `pub` (not `pub(crate)`): this is part of the FFI surface — the gam-pyffi
2057    /// crate calls it from `latent_basis_and_sae_ffi.rs` to attach a trained
2058    /// dictionary's hybrid-linear images to an OOS reconstruction term (#1228).
2059    /// Downgrading it to `pub(crate)` breaks the gam-pyffi cdylib build with
2060    /// E0624 (the gam lib still compiles, so the lib build does not catch it).
2061    pub fn set_hybrid_linear_images(
2062        &mut self,
2063        images: Vec<crate::hybrid_split::AtomLinearImage>,
2064    ) -> Result<(), String> {
2065        let p = self.output_dim();
2066        let k_atoms = self.k_atoms();
2067        for img in &images {
2068            if img.atom_idx >= k_atoms {
2069                return Err(format!(
2070                    "set_hybrid_linear_images: atom_idx {} out of range (k_atoms={k_atoms})",
2071                    img.atom_idx
2072                ));
2073            }
2074            if img.b0.len() != p || img.b1.len() != p {
2075                return Err(format!(
2076                    "set_hybrid_linear_images: atom {} linear image has p=({}, {}) != output_dim {p}",
2077                    img.atom_idx,
2078                    img.b0.len(),
2079                    img.b1.len()
2080                ));
2081            }
2082            // #1777 — a collapse-rescued image's projection direction `v` must
2083            // have one entry per output channel so `coordinate_from_residual` can
2084            // project a held-out row's `p`-vector residual onto it.
2085            if let Some(v) = img.v.as_ref() {
2086                if v.len() != p {
2087                    return Err(format!(
2088                        "set_hybrid_linear_images: atom {} projection direction v has len {} != output_dim {p}",
2089                        img.atom_idx,
2090                        v.len()
2091                    ));
2092                }
2093            }
2094            if self.atoms[img.atom_idx].latent_dim != 1 {
2095                return Err(format!(
2096                    "set_hybrid_linear_images: atom {} is not d=1; only d=1 slots collapse to a straight image",
2097                    img.atom_idx
2098                ));
2099            }
2100        }
2101        self.oos_linear_images = if images.is_empty() {
2102            None
2103        } else {
2104            Some(images)
2105        };
2106        Ok(())
2107    }
2108
2109    /// Assemble the reconstruction `Σ_k a[i,k]·g_k(t_{ik})` from an EXPLICIT
2110    /// per-row assignment matrix (e.g. a hard top-k projection of the fitted
2111    /// soft assignments), honouring the #1026 hybrid collapse when `collapse` is
2112    /// set: a verdict-linear `d = 1` slot decodes its straight sub-model image
2113    /// instead of its curved curve, exactly as the production `try_fitted` does.
2114    /// This is the shared assembler the FFI top-k path uses so the projected
2115    /// reconstruction composes with hybrid collapse (#1233) instead of
2116    /// re-deriving the curved image by hand and silently bypassing the verdict.
2117    /// The atom coordinates (`t`) and decoded curves are the term's own fitted
2118    /// ones; only the assignment masses come from `assignments`.
2119    pub fn reconstruct_from_assignments(
2120        &self,
2121        assignments: ArrayView2<'_, f64>,
2122        collapse: bool,
2123    ) -> Result<Array2<f64>, String> {
2124        let n = self.n_obs();
2125        let p = self.output_dim();
2126        let k_atoms = self.k_atoms();
2127        if assignments.dim() != (n, k_atoms) {
2128            return Err(format!(
2129                "SaeManifoldTerm::reconstruct_from_assignments: assignments {:?} != ({n}, {k_atoms})",
2130                assignments.dim()
2131            ));
2132        }
2133        let linear_images = if collapse {
2134            self.hybrid_linear_image_map()
2135        } else {
2136            std::collections::HashMap::new()
2137        };
2138        let mut out = Array2::<f64>::zeros((n, p));
2139        let mut g_buf = vec![0.0_f64; p];
2140        for row in 0..n {
2141            for atom_idx in 0..k_atoms {
2142                let a_k = assignments[[row, atom_idx]];
2143                if a_k == 0.0 {
2144                    continue;
2145                }
2146                if let Some(image) = linear_images.get(&atom_idx) {
2147                    let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2148                    image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2149                } else {
2150                    self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2151                }
2152                let mut out_row = out.row_mut(row);
2153                for out_col in 0..p {
2154                    out_row[out_col] += a_k * g_buf[out_col];
2155                }
2156            }
2157        }
2158        Ok(out)
2159    }
2160
2161    /// #1777 — TARGET-AWARE hybrid-collapsed reconstruction: identical to
2162    /// [`Self::try_fitted`] except that a #1026 COLLAPSE-RESCUED `d = 1` slot
2163    /// (whose linear image carries a projection direction `v`) recomputes each
2164    /// row's coordinate from THIS `target` as
2165    /// `uᵢ = ⟨y_i − Σ_{j≠k} f_j(x_i), v⟩` — its own leave-this-atom-out residual
2166    /// projected onto `v` — instead of reading the train-only cached
2167    /// `row_codes[i]` (or, worse, the atom's collapsed own coordinate `own_t`).
2168    ///
2169    /// This is the SAME math the train split used to build `row_codes`
2170    /// (`row_codes[i] = ⟨target_resid[i], v⟩`), so on the TRAIN rows/target it
2171    /// reproduces the train reconstruction bit-for-bit, and on a HELD-OUT
2172    /// rows/target it produces the correct out-of-sample coordinate — train and
2173    /// OOS are ONE model. Ordinary (non-rescued) straight images and curved slots
2174    /// are decoded exactly as in [`Self::try_fitted`]; they ignore `target`.
2175    ///
2176    /// `rho` selects the assignment-mass resolution (`Some` uses the ρ-keyed
2177    /// gates, `None` the persisted gates), mirroring [`Self::try_fitted_with_rho`].
2178    /// This is the reconstruction path an OOS predict should call once the trained
2179    /// hybrid-linear images are attached via [`Self::set_hybrid_linear_images`].
2180    pub fn try_fitted_target_aware(
2181        &self,
2182        target: ArrayView2<'_, f64>,
2183        rho: Option<&SaeManifoldRho>,
2184    ) -> Result<Array2<f64>, String> {
2185        let n = self.n_obs();
2186        let p = self.output_dim();
2187        let k_atoms = self.k_atoms();
2188        if target.dim() != (n, p) {
2189            return Err(format!(
2190                "SaeManifoldTerm::try_fitted_target_aware: target {:?} != ({n}, {p})",
2191                target.dim()
2192            ));
2193        }
2194        let linear_images = self.hybrid_linear_image_map();
2195        // The all-curved reconstruction `full = Σ_j a_j·γ_j`, the same quantity the
2196        // train split's `target_resid_for` subtracts. A rescued slot `k`'s
2197        // leave-this-atom-out residual is then `target − full + a_k·γ_k`.
2198        let full_curved = self.try_fitted_with_rho(rho, false)?;
2199        let mut out = Array2::<f64>::zeros((n, p));
2200        let mut g_buf = vec![0.0_f64; p];
2201        let mut decoded_buf = vec![0.0_f64; p];
2202        let mut resid_buf = vec![0.0_f64; p];
2203        for row in 0..n {
2204            let a = match rho {
2205                Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2206                None => self.assignment.try_assignments_row(row)?,
2207            };
2208            for atom_idx in 0..k_atoms {
2209                let a_k = a[atom_idx];
2210                if let Some(image) = linear_images.get(&atom_idx) {
2211                    if image.is_collapse_rescued() {
2212                        // Recompute this row's coordinate from its own
2213                        // leave-this-atom-out residual projected onto `v`.
2214                        self.atoms[atom_idx].fill_decoded_row(row, &mut decoded_buf);
2215                        for col in 0..p {
2216                            resid_buf[col] =
2217                                target[[row, col]] - full_curved[[row, col]] + a_k * decoded_buf[col];
2218                        }
2219                        // `coordinate_from_residual` returns `None` only on a
2220                        // length mismatch (impossible here — validated at attach)
2221                        // or a non-rescued image (excluded by the branch); fall
2222                        // back to the train code/own-coord path if it ever does.
2223                        let coord = image
2224                            .coordinate_from_residual(&resid_buf)
2225                            .unwrap_or_else(|| {
2226                                let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2227                                image.coordinate_for_row(row, own_t)
2228                            });
2229                        image.fill_row(coord, &mut g_buf);
2230                    } else {
2231                        // Ordinary straight image: decode at the atom's own coord.
2232                        let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2233                        image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2234                    }
2235                } else {
2236                    self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2237                }
2238                let mut out_row = out.row_mut(row);
2239                for out_col in 0..p {
2240                    out_row[out_col] += a_k * g_buf[out_col];
2241                }
2242            }
2243        }
2244        Ok(out)
2245    }
2246
2247    pub fn try_fitted(&self) -> Result<Array2<f64>, String> {
2248        // Production/user-facing reconstruction: honours the #1026 hybrid-split
2249        // verdict (verdict-linear `d = 1` slots decode their straight sub-model).
2250        self.try_fitted_with_rho(None, true)
2251    }
2252
2253    pub(crate) fn try_fitted_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
2254        // Internal/fitting reconstruction: the pure CURVED image (the joint fit
2255        // and the #1026 adjudication both require the uncollapsed curve).
2256        self.try_fitted_with_rho(Some(rho), false)
2257    }
2258
2259    pub(crate) fn try_fitted_with_rho(
2260        &self,
2261        rho: Option<&SaeManifoldRho>,
2262        collapse: bool,
2263    ) -> Result<Array2<f64>, String> {
2264        let n = self.n_obs();
2265        let p = self.output_dim();
2266        let k_atoms = self.k_atoms();
2267        let mut out = Array2::<f64>::zeros((n, p));
2268        // #1026 — the curved/linear hybrid-split verdict is LOAD-BEARING on the
2269        // production reconstruction, not just a side report. When
2270        // [`Self::compute_hybrid_split_report`] (run post-fit in
2271        // `canonicalize_charts_post_fit`) adjudicated a `d = 1` atom's evidence
2272        // in favour of its straight (Θ→0) sub-model, the model's output
2273        // reconstruction (`fitted()` / `try_fitted` → predict and the user-facing
2274        // output) decodes that slot with its fitted linear image instead of its
2275        // curved decoded curve. The linear images are coordinate-keyed and
2276        // rho-independent (exact weighted-LS lines realised inside the
2277        // adjudication — no re-fit, no #1051 outer continuation).
2278        //
2279        // The collapse engages only when the caller asks for it (`collapse`):
2280        // the production `try_fitted` path and the explicit
2281        // `hybrid_collapsed_reconstruction` entry point. The pure-curved
2282        // `try_fitted_for_rho` opts out — the joint fit's loss/assembly optimise
2283        // the curved decoder coefficients and must see the curved image, and the
2284        // #1026 adjudication itself compares the curved fit against its straight
2285        // sub-model — both require the uncollapsed curve. (During fitting the
2286        // report is `None` regardless; it is only computed post-fit.)
2287        let linear_images = if collapse {
2288            self.hybrid_linear_image_map()
2289        } else {
2290            std::collections::HashMap::new()
2291        };
2292        // Reuse a single scratch buffer across all (row, atom) pairs instead of
2293        // allocating a fresh `Array1<f64>` of length p per call.
2294        let mut g_buf = vec![0.0_f64; p];
2295        for row in 0..n {
2296            let a = match rho {
2297                Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2298                None => self.assignment.try_assignments_row(row)?,
2299            };
2300            for atom_idx in 0..k_atoms {
2301                let a_k = a[atom_idx];
2302                if let Some(image) = linear_images.get(&atom_idx) {
2303                    // Verdict-linear slot: substitute the straight sub-model image
2304                    // at this row's fitted on-atom coordinate — or, for a #1026
2305                    // collapse-rescued slot, at its fresh per-row code.
2306                    let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2307                    image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2308                } else {
2309                    self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2310                }
2311                let mut out_row = out.row_mut(row);
2312                for out_col in 0..p {
2313                    out_row[out_col] += a_k * g_buf[out_col];
2314                }
2315            }
2316        }
2317        Ok(out)
2318    }
2319
2320    /// Per-atom **leave-one-atom-out (LOAO) explained-variance contribution**
2321    /// (#1026): for each atom `k`, the drop in reconstruction explained variance
2322    /// `ΔEV_k = EV(full) − EV(full ⊖ atom_k)` when that atom's contribution
2323    /// `a[i,k]·g_k(coord[i,k])` is removed from the assembled reconstruction and
2324    /// nothing else is refit. Because every atom adds linearly into the same
2325    /// fitted reconstruction (`fitted[i] = Σ_k a[i,k]·g_k`), zeroing one atom is
2326    /// the exact "this atom withheld" counterfactual, and the EV it was earning
2327    /// is `EV(full) − EV(without k)`. This is the per-atom held-out EV
2328    /// attribution the #1026 roadmap pairs with each atom's fitted turning `Θ`:
2329    /// a `Θ ≈ 0` atom earning a large `ΔEV` is a linear-tail direction; a
2330    /// high-`Θ` atom earning a large `ΔEV` is a genuine curved family carrying
2331    /// reconstruction it would otherwise shatter into `N(ε) ≈ Θ/(2√(2ε))` linear
2332    /// directions. Pure read-only diagnostic — never mutates any atom.
2333    ///
2334    /// Returns one `Option<f64>` per atom in atom order; `None` for an atom
2335    /// whose ⊖-reconstruction EV is undefined (degenerate target variance), and
2336    /// `None` for the whole vector if the full-reconstruction EV is undefined.
2337    /// #1026: the load-bearing curved-vs-linear hybrid-split verdict for the
2338    /// fitted dictionary, or `None` until [`Self::canonicalize_charts_post_fit`]
2339    /// has run (or when no `d = 1` atom is eligible). Surfaced in the Python model
2340    /// output so the user sees which atoms genuinely earn their curvature.
2341    pub fn hybrid_split_report(
2342        &self,
2343    ) -> Option<&crate::hybrid_split::SaeHybridSplitReport> {
2344        self.hybrid_split_report.as_ref()
2345    }
2346
2347    /// Build the #1026 curved-vs-linear hybrid-split report by adjudicating each
2348    /// eligible `d = 1` atom's fitted curved image against its straight (linear
2349    /// special-case) sub-model on the common rank-aware Laplace evidence scale.
2350    ///
2351    /// Both candidates are scored against the SAME data — the atom's
2352    /// leave-this-atom-out response residual `y_resp = target − (full − a_k·γ_k)`
2353    /// (#1202) — over its assigned rows: the curved candidate predicts its actual
2354    /// mass-scaled contribution `a_k·γ_k`, the linear candidate the best
2355    /// mass-weighted straight line fit to `y_resp` (the collapsed linear lane —
2356    /// closed form, NOT the broken euclidean outer fit path of #1051). Linear is
2357    /// the curved family's nested `Θ = 0` sub-model on common data, so the
2358    /// per-slot evidence argmin is a genuine match-or-beat comparison. Eligible
2359    /// atoms are `d = 1` atoms with an installed evaluator at the full curvature
2360    /// dial (`homotopy_eta == 1.0`) whose live coordinate dim still matches the
2361    /// atom's latent dim. Returns `None` when no reconstruction `target` is
2362    /// supplied (there is no data to adjudicate against).
2363    pub fn compute_hybrid_split_report(
2364        &self,
2365        rho: &SaeManifoldRho,
2366        target: Option<ArrayView2<'_, f64>>,
2367    ) -> Result<Option<crate::hybrid_split::SaeHybridSplitReport>, String> {
2368        let n = self.n_obs();
2369        let p = self.output_dim();
2370        // Per-atom held-out `ΔEV_k` (leave-one-atom-out explained-variance drop),
2371        // paired with each atom's fitted turning Θ onto the verdict so the report
2372        // carries the #1026 `(Θ, ΔEV)` frontier point as structured data. Absent
2373        // when no reconstruction target is supplied.
2374        let loao_ev: Vec<Option<f64>> = match target {
2375            Some(t) => self.per_atom_loao_explained_variance(t, rho)?,
2376            None => vec![None; self.k_atoms()],
2377        };
2378        let delta_ev_for =
2379            |atom_idx: usize| -> Option<f64> { loao_ev.get(atom_idx).copied().flatten() };
2380        // The common-evidence comparison (#1202) scores both candidates against
2381        // the response data the atom is responsible for. That requires a target;
2382        // with none supplied there is nothing to adjudicate against, so no report.
2383        let Some(target) = target else {
2384            return Ok(None);
2385        };
2386        if target.dim() != (n, p) {
2387            return Err(format!(
2388                "SaeManifoldTerm::compute_hybrid_split_report: target {:?} != ({n}, {p})",
2389                target.dim()
2390            ));
2391        }
2392        // Per-row assignment masses (once), so each atom's weighted straight-line
2393        // fit uses the same row weighting the joint reconstruction loss does.
2394        let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2395        for row in 0..n {
2396            weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2397        }
2398        // The full assembled reconstruction `Σ_k a[i,k]·γ_k`, computed once. Each
2399        // atom's leave-this-atom-out response residual is `y_resp = target −
2400        // (full − a_k·γ_k)`, the data both that atom's candidates fit (#1202).
2401        let full = self.try_fitted_for_rho(rho)?;
2402        let eligible: Vec<usize> = (0..self.k_atoms())
2403            .filter(|&atom_idx| {
2404                let atom = &self.atoms[atom_idx];
2405                atom.latent_dim == 1
2406                    && atom.basis_evaluator.is_some()
2407                    && atom.homotopy_eta == 1.0
2408                    && self.assignment.coords[atom_idx].latent_dim() == atom.latent_dim
2409            })
2410            .collect();
2411        // Per-atom fitted decoded image at every row (the curved candidate's
2412        // realized curve, which the linear candidate must approximate).
2413        let coords_for = |atom_idx: usize| -> Array1<f64> {
2414            self.assignment.coords[atom_idx]
2415                .as_matrix()
2416                .column(0)
2417                .to_owned()
2418        };
2419        let assign_for = |atom_idx: usize| -> Array1<f64> {
2420            Array1::from_iter((0..n).map(|row| weights[row][atom_idx]))
2421        };
2422        let decoded_for = |atom_idx: usize| -> Array2<f64> {
2423            let mut decoded = Array2::<f64>::zeros((n, p));
2424            let mut buf = vec![0.0_f64; p];
2425            for row in 0..n {
2426                self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2427                for col in 0..p {
2428                    decoded[[row, col]] = buf[col];
2429                }
2430            }
2431            decoded
2432        };
2433        // The atom's leave-this-atom-out response residual `y_resp = target −
2434        // (full − a_k·γ_k) = (target − full) + a_k·γ_k`. Both the curved and the
2435        // linear candidate are scored against this on common data (#1202).
2436        let target_resid_for = |atom_idx: usize| -> Array2<f64> {
2437            let mut resid = Array2::<f64>::zeros((n, p));
2438            let mut buf = vec![0.0_f64; p];
2439            for row in 0..n {
2440                let a_k = weights[row][atom_idx];
2441                self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2442                for col in 0..p {
2443                    resid[[row, col]] = target[[row, col]] - full[[row, col]] + a_k * buf[col];
2444                }
2445            }
2446            resid
2447        };
2448        let manifold_for = |atom_idx: usize| -> gam_terms::latent::LatentManifold {
2449            self.assignment.coords[atom_idx].manifold().clone()
2450        };
2451        // #1026 EV-preservation gate denominator: the full target's total
2452        // column-centered variance `SST_full` (the SAME `sst` the reconstruction
2453        // EV is measured against), so the gate vetoes any collapse that would drop
2454        // full-reconstruction EV by more than its tolerance.
2455        let total_centered_variance = {
2456            let mut tss = 0.0_f64;
2457            for col in 0..p {
2458                let mut mean = 0.0_f64;
2459                for row in 0..n {
2460                    mean += target[[row, col]];
2461                }
2462                mean /= n as f64;
2463                for row in 0..n {
2464                    let c = target[[row, col]] - mean;
2465                    tss += c * c;
2466                }
2467            }
2468            tss
2469        };
2470        crate::hybrid_split::build_hybrid_split_report(
2471            &self.atoms,
2472            eligible.into_iter(),
2473            coords_for,
2474            assign_for,
2475            decoded_for,
2476            target_resid_for,
2477            manifold_for,
2478            delta_ev_for,
2479            total_centered_variance,
2480        )
2481    }
2482
2483    pub fn per_atom_loao_explained_variance(
2484        &self,
2485        target: ArrayView2<'_, f64>,
2486        rho: &SaeManifoldRho,
2487    ) -> Result<Vec<Option<f64>>, String> {
2488        let n = self.n_obs();
2489        let p = self.output_dim();
2490        let k_atoms = self.k_atoms();
2491        if target.dim() != (n, p) {
2492            return Err(format!(
2493                "SaeManifoldTerm::per_atom_loao_explained_variance: target {:?} != ({n}, {p})",
2494                target.dim()
2495            ));
2496        }
2497        let full = self.try_fitted_for_rho(rho)?;
2498        let Some(ev_full) = reconstruction_explained_variance(target, full.view()) else {
2499            return Ok(vec![None; k_atoms]);
2500        };
2501        // Cache each row's assignment weights once, then subtract a single
2502        // atom's decoded contribution per LOAO pass instead of reassembling the
2503        // whole dictionary k times.
2504        let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2505        for row in 0..n {
2506            weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2507        }
2508        let mut g_buf = vec![0.0_f64; p];
2509        let mut out = Vec::with_capacity(k_atoms);
2510        for atom_idx in 0..k_atoms {
2511            let mut without = full.clone();
2512            for row in 0..n {
2513                let a_k = weights[row][atom_idx];
2514                if a_k == 0.0 {
2515                    continue;
2516                }
2517                self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2518                let mut without_row = without.row_mut(row);
2519                for out_col in 0..p {
2520                    without_row[out_col] -= a_k * g_buf[out_col];
2521                }
2522            }
2523            out.push(
2524                reconstruction_explained_variance(target, without.view())
2525                    .map(|ev_without| ev_full - ev_without),
2526            );
2527        }
2528        Ok(out)
2529    }
2530
2531    /// #1026 — the LOAD-BEARING collapsed reconstruction: the assembled
2532    /// dictionary output `Σ_k a[i,k]·g_k(coord[i,k])` in which every slot whose
2533    /// hybrid-split verdict selected LINEAR has its curved decoded image replaced
2534    /// by its fitted straight sub-model `b₀ + (t − t̄)·b₁`. This is what makes the
2535    /// verdict *change the reconstruction* instead of merely logging a choice:
2536    /// the linear-collapsed atom no longer pays its `M·p` curved coefficients, it
2537    /// carries a `2·p` straight image whose decoded curve has zero turning.
2538    ///
2539    /// The straight images are the exact weighted-least-squares lines already
2540    /// realized inside [`Self::compute_hybrid_split_report`] (no re-fit, no outer
2541    /// continuation, sidestepping #1051). Returns the curved reconstruction
2542    /// unchanged when no verdict selected linear, or when the report has not been
2543    /// computed yet (`hybrid_split_report == None`).
2544    pub fn hybrid_collapsed_reconstruction(
2545        &self,
2546        rho: &SaeManifoldRho,
2547    ) -> Result<Array2<f64>, String> {
2548        // #1026 — the hybrid collapse is realised by the SINGLE reconstruction
2549        // path ([`Self::try_fitted_with_rho`]) with the collapse flag set: a
2550        // verdict-linear `d = 1` slot decodes its straight sub-model image
2551        // instead of its curved curve. This replaces the dedicated re-collapse
2552        // loop this method used to carry (a parallel layer). The production
2553        // `try_fitted` shares the identical routine at `rho = None`; this entry
2554        // point keeps the rho-keyed collapse for the #1026 EV-dominance reporting
2555        // (`hybrid_collapsed_explained_variance`) and the regression battery.
2556        self.try_fitted_with_rho(Some(rho), true)
2557    }
2558
2559    /// #1026 — the reconstruction explained variance of the hybrid-collapsed
2560    /// dictionary (every verdict-linear slot decoded by its straight sub-model)
2561    /// against `target`. The companion of [`Self::per_atom_loao_explained_variance`]
2562    /// for the dominance claim: because each linear-collapsed slot is the curved
2563    /// family's `Θ → 0` sub-model and is only kept when its evidence beats the
2564    /// curved candidate's parameter price, the collapsed dictionary match-or-beats
2565    /// the all-curved one on EV-per-parameter — the strict-generalization floor
2566    /// the #1026 hybrid argument rests on. `None` when EV is undefined (degenerate
2567    /// target variance).
2568    pub fn hybrid_collapsed_explained_variance(
2569        &self,
2570        target: ArrayView2<'_, f64>,
2571        rho: &SaeManifoldRho,
2572    ) -> Result<Option<f64>, String> {
2573        let n = self.n_obs();
2574        let p = self.output_dim();
2575        if target.dim() != (n, p) {
2576            return Err(format!(
2577                "SaeManifoldTerm::hybrid_collapsed_explained_variance: target {:?} != ({n}, {p})",
2578                target.dim()
2579            ));
2580        }
2581        let collapsed = self.hybrid_collapsed_reconstruction(rho)?;
2582        Ok(reconstruction_explained_variance(target, collapsed.view()))
2583    }
2584
2585    /// #1026 ladder item 2/3 — the AMORTIZED ENCODER, wired from the fitted
2586    /// dictionary. Builds the offline certified [`EncodeAtlas`] over this term's
2587    /// frozen atoms and encodes a target corpus `targets` (`n × p`) through the
2588    /// per-chart distilled Jacobian predictor, with the Kantorovich certificate
2589    /// gating each row and an exact-solve fallback for the rows the amortized
2590    /// predictor cannot certify. Returns one [`EncodeResult`] per atom (the
2591    /// per-atom encoded coordinates + per-row certificate mask), in dictionary
2592    /// order.
2593    ///
2594    /// This is the thread's "encoder + certificate-gated exact fallback"
2595    /// deployment made reachable from a fit: the distilled map approximates
2596    /// inference at one mat-vec/row, and any row whose amortized prediction fails
2597    /// `h ≤ ½` falls back to the certified IFT-warm-start Newton encode
2598    /// ([`EncodeAtlas::certified_encode_row`]); rows that still cannot be
2599    /// certified ride the [`EncodeResult::encode_uncertified_count`] flag for the
2600    /// upstream exact multi-start solve (honesty, never a silent wrong encode).
2601    ///
2602    /// Magic by default: the atlas's worst-case bounds are auto-derived from the
2603    /// fit — `amplitude_bound[k]` is the largest fitted assignment mass `a[i,k]`
2604    /// the encode can produce for atom `k` (the encode recovers `t` from
2605    /// `x ≈ z·γ_k(t)` at amplitude `z = a[i,k]`), and `target_norm_bound` is the
2606    /// largest target row norm — so no caller supplies a knob. Per-row amplitudes
2607    /// are the fitted assignment masses for the same target the dictionary was fit
2608    /// against; an external corpus reuses the per-row masses the assignment
2609    /// produces for it upstream (passed in `amplitudes`, one column per atom).
2610    pub fn amortized_encode_target(
2611        &self,
2612        targets: ArrayView2<'_, f64>,
2613        amplitudes: ArrayView2<'_, f64>,
2614    ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2615        let p = self.output_dim();
2616        let k_atoms = self.k_atoms();
2617        let n = targets.nrows();
2618        if targets.ncols() != p {
2619            return Err(format!(
2620                "SaeManifoldTerm::amortized_encode_target: targets have {} cols but output_dim is {p}",
2621                targets.ncols()
2622            ));
2623        }
2624        if amplitudes.dim() != (n, k_atoms) {
2625            return Err(format!(
2626                "SaeManifoldTerm::amortized_encode_target: amplitudes {:?} must be (n={n}, K={k_atoms})",
2627                amplitudes.dim()
2628            ));
2629        }
2630
2631        // Magic-by-default offline bounds, auto-derived from the fit so no caller
2632        // supplies a knob. `target_norm_bound` is the largest target row L2 norm
2633        // (bounds `‖x‖` over the corpus); `amplitude_bound[k]` is the largest
2634        // fitted assignment mass for atom `k` (bounds `|z_k|`), with a strictly
2635        // positive floor so a near-inactive atom still certifies a finite radius.
2636        let mut target_norm_bound = 0.0_f64;
2637        for row in 0..n {
2638            let norm = targets.row(row).dot(&targets.row(row)).sqrt();
2639            if norm.is_finite() && norm > target_norm_bound {
2640                target_norm_bound = norm;
2641            }
2642        }
2643        let mut amplitude_bound = vec![0.0_f64; k_atoms];
2644        for atom_idx in 0..k_atoms {
2645            let mut bound = 0.0_f64;
2646            for row in 0..n {
2647                let z = amplitudes[[row, atom_idx]].abs();
2648                if z.is_finite() && z > bound {
2649                    bound = z;
2650                }
2651            }
2652            // A strictly positive amplitude floor keeps the offline Lipschitz
2653            // scaling finite for atoms with no active row in this corpus (those
2654            // rows encode to the chart center via the certificate anyway).
2655            amplitude_bound[atom_idx] = bound.max(1.0);
2656        }
2657
2658        let atlas = crate::encode::EncodeAtlas::build(
2659            &self.atoms,
2660            &amplitude_bound,
2661            target_norm_bound,
2662            crate::encode::AtlasConfig::default(),
2663        )?;
2664
2665        // Per-atom amortized encode with a certificate-gated exact-solve fallback:
2666        // a row whose distilled prediction fails `h ≤ ½` is retried through the
2667        // certified IFT-warm-start Newton path; a row that still cannot be
2668        // certified stays flagged for the upstream multi-start solve.
2669        // (The atlas is rho-free; the per-row amplitudes already carry the
2670        // rho-resolved assignment masses the caller produced upstream.)
2671        let mut results = Vec::with_capacity(k_atoms);
2672        for atom_idx in 0..k_atoms {
2673            let atom = &self.atoms[atom_idx];
2674            let amp_col = amplitudes.column(atom_idx).to_owned();
2675            let amortized =
2676                atlas.amortized_encode_batch(atom, atom_idx, targets, amp_col.view())?;
2677            let mut coords = amortized.coords;
2678            let mut certified = amortized.certified;
2679            for row in 0..n {
2680                if certified[row] {
2681                    continue;
2682                }
2683                let (t, cert) =
2684                    atlas.certified_encode_row(atom, atom_idx, targets.row(row), amp_col[row])?;
2685                if cert.certified() {
2686                    coords.row_mut(row).assign(&t);
2687                    certified[row] = true;
2688                }
2689            }
2690            results.push(crate::encode::EncodeResult::from_rows(
2691                coords, certified,
2692            ));
2693        }
2694        Ok(results)
2695    }
2696
2697    /// #1026 — the fitted per-row assignment masses `a[i,k]` (the activation
2698    /// amplitudes `z_k` the amortized encode recovers `t` against), as an
2699    /// `n × K` matrix. These are exactly the masses
2700    /// [`Self::try_fitted_with_rho`] assembles the reconstruction from, so
2701    /// feeding them to [`Self::amortized_encode_target`] re-encodes the SAME
2702    /// inference the dictionary was fit against — the self-consistency the
2703    /// distilled encoder is supervised to approximate.
2704    pub fn fitted_assignment_amplitudes(
2705        &self,
2706        rho: &SaeManifoldRho,
2707    ) -> Result<Array2<f64>, String> {
2708        let n = self.n_obs();
2709        let k_atoms = self.k_atoms();
2710        let mut amplitudes = Array2::<f64>::zeros((n, k_atoms));
2711        for row in 0..n {
2712            let a = self.assignment.try_assignments_row_for_rho(row, rho)?;
2713            for atom_idx in 0..k_atoms {
2714                amplitudes[[row, atom_idx]] = a[atom_idx];
2715            }
2716        }
2717        Ok(amplitudes)
2718    }
2719
2720    /// #1026 — encode the dictionary's own fit-time target with the amortized
2721    /// encoder, deriving the per-row amplitudes from the fitted assignment so the
2722    /// caller supplies neither bounds nor amplitudes (magic by default). The
2723    /// end-to-end "fit → distilled encoder → certificate-gated encode" path.
2724    pub fn amortized_encode_fitted(
2725        &self,
2726        targets: ArrayView2<'_, f64>,
2727        rho: &SaeManifoldRho,
2728    ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2729        let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2730        self.amortized_encode_target(targets, amplitudes.view())
2731    }
2732
2733    /// #1154 — amortized-encoder consistency of the CURRENT dictionary against
2734    /// its own fit-time target. This is the co-training signal of the joint
2735    /// amortized-encoder + REML loop (Design A): the amortized (one-mat-vec)
2736    /// encode is built from the *current* fitted decoder, run on `targets`, and
2737    /// scored on two principled axes —
2738    ///
2739    /// * `recon_consistency` (the bilinear part of the co-training loss): the
2740    ///   mean per-element squared gap between the **amortized** reconstruction
2741    ///   `Σ_k z_k · Φ_k(t̂_k) B_k` (decode the amortized coords) and the
2742    ///   **exact** fitted reconstruction `Σ_k z_k · Φ_k(t_k^*) B_k` the inner
2743    ///   solve converged to. A dictionary whose encode map is well-approximated
2744    ///   to first order by the per-chart IFT predictor scores near zero; a
2745    ///   dictionary the amortized encoder *cannot* invert faithfully (sharp
2746    ///   curvature, poorly-charted regions) scores high. Minimising this jointly
2747    ///   with REML steers the fit toward dictionaries that admit a fast,
2748    ///   faithful amortized encode — the architectural co-adaptation #1154 adds.
2749    /// * `uncertified_fraction`: the share of (row, atom) encodes whose
2750    ///   Kantorovich certificate failed (`h > ½`), i.e. that fell back to the
2751    ///   certified IFT-warm-start Newton. This is the encoder's *certifiable coverage*
2752    ///   of the dictionary; co-training rewards dictionaries the cheap encode
2753    ///   certifies, not just ones it happens to land.
2754    ///
2755    /// The certificate keeps every accepted amortized coord honest (uncertified
2756    /// rows already ride the exact fallback inside `amortized_encode_target`), so
2757    /// this metric never silently trusts a wrong encode — it MEASURES how much of
2758    /// the dictionary the cheap encoder can faithfully and certifiably invert.
2759    pub fn amortized_encoder_consistency(
2760        &self,
2761        targets: ArrayView2<'_, f64>,
2762        rho: &SaeManifoldRho,
2763    ) -> Result<AmortizedEncoderConsistency, String> {
2764        let n = self.n_obs();
2765        let p = self.output_dim();
2766        let k_atoms = self.k_atoms();
2767        if targets.dim() != (n, p) {
2768            return Err(format!(
2769                "SaeManifoldTerm::amortized_encoder_consistency: targets {:?} must be (n={n}, p={p})",
2770                targets.dim()
2771            ));
2772        }
2773        let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2774        let encodes = self.amortized_encode_target(targets, amplitudes.view())?;
2775        // The EXACT fitted reconstruction the inner solve converged to (pure
2776        // curved image, rho-keyed) is the supervision target for the amortized
2777        // reconstruction. Both are n×p ambient, so the comparison is layout-free.
2778        let exact_recon = self.try_fitted_for_rho(rho)?;
2779
2780        // Build the amortized reconstruction Σ_k z_k · Φ_k(t̂_k) B_k by decoding
2781        // each atom's amortized coords through that atom's own basis evaluator.
2782        let mut amortized_recon = Array2::<f64>::zeros((n, p));
2783        let mut uncertified = 0usize;
2784        for atom_idx in 0..k_atoms {
2785            let atom = &self.atoms[atom_idx];
2786            let result = &encodes[atom_idx];
2787            // An atom with no basis evaluator cannot decode an amortized
2788            // reconstruction; every one of its rows is necessarily uncertified
2789            // (the encode flagged them all), so it contributes nothing to the
2790            // amortized recon and its full row-count to the uncertified tally.
2791            // Count it and skip the decode rather than erroring — the consistency
2792            // fold stays a bounded penalty, never a hard abort of the criterion.
2793            let Some(evaluator) = atom.basis_evaluator.as_ref() else {
2794                uncertified += n;
2795                continue;
2796            };
2797            uncertified += result.encode_uncertified_count;
2798            // Decode the amortized coords: Φ_k(t̂) is (n × M_k); B_k is (M_k × p).
2799            let (phi, _jac) = evaluator.evaluate(result.coords.view())?;
2800            let decoded = phi.dot(&atom.decoder_coefficients); // (n × p)
2801            for row in 0..n {
2802                let z = amplitudes[[row, atom_idx]];
2803                if z == 0.0 {
2804                    continue;
2805                }
2806                for col in 0..p {
2807                    amortized_recon[[row, col]] += z * decoded[[row, col]];
2808                }
2809            }
2810        }
2811
2812        let mut sse = 0.0_f64;
2813        for row in 0..n {
2814            for col in 0..p {
2815                let gap = amortized_recon[[row, col]] - exact_recon[[row, col]];
2816                sse += gap * gap;
2817            }
2818        }
2819        let denom = (n.max(1) * p.max(1)) as f64;
2820        let recon_consistency = sse / denom;
2821        let total_encodes = (n * k_atoms).max(1) as f64;
2822        let uncertified_fraction = uncertified as f64 / total_encodes;
2823
2824        Ok(AmortizedEncoderConsistency {
2825            recon_consistency,
2826            uncertified_fraction,
2827            n_uncertified: uncertified,
2828            n_encodes: n * k_atoms,
2829        })
2830    }
2831
2832    /// #1154 — the co-trained REML criterion: the exact REML criterion at `rho`
2833    /// PLUS the amortized-encoder consistency penalty, so the outer optimizer
2834    /// co-adapts the dictionary + smoothing parameters λ TOWARD a dictionary the
2835    /// fast amortized encoder can faithfully and certifiably invert.
2836    ///
2837    /// This is Design A of #1154. The inner solve still converges the `(t, β)`
2838    /// system to stationarity at the engine's current ρ (so the implicit-function
2839    /// REML λ-gradient `dβ̂/dλ = −(H+S_λ)⁻¹(dS_λ/dλ)β̂` stays EXACT — the encoder
2840    /// only warm-starts/co-adapts, it never replaces the stationary point). The
2841    /// added term
2842    ///
2843    /// ```text
2844    ///   J_cotrain(ρ) = REML(ρ)  +  w · ‖x̂_amortized − x̂_exact‖²/(n·p)
2845    ///                            +  w_cert · uncertified_fraction
2846    /// ```
2847    ///
2848    /// folds the post-fit amortized-encode quality into the ranked objective. The
2849    /// weights are auto-scaled to the REML criterion magnitude (magic by default:
2850    /// no caller knob) so the consistency term is a meaningful but non-dominant
2851    /// fraction of the objective regardless of problem scale.
2852    pub fn reml_criterion_cotrained(
2853        &mut self,
2854        target: ArrayView2<'_, f64>,
2855        rho: &SaeManifoldRho,
2856        registry: Option<&AnalyticPenaltyRegistry>,
2857        inner_max_iter: usize,
2858        learning_rate: f64,
2859        ridge_ext_coord: f64,
2860        ridge_beta: f64,
2861    ) -> Result<(f64, SaeManifoldLoss, AmortizedEncoderConsistency), String> {
2862        // #1154: always attempt the amortized warm-start first inside
2863        // `reml_criterion_cotrained` (the encode/warm path for the cotrained
2864        // objective). Good warm-starts from the running dictionary land the
2865        // inner solve closer to the stationary point used for the fold.
2866        // Advisory only (0 or err falls back to cold); telemetry recorded by
2867        // outer objective callers when present.
2868        self.warm_start_latents_from_amortized_encoder(target, rho)
2869            .unwrap_or(0);
2870        let (reml, loss) = self.reml_criterion_with_refine_policy(
2871            target,
2872            rho,
2873            registry,
2874            inner_max_iter,
2875            learning_rate,
2876            ridge_ext_coord,
2877            ridge_beta,
2878            true,
2879        )?;
2880        let consistency = self.amortized_encoder_consistency(target, rho)?;
2881        // Auto-scale the co-training weights to the REML magnitude so the
2882        // consistency penalty is a bounded, scale-free fraction of the objective
2883        // (magic by default: no caller knob). `reml_scale` floors at 1 so a
2884        // near-zero criterion still admits a meaningful consistency contribution.
2885        let cotrained = Self::fold_cotrain_consistency(reml, &consistency);
2886        Ok((cotrained, loss, consistency))
2887    }
2888
2889    /// #1154 — the single source of the co-training fold arithmetic: add the
2890    /// auto-scaled amortized-encoder consistency penalty to an already-computed
2891    /// REML criterion at the converged dictionary. Both the public
2892    /// [`Self::reml_criterion_cotrained`] entry point and the outer-loop value /
2893    /// gradient lanes (`SaeManifoldOuterObjective::fold_cotrain_consistency`)
2894    /// route through THIS function, so the folded objective cannot drift between
2895    /// the criterion and the cascade-ranked cost (the objective↔gradient desync
2896    /// bug class). The weights are auto-scaled to the REML magnitude (`max(|REML|,
2897    /// 1)`) so the penalty is a bounded, scale-free fraction of the objective
2898    /// regardless of problem scale; the fold carries no analytic gradient (under
2899    /// Design A the REML λ-gradient stays the exact implicit-function path).
2900    #[must_use]
2901    pub fn fold_cotrain_consistency(
2902        reml_cost: f64,
2903        consistency: &AmortizedEncoderConsistency,
2904    ) -> f64 {
2905        let reml_scale = reml_cost.abs().max(1.0);
2906        reml_cost
2907            + COTRAIN_RECON_WEIGHT * reml_scale * consistency.recon_consistency
2908            + COTRAIN_CERT_WEIGHT * reml_scale * consistency.uncertified_fraction
2909    }
2910
2911    /// #1154 item 2 — warm-start the inner latent coordinates from the amortized
2912    /// encoder (Design A). Builds the per-chart IFT-Jacobian atlas from the
2913    /// CURRENT dictionary, runs the one-mat-vec amortized encode of `target`
2914    /// against each atom at the rho-resolved assignment masses, and overwrites
2915    /// each atom's stored latent coords with the predicted `t̂` ON THE ROWS THE
2916    /// KANTOROVICH CERTIFICATE ACCEPTS. Uncertified rows are left at their
2917    /// current coords (the previous-iterate start), so the
2918    /// warm-start can only HELP — a row the cheap predictor cannot certify never
2919    /// corrupts the seed. The subsequent inner Newton refines from this seed to
2920    /// the SAME stationary point (the warm-start changes only the basin entry,
2921    /// not the root), so the REML λ-gradient stays exactly the implicit-function
2922    /// path and the criterion is unchanged at convergence — the amortized encoder
2923    /// only accelerates/co-adapts the inner solve, it never replaces the
2924    /// stationary point.
2925    ///
2926    /// Returns the number of (row, atom) coords actually warm-started (the
2927    /// certified-prediction count), for instrumentation / tests. A first-build
2928    /// dictionary with no usable charts simply warm-starts nothing and returns 0
2929    /// (the cold path is byte-for-byte unchanged).
2930    pub fn warm_start_latents_from_amortized_encoder(
2931        &mut self,
2932        target: ArrayView2<'_, f64>,
2933        rho: &SaeManifoldRho,
2934    ) -> Result<usize, String> {
2935        let n = self.n_obs();
2936        let k_atoms = self.k_atoms();
2937        if n == 0 || k_atoms == 0 {
2938            return Ok(0);
2939        }
2940        let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2941        let encodes = self.amortized_encode_target(target, amplitudes.view())?;
2942        let mut warm_started = 0usize;
2943        for atom_idx in 0..k_atoms {
2944            let d = self.atoms[atom_idx].latent_dim;
2945            if d == 0 {
2946                continue;
2947            }
2948            let result = &encodes[atom_idx];
2949            // Start from the atom's CURRENT coords so uncertified rows are left
2950            // exactly as they were; overwrite only the certified predictions.
2951            let mut coords = self.assignment.coords[atom_idx].as_matrix();
2952            if coords.dim() != (n, d) {
2953                return Err(format!(
2954                    "warm_start_latents_from_amortized_encoder: atom {atom_idx} coords {:?} != (n={n}, d={d})",
2955                    coords.dim()
2956                ));
2957            }
2958            for row in 0..n {
2959                if !result.certified[row] {
2960                    continue;
2961                }
2962                for axis in 0..d {
2963                    coords[[row, axis]] = result.coords[[row, axis]];
2964                }
2965                warm_started += 1;
2966            }
2967            // `as_matrix` lays coords out row-major (`[[row, axis]]`), exactly the
2968            // `values[row*d + axis]` order `set_flat` expects, so a plain
2969            // row-major iterator reconstructs the flat vector.
2970            let flat = Array1::from_iter(coords.iter().copied());
2971            self.assignment.coords[atom_idx].set_flat(flat.view());
2972        }
2973        // The basis caches must follow the freshly-seeded coords so the next
2974        // inner solve evaluates Φ at the warm-started t̂, not the stale coords.
2975        self.refresh_basis_from_current_coords()?;
2976        Ok(warm_started)
2977    }
2978
2979    pub fn loss(
2980        &self,
2981        target: ArrayView2<'_, f64>,
2982        rho: &SaeManifoldRho,
2983    ) -> Result<SaeManifoldLoss, String> {
2984        self.loss_scaled(target, rho, 1.0)
2985    }
2986
2987    /// Penalized objective with a `penalty_scale` applied to the β-tier
2988    /// (decoder smoothness) penalty, mirroring
2989    /// [`Self::assemble_arrow_schur_scaled`]. The streaming line search sums
2990    /// per-chunk `loss_scaled(..., n_chunk / N)` so that the global smoothness
2991    /// penalty is counted exactly once across a pass while the per-row data,
2992    /// assignment-prior, and ARD terms sum naturally. `penalty_scale == 1.0`
2993    /// recovers the full-batch objective.
2994    pub fn loss_scaled(
2995        &self,
2996        target: ArrayView2<'_, f64>,
2997        rho: &SaeManifoldRho,
2998        penalty_scale: f64,
2999    ) -> Result<SaeManifoldLoss, String> {
3000        if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3001            return Err(format!(
3002                "SaeManifoldTerm::loss_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3003            ));
3004        }
3005        if target.dim() != (self.n_obs(), self.output_dim()) {
3006            return Err(format!(
3007                "SaeManifoldTerm::loss: Z must be ({}, {}); got {:?}",
3008                self.n_obs(),
3009                self.output_dim(),
3010                target.dim()
3011            ));
3012        }
3013        // The likelihood whitens through the RowMetric **only** when the metric
3014        // is a genuinely estimated noise model (`metric.whitens_likelihood()`,
3015        // i.e. `WhitenedStructured` — the #974 residual-covariance seam). For
3016        // Euclidean (default `None`) and for the OutputFisher *gauge* metric the
3017        // reconstruction data-fit stays the isotropic `0.5 * Σ r²`: a gauge /
3018        // output-Fisher inner product must NOT silently replace the
3019        // reconstruction loss with a Fisher pullback (#980). It only drives the
3020        // gauge (see `analytic_penalties::corrected_isometry_penalty`). The
3021        // producer of `WhitenedStructured` is
3022        // `inference::residual_factor::StructuredResidualModel::row_metric`; the
3023        // SAME metric whitens the assembled gradient/Hessian in
3024        // `assemble_arrow_schur` (the single #974 seam), so this value and that
3025        // gradient cannot desync. Without a whitening metric this path is
3026        // bit-for-bit the historical isotropic data-fit.
3027        let whitens = self
3028            .row_metric
3029            .as_ref()
3030            .is_some_and(|metric| metric.whitens_likelihood());
3031        // #991 design honesty weights: the reconstruction channel of row `i`
3032        // is weighted by `w_i` (mean-1 HT inclusion correction). The assembly
3033        // applies the same `w_i` via a `√w_i` scaling of the row residual /
3034        // Jacobian / β load at its single seam, so this value and that
3035        // gradient/Hessian carry the identical per-row factor. `None` ⇒ the
3036        // historical unweighted sum, bit-for-bit.
3037        let row_loss_w = self.row_loss_weights.as_deref();
3038        let n = self.n_obs();
3039        let p = self.output_dim();
3040        let k_atoms = self.k_atoms();
3041        // #1017: the data-fit is the dominant per-line-search-trial cost (it
3042        // re-runs every Armijo halving × every inner Newton iteration × every
3043        // outer ρ evaluation). The old path materialised the whole `n × p`
3044        // fitted matrix (`try_fitted_for_rho`) and then walked it AGAIN to form
3045        // the residual sum — two sequential `n·p` passes plus an `n·p`
3046        // allocation per trial. Fuse the reconstruction and the residual reduce
3047        // into ONE row-parallel pass that never materialises the fitted matrix:
3048        // each row decodes its atoms into per-worker scratch, differences
3049        // against the target, and contributes its scalar `0.5·w·‖r‖²` to a
3050        // chunk-ordered fold (bit-identical run-to-run). Per-worker scratch
3051        // (`map_init`) keeps the only allocations one `g_buf`/`fitted_row` pair
3052        // per rayon thread rather than per row. Stay sequential inside a worker
3053        // (the topology race owns the outer pool) to avoid nested
3054        // oversubscription.
3055        let parallel = n >= SAE_LOSS_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
3056        let row_data_fit =
3057            |row: usize,
3058             g_buf: &mut [f64],
3059             fitted_row: &mut [f64],
3060             assign_buf: &mut [f64]|
3061             -> Result<f64, String> {
3062                // #1557 — fill the per-atom assignment row into reused per-worker
3063                // scratch via the `_into` twin instead of heap-allocating a fresh
3064                // `Array1` per row per loss eval. Bit-identical to the allocating
3065                // `try_assignments_row_for_rho` (same arithmetic, same order); this
3066                // loss reruns every Armijo halving × inner Newton iter × outer ρ
3067                // eval, so the per-row K-sized allocation was a hot-path churn.
3068                self.assignment
3069                    .try_assignments_row_for_rho_into(row, rho, assign_buf)?;
3070                let a = &*assign_buf;
3071                for slot in fitted_row.iter_mut() {
3072                    *slot = 0.0;
3073                }
3074                for atom_idx in 0..k_atoms {
3075                    self.atoms[atom_idx].fill_decoded_row(row, g_buf);
3076                    let a_k = a[atom_idx];
3077                    for out_col in 0..p {
3078                        fitted_row[out_col] += a_k * g_buf[out_col];
3079                    }
3080                }
3081                for out_col in 0..p {
3082                    fitted_row[out_col] = target[[row, out_col]] - fitted_row[out_col];
3083                }
3084                let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3085                let mut acc = 0.0_f64;
3086                match self.row_metric.as_ref() {
3087                    Some(metric) if whitens => {
3088                        let resid = ArrayView1::from(&fitted_row[..p]);
3089                        for w in metric.whiten_residual_row(row, resid) {
3090                            acc += 0.5 * w_row * w * w;
3091                        }
3092                    }
3093                    _ => {
3094                        for &r in fitted_row[..p].iter() {
3095                            acc += 0.5 * w_row * r * r;
3096                        }
3097                    }
3098                }
3099                Ok(acc)
3100            };
3101        let data_fit = if parallel {
3102            use rayon::prelude::*;
3103            const CHUNK: usize = 32;
3104            let partials: Vec<Result<f64, String>> = (0..n)
3105                .into_par_iter()
3106                .chunks(CHUNK)
3107                .map_init(
3108                    || (vec![0.0_f64; p], vec![0.0_f64; p], vec![0.0_f64; k_atoms]),
3109                    |(g_buf, fitted_row, assign_buf), idxs| {
3110                        // #1557 — pin any faer GEMM reached from this row-parallel
3111                        // data-fit chunk to `Par::Seq` (no nested Rayon re-fan); the
3112                        // per-row reductions are tiny, so the result is bit-identical.
3113                        with_nested_parallel(|| {
3114                            let mut acc = 0.0_f64;
3115                            for row in idxs {
3116                                acc += row_data_fit(row, g_buf, fitted_row, assign_buf)?;
3117                            }
3118                            Ok(acc)
3119                        })
3120                    },
3121                )
3122                .collect();
3123            let mut total = 0.0_f64;
3124            for partial in partials {
3125                total += partial?;
3126            }
3127            total
3128        } else {
3129            let mut g_buf = vec![0.0_f64; p];
3130            let mut fitted_row = vec![0.0_f64; p];
3131            let mut assign_buf = vec![0.0_f64; k_atoms];
3132            let mut total = 0.0_f64;
3133            for row in 0..n {
3134                total += row_data_fit(row, &mut g_buf, &mut fitted_row, &mut assign_buf)?;
3135            }
3136            total
3137        };
3138        let assignment_sparsity = assignment_prior_value(&self.assignment, rho);
3139        let smoothness = penalty_scale * self.decoder_smoothness_value(&rho.lambda_smooth_vec());
3140        let ard = self.ard_value(rho)?;
3141        Ok(SaeManifoldLoss {
3142            data_fit,
3143            assignment_sparsity,
3144            smoothness,
3145            ard,
3146            evidence_gauge_deflated_directions: 0,
3147        })
3148    }
3149
3150    /// Reconstruction data-fit `0.5·Σ_i w_i·‖whiten(Z_i − R_i)‖²` for an EXPLICIT
3151    /// reconstruction matrix `R` (e.g. the hard top-k–projected `fitted`), using
3152    /// the SAME per-row metric and design-honesty weights as [`Self::loss_scaled`]
3153    /// (the soft-assignment data-fit). The only difference is the residual source:
3154    /// `loss_scaled` decodes the soft assignments on the fly, this consumes a
3155    /// reconstruction the caller already assembled (so the projected loss and the
3156    /// returned projected `fitted` describe one and the same model). The penalty
3157    /// terms (`assignment_sparsity`/`smoothness`/`ard`) are decoder/ρ properties
3158    /// the top-k gate does not change, so the caller keeps them from the soft
3159    /// `loss_scaled` and only swaps this data-fit in — see #1232.
3160    pub fn data_fit_for_reconstruction(
3161        &self,
3162        target: ArrayView2<'_, f64>,
3163        reconstruction: ArrayView2<'_, f64>,
3164    ) -> Result<f64, String> {
3165        let n = self.n_obs();
3166        let p = self.output_dim();
3167        if target.dim() != (n, p) {
3168            return Err(format!(
3169                "SaeManifoldTerm::data_fit_for_reconstruction: Z must be ({n}, {p}); got {:?}",
3170                target.dim()
3171            ));
3172        }
3173        if reconstruction.dim() != (n, p) {
3174            return Err(format!(
3175                "SaeManifoldTerm::data_fit_for_reconstruction: reconstruction must be ({n}, {p}); got {:?}",
3176                reconstruction.dim()
3177            ));
3178        }
3179        let whitens = self
3180            .row_metric
3181            .as_ref()
3182            .is_some_and(|metric| metric.whitens_likelihood());
3183        let row_loss_w = self.row_loss_weights.as_deref();
3184        let mut resid = vec![0.0_f64; p];
3185        let mut total = 0.0_f64;
3186        for row in 0..n {
3187            for out_col in 0..p {
3188                resid[out_col] = target[[row, out_col]] - reconstruction[[row, out_col]];
3189            }
3190            let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3191            match self.row_metric.as_ref() {
3192                Some(metric) if whitens => {
3193                    let r = ArrayView1::from(&resid[..p]);
3194                    for w in metric.whiten_residual_row(row, r) {
3195                        total += 0.5 * w_row * w * w;
3196                    }
3197                }
3198                _ => {
3199                    for &r in resid[..p].iter() {
3200                        total += 0.5 * w_row * r * r;
3201                    }
3202                }
3203            }
3204        }
3205        Ok(total)
3206    }
3207
3208    pub fn analytic_penalty_value_total(
3209        &self,
3210        registry: &AnalyticPenaltyRegistry,
3211        penalty_scale: f64,
3212    ) -> Result<f64, ArrowSchurError> {
3213        if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3214            return Err(ArrowSchurError::SchurFactorFailed {
3215                reason: format!(
3216                    "SaeManifoldTerm::analytic_penalty_value_total: penalty_scale must be finite \
3217                     and positive; got {penalty_scale}"
3218                ),
3219            });
3220        }
3221        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3222        let layout = registry.rho_layout();
3223        let beta = self.flatten_beta();
3224        let mut value = 0.0_f64;
3225        for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
3226            let rho_local = rho_global.slice(s![rho_slice.clone()]);
3227            // Skip the registry `ARDPenalty` here for the same reason it is
3228            // skipped in `add_sae_analytic_penalty_contributions`: the coordinate
3229            // ARD energy is already counted by `loss.ard` (the von-Mises
3230            // `ard_value`), and the registry penalty's legacy Gaussian `½λt²` is
3231            // period-discontinuous. Including it would double-count the energy and
3232            // make this line-search objective jump across the branch cut while the
3233            // assembled gradient (von-Mises only, after the assembly fix) stays
3234            // continuous — i.e. a near-zero step would change the objective by a
3235            // finite amount and Armijo would wrongly reject it.
3236            if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
3237                continue;
3238            }
3239            match tier {
3240                PenaltyTier::Psi => {
3241                    if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
3242                        for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3243                            value += penalty_scale
3244                                * per_atom.value(beta.slice(s![start..end]), rho_local);
3245                        }
3246                    } else {
3247                        if !sae_penalty_is_row_block_supported(penalty) {
3248                            return Err(ArrowSchurError::SchurFactorFailed {
3249                                reason: format!(
3250                                    "validate_analytic_penalty_registry should have refused \
3251                                     non-row-block Psi-tier penalty {:?} (registry layout name \
3252                                     {name:?})",
3253                                    penalty.name()
3254                                ),
3255                            });
3256                        }
3257                        for atom_idx in 0..self.k_atoms() {
3258                            let coord = &self.assignment.coords[atom_idx];
3259                            if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3260                                let corrected_kind =
3261                                    self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3262                                value += corrected_kind.value(coord.as_flat().view(), rho_local);
3263                            } else if sae_coord_penalty_is_origin_anchored_magnitude(penalty) {
3264                                // Origin-anchored magnitude shrinkage (SCAD/MCP) is
3265                                // restricted to the Euclidean axes; periodic axes have
3266                                // no chart origin and would make this energy
3267                                // period-discontinuous (issue #795). This must mirror
3268                                // the gradient/curvature assembly in
3269                                // `add_sae_coord_penalty` exactly.
3270                                match sae_coord_penalty_euclidean_restriction(coord) {
3271                                    Some((_axes, compacted)) => {
3272                                        value += penalty.value(compacted.view(), rho_local);
3273                                    }
3274                                    None => {
3275                                        value += penalty.value(coord.as_flat().view(), rho_local);
3276                                    }
3277                                }
3278                            } else {
3279                                value += penalty.value(coord.as_flat().view(), rho_local);
3280                            }
3281                        }
3282                    }
3283                }
3284                PenaltyTier::Beta => {
3285                    if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
3286                        if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3287                            value += penalty_scale * per_fit.value(beta.view(), rho_local);
3288                        }
3289                    } else if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
3290                        for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3291                            if start < end {
3292                                value += penalty_scale * per_atom.value(beta.view(), rho_local);
3293                            }
3294                        }
3295                    } else {
3296                        value += penalty_scale * penalty.value(beta.view(), rho_local);
3297                    }
3298                }
3299                PenaltyTier::Rho => {}
3300            }
3301        }
3302        Ok(value)
3303    }
3304
3305    /// Energy of the decoder-block analytic penalties that have no native
3306    /// `SaeManifoldLoss` counterpart, evaluated at the current decoder `β` and
3307    /// the converged SAE state. These act on the per-atom decoder coefficient
3308    /// matrices: cross-atom decoder incoherence (#671), mechanism
3309    /// (feature-group) sparsity, and nuclear-norm embedding rank (#672). Each
3310    /// is injected with its live per-atom shape / co-activation before its
3311    /// value is taken, mirroring the assemble path.
3312    ///
3313    /// This is deliberately narrower than [`Self::analytic_penalty_value_total`]:
3314    /// it excludes the Psi-tier coordinate / assignment penalties (ARD,
3315    /// Isometry, ScadMcp, BlockOrthogonality, IBP/softmax assignment sparsity).
3316    /// The SAE already carries its own ARD (`loss.ard`) and assignment sparsity
3317    /// (`loss.assignment_sparsity`) energy, so adding the registry ARD /
3318    /// assignment value on top would double-count, and the gauge-only
3319    /// coordinate penalties are not part of the penalized deviance the
3320    /// REML/Laplace criterion scores. The decoder-block penalties, by contrast,
3321    /// are real penalized-energy terms with no `loss.*` representative: the
3322    /// inner solve minimizes them (they enter `gb`/`hbb`) but they were absent
3323    /// from the criterion scalar `v`. This restores that consistency so the
3324    /// ρ-sweep ranks the same objective the inner solve descends — the #671
3325    /// incoherence lever in particular now shapes model selection, not just the
3326    /// Newton step.
3327    ///
3328    /// NOTE: the coordinate-block penalties with no native `loss.*` twin
3329    /// (`ScadMcp`, `BlockOrthogonality`) carry the same residual inconsistency
3330    /// (scored in the line search via `penalized_objective_total`, absent from
3331    /// the REML scalar). They are left out here because they share a registry
3332    /// dispatch with the always-on `Isometry` gauge, whose inclusion in the
3333    /// topology-comparison criterion is a separate design question (#673:
3334    /// topology evidence is gauge-conditional). Folding the coord-tier energy in
3335    /// is tracked apart from this #671 decoder fix.
3336    pub fn analytic_decoder_penalty_value_total(
3337        &self,
3338        registry: &AnalyticPenaltyRegistry,
3339    ) -> Result<f64, ArrowSchurError> {
3340        // Resolve each penalty's rho slice exactly as `analytic_penalty_value_total`
3341        // does (registry-local rho at zeros), so a learnable decoder-penalty weight
3342        // is honoured rather than indexing into an empty view.
3343        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3344        let layout = registry.rho_layout();
3345        let beta = self.flatten_beta();
3346        let mut value = 0.0_f64;
3347        for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3348            let rho_local = rho_global.slice(s![rho_slice.clone()]);
3349            match penalty {
3350                AnalyticPenaltyKind::DecoderIncoherence(base) => {
3351                    if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3352                        value += per_fit.value(beta.view(), rho_local);
3353                    }
3354                }
3355                AnalyticPenaltyKind::MechanismSparsity(base) => {
3356                    for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3357                        if start < end {
3358                            value += per_atom.value(beta.view(), rho_local);
3359                        }
3360                    }
3361                }
3362                AnalyticPenaltyKind::NuclearNorm(base) => {
3363                    for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3364                        value += per_atom.value(beta.slice(s![start..end]), rho_local);
3365                    }
3366                }
3367                _ => {}
3368            }
3369        }
3370        Ok(value)
3371    }
3372
3373    /// Energy of the COORDINATE-tier isometry penalty(ies) at the converged
3374    /// SAE state. This is the per-atom `½μ Σ_n ‖J_n^T W_n J_n / gbar − g_ref‖²`
3375    /// summed over atoms, evaluated through `corrected_isometry_penalty` so the
3376    /// live decoder/coordinate caches drive the value exactly as the assemble
3377    /// path does. It has no `SaeManifoldLoss` twin (the loss carries only
3378    /// data-fit / assignment / smoothness / ARD), so the Laplace/REML criterion
3379    /// must add it explicitly to score the same penalized objective the inner
3380    /// solve descends.
3381    pub fn isometry_penalty_value_total(
3382        &self,
3383        registry: &AnalyticPenaltyRegistry,
3384    ) -> Result<f64, ArrowSchurError> {
3385        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3386        let layout = registry.rho_layout();
3387        let mut value = 0.0_f64;
3388        for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3389            if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3390                let rho_local = rho_global.slice(s![rho_slice.clone()]);
3391                for atom_idx in 0..self.k_atoms() {
3392                    let coord = &self.assignment.coords[atom_idx];
3393                    let corrected_kind = self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3394                    value += corrected_kind.value(coord.as_flat().view(), rho_local);
3395                }
3396            }
3397        }
3398        Ok(value)
3399    }
3400
3401    /// Whether assembling `registry` will scatter an isometry Gauss-Newton
3402    /// cross-block (`H_tβ`) into the per-row dense `htbeta` slabs.
3403    ///
3404    /// `add_sae_isometry_metric_gn_blocks` writes the coupled cross-block (and
3405    /// flips on `activate_dense_htbeta_supplement`) only when (a) the registry
3406    /// carries an `Isometry` penalty and (b) the atom's chart
3407    /// `preserves_isometry_cross_block_coherence` (flat charts — `Euclidean`,
3408    /// `Circle`, and flat products — keep the full `μ AᵀA` coupling; curved /
3409    /// boundary charts drop it to stay PSD). On the non-frames matrix-free path
3410    /// the data-fit cross-block is carried by the Kronecker row operator and the
3411    /// per-row `htbeta` slab is allocated at zero width (#1406/#1407 anti-leak),
3412    /// so this dense isometry supplement has nowhere to land unless the slab is
3413    /// widened to the full `beta_dim`. This predicate decides exactly that. The
3414    /// effective isometry weight `μ` is NOT consulted here: a near-zero `μ`
3415    /// short-circuits the per-row write, but the slab must still exist so the
3416    /// solver's `htbeta_dense_supplement` read is well-shaped.
3417    pub(crate) fn registry_writes_dense_isometry_cross_block(
3418        &self,
3419        registry: &AnalyticPenaltyRegistry,
3420    ) -> bool {
3421        registry
3422            .penalties
3423            .iter()
3424            .any(|p| matches!(p, AnalyticPenaltyKind::Isometry(_)))
3425            && self
3426                .assignment
3427                .coords
3428                .iter()
3429                .any(|coord| coord.manifold().preserves_isometry_cross_block_coherence())
3430    }
3431
3432    /// Extra analytic-penalty energy that has no native `SaeManifoldLoss`
3433    /// component but is part of the penalized objective ranked by the SAE
3434    /// Laplace/REML criterion.
3435    pub fn reml_extra_penalty_value_total(
3436        &self,
3437        registry: &AnalyticPenaltyRegistry,
3438    ) -> Result<f64, ArrowSchurError> {
3439        Ok(self.analytic_decoder_penalty_value_total(registry)?
3440            + self.isometry_penalty_value_total(registry)?)
3441    }
3442
3443    pub fn penalized_objective_total(
3444        &self,
3445        target: ArrayView2<'_, f64>,
3446        rho: &SaeManifoldRho,
3447        registry: Option<&AnalyticPenaltyRegistry>,
3448        penalty_scale: f64,
3449    ) -> Result<f64, String> {
3450        let mut total = self.loss_scaled(target, rho, penalty_scale)?.total();
3451        if let Some(analytic_registry) = registry {
3452            total += self
3453                .analytic_penalty_value_total(analytic_registry, penalty_scale)
3454                .map_err(|err| format!("SaeManifoldTerm::penalized_objective_total: {err}"))?;
3455        }
3456        // #1026 — decoder-repulsion value, on the SAME frozen gate the assembly
3457        // used, so the line search sees the term the Newton step optimizes. 0
3458        // unless two atoms are near-collinear (the no-op case).
3459        total += self.decoder_repulsion_value(penalty_scale);
3460        // #1026/#1522 — interior-point collapse-prevention barriers, on the SAME
3461        // decoders the assembly's gradient/curvature used, so the line search sees
3462        // exactly the term the inner Newton step optimises (no value/grad desync).
3463        total += self.separation_barrier_value(penalty_scale);
3464        Ok(total)
3465    }
3466
3467    pub(crate) fn decoder_smoothness_value(&self, lambda_smooth: &[f64]) -> f64 {
3468        // Smoothness penalty value is `0.5·λ·Σ_oc B[:,oc]ᵀ S B[:,oc]`. Form the
3469        // `S·B` matrix product once per atom (O(M²·p)) and reduce against `B`
3470        // with a single O(M·p) Hadamard sum, instead of the previous
3471        // four-factor multiply-accumulate inside an `O(M²·p)` triple loop.
3472        // The quadratic form only sees the symmetric part of `S`, so reusing
3473        // the raw (un-symmetrised) `smooth_penalty` here is numerically
3474        // identical to the symmetrised assembly form.
3475        // Per-atom `S_k · B_k` products are independent across atoms, so they ride
3476        // the multi-GPU batched smoothness GEMM (uniform-shape groups tiled across
3477        // every device); `symmetrize = false` because the quadratic form only sees
3478        // the symmetric part of `S` regardless. Exact CPU fallback per atom.
3479        let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3480            .atoms
3481            .iter()
3482            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3483            .collect();
3484        let sb_all = batched_smooth_sb(&sb_inputs, false);
3485        let mut acc = 0.0;
3486        for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3487            acc += 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3488        }
3489        acc
3490    }
3491
3492    /// Per-atom decoder-smoothness values (#1556): entry `k` is
3493    /// `0.5·λ_smooth[k]·<B_k, S_k B_k>` (sum = [`Self::decoder_smoothness_value`]).
3494    /// This is the explicit `∂loss.smoothness/∂log λ_smooth[k]` gradient entry.
3495    pub(crate) fn decoder_smoothness_value_per_atom(&self, lambda_smooth: &[f64]) -> Vec<f64> {
3496        let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3497            .atoms
3498            .iter()
3499            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3500            .collect();
3501        let sb_all = batched_smooth_sb(&sb_inputs, false);
3502        let mut per_atom = vec![0.0_f64; self.atoms.len()];
3503        for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3504            per_atom[atom_idx] =
3505                0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3506        }
3507        per_atom
3508    }
3509
3510    pub(crate) fn ard_value(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
3511        if rho.log_ard.len() != self.k_atoms() {
3512            return Err(format!(
3513                "ARD rho has {} atoms but term has {}",
3514                rho.log_ard.len(),
3515                self.k_atoms()
3516            ));
3517        }
3518        let n = self.n_obs();
3519        let mut acc = 0.0;
3520        for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3521            let d = coord.latent_dim();
3522            if rho.log_ard[atom_idx].is_empty() {
3523                continue;
3524            }
3525            if rho.log_ard[atom_idx].len() != d {
3526                return Err(format!(
3527                    "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
3528                    rho.log_ard[atom_idx].len()
3529                ));
3530            }
3531            // Per-axis periodicity selects the smooth von-Mises energy on
3532            // wrapped (Circle) axes and the Gaussian on Euclidean axes.
3533            let periods = coord.effective_axis_periods();
3534            for axis in 0..d {
3535                let log_alpha = rho.log_ard[atom_idx][axis];
3536                // Clamp the log-precision before exponentiating: a raw
3537                // `exp(log_ard)` overflows to `inf` for `log_ard ≳ 709`, and the
3538                // `inf` precision then poisons the ARD energy / curvature with
3539                // `inf · 0.0 = NaN` (#742, Issue 4).
3540                let alpha = SaeManifoldRho::stable_exp_strength(log_alpha);
3541                let period = periods[axis];
3542                let mut energy = 0.0;
3543                for row in 0..n {
3544                    let v = coord.row(row)[axis];
3545                    energy += ArdAxisPrior::eval(alpha, v, period).value;
3546                }
3547                // Negative-log prior for precision alpha. The data-dependent
3548                // energy is the (Gaussian or von-Mises) coordinate prior; the
3549                // accompanying normaliser is the precision log-partition.
3550                //
3551                // Euclidean axes keep the Gaussian normaliser `-0.5 n log α`.
3552                // Periodic (von-Mises) axes use the EXACT von-Mises precision
3553                // log-partition `n[-η + log I0(η)]`, η = α/κ², κ = 2π/P, rather
3554                // than the Gaussian surrogate: the von-Mises partition function
3555                // is `2π I0(η)` (up to the κ Jacobian), so the per-observation
3556                // normaliser is `-η + log I0(η)` and is exact across the cut.
3557                match period {
3558                    None => {
3559                        acc += energy - 0.5 * (n as f64) * log_alpha;
3560                    }
3561                    Some(p) => {
3562                        let kappa = std::f64::consts::TAU / p;
3563                        let eta = alpha / (kappa * kappa);
3564                        // Overflow-free `log I0(η)`; `bessel_i0(η).ln()` would be
3565                        // `+inf` for `η ≳ 709` (#1113).
3566                        let log_i0 = bessel_i0_log_and_ratio(eta).0;
3567                        acc += energy + (n as f64) * (-eta + log_i0);
3568                    }
3569                }
3570            }
3571        }
3572        Ok(acc)
3573    }
3574
3575    /// Assemble the enlarged `(logits, t)` row-local Arrow-Schur system.
3576    ///
3577    /// Full-batch entry point: a single chunk covering all rows, with the
3578    /// β-tier penalties (decoder smoothness, ARD, analytic β penalties) carrying
3579    /// their full strength. The streaming driver calls
3580    /// [`Self::assemble_arrow_schur_scaled`] directly with a `penalty_scale`
3581    /// equal to the minibatch fraction `n_chunk / N`, so that the sum of the
3582    /// per-chunk β-tier contributions over a full pass reconstructs exactly the
3583    /// single global β penalty (the smoothness/ARD/β terms are functions of `B`
3584    /// and the global coordinates, not of the chunk's rows).
3585    pub fn assemble_arrow_schur(
3586        &mut self,
3587        target: ArrayView2<'_, f64>,
3588        rho: &SaeManifoldRho,
3589        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3590    ) -> Result<ArrowSchurSystem, String> {
3591        self.assemble_arrow_schur_scaled(target, rho, analytic_penalties, 1.0)
3592    }
3593
3594    /// Assemble the row-local Arrow-Schur system with a `penalty_scale` applied
3595    /// to the β-tier (decoder smoothness, ARD prior, analytic β penalties).
3596    ///
3597    /// `penalty_scale == 1.0` recovers the full-batch assembly. The streaming
3598    /// driver passes the minibatch fraction `n_chunk / N` so that the β-tier
3599    /// reduced-Schur and gradient contributions of the chunks sum to exactly one
3600    /// global copy across a full pass (data-fit, assignment-prior, and per-row
3601    /// coord/logit analytic terms are *not* scaled — they are genuine per-row
3602    /// sums).
3603    pub fn assemble_arrow_schur_scaled(
3604        &mut self,
3605        target: ArrayView2<'_, f64>,
3606        rho: &SaeManifoldRho,
3607        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3608        penalty_scale: f64,
3609    ) -> Result<ArrowSchurSystem, String> {
3610        self.assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3611            target,
3612            rho,
3613            analytic_penalties,
3614            penalty_scale,
3615            SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM,
3616        )
3617    }
3618
3619    pub(crate) fn assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3620        &mut self,
3621        target: ArrayView2<'_, f64>,
3622        rho: &SaeManifoldRho,
3623        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3624        penalty_scale: f64,
3625        dense_beta_penalty_probe_max_dim: usize,
3626    ) -> Result<ArrowSchurSystem, String> {
3627        self.assemble_arrow_schur_inner(
3628            target,
3629            rho,
3630            analytic_penalties,
3631            penalty_scale,
3632            dense_beta_penalty_probe_max_dim,
3633            None,
3634        )
3635    }
3636
3637    /// Innermost assembly entry. `forced_layout` overrides the budget-derived
3638    /// active-set layout so a caller can pin the dense (`Forced(None)`) or a
3639    /// specific compact (`Forced(Some(layout))`) path — used by the
3640    /// compact-vs-dense Riemannian-geometry equality regression test to drive
3641    /// both layouts on identical data. `Computed` is the production path:
3642    /// the layout is derived from the assignment mode + `sparse_active_plan`.
3643    pub(crate) fn assemble_arrow_schur_inner(
3644        &mut self,
3645        target: ArrayView2<'_, f64>,
3646        rho: &SaeManifoldRho,
3647        analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3648        penalty_scale: f64,
3649        dense_beta_penalty_probe_max_dim: usize,
3650        forced_layout: ForcedRowLayout,
3651    ) -> Result<ArrowSchurSystem, String> {
3652        if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3653            return Err(format!(
3654                "SaeManifoldTerm::assemble_arrow_schur_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3655            ));
3656        }
3657        if target.dim() != (self.n_obs(), self.output_dim()) {
3658            return Err(format!(
3659                "SaeManifoldTerm::assemble_arrow_schur: Z must be ({}, {}); got {:?}",
3660                self.n_obs(),
3661                self.output_dim(),
3662                target.dim()
3663            ));
3664        }
3665        if rho.log_ard.len() != self.k_atoms() {
3666            return Err(format!(
3667                "SaeManifoldTerm::assemble_arrow_schur: log_ard length {} != K {}",
3668                rho.log_ard.len(),
3669                self.k_atoms()
3670            ));
3671        }
3672        // `lambda_smooth` is indexed per-atom in the smoothness gradient/curvature
3673        // assembly (`lambda_smooth[atom_idx]`); a too-short vector (e.g. a growth
3674        // move that grew `k_atoms()` without extending ρ — #1556) would panic deep
3675        // in the assembly loop with an opaque index-out-of-bounds. Validate it here
3676        // alongside `log_ard` so the contract violation surfaces as a clear Err.
3677        if rho.log_lambda_smooth.len() != self.k_atoms() {
3678            return Err(format!(
3679                "SaeManifoldTerm::assemble_arrow_schur: log_lambda_smooth length {} != K {}",
3680                rho.log_lambda_smooth.len(),
3681                self.k_atoms()
3682            ));
3683        }
3684        for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3685            let ard_len = rho.log_ard[atom_idx].len();
3686            let d = coord.latent_dim();
3687            if ard_len != 0 && ard_len != d {
3688                return Err(format!(
3689                    "SaeManifoldTerm::assemble_arrow_schur: log_ard atom {atom_idx} \
3690                     has len {ard_len}; expected 0 (disabled) or atom dim {d}"
3691                ));
3692            }
3693        }
3694        // Reparameterize each atom's roughness Gram into arc length at the
3695        // current decoder/coordinates (issue #673). This is the single
3696        // chokepoint for both the inner Newton assembly and the undamped
3697        // evidence factorization, so freezing the pullback-metric weight here
3698        // (lagged-diffusivity) keeps the smoothness value, gradient, Kronecker
3699        // Hessian, and REML log-det mutually consistent within each assembly
3700        // and makes the converged penalty — hence the topology evidence —
3701        // gauge-invariant. Constant-speed (periodic) atoms are unaffected.
3702        for atom in &mut self.atoms {
3703            atom.refresh_intrinsic_smooth_penalty();
3704        }
3705        // #1026 — freeze the decoder-repulsion collinearity gate at the SAME
3706        // assembly chokepoint as the smoothness Gram, so the repulsion's
3707        // gradient/curvature (assembled below) and its value (read by the
3708        // line-search `penalized_objective_total`) share one frozen gate.
3709        self.refresh_decoder_repulsion_gate();
3710        // #1625 — freeze the SEPARATION barrier's normalized-coactivation `q_jk`
3711        // at the same chokepoint. The barrier weights its decoder-shape repulsion
3712        // by the routing coactivation, but its gradient treats that weight as a
3713        // constant; recomputing it from the trial logits in the line-search value
3714        // desyncs value vs gradient in the logit block and stalls the inner solve
3715        // (#1625). Freezing it here makes value/gradient/curvature consistent.
3716        self.refresh_barrier_coactivation_gate();
3717        let n = self.n_obs();
3718        let p = self.output_dim();
3719        let k_atoms = self.k_atoms();
3720        let assignment_dim = self.assignment.assignment_coord_dim();
3721        let q = self.assignment.row_block_dim();
3722        let beta_dim = self.beta_dim();
3723        let frame_projection = FrameProjection::new(self);
3724        let beta_offsets = frame_projection.beta_offsets.clone();
3725        let coord_offsets = self.assignment.coord_offsets();
3726        // β-tier decoder smoothness is a global (B-only) penalty; under a
3727        // minibatch pass it is scaled by the chunk fraction so the per-chunk
3728        // contributions sum to one global copy.
3729        // Per-atom decoder-smoothness strengths (#1556): atom k's penalty `S_k`
3730        // is scaled by `λ_smooth[k]·penalty_scale`. The minibatch `penalty_scale`
3731        // multiplies every atom uniformly.
3732        let lambda_smooth: Vec<f64> = rho
3733            .lambda_smooth_vec()
3734            .iter()
3735            .map(|&l| l * penalty_scale)
3736            .collect();
3737        let (assignment_grad, assignment_hdiag) =
3738            assignment_prior_grad_hdiag(&self.assignment, rho)?;
3739
3740        // #1038 softmax entropy: the exact per-row Hessian in logits is dense
3741        // (`H_kj = (λ/τ²) a_k[δ_kj(m−L_k−1)+a_j(L_k+L_j+1−2m)]`), not just the
3742        // `assignment_hdiag` diagonal. Build the shared penalty + `scale = λ/τ²`
3743        // once here so the dense row block written into `block.htt` below, the
3744        // criterion's `log|H|`, and the #1006 θ-adjoint all differentiate the
3745        // SAME operator. JumpReLU / IBP keep their (separately exact) diagonal /
3746        // cross-row channels and leave this `None`. The block is gauge-null in
3747        // isolation (`H·𝟙 = 0`); it is only ever summed onto the gauge-breaking
3748        // data-fit row block before the Cholesky factor, never factored alone.
3749        let softmax_dense: Option<(
3750            gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
3751            f64,
3752        )> = match self.assignment.mode {
3753            AssignmentMode::Softmax {
3754                temperature,
3755                sparsity,
3756            } if k_atoms > 1 => {
3757                let inv_tau = 1.0 / temperature;
3758                let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
3759                Some((
3760                    gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
3761                        k_atoms,
3762                        temperature,
3763                    ),
3764                    scale,
3765                ))
3766            }
3767            _ => None,
3768        };
3769
3770        // Decoder smoothness penalty: build one KroneckerPenaltyOp per atom
3771        // (structure = λ·S_k ⊗ I_p, offset = beta_offsets[k]) instead of
3772        // materialising the dense K×K block.  The gradient is a dense K-vector
3773        // accumulated into `smooth_grad_gb` and written into sys.gb after sys
3774        // is constructed (#296).
3775        let mut smooth_ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len());
3776        // #972 / #977 T1: retain each atom's symmetrised `λ S_k` (`M_k × M_k`) so
3777        // the frame transform can rebuild the smooth penalty in the factored
3778        // coordinate space as `λ S_k ⊗ I_{r_k}` (the `tr(C_kᵀ S_k C_k)` form,
3779        // using `U_kᵀU_k = I`). Unused — and not even read — on the full-`B`
3780        // path, so this is a zero-cost capture there.
3781        let mut smooth_scaled_s: Vec<Array2<f64>> = Vec::with_capacity(self.atoms.len());
3782        let mut smooth_grad_gb = vec![0.0_f64; beta_dim];
3783        // #1117 — rank deficiency is handled at the basis layer: any
3784        // rank-deficient atom was reparametrized onto its data-supported subspace
3785        // at fit entry (`reduce_atoms_to_data_supported_rank`), so the β-tier here
3786        // always sees a full-rank design and needs no step-time data-null
3787        // deflation operator. The well-conditioned (full-rank) path is unchanged.
3788        // Per-atom smoothness-gradient GEMMs `½(S_k+S_kᵀ)·B_k` are independent
3789        // across atoms; batch them across ALL GPUs (uniform-shape tiles) and
3790        // scale by `lambda_smooth` below. `symmetrize = true` reproduces the
3791        // per-atom symmetrised `scaled_s/λ` used by the Kronecker op. Exact CPU
3792        // fallback per atom keeps the result bit-for-bit with the all-CPU path.
3793        let sym_sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3794            .atoms
3795            .iter()
3796            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3797            .collect();
3798        let sym_sb_all = batched_smooth_sb(&sym_sb_inputs, true);
3799        for (atom_idx, atom) in self.atoms.iter().enumerate() {
3800            let m = atom.basis_size();
3801            let off = beta_offsets[atom_idx];
3802            // Symmetrise and scale the smoothness penalty matrix.
3803            let mut scaled_s = Array2::<f64>::zeros((m, m));
3804            for i in 0..m {
3805                for j in 0..m {
3806                    let s_ij = 0.5 * (atom.smooth_penalty[[i, j]] + atom.smooth_penalty[[j, i]]);
3807                    scaled_s[[i, j]] = lambda_smooth[atom_idx] * s_ij;
3808                }
3809            }
3810            // Gradient: g[beta_i] += (λ_k S_k B_k)[i, out_col]. The (m×m)·(m×p)
3811            // GEMM `½(S+Sᵀ)·B_k` was computed in the multi-GPU batch above; here
3812            // we only apply atom k's `lambda_smooth[atom_idx]`.
3813            let sb = &sym_sb_all[atom_idx] * lambda_smooth[atom_idx];
3814            for out_col in 0..p {
3815                for i in 0..m {
3816                    let beta_i = off + i * p + out_col;
3817                    smooth_grad_gb[beta_i] += sb[[i, out_col]];
3818                }
3819            }
3820            // IdentityRightKroneckerPenaltyOp: factor_a = λ·S_k (m×m), factor_b = I_p.
3821            smooth_ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
3822                factor_a: scaled_s.clone(),
3823                p,
3824                global_offset: off,
3825                k: beta_dim,
3826            }));
3827            // Retain `λ S_k` for the factored rebuild (no-op cost on full-`B`).
3828            smooth_scaled_s.push(scaled_s);
3829        }
3830
3831        // Per-row active-set layout. Engaged for two regimes:
3832        //   * JumpReLU — structural gate plus the smooth prior's
3833        //     machine-precision support: atoms with
3834        //     `(logit - threshold)/tau > -36` enter the compact solve
3835        //     ([`jumprelu_in_optimization_band`]). Strictly gated-off atoms
3836        //     (logit ≤ threshold) carry zero assignment mass so their data-fit
3837        //     reconstruction contribution and data-fit logit JVP are zero, but
3838        //     supported atoms keep value-consistent prior gradient in the row block.
3839        //   * IBP-MAP at large `K` — the dense `(m_total · p)²` data
3840        //     Gram is infeasible, so each row is truncated to its
3841        //     top-`k_active` atoms above a relative magnitude cutoff
3842        //     ([`Self::sparse_active_plan`]). Small-`K` problems return `None`
3843        //     and keep the exact full-support layout.
3844        // The compact row block is sized `q_active = |active| + Σ_{k∈active}
3845        // d_k` instead of the full `q`.
3846        let coord_dims: Vec<usize> = self
3847            .assignment
3848            .coords
3849            .iter()
3850            .map(|c| c.latent_dim())
3851            .collect();
3852        let row_layout: Option<SaeRowLayout> = match forced_layout {
3853            Some(layout) => layout,
3854            None => match self.assignment.mode {
3855                AssignmentMode::ThresholdGate {
3856                    threshold,
3857                    temperature,
3858                } => Some(SaeRowLayout::from_jumprelu(
3859                    n,
3860                    k_atoms,
3861                    threshold,
3862                    temperature,
3863                    &self.assignment.logits,
3864                    coord_dims.clone(),
3865                    self.assignment.coord_offsets(),
3866                )),
3867                // #1408/#1409 — Softmax engages the COMPACT top-`k` row layout
3868                // inside the optimization (no longer a post-fit projection).
3869                // The active set is each row's top-`k_active_cap` softmax atoms
3870                // above the relative cutoff; the cap comes from the user's
3871                // `top_k` (`softmax_active_cap`) and/or the in-core memory budget
3872                // ([`Self::softmax_active_plan`]). The full-`K` softmax
3873                // normalization still forms `a` (the gate map); only the dropped
3874                // tail logits, carrying negligible `O(a)` reconstruction mass and
3875                // `O(a²)` curvature, leave the per-row block.
3876                //
3877                // Coherence (the load-bearing correctness invariant): the
3878                // assembly's softmax curvature branch writes the ACTIVE×ACTIVE
3879                // principal sub-block of the Gershgorin Loewner majorizer
3880                // `D = diag(Σ_j|H_kj|)` (#1419; PSD and `D ⪰ H_entropy`) on the
3881                // compact logit slots — NOT the indefinite `assignment_hdiag`
3882                // diagonal. The logdet ρ-trace
3883                // (`assignment_log_strength_hessian_trace`) iterates the row's
3884                // active logit slots and indexes that SAME majorizer by global
3885                // atom, and the θ-adjoint reads its derivative via `jets.vars`
3886                // (global-atom indexed), so value, log|H|, and Γ differentiate
3887                // ONE operator on the compact support. The FFI's after-the-fit
3888                // top-`k` projection is then a no-op at the optimum.
3889                AssignmentMode::Softmax { .. } => match self.softmax_active_plan() {
3890                    Some((k_active_cap, relative_cutoff)) => {
3891                        let mut assignments_all = Vec::with_capacity(n);
3892                        for row in 0..n {
3893                            assignments_all
3894                                .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3895                        }
3896                        Some(SaeRowLayout::from_dense_weights(
3897                            &assignments_all,
3898                            k_active_cap,
3899                            relative_cutoff,
3900                            coord_dims.clone(),
3901                            self.assignment.coord_offsets(),
3902                        ))
3903                    }
3904                    None => None,
3905                },
3906                AssignmentMode::IBPMap { .. } => {
3907                    match self.sparse_active_plan() {
3908                        Some((k_active_cap, relative_cutoff)) => {
3909                            // Build per-row dense assignments once to derive the
3910                            // active set; the row loop re-derives `assignments`
3911                            // (cheap gate map at the same rho) and reuses these
3912                            // active sets.
3913                            let mut assignments_all = Vec::with_capacity(n);
3914                            for row in 0..n {
3915                                assignments_all
3916                                    .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3917                            }
3918                            // #1414: pass the RELATIVE cutoff through;
3919                            // `from_dense_weights` applies it per row against that
3920                            // row's own peak `max_k |a_{n,k}|`, matching the
3921                            // documented `sparse_active_plan` contract. A single
3922                            // global threshold (relative_cutoff · whole-dataset
3923                            // peak) wrongly drops every atom of a uniformly-small
3924                            // row when another row peaks high.
3925                            Some(SaeRowLayout::from_dense_weights(
3926                                &assignments_all,
3927                                k_active_cap,
3928                                relative_cutoff,
3929                                coord_dims.clone(),
3930                                self.assignment.coord_offsets(),
3931                            ))
3932                        }
3933                        None => None,
3934                    }
3935                }
3936            },
3937        };
3938        // #974 likelihood-whitening seam. The single per-row decision: when the
3939        // installed `RowMetric` is a genuinely estimated noise model
3940        // (`whitens_likelihood()` — only `WhitenedStructured`), the
3941        // reconstruction data-fit, its t-block Gauss-Newton row block, AND the
3942        // β-tier data-fit gradient are all assembled through the SAME per-row
3943        // metric `M_n = U_n U_nᵀ = Σ_n^{-1}`. There is exactly ONE construction
3944        // site (the `whiten_rows` closure below), so the value the line-search
3945        // sums and the gradient/Hessian the Newton step solves cannot drift apart
3946        // (the objective↔gradient-desync cure). For Euclidean / OutputFisher /
3947        // no-metric the closure is the identity and every downstream loop is
3948        // byte-identical to the historical isotropic path.
3949        let whitens_likelihood = self
3950            .row_metric
3951            .as_ref()
3952            .is_some_and(|metric| metric.whitens_likelihood());
3953        // #972 / #977 T1: engage the FACTORED Grassmann-coordinate β-tier when
3954        // any atom has an active decoder frame. The closed-form factorization
3955        // `Φᵀ(G ⊗ I_p)Φ = G ⊗ (U_iᵀU_j)` is EXACT only for the isotropic
3956        // likelihood; under an active whitening metric (`whitens_likelihood()`,
3957        // only `WhitenedStructured`) the per-row output factor would be
3958        // `U_iᵀ M_n U_j` and does NOT factor out of the basis Gram, so we fall
3959        // back to the full-`B` path there (frames + whitening is out of scope —
3960        // see #974). The common Euclidean / OutputFisher / no-metric case factors
3961        // cleanly. When `frames_engaged` is false, EVERY β-tier object below is
3962        // assembled bit-for-bit as the historical full-`B` path.
3963        let frames_engaged = self.any_frame_active() && !whitens_likelihood;
3964        // #1407: fixed-decoder mode skips the entire β decoder tier (G/gb/htbeta
3965        // operator/hbb/β-penalties); only per-row htt/gt are produced.
3966        let fixed_decoder = self.fixed_decoder_assembly;
3967        let admission_plan = self
3968            .streaming_plan()
3969            .admitted_or_error(self.n_obs(), self.output_dim(), self.k_atoms())
3970            .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
3971        // #1407: fixed-decoder builds NO dense β-Hessian (hbb) — force the
3972        // empty-hbb system constructor so no `beta_dim × beta_dim` workspace is
3973        // taken (the early return skips `reclaim_border_hbb_workspace`).
3974        let dense_beta_curvature = !fixed_decoder
3975            && admission_plan.direct_admitted
3976            && !(frames_engaged && beta_dim > dense_beta_penalty_probe_max_dim);
3977        // #1406: the dense per-row cross-block slab `block.htbeta` is only WRITTEN
3978        // (line ~4243) and READ by the solver when `frames_engaged` (the factored
3979        // full-B path, which installs NO matrix-free row operator → the solver's
3980        // `sys_htbeta_apply_row` falls back to the dense slab). On the
3981        // `!frames_engaged` path the cross block is carried entirely by the
3982        // matrix-free Kronecker operator (`set_row_htbeta_operator`, ~line 4491);
3983        // `activate_dense_htbeta_supplement` is never called, so the solver never
3984        // touches `block.htbeta`. Allocating it at `beta_dim = K·M·p` there is the
3985        // ~6 TiB high-K leak (#1405/#1406): allocate ZERO columns instead. Frames
3986        // still use the (much smaller) factored border width.
3987        // #795/#1406/#1407: the non-frames matrix-free path normally holds a
3988        // ZERO-width per-row cross-block slab — the data-fit `H_tβ` is carried by
3989        // the Kronecker row operator (`set_row_htbeta_operator`), and allocating
3990        // the dense slab at `beta_dim = K·M·p` is the high-K memory leak. But an
3991        // ISOMETRY penalty on a coherence-preserving (flat) chart scatters an
3992        // ADDITIONAL Gauss-Newton cross-block into the dense per-row `htbeta`
3993        // slab and flips on `activate_dense_htbeta_supplement` — dropping it would
3994        // leave the Newton system block-diagonal and forfeit the strong `t↔B`
3995        // isometry coupling the circle fit needs to reach KKT stationarity (#795).
3996        // So on the non-frames path widen the slab to `beta_dim` exactly when that
3997        // dense supplement will be written, and keep zero width otherwise.
3998        let dense_isometry_cross_block = !fixed_decoder
3999            && analytic_penalties
4000                .map(|registry| self.registry_writes_dense_isometry_cross_block(registry))
4001                .unwrap_or(false);
4002        let row_htbeta_dim = if fixed_decoder {
4003            // Fixed-decoder mode skips the β tier entirely.
4004            0
4005        } else if frames_engaged {
4006            self.factored_border_dim()
4007        } else if dense_isometry_cross_block {
4008            // Matrix-free data-fit cross-block + dense isometry supplement: the
4009            // supplement is written/read in the full-`B` β coordinate system.
4010            beta_dim
4011        } else {
4012            // Matrix-free path with no dense cross-block supplement.
4013            0
4014        };
4015        // Build the Arrow-Schur system: heterogeneous row dims when a compact
4016        // layout is active, uniform `q` otherwise.
4017        let mut sys = if let Some(ref layout) = row_layout {
4018            let per_row_dims: Vec<usize> = (0..n).map(|row| layout.row_q_active(row)).collect();
4019            if dense_beta_curvature {
4020                let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4021                ArrowSchurSystem::new_with_per_row_dims_and_hbb_and_htbeta_cols(
4022                    per_row_dims,
4023                    beta_dim,
4024                    hbb_workspace,
4025                    row_htbeta_dim,
4026                )
4027            } else {
4028                self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4029                ArrowSchurSystem::new_with_per_row_dims_empty_hbb_and_htbeta_cols(
4030                    per_row_dims,
4031                    beta_dim,
4032                    row_htbeta_dim,
4033                )
4034            }
4035        } else if dense_beta_curvature {
4036            let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4037            ArrowSchurSystem::new_with_hbb_and_htbeta_cols(
4038                n,
4039                q,
4040                beta_dim,
4041                hbb_workspace,
4042                row_htbeta_dim,
4043            )
4044        } else {
4045            self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4046            ArrowSchurSystem::new_with_empty_hbb_and_htbeta_cols(n, q, beta_dim, row_htbeta_dim)
4047        };
4048        // Apply accumulated smoothness-penalty gradients into sys.gb.
4049        for (i, g) in smooth_grad_gb.iter().enumerate() {
4050            sys.gb[i] += g;
4051        }
4052        // `w_dim` is the whitened output dimension: `rank` of the metric factor
4053        // when whitening, else `p` (identity). `error_white` is the whitened
4054        // residual `U_nᵀ r_n ∈ ℝ^{w_dim}` whose squared norm is `r_nᵀ M_n r_n`,
4055        // shared by the value path, the t-block GN, and (lifted back to p-space)
4056        // the β-tier gradient.
4057        let w_dim = match self.row_metric.as_ref() {
4058            Some(metric) if whitens_likelihood => metric.metric_rank(),
4059            _ => p,
4060        };
4061        // Data-fit Gauss-Newton β-Hessian is block-diagonal across the `p`
4062        // output channels and identical in each: with the flat β layout
4063        // `β[μ·p + oc] = B[μ, oc]` (μ enumerating (atom, basis_col)) the GN
4064        // outer product `Jβᵀ Jβ` couples only equal `oc`, with the same
4065        // `(M_total × M_total)` block `G[μ, μ'] = Σ_rows (a_k φ_k[m])(a_{k'} φ_{k'}[m'])`
4066        // for every channel. So `H_data = G ⊗ I_p`. The `μ` index of an `a_phi`
4067        // entry whose global β base is `beta_base` is `beta_base / p` (every
4068        // `beta_offset` and the `basis_col·p` stride are multiples of `p`).
4069        //
4070        // `G` is only non-zero on `(atom_i, atom_j)` pairs that co-occur in
4071        // some row's active set, so we accumulate it as a sparse map of dense
4072        // per-atom-pair `(m_i × m_j)` blocks keyed by `(atom_i, atom_j)` rather
4073        // than as a dense `(m_total × m_total)` matrix. At `K = 100K` with
4074        // per-row active sets of size `k_active ≪ K`, only `O(N · k_active²)`
4075        // pairs are ever touched, so the data Gram (and every matvec /
4076        // diagonal pass over it via `SparseBlockKroneckerPenaltyOp`) tracks the
4077        // active atoms instead of `K²`. In the dense full-support layout the
4078        // map degenerates to every co-occurring pair, reproducing the dense
4079        // Gram exactly. A `BTreeMap` key order keeps the installed op's
4080        // fingerprint deterministic. The `μ`-space offset of atom `k` is
4081        // `beta_offsets[k] / p`.
4082        type SaeGBlocks = std::collections::BTreeMap<(usize, usize), Array2<f64>>;
4083        let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
4084        let mu_offsets: Vec<usize> = beta_offsets.iter().map(|&off| off / p).collect();
4085        // Stick-breaking prior for IBP-MAP depends only on (k_atoms, alpha_eff)
4086        // which are constant across rows for the current rho; precompute once.
4087        let ibp_prior_vec = match self.assignment.mode {
4088            AssignmentMode::IBPMap { .. } => {
4089                let alpha = self
4090                    .assignment
4091                    .resolved_ibp_alpha(rho)
4092                    .ok_or_else(|| "IBP assignment alpha resolution failed".to_string())?;
4093                Some(ordered_geometric_shrinkage_prior(k_atoms, alpha).to_vec())
4094            }
4095            _ => None,
4096        };
4097        let ibp_prior_slice = ibp_prior_vec.as_deref();
4098        // #991 design honesty weights (mean-1 HT inclusion corrections); see
4099        // the seam comment at the per-row residual below.
4100        let row_loss_w = self.row_loss_weights.as_deref();
4101        // Dense full-support index `[0, k_atoms)`, used by the row loop when no
4102        // compact layout is engaged so the active-atom iteration is uniform.
4103        let all_atoms_index: Vec<usize> = (0..k_atoms).collect();
4104        // Per-atom per-axis periodicity, hoisted out of the row loop. Selects
4105        // the smooth von-Mises coordinate prior on wrapped (Circle) axes and
4106        // the Gaussian prior on Euclidean axes; see `ArdAxisPrior`.
4107        let ard_axis_periods: Vec<Vec<Option<f64>>> = self
4108            .assignment
4109            .coords
4110            .iter()
4111            .map(|coord| coord.effective_axis_periods())
4112            .collect();
4113        struct SaeAssemblyRow {
4114            pub(crate) row: usize,
4115            pub(crate) block: ArrowRowBlock,
4116            pub(crate) gb_delta: Vec<(usize, f64)>,
4117            pub(crate) g_blocks: SaeGBlocks,
4118            pub(crate) kron_a_phi: Option<Vec<(usize, f64)>>,
4119            pub(crate) kron_jac: Option<Vec<f64>>,
4120        }
4121
4122        // Per-row scratch reused across all rows a rayon worker processes
4123        // (#1017). The assembly closure is re-run every inner Newton iteration ×
4124        // every outer ρ evaluation; allocating these eight loop-invariant-sized
4125        // buffers (`k_atoms·p`, several `p`, one `q·max(w_dim,p)`) once per
4126        // worker via `map_init` — rather than once per (row × assembly) inside
4127        // the closure — removes the dominant small-allocation traffic the
4128        // eu-stack profile attributed to allocator/barrier spin at the SAE LLM
4129        // shape (p≈5120). Every buffer is fully filled (or `.fill(0.0)`'d) before
4130        // it is read each row, so reuse is bit-identical to the fresh-alloc path;
4131        // `gb_delta`/`g_blocks` are NOT scratch (they move into the returned
4132        // `SaeAssemblyRow`) and stay allocated per row.
4133        struct RowScratch {
4134            pub(crate) decoded: Array2<f64>,
4135            pub(crate) dg_buf: Vec<f64>,
4136            pub(crate) fitted: Array1<f64>,
4137            pub(crate) error: Array1<f64>,
4138            pub(crate) error_white: Vec<f64>,
4139            pub(crate) error_metric: Array1<f64>,
4140            pub(crate) jac_white: Vec<f64>,
4141            pub(crate) decoded_scratch: Vec<f64>,
4142            // #1557 — per-worker scratch for the row assignment vector (filled via
4143            // `_into`, not allocated per row); full `k_atoms`, global-atom indexed.
4144            pub(crate) assignments: Array1<f64>,
4145        }
4146        // #1410: size the per-worker scratch by the COMPACT row dimensions, not
4147        // full `K`/`q`. With a compact layout the assembly only ever touches each
4148        // row's active atoms (≤ `max_active`) and its compact tangent block
4149        // (≤ `max_q_row`); allocating `decoded` at `k_atoms·p` and `jac_white` at
4150        // `q·max(w_dim,p)` was the per-worker `O(K)` blow-up (≈11 GiB/worker at
4151        // K=100k, p=5120 — and `map_init` gives every Rayon worker its own copy).
4152        // Without a layout the dense path needs full `k_atoms`/`q`. `decoded` rows
4153        // are addressed by COMPACT SLOT in the compact branch below (the dense
4154        // branch keeps global-atom rows), so the row count is the max active set.
4155        //
4156        // #1410/#1408/#1409: SOFTMAX now ALSO takes the `Some(layout)` branch
4157        // whenever a `top_k` cap (`set_softmax_active_cap`) or an in-core memory
4158        // breach engages `softmax_active_plan` → `from_dense_weights`, so its
4159        // per-worker `decoded`/`jac_white` scratch is the COMPACT
4160        // `max_active`/`max_q_row` size too — no longer the full `(k_atoms·p)` /
4161        // `(q·max(w_dim,p))` blow-up. JumpReLU / IBP-MAP likewise pay only
4162        // `max_active`. The remaining `None` (full-`K`) branch is the UNCAPPED
4163        // softmax / no-budget-breach case, which genuinely assembles the dense
4164        // entropy block over all `K`; capping it (the compact contract) removes
4165        // the per-worker `O(K)` footprint entirely. (#1410: the residual per-row
4166        // `O(K)` softmax-majorizer scratch — a `row_logits` copy and the full-`K`
4167        // `d`/`H_entropy` blocks — is removed separately; see the active-only
4168        // `active_softmax_gershgorin_majorizer_entry` /
4169        // `softmax_dense_entropy_hessian_entry` helpers below.)
4170        let (decoded_rows, scratch_q) = match row_layout.as_ref() {
4171            Some(layout) => {
4172                let max_active = (0..n)
4173                    .map(|r| layout.active_atoms[r].len())
4174                    .max()
4175                    .unwrap_or(0)
4176                    .max(1);
4177                let max_q_row = (0..n)
4178                    .map(|r| layout.row_q_active(r))
4179                    .max()
4180                    .unwrap_or(q)
4181                    .max(1);
4182                (max_active, max_q_row)
4183            }
4184            None => (k_atoms, q),
4185        };
4186        use rayon::iter::{IntoParallelIterator, ParallelIterator};
4187        // #1033 large-n: fold the per-row assembly results in row-ordered CHUNKS
4188        // rather than collecting all `n` `SaeAssemblyRow`s at once. The previous
4189        // path materialized the FULL `Vec<SaeAssemblyRow>` (every row's htt/gt
4190        // block + per-row `g_blocks` + `kron_a_phi`/`kron_jac`) AND the fold
4191        // destinations simultaneously — a ~2× transient peak over the resident
4192        // system during the fold, the assembly-side OOM cliff at large `n`. By
4193        // collecting one chunk, folding it into `sys.rows`/`g_blocks`/`kron_*`,
4194        // and dropping the chunk's `Vec` before the next chunk, the transient
4195        // intermediate is bounded to `O(chunk_size)` while the resident output is
4196        // unchanged. The fold stays STRICTLY row-ascending (chunk `[c0..c1)` then
4197        // `[c1..c2)`, rows in order within each chunk), so every `+=` into
4198        // `sys.gb`, the `g_blocks` BTreeMap, and the `kron_*` pushes lands in the
4199        // identical order as the single-pass fold — bit-for-bit the same system.
4200        // Chunk width is the admission plan's `chunk_size` (the same value
4201        // `streaming_plan` sizes for the matrix-free window), floored so a tiny
4202        // plan still makes forward progress.
4203        let assembly_chunk_rows = self
4204            .assembly_chunk_override
4205            .unwrap_or(admission_plan.chunk_size)
4206            .clamp(1, n.max(1));
4207        let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4208        let mut kron_a_phi: Vec<Vec<(usize, f64)>> = Vec::with_capacity(n);
4209        let mut kron_jac: Vec<Vec<f64>> = Vec::with_capacity(n);
4210        let mut chunk_start = 0usize;
4211        while chunk_start < n {
4212            let chunk_end = (chunk_start + assembly_chunk_rows).min(n);
4213            let mut fold_offset_in_chunk = 0usize;
4214            let row_results: Vec<SaeAssemblyRow> = (chunk_start..chunk_end)
4215                .into_par_iter()
4216                .map_init(
4217                    || RowScratch {
4218                        decoded: Array2::<f64>::zeros((decoded_rows, p)),
4219                        dg_buf: vec![0.0_f64; p],
4220                        fitted: Array1::<f64>::zeros(p),
4221                        error: Array1::<f64>::zeros(p),
4222                        error_white: vec![0.0_f64; w_dim],
4223                        error_metric: Array1::<f64>::zeros(p),
4224                        jac_white: vec![0.0_f64; scratch_q * w_dim.max(p)],
4225                        decoded_scratch: vec![0.0_f64; p],
4226                        assignments: Array1::<f64>::zeros(k_atoms),
4227                    },
4228                    |scratch, row| -> Result<SaeAssemblyRow, String> {
4229                        // #1557 — mark this rayon row worker as a nested data-parallel
4230                        // region so any faer GEMM reached transitively from the per-row
4231                        // assembly (frame `Uᵀ` products, the per-row cross-block /
4232                        // Schur-accumulation matmuls, the Riemannian projections) pins to
4233                        // `Par::Seq` via `effective_global_parallelism` instead of
4234                        // re-fanning the global Rayon pool against this outer fan-out
4235                        // (the `spindle` barrier-spin). Serial vs parallel over these tiny
4236                        // per-row blocks is a single small product, so the result is
4237                        // bit-identical. The guard is held for the whole closure body
4238                        // including its `?`/`return` paths.
4239                        with_nested_parallel(|| {
4240                        let RowScratch {
4241                            decoded,
4242                            dg_buf,
4243                            fitted,
4244                            error,
4245                            error_white,
4246                            error_metric,
4247                            jac_white,
4248                            decoded_scratch,
4249                            assignments,
4250                        } = scratch;
4251                        let mut gb_delta: Vec<(usize, f64)> = Vec::new();
4252                        let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4253                        // #1557 — fill per-worker scratch (bit-identical to alloc path).
4254                        let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
4255                        self.assignment
4256                            .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
4257                        // Reconstruction uses the row's active support: for the dense
4258                        // full-support layout this is all atoms (exact); for a compact
4259                        // layout the dropped atoms carry negligible `O(a)` reconstruction
4260                        // mass and zero curvature, so excluding them keeps `fitted`,
4261                        // `error`, and the logit-JVP cross term `(decoded[k] − fitted)`
4262                        // mutually consistent with the curvature actually assembled.
4263                        fitted.fill(0.0);
4264                        let row_active_owned: Option<&[usize]> =
4265                            row_layout.as_ref().map(|l| l.active_atoms[row].as_slice());
4266                        match row_active_owned {
4267                            Some(active) => {
4268                                // #1410: `decoded` is a compact (max_active × p) buffer
4269                                // here; index it by the active-set SLOT `j` (the same
4270                                // index the compact tangent block / `coord_starts` use),
4271                                // NOT the global `atom_idx`.
4272                                for (j, &atom_idx) in active.iter().enumerate() {
4273                                    let a_k = assignments[atom_idx];
4274                                    self.atoms[atom_idx]
4275                                        .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4276                                    for out_col in 0..p {
4277                                        decoded[[j, out_col]] = decoded_scratch[out_col];
4278                                        fitted[out_col] += a_k * decoded_scratch[out_col];
4279                                    }
4280                                }
4281                            }
4282                            None => {
4283                                for atom_idx in 0..k_atoms {
4284                                    let a_k = assignments[atom_idx];
4285                                    self.atoms[atom_idx]
4286                                        .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4287                                    for out_col in 0..p {
4288                                        decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
4289                                        fitted[out_col] += a_k * decoded_scratch[out_col];
4290                                    }
4291                                }
4292                            }
4293                        }
4294                        for out_col in 0..p {
4295                            error[out_col] = fitted[out_col] - target[[row, out_col]];
4296                        }
4297                        // #991 design-honesty seam: a per-row scalar weight `w_row` on the
4298                        // reconstruction channel is exactly the metric `w_row · I_p`, so it
4299                        // is realized as a `√w_row` scaling of the THREE row-local data
4300                        // quantities at their construction sites — this residual, the
4301                        // latent Jacobian (below), and the β basis load `a·φ` (below).
4302                        // Every downstream data object then carries exactly one factor of
4303                        // `w_row` (gt, htt, htbeta, the β Gram `G`, and the β gradient),
4304                        // matching the `w_row`-weighted value `loss_scaled` sums; the
4305                        // per-row latent priors (assignment / ARD, added to `gt`/`htt`
4306                        // further down) are deliberately unweighted — see the
4307                        // `row_loss_weights` field docs. `None` ⇒ `sqrt_row_w == 1.0` and
4308                        // no multiply is applied (bit-identical unweighted path).
4309                        let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
4310                        if sqrt_row_w != 1.0 {
4311                            for out_col in 0..p {
4312                                error[out_col] *= sqrt_row_w;
4313                            }
4314                        }
4315                        // #974 seam (step 1/2): whiten the per-row residual ONCE.
4316                        //   * not whitening ⇒ `error_white == error` (length p) and
4317                        //     `error_metric == error`; every downstream loop is the
4318                        //     historical isotropic path bit-for-bit.
4319                        //   * whitening ⇒ `error_white = U_nᵀ r_n ∈ ℝ^{w_dim}` (its squared
4320                        //     norm is `r_nᵀ M_n r_n`, the value the data-fit sums) and
4321                        //     `error_metric = U_n (U_nᵀ r_n) = M_n r_n ∈ ℝ^p` (the p-space
4322                        //     metric-applied residual the β-tier gradient contracts).
4323                        match self.row_metric.as_ref() {
4324                            Some(metric) if whitens_likelihood => {
4325                                let wr = metric.whiten_residual_row(row, error.view());
4326                                for (slot, &v) in error_white.iter_mut().zip(wr.iter()) {
4327                                    *slot = v;
4328                                }
4329                                let mr = metric.apply_metric_row(row, error.view());
4330                                for (slot, &v) in error_metric.iter_mut().zip(mr.iter()) {
4331                                    *slot = v;
4332                                }
4333                            }
4334                            _ => {
4335                                for out_col in 0..p {
4336                                    error_white[out_col] = error[out_col];
4337                                    error_metric[out_col] = error[out_col];
4338                                }
4339                            }
4340                        }
4341
4342                        // Determine whether this row uses the compact active-set layout.
4343                        //   * JumpReLU: gated atoms plus the smooth prior's
4344                        //     machine-precision support enter.
4345                        //   * IBP-MAP at large K: only the top-`k_active` atoms.
4346                        //   * Otherwise (small K): the dense uniform-q layout.
4347                        let (q_row, mut local_jac_row) = if let Some(layout) = row_layout.as_ref() {
4348                            let active = &layout.active_atoms[row];
4349                            let starts = &layout.coord_starts[row];
4350                            let q_active = layout.row_q_active(row);
4351                            let mut jac_compact = Array2::<f64>::zeros((q_active, p));
4352                            // Logit JVP rows for active atoms only, using the per-mode
4353                            // assignment sensitivity `da_k/dl_k` contracted into the
4354                            // decoded / fitted-corrected output direction.
4355                            let logits_row = self.assignment.logits.row(row);
4356                            for (j, &k) in active.iter().enumerate() {
4357                                fill_active_atom_logit_jvp(
4358                                    ActiveAtomLogitJvp {
4359                                        mode: self.assignment.mode,
4360                                        k,
4361                                        logit_k: logits_row[k],
4362                                        a_k: assignments[k],
4363                                        // #1410: compact slot `j`, not global atom `k`.
4364                                        decoded_k: decoded.row(j),
4365                                        fitted: fitted.view(),
4366                                        ibp_prior: ibp_prior_slice,
4367                                        compact_index: j,
4368                                        // #1026/#1033: a FIXED logit (ungated, or every
4369                                        // atom under frozen routing) has a constant gate
4370                                        // ⇒ zero logit-JVP.
4371                                        ungated: self.assignment.logit_is_fixed(k),
4372                                    },
4373                                    &mut jac_compact,
4374                                );
4375                            }
4376                            // Coordinate JVP rows for active atoms only.
4377                            for (j, &k) in active.iter().enumerate() {
4378                                let d = self.atoms[k].latent_dim;
4379                                let a_k = assignments[k];
4380                                let coord_start = starts[j];
4381                                for axis in 0..d {
4382                                    self.atoms[k].fill_decoded_derivative_row(
4383                                        row,
4384                                        axis,
4385                                        dg_buf.as_mut_slice(),
4386                                    );
4387                                    for out_col in 0..p {
4388                                        jac_compact[[coord_start + axis, out_col]] =
4389                                            a_k * dg_buf[out_col];
4390                                    }
4391                                }
4392                            }
4393                            (q_active, jac_compact)
4394                        } else {
4395                            // Fresh per-row Jacobian, structurally identical to the
4396                            // JumpReLU branch: every (q × p) element is unconditionally
4397                            // overwritten below (assignment-chart JVP rows + coordinate rows), so the
4398                            // `Array2::zeros` allocation needs no separate `fill(0.0)` and
4399                            // the populated buffer is returned by move without a clone.
4400                            let mut jac_row = Array2::<f64>::zeros((q, p));
4401                            fill_assignment_logit_jvp_rows(
4402                                self.assignment.mode,
4403                                self.assignment.logits.row(row),
4404                                assignments.view(),
4405                                decoded.view(),
4406                                fitted.view(),
4407                                ibp_prior_slice,
4408                                // #1026/#1033: zero logit-JVP rows for FIXED-logit atoms
4409                                // (ungated, and all atoms under frozen routing).
4410                                &self.assignment.fixed_logit_mask(),
4411                                &mut jac_row,
4412                            );
4413                            // Coordinate columns for all atoms.
4414                            for atom_idx in 0..k_atoms {
4415                                let d = self.atoms[atom_idx].latent_dim;
4416                                let off = coord_offsets[atom_idx];
4417                                let a_k = assignments[atom_idx];
4418                                for axis in 0..d {
4419                                    self.atoms[atom_idx].fill_decoded_derivative_row(
4420                                        row,
4421                                        axis,
4422                                        dg_buf.as_mut_slice(),
4423                                    );
4424                                    for out_col in 0..p {
4425                                        jac_row[[off + axis, out_col]] = a_k * dg_buf[out_col];
4426                                    }
4427                                }
4428                            }
4429                            (q, jac_row)
4430                        };
4431
4432                        // #991 design-honesty seam, Jacobian leg: scale the row's latent
4433                        // Jacobian by `√w_row` BEFORE the whitening / Kronecker capture so
4434                        // htt (= J̃J̃ᵀ), the data part of gt (= J̃ẽ, the residual already
4435                        // carries its own √w_row), and the htbeta cross block (J paired
4436                        // with the √w_row-scaled β load below) each carry exactly one
4437                        // factor of `w_row`. No-op on the unweighted path.
4438                        if sqrt_row_w != 1.0 {
4439                            for a in 0..q_row {
4440                                for out_col in 0..p {
4441                                    local_jac_row[[a, out_col]] *= sqrt_row_w;
4442                                }
4443                            }
4444                        }
4445
4446                        // #974 seam (step 2/2): whiten the per-row Jacobian through the SAME
4447                        // metric the residual was whitened by. `jac_white[a*w_dim + k]` holds
4448                        // `J̃[a, k] = Σ_out U_n[out, k] · J_n[a, out]` so the t-block
4449                        // Gauss-Newton row block is `htt = J̃ J̃ᵀ = J_n M_n J_nᵀ` and
4450                        // `gt = J̃ ẽ = J_nᵀ M_n r_n`. When not whitening, `w_dim == p` and the
4451                        // whitened jac equals the raw Jacobian, so htt/gt are byte-identical
4452                        // to the historical isotropic assembly. Because the SAME `error_white`
4453                        // feeds both the value-path data-fit (Σ½ ẽ²) and this gradient
4454                        // (J̃ ẽ), the objective and its t-block gradient share one whitening
4455                        // — they cannot desync.
4456                        if whitens_likelihood {
4457                            if let Some(metric) = self.row_metric.as_ref() {
4458                                for a in 0..q_row {
4459                                    for k in 0..w_dim {
4460                                        let mut acc = 0.0;
4461                                        // U_n[out, k] read through the metric's factor layout.
4462                                        for out_col in 0..p {
4463                                            acc += metric.factor_entry(row, out_col, k)
4464                                                * local_jac_row[[a, out_col]];
4465                                        }
4466                                        jac_white[a * w_dim + k] = acc;
4467                                    }
4468                                }
4469                            }
4470                        } else {
4471                            for a in 0..q_row {
4472                                for out_col in 0..p {
4473                                    jac_white[a * w_dim + out_col] = local_jac_row[[a, out_col]];
4474                                }
4475                            }
4476                        }
4477
4478                        // Build the per-row Arrow-Schur block at the row's active dim.
4479                        let mut block = ArrowRowBlock::new(q_row, row_htbeta_dim);
4480                        for a in 0..q_row {
4481                            let jac_a = &jac_white[a * w_dim..(a + 1) * w_dim];
4482                            let g = jac_a
4483                                .iter()
4484                                .zip(error_white.iter())
4485                                .map(|(&j, &e)| j * e)
4486                                .sum::<f64>();
4487                            block.gt[a] += g;
4488                            for b in 0..q_row {
4489                                let jac_b = &jac_white[b * w_dim..(b + 1) * w_dim];
4490                                let h = jac_a
4491                                    .iter()
4492                                    .zip(jac_b.iter())
4493                                    .map(|(&ja, &jb)| ja * jb)
4494                                    .sum::<f64>();
4495                                block.htt[[a, b]] += h;
4496                            }
4497                        }
4498
4499                        // Assignment prior in logit space.
4500                        // For compact layout: position `j` = active_atoms index.
4501                        // For dense layout: position `atom_idx` directly.
4502                        //
4503                        // H-consistency note (#1006 audit / #1416 update). This
4504                        // `assignment_hdiag` is the assignment channel's raw diagonal
4505                        // curvature, added un-majorized. It is exact for JumpReLU and exact
4506                        // within each IBP row/column diagonal, and stores ONLY the diagonal of
4507                        // two full-Hessian structures — but those off-diagonal structures are
4508                        // now carried elsewhere, not dropped:
4509                        //
4510                        //   * softmax entropy has dense within-row Hessian
4511                        //     H_kj = (λ/τ²) a_k[δ_kj(m-L_k-1) + a_j(L_k+L_j+1-2m)];
4512                        //     this diagonal stores its Gershgorin Loewner majorizer (#1419).
4513                        //   * IBP empirical-π has cross-row rank-one terms per column
4514                        //     H_(i,k),(j,k) = w score_derivative_k z'_ik z'_jk for i != j.
4515                        //     This per-row diagonal stores only the diagonal/self-row part;
4516                        //     the FULL rank-one cross-row block `U D Uᵀ` is now INSTALLED as a
4517                        //     separate Woodbury source by `set_ibp_cross_row_source` (#1038),
4518                        //     so the assembled operator is `H_full = H₀' + U D Uᵀ` on the
4519                        //     NO-SELF base `H₀' = H₀ − Σ_k d_k diag(z'_ik²)` (self term
4520                        //     downdated, see `IbpCrossRowSource::self_term_downdate`). The
4521                        //     scalar `D`-coefficient `d_k = w·s'_k` is
4522                        //     `IbpHessianDiagThirdChannels::cross_row_d` (FD-verified against
4523                        //     ∂²value/∂ℓ_ik∂ℓ_jk in
4524                        //     `ibp_cross_row_woodbury_d_matches_full_off_diagonal_hessian`),
4525                        //     and `z_jac` carries `u_k`'s entries `z'_ik`.
4526                        //
4527                        // The criterion's log|H| and Γ adjoint differentiate this SAME
4528                        // `H_full`: the ρ-trace adds the cross-row off-diagonal in
4529                        // `assignment_log_strength_hessian_trace` (#1416, dense AND compact
4530                        // layouts) and the θ-adjoint adds it in `logdet_theta_adjoint`
4531                        // (#1416/#1641), so value and gradient stay on one operator.
4532                        let assignment_base = row * k_atoms;
4533                        if let Some(layout) = row_layout.as_ref() {
4534                            let active = &layout.active_atoms[row];
4535                            // #1408/#1409 softmax compact curvature: the entropy
4536                            // Hessian diagonal in `assignment_hdiag` is INDEFINITE,
4537                            // so on a compact softmax layout write the Gershgorin
4538                            // Loewner majorizer `D_kk = Σ_j|H_kj|` (#1419) — the same
4539                            // PSD operator the dense softmax branch writes — at each
4540                            // active logit slot. `D` is diagonal, so its active
4541                            // principal sub-block is `diag(D_kk : k ∈ active)`; each
4542                            // `D_kk` is the FULL-`K` abs-row-sum, so it still
4543                            // dominates the active principal sub-block of `H_entropy`
4544                            // (a genuine majorizer on the retained support). The
4545                            // gradient stays the EXACT entropy gradient (it sets the
4546                            // fixed point), so majorizing only conditions the Newton
4547                            // step. JumpReLU/IBP keep their (exact) diagonal.
4548                            //
4549                            // #1410: compute only the active `D_kk` directly from this
4550                            // row's softmax assignments `a` (= `assignments`, already
4551                            // in hand), via `active_softmax_gershgorin_majorizer_entry`.
4552                            // The previous `psd_majorizer_abs_row_sums(&row_logits, ..)`
4553                            // call allocated TWO length-`K` per-row scratch vectors (a
4554                            // fresh `row_logits` copy and the full-`K` returned `d`)
4555                            // only to read `d[k]` for the `≤ top_k` active `k` — an
4556                            // `O(K)` per-row allocation on the path the compact
4557                            // contract keeps `K`-free. The shared `m = Σ_j a_j l_j` is
4558                            // the one irreducible `O(K)` pass, computed once per row.
4559                            let assignments_slice = assignments
4560                                .as_slice()
4561                                .expect("softmax assignments row must be contiguous");
4562                            let majorizer_log_mean: Option<f64> = softmax_dense
4563                                .as_ref()
4564                                .map(|_| softmax_majorizer_log_mean(assignments_slice));
4565                            for (j, &k) in active.iter().enumerate() {
4566                                block.gt[j] += assignment_grad[assignment_base + k];
4567                                match (softmax_dense.as_ref(), majorizer_log_mean) {
4568                                    (Some((_penalty, scale)), Some(m)) => {
4569                                        block.htt[[j, j]] +=
4570                                            active_softmax_gershgorin_majorizer_entry(
4571                                                assignments_slice,
4572                                                k,
4573                                                m,
4574                                                *scale,
4575                                            );
4576                                    }
4577                                    _ => block.htt[[j, j]] += assignment_hdiag[assignment_base + k],
4578                                }
4579                            }
4580                        } else {
4581                            for free_idx in 0..assignment_dim {
4582                                block.gt[free_idx] += assignment_grad[assignment_base + free_idx];
4583                            }
4584                            if let Some((penalty, scale)) = softmax_dense.as_ref() {
4585                                // #1419: write the genuine Gershgorin Loewner majorizer
4586                                // `D = diag(Σ_j|H_kj|)` of the exact entropy Hessian onto the
4587                                // row's logit block in place of the EXACT entropy Hessian. The
4588                                // entropy Hessian is INDEFINITE (concave directions on
4589                                // long-tailed rows), which drove the per-row evidence block
4590                                // non-PD and forced the downstream Faddeev–Popov deflation to
4591                                // flatten data-relevant logit directions (under-identifying the
4592                                // atoms). `D` is a nonnegative diagonal, hence exactly PSD and
4593                                // PD-preserving like the previous Fisher surrogate, so the block
4594                                // stays PD and the deflation no longer fires on the entropy
4595                                // block. Unlike the Fisher metric `G = scale·(diag(a) − a aᵀ)`,
4596                                // which is PSD but NOT a majorizer (`G − H_entropy` can be
4597                                // indefinite — K=2, a=(0.95,0.05): G₁₁=0.0475 < H₁₁=0.0784,
4598                                // #1419), `D` actually satisfies `D ⪰ H_entropy` and `D ⪰ 0`,
4599                                // so it is a true MM/Loewner curvature majorizer. Because the
4600                                // entropy penalty is a FIXED prior whose stationary point is set
4601                                // by its (unchanged) EXACT gradient, replacing its curvature
4602                                // with the majorizer only conditions the Newton step and the
4603                                // Laplace normalizer's curvature operator — it does NOT move the
4604                                // optimum.
4605                                //
4606                                // Softmax uses the REDUCED K−1 free-logit chart (the last
4607                                // reference logit is fixed at 0, `assignment_coord_dim() = K−1`).
4608                                // Holding z_{K-1} fixed, the reduced curvature over the free
4609                                // logits 0..K−1 is exactly the top-left (K−1)×(K−1) submatrix of
4610                                // the full K×K majorizer (the fixed logit contributes no
4611                                // row/column to the free curvature). The criterion's `log|H|`
4612                                // and the #1006 θ-adjoint differentiate this SAME `D` (see the
4613                                // `row_psd_majorizer_logit_derivative` site below), so value and
4614                                // adjoint stay on one exact branch.
4615                                let row_logits: Vec<f64> = (0..k_atoms)
4616                                    .map(|k| self.assignment.logits[[row, k]])
4617                                    .collect();
4618                                let h_dense = penalty.row_psd_majorizer(&row_logits, *scale);
4619                                for ki in 0..assignment_dim {
4620                                    for kj in 0..assignment_dim {
4621                                        block.htt[[ki, kj]] += h_dense[[ki, kj]];
4622                                    }
4623                                }
4624                            } else {
4625                                for free_idx in 0..assignment_dim {
4626                                    block.htt[[free_idx, free_idx]] +=
4627                                        assignment_hdiag[assignment_base + free_idx];
4628                                }
4629                            }
4630                        }
4631
4632                        // ARD on each on-atom coordinate.
4633                        // For compact layout: only active atoms; coord positions use compact starts.
4634                        // For dense layout: all atoms; coord positions use coord_offsets.
4635                        if let Some(layout) = row_layout.as_ref() {
4636                            let active = &layout.active_atoms[row];
4637                            let starts = &layout.coord_starts[row];
4638                            for (j, &k) in active.iter().enumerate() {
4639                                let coord = &self.assignment.coords[k];
4640                                let d = coord.latent_dim();
4641                                if rho.log_ard[k].is_empty() {
4642                                    continue;
4643                                }
4644                                if rho.log_ard[k].len() != d {
4645                                    return Err(format!(
4646                                        "ARD rho atom {k} has len {} but atom dim is {d}",
4647                                        rho.log_ard[k].len()
4648                                    ));
4649                                }
4650                                let row_t = coord.row(row);
4651                                let periods = &ard_axis_periods[k];
4652                                for axis in 0..d {
4653                                    // ARD on coords is a genuine per-row prior (each row
4654                                    // contributes the per-axis prior energy), so it is NOT
4655                                    // minibatch-scaled — the per-chunk row sums already
4656                                    // reconstruct the full coordinate prior across a pass.
4657                                    // The value (`ard_value`/`loss.ard`) and the gradient
4658                                    // both come from the SAME `ArdAxisPrior` energy, so they
4659                                    // stay FD-consistent on periodic axes. The exact
4660                                    // von-Mises curvature `V'' = α·cos(κt)` is INDEFINITE —
4661                                    // it goes negative for |t| past a quarter period — so
4662                                    // writing it raw into the Newton/Schur `htt` diagonal
4663                                    // makes that PSD curvature block indefinite and the Schur
4664                                    // Cholesky (used both for the Newton step and the exact
4665                                    // log-det) fails on a non-PD pivot. Accumulate the PSD
4666                                    // majorizer `max(V'', 0)` instead, exactly as
4667                                    // `add_sae_coord_penalty` does for the registry coord
4668                                    // penalties: the positive part keeps `htt` PSD so the
4669                                    // factorization succeeds, and majorizing the curvature of
4670                                    // a fixed prior only damps the Newton step — it does not
4671                                    // move the stationary point (the gradient, which sets the
4672                                    // fixed point, stays the exact `V'`).
4673                                    let alpha =
4674                                        SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
4675                                    let prior =
4676                                        ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4677                                    block.gt[starts[j] + axis] += prior.grad;
4678                                    block.htt[[starts[j] + axis, starts[j] + axis]] +=
4679                                        prior.hess.max(0.0);
4680                                }
4681                            }
4682                        } else {
4683                            for atom_idx in 0..k_atoms {
4684                                let coord = &self.assignment.coords[atom_idx];
4685                                let d = coord.latent_dim();
4686                                if rho.log_ard[atom_idx].is_empty() {
4687                                    continue;
4688                                }
4689                                if rho.log_ard[atom_idx].len() != d {
4690                                    return Err(format!(
4691                                        "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
4692                                        rho.log_ard[atom_idx].len()
4693                                    ));
4694                                }
4695                                let off = coord_offsets[atom_idx];
4696                                let row_t = coord.row(row);
4697                                let periods = &ard_axis_periods[atom_idx];
4698                                for axis in 0..d {
4699                                    // PSD-majorize the (possibly negative) von-Mises curvature
4700                                    // into the Newton/Schur `htt` block; see the compact-layout
4701                                    // branch above for why `max(V'', 0)` is required to keep
4702                                    // `htt` PD (the exact `V'' = α·cos κt` is indefinite past a
4703                                    // quarter period and breaks the Schur/log-det Cholesky).
4704                                    let alpha = SaeManifoldRho::stable_exp_strength(
4705                                        rho.log_ard[atom_idx][axis],
4706                                    );
4707                                    let prior =
4708                                        ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4709                                    block.gt[off + axis] += prior.grad;
4710                                    block.htt[[off + axis, off + axis]] += prior.hess.max(0.0);
4711                                }
4712                            }
4713                        }
4714
4715                        // Beta gradient/Hessian — Kronecker form J_β = φᵀ ⊗ I_p.
4716                        //
4717                        // The per-row beta Jacobian is
4718                        //   J_β[out_col, beta_idx] = a_k · phi_k[basis_col]   if out_col == out_col(beta_idx)
4719                        //                            0                         otherwise
4720                        // so the data-fit Gauss-Newton beta-Hessian factors as a rank-`p`
4721                        // sum of outer products. We pre-compute the per-(atom, basis_col)
4722                        // scalar `a_k · phi_k` once and reuse it across the `out_col`
4723                        // and inner `(atom_j, basis_col2)` loops.
4724                        //
4725                        // Full-B rows keep the matrix-free Kronecker path below. Factored
4726                        // rows write the `q_i × Σ M_k r_k` C-space cross slab directly by
4727                        // folding each output-channel contribution through the atom frame,
4728                        // so no `q_i × β_dim` slab is ever materialized.
4729                        //
4730                        // Only the row's active atoms contribute `a_phi` support and data
4731                        // curvature: in a compact layout (JumpReLU gate or large-K
4732                        // top-`k_active` truncation) the inactive atoms carry zero (gated)
4733                        // or sub-cutoff assignment mass and are excluded — this is what
4734                        // keeps both the htbeta support and the `G` accumulation
4735                        // `O(k_active)` rather than `O(K)`. In the dense full-support
4736                        // layout `row_active` spans all atoms.
4737                        let row_active: &[usize] = match row_layout.as_ref() {
4738                            Some(layout) => layout.active_atoms[row].as_slice(),
4739                            None => &all_atoms_index,
4740                        };
4741                        // #1407: in fixed-decoder mode the β tier is not assembled at
4742                        // all — leave gb_delta/g_blocks empty and kron None. htt/gt
4743                        // (built above) are the only outputs the frozen-decoder step
4744                        // consumes.
4745                        let mut a_phi: Vec<(usize, f64)> = Vec::with_capacity(row_active.len() * 4);
4746                        // Per-active-atom weighted basis row `a_k · φ_k[·]`, retained so the
4747                        // data Gram blocks can be accumulated as clean per-atom-pair outer
4748                        // products `(a_k φ_k) (a_{k'} φ_{k'})ᵀ`.
4749                        let mut weighted_phi: Vec<(usize, Vec<f64>)> =
4750                            Vec::with_capacity(row_active.len());
4751                        if !fixed_decoder {
4752                            for &atom_idx in row_active {
4753                                let atom = &self.atoms[atom_idx];
4754                                let atom_beta_off = beta_offsets[atom_idx];
4755                                let m = atom.basis_size();
4756                                let a_k = assignments[atom_idx];
4757                                let mut wphi = Vec::with_capacity(m);
4758                                for basis_col in 0..m {
4759                                    let phi = atom.basis_values[[row, basis_col]];
4760                                    // #991 design-honesty seam, β leg: the `√w_row` here pairs
4761                                    // with the `√w_row` on the residual (β gradient =
4762                                    // `a·φ · M r` ⇒ w_row) and with itself (β Gram `G` and the
4763                                    // htbeta Kronecker capture ⇒ w_row). `1.0` when unweighted.
4764                                    let w = a_k * phi * sqrt_row_w;
4765                                    a_phi.push((atom_beta_off + basis_col * p, w));
4766                                    wphi.push(w);
4767                                }
4768                                weighted_phi.push((atom_idx, wphi));
4769                            }
4770                            // β data-fit gradient `gᵦ += J_βᵀ M_n r_n`. The β-Jacobian is
4771                            // `J_β = φ_nᵀ ⊗ I_p`, so `J_βᵀ M_n r_n = φ_n ⊗ (M_n r_n)` —
4772                            // contract the basis weight `a·φ` against the p-space metric-applied
4773                            // residual `error_metric` (= `M_n r_n`), the SAME whitening the value
4774                            // path and t-block share. When not whitening, `error_metric == error`
4775                            // and this is byte-identical to the historical `J_βᵀ r`.
4776                            for &(beta_base_i, j_beta_i) in a_phi.iter() {
4777                                if j_beta_i == 0.0 {
4778                                    continue;
4779                                }
4780                                for out_col in 0..p {
4781                                    gb_delta.push((
4782                                        beta_base_i + out_col,
4783                                        j_beta_i * error_metric[out_col],
4784                                    ));
4785                                    // No dense hbb write — the sparse `G ⊗ I_p` op installed
4786                                    // after the loop carries the data-fit GN β-Hessian.
4787                                }
4788                            }
4789                            if frames_engaged {
4790                                for &atom_idx in row_active {
4791                                    let atom = &self.atoms[atom_idx];
4792                                    let m = atom.basis_size();
4793                                    let a_k = assignments[atom_idx];
4794                                    for basis_col in 0..m {
4795                                        let phi = atom.basis_values[[row, basis_col]];
4796                                        let w = a_k * phi * sqrt_row_w;
4797                                        if w == 0.0 {
4798                                            continue;
4799                                        }
4800                                        let c_base = frame_projection.border_offsets[atom_idx]
4801                                            + basis_col * frame_projection.ranks[atom_idx];
4802                                        for c in 0..q_row {
4803                                            let mut hrow = block.htbeta.row_mut(c);
4804                                            let hrow_slice = hrow
4805                                                .as_slice_mut()
4806                                                .expect("htbeta row is contiguous");
4807                                            for out_col in 0..p {
4808                                                let value = local_jac_row[[c, out_col]] * w;
4809                                                frame_projection.accumulate_output_project(
4810                                                    atom_idx, c_base, out_col, value, hrow_slice,
4811                                                );
4812                                            }
4813                                        }
4814                                    }
4815                                }
4816                            }
4817                            // Data-fit GN β-Hessian: accumulate the channel-independent block
4818                            // `G[μ_i, μ_j] += (a_k φ_k)[μ_i] (a_{k'} φ_{k'})[μ_j]` into the
4819                            // sparse per-atom-pair map (the `out_col` dimension is carried by
4820                            // `I_p`). Only co-occurring `(atom_i, atom_j)` pairs are touched.
4821                            for ai in 0..weighted_phi.len() {
4822                                let (atom_i, ref wphi_i) = weighted_phi[ai];
4823                                let m_i = wphi_i.len();
4824                                for aj in 0..weighted_phi.len() {
4825                                    let (atom_j, ref wphi_j) = weighted_phi[aj];
4826                                    let m_j = wphi_j.len();
4827                                    let blk = g_blocks
4828                                        .entry((atom_i, atom_j))
4829                                        .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4830                                    for li in 0..m_i {
4831                                        let wi = wphi_i[li];
4832                                        if wi == 0.0 {
4833                                            continue;
4834                                        }
4835                                        for lj in 0..m_j {
4836                                            blk[[li, lj]] += wi * wphi_j[lj];
4837                                        }
4838                                    }
4839                                }
4840                            }
4841                        } // #1407 end `if !fixed_decoder` β-tier accumulation
4842                        let (kron_a_phi, kron_jac) = if !frames_engaged && !fixed_decoder {
4843                            // Flatten local_jac_row row-major into a plain Vec<f64> (q_row * p entries).
4844                            let mut jac_flat = vec![0.0_f64; q_row * p];
4845                            for c in 0..q_row {
4846                                for j in 0..p {
4847                                    jac_flat[c * p + j] = local_jac_row[[c, j]];
4848                                }
4849                            }
4850                            (Some(a_phi), Some(jac_flat))
4851                        } else {
4852                            (None, None)
4853                        };
4854                        Ok(SaeAssemblyRow {
4855                            row,
4856                            block,
4857                            gb_delta,
4858                            g_blocks,
4859                            kron_a_phi,
4860                            kron_jac,
4861                        })
4862                        }) // #1557 with_nested_parallel
4863                    },
4864                )
4865                .collect::<Result<Vec<_>, String>>()?;
4866
4867            // Fold THIS chunk's rows (ascending) into the global accumulators.
4868            // The parallel collect preserves index order within the chunk and
4869            // chunks are visited in ascending `chunk_start` order, so the overall
4870            // fold order is `0,1,2,…,n-1` — identical to the former single-pass
4871            // fold. The `row == chunk_start + fold_offset_in_chunk` assert pins
4872            // that strict sequential arrival (the invariant the `kron_*`
4873            // row-aligned pushes depend on).
4874            for row_result in row_results.into_iter() {
4875                let row = row_result.row;
4876                assert_eq!(
4877                    row,
4878                    chunk_start + fold_offset_in_chunk,
4879                    "parallel SAE row assembly returned rows out of order"
4880                );
4881                fold_offset_in_chunk += 1;
4882                for (idx, value) in row_result.gb_delta {
4883                    sys.gb[idx] += value;
4884                }
4885                for ((atom_i, atom_j), data) in row_result.g_blocks {
4886                    let m_i = data.nrows();
4887                    let m_j = data.ncols();
4888                    let blk = g_blocks
4889                        .entry((atom_i, atom_j))
4890                        .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4891                    for li in 0..m_i {
4892                        for lj in 0..m_j {
4893                            blk[[li, lj]] += data[[li, lj]];
4894                        }
4895                    }
4896                }
4897                if !frames_engaged && !fixed_decoder {
4898                    // Rows arrive in ascending order across chunks, so pushing
4899                    // here yields `kron_*[row]` aligned to the row index exactly
4900                    // as the single-pass `push` did.
4901                    kron_a_phi.push(
4902                        row_result
4903                            .kron_a_phi
4904                            .expect("full-B SAE row assembly must return a_phi rows"),
4905                    );
4906                    kron_jac.push(
4907                        row_result
4908                            .kron_jac
4909                            .expect("full-B SAE row assembly must return local Jacobian rows"),
4910                    );
4911                }
4912                sys.rows[row] = row_result.block;
4913            }
4914            chunk_start = chunk_end;
4915        }
4916        // #1407: fixed-decoder early return. The per-row htt/gt are now fully
4917        // assembled (data GN + assignment/ARD prior). Apply only the htt/gt
4918        // Riemannian projection (the decoder/β tier is intentionally absent), then
4919        // return the block-diagonal system. `fixed_decoder_step_from_rows` reads
4920        // only `rows[*].htt`/`gt` + `row_offsets`, so no β-tier object is needed.
4921        if fixed_decoder {
4922            match row_layout.as_ref() {
4923                None => {
4924                    // Dense uniform-q: project htt/gt (and the 0-width htbeta, a
4925                    // no-op) through the ext-coord manifold.
4926                    self.apply_sae_riemannian_geometry(&mut sys);
4927                }
4928                Some(layout) => {
4929                    // Compact heterogeneous-q: project each row's htt/gt at its
4930                    // own ext-coord point, mirroring the full path's compact
4931                    // Riemannian block (htbeta is 0-width here, so skipped).
4932                    if !self.ext_coord_manifold().is_euclidean() {
4933                        for row_idx in 0..n {
4934                            let (manifold_i, point_i) =
4935                                self.compact_row_ext_manifold_and_point(row_idx, layout);
4936                            let t_i = point_i.view();
4937                            let gt_e = sys.rows[row_idx].gt.clone();
4938                            let htt_e = sys.rows[row_idx].htt.clone();
4939                            sys.rows[row_idx].gt =
4940                                manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
4941                            sys.rows[row_idx].htt = manifold_i.riemannian_hessian_matrix(
4942                                t_i,
4943                                gt_e.view(),
4944                                htt_e.view(),
4945                            );
4946                        }
4947                    }
4948                }
4949            }
4950            if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
4951                sys.set_row_gauge_deflation(deflation);
4952            }
4953            self.last_row_layout = row_layout;
4954            self.last_frames_active = frames_engaged;
4955            return Ok(sys);
4956        }
4957        // Apply Riemannian geometry to the per-row row blocks (htt, gt) and
4958        // also to the per-row Kronecker local Jacobians stored in kron_jac.
4959        // When the SAE ext-coord manifold is non-Euclidean (any atom latent
4960        // on sphere / circle / interval), the local Jacobian rows that map
4961        // into the t-block tangent space must be projected via the per-row
4962        // tangent projector P_i.  This mirrors what
4963        // `apply_riemannian_latent_geometry` does to `row.htbeta`, applied
4964        // here to the (q × p) kron_jac so the Kronecker htbeta_matvec uses
4965        // the Riemannian-projected form.
4966        // Apply Riemannian geometry only for the dense uniform-q layout. Any
4967        // compact active-set layout (JumpReLU gate or large-K softmax/IBP
4968        // truncation) has heterogeneous q_i; the Riemannian projector path
4969        // requires a uniform latent dimension. The sparse plan only engages on
4970        // Euclidean ext-coord manifolds (see `sparse_active_plan`), so skipping
4971        // the projector here is correct — there is nothing to project.
4972        match row_layout.as_ref() {
4973            None => {
4974                let raw_gt_rows: Vec<Array1<f64>> =
4975                    sys.rows.iter().map(|row| row.gt.clone()).collect();
4976                self.apply_sae_riemannian_geometry(&mut sys);
4977                let manifold = self.ext_coord_manifold();
4978                if !frames_engaged && !manifold.is_euclidean() {
4979                    let ext = self.ext_coord_matrix();
4980                    // Project the local Jacobian columns onto the tangent space at
4981                    // each row's ext-coord point. Each column `j` of the row's
4982                    // (q_row × p) Jacobian is an ambient-space vector of length
4983                    // `q_row`; the manifold projector acts on one such column at a
4984                    // time. Working directly on the row-major `jac_flat` storage via
4985                    // a single reusable `col_buf` avoids the two dense (q × p) copies
4986                    // (flatten→Array2, project, unflatten→Vec) that previously fired
4987                    // per row. `t_buf` still holds the row's ext-coord vector.
4988                    let mut t_buf = vec![0.0_f64; q];
4989                    let mut col_buf = Array1::<f64>::zeros(q);
4990                    for row_idx in 0..n {
4991                        let ext_row = ext.row(row_idx);
4992                        for (slot, &v) in t_buf.iter_mut().zip(ext_row.iter()) {
4993                            *slot = v;
4994                        }
4995                        let t_i = ArrayView1::from(t_buf.as_slice());
4996                        let raw_gt = raw_gt_rows[row_idx].view();
4997                        let jac_flat = &mut kron_jac[row_idx];
4998                        let q_row = jac_flat.len() / p;
4999                        for j in 0..p {
5000                            for c in 0..q_row {
5001                                col_buf[c] = jac_flat[c * p + j];
5002                            }
5003                            let projected_col = manifold.project_vector_to_gradient_tangent(
5004                                t_i,
5005                                raw_gt.slice(ndarray::s![..q_row]),
5006                                col_buf.slice(ndarray::s![..q_row]),
5007                            );
5008                            for c in 0..q_row {
5009                                jac_flat[c * p + j] = projected_col[c];
5010                            }
5011                        }
5012                    }
5013                }
5014            }
5015            Some(layout) => {
5016                // Compact active-set layout (#1117 follow-up): the dense
5017                // `ext_coord_manifold()` is keyed to the uniform full-`q` block
5018                // ordering, so it cannot be applied to the heterogeneous compact
5019                // rows directly. Instead we rebuild, PER ROW, the product manifold
5020                // and ext-coord point in that row's compact column order (see
5021                // `compact_row_ext_manifold_and_point`) and apply the SAME three
5022                // per-row Riemannian operations the dense
5023                // `apply_riemannian_latent_geometry` applies — gradient tangent
5024                // projection of `gt`, the Riemannian Hessian correction of `htt`,
5025                // and the column tangent projection of `htbeta` — plus the
5026                // identical Kronecker `kron_jac` column projection. On the shared
5027                // active support this is byte-identical to slicing the dense
5028                // product manifold, so engaging the sparse plan on a non-Euclidean
5029                // ext manifold is now correct (the former
5030                // `is_euclidean()`-only guard in `sparse_active_plan` is lifted).
5031                //
5032                // Euclidean ext manifolds still skip all of this (every
5033                // per-row manifold is a product of Euclidean parts whose
5034                // projector is the identity); we early-out so those rows stay
5035                // byte-for-byte the historical compact path.
5036                if !self.ext_coord_manifold().is_euclidean() {
5037                    for row_idx in 0..n {
5038                        let (manifold_i, point_i) =
5039                            self.compact_row_ext_manifold_and_point(row_idx, layout);
5040                        let t_i = point_i.view();
5041                        // gt / htt / htbeta on the compact ArrowRowBlock, exactly
5042                        // as `apply_riemannian_latent_geometry` does for dense
5043                        // uniform-q rows.
5044                        let gt_e = sys.rows[row_idx].gt.clone();
5045                        let htt_e = sys.rows[row_idx].htt.clone();
5046                        sys.rows[row_idx].gt =
5047                            manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
5048                        sys.rows[row_idx].htt =
5049                            manifold_i.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
5050                        // #1406: only the frames path holds a real dense `htbeta`
5051                        // slab; the matrix-free path leaves it 0-width (the
5052                        // cross-block geometry is applied to `kron_jac` below), so
5053                        // projecting a zero-column matrix is a no-op we skip.
5054                        if frames_engaged {
5055                            let htbeta_e = sys.rows[row_idx].htbeta.clone();
5056                            sys.rows[row_idx].htbeta = manifold_i
5057                                .project_matrix_columns_to_gradient_tangent(
5058                                    t_i,
5059                                    gt_e.view(),
5060                                    htbeta_e.view(),
5061                                );
5062                        }
5063                        // Kronecker local-Jacobian column projection (full-B path
5064                        // only), using the SAME pre-projection gradient `gt_e` so
5065                        // the cross-block geometry matches the dense branch.
5066                        if !frames_engaged {
5067                            let jac_flat = &mut kron_jac[row_idx];
5068                            let q_row = jac_flat.len() / p;
5069                            let mut col_buf = Array1::<f64>::zeros(q_row);
5070                            for j in 0..p {
5071                                for c in 0..q_row {
5072                                    col_buf[c] = jac_flat[c * p + j];
5073                                }
5074                                let projected_col = manifold_i.project_vector_to_gradient_tangent(
5075                                    t_i,
5076                                    gt_e.view(),
5077                                    col_buf.view(),
5078                                );
5079                                for c in 0..q_row {
5080                                    jac_flat[c * p + j] = projected_col[c];
5081                                }
5082                            }
5083                        }
5084                    }
5085                }
5086            }
5087        }
5088        // Build and install the full-B Kronecker htbeta_matvec.
5089        //
5090        // `SaeKroneckerRows` holds per-row `(a_phi, local_jac)` and implements
5091        // the cross-block operator without ever materialising the dense
5092        // `(q × K·p)` slab.  The cross-block factorises as `H_tβ = L · J_β`,
5093        // where `J_β = φᵀ ⊗ I_p` projects a length-`K` β vector onto the
5094        // `p`-dimensional decoded output space (`apply_jbeta`) and `L_i` is
5095        // the per-row `(q_i × p)` assignment+coordinate Jacobian that lifts
5096        // that p-vector into the row's `q_i`-dim tangent block (`apply_l`).
5097        // Both factors are required: the contract of `set_row_htbeta_operator`
5098        // is `out.len() == d` (= `q_i`), so writing `apply_jbeta`'s p-vector
5099        // output directly into a length-`q_i` buffer overflows whenever
5100        // `p > q_i` (the common case once `p` reflects real feature width).
5101        // Symmetric for the transpose: `H_βt = J_βᵀ · Lᵀ`, so apply `Lᵀ`
5102        // first to map the q_i-vector back to p-space, then scatter through
5103        // the support.
5104        // #1017/#1026: the legacy full-B device PCG assumes `G ⊗ I_p`, while
5105        // framed systems carry `G_ij ⊗ W_ij` with rank-r atom blocks. Feeding a
5106        // framed system to that kernel would silently return the wrong Newton
5107        // step. Framed device PCG therefore needs the dedicated factored kernel.
5108        // #1033 large-n: the per-row support `kron_a_phi` and local Jacobians
5109        // `kron_jac` are consumed by BOTH the host matrix-free row operator
5110        // (`SaeKroneckerRows`) and the solver's `DeviceSaePcgData`. Previously
5111        // each took its own full `O(n·q·p)` / `O(n·k_active)` clone, so the
5112        // always-resident footprint of the CPU non-frames path carried TWO copies
5113        // of the dominant Jacobian slab. Promote each to a single `Arc<[…]>` once
5114        // and hand both consumers a refcount bump (`O(1)`) — the backing
5115        // allocation is shared, halving the resident per-row Jacobian memory.
5116        // Reads are identical (`&arc[row]`, `.len()`), so the assembled system and
5117        // every matvec are bit-for-bit unchanged.
5118        let device_rows = if frames_engaged {
5119            None
5120        } else {
5121            let a_phi_shared: Arc<[Vec<(usize, f64)>]> =
5122                Arc::from(std::mem::take(&mut kron_a_phi).into_boxed_slice());
5123            let jac_shared: Arc<[Vec<f64>]> =
5124                Arc::from(std::mem::take(&mut kron_jac).into_boxed_slice());
5125            Some((a_phi_shared, jac_shared))
5126        };
5127        if !frames_engaged {
5128            let (a_phi_shared, jac_shared) = device_rows
5129                .clone()
5130                .expect("non-frames path always populates device_rows");
5131            let kron = Arc::new(SaeKroneckerRows::new(p, a_phi_shared, jac_shared));
5132            let kron_t = Arc::clone(&kron);
5133            let p_dim = p;
5134            sys.set_row_htbeta_operator(
5135                move |row_idx, x, out| {
5136                    // out = L_i · (J_β · x). Allocate a length-p scratch buffer
5137                    // for the intermediate decoded-output vector; both factors
5138                    // overwrite their output buffers (`apply_jbeta` zeroes
5139                    // before accumulating, `apply_l` writes per-row), so no
5140                    // pre-zeroing of `u_p`/`out` is needed.
5141                    let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5142                    let mut u_p = vec![0.0_f64; p_dim];
5143                    if let Some(xs) = x.as_slice() {
5144                        kron.apply_jbeta(row_idx, xs, &mut u_p);
5145                    } else {
5146                        let x_vec: Vec<f64> = x.iter().copied().collect();
5147                        kron.apply_jbeta(row_idx, &x_vec, &mut u_p);
5148                    }
5149                    kron.apply_l(row_idx, &u_p, out_slice);
5150                },
5151                move |row_idx, v, out| {
5152                    // out += J_βᵀ · (Lᵀ · v). `apply_l_t` accumulates into a
5153                    // zero-initialised length-p buffer to produce the p-vector
5154                    // `Lᵀ v`; `scatter_jbeta_t` then adds φ_i[s] · u_p[j] into
5155                    // the length-K β accumulator at each active `(s, j)`.
5156                    let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5157                    let mut u_p = vec![0.0_f64; p_dim];
5158                    if let Some(vs) = v.as_slice() {
5159                        kron_t.apply_l_t(row_idx, vs, &mut u_p);
5160                    } else {
5161                        let v_vec: Vec<f64> = v.iter().copied().collect();
5162                        kron_t.apply_l_t(row_idx, &v_vec, &mut u_p);
5163                    }
5164                    kron_t.scatter_jbeta_t(row_idx, &u_p, out_slice);
5165                },
5166            );
5167        }
5168        let mut beta_penalty_assembly = SaeBetaPenaltyAssembly::default();
5169        let factored_row_projection = if frames_engaged && analytic_penalties.is_some() {
5170            Some(&frame_projection)
5171        } else {
5172            None
5173        };
5174        if let Some(registry) = analytic_penalties {
5175            // Upfront validation: refuse penalty kinds the SAE row layout
5176            // cannot host, and refuse mixed-d row-block configurations.
5177            // This makes the dispatch loop below total — no runtime
5178            // "unsupported penalty" fallthrough, no K-gating.
5179            self.validate_analytic_penalty_registry(registry)
5180                .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5181            beta_penalty_assembly = self
5182                .add_sae_analytic_penalty_contributions(
5183                    &mut sys,
5184                    registry,
5185                    penalty_scale,
5186                    row_layout.as_ref(),
5187                    dense_beta_curvature,
5188                    factored_row_projection,
5189                )
5190                .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5191        }
5192        // #1026 — decoder repulsion (collinearity-gated, registry-independent):
5193        // accumulate into the full-`B` β-tier here, BEFORE the frame transform,
5194        // so a framed system carries it identically to the analytic β penalties.
5195        // No-op unless two atoms are near-collinear (the frozen gate is `None`).
5196        if self.add_sae_decoder_repulsion(&mut sys, penalty_scale, dense_beta_curvature) {
5197            beta_penalty_assembly.record_curvature(dense_beta_curvature);
5198        }
5199        // #1026/#1522 — interior-point collapse-prevention barriers. The amplitude
5200        // barrier supplies the OUTWARD radial force at the zero-decoder collapse
5201        // point (the principal failure state the threshold repulsion skips), and
5202        // the separation barrier supplies the alignment-divergent separating
5203        // curvature on normalized shapes weighted by coactivation. Both accumulate
5204        // into the full-`B` β-tier here, BEFORE the frame transform, so a framed
5205        // system carries them identically to the analytic β penalties.
5206        // #1610 — on the dense path the barrier's Levenberg majorizer scatters
5207        // onto `sys.hbb`; on the matrix-free / framed production path `sys.hbb` is
5208        // unused, so the barrier hands back a per-atom scalar ridge which we fold
5209        // into `smooth_scaled_s` (the single source for the CPU composite penalty
5210        // op AND the device smooth blocks), restoring the collapse-prevention
5211        // curvature the operator was silently dropping there.
5212        let mut sep_atom_curv = vec![0.0_f64; self.atoms.len()];
5213        if self.add_sae_separation_barrier(
5214            &mut sys,
5215            penalty_scale,
5216            dense_beta_curvature,
5217            &mut sep_atom_curv,
5218        ) {
5219            if dense_beta_curvature {
5220                beta_penalty_assembly.record_curvature(true);
5221            } else {
5222                // Fold the per-atom majorizer `lev_k·I_{M_k}` into the smooth
5223                // penalty factor `λ S_k`. With `⊗ I_p` (full-`B`) or `⊗ I_{r_k}`
5224                // (factored, `U_kᵀU_k = I`) this is exactly the `lev_k·I` block
5225                // diagonal the dense path writes — and it now flows through the
5226                // structured penalty op and the device smooth blocks. No
5227                // `deferred_factored` mark: the curvature is in the smooth op, not
5228                // a deferred dense block, so the device path stays engaged.
5229                for atom_idx in 0..self.atoms.len() {
5230                    let c = sep_atom_curv[atom_idx];
5231                    if c > 0.0 {
5232                        let m = smooth_scaled_s[atom_idx].nrows();
5233                        for i in 0..m {
5234                            smooth_scaled_s[atom_idx][[i, i]] += c;
5235                        }
5236                        smooth_ops[atom_idx] = Arc::new(IdentityRightKroneckerPenaltyOp {
5237                            factor_a: smooth_scaled_s[atom_idx].clone(),
5238                            p,
5239                            global_offset: beta_offsets[atom_idx],
5240                            k: beta_dim,
5241                        });
5242                    }
5243                }
5244            }
5245        }
5246        if frames_engaged {
5247            // ── #972 / #977 T1 — FACTORED β-tier transform ──────────────────
5248            //
5249            // The entire β-tier above was assembled in the full-`B` (p-wide)
5250            // layout: `sys.gb` is `g_B` (length `beta_dim`), `sys.hbb` carries
5251            // any analytic Beta-tier penalty, and `g_blocks` is the
5252            // FRAME-INDEPENDENT basis Gram. We now rebuild the β-tier in the
5253            // factored coordinate space `C` (width `factored_border_dim`), the
5254            // full-`B` system sandwiched by `Φ = blkdiag(I_{M_k} ⊗ U_k)`:
5255            //   * gradient   `g_C = Φᵀ g_B`              (per atom `(g_B U_k)`),
5256            //   * data H      `Φᵀ(G⊗I_p)Φ = G_{ij}⊗(U_iᵀU_j)`,
5257            //   * smooth      `λ S_k ⊗ I_{r_k}`          (since `U_kᵀU_k = I`),
5258            //   * analytic    `Φᵀ hbb Φ`                 (dense, only if written).
5259            // Un-framed atoms ride the `r_k = p, U_k = I_p` identity special case.
5260            let off_c = &frame_projection.border_offsets;
5261            let ranks = &frame_projection.ranks;
5262            let basis_sizes = &frame_projection.basis_sizes;
5263            let border_dim = frame_projection.border_dim();
5264            let gb_c = frame_projection.project_border_vec(sys.gb.view());
5265
5266            // Data β-Hessian: `G_{ij} ⊗ W_{ij}` with `W_{ij} = U_iᵀU_j`. The
5267            // basis Gram `g_blocks` is unchanged; only the output factor is the
5268            // per-pair frame overlap (`I_{r_k}` within a framed atom, `I_p` for
5269            // un-framed).
5270            let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::with_capacity(g_blocks.len());
5271            for ((atom_i, atom_j), data) in g_blocks.into_iter() {
5272                if data.iter().all(|&v| v == 0.0) {
5273                    continue;
5274                }
5275                // `W_{ij} = U_iᵀ U_j` from the precomputed per-atom frames.
5276                let w = self.frame_cross_factor(atom_i, atom_j);
5277                frame_blocks.push(FactoredFrameGBlock {
5278                    atom_i,
5279                    atom_j,
5280                    g: data,
5281                    w,
5282                });
5283            }
5284            // #1017/#1026 — snapshot the factored data-fit blocks for the
5285            // frames-engaged device PCG BEFORE `FactoredFrameKroneckerOp::new`
5286            // consumes them. Cheap clone (co-occurring blocks only).
5287            let device_frame_blocks = frame_blocks.clone();
5288            let data_op =
5289                FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks)?;
5290
5291            // Smooth penalty in factored space: `λ S_k ⊗ I_{r_k}` at `off_C[k]`.
5292            let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len() + 2);
5293            for k in 0..self.atoms.len() {
5294                let r = ranks[k];
5295                ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
5296                    factor_a: smooth_scaled_s[k].clone(),
5297                    p: r,
5298                    global_offset: off_c[k],
5299                    k: border_dim,
5300                }));
5301            }
5302            ops.push(Arc::new(data_op));
5303            // Analytic Beta-tier penalty: project the dense full-`B` `hbb` block
5304            // `Φᵀ hbb Φ` into the factored space. Only present when a Beta-tier
5305            // penalty actually wrote `hbb` (else `hbb` is all-zero and the dense
5306            // `(border_dim)²` op is skipped entirely, exactly as full-`B`).
5307            if beta_penalty_assembly.dense_written {
5308                let hbb_c =
5309                    self.project_dense_penalty_to_factored(sys.hbb.view(), &frame_projection);
5310                ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5311            } else if beta_penalty_assembly.deferred_factored {
5312                // Registry Beta-tier curvature deferred to factored-space probing.
5313                // The registry may be absent when `deferred_factored` was set ONLY
5314                // by the frozen-gate decoder repulsion (which is
5315                // registry-independent), so start from a zero factored block in
5316                // that case instead of unwrapping.
5317                let mut hbb_c = match analytic_penalties {
5318                    Some(registry) => self.build_factored_beta_penalty_curvature(
5319                        registry,
5320                        penalty_scale,
5321                        &frame_projection,
5322                    ),
5323                    None => Array2::<f64>::zeros((
5324                        frame_projection.border_dim(),
5325                        frame_projection.border_dim(),
5326                    )),
5327                };
5328                // #1610 — the frozen-gate decoder repulsion's PSD majorizer was
5329                // dropped on this matrix-free/framed path (only its gradient was
5330                // applied). Project it into the factored block via the same
5331                // `psd_majorizer_hvp` + frame-projection probe pattern the registry
5332                // DecoderIncoherence uses, so the collapse-prevention curvature
5333                // reaches the operator here too. No-op when no repulsion is active.
5334                self.add_factored_repulsion_curvature(
5335                    &mut hbb_c,
5336                    penalty_scale,
5337                    &frame_projection,
5338                );
5339                ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5340            }
5341
5342            // Re-point the system's β-tier to the factored width. The t-tier
5343            // (per-row `htt`, `gt`) is frame-independent and untouched; row
5344            // cross-block slabs were allocated and assembled directly in
5345            // factored coordinates, so analytic row supplements and data-fit
5346            // cross terms already share shape `(q_i × factored_border_dim)`.
5347            sys.k = border_dim;
5348            sys.gb = gb_c;
5349            self.reclaim_border_hbb_workspace(&mut sys);
5350            // Factored per-atom block ranges for the block-Jacobi Schur
5351            // preconditioner: `[off_C[k] .. off_C[k] + M_k·r_k]`.
5352            let mut block_ranges: Vec<std::ops::Range<usize>> =
5353                Vec::with_capacity(self.atoms.len());
5354            for k in 0..self.atoms.len() {
5355                let start = off_c[k];
5356                block_ranges.push(start..start + basis_sizes[k] * ranks[k]);
5357            }
5358            sys.set_block_offsets(Arc::from(block_ranges.into_boxed_slice()));
5359            sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: border_dim, ops }));
5360            // #1017/#1026 — install the frames-engaged device SAE PCG data. Skipped
5361            // (CPU fallback) when a dense analytic Beta-tier penalty fired (the
5362            // device kernel does not model that extra dense term). Builder:
5363            // `crate::frames::build_framed_device_sae_data`.
5364            let has_dense_beta_penalty =
5365                beta_penalty_assembly.dense_written || beta_penalty_assembly.deferred_factored;
5366            if !has_dense_beta_penalty {
5367                let device = crate::frames::build_framed_device_sae_data(
5368                    crate::frames::FramedDeviceArgs {
5369                        p,
5370                        border_dim,
5371                        border_offsets: off_c.as_slice(),
5372                        ranks: ranks.as_slice(),
5373                        basis_sizes: basis_sizes.as_slice(),
5374                        smooth_scaled_s: &smooth_scaled_s,
5375                        frame_blocks: device_frame_blocks,
5376                        rows: &sys.rows,
5377                    },
5378                );
5379                sys.set_device_sae_pcg_data(device);
5380            }
5381        } else {
5382            let (device_a_phi, device_local_jac) =
5383                device_rows.expect("full-beta SAE PCG rows are cloned before row operator install");
5384            // Wire per-atom β block ranges so the Jacobi preconditioner builds one
5385            // dense Schur sub-block per atom (block-Jacobi) instead of scalar-diagonal
5386            // inversion.  Each atom's decoder coefficients form a natural block:
5387            // `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
5388            sys.set_block_offsets(self.beta_block_offsets());
5389            // Install the composite BetaPenaltyOp (#296): smoothness contributions
5390            // via per-atom KroneckerPenaltyOp (avoid dense K×K materialisation), the
5391            // data-fit Gauss-Newton β-Hessian as the structured `G ⊗ I_p`
5392            // SparseBlockKroneckerPenaltyOp (block-sparse over co-occurring
5393            // `(atom, atom')` pairs, block-diagonal across the `p` output channels,
5394            // identical per channel), plus — only when a Beta-tier analytic penalty
5395            // was written — the dense `sys.hbb` residual contribution. When no beta
5396            // penalty fired, `sys.hbb` is all-zero and the dense `(K·p)²` operator
5397            // is skipped entirely. The sparse data op tracks only the active-atom
5398            // couplings, so its storage and matvec cost scale with `k_active`, not
5399            // `K`, at `K = 100K`.
5400            // Convert the per-atom-pair coupling map into `SparseGBlock`s keyed
5401            // by μ-space offsets. Empty blocks (no co-occurrence) are simply
5402            // absent from the map.
5403            let g_sparse_blocks: Vec<SparseGBlock> = g_blocks
5404                .into_iter()
5405                .filter_map(|((atom_i, atom_j), data)| {
5406                    if data.iter().all(|&v| v == 0.0) {
5407                        None
5408                    } else {
5409                        Some(SparseGBlock {
5410                            row_off: mu_offsets[atom_i],
5411                            col_off: mu_offsets[atom_j],
5412                            data,
5413                        })
5414                    }
5415                })
5416                .collect();
5417            let device_smooth_blocks = smooth_scaled_s
5418                .iter()
5419                .enumerate()
5420                .map(|(atom_idx, factor_a)| {
5421                    // #1117 — rank deficiency is removed at the basis layer, so the
5422                    // device PCG smooth block is just `λ S_k ⊗ I_p` (full-rank
5423                    // design); no data-null deflation is folded in here.
5424                    DeviceSaeSmoothBlock {
5425                        global_offset: beta_offsets[atom_idx],
5426                        factor_a: factor_a.clone(),
5427                    }
5428                })
5429                .collect();
5430            sys.set_device_sae_pcg_data(DeviceSaePcgData {
5431                p,
5432                beta_dim,
5433                a_phi: device_a_phi,
5434                local_jac: device_local_jac,
5435                smooth_blocks: device_smooth_blocks,
5436                sparse_g_blocks: g_sparse_blocks.clone(),
5437                frame: None,
5438            });
5439            let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = smooth_ops;
5440            ops.push(Arc::new(SparseBlockKroneckerPenaltyOp {
5441                p,
5442                dim_a: m_total,
5443                k: beta_dim,
5444                blocks: g_sparse_blocks,
5445            }));
5446            if beta_penalty_assembly.dense_written {
5447                ops.push(Arc::new(DensePenaltyOp(sys.hbb.clone())));
5448            }
5449            sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: beta_dim, ops }));
5450            self.reclaim_border_hbb_workspace(&mut sys);
5451        }
5452        if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
5453            sys.set_row_gauge_deflation(deflation);
5454        }
5455        // #1038 IBP cross-row Woodbury source. The exact IBP Hessian has the
5456        // per-column rank-one cross-row block `H_(i,k),(j,k) = w·s'_k·z'_ik·z'_jk`
5457        // (for ALL `i,j`, including the `i=j` self term) that couples DISTINCT
5458        // latent rows through the shared empirical mass `M_k = Σ_i z_ik`. The
5459        // assembled row-block-diagonal `htt` already carries the `i=j` self term
5460        // `w·s'_k·z'_ik²` — it is the first summand of `assignment_hdiag`'s
5461        // `hessian_diag` value `w·(score_derivative·z_jac² + score·c_ik)` written
5462        // at the logit diagonal above. So the consumer (`solver::arrow_schur`,
5463        // #1038 `IbpCrossRowSource`/`CrossRowWoodbury`) DOWNDATES exactly
5464        // `Σ_k d_k·z'_ik²` (`self_term_downdate`) to recover the NO-SELF base
5465        // `H₀'`, then re-adds the FULL rank-one `U D Uᵀ` via the determinant
5466        // lemma — so value, the evidence log-determinant, and the θ/ρ-adjoint all
5467        // differentiate the SAME `H_full = H₀' + U D Uᵀ`.
5468        //
5469        // The source is built from the SAME `ibp_assignment_third_channels`
5470        // operator the #1006 θ-adjoint consumes:
5471        //   * `d[k] = cross_row_d[k] = w·s'_k = w·score_derivative_k` (the column
5472        //     `D`-coefficient — NOT sign-definite, hence the consumer's
5473        //     indefinite-capacitance LU);
5474        //   * `entries[(i,k)] = (global_t_index, k, z'_ik)` with `z'_ik =
5475        //     z_jac[i·K + k]`. For the DENSE layout (`assignment_coord_dim() = K`,
5476        //     `last_row_layout = None`) atom `k`'s logit slot is local position `k`
5477        //     of row `i`'s block, so `global_t_index = sys.row_offsets[i] + k`. For
5478        //     the COMPACT layout (#1420) only the row's active atoms are
5479        //     coordinates and atom `k` lives at local position `pos` of
5480        //     `active_atoms[row]`, so `global_t_index = sys.row_offsets[i] + pos`.
5481        //     Both pin the `U`-column convention bit-for-bit to the consumer's
5482        //     `ibp_logit_sites`/`row_vars_for_cache_row` slot mapping.
5483        if let Some(channels) = ibp_assignment_third_channels(&self.assignment, rho)? {
5484            let mut entries: Vec<(usize, usize, f64)> = Vec::with_capacity(n * k_atoms);
5485            for row in 0..n {
5486                let start = row * k_atoms;
5487                let g_base = sys.row_offsets[row];
5488                match row_layout.as_ref() {
5489                    // #1420: compact layout — the local logit slot `pos` (not the
5490                    // global atom index `k`) is the t-coordinate. Atom `k`'s logit
5491                    // lives at local position `pos` of `active_atoms[row]`, so emit
5492                    // `(g_base + pos, atom, z_jac[row·K + atom])` for the active set
5493                    // only. Using `g_base + k` would attach atom `k`'s derivative to
5494                    // the wrong slot (and run out of range for compact rows),
5495                    // violating the `IbpCrossRowSource` contract.
5496                    Some(layout) => {
5497                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
5498                            let z_prime = channels.z_jac[start + atom];
5499                            entries.push((g_base + pos, atom, z_prime));
5500                        }
5501                    }
5502                    // Dense layout: atom `k`'s logit slot is local position `k`.
5503                    None => {
5504                        for k in 0..k_atoms {
5505                            let z_prime = channels.z_jac[start + k];
5506                            entries.push((g_base + k, k, z_prime));
5507                        }
5508                    }
5509                }
5510            }
5511            let source = IbpCrossRowSource {
5512                r: k_atoms,
5513                d: channels.cross_row_d.clone(),
5514                entries,
5515            };
5516            sys.set_ibp_cross_row_source(source);
5517        }
5518        // Store the active-set layout for `apply_newton_step`.
5519        self.last_row_layout = row_layout;
5520        // Record whether `delta_beta` from this system is a factored ΔC (needs a
5521        // frame lift) or a full-`B` ΔB. Read by `apply_newton_step_impl`.
5522        self.last_frames_active = frames_engaged;
5523        Ok(sys)
5524    }
5525
5526    /// Project a dense full-`B` Beta-tier penalty Hessian `hbb` (`beta_dim ×
5527    /// beta_dim`, the analytic `∂²P/∂B∂B` block) into the factored coordinate
5528    /// space `Φᵀ hbb Φ` (`border_dim × border_dim`) for the #972 / #977 T1
5529    /// frame transform. `Φ = blkdiag(I_{M_k} ⊗ U_k)` maps C-space → B-space, so
5530    /// the projected block contracts both index legs through the per-atom frames.
5531    ///
5532    /// The projection is done in two passes to stay `O(beta_dim · border_dim +
5533    /// border_dim²)` instead of forming the dense `Φ` explicitly: first
5534    /// `T = hbb · Φ` (right multiply, columns fold `U`), then `Φᵀ · T` (left
5535    /// multiply, rows fold `U`). Analytic Beta-tier penalties are rare and small,
5536    /// so this only fires when one is actually installed.
5537    pub(crate) fn project_dense_penalty_to_factored(
5538        &self,
5539        hbb: ArrayView2<'_, f64>,
5540        projection: &FrameProjection,
5541    ) -> Array2<f64> {
5542        projection.project_block(hbb)
5543    }
5544
5545    pub(crate) fn build_factored_beta_penalty_curvature(
5546        &self,
5547        registry: &AnalyticPenaltyRegistry,
5548        penalty_scale: f64,
5549        projection: &FrameProjection,
5550    ) -> Array2<f64> {
5551        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
5552        let layout = registry.rho_layout();
5553        let target_beta = self.flatten_beta();
5554        let mut hbb_c = Array2::<f64>::zeros((projection.border_dim(), projection.border_dim()));
5555        for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
5556            if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
5557                continue;
5558            }
5559            let rho_local = rho_global.slice(s![rho_slice.clone()]);
5560            match tier {
5561                PenaltyTier::Psi if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) => {
5562                    self.add_factored_beta_penalty_curvature_for_penalty(
5563                        &mut hbb_c,
5564                        penalty,
5565                        target_beta.view(),
5566                        rho_local,
5567                        penalty_scale,
5568                        projection,
5569                    );
5570                }
5571                PenaltyTier::Beta => {
5572                    self.add_factored_beta_penalty_curvature_for_penalty(
5573                        &mut hbb_c,
5574                        penalty,
5575                        target_beta.view(),
5576                        rho_local,
5577                        penalty_scale,
5578                        projection,
5579                    );
5580                }
5581                _ => {}
5582            }
5583        }
5584        hbb_c
5585    }
5586
5587    pub(crate) fn add_factored_beta_penalty_curvature_for_penalty(
5588        &self,
5589        hbb_c: &mut Array2<f64>,
5590        penalty: &AnalyticPenaltyKind,
5591        target_beta: ArrayView1<'_, f64>,
5592        rho_local: ArrayView1<'_, f64>,
5593        penalty_scale: f64,
5594        projection: &FrameProjection,
5595    ) {
5596        let p = self.output_dim();
5597        if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
5598            let Some(per_fit) = self.live_decoder_incoherence_penalty(base) else {
5599                return;
5600            };
5601            let beta_dim = self.beta_dim();
5602            let mut probe = Array1::<f64>::zeros(beta_dim);
5603            for k in 0..self.atoms.len() {
5604                for basis_col in 0..projection.basis_sizes[k] {
5605                    for frame_col in 0..projection.ranks[k] {
5606                        probe.fill(0.0);
5607                        projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5608                        let col = projection.border_offsets[k]
5609                            + basis_col * projection.ranks[k]
5610                            + frame_col;
5611                        let hv = per_fit.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5612                        projection
5613                            .project_border_vec(hv.view())
5614                            .iter()
5615                            .enumerate()
5616                            .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5617                    }
5618                }
5619            }
5620            return;
5621        }
5622        if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
5623            for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
5624                let atom_idx = projection
5625                    .beta_offsets
5626                    .iter()
5627                    .position(|&offset| offset == start)
5628                    .expect("live mechanism-sparsity offset must match an SAE atom");
5629                let block_len = end - start;
5630                let mut local_penalty = per_atom.clone();
5631                local_penalty.target = PsiSlice {
5632                    range: 0..block_len,
5633                    latent_dim: Some(projection.basis_sizes[atom_idx]),
5634                };
5635                let block = target_beta.slice(s![start..end]);
5636                let mut probe = Array1::<f64>::zeros(block_len);
5637                for basis_col in 0..projection.basis_sizes[atom_idx] {
5638                    for frame_col in 0..projection.ranks[atom_idx] {
5639                        probe.fill(0.0);
5640                        projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5641                        let col = projection.border_offsets[atom_idx]
5642                            + basis_col * projection.ranks[atom_idx]
5643                            + frame_col;
5644                        let hv = local_penalty.psd_majorizer_hvp(block, rho_local, probe.view());
5645                        projection.project_local_atom_vec_into(
5646                            atom_idx,
5647                            hv.view(),
5648                            hbb_c.column_mut(col),
5649                            penalty_scale,
5650                        );
5651                    }
5652                }
5653            }
5654            return;
5655        }
5656        if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
5657            for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
5658                let atom_idx = projection
5659                    .beta_offsets
5660                    .iter()
5661                    .position(|&offset| offset == start)
5662                    .expect("live nuclear-norm offset must match an SAE atom");
5663                let block = target_beta.slice(s![start..end]);
5664                let block_len = end - start;
5665                let mut probe = Array1::<f64>::zeros(block_len);
5666                for basis_col in 0..projection.basis_sizes[atom_idx] {
5667                    for frame_col in 0..projection.ranks[atom_idx] {
5668                        probe.fill(0.0);
5669                        projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5670                        let col = projection.border_offsets[atom_idx]
5671                            + basis_col * projection.ranks[atom_idx]
5672                            + frame_col;
5673                        let hv = per_atom.psd_majorizer_hvp(block, rho_local, probe.view());
5674                        projection.project_local_atom_vec_into(
5675                            atom_idx,
5676                            hv.view(),
5677                            hbb_c.column_mut(col),
5678                            penalty_scale,
5679                        );
5680                    }
5681                }
5682            }
5683            return;
5684        }
5685        let beta_dim = self.beta_dim();
5686        let mut probe = Array1::<f64>::zeros(beta_dim);
5687        for k in 0..self.atoms.len() {
5688            for basis_col in 0..projection.basis_sizes[k] {
5689                for frame_col in 0..projection.ranks[k] {
5690                    probe.fill(0.0);
5691                    projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5692                    let col =
5693                        projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5694                    let hv = penalty.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5695                    projection
5696                        .project_border_vec(hv.view())
5697                        .iter()
5698                        .enumerate()
5699                        .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5700                }
5701            }
5702        }
5703        assert_eq!(p, self.output_dim());
5704    }
5705
5706    /// #1610 — project the frozen-gate decoder-repulsion PSD majorizer into the
5707    /// factored β block `hbb_c`. Mirrors the `DecoderIncoherence` arm of
5708    /// [`Self::add_factored_beta_penalty_curvature_for_penalty`] but sources the
5709    /// penalty from [`Self::live_decoder_repulsion_penalty`] (registry-independent,
5710    /// collinearity-gated), so the repulsion curvature reaches the operator on the
5711    /// matrix-free/framed path where the dense `sys.hbb` write is unused. No-op
5712    /// when no repulsion is active.
5713    pub(crate) fn add_factored_repulsion_curvature(
5714        &self,
5715        hbb_c: &mut Array2<f64>,
5716        penalty_scale: f64,
5717        projection: &FrameProjection,
5718    ) {
5719        let Some(per_fit) = self.live_decoder_repulsion_penalty() else {
5720            return;
5721        };
5722        let beta_dim = self.beta_dim();
5723        let target_beta = self.flatten_beta();
5724        // The repulsion penalty is non-learnable; its strength is already folded
5725        // into the frozen gate (see `live_decoder_repulsion_penalty`), so the rho
5726        // slice is empty/inert.
5727        let rho_local = Array1::<f64>::zeros(0);
5728        let mut probe = Array1::<f64>::zeros(beta_dim);
5729        for k in 0..self.atoms.len() {
5730            for basis_col in 0..projection.basis_sizes[k] {
5731                for frame_col in 0..projection.ranks[k] {
5732                    probe.fill(0.0);
5733                    projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5734                    let col =
5735                        projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5736                    let hv =
5737                        per_fit.psd_majorizer_hvp(target_beta.view(), rho_local.view(), probe.view());
5738                    projection
5739                        .project_border_vec(hv.view())
5740                        .iter()
5741                        .enumerate()
5742                        .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5743                }
5744            }
5745        }
5746    }
5747
5748    pub(crate) fn ext_coord_matrix(&self) -> Array2<f64> {
5749        let n = self.n_obs();
5750        let q = self.assignment.row_block_dim();
5751        let flat = self.assignment.flatten_ext_coords();
5752        let mut out = Array2::<f64>::zeros((n, q));
5753        for row in 0..n {
5754            for col in 0..q {
5755                out[[row, col]] = flat[row * q + col];
5756            }
5757        }
5758        out
5759    }
5760
5761    pub(crate) fn ext_coord_manifold(&self) -> LatentManifold {
5762        let mut parts = Vec::with_capacity(self.assignment.row_block_dim());
5763        for _ in 0..self.assignment.assignment_coord_dim() {
5764            parts.push(LatentManifold::Euclidean);
5765        }
5766        let mut any_constrained = false;
5767        for coord in &self.assignment.coords {
5768            if coord.manifold().is_euclidean() {
5769                for _ in 0..coord.latent_dim() {
5770                    parts.push(LatentManifold::Euclidean);
5771                }
5772            } else {
5773                any_constrained = true;
5774                parts.push(coord.manifold().clone());
5775            }
5776        }
5777        if any_constrained {
5778            LatentManifold::Product(parts)
5779        } else {
5780            LatentManifold::Euclidean
5781        }
5782    }
5783
5784    pub(crate) fn apply_sae_riemannian_geometry(&self, sys: &mut ArrowSchurSystem) {
5785        let manifold = self.ext_coord_manifold();
5786        if manifold.is_euclidean() {
5787            return;
5788        }
5789        let ext = self.ext_coord_matrix();
5790        let latent =
5791            LatentCoordValues::from_matrix_with_manifold(ext.view(), LatentIdMode::None, manifold);
5792        sys.apply_riemannian_latent_geometry(&latent);
5793    }
5794
5795    /// Build the compact-layout ext-coord product manifold and point for one row.
5796    ///
5797    /// The dense `ext_coord_manifold()` is keyed to the full-`q` block ordering
5798    /// `[assignment parts (all Euclidean for IBP-MAP / JumpReLU), then per-atom
5799    /// coord blocks in atom order]`. A compact active-set row instead lays its
5800    /// `q_active` columns out as `[one Euclidean logit slot per active atom,
5801    /// then each active atom's coord block in `active` order]` (see
5802    /// [`SaeRowLayout::from_active_atoms`] / `coord_starts`). To reuse the exact
5803    /// per-row Riemannian projector on the compact block we rebuild a product
5804    /// manifold and the matching ext-coord point in that compact order: the
5805    /// `active.len()` logit slots are `Euclidean` (the assignment channel is
5806    /// always Euclidean for the modes that engage sparsity — `assignment_coord_dim
5807    /// == k_atoms`), and each active atom contributes its own coordinate
5808    /// manifold. On the shared active support this is byte-identical to slicing
5809    /// the dense full-`q` product manifold, so the compact projection matches the
5810    /// dense path exactly — it only drops the inactive atoms' (negligible-mass)
5811    /// coordinate blocks the compact layout already excludes from curvature.
5812    ///
5813    /// Returns `(manifold, t_compact)` where `t_compact` has length `q_active`.
5814    /// The logit-slot entries of `t_compact` are filled from the row logits (the
5815    /// Euclidean projector ignores the point, so any finite value is equivalent;
5816    /// using the true logits keeps the point well-defined and finite).
5817    pub(crate) fn compact_row_ext_manifold_and_point(
5818        &self,
5819        row: usize,
5820        layout: &SaeRowLayout,
5821    ) -> (LatentManifold, Array1<f64>) {
5822        let active = &layout.active_atoms[row];
5823        let q_active = layout.row_q_active(row);
5824        let mut parts: Vec<LatentManifold> = Vec::with_capacity(active.len() + active.len());
5825        let mut point = Array1::<f64>::zeros(q_active);
5826        // Logit slots: one Euclidean part per active atom, in `active` order.
5827        let logits_row = self.assignment.logits.row(row);
5828        for (j, &k) in active.iter().enumerate() {
5829            parts.push(LatentManifold::Euclidean);
5830            point[j] = logits_row[k];
5831        }
5832        // Coordinate blocks: each active atom's coordinate manifold + point, at
5833        // the compact coord start the layout assigned it.
5834        for (j, &k) in active.iter().enumerate() {
5835            let coord = &self.assignment.coords[k];
5836            let d = coord.latent_dim();
5837            let coord_start = layout.coord_starts[row][j];
5838            let manifold_k = coord.manifold();
5839            // A `d`-dim coordinate whose manifold is a product (e.g. a torus =
5840            // Circle×Circle) already carries its `d` parts; a scalar manifold is
5841            // one part. Either way the manifold's ambient width must equal `d`,
5842            // matching the `d` compact columns at `coord_start`.
5843            parts.push(manifold_k.clone());
5844            let coord_point = coord.row(row);
5845            for axis in 0..d {
5846                point[coord_start + axis] = coord_point[axis];
5847            }
5848        }
5849        (LatentManifold::Product(parts), point)
5850    }
5851
5852    /// Numerical rank of a symmetric matrix: the count of eigenvalues
5853    /// exceeding `tol · max_eig`, with `tol = 1e-9` (the conventional
5854    /// relative spectral cutoff used elsewhere in the codebase).
5855    ///
5856    /// Used to count the penalised dimension of each atom's `smooth_penalty`
5857    /// `S_k` so the REML criterion's `−½·p·rank(S)·log λ_smooth` Occam term
5858    /// uses the *effective* penalty rank rather than the ambient basis size
5859    /// (a thin-plate / B-spline penalty has a non-trivial null space).
5860    pub(crate) fn symmetric_rank(s: &Array2<f64>) -> Result<usize, String> {
5861        if s.nrows() != s.ncols() {
5862            return Err(format!(
5863                "SaeManifoldTerm::symmetric_rank: matrix must be square, got {}x{}",
5864                s.nrows(),
5865                s.ncols()
5866            ));
5867        }
5868        let m = s.ncols();
5869        if m == 0 {
5870            return Ok(0);
5871        }
5872        // Symmetrize defensively through the shared ndarray helper. The SAE
5873        // rank cutoff is intentionally local to the SAE evidence contract; only
5874        // the symmetric cleanup is shared with the other construction modules.
5875        let mut sym = s.clone();
5876        gam_linalg::matrix::symmetrize_in_place(&mut sym);
5877        let (evals, _evecs) = sym
5878            .eigh(Side::Lower)
5879            .map_err(|e| format!("SaeManifoldTerm::symmetric_rank: eigh failed: {e}"))?;
5880        let max_eig = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v));
5881        if !(max_eig > 0.0) {
5882            return Ok(0);
5883        }
5884        let tol = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * max_eig;
5885        Ok(evals.iter().filter(|&&v| v > tol).count())
5886    }
5887
5888    /// Penalised quasi-Laplace evidence score for the SAE term at a FIXED ρ.
5889    ///
5890    /// #1421: this is NOT a true normalized-prior REML/evidence objective. The
5891    /// assignment priors (softmax entropy, JumpReLU) have NO finite normalizer:
5892    /// for softmax the reference-logit chart sends `P(ℓ)→0` as a free logit →±∞
5893    /// so `∫ e^{−λP} dℓ = ∞`, and JumpReLU's bounded penalty `0<P<λ` keeps
5894    /// `e^{−λP}` bounded below over an unbounded domain, also divergent. There is
5895    /// therefore no ρ-independent assignment-prior normalizer that can be dropped
5896    /// as a constant. The smoothing-penalty `−½log|λS|_+` term IS a genuine
5897    /// (proper-Gaussian) REML normalizer and is kept exactly; the rest is a
5898    /// penalized quasi-Laplace score (Laplace curvature term `½log|H|` around the
5899    /// inner optimum), which the engine minimizes over ρ.
5900    ///
5901    /// Runs the inner `(t, β)` arrow-Schur Newton solve to convergence at the
5902    /// supplied ρ (with NO in-loop ARD update — ρ is owned by the engine),
5903    /// then forms the Laplace/REML cost
5904    ///
5905    /// ```text
5906    /// V(ρ) = ℓ_pen(t̂, β̂; ρ) + ½ log|H(t̂, β̂; ρ)|
5907    ///        − ½ · p · (Σ_k rank S_k) · log λ_smooth
5908    /// ```
5909    ///
5910    /// where `ℓ_pen = loss.total()` is the penalised objective at the inner
5911    /// optimum and `½ log|H|` is the Laplace normaliser. `H` is the joint
5912    /// `(t, β)` Hessian assembled by the arrow-Schur system; its `H_tt` block
5913    /// carries `α = exp(log_ard)` on its diagonal, so as α grows `½ log|H|`
5914    /// rises while the `−½·n·log α` already inside `loss.ard` falls — their
5915    /// balance IS the effective-dof term that the deleted `α = n/‖t‖²` rule
5916    /// dropped, which is why the criterion needs no clamp to stay finite on a
5917    /// collapsing axis.
5918    ///
5919    /// The final `−½·p·rank(S)·log λ_smooth` term is the smoothing-penalty
5920    /// normaliser `−½ log|λ S|_+` restricted to its ρ-dependent part: `S_k` is
5921    /// shared across all `p` decoder output channels (the `⊗ I_p` Kronecker
5922    /// structure), so `log|λ S|_+ = p·rank(S)·log λ + p·log|S|_+`, and the
5923    /// `½ p·log|S|_+` piece is ρ-independent. The ρ-independent additive
5924    /// constants that ARE dropped here (they shift `V` by a constant and do not
5925    /// affect the ρ-argmin) are the `2π` Laplace constant and the base
5926    /// `½ p·log|S|_+` penalty logdet. #1421: NO assignment-prior normalizer is
5927    /// dropped, because none exists (softmax/JumpReLU priors are improper — see
5928    /// the doc on this function): the quasi-Laplace score simply omits a
5929    /// normalizer that is not a finite constant.
5930    ///
5931    /// Returns `(V, loss)` so the engine can both rank ρ and surface the inner
5932    /// loss breakdown.
5933    pub fn reml_criterion(
5934        &mut self,
5935        target: ArrayView2<'_, f64>,
5936        rho: &SaeManifoldRho,
5937        registry: Option<&AnalyticPenaltyRegistry>,
5938        inner_max_iter: usize,
5939        learning_rate: f64,
5940        ridge_ext_coord: f64,
5941        ridge_beta: f64,
5942    ) -> Result<(f64, SaeManifoldLoss), String> {
5943        self.reml_criterion_with_refine_policy(
5944            target,
5945            rho,
5946            registry,
5947            inner_max_iter,
5948            learning_rate,
5949            ridge_ext_coord,
5950            ridge_beta,
5951            true,
5952        )
5953    }
5954
5955    pub(crate) fn reml_criterion_with_refine_policy(
5956        &mut self,
5957        target: ArrayView2<'_, f64>,
5958        rho: &SaeManifoldRho,
5959        registry: Option<&AnalyticPenaltyRegistry>,
5960        inner_max_iter: usize,
5961        learning_rate: f64,
5962        ridge_ext_coord: f64,
5963        ridge_beta: f64,
5964        refine_progress_extension: bool,
5965    ) -> Result<(f64, SaeManifoldLoss), String> {
5966        let plan = self.streaming_plan().admitted_or_error(
5967            self.n_obs(),
5968            self.output_dim(),
5969            self.k_atoms(),
5970        )?;
5971        if plan.streaming {
5972            // #1225: streaming and dense MUST optimize the SAME mathematical
5973            // objective — the full REML criterion `loss.total() + extra_penalty +
5974            // ½ log|H| − Occam`. The streaming branch previously returned only
5975            // `loss.total() + extra_penalty_energy`, dropping the Laplace
5976            // normalizer `½ log|H|` and the Occam term, so large shapes (exactly
5977            // where streaming is needed) were ranked by penalized loss rather than
5978            // REML — and dense vs streaming disagreed on the objective. Route
5979            // through the streaming exact-logdet path, which assembles the same
5980            // chunk-by-chunk-bit-identical `½ log|H|_stream` and the same
5981            // `−Occam`/extra-penalty terms as the dense `reml_criterion_with_cache`
5982            // (different memory strategy, same objective).
5983            self.reml_criterion_streaming_exact(
5984                target,
5985                rho,
5986                registry,
5987                inner_max_iter,
5988                learning_rate,
5989                ridge_ext_coord,
5990                ridge_beta,
5991            )
5992        } else {
5993            let (v, loss, _cache) = self.reml_criterion_with_cache_refine_policy(
5994                target,
5995                rho,
5996                registry,
5997                inner_max_iter,
5998                learning_rate,
5999                ridge_ext_coord,
6000                ridge_beta,
6001                refine_progress_extension,
6002            )?;
6003            Ok((v, loss))
6004        }
6005    }
6006
6007    /// As [`Self::reml_criterion`], but also returns the converged undamped
6008    /// `ArrowFactorCache` so callers (the EFS fixed-point step) can read the
6009    /// selected-inverse traces `(H⁻¹)_tt` / `(H⁻¹)_ββ` without re-factoring.
6010    /// The cache is the single shared O(K³) Direct factor; both the
6011    /// log-determinant criterion and the Fellner-Schall ρ-step consume it.
6012    pub fn reml_criterion_with_cache(
6013        &mut self,
6014        target: ArrayView2<'_, f64>,
6015        rho: &SaeManifoldRho,
6016        registry: Option<&AnalyticPenaltyRegistry>,
6017        inner_max_iter: usize,
6018        learning_rate: f64,
6019        ridge_ext_coord: f64,
6020        ridge_beta: f64,
6021    ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6022        self.reml_criterion_with_cache_refine_policy(
6023            target,
6024            rho,
6025            registry,
6026            inner_max_iter,
6027            learning_rate,
6028            ridge_ext_coord,
6029            ridge_beta,
6030            true,
6031        )
6032    }
6033
6034    pub(crate) fn reml_criterion_with_cache_refine_policy(
6035        &mut self,
6036        target: ArrayView2<'_, f64>,
6037        rho: &SaeManifoldRho,
6038        registry: Option<&AnalyticPenaltyRegistry>,
6039        inner_max_iter: usize,
6040        learning_rate: f64,
6041        ridge_ext_coord: f64,
6042        ridge_beta: f64,
6043        refine_progress_extension: bool,
6044    ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6045        let admission_plan = self.streaming_plan().admitted_or_error(
6046            self.n_obs(),
6047            self.output_dim(),
6048            self.k_atoms(),
6049        )?;
6050        if !admission_plan.direct_logdet_admitted() {
6051            return Err(format!(
6052                "SaeManifoldTerm::reml_criterion_with_cache: predicted working set {} bytes exceeds budget {} bytes for dense evidence cache; shape n={},p={},K={}; cost-only streaming route is required",
6053                admission_plan.estimated_direct_peak_bytes,
6054                admission_plan.in_core_budget_bytes,
6055                self.n_obs(),
6056                self.output_dim(),
6057                self.k_atoms()
6058            ));
6059        }
6060        // 1. Run the inner (t, β) Newton solve to convergence at FIXED ρ.
6061        //    `run_joint_fit_arrow_schur` no longer touches ρ.
6062        let mut rho_fixed = rho.clone();
6063        let mut loss = self.run_joint_fit_arrow_schur(
6064            target,
6065            &mut rho_fixed,
6066            registry,
6067            inner_max_iter,
6068            learning_rate,
6069            ridge_ext_coord,
6070            ridge_beta,
6071        )?;
6072
6073        // 2. Drive the inner (t, β) solve to the KKT/step-converged optimum and
6074        //    take one final UNDAMPED factor there to obtain the joint Hessian
6075        //    log-determinant. We force ridge = 0 and the dense `Direct` Schur
6076        //    mode so `arrow_log_det_from_cache` returns the exact
6077        //    `log|H| = Σ_i log|H_tt^(i)| + log|Schur_β|` (it rejects damped
6078        //    factors and InexactPCG caches, which have no dense Schur factor).
6079        //    This is the same evidence convention the main GAM REML path uses.
6080        //    The shared `converge_inner_for_undamped_logdet` driver guarantees
6081        //    the per-row `H_tt^(i)` blocks are PD at the converged optimum so
6082        //    the undamped (`ridge = 0`) factorization succeeds — the streaming
6083        //    log-det path reuses the identical driver so both rank the same
6084        //    converged Laplace optimum and stay bit-identical.
6085        let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
6086        let cache = self.converge_inner_for_undamped_logdet(
6087            target,
6088            rho,
6089            &mut rho_fixed,
6090            registry,
6091            inner_max_iter,
6092            learning_rate,
6093            ridge_ext_coord,
6094            ridge_beta,
6095            &mut loss,
6096            &options,
6097            refine_progress_extension,
6098        )?;
6099        self.record_evidence_gauge_deflation_count(cache.gauge_deflated_directions)?;
6100        loss.evidence_gauge_deflated_directions = cache.gauge_deflated_directions;
6101        let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
6102            // Distinguish a GENUINE infeasibility — a probed ρ where the joint
6103            // Hessian is not PD so the Laplace evidence log-det is undefined —
6104            // from a real factorization defect. The cross-row IBP Woodbury
6105            // capacitance `C = I_R + D·Uᵀ H₀'⁻¹ U` can have det ≤ 0 at a ρ the
6106            // outer optimizer line-searches into (the indefinite basin adjacent
6107            // to the PD region); there the log-det legitimately does not exist.
6108            // That refusal must be RECOVERABLE (the outer BFGS should get +∞ and
6109            // steer back into the PD region), exactly like the "non-PD per-row
6110            // H_tt block" refusal — not a fatal `RemlOptimizationFailed` that
6111            // aborts the whole fit. See `is_recoverable_value_probe_refusal`.
6112            // (The old message claimed "no dense Schur factor", which is false
6113            // here — the Schur factor is present; the Woodbury correction is the
6114            // non-finite term.)
6115            if cache.cross_row_woodbury.is_some()
6116                && !cache.cross_row_woodbury_log_det().is_finite()
6117            {
6118                "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
6119                 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
6120                    .to_string()
6121            } else {
6122                "SaeManifoldTerm::reml_criterion: arrow_log_det_from_cache returned None \
6123                 (undamped joint Hessian log-det unavailable for the Laplace normaliser)"
6124                    .to_string()
6125            }
6126        })?;
6127
6128        // 3. Smoothing-penalty Occam term `−½·Σ_k r_k·rank(S_k)·log λ_smooth`
6129        //    plus the profiled-frame evidence-dimension correction
6130        //    `+½·Σ_k r_k·(p−r_k)·log λ_smooth` (issue #972). On the full-`B` path
6131        //    (`r_k == p`, no frames) this is exactly the historical
6132        //    `½·p·(Σ rank S_k)·log λ_smooth`, so the small-model criterion is
6133        //    unchanged. The single seam is `reml_occam_term`, shared with the
6134        //    streaming path so both rank the identical Laplace dimension count.
6135        let occam = self.reml_occam_term(rho)?;
6136
6137        // Decoder-block analytic-penalty energy (#671/#672). The inner solve
6138        // descended this energy (it enters `gb`/`hbb`) but it had no native
6139        // `loss.*` representative, so the Laplace criterion `v` was scoring a
6140        // different objective than the one minimized. Add the converged
6141        // decoder-penalty value so the ρ-sweep ranks the same penalized
6142        // deviance. Excludes the Psi-tier ARD/assignment penalties already
6143        // accounted for in `loss.total()` (see
6144        // `analytic_decoder_penalty_value_total`).
6145        // Extra analytic-penalty energy (#671/#737). Decoder-block penalties and
6146        // coordinate-tier isometry enter the inner solve but have no `loss.*`
6147        // representative, so the Laplace criterion must add them explicitly to
6148        // rank the same penalized deviance the Newton solve descends.
6149        let extra_penalty_energy = match registry {
6150            Some(reg) => self
6151                .reml_extra_penalty_value_total(reg)
6152                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?,
6153            None => 0.0,
6154        };
6155
6156        let v = loss.total() + extra_penalty_energy + 0.5 * log_det - occam;
6157        Ok((v, loss, cache))
6158    }
6159
6160    /// The #1037 quotient-dimension invariant: a Laplace normalizer `½log|H|` is
6161    /// only comparable across ρ at a COMMON quotient (gauge-deflation) dimension.
6162    /// The first observation pins the expected count; a later match is a no-op.
6163    ///
6164    /// A later observation that DIFFERS is, under the K>1 fit, a LEGITIMATE
6165    /// quotient-dimension event — an atom born, reseeded (the #976 collapse
6166    /// guards), or rank-reduced moves the number of gauge-flat rows. Because a
6167    /// deflated direction is lifted to unit stiffness and contributes the
6168    /// ρ-independent `log 1 = 0` to the evidence, re-anchoring the comparison to
6169    /// the new dimension is exactly evidence-preserving and keeps every future
6170    /// cross-ρ comparison consistent — the principled response, not an abort.
6171    ///
6172    /// The genuine pathology the guard still catches is a count that NEVER
6173    /// STABILIZES: re-anchors are bounded by the per-atom structural-event budget
6174    /// (`k·(reseed_budget+1)+1`), and a runaway quotient dimension past that
6175    /// bound refuses loudly. This supersedes the prior strict-constant guard and
6176    /// its ±1 flicker band (#1117) at root — the band was masking exactly the
6177    /// legitimate K>1 dimension changes this re-anchoring now handles.
6178    pub(crate) fn record_evidence_gauge_deflation_count(
6179        &mut self,
6180        count: usize,
6181    ) -> Result<(), String> {
6182        match self.expected_evidence_gauge_deflated_directions {
6183            Some(expected) if expected == count => Ok(()),
6184            Some(expected) => {
6185                // A change in the gauge-deflation count between two evidence
6186                // factorizations is a legitimate quotient-dimension event under
6187                // the K>1 fit: an atom can be born, reseeded (the #976 collapse
6188                // guards), or rank-reduced across the ρ-walk, and each such event
6189                // moves the number of gauge-flat rows. The #1037 invariant is
6190                // NOT "the count never changes" — it is "two Laplace normalizers
6191                // are only comparable at a COMMON quotient dimension". The
6192                // principled response to a legitimate change is therefore to
6193                // RE-ANCHOR the comparison to the new dimension (so every future
6194                // cross-ρ comparison within the optimization is consistent), not
6195                // to abort the fit. This is exactly evidence-preserving: each
6196                // gauge-deflated direction is lifted to unit stiffness and
6197                // contributes the ρ-independent `log 1 = 0` to `½log|H|`, so the
6198                // converged criterion value is identical whether a given row is
6199                // counted as deflated or not — only the BOOKKEEPING dimension
6200                // must agree across a comparison, and re-anchoring restores that.
6201                //
6202                // The genuine pathology the guard must still catch is a count
6203                // that NEVER STABILIZES — an OSCILLATING quotient dimension that
6204                // re-anchors without converging, signalling a truly ill-posed
6205                // evidence surface. But the deflation count is NOT a discrete
6206                // dictionary-level event count: it is the per-ROW-summed number of
6207                // near-null evidence directions across all N rows (#1217). On real
6208                // K≥2 activations it is an O(N) quantity that drifts SMOOTHLY and
6209                // monotonically as the conditioning improves over the ρ-walk
6210                // (e.g. 171→156→…→113 as smoothing increases) — a benign,
6211                // evidence-neutral change (each deflated direction contributes the
6212                // ρ-independent `log 1 = 0` to `½log|H|`, so re-anchoring never
6213                // moves the criterion value). Charging such a monotone drift
6214                // against a `k`-sized "structural event" budget was wrong: it
6215                // counts threshold crossings of a continuous per-row quantity, not
6216                // atom births/reseeds, so the budget tripped on a perfectly healthy
6217                // converging K=2 fit (#1217 regression from the #1189/#1190
6218                // basin-escape fixes, which shifted which rows sit near the
6219                // deflation floor).
6220                //
6221                // The principled discriminator is DIRECTION REVERSALS: a count
6222                // that drifts one way and settles is benign; a count that bounces
6223                // up and down without settling is the oscillating-quotient
6224                // pathology. We therefore charge the re-anchor budget ONLY on a
6225                // reversal of the change direction, and size the budget by the
6226                // number of distinct dictionary structural events (births/reseeds)
6227                // that can each legitimately flip the drift direction. A monotone
6228                // drift of any length re-anchors freely (it is consistently
6229                // re-anchored and evidence-neutral); a genuinely oscillating count
6230                // exhausts the reversal budget and refuses loudly.
6231                let delta_sign: i8 = if count > expected { 1 } else { -1 };
6232                let is_reversal = self.evidence_gauge_deflation_last_delta_sign != 0
6233                    && delta_sign != self.evidence_gauge_deflation_last_delta_sign;
6234                self.evidence_gauge_deflation_last_delta_sign = delta_sign;
6235                // A reversal alone is NOT the pathology — a BOUNDED flicker of a
6236                // few rows crossing the near-null deflation floor reverses
6237                // direction every step yet is the discretization jitter of a
6238                // continuous evidence spectrum, fully evidence-neutral (each
6239                // deflated direction contributes `log 1 = 0` either way). The
6240                // genuine "quotient dimension not stabilizing" pathology is a
6241                // WIDE-amplitude oscillation: a substantial FRACTION of the
6242                // dimension flipping back and forth. The count is an O(N) per-row
6243                // sum, so the discriminator must be the reversal AMPLITUDE
6244                // relative to the dimension level, not the bare reversal. Charge
6245                // the reversal budget only when a reversal's step exceeds a
6246                // relative jitter band; a converged-but-flickering fit (e.g.
6247                // 150<->147 on N=200, ~2% of the level) re-anchors freely while a
6248                // true runaway (e.g. 9<->2, ~80% of the level) still trips every
6249                // reversal and exhausts the budget. This was the second #795 root
6250                // cause: the single-planted-circle fit's per-row count flickers
6251                // 150<->147 near the deflation floor, so the bare-reversal guard
6252                // refused the simplest possible fit — with the isometry gauge ON
6253                // *or* OFF — long before the gauge magnitude mattered.
6254                let amplitude = expected.abs_diff(count);
6255                let level = expected.max(count);
6256                let jitter_band = (level / 4).max(2);
6257                if is_reversal && amplitude > jitter_band {
6258                    self.evidence_gauge_deflation_reanchors += 1;
6259                }
6260                let reversal_budget = self
6261                    .k_atoms()
6262                    .saturating_mul(
6263                        SAE_ATOM_COLLAPSE_RESEED_BUDGET
6264                            + SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET
6265                            + 1,
6266                    )
6267                    .saturating_add(1);
6268                if self.evidence_gauge_deflation_reanchors > reversal_budget {
6269                    return Err(format!(
6270                        "SaeManifoldTerm::reml_criterion: row-gauge evidence deflation count \
6271                         oscillated (reversed direction {} times, last {expected}->{count}) within \
6272                         one optimization, exceeding the {reversal_budget}-reversal budget for {} \
6273                         atoms; the quotient dimension is not stabilizing, refusing to compare \
6274                         Laplace normalizers",
6275                        self.evidence_gauge_deflation_reanchors,
6276                        self.k_atoms()
6277                    ));
6278                }
6279                log::debug!(
6280                    "SaeManifoldTerm::reml_criterion: per-row evidence deflation count changed \
6281                     {expected}->{count} (a benign per-row conditioning drift across the ρ-walk; \
6282                     reversal {}/{reversal_budget}); re-anchoring the Laplace normalizer comparison \
6283                     to the new dimension",
6284                    self.evidence_gauge_deflation_reanchors
6285                );
6286                self.expected_evidence_gauge_deflated_directions = Some(count);
6287                Ok(())
6288            }
6289            None => {
6290                self.expected_evidence_gauge_deflated_directions = Some(count);
6291                Ok(())
6292            }
6293        }
6294    }
6295
6296    pub(crate) fn is_undamped_evidence_row_non_pd(err: &ArrowSchurError) -> bool {
6297        matches!(
6298            err,
6299            ArrowSchurError::PerRowFactorFailed { reason, .. }
6300                if reason.contains("H_tt is non-PD at base ridge")
6301                    && reason.contains("evidence mode preserves the genuine Cholesky")
6302        )
6303    }
6304
6305    /// Drive the inner `(t, β)` Newton solve to the KKT/step-converged optimum
6306    /// and return the final UNDAMPED (`ridge = 0`) joint-Hessian factor cache.
6307    ///
6308    /// The Laplace normaliser `½log|H|` is only the correct REML criterion at
6309    /// the inner optimum `(t̂, β̂)`, so the criterion must refine the inner state
6310    /// until either the KKT gradient or the undamped Newton step meets tolerance
6311    /// before factoring. Crucially, **at the converged optimum the per-row
6312    /// `H_tt^(i)` blocks are PD**, so the undamped (`ridge = 0`) factorization
6313    /// succeeds; an off-optimum iterate (e.g. the initial seed, or a state
6314    /// stopped after only `inner_max_iter` steps) can have an indefinite /
6315    /// rank-deficient per-row block (`p_out = 1` → rank-1 `JᵀJ`, softmax
6316    /// assignment-sparsity negative logit curvature) that surfaces
6317    /// `PerRowFactorFailed` from the undamped `factor_one_row`. Both the dense
6318    /// (`reml_criterion_with_cache`) and the streaming
6319    /// (`reml_criterion_streaming_exact`) evidence paths route through this same
6320    /// driver, so they converge to the identical inner state and their
6321    /// `ridge = 0` log-determinants stay bit-identical (#847).
6322    pub(crate) fn converge_inner_for_undamped_logdet(
6323        &mut self,
6324        target: ArrayView2<'_, f64>,
6325        rho: &SaeManifoldRho,
6326        rho_fixed: &mut SaeManifoldRho,
6327        registry: Option<&AnalyticPenaltyRegistry>,
6328        inner_max_iter: usize,
6329        learning_rate: f64,
6330        ridge_ext_coord: f64,
6331        ridge_beta: f64,
6332        loss: &mut SaeManifoldLoss,
6333        options: &ArrowSolveOptions,
6334        refine_progress_extension: bool,
6335    ) -> Result<ArrowFactorCache, String> {
6336        // `inner_max_iter == 0` is a genuine FREEZE of the inner `(t, β)` state
6337        // — a verbatim warm-start reuse, not a convergence request (gam#577/#579,
6338        // #850). The convergence/refinement loop below MUST NOT run even one
6339        // Newton step in that case (the old `inner_max_iter.max(1)` floor moved
6340        // β off the seed), so we factor exactly once at the frozen iterate and
6341        // return that undamped cache without invoking the stationarity gate.
6342        // The caller has already run `run_joint_fit_arrow_schur(..., 0, ...)`,
6343        // which under the `max_iter == 0` freeze (gam#577/#579, #850) runs ONLY
6344        // the β-neutral basis refresh and returns the loss without touching β —
6345        // it skips the rank-reduction, frame activation, re-seed guards, and the
6346        // #1026 decoder-LSQ polish that would otherwise refit β off the seed — so
6347        // `self` is at the warm-start β here.
6348        if inner_max_iter == 0 {
6349            let sys = self
6350                .assemble_arrow_schur(target, rho, registry)
6351                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6352            let factored = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options)
6353                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6354            // The frozen-state Newton step (factored.0, factored.1) is discarded
6355            // — only the undamped factor cache (factored.2) is consumed for the
6356            // log-det / selected-inverse traces; β stays at the warm-start seed.
6357            return Ok(factored.2);
6358        }
6359        let mut total_inner_iter = inner_max_iter;
6360        let accepted_base_refine_iter = inner_max_iter.max(1).saturating_mul(16).max(64);
6361        let value_probe_base_refine_iter = inner_max_iter.max(1).saturating_mul(4).max(16);
6362        let base_refine_iter = if refine_progress_extension {
6363            accepted_base_refine_iter
6364        } else {
6365            value_probe_base_refine_iter
6366        };
6367        let progress_refine_iter = if refine_progress_extension {
6368            inner_max_iter.max(1).saturating_mul(64).max(256)
6369        } else {
6370            base_refine_iter
6371        };
6372        let mut previous_refine_grad_norm: Option<f64> = None;
6373        let mut saw_refine_progress = false;
6374        // #1051 — objective-stagnation convergence. On an ill-conditioned
6375        // penalised bilinear fit (the euclidean / Duchon decoder × latent
6376        // coordinate system on a trivial shape), the inner Newton crawls: each
6377        // refine round lowers the penalised objective by a shrinking amount while
6378        // the KKT gradient and the undamped step stay above their relative
6379        // tolerances (the near-singular Schur amplifies the step in the
6380        // weakly-identified decoder direction). The grad-OR-step gate then never
6381        // fires and the solve is rejected as "did not converge" — the 1e12
6382        // sentinel. A Newton/LM iterate whose objective has stopped decreasing
6383        // to within `√εmach` of its scale IS the numerical inner optimum; ranking
6384        // the Laplace criterion there is correct. We accept that fixed point
6385        // instead of grinding the budget.
6386        let entry_loss_total = loss.total();
6387        let mut previous_loss_total = entry_loss_total;
6388        let mut refine_rounds: usize = 0;
6389        // Consecutive stall rounds: counts how many successive refine rounds
6390        // ended in a stall AND a failed undamped factor.  Once this reaches
6391        // `SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS` the iterate is at
6392        // its numerical fixed point and cannot be improved further; returning
6393        // `Err` here is the same "did not converge" signal that
6394        // `is_recoverable_value_probe_refusal` already handles, so the outer
6395        // BFGS treats it as an INFINITY probe and tries a different ρ instead
6396        // of looping forever burning the extended progress budget.  Without
6397        // this counter the stagnation handler fell through when the undamped
6398        // factor failed and the loop kept extending via `saw_refine_progress`
6399        // from earlier rounds, accumulating minutes of wasted work (#1094).
6400        let mut consecutive_stall_factor_fail: usize = 0;
6401        loop {
6402            let sys = self
6403                .assemble_arrow_schur(target, rho, registry)
6404                .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6405            // Evidence-only factorization: the Newton step (Δt, Δβ) is discarded
6406            // and only the factor cache is consumed — the exact undamped log-det
6407            // and the selected-inverse traces. As ρ sweeps to extremes (e.g. a
6408            // wide ARD-α sweep), H_tt is genuinely PD but can be ill-conditioned;
6409            // the standard Direct guard rejects that to protect Newton-step
6410            // accuracy, but the log-det is exact from diag(L) regardless of the
6411            // condition number and the traces only need the (PD) factor. So
6412            // tolerate the ill-conditioning rejection here (a genuine non-PD pivot
6413            // still errors). The cache stays undamped at ridge=0, so
6414            // `arrow_log_det_from_cache` remains exact.
6415            // The exact KKT stationarity residual is the joint gradient
6416            // ‖g‖ = √(Σ_i ‖g_t^(i)‖² + ‖g_β‖²), read straight off the assembled
6417            // system. Unlike the Newton step Δ = H⁻¹g, the gradient is
6418            // factorisation-independent: it is NOT amplified by an inverse, so a
6419            // genuinely stationary but ill-conditioned fit (tiny g, possibly large
6420            // Δ in a flat direction) is correctly recognised as converged. The
6421            // `with_ill_conditioning_tolerated` Direct factor below documents that
6422            // its Δ may be inaccurate in exactly those flat directions, so using Δ
6423            // alone as the convergence gate would falsely reject healthy fits.
6424            let grad_norm_sq: f64 = sys
6425                .rows
6426                .iter()
6427                .map(|row| row.gt.iter().map(|&v| v * v).sum::<f64>())
6428                .sum::<f64>()
6429                + sys.gb.iter().map(|&v| v * v).sum::<f64>();
6430            let grad_norm = grad_norm_sq.sqrt();
6431            // Quotient KKT-gradient (#1117): the raw joint gradient retains a
6432            // persistent small component in the chart-gauge orbit and the
6433            // rank-deficient decoder β-null even at a stationary fit, so the raw
6434            // grad gate never clears on a rank-deficient circle and the inner
6435            // refine loop crawls until the (large) progress budget dies — the
6436            // 2-min stall. Measure the gradient on the SAME identified quotient
6437            // the step gate already uses: a fit whose only remaining gradient
6438            // lives in those flat directions is stationary on the quotient, so
6439            // ranking the Laplace criterion there is correct. The dense per-row
6440            // g_t is laid into the `n·q` coordinate layout the gauge basis spans;
6441            // non-dense/heterogeneous systems fall back to the raw norm.
6442            let quotient_grad_norm = {
6443                let n = self.n_obs();
6444                let q = self.assignment.row_block_dim();
6445                let dense_len = n.saturating_mul(q);
6446                let mut grad_ext_coord = Array1::<f64>::zeros(dense_len);
6447                let mut dense_layout_ok = sys.rows.len() == n;
6448                if dense_layout_ok {
6449                    for (row_idx, row) in sys.rows.iter().enumerate() {
6450                        let base = sys.row_offsets[row_idx];
6451                        let di = sys.row_dims[row_idx];
6452                        if base + di > dense_len || row.gt.len() < di {
6453                            dense_layout_ok = false;
6454                            break;
6455                        }
6456                        for axis in 0..di {
6457                            grad_ext_coord[base + axis] = row.gt[axis];
6458                        }
6459                    }
6460                }
6461                if dense_layout_ok {
6462                    self.quotient_gradient_norm_sq(
6463                        grad_ext_coord.view(),
6464                        sys.gb.view(),
6465                        grad_norm_sq,
6466                        &rho_fixed.lambda_smooth_vec(),
6467                    )
6468                    .map(|v| v.sqrt())
6469                    .unwrap_or(grad_norm)
6470                } else {
6471                    grad_norm
6472                }
6473            };
6474            let iterate_scale = self.inner_iterate_scale();
6475            // Relative parameter-step tolerance for Δ (well-conditioned charts)
6476            // and a scaled KKT-gradient tolerance. Convergence is accepted on
6477            // EITHER a small KKT gradient OR a small undamped Newton step: SAE
6478            // manifold fits contain gauge-like coordinate/decoder directions (the
6479            // circle's rotation gauge, decoder column-space rotations) where the
6480            // shared-block Hessian is near-singular, so the undamped step can stay
6481            // large in that flat direction even at a genuine stationary point; the
6482            // gradient, which is not amplified by the inverse, recognises it. With
6483            // the isometry Gauss-Newton block now a coherent PSD pullback (no
6484            // indefinite Schur pivot), the inner solve reaches true stationarity,
6485            // so the gradient tolerance is a standard relative KKT residual rather
6486            // than the 0.1.154-regression band-aid (3e-3) that masked the
6487            // non-convergence the indefinite curvature caused.
6488            let step_tolerance = SAE_MANIFOLD_INNER_STEP_REL_TOL * iterate_scale;
6489            let grad_tolerance = SAE_MANIFOLD_INNER_GRAD_REL_TOL * iterate_scale;
6490            if !grad_norm_sq.is_finite() {
6491                return Err(format!(
6492                    "SaeManifoldTerm::reml_criterion: undamped inner KKT residual is non-finite \
6493                     at the inner optimum (‖g‖²={grad_norm_sq}); the joint Hessian \
6494                     factorisation is degenerate at this ρ"
6495                ));
6496            }
6497            let (delta_t, delta_beta, cache): (Array1<f64>, Array1<f64>, ArrowFactorCache) =
6498                match solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options) {
6499                    Ok(factored) => factored,
6500                    Err(err) if Self::is_undamped_evidence_row_non_pd(&err) => {
6501                        if grad_norm <= grad_tolerance || quotient_grad_norm <= grad_tolerance {
6502                            // K>1: the softmax/IBP logit–coordinate Gauss-Newton
6503                            // cross-terms (H_zt = J_z^T J_t, assembled row-locally from
6504                            // the assignment JVP × basis JVP) can make a per-row H_tt
6505                            // indefinite at the TRUE KKT stationary point — when two
6506                            // atoms' decoders specialise in opposite directions the
6507                            // Schur complement of the logit block goes negative even
6508                            // though the priors and the full-joint GN term are PSD.
6509                            //
6510                            // The undamped evidence factor already conditions that
6511                            // block the PRINCIPLED way: `factor_spectral_deflated_
6512                            // evidence_row` discovers the negative/flat eigen-direction
6513                            // and stiffens it to UNIT curvature (eigenvalue → +1), so it
6514                            // contributes a ρ-INDEPENDENT log 1 = 0 to the evidence —
6515                            // the same quotient pseudo-determinant convention the gauge
6516                            // (#1037) and data-null (#1117) deflations use. Reaching
6517                            // THIS arm at stationarity therefore means even the spectral
6518                            // deflation declined (a non-finite block or a failed
6519                            // eigendecomposition): the state is genuinely broken, so we
6520                            // surface the hard refusal and let the outer BFGS treat this
6521                            // ρ as an INFINITY probe (`is_recoverable_value_probe_
6522                            // refusal`). We must NOT ridge-damp here: a `+ridge·I`
6523                            // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6524                            // bias into the VALUE that the analytic ρ-gradient (built
6525                            // for the undamped Laplace log-det) never sees, desyncing
6526                            // the outer line-search — the multi-atom non-convergence
6527                            // this fix (#1117) removes.
6528                            return Err(format!(
6529                                "SaeManifoldTerm::reml_criterion: stationary undamped \
6530                                 evidence factorization has a non-PD per-row H_tt block \
6531                                 that spectral unit-stiffness deflation could not \
6532                                 condition (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}); \
6533                                 {err}"
6534                            ));
6535                        }
6536                        let refine_limit = Self::refine_iteration_limit(
6537                            total_inner_iter,
6538                            base_refine_iter,
6539                            progress_refine_iter,
6540                            previous_refine_grad_norm,
6541                            grad_norm,
6542                            saw_refine_progress,
6543                        );
6544                        if total_inner_iter >= refine_limit {
6545                            // #1117/#1118 — pre-stationarity genuinely-indefinite
6546                            // non-gauge H_tt under K>1 IBP/softmax row-sharing. The
6547                            // logit × coordinate Gauss-Newton cross term H_zt = J_zᵀJ_t
6548                            // can drive a shared row's H_tt Schur complement NEGATIVE off
6549                            // the gauge orbit; the LM-escalated refinement above cannot
6550                            // always cross the indefinite basin into the PD region within
6551                            // the descent-extended budget.
6552                            //
6553                            // The undamped (ridge=0) evidence factor already conditions
6554                            // that block the PRINCIPLED way: `factor_spectral_deflated_
6555                            // evidence_row` discovers the negative/flat eigen-direction
6556                            // and stiffens it to UNIT curvature (eigenvalue → +1), a
6557                            // ρ-INDEPENDENT log 1 = 0 evidence contribution — so the
6558                            // `Ok(factored)` arm above accepts the indefinite block and
6559                            // returns a finite, monotone-comparable value to the outer
6560                            // BFGS WITHOUT a ρ-dependent bias. Reaching THIS arm means
6561                            // even that spectral deflation declined (a non-finite block
6562                            // or a failed eigendecomposition): the iterate is genuinely
6563                            // broken, so we surface the hard refusal and let the outer
6564                            // BFGS treat this ρ as an INFINITY probe.
6565                            //
6566                            // We must NOT ridge-damp here: a `+ridge·I` evidence
6567                            // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6568                            // bias into the VALUE that the analytic ρ-gradient (built
6569                            // for the undamped Laplace log-det) never sees, desyncing
6570                            // the outer line-search — the multi-atom non-convergence this
6571                            // fix removes. K=1 (and any already-PD or spectral-deflatable
6572                            // K>1 row) never reaches this branch.
6573                            return Err(format!(
6574                                "SaeManifoldTerm::reml_criterion: undamped evidence \
6575                                 factorization hit a non-PD per-row H_tt block before KKT \
6576                                 stationarity (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) \
6577                                 and the refinement budget was exhausted after \
6578                                 {total_inner_iter} inner iterations; {err}"
6579                            ));
6580                        }
6581                        let remaining = refine_limit - total_inner_iter;
6582                        let refine_iter = inner_max_iter.max(1).min(remaining);
6583                        saw_refine_progress |=
6584                            Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6585                        previous_refine_grad_norm = Some(grad_norm);
6586                        *loss = self.run_joint_fit_arrow_schur(
6587                            target,
6588                            rho_fixed,
6589                            registry,
6590                            refine_iter,
6591                            learning_rate,
6592                            ridge_ext_coord,
6593                            ridge_beta,
6594                        )?;
6595                        total_inner_iter += refine_iter;
6596                        continue;
6597                    }
6598                    Err(err) => {
6599                        return Err(format!("SaeManifoldTerm::reml_criterion: {err}"));
6600                    }
6601                };
6602            // The Laplace normaliser ½log|H| is only the correct REML criterion at
6603            // the inner optimum (t̂, β̂). Convergence is judged by EITHER a small
6604            // gradient (KKT stationarity) OR a small undamped Newton step; the
6605            // solve is only rejected as non-converged when BOTH are large, i.e.
6606            // the iterate is neither stationary nor about to move negligibly. That
6607            // disjunction is what keeps an ill-conditioned-but-stationary fit
6608            // (small g, large Δ) from being rejected while still refusing to rank
6609            // an off-optimum Laplace criterion that is genuinely mid-flight.
6610            let step_norm_sq: f64 = delta_t.iter().map(|&v| v * v).sum::<f64>()
6611                + delta_beta.iter().map(|&v| v * v).sum::<f64>();
6612            if !step_norm_sq.is_finite() {
6613                return Err(format!(
6614                    "SaeManifoldTerm::reml_criterion: undamped inner residual is non-finite at \
6615                     the inner optimum (‖Δ‖²={step_norm_sq}, ‖g‖²={grad_norm_sq}); the joint \
6616                     Hessian factorisation is degenerate at this ρ"
6617                ));
6618            }
6619            let step_norm = step_norm_sq.sqrt();
6620            let quotient_step_norm_sq = self.quotient_newton_step_norm_sq(
6621                delta_t.view(),
6622                delta_beta.view(),
6623                step_norm_sq,
6624                &rho_fixed.lambda_smooth_vec(),
6625            )?;
6626            let quotient_step_norm = quotient_step_norm_sq.sqrt();
6627            // Converge on ANY of: the raw KKT gradient (well-conditioned fit),
6628            // the QUOTIENT KKT gradient (#1117 — rank-deficient fit whose only
6629            // residual gradient is gauge/null flat-direction crawl), or the
6630            // quotient Newton step. The quotient-gradient disjunct is what lets
6631            // a rank-deficient K=1 circle terminate in budget instead of crawling
6632            // the weakly-identified valley until the refine budget dies.
6633            if grad_norm <= grad_tolerance
6634                || quotient_grad_norm <= grad_tolerance
6635                || quotient_step_norm <= step_tolerance
6636            {
6637                return Ok(cache);
6638            }
6639            let refine_limit = Self::refine_iteration_limit(
6640                total_inner_iter,
6641                base_refine_iter,
6642                progress_refine_iter,
6643                previous_refine_grad_norm,
6644                grad_norm,
6645                saw_refine_progress,
6646            );
6647            if total_inner_iter >= refine_limit {
6648                // Inner solve did not converge in reml_criterion; the returned
6649                // Err below carries the full non-convergence diagnostic
6650                // (gradient / quotient-step norms and tolerances) to the caller.
6651                return Err(format!(
6652                    "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6653                     neither the KKT gradient ‖g‖={grad_norm:.6e} (tol {grad_tolerance:.6e}) nor \
6654                     the quotient Newton step ‖Π⊥gauge Δ‖={quotient_step_norm:.6e} \
6655                     (raw ‖Δ‖={step_norm:.6e}, tol {step_tolerance:.6e}) met \
6656                     tolerance after {total_inner_iter} inner iterations. Refusing to rank an \
6657                     off-optimum Laplace criterion."
6658                ));
6659            }
6660            let remaining = refine_limit - total_inner_iter;
6661            let refine_iter = inner_max_iter.max(1).min(remaining);
6662            saw_refine_progress |=
6663                Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6664            previous_refine_grad_norm = Some(grad_norm);
6665            *loss = self.run_joint_fit_arrow_schur(
6666                target,
6667                rho_fixed,
6668                registry,
6669                refine_iter,
6670                learning_rate,
6671                ridge_ext_coord,
6672                ridge_beta,
6673            )?;
6674            total_inner_iter += refine_iter;
6675            refine_rounds += 1;
6676            // #1051 — objective-stagnation fixed point. A whole refine round that
6677            // failed to lower the penalised objective by a meaningful FRACTION of
6678            // the total since-entry reduction means the Newton/LM iterate is at
6679            // its numerical optimum: the remaining KKT residual lives in the
6680            // weakly-identified decoder / gauge directions the near-singular Schur
6681            // cannot resolve. Ranking the Laplace criterion at this fixed point is
6682            // correct (the only further motion is cosmetic flat-valley crawl), so
6683            // accept the current cache instead of refining until the budget dies.
6684            // Requires a few completed refine rounds (so the fraction baseline is
6685            // meaningful) but is NOT gated behind the full refine budget — the
6686            // whole point is to terminate the crawl long before that.
6687            let new_loss_total = loss.total();
6688            // Two stagnation signals, both required: (1) the latest refine round
6689            // contributed a negligible FRACTION of the total objective reduction
6690            // achieved since entry — the fit has captured essentially all the
6691            // achievable improvement and is now crawling cosmetically along the
6692            // weakly-identified valley; (2) the absolute relative decrease is
6693            // itself tiny. The fraction test is scale- and rate-free (it fires
6694            // whether the crawl decays fast or slow), so it recognises the
6695            // over-smoothed / rank-deficient fixed point the bare relative floor
6696            // misses, while still never firing on a fit that is materially
6697            // improving round over round.
6698            let total_improvement = (entry_loss_total - new_loss_total).max(0.0);
6699            let round_improvement = (previous_loss_total - new_loss_total).max(0.0);
6700            let objective_scale = previous_loss_total.abs().max(new_loss_total.abs()) + 1.0;
6701            let relative_decrease = round_improvement / objective_scale;
6702            let captured_fraction = if total_improvement > 0.0 {
6703                round_improvement / total_improvement
6704            } else {
6705                0.0
6706            };
6707            let stalled = new_loss_total.is_finite()
6708                && relative_decrease.is_finite()
6709                && (relative_decrease < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_REL_TOL
6710                    || captured_fraction < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_FRACTION);
6711            previous_loss_total = new_loss_total;
6712            if stalled && refine_rounds >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6713                let stationary_sys = self
6714                    .assemble_arrow_schur(target, rho_fixed, registry)
6715                    .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6716                if let Ok((_dt, _db, stationary_cache)) =
6717                    solve_arrow_newton_step_with_options(&stationary_sys, 0.0, 0.0, options)
6718                {
6719                    return Ok(stationary_cache);
6720                }
6721                // Stagnated AND the undamped factor still fails: this is the
6722                // numerical fixed point of the inner solve under rank-deficient
6723                // or ill-conditioned geometry (e.g. multi-atom euclidean with
6724                // near-zero initial latent coords, #1094).  The iterate cannot
6725                // be improved further at this ρ.  Treat it as "inner solve did
6726                // not converge" — the same signal `is_recoverable_value_probe_refusal`
6727                // already handles, causing the outer BFGS to return INFINITY for
6728                // this ρ probe and try a different one.  Without this early
6729                // return the stagnation handler fell through and the loop kept
6730                // burning the extended `progress_refine_iter` budget indefinitely.
6731                consecutive_stall_factor_fail += 1;
6732                if consecutive_stall_factor_fail >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6733                    return Err(format!(
6734                        "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6735                         objective stalled for {consecutive_stall_factor_fail} consecutive refine \
6736                         rounds (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) and the undamped \
6737                         evidence factorization failed at each stall point — the iterate is at the \
6738                         numerical fixed point under rank-deficient geometry (#{consecutive_stall_factor_fail} \
6739                         stall-factor-fail rounds; refusing to rank an off-optimum Laplace criterion)"
6740                    ));
6741                }
6742            } else {
6743                consecutive_stall_factor_fail = 0;
6744            }
6745        }
6746    }
6747
6748    pub(crate) fn refine_iteration_limit(
6749        total_inner_iter: usize,
6750        base_refine_iter: usize,
6751        progress_refine_iter: usize,
6752        previous_grad_norm: Option<f64>,
6753        grad_norm: f64,
6754        saw_refine_progress: bool,
6755    ) -> usize {
6756        // Flat affine-gauge valleys can keep crawling productively after the
6757        // historical base budget. Extend only when the measured KKT residual has
6758        // shown a real finite round-to-round drop; true stalls end at the base
6759        // work budget (#968/#1029). Value-order probes pass the base budget as
6760        // their progress budget, so this branch cannot make probes expensive.
6761        if total_inner_iter < base_refine_iter {
6762            return base_refine_iter;
6763        }
6764        let making_progress =
6765            saw_refine_progress || Self::refine_round_made_progress(previous_grad_norm, grad_norm);
6766        if making_progress && grad_norm.is_finite() {
6767            progress_refine_iter
6768        } else {
6769            base_refine_iter
6770        }
6771    }
6772
6773    pub(crate) fn refine_round_made_progress(
6774        previous_grad_norm: Option<f64>,
6775        grad_norm: f64,
6776    ) -> bool {
6777        previous_grad_norm
6778            .is_some_and(|prev| prev.is_finite() && grad_norm.is_finite() && grad_norm < prev)
6779    }
6780
6781    pub(crate) fn outer_gradient_arrow_solver<'a>(
6782        &'a self,
6783        cache: &'a ArrowFactorCache,
6784        penalized_gram_scale: &[f64],
6785    ) -> Result<DeflatedArrowSolver<'a>, OuterGradientError> {
6786        let Err(conditioning_err) = Self::outer_gradient_conditioning_error(cache) else {
6787            return Ok(DeflatedArrowSolver::plain(cache));
6788        };
6789        let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
6790            return Err(conditioning_err);
6791        };
6792        if !(max_pivot.is_finite() && max_pivot > 0.0) {
6793            return Err(conditioning_err);
6794        }
6795
6796        // The conditioning gate has already flagged a near-singular joint Hessian
6797        // (`conditioning_err`). Below we attempt to attribute that flatness to the
6798        // closed-form gauge orbit (chart step gauges) plus the penalty-aware
6799        // decoder-null directions and deflate it. When NO such deflatable
6800        // direction can be recovered, the flat subspace is genuinely
6801        // non-identifiable -- a degenerate direction OUTSIDE the gauge orbit -- a
6802        // diagnosis distinct from the raw pivot-ratio conditioning trip. Both
6803        // classes are #1273 FD-eligible, but surfacing the gauge-degenerate case
6804        // as its own [`OuterGradientError::NonIdentifiable`] keeps the diagnostic
6805        // distinction the FD-eligibility contract is built around.
6806        let non_identifiable_err = OuterGradientError::NonIdentifiable {
6807            reason: format!(
6808                "near-singular joint Hessian with no deflatable gauge/decoder-null \
6809                 direction (max pivot {max_pivot:.3e})"
6810            ),
6811        };
6812
6813        let full_len = cache.delta_t_len() + cache.k;
6814        let mut raw_gauges = Vec::new();
6815        for gauge in self
6816            .dense_step_gauge_vectors()
6817            .map_err(OuterGradientError::internal)?
6818        {
6819            if gauge.len() != full_len {
6820                continue;
6821            }
6822            let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6823            if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6824                continue;
6825            }
6826            raw_gauges.push(gauge);
6827        }
6828        // #1051/#1273: admit the penalty-aware decoder-β null directions as
6829        // additional deflation candidates. A rank-deficient decoder design
6830        // (e.g. a euclidean-1D line in a p=2 ambient: decoder column rank 1 of
6831        // 3) puts a genuine near-null direction of the joint Hessian in the β
6832        // block, OUTSIDE the closed-form chart gauge orbit. #1273: probing the
6833        // RAW unit-β basis `e_j` produced an INCOMPLETE candidate set — the
6834        // true flat direction is the penalised null of `G_k + λ_smooth·S_k`,
6835        // not an axis-aligned coordinate, so the outer gate rejected trial ρ
6836        // with a pivot ratio (5.3e-16 < 1e-12) that the inner gate (which
6837        // already uses `decoder_beta_null_directions(λ_smooth)`) accepts. Use
6838        // the SAME penalty-aware null directions here, evaluated at the smooth
6839        // scale the Schur factor used, so the outer and inner gates agree.
6840        // These full (n·q + beta_dim)-length vectors drop into the same
6841        // Gram-Schmidt + Rayleigh + Faddeev-Popov path below; the Rayleigh
6842        // floor still keeps only genuinely flat (sub-floor) directions, so a
6843        // well-conditioned decoder is unaffected.
6844        for dir in self
6845            .decoder_beta_null_directions(penalized_gram_scale)
6846            .map_err(OuterGradientError::internal)?
6847        {
6848            if dir.len() == full_len {
6849                raw_gauges.push(dir);
6850            }
6851        }
6852        // #1051/#1273: also admit the decoder COLUMN-SPAN null (an unrealised
6853        // ambient output channel of a rank-deficient decoder), which the
6854        // channel-free basis-null above structurally cannot represent. The
6855        // rank-1-decoder-line geometry (e.g. a 1-D euclidean line in p=2
6856        // ambient: decoder column rank 1 of 2) puts the joint Hessian's
6857        // sub-floor pivot entirely in one output channel; without this
6858        // candidate the outer gate had nothing to deflate it with and rejected
6859        // the trial ρ. The Rayleigh floor below still prunes any candidate that
6860        // is not genuinely flat against the cached Hessian.
6861        for dir in self
6862            .decoder_channel_null_directions()
6863            .map_err(OuterGradientError::internal)?
6864        {
6865            if dir.len() == full_len {
6866                raw_gauges.push(dir);
6867            }
6868        }
6869        if raw_gauges.is_empty() {
6870            return Err(non_identifiable_err);
6871        }
6872
6873        let mut gauge_span: Vec<Array1<f64>> = Vec::new();
6874        for mut gauge in raw_gauges {
6875            for basis in &gauge_span {
6876                let coeff = gauge.dot(basis);
6877                for i in 0..gauge.len() {
6878                    gauge[i] -= coeff * basis[i];
6879                }
6880            }
6881            let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6882            if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6883                continue;
6884            }
6885            let inv_norm = norm_sq.sqrt().recip();
6886            for value in gauge.iter_mut() {
6887                *value *= inv_norm;
6888            }
6889            gauge_span.push(gauge);
6890        }
6891        if gauge_span.is_empty() {
6892            return Err(non_identifiable_err);
6893        }
6894
6895        let span_rank = gauge_span.len();
6896        let mut h_span = Array2::<f64>::zeros((span_rank, span_rank));
6897        for col in 0..span_rank {
6898            let h_gauge = match apply_cached_arrow_hessian(
6899                cache,
6900                gauge_span[col].slice(s![..cache.delta_t_len()]),
6901                gauge_span[col].slice(s![cache.delta_t_len()..]),
6902            ) {
6903                Ok(value) => value,
6904                // #1451: a shape/dimension mismatch or non-finite intermediate
6905                // from the Hessian apply is an internal-invariant defect and MUST
6906                // propagate; only a genuine numeric failure on a finite,
6907                // correctly-shaped input keeps the FD-eligible conditioning class.
6908                Err(err) => {
6909                    return Err(OuterGradientError::classify_arrow_solver_error(
6910                        &err,
6911                        conditioning_err.clone(),
6912                    ));
6913                }
6914            };
6915            let h_flat = flatten_arrow_parts(h_gauge.t.view(), h_gauge.beta.view());
6916            for row in 0..span_rank {
6917                h_span[[row, col]] = gauge_span[row].dot(&h_flat);
6918            }
6919        }
6920        for row in 0..span_rank {
6921            for col in 0..row {
6922                let sym = 0.5 * (h_span[[row, col]] + h_span[[col, row]]);
6923                h_span[[row, col]] = sym;
6924                h_span[[col, row]] = sym;
6925            }
6926        }
6927        // #1451: a non-finite entry in the projected gauge Hessian is an
6928        // internal-invariant defect (a NaN/Inf intermediate leaked into the
6929        // span), not a conditioning failure — it MUST propagate rather than be
6930        // masked behind an FD descent. Guard finiteness BEFORE the eigh so only a
6931        // genuine decomposition failure on a finite, correctly-shaped matrix keeps
6932        // the FD-eligible conditioning class.
6933        if !h_span.iter().all(|v| v.is_finite()) {
6934            return Err(OuterGradientError::internal(format!(
6935                "outer_gradient_arrow_solver: non-finite entry in projected gauge \
6936                 Hessian (h_span is {span_rank}x{span_rank})"
6937            )));
6938        }
6939        let (evals, evecs) = h_span
6940            .eigh(Side::Lower)
6941            .map_err(|_| conditioning_err.clone())?;
6942        let strict_gauge_floor = SAE_OUTER_GRADIENT_GAUGE_RAYLEIGH_FACTOR * max_pivot;
6943        let mut orthonormal: Vec<Array1<f64>> = Vec::new();
6944        for eig_idx in 0..evals.len() {
6945            let rayleigh = evals[eig_idx];
6946            if !(rayleigh.is_finite() && rayleigh <= strict_gauge_floor) {
6947                continue;
6948            }
6949            let mut direction = Array1::<f64>::zeros(full_len);
6950            for basis_idx in 0..span_rank {
6951                let coeff = evecs[[basis_idx, eig_idx]];
6952                for row in 0..full_len {
6953                    direction[row] += coeff * gauge_span[basis_idx][row];
6954                }
6955            }
6956            let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
6957            if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6958                continue;
6959            }
6960            let inv_norm = norm_sq.sqrt().recip();
6961            for value in direction.iter_mut() {
6962                *value *= inv_norm;
6963            }
6964            orthonormal.push(direction);
6965        }
6966        if orthonormal.is_empty() {
6967            // #1273/#1440: the conditioning gate has ALREADY certified a
6968            // near-singular joint Hessian (`conditioning_err`), so a genuine flat
6969            // direction exists inside the assembled gauge/decoder-null span even
6970            // when no projected-Hessian eigenvector cleared the strict or the
6971            // `fallback_gauge_floor` Rayleigh band. Rather than declining
6972            // (which historically routed the outer step to a finite-difference
6973            // descent direction — the FD instrument #1440 removes), deflate the
6974            // SMALLEST-Rayleigh eigenvector of the projected gauge Hessian
6975            // UNCONDITIONALLY. That eigenvector is the least-curvature member of
6976            // the validated gauge span (a Faddeev-Popov gauge candidate), so the
6977            // Tikhonov stiffness `max_pivot` in `from_orthonormal_gauges` bounds
6978            // its contribution at the Hessian scale and the components orthogonal
6979            // to it are byte-for-byte the plain analytic inverse solve. This keeps
6980            // the descent direction fully ANALYTIC (a projected/damped gradient),
6981            // never a differenced value path.
6982            let mut best_idx = None;
6983            let mut best_rayleigh = f64::INFINITY;
6984            for eig_idx in 0..evals.len() {
6985                let rayleigh = evals[eig_idx];
6986                if rayleigh.is_finite() && rayleigh < best_rayleigh {
6987                    best_idx = Some(eig_idx);
6988                    best_rayleigh = rayleigh;
6989                }
6990            }
6991            if let Some(eig_idx) = best_idx {
6992                let mut direction = Array1::<f64>::zeros(full_len);
6993                for basis_idx in 0..span_rank {
6994                    let coeff = evecs[[basis_idx, eig_idx]];
6995                    for row in 0..full_len {
6996                        direction[row] += coeff * gauge_span[basis_idx][row];
6997                    }
6998                }
6999                let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
7000                if norm_sq.is_finite() && norm_sq > 1.0e-24 {
7001                    let inv_norm = norm_sq.sqrt().recip();
7002                    for value in direction.iter_mut() {
7003                        *value *= inv_norm;
7004                    }
7005                    orthonormal.push(direction);
7006                }
7007            }
7008        }
7009        if orthonormal.is_empty() {
7010            return Err(non_identifiable_err);
7011        }
7012
7013        // Quotient-geometry gauge fixing: add stiffness only along the closed-form
7014        // gauge orbit (Faddeev-Popov style). Components orthogonal to that orbit
7015        // are identical to the original inverse solve, while gauge components are
7016        // bounded at the Hessian scale `max_pivot`.
7017        // #1451: a shape/length mismatch or non-finite stiffness/intermediate in
7018        // the deflated-solver assembly is an internal-invariant defect and MUST
7019        // propagate; only a genuine near-singular gauge Woodbury/back-solve keeps
7020        // the FD-eligible conditioning class.
7021        DeflatedArrowSolver::from_orthonormal_gauges(cache, orthonormal, max_pivot)
7022            .map_err(|err| OuterGradientError::classify_arrow_solver_error(&err, conditioning_err))
7023    }
7024
7025    pub(crate) fn outer_gradient_conditioning_error(
7026        cache: &ArrowFactorCache,
7027    ) -> Result<(), OuterGradientError> {
7028        let pivot = arrow_factor_min_pivot(cache);
7029        let Some(min_pivot) = pivot.min_pivot else {
7030            return Err(OuterGradientError::IllConditioned {
7031                reason: "joint Hessian numerically singular (no cached Cholesky pivots)"
7032                    .to_string(),
7033            });
7034        };
7035        let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
7036            return Err(OuterGradientError::IllConditioned {
7037                reason: "joint Hessian numerically singular (no cached Cholesky pivot scale)"
7038                    .to_string(),
7039            });
7040        };
7041        let ratio = min_pivot / max_pivot;
7042        if min_pivot.is_finite()
7043            && max_pivot.is_finite()
7044            && max_pivot > 0.0
7045            && ratio.is_finite()
7046            && ratio >= SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR
7047        {
7048            return Ok(());
7049        }
7050        Err(OuterGradientError::IllConditioned {
7051            reason: format!(
7052                "joint Hessian numerically singular (min/max pivot ratio {ratio:.3e} < floor {floor:.3e}; min pivot {min_pivot:.3e}, max pivot {max_pivot:.3e})",
7053                floor = SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR,
7054            ),
7055        })
7056    }
7057
7058    /// Smoothing-penalty Occam normalizer `−½ Σ_k r_k·rank(S_k)·log λ_smooth`
7059    /// PLUS the profiled-frame evidence-dimension term `½ Σ_k r_k·(p−r_k)·log
7060    /// λ_smooth` (issue #972).
7061    ///
7062    /// On the full-`B` path every atom's frame rank `r_k == p`, so the first
7063    /// piece reduces to the historical `½ p·(Σ rank S_k)·log λ_smooth` and the
7064    /// Grassmann term is zero — bit-for-bit unchanged. When a frame is active the
7065    /// decoder coordinates `C_k` carry the `⊗ I_{r_k}` Kronecker structure (the
7066    /// smoothing penalty `S_k` now acts on `r_k` channels, not `p`), so the
7067    /// penalty-logdet normalizer uses `r_k·rank(S_k)`; and the `r_k·(p−r_k)`
7068    /// frame degrees of freedom profiled OUT of the border are counted explicitly
7069    /// in the Laplace dimension accounting (evidence honesty) so the criterion
7070    /// cannot buy a free evidence boost by hiding decoder freedom in the frame.
7071    pub(crate) fn reml_occam_term(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
7072        // #1556: λ_smooth is per-atom, so the Occam penalty normalizer and the
7073        // profiled-frame evidence-dimension term are both per-atom sums, each
7074        // atom `k` weighted by its own `log λ_smooth[k]`. With a uniform
7075        // (broadcast) vector this is bit-for-bit the historical global form.
7076        let mut acc = 0.0_f64;
7077        for (atom_idx, atom) in self.atoms.iter().enumerate() {
7078            let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7079            // Penalized decoder dimension: `r_k` coordinate channels carry the
7080            // `S_k` roughness penalty (full-`B` path ⇒ `r_k == p`).
7081            let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7082            // Profiled Grassmann dimensions enter the Laplace evidence dimension
7083            // count with the OPPOSITE sign of the penalty Occam term (they are
7084            // free, unpenalized-by-`S` profiled directions), so `−occam` adds
7085            // `+½ r(p−r) log λ_k` to the criterion `V` — the honesty correction.
7086            let frame_dim = atom.frame_manifold_dimension();
7087            let log_lambda = rho.log_lambda_smooth[atom_idx];
7088            acc += 0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)) * log_lambda;
7089        }
7090        // `V = … − occam`, so the net occam SUBTRACTS the penalty normalizer and
7091        // ADDS the frame-dimension count after the caller's `− occam`.
7092        Ok(acc)
7093    }
7094
7095    /// Per-atom derivative `∂(occam)/∂log λ_smooth[k]` (#1556): atom `k`'s entry
7096    /// is `½·(r_k·rank(S_k) − frame_dim_k)`, matching the per-atom Occam term in
7097    /// [`Self::reml_occam_term`]. Returns one entry per atom in atom order.
7098    pub(crate) fn reml_occam_log_lambda_smooth_derivative(&self) -> Result<Vec<f64>, String> {
7099        let mut out = Vec::with_capacity(self.atoms.len());
7100        for atom in &self.atoms {
7101            let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7102            let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7103            let frame_dim = atom.frame_manifold_dimension();
7104            out.push(0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)));
7105        }
7106        Ok(out)
7107    }
7108
7109    pub fn reml_criterion_streaming_exact(
7110        &mut self,
7111        target: ArrayView2<'_, f64>,
7112        rho: &SaeManifoldRho,
7113        registry: Option<&AnalyticPenaltyRegistry>,
7114        inner_max_iter: usize,
7115        learning_rate: f64,
7116        ridge_ext_coord: f64,
7117        ridge_beta: f64,
7118    ) -> Result<(f64, SaeManifoldLoss), String> {
7119        let mut rho_fixed = rho.clone();
7120        let mut loss = self.run_joint_fit_arrow_schur(
7121            target,
7122            &mut rho_fixed,
7123            registry,
7124            inner_max_iter,
7125            learning_rate,
7126            ridge_ext_coord,
7127            ridge_beta,
7128        )?;
7129        // Drive the inner (t, β) state to the SAME KKT/step-converged optimum the
7130        // dense `reml_criterion_with_cache` reaches before factoring. At that
7131        // optimum the per-row `H_tt^(i)` blocks are PD, so the undamped
7132        // (`ridge_t = 0`) streaming factorization in `streaming_exact_arrow_log_det`
7133        // succeeds — without this, a state stopped after only `inner_max_iter`
7134        // steps can leave a rank-deficient / indefinite row block (`p_out = 1` →
7135        // rank-1 `JᵀJ`, softmax negative-logit curvature) that surfaces
7136        // `PerRowFactorFailed` at base ridge 0. Sharing the driver also keeps the
7137        // streaming and dense log-determinants bit-identical (#847).
7138        let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7139        // The dense factor cache from convergence is surplus here — the streaming
7140        // path recomputes the (bit-identical) log-determinant chunk-by-chunk in
7141        // `streaming_exact_arrow_log_det` to bound peak memory — so it is dropped.
7142        let converged_cache = self.converge_inner_for_undamped_logdet(
7143            target,
7144            rho,
7145            &mut rho_fixed,
7146            registry,
7147            inner_max_iter,
7148            learning_rate,
7149            ridge_ext_coord,
7150            ridge_beta,
7151            &mut loss,
7152            &options,
7153            true,
7154        )?;
7155        drop(converged_cache);
7156        let log_det = self.streaming_exact_arrow_log_det(target, rho, registry)?;
7157        let occam = self.reml_occam_term(rho)?;
7158        // Extra analytic-penalty energy (#671/#737), matching the full-batch
7159        // `reml_criterion_with_cache` path so streaming and dense criteria rank
7160        // the identical penalized objective.
7161        let extra_penalty_energy = match registry {
7162            Some(reg) => self
7163                .reml_extra_penalty_value_total(reg)
7164                .map_err(|err| format!("SaeManifoldTerm::reml_criterion_streaming_exact: {err}"))?,
7165            None => 0.0,
7166        };
7167        Ok((
7168            loss.total() + extra_penalty_energy + 0.5 * log_det - occam,
7169            loss,
7170        ))
7171    }
7172
7173    pub fn streaming_exact_arrow_log_det(
7174        &mut self,
7175        target: ArrayView2<'_, f64>,
7176        rho: &SaeManifoldRho,
7177        registry: Option<&AnalyticPenaltyRegistry>,
7178    ) -> Result<f64, String> {
7179        if target.dim() != (self.n_obs(), self.output_dim()) {
7180            return Err(format!(
7181                "SaeManifoldTerm::streaming_exact_arrow_log_det: target must be ({}, {}); got {:?}",
7182                self.n_obs(),
7183                self.output_dim(),
7184                target.dim()
7185            ));
7186        }
7187        let plan = self.streaming_plan().admitted_or_error(
7188            self.n_obs(),
7189            self.output_dim(),
7190            self.k_atoms(),
7191        )?;
7192        if plan.estimated_dense_schur_bytes > plan.in_core_budget_bytes {
7193            // #988 memory-matrix-free evidence route. The dense k×k reduced Schur
7194            // (≈8 GB at the K=32k manifold border) does NOT fit the in-core
7195            // budget, so estimate log|S| via Stochastic Lanczos Quadrature on the
7196            // matrix-free `schur_matvec` apply (`gam_solve::arrow_schur::
7197            // matrix_free_arrow_evidence_log_det`) instead of assembling +
7198            // Cholesky-factoring the dense Schur. Peak memory is the per-row block
7199            // storage the inner PCG already holds, not the extra O(k²) dense S.
7200            //
7201            // Valid for the NON-IBP (softmax / JumpReLU) evidence, whose exact
7202            // log-det is `log_det_tt + log_det_schur` with NO cross-row Woodbury
7203            // correction. The IBP cross-row term additionally needs
7204            // `log det(I_R + D Uᵀ H₀'⁻¹ U)`, which has no matrix-free route yet, so
7205            // it keeps refusing (loudly, pointing at the dense resident path).
7206            if ibp_assignment_third_channels(&self.assignment, rho)?.is_some() {
7207                return Err(format!(
7208                    "SaeManifoldTerm::streaming_exact_arrow_log_det: predicted dense reduced Schur \
7209                     {} bytes exceeds budget {} bytes and the exact cross-row IBP Woodbury evidence \
7210                     has no matrix-free log-det route yet; route IBP-active large-K fits through the \
7211                     dense resident ArrowFactorCache::arrow_log_det",
7212                    plan.estimated_dense_schur_bytes, plan.in_core_budget_bytes
7213                ));
7214            }
7215            let n_total = self.n_obs();
7216            let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7217            // Assemble the WHOLE system once (a single "chunk" over all rows) so the
7218            // matrix-free reduced-Schur apply `v ↦ S·v` can iterate every row; the
7219            // per-row block storage is exactly what the inner solve already holds.
7220            let full_logits = self.assignment.logits.slice(s![0..n_total, ..]).to_owned();
7221            let full_coords: Vec<Array2<f64>> = self
7222                .assignment
7223                .coords
7224                .iter()
7225                .map(|coord| coord.as_matrix().slice(s![0..n_total, ..]).to_owned())
7226                .collect();
7227            let mut full_chunk = self.materialize_chunk(full_logits, full_coords)?;
7228            if let Some(w) = self.row_loss_weights.as_deref() {
7229                full_chunk.row_loss_weights = Some(w[0..n_total].to_vec());
7230            }
7231            // Full penalty (`penalty_scale = 1.0`): one chunk carries the whole
7232            // objective, matching the summed per-chunk `(end-start)/n_total` scale.
7233            let sys = full_chunk
7234                .assemble_arrow_schur_scaled(target, rho, registry, 1.0)
7235                .map_err(|err| {
7236                    format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}")
7237                })?;
7238            let (log_det_tt, slq) = matrix_free_arrow_evidence_log_det(
7239                &sys,
7240                0.0,
7241                0.0,
7242                &options,
7243                SCHUR_SLQ_LOGDET_PROBES,
7244                SCHUR_SLQ_LOGDET_LANCZOS_STEPS,
7245                SCHUR_SLQ_LOGDET_SEED,
7246            )
7247            .map_err(|err| {
7248                format!(
7249                    "SaeManifoldTerm::streaming_exact_arrow_log_det: matrix-free evidence log-det: {err:?}"
7250                )
7251            })?;
7252            if !slq.estimate.is_finite() {
7253                return Err(format!(
7254                    "SaeManifoldTerm::streaming_exact_arrow_log_det: matrix-free SLQ reduced-Schur \
7255                     log|S| non-finite ({})",
7256                    slq.estimate
7257                ));
7258            }
7259            return Ok(log_det_tt + slq.estimate);
7260        }
7261        let n_total = self.n_obs();
7262        let chunk_size = plan.chunk_size.min(n_total.max(1));
7263        // #972 / #977 T1: the reduced β-Schur is over the FACTORED border when
7264        // frames are active (each chunk inherits the frames via
7265        // `materialize_chunk`, so every `chunk_schur` is `border_dim²`), matching
7266        // the dense path's factored log-det. Full-`B` ⇒ `border_dim == beta_dim`.
7267        let border_dim = if self.frames_active() {
7268            self.factored_border_dim()
7269        } else {
7270            self.beta_dim()
7271        };
7272        let mut schur_acc = Array2::<f64>::zeros((border_dim, border_dim));
7273        let mut log_det_tt = 0.0_f64;
7274        // #1038 cross-row IBP Woodbury accumulators. `M = Uᵀ H₀'⁻¹ U` is
7275        // chunk-additive in `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` and `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ`
7276        // (`A = H₀'` block-diagonal, `U` row-supported), closed against the
7277        // GLOBAL reduced Schur `S = schur_acc` after the loop. `None` for every
7278        // non-IBP (softmax / JumpReLU) term, where the streaming log-det is
7279        // exactly the bare `log_det_tt + log_det_schur` as before.
7280        let mut wood_m0: Option<Array2<f64>> = None;
7281        let mut wood_w: Option<Array2<f64>> = None;
7282        let mut wood_d: Option<Array1<f64>> = None;
7283        let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7284        let mut start = 0usize;
7285        while start < n_total {
7286            let end = (start + chunk_size).min(n_total);
7287            let penalty_scale = (end - start) as f64 / n_total as f64;
7288            let chunk_logits = self.assignment.logits.slice(s![start..end, ..]).to_owned();
7289            let chunk_coords: Vec<Array2<f64>> = self
7290                .assignment
7291                .coords
7292                .iter()
7293                .map(|coord| coord.as_matrix().slice(s![start..end, ..]).to_owned())
7294                .collect();
7295            let mut chunk = self.materialize_chunk(chunk_logits, chunk_coords)?;
7296            // #1117 — rank deficiency is removed at the basis layer at fit entry
7297            // (`reduce_atoms_to_data_supported_rank`), so each chunk inherits the
7298            // already-reduced full-rank atoms via `materialize_chunk`; there are
7299            // no global deflation projectors to propagate.
7300            // #991: chunk terms inherit the row's design honesty weight slice
7301            // (global mean-1 normalization preserved — NOT re-normalized per
7302            // chunk — so the per-chunk sums reconstruct the global weighted
7303            // objective exactly).
7304            if let Some(w) = self.row_loss_weights.as_deref() {
7305                chunk.row_loss_weights = Some(w[start..end].to_vec());
7306            }
7307            let z_chunk = target.slice(s![start..end, ..]);
7308            let sys = chunk
7309                .assemble_arrow_schur_scaled(z_chunk, rho, registry, penalty_scale)
7310                .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7311            let mut streaming = StreamingArrowSchur::from_system(&sys, sys.rows.len().max(1));
7312            let (chunk_log_det_tt, chunk_schur, chunk_wood) = streaming
7313                .reduced_schur_log_det_tt_woodbury(0.0, 0.0, &options)
7314                .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7315            log_det_tt += chunk_log_det_tt;
7316            for row in 0..border_dim {
7317                for col in 0..border_dim {
7318                    schur_acc[[row, col]] += chunk_schur[[row, col]];
7319                }
7320            }
7321            if chunk_wood.is_some() && chunk_size < n_total {
7322                // The cross-row IBP empirical mass `M_k = Σ_i z_ik` couples ALL
7323                // rows, so the per-row `H₀'` diagonal (`score_derivative_k(M_k)`)
7324                // and the column coefficient `d_k = w·s'_k(M_k)` are only exact
7325                // when every row is assembled together — a SINGLE chunk. Under a
7326                // genuine multi-chunk pass each chunk would see a partial mass and
7327                // the Woodbury (and the bare per-row log-det) would be inexact, so
7328                // refuse loudly and route to the dense resident path rather than
7329                // return a silently-wrong evidence. The streaming log-det only
7330                // runs when the dense reduced Schur fits budget, so the single-
7331                // chunk regime is the common case; this guards the rest.
7332                return Err(
7333                    "SaeManifoldTerm::streaming_exact_arrow_log_det: exact cross-row IBP \
7334                     Woodbury evidence requires a single-chunk pass (the empirical mass \
7335                     M_k = Σ_i z_ik couples all rows); this shape needs >1 chunk. Route \
7336                     IBP-active large-n fits through the dense resident \
7337                     ArrowFactorCache::arrow_log_det."
7338                        .to_string(),
7339                );
7340            }
7341            if let Some(cw) = chunk_wood {
7342                wood_m0 = Some(match wood_m0.take() {
7343                    Some(mut acc) => {
7344                        acc += &cw.m0;
7345                        acc
7346                    }
7347                    None => cw.m0,
7348                });
7349                wood_w = Some(match wood_w.take() {
7350                    Some(mut acc) => {
7351                        acc += &cw.w;
7352                        acc
7353                    }
7354                    None => cw.w,
7355                });
7356                // `D = diag(d_k)` is per-atom; identical across chunks for a
7357                // single-chunk evidence pass (the regime the streaming log-det
7358                // runs in — the dense reduced Schur must fit budget here), where
7359                // it equals the global mass-derived `cross_row_d`.
7360                wood_d = Some(cw.d);
7361            }
7362            start = end;
7363        }
7364        let log_det_schur = StreamingArrowSchur::reduced_schur_log_det(&schur_acc, &options)
7365            .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7366        let mut total = log_det_tt + log_det_schur;
7367        // #1038/#1225: close the exact cross-row IBP Woodbury correction
7368        // `log det(I_R + D Uᵀ H₀'⁻¹ U)` so the streaming evidence equals the
7369        // dense `arrow_log_det_from_cache` (which adds the SAME term). Without
7370        // it the streaming criterion would silently drop the entire cross-row
7371        // coupling and disagree with the dense path by exactly `log|C|`.
7372        if let (Some(m0), Some(w), Some(d)) = (wood_m0, wood_w, wood_d) {
7373            let correction = streaming_cross_row_woodbury_log_det(&schur_acc, &m0, &w, &d)
7374                .map_err(|err| {
7375                    format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}")
7376                })?
7377                .ok_or_else(|| {
7378                    "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
7379                     this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
7380                        .to_string()
7381                })?;
7382            total += correction;
7383        }
7384        Ok(total)
7385    }
7386
7387    /// Per-atom decoder-smoothness penalty quadratic form (#1556): entry `k` is
7388    /// the λ-free `<B_k, ½(S_k+S_kᵀ)·B_k> = Σ_oc B_k[:,oc]ᵀ S_k B_k[:,oc]`, the
7389    /// per-atom denominator of atom `k`'s λ_smooth Fellner-Schall update. The sum
7390    /// over atoms is `βᵀ(⊕_k S_k ⊗ I_p)β`, the un-scaled total penalty energy.
7391    /// `S_k` is symmetrised defensively (as the assembler does); the per-atom
7392    /// `½(S+Sᵀ)·B_k` GEMMs ride the multi-GPU batched smoothness GEMM with an
7393    /// exact per-atom CPU fallback.
7394    pub(crate) fn decoder_smoothness_quadratic_form_per_atom(&self) -> Vec<f64> {
7395        let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
7396            .atoms
7397            .iter()
7398            .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
7399            .collect();
7400        let sb_all = batched_smooth_sb(&sb_inputs, true);
7401        let mut per_atom = vec![0.0_f64; self.atoms.len()];
7402        for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
7403            per_atom[atom_idx] = (&atom.decoder_coefficients * sb).sum();
7404        }
7405        per_atom
7406    }
7407
7408    /// Per-atom effective penalized dof of the decoder smoothness penalty
7409    /// (#1556): entry `k` is `tr(S_β⁻¹ · M_k)` with `M_k = (λ_smooth[k]·S_k) ⊗ I`
7410    /// and `S_β⁻¹ = (H⁻¹)_ββ` the Schur-complement inverse, each atom scaled by
7411    /// its OWN `lambda_smooth[atom_idx]`. Built on
7412    /// [`ArrowFactorCache::schur_inverse_apply`]: column `(k,μ,oc)` of `M_k` is
7413    /// `λ_k·S_k[:,μ] ⊗ e_oc` (sparse), so we apply `S_β⁻¹` to that K-vector and
7414    /// read back `result[col]`. The total edf is the sum of the returned vector
7415    /// (a uniform/broadcast λ reproduces the historical global trace).
7416    ///
7417    /// At `K ≥ SMOOTHNESS_DOF_HUTCHINSON_MIN_ATOMS` this delegates to the
7418    /// matrix-free Hutchinson estimator (the exact `K·M·p`-solve trace is
7419    /// infeasible at that scale); below it the exact column solve is used
7420    /// unchanged.
7421    pub(crate) fn decoder_smoothness_effective_dof_per_atom(
7422        &self,
7423        cache: &ArrowFactorCache,
7424        lambda_smooth: &[f64],
7425    ) -> Result<Vec<f64>, ArrowSchurError> {
7426        let p = self.output_dim();
7427        let frames_active = self.frames_active();
7428        let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7429            let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7430            (
7431                self.factored_beta_offsets(),
7432                Box::new(move |k: usize| ranks[k]),
7433            )
7434        } else {
7435            (self.beta_offsets(), Box::new(move |_k: usize| p))
7436        };
7437        let k = cache.k;
7438        if self.atoms.len() >= Self::SMOOTHNESS_DOF_HUTCHINSON_MIN_ATOMS {
7439            // Massive-K: `Σ_k M_k·r_k` exact solves is infeasible — estimate every
7440            // atom's trace matrix-free with one `S_β⁻¹` solve per Hutchinson probe.
7441            return self
7442                .decoder_smoothness_effective_dof_per_atom_hutchinson(
7443                    k,
7444                    &offsets,
7445                    out_dim.as_ref(),
7446                    lambda_smooth,
7447                    Self::SMOOTHNESS_DOF_HUTCHINSON_PROBES,
7448                    Self::SMOOTHNESS_DOF_HUTCHINSON_SEED,
7449                    |rhs| {
7450                        cache
7451                            .schur_inverse_apply(rhs)
7452                            .map_err(|e| format!("schur_inverse_apply: {e:?}"))
7453                    },
7454                )
7455                .map_err(|reason| ArrowSchurError::SchurFactorFailed { reason });
7456        }
7457        let mut per_atom = vec![0.0_f64; self.atoms.len()];
7458        let mut m_col = Array1::<f64>::zeros(k);
7459        for (atom_idx, atom) in self.atoms.iter().enumerate() {
7460            let s = &atom.smooth_penalty;
7461            let m = atom.basis_size();
7462            let off = offsets[atom_idx];
7463            let r = out_dim(atom_idx);
7464            let lambda = lambda_smooth[atom_idx];
7465            let mut trace = 0.0_f64;
7466            for mu in 0..m {
7467                for oc in 0..r {
7468                    let col = off + mu * r + oc;
7469                    m_col.fill(0.0);
7470                    for nu in 0..m {
7471                        let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7472                        m_col[off + nu * r + oc] = lambda * s_nu_mu;
7473                    }
7474                    let z = cache.schur_inverse_apply(m_col.view())?;
7475                    trace += z[col];
7476                }
7477            }
7478            per_atom[atom_idx] = trace;
7479        }
7480        Ok(per_atom)
7481    }
7482
7483    /// Per-atom effective penalized dof via the deflated solver (#1556): entry
7484    /// `k` is `tr((H⁻¹)_ββ · M_k)` for `M_k = (λ_smooth[k]·S_k) ⊗ I`, each atom
7485    /// scaled by its OWN `lambda_smooth[atom_idx]`. The total is the sum.
7486    pub(crate) fn decoder_smoothness_effective_dof_with_solver_per_atom(
7487        &self,
7488        cache: &ArrowFactorCache,
7489        solver: &DeflatedArrowSolver<'_>,
7490        lambda_smooth: &[f64],
7491    ) -> Result<Vec<f64>, String> {
7492        let p = self.output_dim();
7493        // #972 / #977 T1: the cache's β block is the FACTORED border when frames
7494        // are active (`cache.k == factored_border_dim`), so the smoothness edf
7495        // trace `tr((H⁻¹)_ββ · M)` is taken over the same factored layout, with
7496        // `M = ⊕_k (λ_k S_k) ⊗ I_{r_k}` at the factored offsets (the `U_kᵀU_k = I`
7497        // collapse means the per-coordinate-channel penalty is `λ_k S_k`, exactly
7498        // as in the full-`B` `⊗ I_p` case but with `r_k` channels). On the
7499        // full-`B` path `frames_active` is false: `out_dim_k = p`, the offsets
7500        // are `beta_offsets`, and this is bit-for-bit the historical trace.
7501        let frames_active = self.frames_active();
7502        let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7503            let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7504            (
7505                self.factored_beta_offsets(),
7506                Box::new(move |k: usize| ranks[k]),
7507            )
7508        } else {
7509            (self.beta_offsets(), Box::new(move |_k: usize| p))
7510        };
7511        let k = cache.k;
7512        // The t-RHS is identically zero for every β-only smoothness solve; build
7513        // it once instead of re-zeroing a delta_t_len()-sized buffer per column.
7514        let zero_t = Array1::<f64>::zeros(cache.delta_t_len());
7515        if self.atoms.len() >= Self::SMOOTHNESS_DOF_HUTCHINSON_MIN_ATOMS {
7516            // Massive-K matrix-free path: one deflated `(H⁻¹)_ββ` solve per
7517            // Hutchinson probe estimates ALL per-atom traces, replacing the
7518            // `Σ_k M_k·r_k` deflated solves that form the `O(K³·M·p)` wall.
7519            return self.decoder_smoothness_effective_dof_per_atom_hutchinson(
7520                k,
7521                &offsets,
7522                out_dim.as_ref(),
7523                lambda_smooth,
7524                Self::SMOOTHNESS_DOF_HUTCHINSON_PROBES,
7525                Self::SMOOTHNESS_DOF_HUTCHINSON_SEED,
7526                |rhs| Ok(solver.solve(zero_t.view(), rhs)?.beta),
7527            );
7528        }
7529        let mut per_atom = vec![0.0_f64; self.atoms.len()];
7530        let mut m_col = Array1::<f64>::zeros(k);
7531        for (atom_idx, atom) in self.atoms.iter().enumerate() {
7532            let s = &atom.smooth_penalty;
7533            let m = atom.basis_size();
7534            let off = offsets[atom_idx];
7535            let r = out_dim(atom_idx);
7536            let lambda = lambda_smooth[atom_idx];
7537            let mut trace = 0.0_f64;
7538            for mu in 0..m {
7539                for oc in 0..r {
7540                    let col = off + mu * r + oc;
7541                    // M[:,col] = λ_k · S_k[:,mu] ⊗ e_oc (nonzero at off+ν·r+oc).
7542                    m_col.fill(0.0);
7543                    for nu in 0..m {
7544                        let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7545                        m_col[off + nu * r + oc] = lambda * s_nu_mu;
7546                    }
7547                    let z = solver.solve(zero_t.view(), m_col.view())?.beta;
7548                    trace += z[col];
7549                }
7550            }
7551            per_atom[atom_idx] = trace;
7552        }
7553        Ok(per_atom)
7554    }
7555
7556    pub(crate) fn assignment_log_strength_hessian_trace(
7557        &self,
7558        rho: &SaeManifoldRho,
7559        cache: &ArrowFactorCache,
7560        solver: &DeflatedArrowSolver<'_>,
7561    ) -> Result<f64, String> {
7562        let k_atoms = self.k_atoms();
7563        // #1038 softmax: `H` carries the DENSE entropy block, and since the
7564        // entropy curvature scales linearly with `λ_sparse = exp(ρ)`,
7565        // `∂H/∂ρ = H_entropy` (the full dense per-row block, not just its
7566        // diagonal). The trace `½ tr(H⁻¹ ∂H/∂ρ)` must therefore contract the
7567        // dense `∂H/∂ρ` against the per-row selected-inverse BLOCK, mirroring the
7568        // dense `log|H|` and θ-adjoint — a diagonal-only contraction would
7569        // desync the ρ-gradient from the criterion. The assembled majorizer
7570        // `D = diag(Σ_j|H_kj|)` is itself DIAGONAL (#1419), so the contraction
7571        // reduces to `½ Σ_slot (H⁻¹)_{slot,slot}·D_atom`. On the dense `None`
7572        // layout the logit slot equals the atom position; on the compact
7573        // softmax top-`k` layout (#1408/#1409) the slots are the row's active
7574        // atoms — the SAME `D_atom` (full-`K` abs-row-sum) the assembly wrote.
7575        if let AssignmentMode::Softmax {
7576            temperature,
7577            sparsity,
7578        } = self.assignment.mode
7579        {
7580            if k_atoms <= 1 {
7581                return Ok(0.0);
7582            }
7583            let inv_tau = 1.0 / temperature;
7584            let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
7585            let penalty = gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
7586                k_atoms,
7587                temperature,
7588            );
7589            // Softmax uses the reduced K−1 free-logit chart on the dense layout
7590            // (last reference logit fixed); the compact layout carries one slot
7591            // per active atom. The diagonal selected inverse gives each slot's
7592            // (H⁻¹)_{slot,slot}.
7593            let assignment_dim = self.assignment.assignment_coord_dim();
7594            // Kept-subspace inverse diagonal: the deflated inverse assigns
7595            // `1/λ̃ = 1` to each per-row UNIT-stiffness direction `vᵢ`, so a raw
7596            // diagonal `D` contraction would spuriously add `½ Σ_i vᵢᵀ D vᵢ` (a
7597            // ρ-independent direction must add 0). `latent_inverse_diagonal_kept`
7598            // removes that per-row deflated diagonal centrally.
7599            let inv_diag = solver
7600                .latent_inverse_diagonal_kept()
7601                .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7602            let mut trace = 0.0_f64;
7603            for row in 0..self.n_obs() {
7604                let row_base = cache.row_offsets[row];
7605                // ∂(scale·D)/∂ρ = scale·D (linear in λ_sparse = eᵖ) — the SAME
7606                // operator the assembly and θ-adjoint differentiate.
7607                match self.last_row_layout {
7608                    Some(ref layout) => {
7609                        // #1410: the compact adjoint reads `D_kk` only for this
7610                        // row's `≤ top_k` active atoms, so compute those entries
7611                        // directly from the softmax row `a` via the active-only
7612                        // Gershgorin helper — no full-`K` `row_logits` copy and no
7613                        // full-`K` `d` vector. `a` itself is the irreducible `O(K)`
7614                        // softmax normalisation, computed once per row and shared
7615                        // across the row's active slots.
7616                        let a = crate::assignment::softmax_row(
7617                            self.assignment.logits.row(row),
7618                            temperature,
7619                        );
7620                        let a = a.as_slice().expect("softmax row must be contiguous");
7621                        let m = softmax_majorizer_log_mean(a);
7622                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7623                            let d_atom =
7624                                active_softmax_gershgorin_majorizer_entry(a, atom, m, scale);
7625                            trace += inv_diag[row_base + pos] * d_atom;
7626                        }
7627                    }
7628                    None => {
7629                        // Dense layout genuinely contracts every free logit slot's
7630                        // `D_kk`, so the full-`K` `d` is intrinsic here; keep the
7631                        // single-source dense majorizer call.
7632                        let row_logits: Vec<f64> = (0..k_atoms)
7633                            .map(|k| self.assignment.logits[[row, k]])
7634                            .collect();
7635                        let d = penalty.psd_majorizer_abs_row_sums(&row_logits, scale);
7636                        let q = cache.row_dims[row];
7637                        let logit_dim = assignment_dim.min(q);
7638                        for atom in 0..logit_dim {
7639                            trace += inv_diag[row_base + atom] * d[atom];
7640                        }
7641                    }
7642                }
7643            }
7644            return Ok(0.5 * trace);
7645        }
7646        let hdiag = assignment_prior_log_strength_hdiag(&self.assignment, rho)?;
7647        if hdiag.is_empty() {
7648            return Ok(0.0);
7649        }
7650        // RAW selected-inverse diagonal: the per-row diagonal contraction uses the
7651        // DEFLATED inverse; the full kept-subspace + β-Schur/rotation deflation
7652        // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per row afterwards
7653        // (`deflation_block_correction`), exactly as the data trace does. The
7654        // cross-row off-diagonal pass below contracts only DISTINCT rows `i ≠ j`,
7655        // off any single-row `vᵢ`'s support, so it needs no deflation correction.
7656        let inv_diag = solver
7657            .latent_inverse_diagonal()
7658            .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7659        let assignment_dim = self.assignment.assignment_coord_dim();
7660        let total_t = cache.delta_t_len();
7661        // #932 FRONT C: row-local Takahashi selected inverse on the plain arrow
7662        // for the per-row deflation correction below (the diagonal trace already
7663        // uses the cheap `latent_inverse_diagonal`); gauge / cross-row Woodbury
7664        // fall back to the per-row full-system `solve` loop.
7665        let fast_selected = solver.plain_selected_inverse_available();
7666        let selected_beta_inv = if fast_selected && cache.k > 0 {
7667            solver
7668                .beta_inv()
7669                .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?
7670        } else {
7671            Array2::<f64>::zeros((0, 0))
7672        };
7673        // #1416 cross-row IBP source: the per-row block that the deflation
7674        // factorizes is the NO-SELF base `H₀'` — the rank-one self curvature
7675        // `d_k·J_ik²` is DOWNDATED from each logit diagonal and re-applied through
7676        // the Woodbury carrier. The full-`H` diagonal contraction below still uses
7677        // the full `hdiag` (which carries that self term), but the per-row
7678        // DEFLATION correction must use `(∂H₀'/∂ρ)_tt`, i.e. `hdiag` MINUS the
7679        // downdated self term — otherwise the Daleckii–Krein correction
7680        // mis-attributes the (un-deflated) Woodbury self curvature's derivative to
7681        // the deflated subspace. For non-IBP modes there is no Woodbury source and
7682        // the self term is `0` (the deflated block IS the full block).
7683        // #1416 (compact-layout completion): the IBP cross-row Woodbury source is
7684        // installed for BOTH the dense and the compact (#1420 top-`k`) layouts (see
7685        // `set_ibp_cross_row_source`, which emits `(g_base + pos, atom, z'_ik)` for
7686        // the active set under a compact layout), so the deflated base `H₀'` is the
7687        // no-self block in BOTH layouts. The self-curvature downdate below must
7688        // therefore run regardless of layout — gating it to the dense path (the
7689        // pre-fix bug) left the compact deflation correction differentiating the
7690        // un-downdated full block. For non-IBP modes `ibp_assignment_third_channels`
7691        // returns `None`, there is no Woodbury source, and `self_curv` is
7692        // identically 0 (the deflated block IS the full block).
7693        let cross_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
7694        let learnable_alpha = matches!(
7695            self.assignment.mode,
7696            AssignmentMode::IBPMap {
7697                learnable_alpha: true,
7698                ..
7699            }
7700        );
7701        let self_curv = |row: usize, atom: usize| -> f64 {
7702            let Some(ch) = cross_channels.as_ref() else {
7703                return 0.0;
7704            };
7705            let d_k = if learnable_alpha {
7706                ch.cross_row_d_logalpha[atom]
7707            } else {
7708                ch.cross_row_d[atom]
7709            };
7710            let j = ch.z_jac[row * k_atoms + atom];
7711            d_k * j * j
7712        };
7713        let mut trace = 0.0_f64;
7714        // Hoisted RHS scratch for the gauge/Woodbury per-row solve fallback:
7715        // single-entry set/clear instead of a per-column total_t-sized zeroing.
7716        let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
7717        let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
7718        for row in 0..self.n_obs() {
7719            let row_base = cache.row_offsets[row];
7720            let assignment_base = row * k_atoms;
7721            let q = cache.row_dims[row];
7722            // Per-row diagonal `(∂H₀'/∂ρ)_tt` for the deflation correction: the
7723            // assignment prior curves only the logit/assignment slots (coordinate
7724            // slots are 0 — ARD handles those), MINUS the downdated cross-row self
7725            // curvature. The full-`H` trace contraction keeps the full `hdiag`.
7726            let mut d_diag = Array1::<f64>::zeros(q);
7727            match self.last_row_layout {
7728                Some(ref layout) => {
7729                    for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7730                        let d_slot = hdiag[assignment_base + atom];
7731                        trace += inv_diag[row_base + pos] * d_slot;
7732                        if pos < q {
7733                            d_diag[pos] = d_slot - self_curv(row, atom);
7734                        }
7735                    }
7736                }
7737                None => {
7738                    for free_idx in 0..assignment_dim {
7739                        let d_slot = hdiag[assignment_base + free_idx];
7740                        trace += inv_diag[row_base + free_idx] * d_slot;
7741                        if free_idx < q {
7742                            d_diag[free_idx] = d_slot - self_curv(row, free_idx);
7743                        }
7744                    }
7745                }
7746            }
7747            let dirs = cache
7748                .deflated_row_directions
7749                .get(row)
7750                .map(Vec::as_slice)
7751                .unwrap_or(&[]);
7752            if !dirs.is_empty() {
7753                let inv_vv = if fast_selected {
7754                    let (inv_vv, _inv_vbeta) = solver
7755                        .selected_inverse_row_blocks(row, &selected_beta_inv)
7756                        .map_err(|err| {
7757                            format!("assignment_log_strength_hessian_trace: selected inverse: {err}")
7758                        })?;
7759                    inv_vv
7760                } else {
7761                    let mut inv_vv = Array2::<f64>::zeros((q, q));
7762                    for col in 0..q {
7763                        rhs_t_scratch[row_base + col] = 1.0;
7764                        let solved = solver
7765                            .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
7766                            .map_err(|err| {
7767                                format!(
7768                                    "assignment_log_strength_hessian_trace: selected inverse: {err}"
7769                                )
7770                            })?;
7771                        rhs_t_scratch[row_base + col] = 0.0;
7772                        for r in 0..q {
7773                            inv_vv[[r, col]] = solved.t[row_base + r];
7774                        }
7775                    }
7776                    inv_vv
7777                };
7778                let mut d_mat = Array2::<f64>::zeros((q, q));
7779                for s in 0..q {
7780                    d_mat[[s, s]] = d_diag[s];
7781                }
7782                let spectrum = cache
7783                    .deflation_row_spectra
7784                    .get(row)
7785                    .and_then(Option::as_ref);
7786                trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
7787            }
7788        }
7789        // #1416: the IBP prior Hessian is `H_p = d·J Jᵀ + diag(s, c)`, where the
7790        // rank-one `d·J Jᵀ` couples EVERY row pair `(i, j)` in a column `k`
7791        // through the shared empirical mass `M_k`. The assembled `H` carries the
7792        // full `H_full = H₀' + U D Uᵀ` (Woodbury, `set_ibp_cross_row_source`), and
7793        // for fixed alpha the entire IBP prior scales with `λ = eᵖ`, so
7794        // `∂H_p/∂ρ = H_p`. The diagonal loop above already captures the `i = j`
7795        // self terms (the `d·J_ik²` summand lives in `hdiag`); this pass adds the
7796        // omitted off-diagonal `½·d_k·Σ_{i≠j}(H⁻¹)_{ik,jk}·J_ik·J_jk`. Only IBP
7797        // has the cross-row rank-one source; for other diagonal modes
7798        // `ibp_assignment_third_channels` returns `None` and the trace stays the
7799        // pure diagonal contraction.
7800        //
7801        // #1416 (compact completion): this pass is LAYOUT-AGNOSTIC. Under the dense
7802        // layout atom `k`'s logit slot is local position `k`
7803        // (`row_offsets[i] + k`); under the compact (#1420 top-`k`) layout only the
7804        // row's active atoms carry coordinates and atom `k` lives at local position
7805        // `pos` of `active_atoms[row]` (`row_offsets[i] + pos`). The Woodbury source
7806        // and the θ-adjoint already use this active-slot mapping, so gating the
7807        // cross-row pass to the dense layout (the pre-fix bug) dropped the
7808        // off-diagonal term from `∂log|H|/∂ρ` whenever the budget/`top_k` engaged
7809        // the compact layout. We build per-column active sites `(row, t_index)` once
7810        // — exactly the θ-adjoint `col_sites` construction — then contract the
7811        // off-diagonal `i ≠ j` remainder with one solve per active site.
7812        if let Some(channels) = cross_channels.as_ref() {
7813            let n = self.n_obs();
7814            let total_t = cache.delta_t_len();
7815            // This trace is ½ ∂log|H|/∂ρ. For FIXED-α IBP the whole prior
7816            // scales with λ=eᵖ so ∂H_p/∂ρ = H_p and the rank-one coefficient
7817            // is the VALUE `cross_row_d[k] = w·s'_k`. For LEARNABLE-α this trace
7818            // is ½ ∂log|H|/∂logα, and the rank-one block's logα-derivative is
7819            // `∂d_k/∂logα = w·∂s'_k/∂logα` (`cross_row_d_logalpha[k]`) — the same
7820            // α-derivative the DIAGONAL channel (`hessian_diag_log_alpha_derivative`)
7821            // already uses. Using the value `s'_k` here (the pre-fix bug) made the
7822            // off-diagonal inconsistent with the diagonal and the α-gradient wrong.
7823            // (`learnable_alpha` is the same flag the self-curvature downdate uses.)
7824            // Per-column active sites `(row, global t-index)`. Layout-agnostic.
7825            let mut col_sites: Vec<Vec<(usize, usize)>> = vec![Vec::new(); k_atoms];
7826            match self.last_row_layout {
7827                Some(ref layout) => {
7828                    for row in 0..n {
7829                        let base = cache.row_offsets[row];
7830                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7831                            col_sites[atom].push((row, base + pos));
7832                        }
7833                    }
7834                }
7835                None => {
7836                    for row in 0..n {
7837                        let base = cache.row_offsets[row];
7838                        for k in 0..k_atoms {
7839                            col_sites[k].push((row, base + k));
7840                        }
7841                    }
7842                }
7843            }
7844            let mut cross = 0.0_f64;
7845            // Hoisted RHS scratch: each active site sets exactly one t-slot, so
7846            // set-then-clear that single entry rather than allocating and zeroing
7847            // a total_t-sized vector per (column, site).
7848            let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
7849            let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
7850            for k in 0..k_atoms {
7851                let d_k = if learnable_alpha {
7852                    channels.cross_row_d_logalpha[k]
7853                } else {
7854                    channels.cross_row_d[k]
7855                };
7856                if d_k == 0.0 || col_sites[k].len() < 2 {
7857                    continue;
7858                }
7859                for &(i, t_i) in &col_sites[k] {
7860                    let j_ik = channels.z_jac[i * k_atoms + k];
7861                    if j_ik == 0.0 {
7862                        continue;
7863                    }
7864                    // (H⁻¹) column at row `i`'s active logit-`k` slot.
7865                    rhs_t_scratch[t_i] = 1.0;
7866                    let solved = solver
7867                        .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
7868                        .map_err(|err| {
7869                            format!("assignment_log_strength_hessian_trace: {err}")
7870                        })?;
7871                    rhs_t_scratch[t_i] = 0.0;
7872                    for &(j, t_j) in &col_sites[k] {
7873                        if j == i {
7874                            continue;
7875                        }
7876                        let j_jk = channels.z_jac[j * k_atoms + k];
7877                        if j_jk == 0.0 {
7878                            continue;
7879                        }
7880                        cross += d_k * solved.t[t_j] * j_ik * j_jk;
7881                    }
7882                }
7883            }
7884            trace += cross;
7885        }
7886        Ok(0.5 * trace)
7887    }
7888
7889    pub(crate) fn learnable_ibp_forward_alpha_data_derivative(
7890        &self,
7891        rho: &SaeManifoldRho,
7892        target: ArrayView2<'_, f64>,
7893    ) -> Result<f64, String> {
7894        let AssignmentMode::IBPMap {
7895            temperature: _,
7896            learnable_alpha: true,
7897            ..
7898        } = self.assignment.mode
7899        else {
7900            return Ok(0.0);
7901        };
7902        let alpha = self
7903            .assignment
7904            .resolved_ibp_alpha(rho)
7905            .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
7906        let k_atoms = self.k_atoms();
7907        let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
7908        let mut dprior = Array1::<f64>::zeros(k_atoms);
7909        for k in 0..k_atoms {
7910            // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
7911            // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
7912            // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
7913            dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
7914        }
7915        let n = self.n_obs();
7916        let p = self.output_dim();
7917        let row_loss_w = self.row_loss_weights.as_deref();
7918        let whitens = self
7919            .row_metric
7920            .as_ref()
7921            .is_some_and(|metric| metric.whitens_likelihood());
7922        let mut decoded = vec![0.0_f64; p];
7923        let mut fitted = Array1::<f64>::zeros(p);
7924        let mut f_rho = Array1::<f64>::zeros(p);
7925        let mut residual = Array1::<f64>::zeros(p);
7926        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
7927        let mut assignments = vec![0.0_f64; k_atoms];
7928        let mut total = 0.0_f64;
7929        for row in 0..n {
7930            self.assignment
7931                .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
7932            fitted.fill(0.0);
7933            f_rho.fill(0.0);
7934            for k in 0..k_atoms {
7935                self.atoms[k].fill_decoded_row(row, &mut decoded);
7936                // Ungated (#1026 background-tier) atoms have a force-fixed unit
7937                // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
7938                // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
7939                // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
7940                // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
7941                // but `a_k` still varies with α through `π_k`, so it must NOT be
7942                // skipped.)
7943                let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
7944                    0.0
7945                } else {
7946                    (assignments[k] / prior[k]) * dprior[k]
7947                };
7948                for out_col in 0..p {
7949                    fitted[out_col] += assignments[k] * decoded[out_col];
7950                    f_rho[out_col] += da_rho * decoded[out_col];
7951                }
7952            }
7953            for out_col in 0..p {
7954                residual[out_col] = fitted[out_col] - target[[row, out_col]];
7955            }
7956            let residual_metric = match self.row_metric.as_ref() {
7957                Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
7958                _ => residual.to_vec(),
7959            };
7960            let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
7961            let mut row_dot = 0.0_f64;
7962            for out_col in 0..p {
7963                row_dot += residual_metric[out_col] * f_rho[out_col];
7964            }
7965            total += row_weight * row_dot;
7966        }
7967        Ok(total)
7968    }
7969
7970    /// Per-row spectral-deflation correction `tr((H⁻¹)_tt · (D − DΦ[D]))` for one
7971    /// evidence ρ-component, to be SUBTRACTED from the raw-derivative trace
7972    /// `tr((H⁻¹)_tt · D)` the trace otherwise accumulates.
7973    ///
7974    /// The criterion VALUE re-deflates each per-row `H_tt` at every ρ, so the
7975    /// correct evidence gradient contracts `(H⁻¹)_tt` against the deflation-map
7976    /// derivative `DΦ[D]`, not the raw `D = (∂H_raw/∂ρ)_tt`. By Daleckii–Krein,
7977    /// in the row's RAW eigenbasis `U`,
7978    ///   `DΦ[D] = U (F ∘ (Uᵀ D U)) Uᵀ`,  `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)`
7979    /// (raw `λ` in the denominator, conditioned `λ̃` in the numerator; the
7980    /// diagonal / degenerate entry is `f'(λₘ) = 1` for an unclamped kept
7981    /// direction and `0` otherwise). Hence `D − DΦ[D] = U ((1−F) ∘ (Uᵀ D U)) Uᵀ`,
7982    /// whose kept×kept block is `0`, deflated×deflated block is the full `M`, and
7983    /// kept(m)×deflated(i) block carries the ROTATION coefficient
7984    /// `(1−λᵢ)/(λₘ−λᵢ)`. Contracting against the FULL deflated selected-inverse
7985    /// t-block `inv_vv` (which carries the β-Schur back-substitution) captures
7986    /// both the within-row kept-subspace term and the deferred β-Schur/rotation
7987    /// coupling in one pass, matching the re-deflating fixed-state FD oracle.
7988    ///
7989    /// `spectrum = Some` (spectral deflation): exact Daleckii–Krein. `None` with a
7990    /// non-empty `dirs` (gauge-only deflation, ρ-independent structural null):
7991    /// fall back to the within-row kept-subspace term `Σᵢ vᵢᵀ D vᵢ`.
7992    /// `inv_vv` is assumed symmetric (selected inverse of a symmetric PD system).
7993    // #1610 — `pub(crate)` so the ARD/latent-block helpers moved into
7994    // `construction_ard.rs` (pure code move to stay under the 10k-line ban gate)
7995    // can still call this from the sibling module.
7996    pub(crate) fn deflation_block_correction(
7997        inv_vv: &Array2<f64>,
7998        d_mat: &Array2<f64>,
7999        dirs: &[Array1<f64>],
8000        spectrum: Option<&RowDeflationSpectrum>,
8001    ) -> f64 {
8002        let q = inv_vv.nrows();
8003        let Some(spec) = spectrum else {
8004            // Gauge-only deflation: ρ-independent structural null → within-row term.
8005            let mut acc = 0.0_f64;
8006            for v in dirs {
8007                for a in 0..q {
8008                    let va = if a < v.len() { v[a] } else { 0.0 };
8009                    if va == 0.0 {
8010                        continue;
8011                    }
8012                    for b in 0..q {
8013                        let vb = if b < v.len() { v[b] } else { 0.0 };
8014                        acc += va * vb * d_mat[[a, b]];
8015                    }
8016                }
8017            }
8018            return acc;
8019        };
8020        let u = &spec.evecs;
8021        if u.nrows() != q || u.ncols() != q {
8022            return 0.0;
8023        }
8024        let raw = &spec.raw_evals;
8025        let cond = &spec.cond_evals;
8026        // M = Uᵀ D U, W = Uᵀ inv_vv U (both q×q, symmetric).
8027        let m = u.t().dot(d_mat).dot(u);
8028        let w = u.t().dot(inv_vv).dot(u);
8029        // correction = Σ_{m,l} W[m,l]·M[m,l]·(1 − F[m,l]).
8030        let mut acc = 0.0_f64;
8031        let eps = 1.0e-12;
8032        for a in 0..q {
8033            for b in 0..q {
8034                let denom = raw[a] - raw[b];
8035                let f1 = if denom.abs() > eps {
8036                    (cond[a] - cond[b]) / denom
8037                } else if cond[a] == raw[a] {
8038                    1.0
8039                } else {
8040                    0.0
8041                };
8042                acc += w[[a, b]] * m[[a, b]] * (1.0 - f1);
8043            }
8044        }
8045        acc
8046    }
8047
8048    /// #1417: exact `½ tr(H⁻¹ ∂H_data/∂logα)` for LEARNABLE IBP alpha.
8049    ///
8050    /// The forward assignment is `a_ik = σ(ℓ_ik/τ)·π_k(α)` with the #614
8051    /// consistent stick-breaking mean `π_k(α) = (α/(α+1))^(k+1)`, so
8052    /// `∂logπ_k/∂logα = (k+1)/(α+1)`. EVERY data-Jacobian column for atom `k` —
8053    /// the logit-JVP row (carries one `π_k`), the coordinate rows (carry one
8054    /// `a_k`), and the β-leg (`a_k·φ`) — carries exactly ONE `a_k`/`π_k` factor
8055    /// (`σ(ℓ/τ)` is α-independent). Hence each Jacobian column scales as
8056    /// `∂J_·k/∂logα = ((k+1)/(α+1))·J_·k`, and the data Hessian block for the
8057    /// atom pair `(k_a, k_b)` scales as
8058    ///   ∂H_data[a,b]/∂logα = (((k_a+1) + (k_b+1))/(α+1))·H_data[a,b].
8059    /// Therefore the exact data-block contribution to the α-logdet trace is
8060    ///   ½ tr(H⁻¹ ∂H_data/∂logα)
8061    ///     = ½/(α+1) · Σ_{a,b} ((k_a+1) + (k_b+1))·(H⁻¹)_{ba}·H_data[a,b],
8062    /// over the full joint `(t, β)` index set. `H_data[a,b]` is the data-fit
8063    /// Gauss-Newton block built from the SAME `row_jets_for_logdet` first-jets the
8064    /// θ-adjoint uses (`H_tt = ⟨J_a,J_b⟩`, `H_tβ = ⟨J_a,J_β⟩`, `H_ββ = ⟨J_β,J_β'⟩`),
8065    /// and `(H⁻¹)` is contracted through the same per-row selected-inverse blocks.
8066    /// This closes the learnable-α gradient: combined with the prior-Hessian
8067    /// trace (`assignment_log_strength_hessian_trace`) the full
8068    /// `½ tr(H⁻¹ ∂H/∂logα)` is now assembled. For FIXED alpha (and non-IBP modes)
8069    /// this is identically zero.
8070    pub(crate) fn learnable_ibp_data_logdet_alpha_trace(
8071        &self,
8072        rho: &SaeManifoldRho,
8073        cache: &ArrowFactorCache,
8074        solver: &DeflatedArrowSolver<'_>,
8075    ) -> Result<f64, String> {
8076        let AssignmentMode::IBPMap {
8077            learnable_alpha: true,
8078            ..
8079        } = self.assignment.mode
8080        else {
8081            return Ok(0.0);
8082        };
8083        let alpha = self
8084            .assignment
8085            .resolved_ibp_alpha(rho)
8086            .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8087        let inv_alpha1 = 1.0 / (alpha + 1.0);
8088        let n = self.n_obs();
8089        let total_t = cache.delta_t_len();
8090        let second_jets = self.atom_second_jets()?;
8091        let border = self.border_channels_for_cache(cache)?;
8092
8093        // β-tier selected inverse `(H⁻¹)_ββ` (shared across rows). #932 FRONT C:
8094        // on the plain bordered arrow this is the cached dense `S⁻¹` formed once
8095        // (no `K` full-system solves); when a gauge / #1038 cross-row Woodbury is
8096        // active the row-local Takahashi blocks are NOT valid, so we fall back to
8097        // the per-β-coordinate `solve` loop (bit-identical, just O(n) per call).
8098        let fast_selected = solver.plain_selected_inverse_available();
8099        let beta_inv = if cache.k == 0 {
8100            Array2::<f64>::zeros((0, 0))
8101        } else if fast_selected {
8102            solver.beta_inv().map_err(|err| {
8103                format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8104            })?
8105        } else {
8106            let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8107            let rhs_t = Array1::<f64>::zeros(total_t);
8108            let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8109            for col in 0..cache.k {
8110                rhs_beta[col] = 1.0;
8111                let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8112                    format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8113                })?;
8114                rhs_beta[col] = 0.0;
8115                for r in 0..cache.k {
8116                    beta_inv[[r, col]] = solved.beta[r];
8117                }
8118            }
8119            beta_inv
8120        };
8121        // Atom index of each β border channel (the `k_b` weight for the β leg).
8122        let border_atom: Vec<usize> = border.iter().map(|c| c.atom).collect();
8123
8124        let mut trace = 0.0_f64;
8125        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8126        let mut assignments = Array1::<f64>::zeros(self.k_atoms());
8127        // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
8128        // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
8129        // rows fall back to the scalar per-row path (bit-identical either way).
8130        let mut jet_window: std::collections::VecDeque<SaeRowJets> =
8131            std::collections::VecDeque::new();
8132        let mut jet_window_next = 0usize;
8133        // Hoisted RHS scratch for the gauge/Woodbury per-row solve fallback.
8134        let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
8135        let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
8136        for row in 0..n {
8137            let q = cache.row_dims[row];
8138            let base = cache.row_offsets[row];
8139            let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
8140            self.assignment
8141                .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
8142            if jet_window.is_empty() {
8143                jet_window_next = self.refill_jet_window(
8144                    rho,
8145                    jet_window_next,
8146                    cache,
8147                    &second_jets,
8148                    &border,
8149                    &mut jet_window,
8150                )?;
8151            }
8152            let jets = jet_window.pop_front().expect("jet window must be non-empty");
8153            // Atom index (k-weight) of each local t-var.
8154            let var_atom: Vec<usize> = jets
8155                .vars
8156                .iter()
8157                .map(|v| match *v {
8158                    SaeLocalRowVar::Logit { atom } => atom,
8159                    SaeLocalRowVar::Coord { atom, .. } => atom,
8160                })
8161                .collect();
8162
8163            // Per-row selected inverse blocks `(H⁻¹)_tt` (q×q) and `(H⁻¹)_tβ`.
8164            // #932 FRONT C: row-local Takahashi (O(q·(q+K))) on the plain arrow;
8165            // per-row full-system `solve` loop (O(n·q)) under gauge / cross-row
8166            // Woodbury where the row-local blocks are not valid.
8167            let (inv_vv, inv_vbeta) = if fast_selected {
8168                solver
8169                    .selected_inverse_row_blocks(row, &beta_inv)
8170                    .map_err(|err| {
8171                        format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8172                    })?
8173            } else {
8174                let mut inv_vv = Array2::<f64>::zeros((q, q));
8175                let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
8176                for col in 0..q {
8177                    rhs_t_scratch[base + col] = 1.0;
8178                    let solved = solver
8179                        .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
8180                        .map_err(|err| {
8181                            format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8182                        })?;
8183                    rhs_t_scratch[base + col] = 0.0;
8184                    for r in 0..q {
8185                        inv_vv[[r, col]] = solved.t[base + r];
8186                    }
8187                    for b in 0..cache.k {
8188                        inv_vbeta[[col, b]] = solved.beta[b];
8189                    }
8190                }
8191                (inv_vv, inv_vbeta)
8192            };
8193
8194            // #1026 — UNGATED (background-tier) atoms have a force-fixed unit gate,
8195            // so their mass `a_k ≡ 1` is α-INDEPENDENT: every data-Jacobian column
8196            // for an ungated atom carries `a_k = 1`, NOT `π_k(α)`, so its α-exponent
8197            // is `e_k = 0`, not `k+1`. Gated atoms keep `e_k = k+1`. (The prior trace
8198            // handles ungated separately by zeroing the fixed-logit `z_jac`.)
8199            let kfac = |atom: usize| -> f64 {
8200                if self.assignment.ungated.get(atom).copied().unwrap_or(false) {
8201                    0.0
8202                } else {
8203                    (atom + 1) as f64
8204                }
8205            };
8206            // t–t block: Σ_{a,b} (e_a + e_b)·(H⁻¹)_{ba}·⟨J_a, J_b⟩, where the
8207            // per-atom log-prior exponent is e_k = k+1 for the #614 consistent
8208            // stick-breaking mean π_k = (α/(α+1))^(k+1) (dlogπ_k/dlogα = (k+1)·inv_alpha1).
8209            for a in 0..q {
8210                for b in 0..q {
8211                    let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8212                    if h_ab == 0.0 {
8213                        continue;
8214                    }
8215                    let kw = kfac(var_atom[a]) + kfac(var_atom[b]);
8216                    trace += kw * inv_vv[[b, a]] * h_ab;
8217                }
8218            }
8219            // Deflation correction (kept-subspace restriction + β-Schur/rotation).
8220            // `inv_vv` is the DEFLATED selected inverse, so the t–t contraction
8221            // above contracts the RAW derivative `D` where the re-deflating
8222            // criterion uses the deflation-map derivative `DΦ[D]`. Subtract the
8223            // exact over-count `tr(inv_vv·(D − DΦ[D]))` via the Daleckii–Krein
8224            // helper, with `D_{ab} = kw_ab·⟨J_a, J_b⟩` the SAME t–t operator the
8225            // trace contracts. The t–β/β–β blocks are not deflated, so only the
8226            // t–t contraction is corrected.
8227            let dirs = cache
8228                .deflated_row_directions
8229                .get(row)
8230                .map(Vec::as_slice)
8231                .unwrap_or(&[]);
8232            if !dirs.is_empty() {
8233                let mut d_mat = Array2::<f64>::zeros((q, q));
8234                for a in 0..q {
8235                    for b in 0..q {
8236                        let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8237                        if h_ab == 0.0 {
8238                            continue;
8239                        }
8240                        d_mat[[a, b]] = (kfac(var_atom[a]) + kfac(var_atom[b])) * h_ab;
8241                    }
8242                }
8243                let spectrum = cache
8244                    .deflation_row_spectra
8245                    .get(row)
8246                    .and_then(Option::as_ref);
8247                trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
8248            }
8249            // t–β and β–t blocks: appear symmetrically, contract once with the
8250            // factor 2 (H, H⁻¹ symmetric; `(H⁻¹)_βt = (H⁻¹)_tβᵀ`).
8251            for a in 0..q {
8252                for (beta_pos, channel) in border.iter().enumerate() {
8253                    let h_ab = sae_dot(&jets.first[a], &jets.beta[beta_pos]);
8254                    if h_ab == 0.0 {
8255                        continue;
8256                    }
8257                    let kw = kfac(var_atom[a]) + kfac(border_atom[beta_pos]);
8258                    trace += 2.0 * kw * inv_vbeta[[a, channel.index]] * h_ab;
8259                }
8260            }
8261            // β–β block: Σ_{β,β'} (k_β + k_β')·(H⁻¹)_{β'β}·⟨J_β, J_β'⟩.
8262            for (beta_i, channel_i) in border.iter().enumerate() {
8263                for (beta_j, channel_j) in border.iter().enumerate() {
8264                    let h_ab = sae_dot(&jets.beta[beta_i], &jets.beta[beta_j]);
8265                    if h_ab == 0.0 {
8266                        continue;
8267                    }
8268                    let kw = kfac(border_atom[beta_i]) + kfac(border_atom[beta_j]);
8269                    trace += kw * beta_inv[[channel_i.index, channel_j.index]] * h_ab;
8270                }
8271            }
8272        }
8273        Ok(0.5 * inv_alpha1 * trace)
8274    }
8275
8276    pub(crate) fn add_learnable_ibp_forward_alpha_data_rhs(
8277        &self,
8278        rho: &SaeManifoldRho,
8279        target: ArrayView2<'_, f64>,
8280        cache: &ArrowFactorCache,
8281        t: &mut Array1<f64>,
8282        beta: &mut Array1<f64>,
8283    ) -> Result<(), String> {
8284        let AssignmentMode::IBPMap {
8285            temperature,
8286            learnable_alpha: true,
8287            ..
8288        } = self.assignment.mode
8289        else {
8290            return Ok(());
8291        };
8292        let alpha = self
8293            .assignment
8294            .resolved_ibp_alpha(rho)
8295            .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8296        let k_atoms = self.k_atoms();
8297        let p = self.output_dim();
8298        let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8299        let mut dprior = Array1::<f64>::zeros(k_atoms);
8300        for k in 0..k_atoms {
8301            // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8302            // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8303            // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8304            dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8305        }
8306        let inv_tau = 1.0 / temperature;
8307        let row_loss_w = self.row_loss_weights.as_deref();
8308        let whitens = self
8309            .row_metric
8310            .as_ref()
8311            .is_some_and(|metric| metric.whitens_likelihood());
8312        let border = self.border_channels_for_cache(cache)?;
8313        let mut decoded_rows = vec![vec![0.0_f64; p]; k_atoms];
8314        let mut decoded_deriv = vec![0.0_f64; p];
8315        let mut fitted = Array1::<f64>::zeros(p);
8316        let mut f_rho = Array1::<f64>::zeros(p);
8317        let mut residual = Array1::<f64>::zeros(p);
8318        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8319        let mut assignments = vec![0.0_f64; k_atoms];
8320        for row in 0..self.n_obs() {
8321            self.assignment
8322                .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8323            fitted.fill(0.0);
8324            f_rho.fill(0.0);
8325            for k in 0..k_atoms {
8326                self.atoms[k].fill_decoded_row(row, &mut decoded_rows[k]);
8327                // Ungated (#1026 background-tier) atoms have a force-fixed unit
8328                // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8329                // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8330                // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8331                // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8332                // but `a_k` still varies with α through `π_k`, so it must NOT be
8333                // skipped.)
8334                let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8335                    0.0
8336                } else {
8337                    (assignments[k] / prior[k]) * dprior[k]
8338                };
8339                for out_col in 0..p {
8340                    fitted[out_col] += assignments[k] * decoded_rows[k][out_col];
8341                    f_rho[out_col] += da_rho * decoded_rows[k][out_col];
8342                }
8343            }
8344            for out_col in 0..p {
8345                residual[out_col] = fitted[out_col] - target[[row, out_col]];
8346            }
8347            let residual_metric = match self.row_metric.as_ref() {
8348                Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8349                _ => residual.to_vec(),
8350            };
8351            let f_metric = match self.row_metric.as_ref() {
8352                Some(metric) if whitens => metric.apply_metric_row(row, f_rho.view()),
8353                _ => f_rho.to_vec(),
8354            };
8355            let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8356            let row_vars = self.row_vars_for_cache_row(row, cache)?;
8357            let row_base = cache.row_offsets[row];
8358            for (pos, var) in row_vars.iter().enumerate() {
8359                let mut contribution = 0.0_f64;
8360                match *var {
8361                    SaeLocalRowVar::Logit { atom } => {
8362                        let sigma = assignments[atom] / prior[atom];
8363                        let sigma_jac = sigma * (1.0 - sigma) * inv_tau;
8364                        let da_dl = sigma_jac * prior[atom];
8365                        let d_da_rho_dl = sigma_jac * dprior[atom];
8366                        for out_col in 0..p {
8367                            contribution += da_dl * decoded_rows[atom][out_col] * f_metric[out_col];
8368                            contribution += d_da_rho_dl
8369                                * decoded_rows[atom][out_col]
8370                                * residual_metric[out_col];
8371                        }
8372                    }
8373                    SaeLocalRowVar::Coord { atom, axis } => {
8374                        let sigma = assignments[atom] / prior[atom];
8375                        let da_rho = sigma * dprior[atom];
8376                        self.atoms[atom].fill_decoded_derivative_row(row, axis, &mut decoded_deriv);
8377                        for out_col in 0..p {
8378                            contribution +=
8379                                assignments[atom] * decoded_deriv[out_col] * f_metric[out_col];
8380                            contribution +=
8381                                da_rho * decoded_deriv[out_col] * residual_metric[out_col];
8382                        }
8383                    }
8384                }
8385                t[row_base + pos] += row_weight * contribution;
8386            }
8387            for channel in &border {
8388                let phi = self.atoms[channel.atom].basis_values[[row, channel.basis_col]];
8389                let sigma = assignments[channel.atom] / prior[channel.atom];
8390                let da_rho = sigma * dprior[channel.atom];
8391                let mut contribution = 0.0_f64;
8392                for out_col in 0..p {
8393                    let output = channel.output[out_col];
8394                    contribution += assignments[channel.atom] * phi * output * f_metric[out_col];
8395                    contribution += da_rho * phi * output * residual_metric[out_col];
8396                }
8397                beta[channel.index] += row_weight * contribution;
8398            }
8399        }
8400        Ok(())
8401    }
8402
8403    pub(crate) fn border_channels_for_cache(
8404        &self,
8405        cache: &ArrowFactorCache,
8406    ) -> Result<Vec<SaeBorderChannel>, String> {
8407        let p = self.output_dim();
8408        let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8409        let offsets = if frames_active {
8410            self.factored_beta_offsets()
8411        } else {
8412            self.beta_offsets()
8413        };
8414        let mut channels = Vec::with_capacity(cache.k);
8415        for (atom_idx, atom) in self.atoms.iter().enumerate() {
8416            let m = atom.basis_size();
8417            let frame = if frames_active {
8418                self.frame_output_matrix(atom_idx)
8419            } else {
8420                Array2::<f64>::eye(p)
8421            };
8422            let r = frame.ncols();
8423            for basis_col in 0..m {
8424                for channel in 0..r {
8425                    let mut output = vec![0.0_f64; p];
8426                    for out_col in 0..p {
8427                        output[out_col] = frame[[out_col, channel]];
8428                    }
8429                    channels.push(SaeBorderChannel {
8430                        atom: atom_idx,
8431                        basis_col,
8432                        index: offsets[atom_idx] + basis_col * r + channel,
8433                        output,
8434                    });
8435                }
8436            }
8437        }
8438        if channels.len() != cache.k {
8439            return Err(format!(
8440                "border channel layout has {} entries but cache border has {}",
8441                channels.len(),
8442                cache.k
8443            ));
8444        }
8445        Ok(channels)
8446    }
8447
8448    pub(crate) fn row_vars_for_cache_row(
8449        &self,
8450        row: usize,
8451        cache: &ArrowFactorCache,
8452    ) -> Result<Vec<SaeLocalRowVar>, String> {
8453        let q_row = cache.row_dims[row];
8454        let mut vars: Vec<Option<SaeLocalRowVar>> = vec![None; q_row];
8455        match self.last_row_layout {
8456            Some(ref layout) => {
8457                for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8458                    vars[pos] = Some(SaeLocalRowVar::Logit { atom });
8459                    let start = layout.coord_starts[row][pos];
8460                    let d = self.assignment.coords[atom].latent_dim();
8461                    for axis in 0..d {
8462                        vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8463                    }
8464                }
8465            }
8466            None => {
8467                let assignment_dim = self.assignment.assignment_coord_dim();
8468                let coord_offsets = self.assignment.coord_offsets();
8469                for atom in 0..assignment_dim {
8470                    vars[atom] = Some(SaeLocalRowVar::Logit { atom });
8471                }
8472                for atom in 0..self.k_atoms() {
8473                    let start = coord_offsets[atom];
8474                    let d = self.assignment.coords[atom].latent_dim();
8475                    for axis in 0..d {
8476                        vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8477                    }
8478                }
8479            }
8480        }
8481        vars.into_iter()
8482            .enumerate()
8483            .map(|(idx, v)| {
8484                v.ok_or_else(|| {
8485                    format!("row_vars_for_cache_row: row {row} position {idx} was not mapped")
8486                })
8487            })
8488            .collect()
8489    }
8490
8491    pub(crate) fn atom_second_jets(&self) -> Result<Vec<Array4<f64>>, String> {
8492        let mut out = Vec::with_capacity(self.k_atoms());
8493        for (atom_idx, atom) in self.atoms.iter().enumerate() {
8494            let coords = self.assignment.coords[atom_idx].as_matrix();
8495            let jet = if let Some(second) = atom.basis_second_jet.as_ref() {
8496                second.second_jet(coords.view())?
8497            } else {
8498                let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
8499                    format!(
8500                        "logdet_theta_adjoint: atom '{}' has no basis evaluator for second jets",
8501                        atom.name
8502                    )
8503                })?;
8504                evaluator
8505                    .second_jet_dyn(coords.view())
8506                    .ok_or_else(|| {
8507                        format!(
8508                            "logdet_theta_adjoint: atom '{}' basis does not expose analytic second jets",
8509                            atom.name
8510                        )
8511                    })??
8512            };
8513            let expected = (
8514                atom.n_obs(),
8515                atom.basis_size(),
8516                atom.latent_dim,
8517                atom.latent_dim,
8518            );
8519            if jet.dim() != expected {
8520                return Err(format!(
8521                    "logdet_theta_adjoint: atom '{}' second jet shape {:?}, expected {:?}",
8522                    atom.name,
8523                    jet.dim(),
8524                    expected
8525                ));
8526            }
8527            out.push(jet);
8528        }
8529        Ok(out)
8530    }
8531
8532    // [#780 line-count gate] The per-row jet / reconstruction-channel cluster
8533    // (`reconstruction_row_program_for_logdet`, the const-generic
8534    // reconstruction / β-border channel fills and their dynamic dispatchers,
8535    // `row_jets_for_logdet`, `row_jets_for_logdet_batch4`, `batch4_assemble`,
8536    // and `refill_jet_window`) lives in the sibling
8537    // `construction_row_jet_logdet_channels.rs` file, inlined via `include!`
8538    // below at module scope as a second `impl SaeManifoldTerm` block. Splitting
8539    // it out keeps this tracked file under the 10k limit; `include!` preserves
8540    // the identical module scope and private-field access.
8541
8542    pub(crate) fn assignment_prior_hdiag_derivative_entry(
8543        &self,
8544        rho: &SaeManifoldRho,
8545        row: usize,
8546        diag_atom: usize,
8547        wrt: SaeLocalRowVar,
8548        ibp_channels: Option<&IbpHessianDiagThirdChannels>,
8549    ) -> f64 {
8550        let SaeLocalRowVar::Logit { atom: wrt_atom } = wrt else {
8551            return 0.0;
8552        };
8553        match self.assignment.mode {
8554            AssignmentMode::Softmax { .. } => {
8555                // #1038: the softmax entropy Hessian is now stored DENSE in
8556                // `block.htt` and its full θ-derivative `∂H_{k,j}/∂z_w` (diagonal
8557                // AND off-diagonal) is added inline in `logdet_theta_adjoint` from
8558                // the shared `row_dense_hessian_logit_derivative`. Returning the
8559                // diagonal contribution here too would double-count, so this
8560                // primitive is silent for softmax — the dense path is the single
8561                // source for value, logdet, and adjoint.
8562                0.0
8563            }
8564            AssignmentMode::ThresholdGate {
8565                temperature,
8566                threshold,
8567            } => {
8568                if diag_atom != wrt_atom {
8569                    return 0.0;
8570                }
8571                let logit = self.assignment.logits[[row, diag_atom]];
8572                if !crate::assignment::jumprelu_in_optimization_band(
8573                    logit,
8574                    threshold,
8575                    temperature,
8576                ) {
8577                    return 0.0;
8578                }
8579                let inv_tau = 1.0 / temperature;
8580                let activation =
8581                    gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
8582                let slope = activation * (1.0 - activation);
8583                // #1415: P(ℓ)=λσ((ℓ−θ)/τ); P''(ℓ)=(λ/τ²)s(1−2a) so the third
8584                // derivative is P'''(ℓ)=(λ/τ³)·s·(1−6a+6a²), because
8585                // d/dℓ[s(1−2a)] = (1/τ)s[(1−2a)²−2s] = (1/τ)s(1−6a+6a²).
8586                rho.lambda_sparse()
8587                    * slope
8588                    * (1.0 - 6.0 * activation + 6.0 * activation * activation)
8589                    * inv_tau
8590                    * inv_tau
8591                    * inv_tau
8592            }
8593            AssignmentMode::IBPMap { .. } => {
8594                // The assembled `htt` diagonal consumes
8595                // `IBPAssignmentPenalty::hessian_diag`, whose logit derivative
8596                // splits into a row-local direct-`z` channel and a global
8597                // empirical-`M_k` channel (π_k couples every row in column k).
8598                // This same-row primitive returns only the LOCAL direct-`z`
8599                // channel — and only on the matching logit (`diag_atom == w`),
8600                // since H_ik depends on no other row's z explicitly. The global
8601                // M_k channel is accumulated column-wise in
8602                // `logdet_theta_adjoint` (it needs the per-row selected-inverse
8603                // diagonals), so adding it here would double-count.
8604                if diag_atom != wrt_atom {
8605                    return 0.0;
8606                }
8607                match ibp_channels {
8608                    Some(ch) => ch.local_logit_third[row * ch.k_max + diag_atom],
8609                    None => 0.0,
8610                }
8611            }
8612        }
8613    }
8614
8615    pub(crate) fn ard_majorized_hessian_derivative(
8616        &self,
8617        rho: &SaeManifoldRho,
8618        row: usize,
8619        atom: usize,
8620        axis: usize,
8621    ) -> f64 {
8622        if rho.log_ard[atom].is_empty() {
8623            return 0.0;
8624        }
8625        let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8626        let periods = self.assignment.coords[atom].effective_axis_periods();
8627        let t = self.assignment.coords[atom].row(row)[axis];
8628        let prior = ArdAxisPrior::eval(alpha, t, periods[axis]);
8629        if prior.hess <= 0.0 {
8630            return 0.0;
8631        }
8632        match periods[axis] {
8633            None => 0.0,
8634            Some(period) => {
8635                let kappa = std::f64::consts::TAU / period;
8636                -alpha * kappa * (kappa * t).sin()
8637            }
8638        }
8639    }
8640
8641    pub fn outer_rho_gradient_ift_rhs(
8642        &self,
8643        rho: &SaeManifoldRho,
8644        target: ArrayView2<'_, f64>,
8645        j: usize,
8646        cache: &ArrowFactorCache,
8647    ) -> Result<SaeArrowVector, String> {
8648        let n_params = rho.to_flat().len();
8649        if j >= n_params {
8650            return Err(format!(
8651                "outer_rho_gradient_ift_rhs: coordinate {j} outside rho dim {n_params}"
8652            ));
8653        }
8654        let mut t = Array1::<f64>::zeros(cache.delta_t_len());
8655        let mut beta = Array1::<f64>::zeros(cache.k);
8656        if j == 0 {
8657            let assignment_grad =
8658                assignment_prior_log_strength_target_mixed(&self.assignment, rho)?;
8659            let k_atoms = self.k_atoms();
8660            let assignment_dim = self.assignment.assignment_coord_dim();
8661            for row in 0..self.n_obs() {
8662                let base = cache.row_offsets[row];
8663                let assignment_base = row * k_atoms;
8664                match self.last_row_layout {
8665                    Some(ref layout) => {
8666                        for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8667                            t[base + pos] = assignment_grad[assignment_base + atom];
8668                        }
8669                    }
8670                    None => {
8671                        for free_idx in 0..assignment_dim {
8672                            t[base + free_idx] = assignment_grad[assignment_base + free_idx];
8673                        }
8674                    }
8675                }
8676            }
8677            self.add_learnable_ibp_forward_alpha_data_rhs(rho, target, cache, &mut t, &mut beta)?;
8678        } else if (1..=rho.log_lambda_smooth.len()).contains(&j) {
8679            // #1556: coordinate `j ∈ 1..=K` is the per-atom smoothness strength
8680            // `log λ_smooth[j-1]`. `∂(penalty)/∂log λ_k = λ_k·S_k C_k` touches ONLY
8681            // atom `k = j-1`'s decoder block; every other atom's RHS is zero.
8682            let target_atom = j - 1;
8683            let lambda = rho.lambda_smooth_for(target_atom);
8684            let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8685            let offsets = if frames_active {
8686                self.factored_beta_offsets()
8687            } else {
8688                self.beta_offsets()
8689            };
8690            let atom = &self.atoms[target_atom];
8691            let m = atom.basis_size();
8692            let coeffs = if frames_active {
8693                match &atom.decoder_frame {
8694                    Some(frame) => frame.project_decoder(atom.decoder_coefficients.view())?,
8695                    None => atom.decoder_coefficients.clone(),
8696                }
8697            } else {
8698                atom.decoder_coefficients.clone()
8699            };
8700            let r = coeffs.ncols();
8701            let off = offsets[target_atom];
8702            for mu in 0..m {
8703                for channel in 0..r {
8704                    let mut acc = 0.0_f64;
8705                    for nu in 0..m {
8706                        let s_sym =
8707                            0.5 * (atom.smooth_penalty[[mu, nu]] + atom.smooth_penalty[[nu, mu]]);
8708                        acc += s_sym * coeffs[[nu, channel]];
8709                    }
8710                    beta[off + mu * r + channel] = lambda * acc;
8711                }
8712            }
8713        } else {
8714            let mut cursor = 1 + rho.log_lambda_smooth.len();
8715            for atom in 0..rho.log_ard.len() {
8716                for axis in 0..rho.log_ard[atom].len() {
8717                    if cursor == j {
8718                        let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8719                        let periods = self.assignment.coords[atom].effective_axis_periods();
8720                        for row in 0..self.n_obs() {
8721                            let row_t = self.assignment.coords[atom].row(row);
8722                            let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
8723                            let Some(pos) = sae_coord_penalty_offset(
8724                                self.last_row_layout.as_ref(),
8725                                self.assignment.coord_offsets()[atom] + axis,
8726                                row,
8727                                atom,
8728                            ) else {
8729                                continue;
8730                            };
8731                            t[cache.row_offsets[row] + pos] = prior.grad;
8732                        }
8733                        return Ok(SaeArrowVector { t, beta });
8734                    }
8735                    cursor += 1;
8736                }
8737            }
8738        }
8739        Ok(SaeArrowVector { t, beta })
8740    }
8741
8742    pub(crate) fn logdet_theta_adjoint(
8743        &self,
8744        rho: &SaeManifoldRho,
8745        cache: &ArrowFactorCache,
8746        solver: &DeflatedArrowSolver<'_>,
8747    ) -> Result<SaeArrowVector, String> {
8748        // Γ_a = tr(H⁻¹ ∂H/∂θ_a) over the inner variables θ (#1006). `H` here is
8749        // the SAME object the evidence factor builds — Gauss-Newton data
8750        // curvature plus the prior majorizers / `hessian_diag` diagonals the
8751        // Newton/Schur Cholesky factorizes — so each block's θ-derivative channel
8752        // is differentiated on the criterion's own branch (no value/gradient
8753        // desync). The IBP-MAP assignment prior is the one block whose
8754        // `hessian_diag` couples every row in a column through the plug-in
8755        // empirical mass `M_k = Σ_i z_ik`; its logit derivative therefore has a
8756        // row-local channel (handled inline via
8757        // `assignment_prior_hdiag_derivative_entry`) and a cross-row channel
8758        // (accumulated column-wise after the row loop, below).
8759        let n = self.n_obs();
8760        let total_t = cache.delta_t_len();
8761        let mut gamma_t = Array1::<f64>::zeros(total_t);
8762        let mut gamma_beta = Array1::<f64>::zeros(cache.k);
8763        let second_jets = self.atom_second_jets()?;
8764        let border = self.border_channels_for_cache(cache)?;
8765        // #932 FRONT C: plain-arrow `(H⁻¹)_ββ = S⁻¹` formed once from the cached
8766        // Schur factor; gauge / #1038 cross-row Woodbury fall back to the per-β
8767        // `solve` loop where the row-local Takahashi blocks are not valid.
8768        let fast_selected = solver.plain_selected_inverse_available();
8769        let beta_inv = if cache.k == 0 {
8770            Array2::<f64>::zeros((0, 0))
8771        } else if fast_selected {
8772            solver
8773                .beta_inv()
8774                .map_err(|err| format!("logdet_theta_adjoint: beta selected inverse: {err}"))?
8775        } else {
8776            let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8777            let rhs_t = Array1::<f64>::zeros(total_t);
8778            let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8779            for col in 0..cache.k {
8780                rhs_beta[col] = 1.0;
8781                let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8782                    format!("logdet_theta_adjoint: beta selected inverse solve: {err}")
8783                })?;
8784                rhs_beta[col] = 0.0;
8785                for row in 0..cache.k {
8786                    beta_inv[[row, col]] = solved.beta[row];
8787                }
8788            }
8789            beta_inv
8790        };
8791        // IBP `hessian_diag` logit third-derivative channels (#1006). The full
8792        // IBP Hessian also has per-column cross-row rank-one terms
8793        // `H_(i,k),(j,k) = d_k·J_ik·J_jk`; these ARE carried in `H` via the #1038
8794        // Woodbury source (`IbpCrossRowSource`, construction.rs:4710-4752), the
8795        // ρ-trace differentiates them (#1416,
8796        // `assignment_log_strength_hessian_trace`), AND this θ-adjoint now
8797        // differentiates them exactly too: the empirical-`M_k` channel below
8798        // contracts the shared-mass coupling of the DIAGONAL curvature, and the
8799        // cross-row Woodbury pass (further below, using `cross_row_dd` and
8800        // `logit_curvature`) contracts the `∂/∂ℓ_w (d_k·J_ik·J_jk)` rank-one
8801        // derivative — so value, logdet, ρ-trace, and θ-adjoint all differentiate
8802        // the one operator `H = H₀ + Σ_k d_k u_k u_kᵀ`.
8803        let ibp_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
8804        let k_atoms = self.k_atoms();
8805        // #1038 softmax entropy: the dense per-row entropy Hessian written into
8806        // `block.htt` has off-diagonal logit terms whose θ-derivative the adjoint
8807        // must contract too (not just the diagonal). Build the SAME penalty +
8808        // `scale = λ/τ²` the assembly uses so value/logdet/adjoint differentiate
8809        // one operator. `None` for non-softmax modes (their diagonal/cross-row
8810        // channels are handled by `assignment_prior_hdiag_derivative_entry` and
8811        // the IBP column pass).
8812        let softmax_dense_adjoint: Option<(
8813            gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
8814            f64,
8815        )> = match self.assignment.mode {
8816            AssignmentMode::Softmax {
8817                temperature,
8818                sparsity,
8819            } if k_atoms > 1 => {
8820                let inv_tau = 1.0 / temperature;
8821                let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
8822                Some((
8823                    gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
8824                        k_atoms,
8825                        temperature,
8826                    ),
8827                    scale,
8828                ))
8829            }
8830            _ => None,
8831        };
8832        // Per active logit position: (row i, column k, global t-index,
8833        // (H⁻¹)_ik,ik) — the inputs to the IBP cross-row empirical-`M_k` channel.
8834        let mut ibp_logit_sites: Vec<(usize, usize, usize, f64)> = Vec::new();
8835
8836        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8837        let mut assignments = Array1::<f64>::zeros(self.k_atoms());
8838        // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
8839        // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
8840        // rows fall back to the scalar per-row path (bit-identical either way).
8841        let mut jet_window: std::collections::VecDeque<SaeRowJets> =
8842            std::collections::VecDeque::new();
8843        let mut jet_window_next = 0usize;
8844        // Hoisted RHS scratch for the gauge/Woodbury per-row solve fallback.
8845        let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
8846        let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
8847        for row in 0..n {
8848            let q = cache.row_dims[row];
8849            let base = cache.row_offsets[row];
8850            let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
8851            self.assignment
8852                .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
8853            if jet_window.is_empty() {
8854                jet_window_next = self.refill_jet_window(
8855                    rho,
8856                    jet_window_next,
8857                    cache,
8858                    &second_jets,
8859                    &border,
8860                    &mut jet_window,
8861                )?;
8862            }
8863            let jets = jet_window.pop_front().expect("jet window must be non-empty");
8864
8865            // #932 FRONT C: row-local Takahashi on the plain arrow; per-row
8866            // full-system `solve` loop under gauge / cross-row Woodbury.
8867            let (inv_vv, inv_vbeta) = if fast_selected {
8868                solver
8869                    .selected_inverse_row_blocks(row, &beta_inv)
8870                    .map_err(|err| {
8871                        format!("logdet_theta_adjoint: selected inverse: {err}")
8872                    })?
8873            } else {
8874                let mut inv_vv = Array2::<f64>::zeros((q, q));
8875                let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
8876                for col in 0..q {
8877                    rhs_t_scratch[base + col] = 1.0;
8878                    let solved = solver
8879                        .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
8880                        .map_err(|err| {
8881                            format!("logdet_theta_adjoint: selected inverse solve: {err}")
8882                        })?;
8883                    rhs_t_scratch[base + col] = 0.0;
8884                    for r in 0..q {
8885                        inv_vv[[r, col]] = solved.t[base + r];
8886                    }
8887                    for b in 0..cache.k {
8888                        inv_vbeta[[col, b]] = solved.beta[b];
8889                    }
8890                }
8891                (inv_vv, inv_vbeta)
8892            };
8893
8894            // Record each active logit's column, global t-index, and
8895            // selected-inverse diagonal (H⁻¹)_ik,ik for the IBP cross-row pass.
8896            if ibp_channels.is_some() {
8897                for (pos, var) in jets.vars.iter().enumerate() {
8898                    if let SaeLocalRowVar::Logit { atom } = *var {
8899                        ibp_logit_sites.push((row, atom, base + pos, inv_vv[[pos, pos]]));
8900                    }
8901                }
8902            }
8903
8904            // #1419: when `w` is a logit and the assignment is softmax, the per-row
8905            // Gershgorin majorizer `D = diag(Σ_j|H_kj|)` is what the assembly wrote
8906            // into `htt` (the genuine Loewner majorizer that replaces the indefinite
8907            // exact entropy Hessian). Its full θ-derivative `∂D_{k,k}/∂z_w` (diagonal;
8908            // `∂D_kk/∂z_w = Σ_j sign(H_kj)·∂H_kj/∂z_w`) is the SAME operator the
8909            // assembly and logdet now differentiate, so value and adjoint stay on ONE
8910            // exact branch. Compute it once per logit `w` and add it at every logit
8911            // pair `(a,b)` below. The diagonal softmax case is therefore handled here,
8912            // NOT in `assignment_prior_hdiag_derivative_entry` (which returns 0 for
8913            // softmax to avoid double-counting).
8914            // #1410: the softmax majorizer θ-derivative `∂D_kk/∂z_w` is DIAGONAL
8915            // (`D` is diagonal), and the compact adjoint reads it only for this
8916            // row's `≤ top_k` active atoms. Compute the needed diagonal entry
8917            // directly from the softmax row `a` (= `assignments`, in hand) via
8918            // `active_softmax_majorizer_logit_derivative_entry`, instead of the old
8919            // per-(row, logit) full `K×K` `row_psd_majorizer_logit_derivative`
8920            // allocation. `m = Σ_j a_j l_j` is shared across all `(w, k)` pairs of
8921            // the row, so compute it once. `inv_tau` carries the softmax `∂a/∂z`
8922            // convention.
8923            let softmax_adjoint_row: Option<(&[f64], f64, f64, f64)> =
8924                match (softmax_dense_adjoint.as_ref(), self.assignment.mode) {
8925                    (Some((_penalty, scale)), AssignmentMode::Softmax { temperature, .. }) => {
8926                        let a = assignments
8927                            .as_slice()
8928                            .expect("softmax assignments row must be contiguous");
8929                        let m = softmax_majorizer_log_mean(a);
8930                        Some((a, m, *scale, 1.0 / temperature))
8931                    }
8932                    _ => None,
8933                };
8934            // Per-row UNIT-stiffness deflated directions: the selected inverse
8935            // `inv_vv` is the DEFLATED inverse (it assigns `1/λ̃ = 1` to each
8936            // `vᵢ`), so every `inv_vv`-weighted t–t contraction of `∂H/∂θ_w`
8937            // below spuriously contracts the RAW derivative where the re-deflating
8938            // criterion uses the deflation-map derivative `DΦ`. The kept-subspace Γ
8939            // subtracts `tr(inv_vv·(D − DΦ[D]))` over the t–t block via the same
8940            // Daleckii–Krein helper the ρ-traces use (the t–β / β–β blocks are not
8941            // deflated). `θ` enters only the per-row block (no cross-row Woodbury
8942            // self-downdate on the θ path), so the raw t–t derivative `D` is used
8943            // directly.
8944            let defl_dirs = cache
8945                .deflated_row_directions
8946                .get(row)
8947                .map(Vec::as_slice)
8948                .unwrap_or(&[]);
8949            let defl_spectrum = cache
8950                .deflation_row_spectra
8951                .get(row)
8952                .and_then(Option::as_ref);
8953            for w in 0..q {
8954                let mut gamma = 0.0_f64;
8955                // The active logit `w` differentiates against; `None` unless this
8956                // slot is a softmax logit on the softmax path.
8957                let softmax_d_dw: Option<(&[f64], f64, f64, f64, usize)> =
8958                    match (softmax_adjoint_row, jets.vars[w]) {
8959                        (Some((a, m, scale, inv_tau)), SaeLocalRowVar::Logit { atom: atom_w }) => {
8960                            Some((a, m, scale, inv_tau, atom_w))
8961                        }
8962                        _ => None,
8963                    };
8964                let mut dh_mat = Array2::<f64>::zeros((q, q));
8965                for a in 0..q {
8966                    for b in 0..q {
8967                        let mut dh = sae_dot(&jets.second[a][w], &jets.first[b])
8968                            + sae_dot(&jets.first[a], &jets.second[b][w]);
8969                        // `∂D/∂z_w` is diagonal, so it contributes only when the two
8970                        // logit slots are the SAME atom (`atom_a == atom_b`).
8971                        if let (
8972                            Some((a_soft, m, scale, inv_tau, _atom_w)),
8973                            SaeLocalRowVar::Logit { atom: atom_a },
8974                            SaeLocalRowVar::Logit { atom: atom_b },
8975                        ) = (softmax_d_dw, jets.vars[a], jets.vars[b])
8976                        {
8977                            if atom_a == atom_b {
8978                                dh += active_softmax_majorizer_logit_derivative_entry(
8979                                    a_soft, atom_a, _atom_w, m, scale, inv_tau,
8980                                );
8981                            }
8982                        }
8983                        if a == b {
8984                            dh += match jets.vars[a] {
8985                                SaeLocalRowVar::Logit { atom } => self
8986                                    .assignment_prior_hdiag_derivative_entry(
8987                                        rho,
8988                                        row,
8989                                        atom,
8990                                        jets.vars[w],
8991                                        ibp_channels.as_ref(),
8992                                    ),
8993                                SaeLocalRowVar::Coord { atom, axis } if a == w => {
8994                                    self.ard_majorized_hessian_derivative(rho, row, atom, axis)
8995                                }
8996                                _ => 0.0,
8997                            };
8998                        }
8999                        dh_mat[[a, b]] = dh;
9000                        gamma += inv_vv[[b, a]] * dh;
9001                    }
9002                }
9003                if !defl_dirs.is_empty() {
9004                    gamma -= Self::deflation_block_correction(
9005                        &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9006                    );
9007                }
9008                for a in 0..q {
9009                    for (beta_pos, channel) in border.iter().enumerate() {
9010                        let dh = sae_dot(&jets.second[a][w], &jets.beta[beta_pos])
9011                            + sae_dot(&jets.first[a], &jets.beta_deriv[w][beta_pos]);
9012                        gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9013                    }
9014                }
9015                for (beta_i, channel_i) in border.iter().enumerate() {
9016                    for (beta_j, channel_j) in border.iter().enumerate() {
9017                        let dh = sae_dot(&jets.beta_deriv[w][beta_i], &jets.beta[beta_j])
9018                            + sae_dot(&jets.beta[beta_i], &jets.beta_deriv[w][beta_j]);
9019                        gamma += beta_inv[[channel_i.index, channel_j.index]] * dh;
9020                    }
9021                }
9022                gamma_t[base + w] = gamma;
9023            }
9024
9025            for (w_beta_pos, w_channel) in border.iter().enumerate() {
9026                let mut gamma = 0.0_f64;
9027                let mut dh_mat = Array2::<f64>::zeros((q, q));
9028                for a in 0..q {
9029                    for b in 0..q {
9030                        let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.first[b])
9031                            + sae_dot(&jets.first[a], &jets.beta_l_deriv[b][w_beta_pos]);
9032                        dh_mat[[a, b]] = dh;
9033                        gamma += inv_vv[[b, a]] * dh;
9034                    }
9035                }
9036                if !defl_dirs.is_empty() {
9037                    gamma -= Self::deflation_block_correction(
9038                        &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9039                    );
9040                }
9041                for a in 0..q {
9042                    for (beta_pos, channel) in border.iter().enumerate() {
9043                        let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.beta[beta_pos]);
9044                        gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9045                    }
9046                }
9047                gamma_beta[w_channel.index] += gamma;
9048            }
9049        }
9050
9051        // IBP cross-row empirical-`M_k` channel of Γ (#1006). The assembled
9052        // diagonal H_ik consumes `hessian_diag`, whose dependence on the column
9053        // mass M_k = Σ_i z_ik couples every row in a column. Differentiating
9054        // tr(H⁻¹ ∂H/∂ℓ_wk) on that shared branch:
9055        //   Γ_wk += [ Σ_i (H⁻¹)_ik,ik · ∂_M H_ik ] · J_wk = C_k · J_wk,
9056        // where ∂_M H_ik = `m_channel[i*K+k]` and J_wk = `z_jac[w*K+k]`. The
9057        // row-local direct-`z` channel was already added inline above, so this
9058        // pass adds only the cross-row remainder (it spans `w ≠ i` and the
9059        // self-row M_k self-coupling, which the row-local primitive deliberately
9060        // omits to avoid double-counting).
9061        if let Some(channels) = ibp_channels.as_ref() {
9062            let mut col_coeff = vec![0.0_f64; k_atoms];
9063            for &(row, atom, _t_index, inv_diag) in &ibp_logit_sites {
9064                col_coeff[atom] += inv_diag * channels.m_channel[row * k_atoms + atom];
9065            }
9066            for &(row, atom, t_index, _inv_diag) in &ibp_logit_sites {
9067                gamma_t[t_index] += col_coeff[atom] * channels.z_jac[row * k_atoms + atom];
9068            }
9069
9070            // #1416 / #1641: the EXACT cross-row Woodbury derivative of Γ. The
9071            // assembled `H` carries the per-column rank-one block
9072            // `W_k = d_k·u_k u_kᵀ` with `u_k` the J-weighted column indicator
9073            // (`u_k[slot(i,k)] = J_ik`) and `d_k = w·s'_k` (`cross_row_d[k]`). Both
9074            // `d_k` (through `M_k`) and the `u_k` entries (through `ℓ_ik`) depend on
9075            // the logits, so
9076            //   ∂W_k/∂ℓ_wk = dd_k·J_wk·u_k u_kᵀ
9077            //               + d_k·c_wk·(e_w u_kᵀ + u_k e_wᵀ),
9078            // where `dd_k = ∂d_k/∂M_k = w·s''_k` (`cross_row_dd[k]`),
9079            // `c_wk = ∂J_wk/∂ℓ_wk` (`logit_curvature`), and `e_w` is the unit
9080            // vector at row `w`'s logit-`k` slot.
9081            //
9082            // The θ-adjoint contracts the FULL trace `Γ_wk = tr(H⁻¹ ∂H/∂ℓ_wk)`
9083            // (NOT the `½ tr` the ρ-trace uses — `fixed_state_logdet` differentiates
9084            // the full `log|H|`, and the per-row blocks above contract `inv_vv·dh`
9085            // with no ½). Critically, the `i=j` self curvature `w·s'_k·J_ik²` of the
9086            // rank-one block lives on the assembled `htt` DIAGONAL `H_ik`, so its
9087            // derivative is ALREADY differentiated by the row-local
9088            // `local_logit_third` channel (direct-z, `i=w`) and the `m_channel`
9089            // column pass (via `M_k`) above. This Woodbury pass must therefore add
9090            // ONLY the off-diagonal `i≠j` remainder — otherwise the self term is
9091            // double-counted (the #1641 defect: the pre-fix pass summed the full
9092            // `u_k u_kᵀ` including `i=j`, AND carried the ρ-trace ½, AND dropped the
9093            // factor 2 on the symmetric `e_w u_kᵀ + u_k e_wᵀ` term). Excluding `i=j`
9094            // is also why this pass needs no deflation correction: it contracts only
9095            // DISTINCT rows, off any single-row `vᵢ`'s support (matching the
9096            // #1416 ρ-trace cross-row pass).
9097            //
9098            // Contracting `tr(H⁻¹ ∂W_k/∂ℓ_wk)` over `i≠j` only:
9099            //   Γ_wk += dd_k·J_wk·( u_kᵀ H⁻¹ u_k − Σ_i P_ii·J_ik² )       (term A)
9100            //         + 2·d_k·c_wk·( (H⁻¹ u_k)_{slot(w,k)} − P_ww·J_wk )  (term B),
9101            // where `P_ii = (H⁻¹)_{slot(i,k),slot(i,k)}` is the selected-inverse
9102            // diagonal recorded in `ibp_logit_sites`. The subtracted self pieces are
9103            // exactly the `i=j` terms the diagonal channels own. Both `u_kᵀ H⁻¹ u_k`
9104            // and `(H⁻¹ u_k)` come from ONE solve per column, `x_k = H⁻¹ u_k` — so
9105            // the adjoint differentiates the SAME `H = H₀ + Σ_k W_k` the
9106            // value/logdet use, closing the one-operator contract on the rank-one
9107            // block too.
9108            //
9109            // Group the column sites once (the layout is mode-agnostic: dense or
9110            // compact, `ibp_logit_sites` already carries each active logit's
9111            // global t-index AND its selected-inverse diagonal `G_ii`), then per
9112            // column build `u_k`, solve, and distribute the OFF-DIAGONAL remainder.
9113            //
9114            // #1416 FIX: the diagonal (`i = w`) parts of term A and term B are
9115            // ALREADY supplied — `diag(term A) = dd_k·J_w·Σ_i G_ii·J_i²` by the
9116            // `m_channel` column pass above (whose `m_channel = w·(s''·J² + s'·c)`
9117            // carries the `s''·J²` self piece), and `diag(term B) = 2·d_k·c_w·G_ww·J_w`
9118            // by the inline `local_logit_third` self channel (whose
9119            // `s'·2J·∂_z J` piece is exactly that). So this pass must add ONLY the
9120            // cross-row off-diagonal remainder; double-counting the diagonal here
9121            // (the pre-fix `0.5·dd·J·uᵀGu + d·c·x_w` form, which is neither the
9122            // full nor the off-diagonal value) desynced the θ-adjoint from the FD
9123            // of `log|H|`. The exact `tr(H⁻¹ ∂W_k/∂ℓ_wk)` is
9124            //   Γ_wk += dd_k·J_wk·(uᵀ G u − Σ_i G_ii·J_ik²)   (term A, off-diagonal)
9125            //         + 2·d_k·c_wk·((G u)_w − G_ww·J_wk)        (term B, off-diagonal),
9126            // with `uᵀGu = Σ_i J_ik·(Gu)_i`, `(Gu) = x_k = H⁻¹ u_k` from one solve,
9127            // and `G_ii` the per-site selected-inverse diagonal.
9128            let total_t = cache.delta_t_len();
9129            let mut col_sites: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); k_atoms];
9130            for &(row, atom, t_index, inv_diag) in &ibp_logit_sites {
9131                col_sites[atom].push((row, t_index, inv_diag));
9132            }
9133            // Hoisted RHS scratch: fill only this column's active slots, solve,
9134            // then clear exactly those slots — no per-column total_t zeroing.
9135            let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
9136            let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
9137            for atom in 0..k_atoms {
9138                let d_k = channels.cross_row_d[atom];
9139                let dd_k = channels.cross_row_dd[atom];
9140                if col_sites[atom].is_empty() || (d_k == 0.0 && dd_k == 0.0) {
9141                    continue;
9142                }
9143                // u_k as a full t-RHS: J at each active logit-k slot.
9144                for &(row, t_index, _g) in &col_sites[atom] {
9145                    rhs_t_scratch[t_index] = channels.z_jac[row * k_atoms + atom];
9146                }
9147                let x_k = solver
9148                    .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
9149                    .map_err(|err| {
9150                        format!("logdet_theta_adjoint: IBP cross-row Woodbury solve: {err}")
9151                    })?;
9152                // Clear this column's active slots for the next atom's RHS.
9153                for &(_row, t_index, _g) in &col_sites[atom] {
9154                    rhs_t_scratch[t_index] = 0.0;
9155                }
9156                // (JᵀH⁻¹J)_k = u_kᵀ x_k, and the diagonal `Σ_i G_ii·J_ik²` that the
9157                // `m_channel` pass already counted (subtract it from term A so this
9158                // pass holds only the off-diagonal `i ≠ j` remainder).
9159                let mut jt_hinv_j = 0.0_f64;
9160                let mut diag_jt_g_j = 0.0_f64;
9161                for &(row, t_index, g_ii) in &col_sites[atom] {
9162                    let j = channels.z_jac[row * k_atoms + atom];
9163                    jt_hinv_j += j * x_k.t[t_index];
9164                    diag_jt_g_j += g_ii * j * j;
9165                }
9166                let off_diag_a = jt_hinv_j - diag_jt_g_j;
9167                for &(row, t_index, g_ii) in &col_sites[atom] {
9168                    let j_wk = channels.z_jac[row * k_atoms + atom];
9169                    let c_wk = channels.logit_curvature[row * k_atoms + atom];
9170                    // term A (off-diagonal) + term B (off-diagonal); the inline /
9171                    // `m_channel` passes already added the diagonal parts.
9172                    let off_diag_b = x_k.t[t_index] - g_ii * j_wk;
9173                    gamma_t[t_index] += dd_k * j_wk * off_diag_a + 2.0 * d_k * c_wk * off_diag_b;
9174                }
9175            }
9176        }
9177
9178        Ok(SaeArrowVector {
9179            t: gamma_t,
9180            beta: gamma_beta,
9181        })
9182    }
9183
9184
9185    /// Public analytic outer-ρ gradient at a converged inner state, constructing
9186    /// the deflated arrow solver from the supplied cache. Use this seam from
9187    /// integration tests and external consumers that have a converged
9188    /// `(loss, cache)` from [`Self::reml_criterion_with_cache`] but no access to
9189    /// the crate-private `DeflatedArrowSolver`.
9190    pub fn analytic_outer_rho_gradient_at_converged(
9191        &self,
9192        target: ArrayView2<'_, f64>,
9193        rho: &SaeManifoldRho,
9194        loss: &SaeManifoldLoss,
9195        cache: &ArrowFactorCache,
9196    ) -> Result<SaeOuterRhoGradientComponents, String> {
9197        let solver = self.outer_gradient_arrow_solver(cache, &rho.lambda_smooth_vec())?;
9198        self.analytic_outer_rho_gradient_components(target, rho, loss, cache, &solver)
9199            .map_err(|e| e.to_string())
9200    }
9201
9202    /// Compose the SAE LAML criterion as a sum of atoms (#931 SAE pilot).
9203    ///
9204    /// This is the single seam that establishes value↔gradient coherence for
9205    /// the SAE objective: it runs the inner solve once via
9206    /// [`Self::reml_criterion_with_cache`], reads the value decomposition
9207    /// (`loss.total() + extra_penalty_energy`, `log|H|`, `occam`) and the
9208    /// matching gradient channels (`SaeOuterRhoGradientComponents`) from the
9209    /// SAME converged cache, and hands them to [`SaeCriterion::assemble`]. The
9210    /// returned criterion's [`SaeCriterion::value`] and
9211    /// [`SaeCriterion::gradient`] are then projections of one factorization —
9212    /// the outer optimizer can no longer evaluate a value path and a gradient
9213    /// path that disagree (the #752/#748/#901 desync class). The
9214    /// implicit-stationarity envelope correction (#1006's Γ term) is its own
9215    /// named atom, so the channel the desync class keeps dropping is visible
9216    /// rather than a silent zero.
9217    pub fn criterion_as_atoms(
9218        &mut self,
9219        target: ArrayView2<'_, f64>,
9220        rho: &SaeManifoldRho,
9221        registry: Option<&AnalyticPenaltyRegistry>,
9222        inner_max_iter: usize,
9223        learning_rate: f64,
9224        ridge_ext_coord: f64,
9225        ridge_beta: f64,
9226    ) -> Result<SaeCriterion, String> {
9227        let (_v, loss, cache) = self.reml_criterion_with_cache(
9228            target,
9229            rho,
9230            registry,
9231            inner_max_iter,
9232            learning_rate,
9233            ridge_ext_coord,
9234            ridge_beta,
9235        )?;
9236        let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
9237            "criterion_as_atoms: arrow_log_det_from_cache returned None".to_string()
9238        })?;
9239        let occam = self.reml_occam_term(rho)?;
9240        let extra_penalty_energy = match registry {
9241            Some(reg) => self
9242                .reml_extra_penalty_value_total(reg)
9243                .map_err(|err| format!("SaeManifoldTerm::criterion_as_atoms: {err}"))?,
9244            None => 0.0,
9245        };
9246        let data_fit_priors_value = loss.total() + extra_penalty_energy;
9247
9248        let solver = self.outer_gradient_arrow_solver(&cache, &rho.lambda_smooth_vec())?;
9249        let components =
9250            self.analytic_outer_rho_gradient_components(target, rho, &loss, &cache, &solver)?;
9251        Ok(SaeCriterion::assemble(
9252            data_fit_priors_value,
9253            log_det,
9254            occam,
9255            components.explicit,
9256            components.logdet_trace,
9257            components.occam,
9258            components.third_order_correction,
9259        ))
9260    }
9261
9262    // [#780 line-count gate] reconstruction_dispersion + assemble_shape_uncertainty
9263    // + complete_born_atom_shape_bands + shape_uncertainty_without_decoder_covariance
9264    // (the contiguous trailing methods of this impl block) were split into the
9265    // sibling construction_reconstruction.rs (declared in mod.rs); callers reach
9266    // them bare via use super::*.
9267}
9268
9269// [#780 line-count gate] Per-row jet / reconstruction-channel assembly for the
9270// streaming-exact arrow log-det lives in a sibling file as a second
9271// `impl SaeManifoldTerm` block, inlined here so it keeps the SAME module scope
9272// and private-field access. Keeps this tracked file under the 10k limit.
9273include!("construction_row_jet_logdet_channels.rs");
9274
9275// [#780 line-count gate] Massive-K decoder-smoothness effective-dof Hutchinson
9276// estimator (associated constants + the matrix-free per-atom trace) lives in a
9277// sibling file as another `impl SaeManifoldTerm` block, inlined here so it keeps
9278// the SAME module scope and private-field access. The two gated exact/estimator
9279// entry points above dispatch into it at `K >= MIN_ATOMS`.
9280include!("construction_smoothness_dof.rs");
9281
9282// [#780 line-count gate] `term_from_padded_blocks_with_mode` (the padded-FFI
9283// term builder) was split into the sibling `construction_padded_blocks.rs`
9284// module (declared and re-exported from `mod.rs`), keeping this tracked file
9285// under the 10k limit. Callers still reach it bare through `use super::*`.
9286
9287// [#780 line-count gate] `refresh_isometry_caches_from_atom` and
9288// `refresh_isometry_caches_from_term` were split into the sibling
9289// `construction_cache_refresh.rs` module (declared and re-exported from
9290// `mod.rs`), keeping this tracked file under the 10k limit. Callers still reach
9291// both functions bare through `use super::*`.
9292
9293// [#780 line-count gate] The `#[cfg(test)]` modules below the production code
9294// are mechanically split into a sibling `*_tests` file and inlined via
9295// `include!` (the sanctioned cohesive-module decomposition — see build.rs
9296// file_stem_is_exempt_test_module). Keeps this tracked file under the 10k limit.
9297include!("construction_tests.rs");