gam_sae/manifold/construction.rs
1use super::*;
2use gam_math::jet_scalar::JetScalar;
3
4// [#780] Softmax-entropy Gershgorin majorizer leaf helpers live in a sibling
5// cohesive module, inlined here so they share this module scope.
6include!("softmax_entropy_majorizer.rs");
7
8// [#780] The exact stationarity-Jacobian correction and exact-Hessian solve
9// methods live in a sibling file, inlined here so they share this `impl
10// SaeManifoldTerm` / module scope while keeping this file under the line-count
11// gate.
12include!("construction_exact_hessian.rs");
13
14// [#780] The outer-gradient error taxonomy (`OuterGradientError`), the
15// `ForcedRowLayout` override alias, the `COTRAIN_*` co-training weight
16// constants, and the `AmortizedEncoderConsistency` report were extracted
17// verbatim into the sibling `construction_aux_types` module to keep this file
18// under the per-file line-count gate. They re-enter this module's scope via the
19// parent's glob re-export (`use super::*;` above).
20
21impl SaeManifoldTerm {
22 #[must_use = "build error must be handled"]
23 pub fn new(atoms: Vec<SaeManifoldAtom>, assignment: SaeAssignment) -> Result<Self, String> {
24 if atoms.is_empty() {
25 return Err("SaeManifoldTerm::new: at least one atom required".into());
26 }
27 let n = atoms[0].n_obs();
28 let p = atoms[0].output_dim();
29 if assignment.n_obs() != n || assignment.k_atoms() != atoms.len() {
30 return Err(format!(
31 "SaeManifoldTerm::new: assignment shape ({}, {}) does not match atoms ({n}, {})",
32 assignment.n_obs(),
33 assignment.k_atoms(),
34 atoms.len()
35 ));
36 }
37 for (k, atom) in atoms.iter().enumerate() {
38 if atom.n_obs() != n {
39 return Err(format!(
40 "SaeManifoldTerm::new: atom {k} has n_obs={} but atom 0 has {n}",
41 atom.n_obs()
42 ));
43 }
44 if atom.output_dim() != p {
45 return Err(format!(
46 "SaeManifoldTerm::new: atom {k} output_dim={} but atom 0 has {p}",
47 atom.output_dim()
48 ));
49 }
50 if atom.latent_dim != assignment.coords[k].latent_dim() {
51 return Err(format!(
52 "SaeManifoldTerm::new: atom {k} latent_dim={} but assignment coord has {}",
53 atom.latent_dim,
54 assignment.coords[k].latent_dim()
55 ));
56 }
57 }
58 Ok(Self {
59 atoms,
60 assignment,
61 temperature_schedule: None,
62 last_row_layout: None,
63 row_metric: None,
64 collapse_events: Vec::new(),
65 row_loss_weights: None,
66 last_frames_active: false,
67 assembly_chunk_override: None,
68 fixed_decoder_assembly: false,
69 softmax_active_cap: None,
70 border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
71 certificate_dispersion: None,
72 curvature_walk_report: None,
73 expected_evidence_gauge_deflated_directions: None,
74 evidence_gauge_deflation_reanchors: 0,
75 evidence_gauge_deflation_last_delta_sign: 0,
76 dictionary_cocollapse_reseeds: 0,
77 best_cocollapse_incumbent: None,
78 decoder_repulsion_gate: None,
79 barrier_coactivation_gate: None,
80 hybrid_split_report: None,
81 atom_inner_fits: None,
82 oos_linear_images: None,
83 separation_barrier_strength_override: None,
84 })
85 }
86
87 /// #1777 — apply the PER-FIT configuration overrides (the FFI-facing
88 /// [`SaeFitConfig`]) as the source of truth for this term's fit, isolating it
89 /// from the deprecated process-global barrier/α atomics.
90 ///
91 /// Distributes the config to its two authorities: the barrier strength override
92 /// onto the term (read by `separation_barrier_strength`), and the IBP-α
93 /// override onto the assignment (read by
94 /// [`SaeAssignment::resolved_ibp_alpha`]). Any `None` field leaves that axis on
95 /// its historical fallback (process-global override, then the
96 /// data-derived/mode default), so an all-`None` config is a strict no-op. Call
97 /// this after building the term (before the fit) so concurrent fits carrying
98 /// distinct configs stay isolated without any global writes.
99 pub fn set_fit_config(&mut self, config: SaeFitConfig) {
100 self.separation_barrier_strength_override = config.separation_barrier_strength_override;
101 self.assignment
102 .set_ibp_alpha_override(config.ibp_alpha_override);
103 }
104
105 /// #1777 — the per-fit configuration currently in force on this term,
106 /// reconstructed from its two authorities (the term's barrier override and the
107 /// assignment's α override). Round-trips with [`Self::set_fit_config`].
108 #[must_use]
109 pub fn fit_config(&self) -> SaeFitConfig {
110 SaeFitConfig {
111 separation_barrier_strength_override: self.separation_barrier_strength_override,
112 ibp_alpha_override: self.assignment.ibp_alpha_override,
113 }
114 }
115
116 /// #1408/#1409 — install the optional hard per-row active-atom cap for
117 /// Softmax mode (threaded from the fit/encode `top_k`). A `Some(k)` with
118 /// `1 <= k < K` makes the Softmax assignment optimize on the COMPACT
119 /// top-`k` row layout (see [`Self::softmax_active_cap`]); `Some(k) >= K`
120 /// and `None` are both no-ops (full support). Non-softmax modes ignore it.
121 pub fn set_softmax_active_cap(&mut self, top_k: Option<usize>) {
122 self.softmax_active_cap = match top_k {
123 Some(k) if k >= 1 && k < self.k_atoms() => Some(k),
124 _ => None,
125 };
126 }
127
128 /// Install the fitted reconstruction dispersion used by
129 /// [`dictionary_incoherence_report`]. This is a pure diagnostic scalar and
130 /// does not feed any loss, criterion, penalty, or optimizer state.
131 pub fn set_certificate_dispersion(&mut self, dispersion: f64) -> Result<(), String> {
132 if !dispersion.is_finite() || dispersion <= 0.0 {
133 return Err(format!(
134 "SaeManifoldTerm::set_certificate_dispersion: dispersion must be finite and positive, got {dispersion}"
135 ));
136 }
137 self.certificate_dispersion = Some(dispersion);
138 Ok(())
139 }
140
141 /// Harvest the per-atom inner-decoder-smooth byproducts (#1097 / #1103) the
142 /// residual-gauge certificate's post-PIRLS atom inference reports consume.
143 ///
144 /// This is the post-fit harness seam: it needs the reconstruction target `Z`
145 /// (`target`) and the fitted dispersion `φ` (`dispersion`), both available
146 /// only after the joint fit converges and the engine has discarded `Z` from
147 /// the objective. For each atom `k` it captures the Gaussian-identity
148 /// penalized smooth of the atom's leading decoder output channel `j`
149 /// (largest column 2-norm of `B_k`) against its partial residual
150 /// `e_{i} = z_i − fitted_i + a_{ik} g_k(t_i)` on channel `j`, holding all
151 /// other atoms and the assignment fixed at the fitted optimum — exactly the
152 /// fixed snapshot ([`crate::identifiability::AtomInnerFit`]) the Riesz
153 /// debiasing and split-LRT smooth-structure e-value read.
154 ///
155 /// A pure read of the fitted state: it mutates only the diagnostic
156 /// `atom_inner_fits` field, never a loss / criterion / penalty / optimizer
157 /// state. Atoms with no active rows or a degenerate (rank-deficient,
158 /// non-SPD) inner Hessian get a `None` slot — the genuine prerequisite (an
159 /// SPD penalized inner Hessian on a non-empty active set) is absent there.
160 pub fn set_atom_inner_fits(
161 &mut self,
162 target: ArrayView2<'_, f64>,
163 rho: &SaeManifoldRho,
164 dispersion: f64,
165 ) -> Result<(), String> {
166 if !dispersion.is_finite() || dispersion <= 0.0 {
167 return Err(format!(
168 "SaeManifoldTerm::set_atom_inner_fits: dispersion must be finite and positive, got {dispersion}"
169 ));
170 }
171 let n = self.n_obs();
172 let p = self.output_dim();
173 let k_atoms = self.k_atoms();
174 if target.dim() != (n, p) {
175 return Err(format!(
176 "SaeManifoldTerm::set_atom_inner_fits: target {:?} != ({n}, {p})",
177 target.dim()
178 ));
179 }
180
181 // #1026 — `atom_inner_fits` is a pure diagnostic; skip its dense (N×K×P)
182 // tensor (~256 GiB at K=32768,P=32) past a cell ceiling — all-None slots,
183 // never OOM. The fit is unaffected; only this audit field is absent.
184 if n.saturating_mul(k_atoms).saturating_mul(p) > 64_000_000 {
185 self.atom_inner_fits = Some((0..k_atoms).map(|_| None).collect());
186 return Ok(());
187 }
188
189 // Settled per-row assignments and per-(row, atom) decoded outputs, so the
190 // per-atom partial residual is `e_k = (z − fitted) + a_k decoded_k`.
191 let mut assignments = Vec::with_capacity(n);
192 for row in 0..n {
193 assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
194 }
195 let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
196 let mut dbuf = vec![0.0_f64; p];
197 for row in 0..n {
198 for atom_idx in 0..k_atoms {
199 self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
200 for c in 0..p {
201 decoded[[row, atom_idx, c]] = dbuf[c];
202 }
203 }
204 }
205 let mut fitted = Array2::<f64>::zeros((n, p));
206 for row in 0..n {
207 for atom_idx in 0..k_atoms {
208 let a = assignments[row][atom_idx];
209 if a == 0.0 {
210 continue;
211 }
212 for c in 0..p {
213 fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
214 }
215 }
216 }
217
218 let mut inner_fits: Vec<Option<crate::identifiability::AtomInnerFit>> =
219 Vec::with_capacity(k_atoms);
220 for atom_idx in 0..k_atoms {
221 inner_fits.push(self.build_atom_inner_fit(
222 atom_idx,
223 target,
224 &assignments,
225 decoded.view(),
226 fitted.view(),
227 dispersion,
228 )?);
229 }
230 self.atom_inner_fits = Some(inner_fits);
231 Ok(())
232 }
233
234 /// Build one atom's fixed inner-smooth snapshot for the post-PIRLS atom
235 /// inference reports, or `None` when the atom has no active rows or the
236 /// penalized inner Hessian is not SPD. Returns `Err` only on a structural
237 /// inconsistency (shape mismatch), never on a benign degenerate atom.
238 pub(crate) fn build_atom_inner_fit(
239 &self,
240 atom_idx: usize,
241 target: ArrayView2<'_, f64>,
242 assignments: &[Array1<f64>],
243 decoded: ArrayView3<'_, f64>,
244 fitted: ArrayView2<'_, f64>,
245 dispersion: f64,
246 ) -> Result<Option<crate::identifiability::AtomInnerFit>, String> {
247 let atom = &self.atoms[atom_idx];
248 let n = atom.n_obs();
249 let m = atom.basis_size();
250 let p = atom.output_dim();
251 if m == 0 || p == 0 {
252 return Ok(None);
253 }
254
255 // Leading decoder output channel j = argmax_j ‖B_k[:, j]‖, the channel
256 // that carries the atom's signal.
257 let mut j_lead = 0usize;
258 let mut best_norm = -1.0_f64;
259 for col in 0..p {
260 let mut norm = 0.0_f64;
261 for r in 0..m {
262 let v = atom.decoder_coefficients[[r, col]];
263 norm += v * v;
264 }
265 if norm > best_norm {
266 best_norm = norm;
267 j_lead = col;
268 }
269 }
270 let beta = atom.decoder_coefficients.column(j_lead).to_owned();
271
272 // Active rows: a_{ik} > 0.
273 let active: Vec<usize> = (0..n)
274 .filter(|&row| assignments[row][atom_idx] > 0.0)
275 .collect();
276 let n_active = active.len();
277 // The penalized smooth needs at least as many active rows as it has
278 // basis columns to give a non-degenerate data Gram; below that the inner
279 // fit's SPD prerequisite is genuinely unmet.
280 if n_active == 0 {
281 return Ok(None);
282 }
283
284 let mut design = Array2::<f64>::zeros((n_active, m));
285 let mut derivative_design = Array2::<f64>::zeros((n_active, m));
286 let mut row_scores = Array2::<f64>::zeros((n_active, m));
287 let mut weights = Array1::<f64>::zeros(n_active);
288 for (slot, &row) in active.iter().enumerate() {
289 let a_ik = assignments[row][atom_idx];
290 let w_i = a_ik * a_ik;
291 weights[slot] = w_i;
292 for col in 0..m {
293 design[[slot, col]] = atom.basis_values[[row, col]];
294 // Leading latent axis (axis 0) is the atom's primary coordinate;
295 // it is the one the average-derivative functional integrates.
296 derivative_design[[slot, col]] = atom.basis_jacobian[[row, col, 0]];
297 }
298 // Partial residual on channel j, then the inner-smooth working
299 // response z_i = e_i / a_ik so that w_i (z_i − Φᵀβ) = a_ik r_i.
300 let e_i = target[[row, j_lead]] - fitted[[row, j_lead]]
301 + a_ik * decoded[[row, atom_idx, j_lead]];
302 let mu_hat = design.row(slot).dot(&beta);
303 let z_i = e_i / a_ik;
304 let res_i = z_i - mu_hat;
305 // Gaussian-identity score s_i = −w_i res_i Φ_i / φ.
306 let scale = -w_i * res_i / dispersion;
307 for col in 0..m {
308 row_scores[[slot, col]] = scale * design[[slot, col]];
309 }
310 }
311
312 // Penalized inner Hessian H = ΦᵀWΦ + S̃_k.
313 let mut xtwx = Array2::<f64>::zeros((m, m));
314 for slot in 0..n_active {
315 let w_i = weights[slot];
316 for a in 0..m {
317 let xa = design[[slot, a]];
318 if xa == 0.0 {
319 continue;
320 }
321 for b in 0..m {
322 xtwx[[a, b]] += w_i * xa * design[[slot, b]];
323 }
324 }
325 }
326 let penalty = atom.smooth_penalty.clone();
327 if penalty.dim() != (m, m) {
328 return Err(format!(
329 "build_atom_inner_fit: atom {atom_idx} smooth penalty {:?} != ({m}, {m})",
330 penalty.dim()
331 ));
332 }
333 let penalized_hessian = &xtwx + &penalty;
334
335 // SPD prerequisite: the inner penalized Hessian must factor, else the
336 // atom's inner-smooth fit is degenerate and no report is producible.
337 if penalized_hessian.cholesky(Side::Lower).is_err() {
338 return Ok(None);
339 }
340
341 // Peak (largest fitted |g_k| on channel j) and mode (largest assignment
342 // mass) design rows, over the active set.
343 let mut peak_slot = 0usize;
344 let mut peak_val = -1.0_f64;
345 let mut mode_slot = 0usize;
346 let mut mode_mass = -1.0_f64;
347 for (slot, &row) in active.iter().enumerate() {
348 let g_val = design.row(slot).dot(&beta).abs();
349 if g_val > peak_val {
350 peak_val = g_val;
351 peak_slot = slot;
352 }
353 let mass = assignments[row][atom_idx];
354 if mass > mode_mass {
355 mode_mass = mass;
356 mode_slot = slot;
357 }
358 }
359 let peak_design_row = design.row(peak_slot).to_owned();
360 let mode_design_row = design.row(mode_slot).to_owned();
361
362 Ok(Some(crate::identifiability::AtomInnerFit {
363 design,
364 derivative_design,
365 beta,
366 penalty,
367 penalized_hessian,
368 row_scores,
369 weights,
370 dispersion,
371 peak_design_row,
372 mode_design_row,
373 }))
374 }
375
376 /// Profile the Gaussian reconstruction dispersion at the current seed
377 /// state. This is the scale used to make SAE penalty seeds dimensionless
378 /// before the outer rho search starts.
379 pub fn seed_reconstruction_dispersion(
380 &self,
381 target: ArrayView2<'_, f64>,
382 ) -> Result<f64, String> {
383 let fitted = self.try_fitted()?;
384 if fitted.dim() != target.dim() {
385 return Err(format!(
386 "SaeManifoldTerm::seed_reconstruction_dispersion: fitted {:?} != target {:?}",
387 fitted.dim(),
388 target.dim()
389 ));
390 }
391 let n_scalar = (target.nrows() * target.ncols()).max(1) as f64;
392 let mut rss = 0.0_f64;
393 for row in 0..target.nrows() {
394 for col in 0..target.ncols() {
395 let r = target[[row, col]] - fitted[[row, col]];
396 rss += r * r;
397 }
398 }
399 if !rss.is_finite() || rss < 0.0 {
400 return Err(format!(
401 "SaeManifoldTerm::seed_reconstruction_dispersion: non-finite seed RSS {rss}"
402 ));
403 }
404 Ok((rss / n_scalar).max(SAE_SEED_DISPERSION_FLOOR))
405 }
406
407 /// Install per-row design honesty weights (#991) — the `1/π` inclusion
408 /// corrections of a designed corpus subsample (see the field docs on
409 /// `row_loss_weights` for exactly where they enter the objective).
410 ///
411 /// Weights must be finite and strictly positive, one per term row. They
412 /// are self-normalized to mean `1.0` here (only the *relative* design
413 /// correction matters at the fitted sample size; the absolute `n/budget`
414 /// scale would silently inflate the dispersion estimate against the
415 /// sample-sized dof). Weights that are identically equal after
416 /// normalization (an exact full pass, or any uniform design) are stored
417 /// as `None`, so the unweighted path stays bit-for-bit identical rather
418 /// than "multiplied by 1.0".
419 pub fn set_row_loss_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
420 if weights.len() != self.n_obs() {
421 return Err(format!(
422 "SaeManifoldTerm::set_row_loss_weights: {} weights for {} rows",
423 weights.len(),
424 self.n_obs()
425 ));
426 }
427 if weights.is_empty() {
428 self.row_loss_weights = None;
429 return Ok(());
430 }
431 if !weights.iter().all(|w| w.is_finite() && *w > 0.0) {
432 return Err(
433 "SaeManifoldTerm::set_row_loss_weights: weights must be finite and strictly \
434 positive"
435 .to_string(),
436 );
437 }
438 let first = weights[0];
439 if weights.iter().all(|w| *w == first) {
440 // Uniform design (full pass, or flat measure): the normalized
441 // weight is exactly 1 everywhere — take the unweighted path.
442 self.row_loss_weights = None;
443 return Ok(());
444 }
445 let mean = weights.iter().sum::<f64>() / weights.len() as f64;
446 self.row_loss_weights = Some(weights.into_iter().map(|w| w / mean).collect());
447 Ok(())
448 }
449
450 /// The installed (mean-1 normalized) design honesty weights, `None` on the
451 /// exact unweighted path.
452 pub fn row_loss_weights(&self) -> Option<&[f64]> {
453 self.row_loss_weights.as_deref()
454 }
455
456 /// Drop any installed per-row reconstruction weights, returning the term to
457 /// the exact unweighted (full-pass) path. Used by the #997 structure-search
458 /// wiring to clear the internal estimation/evaluation mask off the adopted
459 /// term before the payload reconstruction is read over all rows.
460 pub fn clear_row_loss_weights(&mut self) {
461 self.row_loss_weights = None;
462 }
463
464 /// Huber-style OUTLIER-ROBUST per-row weights from the target activation
465 /// norms — the missing default *policy* for the existing
466 /// [`set_row_loss_weights`](Self::set_row_loss_weights) mechanism.
467 ///
468 /// The SAE fits unweighted least squares, which weights each token by its
469 /// squared residual ∝ `‖z_i‖²`. On real LLM residual streams the per-token
470 /// norm distribution is heavy-tailed (e.g. an OLMo mixed-layer slice has
471 /// `p99/median ≈ 4.7`), so a small **coherent** cluster of high-norm tokens —
472 /// typically special / attention-sink tokens, not semantic content —
473 /// dominates the objective (measured: the top 5% of tokens carry ~31% of the
474 /// total `‖z‖²` budget) and pulls dictionary atoms toward their direction.
475 /// Mean-centering does NOT address this (it is per-feature, not per-token).
476 ///
477 /// This returns Huber weights `w_i = min(1, δ·m / ‖z_i‖)` where `m` is the
478 /// MEDIAN token norm: tokens at or below `δ·m` keep full weight, higher-norm
479 /// tokens are downweighted so their objective share grows only LINEARLY (not
480 /// quadratically) with norm. `δ` is the robustness knob (`δ=1` thresholds at
481 /// the median; larger `δ` only touches the extreme tail). The result is
482 /// mean-normalized (overall objective scale preserved). OPT-IN: the caller
483 /// installs it via `set_row_loss_weights` — the default fit is unchanged.
484 pub fn robust_norm_row_weights(
485 target: ArrayView2<'_, f64>,
486 delta: f64,
487 ) -> Result<Vec<f64>, String> {
488 if !(delta.is_finite() && delta > 0.0) {
489 return Err(format!(
490 "robust_norm_row_weights: delta must be finite and positive; got {delta}"
491 ));
492 }
493 let n = target.nrows();
494 if n == 0 {
495 return Ok(Vec::new());
496 }
497 let norms: Vec<f64> = (0..n)
498 .map(|i| {
499 let r = target.row(i);
500 r.dot(&r).sqrt()
501 })
502 .collect();
503 let mut sorted = norms.clone();
504 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
505 // Median token norm (lower-median for even n; floored off zero so an
506 // all-zero/degenerate slice yields uniform weights instead of NaN).
507 let median = sorted[n / 2].max(f64::MIN_POSITIVE);
508 let thresh = delta * median;
509 let raw: Vec<f64> = norms
510 .iter()
511 .map(|&nm| if nm <= thresh { 1.0 } else { thresh / nm })
512 .collect();
513 let mean = raw.iter().sum::<f64>() / n as f64;
514 if !(mean.is_finite() && mean > 0.0) {
515 return Err("robust_norm_row_weights: degenerate weight normalizer".to_string());
516 }
517 Ok(raw.into_iter().map(|w| w / mean).collect())
518 }
519
520 /// Install the single per-row [`RowMetric`](gam_problem::RowMetric)
521 /// that both the reconstruction likelihood and the isometry gauge read.
522 /// Installing per-row output-Fisher factors here flips the provenance to
523 /// `OutputFisher` *and* is the only way the gauge acquires a non-identity
524 /// weight, so the two inner products cannot diverge. Passing a Euclidean
525 /// metric (or never calling this) keeps the bit-identical isotropic path.
526 ///
527 /// The metric's row count and output dimension must match the term.
528 pub fn set_row_metric(
529 &mut self,
530 metric: gam_problem::RowMetric,
531 ) -> Result<(), String> {
532 if metric.n_rows() != self.n_obs() {
533 return Err(format!(
534 "SaeManifoldTerm::set_row_metric: metric has {} rows but term has {}",
535 metric.n_rows(),
536 self.n_obs()
537 ));
538 }
539 if metric.p_out() != self.output_dim() {
540 return Err(format!(
541 "SaeManifoldTerm::set_row_metric: metric output dim {} but term has {}",
542 metric.p_out(),
543 self.output_dim()
544 ));
545 }
546 self.row_metric = Some(metric);
547 Ok(())
548 }
549
550 /// The installed per-row metric, if any. `None` ⇒ Euclidean / isotropic.
551 /// Consumed by the gauge wiring (to build the matching `WeightField`) and by
552 /// Object 4 (to read the [`MetricProvenance`](gam_problem::MetricProvenance)).
553 pub fn row_metric(&self) -> Option<&gam_problem::RowMetric> {
554 self.row_metric.as_ref()
555 }
556
557 /// The per-row inner product the additive diagnostics read through: the
558 /// installed [`RowMetric`](gam_problem::RowMetric) when one
559 /// was set (output-Fisher harvest present), otherwise a freshly-built
560 /// Euclidean metric of the term's own `(n_obs, output_dim)` shape. Either way
561 /// a metric always exists, so the diagnostics are never gated by a flag — the
562 /// Euclidean fallback is the bit-identical isotropic path.
563 pub(crate) fn diagnostic_metric(
564 &self,
565 ) -> Result<gam_problem::RowMetric, String> {
566 match self.row_metric() {
567 Some(metric) => Ok(metric.clone()),
568 None => {
569 gam_problem::RowMetric::euclidean(self.n_obs(), self.output_dim())
570 }
571 }
572 }
573
574 /// Build the additive post-fit diagnostic report for this fitted term: the
575 /// two-score per-atom [`AtomTwoLensReport`](crate::inference::atom_lens::AtomTwoLensReport)
576 /// (presence / behavioral coupling / discrepancy) and the residual-gauge
577 /// [`ResidualGaugeReport`](crate::identifiability::ResidualGaugeReport)
578 /// certificate.
579 ///
580 /// Both reports are read through the same single metric
581 /// ([`Self::diagnostic_metric`]): under a Euclidean / no-harvest provenance
582 /// the lens coupling is `None` and the gauge is certified under Euclidean
583 /// provenance — never an error, never gated by a flag (magic-by-default,
584 /// mirroring the metric selection itself).
585 ///
586 /// `per_atom_ard_variances`, when supplied, is one ARD variance vector per
587 /// atom (length = `latent_dim_k`), threaded into the certificate's
588 /// equal-ARD-rotation detection. `None` (or a per-atom `None`) ⇒ no ARD prior
589 /// on that atom. `isometry_pin_active` records whether an isometry gauge
590 /// penalty was installed on the fit: `false` escalates the certificate to the
591 /// `diffeomorphism-unpinned` verdict (the honest "no metric pin" statement),
592 /// exactly as the certificate's own escalation flag specifies.
593 ///
594 /// Pure read: it never mutates the term, never touches a loss / criterion /
595 /// penalty / optimizer state.
596 pub fn fit_diagnostics_report(
597 &self,
598 per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
599 isometry_pin_active: bool,
600 reconstruction_dispersion: Option<f64>,
601 assignments_override: Option<ArrayView2<'_, f64>>,
602 ) -> Result<SaeManifoldFitDiagnostics, String> {
603 if let Some(view) = assignments_override {
604 let n = self.n_obs();
605 let k = self.k_atoms();
606 if view.dim() != (n, k) {
607 return Err(format!(
608 "fit_diagnostics_report: assignments_override shape {:?} must be ({n}, {k})",
609 view.dim()
610 ));
611 }
612 }
613 let metric = self.diagnostic_metric()?;
614 let atom_two_lens =
615 crate::inference::atom_lens::atom_two_lens(self, &metric, assignments_override);
616
617 let (certificate_model, streamed_curvature) =
618 self.to_residual_gauge_model(metric, per_atom_ard_variances, isometry_pin_active)?;
619 // #998: within-atom gauge families are certified on their EXACT orbits
620 // in the model's own (decoder, coordinate) parameter space — compensated
621 // symmetries are data-nulls by construction there, no lowering-error
622 // calibration involved. This now holds whether or not an isometry pin is
623 // active:
624 // * pin INACTIVE ⇒ the orbit verdict is the data residual alone (no
625 // penalty operator);
626 // * pin ACTIVE ⇒ the orbit verdict adds the isometry pin's orbit-space
627 // curvature through an [`OrbitPenaltyOperator`] lowered from the
628 // atom's second jet `Φ''` (the pullback-metric change along the orbit
629 // differentiates `J = Φ'B` through `t`). A model-class symmetry that
630 // preserves the metric stays a certified freedom; a non-isometric
631 // orbit (a basis not closed under the action) is genuinely pinned.
632 // The relative-curvature fraction `cost/stiffness²` is invariant to the
633 // pin strength μ (both faces scale with μ), so the operator is built at a
634 // canonical unit weight. An atom whose basis exposes no analytic second
635 // jet supplies no operator and falls back to the data residual — never an
636 // error. Magic-by-default either way: the choice is derived from the fit,
637 // never a flag.
638 let views = self.atom_parameter_views();
639 let ops: Vec<Option<crate::identifiability::OrbitPenaltyOperator>> =
640 if isometry_pin_active {
641 views
642 .iter()
643 .map(|view| {
644 view.as_ref().and_then(|v| {
645 crate::identifiability::isometry_orbit_penalty_operator(
646 v, 1.0,
647 )
648 })
649 })
650 .collect()
651 } else {
652 (0..self.k_atoms()).map(|_| None).collect()
653 };
654 let residual_gauge = if isometry_pin_active {
655 // The pin-active path consumes the per-row Jacobian curvature
656 // directly (the certificate_model retains it under a pin), so route
657 // through the non-streamed exact entry point.
658 crate::identifiability::residual_gauge_exact(
659 &certificate_model,
660 &views,
661 &ops,
662 )?
663 } else {
664 let (curvature_gram, root_rows) = streamed_curvature.ok_or_else(|| {
665 "fit_diagnostics_report: missing streamed residual-gauge curvature for unpinned exact path"
666 .to_string()
667 })?;
668 crate::identifiability::residual_gauge_exact_from_curvature_gram(
669 &certificate_model,
670 &views,
671 &ops,
672 curvature_gram,
673 root_rows,
674 )?
675 };
676
677 // #1097 / #1103: per-atom Riesz-debiased functionals and the any-n-valid
678 // split-LRT smooth-structure e-value (non-constant vs constant inner
679 // decoder), read straight off the certificate model — which carries
680 // each atom's `inner_fit` snapshot when the caller harvested it via
681 // [`Self::set_atom_inner_fits`] before this report. Atoms without a
682 // harvested inner fit degrade their inference fields to `None` inside
683 // `atom_inference_reports`, so this is always populated (one entry per
684 // atom) and never gated by a flag.
685 let atom_inference =
686 crate::identifiability::atom_inference_reports(&certificate_model);
687
688 Ok(SaeManifoldFitDiagnostics {
689 atom_two_lens,
690 residual_gauge,
691 incoherence_report: match reconstruction_dispersion.or(self.certificate_dispersion) {
692 Some(dispersion) => Some(dictionary_incoherence_report_with_dispersion(
693 self, dispersion,
694 )?),
695 None => None,
696 },
697 atom_inference,
698 })
699 }
700
701 /// Build the trust-diagnostics producer for the Python `diagnostics` block.
702 ///
703 /// `assignments` is supplied by the payload assembly site so top-k projection,
704 /// when requested, is reflected in coverage/frequency and in the tangent
705 /// spectra. The active threshold is shared with the atom lens so all
706 /// assignment-support diagnostics agree on what "active" means.
707 pub fn trust_diagnostics_report(
708 &self,
709 assignments: ArrayView2<'_, f64>,
710 ) -> Result<SaeTrustDiagnostics, String> {
711 let n = self.n_obs();
712 let k_atoms = self.k_atoms();
713 if assignments.dim() != (n, k_atoms) {
714 return Err(format!(
715 "trust_diagnostics_report: assignments shape {:?} must be ({n}, {k_atoms})",
716 assignments.dim()
717 ));
718 }
719 if !assignments.iter().all(|v| v.is_finite()) {
720 return Err("trust_diagnostics_report: assignments must be finite".to_string());
721 }
722 let metric = self.diagnostic_metric()?;
723 let active_threshold = crate::inference::atom_lens::SAE_TRUST_ACTIVE_MASS_FLOOR;
724 let mut atoms = Vec::with_capacity(k_atoms);
725 let mut atom_trust = Vec::with_capacity(k_atoms);
726 for (atom_idx, atom) in self.atoms.iter().enumerate() {
727 let mut active_token_count = 0usize;
728 let mut activation_sum = 0.0_f64;
729 for row in 0..n {
730 let mass = assignments[[row, atom_idx]];
731 activation_sum += mass;
732 if mass > active_threshold {
733 active_token_count += 1;
734 }
735 }
736 let coverage = if n > 0 {
737 active_token_count as f64 / n as f64
738 } else {
739 0.0
740 };
741 let activation_frequency = if n > 0 {
742 activation_sum / n as f64
743 } else {
744 0.0
745 };
746 let (sigma_min_tangent, sigma_max_tangent) = self
747 .atom_tangent_spectrum_from_assignments(
748 atom_idx,
749 assignments,
750 &metric,
751 active_threshold,
752 )?;
753 let tangent_condition_score = if sigma_max_tangent > 0.0 {
754 (sigma_min_tangent / sigma_max_tangent).clamp(0.0, 1.0)
755 } else {
756 0.0
757 };
758 let trust_score = tangent_condition_score;
759 atom_trust.push(trust_score);
760 atoms.push(SaeAtomTrustDiagnostics {
761 trust_score,
762 sigma_min_tangent,
763 sigma_max_tangent,
764 tangent_condition_score,
765 coverage,
766 activation_frequency,
767 untyped: matches!(atom.basis_kind, SaeAtomBasisKind::Precomputed(_)),
768 active_token_count,
769 });
770 }
771 Ok(SaeTrustDiagnostics { atom_trust, atoms })
772 }
773
774 pub(crate) fn atom_tangent_spectrum_from_assignments(
775 &self,
776 atom_idx: usize,
777 assignments: ArrayView2<'_, f64>,
778 metric: &gam_problem::RowMetric,
779 active_threshold: f64,
780 ) -> Result<(f64, f64), String> {
781 let atom = &self.atoms[atom_idx];
782 let d = atom.latent_dim;
783 let p = self.output_dim();
784 if d == 0 || p == 0 {
785 return Ok((0.0, 0.0));
786 }
787 let mut gram = Array2::<f64>::zeros((d, d));
788 let mut active_mass_sum = 0.0_f64;
789 let mut jac_row = vec![0.0_f64; p * d];
790 for row in 0..self.n_obs() {
791 let mass = assignments[[row, atom_idx]];
792 if !(mass > active_threshold) {
793 continue;
794 }
795 active_mass_sum += mass;
796 for axis in 0..d {
797 let start = axis;
798 let mut tangent = vec![0.0_f64; p];
799 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
800 for out in 0..p {
801 jac_row[out * d + start] = tangent[out];
802 }
803 }
804 let row_pullback = metric.pullback(row, &jac_row, d);
805 for axis_a in 0..d {
806 for axis_b in 0..=axis_a {
807 gram[[axis_a, axis_b]] += mass * row_pullback[[axis_a, axis_b]];
808 }
809 }
810 jac_row.fill(0.0);
811 }
812 if !(active_mass_sum > 0.0) {
813 return Ok((0.0, 0.0));
814 }
815 let inv_mass = 1.0 / active_mass_sum;
816 for axis_a in 0..d {
817 for axis_b in 0..=axis_a {
818 let value = gram[[axis_a, axis_b]] * inv_mass;
819 gram[[axis_a, axis_b]] = value;
820 gram[[axis_b, axis_a]] = value;
821 }
822 }
823 let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
824 format!(
825 "trust_diagnostics_report: atom {atom_idx} tangent eigendecomposition failed: {e}"
826 )
827 })?;
828 let mut sigma_min = f64::INFINITY;
829 let mut sigma_max = 0.0_f64;
830 for value in evals.iter().copied() {
831 let clamped = value.max(0.0);
832 let sigma = clamped.sqrt();
833 sigma_min = sigma_min.min(sigma);
834 sigma_max = sigma_max.max(sigma);
835 }
836 if sigma_min.is_finite() {
837 Ok((sigma_min, sigma_max))
838 } else {
839 Ok((0.0, 0.0))
840 }
841 }
842
843 /// Per-atom exact parameter-space views for the #998 certificate path:
844 /// the basis values / first-derivative jet, decoder coefficients, latent
845 /// coordinates, and assignment mass each atom was actually fitted with.
846 /// Sphere atoms get `None` (their chart's group action is nonlinear, so
847 /// the exact-orbit realisation does not apply and they stay on the frame
848 /// path), as does any atom whose coordinate chart width disagrees with its
849 /// latent dimension (a structurally inconsistent atom must not masquerade
850 /// as exactly certified).
851 pub(crate) fn atom_parameter_views(
852 &self,
853 ) -> Vec<Option<crate::identifiability::AtomParameterView>> {
854 let assignments = self.assignment.assignments();
855 let n = self.n_obs();
856 self.atoms
857 .iter()
858 .enumerate()
859 .map(|(k, atom)| {
860 if matches!(atom.basis_kind, SaeAtomBasisKind::Sphere) {
861 return None;
862 }
863 let coords = self.assignment.coords[k].as_matrix().to_owned();
864 if coords.nrows() != n || coords.ncols() != atom.latent_dim {
865 return None;
866 }
867 let mut activations = Array1::<f64>::zeros(n);
868 for row in 0..n {
869 activations[row] = assignments[[row, k]];
870 }
871 // Second jet Φ'' (#998): supplied when the atom's evaluator
872 // exposes an analytic Hessian, so a pin-active fit can lower its
873 // orbit-space isometry penalty operator (the metric-change of the
874 // pullback gram differentiates Φ' through t). Absent ⇒ the orbit
875 // verdict stays on the data residual / no-pin path, never an
876 // error.
877 let basis_second_jet = atom
878 .basis_evaluator
879 .as_ref()
880 .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()))
881 .and_then(|res| res.ok());
882 Some(crate::identifiability::AtomParameterView {
883 basis_values: atom.basis_values.clone(),
884 basis_jacobian: atom.basis_jacobian.clone(),
885 decoder: atom.decoder_coefficients.clone(),
886 coords,
887 activations,
888 basis_second_jet,
889 })
890 })
891 .collect()
892 }
893
894 /// Lower this fitted term into the self-contained
895 /// [`FittedSaeManifold`](crate::identifiability::FittedSaeManifold) the
896 /// residual-gauge certificate consumes.
897 ///
898 /// The certificate's parameter space is the per-atom decoder **frame** — the
899 /// `(output_dim, latent_dim)` image of the atom's latent axes in output space.
900 /// We realise it as the active-mass-weighted mean decoder tangent
901 /// `frame_k[:, a] = (Σ_n a_{nk} · ∂g_k/∂t_a(n)) / Σ_n a_{nk}` over the atom's
902 /// active rows (the centroid decoder Jacobian columns the certificate docs
903 /// name). The per-row pinning Jacobian block `J_n ∈ ℝ^{p × param_dim}` is the
904 /// assignment-weighted per-row decoder tangent placed at each atom's frame
905 /// slot: column `(k, i, a)` of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i]` — exactly
906 /// the directions the reconstruction data gives cost to, in the same metric
907 /// the fit used (whitened by the certificate through `RowMetric`).
908 ///
909 /// The flattened frame layout matches the certificate's
910 /// `vec(frame_0) ⊕ vec(frame_1) ⊕ …`, row-major within each frame
911 /// (`frame_k[i, a]` at offset `atom_offset(k) + i·latent_dim_k + a`).
912 pub(crate) fn to_residual_gauge_model(
913 &self,
914 metric: gam_problem::RowMetric,
915 per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
916 isometry_pin_active: bool,
917 ) -> Result<
918 (
919 crate::identifiability::FittedSaeManifold,
920 Option<(Array2<f64>, usize)>,
921 ),
922 String,
923 > {
924 use crate::identifiability::{AtomTopology, FittedAtom, FittedSaeManifold};
925
926 let n = self.n_obs();
927 let p = self.output_dim();
928 let k = self.k_atoms();
929 let assignments = self.assignment.assignments();
930
931 // Per-atom frame `(p, d)` = active-mass-weighted mean decoder tangent,
932 // and the flattened-frame column offset bookkeeping for the joint
933 // parameter vector (`vec(frame_0) ⊕ …`, row-major within each frame).
934 let mut fitted_atoms: Vec<FittedAtom> = Vec::with_capacity(k);
935 let mut atom_offsets: Vec<usize> = Vec::with_capacity(k);
936 let mut atom_axis_dim: Vec<usize> = Vec::with_capacity(k);
937 let mut cursor = 0usize;
938 for (atom_idx, atom) in self.atoms.iter().enumerate() {
939 let d = atom.latent_dim;
940 let topology = match (&atom.basis_kind, d) {
941 (SaeAtomBasisKind::Periodic, 1) | (SaeAtomBasisKind::Torus, 1) => {
942 AtomTopology::Circle
943 }
944 (SaeAtomBasisKind::Periodic, _) | (SaeAtomBasisKind::Torus, _) => {
945 AtomTopology::Torus { latent_dim: d }
946 }
947 (SaeAtomBasisKind::Sphere, _) => AtomTopology::Sphere,
948 // `Cylinder` (`S¹ × ℝ`) has exactly one continuous gauge: the
949 // rotation (shift) of the periodic axis. The unbounded line axis
950 // carries no rotational gauge, and its translation is already
951 // pinned by the design's constant column — so the identifiability
952 // gauge is that of a single circle. Fixing it as `Torus` would
953 // over-impose a second (nonexistent) circle shift; fixing it as
954 // `EuclideanPatch { 2 }` would over-impose a frame rotation
955 // mixing the periodic and linear axes. `Circle` fixes the one
956 // real continuous gauge and leaves the linear axis ungauged.
957 (SaeAtomBasisKind::Cylinder, _) => AtomTopology::Circle,
958 (
959 SaeAtomBasisKind::Linear
960 | SaeAtomBasisKind::Duchon
961 | SaeAtomBasisKind::EuclideanPatch
962 | SaeAtomBasisKind::Poincare
963 | SaeAtomBasisKind::Precomputed(_),
964 _,
965 ) => AtomTopology::EuclideanPatch { latent_dim: d },
966 };
967
968 let mut frame = Array2::<f64>::zeros((p, d));
969 let mut active_mass = 0.0_f64;
970 let mut tangent = vec![0.0_f64; p];
971 for row in 0..n {
972 let a_nk = assignments[[row, atom_idx]];
973 if !(a_nk > 0.0) {
974 continue;
975 }
976 active_mass += a_nk;
977 for axis in 0..d {
978 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
979 for i in 0..p {
980 frame[[i, axis]] += a_nk * tangent[i];
981 }
982 }
983 }
984 if active_mass > 0.0 {
985 let inv = 1.0 / active_mass;
986 frame.mapv_inplace(|v| v * inv);
987 }
988
989 // #995 lowering-error scale: mass-weighted relative dispersion of
990 // the per-row tangents around the mean frame just built,
991 // Σ_n a_n Σ_ax ‖t_ax(n) − frame[:,ax]‖² / Σ_n a_n Σ_ax ‖t_ax(n)‖².
992 // 0 ⇒ the frame represents every active row exactly (flat
993 // decoder); → 1 ⇒ the tangent field disperses so strongly (e.g. a
994 // full circle, whose tangents average out) that the mean-frame
995 // compression cannot distinguish gauge motion from curvature. The
996 // certificate calibrates its per-generator verdict tolerance to
997 // this scale so it never claims a pin it cannot resolve.
998 let mut disp_num = 0.0_f64;
999 let mut disp_den = 0.0_f64;
1000 for row in 0..n {
1001 let a_nk = assignments[[row, atom_idx]];
1002 if !(a_nk > 0.0) {
1003 continue;
1004 }
1005 for axis in 0..d {
1006 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1007 for i in 0..p {
1008 let dev = tangent[i] - frame[[i, axis]];
1009 disp_num += a_nk * dev * dev;
1010 disp_den += a_nk * tangent[i] * tangent[i];
1011 }
1012 }
1013 }
1014 let lowering_error = if disp_den > 0.0 {
1015 (disp_num / disp_den).clamp(0.0, 1.0)
1016 } else {
1017 0.0
1018 };
1019
1020 let ard_variances = per_atom_ard_variances
1021 .and_then(|all| all.get(atom_idx))
1022 .and_then(|opt| opt.clone())
1023 .filter(|v| v.len() == d);
1024
1025 fitted_atoms.push(FittedAtom {
1026 name: atom.name.clone(),
1027 topology,
1028 frame,
1029 ard_variances,
1030 lowering_error,
1031 // #1019: post-fit chart canonicalization (arc length for
1032 // d = 1, isometry-flow for d = 2 torus, flat-reference
1033 // isometry-flow for d = 2 free/patch, round-sphere
1034 // conformal-boost flow for d = 2 sphere atoms) pins the chart;
1035 // the certificate downgrades this atom's chart freedom to the
1036 // finite isometry group with PinnedByCanonicalization
1037 // provenance.
1038 chart_canonicalized: atom.chart_canonicalized
1039 && (d == 1
1040 || (d == 2
1041 && matches!(
1042 atom.basis_kind,
1043 SaeAtomBasisKind::Torus
1044 | SaeAtomBasisKind::Linear
1045 | SaeAtomBasisKind::Duchon
1046 | SaeAtomBasisKind::EuclideanPatch
1047 | SaeAtomBasisKind::Sphere
1048 ))),
1049 // #1097 / #1103: the per-atom inner-decoder-smooth snapshot,
1050 // attached when the post-fit harness has run
1051 // [`Self::set_atom_inner_fits`] (it needs the reconstruction
1052 // target Z, dropped from the objective at fit end). `None` on a
1053 // bare certificate-only model, or for a degenerate atom whose
1054 // inner Hessian was not SPD.
1055 inner_fit: self
1056 .atom_inner_fits
1057 .as_ref()
1058 .and_then(|fits| fits.get(atom_idx))
1059 .and_then(|slot| slot.clone()),
1060 });
1061 atom_offsets.push(cursor);
1062 atom_axis_dim.push(d);
1063 cursor += p * d;
1064 }
1065 let param_dim = cursor;
1066
1067 // Per-row pinning Jacobian `J_n ∈ ℝ^{p × param_dim}` flattened row-major
1068 // (`J_n[i, c] = jacobian_rows[n][i · param_dim + c]`). Column `(k, i', a)`
1069 // of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i']` placed at the atom-k frame slot
1070 // and read out on output coordinate `i = i'` (a frame perturbation of
1071 // output `i'` moves only the row's output coordinate `i'`).
1072 //
1073 // The pinned certificate still consumes the legacy row-block contract.
1074 // The unpinned exact path consumes only `RᵀR`, so stream each transient
1075 // row Jacobian through the metric whitening and discard it immediately.
1076 let (jacobian_rows, streamed_curvature) = if isometry_pin_active {
1077 let mut jacobian_rows: Vec<Vec<f64>> = Vec::with_capacity(n);
1078 let mut tangent = vec![0.0_f64; p];
1079 for row in 0..n {
1080 let mut j_flat = vec![0.0_f64; p * param_dim];
1081 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1082 let a_nk = assignments[[row, atom_idx]];
1083 if !(a_nk > 0.0) {
1084 continue;
1085 }
1086 let d = atom_axis_dim[atom_idx];
1087 let base = atom_offsets[atom_idx];
1088 for axis in 0..d {
1089 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1090 for i in 0..p {
1091 // Frame coordinate `(k, i, axis)` sits at column
1092 // `base + i·d + axis`; it sources output coordinate `i`.
1093 j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1094 }
1095 }
1096 }
1097 jacobian_rows.push(j_flat);
1098 }
1099 (jacobian_rows, None)
1100 } else {
1101 let streamed = self.residual_gauge_streamed_data_curvature(
1102 &metric,
1103 &atom_offsets,
1104 &atom_axis_dim,
1105 param_dim,
1106 )?;
1107 (Vec::new(), Some(streamed))
1108 };
1109
1110 // Isometry-penalty curvature root over the frame parameter space. When
1111 // the isometry gauge pin is active it gives curvature along every fitted
1112 // frame direction (it resists deviation of the decoder image from its
1113 // arc-length parameterization), so its row space is the span of the
1114 // per-atom frame columns: one root row per `(k, axis)` carrying that
1115 // atom's frame column at the atom's frame slot. Empty (`0 × param_dim`)
1116 // when the pin is inactive — exactly the certificate's escalation
1117 // condition to `diffeomorphism-unpinned`.
1118 let isometry_penalty_root = if isometry_pin_active && param_dim > 0 {
1119 let mut root_rows: Vec<Array1<f64>> = Vec::new();
1120 for (atom_idx, fitted) in fitted_atoms.iter().enumerate() {
1121 let d = atom_axis_dim[atom_idx];
1122 let base = atom_offsets[atom_idx];
1123 for axis in 0..d {
1124 let mut r = Array1::<f64>::zeros(param_dim);
1125 let mut any = false;
1126 for i in 0..p {
1127 let v = fitted.frame[[i, axis]];
1128 if v != 0.0 {
1129 any = true;
1130 }
1131 r[base + i * d + axis] = v;
1132 }
1133 if any {
1134 root_rows.push(r);
1135 }
1136 }
1137 }
1138 let mut root = Array2::<f64>::zeros((root_rows.len(), param_dim));
1139 for (ri, r) in root_rows.iter().enumerate() {
1140 root.row_mut(ri).assign(r);
1141 }
1142 root
1143 } else {
1144 Array2::<f64>::zeros((0, param_dim))
1145 };
1146
1147 Ok((
1148 FittedSaeManifold {
1149 atoms: fitted_atoms,
1150 jacobian_rows,
1151 isometry_penalty_root,
1152 metric,
1153 },
1154 streamed_curvature,
1155 ))
1156 }
1157
1158 pub(crate) fn residual_gauge_streamed_data_curvature(
1159 &self,
1160 metric: &gam_problem::RowMetric,
1161 atom_offsets: &[usize],
1162 atom_axis_dim: &[usize],
1163 param_dim: usize,
1164 ) -> Result<(Array2<f64>, usize), String> {
1165 let n = self.n_obs();
1166 let p = self.output_dim();
1167 if metric.p_out() != p {
1168 return Err(format!(
1169 "residual_gauge_streamed_data_curvature: metric output dim {} but term has {p}",
1170 metric.p_out()
1171 ));
1172 }
1173 let rank = metric.metric_rank();
1174 let mut gram = Array2::<f64>::zeros((param_dim, param_dim));
1175 if param_dim == 0 || n == 0 || rank == 0 {
1176 return Ok((gram, n * rank));
1177 }
1178
1179 let assignments = self.assignment.assignments();
1180 let mut tangent = vec![0.0_f64; p];
1181 let mut j_flat = vec![0.0_f64; p * param_dim];
1182 let mut root_row = Array1::<f64>::zeros(param_dim);
1183 for row in 0..n {
1184 j_flat.fill(0.0);
1185 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1186 let a_nk = assignments[[row, atom_idx]];
1187 if !(a_nk > 0.0) {
1188 continue;
1189 }
1190 let d = atom_axis_dim[atom_idx];
1191 let base = atom_offsets[atom_idx];
1192 for axis in 0..d {
1193 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1194 for i in 0..p {
1195 j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1196 }
1197 }
1198 }
1199
1200 if metric.drives_gauge() {
1201 for r in 0..rank {
1202 root_row.fill(0.0);
1203 for c in 0..param_dim {
1204 let mut acc = 0.0_f64;
1205 for i in 0..p {
1206 acc += metric.factor_entry(row, i, r) * j_flat[i * param_dim + c];
1207 }
1208 root_row[c] = acc;
1209 }
1210 let row_slice = root_row.as_slice().ok_or_else(|| {
1211 "residual_gauge_streamed_data_curvature: non-contiguous root row"
1212 .to_string()
1213 })?;
1214 Self::accumulate_residual_gauge_gram_row(&mut gram, row_slice);
1215 }
1216 } else {
1217 for i in 0..p {
1218 let start = i * param_dim;
1219 let end = start + param_dim;
1220 Self::accumulate_residual_gauge_gram_row(&mut gram, &j_flat[start..end]);
1221 }
1222 }
1223 }
1224
1225 for a in 0..param_dim {
1226 for b in 0..a {
1227 gram[[b, a]] = gram[[a, b]];
1228 }
1229 }
1230 Ok((gram, n * rank))
1231 }
1232
1233 pub(crate) fn accumulate_residual_gauge_gram_row(gram: &mut Array2<f64>, row: &[f64]) {
1234 for a in 0..row.len() {
1235 let va = row[a];
1236 if va == 0.0 {
1237 continue;
1238 }
1239 for b in 0..=a {
1240 let vb = row[b];
1241 if vb != 0.0 {
1242 gram[[a, b]] += va * vb;
1243 }
1244 }
1245 }
1246 }
1247
1248 pub fn set_temperature_schedule(
1249 &mut self,
1250 sched: GumbelTemperatureSchedule,
1251 ) -> Result<(), String> {
1252 sched.validate()?;
1253 self.assignment
1254 .mode
1255 .set_temperature(sched.current_tau(sched.iter_count))?;
1256 self.temperature_schedule = Some(sched);
1257 Ok(())
1258 }
1259
1260 pub(crate) fn advance_temperature_schedule(&mut self) -> Result<Option<f64>, String> {
1261 let Some(schedule) = self.temperature_schedule.as_mut() else {
1262 return Ok(None);
1263 };
1264 schedule.validate()?;
1265 let tau = schedule.step();
1266 self.assignment.mode.set_temperature(tau)?;
1267 Ok(Some(tau))
1268 }
1269
1270 pub fn n_obs(&self) -> usize {
1271 self.assignment.n_obs()
1272 }
1273
1274 pub fn k_atoms(&self) -> usize {
1275 self.atoms.len()
1276 }
1277
1278 /// Auto-derived in-core vs streaming plan for SAE Arrow-Schur work.
1279 ///
1280 /// This is intentionally not user-configurable: the route follows the
1281 /// retained full-batch working-set estimate and the currently selected GPU
1282 /// memory budget when CUDA is usable, otherwise a conservative host budget.
1283 pub fn streaming_plan(&self) -> SaeStreamingPlan {
1284 let n_obs = self.n_obs();
1285 let total_basis: usize = self.atoms.iter().map(|atom| atom.basis_size()).sum();
1286 let d_max = self
1287 .atoms
1288 .iter()
1289 .map(|atom| atom.latent_dim)
1290 .max()
1291 .unwrap_or(0);
1292 let border_dim = if self.any_frame_active() {
1293 self.factored_border_dim()
1294 } else {
1295 self.beta_dim()
1296 };
1297 sae_streaming_plan_for_shape(n_obs, total_basis, self.k_atoms(), d_max, border_dim)
1298 }
1299
1300 /// Construction-time validation: every Psi-tier analytic penalty in the
1301 /// registry must be dispatchable into the SAE arrow-Schur row layout.
1302 ///
1303 /// Two invariants are enforced upfront so the dispatch loop in
1304 /// `add_sae_analytic_penalty_contributions` is total (no runtime
1305 /// "unsupported penalty" fallthrough, no per-call K-gating):
1306 ///
1307 /// 1. Every Psi-tier penalty is either in [`sae_penalty_is_row_block_supported`],
1308 /// or `NuclearNorm` (which is redirected to the per-atom decoder (β) block
1309 /// rather than the coord "t" row block). Assignment sparsity penalties
1310 /// (`IBPAssignment`, `SoftmaxAssignmentSparsity`) are refused because the SAE
1311 /// term already owns them through its built-in assignment path
1312 /// (`loss.assignment_sparsity`). Penalty kinds with cross-row structure
1313 /// (`TotalVariation`, `Monotonicity`, `BlockSparsity`,
1314 /// `IvaeRidgeMeanGauge`, `Orthogonality`, `NestedPrefix`,
1315 /// `SheafConsistency`) cannot be expressed in the SAE row-block layout
1316 /// and are refused here.
1317 ///
1318 /// 2. If any Psi-tier row-block penalty is present, every atom shares
1319 /// the same coord latent dim. The current registry model carries one
1320 /// `latent_dim` per descriptor (the "t" latent block declares one
1321 /// `d` value); per-atom dispatch with heterogeneous `d_k` would
1322 /// require per-atom registry entries or per-kind in-place
1323 /// reshaping. Mixed-d row-block fits are rejected with an actionable
1324 /// error pointing at the configuration mismatch.
1325 ///
1326 /// The K=1 case trivially satisfies (2). Beta-tier and rho-tier
1327 /// penalties are not constrained here.
1328 pub(crate) fn validate_analytic_penalty_registry(
1329 &self,
1330 registry: &AnalyticPenaltyRegistry,
1331 ) -> Result<(), String> {
1332 let mut row_block_penalty_present = false;
1333 for penalty in ®istry.penalties {
1334 if penalty.tier() != PenaltyTier::Psi {
1335 continue;
1336 }
1337 if matches!(
1338 penalty,
1339 AnalyticPenaltyKind::IBPAssignment(_)
1340 | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
1341 ) {
1342 return Err(format!(
1343 "SAE-manifold term refuses analytic penalty {:?}: assignment sparsity \
1344 is owned by the built-in SAE assignment path (loss.assignment_sparsity). \
1345 Registering it would double-count the objective and gradient",
1346 penalty.name()
1347 ));
1348 }
1349 // NuclearNorm is redirected to the per-atom decoder (β) block in
1350 // `add_sae_beta_penalty` (it penalizes each atom's decoder matrix
1351 // singular spectrum, i.e. its embedding rank), so it bypasses the
1352 // coord "t" row-block requirement below.
1353 if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) {
1354 continue;
1355 }
1356 if !sae_penalty_is_row_block_supported(penalty) {
1357 return Err(format!(
1358 "SAE-manifold term refuses analytic penalty {:?}: this kind \
1359 has cross-row structure and cannot be expressed in the \
1360 arrow-Schur row layout. Use only row-block-supported \
1361 coord penalties (ARD, BlockOrthogonality, \
1362 Sparsity/TopK/JumpReLU, RowPrecisionPrior, \
1363 ParametricRowPrecisionPrior, ScadMcp, Isometry) on the \
1364 coord latent block, or move the penalty to a non-SAE \
1365 term",
1366 penalty.name()
1367 ));
1368 }
1369 row_block_penalty_present = true;
1370 }
1371 if row_block_penalty_present {
1372 let mut dims = self.assignment.coords.iter().map(|c| c.latent_dim());
1373 if let Some(first) = dims.next() {
1374 if let Some(mismatch) = dims.find(|d| *d != first) {
1375 return Err(format!(
1376 "SAE-manifold term refuses row-block analytic penalty: \
1377 atoms have heterogeneous coord latent dims (saw {first} \
1378 and {mismatch}). Row-block penalties (ARD, \
1379 BlockOrthogonality, ...) target the unified \"t\" \
1380 latent block whose declared `d` matches one shape; \
1381 per-atom dispatch with mixed `d_k` would silently \
1382 truncate or expand axes. Configure all atoms with the \
1383 same `atom_dim`, or split the row-block penalty into \
1384 per-atom descriptors keyed to per-atom latent blocks"
1385 ));
1386 }
1387 }
1388 }
1389 Ok(())
1390 }
1391
1392 pub fn output_dim(&self) -> usize {
1393 self.atoms[0].output_dim()
1394 }
1395
1396 pub fn beta_dim(&self) -> usize {
1397 let p = self.output_dim();
1398 self.atoms.iter().map(|a| a.basis_size() * p).sum()
1399 }
1400
1401 pub(crate) fn take_border_hbb_workspace(&mut self, border_dim: usize) -> Array2<f64> {
1402 let mut workspace =
1403 std::mem::replace(&mut self.border_hbb_workspace, Array2::<f64>::zeros((0, 0)));
1404 if workspace.dim() != (border_dim, border_dim) {
1405 workspace = Array2::<f64>::zeros((border_dim, border_dim));
1406 } else {
1407 workspace.fill(0.0);
1408 }
1409 workspace
1410 }
1411
1412 pub(crate) fn reclaim_border_hbb_workspace(&mut self, sys: &mut ArrowSchurSystem) {
1413 let workspace = std::mem::replace(&mut sys.hbb, Array2::<f64>::zeros((0, 0)));
1414 self.border_hbb_workspace = workspace;
1415 }
1416
1417 /// Factored arrow-Schur border dimension `Σ_k M_k · r_k` (issue #972): the
1418 /// number of decoder coordinates the border actually carries once the
1419 /// low-rank Grassmann frames are profiled out. Atoms with no active frame
1420 /// contribute their full `M_k · p` (`r_k == p`), so on the all-full-`B` path
1421 /// this equals [`Self::beta_dim`]. The border Cholesky / evidence log-det
1422 /// scale with THIS count, not `beta_dim`.
1423 pub fn factored_border_dim(&self) -> usize {
1424 self.atoms.iter().map(|a| a.border_coeff_count()).sum()
1425 }
1426
1427 /// Total profiled-out Grassmann manifold dimension `Σ_k r_k·(p − r_k)` across
1428 /// all active frames (issue #972). This is the count of decoder-frame degrees
1429 /// of freedom estimated OUTSIDE the border by closed-form polar steps, and it
1430 /// must enter the Laplace evidence dimension accounting (evidence honesty):
1431 /// the profiled frame is a MAP point on `∏_k Gr(r_k, p)`, contributing this
1432 /// many free dimensions to the model. `0` when every atom is on the full-`B`
1433 /// path. Threaded into [`Self::reml_occam_term`].
1434 pub fn grassmann_evidence_dimension(&self) -> usize {
1435 self.atoms
1436 .iter()
1437 .map(|a| a.frame_manifold_dimension())
1438 .sum()
1439 }
1440
1441 /// True iff any atom has an active low-rank Grassmann frame (issue #972).
1442 pub fn frames_active(&self) -> bool {
1443 self.atoms.iter().any(|a| a.decoder_frame.is_some())
1444 }
1445
1446 /// Alias of [`Self::frames_active`] (issue #972 / #977 T1): the predicate the
1447 /// assembly / step-lift branch on to decide whether the β-tier is built in
1448 /// the factored coordinate layout. Named to read as the question
1449 /// "is the factored path engaged?" at its call sites.
1450 pub fn any_frame_active(&self) -> bool {
1451 self.frames_active()
1452 }
1453
1454 /// Per-atom column offsets of the *factored* border (issue #972 / #977 T1):
1455 /// the running prefix sum of `M_k · r_k`, one entry per atom (the same
1456 /// convention as [`Self::beta_offsets`]). This is the start of each atom's
1457 /// `C_k` block in the reduced border vector; on the all-full-`B` path it
1458 /// equals `beta_offsets`. Distinct from [`Self::factored_border_offsets`]
1459 /// only in name (both compute the identical prefix sum) — this method is the
1460 /// one the frame transform reads, mirroring `beta_offsets` at the call site.
1461 pub fn factored_beta_offsets(&self) -> Vec<usize> {
1462 self.factored_border_offsets()
1463 }
1464
1465 /// Frame output matrix `U_k ∈ St(p, r_k)` for atom `k` (issue #972 / #977 T1).
1466 /// Returns the active frame `U_k` (`p × r_k`) when atom `k` is framed, else
1467 /// the identity `I_p` (the `r_k == p`, `U_k == I_p` full-`B` special case) so
1468 /// the projection / lift code is uniform across a mixed dictionary.
1469 pub fn frame_output_matrix(&self, atom_idx: usize) -> Array2<f64> {
1470 let atom = &self.atoms[atom_idx];
1471 match &atom.decoder_frame {
1472 Some(frame) => frame.frame().to_owned(),
1473 None => Array2::<f64>::eye(atom.output_dim()),
1474 }
1475 }
1476
1477 /// Per-pair frame factor `W_{ij} = U_iᵀ U_j` (`r_i × r_j`) used as the output
1478 /// factor of the factored data β-Hessian block `G_{ij} ⊗ W_{ij}` (issue #972
1479 /// / #977 T1). When both atoms are framed this is the dense principal-angle
1480 /// cosine matrix between the two frames; for `i == j` with an orthonormal
1481 /// frame it is exactly `I_{r_i}`; for any un-framed atom the corresponding
1482 /// `U` is `I_p`, so a same-atom un-framed pair gives `I_p` (the clean full-`B`
1483 /// `G ⊗ I_p` collapse) and a framed/un-framed cross pair gives the rectangular
1484 /// `U_iᵀ` / `U_j` overlap.
1485 pub fn frame_cross_factor(&self, atom_i: usize, atom_j: usize) -> Array2<f64> {
1486 let ui = self.frame_output_matrix(atom_i);
1487 let uj = self.frame_output_matrix(atom_j);
1488 // `U_iᵀ U_j`: `(r_i × p) · (p × r_j)`. `fast_atb` forms `U_iᵀ U_j` directly.
1489 fast_atb(&ui, &uj)
1490 }
1491
1492 /// Per-atom column offsets of the *factored* border (issue #972): the
1493 /// running prefix sum of `M_k · r_k`. The analogue of [`Self::beta_offsets`]
1494 /// for the reduced coordinate layout — atom `k`'s `C_k` occupies
1495 /// `[factored_border_offsets()[k] .. + M_k·r_k)`. On the full-`B` path this
1496 /// equals `beta_offsets`.
1497 pub fn factored_border_offsets(&self) -> Vec<usize> {
1498 let mut out = Vec::with_capacity(self.k_atoms());
1499 let mut cursor = 0usize;
1500 for atom in &self.atoms {
1501 out.push(cursor);
1502 cursor += atom.border_coeff_count();
1503 }
1504 out
1505 }
1506
1507 /// Assemble the factored border coordinate vector `C = [vec(C_1); …; vec(C_K)]`
1508 /// in row-major `C_k[m, j] → C[off_k + m·r_k + j]` layout (issue #972).
1509 ///
1510 /// This is the reduced state the arrow-Schur border carries when frames are
1511 /// active: its length is [`Self::factored_border_dim`] (`Σ M_k·r_k`), the
1512 /// border-size invariant verified by [`grassmann_assert_border_dim_invariant`].
1513 /// Atoms
1514 /// without an active frame contribute their full `vec(B_k)` (their `r_k == p`
1515 /// coordinates are the decoder itself), so on the all-full-`B` path this
1516 /// reproduces [`Self::flatten_beta`].
1517 pub fn flatten_factored_border(&self) -> Result<Array1<f64>, String> {
1518 let offsets = self.factored_border_offsets();
1519 let mut out = Array1::<f64>::zeros(self.factored_border_dim());
1520 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1521 let off = offsets[atom_idx];
1522 let r = atom.border_frame_rank();
1523 let m = atom.basis_size();
1524 let coords = match atom.factored_coordinates()? {
1525 Some(c) => c,
1526 // Full-`B` path: the decoder itself is the coordinate matrix.
1527 None => atom.decoder_coefficients.clone(),
1528 };
1529 for basis_col in 0..m {
1530 for j in 0..r {
1531 out[off + basis_col * r + j] = coords[[basis_col, j]];
1532 }
1533 }
1534 }
1535 Ok(out)
1536 }
1537
1538 /// Scatter a factored border coordinate vector `C` (length
1539 /// [`Self::factored_border_dim`]) back into the per-atom decoders, refreshing
1540 /// each `decoder_coefficients = C_k · U_kᵀ` so the full-`B` consumers stay
1541 /// consistent after a factored border solve (issue #972). The inverse of
1542 /// [`Self::flatten_factored_border`].
1543 pub fn scatter_factored_border(&mut self, border: ArrayView1<'_, f64>) -> Result<(), String> {
1544 let expected = self.factored_border_dim();
1545 if border.len() != expected {
1546 return Err(format!(
1547 "SaeManifoldTerm::scatter_factored_border: border length {} must equal \
1548 factored border dim {expected}",
1549 border.len()
1550 ));
1551 }
1552 let offsets = self.factored_border_offsets();
1553 for atom_idx in 0..self.atoms.len() {
1554 let off = offsets[atom_idx];
1555 let (r, m, has_frame) = {
1556 let atom = &self.atoms[atom_idx];
1557 (
1558 atom.border_frame_rank(),
1559 atom.basis_size(),
1560 atom.decoder_frame.is_some(),
1561 )
1562 };
1563 let mut coords = Array2::<f64>::zeros((m, r));
1564 for basis_col in 0..m {
1565 for j in 0..r {
1566 coords[[basis_col, j]] = border[off + basis_col * r + j];
1567 }
1568 }
1569 if has_frame {
1570 self.atoms[atom_idx].set_factored_coordinates(coords.view())?;
1571 } else {
1572 // Full-`B` path: the coordinates ARE the decoder.
1573 self.atoms[atom_idx].decoder_coefficients = coords;
1574 }
1575 }
1576 Ok(())
1577 }
1578
1579 /// Auto-derive and install low-rank Grassmann decoder frames across all
1580 /// atoms (issue #972) — magic-by-default, no flag. Each atom independently
1581 /// activates its frame iff the factorization materially shrinks its border
1582 /// (see [`SaeManifoldAtom::maybe_activate_decoder_frame`]). Returns the
1583 /// number of atoms that activated a frame. Idempotent: re-running re-derives
1584 /// each frame from the current decoder.
1585 ///
1586 /// The decision keys on the *frontier* regime the issue targets: at large
1587 /// ambient `p` the full border `Σ M_k · p` reaches `10^7`–`10^8` and the
1588 /// border Cholesky dies, while the decoder's effective column rank `r` stays
1589 /// `≪ p`. Small-`p` atoms (where `r` cannot beat the activation margin)
1590 /// keep the bit-for-bit full-`B` path, so the small-model evidence is
1591 /// unchanged (verified by `factored_evidence_matches_full_b_at_small_p`).
1592 pub fn auto_activate_decoder_frames(&mut self) -> Result<usize, String> {
1593 let mut activated = 0usize;
1594 for atom in &mut self.atoms {
1595 let expected_rank = atom.decoder_frame_activation_rank()?;
1596 match (
1597 expected_rank,
1598 atom.decoder_frame.as_ref().map(GrassmannFrame::rank),
1599 ) {
1600 (Some(expected), Some(current)) if expected == current => {
1601 continue;
1602 }
1603 (None, Some(_)) => {
1604 atom.deactivate_decoder_frame();
1605 continue;
1606 }
1607 (None, None) => {
1608 continue;
1609 }
1610 (Some(_), _) => {}
1611 }
1612 if atom.maybe_activate_decoder_frame()?.is_some() {
1613 activated += 1;
1614 }
1615 }
1616 Ok(activated)
1617 }
1618
1619 /// Reconcile decoder-frame activation before a fit entry point. The
1620 /// user-facing `auto_activate_decoder_frames` contract returns only newly
1621 /// installed frames; this helper enforces the stronger invariant the large-p
1622 /// solver needs: every atom whose current decoder satisfies the activation
1623 /// predicate has an active frame after the pass.
1624 pub(crate) fn ensure_decoder_frames_active_for_current_decoder(
1625 &mut self,
1626 ) -> Result<(), String> {
1627 self.auto_activate_decoder_frames()?;
1628 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1629 let expected_rank = atom.decoder_frame_activation_rank()?;
1630 if let Some(expected_rank) = expected_rank {
1631 match atom.decoder_frame.as_ref() {
1632 Some(frame) if frame.rank() == expected_rank => {}
1633 Some(frame) => {
1634 return Err(format!(
1635 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1636 atom {atom_idx} frame rank {} must equal audited rank {expected_rank}",
1637 frame.rank()
1638 ));
1639 }
1640 None => {
1641 return Err(format!(
1642 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1643 atom {atom_idx} has audited rank {expected_rank} but no active frame"
1644 ));
1645 }
1646 }
1647 } else if atom.decoder_frame.is_some() {
1648 return Err(format!(
1649 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1650 atom {atom_idx} kept a frame after the full-B predicate won"
1651 ));
1652 }
1653 }
1654 Ok(())
1655 }
1656
1657 /// Closed-form streaming POLAR refresh of every ACTIVE decoder frame from the
1658 /// current data evidence (issue #972 / #977 T1) — the U-block of the
1659 /// alternating block-coordinate ascent that complements the border's
1660 /// C-block Newton step.
1661 ///
1662 /// For each framed atom `k` we accumulate the `p × r_k` cross-moment
1663 /// `A_k = Σ_n a_{n,k} · e_{n,k} · ĉ_{n,k}ᵀ`,
1664 /// where `e_{n,k} = z_n − Σ_{k'≠k} a_{n,k'}·decoded_{k'}(n)` is the row's
1665 /// partial reconstruction residual (everything except atom `k`) and
1666 /// `ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^{r_k}` is atom `k`'s in-span decoded
1667 /// coordinate. The polar factor `U_new = polar(A_k)` is the closed-form MAP
1668 /// frame on `Gr(r_k, p)` given the C-coordinates held fixed — the same
1669 /// `O(p r²)` thin SVD the issue prescribes, run OUTSIDE the border. The frame
1670 /// is then re-installed and the decoder re-projected onto it so the
1671 /// authoritative `B_k = C_k U_newᵀ` and the `(C_k, U_new)` pair stay
1672 /// consistent (a no-op in span for a truly rank-`r` atom). Un-framed atoms
1673 /// are skipped. Returns the number of frames refreshed.
1674 pub(crate) fn refresh_active_frames_from_data(
1675 &mut self,
1676 target: ArrayView2<'_, f64>,
1677 rho: &SaeManifoldRho,
1678 ) -> Result<usize, String> {
1679 let n = self.n_obs();
1680 let p = self.output_dim();
1681 let k_atoms = self.k_atoms();
1682 if n == 0 {
1683 return Ok(0);
1684 }
1685 // Per-row assignments and per-(row, atom) decoded outputs, computed once.
1686 let mut assignments = Vec::with_capacity(n);
1687 for row in 0..n {
1688 assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
1689 }
1690 let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
1691 let mut dbuf = vec![0.0_f64; p];
1692 for row in 0..n {
1693 for atom_idx in 0..k_atoms {
1694 self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
1695 for c in 0..p {
1696 decoded[[row, atom_idx, c]] = dbuf[c];
1697 }
1698 }
1699 }
1700 // Full fitted reconstruction `Σ_k a_k decoded_k`, so the per-atom partial
1701 // residual is `e_k = (z − fitted) + a_k decoded_k` (add atom k back in).
1702 let mut fitted = Array2::<f64>::zeros((n, p));
1703 for row in 0..n {
1704 for atom_idx in 0..k_atoms {
1705 let a = assignments[row][atom_idx];
1706 if a == 0.0 {
1707 continue;
1708 }
1709 for c in 0..p {
1710 fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
1711 }
1712 }
1713 }
1714 let mut refreshed = 0usize;
1715 for atom_idx in 0..k_atoms {
1716 // Only atoms with an active frame are refreshed.
1717 let Some(coords_c) = self.atoms[atom_idx].factored_coordinates()? else {
1718 continue;
1719 };
1720 let r = self.atoms[atom_idx].border_frame_rank();
1721 let m = self.atoms[atom_idx].basis_size();
1722 // Accumulate `A_k = Σ_n a_k · e_{n,k} · ĉ_{n,k}ᵀ` directly (p × r).
1723 let mut cross = GrassmannCrossMoment::new(p, r);
1724 // Build per-row p-target `a_k·e_k` and r-coord `a_k·ĉ` batched, then
1725 // accumulate as one outer-product sum. `accumulate` forms
1726 // `targetsᵀ·coords`, so scaling EITHER side by `a_k` once gives the
1727 // `a_k²` weight on the cross-moment that matches the C-block normal
1728 // equations (residual leg carries `a_k`, coordinate leg carries
1729 // `a_k`).
1730 let mut targets = Array2::<f64>::zeros((n, p));
1731 let mut rcoords = Array2::<f64>::zeros((n, r));
1732 for row in 0..n {
1733 let a = assignments[row][atom_idx];
1734 // Partial residual e_{n,k} = z_n − (fitted − a_k decoded_k).
1735 for c in 0..p {
1736 let e = target[[row, c]] - fitted[[row, c]] + a * decoded[[row, atom_idx, c]];
1737 targets[[row, c]] = a * e;
1738 }
1739 // In-span coordinate ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^r.
1740 for j in 0..r {
1741 let mut acc = 0.0_f64;
1742 for basis_col in 0..m {
1743 acc += self.atoms[atom_idx].basis_values[[row, basis_col]]
1744 * coords_c[[basis_col, j]];
1745 }
1746 rcoords[[row, j]] = a * acc;
1747 }
1748 }
1749 cross.accumulate(targets.view(), rcoords.view())?;
1750 // `polar(A_k)` is well-defined only when the moment is non-trivial;
1751 // a zero moment (e.g. a fully collapsed atom) leaves the frame as-is.
1752 if cross.moment().iter().all(|&v| v == 0.0) {
1753 continue;
1754 }
1755 self.atoms[atom_idx].refresh_frame_from_cross_moment(cross.moment())?;
1756 refreshed += 1;
1757 }
1758 Ok(refreshed)
1759 }
1760
1761 pub fn beta_offsets(&self) -> Vec<usize> {
1762 let p = self.output_dim();
1763 let mut out = Vec::with_capacity(self.k_atoms());
1764 let mut cursor = 0usize;
1765 for atom in &self.atoms {
1766 out.push(cursor);
1767 cursor += atom.basis_size() * p;
1768 }
1769 out
1770 }
1771
1772 /// Per-atom β column ranges for the block-Jacobi Schur preconditioner.
1773 ///
1774 /// Returns one `Range<usize>` per atom, covering that atom's decoder
1775 /// coefficients in the flat β vector:
1776 /// `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
1777 ///
1778 /// Pass to [`ArrowSchurSystem::set_block_offsets`] so that
1779 /// [`gam_solve::arrow_schur::JacobiPreconditioner`] builds one dense
1780 /// Schur sub-block per atom instead of scalar-diagonal inversion.
1781 pub fn beta_block_offsets(&self) -> Arc<[std::ops::Range<usize>]> {
1782 let p = self.output_dim();
1783 let mut ranges: Vec<std::ops::Range<usize>> = Vec::with_capacity(self.k_atoms());
1784 let mut cursor = 0usize;
1785 for atom in &self.atoms {
1786 let width = atom.basis_size() * p;
1787 ranges.push(cursor..cursor + width);
1788 cursor += width;
1789 }
1790 Arc::from(ranges.into_boxed_slice())
1791 }
1792
1793 /// Decide whether the sparse per-row active-set layout is engaged for a
1794 /// dense-weight assignment mode, and if so derive the per-row active-atom
1795 /// cap and magnitude cutoff.
1796 ///
1797 /// #1408: this plan is mode-agnostic. `assemble_arrow_schur` consults it
1798 /// directly for IBP-MAP, and for `AssignmentMode::Softmax` via
1799 /// [`Self::softmax_active_plan`], which tightens it with an explicit `top_k`
1800 /// (`softmax_active_cap`). Softmax therefore engages the compact active-set
1801 /// layout whenever `top_k` or the budget bounds the active set (the
1802 /// active-sub-block Gershgorin majorizer + coherent logdet/θ-adjoint are
1803 /// landed — see `SaeRowLayout`'s doc); it keeps the full `K`-atom layout only
1804 /// when neither lever engages. The decision is auto-derived from
1805 /// the problem size and the device/host working-set budget — never a CLI flag
1806 /// or kwarg. JumpReLU is not handled here (it always uses its structural gate
1807 /// via [`SaeRowLayout::from_jumprelu`]). The dense Gauss-Newton data Gram `G`
1808 /// is `(m_total × m_total)` f64; if its dense form fits the budget we keep
1809 /// the exact full-support solve (every atom active per row), so small-`K`
1810 /// problems are bit-for-bit unchanged. Above that, we cap each row to the
1811 /// `k_active` atoms that make the *sparse* Gram fit the same budget, with a
1812 /// relative magnitude cutoff that drops assignment mass contributing
1813 /// negligible `O(a²)` curvature.
1814 ///
1815 /// Returns `Some((k_active_cap, cutoff))` to engage sparsity, or `None` to
1816 /// keep the dense full-support layout.
1817 pub(crate) fn sparse_active_plan(&self) -> Option<(usize, f64)> {
1818 // The per-row Riemannian tangent projection for non-Euclidean atom
1819 // latents is now applied directly on the compact active-set rows (see
1820 // the `Some(layout)` arm in `assemble_arrow_schur`, via
1821 // `compact_row_ext_manifold_and_point`), which rebuilds each row's
1822 // product manifold in its compact column order and applies the SAME
1823 // gt/htt/htbeta + Kronecker-Jacobian projections the dense path uses. So
1824 // the sparse plan may engage on curved ext-coord manifolds (circle /
1825 // torus / sphere atoms) — the affordability lever for manifold-SAE at
1826 // large `K`, where the dense `K²` co-assignment Gram is the cost. (The
1827 // former `is_euclidean()`-only restriction punted every curved atom to
1828 // the dense layout; it is lifted.) The host/device in-core budget is the
1829 // single gate now; it is parameterised in `sparse_active_plan_for_budget`
1830 // so the engagement regression can pin a small budget without allocating
1831 // a multi-GB dense Gram.
1832 let budget = match crate::gpu::device_runtime::GpuRuntime::global() {
1833 // Allow up to one quarter of the AGGREGATE device budget for the dense
1834 // Gram, matching the streaming dispatcher's in-core fraction. The
1835 // per-atom-pair Gram blocks fan out across the whole device pool, so
1836 // the in-core fraction sums every ordinal's budget, not just the
1837 // primary's.
1838 Some(rt) => {
1839 let aggregate: usize = rt
1840 .device_ordinals()
1841 .iter()
1842 .map(|&ord| rt.memory_budget_for(ord))
1843 .sum();
1844 aggregate / 4
1845 }
1846 None => sae_host_in_core_budget_bytes().0,
1847 };
1848 self.sparse_active_plan_for_budget(budget)
1849 }
1850
1851 /// Budget-parameterised core of [`Self::sparse_active_plan`]. The dense data
1852 /// Gram footprint `(m_total · m_total) f64` is compared against `budget`; a
1853 /// term whose dense Gram exceeds the budget engages the compact active-set
1854 /// plan (returns `Some((k_active_cap, cutoff))`), regardless of whether any
1855 /// atom latent is curved. Pulled out so the curved-atom engagement
1856 /// regression can pin a small budget deterministically.
1857 pub(crate) fn sparse_active_plan_for_budget(&self, budget: usize) -> Option<(usize, f64)> {
1858 // Relative magnitude cutoff: assignment mass below this fraction of the
1859 // row's peak `|a_k|` enters the Gram only as `O(a²)` curvature and is
1860 // dropped. Chosen so dropped terms are ~1e-6 of the peak self-coupling.
1861 const RELATIVE_CUTOFF: f64 = 1.0e-3;
1862
1863 let k_atoms = self.k_atoms();
1864 if k_atoms <= 1 {
1865 return None;
1866 }
1867 let p = self.output_dim();
1868 let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
1869 // Dense data Gram footprint: (m_total · m_total) f64.
1870 let dense_gram_bytes = m_total
1871 .saturating_mul(m_total)
1872 .saturating_mul(SAE_BYTES_PER_F64);
1873 if dense_gram_bytes <= budget {
1874 return None;
1875 }
1876
1877 // Sparse Gram footprint scales with the per-row active basis count
1878 // `k_active · m_atom`. Solve for the largest `k_active` whose sparse
1879 // Gram `(k_active · m_atom)²` still fits the budget.
1880 let m_atom = (m_total as f64 / k_atoms as f64).max(1.0);
1881 let max_active_basis = ((budget as f64 / SAE_BYTES_PER_F64 as f64).sqrt() / m_atom).floor();
1882 let k_active_cap = (max_active_basis as usize).clamp(1, k_atoms);
1883 // p does not enter the Gram dimension (it is carried by the `⊗ I_p`
1884 // structure), but a degenerate `p == 0` term has no decoder columns.
1885 if p == 0 {
1886 return None;
1887 }
1888 Some((k_active_cap, RELATIVE_CUTOFF))
1889 }
1890
1891 /// #1408/#1409 — per-row active-set plan for the Softmax assignment.
1892 ///
1893 /// Engages the compact top-`k` row layout when EITHER the user supplied a
1894 /// hard `top_k` cap ([`Self::softmax_active_cap`], `1 <= k < K`) OR the
1895 /// dense data Gram exceeds the in-core budget (the same memory lever the
1896 /// IBP path uses via [`Self::sparse_active_plan`]). The returned
1897 /// `k_active_cap` is the tighter of the two, so an explicit `top_k`
1898 /// genuinely bounds the optimization even below the memory threshold and a
1899 /// large-K budget breach still bounds it when no `top_k` is set. Returns
1900 /// `None` (keep the exact full-`K` dense softmax layout) when neither lever
1901 /// engages.
1902 ///
1903 /// The cutoff is the same relative magnitude floor as the budget plan
1904 /// (`1e-3` of the row peak); under an explicit `top_k` cap alone (no budget
1905 /// breach) it is `0.0` so exactly the top-`k` atoms are retained.
1906 pub(crate) fn softmax_active_plan(&self) -> Option<(usize, f64)> {
1907 if self.k_atoms() <= 1 {
1908 return None;
1909 }
1910 let budget_plan = self.sparse_active_plan();
1911 match (self.softmax_active_cap, budget_plan) {
1912 (Some(cap), Some((budget_cap, cutoff))) => Some((cap.min(budget_cap), cutoff)),
1913 // Explicit cap only: retain exactly the top-`cap` atoms (no extra
1914 // magnitude cutoff beyond the cap).
1915 (Some(cap), None) => Some((cap, 0.0)),
1916 (None, plan) => plan,
1917 }
1918 }
1919
1920 pub fn flatten_beta(&self) -> Array1<f64> {
1921 let p = self.output_dim();
1922 let offsets = self.beta_offsets();
1923 let mut out = Array1::<f64>::zeros(self.beta_dim());
1924 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1925 let m = atom.basis_size();
1926 let off = offsets[atom_idx];
1927 for basis_col in 0..m {
1928 for out_col in 0..p {
1929 out[off + basis_col * p + out_col] =
1930 atom.decoder_coefficients[[basis_col, out_col]];
1931 }
1932 }
1933 }
1934 out
1935 }
1936
1937 pub fn set_flat_beta(&mut self, beta: ArrayView1<'_, f64>) -> Result<(), String> {
1938 if beta.len() != self.beta_dim() {
1939 return Err(format!(
1940 "set_flat_beta: beta length {} != expected {}",
1941 beta.len(),
1942 self.beta_dim()
1943 ));
1944 }
1945 let p = self.output_dim();
1946 let offsets = self.beta_offsets();
1947 for (atom_idx, atom) in self.atoms.iter_mut().enumerate() {
1948 let m = atom.basis_size();
1949 let off = offsets[atom_idx];
1950 for basis_col in 0..m {
1951 for out_col in 0..p {
1952 atom.decoder_coefficients[[basis_col, out_col]] =
1953 beta[off + basis_col * p + out_col];
1954 }
1955 }
1956 }
1957 Ok(())
1958 }
1959
1960 pub fn refit_decoder_least_squares_at_current_state(
1961 &mut self,
1962 target: ArrayView2<'_, f64>,
1963 rho: Option<&SaeManifoldRho>,
1964 ) -> Result<(), String> {
1965 let n = self.n_obs();
1966 let p = self.output_dim();
1967 if target.dim() != (n, p) {
1968 return Err(format!(
1969 "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: target shape {:?} != ({n}, {p})",
1970 target.dim()
1971 ));
1972 }
1973 let k_atoms = self.k_atoms();
1974 let offsets = self.beta_offsets();
1975 let m_total = self.beta_dim() / p;
1976 let mut design = Array2::<f64>::zeros((n, m_total));
1977 for row in 0..n {
1978 let assignments = match rho {
1979 Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
1980 None => self.assignment.try_assignments_row(row)?,
1981 };
1982 for atom_idx in 0..k_atoms {
1983 let atom = &self.atoms[atom_idx];
1984 let weight = assignments[atom_idx];
1985 let m = atom.basis_size();
1986 let off = offsets[atom_idx] / p;
1987 for basis_col in 0..m {
1988 design[[row, off + basis_col]] = weight * atom.basis_values[[row, basis_col]];
1989 }
1990 }
1991 }
1992 let beta = solve_design_least_squares(design.view(), target)?;
1993 if beta.dim() != (m_total, p) {
1994 return Err(format!(
1995 "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: beta shape {:?} != ({m_total}, {p})",
1996 beta.dim()
1997 ));
1998 }
1999 for atom_idx in 0..k_atoms {
2000 let m = self.atoms[atom_idx].basis_size();
2001 let off = offsets[atom_idx] / p;
2002 for basis_col in 0..m {
2003 for out_col in 0..p {
2004 self.atoms[atom_idx].decoder_coefficients[[basis_col, out_col]] =
2005 beta[[off + basis_col, out_col]];
2006 }
2007 }
2008 self.atoms[atom_idx].refresh_intrinsic_smooth_penalty();
2009 }
2010 Ok(())
2011 }
2012
2013 pub fn fitted(&self) -> Array2<f64> {
2014 self.try_fitted().expect("assignment logits must be finite")
2015 }
2016
2017 /// The #1026 hybrid-collapse substitution map: `atom_idx → &AtomLinearImage`
2018 /// for every `d = 1` slot whose post-fit verdict selected its straight
2019 /// (`Θ → 0`) sub-model. Empty when no report has been computed
2020 /// (`hybrid_split_report == None`, e.g. mid-fit) or no slot collapsed. The
2021 /// SINGLE source of the collapse policy — every reconstruction path (the
2022 /// rho-keyed `try_fitted_with_rho`, the explicit-assignment
2023 /// [`Self::reconstruct_from_assignments`] used by the top-k projection)
2024 /// reads it so train, OOS, and top-k reconstructions decode collapsed slots
2025 /// identically (#1228, #1233).
2026 pub(crate) fn hybrid_linear_image_map(
2027 &self,
2028 ) -> std::collections::HashMap<usize, &crate::hybrid_split::AtomLinearImage> {
2029 // A fitted term carries its collapse policy on the post-fit
2030 // `hybrid_split_report`; an OOS term carries the same trained images on
2031 // `oos_linear_images` (#1228). At most one is `Some` in practice, but
2032 // prefer the report when both are present.
2033 if let Some(report) = self.hybrid_split_report.as_ref() {
2034 return report
2035 .verdicts
2036 .iter()
2037 .filter_map(|v| v.linear_image.as_ref().map(|img| (img.atom_idx, img)))
2038 .collect();
2039 }
2040 if let Some(images) = self.oos_linear_images.as_ref() {
2041 return images.iter().map(|img| (img.atom_idx, img)).collect();
2042 }
2043 std::collections::HashMap::new()
2044 }
2045
2046 /// #1228 — attach the trained dictionary's hybrid-collapsed linear images to
2047 /// this (typically OOS) term so its reconstruction (`fitted` / the top-k
2048 /// assembler) decodes verdict-linear `d = 1` slots by the SAME straight
2049 /// sub-model the training reconstruction used, instead of the original
2050 /// curved decoder. Each image's `atom_idx` must index a real slot; an image
2051 /// whose channel count `p` disagrees with this term's output dim, or whose
2052 /// `atom_idx` is out of range, is rejected so a stale/mismatched payload
2053 /// cannot silently corrupt the reconstruction. Pass an empty slice (or never
2054 /// call this) for an all-curved OOS reconstruction.
2055 ///
2056 /// `pub` (not `pub(crate)`): this is part of the FFI surface — the gam-pyffi
2057 /// crate calls it from `latent_basis_and_sae_ffi.rs` to attach a trained
2058 /// dictionary's hybrid-linear images to an OOS reconstruction term (#1228).
2059 /// Downgrading it to `pub(crate)` breaks the gam-pyffi cdylib build with
2060 /// E0624 (the gam lib still compiles, so the lib build does not catch it).
2061 pub fn set_hybrid_linear_images(
2062 &mut self,
2063 images: Vec<crate::hybrid_split::AtomLinearImage>,
2064 ) -> Result<(), String> {
2065 let p = self.output_dim();
2066 let k_atoms = self.k_atoms();
2067 for img in &images {
2068 if img.atom_idx >= k_atoms {
2069 return Err(format!(
2070 "set_hybrid_linear_images: atom_idx {} out of range (k_atoms={k_atoms})",
2071 img.atom_idx
2072 ));
2073 }
2074 if img.b0.len() != p || img.b1.len() != p {
2075 return Err(format!(
2076 "set_hybrid_linear_images: atom {} linear image has p=({}, {}) != output_dim {p}",
2077 img.atom_idx,
2078 img.b0.len(),
2079 img.b1.len()
2080 ));
2081 }
2082 // #1777 — a collapse-rescued image's projection direction `v` must
2083 // have one entry per output channel so `coordinate_from_residual` can
2084 // project a held-out row's `p`-vector residual onto it.
2085 if let Some(v) = img.v.as_ref() {
2086 if v.len() != p {
2087 return Err(format!(
2088 "set_hybrid_linear_images: atom {} projection direction v has len {} != output_dim {p}",
2089 img.atom_idx,
2090 v.len()
2091 ));
2092 }
2093 }
2094 if self.atoms[img.atom_idx].latent_dim != 1 {
2095 return Err(format!(
2096 "set_hybrid_linear_images: atom {} is not d=1; only d=1 slots collapse to a straight image",
2097 img.atom_idx
2098 ));
2099 }
2100 }
2101 self.oos_linear_images = if images.is_empty() {
2102 None
2103 } else {
2104 Some(images)
2105 };
2106 Ok(())
2107 }
2108
2109 /// Assemble the reconstruction `Σ_k a[i,k]·g_k(t_{ik})` from an EXPLICIT
2110 /// per-row assignment matrix (e.g. a hard top-k projection of the fitted
2111 /// soft assignments), honouring the #1026 hybrid collapse when `collapse` is
2112 /// set: a verdict-linear `d = 1` slot decodes its straight sub-model image
2113 /// instead of its curved curve, exactly as the production `try_fitted` does.
2114 /// This is the shared assembler the FFI top-k path uses so the projected
2115 /// reconstruction composes with hybrid collapse (#1233) instead of
2116 /// re-deriving the curved image by hand and silently bypassing the verdict.
2117 /// The atom coordinates (`t`) and decoded curves are the term's own fitted
2118 /// ones; only the assignment masses come from `assignments`.
2119 pub fn reconstruct_from_assignments(
2120 &self,
2121 assignments: ArrayView2<'_, f64>,
2122 collapse: bool,
2123 ) -> Result<Array2<f64>, String> {
2124 let n = self.n_obs();
2125 let p = self.output_dim();
2126 let k_atoms = self.k_atoms();
2127 if assignments.dim() != (n, k_atoms) {
2128 return Err(format!(
2129 "SaeManifoldTerm::reconstruct_from_assignments: assignments {:?} != ({n}, {k_atoms})",
2130 assignments.dim()
2131 ));
2132 }
2133 let linear_images = if collapse {
2134 self.hybrid_linear_image_map()
2135 } else {
2136 std::collections::HashMap::new()
2137 };
2138 let mut out = Array2::<f64>::zeros((n, p));
2139 let mut g_buf = vec![0.0_f64; p];
2140 for row in 0..n {
2141 for atom_idx in 0..k_atoms {
2142 let a_k = assignments[[row, atom_idx]];
2143 if a_k == 0.0 {
2144 continue;
2145 }
2146 if let Some(image) = linear_images.get(&atom_idx) {
2147 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2148 image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2149 } else {
2150 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2151 }
2152 let mut out_row = out.row_mut(row);
2153 for out_col in 0..p {
2154 out_row[out_col] += a_k * g_buf[out_col];
2155 }
2156 }
2157 }
2158 Ok(out)
2159 }
2160
2161 /// #1777 — TARGET-AWARE hybrid-collapsed reconstruction: identical to
2162 /// [`Self::try_fitted`] except that a #1026 COLLAPSE-RESCUED `d = 1` slot
2163 /// (whose linear image carries a projection direction `v`) recomputes each
2164 /// row's coordinate from THIS `target` as
2165 /// `uᵢ = ⟨y_i − Σ_{j≠k} f_j(x_i), v⟩` — its own leave-this-atom-out residual
2166 /// projected onto `v` — instead of reading the train-only cached
2167 /// `row_codes[i]` (or, worse, the atom's collapsed own coordinate `own_t`).
2168 ///
2169 /// This is the SAME math the train split used to build `row_codes`
2170 /// (`row_codes[i] = ⟨target_resid[i], v⟩`), so on the TRAIN rows/target it
2171 /// reproduces the train reconstruction bit-for-bit, and on a HELD-OUT
2172 /// rows/target it produces the correct out-of-sample coordinate — train and
2173 /// OOS are ONE model. Ordinary (non-rescued) straight images and curved slots
2174 /// are decoded exactly as in [`Self::try_fitted`]; they ignore `target`.
2175 ///
2176 /// `rho` selects the assignment-mass resolution (`Some` uses the ρ-keyed
2177 /// gates, `None` the persisted gates), mirroring [`Self::try_fitted_with_rho`].
2178 /// This is the reconstruction path an OOS predict should call once the trained
2179 /// hybrid-linear images are attached via [`Self::set_hybrid_linear_images`].
2180 pub fn try_fitted_target_aware(
2181 &self,
2182 target: ArrayView2<'_, f64>,
2183 rho: Option<&SaeManifoldRho>,
2184 ) -> Result<Array2<f64>, String> {
2185 let n = self.n_obs();
2186 let p = self.output_dim();
2187 let k_atoms = self.k_atoms();
2188 if target.dim() != (n, p) {
2189 return Err(format!(
2190 "SaeManifoldTerm::try_fitted_target_aware: target {:?} != ({n}, {p})",
2191 target.dim()
2192 ));
2193 }
2194 let linear_images = self.hybrid_linear_image_map();
2195 // The all-curved reconstruction `full = Σ_j a_j·γ_j`, the same quantity the
2196 // train split's `target_resid_for` subtracts. A rescued slot `k`'s
2197 // leave-this-atom-out residual is then `target − full + a_k·γ_k`.
2198 let full_curved = self.try_fitted_with_rho(rho, false)?;
2199 let mut out = Array2::<f64>::zeros((n, p));
2200 let mut g_buf = vec![0.0_f64; p];
2201 let mut decoded_buf = vec![0.0_f64; p];
2202 let mut resid_buf = vec![0.0_f64; p];
2203 for row in 0..n {
2204 let a = match rho {
2205 Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2206 None => self.assignment.try_assignments_row(row)?,
2207 };
2208 for atom_idx in 0..k_atoms {
2209 let a_k = a[atom_idx];
2210 if let Some(image) = linear_images.get(&atom_idx) {
2211 if image.is_collapse_rescued() {
2212 // Recompute this row's coordinate from its own
2213 // leave-this-atom-out residual projected onto `v`.
2214 self.atoms[atom_idx].fill_decoded_row(row, &mut decoded_buf);
2215 for col in 0..p {
2216 resid_buf[col] =
2217 target[[row, col]] - full_curved[[row, col]] + a_k * decoded_buf[col];
2218 }
2219 // `coordinate_from_residual` returns `None` only on a
2220 // length mismatch (impossible here — validated at attach)
2221 // or a non-rescued image (excluded by the branch); fall
2222 // back to the train code/own-coord path if it ever does.
2223 let coord = image
2224 .coordinate_from_residual(&resid_buf)
2225 .unwrap_or_else(|| {
2226 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2227 image.coordinate_for_row(row, own_t)
2228 });
2229 image.fill_row(coord, &mut g_buf);
2230 } else {
2231 // Ordinary straight image: decode at the atom's own coord.
2232 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2233 image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2234 }
2235 } else {
2236 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2237 }
2238 let mut out_row = out.row_mut(row);
2239 for out_col in 0..p {
2240 out_row[out_col] += a_k * g_buf[out_col];
2241 }
2242 }
2243 }
2244 Ok(out)
2245 }
2246
2247 pub fn try_fitted(&self) -> Result<Array2<f64>, String> {
2248 // Production/user-facing reconstruction: honours the #1026 hybrid-split
2249 // verdict (verdict-linear `d = 1` slots decode their straight sub-model).
2250 self.try_fitted_with_rho(None, true)
2251 }
2252
2253 pub(crate) fn try_fitted_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
2254 // Internal/fitting reconstruction: the pure CURVED image (the joint fit
2255 // and the #1026 adjudication both require the uncollapsed curve).
2256 self.try_fitted_with_rho(Some(rho), false)
2257 }
2258
2259 pub(crate) fn try_fitted_with_rho(
2260 &self,
2261 rho: Option<&SaeManifoldRho>,
2262 collapse: bool,
2263 ) -> Result<Array2<f64>, String> {
2264 let n = self.n_obs();
2265 let p = self.output_dim();
2266 let k_atoms = self.k_atoms();
2267 let mut out = Array2::<f64>::zeros((n, p));
2268 // #1026 — the curved/linear hybrid-split verdict is LOAD-BEARING on the
2269 // production reconstruction, not just a side report. When
2270 // [`Self::compute_hybrid_split_report`] (run post-fit in
2271 // `canonicalize_charts_post_fit`) adjudicated a `d = 1` atom's evidence
2272 // in favour of its straight (Θ→0) sub-model, the model's output
2273 // reconstruction (`fitted()` / `try_fitted` → predict and the user-facing
2274 // output) decodes that slot with its fitted linear image instead of its
2275 // curved decoded curve. The linear images are coordinate-keyed and
2276 // rho-independent (exact weighted-LS lines realised inside the
2277 // adjudication — no re-fit, no #1051 outer continuation).
2278 //
2279 // The collapse engages only when the caller asks for it (`collapse`):
2280 // the production `try_fitted` path and the explicit
2281 // `hybrid_collapsed_reconstruction` entry point. The pure-curved
2282 // `try_fitted_for_rho` opts out — the joint fit's loss/assembly optimise
2283 // the curved decoder coefficients and must see the curved image, and the
2284 // #1026 adjudication itself compares the curved fit against its straight
2285 // sub-model — both require the uncollapsed curve. (During fitting the
2286 // report is `None` regardless; it is only computed post-fit.)
2287 let linear_images = if collapse {
2288 self.hybrid_linear_image_map()
2289 } else {
2290 std::collections::HashMap::new()
2291 };
2292 // Reuse a single scratch buffer across all (row, atom) pairs instead of
2293 // allocating a fresh `Array1<f64>` of length p per call.
2294 let mut g_buf = vec![0.0_f64; p];
2295 for row in 0..n {
2296 let a = match rho {
2297 Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2298 None => self.assignment.try_assignments_row(row)?,
2299 };
2300 for atom_idx in 0..k_atoms {
2301 let a_k = a[atom_idx];
2302 if let Some(image) = linear_images.get(&atom_idx) {
2303 // Verdict-linear slot: substitute the straight sub-model image
2304 // at this row's fitted on-atom coordinate — or, for a #1026
2305 // collapse-rescued slot, at its fresh per-row code.
2306 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2307 image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2308 } else {
2309 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2310 }
2311 let mut out_row = out.row_mut(row);
2312 for out_col in 0..p {
2313 out_row[out_col] += a_k * g_buf[out_col];
2314 }
2315 }
2316 }
2317 Ok(out)
2318 }
2319
2320 /// Per-atom **leave-one-atom-out (LOAO) explained-variance contribution**
2321 /// (#1026): for each atom `k`, the drop in reconstruction explained variance
2322 /// `ΔEV_k = EV(full) − EV(full ⊖ atom_k)` when that atom's contribution
2323 /// `a[i,k]·g_k(coord[i,k])` is removed from the assembled reconstruction and
2324 /// nothing else is refit. Because every atom adds linearly into the same
2325 /// fitted reconstruction (`fitted[i] = Σ_k a[i,k]·g_k`), zeroing one atom is
2326 /// the exact "this atom withheld" counterfactual, and the EV it was earning
2327 /// is `EV(full) − EV(without k)`. This is the per-atom held-out EV
2328 /// attribution the #1026 roadmap pairs with each atom's fitted turning `Θ`:
2329 /// a `Θ ≈ 0` atom earning a large `ΔEV` is a linear-tail direction; a
2330 /// high-`Θ` atom earning a large `ΔEV` is a genuine curved family carrying
2331 /// reconstruction it would otherwise shatter into `N(ε) ≈ Θ/(2√(2ε))` linear
2332 /// directions. Pure read-only diagnostic — never mutates any atom.
2333 ///
2334 /// Returns one `Option<f64>` per atom in atom order; `None` for an atom
2335 /// whose ⊖-reconstruction EV is undefined (degenerate target variance), and
2336 /// `None` for the whole vector if the full-reconstruction EV is undefined.
2337 /// #1026: the load-bearing curved-vs-linear hybrid-split verdict for the
2338 /// fitted dictionary, or `None` until [`Self::canonicalize_charts_post_fit`]
2339 /// has run (or when no `d = 1` atom is eligible). Surfaced in the Python model
2340 /// output so the user sees which atoms genuinely earn their curvature.
2341 pub fn hybrid_split_report(
2342 &self,
2343 ) -> Option<&crate::hybrid_split::SaeHybridSplitReport> {
2344 self.hybrid_split_report.as_ref()
2345 }
2346
2347 /// Build the #1026 curved-vs-linear hybrid-split report by adjudicating each
2348 /// eligible `d = 1` atom's fitted curved image against its straight (linear
2349 /// special-case) sub-model on the common rank-aware Laplace evidence scale.
2350 ///
2351 /// Both candidates are scored against the SAME data — the atom's
2352 /// leave-this-atom-out response residual `y_resp = target − (full − a_k·γ_k)`
2353 /// (#1202) — over its assigned rows: the curved candidate predicts its actual
2354 /// mass-scaled contribution `a_k·γ_k`, the linear candidate the best
2355 /// mass-weighted straight line fit to `y_resp` (the collapsed linear lane —
2356 /// closed form, NOT the broken euclidean outer fit path of #1051). Linear is
2357 /// the curved family's nested `Θ = 0` sub-model on common data, so the
2358 /// per-slot evidence argmin is a genuine match-or-beat comparison. Eligible
2359 /// atoms are `d = 1` atoms with an installed evaluator at the full curvature
2360 /// dial (`homotopy_eta == 1.0`) whose live coordinate dim still matches the
2361 /// atom's latent dim. Returns `None` when no reconstruction `target` is
2362 /// supplied (there is no data to adjudicate against).
2363 pub fn compute_hybrid_split_report(
2364 &self,
2365 rho: &SaeManifoldRho,
2366 target: Option<ArrayView2<'_, f64>>,
2367 ) -> Result<Option<crate::hybrid_split::SaeHybridSplitReport>, String> {
2368 let n = self.n_obs();
2369 let p = self.output_dim();
2370 // Per-atom held-out `ΔEV_k` (leave-one-atom-out explained-variance drop),
2371 // paired with each atom's fitted turning Θ onto the verdict so the report
2372 // carries the #1026 `(Θ, ΔEV)` frontier point as structured data. Absent
2373 // when no reconstruction target is supplied.
2374 let loao_ev: Vec<Option<f64>> = match target {
2375 Some(t) => self.per_atom_loao_explained_variance(t, rho)?,
2376 None => vec![None; self.k_atoms()],
2377 };
2378 let delta_ev_for =
2379 |atom_idx: usize| -> Option<f64> { loao_ev.get(atom_idx).copied().flatten() };
2380 // The common-evidence comparison (#1202) scores both candidates against
2381 // the response data the atom is responsible for. That requires a target;
2382 // with none supplied there is nothing to adjudicate against, so no report.
2383 let Some(target) = target else {
2384 return Ok(None);
2385 };
2386 if target.dim() != (n, p) {
2387 return Err(format!(
2388 "SaeManifoldTerm::compute_hybrid_split_report: target {:?} != ({n}, {p})",
2389 target.dim()
2390 ));
2391 }
2392 // Per-row assignment masses (once), so each atom's weighted straight-line
2393 // fit uses the same row weighting the joint reconstruction loss does.
2394 let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2395 for row in 0..n {
2396 weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2397 }
2398 // The full assembled reconstruction `Σ_k a[i,k]·γ_k`, computed once. Each
2399 // atom's leave-this-atom-out response residual is `y_resp = target −
2400 // (full − a_k·γ_k)`, the data both that atom's candidates fit (#1202).
2401 let full = self.try_fitted_for_rho(rho)?;
2402 let eligible: Vec<usize> = (0..self.k_atoms())
2403 .filter(|&atom_idx| {
2404 let atom = &self.atoms[atom_idx];
2405 atom.latent_dim == 1
2406 && atom.basis_evaluator.is_some()
2407 && atom.homotopy_eta == 1.0
2408 && self.assignment.coords[atom_idx].latent_dim() == atom.latent_dim
2409 })
2410 .collect();
2411 // Per-atom fitted decoded image at every row (the curved candidate's
2412 // realized curve, which the linear candidate must approximate).
2413 let coords_for = |atom_idx: usize| -> Array1<f64> {
2414 self.assignment.coords[atom_idx]
2415 .as_matrix()
2416 .column(0)
2417 .to_owned()
2418 };
2419 let assign_for = |atom_idx: usize| -> Array1<f64> {
2420 Array1::from_iter((0..n).map(|row| weights[row][atom_idx]))
2421 };
2422 let decoded_for = |atom_idx: usize| -> Array2<f64> {
2423 let mut decoded = Array2::<f64>::zeros((n, p));
2424 let mut buf = vec![0.0_f64; p];
2425 for row in 0..n {
2426 self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2427 for col in 0..p {
2428 decoded[[row, col]] = buf[col];
2429 }
2430 }
2431 decoded
2432 };
2433 // The atom's leave-this-atom-out response residual `y_resp = target −
2434 // (full − a_k·γ_k) = (target − full) + a_k·γ_k`. Both the curved and the
2435 // linear candidate are scored against this on common data (#1202).
2436 let target_resid_for = |atom_idx: usize| -> Array2<f64> {
2437 let mut resid = Array2::<f64>::zeros((n, p));
2438 let mut buf = vec![0.0_f64; p];
2439 for row in 0..n {
2440 let a_k = weights[row][atom_idx];
2441 self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2442 for col in 0..p {
2443 resid[[row, col]] = target[[row, col]] - full[[row, col]] + a_k * buf[col];
2444 }
2445 }
2446 resid
2447 };
2448 let manifold_for = |atom_idx: usize| -> gam_terms::latent::LatentManifold {
2449 self.assignment.coords[atom_idx].manifold().clone()
2450 };
2451 // #1026 EV-preservation gate denominator: the full target's total
2452 // column-centered variance `SST_full` (the SAME `sst` the reconstruction
2453 // EV is measured against), so the gate vetoes any collapse that would drop
2454 // full-reconstruction EV by more than its tolerance.
2455 let total_centered_variance = {
2456 let mut tss = 0.0_f64;
2457 for col in 0..p {
2458 let mut mean = 0.0_f64;
2459 for row in 0..n {
2460 mean += target[[row, col]];
2461 }
2462 mean /= n as f64;
2463 for row in 0..n {
2464 let c = target[[row, col]] - mean;
2465 tss += c * c;
2466 }
2467 }
2468 tss
2469 };
2470 crate::hybrid_split::build_hybrid_split_report(
2471 &self.atoms,
2472 eligible.into_iter(),
2473 coords_for,
2474 assign_for,
2475 decoded_for,
2476 target_resid_for,
2477 manifold_for,
2478 delta_ev_for,
2479 total_centered_variance,
2480 )
2481 }
2482
2483 pub fn per_atom_loao_explained_variance(
2484 &self,
2485 target: ArrayView2<'_, f64>,
2486 rho: &SaeManifoldRho,
2487 ) -> Result<Vec<Option<f64>>, String> {
2488 let n = self.n_obs();
2489 let p = self.output_dim();
2490 let k_atoms = self.k_atoms();
2491 if target.dim() != (n, p) {
2492 return Err(format!(
2493 "SaeManifoldTerm::per_atom_loao_explained_variance: target {:?} != ({n}, {p})",
2494 target.dim()
2495 ));
2496 }
2497 let full = self.try_fitted_for_rho(rho)?;
2498 let Some(ev_full) = reconstruction_explained_variance(target, full.view()) else {
2499 return Ok(vec![None; k_atoms]);
2500 };
2501 // Cache each row's assignment weights once, then subtract a single
2502 // atom's decoded contribution per LOAO pass instead of reassembling the
2503 // whole dictionary k times.
2504 let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2505 for row in 0..n {
2506 weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2507 }
2508 let mut g_buf = vec![0.0_f64; p];
2509 let mut out = Vec::with_capacity(k_atoms);
2510 for atom_idx in 0..k_atoms {
2511 let mut without = full.clone();
2512 for row in 0..n {
2513 let a_k = weights[row][atom_idx];
2514 if a_k == 0.0 {
2515 continue;
2516 }
2517 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2518 let mut without_row = without.row_mut(row);
2519 for out_col in 0..p {
2520 without_row[out_col] -= a_k * g_buf[out_col];
2521 }
2522 }
2523 out.push(
2524 reconstruction_explained_variance(target, without.view())
2525 .map(|ev_without| ev_full - ev_without),
2526 );
2527 }
2528 Ok(out)
2529 }
2530
2531 /// #1026 — the LOAD-BEARING collapsed reconstruction: the assembled
2532 /// dictionary output `Σ_k a[i,k]·g_k(coord[i,k])` in which every slot whose
2533 /// hybrid-split verdict selected LINEAR has its curved decoded image replaced
2534 /// by its fitted straight sub-model `b₀ + (t − t̄)·b₁`. This is what makes the
2535 /// verdict *change the reconstruction* instead of merely logging a choice:
2536 /// the linear-collapsed atom no longer pays its `M·p` curved coefficients, it
2537 /// carries a `2·p` straight image whose decoded curve has zero turning.
2538 ///
2539 /// The straight images are the exact weighted-least-squares lines already
2540 /// realized inside [`Self::compute_hybrid_split_report`] (no re-fit, no outer
2541 /// continuation, sidestepping #1051). Returns the curved reconstruction
2542 /// unchanged when no verdict selected linear, or when the report has not been
2543 /// computed yet (`hybrid_split_report == None`).
2544 pub fn hybrid_collapsed_reconstruction(
2545 &self,
2546 rho: &SaeManifoldRho,
2547 ) -> Result<Array2<f64>, String> {
2548 // #1026 — the hybrid collapse is realised by the SINGLE reconstruction
2549 // path ([`Self::try_fitted_with_rho`]) with the collapse flag set: a
2550 // verdict-linear `d = 1` slot decodes its straight sub-model image
2551 // instead of its curved curve. This replaces the dedicated re-collapse
2552 // loop this method used to carry (a parallel layer). The production
2553 // `try_fitted` shares the identical routine at `rho = None`; this entry
2554 // point keeps the rho-keyed collapse for the #1026 EV-dominance reporting
2555 // (`hybrid_collapsed_explained_variance`) and the regression battery.
2556 self.try_fitted_with_rho(Some(rho), true)
2557 }
2558
2559 /// #1026 — the reconstruction explained variance of the hybrid-collapsed
2560 /// dictionary (every verdict-linear slot decoded by its straight sub-model)
2561 /// against `target`. The companion of [`Self::per_atom_loao_explained_variance`]
2562 /// for the dominance claim: because each linear-collapsed slot is the curved
2563 /// family's `Θ → 0` sub-model and is only kept when its evidence beats the
2564 /// curved candidate's parameter price, the collapsed dictionary match-or-beats
2565 /// the all-curved one on EV-per-parameter — the strict-generalization floor
2566 /// the #1026 hybrid argument rests on. `None` when EV is undefined (degenerate
2567 /// target variance).
2568 pub fn hybrid_collapsed_explained_variance(
2569 &self,
2570 target: ArrayView2<'_, f64>,
2571 rho: &SaeManifoldRho,
2572 ) -> Result<Option<f64>, String> {
2573 let n = self.n_obs();
2574 let p = self.output_dim();
2575 if target.dim() != (n, p) {
2576 return Err(format!(
2577 "SaeManifoldTerm::hybrid_collapsed_explained_variance: target {:?} != ({n}, {p})",
2578 target.dim()
2579 ));
2580 }
2581 let collapsed = self.hybrid_collapsed_reconstruction(rho)?;
2582 Ok(reconstruction_explained_variance(target, collapsed.view()))
2583 }
2584
2585 /// #1026 ladder item 2/3 — the AMORTIZED ENCODER, wired from the fitted
2586 /// dictionary. Builds the offline certified [`EncodeAtlas`] over this term's
2587 /// frozen atoms and encodes a target corpus `targets` (`n × p`) through the
2588 /// per-chart distilled Jacobian predictor, with the Kantorovich certificate
2589 /// gating each row and an exact-solve fallback for the rows the amortized
2590 /// predictor cannot certify. Returns one [`EncodeResult`] per atom (the
2591 /// per-atom encoded coordinates + per-row certificate mask), in dictionary
2592 /// order.
2593 ///
2594 /// This is the thread's "encoder + certificate-gated exact fallback"
2595 /// deployment made reachable from a fit: the distilled map approximates
2596 /// inference at one mat-vec/row, and any row whose amortized prediction fails
2597 /// `h ≤ ½` falls back to the certified IFT-warm-start Newton encode
2598 /// ([`EncodeAtlas::certified_encode_row`]); rows that still cannot be
2599 /// certified ride the [`EncodeResult::encode_uncertified_count`] flag for the
2600 /// upstream exact multi-start solve (honesty, never a silent wrong encode).
2601 ///
2602 /// Magic by default: the atlas's worst-case bounds are auto-derived from the
2603 /// fit — `amplitude_bound[k]` is the largest fitted assignment mass `a[i,k]`
2604 /// the encode can produce for atom `k` (the encode recovers `t` from
2605 /// `x ≈ z·γ_k(t)` at amplitude `z = a[i,k]`), and `target_norm_bound` is the
2606 /// largest target row norm — so no caller supplies a knob. Per-row amplitudes
2607 /// are the fitted assignment masses for the same target the dictionary was fit
2608 /// against; an external corpus reuses the per-row masses the assignment
2609 /// produces for it upstream (passed in `amplitudes`, one column per atom).
2610 pub fn amortized_encode_target(
2611 &self,
2612 targets: ArrayView2<'_, f64>,
2613 amplitudes: ArrayView2<'_, f64>,
2614 ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2615 let p = self.output_dim();
2616 let k_atoms = self.k_atoms();
2617 let n = targets.nrows();
2618 if targets.ncols() != p {
2619 return Err(format!(
2620 "SaeManifoldTerm::amortized_encode_target: targets have {} cols but output_dim is {p}",
2621 targets.ncols()
2622 ));
2623 }
2624 if amplitudes.dim() != (n, k_atoms) {
2625 return Err(format!(
2626 "SaeManifoldTerm::amortized_encode_target: amplitudes {:?} must be (n={n}, K={k_atoms})",
2627 amplitudes.dim()
2628 ));
2629 }
2630
2631 // Magic-by-default offline bounds, auto-derived from the fit so no caller
2632 // supplies a knob. `target_norm_bound` is the largest target row L2 norm
2633 // (bounds `‖x‖` over the corpus); `amplitude_bound[k]` is the largest
2634 // fitted assignment mass for atom `k` (bounds `|z_k|`), with a strictly
2635 // positive floor so a near-inactive atom still certifies a finite radius.
2636 let mut target_norm_bound = 0.0_f64;
2637 for row in 0..n {
2638 let norm = targets.row(row).dot(&targets.row(row)).sqrt();
2639 if norm.is_finite() && norm > target_norm_bound {
2640 target_norm_bound = norm;
2641 }
2642 }
2643 let mut amplitude_bound = vec![0.0_f64; k_atoms];
2644 for atom_idx in 0..k_atoms {
2645 let mut bound = 0.0_f64;
2646 for row in 0..n {
2647 let z = amplitudes[[row, atom_idx]].abs();
2648 if z.is_finite() && z > bound {
2649 bound = z;
2650 }
2651 }
2652 // A strictly positive amplitude floor keeps the offline Lipschitz
2653 // scaling finite for atoms with no active row in this corpus (those
2654 // rows encode to the chart center via the certificate anyway).
2655 amplitude_bound[atom_idx] = bound.max(1.0);
2656 }
2657
2658 let atlas = crate::encode::EncodeAtlas::build(
2659 &self.atoms,
2660 &litude_bound,
2661 target_norm_bound,
2662 crate::encode::AtlasConfig::default(),
2663 )?;
2664
2665 // Per-atom amortized encode with a certificate-gated exact-solve fallback:
2666 // a row whose distilled prediction fails `h ≤ ½` is retried through the
2667 // certified IFT-warm-start Newton path; a row that still cannot be
2668 // certified stays flagged for the upstream multi-start solve.
2669 // (The atlas is rho-free; the per-row amplitudes already carry the
2670 // rho-resolved assignment masses the caller produced upstream.)
2671 let mut results = Vec::with_capacity(k_atoms);
2672 for atom_idx in 0..k_atoms {
2673 let atom = &self.atoms[atom_idx];
2674 let amp_col = amplitudes.column(atom_idx).to_owned();
2675 let amortized =
2676 atlas.amortized_encode_batch(atom, atom_idx, targets, amp_col.view())?;
2677 let mut coords = amortized.coords;
2678 let mut certified = amortized.certified;
2679 for row in 0..n {
2680 if certified[row] {
2681 continue;
2682 }
2683 let (t, cert) =
2684 atlas.certified_encode_row(atom, atom_idx, targets.row(row), amp_col[row])?;
2685 if cert.certified() {
2686 coords.row_mut(row).assign(&t);
2687 certified[row] = true;
2688 }
2689 }
2690 results.push(crate::encode::EncodeResult::from_rows(
2691 coords, certified,
2692 ));
2693 }
2694 Ok(results)
2695 }
2696
2697 /// #1026 — the fitted per-row assignment masses `a[i,k]` (the activation
2698 /// amplitudes `z_k` the amortized encode recovers `t` against), as an
2699 /// `n × K` matrix. These are exactly the masses
2700 /// [`Self::try_fitted_with_rho`] assembles the reconstruction from, so
2701 /// feeding them to [`Self::amortized_encode_target`] re-encodes the SAME
2702 /// inference the dictionary was fit against — the self-consistency the
2703 /// distilled encoder is supervised to approximate.
2704 pub fn fitted_assignment_amplitudes(
2705 &self,
2706 rho: &SaeManifoldRho,
2707 ) -> Result<Array2<f64>, String> {
2708 let n = self.n_obs();
2709 let k_atoms = self.k_atoms();
2710 let mut amplitudes = Array2::<f64>::zeros((n, k_atoms));
2711 for row in 0..n {
2712 let a = self.assignment.try_assignments_row_for_rho(row, rho)?;
2713 for atom_idx in 0..k_atoms {
2714 amplitudes[[row, atom_idx]] = a[atom_idx];
2715 }
2716 }
2717 Ok(amplitudes)
2718 }
2719
2720 /// #1026 — encode the dictionary's own fit-time target with the amortized
2721 /// encoder, deriving the per-row amplitudes from the fitted assignment so the
2722 /// caller supplies neither bounds nor amplitudes (magic by default). The
2723 /// end-to-end "fit → distilled encoder → certificate-gated encode" path.
2724 pub fn amortized_encode_fitted(
2725 &self,
2726 targets: ArrayView2<'_, f64>,
2727 rho: &SaeManifoldRho,
2728 ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2729 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2730 self.amortized_encode_target(targets, amplitudes.view())
2731 }
2732
2733 /// #1154 — amortized-encoder consistency of the CURRENT dictionary against
2734 /// its own fit-time target. This is the co-training signal of the joint
2735 /// amortized-encoder + REML loop (Design A): the amortized (one-mat-vec)
2736 /// encode is built from the *current* fitted decoder, run on `targets`, and
2737 /// scored on two principled axes —
2738 ///
2739 /// * `recon_consistency` (the bilinear part of the co-training loss): the
2740 /// mean per-element squared gap between the **amortized** reconstruction
2741 /// `Σ_k z_k · Φ_k(t̂_k) B_k` (decode the amortized coords) and the
2742 /// **exact** fitted reconstruction `Σ_k z_k · Φ_k(t_k^*) B_k` the inner
2743 /// solve converged to. A dictionary whose encode map is well-approximated
2744 /// to first order by the per-chart IFT predictor scores near zero; a
2745 /// dictionary the amortized encoder *cannot* invert faithfully (sharp
2746 /// curvature, poorly-charted regions) scores high. Minimising this jointly
2747 /// with REML steers the fit toward dictionaries that admit a fast,
2748 /// faithful amortized encode — the architectural co-adaptation #1154 adds.
2749 /// * `uncertified_fraction`: the share of (row, atom) encodes whose
2750 /// Kantorovich certificate failed (`h > ½`), i.e. that fell back to the
2751 /// certified IFT-warm-start Newton. This is the encoder's *certifiable coverage*
2752 /// of the dictionary; co-training rewards dictionaries the cheap encode
2753 /// certifies, not just ones it happens to land.
2754 ///
2755 /// The certificate keeps every accepted amortized coord honest (uncertified
2756 /// rows already ride the exact fallback inside `amortized_encode_target`), so
2757 /// this metric never silently trusts a wrong encode — it MEASURES how much of
2758 /// the dictionary the cheap encoder can faithfully and certifiably invert.
2759 pub fn amortized_encoder_consistency(
2760 &self,
2761 targets: ArrayView2<'_, f64>,
2762 rho: &SaeManifoldRho,
2763 ) -> Result<AmortizedEncoderConsistency, String> {
2764 let n = self.n_obs();
2765 let p = self.output_dim();
2766 let k_atoms = self.k_atoms();
2767 if targets.dim() != (n, p) {
2768 return Err(format!(
2769 "SaeManifoldTerm::amortized_encoder_consistency: targets {:?} must be (n={n}, p={p})",
2770 targets.dim()
2771 ));
2772 }
2773 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2774 let encodes = self.amortized_encode_target(targets, amplitudes.view())?;
2775 // The EXACT fitted reconstruction the inner solve converged to (pure
2776 // curved image, rho-keyed) is the supervision target for the amortized
2777 // reconstruction. Both are n×p ambient, so the comparison is layout-free.
2778 let exact_recon = self.try_fitted_for_rho(rho)?;
2779
2780 // Build the amortized reconstruction Σ_k z_k · Φ_k(t̂_k) B_k by decoding
2781 // each atom's amortized coords through that atom's own basis evaluator.
2782 let mut amortized_recon = Array2::<f64>::zeros((n, p));
2783 let mut uncertified = 0usize;
2784 for atom_idx in 0..k_atoms {
2785 let atom = &self.atoms[atom_idx];
2786 let result = &encodes[atom_idx];
2787 // An atom with no basis evaluator cannot decode an amortized
2788 // reconstruction; every one of its rows is necessarily uncertified
2789 // (the encode flagged them all), so it contributes nothing to the
2790 // amortized recon and its full row-count to the uncertified tally.
2791 // Count it and skip the decode rather than erroring — the consistency
2792 // fold stays a bounded penalty, never a hard abort of the criterion.
2793 let Some(evaluator) = atom.basis_evaluator.as_ref() else {
2794 uncertified += n;
2795 continue;
2796 };
2797 uncertified += result.encode_uncertified_count;
2798 // Decode the amortized coords: Φ_k(t̂) is (n × M_k); B_k is (M_k × p).
2799 let (phi, _jac) = evaluator.evaluate(result.coords.view())?;
2800 let decoded = phi.dot(&atom.decoder_coefficients); // (n × p)
2801 for row in 0..n {
2802 let z = amplitudes[[row, atom_idx]];
2803 if z == 0.0 {
2804 continue;
2805 }
2806 for col in 0..p {
2807 amortized_recon[[row, col]] += z * decoded[[row, col]];
2808 }
2809 }
2810 }
2811
2812 let mut sse = 0.0_f64;
2813 for row in 0..n {
2814 for col in 0..p {
2815 let gap = amortized_recon[[row, col]] - exact_recon[[row, col]];
2816 sse += gap * gap;
2817 }
2818 }
2819 let denom = (n.max(1) * p.max(1)) as f64;
2820 let recon_consistency = sse / denom;
2821 let total_encodes = (n * k_atoms).max(1) as f64;
2822 let uncertified_fraction = uncertified as f64 / total_encodes;
2823
2824 Ok(AmortizedEncoderConsistency {
2825 recon_consistency,
2826 uncertified_fraction,
2827 n_uncertified: uncertified,
2828 n_encodes: n * k_atoms,
2829 })
2830 }
2831
2832 /// #1154 — the co-trained REML criterion: the exact REML criterion at `rho`
2833 /// PLUS the amortized-encoder consistency penalty, so the outer optimizer
2834 /// co-adapts the dictionary + smoothing parameters λ TOWARD a dictionary the
2835 /// fast amortized encoder can faithfully and certifiably invert.
2836 ///
2837 /// This is Design A of #1154. The inner solve still converges the `(t, β)`
2838 /// system to stationarity at the engine's current ρ (so the implicit-function
2839 /// REML λ-gradient `dβ̂/dλ = −(H+S_λ)⁻¹(dS_λ/dλ)β̂` stays EXACT — the encoder
2840 /// only warm-starts/co-adapts, it never replaces the stationary point). The
2841 /// added term
2842 ///
2843 /// ```text
2844 /// J_cotrain(ρ) = REML(ρ) + w · ‖x̂_amortized − x̂_exact‖²/(n·p)
2845 /// + w_cert · uncertified_fraction
2846 /// ```
2847 ///
2848 /// folds the post-fit amortized-encode quality into the ranked objective. The
2849 /// weights are auto-scaled to the REML criterion magnitude (magic by default:
2850 /// no caller knob) so the consistency term is a meaningful but non-dominant
2851 /// fraction of the objective regardless of problem scale.
2852 pub fn reml_criterion_cotrained(
2853 &mut self,
2854 target: ArrayView2<'_, f64>,
2855 rho: &SaeManifoldRho,
2856 registry: Option<&AnalyticPenaltyRegistry>,
2857 inner_max_iter: usize,
2858 learning_rate: f64,
2859 ridge_ext_coord: f64,
2860 ridge_beta: f64,
2861 ) -> Result<(f64, SaeManifoldLoss, AmortizedEncoderConsistency), String> {
2862 // #1154: always attempt the amortized warm-start first inside
2863 // `reml_criterion_cotrained` (the encode/warm path for the cotrained
2864 // objective). Good warm-starts from the running dictionary land the
2865 // inner solve closer to the stationary point used for the fold.
2866 // Advisory only (0 or err falls back to cold); telemetry recorded by
2867 // outer objective callers when present.
2868 self.warm_start_latents_from_amortized_encoder(target, rho)
2869 .unwrap_or(0);
2870 let (reml, loss) = self.reml_criterion_with_refine_policy(
2871 target,
2872 rho,
2873 registry,
2874 inner_max_iter,
2875 learning_rate,
2876 ridge_ext_coord,
2877 ridge_beta,
2878 true,
2879 )?;
2880 let consistency = self.amortized_encoder_consistency(target, rho)?;
2881 // Auto-scale the co-training weights to the REML magnitude so the
2882 // consistency penalty is a bounded, scale-free fraction of the objective
2883 // (magic by default: no caller knob). `reml_scale` floors at 1 so a
2884 // near-zero criterion still admits a meaningful consistency contribution.
2885 let cotrained = Self::fold_cotrain_consistency(reml, &consistency);
2886 Ok((cotrained, loss, consistency))
2887 }
2888
2889 /// #1154 — the single source of the co-training fold arithmetic: add the
2890 /// auto-scaled amortized-encoder consistency penalty to an already-computed
2891 /// REML criterion at the converged dictionary. Both the public
2892 /// [`Self::reml_criterion_cotrained`] entry point and the outer-loop value /
2893 /// gradient lanes (`SaeManifoldOuterObjective::fold_cotrain_consistency`)
2894 /// route through THIS function, so the folded objective cannot drift between
2895 /// the criterion and the cascade-ranked cost (the objective↔gradient desync
2896 /// bug class). The weights are auto-scaled to the REML magnitude (`max(|REML|,
2897 /// 1)`) so the penalty is a bounded, scale-free fraction of the objective
2898 /// regardless of problem scale; the fold carries no analytic gradient (under
2899 /// Design A the REML λ-gradient stays the exact implicit-function path).
2900 #[must_use]
2901 pub fn fold_cotrain_consistency(
2902 reml_cost: f64,
2903 consistency: &AmortizedEncoderConsistency,
2904 ) -> f64 {
2905 let reml_scale = reml_cost.abs().max(1.0);
2906 reml_cost
2907 + COTRAIN_RECON_WEIGHT * reml_scale * consistency.recon_consistency
2908 + COTRAIN_CERT_WEIGHT * reml_scale * consistency.uncertified_fraction
2909 }
2910
2911 /// #1154 item 2 — warm-start the inner latent coordinates from the amortized
2912 /// encoder (Design A). Builds the per-chart IFT-Jacobian atlas from the
2913 /// CURRENT dictionary, runs the one-mat-vec amortized encode of `target`
2914 /// against each atom at the rho-resolved assignment masses, and overwrites
2915 /// each atom's stored latent coords with the predicted `t̂` ON THE ROWS THE
2916 /// KANTOROVICH CERTIFICATE ACCEPTS. Uncertified rows are left at their
2917 /// current coords (the previous-iterate start), so the
2918 /// warm-start can only HELP — a row the cheap predictor cannot certify never
2919 /// corrupts the seed. The subsequent inner Newton refines from this seed to
2920 /// the SAME stationary point (the warm-start changes only the basin entry,
2921 /// not the root), so the REML λ-gradient stays exactly the implicit-function
2922 /// path and the criterion is unchanged at convergence — the amortized encoder
2923 /// only accelerates/co-adapts the inner solve, it never replaces the
2924 /// stationary point.
2925 ///
2926 /// Returns the number of (row, atom) coords actually warm-started (the
2927 /// certified-prediction count), for instrumentation / tests. A first-build
2928 /// dictionary with no usable charts simply warm-starts nothing and returns 0
2929 /// (the cold path is byte-for-byte unchanged).
2930 pub fn warm_start_latents_from_amortized_encoder(
2931 &mut self,
2932 target: ArrayView2<'_, f64>,
2933 rho: &SaeManifoldRho,
2934 ) -> Result<usize, String> {
2935 let n = self.n_obs();
2936 let k_atoms = self.k_atoms();
2937 if n == 0 || k_atoms == 0 {
2938 return Ok(0);
2939 }
2940 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2941 let encodes = self.amortized_encode_target(target, amplitudes.view())?;
2942 let mut warm_started = 0usize;
2943 for atom_idx in 0..k_atoms {
2944 let d = self.atoms[atom_idx].latent_dim;
2945 if d == 0 {
2946 continue;
2947 }
2948 let result = &encodes[atom_idx];
2949 // Start from the atom's CURRENT coords so uncertified rows are left
2950 // exactly as they were; overwrite only the certified predictions.
2951 let mut coords = self.assignment.coords[atom_idx].as_matrix();
2952 if coords.dim() != (n, d) {
2953 return Err(format!(
2954 "warm_start_latents_from_amortized_encoder: atom {atom_idx} coords {:?} != (n={n}, d={d})",
2955 coords.dim()
2956 ));
2957 }
2958 for row in 0..n {
2959 if !result.certified[row] {
2960 continue;
2961 }
2962 for axis in 0..d {
2963 coords[[row, axis]] = result.coords[[row, axis]];
2964 }
2965 warm_started += 1;
2966 }
2967 // `as_matrix` lays coords out row-major (`[[row, axis]]`), exactly the
2968 // `values[row*d + axis]` order `set_flat` expects, so a plain
2969 // row-major iterator reconstructs the flat vector.
2970 let flat = Array1::from_iter(coords.iter().copied());
2971 self.assignment.coords[atom_idx].set_flat(flat.view());
2972 }
2973 // The basis caches must follow the freshly-seeded coords so the next
2974 // inner solve evaluates Φ at the warm-started t̂, not the stale coords.
2975 self.refresh_basis_from_current_coords()?;
2976 Ok(warm_started)
2977 }
2978
2979 pub fn loss(
2980 &self,
2981 target: ArrayView2<'_, f64>,
2982 rho: &SaeManifoldRho,
2983 ) -> Result<SaeManifoldLoss, String> {
2984 self.loss_scaled(target, rho, 1.0)
2985 }
2986
2987 /// Penalized objective with a `penalty_scale` applied to the β-tier
2988 /// (decoder smoothness) penalty, mirroring
2989 /// [`Self::assemble_arrow_schur_scaled`]. The streaming line search sums
2990 /// per-chunk `loss_scaled(..., n_chunk / N)` so that the global smoothness
2991 /// penalty is counted exactly once across a pass while the per-row data,
2992 /// assignment-prior, and ARD terms sum naturally. `penalty_scale == 1.0`
2993 /// recovers the full-batch objective.
2994 pub fn loss_scaled(
2995 &self,
2996 target: ArrayView2<'_, f64>,
2997 rho: &SaeManifoldRho,
2998 penalty_scale: f64,
2999 ) -> Result<SaeManifoldLoss, String> {
3000 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3001 return Err(format!(
3002 "SaeManifoldTerm::loss_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3003 ));
3004 }
3005 if target.dim() != (self.n_obs(), self.output_dim()) {
3006 return Err(format!(
3007 "SaeManifoldTerm::loss: Z must be ({}, {}); got {:?}",
3008 self.n_obs(),
3009 self.output_dim(),
3010 target.dim()
3011 ));
3012 }
3013 // The likelihood whitens through the RowMetric **only** when the metric
3014 // is a genuinely estimated noise model (`metric.whitens_likelihood()`,
3015 // i.e. `WhitenedStructured` — the #974 residual-covariance seam). For
3016 // Euclidean (default `None`) and for the OutputFisher *gauge* metric the
3017 // reconstruction data-fit stays the isotropic `0.5 * Σ r²`: a gauge /
3018 // output-Fisher inner product must NOT silently replace the
3019 // reconstruction loss with a Fisher pullback (#980). It only drives the
3020 // gauge (see `analytic_penalties::corrected_isometry_penalty`). The
3021 // producer of `WhitenedStructured` is
3022 // `inference::residual_factor::StructuredResidualModel::row_metric`; the
3023 // SAME metric whitens the assembled gradient/Hessian in
3024 // `assemble_arrow_schur` (the single #974 seam), so this value and that
3025 // gradient cannot desync. Without a whitening metric this path is
3026 // bit-for-bit the historical isotropic data-fit.
3027 let whitens = self
3028 .row_metric
3029 .as_ref()
3030 .is_some_and(|metric| metric.whitens_likelihood());
3031 // #991 design honesty weights: the reconstruction channel of row `i`
3032 // is weighted by `w_i` (mean-1 HT inclusion correction). The assembly
3033 // applies the same `w_i` via a `√w_i` scaling of the row residual /
3034 // Jacobian / β load at its single seam, so this value and that
3035 // gradient/Hessian carry the identical per-row factor. `None` ⇒ the
3036 // historical unweighted sum, bit-for-bit.
3037 let row_loss_w = self.row_loss_weights.as_deref();
3038 let n = self.n_obs();
3039 let p = self.output_dim();
3040 let k_atoms = self.k_atoms();
3041 // #1017: the data-fit is the dominant per-line-search-trial cost (it
3042 // re-runs every Armijo halving × every inner Newton iteration × every
3043 // outer ρ evaluation). The old path materialised the whole `n × p`
3044 // fitted matrix (`try_fitted_for_rho`) and then walked it AGAIN to form
3045 // the residual sum — two sequential `n·p` passes plus an `n·p`
3046 // allocation per trial. Fuse the reconstruction and the residual reduce
3047 // into ONE row-parallel pass that never materialises the fitted matrix:
3048 // each row decodes its atoms into per-worker scratch, differences
3049 // against the target, and contributes its scalar `0.5·w·‖r‖²` to a
3050 // chunk-ordered fold (bit-identical run-to-run). Per-worker scratch
3051 // (`map_init`) keeps the only allocations one `g_buf`/`fitted_row` pair
3052 // per rayon thread rather than per row. Stay sequential inside a worker
3053 // (the topology race owns the outer pool) to avoid nested
3054 // oversubscription.
3055 let parallel = n >= SAE_LOSS_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
3056 let row_data_fit =
3057 |row: usize,
3058 g_buf: &mut [f64],
3059 fitted_row: &mut [f64],
3060 assign_buf: &mut [f64]|
3061 -> Result<f64, String> {
3062 // #1557 — fill the per-atom assignment row into reused per-worker
3063 // scratch via the `_into` twin instead of heap-allocating a fresh
3064 // `Array1` per row per loss eval. Bit-identical to the allocating
3065 // `try_assignments_row_for_rho` (same arithmetic, same order); this
3066 // loss reruns every Armijo halving × inner Newton iter × outer ρ
3067 // eval, so the per-row K-sized allocation was a hot-path churn.
3068 self.assignment
3069 .try_assignments_row_for_rho_into(row, rho, assign_buf)?;
3070 let a = &*assign_buf;
3071 for slot in fitted_row.iter_mut() {
3072 *slot = 0.0;
3073 }
3074 for atom_idx in 0..k_atoms {
3075 self.atoms[atom_idx].fill_decoded_row(row, g_buf);
3076 let a_k = a[atom_idx];
3077 for out_col in 0..p {
3078 fitted_row[out_col] += a_k * g_buf[out_col];
3079 }
3080 }
3081 for out_col in 0..p {
3082 fitted_row[out_col] = target[[row, out_col]] - fitted_row[out_col];
3083 }
3084 let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3085 let mut acc = 0.0_f64;
3086 match self.row_metric.as_ref() {
3087 Some(metric) if whitens => {
3088 let resid = ArrayView1::from(&fitted_row[..p]);
3089 for w in metric.whiten_residual_row(row, resid) {
3090 acc += 0.5 * w_row * w * w;
3091 }
3092 }
3093 _ => {
3094 for &r in fitted_row[..p].iter() {
3095 acc += 0.5 * w_row * r * r;
3096 }
3097 }
3098 }
3099 Ok(acc)
3100 };
3101 let data_fit = if parallel {
3102 use rayon::prelude::*;
3103 const CHUNK: usize = 32;
3104 let partials: Vec<Result<f64, String>> = (0..n)
3105 .into_par_iter()
3106 .chunks(CHUNK)
3107 .map_init(
3108 || (vec![0.0_f64; p], vec![0.0_f64; p], vec![0.0_f64; k_atoms]),
3109 |(g_buf, fitted_row, assign_buf), idxs| {
3110 // #1557 — pin any faer GEMM reached from this row-parallel
3111 // data-fit chunk to `Par::Seq` (no nested Rayon re-fan); the
3112 // per-row reductions are tiny, so the result is bit-identical.
3113 with_nested_parallel(|| {
3114 let mut acc = 0.0_f64;
3115 for row in idxs {
3116 acc += row_data_fit(row, g_buf, fitted_row, assign_buf)?;
3117 }
3118 Ok(acc)
3119 })
3120 },
3121 )
3122 .collect();
3123 let mut total = 0.0_f64;
3124 for partial in partials {
3125 total += partial?;
3126 }
3127 total
3128 } else {
3129 let mut g_buf = vec![0.0_f64; p];
3130 let mut fitted_row = vec![0.0_f64; p];
3131 let mut assign_buf = vec![0.0_f64; k_atoms];
3132 let mut total = 0.0_f64;
3133 for row in 0..n {
3134 total += row_data_fit(row, &mut g_buf, &mut fitted_row, &mut assign_buf)?;
3135 }
3136 total
3137 };
3138 let assignment_sparsity = assignment_prior_value(&self.assignment, rho);
3139 let smoothness = penalty_scale * self.decoder_smoothness_value(&rho.lambda_smooth_vec());
3140 let ard = self.ard_value(rho)?;
3141 Ok(SaeManifoldLoss {
3142 data_fit,
3143 assignment_sparsity,
3144 smoothness,
3145 ard,
3146 evidence_gauge_deflated_directions: 0,
3147 })
3148 }
3149
3150 /// Reconstruction data-fit `0.5·Σ_i w_i·‖whiten(Z_i − R_i)‖²` for an EXPLICIT
3151 /// reconstruction matrix `R` (e.g. the hard top-k–projected `fitted`), using
3152 /// the SAME per-row metric and design-honesty weights as [`Self::loss_scaled`]
3153 /// (the soft-assignment data-fit). The only difference is the residual source:
3154 /// `loss_scaled` decodes the soft assignments on the fly, this consumes a
3155 /// reconstruction the caller already assembled (so the projected loss and the
3156 /// returned projected `fitted` describe one and the same model). The penalty
3157 /// terms (`assignment_sparsity`/`smoothness`/`ard`) are decoder/ρ properties
3158 /// the top-k gate does not change, so the caller keeps them from the soft
3159 /// `loss_scaled` and only swaps this data-fit in — see #1232.
3160 pub fn data_fit_for_reconstruction(
3161 &self,
3162 target: ArrayView2<'_, f64>,
3163 reconstruction: ArrayView2<'_, f64>,
3164 ) -> Result<f64, String> {
3165 let n = self.n_obs();
3166 let p = self.output_dim();
3167 if target.dim() != (n, p) {
3168 return Err(format!(
3169 "SaeManifoldTerm::data_fit_for_reconstruction: Z must be ({n}, {p}); got {:?}",
3170 target.dim()
3171 ));
3172 }
3173 if reconstruction.dim() != (n, p) {
3174 return Err(format!(
3175 "SaeManifoldTerm::data_fit_for_reconstruction: reconstruction must be ({n}, {p}); got {:?}",
3176 reconstruction.dim()
3177 ));
3178 }
3179 let whitens = self
3180 .row_metric
3181 .as_ref()
3182 .is_some_and(|metric| metric.whitens_likelihood());
3183 let row_loss_w = self.row_loss_weights.as_deref();
3184 let mut resid = vec![0.0_f64; p];
3185 let mut total = 0.0_f64;
3186 for row in 0..n {
3187 for out_col in 0..p {
3188 resid[out_col] = target[[row, out_col]] - reconstruction[[row, out_col]];
3189 }
3190 let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3191 match self.row_metric.as_ref() {
3192 Some(metric) if whitens => {
3193 let r = ArrayView1::from(&resid[..p]);
3194 for w in metric.whiten_residual_row(row, r) {
3195 total += 0.5 * w_row * w * w;
3196 }
3197 }
3198 _ => {
3199 for &r in resid[..p].iter() {
3200 total += 0.5 * w_row * r * r;
3201 }
3202 }
3203 }
3204 }
3205 Ok(total)
3206 }
3207
3208 pub fn analytic_penalty_value_total(
3209 &self,
3210 registry: &AnalyticPenaltyRegistry,
3211 penalty_scale: f64,
3212 ) -> Result<f64, ArrowSchurError> {
3213 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3214 return Err(ArrowSchurError::SchurFactorFailed {
3215 reason: format!(
3216 "SaeManifoldTerm::analytic_penalty_value_total: penalty_scale must be finite \
3217 and positive; got {penalty_scale}"
3218 ),
3219 });
3220 }
3221 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3222 let layout = registry.rho_layout();
3223 let beta = self.flatten_beta();
3224 let mut value = 0.0_f64;
3225 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
3226 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3227 // Skip the registry `ARDPenalty` here for the same reason it is
3228 // skipped in `add_sae_analytic_penalty_contributions`: the coordinate
3229 // ARD energy is already counted by `loss.ard` (the von-Mises
3230 // `ard_value`), and the registry penalty's legacy Gaussian `½λt²` is
3231 // period-discontinuous. Including it would double-count the energy and
3232 // make this line-search objective jump across the branch cut while the
3233 // assembled gradient (von-Mises only, after the assembly fix) stays
3234 // continuous — i.e. a near-zero step would change the objective by a
3235 // finite amount and Armijo would wrongly reject it.
3236 if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
3237 continue;
3238 }
3239 match tier {
3240 PenaltyTier::Psi => {
3241 if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
3242 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3243 value += penalty_scale
3244 * per_atom.value(beta.slice(s![start..end]), rho_local);
3245 }
3246 } else {
3247 if !sae_penalty_is_row_block_supported(penalty) {
3248 return Err(ArrowSchurError::SchurFactorFailed {
3249 reason: format!(
3250 "validate_analytic_penalty_registry should have refused \
3251 non-row-block Psi-tier penalty {:?} (registry layout name \
3252 {name:?})",
3253 penalty.name()
3254 ),
3255 });
3256 }
3257 for atom_idx in 0..self.k_atoms() {
3258 let coord = &self.assignment.coords[atom_idx];
3259 if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3260 let corrected_kind =
3261 self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3262 value += corrected_kind.value(coord.as_flat().view(), rho_local);
3263 } else if sae_coord_penalty_is_origin_anchored_magnitude(penalty) {
3264 // Origin-anchored magnitude shrinkage (SCAD/MCP) is
3265 // restricted to the Euclidean axes; periodic axes have
3266 // no chart origin and would make this energy
3267 // period-discontinuous (issue #795). This must mirror
3268 // the gradient/curvature assembly in
3269 // `add_sae_coord_penalty` exactly.
3270 match sae_coord_penalty_euclidean_restriction(coord) {
3271 Some((_axes, compacted)) => {
3272 value += penalty.value(compacted.view(), rho_local);
3273 }
3274 None => {
3275 value += penalty.value(coord.as_flat().view(), rho_local);
3276 }
3277 }
3278 } else {
3279 value += penalty.value(coord.as_flat().view(), rho_local);
3280 }
3281 }
3282 }
3283 }
3284 PenaltyTier::Beta => {
3285 if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
3286 if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3287 value += penalty_scale * per_fit.value(beta.view(), rho_local);
3288 }
3289 } else if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
3290 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3291 if start < end {
3292 value += penalty_scale * per_atom.value(beta.view(), rho_local);
3293 }
3294 }
3295 } else {
3296 value += penalty_scale * penalty.value(beta.view(), rho_local);
3297 }
3298 }
3299 PenaltyTier::Rho => {}
3300 }
3301 }
3302 Ok(value)
3303 }
3304
3305 /// Energy of the decoder-block analytic penalties that have no native
3306 /// `SaeManifoldLoss` counterpart, evaluated at the current decoder `β` and
3307 /// the converged SAE state. These act on the per-atom decoder coefficient
3308 /// matrices: cross-atom decoder incoherence (#671), mechanism
3309 /// (feature-group) sparsity, and nuclear-norm embedding rank (#672). Each
3310 /// is injected with its live per-atom shape / co-activation before its
3311 /// value is taken, mirroring the assemble path.
3312 ///
3313 /// This is deliberately narrower than [`Self::analytic_penalty_value_total`]:
3314 /// it excludes the Psi-tier coordinate / assignment penalties (ARD,
3315 /// Isometry, ScadMcp, BlockOrthogonality, IBP/softmax assignment sparsity).
3316 /// The SAE already carries its own ARD (`loss.ard`) and assignment sparsity
3317 /// (`loss.assignment_sparsity`) energy, so adding the registry ARD /
3318 /// assignment value on top would double-count, and the gauge-only
3319 /// coordinate penalties are not part of the penalized deviance the
3320 /// REML/Laplace criterion scores. The decoder-block penalties, by contrast,
3321 /// are real penalized-energy terms with no `loss.*` representative: the
3322 /// inner solve minimizes them (they enter `gb`/`hbb`) but they were absent
3323 /// from the criterion scalar `v`. This restores that consistency so the
3324 /// ρ-sweep ranks the same objective the inner solve descends — the #671
3325 /// incoherence lever in particular now shapes model selection, not just the
3326 /// Newton step.
3327 ///
3328 /// NOTE: the coordinate-block penalties with no native `loss.*` twin
3329 /// (`ScadMcp`, `BlockOrthogonality`) carry the same residual inconsistency
3330 /// (scored in the line search via `penalized_objective_total`, absent from
3331 /// the REML scalar). They are left out here because they share a registry
3332 /// dispatch with the always-on `Isometry` gauge, whose inclusion in the
3333 /// topology-comparison criterion is a separate design question (#673:
3334 /// topology evidence is gauge-conditional). Folding the coord-tier energy in
3335 /// is tracked apart from this #671 decoder fix.
3336 pub fn analytic_decoder_penalty_value_total(
3337 &self,
3338 registry: &AnalyticPenaltyRegistry,
3339 ) -> Result<f64, ArrowSchurError> {
3340 // Resolve each penalty's rho slice exactly as `analytic_penalty_value_total`
3341 // does (registry-local rho at zeros), so a learnable decoder-penalty weight
3342 // is honoured rather than indexing into an empty view.
3343 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3344 let layout = registry.rho_layout();
3345 let beta = self.flatten_beta();
3346 let mut value = 0.0_f64;
3347 for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3348 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3349 match penalty {
3350 AnalyticPenaltyKind::DecoderIncoherence(base) => {
3351 if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3352 value += per_fit.value(beta.view(), rho_local);
3353 }
3354 }
3355 AnalyticPenaltyKind::MechanismSparsity(base) => {
3356 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3357 if start < end {
3358 value += per_atom.value(beta.view(), rho_local);
3359 }
3360 }
3361 }
3362 AnalyticPenaltyKind::NuclearNorm(base) => {
3363 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3364 value += per_atom.value(beta.slice(s![start..end]), rho_local);
3365 }
3366 }
3367 _ => {}
3368 }
3369 }
3370 Ok(value)
3371 }
3372
3373 /// Energy of the COORDINATE-tier isometry penalty(ies) at the converged
3374 /// SAE state. This is the per-atom `½μ Σ_n ‖J_n^T W_n J_n / gbar − g_ref‖²`
3375 /// summed over atoms, evaluated through `corrected_isometry_penalty` so the
3376 /// live decoder/coordinate caches drive the value exactly as the assemble
3377 /// path does. It has no `SaeManifoldLoss` twin (the loss carries only
3378 /// data-fit / assignment / smoothness / ARD), so the Laplace/REML criterion
3379 /// must add it explicitly to score the same penalized objective the inner
3380 /// solve descends.
3381 pub fn isometry_penalty_value_total(
3382 &self,
3383 registry: &AnalyticPenaltyRegistry,
3384 ) -> Result<f64, ArrowSchurError> {
3385 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3386 let layout = registry.rho_layout();
3387 let mut value = 0.0_f64;
3388 for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3389 if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3390 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3391 for atom_idx in 0..self.k_atoms() {
3392 let coord = &self.assignment.coords[atom_idx];
3393 let corrected_kind = self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3394 value += corrected_kind.value(coord.as_flat().view(), rho_local);
3395 }
3396 }
3397 }
3398 Ok(value)
3399 }
3400
3401 /// Whether assembling `registry` will scatter an isometry Gauss-Newton
3402 /// cross-block (`H_tβ`) into the per-row dense `htbeta` slabs.
3403 ///
3404 /// `add_sae_isometry_metric_gn_blocks` writes the coupled cross-block (and
3405 /// flips on `activate_dense_htbeta_supplement`) only when (a) the registry
3406 /// carries an `Isometry` penalty and (b) the atom's chart
3407 /// `preserves_isometry_cross_block_coherence` (flat charts — `Euclidean`,
3408 /// `Circle`, and flat products — keep the full `μ AᵀA` coupling; curved /
3409 /// boundary charts drop it to stay PSD). On the non-frames matrix-free path
3410 /// the data-fit cross-block is carried by the Kronecker row operator and the
3411 /// per-row `htbeta` slab is allocated at zero width (#1406/#1407 anti-leak),
3412 /// so this dense isometry supplement has nowhere to land unless the slab is
3413 /// widened to the full `beta_dim`. This predicate decides exactly that. The
3414 /// effective isometry weight `μ` is NOT consulted here: a near-zero `μ`
3415 /// short-circuits the per-row write, but the slab must still exist so the
3416 /// solver's `htbeta_dense_supplement` read is well-shaped.
3417 pub(crate) fn registry_writes_dense_isometry_cross_block(
3418 &self,
3419 registry: &AnalyticPenaltyRegistry,
3420 ) -> bool {
3421 registry
3422 .penalties
3423 .iter()
3424 .any(|p| matches!(p, AnalyticPenaltyKind::Isometry(_)))
3425 && self
3426 .assignment
3427 .coords
3428 .iter()
3429 .any(|coord| coord.manifold().preserves_isometry_cross_block_coherence())
3430 }
3431
3432 /// Extra analytic-penalty energy that has no native `SaeManifoldLoss`
3433 /// component but is part of the penalized objective ranked by the SAE
3434 /// Laplace/REML criterion.
3435 pub fn reml_extra_penalty_value_total(
3436 &self,
3437 registry: &AnalyticPenaltyRegistry,
3438 ) -> Result<f64, ArrowSchurError> {
3439 Ok(self.analytic_decoder_penalty_value_total(registry)?
3440 + self.isometry_penalty_value_total(registry)?)
3441 }
3442
3443 pub fn penalized_objective_total(
3444 &self,
3445 target: ArrayView2<'_, f64>,
3446 rho: &SaeManifoldRho,
3447 registry: Option<&AnalyticPenaltyRegistry>,
3448 penalty_scale: f64,
3449 ) -> Result<f64, String> {
3450 let mut total = self.loss_scaled(target, rho, penalty_scale)?.total();
3451 if let Some(analytic_registry) = registry {
3452 total += self
3453 .analytic_penalty_value_total(analytic_registry, penalty_scale)
3454 .map_err(|err| format!("SaeManifoldTerm::penalized_objective_total: {err}"))?;
3455 }
3456 // #1026 — decoder-repulsion value, on the SAME frozen gate the assembly
3457 // used, so the line search sees the term the Newton step optimizes. 0
3458 // unless two atoms are near-collinear (the no-op case).
3459 total += self.decoder_repulsion_value(penalty_scale);
3460 // #1026/#1522 — interior-point collapse-prevention barriers, on the SAME
3461 // decoders the assembly's gradient/curvature used, so the line search sees
3462 // exactly the term the inner Newton step optimises (no value/grad desync).
3463 total += self.separation_barrier_value(penalty_scale);
3464 Ok(total)
3465 }
3466
3467 pub(crate) fn decoder_smoothness_value(&self, lambda_smooth: &[f64]) -> f64 {
3468 // Smoothness penalty value is `0.5·λ·Σ_oc B[:,oc]ᵀ S B[:,oc]`. Form the
3469 // `S·B` matrix product once per atom (O(M²·p)) and reduce against `B`
3470 // with a single O(M·p) Hadamard sum, instead of the previous
3471 // four-factor multiply-accumulate inside an `O(M²·p)` triple loop.
3472 // The quadratic form only sees the symmetric part of `S`, so reusing
3473 // the raw (un-symmetrised) `smooth_penalty` here is numerically
3474 // identical to the symmetrised assembly form.
3475 // Per-atom `S_k · B_k` products are independent across atoms, so they ride
3476 // the multi-GPU batched smoothness GEMM (uniform-shape groups tiled across
3477 // every device); `symmetrize = false` because the quadratic form only sees
3478 // the symmetric part of `S` regardless. Exact CPU fallback per atom.
3479 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3480 .atoms
3481 .iter()
3482 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3483 .collect();
3484 let sb_all = batched_smooth_sb(&sb_inputs, false);
3485 let mut acc = 0.0;
3486 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3487 acc += 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3488 }
3489 acc
3490 }
3491
3492 /// Per-atom decoder-smoothness values (#1556): entry `k` is
3493 /// `0.5·λ_smooth[k]·<B_k, S_k B_k>` (sum = [`Self::decoder_smoothness_value`]).
3494 /// This is the explicit `∂loss.smoothness/∂log λ_smooth[k]` gradient entry.
3495 pub(crate) fn decoder_smoothness_value_per_atom(&self, lambda_smooth: &[f64]) -> Vec<f64> {
3496 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3497 .atoms
3498 .iter()
3499 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3500 .collect();
3501 let sb_all = batched_smooth_sb(&sb_inputs, false);
3502 let mut per_atom = vec![0.0_f64; self.atoms.len()];
3503 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3504 per_atom[atom_idx] =
3505 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3506 }
3507 per_atom
3508 }
3509
3510 pub(crate) fn ard_value(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
3511 if rho.log_ard.len() != self.k_atoms() {
3512 return Err(format!(
3513 "ARD rho has {} atoms but term has {}",
3514 rho.log_ard.len(),
3515 self.k_atoms()
3516 ));
3517 }
3518 let n = self.n_obs();
3519 let mut acc = 0.0;
3520 for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3521 let d = coord.latent_dim();
3522 if rho.log_ard[atom_idx].is_empty() {
3523 continue;
3524 }
3525 if rho.log_ard[atom_idx].len() != d {
3526 return Err(format!(
3527 "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
3528 rho.log_ard[atom_idx].len()
3529 ));
3530 }
3531 // Per-axis periodicity selects the smooth von-Mises energy on
3532 // wrapped (Circle) axes and the Gaussian on Euclidean axes.
3533 let periods = coord.effective_axis_periods();
3534 for axis in 0..d {
3535 let log_alpha = rho.log_ard[atom_idx][axis];
3536 // Clamp the log-precision before exponentiating: a raw
3537 // `exp(log_ard)` overflows to `inf` for `log_ard ≳ 709`, and the
3538 // `inf` precision then poisons the ARD energy / curvature with
3539 // `inf · 0.0 = NaN` (#742, Issue 4).
3540 let alpha = SaeManifoldRho::stable_exp_strength(log_alpha);
3541 let period = periods[axis];
3542 let mut energy = 0.0;
3543 for row in 0..n {
3544 let v = coord.row(row)[axis];
3545 energy += ArdAxisPrior::eval(alpha, v, period).value;
3546 }
3547 // Negative-log prior for precision alpha. The data-dependent
3548 // energy is the (Gaussian or von-Mises) coordinate prior; the
3549 // accompanying normaliser is the precision log-partition.
3550 //
3551 // Euclidean axes keep the Gaussian normaliser `-0.5 n log α`.
3552 // Periodic (von-Mises) axes use the EXACT von-Mises precision
3553 // log-partition `n[-η + log I0(η)]`, η = α/κ², κ = 2π/P, rather
3554 // than the Gaussian surrogate: the von-Mises partition function
3555 // is `2π I0(η)` (up to the κ Jacobian), so the per-observation
3556 // normaliser is `-η + log I0(η)` and is exact across the cut.
3557 match period {
3558 None => {
3559 acc += energy - 0.5 * (n as f64) * log_alpha;
3560 }
3561 Some(p) => {
3562 let kappa = std::f64::consts::TAU / p;
3563 let eta = alpha / (kappa * kappa);
3564 // Overflow-free `log I0(η)`; `bessel_i0(η).ln()` would be
3565 // `+inf` for `η ≳ 709` (#1113).
3566 let log_i0 = bessel_i0_log_and_ratio(eta).0;
3567 acc += energy + (n as f64) * (-eta + log_i0);
3568 }
3569 }
3570 }
3571 }
3572 Ok(acc)
3573 }
3574
3575 /// Assemble the enlarged `(logits, t)` row-local Arrow-Schur system.
3576 ///
3577 /// Full-batch entry point: a single chunk covering all rows, with the
3578 /// β-tier penalties (decoder smoothness, ARD, analytic β penalties) carrying
3579 /// their full strength. The streaming driver calls
3580 /// [`Self::assemble_arrow_schur_scaled`] directly with a `penalty_scale`
3581 /// equal to the minibatch fraction `n_chunk / N`, so that the sum of the
3582 /// per-chunk β-tier contributions over a full pass reconstructs exactly the
3583 /// single global β penalty (the smoothness/ARD/β terms are functions of `B`
3584 /// and the global coordinates, not of the chunk's rows).
3585 pub fn assemble_arrow_schur(
3586 &mut self,
3587 target: ArrayView2<'_, f64>,
3588 rho: &SaeManifoldRho,
3589 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3590 ) -> Result<ArrowSchurSystem, String> {
3591 self.assemble_arrow_schur_scaled(target, rho, analytic_penalties, 1.0)
3592 }
3593
3594 /// Assemble the row-local Arrow-Schur system with a `penalty_scale` applied
3595 /// to the β-tier (decoder smoothness, ARD prior, analytic β penalties).
3596 ///
3597 /// `penalty_scale == 1.0` recovers the full-batch assembly. The streaming
3598 /// driver passes the minibatch fraction `n_chunk / N` so that the β-tier
3599 /// reduced-Schur and gradient contributions of the chunks sum to exactly one
3600 /// global copy across a full pass (data-fit, assignment-prior, and per-row
3601 /// coord/logit analytic terms are *not* scaled — they are genuine per-row
3602 /// sums).
3603 pub fn assemble_arrow_schur_scaled(
3604 &mut self,
3605 target: ArrayView2<'_, f64>,
3606 rho: &SaeManifoldRho,
3607 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3608 penalty_scale: f64,
3609 ) -> Result<ArrowSchurSystem, String> {
3610 self.assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3611 target,
3612 rho,
3613 analytic_penalties,
3614 penalty_scale,
3615 SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM,
3616 )
3617 }
3618
3619 pub(crate) fn assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3620 &mut self,
3621 target: ArrayView2<'_, f64>,
3622 rho: &SaeManifoldRho,
3623 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3624 penalty_scale: f64,
3625 dense_beta_penalty_probe_max_dim: usize,
3626 ) -> Result<ArrowSchurSystem, String> {
3627 self.assemble_arrow_schur_inner(
3628 target,
3629 rho,
3630 analytic_penalties,
3631 penalty_scale,
3632 dense_beta_penalty_probe_max_dim,
3633 None,
3634 )
3635 }
3636
3637 /// Innermost assembly entry. `forced_layout` overrides the budget-derived
3638 /// active-set layout so a caller can pin the dense (`Forced(None)`) or a
3639 /// specific compact (`Forced(Some(layout))`) path — used by the
3640 /// compact-vs-dense Riemannian-geometry equality regression test to drive
3641 /// both layouts on identical data. `Computed` is the production path:
3642 /// the layout is derived from the assignment mode + `sparse_active_plan`.
3643 pub(crate) fn assemble_arrow_schur_inner(
3644 &mut self,
3645 target: ArrayView2<'_, f64>,
3646 rho: &SaeManifoldRho,
3647 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3648 penalty_scale: f64,
3649 dense_beta_penalty_probe_max_dim: usize,
3650 forced_layout: ForcedRowLayout,
3651 ) -> Result<ArrowSchurSystem, String> {
3652 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3653 return Err(format!(
3654 "SaeManifoldTerm::assemble_arrow_schur_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3655 ));
3656 }
3657 if target.dim() != (self.n_obs(), self.output_dim()) {
3658 return Err(format!(
3659 "SaeManifoldTerm::assemble_arrow_schur: Z must be ({}, {}); got {:?}",
3660 self.n_obs(),
3661 self.output_dim(),
3662 target.dim()
3663 ));
3664 }
3665 if rho.log_ard.len() != self.k_atoms() {
3666 return Err(format!(
3667 "SaeManifoldTerm::assemble_arrow_schur: log_ard length {} != K {}",
3668 rho.log_ard.len(),
3669 self.k_atoms()
3670 ));
3671 }
3672 // `lambda_smooth` is indexed per-atom in the smoothness gradient/curvature
3673 // assembly (`lambda_smooth[atom_idx]`); a too-short vector (e.g. a growth
3674 // move that grew `k_atoms()` without extending ρ — #1556) would panic deep
3675 // in the assembly loop with an opaque index-out-of-bounds. Validate it here
3676 // alongside `log_ard` so the contract violation surfaces as a clear Err.
3677 if rho.log_lambda_smooth.len() != self.k_atoms() {
3678 return Err(format!(
3679 "SaeManifoldTerm::assemble_arrow_schur: log_lambda_smooth length {} != K {}",
3680 rho.log_lambda_smooth.len(),
3681 self.k_atoms()
3682 ));
3683 }
3684 for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3685 let ard_len = rho.log_ard[atom_idx].len();
3686 let d = coord.latent_dim();
3687 if ard_len != 0 && ard_len != d {
3688 return Err(format!(
3689 "SaeManifoldTerm::assemble_arrow_schur: log_ard atom {atom_idx} \
3690 has len {ard_len}; expected 0 (disabled) or atom dim {d}"
3691 ));
3692 }
3693 }
3694 // Reparameterize each atom's roughness Gram into arc length at the
3695 // current decoder/coordinates (issue #673). This is the single
3696 // chokepoint for both the inner Newton assembly and the undamped
3697 // evidence factorization, so freezing the pullback-metric weight here
3698 // (lagged-diffusivity) keeps the smoothness value, gradient, Kronecker
3699 // Hessian, and REML log-det mutually consistent within each assembly
3700 // and makes the converged penalty — hence the topology evidence —
3701 // gauge-invariant. Constant-speed (periodic) atoms are unaffected.
3702 for atom in &mut self.atoms {
3703 atom.refresh_intrinsic_smooth_penalty();
3704 }
3705 // #1026 — freeze the decoder-repulsion collinearity gate at the SAME
3706 // assembly chokepoint as the smoothness Gram, so the repulsion's
3707 // gradient/curvature (assembled below) and its value (read by the
3708 // line-search `penalized_objective_total`) share one frozen gate.
3709 self.refresh_decoder_repulsion_gate();
3710 // #1625 — freeze the SEPARATION barrier's normalized-coactivation `q_jk`
3711 // at the same chokepoint. The barrier weights its decoder-shape repulsion
3712 // by the routing coactivation, but its gradient treats that weight as a
3713 // constant; recomputing it from the trial logits in the line-search value
3714 // desyncs value vs gradient in the logit block and stalls the inner solve
3715 // (#1625). Freezing it here makes value/gradient/curvature consistent.
3716 self.refresh_barrier_coactivation_gate();
3717 let n = self.n_obs();
3718 let p = self.output_dim();
3719 let k_atoms = self.k_atoms();
3720 let assignment_dim = self.assignment.assignment_coord_dim();
3721 let q = self.assignment.row_block_dim();
3722 let beta_dim = self.beta_dim();
3723 let frame_projection = FrameProjection::new(self);
3724 let beta_offsets = frame_projection.beta_offsets.clone();
3725 let coord_offsets = self.assignment.coord_offsets();
3726 // β-tier decoder smoothness is a global (B-only) penalty; under a
3727 // minibatch pass it is scaled by the chunk fraction so the per-chunk
3728 // contributions sum to one global copy.
3729 // Per-atom decoder-smoothness strengths (#1556): atom k's penalty `S_k`
3730 // is scaled by `λ_smooth[k]·penalty_scale`. The minibatch `penalty_scale`
3731 // multiplies every atom uniformly.
3732 let lambda_smooth: Vec<f64> = rho
3733 .lambda_smooth_vec()
3734 .iter()
3735 .map(|&l| l * penalty_scale)
3736 .collect();
3737 let (assignment_grad, assignment_hdiag) =
3738 assignment_prior_grad_hdiag(&self.assignment, rho)?;
3739
3740 // #1038 softmax entropy: the exact per-row Hessian in logits is dense
3741 // (`H_kj = (λ/τ²) a_k[δ_kj(m−L_k−1)+a_j(L_k+L_j+1−2m)]`), not just the
3742 // `assignment_hdiag` diagonal. Build the shared penalty + `scale = λ/τ²`
3743 // once here so the dense row block written into `block.htt` below, the
3744 // criterion's `log|H|`, and the #1006 θ-adjoint all differentiate the
3745 // SAME operator. JumpReLU / IBP keep their (separately exact) diagonal /
3746 // cross-row channels and leave this `None`. The block is gauge-null in
3747 // isolation (`H·𝟙 = 0`); it is only ever summed onto the gauge-breaking
3748 // data-fit row block before the Cholesky factor, never factored alone.
3749 let softmax_dense: Option<(
3750 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
3751 f64,
3752 )> = match self.assignment.mode {
3753 AssignmentMode::Softmax {
3754 temperature,
3755 sparsity,
3756 } if k_atoms > 1 => {
3757 let inv_tau = 1.0 / temperature;
3758 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
3759 Some((
3760 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
3761 k_atoms,
3762 temperature,
3763 ),
3764 scale,
3765 ))
3766 }
3767 _ => None,
3768 };
3769
3770 // Decoder smoothness penalty: build one KroneckerPenaltyOp per atom
3771 // (structure = λ·S_k ⊗ I_p, offset = beta_offsets[k]) instead of
3772 // materialising the dense K×K block. The gradient is a dense K-vector
3773 // accumulated into `smooth_grad_gb` and written into sys.gb after sys
3774 // is constructed (#296).
3775 let mut smooth_ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len());
3776 // #972 / #977 T1: retain each atom's symmetrised `λ S_k` (`M_k × M_k`) so
3777 // the frame transform can rebuild the smooth penalty in the factored
3778 // coordinate space as `λ S_k ⊗ I_{r_k}` (the `tr(C_kᵀ S_k C_k)` form,
3779 // using `U_kᵀU_k = I`). Unused — and not even read — on the full-`B`
3780 // path, so this is a zero-cost capture there.
3781 let mut smooth_scaled_s: Vec<Array2<f64>> = Vec::with_capacity(self.atoms.len());
3782 let mut smooth_grad_gb = vec![0.0_f64; beta_dim];
3783 // #1117 — rank deficiency is handled at the basis layer: any
3784 // rank-deficient atom was reparametrized onto its data-supported subspace
3785 // at fit entry (`reduce_atoms_to_data_supported_rank`), so the β-tier here
3786 // always sees a full-rank design and needs no step-time data-null
3787 // deflation operator. The well-conditioned (full-rank) path is unchanged.
3788 // Per-atom smoothness-gradient GEMMs `½(S_k+S_kᵀ)·B_k` are independent
3789 // across atoms; batch them across ALL GPUs (uniform-shape tiles) and
3790 // scale by `lambda_smooth` below. `symmetrize = true` reproduces the
3791 // per-atom symmetrised `scaled_s/λ` used by the Kronecker op. Exact CPU
3792 // fallback per atom keeps the result bit-for-bit with the all-CPU path.
3793 let sym_sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3794 .atoms
3795 .iter()
3796 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3797 .collect();
3798 let sym_sb_all = batched_smooth_sb(&sym_sb_inputs, true);
3799 for (atom_idx, atom) in self.atoms.iter().enumerate() {
3800 let m = atom.basis_size();
3801 let off = beta_offsets[atom_idx];
3802 // Symmetrise and scale the smoothness penalty matrix.
3803 let mut scaled_s = Array2::<f64>::zeros((m, m));
3804 for i in 0..m {
3805 for j in 0..m {
3806 let s_ij = 0.5 * (atom.smooth_penalty[[i, j]] + atom.smooth_penalty[[j, i]]);
3807 scaled_s[[i, j]] = lambda_smooth[atom_idx] * s_ij;
3808 }
3809 }
3810 // Gradient: g[beta_i] += (λ_k S_k B_k)[i, out_col]. The (m×m)·(m×p)
3811 // GEMM `½(S+Sᵀ)·B_k` was computed in the multi-GPU batch above; here
3812 // we only apply atom k's `lambda_smooth[atom_idx]`.
3813 let sb = &sym_sb_all[atom_idx] * lambda_smooth[atom_idx];
3814 for out_col in 0..p {
3815 for i in 0..m {
3816 let beta_i = off + i * p + out_col;
3817 smooth_grad_gb[beta_i] += sb[[i, out_col]];
3818 }
3819 }
3820 // IdentityRightKroneckerPenaltyOp: factor_a = λ·S_k (m×m), factor_b = I_p.
3821 smooth_ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
3822 factor_a: scaled_s.clone(),
3823 p,
3824 global_offset: off,
3825 k: beta_dim,
3826 }));
3827 // Retain `λ S_k` for the factored rebuild (no-op cost on full-`B`).
3828 smooth_scaled_s.push(scaled_s);
3829 }
3830
3831 // Per-row active-set layout. Engaged for two regimes:
3832 // * JumpReLU — structural gate plus the smooth prior's
3833 // machine-precision support: atoms with
3834 // `(logit - threshold)/tau > -36` enter the compact solve
3835 // ([`jumprelu_in_optimization_band`]). Strictly gated-off atoms
3836 // (logit ≤ threshold) carry zero assignment mass so their data-fit
3837 // reconstruction contribution and data-fit logit JVP are zero, but
3838 // supported atoms keep value-consistent prior gradient in the row block.
3839 // * IBP-MAP at large `K` — the dense `(m_total · p)²` data
3840 // Gram is infeasible, so each row is truncated to its
3841 // top-`k_active` atoms above a relative magnitude cutoff
3842 // ([`Self::sparse_active_plan`]). Small-`K` problems return `None`
3843 // and keep the exact full-support layout.
3844 // The compact row block is sized `q_active = |active| + Σ_{k∈active}
3845 // d_k` instead of the full `q`.
3846 let coord_dims: Vec<usize> = self
3847 .assignment
3848 .coords
3849 .iter()
3850 .map(|c| c.latent_dim())
3851 .collect();
3852 let row_layout: Option<SaeRowLayout> = match forced_layout {
3853 Some(layout) => layout,
3854 None => match self.assignment.mode {
3855 AssignmentMode::ThresholdGate {
3856 threshold,
3857 temperature,
3858 } => Some(SaeRowLayout::from_jumprelu(
3859 n,
3860 k_atoms,
3861 threshold,
3862 temperature,
3863 &self.assignment.logits,
3864 coord_dims.clone(),
3865 self.assignment.coord_offsets(),
3866 )),
3867 // #1408/#1409 — Softmax engages the COMPACT top-`k` row layout
3868 // inside the optimization (no longer a post-fit projection).
3869 // The active set is each row's top-`k_active_cap` softmax atoms
3870 // above the relative cutoff; the cap comes from the user's
3871 // `top_k` (`softmax_active_cap`) and/or the in-core memory budget
3872 // ([`Self::softmax_active_plan`]). The full-`K` softmax
3873 // normalization still forms `a` (the gate map); only the dropped
3874 // tail logits, carrying negligible `O(a)` reconstruction mass and
3875 // `O(a²)` curvature, leave the per-row block.
3876 //
3877 // Coherence (the load-bearing correctness invariant): the
3878 // assembly's softmax curvature branch writes the ACTIVE×ACTIVE
3879 // principal sub-block of the Gershgorin Loewner majorizer
3880 // `D = diag(Σ_j|H_kj|)` (#1419; PSD and `D ⪰ H_entropy`) on the
3881 // compact logit slots — NOT the indefinite `assignment_hdiag`
3882 // diagonal. The logdet ρ-trace
3883 // (`assignment_log_strength_hessian_trace`) iterates the row's
3884 // active logit slots and indexes that SAME majorizer by global
3885 // atom, and the θ-adjoint reads its derivative via `jets.vars`
3886 // (global-atom indexed), so value, log|H|, and Γ differentiate
3887 // ONE operator on the compact support. The FFI's after-the-fit
3888 // top-`k` projection is then a no-op at the optimum.
3889 AssignmentMode::Softmax { .. } => match self.softmax_active_plan() {
3890 Some((k_active_cap, relative_cutoff)) => {
3891 let mut assignments_all = Vec::with_capacity(n);
3892 for row in 0..n {
3893 assignments_all
3894 .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3895 }
3896 Some(SaeRowLayout::from_dense_weights(
3897 &assignments_all,
3898 k_active_cap,
3899 relative_cutoff,
3900 coord_dims.clone(),
3901 self.assignment.coord_offsets(),
3902 ))
3903 }
3904 None => None,
3905 },
3906 AssignmentMode::IBPMap { .. } => {
3907 match self.sparse_active_plan() {
3908 Some((k_active_cap, relative_cutoff)) => {
3909 // Build per-row dense assignments once to derive the
3910 // active set; the row loop re-derives `assignments`
3911 // (cheap gate map at the same rho) and reuses these
3912 // active sets.
3913 let mut assignments_all = Vec::with_capacity(n);
3914 for row in 0..n {
3915 assignments_all
3916 .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3917 }
3918 // #1414: pass the RELATIVE cutoff through;
3919 // `from_dense_weights` applies it per row against that
3920 // row's own peak `max_k |a_{n,k}|`, matching the
3921 // documented `sparse_active_plan` contract. A single
3922 // global threshold (relative_cutoff · whole-dataset
3923 // peak) wrongly drops every atom of a uniformly-small
3924 // row when another row peaks high.
3925 Some(SaeRowLayout::from_dense_weights(
3926 &assignments_all,
3927 k_active_cap,
3928 relative_cutoff,
3929 coord_dims.clone(),
3930 self.assignment.coord_offsets(),
3931 ))
3932 }
3933 None => None,
3934 }
3935 }
3936 },
3937 };
3938 // #974 likelihood-whitening seam. The single per-row decision: when the
3939 // installed `RowMetric` is a genuinely estimated noise model
3940 // (`whitens_likelihood()` — only `WhitenedStructured`), the
3941 // reconstruction data-fit, its t-block Gauss-Newton row block, AND the
3942 // β-tier data-fit gradient are all assembled through the SAME per-row
3943 // metric `M_n = U_n U_nᵀ = Σ_n^{-1}`. There is exactly ONE construction
3944 // site (the `whiten_rows` closure below), so the value the line-search
3945 // sums and the gradient/Hessian the Newton step solves cannot drift apart
3946 // (the objective↔gradient-desync cure). For Euclidean / OutputFisher /
3947 // no-metric the closure is the identity and every downstream loop is
3948 // byte-identical to the historical isotropic path.
3949 let whitens_likelihood = self
3950 .row_metric
3951 .as_ref()
3952 .is_some_and(|metric| metric.whitens_likelihood());
3953 // #972 / #977 T1: engage the FACTORED Grassmann-coordinate β-tier when
3954 // any atom has an active decoder frame. The closed-form factorization
3955 // `Φᵀ(G ⊗ I_p)Φ = G ⊗ (U_iᵀU_j)` is EXACT only for the isotropic
3956 // likelihood; under an active whitening metric (`whitens_likelihood()`,
3957 // only `WhitenedStructured`) the per-row output factor would be
3958 // `U_iᵀ M_n U_j` and does NOT factor out of the basis Gram, so we fall
3959 // back to the full-`B` path there (frames + whitening is out of scope —
3960 // see #974). The common Euclidean / OutputFisher / no-metric case factors
3961 // cleanly. When `frames_engaged` is false, EVERY β-tier object below is
3962 // assembled bit-for-bit as the historical full-`B` path.
3963 let frames_engaged = self.any_frame_active() && !whitens_likelihood;
3964 // #1407: fixed-decoder mode skips the entire β decoder tier (G/gb/htbeta
3965 // operator/hbb/β-penalties); only per-row htt/gt are produced.
3966 let fixed_decoder = self.fixed_decoder_assembly;
3967 let admission_plan = self
3968 .streaming_plan()
3969 .admitted_or_error(self.n_obs(), self.output_dim(), self.k_atoms())
3970 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
3971 // #1407: fixed-decoder builds NO dense β-Hessian (hbb) — force the
3972 // empty-hbb system constructor so no `beta_dim × beta_dim` workspace is
3973 // taken (the early return skips `reclaim_border_hbb_workspace`).
3974 let dense_beta_curvature = !fixed_decoder
3975 && admission_plan.direct_admitted
3976 && !(frames_engaged && beta_dim > dense_beta_penalty_probe_max_dim);
3977 // #1406: the dense per-row cross-block slab `block.htbeta` is only WRITTEN
3978 // (line ~4243) and READ by the solver when `frames_engaged` (the factored
3979 // full-B path, which installs NO matrix-free row operator → the solver's
3980 // `sys_htbeta_apply_row` falls back to the dense slab). On the
3981 // `!frames_engaged` path the cross block is carried entirely by the
3982 // matrix-free Kronecker operator (`set_row_htbeta_operator`, ~line 4491);
3983 // `activate_dense_htbeta_supplement` is never called, so the solver never
3984 // touches `block.htbeta`. Allocating it at `beta_dim = K·M·p` there is the
3985 // ~6 TiB high-K leak (#1405/#1406): allocate ZERO columns instead. Frames
3986 // still use the (much smaller) factored border width.
3987 // #795/#1406/#1407: the non-frames matrix-free path normally holds a
3988 // ZERO-width per-row cross-block slab — the data-fit `H_tβ` is carried by
3989 // the Kronecker row operator (`set_row_htbeta_operator`), and allocating
3990 // the dense slab at `beta_dim = K·M·p` is the high-K memory leak. But an
3991 // ISOMETRY penalty on a coherence-preserving (flat) chart scatters an
3992 // ADDITIONAL Gauss-Newton cross-block into the dense per-row `htbeta`
3993 // slab and flips on `activate_dense_htbeta_supplement` — dropping it would
3994 // leave the Newton system block-diagonal and forfeit the strong `t↔B`
3995 // isometry coupling the circle fit needs to reach KKT stationarity (#795).
3996 // So on the non-frames path widen the slab to `beta_dim` exactly when that
3997 // dense supplement will be written, and keep zero width otherwise.
3998 let dense_isometry_cross_block = !fixed_decoder
3999 && analytic_penalties
4000 .map(|registry| self.registry_writes_dense_isometry_cross_block(registry))
4001 .unwrap_or(false);
4002 let row_htbeta_dim = if fixed_decoder {
4003 // Fixed-decoder mode skips the β tier entirely.
4004 0
4005 } else if frames_engaged {
4006 self.factored_border_dim()
4007 } else if dense_isometry_cross_block {
4008 // Matrix-free data-fit cross-block + dense isometry supplement: the
4009 // supplement is written/read in the full-`B` β coordinate system.
4010 beta_dim
4011 } else {
4012 // Matrix-free path with no dense cross-block supplement.
4013 0
4014 };
4015 // Build the Arrow-Schur system: heterogeneous row dims when a compact
4016 // layout is active, uniform `q` otherwise.
4017 let mut sys = if let Some(ref layout) = row_layout {
4018 let per_row_dims: Vec<usize> = (0..n).map(|row| layout.row_q_active(row)).collect();
4019 if dense_beta_curvature {
4020 let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4021 ArrowSchurSystem::new_with_per_row_dims_and_hbb_and_htbeta_cols(
4022 per_row_dims,
4023 beta_dim,
4024 hbb_workspace,
4025 row_htbeta_dim,
4026 )
4027 } else {
4028 self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4029 ArrowSchurSystem::new_with_per_row_dims_empty_hbb_and_htbeta_cols(
4030 per_row_dims,
4031 beta_dim,
4032 row_htbeta_dim,
4033 )
4034 }
4035 } else if dense_beta_curvature {
4036 let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4037 ArrowSchurSystem::new_with_hbb_and_htbeta_cols(
4038 n,
4039 q,
4040 beta_dim,
4041 hbb_workspace,
4042 row_htbeta_dim,
4043 )
4044 } else {
4045 self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4046 ArrowSchurSystem::new_with_empty_hbb_and_htbeta_cols(n, q, beta_dim, row_htbeta_dim)
4047 };
4048 // Apply accumulated smoothness-penalty gradients into sys.gb.
4049 for (i, g) in smooth_grad_gb.iter().enumerate() {
4050 sys.gb[i] += g;
4051 }
4052 // `w_dim` is the whitened output dimension: `rank` of the metric factor
4053 // when whitening, else `p` (identity). `error_white` is the whitened
4054 // residual `U_nᵀ r_n ∈ ℝ^{w_dim}` whose squared norm is `r_nᵀ M_n r_n`,
4055 // shared by the value path, the t-block GN, and (lifted back to p-space)
4056 // the β-tier gradient.
4057 let w_dim = match self.row_metric.as_ref() {
4058 Some(metric) if whitens_likelihood => metric.metric_rank(),
4059 _ => p,
4060 };
4061 // Data-fit Gauss-Newton β-Hessian is block-diagonal across the `p`
4062 // output channels and identical in each: with the flat β layout
4063 // `β[μ·p + oc] = B[μ, oc]` (μ enumerating (atom, basis_col)) the GN
4064 // outer product `Jβᵀ Jβ` couples only equal `oc`, with the same
4065 // `(M_total × M_total)` block `G[μ, μ'] = Σ_rows (a_k φ_k[m])(a_{k'} φ_{k'}[m'])`
4066 // for every channel. So `H_data = G ⊗ I_p`. The `μ` index of an `a_phi`
4067 // entry whose global β base is `beta_base` is `beta_base / p` (every
4068 // `beta_offset` and the `basis_col·p` stride are multiples of `p`).
4069 //
4070 // `G` is only non-zero on `(atom_i, atom_j)` pairs that co-occur in
4071 // some row's active set, so we accumulate it as a sparse map of dense
4072 // per-atom-pair `(m_i × m_j)` blocks keyed by `(atom_i, atom_j)` rather
4073 // than as a dense `(m_total × m_total)` matrix. At `K = 100K` with
4074 // per-row active sets of size `k_active ≪ K`, only `O(N · k_active²)`
4075 // pairs are ever touched, so the data Gram (and every matvec /
4076 // diagonal pass over it via `SparseBlockKroneckerPenaltyOp`) tracks the
4077 // active atoms instead of `K²`. In the dense full-support layout the
4078 // map degenerates to every co-occurring pair, reproducing the dense
4079 // Gram exactly. A `BTreeMap` key order keeps the installed op's
4080 // fingerprint deterministic. The `μ`-space offset of atom `k` is
4081 // `beta_offsets[k] / p`.
4082 type SaeGBlocks = std::collections::BTreeMap<(usize, usize), Array2<f64>>;
4083 let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
4084 let mu_offsets: Vec<usize> = beta_offsets.iter().map(|&off| off / p).collect();
4085 // Stick-breaking prior for IBP-MAP depends only on (k_atoms, alpha_eff)
4086 // which are constant across rows for the current rho; precompute once.
4087 let ibp_prior_vec = match self.assignment.mode {
4088 AssignmentMode::IBPMap { .. } => {
4089 let alpha = self
4090 .assignment
4091 .resolved_ibp_alpha(rho)
4092 .ok_or_else(|| "IBP assignment alpha resolution failed".to_string())?;
4093 Some(ordered_geometric_shrinkage_prior(k_atoms, alpha).to_vec())
4094 }
4095 _ => None,
4096 };
4097 let ibp_prior_slice = ibp_prior_vec.as_deref();
4098 // #991 design honesty weights (mean-1 HT inclusion corrections); see
4099 // the seam comment at the per-row residual below.
4100 let row_loss_w = self.row_loss_weights.as_deref();
4101 // Dense full-support index `[0, k_atoms)`, used by the row loop when no
4102 // compact layout is engaged so the active-atom iteration is uniform.
4103 let all_atoms_index: Vec<usize> = (0..k_atoms).collect();
4104 // Per-atom per-axis periodicity, hoisted out of the row loop. Selects
4105 // the smooth von-Mises coordinate prior on wrapped (Circle) axes and
4106 // the Gaussian prior on Euclidean axes; see `ArdAxisPrior`.
4107 let ard_axis_periods: Vec<Vec<Option<f64>>> = self
4108 .assignment
4109 .coords
4110 .iter()
4111 .map(|coord| coord.effective_axis_periods())
4112 .collect();
4113 struct SaeAssemblyRow {
4114 pub(crate) row: usize,
4115 pub(crate) block: ArrowRowBlock,
4116 pub(crate) gb_delta: Vec<(usize, f64)>,
4117 pub(crate) g_blocks: SaeGBlocks,
4118 pub(crate) kron_a_phi: Option<Vec<(usize, f64)>>,
4119 pub(crate) kron_jac: Option<Vec<f64>>,
4120 }
4121
4122 // Per-row scratch reused across all rows a rayon worker processes
4123 // (#1017). The assembly closure is re-run every inner Newton iteration ×
4124 // every outer ρ evaluation; allocating these eight loop-invariant-sized
4125 // buffers (`k_atoms·p`, several `p`, one `q·max(w_dim,p)`) once per
4126 // worker via `map_init` — rather than once per (row × assembly) inside
4127 // the closure — removes the dominant small-allocation traffic the
4128 // eu-stack profile attributed to allocator/barrier spin at the SAE LLM
4129 // shape (p≈5120). Every buffer is fully filled (or `.fill(0.0)`'d) before
4130 // it is read each row, so reuse is bit-identical to the fresh-alloc path;
4131 // `gb_delta`/`g_blocks` are NOT scratch (they move into the returned
4132 // `SaeAssemblyRow`) and stay allocated per row.
4133 struct RowScratch {
4134 pub(crate) decoded: Array2<f64>,
4135 pub(crate) dg_buf: Vec<f64>,
4136 pub(crate) fitted: Array1<f64>,
4137 pub(crate) error: Array1<f64>,
4138 pub(crate) error_white: Vec<f64>,
4139 pub(crate) error_metric: Array1<f64>,
4140 pub(crate) jac_white: Vec<f64>,
4141 pub(crate) decoded_scratch: Vec<f64>,
4142 // #1557 — per-worker scratch for the row assignment vector (filled via
4143 // `_into`, not allocated per row); full `k_atoms`, global-atom indexed.
4144 pub(crate) assignments: Array1<f64>,
4145 }
4146 // #1410: size the per-worker scratch by the COMPACT row dimensions, not
4147 // full `K`/`q`. With a compact layout the assembly only ever touches each
4148 // row's active atoms (≤ `max_active`) and its compact tangent block
4149 // (≤ `max_q_row`); allocating `decoded` at `k_atoms·p` and `jac_white` at
4150 // `q·max(w_dim,p)` was the per-worker `O(K)` blow-up (≈11 GiB/worker at
4151 // K=100k, p=5120 — and `map_init` gives every Rayon worker its own copy).
4152 // Without a layout the dense path needs full `k_atoms`/`q`. `decoded` rows
4153 // are addressed by COMPACT SLOT in the compact branch below (the dense
4154 // branch keeps global-atom rows), so the row count is the max active set.
4155 //
4156 // #1410/#1408/#1409: SOFTMAX now ALSO takes the `Some(layout)` branch
4157 // whenever a `top_k` cap (`set_softmax_active_cap`) or an in-core memory
4158 // breach engages `softmax_active_plan` → `from_dense_weights`, so its
4159 // per-worker `decoded`/`jac_white` scratch is the COMPACT
4160 // `max_active`/`max_q_row` size too — no longer the full `(k_atoms·p)` /
4161 // `(q·max(w_dim,p))` blow-up. JumpReLU / IBP-MAP likewise pay only
4162 // `max_active`. The remaining `None` (full-`K`) branch is the UNCAPPED
4163 // softmax / no-budget-breach case, which genuinely assembles the dense
4164 // entropy block over all `K`; capping it (the compact contract) removes
4165 // the per-worker `O(K)` footprint entirely. (#1410: the residual per-row
4166 // `O(K)` softmax-majorizer scratch — a `row_logits` copy and the full-`K`
4167 // `d`/`H_entropy` blocks — is removed separately; see the active-only
4168 // `active_softmax_gershgorin_majorizer_entry` /
4169 // `softmax_dense_entropy_hessian_entry` helpers below.)
4170 let (decoded_rows, scratch_q) = match row_layout.as_ref() {
4171 Some(layout) => {
4172 let max_active = (0..n)
4173 .map(|r| layout.active_atoms[r].len())
4174 .max()
4175 .unwrap_or(0)
4176 .max(1);
4177 let max_q_row = (0..n)
4178 .map(|r| layout.row_q_active(r))
4179 .max()
4180 .unwrap_or(q)
4181 .max(1);
4182 (max_active, max_q_row)
4183 }
4184 None => (k_atoms, q),
4185 };
4186 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4187 // #1033 large-n: fold the per-row assembly results in row-ordered CHUNKS
4188 // rather than collecting all `n` `SaeAssemblyRow`s at once. The previous
4189 // path materialized the FULL `Vec<SaeAssemblyRow>` (every row's htt/gt
4190 // block + per-row `g_blocks` + `kron_a_phi`/`kron_jac`) AND the fold
4191 // destinations simultaneously — a ~2× transient peak over the resident
4192 // system during the fold, the assembly-side OOM cliff at large `n`. By
4193 // collecting one chunk, folding it into `sys.rows`/`g_blocks`/`kron_*`,
4194 // and dropping the chunk's `Vec` before the next chunk, the transient
4195 // intermediate is bounded to `O(chunk_size)` while the resident output is
4196 // unchanged. The fold stays STRICTLY row-ascending (chunk `[c0..c1)` then
4197 // `[c1..c2)`, rows in order within each chunk), so every `+=` into
4198 // `sys.gb`, the `g_blocks` BTreeMap, and the `kron_*` pushes lands in the
4199 // identical order as the single-pass fold — bit-for-bit the same system.
4200 // Chunk width is the admission plan's `chunk_size` (the same value
4201 // `streaming_plan` sizes for the matrix-free window), floored so a tiny
4202 // plan still makes forward progress.
4203 let assembly_chunk_rows = self
4204 .assembly_chunk_override
4205 .unwrap_or(admission_plan.chunk_size)
4206 .clamp(1, n.max(1));
4207 let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4208 let mut kron_a_phi: Vec<Vec<(usize, f64)>> = Vec::with_capacity(n);
4209 let mut kron_jac: Vec<Vec<f64>> = Vec::with_capacity(n);
4210 let mut chunk_start = 0usize;
4211 while chunk_start < n {
4212 let chunk_end = (chunk_start + assembly_chunk_rows).min(n);
4213 let mut fold_offset_in_chunk = 0usize;
4214 let row_results: Vec<SaeAssemblyRow> = (chunk_start..chunk_end)
4215 .into_par_iter()
4216 .map_init(
4217 || RowScratch {
4218 decoded: Array2::<f64>::zeros((decoded_rows, p)),
4219 dg_buf: vec![0.0_f64; p],
4220 fitted: Array1::<f64>::zeros(p),
4221 error: Array1::<f64>::zeros(p),
4222 error_white: vec![0.0_f64; w_dim],
4223 error_metric: Array1::<f64>::zeros(p),
4224 jac_white: vec![0.0_f64; scratch_q * w_dim.max(p)],
4225 decoded_scratch: vec![0.0_f64; p],
4226 assignments: Array1::<f64>::zeros(k_atoms),
4227 },
4228 |scratch, row| -> Result<SaeAssemblyRow, String> {
4229 // #1557 — mark this rayon row worker as a nested data-parallel
4230 // region so any faer GEMM reached transitively from the per-row
4231 // assembly (frame `Uᵀ` products, the per-row cross-block /
4232 // Schur-accumulation matmuls, the Riemannian projections) pins to
4233 // `Par::Seq` via `effective_global_parallelism` instead of
4234 // re-fanning the global Rayon pool against this outer fan-out
4235 // (the `spindle` barrier-spin). Serial vs parallel over these tiny
4236 // per-row blocks is a single small product, so the result is
4237 // bit-identical. The guard is held for the whole closure body
4238 // including its `?`/`return` paths.
4239 with_nested_parallel(|| {
4240 let RowScratch {
4241 decoded,
4242 dg_buf,
4243 fitted,
4244 error,
4245 error_white,
4246 error_metric,
4247 jac_white,
4248 decoded_scratch,
4249 assignments,
4250 } = scratch;
4251 let mut gb_delta: Vec<(usize, f64)> = Vec::new();
4252 let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4253 // #1557 — fill per-worker scratch (bit-identical to alloc path).
4254 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
4255 self.assignment
4256 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
4257 // Reconstruction uses the row's active support: for the dense
4258 // full-support layout this is all atoms (exact); for a compact
4259 // layout the dropped atoms carry negligible `O(a)` reconstruction
4260 // mass and zero curvature, so excluding them keeps `fitted`,
4261 // `error`, and the logit-JVP cross term `(decoded[k] − fitted)`
4262 // mutually consistent with the curvature actually assembled.
4263 fitted.fill(0.0);
4264 let row_active_owned: Option<&[usize]> =
4265 row_layout.as_ref().map(|l| l.active_atoms[row].as_slice());
4266 match row_active_owned {
4267 Some(active) => {
4268 // #1410: `decoded` is a compact (max_active × p) buffer
4269 // here; index it by the active-set SLOT `j` (the same
4270 // index the compact tangent block / `coord_starts` use),
4271 // NOT the global `atom_idx`.
4272 for (j, &atom_idx) in active.iter().enumerate() {
4273 let a_k = assignments[atom_idx];
4274 self.atoms[atom_idx]
4275 .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4276 for out_col in 0..p {
4277 decoded[[j, out_col]] = decoded_scratch[out_col];
4278 fitted[out_col] += a_k * decoded_scratch[out_col];
4279 }
4280 }
4281 }
4282 None => {
4283 for atom_idx in 0..k_atoms {
4284 let a_k = assignments[atom_idx];
4285 self.atoms[atom_idx]
4286 .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4287 for out_col in 0..p {
4288 decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
4289 fitted[out_col] += a_k * decoded_scratch[out_col];
4290 }
4291 }
4292 }
4293 }
4294 for out_col in 0..p {
4295 error[out_col] = fitted[out_col] - target[[row, out_col]];
4296 }
4297 // #991 design-honesty seam: a per-row scalar weight `w_row` on the
4298 // reconstruction channel is exactly the metric `w_row · I_p`, so it
4299 // is realized as a `√w_row` scaling of the THREE row-local data
4300 // quantities at their construction sites — this residual, the
4301 // latent Jacobian (below), and the β basis load `a·φ` (below).
4302 // Every downstream data object then carries exactly one factor of
4303 // `w_row` (gt, htt, htbeta, the β Gram `G`, and the β gradient),
4304 // matching the `w_row`-weighted value `loss_scaled` sums; the
4305 // per-row latent priors (assignment / ARD, added to `gt`/`htt`
4306 // further down) are deliberately unweighted — see the
4307 // `row_loss_weights` field docs. `None` ⇒ `sqrt_row_w == 1.0` and
4308 // no multiply is applied (bit-identical unweighted path).
4309 let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
4310 if sqrt_row_w != 1.0 {
4311 for out_col in 0..p {
4312 error[out_col] *= sqrt_row_w;
4313 }
4314 }
4315 // #974 seam (step 1/2): whiten the per-row residual ONCE.
4316 // * not whitening ⇒ `error_white == error` (length p) and
4317 // `error_metric == error`; every downstream loop is the
4318 // historical isotropic path bit-for-bit.
4319 // * whitening ⇒ `error_white = U_nᵀ r_n ∈ ℝ^{w_dim}` (its squared
4320 // norm is `r_nᵀ M_n r_n`, the value the data-fit sums) and
4321 // `error_metric = U_n (U_nᵀ r_n) = M_n r_n ∈ ℝ^p` (the p-space
4322 // metric-applied residual the β-tier gradient contracts).
4323 match self.row_metric.as_ref() {
4324 Some(metric) if whitens_likelihood => {
4325 let wr = metric.whiten_residual_row(row, error.view());
4326 for (slot, &v) in error_white.iter_mut().zip(wr.iter()) {
4327 *slot = v;
4328 }
4329 let mr = metric.apply_metric_row(row, error.view());
4330 for (slot, &v) in error_metric.iter_mut().zip(mr.iter()) {
4331 *slot = v;
4332 }
4333 }
4334 _ => {
4335 for out_col in 0..p {
4336 error_white[out_col] = error[out_col];
4337 error_metric[out_col] = error[out_col];
4338 }
4339 }
4340 }
4341
4342 // Determine whether this row uses the compact active-set layout.
4343 // * JumpReLU: gated atoms plus the smooth prior's
4344 // machine-precision support enter.
4345 // * IBP-MAP at large K: only the top-`k_active` atoms.
4346 // * Otherwise (small K): the dense uniform-q layout.
4347 let (q_row, mut local_jac_row) = if let Some(layout) = row_layout.as_ref() {
4348 let active = &layout.active_atoms[row];
4349 let starts = &layout.coord_starts[row];
4350 let q_active = layout.row_q_active(row);
4351 let mut jac_compact = Array2::<f64>::zeros((q_active, p));
4352 // Logit JVP rows for active atoms only, using the per-mode
4353 // assignment sensitivity `da_k/dl_k` contracted into the
4354 // decoded / fitted-corrected output direction.
4355 let logits_row = self.assignment.logits.row(row);
4356 for (j, &k) in active.iter().enumerate() {
4357 fill_active_atom_logit_jvp(
4358 ActiveAtomLogitJvp {
4359 mode: self.assignment.mode,
4360 k,
4361 logit_k: logits_row[k],
4362 a_k: assignments[k],
4363 // #1410: compact slot `j`, not global atom `k`.
4364 decoded_k: decoded.row(j),
4365 fitted: fitted.view(),
4366 ibp_prior: ibp_prior_slice,
4367 compact_index: j,
4368 // #1026/#1033: a FIXED logit (ungated, or every
4369 // atom under frozen routing) has a constant gate
4370 // ⇒ zero logit-JVP.
4371 ungated: self.assignment.logit_is_fixed(k),
4372 },
4373 &mut jac_compact,
4374 );
4375 }
4376 // Coordinate JVP rows for active atoms only.
4377 for (j, &k) in active.iter().enumerate() {
4378 let d = self.atoms[k].latent_dim;
4379 let a_k = assignments[k];
4380 let coord_start = starts[j];
4381 for axis in 0..d {
4382 self.atoms[k].fill_decoded_derivative_row(
4383 row,
4384 axis,
4385 dg_buf.as_mut_slice(),
4386 );
4387 for out_col in 0..p {
4388 jac_compact[[coord_start + axis, out_col]] =
4389 a_k * dg_buf[out_col];
4390 }
4391 }
4392 }
4393 (q_active, jac_compact)
4394 } else {
4395 // Fresh per-row Jacobian, structurally identical to the
4396 // JumpReLU branch: every (q × p) element is unconditionally
4397 // overwritten below (assignment-chart JVP rows + coordinate rows), so the
4398 // `Array2::zeros` allocation needs no separate `fill(0.0)` and
4399 // the populated buffer is returned by move without a clone.
4400 let mut jac_row = Array2::<f64>::zeros((q, p));
4401 fill_assignment_logit_jvp_rows(
4402 self.assignment.mode,
4403 self.assignment.logits.row(row),
4404 assignments.view(),
4405 decoded.view(),
4406 fitted.view(),
4407 ibp_prior_slice,
4408 // #1026/#1033: zero logit-JVP rows for FIXED-logit atoms
4409 // (ungated, and all atoms under frozen routing).
4410 &self.assignment.fixed_logit_mask(),
4411 &mut jac_row,
4412 );
4413 // Coordinate columns for all atoms.
4414 for atom_idx in 0..k_atoms {
4415 let d = self.atoms[atom_idx].latent_dim;
4416 let off = coord_offsets[atom_idx];
4417 let a_k = assignments[atom_idx];
4418 for axis in 0..d {
4419 self.atoms[atom_idx].fill_decoded_derivative_row(
4420 row,
4421 axis,
4422 dg_buf.as_mut_slice(),
4423 );
4424 for out_col in 0..p {
4425 jac_row[[off + axis, out_col]] = a_k * dg_buf[out_col];
4426 }
4427 }
4428 }
4429 (q, jac_row)
4430 };
4431
4432 // #991 design-honesty seam, Jacobian leg: scale the row's latent
4433 // Jacobian by `√w_row` BEFORE the whitening / Kronecker capture so
4434 // htt (= J̃J̃ᵀ), the data part of gt (= J̃ẽ, the residual already
4435 // carries its own √w_row), and the htbeta cross block (J paired
4436 // with the √w_row-scaled β load below) each carry exactly one
4437 // factor of `w_row`. No-op on the unweighted path.
4438 if sqrt_row_w != 1.0 {
4439 for a in 0..q_row {
4440 for out_col in 0..p {
4441 local_jac_row[[a, out_col]] *= sqrt_row_w;
4442 }
4443 }
4444 }
4445
4446 // #974 seam (step 2/2): whiten the per-row Jacobian through the SAME
4447 // metric the residual was whitened by. `jac_white[a*w_dim + k]` holds
4448 // `J̃[a, k] = Σ_out U_n[out, k] · J_n[a, out]` so the t-block
4449 // Gauss-Newton row block is `htt = J̃ J̃ᵀ = J_n M_n J_nᵀ` and
4450 // `gt = J̃ ẽ = J_nᵀ M_n r_n`. When not whitening, `w_dim == p` and the
4451 // whitened jac equals the raw Jacobian, so htt/gt are byte-identical
4452 // to the historical isotropic assembly. Because the SAME `error_white`
4453 // feeds both the value-path data-fit (Σ½ ẽ²) and this gradient
4454 // (J̃ ẽ), the objective and its t-block gradient share one whitening
4455 // — they cannot desync.
4456 if whitens_likelihood {
4457 if let Some(metric) = self.row_metric.as_ref() {
4458 for a in 0..q_row {
4459 for k in 0..w_dim {
4460 let mut acc = 0.0;
4461 // U_n[out, k] read through the metric's factor layout.
4462 for out_col in 0..p {
4463 acc += metric.factor_entry(row, out_col, k)
4464 * local_jac_row[[a, out_col]];
4465 }
4466 jac_white[a * w_dim + k] = acc;
4467 }
4468 }
4469 }
4470 } else {
4471 for a in 0..q_row {
4472 for out_col in 0..p {
4473 jac_white[a * w_dim + out_col] = local_jac_row[[a, out_col]];
4474 }
4475 }
4476 }
4477
4478 // Build the per-row Arrow-Schur block at the row's active dim.
4479 let mut block = ArrowRowBlock::new(q_row, row_htbeta_dim);
4480 for a in 0..q_row {
4481 let jac_a = &jac_white[a * w_dim..(a + 1) * w_dim];
4482 let g = jac_a
4483 .iter()
4484 .zip(error_white.iter())
4485 .map(|(&j, &e)| j * e)
4486 .sum::<f64>();
4487 block.gt[a] += g;
4488 for b in 0..q_row {
4489 let jac_b = &jac_white[b * w_dim..(b + 1) * w_dim];
4490 let h = jac_a
4491 .iter()
4492 .zip(jac_b.iter())
4493 .map(|(&ja, &jb)| ja * jb)
4494 .sum::<f64>();
4495 block.htt[[a, b]] += h;
4496 }
4497 }
4498
4499 // Assignment prior in logit space.
4500 // For compact layout: position `j` = active_atoms index.
4501 // For dense layout: position `atom_idx` directly.
4502 //
4503 // H-consistency note (#1006 audit / #1416 update). This
4504 // `assignment_hdiag` is the assignment channel's raw diagonal
4505 // curvature, added un-majorized. It is exact for JumpReLU and exact
4506 // within each IBP row/column diagonal, and stores ONLY the diagonal of
4507 // two full-Hessian structures — but those off-diagonal structures are
4508 // now carried elsewhere, not dropped:
4509 //
4510 // * softmax entropy has dense within-row Hessian
4511 // H_kj = (λ/τ²) a_k[δ_kj(m-L_k-1) + a_j(L_k+L_j+1-2m)];
4512 // this diagonal stores its Gershgorin Loewner majorizer (#1419).
4513 // * IBP empirical-π has cross-row rank-one terms per column
4514 // H_(i,k),(j,k) = w score_derivative_k z'_ik z'_jk for i != j.
4515 // This per-row diagonal stores only the diagonal/self-row part;
4516 // the FULL rank-one cross-row block `U D Uᵀ` is now INSTALLED as a
4517 // separate Woodbury source by `set_ibp_cross_row_source` (#1038),
4518 // so the assembled operator is `H_full = H₀' + U D Uᵀ` on the
4519 // NO-SELF base `H₀' = H₀ − Σ_k d_k diag(z'_ik²)` (self term
4520 // downdated, see `IbpCrossRowSource::self_term_downdate`). The
4521 // scalar `D`-coefficient `d_k = w·s'_k` is
4522 // `IbpHessianDiagThirdChannels::cross_row_d` (FD-verified against
4523 // ∂²value/∂ℓ_ik∂ℓ_jk in
4524 // `ibp_cross_row_woodbury_d_matches_full_off_diagonal_hessian`),
4525 // and `z_jac` carries `u_k`'s entries `z'_ik`.
4526 //
4527 // The criterion's log|H| and Γ adjoint differentiate this SAME
4528 // `H_full`: the ρ-trace adds the cross-row off-diagonal in
4529 // `assignment_log_strength_hessian_trace` (#1416, dense AND compact
4530 // layouts) and the θ-adjoint adds it in `logdet_theta_adjoint`
4531 // (#1416/#1641), so value and gradient stay on one operator.
4532 let assignment_base = row * k_atoms;
4533 if let Some(layout) = row_layout.as_ref() {
4534 let active = &layout.active_atoms[row];
4535 // #1408/#1409 softmax compact curvature: the entropy
4536 // Hessian diagonal in `assignment_hdiag` is INDEFINITE,
4537 // so on a compact softmax layout write the Gershgorin
4538 // Loewner majorizer `D_kk = Σ_j|H_kj|` (#1419) — the same
4539 // PSD operator the dense softmax branch writes — at each
4540 // active logit slot. `D` is diagonal, so its active
4541 // principal sub-block is `diag(D_kk : k ∈ active)`; each
4542 // `D_kk` is the FULL-`K` abs-row-sum, so it still
4543 // dominates the active principal sub-block of `H_entropy`
4544 // (a genuine majorizer on the retained support). The
4545 // gradient stays the EXACT entropy gradient (it sets the
4546 // fixed point), so majorizing only conditions the Newton
4547 // step. JumpReLU/IBP keep their (exact) diagonal.
4548 //
4549 // #1410: compute only the active `D_kk` directly from this
4550 // row's softmax assignments `a` (= `assignments`, already
4551 // in hand), via `active_softmax_gershgorin_majorizer_entry`.
4552 // The previous `psd_majorizer_abs_row_sums(&row_logits, ..)`
4553 // call allocated TWO length-`K` per-row scratch vectors (a
4554 // fresh `row_logits` copy and the full-`K` returned `d`)
4555 // only to read `d[k]` for the `≤ top_k` active `k` — an
4556 // `O(K)` per-row allocation on the path the compact
4557 // contract keeps `K`-free. The shared `m = Σ_j a_j l_j` is
4558 // the one irreducible `O(K)` pass, computed once per row.
4559 let assignments_slice = assignments
4560 .as_slice()
4561 .expect("softmax assignments row must be contiguous");
4562 let majorizer_log_mean: Option<f64> = softmax_dense
4563 .as_ref()
4564 .map(|_| softmax_majorizer_log_mean(assignments_slice));
4565 for (j, &k) in active.iter().enumerate() {
4566 block.gt[j] += assignment_grad[assignment_base + k];
4567 match (softmax_dense.as_ref(), majorizer_log_mean) {
4568 (Some((_penalty, scale)), Some(m)) => {
4569 block.htt[[j, j]] +=
4570 active_softmax_gershgorin_majorizer_entry(
4571 assignments_slice,
4572 k,
4573 m,
4574 *scale,
4575 );
4576 }
4577 _ => block.htt[[j, j]] += assignment_hdiag[assignment_base + k],
4578 }
4579 }
4580 } else {
4581 for free_idx in 0..assignment_dim {
4582 block.gt[free_idx] += assignment_grad[assignment_base + free_idx];
4583 }
4584 if let Some((penalty, scale)) = softmax_dense.as_ref() {
4585 // #1419: write the genuine Gershgorin Loewner majorizer
4586 // `D = diag(Σ_j|H_kj|)` of the exact entropy Hessian onto the
4587 // row's logit block in place of the EXACT entropy Hessian. The
4588 // entropy Hessian is INDEFINITE (concave directions on
4589 // long-tailed rows), which drove the per-row evidence block
4590 // non-PD and forced the downstream Faddeev–Popov deflation to
4591 // flatten data-relevant logit directions (under-identifying the
4592 // atoms). `D` is a nonnegative diagonal, hence exactly PSD and
4593 // PD-preserving like the previous Fisher surrogate, so the block
4594 // stays PD and the deflation no longer fires on the entropy
4595 // block. Unlike the Fisher metric `G = scale·(diag(a) − a aᵀ)`,
4596 // which is PSD but NOT a majorizer (`G − H_entropy` can be
4597 // indefinite — K=2, a=(0.95,0.05): G₁₁=0.0475 < H₁₁=0.0784,
4598 // #1419), `D` actually satisfies `D ⪰ H_entropy` and `D ⪰ 0`,
4599 // so it is a true MM/Loewner curvature majorizer. Because the
4600 // entropy penalty is a FIXED prior whose stationary point is set
4601 // by its (unchanged) EXACT gradient, replacing its curvature
4602 // with the majorizer only conditions the Newton step and the
4603 // Laplace normalizer's curvature operator — it does NOT move the
4604 // optimum.
4605 //
4606 // Softmax uses the REDUCED K−1 free-logit chart (the last
4607 // reference logit is fixed at 0, `assignment_coord_dim() = K−1`).
4608 // Holding z_{K-1} fixed, the reduced curvature over the free
4609 // logits 0..K−1 is exactly the top-left (K−1)×(K−1) submatrix of
4610 // the full K×K majorizer (the fixed logit contributes no
4611 // row/column to the free curvature). The criterion's `log|H|`
4612 // and the #1006 θ-adjoint differentiate this SAME `D` (see the
4613 // `row_psd_majorizer_logit_derivative` site below), so value and
4614 // adjoint stay on one exact branch.
4615 let row_logits: Vec<f64> = (0..k_atoms)
4616 .map(|k| self.assignment.logits[[row, k]])
4617 .collect();
4618 let h_dense = penalty.row_psd_majorizer(&row_logits, *scale);
4619 for ki in 0..assignment_dim {
4620 for kj in 0..assignment_dim {
4621 block.htt[[ki, kj]] += h_dense[[ki, kj]];
4622 }
4623 }
4624 } else {
4625 for free_idx in 0..assignment_dim {
4626 block.htt[[free_idx, free_idx]] +=
4627 assignment_hdiag[assignment_base + free_idx];
4628 }
4629 }
4630 }
4631
4632 // ARD on each on-atom coordinate.
4633 // For compact layout: only active atoms; coord positions use compact starts.
4634 // For dense layout: all atoms; coord positions use coord_offsets.
4635 if let Some(layout) = row_layout.as_ref() {
4636 let active = &layout.active_atoms[row];
4637 let starts = &layout.coord_starts[row];
4638 for (j, &k) in active.iter().enumerate() {
4639 let coord = &self.assignment.coords[k];
4640 let d = coord.latent_dim();
4641 if rho.log_ard[k].is_empty() {
4642 continue;
4643 }
4644 if rho.log_ard[k].len() != d {
4645 return Err(format!(
4646 "ARD rho atom {k} has len {} but atom dim is {d}",
4647 rho.log_ard[k].len()
4648 ));
4649 }
4650 let row_t = coord.row(row);
4651 let periods = &ard_axis_periods[k];
4652 for axis in 0..d {
4653 // ARD on coords is a genuine per-row prior (each row
4654 // contributes the per-axis prior energy), so it is NOT
4655 // minibatch-scaled — the per-chunk row sums already
4656 // reconstruct the full coordinate prior across a pass.
4657 // The value (`ard_value`/`loss.ard`) and the gradient
4658 // both come from the SAME `ArdAxisPrior` energy, so they
4659 // stay FD-consistent on periodic axes. The exact
4660 // von-Mises curvature `V'' = α·cos(κt)` is INDEFINITE —
4661 // it goes negative for |t| past a quarter period — so
4662 // writing it raw into the Newton/Schur `htt` diagonal
4663 // makes that PSD curvature block indefinite and the Schur
4664 // Cholesky (used both for the Newton step and the exact
4665 // log-det) fails on a non-PD pivot. Accumulate the PSD
4666 // majorizer `max(V'', 0)` instead, exactly as
4667 // `add_sae_coord_penalty` does for the registry coord
4668 // penalties: the positive part keeps `htt` PSD so the
4669 // factorization succeeds, and majorizing the curvature of
4670 // a fixed prior only damps the Newton step — it does not
4671 // move the stationary point (the gradient, which sets the
4672 // fixed point, stays the exact `V'`).
4673 let alpha =
4674 SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
4675 let prior =
4676 ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4677 block.gt[starts[j] + axis] += prior.grad;
4678 block.htt[[starts[j] + axis, starts[j] + axis]] +=
4679 prior.hess.max(0.0);
4680 }
4681 }
4682 } else {
4683 for atom_idx in 0..k_atoms {
4684 let coord = &self.assignment.coords[atom_idx];
4685 let d = coord.latent_dim();
4686 if rho.log_ard[atom_idx].is_empty() {
4687 continue;
4688 }
4689 if rho.log_ard[atom_idx].len() != d {
4690 return Err(format!(
4691 "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
4692 rho.log_ard[atom_idx].len()
4693 ));
4694 }
4695 let off = coord_offsets[atom_idx];
4696 let row_t = coord.row(row);
4697 let periods = &ard_axis_periods[atom_idx];
4698 for axis in 0..d {
4699 // PSD-majorize the (possibly negative) von-Mises curvature
4700 // into the Newton/Schur `htt` block; see the compact-layout
4701 // branch above for why `max(V'', 0)` is required to keep
4702 // `htt` PD (the exact `V'' = α·cos κt` is indefinite past a
4703 // quarter period and breaks the Schur/log-det Cholesky).
4704 let alpha = SaeManifoldRho::stable_exp_strength(
4705 rho.log_ard[atom_idx][axis],
4706 );
4707 let prior =
4708 ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4709 block.gt[off + axis] += prior.grad;
4710 block.htt[[off + axis, off + axis]] += prior.hess.max(0.0);
4711 }
4712 }
4713 }
4714
4715 // Beta gradient/Hessian — Kronecker form J_β = φᵀ ⊗ I_p.
4716 //
4717 // The per-row beta Jacobian is
4718 // J_β[out_col, beta_idx] = a_k · phi_k[basis_col] if out_col == out_col(beta_idx)
4719 // 0 otherwise
4720 // so the data-fit Gauss-Newton beta-Hessian factors as a rank-`p`
4721 // sum of outer products. We pre-compute the per-(atom, basis_col)
4722 // scalar `a_k · phi_k` once and reuse it across the `out_col`
4723 // and inner `(atom_j, basis_col2)` loops.
4724 //
4725 // Full-B rows keep the matrix-free Kronecker path below. Factored
4726 // rows write the `q_i × Σ M_k r_k` C-space cross slab directly by
4727 // folding each output-channel contribution through the atom frame,
4728 // so no `q_i × β_dim` slab is ever materialized.
4729 //
4730 // Only the row's active atoms contribute `a_phi` support and data
4731 // curvature: in a compact layout (JumpReLU gate or large-K
4732 // top-`k_active` truncation) the inactive atoms carry zero (gated)
4733 // or sub-cutoff assignment mass and are excluded — this is what
4734 // keeps both the htbeta support and the `G` accumulation
4735 // `O(k_active)` rather than `O(K)`. In the dense full-support
4736 // layout `row_active` spans all atoms.
4737 let row_active: &[usize] = match row_layout.as_ref() {
4738 Some(layout) => layout.active_atoms[row].as_slice(),
4739 None => &all_atoms_index,
4740 };
4741 // #1407: in fixed-decoder mode the β tier is not assembled at
4742 // all — leave gb_delta/g_blocks empty and kron None. htt/gt
4743 // (built above) are the only outputs the frozen-decoder step
4744 // consumes.
4745 let mut a_phi: Vec<(usize, f64)> = Vec::with_capacity(row_active.len() * 4);
4746 // Per-active-atom weighted basis row `a_k · φ_k[·]`, retained so the
4747 // data Gram blocks can be accumulated as clean per-atom-pair outer
4748 // products `(a_k φ_k) (a_{k'} φ_{k'})ᵀ`.
4749 let mut weighted_phi: Vec<(usize, Vec<f64>)> =
4750 Vec::with_capacity(row_active.len());
4751 if !fixed_decoder {
4752 for &atom_idx in row_active {
4753 let atom = &self.atoms[atom_idx];
4754 let atom_beta_off = beta_offsets[atom_idx];
4755 let m = atom.basis_size();
4756 let a_k = assignments[atom_idx];
4757 let mut wphi = Vec::with_capacity(m);
4758 for basis_col in 0..m {
4759 let phi = atom.basis_values[[row, basis_col]];
4760 // #991 design-honesty seam, β leg: the `√w_row` here pairs
4761 // with the `√w_row` on the residual (β gradient =
4762 // `a·φ · M r` ⇒ w_row) and with itself (β Gram `G` and the
4763 // htbeta Kronecker capture ⇒ w_row). `1.0` when unweighted.
4764 let w = a_k * phi * sqrt_row_w;
4765 a_phi.push((atom_beta_off + basis_col * p, w));
4766 wphi.push(w);
4767 }
4768 weighted_phi.push((atom_idx, wphi));
4769 }
4770 // β data-fit gradient `gᵦ += J_βᵀ M_n r_n`. The β-Jacobian is
4771 // `J_β = φ_nᵀ ⊗ I_p`, so `J_βᵀ M_n r_n = φ_n ⊗ (M_n r_n)` —
4772 // contract the basis weight `a·φ` against the p-space metric-applied
4773 // residual `error_metric` (= `M_n r_n`), the SAME whitening the value
4774 // path and t-block share. When not whitening, `error_metric == error`
4775 // and this is byte-identical to the historical `J_βᵀ r`.
4776 for &(beta_base_i, j_beta_i) in a_phi.iter() {
4777 if j_beta_i == 0.0 {
4778 continue;
4779 }
4780 for out_col in 0..p {
4781 gb_delta.push((
4782 beta_base_i + out_col,
4783 j_beta_i * error_metric[out_col],
4784 ));
4785 // No dense hbb write — the sparse `G ⊗ I_p` op installed
4786 // after the loop carries the data-fit GN β-Hessian.
4787 }
4788 }
4789 if frames_engaged {
4790 for &atom_idx in row_active {
4791 let atom = &self.atoms[atom_idx];
4792 let m = atom.basis_size();
4793 let a_k = assignments[atom_idx];
4794 for basis_col in 0..m {
4795 let phi = atom.basis_values[[row, basis_col]];
4796 let w = a_k * phi * sqrt_row_w;
4797 if w == 0.0 {
4798 continue;
4799 }
4800 let c_base = frame_projection.border_offsets[atom_idx]
4801 + basis_col * frame_projection.ranks[atom_idx];
4802 for c in 0..q_row {
4803 let mut hrow = block.htbeta.row_mut(c);
4804 let hrow_slice = hrow
4805 .as_slice_mut()
4806 .expect("htbeta row is contiguous");
4807 for out_col in 0..p {
4808 let value = local_jac_row[[c, out_col]] * w;
4809 frame_projection.accumulate_output_project(
4810 atom_idx, c_base, out_col, value, hrow_slice,
4811 );
4812 }
4813 }
4814 }
4815 }
4816 }
4817 // Data-fit GN β-Hessian: accumulate the channel-independent block
4818 // `G[μ_i, μ_j] += (a_k φ_k)[μ_i] (a_{k'} φ_{k'})[μ_j]` into the
4819 // sparse per-atom-pair map (the `out_col` dimension is carried by
4820 // `I_p`). Only co-occurring `(atom_i, atom_j)` pairs are touched.
4821 for ai in 0..weighted_phi.len() {
4822 let (atom_i, ref wphi_i) = weighted_phi[ai];
4823 let m_i = wphi_i.len();
4824 for aj in 0..weighted_phi.len() {
4825 let (atom_j, ref wphi_j) = weighted_phi[aj];
4826 let m_j = wphi_j.len();
4827 let blk = g_blocks
4828 .entry((atom_i, atom_j))
4829 .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4830 for li in 0..m_i {
4831 let wi = wphi_i[li];
4832 if wi == 0.0 {
4833 continue;
4834 }
4835 for lj in 0..m_j {
4836 blk[[li, lj]] += wi * wphi_j[lj];
4837 }
4838 }
4839 }
4840 }
4841 } // #1407 end `if !fixed_decoder` β-tier accumulation
4842 let (kron_a_phi, kron_jac) = if !frames_engaged && !fixed_decoder {
4843 // Flatten local_jac_row row-major into a plain Vec<f64> (q_row * p entries).
4844 let mut jac_flat = vec![0.0_f64; q_row * p];
4845 for c in 0..q_row {
4846 for j in 0..p {
4847 jac_flat[c * p + j] = local_jac_row[[c, j]];
4848 }
4849 }
4850 (Some(a_phi), Some(jac_flat))
4851 } else {
4852 (None, None)
4853 };
4854 Ok(SaeAssemblyRow {
4855 row,
4856 block,
4857 gb_delta,
4858 g_blocks,
4859 kron_a_phi,
4860 kron_jac,
4861 })
4862 }) // #1557 with_nested_parallel
4863 },
4864 )
4865 .collect::<Result<Vec<_>, String>>()?;
4866
4867 // Fold THIS chunk's rows (ascending) into the global accumulators.
4868 // The parallel collect preserves index order within the chunk and
4869 // chunks are visited in ascending `chunk_start` order, so the overall
4870 // fold order is `0,1,2,…,n-1` — identical to the former single-pass
4871 // fold. The `row == chunk_start + fold_offset_in_chunk` assert pins
4872 // that strict sequential arrival (the invariant the `kron_*`
4873 // row-aligned pushes depend on).
4874 for row_result in row_results.into_iter() {
4875 let row = row_result.row;
4876 assert_eq!(
4877 row,
4878 chunk_start + fold_offset_in_chunk,
4879 "parallel SAE row assembly returned rows out of order"
4880 );
4881 fold_offset_in_chunk += 1;
4882 for (idx, value) in row_result.gb_delta {
4883 sys.gb[idx] += value;
4884 }
4885 for ((atom_i, atom_j), data) in row_result.g_blocks {
4886 let m_i = data.nrows();
4887 let m_j = data.ncols();
4888 let blk = g_blocks
4889 .entry((atom_i, atom_j))
4890 .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4891 for li in 0..m_i {
4892 for lj in 0..m_j {
4893 blk[[li, lj]] += data[[li, lj]];
4894 }
4895 }
4896 }
4897 if !frames_engaged && !fixed_decoder {
4898 // Rows arrive in ascending order across chunks, so pushing
4899 // here yields `kron_*[row]` aligned to the row index exactly
4900 // as the single-pass `push` did.
4901 kron_a_phi.push(
4902 row_result
4903 .kron_a_phi
4904 .expect("full-B SAE row assembly must return a_phi rows"),
4905 );
4906 kron_jac.push(
4907 row_result
4908 .kron_jac
4909 .expect("full-B SAE row assembly must return local Jacobian rows"),
4910 );
4911 }
4912 sys.rows[row] = row_result.block;
4913 }
4914 chunk_start = chunk_end;
4915 }
4916 // #1407: fixed-decoder early return. The per-row htt/gt are now fully
4917 // assembled (data GN + assignment/ARD prior). Apply only the htt/gt
4918 // Riemannian projection (the decoder/β tier is intentionally absent), then
4919 // return the block-diagonal system. `fixed_decoder_step_from_rows` reads
4920 // only `rows[*].htt`/`gt` + `row_offsets`, so no β-tier object is needed.
4921 if fixed_decoder {
4922 match row_layout.as_ref() {
4923 None => {
4924 // Dense uniform-q: project htt/gt (and the 0-width htbeta, a
4925 // no-op) through the ext-coord manifold.
4926 self.apply_sae_riemannian_geometry(&mut sys);
4927 }
4928 Some(layout) => {
4929 // Compact heterogeneous-q: project each row's htt/gt at its
4930 // own ext-coord point, mirroring the full path's compact
4931 // Riemannian block (htbeta is 0-width here, so skipped).
4932 if !self.ext_coord_manifold().is_euclidean() {
4933 for row_idx in 0..n {
4934 let (manifold_i, point_i) =
4935 self.compact_row_ext_manifold_and_point(row_idx, layout);
4936 let t_i = point_i.view();
4937 let gt_e = sys.rows[row_idx].gt.clone();
4938 let htt_e = sys.rows[row_idx].htt.clone();
4939 sys.rows[row_idx].gt =
4940 manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
4941 sys.rows[row_idx].htt = manifold_i.riemannian_hessian_matrix(
4942 t_i,
4943 gt_e.view(),
4944 htt_e.view(),
4945 );
4946 }
4947 }
4948 }
4949 }
4950 if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
4951 sys.set_row_gauge_deflation(deflation);
4952 }
4953 self.last_row_layout = row_layout;
4954 self.last_frames_active = frames_engaged;
4955 return Ok(sys);
4956 }
4957 // Apply Riemannian geometry to the per-row row blocks (htt, gt) and
4958 // also to the per-row Kronecker local Jacobians stored in kron_jac.
4959 // When the SAE ext-coord manifold is non-Euclidean (any atom latent
4960 // on sphere / circle / interval), the local Jacobian rows that map
4961 // into the t-block tangent space must be projected via the per-row
4962 // tangent projector P_i. This mirrors what
4963 // `apply_riemannian_latent_geometry` does to `row.htbeta`, applied
4964 // here to the (q × p) kron_jac so the Kronecker htbeta_matvec uses
4965 // the Riemannian-projected form.
4966 // Apply Riemannian geometry only for the dense uniform-q layout. Any
4967 // compact active-set layout (JumpReLU gate or large-K softmax/IBP
4968 // truncation) has heterogeneous q_i; the Riemannian projector path
4969 // requires a uniform latent dimension. The sparse plan only engages on
4970 // Euclidean ext-coord manifolds (see `sparse_active_plan`), so skipping
4971 // the projector here is correct — there is nothing to project.
4972 match row_layout.as_ref() {
4973 None => {
4974 let raw_gt_rows: Vec<Array1<f64>> =
4975 sys.rows.iter().map(|row| row.gt.clone()).collect();
4976 self.apply_sae_riemannian_geometry(&mut sys);
4977 let manifold = self.ext_coord_manifold();
4978 if !frames_engaged && !manifold.is_euclidean() {
4979 let ext = self.ext_coord_matrix();
4980 // Project the local Jacobian columns onto the tangent space at
4981 // each row's ext-coord point. Each column `j` of the row's
4982 // (q_row × p) Jacobian is an ambient-space vector of length
4983 // `q_row`; the manifold projector acts on one such column at a
4984 // time. Working directly on the row-major `jac_flat` storage via
4985 // a single reusable `col_buf` avoids the two dense (q × p) copies
4986 // (flatten→Array2, project, unflatten→Vec) that previously fired
4987 // per row. `t_buf` still holds the row's ext-coord vector.
4988 let mut t_buf = vec![0.0_f64; q];
4989 let mut col_buf = Array1::<f64>::zeros(q);
4990 for row_idx in 0..n {
4991 let ext_row = ext.row(row_idx);
4992 for (slot, &v) in t_buf.iter_mut().zip(ext_row.iter()) {
4993 *slot = v;
4994 }
4995 let t_i = ArrayView1::from(t_buf.as_slice());
4996 let raw_gt = raw_gt_rows[row_idx].view();
4997 let jac_flat = &mut kron_jac[row_idx];
4998 let q_row = jac_flat.len() / p;
4999 for j in 0..p {
5000 for c in 0..q_row {
5001 col_buf[c] = jac_flat[c * p + j];
5002 }
5003 let projected_col = manifold.project_vector_to_gradient_tangent(
5004 t_i,
5005 raw_gt.slice(ndarray::s![..q_row]),
5006 col_buf.slice(ndarray::s![..q_row]),
5007 );
5008 for c in 0..q_row {
5009 jac_flat[c * p + j] = projected_col[c];
5010 }
5011 }
5012 }
5013 }
5014 }
5015 Some(layout) => {
5016 // Compact active-set layout (#1117 follow-up): the dense
5017 // `ext_coord_manifold()` is keyed to the uniform full-`q` block
5018 // ordering, so it cannot be applied to the heterogeneous compact
5019 // rows directly. Instead we rebuild, PER ROW, the product manifold
5020 // and ext-coord point in that row's compact column order (see
5021 // `compact_row_ext_manifold_and_point`) and apply the SAME three
5022 // per-row Riemannian operations the dense
5023 // `apply_riemannian_latent_geometry` applies — gradient tangent
5024 // projection of `gt`, the Riemannian Hessian correction of `htt`,
5025 // and the column tangent projection of `htbeta` — plus the
5026 // identical Kronecker `kron_jac` column projection. On the shared
5027 // active support this is byte-identical to slicing the dense
5028 // product manifold, so engaging the sparse plan on a non-Euclidean
5029 // ext manifold is now correct (the former
5030 // `is_euclidean()`-only guard in `sparse_active_plan` is lifted).
5031 //
5032 // Euclidean ext manifolds still skip all of this (every
5033 // per-row manifold is a product of Euclidean parts whose
5034 // projector is the identity); we early-out so those rows stay
5035 // byte-for-byte the historical compact path.
5036 if !self.ext_coord_manifold().is_euclidean() {
5037 for row_idx in 0..n {
5038 let (manifold_i, point_i) =
5039 self.compact_row_ext_manifold_and_point(row_idx, layout);
5040 let t_i = point_i.view();
5041 // gt / htt / htbeta on the compact ArrowRowBlock, exactly
5042 // as `apply_riemannian_latent_geometry` does for dense
5043 // uniform-q rows.
5044 let gt_e = sys.rows[row_idx].gt.clone();
5045 let htt_e = sys.rows[row_idx].htt.clone();
5046 sys.rows[row_idx].gt =
5047 manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
5048 sys.rows[row_idx].htt =
5049 manifold_i.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
5050 // #1406: only the frames path holds a real dense `htbeta`
5051 // slab; the matrix-free path leaves it 0-width (the
5052 // cross-block geometry is applied to `kron_jac` below), so
5053 // projecting a zero-column matrix is a no-op we skip.
5054 if frames_engaged {
5055 let htbeta_e = sys.rows[row_idx].htbeta.clone();
5056 sys.rows[row_idx].htbeta = manifold_i
5057 .project_matrix_columns_to_gradient_tangent(
5058 t_i,
5059 gt_e.view(),
5060 htbeta_e.view(),
5061 );
5062 }
5063 // Kronecker local-Jacobian column projection (full-B path
5064 // only), using the SAME pre-projection gradient `gt_e` so
5065 // the cross-block geometry matches the dense branch.
5066 if !frames_engaged {
5067 let jac_flat = &mut kron_jac[row_idx];
5068 let q_row = jac_flat.len() / p;
5069 let mut col_buf = Array1::<f64>::zeros(q_row);
5070 for j in 0..p {
5071 for c in 0..q_row {
5072 col_buf[c] = jac_flat[c * p + j];
5073 }
5074 let projected_col = manifold_i.project_vector_to_gradient_tangent(
5075 t_i,
5076 gt_e.view(),
5077 col_buf.view(),
5078 );
5079 for c in 0..q_row {
5080 jac_flat[c * p + j] = projected_col[c];
5081 }
5082 }
5083 }
5084 }
5085 }
5086 }
5087 }
5088 // Build and install the full-B Kronecker htbeta_matvec.
5089 //
5090 // `SaeKroneckerRows` holds per-row `(a_phi, local_jac)` and implements
5091 // the cross-block operator without ever materialising the dense
5092 // `(q × K·p)` slab. The cross-block factorises as `H_tβ = L · J_β`,
5093 // where `J_β = φᵀ ⊗ I_p` projects a length-`K` β vector onto the
5094 // `p`-dimensional decoded output space (`apply_jbeta`) and `L_i` is
5095 // the per-row `(q_i × p)` assignment+coordinate Jacobian that lifts
5096 // that p-vector into the row's `q_i`-dim tangent block (`apply_l`).
5097 // Both factors are required: the contract of `set_row_htbeta_operator`
5098 // is `out.len() == d` (= `q_i`), so writing `apply_jbeta`'s p-vector
5099 // output directly into a length-`q_i` buffer overflows whenever
5100 // `p > q_i` (the common case once `p` reflects real feature width).
5101 // Symmetric for the transpose: `H_βt = J_βᵀ · Lᵀ`, so apply `Lᵀ`
5102 // first to map the q_i-vector back to p-space, then scatter through
5103 // the support.
5104 // #1017/#1026: the legacy full-B device PCG assumes `G ⊗ I_p`, while
5105 // framed systems carry `G_ij ⊗ W_ij` with rank-r atom blocks. Feeding a
5106 // framed system to that kernel would silently return the wrong Newton
5107 // step. Framed device PCG therefore needs the dedicated factored kernel.
5108 // #1033 large-n: the per-row support `kron_a_phi` and local Jacobians
5109 // `kron_jac` are consumed by BOTH the host matrix-free row operator
5110 // (`SaeKroneckerRows`) and the solver's `DeviceSaePcgData`. Previously
5111 // each took its own full `O(n·q·p)` / `O(n·k_active)` clone, so the
5112 // always-resident footprint of the CPU non-frames path carried TWO copies
5113 // of the dominant Jacobian slab. Promote each to a single `Arc<[…]>` once
5114 // and hand both consumers a refcount bump (`O(1)`) — the backing
5115 // allocation is shared, halving the resident per-row Jacobian memory.
5116 // Reads are identical (`&arc[row]`, `.len()`), so the assembled system and
5117 // every matvec are bit-for-bit unchanged.
5118 let device_rows = if frames_engaged {
5119 None
5120 } else {
5121 let a_phi_shared: Arc<[Vec<(usize, f64)>]> =
5122 Arc::from(std::mem::take(&mut kron_a_phi).into_boxed_slice());
5123 let jac_shared: Arc<[Vec<f64>]> =
5124 Arc::from(std::mem::take(&mut kron_jac).into_boxed_slice());
5125 Some((a_phi_shared, jac_shared))
5126 };
5127 if !frames_engaged {
5128 let (a_phi_shared, jac_shared) = device_rows
5129 .clone()
5130 .expect("non-frames path always populates device_rows");
5131 let kron = Arc::new(SaeKroneckerRows::new(p, a_phi_shared, jac_shared));
5132 let kron_t = Arc::clone(&kron);
5133 let p_dim = p;
5134 sys.set_row_htbeta_operator(
5135 move |row_idx, x, out| {
5136 // out = L_i · (J_β · x). Allocate a length-p scratch buffer
5137 // for the intermediate decoded-output vector; both factors
5138 // overwrite their output buffers (`apply_jbeta` zeroes
5139 // before accumulating, `apply_l` writes per-row), so no
5140 // pre-zeroing of `u_p`/`out` is needed.
5141 let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5142 let mut u_p = vec![0.0_f64; p_dim];
5143 if let Some(xs) = x.as_slice() {
5144 kron.apply_jbeta(row_idx, xs, &mut u_p);
5145 } else {
5146 let x_vec: Vec<f64> = x.iter().copied().collect();
5147 kron.apply_jbeta(row_idx, &x_vec, &mut u_p);
5148 }
5149 kron.apply_l(row_idx, &u_p, out_slice);
5150 },
5151 move |row_idx, v, out| {
5152 // out += J_βᵀ · (Lᵀ · v). `apply_l_t` accumulates into a
5153 // zero-initialised length-p buffer to produce the p-vector
5154 // `Lᵀ v`; `scatter_jbeta_t` then adds φ_i[s] · u_p[j] into
5155 // the length-K β accumulator at each active `(s, j)`.
5156 let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5157 let mut u_p = vec![0.0_f64; p_dim];
5158 if let Some(vs) = v.as_slice() {
5159 kron_t.apply_l_t(row_idx, vs, &mut u_p);
5160 } else {
5161 let v_vec: Vec<f64> = v.iter().copied().collect();
5162 kron_t.apply_l_t(row_idx, &v_vec, &mut u_p);
5163 }
5164 kron_t.scatter_jbeta_t(row_idx, &u_p, out_slice);
5165 },
5166 );
5167 }
5168 let mut beta_penalty_assembly = SaeBetaPenaltyAssembly::default();
5169 let factored_row_projection = if frames_engaged && analytic_penalties.is_some() {
5170 Some(&frame_projection)
5171 } else {
5172 None
5173 };
5174 if let Some(registry) = analytic_penalties {
5175 // Upfront validation: refuse penalty kinds the SAE row layout
5176 // cannot host, and refuse mixed-d row-block configurations.
5177 // This makes the dispatch loop below total — no runtime
5178 // "unsupported penalty" fallthrough, no K-gating.
5179 self.validate_analytic_penalty_registry(registry)
5180 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5181 beta_penalty_assembly = self
5182 .add_sae_analytic_penalty_contributions(
5183 &mut sys,
5184 registry,
5185 penalty_scale,
5186 row_layout.as_ref(),
5187 dense_beta_curvature,
5188 factored_row_projection,
5189 )
5190 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5191 }
5192 // #1026 — decoder repulsion (collinearity-gated, registry-independent):
5193 // accumulate into the full-`B` β-tier here, BEFORE the frame transform,
5194 // so a framed system carries it identically to the analytic β penalties.
5195 // No-op unless two atoms are near-collinear (the frozen gate is `None`).
5196 if self.add_sae_decoder_repulsion(&mut sys, penalty_scale, dense_beta_curvature) {
5197 beta_penalty_assembly.record_curvature(dense_beta_curvature);
5198 }
5199 // #1026/#1522 — interior-point collapse-prevention barriers. The amplitude
5200 // barrier supplies the OUTWARD radial force at the zero-decoder collapse
5201 // point (the principal failure state the threshold repulsion skips), and
5202 // the separation barrier supplies the alignment-divergent separating
5203 // curvature on normalized shapes weighted by coactivation. Both accumulate
5204 // into the full-`B` β-tier here, BEFORE the frame transform, so a framed
5205 // system carries them identically to the analytic β penalties.
5206 // #1610 — on the dense path the barrier's Levenberg majorizer scatters
5207 // onto `sys.hbb`; on the matrix-free / framed production path `sys.hbb` is
5208 // unused, so the barrier hands back a per-atom scalar ridge which we fold
5209 // into `smooth_scaled_s` (the single source for the CPU composite penalty
5210 // op AND the device smooth blocks), restoring the collapse-prevention
5211 // curvature the operator was silently dropping there.
5212 let mut sep_atom_curv = vec![0.0_f64; self.atoms.len()];
5213 if self.add_sae_separation_barrier(
5214 &mut sys,
5215 penalty_scale,
5216 dense_beta_curvature,
5217 &mut sep_atom_curv,
5218 ) {
5219 if dense_beta_curvature {
5220 beta_penalty_assembly.record_curvature(true);
5221 } else {
5222 // Fold the per-atom majorizer `lev_k·I_{M_k}` into the smooth
5223 // penalty factor `λ S_k`. With `⊗ I_p` (full-`B`) or `⊗ I_{r_k}`
5224 // (factored, `U_kᵀU_k = I`) this is exactly the `lev_k·I` block
5225 // diagonal the dense path writes — and it now flows through the
5226 // structured penalty op and the device smooth blocks. No
5227 // `deferred_factored` mark: the curvature is in the smooth op, not
5228 // a deferred dense block, so the device path stays engaged.
5229 for atom_idx in 0..self.atoms.len() {
5230 let c = sep_atom_curv[atom_idx];
5231 if c > 0.0 {
5232 let m = smooth_scaled_s[atom_idx].nrows();
5233 for i in 0..m {
5234 smooth_scaled_s[atom_idx][[i, i]] += c;
5235 }
5236 smooth_ops[atom_idx] = Arc::new(IdentityRightKroneckerPenaltyOp {
5237 factor_a: smooth_scaled_s[atom_idx].clone(),
5238 p,
5239 global_offset: beta_offsets[atom_idx],
5240 k: beta_dim,
5241 });
5242 }
5243 }
5244 }
5245 }
5246 if frames_engaged {
5247 // ── #972 / #977 T1 — FACTORED β-tier transform ──────────────────
5248 //
5249 // The entire β-tier above was assembled in the full-`B` (p-wide)
5250 // layout: `sys.gb` is `g_B` (length `beta_dim`), `sys.hbb` carries
5251 // any analytic Beta-tier penalty, and `g_blocks` is the
5252 // FRAME-INDEPENDENT basis Gram. We now rebuild the β-tier in the
5253 // factored coordinate space `C` (width `factored_border_dim`), the
5254 // full-`B` system sandwiched by `Φ = blkdiag(I_{M_k} ⊗ U_k)`:
5255 // * gradient `g_C = Φᵀ g_B` (per atom `(g_B U_k)`),
5256 // * data H `Φᵀ(G⊗I_p)Φ = G_{ij}⊗(U_iᵀU_j)`,
5257 // * smooth `λ S_k ⊗ I_{r_k}` (since `U_kᵀU_k = I`),
5258 // * analytic `Φᵀ hbb Φ` (dense, only if written).
5259 // Un-framed atoms ride the `r_k = p, U_k = I_p` identity special case.
5260 let off_c = &frame_projection.border_offsets;
5261 let ranks = &frame_projection.ranks;
5262 let basis_sizes = &frame_projection.basis_sizes;
5263 let border_dim = frame_projection.border_dim();
5264 let gb_c = frame_projection.project_border_vec(sys.gb.view());
5265
5266 // Data β-Hessian: `G_{ij} ⊗ W_{ij}` with `W_{ij} = U_iᵀU_j`. The
5267 // basis Gram `g_blocks` is unchanged; only the output factor is the
5268 // per-pair frame overlap (`I_{r_k}` within a framed atom, `I_p` for
5269 // un-framed).
5270 let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::with_capacity(g_blocks.len());
5271 for ((atom_i, atom_j), data) in g_blocks.into_iter() {
5272 if data.iter().all(|&v| v == 0.0) {
5273 continue;
5274 }
5275 // `W_{ij} = U_iᵀ U_j` from the precomputed per-atom frames.
5276 let w = self.frame_cross_factor(atom_i, atom_j);
5277 frame_blocks.push(FactoredFrameGBlock {
5278 atom_i,
5279 atom_j,
5280 g: data,
5281 w,
5282 });
5283 }
5284 // #1017/#1026 — snapshot the factored data-fit blocks for the
5285 // frames-engaged device PCG BEFORE `FactoredFrameKroneckerOp::new`
5286 // consumes them. Cheap clone (co-occurring blocks only).
5287 let device_frame_blocks = frame_blocks.clone();
5288 let data_op =
5289 FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks)?;
5290
5291 // Smooth penalty in factored space: `λ S_k ⊗ I_{r_k}` at `off_C[k]`.
5292 let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len() + 2);
5293 for k in 0..self.atoms.len() {
5294 let r = ranks[k];
5295 ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
5296 factor_a: smooth_scaled_s[k].clone(),
5297 p: r,
5298 global_offset: off_c[k],
5299 k: border_dim,
5300 }));
5301 }
5302 ops.push(Arc::new(data_op));
5303 // Analytic Beta-tier penalty: project the dense full-`B` `hbb` block
5304 // `Φᵀ hbb Φ` into the factored space. Only present when a Beta-tier
5305 // penalty actually wrote `hbb` (else `hbb` is all-zero and the dense
5306 // `(border_dim)²` op is skipped entirely, exactly as full-`B`).
5307 if beta_penalty_assembly.dense_written {
5308 let hbb_c =
5309 self.project_dense_penalty_to_factored(sys.hbb.view(), &frame_projection);
5310 ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5311 } else if beta_penalty_assembly.deferred_factored {
5312 // Registry Beta-tier curvature deferred to factored-space probing.
5313 // The registry may be absent when `deferred_factored` was set ONLY
5314 // by the frozen-gate decoder repulsion (which is
5315 // registry-independent), so start from a zero factored block in
5316 // that case instead of unwrapping.
5317 let mut hbb_c = match analytic_penalties {
5318 Some(registry) => self.build_factored_beta_penalty_curvature(
5319 registry,
5320 penalty_scale,
5321 &frame_projection,
5322 ),
5323 None => Array2::<f64>::zeros((
5324 frame_projection.border_dim(),
5325 frame_projection.border_dim(),
5326 )),
5327 };
5328 // #1610 — the frozen-gate decoder repulsion's PSD majorizer was
5329 // dropped on this matrix-free/framed path (only its gradient was
5330 // applied). Project it into the factored block via the same
5331 // `psd_majorizer_hvp` + frame-projection probe pattern the registry
5332 // DecoderIncoherence uses, so the collapse-prevention curvature
5333 // reaches the operator here too. No-op when no repulsion is active.
5334 self.add_factored_repulsion_curvature(
5335 &mut hbb_c,
5336 penalty_scale,
5337 &frame_projection,
5338 );
5339 ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5340 }
5341
5342 // Re-point the system's β-tier to the factored width. The t-tier
5343 // (per-row `htt`, `gt`) is frame-independent and untouched; row
5344 // cross-block slabs were allocated and assembled directly in
5345 // factored coordinates, so analytic row supplements and data-fit
5346 // cross terms already share shape `(q_i × factored_border_dim)`.
5347 sys.k = border_dim;
5348 sys.gb = gb_c;
5349 self.reclaim_border_hbb_workspace(&mut sys);
5350 // Factored per-atom block ranges for the block-Jacobi Schur
5351 // preconditioner: `[off_C[k] .. off_C[k] + M_k·r_k]`.
5352 let mut block_ranges: Vec<std::ops::Range<usize>> =
5353 Vec::with_capacity(self.atoms.len());
5354 for k in 0..self.atoms.len() {
5355 let start = off_c[k];
5356 block_ranges.push(start..start + basis_sizes[k] * ranks[k]);
5357 }
5358 sys.set_block_offsets(Arc::from(block_ranges.into_boxed_slice()));
5359 sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: border_dim, ops }));
5360 // #1017/#1026 — install the frames-engaged device SAE PCG data. Skipped
5361 // (CPU fallback) when a dense analytic Beta-tier penalty fired (the
5362 // device kernel does not model that extra dense term). Builder:
5363 // `crate::frames::build_framed_device_sae_data`.
5364 let has_dense_beta_penalty =
5365 beta_penalty_assembly.dense_written || beta_penalty_assembly.deferred_factored;
5366 if !has_dense_beta_penalty {
5367 let device = crate::frames::build_framed_device_sae_data(
5368 crate::frames::FramedDeviceArgs {
5369 p,
5370 border_dim,
5371 border_offsets: off_c.as_slice(),
5372 ranks: ranks.as_slice(),
5373 basis_sizes: basis_sizes.as_slice(),
5374 smooth_scaled_s: &smooth_scaled_s,
5375 frame_blocks: device_frame_blocks,
5376 rows: &sys.rows,
5377 },
5378 );
5379 sys.set_device_sae_pcg_data(device);
5380 }
5381 } else {
5382 let (device_a_phi, device_local_jac) =
5383 device_rows.expect("full-beta SAE PCG rows are cloned before row operator install");
5384 // Wire per-atom β block ranges so the Jacobi preconditioner builds one
5385 // dense Schur sub-block per atom (block-Jacobi) instead of scalar-diagonal
5386 // inversion. Each atom's decoder coefficients form a natural block:
5387 // `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
5388 sys.set_block_offsets(self.beta_block_offsets());
5389 // Install the composite BetaPenaltyOp (#296): smoothness contributions
5390 // via per-atom KroneckerPenaltyOp (avoid dense K×K materialisation), the
5391 // data-fit Gauss-Newton β-Hessian as the structured `G ⊗ I_p`
5392 // SparseBlockKroneckerPenaltyOp (block-sparse over co-occurring
5393 // `(atom, atom')` pairs, block-diagonal across the `p` output channels,
5394 // identical per channel), plus — only when a Beta-tier analytic penalty
5395 // was written — the dense `sys.hbb` residual contribution. When no beta
5396 // penalty fired, `sys.hbb` is all-zero and the dense `(K·p)²` operator
5397 // is skipped entirely. The sparse data op tracks only the active-atom
5398 // couplings, so its storage and matvec cost scale with `k_active`, not
5399 // `K`, at `K = 100K`.
5400 // Convert the per-atom-pair coupling map into `SparseGBlock`s keyed
5401 // by μ-space offsets. Empty blocks (no co-occurrence) are simply
5402 // absent from the map.
5403 let g_sparse_blocks: Vec<SparseGBlock> = g_blocks
5404 .into_iter()
5405 .filter_map(|((atom_i, atom_j), data)| {
5406 if data.iter().all(|&v| v == 0.0) {
5407 None
5408 } else {
5409 Some(SparseGBlock {
5410 row_off: mu_offsets[atom_i],
5411 col_off: mu_offsets[atom_j],
5412 data,
5413 })
5414 }
5415 })
5416 .collect();
5417 let device_smooth_blocks = smooth_scaled_s
5418 .iter()
5419 .enumerate()
5420 .map(|(atom_idx, factor_a)| {
5421 // #1117 — rank deficiency is removed at the basis layer, so the
5422 // device PCG smooth block is just `λ S_k ⊗ I_p` (full-rank
5423 // design); no data-null deflation is folded in here.
5424 DeviceSaeSmoothBlock {
5425 global_offset: beta_offsets[atom_idx],
5426 factor_a: factor_a.clone(),
5427 }
5428 })
5429 .collect();
5430 sys.set_device_sae_pcg_data(DeviceSaePcgData {
5431 p,
5432 beta_dim,
5433 a_phi: device_a_phi,
5434 local_jac: device_local_jac,
5435 smooth_blocks: device_smooth_blocks,
5436 sparse_g_blocks: g_sparse_blocks.clone(),
5437 frame: None,
5438 });
5439 let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = smooth_ops;
5440 ops.push(Arc::new(SparseBlockKroneckerPenaltyOp {
5441 p,
5442 dim_a: m_total,
5443 k: beta_dim,
5444 blocks: g_sparse_blocks,
5445 }));
5446 if beta_penalty_assembly.dense_written {
5447 ops.push(Arc::new(DensePenaltyOp(sys.hbb.clone())));
5448 }
5449 sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: beta_dim, ops }));
5450 self.reclaim_border_hbb_workspace(&mut sys);
5451 }
5452 if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
5453 sys.set_row_gauge_deflation(deflation);
5454 }
5455 // #1038 IBP cross-row Woodbury source. The exact IBP Hessian has the
5456 // per-column rank-one cross-row block `H_(i,k),(j,k) = w·s'_k·z'_ik·z'_jk`
5457 // (for ALL `i,j`, including the `i=j` self term) that couples DISTINCT
5458 // latent rows through the shared empirical mass `M_k = Σ_i z_ik`. The
5459 // assembled row-block-diagonal `htt` already carries the `i=j` self term
5460 // `w·s'_k·z'_ik²` — it is the first summand of `assignment_hdiag`'s
5461 // `hessian_diag` value `w·(score_derivative·z_jac² + score·c_ik)` written
5462 // at the logit diagonal above. So the consumer (`solver::arrow_schur`,
5463 // #1038 `IbpCrossRowSource`/`CrossRowWoodbury`) DOWNDATES exactly
5464 // `Σ_k d_k·z'_ik²` (`self_term_downdate`) to recover the NO-SELF base
5465 // `H₀'`, then re-adds the FULL rank-one `U D Uᵀ` via the determinant
5466 // lemma — so value, the evidence log-determinant, and the θ/ρ-adjoint all
5467 // differentiate the SAME `H_full = H₀' + U D Uᵀ`.
5468 //
5469 // The source is built from the SAME `ibp_assignment_third_channels`
5470 // operator the #1006 θ-adjoint consumes:
5471 // * `d[k] = cross_row_d[k] = w·s'_k = w·score_derivative_k` (the column
5472 // `D`-coefficient — NOT sign-definite, hence the consumer's
5473 // indefinite-capacitance LU);
5474 // * `entries[(i,k)] = (global_t_index, k, z'_ik)` with `z'_ik =
5475 // z_jac[i·K + k]`. For the DENSE layout (`assignment_coord_dim() = K`,
5476 // `last_row_layout = None`) atom `k`'s logit slot is local position `k`
5477 // of row `i`'s block, so `global_t_index = sys.row_offsets[i] + k`. For
5478 // the COMPACT layout (#1420) only the row's active atoms are
5479 // coordinates and atom `k` lives at local position `pos` of
5480 // `active_atoms[row]`, so `global_t_index = sys.row_offsets[i] + pos`.
5481 // Both pin the `U`-column convention bit-for-bit to the consumer's
5482 // `ibp_logit_sites`/`row_vars_for_cache_row` slot mapping.
5483 if let Some(channels) = ibp_assignment_third_channels(&self.assignment, rho)? {
5484 let mut entries: Vec<(usize, usize, f64)> = Vec::with_capacity(n * k_atoms);
5485 for row in 0..n {
5486 let start = row * k_atoms;
5487 let g_base = sys.row_offsets[row];
5488 match row_layout.as_ref() {
5489 // #1420: compact layout — the local logit slot `pos` (not the
5490 // global atom index `k`) is the t-coordinate. Atom `k`'s logit
5491 // lives at local position `pos` of `active_atoms[row]`, so emit
5492 // `(g_base + pos, atom, z_jac[row·K + atom])` for the active set
5493 // only. Using `g_base + k` would attach atom `k`'s derivative to
5494 // the wrong slot (and run out of range for compact rows),
5495 // violating the `IbpCrossRowSource` contract.
5496 Some(layout) => {
5497 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
5498 let z_prime = channels.z_jac[start + atom];
5499 entries.push((g_base + pos, atom, z_prime));
5500 }
5501 }
5502 // Dense layout: atom `k`'s logit slot is local position `k`.
5503 None => {
5504 for k in 0..k_atoms {
5505 let z_prime = channels.z_jac[start + k];
5506 entries.push((g_base + k, k, z_prime));
5507 }
5508 }
5509 }
5510 }
5511 let source = IbpCrossRowSource {
5512 r: k_atoms,
5513 d: channels.cross_row_d.clone(),
5514 entries,
5515 };
5516 sys.set_ibp_cross_row_source(source);
5517 }
5518 // Store the active-set layout for `apply_newton_step`.
5519 self.last_row_layout = row_layout;
5520 // Record whether `delta_beta` from this system is a factored ΔC (needs a
5521 // frame lift) or a full-`B` ΔB. Read by `apply_newton_step_impl`.
5522 self.last_frames_active = frames_engaged;
5523 Ok(sys)
5524 }
5525
5526 /// Project a dense full-`B` Beta-tier penalty Hessian `hbb` (`beta_dim ×
5527 /// beta_dim`, the analytic `∂²P/∂B∂B` block) into the factored coordinate
5528 /// space `Φᵀ hbb Φ` (`border_dim × border_dim`) for the #972 / #977 T1
5529 /// frame transform. `Φ = blkdiag(I_{M_k} ⊗ U_k)` maps C-space → B-space, so
5530 /// the projected block contracts both index legs through the per-atom frames.
5531 ///
5532 /// The projection is done in two passes to stay `O(beta_dim · border_dim +
5533 /// border_dim²)` instead of forming the dense `Φ` explicitly: first
5534 /// `T = hbb · Φ` (right multiply, columns fold `U`), then `Φᵀ · T` (left
5535 /// multiply, rows fold `U`). Analytic Beta-tier penalties are rare and small,
5536 /// so this only fires when one is actually installed.
5537 pub(crate) fn project_dense_penalty_to_factored(
5538 &self,
5539 hbb: ArrayView2<'_, f64>,
5540 projection: &FrameProjection,
5541 ) -> Array2<f64> {
5542 projection.project_block(hbb)
5543 }
5544
5545 pub(crate) fn build_factored_beta_penalty_curvature(
5546 &self,
5547 registry: &AnalyticPenaltyRegistry,
5548 penalty_scale: f64,
5549 projection: &FrameProjection,
5550 ) -> Array2<f64> {
5551 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
5552 let layout = registry.rho_layout();
5553 let target_beta = self.flatten_beta();
5554 let mut hbb_c = Array2::<f64>::zeros((projection.border_dim(), projection.border_dim()));
5555 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
5556 if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
5557 continue;
5558 }
5559 let rho_local = rho_global.slice(s![rho_slice.clone()]);
5560 match tier {
5561 PenaltyTier::Psi if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) => {
5562 self.add_factored_beta_penalty_curvature_for_penalty(
5563 &mut hbb_c,
5564 penalty,
5565 target_beta.view(),
5566 rho_local,
5567 penalty_scale,
5568 projection,
5569 );
5570 }
5571 PenaltyTier::Beta => {
5572 self.add_factored_beta_penalty_curvature_for_penalty(
5573 &mut hbb_c,
5574 penalty,
5575 target_beta.view(),
5576 rho_local,
5577 penalty_scale,
5578 projection,
5579 );
5580 }
5581 _ => {}
5582 }
5583 }
5584 hbb_c
5585 }
5586
5587 pub(crate) fn add_factored_beta_penalty_curvature_for_penalty(
5588 &self,
5589 hbb_c: &mut Array2<f64>,
5590 penalty: &AnalyticPenaltyKind,
5591 target_beta: ArrayView1<'_, f64>,
5592 rho_local: ArrayView1<'_, f64>,
5593 penalty_scale: f64,
5594 projection: &FrameProjection,
5595 ) {
5596 let p = self.output_dim();
5597 if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
5598 let Some(per_fit) = self.live_decoder_incoherence_penalty(base) else {
5599 return;
5600 };
5601 let beta_dim = self.beta_dim();
5602 let mut probe = Array1::<f64>::zeros(beta_dim);
5603 for k in 0..self.atoms.len() {
5604 for basis_col in 0..projection.basis_sizes[k] {
5605 for frame_col in 0..projection.ranks[k] {
5606 probe.fill(0.0);
5607 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5608 let col = projection.border_offsets[k]
5609 + basis_col * projection.ranks[k]
5610 + frame_col;
5611 let hv = per_fit.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5612 projection
5613 .project_border_vec(hv.view())
5614 .iter()
5615 .enumerate()
5616 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5617 }
5618 }
5619 }
5620 return;
5621 }
5622 if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
5623 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
5624 let atom_idx = projection
5625 .beta_offsets
5626 .iter()
5627 .position(|&offset| offset == start)
5628 .expect("live mechanism-sparsity offset must match an SAE atom");
5629 let block_len = end - start;
5630 let mut local_penalty = per_atom.clone();
5631 local_penalty.target = PsiSlice {
5632 range: 0..block_len,
5633 latent_dim: Some(projection.basis_sizes[atom_idx]),
5634 };
5635 let block = target_beta.slice(s![start..end]);
5636 let mut probe = Array1::<f64>::zeros(block_len);
5637 for basis_col in 0..projection.basis_sizes[atom_idx] {
5638 for frame_col in 0..projection.ranks[atom_idx] {
5639 probe.fill(0.0);
5640 projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5641 let col = projection.border_offsets[atom_idx]
5642 + basis_col * projection.ranks[atom_idx]
5643 + frame_col;
5644 let hv = local_penalty.psd_majorizer_hvp(block, rho_local, probe.view());
5645 projection.project_local_atom_vec_into(
5646 atom_idx,
5647 hv.view(),
5648 hbb_c.column_mut(col),
5649 penalty_scale,
5650 );
5651 }
5652 }
5653 }
5654 return;
5655 }
5656 if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
5657 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
5658 let atom_idx = projection
5659 .beta_offsets
5660 .iter()
5661 .position(|&offset| offset == start)
5662 .expect("live nuclear-norm offset must match an SAE atom");
5663 let block = target_beta.slice(s![start..end]);
5664 let block_len = end - start;
5665 let mut probe = Array1::<f64>::zeros(block_len);
5666 for basis_col in 0..projection.basis_sizes[atom_idx] {
5667 for frame_col in 0..projection.ranks[atom_idx] {
5668 probe.fill(0.0);
5669 projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5670 let col = projection.border_offsets[atom_idx]
5671 + basis_col * projection.ranks[atom_idx]
5672 + frame_col;
5673 let hv = per_atom.psd_majorizer_hvp(block, rho_local, probe.view());
5674 projection.project_local_atom_vec_into(
5675 atom_idx,
5676 hv.view(),
5677 hbb_c.column_mut(col),
5678 penalty_scale,
5679 );
5680 }
5681 }
5682 }
5683 return;
5684 }
5685 let beta_dim = self.beta_dim();
5686 let mut probe = Array1::<f64>::zeros(beta_dim);
5687 for k in 0..self.atoms.len() {
5688 for basis_col in 0..projection.basis_sizes[k] {
5689 for frame_col in 0..projection.ranks[k] {
5690 probe.fill(0.0);
5691 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5692 let col =
5693 projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5694 let hv = penalty.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5695 projection
5696 .project_border_vec(hv.view())
5697 .iter()
5698 .enumerate()
5699 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5700 }
5701 }
5702 }
5703 assert_eq!(p, self.output_dim());
5704 }
5705
5706 /// #1610 — project the frozen-gate decoder-repulsion PSD majorizer into the
5707 /// factored β block `hbb_c`. Mirrors the `DecoderIncoherence` arm of
5708 /// [`Self::add_factored_beta_penalty_curvature_for_penalty`] but sources the
5709 /// penalty from [`Self::live_decoder_repulsion_penalty`] (registry-independent,
5710 /// collinearity-gated), so the repulsion curvature reaches the operator on the
5711 /// matrix-free/framed path where the dense `sys.hbb` write is unused. No-op
5712 /// when no repulsion is active.
5713 pub(crate) fn add_factored_repulsion_curvature(
5714 &self,
5715 hbb_c: &mut Array2<f64>,
5716 penalty_scale: f64,
5717 projection: &FrameProjection,
5718 ) {
5719 let Some(per_fit) = self.live_decoder_repulsion_penalty() else {
5720 return;
5721 };
5722 let beta_dim = self.beta_dim();
5723 let target_beta = self.flatten_beta();
5724 // The repulsion penalty is non-learnable; its strength is already folded
5725 // into the frozen gate (see `live_decoder_repulsion_penalty`), so the rho
5726 // slice is empty/inert.
5727 let rho_local = Array1::<f64>::zeros(0);
5728 let mut probe = Array1::<f64>::zeros(beta_dim);
5729 for k in 0..self.atoms.len() {
5730 for basis_col in 0..projection.basis_sizes[k] {
5731 for frame_col in 0..projection.ranks[k] {
5732 probe.fill(0.0);
5733 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5734 let col =
5735 projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5736 let hv =
5737 per_fit.psd_majorizer_hvp(target_beta.view(), rho_local.view(), probe.view());
5738 projection
5739 .project_border_vec(hv.view())
5740 .iter()
5741 .enumerate()
5742 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5743 }
5744 }
5745 }
5746 }
5747
5748 pub(crate) fn ext_coord_matrix(&self) -> Array2<f64> {
5749 let n = self.n_obs();
5750 let q = self.assignment.row_block_dim();
5751 let flat = self.assignment.flatten_ext_coords();
5752 let mut out = Array2::<f64>::zeros((n, q));
5753 for row in 0..n {
5754 for col in 0..q {
5755 out[[row, col]] = flat[row * q + col];
5756 }
5757 }
5758 out
5759 }
5760
5761 pub(crate) fn ext_coord_manifold(&self) -> LatentManifold {
5762 let mut parts = Vec::with_capacity(self.assignment.row_block_dim());
5763 for _ in 0..self.assignment.assignment_coord_dim() {
5764 parts.push(LatentManifold::Euclidean);
5765 }
5766 let mut any_constrained = false;
5767 for coord in &self.assignment.coords {
5768 if coord.manifold().is_euclidean() {
5769 for _ in 0..coord.latent_dim() {
5770 parts.push(LatentManifold::Euclidean);
5771 }
5772 } else {
5773 any_constrained = true;
5774 parts.push(coord.manifold().clone());
5775 }
5776 }
5777 if any_constrained {
5778 LatentManifold::Product(parts)
5779 } else {
5780 LatentManifold::Euclidean
5781 }
5782 }
5783
5784 pub(crate) fn apply_sae_riemannian_geometry(&self, sys: &mut ArrowSchurSystem) {
5785 let manifold = self.ext_coord_manifold();
5786 if manifold.is_euclidean() {
5787 return;
5788 }
5789 let ext = self.ext_coord_matrix();
5790 let latent =
5791 LatentCoordValues::from_matrix_with_manifold(ext.view(), LatentIdMode::None, manifold);
5792 sys.apply_riemannian_latent_geometry(&latent);
5793 }
5794
5795 /// Build the compact-layout ext-coord product manifold and point for one row.
5796 ///
5797 /// The dense `ext_coord_manifold()` is keyed to the full-`q` block ordering
5798 /// `[assignment parts (all Euclidean for IBP-MAP / JumpReLU), then per-atom
5799 /// coord blocks in atom order]`. A compact active-set row instead lays its
5800 /// `q_active` columns out as `[one Euclidean logit slot per active atom,
5801 /// then each active atom's coord block in `active` order]` (see
5802 /// [`SaeRowLayout::from_active_atoms`] / `coord_starts`). To reuse the exact
5803 /// per-row Riemannian projector on the compact block we rebuild a product
5804 /// manifold and the matching ext-coord point in that compact order: the
5805 /// `active.len()` logit slots are `Euclidean` (the assignment channel is
5806 /// always Euclidean for the modes that engage sparsity — `assignment_coord_dim
5807 /// == k_atoms`), and each active atom contributes its own coordinate
5808 /// manifold. On the shared active support this is byte-identical to slicing
5809 /// the dense full-`q` product manifold, so the compact projection matches the
5810 /// dense path exactly — it only drops the inactive atoms' (negligible-mass)
5811 /// coordinate blocks the compact layout already excludes from curvature.
5812 ///
5813 /// Returns `(manifold, t_compact)` where `t_compact` has length `q_active`.
5814 /// The logit-slot entries of `t_compact` are filled from the row logits (the
5815 /// Euclidean projector ignores the point, so any finite value is equivalent;
5816 /// using the true logits keeps the point well-defined and finite).
5817 pub(crate) fn compact_row_ext_manifold_and_point(
5818 &self,
5819 row: usize,
5820 layout: &SaeRowLayout,
5821 ) -> (LatentManifold, Array1<f64>) {
5822 let active = &layout.active_atoms[row];
5823 let q_active = layout.row_q_active(row);
5824 let mut parts: Vec<LatentManifold> = Vec::with_capacity(active.len() + active.len());
5825 let mut point = Array1::<f64>::zeros(q_active);
5826 // Logit slots: one Euclidean part per active atom, in `active` order.
5827 let logits_row = self.assignment.logits.row(row);
5828 for (j, &k) in active.iter().enumerate() {
5829 parts.push(LatentManifold::Euclidean);
5830 point[j] = logits_row[k];
5831 }
5832 // Coordinate blocks: each active atom's coordinate manifold + point, at
5833 // the compact coord start the layout assigned it.
5834 for (j, &k) in active.iter().enumerate() {
5835 let coord = &self.assignment.coords[k];
5836 let d = coord.latent_dim();
5837 let coord_start = layout.coord_starts[row][j];
5838 let manifold_k = coord.manifold();
5839 // A `d`-dim coordinate whose manifold is a product (e.g. a torus =
5840 // Circle×Circle) already carries its `d` parts; a scalar manifold is
5841 // one part. Either way the manifold's ambient width must equal `d`,
5842 // matching the `d` compact columns at `coord_start`.
5843 parts.push(manifold_k.clone());
5844 let coord_point = coord.row(row);
5845 for axis in 0..d {
5846 point[coord_start + axis] = coord_point[axis];
5847 }
5848 }
5849 (LatentManifold::Product(parts), point)
5850 }
5851
5852 /// Numerical rank of a symmetric matrix: the count of eigenvalues
5853 /// exceeding `tol · max_eig`, with `tol = 1e-9` (the conventional
5854 /// relative spectral cutoff used elsewhere in the codebase).
5855 ///
5856 /// Used to count the penalised dimension of each atom's `smooth_penalty`
5857 /// `S_k` so the REML criterion's `−½·p·rank(S)·log λ_smooth` Occam term
5858 /// uses the *effective* penalty rank rather than the ambient basis size
5859 /// (a thin-plate / B-spline penalty has a non-trivial null space).
5860 pub(crate) fn symmetric_rank(s: &Array2<f64>) -> Result<usize, String> {
5861 if s.nrows() != s.ncols() {
5862 return Err(format!(
5863 "SaeManifoldTerm::symmetric_rank: matrix must be square, got {}x{}",
5864 s.nrows(),
5865 s.ncols()
5866 ));
5867 }
5868 let m = s.ncols();
5869 if m == 0 {
5870 return Ok(0);
5871 }
5872 // Symmetrize defensively through the shared ndarray helper. The SAE
5873 // rank cutoff is intentionally local to the SAE evidence contract; only
5874 // the symmetric cleanup is shared with the other construction modules.
5875 let mut sym = s.clone();
5876 gam_linalg::matrix::symmetrize_in_place(&mut sym);
5877 let (evals, _evecs) = sym
5878 .eigh(Side::Lower)
5879 .map_err(|e| format!("SaeManifoldTerm::symmetric_rank: eigh failed: {e}"))?;
5880 let max_eig = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v));
5881 if !(max_eig > 0.0) {
5882 return Ok(0);
5883 }
5884 let tol = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * max_eig;
5885 Ok(evals.iter().filter(|&&v| v > tol).count())
5886 }
5887
5888 /// Penalised quasi-Laplace evidence score for the SAE term at a FIXED ρ.
5889 ///
5890 /// #1421: this is NOT a true normalized-prior REML/evidence objective. The
5891 /// assignment priors (softmax entropy, JumpReLU) have NO finite normalizer:
5892 /// for softmax the reference-logit chart sends `P(ℓ)→0` as a free logit →±∞
5893 /// so `∫ e^{−λP} dℓ = ∞`, and JumpReLU's bounded penalty `0<P<λ` keeps
5894 /// `e^{−λP}` bounded below over an unbounded domain, also divergent. There is
5895 /// therefore no ρ-independent assignment-prior normalizer that can be dropped
5896 /// as a constant. The smoothing-penalty `−½log|λS|_+` term IS a genuine
5897 /// (proper-Gaussian) REML normalizer and is kept exactly; the rest is a
5898 /// penalized quasi-Laplace score (Laplace curvature term `½log|H|` around the
5899 /// inner optimum), which the engine minimizes over ρ.
5900 ///
5901 /// Runs the inner `(t, β)` arrow-Schur Newton solve to convergence at the
5902 /// supplied ρ (with NO in-loop ARD update — ρ is owned by the engine),
5903 /// then forms the Laplace/REML cost
5904 ///
5905 /// ```text
5906 /// V(ρ) = ℓ_pen(t̂, β̂; ρ) + ½ log|H(t̂, β̂; ρ)|
5907 /// − ½ · p · (Σ_k rank S_k) · log λ_smooth
5908 /// ```
5909 ///
5910 /// where `ℓ_pen = loss.total()` is the penalised objective at the inner
5911 /// optimum and `½ log|H|` is the Laplace normaliser. `H` is the joint
5912 /// `(t, β)` Hessian assembled by the arrow-Schur system; its `H_tt` block
5913 /// carries `α = exp(log_ard)` on its diagonal, so as α grows `½ log|H|`
5914 /// rises while the `−½·n·log α` already inside `loss.ard` falls — their
5915 /// balance IS the effective-dof term that the deleted `α = n/‖t‖²` rule
5916 /// dropped, which is why the criterion needs no clamp to stay finite on a
5917 /// collapsing axis.
5918 ///
5919 /// The final `−½·p·rank(S)·log λ_smooth` term is the smoothing-penalty
5920 /// normaliser `−½ log|λ S|_+` restricted to its ρ-dependent part: `S_k` is
5921 /// shared across all `p` decoder output channels (the `⊗ I_p` Kronecker
5922 /// structure), so `log|λ S|_+ = p·rank(S)·log λ + p·log|S|_+`, and the
5923 /// `½ p·log|S|_+` piece is ρ-independent. The ρ-independent additive
5924 /// constants that ARE dropped here (they shift `V` by a constant and do not
5925 /// affect the ρ-argmin) are the `2π` Laplace constant and the base
5926 /// `½ p·log|S|_+` penalty logdet. #1421: NO assignment-prior normalizer is
5927 /// dropped, because none exists (softmax/JumpReLU priors are improper — see
5928 /// the doc on this function): the quasi-Laplace score simply omits a
5929 /// normalizer that is not a finite constant.
5930 ///
5931 /// Returns `(V, loss)` so the engine can both rank ρ and surface the inner
5932 /// loss breakdown.
5933 pub fn reml_criterion(
5934 &mut self,
5935 target: ArrayView2<'_, f64>,
5936 rho: &SaeManifoldRho,
5937 registry: Option<&AnalyticPenaltyRegistry>,
5938 inner_max_iter: usize,
5939 learning_rate: f64,
5940 ridge_ext_coord: f64,
5941 ridge_beta: f64,
5942 ) -> Result<(f64, SaeManifoldLoss), String> {
5943 self.reml_criterion_with_refine_policy(
5944 target,
5945 rho,
5946 registry,
5947 inner_max_iter,
5948 learning_rate,
5949 ridge_ext_coord,
5950 ridge_beta,
5951 true,
5952 )
5953 }
5954
5955 pub(crate) fn reml_criterion_with_refine_policy(
5956 &mut self,
5957 target: ArrayView2<'_, f64>,
5958 rho: &SaeManifoldRho,
5959 registry: Option<&AnalyticPenaltyRegistry>,
5960 inner_max_iter: usize,
5961 learning_rate: f64,
5962 ridge_ext_coord: f64,
5963 ridge_beta: f64,
5964 refine_progress_extension: bool,
5965 ) -> Result<(f64, SaeManifoldLoss), String> {
5966 let plan = self.streaming_plan().admitted_or_error(
5967 self.n_obs(),
5968 self.output_dim(),
5969 self.k_atoms(),
5970 )?;
5971 if plan.streaming {
5972 // #1225: streaming and dense MUST optimize the SAME mathematical
5973 // objective — the full REML criterion `loss.total() + extra_penalty +
5974 // ½ log|H| − Occam`. The streaming branch previously returned only
5975 // `loss.total() + extra_penalty_energy`, dropping the Laplace
5976 // normalizer `½ log|H|` and the Occam term, so large shapes (exactly
5977 // where streaming is needed) were ranked by penalized loss rather than
5978 // REML — and dense vs streaming disagreed on the objective. Route
5979 // through the streaming exact-logdet path, which assembles the same
5980 // chunk-by-chunk-bit-identical `½ log|H|_stream` and the same
5981 // `−Occam`/extra-penalty terms as the dense `reml_criterion_with_cache`
5982 // (different memory strategy, same objective).
5983 self.reml_criterion_streaming_exact(
5984 target,
5985 rho,
5986 registry,
5987 inner_max_iter,
5988 learning_rate,
5989 ridge_ext_coord,
5990 ridge_beta,
5991 )
5992 } else {
5993 let (v, loss, _cache) = self.reml_criterion_with_cache_refine_policy(
5994 target,
5995 rho,
5996 registry,
5997 inner_max_iter,
5998 learning_rate,
5999 ridge_ext_coord,
6000 ridge_beta,
6001 refine_progress_extension,
6002 )?;
6003 Ok((v, loss))
6004 }
6005 }
6006
6007 /// As [`Self::reml_criterion`], but also returns the converged undamped
6008 /// `ArrowFactorCache` so callers (the EFS fixed-point step) can read the
6009 /// selected-inverse traces `(H⁻¹)_tt` / `(H⁻¹)_ββ` without re-factoring.
6010 /// The cache is the single shared O(K³) Direct factor; both the
6011 /// log-determinant criterion and the Fellner-Schall ρ-step consume it.
6012 pub fn reml_criterion_with_cache(
6013 &mut self,
6014 target: ArrayView2<'_, f64>,
6015 rho: &SaeManifoldRho,
6016 registry: Option<&AnalyticPenaltyRegistry>,
6017 inner_max_iter: usize,
6018 learning_rate: f64,
6019 ridge_ext_coord: f64,
6020 ridge_beta: f64,
6021 ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6022 self.reml_criterion_with_cache_refine_policy(
6023 target,
6024 rho,
6025 registry,
6026 inner_max_iter,
6027 learning_rate,
6028 ridge_ext_coord,
6029 ridge_beta,
6030 true,
6031 )
6032 }
6033
6034 pub(crate) fn reml_criterion_with_cache_refine_policy(
6035 &mut self,
6036 target: ArrayView2<'_, f64>,
6037 rho: &SaeManifoldRho,
6038 registry: Option<&AnalyticPenaltyRegistry>,
6039 inner_max_iter: usize,
6040 learning_rate: f64,
6041 ridge_ext_coord: f64,
6042 ridge_beta: f64,
6043 refine_progress_extension: bool,
6044 ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6045 let admission_plan = self.streaming_plan().admitted_or_error(
6046 self.n_obs(),
6047 self.output_dim(),
6048 self.k_atoms(),
6049 )?;
6050 if !admission_plan.direct_logdet_admitted() {
6051 return Err(format!(
6052 "SaeManifoldTerm::reml_criterion_with_cache: predicted working set {} bytes exceeds budget {} bytes for dense evidence cache; shape n={},p={},K={}; cost-only streaming route is required",
6053 admission_plan.estimated_direct_peak_bytes,
6054 admission_plan.in_core_budget_bytes,
6055 self.n_obs(),
6056 self.output_dim(),
6057 self.k_atoms()
6058 ));
6059 }
6060 // 1. Run the inner (t, β) Newton solve to convergence at FIXED ρ.
6061 // `run_joint_fit_arrow_schur` no longer touches ρ.
6062 let mut rho_fixed = rho.clone();
6063 let mut loss = self.run_joint_fit_arrow_schur(
6064 target,
6065 &mut rho_fixed,
6066 registry,
6067 inner_max_iter,
6068 learning_rate,
6069 ridge_ext_coord,
6070 ridge_beta,
6071 )?;
6072
6073 // 2. Drive the inner (t, β) solve to the KKT/step-converged optimum and
6074 // take one final UNDAMPED factor there to obtain the joint Hessian
6075 // log-determinant. We force ridge = 0 and the dense `Direct` Schur
6076 // mode so `arrow_log_det_from_cache` returns the exact
6077 // `log|H| = Σ_i log|H_tt^(i)| + log|Schur_β|` (it rejects damped
6078 // factors and InexactPCG caches, which have no dense Schur factor).
6079 // This is the same evidence convention the main GAM REML path uses.
6080 // The shared `converge_inner_for_undamped_logdet` driver guarantees
6081 // the per-row `H_tt^(i)` blocks are PD at the converged optimum so
6082 // the undamped (`ridge = 0`) factorization succeeds — the streaming
6083 // log-det path reuses the identical driver so both rank the same
6084 // converged Laplace optimum and stay bit-identical.
6085 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
6086 let cache = self.converge_inner_for_undamped_logdet(
6087 target,
6088 rho,
6089 &mut rho_fixed,
6090 registry,
6091 inner_max_iter,
6092 learning_rate,
6093 ridge_ext_coord,
6094 ridge_beta,
6095 &mut loss,
6096 &options,
6097 refine_progress_extension,
6098 )?;
6099 self.record_evidence_gauge_deflation_count(cache.gauge_deflated_directions)?;
6100 loss.evidence_gauge_deflated_directions = cache.gauge_deflated_directions;
6101 let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
6102 // Distinguish a GENUINE infeasibility — a probed ρ where the joint
6103 // Hessian is not PD so the Laplace evidence log-det is undefined —
6104 // from a real factorization defect. The cross-row IBP Woodbury
6105 // capacitance `C = I_R + D·Uᵀ H₀'⁻¹ U` can have det ≤ 0 at a ρ the
6106 // outer optimizer line-searches into (the indefinite basin adjacent
6107 // to the PD region); there the log-det legitimately does not exist.
6108 // That refusal must be RECOVERABLE (the outer BFGS should get +∞ and
6109 // steer back into the PD region), exactly like the "non-PD per-row
6110 // H_tt block" refusal — not a fatal `RemlOptimizationFailed` that
6111 // aborts the whole fit. See `is_recoverable_value_probe_refusal`.
6112 // (The old message claimed "no dense Schur factor", which is false
6113 // here — the Schur factor is present; the Woodbury correction is the
6114 // non-finite term.)
6115 if cache.cross_row_woodbury.is_some()
6116 && !cache.cross_row_woodbury_log_det().is_finite()
6117 {
6118 "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
6119 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
6120 .to_string()
6121 } else {
6122 "SaeManifoldTerm::reml_criterion: arrow_log_det_from_cache returned None \
6123 (undamped joint Hessian log-det unavailable for the Laplace normaliser)"
6124 .to_string()
6125 }
6126 })?;
6127
6128 // 3. Smoothing-penalty Occam term `−½·Σ_k r_k·rank(S_k)·log λ_smooth`
6129 // plus the profiled-frame evidence-dimension correction
6130 // `+½·Σ_k r_k·(p−r_k)·log λ_smooth` (issue #972). On the full-`B` path
6131 // (`r_k == p`, no frames) this is exactly the historical
6132 // `½·p·(Σ rank S_k)·log λ_smooth`, so the small-model criterion is
6133 // unchanged. The single seam is `reml_occam_term`, shared with the
6134 // streaming path so both rank the identical Laplace dimension count.
6135 let occam = self.reml_occam_term(rho)?;
6136
6137 // Decoder-block analytic-penalty energy (#671/#672). The inner solve
6138 // descended this energy (it enters `gb`/`hbb`) but it had no native
6139 // `loss.*` representative, so the Laplace criterion `v` was scoring a
6140 // different objective than the one minimized. Add the converged
6141 // decoder-penalty value so the ρ-sweep ranks the same penalized
6142 // deviance. Excludes the Psi-tier ARD/assignment penalties already
6143 // accounted for in `loss.total()` (see
6144 // `analytic_decoder_penalty_value_total`).
6145 // Extra analytic-penalty energy (#671/#737). Decoder-block penalties and
6146 // coordinate-tier isometry enter the inner solve but have no `loss.*`
6147 // representative, so the Laplace criterion must add them explicitly to
6148 // rank the same penalized deviance the Newton solve descends.
6149 let extra_penalty_energy = match registry {
6150 Some(reg) => self
6151 .reml_extra_penalty_value_total(reg)
6152 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?,
6153 None => 0.0,
6154 };
6155
6156 let v = loss.total() + extra_penalty_energy + 0.5 * log_det - occam;
6157 Ok((v, loss, cache))
6158 }
6159
6160 /// The #1037 quotient-dimension invariant: a Laplace normalizer `½log|H|` is
6161 /// only comparable across ρ at a COMMON quotient (gauge-deflation) dimension.
6162 /// The first observation pins the expected count; a later match is a no-op.
6163 ///
6164 /// A later observation that DIFFERS is, under the K>1 fit, a LEGITIMATE
6165 /// quotient-dimension event — an atom born, reseeded (the #976 collapse
6166 /// guards), or rank-reduced moves the number of gauge-flat rows. Because a
6167 /// deflated direction is lifted to unit stiffness and contributes the
6168 /// ρ-independent `log 1 = 0` to the evidence, re-anchoring the comparison to
6169 /// the new dimension is exactly evidence-preserving and keeps every future
6170 /// cross-ρ comparison consistent — the principled response, not an abort.
6171 ///
6172 /// The genuine pathology the guard still catches is a count that NEVER
6173 /// STABILIZES: re-anchors are bounded by the per-atom structural-event budget
6174 /// (`k·(reseed_budget+1)+1`), and a runaway quotient dimension past that
6175 /// bound refuses loudly. This supersedes the prior strict-constant guard and
6176 /// its ±1 flicker band (#1117) at root — the band was masking exactly the
6177 /// legitimate K>1 dimension changes this re-anchoring now handles.
6178 pub(crate) fn record_evidence_gauge_deflation_count(
6179 &mut self,
6180 count: usize,
6181 ) -> Result<(), String> {
6182 match self.expected_evidence_gauge_deflated_directions {
6183 Some(expected) if expected == count => Ok(()),
6184 Some(expected) => {
6185 // A change in the gauge-deflation count between two evidence
6186 // factorizations is a legitimate quotient-dimension event under
6187 // the K>1 fit: an atom can be born, reseeded (the #976 collapse
6188 // guards), or rank-reduced across the ρ-walk, and each such event
6189 // moves the number of gauge-flat rows. The #1037 invariant is
6190 // NOT "the count never changes" — it is "two Laplace normalizers
6191 // are only comparable at a COMMON quotient dimension". The
6192 // principled response to a legitimate change is therefore to
6193 // RE-ANCHOR the comparison to the new dimension (so every future
6194 // cross-ρ comparison within the optimization is consistent), not
6195 // to abort the fit. This is exactly evidence-preserving: each
6196 // gauge-deflated direction is lifted to unit stiffness and
6197 // contributes the ρ-independent `log 1 = 0` to `½log|H|`, so the
6198 // converged criterion value is identical whether a given row is
6199 // counted as deflated or not — only the BOOKKEEPING dimension
6200 // must agree across a comparison, and re-anchoring restores that.
6201 //
6202 // The genuine pathology the guard must still catch is a count
6203 // that NEVER STABILIZES — an OSCILLATING quotient dimension that
6204 // re-anchors without converging, signalling a truly ill-posed
6205 // evidence surface. But the deflation count is NOT a discrete
6206 // dictionary-level event count: it is the per-ROW-summed number of
6207 // near-null evidence directions across all N rows (#1217). On real
6208 // K≥2 activations it is an O(N) quantity that drifts SMOOTHLY and
6209 // monotonically as the conditioning improves over the ρ-walk
6210 // (e.g. 171→156→…→113 as smoothing increases) — a benign,
6211 // evidence-neutral change (each deflated direction contributes the
6212 // ρ-independent `log 1 = 0` to `½log|H|`, so re-anchoring never
6213 // moves the criterion value). Charging such a monotone drift
6214 // against a `k`-sized "structural event" budget was wrong: it
6215 // counts threshold crossings of a continuous per-row quantity, not
6216 // atom births/reseeds, so the budget tripped on a perfectly healthy
6217 // converging K=2 fit (#1217 regression from the #1189/#1190
6218 // basin-escape fixes, which shifted which rows sit near the
6219 // deflation floor).
6220 //
6221 // The principled discriminator is DIRECTION REVERSALS: a count
6222 // that drifts one way and settles is benign; a count that bounces
6223 // up and down without settling is the oscillating-quotient
6224 // pathology. We therefore charge the re-anchor budget ONLY on a
6225 // reversal of the change direction, and size the budget by the
6226 // number of distinct dictionary structural events (births/reseeds)
6227 // that can each legitimately flip the drift direction. A monotone
6228 // drift of any length re-anchors freely (it is consistently
6229 // re-anchored and evidence-neutral); a genuinely oscillating count
6230 // exhausts the reversal budget and refuses loudly.
6231 let delta_sign: i8 = if count > expected { 1 } else { -1 };
6232 let is_reversal = self.evidence_gauge_deflation_last_delta_sign != 0
6233 && delta_sign != self.evidence_gauge_deflation_last_delta_sign;
6234 self.evidence_gauge_deflation_last_delta_sign = delta_sign;
6235 // A reversal alone is NOT the pathology — a BOUNDED flicker of a
6236 // few rows crossing the near-null deflation floor reverses
6237 // direction every step yet is the discretization jitter of a
6238 // continuous evidence spectrum, fully evidence-neutral (each
6239 // deflated direction contributes `log 1 = 0` either way). The
6240 // genuine "quotient dimension not stabilizing" pathology is a
6241 // WIDE-amplitude oscillation: a substantial FRACTION of the
6242 // dimension flipping back and forth. The count is an O(N) per-row
6243 // sum, so the discriminator must be the reversal AMPLITUDE
6244 // relative to the dimension level, not the bare reversal. Charge
6245 // the reversal budget only when a reversal's step exceeds a
6246 // relative jitter band; a converged-but-flickering fit (e.g.
6247 // 150<->147 on N=200, ~2% of the level) re-anchors freely while a
6248 // true runaway (e.g. 9<->2, ~80% of the level) still trips every
6249 // reversal and exhausts the budget. This was the second #795 root
6250 // cause: the single-planted-circle fit's per-row count flickers
6251 // 150<->147 near the deflation floor, so the bare-reversal guard
6252 // refused the simplest possible fit — with the isometry gauge ON
6253 // *or* OFF — long before the gauge magnitude mattered.
6254 let amplitude = expected.abs_diff(count);
6255 let level = expected.max(count);
6256 let jitter_band = (level / 4).max(2);
6257 if is_reversal && amplitude > jitter_band {
6258 self.evidence_gauge_deflation_reanchors += 1;
6259 }
6260 let reversal_budget = self
6261 .k_atoms()
6262 .saturating_mul(
6263 SAE_ATOM_COLLAPSE_RESEED_BUDGET
6264 + SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET
6265 + 1,
6266 )
6267 .saturating_add(1);
6268 if self.evidence_gauge_deflation_reanchors > reversal_budget {
6269 return Err(format!(
6270 "SaeManifoldTerm::reml_criterion: row-gauge evidence deflation count \
6271 oscillated (reversed direction {} times, last {expected}->{count}) within \
6272 one optimization, exceeding the {reversal_budget}-reversal budget for {} \
6273 atoms; the quotient dimension is not stabilizing, refusing to compare \
6274 Laplace normalizers",
6275 self.evidence_gauge_deflation_reanchors,
6276 self.k_atoms()
6277 ));
6278 }
6279 log::debug!(
6280 "SaeManifoldTerm::reml_criterion: per-row evidence deflation count changed \
6281 {expected}->{count} (a benign per-row conditioning drift across the ρ-walk; \
6282 reversal {}/{reversal_budget}); re-anchoring the Laplace normalizer comparison \
6283 to the new dimension",
6284 self.evidence_gauge_deflation_reanchors
6285 );
6286 self.expected_evidence_gauge_deflated_directions = Some(count);
6287 Ok(())
6288 }
6289 None => {
6290 self.expected_evidence_gauge_deflated_directions = Some(count);
6291 Ok(())
6292 }
6293 }
6294 }
6295
6296 pub(crate) fn is_undamped_evidence_row_non_pd(err: &ArrowSchurError) -> bool {
6297 matches!(
6298 err,
6299 ArrowSchurError::PerRowFactorFailed { reason, .. }
6300 if reason.contains("H_tt is non-PD at base ridge")
6301 && reason.contains("evidence mode preserves the genuine Cholesky")
6302 )
6303 }
6304
6305 /// Drive the inner `(t, β)` Newton solve to the KKT/step-converged optimum
6306 /// and return the final UNDAMPED (`ridge = 0`) joint-Hessian factor cache.
6307 ///
6308 /// The Laplace normaliser `½log|H|` is only the correct REML criterion at
6309 /// the inner optimum `(t̂, β̂)`, so the criterion must refine the inner state
6310 /// until either the KKT gradient or the undamped Newton step meets tolerance
6311 /// before factoring. Crucially, **at the converged optimum the per-row
6312 /// `H_tt^(i)` blocks are PD**, so the undamped (`ridge = 0`) factorization
6313 /// succeeds; an off-optimum iterate (e.g. the initial seed, or a state
6314 /// stopped after only `inner_max_iter` steps) can have an indefinite /
6315 /// rank-deficient per-row block (`p_out = 1` → rank-1 `JᵀJ`, softmax
6316 /// assignment-sparsity negative logit curvature) that surfaces
6317 /// `PerRowFactorFailed` from the undamped `factor_one_row`. Both the dense
6318 /// (`reml_criterion_with_cache`) and the streaming
6319 /// (`reml_criterion_streaming_exact`) evidence paths route through this same
6320 /// driver, so they converge to the identical inner state and their
6321 /// `ridge = 0` log-determinants stay bit-identical (#847).
6322 pub(crate) fn converge_inner_for_undamped_logdet(
6323 &mut self,
6324 target: ArrayView2<'_, f64>,
6325 rho: &SaeManifoldRho,
6326 rho_fixed: &mut SaeManifoldRho,
6327 registry: Option<&AnalyticPenaltyRegistry>,
6328 inner_max_iter: usize,
6329 learning_rate: f64,
6330 ridge_ext_coord: f64,
6331 ridge_beta: f64,
6332 loss: &mut SaeManifoldLoss,
6333 options: &ArrowSolveOptions,
6334 refine_progress_extension: bool,
6335 ) -> Result<ArrowFactorCache, String> {
6336 // `inner_max_iter == 0` is a genuine FREEZE of the inner `(t, β)` state
6337 // — a verbatim warm-start reuse, not a convergence request (gam#577/#579,
6338 // #850). The convergence/refinement loop below MUST NOT run even one
6339 // Newton step in that case (the old `inner_max_iter.max(1)` floor moved
6340 // β off the seed), so we factor exactly once at the frozen iterate and
6341 // return that undamped cache without invoking the stationarity gate.
6342 // The caller has already run `run_joint_fit_arrow_schur(..., 0, ...)`,
6343 // which under the `max_iter == 0` freeze (gam#577/#579, #850) runs ONLY
6344 // the β-neutral basis refresh and returns the loss without touching β —
6345 // it skips the rank-reduction, frame activation, re-seed guards, and the
6346 // #1026 decoder-LSQ polish that would otherwise refit β off the seed — so
6347 // `self` is at the warm-start β here.
6348 if inner_max_iter == 0 {
6349 let sys = self
6350 .assemble_arrow_schur(target, rho, registry)
6351 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6352 let factored = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options)
6353 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6354 // The frozen-state Newton step (factored.0, factored.1) is discarded
6355 // — only the undamped factor cache (factored.2) is consumed for the
6356 // log-det / selected-inverse traces; β stays at the warm-start seed.
6357 return Ok(factored.2);
6358 }
6359 let mut total_inner_iter = inner_max_iter;
6360 let accepted_base_refine_iter = inner_max_iter.max(1).saturating_mul(16).max(64);
6361 let value_probe_base_refine_iter = inner_max_iter.max(1).saturating_mul(4).max(16);
6362 let base_refine_iter = if refine_progress_extension {
6363 accepted_base_refine_iter
6364 } else {
6365 value_probe_base_refine_iter
6366 };
6367 let progress_refine_iter = if refine_progress_extension {
6368 inner_max_iter.max(1).saturating_mul(64).max(256)
6369 } else {
6370 base_refine_iter
6371 };
6372 let mut previous_refine_grad_norm: Option<f64> = None;
6373 let mut saw_refine_progress = false;
6374 // #1051 — objective-stagnation convergence. On an ill-conditioned
6375 // penalised bilinear fit (the euclidean / Duchon decoder × latent
6376 // coordinate system on a trivial shape), the inner Newton crawls: each
6377 // refine round lowers the penalised objective by a shrinking amount while
6378 // the KKT gradient and the undamped step stay above their relative
6379 // tolerances (the near-singular Schur amplifies the step in the
6380 // weakly-identified decoder direction). The grad-OR-step gate then never
6381 // fires and the solve is rejected as "did not converge" — the 1e12
6382 // sentinel. A Newton/LM iterate whose objective has stopped decreasing
6383 // to within `√εmach` of its scale IS the numerical inner optimum; ranking
6384 // the Laplace criterion there is correct. We accept that fixed point
6385 // instead of grinding the budget.
6386 let entry_loss_total = loss.total();
6387 let mut previous_loss_total = entry_loss_total;
6388 let mut refine_rounds: usize = 0;
6389 // Consecutive stall rounds: counts how many successive refine rounds
6390 // ended in a stall AND a failed undamped factor. Once this reaches
6391 // `SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS` the iterate is at
6392 // its numerical fixed point and cannot be improved further; returning
6393 // `Err` here is the same "did not converge" signal that
6394 // `is_recoverable_value_probe_refusal` already handles, so the outer
6395 // BFGS treats it as an INFINITY probe and tries a different ρ instead
6396 // of looping forever burning the extended progress budget. Without
6397 // this counter the stagnation handler fell through when the undamped
6398 // factor failed and the loop kept extending via `saw_refine_progress`
6399 // from earlier rounds, accumulating minutes of wasted work (#1094).
6400 let mut consecutive_stall_factor_fail: usize = 0;
6401 loop {
6402 let sys = self
6403 .assemble_arrow_schur(target, rho, registry)
6404 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6405 // Evidence-only factorization: the Newton step (Δt, Δβ) is discarded
6406 // and only the factor cache is consumed — the exact undamped log-det
6407 // and the selected-inverse traces. As ρ sweeps to extremes (e.g. a
6408 // wide ARD-α sweep), H_tt is genuinely PD but can be ill-conditioned;
6409 // the standard Direct guard rejects that to protect Newton-step
6410 // accuracy, but the log-det is exact from diag(L) regardless of the
6411 // condition number and the traces only need the (PD) factor. So
6412 // tolerate the ill-conditioning rejection here (a genuine non-PD pivot
6413 // still errors). The cache stays undamped at ridge=0, so
6414 // `arrow_log_det_from_cache` remains exact.
6415 // The exact KKT stationarity residual is the joint gradient
6416 // ‖g‖ = √(Σ_i ‖g_t^(i)‖² + ‖g_β‖²), read straight off the assembled
6417 // system. Unlike the Newton step Δ = H⁻¹g, the gradient is
6418 // factorisation-independent: it is NOT amplified by an inverse, so a
6419 // genuinely stationary but ill-conditioned fit (tiny g, possibly large
6420 // Δ in a flat direction) is correctly recognised as converged. The
6421 // `with_ill_conditioning_tolerated` Direct factor below documents that
6422 // its Δ may be inaccurate in exactly those flat directions, so using Δ
6423 // alone as the convergence gate would falsely reject healthy fits.
6424 let grad_norm_sq: f64 = sys
6425 .rows
6426 .iter()
6427 .map(|row| row.gt.iter().map(|&v| v * v).sum::<f64>())
6428 .sum::<f64>()
6429 + sys.gb.iter().map(|&v| v * v).sum::<f64>();
6430 let grad_norm = grad_norm_sq.sqrt();
6431 // Quotient KKT-gradient (#1117): the raw joint gradient retains a
6432 // persistent small component in the chart-gauge orbit and the
6433 // rank-deficient decoder β-null even at a stationary fit, so the raw
6434 // grad gate never clears on a rank-deficient circle and the inner
6435 // refine loop crawls until the (large) progress budget dies — the
6436 // 2-min stall. Measure the gradient on the SAME identified quotient
6437 // the step gate already uses: a fit whose only remaining gradient
6438 // lives in those flat directions is stationary on the quotient, so
6439 // ranking the Laplace criterion there is correct. The dense per-row
6440 // g_t is laid into the `n·q` coordinate layout the gauge basis spans;
6441 // non-dense/heterogeneous systems fall back to the raw norm.
6442 let quotient_grad_norm = {
6443 let n = self.n_obs();
6444 let q = self.assignment.row_block_dim();
6445 let dense_len = n.saturating_mul(q);
6446 let mut grad_ext_coord = Array1::<f64>::zeros(dense_len);
6447 let mut dense_layout_ok = sys.rows.len() == n;
6448 if dense_layout_ok {
6449 for (row_idx, row) in sys.rows.iter().enumerate() {
6450 let base = sys.row_offsets[row_idx];
6451 let di = sys.row_dims[row_idx];
6452 if base + di > dense_len || row.gt.len() < di {
6453 dense_layout_ok = false;
6454 break;
6455 }
6456 for axis in 0..di {
6457 grad_ext_coord[base + axis] = row.gt[axis];
6458 }
6459 }
6460 }
6461 if dense_layout_ok {
6462 self.quotient_gradient_norm_sq(
6463 grad_ext_coord.view(),
6464 sys.gb.view(),
6465 grad_norm_sq,
6466 &rho_fixed.lambda_smooth_vec(),
6467 )
6468 .map(|v| v.sqrt())
6469 .unwrap_or(grad_norm)
6470 } else {
6471 grad_norm
6472 }
6473 };
6474 let iterate_scale = self.inner_iterate_scale();
6475 // Relative parameter-step tolerance for Δ (well-conditioned charts)
6476 // and a scaled KKT-gradient tolerance. Convergence is accepted on
6477 // EITHER a small KKT gradient OR a small undamped Newton step: SAE
6478 // manifold fits contain gauge-like coordinate/decoder directions (the
6479 // circle's rotation gauge, decoder column-space rotations) where the
6480 // shared-block Hessian is near-singular, so the undamped step can stay
6481 // large in that flat direction even at a genuine stationary point; the
6482 // gradient, which is not amplified by the inverse, recognises it. With
6483 // the isometry Gauss-Newton block now a coherent PSD pullback (no
6484 // indefinite Schur pivot), the inner solve reaches true stationarity,
6485 // so the gradient tolerance is a standard relative KKT residual rather
6486 // than the 0.1.154-regression band-aid (3e-3) that masked the
6487 // non-convergence the indefinite curvature caused.
6488 let step_tolerance = SAE_MANIFOLD_INNER_STEP_REL_TOL * iterate_scale;
6489 let grad_tolerance = SAE_MANIFOLD_INNER_GRAD_REL_TOL * iterate_scale;
6490 if !grad_norm_sq.is_finite() {
6491 return Err(format!(
6492 "SaeManifoldTerm::reml_criterion: undamped inner KKT residual is non-finite \
6493 at the inner optimum (‖g‖²={grad_norm_sq}); the joint Hessian \
6494 factorisation is degenerate at this ρ"
6495 ));
6496 }
6497 let (delta_t, delta_beta, cache): (Array1<f64>, Array1<f64>, ArrowFactorCache) =
6498 match solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options) {
6499 Ok(factored) => factored,
6500 Err(err) if Self::is_undamped_evidence_row_non_pd(&err) => {
6501 if grad_norm <= grad_tolerance || quotient_grad_norm <= grad_tolerance {
6502 // K>1: the softmax/IBP logit–coordinate Gauss-Newton
6503 // cross-terms (H_zt = J_z^T J_t, assembled row-locally from
6504 // the assignment JVP × basis JVP) can make a per-row H_tt
6505 // indefinite at the TRUE KKT stationary point — when two
6506 // atoms' decoders specialise in opposite directions the
6507 // Schur complement of the logit block goes negative even
6508 // though the priors and the full-joint GN term are PSD.
6509 //
6510 // The undamped evidence factor already conditions that
6511 // block the PRINCIPLED way: `factor_spectral_deflated_
6512 // evidence_row` discovers the negative/flat eigen-direction
6513 // and stiffens it to UNIT curvature (eigenvalue → +1), so it
6514 // contributes a ρ-INDEPENDENT log 1 = 0 to the evidence —
6515 // the same quotient pseudo-determinant convention the gauge
6516 // (#1037) and data-null (#1117) deflations use. Reaching
6517 // THIS arm at stationarity therefore means even the spectral
6518 // deflation declined (a non-finite block or a failed
6519 // eigendecomposition): the state is genuinely broken, so we
6520 // surface the hard refusal and let the outer BFGS treat this
6521 // ρ as an INFINITY probe (`is_recoverable_value_probe_
6522 // refusal`). We must NOT ridge-damp here: a `+ridge·I`
6523 // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6524 // bias into the VALUE that the analytic ρ-gradient (built
6525 // for the undamped Laplace log-det) never sees, desyncing
6526 // the outer line-search — the multi-atom non-convergence
6527 // this fix (#1117) removes.
6528 return Err(format!(
6529 "SaeManifoldTerm::reml_criterion: stationary undamped \
6530 evidence factorization has a non-PD per-row H_tt block \
6531 that spectral unit-stiffness deflation could not \
6532 condition (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}); \
6533 {err}"
6534 ));
6535 }
6536 let refine_limit = Self::refine_iteration_limit(
6537 total_inner_iter,
6538 base_refine_iter,
6539 progress_refine_iter,
6540 previous_refine_grad_norm,
6541 grad_norm,
6542 saw_refine_progress,
6543 );
6544 if total_inner_iter >= refine_limit {
6545 // #1117/#1118 — pre-stationarity genuinely-indefinite
6546 // non-gauge H_tt under K>1 IBP/softmax row-sharing. The
6547 // logit × coordinate Gauss-Newton cross term H_zt = J_zᵀJ_t
6548 // can drive a shared row's H_tt Schur complement NEGATIVE off
6549 // the gauge orbit; the LM-escalated refinement above cannot
6550 // always cross the indefinite basin into the PD region within
6551 // the descent-extended budget.
6552 //
6553 // The undamped (ridge=0) evidence factor already conditions
6554 // that block the PRINCIPLED way: `factor_spectral_deflated_
6555 // evidence_row` discovers the negative/flat eigen-direction
6556 // and stiffens it to UNIT curvature (eigenvalue → +1), a
6557 // ρ-INDEPENDENT log 1 = 0 evidence contribution — so the
6558 // `Ok(factored)` arm above accepts the indefinite block and
6559 // returns a finite, monotone-comparable value to the outer
6560 // BFGS WITHOUT a ρ-dependent bias. Reaching THIS arm means
6561 // even that spectral deflation declined (a non-finite block
6562 // or a failed eigendecomposition): the iterate is genuinely
6563 // broken, so we surface the hard refusal and let the outer
6564 // BFGS treat this ρ as an INFINITY probe.
6565 //
6566 // We must NOT ridge-damp here: a `+ridge·I` evidence
6567 // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6568 // bias into the VALUE that the analytic ρ-gradient (built
6569 // for the undamped Laplace log-det) never sees, desyncing
6570 // the outer line-search — the multi-atom non-convergence this
6571 // fix removes. K=1 (and any already-PD or spectral-deflatable
6572 // K>1 row) never reaches this branch.
6573 return Err(format!(
6574 "SaeManifoldTerm::reml_criterion: undamped evidence \
6575 factorization hit a non-PD per-row H_tt block before KKT \
6576 stationarity (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) \
6577 and the refinement budget was exhausted after \
6578 {total_inner_iter} inner iterations; {err}"
6579 ));
6580 }
6581 let remaining = refine_limit - total_inner_iter;
6582 let refine_iter = inner_max_iter.max(1).min(remaining);
6583 saw_refine_progress |=
6584 Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6585 previous_refine_grad_norm = Some(grad_norm);
6586 *loss = self.run_joint_fit_arrow_schur(
6587 target,
6588 rho_fixed,
6589 registry,
6590 refine_iter,
6591 learning_rate,
6592 ridge_ext_coord,
6593 ridge_beta,
6594 )?;
6595 total_inner_iter += refine_iter;
6596 continue;
6597 }
6598 Err(err) => {
6599 return Err(format!("SaeManifoldTerm::reml_criterion: {err}"));
6600 }
6601 };
6602 // The Laplace normaliser ½log|H| is only the correct REML criterion at
6603 // the inner optimum (t̂, β̂). Convergence is judged by EITHER a small
6604 // gradient (KKT stationarity) OR a small undamped Newton step; the
6605 // solve is only rejected as non-converged when BOTH are large, i.e.
6606 // the iterate is neither stationary nor about to move negligibly. That
6607 // disjunction is what keeps an ill-conditioned-but-stationary fit
6608 // (small g, large Δ) from being rejected while still refusing to rank
6609 // an off-optimum Laplace criterion that is genuinely mid-flight.
6610 let step_norm_sq: f64 = delta_t.iter().map(|&v| v * v).sum::<f64>()
6611 + delta_beta.iter().map(|&v| v * v).sum::<f64>();
6612 if !step_norm_sq.is_finite() {
6613 return Err(format!(
6614 "SaeManifoldTerm::reml_criterion: undamped inner residual is non-finite at \
6615 the inner optimum (‖Δ‖²={step_norm_sq}, ‖g‖²={grad_norm_sq}); the joint \
6616 Hessian factorisation is degenerate at this ρ"
6617 ));
6618 }
6619 let step_norm = step_norm_sq.sqrt();
6620 let quotient_step_norm_sq = self.quotient_newton_step_norm_sq(
6621 delta_t.view(),
6622 delta_beta.view(),
6623 step_norm_sq,
6624 &rho_fixed.lambda_smooth_vec(),
6625 )?;
6626 let quotient_step_norm = quotient_step_norm_sq.sqrt();
6627 // Converge on ANY of: the raw KKT gradient (well-conditioned fit),
6628 // the QUOTIENT KKT gradient (#1117 — rank-deficient fit whose only
6629 // residual gradient is gauge/null flat-direction crawl), or the
6630 // quotient Newton step. The quotient-gradient disjunct is what lets
6631 // a rank-deficient K=1 circle terminate in budget instead of crawling
6632 // the weakly-identified valley until the refine budget dies.
6633 if grad_norm <= grad_tolerance
6634 || quotient_grad_norm <= grad_tolerance
6635 || quotient_step_norm <= step_tolerance
6636 {
6637 return Ok(cache);
6638 }
6639 let refine_limit = Self::refine_iteration_limit(
6640 total_inner_iter,
6641 base_refine_iter,
6642 progress_refine_iter,
6643 previous_refine_grad_norm,
6644 grad_norm,
6645 saw_refine_progress,
6646 );
6647 if total_inner_iter >= refine_limit {
6648 // Inner solve did not converge in reml_criterion; the returned
6649 // Err below carries the full non-convergence diagnostic
6650 // (gradient / quotient-step norms and tolerances) to the caller.
6651 return Err(format!(
6652 "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6653 neither the KKT gradient ‖g‖={grad_norm:.6e} (tol {grad_tolerance:.6e}) nor \
6654 the quotient Newton step ‖Π⊥gauge Δ‖={quotient_step_norm:.6e} \
6655 (raw ‖Δ‖={step_norm:.6e}, tol {step_tolerance:.6e}) met \
6656 tolerance after {total_inner_iter} inner iterations. Refusing to rank an \
6657 off-optimum Laplace criterion."
6658 ));
6659 }
6660 let remaining = refine_limit - total_inner_iter;
6661 let refine_iter = inner_max_iter.max(1).min(remaining);
6662 saw_refine_progress |=
6663 Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6664 previous_refine_grad_norm = Some(grad_norm);
6665 *loss = self.run_joint_fit_arrow_schur(
6666 target,
6667 rho_fixed,
6668 registry,
6669 refine_iter,
6670 learning_rate,
6671 ridge_ext_coord,
6672 ridge_beta,
6673 )?;
6674 total_inner_iter += refine_iter;
6675 refine_rounds += 1;
6676 // #1051 — objective-stagnation fixed point. A whole refine round that
6677 // failed to lower the penalised objective by a meaningful FRACTION of
6678 // the total since-entry reduction means the Newton/LM iterate is at
6679 // its numerical optimum: the remaining KKT residual lives in the
6680 // weakly-identified decoder / gauge directions the near-singular Schur
6681 // cannot resolve. Ranking the Laplace criterion at this fixed point is
6682 // correct (the only further motion is cosmetic flat-valley crawl), so
6683 // accept the current cache instead of refining until the budget dies.
6684 // Requires a few completed refine rounds (so the fraction baseline is
6685 // meaningful) but is NOT gated behind the full refine budget — the
6686 // whole point is to terminate the crawl long before that.
6687 let new_loss_total = loss.total();
6688 // Two stagnation signals, both required: (1) the latest refine round
6689 // contributed a negligible FRACTION of the total objective reduction
6690 // achieved since entry — the fit has captured essentially all the
6691 // achievable improvement and is now crawling cosmetically along the
6692 // weakly-identified valley; (2) the absolute relative decrease is
6693 // itself tiny. The fraction test is scale- and rate-free (it fires
6694 // whether the crawl decays fast or slow), so it recognises the
6695 // over-smoothed / rank-deficient fixed point the bare relative floor
6696 // misses, while still never firing on a fit that is materially
6697 // improving round over round.
6698 let total_improvement = (entry_loss_total - new_loss_total).max(0.0);
6699 let round_improvement = (previous_loss_total - new_loss_total).max(0.0);
6700 let objective_scale = previous_loss_total.abs().max(new_loss_total.abs()) + 1.0;
6701 let relative_decrease = round_improvement / objective_scale;
6702 let captured_fraction = if total_improvement > 0.0 {
6703 round_improvement / total_improvement
6704 } else {
6705 0.0
6706 };
6707 let stalled = new_loss_total.is_finite()
6708 && relative_decrease.is_finite()
6709 && (relative_decrease < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_REL_TOL
6710 || captured_fraction < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_FRACTION);
6711 previous_loss_total = new_loss_total;
6712 if stalled && refine_rounds >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6713 let stationary_sys = self
6714 .assemble_arrow_schur(target, rho_fixed, registry)
6715 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6716 if let Ok((_dt, _db, stationary_cache)) =
6717 solve_arrow_newton_step_with_options(&stationary_sys, 0.0, 0.0, options)
6718 {
6719 return Ok(stationary_cache);
6720 }
6721 // Stagnated AND the undamped factor still fails: this is the
6722 // numerical fixed point of the inner solve under rank-deficient
6723 // or ill-conditioned geometry (e.g. multi-atom euclidean with
6724 // near-zero initial latent coords, #1094). The iterate cannot
6725 // be improved further at this ρ. Treat it as "inner solve did
6726 // not converge" — the same signal `is_recoverable_value_probe_refusal`
6727 // already handles, causing the outer BFGS to return INFINITY for
6728 // this ρ probe and try a different one. Without this early
6729 // return the stagnation handler fell through and the loop kept
6730 // burning the extended `progress_refine_iter` budget indefinitely.
6731 consecutive_stall_factor_fail += 1;
6732 if consecutive_stall_factor_fail >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6733 return Err(format!(
6734 "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6735 objective stalled for {consecutive_stall_factor_fail} consecutive refine \
6736 rounds (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) and the undamped \
6737 evidence factorization failed at each stall point — the iterate is at the \
6738 numerical fixed point under rank-deficient geometry (#{consecutive_stall_factor_fail} \
6739 stall-factor-fail rounds; refusing to rank an off-optimum Laplace criterion)"
6740 ));
6741 }
6742 } else {
6743 consecutive_stall_factor_fail = 0;
6744 }
6745 }
6746 }
6747
6748 pub(crate) fn refine_iteration_limit(
6749 total_inner_iter: usize,
6750 base_refine_iter: usize,
6751 progress_refine_iter: usize,
6752 previous_grad_norm: Option<f64>,
6753 grad_norm: f64,
6754 saw_refine_progress: bool,
6755 ) -> usize {
6756 // Flat affine-gauge valleys can keep crawling productively after the
6757 // historical base budget. Extend only when the measured KKT residual has
6758 // shown a real finite round-to-round drop; true stalls end at the base
6759 // work budget (#968/#1029). Value-order probes pass the base budget as
6760 // their progress budget, so this branch cannot make probes expensive.
6761 if total_inner_iter < base_refine_iter {
6762 return base_refine_iter;
6763 }
6764 let making_progress =
6765 saw_refine_progress || Self::refine_round_made_progress(previous_grad_norm, grad_norm);
6766 if making_progress && grad_norm.is_finite() {
6767 progress_refine_iter
6768 } else {
6769 base_refine_iter
6770 }
6771 }
6772
6773 pub(crate) fn refine_round_made_progress(
6774 previous_grad_norm: Option<f64>,
6775 grad_norm: f64,
6776 ) -> bool {
6777 previous_grad_norm
6778 .is_some_and(|prev| prev.is_finite() && grad_norm.is_finite() && grad_norm < prev)
6779 }
6780
6781 pub(crate) fn outer_gradient_arrow_solver<'a>(
6782 &'a self,
6783 cache: &'a ArrowFactorCache,
6784 penalized_gram_scale: &[f64],
6785 ) -> Result<DeflatedArrowSolver<'a>, OuterGradientError> {
6786 let Err(conditioning_err) = Self::outer_gradient_conditioning_error(cache) else {
6787 return Ok(DeflatedArrowSolver::plain(cache));
6788 };
6789 let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
6790 return Err(conditioning_err);
6791 };
6792 if !(max_pivot.is_finite() && max_pivot > 0.0) {
6793 return Err(conditioning_err);
6794 }
6795
6796 // The conditioning gate has already flagged a near-singular joint Hessian
6797 // (`conditioning_err`). Below we attempt to attribute that flatness to the
6798 // closed-form gauge orbit (chart step gauges) plus the penalty-aware
6799 // decoder-null directions and deflate it. When NO such deflatable
6800 // direction can be recovered, the flat subspace is genuinely
6801 // non-identifiable -- a degenerate direction OUTSIDE the gauge orbit -- a
6802 // diagnosis distinct from the raw pivot-ratio conditioning trip. Both
6803 // classes are #1273 FD-eligible, but surfacing the gauge-degenerate case
6804 // as its own [`OuterGradientError::NonIdentifiable`] keeps the diagnostic
6805 // distinction the FD-eligibility contract is built around.
6806 let non_identifiable_err = OuterGradientError::NonIdentifiable {
6807 reason: format!(
6808 "near-singular joint Hessian with no deflatable gauge/decoder-null \
6809 direction (max pivot {max_pivot:.3e})"
6810 ),
6811 };
6812
6813 let full_len = cache.delta_t_len() + cache.k;
6814 let mut raw_gauges = Vec::new();
6815 for gauge in self
6816 .dense_step_gauge_vectors()
6817 .map_err(OuterGradientError::internal)?
6818 {
6819 if gauge.len() != full_len {
6820 continue;
6821 }
6822 let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6823 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6824 continue;
6825 }
6826 raw_gauges.push(gauge);
6827 }
6828 // #1051/#1273: admit the penalty-aware decoder-β null directions as
6829 // additional deflation candidates. A rank-deficient decoder design
6830 // (e.g. a euclidean-1D line in a p=2 ambient: decoder column rank 1 of
6831 // 3) puts a genuine near-null direction of the joint Hessian in the β
6832 // block, OUTSIDE the closed-form chart gauge orbit. #1273: probing the
6833 // RAW unit-β basis `e_j` produced an INCOMPLETE candidate set — the
6834 // true flat direction is the penalised null of `G_k + λ_smooth·S_k`,
6835 // not an axis-aligned coordinate, so the outer gate rejected trial ρ
6836 // with a pivot ratio (5.3e-16 < 1e-12) that the inner gate (which
6837 // already uses `decoder_beta_null_directions(λ_smooth)`) accepts. Use
6838 // the SAME penalty-aware null directions here, evaluated at the smooth
6839 // scale the Schur factor used, so the outer and inner gates agree.
6840 // These full (n·q + beta_dim)-length vectors drop into the same
6841 // Gram-Schmidt + Rayleigh + Faddeev-Popov path below; the Rayleigh
6842 // floor still keeps only genuinely flat (sub-floor) directions, so a
6843 // well-conditioned decoder is unaffected.
6844 for dir in self
6845 .decoder_beta_null_directions(penalized_gram_scale)
6846 .map_err(OuterGradientError::internal)?
6847 {
6848 if dir.len() == full_len {
6849 raw_gauges.push(dir);
6850 }
6851 }
6852 // #1051/#1273: also admit the decoder COLUMN-SPAN null (an unrealised
6853 // ambient output channel of a rank-deficient decoder), which the
6854 // channel-free basis-null above structurally cannot represent. The
6855 // rank-1-decoder-line geometry (e.g. a 1-D euclidean line in p=2
6856 // ambient: decoder column rank 1 of 2) puts the joint Hessian's
6857 // sub-floor pivot entirely in one output channel; without this
6858 // candidate the outer gate had nothing to deflate it with and rejected
6859 // the trial ρ. The Rayleigh floor below still prunes any candidate that
6860 // is not genuinely flat against the cached Hessian.
6861 for dir in self
6862 .decoder_channel_null_directions()
6863 .map_err(OuterGradientError::internal)?
6864 {
6865 if dir.len() == full_len {
6866 raw_gauges.push(dir);
6867 }
6868 }
6869 if raw_gauges.is_empty() {
6870 return Err(non_identifiable_err);
6871 }
6872
6873 let mut gauge_span: Vec<Array1<f64>> = Vec::new();
6874 for mut gauge in raw_gauges {
6875 for basis in &gauge_span {
6876 let coeff = gauge.dot(basis);
6877 for i in 0..gauge.len() {
6878 gauge[i] -= coeff * basis[i];
6879 }
6880 }
6881 let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6882 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6883 continue;
6884 }
6885 let inv_norm = norm_sq.sqrt().recip();
6886 for value in gauge.iter_mut() {
6887 *value *= inv_norm;
6888 }
6889 gauge_span.push(gauge);
6890 }
6891 if gauge_span.is_empty() {
6892 return Err(non_identifiable_err);
6893 }
6894
6895 let span_rank = gauge_span.len();
6896 let mut h_span = Array2::<f64>::zeros((span_rank, span_rank));
6897 for col in 0..span_rank {
6898 let h_gauge = match apply_cached_arrow_hessian(
6899 cache,
6900 gauge_span[col].slice(s![..cache.delta_t_len()]),
6901 gauge_span[col].slice(s![cache.delta_t_len()..]),
6902 ) {
6903 Ok(value) => value,
6904 // #1451: a shape/dimension mismatch or non-finite intermediate
6905 // from the Hessian apply is an internal-invariant defect and MUST
6906 // propagate; only a genuine numeric failure on a finite,
6907 // correctly-shaped input keeps the FD-eligible conditioning class.
6908 Err(err) => {
6909 return Err(OuterGradientError::classify_arrow_solver_error(
6910 &err,
6911 conditioning_err.clone(),
6912 ));
6913 }
6914 };
6915 let h_flat = flatten_arrow_parts(h_gauge.t.view(), h_gauge.beta.view());
6916 for row in 0..span_rank {
6917 h_span[[row, col]] = gauge_span[row].dot(&h_flat);
6918 }
6919 }
6920 for row in 0..span_rank {
6921 for col in 0..row {
6922 let sym = 0.5 * (h_span[[row, col]] + h_span[[col, row]]);
6923 h_span[[row, col]] = sym;
6924 h_span[[col, row]] = sym;
6925 }
6926 }
6927 // #1451: a non-finite entry in the projected gauge Hessian is an
6928 // internal-invariant defect (a NaN/Inf intermediate leaked into the
6929 // span), not a conditioning failure — it MUST propagate rather than be
6930 // masked behind an FD descent. Guard finiteness BEFORE the eigh so only a
6931 // genuine decomposition failure on a finite, correctly-shaped matrix keeps
6932 // the FD-eligible conditioning class.
6933 if !h_span.iter().all(|v| v.is_finite()) {
6934 return Err(OuterGradientError::internal(format!(
6935 "outer_gradient_arrow_solver: non-finite entry in projected gauge \
6936 Hessian (h_span is {span_rank}x{span_rank})"
6937 )));
6938 }
6939 let (evals, evecs) = h_span
6940 .eigh(Side::Lower)
6941 .map_err(|_| conditioning_err.clone())?;
6942 let strict_gauge_floor = SAE_OUTER_GRADIENT_GAUGE_RAYLEIGH_FACTOR * max_pivot;
6943 let mut orthonormal: Vec<Array1<f64>> = Vec::new();
6944 for eig_idx in 0..evals.len() {
6945 let rayleigh = evals[eig_idx];
6946 if !(rayleigh.is_finite() && rayleigh <= strict_gauge_floor) {
6947 continue;
6948 }
6949 let mut direction = Array1::<f64>::zeros(full_len);
6950 for basis_idx in 0..span_rank {
6951 let coeff = evecs[[basis_idx, eig_idx]];
6952 for row in 0..full_len {
6953 direction[row] += coeff * gauge_span[basis_idx][row];
6954 }
6955 }
6956 let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
6957 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6958 continue;
6959 }
6960 let inv_norm = norm_sq.sqrt().recip();
6961 for value in direction.iter_mut() {
6962 *value *= inv_norm;
6963 }
6964 orthonormal.push(direction);
6965 }
6966 if orthonormal.is_empty() {
6967 // #1273/#1440: the conditioning gate has ALREADY certified a
6968 // near-singular joint Hessian (`conditioning_err`), so a genuine flat
6969 // direction exists inside the assembled gauge/decoder-null span even
6970 // when no projected-Hessian eigenvector cleared the strict or the
6971 // `fallback_gauge_floor` Rayleigh band. Rather than declining
6972 // (which historically routed the outer step to a finite-difference
6973 // descent direction — the FD instrument #1440 removes), deflate the
6974 // SMALLEST-Rayleigh eigenvector of the projected gauge Hessian
6975 // UNCONDITIONALLY. That eigenvector is the least-curvature member of
6976 // the validated gauge span (a Faddeev-Popov gauge candidate), so the
6977 // Tikhonov stiffness `max_pivot` in `from_orthonormal_gauges` bounds
6978 // its contribution at the Hessian scale and the components orthogonal
6979 // to it are byte-for-byte the plain analytic inverse solve. This keeps
6980 // the descent direction fully ANALYTIC (a projected/damped gradient),
6981 // never a differenced value path.
6982 let mut best_idx = None;
6983 let mut best_rayleigh = f64::INFINITY;
6984 for eig_idx in 0..evals.len() {
6985 let rayleigh = evals[eig_idx];
6986 if rayleigh.is_finite() && rayleigh < best_rayleigh {
6987 best_idx = Some(eig_idx);
6988 best_rayleigh = rayleigh;
6989 }
6990 }
6991 if let Some(eig_idx) = best_idx {
6992 let mut direction = Array1::<f64>::zeros(full_len);
6993 for basis_idx in 0..span_rank {
6994 let coeff = evecs[[basis_idx, eig_idx]];
6995 for row in 0..full_len {
6996 direction[row] += coeff * gauge_span[basis_idx][row];
6997 }
6998 }
6999 let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
7000 if norm_sq.is_finite() && norm_sq > 1.0e-24 {
7001 let inv_norm = norm_sq.sqrt().recip();
7002 for value in direction.iter_mut() {
7003 *value *= inv_norm;
7004 }
7005 orthonormal.push(direction);
7006 }
7007 }
7008 }
7009 if orthonormal.is_empty() {
7010 return Err(non_identifiable_err);
7011 }
7012
7013 // Quotient-geometry gauge fixing: add stiffness only along the closed-form
7014 // gauge orbit (Faddeev-Popov style). Components orthogonal to that orbit
7015 // are identical to the original inverse solve, while gauge components are
7016 // bounded at the Hessian scale `max_pivot`.
7017 // #1451: a shape/length mismatch or non-finite stiffness/intermediate in
7018 // the deflated-solver assembly is an internal-invariant defect and MUST
7019 // propagate; only a genuine near-singular gauge Woodbury/back-solve keeps
7020 // the FD-eligible conditioning class.
7021 DeflatedArrowSolver::from_orthonormal_gauges(cache, orthonormal, max_pivot)
7022 .map_err(|err| OuterGradientError::classify_arrow_solver_error(&err, conditioning_err))
7023 }
7024
7025 pub(crate) fn outer_gradient_conditioning_error(
7026 cache: &ArrowFactorCache,
7027 ) -> Result<(), OuterGradientError> {
7028 let pivot = arrow_factor_min_pivot(cache);
7029 let Some(min_pivot) = pivot.min_pivot else {
7030 return Err(OuterGradientError::IllConditioned {
7031 reason: "joint Hessian numerically singular (no cached Cholesky pivots)"
7032 .to_string(),
7033 });
7034 };
7035 let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
7036 return Err(OuterGradientError::IllConditioned {
7037 reason: "joint Hessian numerically singular (no cached Cholesky pivot scale)"
7038 .to_string(),
7039 });
7040 };
7041 let ratio = min_pivot / max_pivot;
7042 if min_pivot.is_finite()
7043 && max_pivot.is_finite()
7044 && max_pivot > 0.0
7045 && ratio.is_finite()
7046 && ratio >= SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR
7047 {
7048 return Ok(());
7049 }
7050 Err(OuterGradientError::IllConditioned {
7051 reason: format!(
7052 "joint Hessian numerically singular (min/max pivot ratio {ratio:.3e} < floor {floor:.3e}; min pivot {min_pivot:.3e}, max pivot {max_pivot:.3e})",
7053 floor = SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR,
7054 ),
7055 })
7056 }
7057
7058 /// Smoothing-penalty Occam normalizer `−½ Σ_k r_k·rank(S_k)·log λ_smooth`
7059 /// PLUS the profiled-frame evidence-dimension term `½ Σ_k r_k·(p−r_k)·log
7060 /// λ_smooth` (issue #972).
7061 ///
7062 /// On the full-`B` path every atom's frame rank `r_k == p`, so the first
7063 /// piece reduces to the historical `½ p·(Σ rank S_k)·log λ_smooth` and the
7064 /// Grassmann term is zero — bit-for-bit unchanged. When a frame is active the
7065 /// decoder coordinates `C_k` carry the `⊗ I_{r_k}` Kronecker structure (the
7066 /// smoothing penalty `S_k` now acts on `r_k` channels, not `p`), so the
7067 /// penalty-logdet normalizer uses `r_k·rank(S_k)`; and the `r_k·(p−r_k)`
7068 /// frame degrees of freedom profiled OUT of the border are counted explicitly
7069 /// in the Laplace dimension accounting (evidence honesty) so the criterion
7070 /// cannot buy a free evidence boost by hiding decoder freedom in the frame.
7071 pub(crate) fn reml_occam_term(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
7072 // #1556: λ_smooth is per-atom, so the Occam penalty normalizer and the
7073 // profiled-frame evidence-dimension term are both per-atom sums, each
7074 // atom `k` weighted by its own `log λ_smooth[k]`. With a uniform
7075 // (broadcast) vector this is bit-for-bit the historical global form.
7076 let mut acc = 0.0_f64;
7077 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7078 let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7079 // Penalized decoder dimension: `r_k` coordinate channels carry the
7080 // `S_k` roughness penalty (full-`B` path ⇒ `r_k == p`).
7081 let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7082 // Profiled Grassmann dimensions enter the Laplace evidence dimension
7083 // count with the OPPOSITE sign of the penalty Occam term (they are
7084 // free, unpenalized-by-`S` profiled directions), so `−occam` adds
7085 // `+½ r(p−r) log λ_k` to the criterion `V` — the honesty correction.
7086 let frame_dim = atom.frame_manifold_dimension();
7087 let log_lambda = rho.log_lambda_smooth[atom_idx];
7088 acc += 0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)) * log_lambda;
7089 }
7090 // `V = … − occam`, so the net occam SUBTRACTS the penalty normalizer and
7091 // ADDS the frame-dimension count after the caller's `− occam`.
7092 Ok(acc)
7093 }
7094
7095 /// Per-atom derivative `∂(occam)/∂log λ_smooth[k]` (#1556): atom `k`'s entry
7096 /// is `½·(r_k·rank(S_k) − frame_dim_k)`, matching the per-atom Occam term in
7097 /// [`Self::reml_occam_term`]. Returns one entry per atom in atom order.
7098 pub(crate) fn reml_occam_log_lambda_smooth_derivative(&self) -> Result<Vec<f64>, String> {
7099 let mut out = Vec::with_capacity(self.atoms.len());
7100 for atom in &self.atoms {
7101 let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7102 let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7103 let frame_dim = atom.frame_manifold_dimension();
7104 out.push(0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)));
7105 }
7106 Ok(out)
7107 }
7108
7109 pub fn reml_criterion_streaming_exact(
7110 &mut self,
7111 target: ArrayView2<'_, f64>,
7112 rho: &SaeManifoldRho,
7113 registry: Option<&AnalyticPenaltyRegistry>,
7114 inner_max_iter: usize,
7115 learning_rate: f64,
7116 ridge_ext_coord: f64,
7117 ridge_beta: f64,
7118 ) -> Result<(f64, SaeManifoldLoss), String> {
7119 let mut rho_fixed = rho.clone();
7120 let mut loss = self.run_joint_fit_arrow_schur(
7121 target,
7122 &mut rho_fixed,
7123 registry,
7124 inner_max_iter,
7125 learning_rate,
7126 ridge_ext_coord,
7127 ridge_beta,
7128 )?;
7129 // Drive the inner (t, β) state to the SAME KKT/step-converged optimum the
7130 // dense `reml_criterion_with_cache` reaches before factoring. At that
7131 // optimum the per-row `H_tt^(i)` blocks are PD, so the undamped
7132 // (`ridge_t = 0`) streaming factorization in `streaming_exact_arrow_log_det`
7133 // succeeds — without this, a state stopped after only `inner_max_iter`
7134 // steps can leave a rank-deficient / indefinite row block (`p_out = 1` →
7135 // rank-1 `JᵀJ`, softmax negative-logit curvature) that surfaces
7136 // `PerRowFactorFailed` at base ridge 0. Sharing the driver also keeps the
7137 // streaming and dense log-determinants bit-identical (#847).
7138 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7139 // The dense factor cache from convergence is surplus here — the streaming
7140 // path recomputes the (bit-identical) log-determinant chunk-by-chunk in
7141 // `streaming_exact_arrow_log_det` to bound peak memory — so it is dropped.
7142 let converged_cache = self.converge_inner_for_undamped_logdet(
7143 target,
7144 rho,
7145 &mut rho_fixed,
7146 registry,
7147 inner_max_iter,
7148 learning_rate,
7149 ridge_ext_coord,
7150 ridge_beta,
7151 &mut loss,
7152 &options,
7153 true,
7154 )?;
7155 drop(converged_cache);
7156 let log_det = self.streaming_exact_arrow_log_det(target, rho, registry)?;
7157 let occam = self.reml_occam_term(rho)?;
7158 // Extra analytic-penalty energy (#671/#737), matching the full-batch
7159 // `reml_criterion_with_cache` path so streaming and dense criteria rank
7160 // the identical penalized objective.
7161 let extra_penalty_energy = match registry {
7162 Some(reg) => self
7163 .reml_extra_penalty_value_total(reg)
7164 .map_err(|err| format!("SaeManifoldTerm::reml_criterion_streaming_exact: {err}"))?,
7165 None => 0.0,
7166 };
7167 Ok((
7168 loss.total() + extra_penalty_energy + 0.5 * log_det - occam,
7169 loss,
7170 ))
7171 }
7172
7173 pub fn streaming_exact_arrow_log_det(
7174 &mut self,
7175 target: ArrayView2<'_, f64>,
7176 rho: &SaeManifoldRho,
7177 registry: Option<&AnalyticPenaltyRegistry>,
7178 ) -> Result<f64, String> {
7179 if target.dim() != (self.n_obs(), self.output_dim()) {
7180 return Err(format!(
7181 "SaeManifoldTerm::streaming_exact_arrow_log_det: target must be ({}, {}); got {:?}",
7182 self.n_obs(),
7183 self.output_dim(),
7184 target.dim()
7185 ));
7186 }
7187 let plan = self.streaming_plan().admitted_or_error(
7188 self.n_obs(),
7189 self.output_dim(),
7190 self.k_atoms(),
7191 )?;
7192 if plan.estimated_dense_schur_bytes > plan.in_core_budget_bytes {
7193 // #988 memory-matrix-free evidence route. The dense k×k reduced Schur
7194 // (≈8 GB at the K=32k manifold border) does NOT fit the in-core
7195 // budget, so estimate log|S| via Stochastic Lanczos Quadrature on the
7196 // matrix-free `schur_matvec` apply (`gam_solve::arrow_schur::
7197 // matrix_free_arrow_evidence_log_det`) instead of assembling +
7198 // Cholesky-factoring the dense Schur. Peak memory is the per-row block
7199 // storage the inner PCG already holds, not the extra O(k²) dense S.
7200 //
7201 // Valid for the NON-IBP (softmax / JumpReLU) evidence, whose exact
7202 // log-det is `log_det_tt + log_det_schur` with NO cross-row Woodbury
7203 // correction. The IBP cross-row term additionally needs
7204 // `log det(I_R + D Uᵀ H₀'⁻¹ U)`, which has no matrix-free route yet, so
7205 // it keeps refusing (loudly, pointing at the dense resident path).
7206 if ibp_assignment_third_channels(&self.assignment, rho)?.is_some() {
7207 return Err(format!(
7208 "SaeManifoldTerm::streaming_exact_arrow_log_det: predicted dense reduced Schur \
7209 {} bytes exceeds budget {} bytes and the exact cross-row IBP Woodbury evidence \
7210 has no matrix-free log-det route yet; route IBP-active large-K fits through the \
7211 dense resident ArrowFactorCache::arrow_log_det",
7212 plan.estimated_dense_schur_bytes, plan.in_core_budget_bytes
7213 ));
7214 }
7215 let n_total = self.n_obs();
7216 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7217 // Assemble the WHOLE system once (a single "chunk" over all rows) so the
7218 // matrix-free reduced-Schur apply `v ↦ S·v` can iterate every row; the
7219 // per-row block storage is exactly what the inner solve already holds.
7220 let full_logits = self.assignment.logits.slice(s![0..n_total, ..]).to_owned();
7221 let full_coords: Vec<Array2<f64>> = self
7222 .assignment
7223 .coords
7224 .iter()
7225 .map(|coord| coord.as_matrix().slice(s![0..n_total, ..]).to_owned())
7226 .collect();
7227 let mut full_chunk = self.materialize_chunk(full_logits, full_coords)?;
7228 if let Some(w) = self.row_loss_weights.as_deref() {
7229 full_chunk.row_loss_weights = Some(w[0..n_total].to_vec());
7230 }
7231 // Full penalty (`penalty_scale = 1.0`): one chunk carries the whole
7232 // objective, matching the summed per-chunk `(end-start)/n_total` scale.
7233 let sys = full_chunk
7234 .assemble_arrow_schur_scaled(target, rho, registry, 1.0)
7235 .map_err(|err| {
7236 format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}")
7237 })?;
7238 let (log_det_tt, slq) = matrix_free_arrow_evidence_log_det(
7239 &sys,
7240 0.0,
7241 0.0,
7242 &options,
7243 SCHUR_SLQ_LOGDET_PROBES,
7244 SCHUR_SLQ_LOGDET_LANCZOS_STEPS,
7245 SCHUR_SLQ_LOGDET_SEED,
7246 )
7247 .map_err(|err| {
7248 format!(
7249 "SaeManifoldTerm::streaming_exact_arrow_log_det: matrix-free evidence log-det: {err:?}"
7250 )
7251 })?;
7252 if !slq.estimate.is_finite() {
7253 return Err(format!(
7254 "SaeManifoldTerm::streaming_exact_arrow_log_det: matrix-free SLQ reduced-Schur \
7255 log|S| non-finite ({})",
7256 slq.estimate
7257 ));
7258 }
7259 return Ok(log_det_tt + slq.estimate);
7260 }
7261 let n_total = self.n_obs();
7262 let chunk_size = plan.chunk_size.min(n_total.max(1));
7263 // #972 / #977 T1: the reduced β-Schur is over the FACTORED border when
7264 // frames are active (each chunk inherits the frames via
7265 // `materialize_chunk`, so every `chunk_schur` is `border_dim²`), matching
7266 // the dense path's factored log-det. Full-`B` ⇒ `border_dim == beta_dim`.
7267 let border_dim = if self.frames_active() {
7268 self.factored_border_dim()
7269 } else {
7270 self.beta_dim()
7271 };
7272 let mut schur_acc = Array2::<f64>::zeros((border_dim, border_dim));
7273 let mut log_det_tt = 0.0_f64;
7274 // #1038 cross-row IBP Woodbury accumulators. `M = Uᵀ H₀'⁻¹ U` is
7275 // chunk-additive in `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` and `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ`
7276 // (`A = H₀'` block-diagonal, `U` row-supported), closed against the
7277 // GLOBAL reduced Schur `S = schur_acc` after the loop. `None` for every
7278 // non-IBP (softmax / JumpReLU) term, where the streaming log-det is
7279 // exactly the bare `log_det_tt + log_det_schur` as before.
7280 let mut wood_m0: Option<Array2<f64>> = None;
7281 let mut wood_w: Option<Array2<f64>> = None;
7282 let mut wood_d: Option<Array1<f64>> = None;
7283 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7284 let mut start = 0usize;
7285 while start < n_total {
7286 let end = (start + chunk_size).min(n_total);
7287 let penalty_scale = (end - start) as f64 / n_total as f64;
7288 let chunk_logits = self.assignment.logits.slice(s![start..end, ..]).to_owned();
7289 let chunk_coords: Vec<Array2<f64>> = self
7290 .assignment
7291 .coords
7292 .iter()
7293 .map(|coord| coord.as_matrix().slice(s![start..end, ..]).to_owned())
7294 .collect();
7295 let mut chunk = self.materialize_chunk(chunk_logits, chunk_coords)?;
7296 // #1117 — rank deficiency is removed at the basis layer at fit entry
7297 // (`reduce_atoms_to_data_supported_rank`), so each chunk inherits the
7298 // already-reduced full-rank atoms via `materialize_chunk`; there are
7299 // no global deflation projectors to propagate.
7300 // #991: chunk terms inherit the row's design honesty weight slice
7301 // (global mean-1 normalization preserved — NOT re-normalized per
7302 // chunk — so the per-chunk sums reconstruct the global weighted
7303 // objective exactly).
7304 if let Some(w) = self.row_loss_weights.as_deref() {
7305 chunk.row_loss_weights = Some(w[start..end].to_vec());
7306 }
7307 let z_chunk = target.slice(s![start..end, ..]);
7308 let sys = chunk
7309 .assemble_arrow_schur_scaled(z_chunk, rho, registry, penalty_scale)
7310 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7311 let mut streaming = StreamingArrowSchur::from_system(&sys, sys.rows.len().max(1));
7312 let (chunk_log_det_tt, chunk_schur, chunk_wood) = streaming
7313 .reduced_schur_log_det_tt_woodbury(0.0, 0.0, &options)
7314 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7315 log_det_tt += chunk_log_det_tt;
7316 for row in 0..border_dim {
7317 for col in 0..border_dim {
7318 schur_acc[[row, col]] += chunk_schur[[row, col]];
7319 }
7320 }
7321 if chunk_wood.is_some() && chunk_size < n_total {
7322 // The cross-row IBP empirical mass `M_k = Σ_i z_ik` couples ALL
7323 // rows, so the per-row `H₀'` diagonal (`score_derivative_k(M_k)`)
7324 // and the column coefficient `d_k = w·s'_k(M_k)` are only exact
7325 // when every row is assembled together — a SINGLE chunk. Under a
7326 // genuine multi-chunk pass each chunk would see a partial mass and
7327 // the Woodbury (and the bare per-row log-det) would be inexact, so
7328 // refuse loudly and route to the dense resident path rather than
7329 // return a silently-wrong evidence. The streaming log-det only
7330 // runs when the dense reduced Schur fits budget, so the single-
7331 // chunk regime is the common case; this guards the rest.
7332 return Err(
7333 "SaeManifoldTerm::streaming_exact_arrow_log_det: exact cross-row IBP \
7334 Woodbury evidence requires a single-chunk pass (the empirical mass \
7335 M_k = Σ_i z_ik couples all rows); this shape needs >1 chunk. Route \
7336 IBP-active large-n fits through the dense resident \
7337 ArrowFactorCache::arrow_log_det."
7338 .to_string(),
7339 );
7340 }
7341 if let Some(cw) = chunk_wood {
7342 wood_m0 = Some(match wood_m0.take() {
7343 Some(mut acc) => {
7344 acc += &cw.m0;
7345 acc
7346 }
7347 None => cw.m0,
7348 });
7349 wood_w = Some(match wood_w.take() {
7350 Some(mut acc) => {
7351 acc += &cw.w;
7352 acc
7353 }
7354 None => cw.w,
7355 });
7356 // `D = diag(d_k)` is per-atom; identical across chunks for a
7357 // single-chunk evidence pass (the regime the streaming log-det
7358 // runs in — the dense reduced Schur must fit budget here), where
7359 // it equals the global mass-derived `cross_row_d`.
7360 wood_d = Some(cw.d);
7361 }
7362 start = end;
7363 }
7364 let log_det_schur = StreamingArrowSchur::reduced_schur_log_det(&schur_acc, &options)
7365 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7366 let mut total = log_det_tt + log_det_schur;
7367 // #1038/#1225: close the exact cross-row IBP Woodbury correction
7368 // `log det(I_R + D Uᵀ H₀'⁻¹ U)` so the streaming evidence equals the
7369 // dense `arrow_log_det_from_cache` (which adds the SAME term). Without
7370 // it the streaming criterion would silently drop the entire cross-row
7371 // coupling and disagree with the dense path by exactly `log|C|`.
7372 if let (Some(m0), Some(w), Some(d)) = (wood_m0, wood_w, wood_d) {
7373 let correction = streaming_cross_row_woodbury_log_det(&schur_acc, &m0, &w, &d)
7374 .map_err(|err| {
7375 format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}")
7376 })?
7377 .ok_or_else(|| {
7378 "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
7379 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
7380 .to_string()
7381 })?;
7382 total += correction;
7383 }
7384 Ok(total)
7385 }
7386
7387 /// Per-atom decoder-smoothness penalty quadratic form (#1556): entry `k` is
7388 /// the λ-free `<B_k, ½(S_k+S_kᵀ)·B_k> = Σ_oc B_k[:,oc]ᵀ S_k B_k[:,oc]`, the
7389 /// per-atom denominator of atom `k`'s λ_smooth Fellner-Schall update. The sum
7390 /// over atoms is `βᵀ(⊕_k S_k ⊗ I_p)β`, the un-scaled total penalty energy.
7391 /// `S_k` is symmetrised defensively (as the assembler does); the per-atom
7392 /// `½(S+Sᵀ)·B_k` GEMMs ride the multi-GPU batched smoothness GEMM with an
7393 /// exact per-atom CPU fallback.
7394 pub(crate) fn decoder_smoothness_quadratic_form_per_atom(&self) -> Vec<f64> {
7395 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
7396 .atoms
7397 .iter()
7398 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
7399 .collect();
7400 let sb_all = batched_smooth_sb(&sb_inputs, true);
7401 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7402 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
7403 per_atom[atom_idx] = (&atom.decoder_coefficients * sb).sum();
7404 }
7405 per_atom
7406 }
7407
7408 /// Per-atom effective penalized dof of the decoder smoothness penalty
7409 /// (#1556): entry `k` is `tr(S_β⁻¹ · M_k)` with `M_k = (λ_smooth[k]·S_k) ⊗ I`
7410 /// and `S_β⁻¹ = (H⁻¹)_ββ` the Schur-complement inverse, each atom scaled by
7411 /// its OWN `lambda_smooth[atom_idx]`. Built on
7412 /// [`ArrowFactorCache::schur_inverse_apply`]: column `(k,μ,oc)` of `M_k` is
7413 /// `λ_k·S_k[:,μ] ⊗ e_oc` (sparse), so we apply `S_β⁻¹` to that K-vector and
7414 /// read back `result[col]`. The total edf is the sum of the returned vector
7415 /// (a uniform/broadcast λ reproduces the historical global trace).
7416 ///
7417 /// At `K ≥ SMOOTHNESS_DOF_HUTCHINSON_MIN_ATOMS` this delegates to the
7418 /// matrix-free Hutchinson estimator (the exact `K·M·p`-solve trace is
7419 /// infeasible at that scale); below it the exact column solve is used
7420 /// unchanged.
7421 pub(crate) fn decoder_smoothness_effective_dof_per_atom(
7422 &self,
7423 cache: &ArrowFactorCache,
7424 lambda_smooth: &[f64],
7425 ) -> Result<Vec<f64>, ArrowSchurError> {
7426 let p = self.output_dim();
7427 let frames_active = self.frames_active();
7428 let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7429 let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7430 (
7431 self.factored_beta_offsets(),
7432 Box::new(move |k: usize| ranks[k]),
7433 )
7434 } else {
7435 (self.beta_offsets(), Box::new(move |_k: usize| p))
7436 };
7437 let k = cache.k;
7438 if self.atoms.len() >= Self::SMOOTHNESS_DOF_HUTCHINSON_MIN_ATOMS {
7439 // Massive-K: `Σ_k M_k·r_k` exact solves is infeasible — estimate every
7440 // atom's trace matrix-free with one `S_β⁻¹` solve per Hutchinson probe.
7441 return self
7442 .decoder_smoothness_effective_dof_per_atom_hutchinson(
7443 k,
7444 &offsets,
7445 out_dim.as_ref(),
7446 lambda_smooth,
7447 Self::SMOOTHNESS_DOF_HUTCHINSON_PROBES,
7448 Self::SMOOTHNESS_DOF_HUTCHINSON_SEED,
7449 |rhs| {
7450 cache
7451 .schur_inverse_apply(rhs)
7452 .map_err(|e| format!("schur_inverse_apply: {e:?}"))
7453 },
7454 )
7455 .map_err(|reason| ArrowSchurError::SchurFactorFailed { reason });
7456 }
7457 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7458 let mut m_col = Array1::<f64>::zeros(k);
7459 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7460 let s = &atom.smooth_penalty;
7461 let m = atom.basis_size();
7462 let off = offsets[atom_idx];
7463 let r = out_dim(atom_idx);
7464 let lambda = lambda_smooth[atom_idx];
7465 let mut trace = 0.0_f64;
7466 for mu in 0..m {
7467 for oc in 0..r {
7468 let col = off + mu * r + oc;
7469 m_col.fill(0.0);
7470 for nu in 0..m {
7471 let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7472 m_col[off + nu * r + oc] = lambda * s_nu_mu;
7473 }
7474 let z = cache.schur_inverse_apply(m_col.view())?;
7475 trace += z[col];
7476 }
7477 }
7478 per_atom[atom_idx] = trace;
7479 }
7480 Ok(per_atom)
7481 }
7482
7483 /// Per-atom effective penalized dof via the deflated solver (#1556): entry
7484 /// `k` is `tr((H⁻¹)_ββ · M_k)` for `M_k = (λ_smooth[k]·S_k) ⊗ I`, each atom
7485 /// scaled by its OWN `lambda_smooth[atom_idx]`. The total is the sum.
7486 pub(crate) fn decoder_smoothness_effective_dof_with_solver_per_atom(
7487 &self,
7488 cache: &ArrowFactorCache,
7489 solver: &DeflatedArrowSolver<'_>,
7490 lambda_smooth: &[f64],
7491 ) -> Result<Vec<f64>, String> {
7492 let p = self.output_dim();
7493 // #972 / #977 T1: the cache's β block is the FACTORED border when frames
7494 // are active (`cache.k == factored_border_dim`), so the smoothness edf
7495 // trace `tr((H⁻¹)_ββ · M)` is taken over the same factored layout, with
7496 // `M = ⊕_k (λ_k S_k) ⊗ I_{r_k}` at the factored offsets (the `U_kᵀU_k = I`
7497 // collapse means the per-coordinate-channel penalty is `λ_k S_k`, exactly
7498 // as in the full-`B` `⊗ I_p` case but with `r_k` channels). On the
7499 // full-`B` path `frames_active` is false: `out_dim_k = p`, the offsets
7500 // are `beta_offsets`, and this is bit-for-bit the historical trace.
7501 let frames_active = self.frames_active();
7502 let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7503 let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7504 (
7505 self.factored_beta_offsets(),
7506 Box::new(move |k: usize| ranks[k]),
7507 )
7508 } else {
7509 (self.beta_offsets(), Box::new(move |_k: usize| p))
7510 };
7511 let k = cache.k;
7512 // The t-RHS is identically zero for every β-only smoothness solve; build
7513 // it once instead of re-zeroing a delta_t_len()-sized buffer per column.
7514 let zero_t = Array1::<f64>::zeros(cache.delta_t_len());
7515 if self.atoms.len() >= Self::SMOOTHNESS_DOF_HUTCHINSON_MIN_ATOMS {
7516 // Massive-K matrix-free path: one deflated `(H⁻¹)_ββ` solve per
7517 // Hutchinson probe estimates ALL per-atom traces, replacing the
7518 // `Σ_k M_k·r_k` deflated solves that form the `O(K³·M·p)` wall.
7519 return self.decoder_smoothness_effective_dof_per_atom_hutchinson(
7520 k,
7521 &offsets,
7522 out_dim.as_ref(),
7523 lambda_smooth,
7524 Self::SMOOTHNESS_DOF_HUTCHINSON_PROBES,
7525 Self::SMOOTHNESS_DOF_HUTCHINSON_SEED,
7526 |rhs| Ok(solver.solve(zero_t.view(), rhs)?.beta),
7527 );
7528 }
7529 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7530 let mut m_col = Array1::<f64>::zeros(k);
7531 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7532 let s = &atom.smooth_penalty;
7533 let m = atom.basis_size();
7534 let off = offsets[atom_idx];
7535 let r = out_dim(atom_idx);
7536 let lambda = lambda_smooth[atom_idx];
7537 let mut trace = 0.0_f64;
7538 for mu in 0..m {
7539 for oc in 0..r {
7540 let col = off + mu * r + oc;
7541 // M[:,col] = λ_k · S_k[:,mu] ⊗ e_oc (nonzero at off+ν·r+oc).
7542 m_col.fill(0.0);
7543 for nu in 0..m {
7544 let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7545 m_col[off + nu * r + oc] = lambda * s_nu_mu;
7546 }
7547 let z = solver.solve(zero_t.view(), m_col.view())?.beta;
7548 trace += z[col];
7549 }
7550 }
7551 per_atom[atom_idx] = trace;
7552 }
7553 Ok(per_atom)
7554 }
7555
7556 pub(crate) fn assignment_log_strength_hessian_trace(
7557 &self,
7558 rho: &SaeManifoldRho,
7559 cache: &ArrowFactorCache,
7560 solver: &DeflatedArrowSolver<'_>,
7561 ) -> Result<f64, String> {
7562 let k_atoms = self.k_atoms();
7563 // #1038 softmax: `H` carries the DENSE entropy block, and since the
7564 // entropy curvature scales linearly with `λ_sparse = exp(ρ)`,
7565 // `∂H/∂ρ = H_entropy` (the full dense per-row block, not just its
7566 // diagonal). The trace `½ tr(H⁻¹ ∂H/∂ρ)` must therefore contract the
7567 // dense `∂H/∂ρ` against the per-row selected-inverse BLOCK, mirroring the
7568 // dense `log|H|` and θ-adjoint — a diagonal-only contraction would
7569 // desync the ρ-gradient from the criterion. The assembled majorizer
7570 // `D = diag(Σ_j|H_kj|)` is itself DIAGONAL (#1419), so the contraction
7571 // reduces to `½ Σ_slot (H⁻¹)_{slot,slot}·D_atom`. On the dense `None`
7572 // layout the logit slot equals the atom position; on the compact
7573 // softmax top-`k` layout (#1408/#1409) the slots are the row's active
7574 // atoms — the SAME `D_atom` (full-`K` abs-row-sum) the assembly wrote.
7575 if let AssignmentMode::Softmax {
7576 temperature,
7577 sparsity,
7578 } = self.assignment.mode
7579 {
7580 if k_atoms <= 1 {
7581 return Ok(0.0);
7582 }
7583 let inv_tau = 1.0 / temperature;
7584 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
7585 let penalty = gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
7586 k_atoms,
7587 temperature,
7588 );
7589 // Softmax uses the reduced K−1 free-logit chart on the dense layout
7590 // (last reference logit fixed); the compact layout carries one slot
7591 // per active atom. The diagonal selected inverse gives each slot's
7592 // (H⁻¹)_{slot,slot}.
7593 let assignment_dim = self.assignment.assignment_coord_dim();
7594 // Kept-subspace inverse diagonal: the deflated inverse assigns
7595 // `1/λ̃ = 1` to each per-row UNIT-stiffness direction `vᵢ`, so a raw
7596 // diagonal `D` contraction would spuriously add `½ Σ_i vᵢᵀ D vᵢ` (a
7597 // ρ-independent direction must add 0). `latent_inverse_diagonal_kept`
7598 // removes that per-row deflated diagonal centrally.
7599 let inv_diag = solver
7600 .latent_inverse_diagonal_kept()
7601 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7602 let mut trace = 0.0_f64;
7603 for row in 0..self.n_obs() {
7604 let row_base = cache.row_offsets[row];
7605 // ∂(scale·D)/∂ρ = scale·D (linear in λ_sparse = eᵖ) — the SAME
7606 // operator the assembly and θ-adjoint differentiate.
7607 match self.last_row_layout {
7608 Some(ref layout) => {
7609 // #1410: the compact adjoint reads `D_kk` only for this
7610 // row's `≤ top_k` active atoms, so compute those entries
7611 // directly from the softmax row `a` via the active-only
7612 // Gershgorin helper — no full-`K` `row_logits` copy and no
7613 // full-`K` `d` vector. `a` itself is the irreducible `O(K)`
7614 // softmax normalisation, computed once per row and shared
7615 // across the row's active slots.
7616 let a = crate::assignment::softmax_row(
7617 self.assignment.logits.row(row),
7618 temperature,
7619 );
7620 let a = a.as_slice().expect("softmax row must be contiguous");
7621 let m = softmax_majorizer_log_mean(a);
7622 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7623 let d_atom =
7624 active_softmax_gershgorin_majorizer_entry(a, atom, m, scale);
7625 trace += inv_diag[row_base + pos] * d_atom;
7626 }
7627 }
7628 None => {
7629 // Dense layout genuinely contracts every free logit slot's
7630 // `D_kk`, so the full-`K` `d` is intrinsic here; keep the
7631 // single-source dense majorizer call.
7632 let row_logits: Vec<f64> = (0..k_atoms)
7633 .map(|k| self.assignment.logits[[row, k]])
7634 .collect();
7635 let d = penalty.psd_majorizer_abs_row_sums(&row_logits, scale);
7636 let q = cache.row_dims[row];
7637 let logit_dim = assignment_dim.min(q);
7638 for atom in 0..logit_dim {
7639 trace += inv_diag[row_base + atom] * d[atom];
7640 }
7641 }
7642 }
7643 }
7644 return Ok(0.5 * trace);
7645 }
7646 let hdiag = assignment_prior_log_strength_hdiag(&self.assignment, rho)?;
7647 if hdiag.is_empty() {
7648 return Ok(0.0);
7649 }
7650 // RAW selected-inverse diagonal: the per-row diagonal contraction uses the
7651 // DEFLATED inverse; the full kept-subspace + β-Schur/rotation deflation
7652 // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per row afterwards
7653 // (`deflation_block_correction`), exactly as the data trace does. The
7654 // cross-row off-diagonal pass below contracts only DISTINCT rows `i ≠ j`,
7655 // off any single-row `vᵢ`'s support, so it needs no deflation correction.
7656 let inv_diag = solver
7657 .latent_inverse_diagonal()
7658 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7659 let assignment_dim = self.assignment.assignment_coord_dim();
7660 let total_t = cache.delta_t_len();
7661 // #932 FRONT C: row-local Takahashi selected inverse on the plain arrow
7662 // for the per-row deflation correction below (the diagonal trace already
7663 // uses the cheap `latent_inverse_diagonal`); gauge / cross-row Woodbury
7664 // fall back to the per-row full-system `solve` loop.
7665 let fast_selected = solver.plain_selected_inverse_available();
7666 let selected_beta_inv = if fast_selected && cache.k > 0 {
7667 solver
7668 .beta_inv()
7669 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?
7670 } else {
7671 Array2::<f64>::zeros((0, 0))
7672 };
7673 // #1416 cross-row IBP source: the per-row block that the deflation
7674 // factorizes is the NO-SELF base `H₀'` — the rank-one self curvature
7675 // `d_k·J_ik²` is DOWNDATED from each logit diagonal and re-applied through
7676 // the Woodbury carrier. The full-`H` diagonal contraction below still uses
7677 // the full `hdiag` (which carries that self term), but the per-row
7678 // DEFLATION correction must use `(∂H₀'/∂ρ)_tt`, i.e. `hdiag` MINUS the
7679 // downdated self term — otherwise the Daleckii–Krein correction
7680 // mis-attributes the (un-deflated) Woodbury self curvature's derivative to
7681 // the deflated subspace. For non-IBP modes there is no Woodbury source and
7682 // the self term is `0` (the deflated block IS the full block).
7683 // #1416 (compact-layout completion): the IBP cross-row Woodbury source is
7684 // installed for BOTH the dense and the compact (#1420 top-`k`) layouts (see
7685 // `set_ibp_cross_row_source`, which emits `(g_base + pos, atom, z'_ik)` for
7686 // the active set under a compact layout), so the deflated base `H₀'` is the
7687 // no-self block in BOTH layouts. The self-curvature downdate below must
7688 // therefore run regardless of layout — gating it to the dense path (the
7689 // pre-fix bug) left the compact deflation correction differentiating the
7690 // un-downdated full block. For non-IBP modes `ibp_assignment_third_channels`
7691 // returns `None`, there is no Woodbury source, and `self_curv` is
7692 // identically 0 (the deflated block IS the full block).
7693 let cross_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
7694 let learnable_alpha = matches!(
7695 self.assignment.mode,
7696 AssignmentMode::IBPMap {
7697 learnable_alpha: true,
7698 ..
7699 }
7700 );
7701 let self_curv = |row: usize, atom: usize| -> f64 {
7702 let Some(ch) = cross_channels.as_ref() else {
7703 return 0.0;
7704 };
7705 let d_k = if learnable_alpha {
7706 ch.cross_row_d_logalpha[atom]
7707 } else {
7708 ch.cross_row_d[atom]
7709 };
7710 let j = ch.z_jac[row * k_atoms + atom];
7711 d_k * j * j
7712 };
7713 let mut trace = 0.0_f64;
7714 // Hoisted RHS scratch for the gauge/Woodbury per-row solve fallback:
7715 // single-entry set/clear instead of a per-column total_t-sized zeroing.
7716 let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
7717 let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
7718 for row in 0..self.n_obs() {
7719 let row_base = cache.row_offsets[row];
7720 let assignment_base = row * k_atoms;
7721 let q = cache.row_dims[row];
7722 // Per-row diagonal `(∂H₀'/∂ρ)_tt` for the deflation correction: the
7723 // assignment prior curves only the logit/assignment slots (coordinate
7724 // slots are 0 — ARD handles those), MINUS the downdated cross-row self
7725 // curvature. The full-`H` trace contraction keeps the full `hdiag`.
7726 let mut d_diag = Array1::<f64>::zeros(q);
7727 match self.last_row_layout {
7728 Some(ref layout) => {
7729 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7730 let d_slot = hdiag[assignment_base + atom];
7731 trace += inv_diag[row_base + pos] * d_slot;
7732 if pos < q {
7733 d_diag[pos] = d_slot - self_curv(row, atom);
7734 }
7735 }
7736 }
7737 None => {
7738 for free_idx in 0..assignment_dim {
7739 let d_slot = hdiag[assignment_base + free_idx];
7740 trace += inv_diag[row_base + free_idx] * d_slot;
7741 if free_idx < q {
7742 d_diag[free_idx] = d_slot - self_curv(row, free_idx);
7743 }
7744 }
7745 }
7746 }
7747 let dirs = cache
7748 .deflated_row_directions
7749 .get(row)
7750 .map(Vec::as_slice)
7751 .unwrap_or(&[]);
7752 if !dirs.is_empty() {
7753 let inv_vv = if fast_selected {
7754 let (inv_vv, _inv_vbeta) = solver
7755 .selected_inverse_row_blocks(row, &selected_beta_inv)
7756 .map_err(|err| {
7757 format!("assignment_log_strength_hessian_trace: selected inverse: {err}")
7758 })?;
7759 inv_vv
7760 } else {
7761 let mut inv_vv = Array2::<f64>::zeros((q, q));
7762 for col in 0..q {
7763 rhs_t_scratch[row_base + col] = 1.0;
7764 let solved = solver
7765 .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
7766 .map_err(|err| {
7767 format!(
7768 "assignment_log_strength_hessian_trace: selected inverse: {err}"
7769 )
7770 })?;
7771 rhs_t_scratch[row_base + col] = 0.0;
7772 for r in 0..q {
7773 inv_vv[[r, col]] = solved.t[row_base + r];
7774 }
7775 }
7776 inv_vv
7777 };
7778 let mut d_mat = Array2::<f64>::zeros((q, q));
7779 for s in 0..q {
7780 d_mat[[s, s]] = d_diag[s];
7781 }
7782 let spectrum = cache
7783 .deflation_row_spectra
7784 .get(row)
7785 .and_then(Option::as_ref);
7786 trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
7787 }
7788 }
7789 // #1416: the IBP prior Hessian is `H_p = d·J Jᵀ + diag(s, c)`, where the
7790 // rank-one `d·J Jᵀ` couples EVERY row pair `(i, j)` in a column `k`
7791 // through the shared empirical mass `M_k`. The assembled `H` carries the
7792 // full `H_full = H₀' + U D Uᵀ` (Woodbury, `set_ibp_cross_row_source`), and
7793 // for fixed alpha the entire IBP prior scales with `λ = eᵖ`, so
7794 // `∂H_p/∂ρ = H_p`. The diagonal loop above already captures the `i = j`
7795 // self terms (the `d·J_ik²` summand lives in `hdiag`); this pass adds the
7796 // omitted off-diagonal `½·d_k·Σ_{i≠j}(H⁻¹)_{ik,jk}·J_ik·J_jk`. Only IBP
7797 // has the cross-row rank-one source; for other diagonal modes
7798 // `ibp_assignment_third_channels` returns `None` and the trace stays the
7799 // pure diagonal contraction.
7800 //
7801 // #1416 (compact completion): this pass is LAYOUT-AGNOSTIC. Under the dense
7802 // layout atom `k`'s logit slot is local position `k`
7803 // (`row_offsets[i] + k`); under the compact (#1420 top-`k`) layout only the
7804 // row's active atoms carry coordinates and atom `k` lives at local position
7805 // `pos` of `active_atoms[row]` (`row_offsets[i] + pos`). The Woodbury source
7806 // and the θ-adjoint already use this active-slot mapping, so gating the
7807 // cross-row pass to the dense layout (the pre-fix bug) dropped the
7808 // off-diagonal term from `∂log|H|/∂ρ` whenever the budget/`top_k` engaged
7809 // the compact layout. We build per-column active sites `(row, t_index)` once
7810 // — exactly the θ-adjoint `col_sites` construction — then contract the
7811 // off-diagonal `i ≠ j` remainder with one solve per active site.
7812 if let Some(channels) = cross_channels.as_ref() {
7813 let n = self.n_obs();
7814 let total_t = cache.delta_t_len();
7815 // This trace is ½ ∂log|H|/∂ρ. For FIXED-α IBP the whole prior
7816 // scales with λ=eᵖ so ∂H_p/∂ρ = H_p and the rank-one coefficient
7817 // is the VALUE `cross_row_d[k] = w·s'_k`. For LEARNABLE-α this trace
7818 // is ½ ∂log|H|/∂logα, and the rank-one block's logα-derivative is
7819 // `∂d_k/∂logα = w·∂s'_k/∂logα` (`cross_row_d_logalpha[k]`) — the same
7820 // α-derivative the DIAGONAL channel (`hessian_diag_log_alpha_derivative`)
7821 // already uses. Using the value `s'_k` here (the pre-fix bug) made the
7822 // off-diagonal inconsistent with the diagonal and the α-gradient wrong.
7823 // (`learnable_alpha` is the same flag the self-curvature downdate uses.)
7824 // Per-column active sites `(row, global t-index)`. Layout-agnostic.
7825 let mut col_sites: Vec<Vec<(usize, usize)>> = vec![Vec::new(); k_atoms];
7826 match self.last_row_layout {
7827 Some(ref layout) => {
7828 for row in 0..n {
7829 let base = cache.row_offsets[row];
7830 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7831 col_sites[atom].push((row, base + pos));
7832 }
7833 }
7834 }
7835 None => {
7836 for row in 0..n {
7837 let base = cache.row_offsets[row];
7838 for k in 0..k_atoms {
7839 col_sites[k].push((row, base + k));
7840 }
7841 }
7842 }
7843 }
7844 let mut cross = 0.0_f64;
7845 // Hoisted RHS scratch: each active site sets exactly one t-slot, so
7846 // set-then-clear that single entry rather than allocating and zeroing
7847 // a total_t-sized vector per (column, site).
7848 let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
7849 let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
7850 for k in 0..k_atoms {
7851 let d_k = if learnable_alpha {
7852 channels.cross_row_d_logalpha[k]
7853 } else {
7854 channels.cross_row_d[k]
7855 };
7856 if d_k == 0.0 || col_sites[k].len() < 2 {
7857 continue;
7858 }
7859 for &(i, t_i) in &col_sites[k] {
7860 let j_ik = channels.z_jac[i * k_atoms + k];
7861 if j_ik == 0.0 {
7862 continue;
7863 }
7864 // (H⁻¹) column at row `i`'s active logit-`k` slot.
7865 rhs_t_scratch[t_i] = 1.0;
7866 let solved = solver
7867 .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
7868 .map_err(|err| {
7869 format!("assignment_log_strength_hessian_trace: {err}")
7870 })?;
7871 rhs_t_scratch[t_i] = 0.0;
7872 for &(j, t_j) in &col_sites[k] {
7873 if j == i {
7874 continue;
7875 }
7876 let j_jk = channels.z_jac[j * k_atoms + k];
7877 if j_jk == 0.0 {
7878 continue;
7879 }
7880 cross += d_k * solved.t[t_j] * j_ik * j_jk;
7881 }
7882 }
7883 }
7884 trace += cross;
7885 }
7886 Ok(0.5 * trace)
7887 }
7888
7889 pub(crate) fn learnable_ibp_forward_alpha_data_derivative(
7890 &self,
7891 rho: &SaeManifoldRho,
7892 target: ArrayView2<'_, f64>,
7893 ) -> Result<f64, String> {
7894 let AssignmentMode::IBPMap {
7895 temperature: _,
7896 learnable_alpha: true,
7897 ..
7898 } = self.assignment.mode
7899 else {
7900 return Ok(0.0);
7901 };
7902 let alpha = self
7903 .assignment
7904 .resolved_ibp_alpha(rho)
7905 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
7906 let k_atoms = self.k_atoms();
7907 let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
7908 let mut dprior = Array1::<f64>::zeros(k_atoms);
7909 for k in 0..k_atoms {
7910 // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
7911 // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
7912 // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
7913 dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
7914 }
7915 let n = self.n_obs();
7916 let p = self.output_dim();
7917 let row_loss_w = self.row_loss_weights.as_deref();
7918 let whitens = self
7919 .row_metric
7920 .as_ref()
7921 .is_some_and(|metric| metric.whitens_likelihood());
7922 let mut decoded = vec![0.0_f64; p];
7923 let mut fitted = Array1::<f64>::zeros(p);
7924 let mut f_rho = Array1::<f64>::zeros(p);
7925 let mut residual = Array1::<f64>::zeros(p);
7926 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
7927 let mut assignments = vec![0.0_f64; k_atoms];
7928 let mut total = 0.0_f64;
7929 for row in 0..n {
7930 self.assignment
7931 .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
7932 fitted.fill(0.0);
7933 f_rho.fill(0.0);
7934 for k in 0..k_atoms {
7935 self.atoms[k].fill_decoded_row(row, &mut decoded);
7936 // Ungated (#1026 background-tier) atoms have a force-fixed unit
7937 // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
7938 // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
7939 // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
7940 // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
7941 // but `a_k` still varies with α through `π_k`, so it must NOT be
7942 // skipped.)
7943 let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
7944 0.0
7945 } else {
7946 (assignments[k] / prior[k]) * dprior[k]
7947 };
7948 for out_col in 0..p {
7949 fitted[out_col] += assignments[k] * decoded[out_col];
7950 f_rho[out_col] += da_rho * decoded[out_col];
7951 }
7952 }
7953 for out_col in 0..p {
7954 residual[out_col] = fitted[out_col] - target[[row, out_col]];
7955 }
7956 let residual_metric = match self.row_metric.as_ref() {
7957 Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
7958 _ => residual.to_vec(),
7959 };
7960 let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
7961 let mut row_dot = 0.0_f64;
7962 for out_col in 0..p {
7963 row_dot += residual_metric[out_col] * f_rho[out_col];
7964 }
7965 total += row_weight * row_dot;
7966 }
7967 Ok(total)
7968 }
7969
7970 /// Per-row spectral-deflation correction `tr((H⁻¹)_tt · (D − DΦ[D]))` for one
7971 /// evidence ρ-component, to be SUBTRACTED from the raw-derivative trace
7972 /// `tr((H⁻¹)_tt · D)` the trace otherwise accumulates.
7973 ///
7974 /// The criterion VALUE re-deflates each per-row `H_tt` at every ρ, so the
7975 /// correct evidence gradient contracts `(H⁻¹)_tt` against the deflation-map
7976 /// derivative `DΦ[D]`, not the raw `D = (∂H_raw/∂ρ)_tt`. By Daleckii–Krein,
7977 /// in the row's RAW eigenbasis `U`,
7978 /// `DΦ[D] = U (F ∘ (Uᵀ D U)) Uᵀ`, `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)`
7979 /// (raw `λ` in the denominator, conditioned `λ̃` in the numerator; the
7980 /// diagonal / degenerate entry is `f'(λₘ) = 1` for an unclamped kept
7981 /// direction and `0` otherwise). Hence `D − DΦ[D] = U ((1−F) ∘ (Uᵀ D U)) Uᵀ`,
7982 /// whose kept×kept block is `0`, deflated×deflated block is the full `M`, and
7983 /// kept(m)×deflated(i) block carries the ROTATION coefficient
7984 /// `(1−λᵢ)/(λₘ−λᵢ)`. Contracting against the FULL deflated selected-inverse
7985 /// t-block `inv_vv` (which carries the β-Schur back-substitution) captures
7986 /// both the within-row kept-subspace term and the deferred β-Schur/rotation
7987 /// coupling in one pass, matching the re-deflating fixed-state FD oracle.
7988 ///
7989 /// `spectrum = Some` (spectral deflation): exact Daleckii–Krein. `None` with a
7990 /// non-empty `dirs` (gauge-only deflation, ρ-independent structural null):
7991 /// fall back to the within-row kept-subspace term `Σᵢ vᵢᵀ D vᵢ`.
7992 /// `inv_vv` is assumed symmetric (selected inverse of a symmetric PD system).
7993 // #1610 — `pub(crate)` so the ARD/latent-block helpers moved into
7994 // `construction_ard.rs` (pure code move to stay under the 10k-line ban gate)
7995 // can still call this from the sibling module.
7996 pub(crate) fn deflation_block_correction(
7997 inv_vv: &Array2<f64>,
7998 d_mat: &Array2<f64>,
7999 dirs: &[Array1<f64>],
8000 spectrum: Option<&RowDeflationSpectrum>,
8001 ) -> f64 {
8002 let q = inv_vv.nrows();
8003 let Some(spec) = spectrum else {
8004 // Gauge-only deflation: ρ-independent structural null → within-row term.
8005 let mut acc = 0.0_f64;
8006 for v in dirs {
8007 for a in 0..q {
8008 let va = if a < v.len() { v[a] } else { 0.0 };
8009 if va == 0.0 {
8010 continue;
8011 }
8012 for b in 0..q {
8013 let vb = if b < v.len() { v[b] } else { 0.0 };
8014 acc += va * vb * d_mat[[a, b]];
8015 }
8016 }
8017 }
8018 return acc;
8019 };
8020 let u = &spec.evecs;
8021 if u.nrows() != q || u.ncols() != q {
8022 return 0.0;
8023 }
8024 let raw = &spec.raw_evals;
8025 let cond = &spec.cond_evals;
8026 // M = Uᵀ D U, W = Uᵀ inv_vv U (both q×q, symmetric).
8027 let m = u.t().dot(d_mat).dot(u);
8028 let w = u.t().dot(inv_vv).dot(u);
8029 // correction = Σ_{m,l} W[m,l]·M[m,l]·(1 − F[m,l]).
8030 let mut acc = 0.0_f64;
8031 let eps = 1.0e-12;
8032 for a in 0..q {
8033 for b in 0..q {
8034 let denom = raw[a] - raw[b];
8035 let f1 = if denom.abs() > eps {
8036 (cond[a] - cond[b]) / denom
8037 } else if cond[a] == raw[a] {
8038 1.0
8039 } else {
8040 0.0
8041 };
8042 acc += w[[a, b]] * m[[a, b]] * (1.0 - f1);
8043 }
8044 }
8045 acc
8046 }
8047
8048 /// #1417: exact `½ tr(H⁻¹ ∂H_data/∂logα)` for LEARNABLE IBP alpha.
8049 ///
8050 /// The forward assignment is `a_ik = σ(ℓ_ik/τ)·π_k(α)` with the #614
8051 /// consistent stick-breaking mean `π_k(α) = (α/(α+1))^(k+1)`, so
8052 /// `∂logπ_k/∂logα = (k+1)/(α+1)`. EVERY data-Jacobian column for atom `k` —
8053 /// the logit-JVP row (carries one `π_k`), the coordinate rows (carry one
8054 /// `a_k`), and the β-leg (`a_k·φ`) — carries exactly ONE `a_k`/`π_k` factor
8055 /// (`σ(ℓ/τ)` is α-independent). Hence each Jacobian column scales as
8056 /// `∂J_·k/∂logα = ((k+1)/(α+1))·J_·k`, and the data Hessian block for the
8057 /// atom pair `(k_a, k_b)` scales as
8058 /// ∂H_data[a,b]/∂logα = (((k_a+1) + (k_b+1))/(α+1))·H_data[a,b].
8059 /// Therefore the exact data-block contribution to the α-logdet trace is
8060 /// ½ tr(H⁻¹ ∂H_data/∂logα)
8061 /// = ½/(α+1) · Σ_{a,b} ((k_a+1) + (k_b+1))·(H⁻¹)_{ba}·H_data[a,b],
8062 /// over the full joint `(t, β)` index set. `H_data[a,b]` is the data-fit
8063 /// Gauss-Newton block built from the SAME `row_jets_for_logdet` first-jets the
8064 /// θ-adjoint uses (`H_tt = ⟨J_a,J_b⟩`, `H_tβ = ⟨J_a,J_β⟩`, `H_ββ = ⟨J_β,J_β'⟩`),
8065 /// and `(H⁻¹)` is contracted through the same per-row selected-inverse blocks.
8066 /// This closes the learnable-α gradient: combined with the prior-Hessian
8067 /// trace (`assignment_log_strength_hessian_trace`) the full
8068 /// `½ tr(H⁻¹ ∂H/∂logα)` is now assembled. For FIXED alpha (and non-IBP modes)
8069 /// this is identically zero.
8070 pub(crate) fn learnable_ibp_data_logdet_alpha_trace(
8071 &self,
8072 rho: &SaeManifoldRho,
8073 cache: &ArrowFactorCache,
8074 solver: &DeflatedArrowSolver<'_>,
8075 ) -> Result<f64, String> {
8076 let AssignmentMode::IBPMap {
8077 learnable_alpha: true,
8078 ..
8079 } = self.assignment.mode
8080 else {
8081 return Ok(0.0);
8082 };
8083 let alpha = self
8084 .assignment
8085 .resolved_ibp_alpha(rho)
8086 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8087 let inv_alpha1 = 1.0 / (alpha + 1.0);
8088 let n = self.n_obs();
8089 let total_t = cache.delta_t_len();
8090 let second_jets = self.atom_second_jets()?;
8091 let border = self.border_channels_for_cache(cache)?;
8092
8093 // β-tier selected inverse `(H⁻¹)_ββ` (shared across rows). #932 FRONT C:
8094 // on the plain bordered arrow this is the cached dense `S⁻¹` formed once
8095 // (no `K` full-system solves); when a gauge / #1038 cross-row Woodbury is
8096 // active the row-local Takahashi blocks are NOT valid, so we fall back to
8097 // the per-β-coordinate `solve` loop (bit-identical, just O(n) per call).
8098 let fast_selected = solver.plain_selected_inverse_available();
8099 let beta_inv = if cache.k == 0 {
8100 Array2::<f64>::zeros((0, 0))
8101 } else if fast_selected {
8102 solver.beta_inv().map_err(|err| {
8103 format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8104 })?
8105 } else {
8106 let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8107 let rhs_t = Array1::<f64>::zeros(total_t);
8108 let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8109 for col in 0..cache.k {
8110 rhs_beta[col] = 1.0;
8111 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8112 format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8113 })?;
8114 rhs_beta[col] = 0.0;
8115 for r in 0..cache.k {
8116 beta_inv[[r, col]] = solved.beta[r];
8117 }
8118 }
8119 beta_inv
8120 };
8121 // Atom index of each β border channel (the `k_b` weight for the β leg).
8122 let border_atom: Vec<usize> = border.iter().map(|c| c.atom).collect();
8123
8124 let mut trace = 0.0_f64;
8125 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8126 let mut assignments = Array1::<f64>::zeros(self.k_atoms());
8127 // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
8128 // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
8129 // rows fall back to the scalar per-row path (bit-identical either way).
8130 let mut jet_window: std::collections::VecDeque<SaeRowJets> =
8131 std::collections::VecDeque::new();
8132 let mut jet_window_next = 0usize;
8133 // Hoisted RHS scratch for the gauge/Woodbury per-row solve fallback.
8134 let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
8135 let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
8136 for row in 0..n {
8137 let q = cache.row_dims[row];
8138 let base = cache.row_offsets[row];
8139 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
8140 self.assignment
8141 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
8142 if jet_window.is_empty() {
8143 jet_window_next = self.refill_jet_window(
8144 rho,
8145 jet_window_next,
8146 cache,
8147 &second_jets,
8148 &border,
8149 &mut jet_window,
8150 )?;
8151 }
8152 let jets = jet_window.pop_front().expect("jet window must be non-empty");
8153 // Atom index (k-weight) of each local t-var.
8154 let var_atom: Vec<usize> = jets
8155 .vars
8156 .iter()
8157 .map(|v| match *v {
8158 SaeLocalRowVar::Logit { atom } => atom,
8159 SaeLocalRowVar::Coord { atom, .. } => atom,
8160 })
8161 .collect();
8162
8163 // Per-row selected inverse blocks `(H⁻¹)_tt` (q×q) and `(H⁻¹)_tβ`.
8164 // #932 FRONT C: row-local Takahashi (O(q·(q+K))) on the plain arrow;
8165 // per-row full-system `solve` loop (O(n·q)) under gauge / cross-row
8166 // Woodbury where the row-local blocks are not valid.
8167 let (inv_vv, inv_vbeta) = if fast_selected {
8168 solver
8169 .selected_inverse_row_blocks(row, &beta_inv)
8170 .map_err(|err| {
8171 format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8172 })?
8173 } else {
8174 let mut inv_vv = Array2::<f64>::zeros((q, q));
8175 let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
8176 for col in 0..q {
8177 rhs_t_scratch[base + col] = 1.0;
8178 let solved = solver
8179 .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
8180 .map_err(|err| {
8181 format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8182 })?;
8183 rhs_t_scratch[base + col] = 0.0;
8184 for r in 0..q {
8185 inv_vv[[r, col]] = solved.t[base + r];
8186 }
8187 for b in 0..cache.k {
8188 inv_vbeta[[col, b]] = solved.beta[b];
8189 }
8190 }
8191 (inv_vv, inv_vbeta)
8192 };
8193
8194 // #1026 — UNGATED (background-tier) atoms have a force-fixed unit gate,
8195 // so their mass `a_k ≡ 1` is α-INDEPENDENT: every data-Jacobian column
8196 // for an ungated atom carries `a_k = 1`, NOT `π_k(α)`, so its α-exponent
8197 // is `e_k = 0`, not `k+1`. Gated atoms keep `e_k = k+1`. (The prior trace
8198 // handles ungated separately by zeroing the fixed-logit `z_jac`.)
8199 let kfac = |atom: usize| -> f64 {
8200 if self.assignment.ungated.get(atom).copied().unwrap_or(false) {
8201 0.0
8202 } else {
8203 (atom + 1) as f64
8204 }
8205 };
8206 // t–t block: Σ_{a,b} (e_a + e_b)·(H⁻¹)_{ba}·⟨J_a, J_b⟩, where the
8207 // per-atom log-prior exponent is e_k = k+1 for the #614 consistent
8208 // stick-breaking mean π_k = (α/(α+1))^(k+1) (dlogπ_k/dlogα = (k+1)·inv_alpha1).
8209 for a in 0..q {
8210 for b in 0..q {
8211 let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8212 if h_ab == 0.0 {
8213 continue;
8214 }
8215 let kw = kfac(var_atom[a]) + kfac(var_atom[b]);
8216 trace += kw * inv_vv[[b, a]] * h_ab;
8217 }
8218 }
8219 // Deflation correction (kept-subspace restriction + β-Schur/rotation).
8220 // `inv_vv` is the DEFLATED selected inverse, so the t–t contraction
8221 // above contracts the RAW derivative `D` where the re-deflating
8222 // criterion uses the deflation-map derivative `DΦ[D]`. Subtract the
8223 // exact over-count `tr(inv_vv·(D − DΦ[D]))` via the Daleckii–Krein
8224 // helper, with `D_{ab} = kw_ab·⟨J_a, J_b⟩` the SAME t–t operator the
8225 // trace contracts. The t–β/β–β blocks are not deflated, so only the
8226 // t–t contraction is corrected.
8227 let dirs = cache
8228 .deflated_row_directions
8229 .get(row)
8230 .map(Vec::as_slice)
8231 .unwrap_or(&[]);
8232 if !dirs.is_empty() {
8233 let mut d_mat = Array2::<f64>::zeros((q, q));
8234 for a in 0..q {
8235 for b in 0..q {
8236 let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8237 if h_ab == 0.0 {
8238 continue;
8239 }
8240 d_mat[[a, b]] = (kfac(var_atom[a]) + kfac(var_atom[b])) * h_ab;
8241 }
8242 }
8243 let spectrum = cache
8244 .deflation_row_spectra
8245 .get(row)
8246 .and_then(Option::as_ref);
8247 trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
8248 }
8249 // t–β and β–t blocks: appear symmetrically, contract once with the
8250 // factor 2 (H, H⁻¹ symmetric; `(H⁻¹)_βt = (H⁻¹)_tβᵀ`).
8251 for a in 0..q {
8252 for (beta_pos, channel) in border.iter().enumerate() {
8253 let h_ab = sae_dot(&jets.first[a], &jets.beta[beta_pos]);
8254 if h_ab == 0.0 {
8255 continue;
8256 }
8257 let kw = kfac(var_atom[a]) + kfac(border_atom[beta_pos]);
8258 trace += 2.0 * kw * inv_vbeta[[a, channel.index]] * h_ab;
8259 }
8260 }
8261 // β–β block: Σ_{β,β'} (k_β + k_β')·(H⁻¹)_{β'β}·⟨J_β, J_β'⟩.
8262 for (beta_i, channel_i) in border.iter().enumerate() {
8263 for (beta_j, channel_j) in border.iter().enumerate() {
8264 let h_ab = sae_dot(&jets.beta[beta_i], &jets.beta[beta_j]);
8265 if h_ab == 0.0 {
8266 continue;
8267 }
8268 let kw = kfac(border_atom[beta_i]) + kfac(border_atom[beta_j]);
8269 trace += kw * beta_inv[[channel_i.index, channel_j.index]] * h_ab;
8270 }
8271 }
8272 }
8273 Ok(0.5 * inv_alpha1 * trace)
8274 }
8275
8276 pub(crate) fn add_learnable_ibp_forward_alpha_data_rhs(
8277 &self,
8278 rho: &SaeManifoldRho,
8279 target: ArrayView2<'_, f64>,
8280 cache: &ArrowFactorCache,
8281 t: &mut Array1<f64>,
8282 beta: &mut Array1<f64>,
8283 ) -> Result<(), String> {
8284 let AssignmentMode::IBPMap {
8285 temperature,
8286 learnable_alpha: true,
8287 ..
8288 } = self.assignment.mode
8289 else {
8290 return Ok(());
8291 };
8292 let alpha = self
8293 .assignment
8294 .resolved_ibp_alpha(rho)
8295 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8296 let k_atoms = self.k_atoms();
8297 let p = self.output_dim();
8298 let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8299 let mut dprior = Array1::<f64>::zeros(k_atoms);
8300 for k in 0..k_atoms {
8301 // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8302 // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8303 // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8304 dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8305 }
8306 let inv_tau = 1.0 / temperature;
8307 let row_loss_w = self.row_loss_weights.as_deref();
8308 let whitens = self
8309 .row_metric
8310 .as_ref()
8311 .is_some_and(|metric| metric.whitens_likelihood());
8312 let border = self.border_channels_for_cache(cache)?;
8313 let mut decoded_rows = vec![vec![0.0_f64; p]; k_atoms];
8314 let mut decoded_deriv = vec![0.0_f64; p];
8315 let mut fitted = Array1::<f64>::zeros(p);
8316 let mut f_rho = Array1::<f64>::zeros(p);
8317 let mut residual = Array1::<f64>::zeros(p);
8318 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8319 let mut assignments = vec![0.0_f64; k_atoms];
8320 for row in 0..self.n_obs() {
8321 self.assignment
8322 .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8323 fitted.fill(0.0);
8324 f_rho.fill(0.0);
8325 for k in 0..k_atoms {
8326 self.atoms[k].fill_decoded_row(row, &mut decoded_rows[k]);
8327 // Ungated (#1026 background-tier) atoms have a force-fixed unit
8328 // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8329 // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8330 // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8331 // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8332 // but `a_k` still varies with α through `π_k`, so it must NOT be
8333 // skipped.)
8334 let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8335 0.0
8336 } else {
8337 (assignments[k] / prior[k]) * dprior[k]
8338 };
8339 for out_col in 0..p {
8340 fitted[out_col] += assignments[k] * decoded_rows[k][out_col];
8341 f_rho[out_col] += da_rho * decoded_rows[k][out_col];
8342 }
8343 }
8344 for out_col in 0..p {
8345 residual[out_col] = fitted[out_col] - target[[row, out_col]];
8346 }
8347 let residual_metric = match self.row_metric.as_ref() {
8348 Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8349 _ => residual.to_vec(),
8350 };
8351 let f_metric = match self.row_metric.as_ref() {
8352 Some(metric) if whitens => metric.apply_metric_row(row, f_rho.view()),
8353 _ => f_rho.to_vec(),
8354 };
8355 let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8356 let row_vars = self.row_vars_for_cache_row(row, cache)?;
8357 let row_base = cache.row_offsets[row];
8358 for (pos, var) in row_vars.iter().enumerate() {
8359 let mut contribution = 0.0_f64;
8360 match *var {
8361 SaeLocalRowVar::Logit { atom } => {
8362 let sigma = assignments[atom] / prior[atom];
8363 let sigma_jac = sigma * (1.0 - sigma) * inv_tau;
8364 let da_dl = sigma_jac * prior[atom];
8365 let d_da_rho_dl = sigma_jac * dprior[atom];
8366 for out_col in 0..p {
8367 contribution += da_dl * decoded_rows[atom][out_col] * f_metric[out_col];
8368 contribution += d_da_rho_dl
8369 * decoded_rows[atom][out_col]
8370 * residual_metric[out_col];
8371 }
8372 }
8373 SaeLocalRowVar::Coord { atom, axis } => {
8374 let sigma = assignments[atom] / prior[atom];
8375 let da_rho = sigma * dprior[atom];
8376 self.atoms[atom].fill_decoded_derivative_row(row, axis, &mut decoded_deriv);
8377 for out_col in 0..p {
8378 contribution +=
8379 assignments[atom] * decoded_deriv[out_col] * f_metric[out_col];
8380 contribution +=
8381 da_rho * decoded_deriv[out_col] * residual_metric[out_col];
8382 }
8383 }
8384 }
8385 t[row_base + pos] += row_weight * contribution;
8386 }
8387 for channel in &border {
8388 let phi = self.atoms[channel.atom].basis_values[[row, channel.basis_col]];
8389 let sigma = assignments[channel.atom] / prior[channel.atom];
8390 let da_rho = sigma * dprior[channel.atom];
8391 let mut contribution = 0.0_f64;
8392 for out_col in 0..p {
8393 let output = channel.output[out_col];
8394 contribution += assignments[channel.atom] * phi * output * f_metric[out_col];
8395 contribution += da_rho * phi * output * residual_metric[out_col];
8396 }
8397 beta[channel.index] += row_weight * contribution;
8398 }
8399 }
8400 Ok(())
8401 }
8402
8403 pub(crate) fn border_channels_for_cache(
8404 &self,
8405 cache: &ArrowFactorCache,
8406 ) -> Result<Vec<SaeBorderChannel>, String> {
8407 let p = self.output_dim();
8408 let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8409 let offsets = if frames_active {
8410 self.factored_beta_offsets()
8411 } else {
8412 self.beta_offsets()
8413 };
8414 let mut channels = Vec::with_capacity(cache.k);
8415 for (atom_idx, atom) in self.atoms.iter().enumerate() {
8416 let m = atom.basis_size();
8417 let frame = if frames_active {
8418 self.frame_output_matrix(atom_idx)
8419 } else {
8420 Array2::<f64>::eye(p)
8421 };
8422 let r = frame.ncols();
8423 for basis_col in 0..m {
8424 for channel in 0..r {
8425 let mut output = vec![0.0_f64; p];
8426 for out_col in 0..p {
8427 output[out_col] = frame[[out_col, channel]];
8428 }
8429 channels.push(SaeBorderChannel {
8430 atom: atom_idx,
8431 basis_col,
8432 index: offsets[atom_idx] + basis_col * r + channel,
8433 output,
8434 });
8435 }
8436 }
8437 }
8438 if channels.len() != cache.k {
8439 return Err(format!(
8440 "border channel layout has {} entries but cache border has {}",
8441 channels.len(),
8442 cache.k
8443 ));
8444 }
8445 Ok(channels)
8446 }
8447
8448 pub(crate) fn row_vars_for_cache_row(
8449 &self,
8450 row: usize,
8451 cache: &ArrowFactorCache,
8452 ) -> Result<Vec<SaeLocalRowVar>, String> {
8453 let q_row = cache.row_dims[row];
8454 let mut vars: Vec<Option<SaeLocalRowVar>> = vec![None; q_row];
8455 match self.last_row_layout {
8456 Some(ref layout) => {
8457 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8458 vars[pos] = Some(SaeLocalRowVar::Logit { atom });
8459 let start = layout.coord_starts[row][pos];
8460 let d = self.assignment.coords[atom].latent_dim();
8461 for axis in 0..d {
8462 vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8463 }
8464 }
8465 }
8466 None => {
8467 let assignment_dim = self.assignment.assignment_coord_dim();
8468 let coord_offsets = self.assignment.coord_offsets();
8469 for atom in 0..assignment_dim {
8470 vars[atom] = Some(SaeLocalRowVar::Logit { atom });
8471 }
8472 for atom in 0..self.k_atoms() {
8473 let start = coord_offsets[atom];
8474 let d = self.assignment.coords[atom].latent_dim();
8475 for axis in 0..d {
8476 vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8477 }
8478 }
8479 }
8480 }
8481 vars.into_iter()
8482 .enumerate()
8483 .map(|(idx, v)| {
8484 v.ok_or_else(|| {
8485 format!("row_vars_for_cache_row: row {row} position {idx} was not mapped")
8486 })
8487 })
8488 .collect()
8489 }
8490
8491 pub(crate) fn atom_second_jets(&self) -> Result<Vec<Array4<f64>>, String> {
8492 let mut out = Vec::with_capacity(self.k_atoms());
8493 for (atom_idx, atom) in self.atoms.iter().enumerate() {
8494 let coords = self.assignment.coords[atom_idx].as_matrix();
8495 let jet = if let Some(second) = atom.basis_second_jet.as_ref() {
8496 second.second_jet(coords.view())?
8497 } else {
8498 let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
8499 format!(
8500 "logdet_theta_adjoint: atom '{}' has no basis evaluator for second jets",
8501 atom.name
8502 )
8503 })?;
8504 evaluator
8505 .second_jet_dyn(coords.view())
8506 .ok_or_else(|| {
8507 format!(
8508 "logdet_theta_adjoint: atom '{}' basis does not expose analytic second jets",
8509 atom.name
8510 )
8511 })??
8512 };
8513 let expected = (
8514 atom.n_obs(),
8515 atom.basis_size(),
8516 atom.latent_dim,
8517 atom.latent_dim,
8518 );
8519 if jet.dim() != expected {
8520 return Err(format!(
8521 "logdet_theta_adjoint: atom '{}' second jet shape {:?}, expected {:?}",
8522 atom.name,
8523 jet.dim(),
8524 expected
8525 ));
8526 }
8527 out.push(jet);
8528 }
8529 Ok(out)
8530 }
8531
8532 // [#780 line-count gate] The per-row jet / reconstruction-channel cluster
8533 // (`reconstruction_row_program_for_logdet`, the const-generic
8534 // reconstruction / β-border channel fills and their dynamic dispatchers,
8535 // `row_jets_for_logdet`, `row_jets_for_logdet_batch4`, `batch4_assemble`,
8536 // and `refill_jet_window`) lives in the sibling
8537 // `construction_row_jet_logdet_channels.rs` file, inlined via `include!`
8538 // below at module scope as a second `impl SaeManifoldTerm` block. Splitting
8539 // it out keeps this tracked file under the 10k limit; `include!` preserves
8540 // the identical module scope and private-field access.
8541
8542 pub(crate) fn assignment_prior_hdiag_derivative_entry(
8543 &self,
8544 rho: &SaeManifoldRho,
8545 row: usize,
8546 diag_atom: usize,
8547 wrt: SaeLocalRowVar,
8548 ibp_channels: Option<&IbpHessianDiagThirdChannels>,
8549 ) -> f64 {
8550 let SaeLocalRowVar::Logit { atom: wrt_atom } = wrt else {
8551 return 0.0;
8552 };
8553 match self.assignment.mode {
8554 AssignmentMode::Softmax { .. } => {
8555 // #1038: the softmax entropy Hessian is now stored DENSE in
8556 // `block.htt` and its full θ-derivative `∂H_{k,j}/∂z_w` (diagonal
8557 // AND off-diagonal) is added inline in `logdet_theta_adjoint` from
8558 // the shared `row_dense_hessian_logit_derivative`. Returning the
8559 // diagonal contribution here too would double-count, so this
8560 // primitive is silent for softmax — the dense path is the single
8561 // source for value, logdet, and adjoint.
8562 0.0
8563 }
8564 AssignmentMode::ThresholdGate {
8565 temperature,
8566 threshold,
8567 } => {
8568 if diag_atom != wrt_atom {
8569 return 0.0;
8570 }
8571 let logit = self.assignment.logits[[row, diag_atom]];
8572 if !crate::assignment::jumprelu_in_optimization_band(
8573 logit,
8574 threshold,
8575 temperature,
8576 ) {
8577 return 0.0;
8578 }
8579 let inv_tau = 1.0 / temperature;
8580 let activation =
8581 gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
8582 let slope = activation * (1.0 - activation);
8583 // #1415: P(ℓ)=λσ((ℓ−θ)/τ); P''(ℓ)=(λ/τ²)s(1−2a) so the third
8584 // derivative is P'''(ℓ)=(λ/τ³)·s·(1−6a+6a²), because
8585 // d/dℓ[s(1−2a)] = (1/τ)s[(1−2a)²−2s] = (1/τ)s(1−6a+6a²).
8586 rho.lambda_sparse()
8587 * slope
8588 * (1.0 - 6.0 * activation + 6.0 * activation * activation)
8589 * inv_tau
8590 * inv_tau
8591 * inv_tau
8592 }
8593 AssignmentMode::IBPMap { .. } => {
8594 // The assembled `htt` diagonal consumes
8595 // `IBPAssignmentPenalty::hessian_diag`, whose logit derivative
8596 // splits into a row-local direct-`z` channel and a global
8597 // empirical-`M_k` channel (π_k couples every row in column k).
8598 // This same-row primitive returns only the LOCAL direct-`z`
8599 // channel — and only on the matching logit (`diag_atom == w`),
8600 // since H_ik depends on no other row's z explicitly. The global
8601 // M_k channel is accumulated column-wise in
8602 // `logdet_theta_adjoint` (it needs the per-row selected-inverse
8603 // diagonals), so adding it here would double-count.
8604 if diag_atom != wrt_atom {
8605 return 0.0;
8606 }
8607 match ibp_channels {
8608 Some(ch) => ch.local_logit_third[row * ch.k_max + diag_atom],
8609 None => 0.0,
8610 }
8611 }
8612 }
8613 }
8614
8615 pub(crate) fn ard_majorized_hessian_derivative(
8616 &self,
8617 rho: &SaeManifoldRho,
8618 row: usize,
8619 atom: usize,
8620 axis: usize,
8621 ) -> f64 {
8622 if rho.log_ard[atom].is_empty() {
8623 return 0.0;
8624 }
8625 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8626 let periods = self.assignment.coords[atom].effective_axis_periods();
8627 let t = self.assignment.coords[atom].row(row)[axis];
8628 let prior = ArdAxisPrior::eval(alpha, t, periods[axis]);
8629 if prior.hess <= 0.0 {
8630 return 0.0;
8631 }
8632 match periods[axis] {
8633 None => 0.0,
8634 Some(period) => {
8635 let kappa = std::f64::consts::TAU / period;
8636 -alpha * kappa * (kappa * t).sin()
8637 }
8638 }
8639 }
8640
8641 pub fn outer_rho_gradient_ift_rhs(
8642 &self,
8643 rho: &SaeManifoldRho,
8644 target: ArrayView2<'_, f64>,
8645 j: usize,
8646 cache: &ArrowFactorCache,
8647 ) -> Result<SaeArrowVector, String> {
8648 let n_params = rho.to_flat().len();
8649 if j >= n_params {
8650 return Err(format!(
8651 "outer_rho_gradient_ift_rhs: coordinate {j} outside rho dim {n_params}"
8652 ));
8653 }
8654 let mut t = Array1::<f64>::zeros(cache.delta_t_len());
8655 let mut beta = Array1::<f64>::zeros(cache.k);
8656 if j == 0 {
8657 let assignment_grad =
8658 assignment_prior_log_strength_target_mixed(&self.assignment, rho)?;
8659 let k_atoms = self.k_atoms();
8660 let assignment_dim = self.assignment.assignment_coord_dim();
8661 for row in 0..self.n_obs() {
8662 let base = cache.row_offsets[row];
8663 let assignment_base = row * k_atoms;
8664 match self.last_row_layout {
8665 Some(ref layout) => {
8666 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8667 t[base + pos] = assignment_grad[assignment_base + atom];
8668 }
8669 }
8670 None => {
8671 for free_idx in 0..assignment_dim {
8672 t[base + free_idx] = assignment_grad[assignment_base + free_idx];
8673 }
8674 }
8675 }
8676 }
8677 self.add_learnable_ibp_forward_alpha_data_rhs(rho, target, cache, &mut t, &mut beta)?;
8678 } else if (1..=rho.log_lambda_smooth.len()).contains(&j) {
8679 // #1556: coordinate `j ∈ 1..=K` is the per-atom smoothness strength
8680 // `log λ_smooth[j-1]`. `∂(penalty)/∂log λ_k = λ_k·S_k C_k` touches ONLY
8681 // atom `k = j-1`'s decoder block; every other atom's RHS is zero.
8682 let target_atom = j - 1;
8683 let lambda = rho.lambda_smooth_for(target_atom);
8684 let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8685 let offsets = if frames_active {
8686 self.factored_beta_offsets()
8687 } else {
8688 self.beta_offsets()
8689 };
8690 let atom = &self.atoms[target_atom];
8691 let m = atom.basis_size();
8692 let coeffs = if frames_active {
8693 match &atom.decoder_frame {
8694 Some(frame) => frame.project_decoder(atom.decoder_coefficients.view())?,
8695 None => atom.decoder_coefficients.clone(),
8696 }
8697 } else {
8698 atom.decoder_coefficients.clone()
8699 };
8700 let r = coeffs.ncols();
8701 let off = offsets[target_atom];
8702 for mu in 0..m {
8703 for channel in 0..r {
8704 let mut acc = 0.0_f64;
8705 for nu in 0..m {
8706 let s_sym =
8707 0.5 * (atom.smooth_penalty[[mu, nu]] + atom.smooth_penalty[[nu, mu]]);
8708 acc += s_sym * coeffs[[nu, channel]];
8709 }
8710 beta[off + mu * r + channel] = lambda * acc;
8711 }
8712 }
8713 } else {
8714 let mut cursor = 1 + rho.log_lambda_smooth.len();
8715 for atom in 0..rho.log_ard.len() {
8716 for axis in 0..rho.log_ard[atom].len() {
8717 if cursor == j {
8718 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8719 let periods = self.assignment.coords[atom].effective_axis_periods();
8720 for row in 0..self.n_obs() {
8721 let row_t = self.assignment.coords[atom].row(row);
8722 let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
8723 let Some(pos) = sae_coord_penalty_offset(
8724 self.last_row_layout.as_ref(),
8725 self.assignment.coord_offsets()[atom] + axis,
8726 row,
8727 atom,
8728 ) else {
8729 continue;
8730 };
8731 t[cache.row_offsets[row] + pos] = prior.grad;
8732 }
8733 return Ok(SaeArrowVector { t, beta });
8734 }
8735 cursor += 1;
8736 }
8737 }
8738 }
8739 Ok(SaeArrowVector { t, beta })
8740 }
8741
8742 pub(crate) fn logdet_theta_adjoint(
8743 &self,
8744 rho: &SaeManifoldRho,
8745 cache: &ArrowFactorCache,
8746 solver: &DeflatedArrowSolver<'_>,
8747 ) -> Result<SaeArrowVector, String> {
8748 // Γ_a = tr(H⁻¹ ∂H/∂θ_a) over the inner variables θ (#1006). `H` here is
8749 // the SAME object the evidence factor builds — Gauss-Newton data
8750 // curvature plus the prior majorizers / `hessian_diag` diagonals the
8751 // Newton/Schur Cholesky factorizes — so each block's θ-derivative channel
8752 // is differentiated on the criterion's own branch (no value/gradient
8753 // desync). The IBP-MAP assignment prior is the one block whose
8754 // `hessian_diag` couples every row in a column through the plug-in
8755 // empirical mass `M_k = Σ_i z_ik`; its logit derivative therefore has a
8756 // row-local channel (handled inline via
8757 // `assignment_prior_hdiag_derivative_entry`) and a cross-row channel
8758 // (accumulated column-wise after the row loop, below).
8759 let n = self.n_obs();
8760 let total_t = cache.delta_t_len();
8761 let mut gamma_t = Array1::<f64>::zeros(total_t);
8762 let mut gamma_beta = Array1::<f64>::zeros(cache.k);
8763 let second_jets = self.atom_second_jets()?;
8764 let border = self.border_channels_for_cache(cache)?;
8765 // #932 FRONT C: plain-arrow `(H⁻¹)_ββ = S⁻¹` formed once from the cached
8766 // Schur factor; gauge / #1038 cross-row Woodbury fall back to the per-β
8767 // `solve` loop where the row-local Takahashi blocks are not valid.
8768 let fast_selected = solver.plain_selected_inverse_available();
8769 let beta_inv = if cache.k == 0 {
8770 Array2::<f64>::zeros((0, 0))
8771 } else if fast_selected {
8772 solver
8773 .beta_inv()
8774 .map_err(|err| format!("logdet_theta_adjoint: beta selected inverse: {err}"))?
8775 } else {
8776 let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8777 let rhs_t = Array1::<f64>::zeros(total_t);
8778 let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8779 for col in 0..cache.k {
8780 rhs_beta[col] = 1.0;
8781 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8782 format!("logdet_theta_adjoint: beta selected inverse solve: {err}")
8783 })?;
8784 rhs_beta[col] = 0.0;
8785 for row in 0..cache.k {
8786 beta_inv[[row, col]] = solved.beta[row];
8787 }
8788 }
8789 beta_inv
8790 };
8791 // IBP `hessian_diag` logit third-derivative channels (#1006). The full
8792 // IBP Hessian also has per-column cross-row rank-one terms
8793 // `H_(i,k),(j,k) = d_k·J_ik·J_jk`; these ARE carried in `H` via the #1038
8794 // Woodbury source (`IbpCrossRowSource`, construction.rs:4710-4752), the
8795 // ρ-trace differentiates them (#1416,
8796 // `assignment_log_strength_hessian_trace`), AND this θ-adjoint now
8797 // differentiates them exactly too: the empirical-`M_k` channel below
8798 // contracts the shared-mass coupling of the DIAGONAL curvature, and the
8799 // cross-row Woodbury pass (further below, using `cross_row_dd` and
8800 // `logit_curvature`) contracts the `∂/∂ℓ_w (d_k·J_ik·J_jk)` rank-one
8801 // derivative — so value, logdet, ρ-trace, and θ-adjoint all differentiate
8802 // the one operator `H = H₀ + Σ_k d_k u_k u_kᵀ`.
8803 let ibp_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
8804 let k_atoms = self.k_atoms();
8805 // #1038 softmax entropy: the dense per-row entropy Hessian written into
8806 // `block.htt` has off-diagonal logit terms whose θ-derivative the adjoint
8807 // must contract too (not just the diagonal). Build the SAME penalty +
8808 // `scale = λ/τ²` the assembly uses so value/logdet/adjoint differentiate
8809 // one operator. `None` for non-softmax modes (their diagonal/cross-row
8810 // channels are handled by `assignment_prior_hdiag_derivative_entry` and
8811 // the IBP column pass).
8812 let softmax_dense_adjoint: Option<(
8813 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
8814 f64,
8815 )> = match self.assignment.mode {
8816 AssignmentMode::Softmax {
8817 temperature,
8818 sparsity,
8819 } if k_atoms > 1 => {
8820 let inv_tau = 1.0 / temperature;
8821 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
8822 Some((
8823 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
8824 k_atoms,
8825 temperature,
8826 ),
8827 scale,
8828 ))
8829 }
8830 _ => None,
8831 };
8832 // Per active logit position: (row i, column k, global t-index,
8833 // (H⁻¹)_ik,ik) — the inputs to the IBP cross-row empirical-`M_k` channel.
8834 let mut ibp_logit_sites: Vec<(usize, usize, usize, f64)> = Vec::new();
8835
8836 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8837 let mut assignments = Array1::<f64>::zeros(self.k_atoms());
8838 // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
8839 // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
8840 // rows fall back to the scalar per-row path (bit-identical either way).
8841 let mut jet_window: std::collections::VecDeque<SaeRowJets> =
8842 std::collections::VecDeque::new();
8843 let mut jet_window_next = 0usize;
8844 // Hoisted RHS scratch for the gauge/Woodbury per-row solve fallback.
8845 let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
8846 let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
8847 for row in 0..n {
8848 let q = cache.row_dims[row];
8849 let base = cache.row_offsets[row];
8850 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
8851 self.assignment
8852 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
8853 if jet_window.is_empty() {
8854 jet_window_next = self.refill_jet_window(
8855 rho,
8856 jet_window_next,
8857 cache,
8858 &second_jets,
8859 &border,
8860 &mut jet_window,
8861 )?;
8862 }
8863 let jets = jet_window.pop_front().expect("jet window must be non-empty");
8864
8865 // #932 FRONT C: row-local Takahashi on the plain arrow; per-row
8866 // full-system `solve` loop under gauge / cross-row Woodbury.
8867 let (inv_vv, inv_vbeta) = if fast_selected {
8868 solver
8869 .selected_inverse_row_blocks(row, &beta_inv)
8870 .map_err(|err| {
8871 format!("logdet_theta_adjoint: selected inverse: {err}")
8872 })?
8873 } else {
8874 let mut inv_vv = Array2::<f64>::zeros((q, q));
8875 let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
8876 for col in 0..q {
8877 rhs_t_scratch[base + col] = 1.0;
8878 let solved = solver
8879 .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
8880 .map_err(|err| {
8881 format!("logdet_theta_adjoint: selected inverse solve: {err}")
8882 })?;
8883 rhs_t_scratch[base + col] = 0.0;
8884 for r in 0..q {
8885 inv_vv[[r, col]] = solved.t[base + r];
8886 }
8887 for b in 0..cache.k {
8888 inv_vbeta[[col, b]] = solved.beta[b];
8889 }
8890 }
8891 (inv_vv, inv_vbeta)
8892 };
8893
8894 // Record each active logit's column, global t-index, and
8895 // selected-inverse diagonal (H⁻¹)_ik,ik for the IBP cross-row pass.
8896 if ibp_channels.is_some() {
8897 for (pos, var) in jets.vars.iter().enumerate() {
8898 if let SaeLocalRowVar::Logit { atom } = *var {
8899 ibp_logit_sites.push((row, atom, base + pos, inv_vv[[pos, pos]]));
8900 }
8901 }
8902 }
8903
8904 // #1419: when `w` is a logit and the assignment is softmax, the per-row
8905 // Gershgorin majorizer `D = diag(Σ_j|H_kj|)` is what the assembly wrote
8906 // into `htt` (the genuine Loewner majorizer that replaces the indefinite
8907 // exact entropy Hessian). Its full θ-derivative `∂D_{k,k}/∂z_w` (diagonal;
8908 // `∂D_kk/∂z_w = Σ_j sign(H_kj)·∂H_kj/∂z_w`) is the SAME operator the
8909 // assembly and logdet now differentiate, so value and adjoint stay on ONE
8910 // exact branch. Compute it once per logit `w` and add it at every logit
8911 // pair `(a,b)` below. The diagonal softmax case is therefore handled here,
8912 // NOT in `assignment_prior_hdiag_derivative_entry` (which returns 0 for
8913 // softmax to avoid double-counting).
8914 // #1410: the softmax majorizer θ-derivative `∂D_kk/∂z_w` is DIAGONAL
8915 // (`D` is diagonal), and the compact adjoint reads it only for this
8916 // row's `≤ top_k` active atoms. Compute the needed diagonal entry
8917 // directly from the softmax row `a` (= `assignments`, in hand) via
8918 // `active_softmax_majorizer_logit_derivative_entry`, instead of the old
8919 // per-(row, logit) full `K×K` `row_psd_majorizer_logit_derivative`
8920 // allocation. `m = Σ_j a_j l_j` is shared across all `(w, k)` pairs of
8921 // the row, so compute it once. `inv_tau` carries the softmax `∂a/∂z`
8922 // convention.
8923 let softmax_adjoint_row: Option<(&[f64], f64, f64, f64)> =
8924 match (softmax_dense_adjoint.as_ref(), self.assignment.mode) {
8925 (Some((_penalty, scale)), AssignmentMode::Softmax { temperature, .. }) => {
8926 let a = assignments
8927 .as_slice()
8928 .expect("softmax assignments row must be contiguous");
8929 let m = softmax_majorizer_log_mean(a);
8930 Some((a, m, *scale, 1.0 / temperature))
8931 }
8932 _ => None,
8933 };
8934 // Per-row UNIT-stiffness deflated directions: the selected inverse
8935 // `inv_vv` is the DEFLATED inverse (it assigns `1/λ̃ = 1` to each
8936 // `vᵢ`), so every `inv_vv`-weighted t–t contraction of `∂H/∂θ_w`
8937 // below spuriously contracts the RAW derivative where the re-deflating
8938 // criterion uses the deflation-map derivative `DΦ`. The kept-subspace Γ
8939 // subtracts `tr(inv_vv·(D − DΦ[D]))` over the t–t block via the same
8940 // Daleckii–Krein helper the ρ-traces use (the t–β / β–β blocks are not
8941 // deflated). `θ` enters only the per-row block (no cross-row Woodbury
8942 // self-downdate on the θ path), so the raw t–t derivative `D` is used
8943 // directly.
8944 let defl_dirs = cache
8945 .deflated_row_directions
8946 .get(row)
8947 .map(Vec::as_slice)
8948 .unwrap_or(&[]);
8949 let defl_spectrum = cache
8950 .deflation_row_spectra
8951 .get(row)
8952 .and_then(Option::as_ref);
8953 for w in 0..q {
8954 let mut gamma = 0.0_f64;
8955 // The active logit `w` differentiates against; `None` unless this
8956 // slot is a softmax logit on the softmax path.
8957 let softmax_d_dw: Option<(&[f64], f64, f64, f64, usize)> =
8958 match (softmax_adjoint_row, jets.vars[w]) {
8959 (Some((a, m, scale, inv_tau)), SaeLocalRowVar::Logit { atom: atom_w }) => {
8960 Some((a, m, scale, inv_tau, atom_w))
8961 }
8962 _ => None,
8963 };
8964 let mut dh_mat = Array2::<f64>::zeros((q, q));
8965 for a in 0..q {
8966 for b in 0..q {
8967 let mut dh = sae_dot(&jets.second[a][w], &jets.first[b])
8968 + sae_dot(&jets.first[a], &jets.second[b][w]);
8969 // `∂D/∂z_w` is diagonal, so it contributes only when the two
8970 // logit slots are the SAME atom (`atom_a == atom_b`).
8971 if let (
8972 Some((a_soft, m, scale, inv_tau, _atom_w)),
8973 SaeLocalRowVar::Logit { atom: atom_a },
8974 SaeLocalRowVar::Logit { atom: atom_b },
8975 ) = (softmax_d_dw, jets.vars[a], jets.vars[b])
8976 {
8977 if atom_a == atom_b {
8978 dh += active_softmax_majorizer_logit_derivative_entry(
8979 a_soft, atom_a, _atom_w, m, scale, inv_tau,
8980 );
8981 }
8982 }
8983 if a == b {
8984 dh += match jets.vars[a] {
8985 SaeLocalRowVar::Logit { atom } => self
8986 .assignment_prior_hdiag_derivative_entry(
8987 rho,
8988 row,
8989 atom,
8990 jets.vars[w],
8991 ibp_channels.as_ref(),
8992 ),
8993 SaeLocalRowVar::Coord { atom, axis } if a == w => {
8994 self.ard_majorized_hessian_derivative(rho, row, atom, axis)
8995 }
8996 _ => 0.0,
8997 };
8998 }
8999 dh_mat[[a, b]] = dh;
9000 gamma += inv_vv[[b, a]] * dh;
9001 }
9002 }
9003 if !defl_dirs.is_empty() {
9004 gamma -= Self::deflation_block_correction(
9005 &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9006 );
9007 }
9008 for a in 0..q {
9009 for (beta_pos, channel) in border.iter().enumerate() {
9010 let dh = sae_dot(&jets.second[a][w], &jets.beta[beta_pos])
9011 + sae_dot(&jets.first[a], &jets.beta_deriv[w][beta_pos]);
9012 gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9013 }
9014 }
9015 for (beta_i, channel_i) in border.iter().enumerate() {
9016 for (beta_j, channel_j) in border.iter().enumerate() {
9017 let dh = sae_dot(&jets.beta_deriv[w][beta_i], &jets.beta[beta_j])
9018 + sae_dot(&jets.beta[beta_i], &jets.beta_deriv[w][beta_j]);
9019 gamma += beta_inv[[channel_i.index, channel_j.index]] * dh;
9020 }
9021 }
9022 gamma_t[base + w] = gamma;
9023 }
9024
9025 for (w_beta_pos, w_channel) in border.iter().enumerate() {
9026 let mut gamma = 0.0_f64;
9027 let mut dh_mat = Array2::<f64>::zeros((q, q));
9028 for a in 0..q {
9029 for b in 0..q {
9030 let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.first[b])
9031 + sae_dot(&jets.first[a], &jets.beta_l_deriv[b][w_beta_pos]);
9032 dh_mat[[a, b]] = dh;
9033 gamma += inv_vv[[b, a]] * dh;
9034 }
9035 }
9036 if !defl_dirs.is_empty() {
9037 gamma -= Self::deflation_block_correction(
9038 &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9039 );
9040 }
9041 for a in 0..q {
9042 for (beta_pos, channel) in border.iter().enumerate() {
9043 let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.beta[beta_pos]);
9044 gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9045 }
9046 }
9047 gamma_beta[w_channel.index] += gamma;
9048 }
9049 }
9050
9051 // IBP cross-row empirical-`M_k` channel of Γ (#1006). The assembled
9052 // diagonal H_ik consumes `hessian_diag`, whose dependence on the column
9053 // mass M_k = Σ_i z_ik couples every row in a column. Differentiating
9054 // tr(H⁻¹ ∂H/∂ℓ_wk) on that shared branch:
9055 // Γ_wk += [ Σ_i (H⁻¹)_ik,ik · ∂_M H_ik ] · J_wk = C_k · J_wk,
9056 // where ∂_M H_ik = `m_channel[i*K+k]` and J_wk = `z_jac[w*K+k]`. The
9057 // row-local direct-`z` channel was already added inline above, so this
9058 // pass adds only the cross-row remainder (it spans `w ≠ i` and the
9059 // self-row M_k self-coupling, which the row-local primitive deliberately
9060 // omits to avoid double-counting).
9061 if let Some(channels) = ibp_channels.as_ref() {
9062 let mut col_coeff = vec![0.0_f64; k_atoms];
9063 for &(row, atom, _t_index, inv_diag) in &ibp_logit_sites {
9064 col_coeff[atom] += inv_diag * channels.m_channel[row * k_atoms + atom];
9065 }
9066 for &(row, atom, t_index, _inv_diag) in &ibp_logit_sites {
9067 gamma_t[t_index] += col_coeff[atom] * channels.z_jac[row * k_atoms + atom];
9068 }
9069
9070 // #1416 / #1641: the EXACT cross-row Woodbury derivative of Γ. The
9071 // assembled `H` carries the per-column rank-one block
9072 // `W_k = d_k·u_k u_kᵀ` with `u_k` the J-weighted column indicator
9073 // (`u_k[slot(i,k)] = J_ik`) and `d_k = w·s'_k` (`cross_row_d[k]`). Both
9074 // `d_k` (through `M_k`) and the `u_k` entries (through `ℓ_ik`) depend on
9075 // the logits, so
9076 // ∂W_k/∂ℓ_wk = dd_k·J_wk·u_k u_kᵀ
9077 // + d_k·c_wk·(e_w u_kᵀ + u_k e_wᵀ),
9078 // where `dd_k = ∂d_k/∂M_k = w·s''_k` (`cross_row_dd[k]`),
9079 // `c_wk = ∂J_wk/∂ℓ_wk` (`logit_curvature`), and `e_w` is the unit
9080 // vector at row `w`'s logit-`k` slot.
9081 //
9082 // The θ-adjoint contracts the FULL trace `Γ_wk = tr(H⁻¹ ∂H/∂ℓ_wk)`
9083 // (NOT the `½ tr` the ρ-trace uses — `fixed_state_logdet` differentiates
9084 // the full `log|H|`, and the per-row blocks above contract `inv_vv·dh`
9085 // with no ½). Critically, the `i=j` self curvature `w·s'_k·J_ik²` of the
9086 // rank-one block lives on the assembled `htt` DIAGONAL `H_ik`, so its
9087 // derivative is ALREADY differentiated by the row-local
9088 // `local_logit_third` channel (direct-z, `i=w`) and the `m_channel`
9089 // column pass (via `M_k`) above. This Woodbury pass must therefore add
9090 // ONLY the off-diagonal `i≠j` remainder — otherwise the self term is
9091 // double-counted (the #1641 defect: the pre-fix pass summed the full
9092 // `u_k u_kᵀ` including `i=j`, AND carried the ρ-trace ½, AND dropped the
9093 // factor 2 on the symmetric `e_w u_kᵀ + u_k e_wᵀ` term). Excluding `i=j`
9094 // is also why this pass needs no deflation correction: it contracts only
9095 // DISTINCT rows, off any single-row `vᵢ`'s support (matching the
9096 // #1416 ρ-trace cross-row pass).
9097 //
9098 // Contracting `tr(H⁻¹ ∂W_k/∂ℓ_wk)` over `i≠j` only:
9099 // Γ_wk += dd_k·J_wk·( u_kᵀ H⁻¹ u_k − Σ_i P_ii·J_ik² ) (term A)
9100 // + 2·d_k·c_wk·( (H⁻¹ u_k)_{slot(w,k)} − P_ww·J_wk ) (term B),
9101 // where `P_ii = (H⁻¹)_{slot(i,k),slot(i,k)}` is the selected-inverse
9102 // diagonal recorded in `ibp_logit_sites`. The subtracted self pieces are
9103 // exactly the `i=j` terms the diagonal channels own. Both `u_kᵀ H⁻¹ u_k`
9104 // and `(H⁻¹ u_k)` come from ONE solve per column, `x_k = H⁻¹ u_k` — so
9105 // the adjoint differentiates the SAME `H = H₀ + Σ_k W_k` the
9106 // value/logdet use, closing the one-operator contract on the rank-one
9107 // block too.
9108 //
9109 // Group the column sites once (the layout is mode-agnostic: dense or
9110 // compact, `ibp_logit_sites` already carries each active logit's
9111 // global t-index AND its selected-inverse diagonal `G_ii`), then per
9112 // column build `u_k`, solve, and distribute the OFF-DIAGONAL remainder.
9113 //
9114 // #1416 FIX: the diagonal (`i = w`) parts of term A and term B are
9115 // ALREADY supplied — `diag(term A) = dd_k·J_w·Σ_i G_ii·J_i²` by the
9116 // `m_channel` column pass above (whose `m_channel = w·(s''·J² + s'·c)`
9117 // carries the `s''·J²` self piece), and `diag(term B) = 2·d_k·c_w·G_ww·J_w`
9118 // by the inline `local_logit_third` self channel (whose
9119 // `s'·2J·∂_z J` piece is exactly that). So this pass must add ONLY the
9120 // cross-row off-diagonal remainder; double-counting the diagonal here
9121 // (the pre-fix `0.5·dd·J·uᵀGu + d·c·x_w` form, which is neither the
9122 // full nor the off-diagonal value) desynced the θ-adjoint from the FD
9123 // of `log|H|`. The exact `tr(H⁻¹ ∂W_k/∂ℓ_wk)` is
9124 // Γ_wk += dd_k·J_wk·(uᵀ G u − Σ_i G_ii·J_ik²) (term A, off-diagonal)
9125 // + 2·d_k·c_wk·((G u)_w − G_ww·J_wk) (term B, off-diagonal),
9126 // with `uᵀGu = Σ_i J_ik·(Gu)_i`, `(Gu) = x_k = H⁻¹ u_k` from one solve,
9127 // and `G_ii` the per-site selected-inverse diagonal.
9128 let total_t = cache.delta_t_len();
9129 let mut col_sites: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); k_atoms];
9130 for &(row, atom, t_index, inv_diag) in &ibp_logit_sites {
9131 col_sites[atom].push((row, t_index, inv_diag));
9132 }
9133 // Hoisted RHS scratch: fill only this column's active slots, solve,
9134 // then clear exactly those slots — no per-column total_t zeroing.
9135 let mut rhs_t_scratch = Array1::<f64>::zeros(total_t);
9136 let rhs_beta_zero = Array1::<f64>::zeros(cache.k);
9137 for atom in 0..k_atoms {
9138 let d_k = channels.cross_row_d[atom];
9139 let dd_k = channels.cross_row_dd[atom];
9140 if col_sites[atom].is_empty() || (d_k == 0.0 && dd_k == 0.0) {
9141 continue;
9142 }
9143 // u_k as a full t-RHS: J at each active logit-k slot.
9144 for &(row, t_index, _g) in &col_sites[atom] {
9145 rhs_t_scratch[t_index] = channels.z_jac[row * k_atoms + atom];
9146 }
9147 let x_k = solver
9148 .solve(rhs_t_scratch.view(), rhs_beta_zero.view())
9149 .map_err(|err| {
9150 format!("logdet_theta_adjoint: IBP cross-row Woodbury solve: {err}")
9151 })?;
9152 // Clear this column's active slots for the next atom's RHS.
9153 for &(_row, t_index, _g) in &col_sites[atom] {
9154 rhs_t_scratch[t_index] = 0.0;
9155 }
9156 // (JᵀH⁻¹J)_k = u_kᵀ x_k, and the diagonal `Σ_i G_ii·J_ik²` that the
9157 // `m_channel` pass already counted (subtract it from term A so this
9158 // pass holds only the off-diagonal `i ≠ j` remainder).
9159 let mut jt_hinv_j = 0.0_f64;
9160 let mut diag_jt_g_j = 0.0_f64;
9161 for &(row, t_index, g_ii) in &col_sites[atom] {
9162 let j = channels.z_jac[row * k_atoms + atom];
9163 jt_hinv_j += j * x_k.t[t_index];
9164 diag_jt_g_j += g_ii * j * j;
9165 }
9166 let off_diag_a = jt_hinv_j - diag_jt_g_j;
9167 for &(row, t_index, g_ii) in &col_sites[atom] {
9168 let j_wk = channels.z_jac[row * k_atoms + atom];
9169 let c_wk = channels.logit_curvature[row * k_atoms + atom];
9170 // term A (off-diagonal) + term B (off-diagonal); the inline /
9171 // `m_channel` passes already added the diagonal parts.
9172 let off_diag_b = x_k.t[t_index] - g_ii * j_wk;
9173 gamma_t[t_index] += dd_k * j_wk * off_diag_a + 2.0 * d_k * c_wk * off_diag_b;
9174 }
9175 }
9176 }
9177
9178 Ok(SaeArrowVector {
9179 t: gamma_t,
9180 beta: gamma_beta,
9181 })
9182 }
9183
9184
9185 /// Public analytic outer-ρ gradient at a converged inner state, constructing
9186 /// the deflated arrow solver from the supplied cache. Use this seam from
9187 /// integration tests and external consumers that have a converged
9188 /// `(loss, cache)` from [`Self::reml_criterion_with_cache`] but no access to
9189 /// the crate-private `DeflatedArrowSolver`.
9190 pub fn analytic_outer_rho_gradient_at_converged(
9191 &self,
9192 target: ArrayView2<'_, f64>,
9193 rho: &SaeManifoldRho,
9194 loss: &SaeManifoldLoss,
9195 cache: &ArrowFactorCache,
9196 ) -> Result<SaeOuterRhoGradientComponents, String> {
9197 let solver = self.outer_gradient_arrow_solver(cache, &rho.lambda_smooth_vec())?;
9198 self.analytic_outer_rho_gradient_components(target, rho, loss, cache, &solver)
9199 .map_err(|e| e.to_string())
9200 }
9201
9202 /// Compose the SAE LAML criterion as a sum of atoms (#931 SAE pilot).
9203 ///
9204 /// This is the single seam that establishes value↔gradient coherence for
9205 /// the SAE objective: it runs the inner solve once via
9206 /// [`Self::reml_criterion_with_cache`], reads the value decomposition
9207 /// (`loss.total() + extra_penalty_energy`, `log|H|`, `occam`) and the
9208 /// matching gradient channels (`SaeOuterRhoGradientComponents`) from the
9209 /// SAME converged cache, and hands them to [`SaeCriterion::assemble`]. The
9210 /// returned criterion's [`SaeCriterion::value`] and
9211 /// [`SaeCriterion::gradient`] are then projections of one factorization —
9212 /// the outer optimizer can no longer evaluate a value path and a gradient
9213 /// path that disagree (the #752/#748/#901 desync class). The
9214 /// implicit-stationarity envelope correction (#1006's Γ term) is its own
9215 /// named atom, so the channel the desync class keeps dropping is visible
9216 /// rather than a silent zero.
9217 pub fn criterion_as_atoms(
9218 &mut self,
9219 target: ArrayView2<'_, f64>,
9220 rho: &SaeManifoldRho,
9221 registry: Option<&AnalyticPenaltyRegistry>,
9222 inner_max_iter: usize,
9223 learning_rate: f64,
9224 ridge_ext_coord: f64,
9225 ridge_beta: f64,
9226 ) -> Result<SaeCriterion, String> {
9227 let (_v, loss, cache) = self.reml_criterion_with_cache(
9228 target,
9229 rho,
9230 registry,
9231 inner_max_iter,
9232 learning_rate,
9233 ridge_ext_coord,
9234 ridge_beta,
9235 )?;
9236 let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
9237 "criterion_as_atoms: arrow_log_det_from_cache returned None".to_string()
9238 })?;
9239 let occam = self.reml_occam_term(rho)?;
9240 let extra_penalty_energy = match registry {
9241 Some(reg) => self
9242 .reml_extra_penalty_value_total(reg)
9243 .map_err(|err| format!("SaeManifoldTerm::criterion_as_atoms: {err}"))?,
9244 None => 0.0,
9245 };
9246 let data_fit_priors_value = loss.total() + extra_penalty_energy;
9247
9248 let solver = self.outer_gradient_arrow_solver(&cache, &rho.lambda_smooth_vec())?;
9249 let components =
9250 self.analytic_outer_rho_gradient_components(target, rho, &loss, &cache, &solver)?;
9251 Ok(SaeCriterion::assemble(
9252 data_fit_priors_value,
9253 log_det,
9254 occam,
9255 components.explicit,
9256 components.logdet_trace,
9257 components.occam,
9258 components.third_order_correction,
9259 ))
9260 }
9261
9262 // [#780 line-count gate] reconstruction_dispersion + assemble_shape_uncertainty
9263 // + complete_born_atom_shape_bands + shape_uncertainty_without_decoder_covariance
9264 // (the contiguous trailing methods of this impl block) were split into the
9265 // sibling construction_reconstruction.rs (declared in mod.rs); callers reach
9266 // them bare via use super::*.
9267}
9268
9269// [#780 line-count gate] Per-row jet / reconstruction-channel assembly for the
9270// streaming-exact arrow log-det lives in a sibling file as a second
9271// `impl SaeManifoldTerm` block, inlined here so it keeps the SAME module scope
9272// and private-field access. Keeps this tracked file under the 10k limit.
9273include!("construction_row_jet_logdet_channels.rs");
9274
9275// [#780 line-count gate] Massive-K decoder-smoothness effective-dof Hutchinson
9276// estimator (associated constants + the matrix-free per-atom trace) lives in a
9277// sibling file as another `impl SaeManifoldTerm` block, inlined here so it keeps
9278// the SAME module scope and private-field access. The two gated exact/estimator
9279// entry points above dispatch into it at `K >= MIN_ATOMS`.
9280include!("construction_smoothness_dof.rs");
9281
9282// [#780 line-count gate] `term_from_padded_blocks_with_mode` (the padded-FFI
9283// term builder) was split into the sibling `construction_padded_blocks.rs`
9284// module (declared and re-exported from `mod.rs`), keeping this tracked file
9285// under the 10k limit. Callers still reach it bare through `use super::*`.
9286
9287// [#780 line-count gate] `refresh_isometry_caches_from_atom` and
9288// `refresh_isometry_caches_from_term` were split into the sibling
9289// `construction_cache_refresh.rs` module (declared and re-exported from
9290// `mod.rs`), keeping this tracked file under the 10k limit. Callers still reach
9291// both functions bare through `use super::*`.
9292
9293// [#780 line-count gate] The `#[cfg(test)]` modules below the production code
9294// are mechanically split into a sibling `*_tests` file and inlined via
9295// `include!` (the sanctioned cohesive-module decomposition — see build.rs
9296// file_stem_is_exempt_test_module). Keeps this tracked file under the 10k limit.
9297include!("construction_tests.rs");