gam_sae/manifold/construction.rs
1use super::*;
2use gam_math::jet_scalar::JetScalar;
3
4// [#780] Softmax-entropy Gershgorin majorizer leaf helpers live in a sibling
5// cohesive module, inlined here so they share this module scope.
6include!("softmax_entropy_majorizer.rs");
7
/// Typed error from the SAE outer-gradient analytic assembly path (#1436).
///
/// The `eval()` analytic fallback (#1273/#1440: the plain undeflated analytic
/// outer gradient, NOT a finite difference) must fire ONLY for the genuine
/// conditioning/identifiability failure modes it was designed for — a
/// near-singular-but-valid joint Hessian or a gauge-degenerate direction.
/// Shape/indexing bugs, non-finite intermediates, and violated internal
/// invariants are [`OuterGradientError::InternalInvariant`] and MUST propagate
/// as hard errors so regressions surface instead of being silently masked by a
/// degraded descent direction.
#[derive(Clone, Debug)]
pub(crate) enum OuterGradientError {
    /// Expected: near-singular or ill-conditioned joint Hessian at a feasible ρ
    /// (the genuine #1273 flat-valley case). Eligible for the analytic
    /// plain-solver fallback.
    IllConditioned { reason: String },
    /// Expected: a non-identifiable / gauge-degenerate direction at this ρ.
    /// Eligible for the analytic plain-solver fallback.
    NonIdentifiable { reason: String },
    /// Unexpected: shape/dimension mismatch, non-finite intermediate, or a
    /// violated internal invariant. MUST propagate — never eligible for the
    /// plain-solver fallback.
    InternalInvariant { reason: String },
}

impl OuterGradientError {
    /// Whether this error class is recoverable by the #1273/#1440 analytic
    /// plain-solver fallback (i.e. it represents a legitimate
    /// conditioning/identifiability failure, not a programming/invariant defect).
    pub(crate) fn is_conditioning_recoverable(&self) -> bool {
        matches!(
            self,
            Self::IllConditioned { .. } | Self::NonIdentifiable { .. }
        )
    }

    /// Construct an [`OuterGradientError::InternalInvariant`] from any error
    /// displayable — the default classification for unexpected assembly failures
    /// (shape mismatches, non-finite intermediates, violated invariants).
    pub(crate) fn internal<E: std::fmt::Display>(err: E) -> Self {
        Self::InternalInvariant {
            reason: err.to_string(),
        }
    }

    /// #1451 — classify a `String` error surfaced by the deflation linear-algebra
    /// path (`apply_cached_arrow_hessian`, `DeflatedArrowSolver::from_orthonormal_gauges`)
    /// into the correct [`OuterGradientError`] class.
    ///
    /// A genuine rank-deficiency / near-singularity failure (a back-solve or
    /// Cholesky/Woodbury factor that tripped on a finite, correctly-shaped input)
    /// is a legitimate #1273 conditioning failure and keeps `conditioning_err`
    /// (`IllConditioned`), so it stays recoverable by the analytic fallback. A
    /// shape/dimension mismatch or a non-finite intermediate is an
    /// internal-invariant defect and MUST propagate ([`Self::internal`]) instead
    /// of being masked as a plausible-but-wrong descent direction — exactly the
    /// #1436 contract.
    ///
    /// The two solver helpers return `String` (not a typed error), so the
    /// distinction is drawn from the stable markers those helpers emit for their
    /// shape/non-finite guards. Matching is case-insensitive:
    /// * multi-word phrases (`vector shapes`, `gauge length`, `solution length`,
    ///   `!= cache`, `must be finite`, `non-finite`, `not finite`) are matched
    ///   as substrings;
    /// * the short non-finite spellings (`nan`, `inf`, `infinite`, `infinity`)
    ///   are matched only as WHOLE tokens (split on non-alphanumerics), so words
    ///   such as "dominant", "info" or "infeasible" in a genuine conditioning
    ///   message do not misclassify it as an internal defect.
    ///
    /// Everything else — including the `cholesky`/back-solve near-singular
    /// failures — is treated as a genuine conditioning trip.
    pub(crate) fn classify_arrow_solver_error(message: &str, conditioning_err: Self) -> Self {
        // Phrases long/punctuated enough to be unambiguous as substrings.
        const INTERNAL_PHRASES: [&str; 7] = [
            "vector shapes",
            "gauge length",
            "solution length",
            "!= cache",
            "must be finite",
            "non-finite",
            "not finite",
        ];
        // Short non-finite markers: token-exact only (e.g. Rust prints `NaN`,
        // `inf`, `-inf` for non-finite f64 values).
        const NON_FINITE_TOKENS: [&str; 4] = ["nan", "inf", "infinite", "infinity"];

        let lower = message.to_ascii_lowercase();
        let has_internal_phrase = INTERNAL_PHRASES
            .iter()
            .any(|phrase| lower.contains(phrase));
        let has_non_finite_token = lower
            .split(|c: char| !c.is_ascii_alphanumeric())
            .any(|token| NON_FINITE_TOKENS.contains(&token));
        if has_internal_phrase || has_non_finite_token {
            Self::internal(message)
        } else {
            conditioning_err
        }
    }

    /// The exact gate the gradient lane (`SaeManifoldOuterObjective::eval`) uses
    /// to decide whether to descend with the #1273/#1440 analytic plain-solver
    /// fallback instead of propagating the error as a hard failure.
    ///
    /// The fallback is admissible ONLY when BOTH hold:
    /// * the REML cost at this rho is finite (a genuinely feasible point -- the
    ///   plain analytic solver supplies a descent direction for a value the
    ///   analytic path already produced), and
    /// * the error is a legitimate conditioning/identifiability failure
    ///   ([`Self::is_conditioning_recoverable`]) -- the genuine #1273 flat-valley
    ///   case.
    ///
    /// A non-finite cost or an [`OuterGradientError::InternalInvariant`] must
    /// propagate: masking a shape/indexing bug, a non-finite intermediate, or a
    /// violated invariant behind a plausible-but-wrong step is exactly the
    /// regression #1436 closes. Centralising the decision here (rather than
    /// inlining the boolean at the call site) makes the `cost x error-class`
    /// contract a single, directly unit-testable predicate.
    pub(crate) fn admits_plain_solver_fallback(&self, cost: f64) -> bool {
        cost.is_finite() && self.is_conditioning_recoverable()
    }
}

impl std::fmt::Display for OuterGradientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IllConditioned { reason } => write!(f, "ill-conditioned: {reason}"),
            Self::NonIdentifiable { reason } => write!(f, "non-identifiable: {reason}"),
            Self::InternalInvariant { reason } => write!(f, "internal invariant: {reason}"),
        }
    }
}

impl From<OuterGradientError> for String {
    fn from(e: OuterGradientError) -> String {
        e.to_string()
    }
}
125
/// Active-set layout override for [`SaeManifoldTerm::assemble_arrow_schur_inner`].
///
/// The nesting encodes three distinct states:
/// * `None` — the production path: the layout is derived from the assignment
///   mode and `sparse_active_plan`;
/// * `Some(None)` — pin the dense layout;
/// * `Some(Some(layout))` — pin the given compact `SaeRowLayout`.
///
/// Pinning exists so the compact-vs-dense Riemannian-geometry equality
/// regression can drive both code paths on identical data without depending on
/// the host/device memory budget that gates the compact path in production.
pub(crate) type ForcedRowLayout = Option<Option<SaeRowLayout>>;
135
/// #1154 — base weight of the amortized-encoder reconstruction-consistency
/// co-training penalty, expressed as a fraction of the REML criterion
/// magnitude. The applied weight is `COTRAIN_RECON_WEIGHT · max(|REML|, 1)`,
/// which keeps the penalty a bounded, scale-free share of the objective with no
/// caller-facing knob.
pub(crate) const COTRAIN_RECON_WEIGHT: f64 = 1.0e-1;
141
/// #1154 — base weight of the encoder's certifiable-coverage co-training
/// penalty, i.e. the penalty on the fraction of (row, atom) encodes rejected by
/// the Kantorovich certificate. It is scaled the same way as
/// [`COTRAIN_RECON_WEIGHT`].
pub(crate) const COTRAIN_CERT_WEIGHT: f64 = 5.0e-2;
146
/// #1154 — how well the cheap one-mat-vec amortized encoder inverts the
/// dictionary the inner solve converged to, measured against the fit-time
/// target. This is the co-training signal of the joint amortized-encoder +
/// REML loop: it reports both reconstruction fidelity and certificate coverage.
#[derive(Clone, Copy, Debug)]
pub struct AmortizedEncoderConsistency {
    /// Mean squared per-element gap between the amortized and the exact fitted
    /// reconstructions, `‖x̂_amortized − x̂_exact‖² / (n·p)`. A value of zero
    /// means the IFT predictor reproduces the encode map exactly to first order.
    pub recon_consistency: f64,
    /// Share of (row, atom) amortized encodes whose Kantorovich certificate
    /// failed (`h > ½`) and were re-done with the certified Newton encode.
    pub uncertified_fraction: f64,
    /// Number of uncertified (row, atom) encodes — the fraction's numerator.
    pub n_uncertified: usize,
    /// Number of (row, atom) encodes scored in total (`n · K`).
    pub n_encodes: usize,
}
165
166impl SaeManifoldTerm {
167 #[must_use = "build error must be handled"]
168 pub fn new(atoms: Vec<SaeManifoldAtom>, assignment: SaeAssignment) -> Result<Self, String> {
169 if atoms.is_empty() {
170 return Err("SaeManifoldTerm::new: at least one atom required".into());
171 }
172 let n = atoms[0].n_obs();
173 let p = atoms[0].output_dim();
174 if assignment.n_obs() != n || assignment.k_atoms() != atoms.len() {
175 return Err(format!(
176 "SaeManifoldTerm::new: assignment shape ({}, {}) does not match atoms ({n}, {})",
177 assignment.n_obs(),
178 assignment.k_atoms(),
179 atoms.len()
180 ));
181 }
182 for (k, atom) in atoms.iter().enumerate() {
183 if atom.n_obs() != n {
184 return Err(format!(
185 "SaeManifoldTerm::new: atom {k} has n_obs={} but atom 0 has {n}",
186 atom.n_obs()
187 ));
188 }
189 if atom.output_dim() != p {
190 return Err(format!(
191 "SaeManifoldTerm::new: atom {k} output_dim={} but atom 0 has {p}",
192 atom.output_dim()
193 ));
194 }
195 if atom.latent_dim != assignment.coords[k].latent_dim() {
196 return Err(format!(
197 "SaeManifoldTerm::new: atom {k} latent_dim={} but assignment coord has {}",
198 atom.latent_dim,
199 assignment.coords[k].latent_dim()
200 ));
201 }
202 }
203 Ok(Self {
204 atoms,
205 assignment,
206 temperature_schedule: None,
207 last_row_layout: None,
208 row_metric: None,
209 collapse_events: Vec::new(),
210 row_loss_weights: None,
211 last_frames_active: false,
212 assembly_chunk_override: None,
213 fixed_decoder_assembly: false,
214 softmax_active_cap: None,
215 border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
216 certificate_dispersion: None,
217 curvature_walk_report: None,
218 expected_evidence_gauge_deflated_directions: None,
219 evidence_gauge_deflation_reanchors: 0,
220 evidence_gauge_deflation_last_delta_sign: 0,
221 dictionary_cocollapse_reseeds: 0,
222 best_cocollapse_incumbent: None,
223 decoder_repulsion_gate: None,
224 barrier_coactivation_gate: None,
225 hybrid_split_report: None,
226 atom_inner_fits: None,
227 oos_linear_images: None,
228 })
229 }
230
231 /// #1408/#1409 — install the optional hard per-row active-atom cap for
232 /// Softmax mode (threaded from the fit/encode `top_k`). A `Some(k)` with
233 /// `1 <= k < K` makes the Softmax assignment optimize on the COMPACT
234 /// top-`k` row layout (see [`Self::softmax_active_cap`]); `Some(k) >= K`
235 /// and `None` are both no-ops (full support). Non-softmax modes ignore it.
236 pub fn set_softmax_active_cap(&mut self, top_k: Option<usize>) {
237 self.softmax_active_cap = match top_k {
238 Some(k) if k >= 1 && k < self.k_atoms() => Some(k),
239 _ => None,
240 };
241 }
242
243 /// Install the fitted reconstruction dispersion used by
244 /// [`dictionary_incoherence_report`]. This is a pure diagnostic scalar and
245 /// does not feed any loss, criterion, penalty, or optimizer state.
246 pub fn set_certificate_dispersion(&mut self, dispersion: f64) -> Result<(), String> {
247 if !dispersion.is_finite() || dispersion <= 0.0 {
248 return Err(format!(
249 "SaeManifoldTerm::set_certificate_dispersion: dispersion must be finite and positive, got {dispersion}"
250 ));
251 }
252 self.certificate_dispersion = Some(dispersion);
253 Ok(())
254 }
255
256 /// Harvest the per-atom inner-decoder-smooth byproducts (#1097 / #1103) the
257 /// residual-gauge certificate's post-PIRLS atom inference reports consume.
258 ///
259 /// This is the post-fit harness seam: it needs the reconstruction target `Z`
260 /// (`target`) and the fitted dispersion `φ` (`dispersion`), both available
261 /// only after the joint fit converges and the engine has discarded `Z` from
262 /// the objective. For each atom `k` it captures the Gaussian-identity
263 /// penalized smooth of the atom's leading decoder output channel `j`
264 /// (largest column 2-norm of `B_k`) against its partial residual
265 /// `e_{i} = z_i − fitted_i + a_{ik} g_k(t_i)` on channel `j`, holding all
266 /// other atoms and the assignment fixed at the fitted optimum — exactly the
267 /// fixed snapshot ([`crate::identifiability::AtomInnerFit`]) the Riesz
268 /// debiasing and split-LRT smooth-structure e-value read.
269 ///
270 /// A pure read of the fitted state: it mutates only the diagnostic
271 /// `atom_inner_fits` field, never a loss / criterion / penalty / optimizer
272 /// state. Atoms with no active rows or a degenerate (rank-deficient,
273 /// non-SPD) inner Hessian get a `None` slot — the genuine prerequisite (an
274 /// SPD penalized inner Hessian on a non-empty active set) is absent there.
275 pub fn set_atom_inner_fits(
276 &mut self,
277 target: ArrayView2<'_, f64>,
278 rho: &SaeManifoldRho,
279 dispersion: f64,
280 ) -> Result<(), String> {
281 if !dispersion.is_finite() || dispersion <= 0.0 {
282 return Err(format!(
283 "SaeManifoldTerm::set_atom_inner_fits: dispersion must be finite and positive, got {dispersion}"
284 ));
285 }
286 let n = self.n_obs();
287 let p = self.output_dim();
288 let k_atoms = self.k_atoms();
289 if target.dim() != (n, p) {
290 return Err(format!(
291 "SaeManifoldTerm::set_atom_inner_fits: target {:?} != ({n}, {p})",
292 target.dim()
293 ));
294 }
295
296 // #1026 — `atom_inner_fits` is a pure diagnostic; skip its dense (N×K×P)
297 // tensor (~256 GiB at K=32768,P=32) past a cell ceiling — all-None slots,
298 // never OOM. The fit is unaffected; only this audit field is absent.
299 if n.saturating_mul(k_atoms).saturating_mul(p) > 64_000_000 {
300 self.atom_inner_fits = Some((0..k_atoms).map(|_| None).collect());
301 return Ok(());
302 }
303
304 // Settled per-row assignments and per-(row, atom) decoded outputs, so the
305 // per-atom partial residual is `e_k = (z − fitted) + a_k decoded_k`.
306 let mut assignments = Vec::with_capacity(n);
307 for row in 0..n {
308 assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
309 }
310 let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
311 let mut dbuf = vec![0.0_f64; p];
312 for row in 0..n {
313 for atom_idx in 0..k_atoms {
314 self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
315 for c in 0..p {
316 decoded[[row, atom_idx, c]] = dbuf[c];
317 }
318 }
319 }
320 let mut fitted = Array2::<f64>::zeros((n, p));
321 for row in 0..n {
322 for atom_idx in 0..k_atoms {
323 let a = assignments[row][atom_idx];
324 if a == 0.0 {
325 continue;
326 }
327 for c in 0..p {
328 fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
329 }
330 }
331 }
332
333 let mut inner_fits: Vec<Option<crate::identifiability::AtomInnerFit>> =
334 Vec::with_capacity(k_atoms);
335 for atom_idx in 0..k_atoms {
336 inner_fits.push(self.build_atom_inner_fit(
337 atom_idx,
338 target,
339 &assignments,
340 decoded.view(),
341 fitted.view(),
342 dispersion,
343 )?);
344 }
345 self.atom_inner_fits = Some(inner_fits);
346 Ok(())
347 }
348
349 /// Build one atom's fixed inner-smooth snapshot for the post-PIRLS atom
350 /// inference reports, or `None` when the atom has no active rows or the
351 /// penalized inner Hessian is not SPD. Returns `Err` only on a structural
352 /// inconsistency (shape mismatch), never on a benign degenerate atom.
353 pub(crate) fn build_atom_inner_fit(
354 &self,
355 atom_idx: usize,
356 target: ArrayView2<'_, f64>,
357 assignments: &[Array1<f64>],
358 decoded: ArrayView3<'_, f64>,
359 fitted: ArrayView2<'_, f64>,
360 dispersion: f64,
361 ) -> Result<Option<crate::identifiability::AtomInnerFit>, String> {
362 let atom = &self.atoms[atom_idx];
363 let n = atom.n_obs();
364 let m = atom.basis_size();
365 let p = atom.output_dim();
366 if m == 0 || p == 0 {
367 return Ok(None);
368 }
369
370 // Leading decoder output channel j = argmax_j ‖B_k[:, j]‖, the channel
371 // that carries the atom's signal.
372 let mut j_lead = 0usize;
373 let mut best_norm = -1.0_f64;
374 for col in 0..p {
375 let mut norm = 0.0_f64;
376 for r in 0..m {
377 let v = atom.decoder_coefficients[[r, col]];
378 norm += v * v;
379 }
380 if norm > best_norm {
381 best_norm = norm;
382 j_lead = col;
383 }
384 }
385 let beta = atom.decoder_coefficients.column(j_lead).to_owned();
386
387 // Active rows: a_{ik} > 0.
388 let active: Vec<usize> = (0..n)
389 .filter(|&row| assignments[row][atom_idx] > 0.0)
390 .collect();
391 let n_active = active.len();
392 // The penalized smooth needs at least as many active rows as it has
393 // basis columns to give a non-degenerate data Gram; below that the inner
394 // fit's SPD prerequisite is genuinely unmet.
395 if n_active == 0 {
396 return Ok(None);
397 }
398
399 let mut design = Array2::<f64>::zeros((n_active, m));
400 let mut derivative_design = Array2::<f64>::zeros((n_active, m));
401 let mut row_scores = Array2::<f64>::zeros((n_active, m));
402 let mut weights = Array1::<f64>::zeros(n_active);
403 for (slot, &row) in active.iter().enumerate() {
404 let a_ik = assignments[row][atom_idx];
405 let w_i = a_ik * a_ik;
406 weights[slot] = w_i;
407 for col in 0..m {
408 design[[slot, col]] = atom.basis_values[[row, col]];
409 // Leading latent axis (axis 0) is the atom's primary coordinate;
410 // it is the one the average-derivative functional integrates.
411 derivative_design[[slot, col]] = atom.basis_jacobian[[row, col, 0]];
412 }
413 // Partial residual on channel j, then the inner-smooth working
414 // response z_i = e_i / a_ik so that w_i (z_i − Φᵀβ) = a_ik r_i.
415 let e_i = target[[row, j_lead]] - fitted[[row, j_lead]]
416 + a_ik * decoded[[row, atom_idx, j_lead]];
417 let mu_hat = design.row(slot).dot(&beta);
418 let z_i = e_i / a_ik;
419 let res_i = z_i - mu_hat;
420 // Gaussian-identity score s_i = −w_i res_i Φ_i / φ.
421 let scale = -w_i * res_i / dispersion;
422 for col in 0..m {
423 row_scores[[slot, col]] = scale * design[[slot, col]];
424 }
425 }
426
427 // Penalized inner Hessian H = ΦᵀWΦ + S̃_k.
428 let mut xtwx = Array2::<f64>::zeros((m, m));
429 for slot in 0..n_active {
430 let w_i = weights[slot];
431 for a in 0..m {
432 let xa = design[[slot, a]];
433 if xa == 0.0 {
434 continue;
435 }
436 for b in 0..m {
437 xtwx[[a, b]] += w_i * xa * design[[slot, b]];
438 }
439 }
440 }
441 let penalty = atom.smooth_penalty.clone();
442 if penalty.dim() != (m, m) {
443 return Err(format!(
444 "build_atom_inner_fit: atom {atom_idx} smooth penalty {:?} != ({m}, {m})",
445 penalty.dim()
446 ));
447 }
448 let penalized_hessian = &xtwx + &penalty;
449
450 // SPD prerequisite: the inner penalized Hessian must factor, else the
451 // atom's inner-smooth fit is degenerate and no report is producible.
452 if penalized_hessian.cholesky(Side::Lower).is_err() {
453 return Ok(None);
454 }
455
456 // Peak (largest fitted |g_k| on channel j) and mode (largest assignment
457 // mass) design rows, over the active set.
458 let mut peak_slot = 0usize;
459 let mut peak_val = -1.0_f64;
460 let mut mode_slot = 0usize;
461 let mut mode_mass = -1.0_f64;
462 for (slot, &row) in active.iter().enumerate() {
463 let g_val = design.row(slot).dot(&beta).abs();
464 if g_val > peak_val {
465 peak_val = g_val;
466 peak_slot = slot;
467 }
468 let mass = assignments[row][atom_idx];
469 if mass > mode_mass {
470 mode_mass = mass;
471 mode_slot = slot;
472 }
473 }
474 let peak_design_row = design.row(peak_slot).to_owned();
475 let mode_design_row = design.row(mode_slot).to_owned();
476
477 Ok(Some(crate::identifiability::AtomInnerFit {
478 design,
479 derivative_design,
480 beta,
481 penalty,
482 penalized_hessian,
483 row_scores,
484 weights,
485 dispersion,
486 peak_design_row,
487 mode_design_row,
488 }))
489 }
490
491 /// Profile the Gaussian reconstruction dispersion at the current seed
492 /// state. This is the scale used to make SAE penalty seeds dimensionless
493 /// before the outer rho search starts.
494 pub fn seed_reconstruction_dispersion(
495 &self,
496 target: ArrayView2<'_, f64>,
497 ) -> Result<f64, String> {
498 let fitted = self.try_fitted()?;
499 if fitted.dim() != target.dim() {
500 return Err(format!(
501 "SaeManifoldTerm::seed_reconstruction_dispersion: fitted {:?} != target {:?}",
502 fitted.dim(),
503 target.dim()
504 ));
505 }
506 let n_scalar = (target.nrows() * target.ncols()).max(1) as f64;
507 let mut rss = 0.0_f64;
508 for row in 0..target.nrows() {
509 for col in 0..target.ncols() {
510 let r = target[[row, col]] - fitted[[row, col]];
511 rss += r * r;
512 }
513 }
514 if !rss.is_finite() || rss < 0.0 {
515 return Err(format!(
516 "SaeManifoldTerm::seed_reconstruction_dispersion: non-finite seed RSS {rss}"
517 ));
518 }
519 Ok((rss / n_scalar).max(SAE_SEED_DISPERSION_FLOOR))
520 }
521
522 /// Install per-row design honesty weights (#991) — the `1/π` inclusion
523 /// corrections of a designed corpus subsample (see the field docs on
524 /// `row_loss_weights` for exactly where they enter the objective).
525 ///
526 /// Weights must be finite and strictly positive, one per term row. They
527 /// are self-normalized to mean `1.0` here (only the *relative* design
528 /// correction matters at the fitted sample size; the absolute `n/budget`
529 /// scale would silently inflate the dispersion estimate against the
530 /// sample-sized dof). Weights that are identically equal after
531 /// normalization (an exact full pass, or any uniform design) are stored
532 /// as `None`, so the unweighted path stays bit-for-bit identical rather
533 /// than "multiplied by 1.0".
534 pub fn set_row_loss_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
535 if weights.len() != self.n_obs() {
536 return Err(format!(
537 "SaeManifoldTerm::set_row_loss_weights: {} weights for {} rows",
538 weights.len(),
539 self.n_obs()
540 ));
541 }
542 if weights.is_empty() {
543 self.row_loss_weights = None;
544 return Ok(());
545 }
546 if !weights.iter().all(|w| w.is_finite() && *w > 0.0) {
547 return Err(
548 "SaeManifoldTerm::set_row_loss_weights: weights must be finite and strictly \
549 positive"
550 .to_string(),
551 );
552 }
553 let first = weights[0];
554 if weights.iter().all(|w| *w == first) {
555 // Uniform design (full pass, or flat measure): the normalized
556 // weight is exactly 1 everywhere — take the unweighted path.
557 self.row_loss_weights = None;
558 return Ok(());
559 }
560 let mean = weights.iter().sum::<f64>() / weights.len() as f64;
561 self.row_loss_weights = Some(weights.into_iter().map(|w| w / mean).collect());
562 Ok(())
563 }
564
565 /// The installed (mean-1 normalized) design honesty weights, `None` on the
566 /// exact unweighted path.
567 pub fn row_loss_weights(&self) -> Option<&[f64]> {
568 self.row_loss_weights.as_deref()
569 }
570
571 /// Drop any installed per-row reconstruction weights, returning the term to
572 /// the exact unweighted (full-pass) path. Used by the #997 structure-search
573 /// wiring to clear the internal estimation/evaluation mask off the adopted
574 /// term before the payload reconstruction is read over all rows.
575 pub fn clear_row_loss_weights(&mut self) {
576 self.row_loss_weights = None;
577 }
578
579 /// Huber-style OUTLIER-ROBUST per-row weights from the target activation
580 /// norms — the missing default *policy* for the existing
581 /// [`set_row_loss_weights`](Self::set_row_loss_weights) mechanism.
582 ///
583 /// The SAE fits unweighted least squares, which weights each token by its
584 /// squared residual ∝ `‖z_i‖²`. On real LLM residual streams the per-token
585 /// norm distribution is heavy-tailed (e.g. an OLMo mixed-layer slice has
586 /// `p99/median ≈ 4.7`), so a small **coherent** cluster of high-norm tokens —
587 /// typically special / attention-sink tokens, not semantic content —
588 /// dominates the objective (measured: the top 5% of tokens carry ~31% of the
589 /// total `‖z‖²` budget) and pulls dictionary atoms toward their direction.
590 /// Mean-centering does NOT address this (it is per-feature, not per-token).
591 ///
592 /// This returns Huber weights `w_i = min(1, δ·m / ‖z_i‖)` where `m` is the
593 /// MEDIAN token norm: tokens at or below `δ·m` keep full weight, higher-norm
594 /// tokens are downweighted so their objective share grows only LINEARLY (not
595 /// quadratically) with norm. `δ` is the robustness knob (`δ=1` thresholds at
596 /// the median; larger `δ` only touches the extreme tail). The result is
597 /// mean-normalized (overall objective scale preserved). OPT-IN: the caller
598 /// installs it via `set_row_loss_weights` — the default fit is unchanged.
599 pub fn robust_norm_row_weights(
600 target: ArrayView2<'_, f64>,
601 delta: f64,
602 ) -> Result<Vec<f64>, String> {
603 if !(delta.is_finite() && delta > 0.0) {
604 return Err(format!(
605 "robust_norm_row_weights: delta must be finite and positive; got {delta}"
606 ));
607 }
608 let n = target.nrows();
609 if n == 0 {
610 return Ok(Vec::new());
611 }
612 let norms: Vec<f64> = (0..n)
613 .map(|i| {
614 let r = target.row(i);
615 r.dot(&r).sqrt()
616 })
617 .collect();
618 let mut sorted = norms.clone();
619 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
620 // Median token norm (lower-median for even n; floored off zero so an
621 // all-zero/degenerate slice yields uniform weights instead of NaN).
622 let median = sorted[n / 2].max(f64::MIN_POSITIVE);
623 let thresh = delta * median;
624 let raw: Vec<f64> = norms
625 .iter()
626 .map(|&nm| if nm <= thresh { 1.0 } else { thresh / nm })
627 .collect();
628 let mean = raw.iter().sum::<f64>() / n as f64;
629 if !(mean.is_finite() && mean > 0.0) {
630 return Err("robust_norm_row_weights: degenerate weight normalizer".to_string());
631 }
632 Ok(raw.into_iter().map(|w| w / mean).collect())
633 }
634
635 /// Install the single per-row [`RowMetric`](gam_problem::RowMetric)
636 /// that both the reconstruction likelihood and the isometry gauge read.
637 /// Installing per-row output-Fisher factors here flips the provenance to
638 /// `OutputFisher` *and* is the only way the gauge acquires a non-identity
639 /// weight, so the two inner products cannot diverge. Passing a Euclidean
640 /// metric (or never calling this) keeps the bit-identical isotropic path.
641 ///
642 /// The metric's row count and output dimension must match the term.
643 pub fn set_row_metric(
644 &mut self,
645 metric: gam_problem::RowMetric,
646 ) -> Result<(), String> {
647 if metric.n_rows() != self.n_obs() {
648 return Err(format!(
649 "SaeManifoldTerm::set_row_metric: metric has {} rows but term has {}",
650 metric.n_rows(),
651 self.n_obs()
652 ));
653 }
654 if metric.p_out() != self.output_dim() {
655 return Err(format!(
656 "SaeManifoldTerm::set_row_metric: metric output dim {} but term has {}",
657 metric.p_out(),
658 self.output_dim()
659 ));
660 }
661 self.row_metric = Some(metric);
662 Ok(())
663 }
664
665 /// The installed per-row metric, if any. `None` ⇒ Euclidean / isotropic.
666 /// Consumed by the gauge wiring (to build the matching `WeightField`) and by
667 /// Object 4 (to read the [`MetricProvenance`](gam_problem::MetricProvenance)).
668 pub fn row_metric(&self) -> Option<&gam_problem::RowMetric> {
669 self.row_metric.as_ref()
670 }
671
672 /// The per-row inner product the additive diagnostics read through: the
673 /// installed [`RowMetric`](gam_problem::RowMetric) when one
674 /// was set (output-Fisher harvest present), otherwise a freshly-built
675 /// Euclidean metric of the term's own `(n_obs, output_dim)` shape. Either way
676 /// a metric always exists, so the diagnostics are never gated by a flag — the
677 /// Euclidean fallback is the bit-identical isotropic path.
678 pub(crate) fn diagnostic_metric(
679 &self,
680 ) -> Result<gam_problem::RowMetric, String> {
681 match self.row_metric() {
682 Some(metric) => Ok(metric.clone()),
683 None => {
684 gam_problem::RowMetric::euclidean(self.n_obs(), self.output_dim())
685 }
686 }
687 }
688
689 /// Build the additive post-fit diagnostic report for this fitted term: the
690 /// two-score per-atom [`AtomTwoLensReport`](crate::inference::atom_lens::AtomTwoLensReport)
691 /// (presence / behavioral coupling / discrepancy) and the residual-gauge
692 /// [`ResidualGaugeReport`](crate::identifiability::ResidualGaugeReport)
693 /// certificate.
694 ///
695 /// Both reports are read through the same single metric
696 /// ([`Self::diagnostic_metric`]): under a Euclidean / no-harvest provenance
697 /// the lens coupling is `None` and the gauge is certified under Euclidean
698 /// provenance — never an error, never gated by a flag (magic-by-default,
699 /// mirroring the metric selection itself).
700 ///
701 /// `per_atom_ard_variances`, when supplied, is one ARD variance vector per
702 /// atom (length = `latent_dim_k`), threaded into the certificate's
703 /// equal-ARD-rotation detection. `None` (or a per-atom `None`) ⇒ no ARD prior
704 /// on that atom. `isometry_pin_active` records whether an isometry gauge
705 /// penalty was installed on the fit: `false` escalates the certificate to the
706 /// `diffeomorphism-unpinned` verdict (the honest "no metric pin" statement),
707 /// exactly as the certificate's own escalation flag specifies.
708 ///
709 /// Pure read: it never mutates the term, never touches a loss / criterion /
710 /// penalty / optimizer state.
711 pub fn fit_diagnostics_report(
712 &self,
713 per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
714 isometry_pin_active: bool,
715 reconstruction_dispersion: Option<f64>,
716 assignments_override: Option<ArrayView2<'_, f64>>,
717 ) -> Result<SaeManifoldFitDiagnostics, String> {
718 if let Some(view) = assignments_override {
719 let n = self.n_obs();
720 let k = self.k_atoms();
721 if view.dim() != (n, k) {
722 return Err(format!(
723 "fit_diagnostics_report: assignments_override shape {:?} must be ({n}, {k})",
724 view.dim()
725 ));
726 }
727 }
728 let metric = self.diagnostic_metric()?;
729 let atom_two_lens =
730 crate::inference::atom_lens::atom_two_lens(self, &metric, assignments_override);
731
732 let (certificate_model, streamed_curvature) =
733 self.to_residual_gauge_model(metric, per_atom_ard_variances, isometry_pin_active)?;
734 // #998: within-atom gauge families are certified on their EXACT orbits
735 // in the model's own (decoder, coordinate) parameter space — compensated
736 // symmetries are data-nulls by construction there, no lowering-error
737 // calibration involved. This now holds whether or not an isometry pin is
738 // active:
739 // * pin INACTIVE ⇒ the orbit verdict is the data residual alone (no
740 // penalty operator);
741 // * pin ACTIVE ⇒ the orbit verdict adds the isometry pin's orbit-space
742 // curvature through an [`OrbitPenaltyOperator`] lowered from the
743 // atom's second jet `Φ''` (the pullback-metric change along the orbit
744 // differentiates `J = Φ'B` through `t`). A model-class symmetry that
745 // preserves the metric stays a certified freedom; a non-isometric
746 // orbit (a basis not closed under the action) is genuinely pinned.
747 // The relative-curvature fraction `cost/stiffness²` is invariant to the
748 // pin strength μ (both faces scale with μ), so the operator is built at a
749 // canonical unit weight. An atom whose basis exposes no analytic second
750 // jet supplies no operator and falls back to the data residual — never an
751 // error. Magic-by-default either way: the choice is derived from the fit,
752 // never a flag.
753 let views = self.atom_parameter_views();
754 let ops: Vec<Option<crate::identifiability::OrbitPenaltyOperator>> =
755 if isometry_pin_active {
756 views
757 .iter()
758 .map(|view| {
759 view.as_ref().and_then(|v| {
760 crate::identifiability::isometry_orbit_penalty_operator(
761 v, 1.0,
762 )
763 })
764 })
765 .collect()
766 } else {
767 (0..self.k_atoms()).map(|_| None).collect()
768 };
769 let residual_gauge = if isometry_pin_active {
770 // The pin-active path consumes the per-row Jacobian curvature
771 // directly (the certificate_model retains it under a pin), so route
772 // through the non-streamed exact entry point.
773 crate::identifiability::residual_gauge_exact(
774 &certificate_model,
775 &views,
776 &ops,
777 )?
778 } else {
779 let (curvature_gram, root_rows) = streamed_curvature.ok_or_else(|| {
780 "fit_diagnostics_report: missing streamed residual-gauge curvature for unpinned exact path"
781 .to_string()
782 })?;
783 crate::identifiability::residual_gauge_exact_from_curvature_gram(
784 &certificate_model,
785 &views,
786 &ops,
787 curvature_gram,
788 root_rows,
789 )?
790 };
791
792 // #1097 / #1103: per-atom Riesz-debiased functionals and the any-n-valid
793 // split-LRT smooth-structure e-value (non-constant vs constant inner
794 // decoder), read straight off the certificate model — which carries
795 // each atom's `inner_fit` snapshot when the caller harvested it via
796 // [`Self::set_atom_inner_fits`] before this report. Atoms without a
797 // harvested inner fit degrade their inference fields to `None` inside
798 // `atom_inference_reports`, so this is always populated (one entry per
799 // atom) and never gated by a flag.
800 let atom_inference =
801 crate::identifiability::atom_inference_reports(&certificate_model);
802
803 Ok(SaeManifoldFitDiagnostics {
804 atom_two_lens,
805 residual_gauge,
806 incoherence_report: match reconstruction_dispersion.or(self.certificate_dispersion) {
807 Some(dispersion) => Some(dictionary_incoherence_report_with_dispersion(
808 self, dispersion,
809 )?),
810 None => None,
811 },
812 atom_inference,
813 })
814 }
815
816 /// Build the trust-diagnostics producer for the Python `diagnostics` block.
817 ///
818 /// `assignments` is supplied by the payload assembly site so top-k projection,
819 /// when requested, is reflected in coverage/frequency and in the tangent
820 /// spectra. The active threshold is shared with the atom lens so all
821 /// assignment-support diagnostics agree on what "active" means.
822 pub fn trust_diagnostics_report(
823 &self,
824 assignments: ArrayView2<'_, f64>,
825 ) -> Result<SaeTrustDiagnostics, String> {
826 let n = self.n_obs();
827 let k_atoms = self.k_atoms();
828 if assignments.dim() != (n, k_atoms) {
829 return Err(format!(
830 "trust_diagnostics_report: assignments shape {:?} must be ({n}, {k_atoms})",
831 assignments.dim()
832 ));
833 }
834 if !assignments.iter().all(|v| v.is_finite()) {
835 return Err("trust_diagnostics_report: assignments must be finite".to_string());
836 }
837 let metric = self.diagnostic_metric()?;
838 let active_threshold = crate::inference::atom_lens::SAE_TRUST_ACTIVE_MASS_FLOOR;
839 let mut atoms = Vec::with_capacity(k_atoms);
840 let mut atom_trust = Vec::with_capacity(k_atoms);
841 for (atom_idx, atom) in self.atoms.iter().enumerate() {
842 let mut active_token_count = 0usize;
843 let mut activation_sum = 0.0_f64;
844 for row in 0..n {
845 let mass = assignments[[row, atom_idx]];
846 activation_sum += mass;
847 if mass > active_threshold {
848 active_token_count += 1;
849 }
850 }
851 let coverage = if n > 0 {
852 active_token_count as f64 / n as f64
853 } else {
854 0.0
855 };
856 let activation_frequency = if n > 0 {
857 activation_sum / n as f64
858 } else {
859 0.0
860 };
861 let (sigma_min_tangent, sigma_max_tangent) = self
862 .atom_tangent_spectrum_from_assignments(
863 atom_idx,
864 assignments,
865 &metric,
866 active_threshold,
867 )?;
868 let tangent_condition_score = if sigma_max_tangent > 0.0 {
869 (sigma_min_tangent / sigma_max_tangent).clamp(0.0, 1.0)
870 } else {
871 0.0
872 };
873 let trust_score = tangent_condition_score;
874 atom_trust.push(trust_score);
875 atoms.push(SaeAtomTrustDiagnostics {
876 trust_score,
877 sigma_min_tangent,
878 sigma_max_tangent,
879 tangent_condition_score,
880 coverage,
881 activation_frequency,
882 untyped: matches!(atom.basis_kind, SaeAtomBasisKind::Precomputed(_)),
883 active_token_count,
884 });
885 }
886 Ok(SaeTrustDiagnostics { atom_trust, atoms })
887 }
888
889 pub(crate) fn atom_tangent_spectrum_from_assignments(
890 &self,
891 atom_idx: usize,
892 assignments: ArrayView2<'_, f64>,
893 metric: &gam_problem::RowMetric,
894 active_threshold: f64,
895 ) -> Result<(f64, f64), String> {
896 let atom = &self.atoms[atom_idx];
897 let d = atom.latent_dim;
898 let p = self.output_dim();
899 if d == 0 || p == 0 {
900 return Ok((0.0, 0.0));
901 }
902 let mut gram = Array2::<f64>::zeros((d, d));
903 let mut active_mass_sum = 0.0_f64;
904 let mut jac_row = vec![0.0_f64; p * d];
905 for row in 0..self.n_obs() {
906 let mass = assignments[[row, atom_idx]];
907 if !(mass > active_threshold) {
908 continue;
909 }
910 active_mass_sum += mass;
911 for axis in 0..d {
912 let start = axis;
913 let mut tangent = vec![0.0_f64; p];
914 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
915 for out in 0..p {
916 jac_row[out * d + start] = tangent[out];
917 }
918 }
919 let row_pullback = metric.pullback(row, &jac_row, d);
920 for axis_a in 0..d {
921 for axis_b in 0..=axis_a {
922 gram[[axis_a, axis_b]] += mass * row_pullback[[axis_a, axis_b]];
923 }
924 }
925 jac_row.fill(0.0);
926 }
927 if !(active_mass_sum > 0.0) {
928 return Ok((0.0, 0.0));
929 }
930 let inv_mass = 1.0 / active_mass_sum;
931 for axis_a in 0..d {
932 for axis_b in 0..=axis_a {
933 let value = gram[[axis_a, axis_b]] * inv_mass;
934 gram[[axis_a, axis_b]] = value;
935 gram[[axis_b, axis_a]] = value;
936 }
937 }
938 let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
939 format!(
940 "trust_diagnostics_report: atom {atom_idx} tangent eigendecomposition failed: {e}"
941 )
942 })?;
943 let mut sigma_min = f64::INFINITY;
944 let mut sigma_max = 0.0_f64;
945 for value in evals.iter().copied() {
946 let clamped = value.max(0.0);
947 let sigma = clamped.sqrt();
948 sigma_min = sigma_min.min(sigma);
949 sigma_max = sigma_max.max(sigma);
950 }
951 if sigma_min.is_finite() {
952 Ok((sigma_min, sigma_max))
953 } else {
954 Ok((0.0, 0.0))
955 }
956 }
957
958 /// Per-atom exact parameter-space views for the #998 certificate path:
959 /// the basis values / first-derivative jet, decoder coefficients, latent
960 /// coordinates, and assignment mass each atom was actually fitted with.
961 /// Sphere atoms get `None` (their chart's group action is nonlinear, so
962 /// the exact-orbit realisation does not apply and they stay on the frame
963 /// path), as does any atom whose coordinate chart width disagrees with its
964 /// latent dimension (a structurally inconsistent atom must not masquerade
965 /// as exactly certified).
966 pub(crate) fn atom_parameter_views(
967 &self,
968 ) -> Vec<Option<crate::identifiability::AtomParameterView>> {
969 let assignments = self.assignment.assignments();
970 let n = self.n_obs();
971 self.atoms
972 .iter()
973 .enumerate()
974 .map(|(k, atom)| {
975 if matches!(atom.basis_kind, SaeAtomBasisKind::Sphere) {
976 return None;
977 }
978 let coords = self.assignment.coords[k].as_matrix().to_owned();
979 if coords.nrows() != n || coords.ncols() != atom.latent_dim {
980 return None;
981 }
982 let mut activations = Array1::<f64>::zeros(n);
983 for row in 0..n {
984 activations[row] = assignments[[row, k]];
985 }
986 // Second jet Φ'' (#998): supplied when the atom's evaluator
987 // exposes an analytic Hessian, so a pin-active fit can lower its
988 // orbit-space isometry penalty operator (the metric-change of the
989 // pullback gram differentiates Φ' through t). Absent ⇒ the orbit
990 // verdict stays on the data residual / no-pin path, never an
991 // error.
992 let basis_second_jet = atom
993 .basis_evaluator
994 .as_ref()
995 .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()))
996 .and_then(|res| res.ok());
997 Some(crate::identifiability::AtomParameterView {
998 basis_values: atom.basis_values.clone(),
999 basis_jacobian: atom.basis_jacobian.clone(),
1000 decoder: atom.decoder_coefficients.clone(),
1001 coords,
1002 activations,
1003 basis_second_jet,
1004 })
1005 })
1006 .collect()
1007 }
1008
/// Lower this fitted term into the self-contained
/// [`FittedSaeManifold`](crate::identifiability::FittedSaeManifold) the
/// residual-gauge certificate consumes.
///
/// The certificate's parameter space is the per-atom decoder **frame** — the
/// `(output_dim, latent_dim)` image of the atom's latent axes in output space.
/// We realise it as the active-mass-weighted mean decoder tangent
/// `frame_k[:, a] = (Σ_n a_{nk} · ∂g_k/∂t_a(n)) / Σ_n a_{nk}` over the atom's
/// active rows (the centroid decoder Jacobian columns the certificate docs
/// name). The per-row pinning Jacobian block `J_n ∈ ℝ^{p × param_dim}` is the
/// assignment-weighted per-row decoder tangent placed at each atom's frame
/// slot: column `(k, i, a)` of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i]` — exactly
/// the directions the reconstruction data gives cost to, in the same metric
/// the fit used (whitened by the certificate through `RowMetric`).
///
/// The flattened frame layout matches the certificate's
/// `vec(frame_0) ⊕ vec(frame_1) ⊕ …`, row-major within each frame
/// (`frame_k[i, a]` at offset `atom_offset(k) + i·latent_dim_k + a`).
///
/// Returns `(model, streamed_curvature)`:
/// * With `isometry_pin_active`:
///   - `model.jacobian_rows` holds one dense flattened `J_n` per observation;
///   - `model.isometry_penalty_root` carries one root row per non-zero
///     `(atom, axis)` frame column;
///   - `streamed_curvature` is `None`.
/// * Without it:
///   - `model.jacobian_rows` is empty;
///   - the penalty root is `0 × param_dim`;
///   - `streamed_curvature` is `Some((gram, root_row_count))` from
///     [`Self::residual_gauge_streamed_data_curvature`].
///
/// `per_atom_ard_variances`, when supplied, attaches atom `k`'s entry only if
/// its length equals that atom's `latent_dim`; otherwise the atom gets `None`.
///
/// # Errors
/// Propagates errors from [`Self::residual_gauge_streamed_data_curvature`].
/// This can only happen on the unpinned path.
pub(crate) fn to_residual_gauge_model(
    &self,
    metric: gam_problem::RowMetric,
    per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
    isometry_pin_active: bool,
) -> Result<
    (
        crate::identifiability::FittedSaeManifold,
        Option<(Array2<f64>, usize)>,
    ),
    String,
> {
    use crate::identifiability::{AtomTopology, FittedAtom, FittedSaeManifold};

    let n = self.n_obs();
    let p = self.output_dim();
    let k = self.k_atoms();
    let assignments = self.assignment.assignments();

    // Per-atom frame `(p, d)` = active-mass-weighted mean decoder tangent,
    // and the flattened-frame column offset bookkeeping for the joint
    // parameter vector (`vec(frame_0) ⊕ …`, row-major within each frame).
    let mut fitted_atoms: Vec<FittedAtom> = Vec::with_capacity(k);
    let mut atom_offsets: Vec<usize> = Vec::with_capacity(k);
    let mut atom_axis_dim: Vec<usize> = Vec::with_capacity(k);
    let mut cursor = 0usize;
    for (atom_idx, atom) in self.atoms.iter().enumerate() {
        let d = atom.latent_dim;
        // Map the basis family to the certificate's gauge topology. For
        // periodic/torus bases the choice also depends on the latent width.
        let topology = match (&atom.basis_kind, d) {
            (SaeAtomBasisKind::Periodic, 1) | (SaeAtomBasisKind::Torus, 1) => {
                AtomTopology::Circle
            }
            (SaeAtomBasisKind::Periodic, _) | (SaeAtomBasisKind::Torus, _) => {
                AtomTopology::Torus { latent_dim: d }
            }
            (SaeAtomBasisKind::Sphere, _) => AtomTopology::Sphere,
            // `Cylinder` (`S¹ × ℝ`) has exactly one continuous gauge: the
            // rotation (shift) of the periodic axis. The unbounded line axis
            // carries no rotational gauge, and its translation is already
            // pinned by the design's constant column — so the identifiability
            // gauge is that of a single circle. Fixing it as `Torus` would
            // over-impose a second (nonexistent) circle shift; fixing it as
            // `EuclideanPatch { 2 }` would over-impose a frame rotation
            // mixing the periodic and linear axes. `Circle` fixes the one
            // real continuous gauge and leaves the linear axis ungauged.
            (SaeAtomBasisKind::Cylinder, _) => AtomTopology::Circle,
            (
                SaeAtomBasisKind::Linear
                | SaeAtomBasisKind::Duchon
                | SaeAtomBasisKind::EuclideanPatch
                | SaeAtomBasisKind::Poincare
                | SaeAtomBasisKind::Precomputed(_),
                _,
            ) => AtomTopology::EuclideanPatch { latent_dim: d },
        };

        // Mass-weighted sum of per-row decoded tangents over rows with strictly
        // positive assignment. `!(a > 0)` also skips NaN. The sum is normalised
        // to a mean below. `tangent` is a scratch buffer reused across rows and
        // axes.
        let mut frame = Array2::<f64>::zeros((p, d));
        let mut active_mass = 0.0_f64;
        let mut tangent = vec![0.0_f64; p];
        for row in 0..n {
            let a_nk = assignments[[row, atom_idx]];
            if !(a_nk > 0.0) {
                continue;
            }
            active_mass += a_nk;
            for axis in 0..d {
                atom.fill_decoded_derivative_row(row, axis, &mut tangent);
                for i in 0..p {
                    frame[[i, axis]] += a_nk * tangent[i];
                }
            }
        }
        // An atom with no active mass keeps an all-zero frame.
        if active_mass > 0.0 {
            let inv = 1.0 / active_mass;
            frame.mapv_inplace(|v| v * inv);
        }

        // #995 lowering-error scale: mass-weighted relative dispersion of
        // the per-row tangents around the mean frame just built,
        // Σ_n a_n Σ_ax ‖t_ax(n) − frame[:,ax]‖² / Σ_n a_n Σ_ax ‖t_ax(n)‖².
        // 0 ⇒ the frame represents every active row exactly (flat
        // decoder); → 1 ⇒ the tangent field disperses so strongly (e.g. a
        // full circle, whose tangents average out) that the mean-frame
        // compression cannot distinguish gauge motion from curvature. The
        // certificate calibrates its per-generator verdict tolerance to
        // this scale so it never claims a pin it cannot resolve.
        // Tangents are recomputed here rather than cached, so no
        // `n × d × p` buffer is retained.
        let mut disp_num = 0.0_f64;
        let mut disp_den = 0.0_f64;
        for row in 0..n {
            let a_nk = assignments[[row, atom_idx]];
            if !(a_nk > 0.0) {
                continue;
            }
            for axis in 0..d {
                atom.fill_decoded_derivative_row(row, axis, &mut tangent);
                for i in 0..p {
                    let dev = tangent[i] - frame[[i, axis]];
                    disp_num += a_nk * dev * dev;
                    disp_den += a_nk * tangent[i] * tangent[i];
                }
            }
        }
        let lowering_error = if disp_den > 0.0 {
            (disp_num / disp_den).clamp(0.0, 1.0)
        } else {
            0.0
        };

        // Attach ARD variances only when the caller supplied a vector for this
        // atom whose length matches its latent width.
        let ard_variances = per_atom_ard_variances
            .and_then(|all| all.get(atom_idx))
            .and_then(|opt| opt.clone())
            .filter(|v| v.len() == d);

        fitted_atoms.push(FittedAtom {
            name: atom.name.clone(),
            topology,
            frame,
            ard_variances,
            lowering_error,
            // #1019: post-fit chart canonicalization pins the chart. The
            // canonicalizations are arc length for d = 1, isometry-flow for
            // d = 2 torus, flat-reference isometry-flow for d = 2 free/patch,
            // and round-sphere conformal-boost flow for d = 2 sphere atoms.
            // The certificate downgrades this atom's chart freedom to the
            // finite isometry group with PinnedByCanonicalization provenance.
            chart_canonicalized: atom.chart_canonicalized
                && (d == 1
                    || (d == 2
                        && matches!(
                            atom.basis_kind,
                            SaeAtomBasisKind::Torus
                                | SaeAtomBasisKind::Linear
                                | SaeAtomBasisKind::Duchon
                                | SaeAtomBasisKind::EuclideanPatch
                                | SaeAtomBasisKind::Sphere
                        ))),
            // #1097 / #1103: the per-atom inner-decoder-smooth snapshot,
            // attached when the post-fit harness has run
            // [`Self::set_atom_inner_fits`] (it needs the reconstruction
            // target Z, dropped from the objective at fit end). `None` on a
            // bare certificate-only model, or for a degenerate atom whose
            // inner Hessian was not SPD.
            inner_fit: self
                .atom_inner_fits
                .as_ref()
                .and_then(|fits| fits.get(atom_idx))
                .and_then(|slot| slot.clone()),
        });
        // Atom k's frame occupies the `p · d` columns starting at `cursor`.
        atom_offsets.push(cursor);
        atom_axis_dim.push(d);
        cursor += p * d;
    }
    let param_dim = cursor;

    // Per-row pinning Jacobian `J_n ∈ ℝ^{p × param_dim}` flattened row-major
    // (`J_n[i, c] = jacobian_rows[n][i · param_dim + c]`). Column `(k, i', a)`
    // of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i']` placed at the atom-k frame slot
    // and read out on output coordinate `i = i'` (a frame perturbation of
    // output `i'` moves only the row's output coordinate `i'`).
    //
    // The pinned certificate still consumes the legacy row-block contract.
    // The unpinned exact path consumes only `RᵀR`, so stream each transient
    // row Jacobian through the metric whitening and discard it immediately.
    let (jacobian_rows, streamed_curvature) = if isometry_pin_active {
        let mut jacobian_rows: Vec<Vec<f64>> = Vec::with_capacity(n);
        let mut tangent = vec![0.0_f64; p];
        for row in 0..n {
            // Dense `p × param_dim` block per observation. Only the frame slots
            // of atoms with positive mass on this row are written.
            let mut j_flat = vec![0.0_f64; p * param_dim];
            for (atom_idx, atom) in self.atoms.iter().enumerate() {
                let a_nk = assignments[[row, atom_idx]];
                if !(a_nk > 0.0) {
                    continue;
                }
                let d = atom_axis_dim[atom_idx];
                let base = atom_offsets[atom_idx];
                for axis in 0..d {
                    atom.fill_decoded_derivative_row(row, axis, &mut tangent);
                    for i in 0..p {
                        // Frame coordinate `(k, i, axis)` sits at column
                        // `base + i·d + axis`; it sources output coordinate `i`.
                        j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
                    }
                }
            }
            jacobian_rows.push(j_flat);
        }
        (jacobian_rows, None)
    } else {
        let streamed = self.residual_gauge_streamed_data_curvature(
            &metric,
            &atom_offsets,
            &atom_axis_dim,
            param_dim,
        )?;
        (Vec::new(), Some(streamed))
    };

    // Isometry-penalty curvature root over the frame parameter space. When
    // the isometry gauge pin is active it gives curvature along every fitted
    // frame direction (it resists deviation of the decoder image from its
    // arc-length parameterization), so its row space is the span of the
    // per-atom frame columns: one root row per `(k, axis)` carrying that
    // atom's frame column at the atom's frame slot. Empty (`0 × param_dim`)
    // when the pin is inactive — exactly the certificate's escalation
    // condition to `diffeomorphism-unpinned`.
    let isometry_penalty_root = if isometry_pin_active && param_dim > 0 {
        let mut root_rows: Vec<Array1<f64>> = Vec::new();
        for (atom_idx, fitted) in fitted_atoms.iter().enumerate() {
            let d = atom_axis_dim[atom_idx];
            let base = atom_offsets[atom_idx];
            for axis in 0..d {
                let mut r = Array1::<f64>::zeros(param_dim);
                let mut any = false;
                for i in 0..p {
                    let v = fitted.frame[[i, axis]];
                    if v != 0.0 {
                        any = true;
                    }
                    r[base + i * d + axis] = v;
                }
                // An all-zero frame column (e.g. from an atom with no active
                // mass) would be a zero root row; it is dropped.
                if any {
                    root_rows.push(r);
                }
            }
        }
        let mut root = Array2::<f64>::zeros((root_rows.len(), param_dim));
        for (ri, r) in root_rows.iter().enumerate() {
            root.row_mut(ri).assign(r);
        }
        root
    } else {
        Array2::<f64>::zeros((0, param_dim))
    };

    Ok((
        FittedSaeManifold {
            atoms: fitted_atoms,
            jacobian_rows,
            isometry_penalty_root,
            metric,
        },
        streamed_curvature,
    ))
}
1272
1273 pub(crate) fn residual_gauge_streamed_data_curvature(
1274 &self,
1275 metric: &gam_problem::RowMetric,
1276 atom_offsets: &[usize],
1277 atom_axis_dim: &[usize],
1278 param_dim: usize,
1279 ) -> Result<(Array2<f64>, usize), String> {
1280 let n = self.n_obs();
1281 let p = self.output_dim();
1282 if metric.p_out() != p {
1283 return Err(format!(
1284 "residual_gauge_streamed_data_curvature: metric output dim {} but term has {p}",
1285 metric.p_out()
1286 ));
1287 }
1288 let rank = metric.metric_rank();
1289 let mut gram = Array2::<f64>::zeros((param_dim, param_dim));
1290 if param_dim == 0 || n == 0 || rank == 0 {
1291 return Ok((gram, n * rank));
1292 }
1293
1294 let assignments = self.assignment.assignments();
1295 let mut tangent = vec![0.0_f64; p];
1296 let mut j_flat = vec![0.0_f64; p * param_dim];
1297 let mut root_row = Array1::<f64>::zeros(param_dim);
1298 for row in 0..n {
1299 j_flat.fill(0.0);
1300 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1301 let a_nk = assignments[[row, atom_idx]];
1302 if !(a_nk > 0.0) {
1303 continue;
1304 }
1305 let d = atom_axis_dim[atom_idx];
1306 let base = atom_offsets[atom_idx];
1307 for axis in 0..d {
1308 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1309 for i in 0..p {
1310 j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1311 }
1312 }
1313 }
1314
1315 if metric.drives_gauge() {
1316 for r in 0..rank {
1317 root_row.fill(0.0);
1318 for c in 0..param_dim {
1319 let mut acc = 0.0_f64;
1320 for i in 0..p {
1321 acc += metric.factor_entry(row, i, r) * j_flat[i * param_dim + c];
1322 }
1323 root_row[c] = acc;
1324 }
1325 let row_slice = root_row.as_slice().ok_or_else(|| {
1326 "residual_gauge_streamed_data_curvature: non-contiguous root row"
1327 .to_string()
1328 })?;
1329 Self::accumulate_residual_gauge_gram_row(&mut gram, row_slice);
1330 }
1331 } else {
1332 for i in 0..p {
1333 let start = i * param_dim;
1334 let end = start + param_dim;
1335 Self::accumulate_residual_gauge_gram_row(&mut gram, &j_flat[start..end]);
1336 }
1337 }
1338 }
1339
1340 for a in 0..param_dim {
1341 for b in 0..a {
1342 gram[[b, a]] = gram[[a, b]];
1343 }
1344 }
1345 Ok((gram, n * rank))
1346 }
1347
1348 pub(crate) fn accumulate_residual_gauge_gram_row(gram: &mut Array2<f64>, row: &[f64]) {
1349 for a in 0..row.len() {
1350 let va = row[a];
1351 if va == 0.0 {
1352 continue;
1353 }
1354 for b in 0..=a {
1355 let vb = row[b];
1356 if vb != 0.0 {
1357 gram[[a, b]] += va * vb;
1358 }
1359 }
1360 }
1361 }
1362
1363 pub fn set_temperature_schedule(
1364 &mut self,
1365 sched: GumbelTemperatureSchedule,
1366 ) -> Result<(), String> {
1367 sched.validate()?;
1368 self.assignment
1369 .mode
1370 .set_temperature(sched.current_tau(sched.iter_count))?;
1371 self.temperature_schedule = Some(sched);
1372 Ok(())
1373 }
1374
1375 pub(crate) fn advance_temperature_schedule(&mut self) -> Result<Option<f64>, String> {
1376 let Some(schedule) = self.temperature_schedule.as_mut() else {
1377 return Ok(None);
1378 };
1379 schedule.validate()?;
1380 let tau = schedule.step();
1381 self.assignment.mode.set_temperature(tau)?;
1382 Ok(Some(tau))
1383 }
1384
1385 pub fn n_obs(&self) -> usize {
1386 self.assignment.n_obs()
1387 }
1388
1389 pub fn k_atoms(&self) -> usize {
1390 self.atoms.len()
1391 }
1392
1393 /// Auto-derived in-core vs streaming plan for SAE Arrow-Schur work.
1394 ///
1395 /// This is intentionally not user-configurable: the route follows the
1396 /// retained full-batch working-set estimate and the currently selected GPU
1397 /// memory budget when CUDA is usable, otherwise a conservative host budget.
1398 pub fn streaming_plan(&self) -> SaeStreamingPlan {
1399 let n_obs = self.n_obs();
1400 let total_basis: usize = self.atoms.iter().map(|atom| atom.basis_size()).sum();
1401 let d_max = self
1402 .atoms
1403 .iter()
1404 .map(|atom| atom.latent_dim)
1405 .max()
1406 .unwrap_or(0);
1407 let border_dim = if self.any_frame_active() {
1408 self.factored_border_dim()
1409 } else {
1410 self.beta_dim()
1411 };
1412 sae_streaming_plan_for_shape(n_obs, total_basis, self.k_atoms(), d_max, border_dim)
1413 }
1414
1415 /// Construction-time validation: every Psi-tier analytic penalty in the
1416 /// registry must be dispatchable into the SAE arrow-Schur row layout.
1417 ///
1418 /// Two invariants are enforced upfront so the dispatch loop in
1419 /// `add_sae_analytic_penalty_contributions` is total (no runtime
1420 /// "unsupported penalty" fallthrough, no per-call K-gating):
1421 ///
1422 /// 1. Every Psi-tier penalty is either in [`sae_penalty_is_row_block_supported`],
1423 /// or `NuclearNorm` (which is redirected to the per-atom decoder (β) block
1424 /// rather than the coord "t" row block). Assignment sparsity penalties
1425 /// (`IBPAssignment`, `SoftmaxAssignmentSparsity`) are refused because the SAE
1426 /// term already owns them through its built-in assignment path
1427 /// (`loss.assignment_sparsity`). Penalty kinds with cross-row structure
1428 /// (`TotalVariation`, `Monotonicity`, `BlockSparsity`,
1429 /// `IvaeRidgeMeanGauge`, `Orthogonality`, `NestedPrefix`,
1430 /// `SheafConsistency`) cannot be expressed in the SAE row-block layout
1431 /// and are refused here.
1432 ///
1433 /// 2. If any Psi-tier row-block penalty is present, every atom shares
1434 /// the same coord latent dim. The current registry model carries one
1435 /// `latent_dim` per descriptor (the "t" latent block declares one
1436 /// `d` value); per-atom dispatch with heterogeneous `d_k` would
1437 /// require per-atom registry entries or per-kind in-place
1438 /// reshaping. Mixed-d row-block fits are rejected with an actionable
1439 /// error pointing at the configuration mismatch.
1440 ///
1441 /// The K=1 case trivially satisfies (2). Beta-tier and rho-tier
1442 /// penalties are not constrained here.
1443 pub(crate) fn validate_analytic_penalty_registry(
1444 &self,
1445 registry: &AnalyticPenaltyRegistry,
1446 ) -> Result<(), String> {
1447 let mut row_block_penalty_present = false;
1448 for penalty in ®istry.penalties {
1449 if penalty.tier() != PenaltyTier::Psi {
1450 continue;
1451 }
1452 if matches!(
1453 penalty,
1454 AnalyticPenaltyKind::IBPAssignment(_)
1455 | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
1456 ) {
1457 return Err(format!(
1458 "SAE-manifold term refuses analytic penalty {:?}: assignment sparsity \
1459 is owned by the built-in SAE assignment path (loss.assignment_sparsity). \
1460 Registering it would double-count the objective and gradient",
1461 penalty.name()
1462 ));
1463 }
1464 // NuclearNorm is redirected to the per-atom decoder (β) block in
1465 // `add_sae_beta_penalty` (it penalizes each atom's decoder matrix
1466 // singular spectrum, i.e. its embedding rank), so it bypasses the
1467 // coord "t" row-block requirement below.
1468 if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) {
1469 continue;
1470 }
1471 if !sae_penalty_is_row_block_supported(penalty) {
1472 return Err(format!(
1473 "SAE-manifold term refuses analytic penalty {:?}: this kind \
1474 has cross-row structure and cannot be expressed in the \
1475 arrow-Schur row layout. Use only row-block-supported \
1476 coord penalties (ARD, BlockOrthogonality, \
1477 Sparsity/TopK/JumpReLU, RowPrecisionPrior, \
1478 ParametricRowPrecisionPrior, ScadMcp, Isometry) on the \
1479 coord latent block, or move the penalty to a non-SAE \
1480 term",
1481 penalty.name()
1482 ));
1483 }
1484 row_block_penalty_present = true;
1485 }
1486 if row_block_penalty_present {
1487 let mut dims = self.assignment.coords.iter().map(|c| c.latent_dim());
1488 if let Some(first) = dims.next() {
1489 if let Some(mismatch) = dims.find(|d| *d != first) {
1490 return Err(format!(
1491 "SAE-manifold term refuses row-block analytic penalty: \
1492 atoms have heterogeneous coord latent dims (saw {first} \
1493 and {mismatch}). Row-block penalties (ARD, \
1494 BlockOrthogonality, ...) target the unified \"t\" \
1495 latent block whose declared `d` matches one shape; \
1496 per-atom dispatch with mixed `d_k` would silently \
1497 truncate or expand axes. Configure all atoms with the \
1498 same `atom_dim`, or split the row-block penalty into \
1499 per-atom descriptors keyed to per-atom latent blocks"
1500 ));
1501 }
1502 }
1503 }
1504 Ok(())
1505 }
1506
1507 pub fn output_dim(&self) -> usize {
1508 self.atoms[0].output_dim()
1509 }
1510
1511 pub fn beta_dim(&self) -> usize {
1512 let p = self.output_dim();
1513 self.atoms.iter().map(|a| a.basis_size() * p).sum()
1514 }
1515
1516 pub(crate) fn take_border_hbb_workspace(&mut self, border_dim: usize) -> Array2<f64> {
1517 let mut workspace =
1518 std::mem::replace(&mut self.border_hbb_workspace, Array2::<f64>::zeros((0, 0)));
1519 if workspace.dim() != (border_dim, border_dim) {
1520 workspace = Array2::<f64>::zeros((border_dim, border_dim));
1521 } else {
1522 workspace.fill(0.0);
1523 }
1524 workspace
1525 }
1526
1527 pub(crate) fn reclaim_border_hbb_workspace(&mut self, sys: &mut ArrowSchurSystem) {
1528 let workspace = std::mem::replace(&mut sys.hbb, Array2::<f64>::zeros((0, 0)));
1529 self.border_hbb_workspace = workspace;
1530 }
1531
1532 /// Factored arrow-Schur border dimension `Σ_k M_k · r_k` (issue #972): the
1533 /// number of decoder coordinates the border actually carries once the
1534 /// low-rank Grassmann frames are profiled out. Atoms with no active frame
1535 /// contribute their full `M_k · p` (`r_k == p`), so on the all-full-`B` path
1536 /// this equals [`Self::beta_dim`]. The border Cholesky / evidence log-det
1537 /// scale with THIS count, not `beta_dim`.
1538 pub fn factored_border_dim(&self) -> usize {
1539 self.atoms.iter().map(|a| a.border_coeff_count()).sum()
1540 }
1541
1542 /// Total profiled-out Grassmann manifold dimension `Σ_k r_k·(p − r_k)` across
1543 /// all active frames (issue #972). This is the count of decoder-frame degrees
1544 /// of freedom estimated OUTSIDE the border by closed-form polar steps, and it
1545 /// must enter the Laplace evidence dimension accounting (evidence honesty):
1546 /// the profiled frame is a MAP point on `∏_k Gr(r_k, p)`, contributing this
1547 /// many free dimensions to the model. `0` when every atom is on the full-`B`
1548 /// path. Threaded into [`Self::reml_occam_term`].
1549 pub fn grassmann_evidence_dimension(&self) -> usize {
1550 self.atoms
1551 .iter()
1552 .map(|a| a.frame_manifold_dimension())
1553 .sum()
1554 }
1555
1556 /// True iff any atom has an active low-rank Grassmann frame (issue #972).
1557 pub fn frames_active(&self) -> bool {
1558 self.atoms.iter().any(|a| a.decoder_frame.is_some())
1559 }
1560
1561 /// Alias of [`Self::frames_active`] (issue #972 / #977 T1): the predicate the
1562 /// assembly / step-lift branch on to decide whether the β-tier is built in
1563 /// the factored coordinate layout. Named to read as the question
1564 /// "is the factored path engaged?" at its call sites.
1565 pub fn any_frame_active(&self) -> bool {
1566 self.frames_active()
1567 }
1568
1569 /// Per-atom column offsets of the *factored* border (issue #972 / #977 T1):
1570 /// the running prefix sum of `M_k · r_k`, one entry per atom (the same
1571 /// convention as [`Self::beta_offsets`]). This is the start of each atom's
1572 /// `C_k` block in the reduced border vector; on the all-full-`B` path it
1573 /// equals `beta_offsets`. Distinct from [`Self::factored_border_offsets`]
1574 /// only in name (both compute the identical prefix sum) — this method is the
1575 /// one the frame transform reads, mirroring `beta_offsets` at the call site.
1576 pub fn factored_beta_offsets(&self) -> Vec<usize> {
1577 self.factored_border_offsets()
1578 }
1579
1580 /// Frame output matrix `U_k ∈ St(p, r_k)` for atom `k` (issue #972 / #977 T1).
1581 /// Returns the active frame `U_k` (`p × r_k`) when atom `k` is framed, else
1582 /// the identity `I_p` (the `r_k == p`, `U_k == I_p` full-`B` special case) so
1583 /// the projection / lift code is uniform across a mixed dictionary.
1584 pub fn frame_output_matrix(&self, atom_idx: usize) -> Array2<f64> {
1585 let atom = &self.atoms[atom_idx];
1586 match &atom.decoder_frame {
1587 Some(frame) => frame.frame().to_owned(),
1588 None => Array2::<f64>::eye(atom.output_dim()),
1589 }
1590 }
1591
1592 /// Per-pair frame factor `W_{ij} = U_iᵀ U_j` (`r_i × r_j`) used as the output
1593 /// factor of the factored data β-Hessian block `G_{ij} ⊗ W_{ij}` (issue #972
1594 /// / #977 T1). When both atoms are framed this is the dense principal-angle
1595 /// cosine matrix between the two frames; for `i == j` with an orthonormal
1596 /// frame it is exactly `I_{r_i}`; for any un-framed atom the corresponding
1597 /// `U` is `I_p`, so a same-atom un-framed pair gives `I_p` (the clean full-`B`
1598 /// `G ⊗ I_p` collapse) and a framed/un-framed cross pair gives the rectangular
1599 /// `U_iᵀ` / `U_j` overlap.
1600 pub fn frame_cross_factor(&self, atom_i: usize, atom_j: usize) -> Array2<f64> {
1601 let ui = self.frame_output_matrix(atom_i);
1602 let uj = self.frame_output_matrix(atom_j);
1603 // `U_iᵀ U_j`: `(r_i × p) · (p × r_j)`. `fast_atb` forms `U_iᵀ U_j` directly.
1604 fast_atb(&ui, &uj)
1605 }
1606
1607 /// Per-atom column offsets of the *factored* border (issue #972): the
1608 /// running prefix sum of `M_k · r_k`. The analogue of [`Self::beta_offsets`]
1609 /// for the reduced coordinate layout — atom `k`'s `C_k` occupies
1610 /// `[factored_border_offsets()[k] .. + M_k·r_k)`. On the full-`B` path this
1611 /// equals `beta_offsets`.
1612 pub fn factored_border_offsets(&self) -> Vec<usize> {
1613 let mut out = Vec::with_capacity(self.k_atoms());
1614 let mut cursor = 0usize;
1615 for atom in &self.atoms {
1616 out.push(cursor);
1617 cursor += atom.border_coeff_count();
1618 }
1619 out
1620 }
1621
1622 /// Assemble the factored border coordinate vector `C = [vec(C_1); …; vec(C_K)]`
1623 /// in row-major `C_k[m, j] → C[off_k + m·r_k + j]` layout (issue #972).
1624 ///
1625 /// This is the reduced state the arrow-Schur border carries when frames are
1626 /// active: its length is [`Self::factored_border_dim`] (`Σ M_k·r_k`), the
1627 /// border-size invariant verified by [`grassmann_assert_border_dim_invariant`].
1628 /// Atoms
1629 /// without an active frame contribute their full `vec(B_k)` (their `r_k == p`
1630 /// coordinates are the decoder itself), so on the all-full-`B` path this
1631 /// reproduces [`Self::flatten_beta`].
1632 pub fn flatten_factored_border(&self) -> Result<Array1<f64>, String> {
1633 let offsets = self.factored_border_offsets();
1634 let mut out = Array1::<f64>::zeros(self.factored_border_dim());
1635 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1636 let off = offsets[atom_idx];
1637 let r = atom.border_frame_rank();
1638 let m = atom.basis_size();
1639 let coords = match atom.factored_coordinates()? {
1640 Some(c) => c,
1641 // Full-`B` path: the decoder itself is the coordinate matrix.
1642 None => atom.decoder_coefficients.clone(),
1643 };
1644 for basis_col in 0..m {
1645 for j in 0..r {
1646 out[off + basis_col * r + j] = coords[[basis_col, j]];
1647 }
1648 }
1649 }
1650 Ok(out)
1651 }
1652
1653 /// Scatter a factored border coordinate vector `C` (length
1654 /// [`Self::factored_border_dim`]) back into the per-atom decoders, refreshing
1655 /// each `decoder_coefficients = C_k · U_kᵀ` so the full-`B` consumers stay
1656 /// consistent after a factored border solve (issue #972). The inverse of
1657 /// [`Self::flatten_factored_border`].
1658 pub fn scatter_factored_border(&mut self, border: ArrayView1<'_, f64>) -> Result<(), String> {
1659 let expected = self.factored_border_dim();
1660 if border.len() != expected {
1661 return Err(format!(
1662 "SaeManifoldTerm::scatter_factored_border: border length {} must equal \
1663 factored border dim {expected}",
1664 border.len()
1665 ));
1666 }
1667 let offsets = self.factored_border_offsets();
1668 for atom_idx in 0..self.atoms.len() {
1669 let off = offsets[atom_idx];
1670 let (r, m, has_frame) = {
1671 let atom = &self.atoms[atom_idx];
1672 (
1673 atom.border_frame_rank(),
1674 atom.basis_size(),
1675 atom.decoder_frame.is_some(),
1676 )
1677 };
1678 let mut coords = Array2::<f64>::zeros((m, r));
1679 for basis_col in 0..m {
1680 for j in 0..r {
1681 coords[[basis_col, j]] = border[off + basis_col * r + j];
1682 }
1683 }
1684 if has_frame {
1685 self.atoms[atom_idx].set_factored_coordinates(coords.view())?;
1686 } else {
1687 // Full-`B` path: the coordinates ARE the decoder.
1688 self.atoms[atom_idx].decoder_coefficients = coords;
1689 }
1690 }
1691 Ok(())
1692 }
1693
1694 /// Auto-derive and install low-rank Grassmann decoder frames across all
1695 /// atoms (issue #972) — magic-by-default, no flag. Each atom independently
1696 /// activates its frame iff the factorization materially shrinks its border
1697 /// (see [`SaeManifoldAtom::maybe_activate_decoder_frame`]). Returns the
1698 /// number of atoms that activated a frame. Idempotent: re-running re-derives
1699 /// each frame from the current decoder.
1700 ///
1701 /// The decision keys on the *frontier* regime the issue targets: at large
1702 /// ambient `p` the full border `Σ M_k · p` reaches `10^7`–`10^8` and the
1703 /// border Cholesky dies, while the decoder's effective column rank `r` stays
1704 /// `≪ p`. Small-`p` atoms (where `r` cannot beat the activation margin)
1705 /// keep the bit-for-bit full-`B` path, so the small-model evidence is
1706 /// unchanged (verified by `factored_evidence_matches_full_b_at_small_p`).
1707 pub fn auto_activate_decoder_frames(&mut self) -> Result<usize, String> {
1708 let mut activated = 0usize;
1709 for atom in &mut self.atoms {
1710 let expected_rank = atom.decoder_frame_activation_rank()?;
1711 match (
1712 expected_rank,
1713 atom.decoder_frame.as_ref().map(GrassmannFrame::rank),
1714 ) {
1715 (Some(expected), Some(current)) if expected == current => {
1716 continue;
1717 }
1718 (None, Some(_)) => {
1719 atom.deactivate_decoder_frame();
1720 continue;
1721 }
1722 (None, None) => {
1723 continue;
1724 }
1725 (Some(_), _) => {}
1726 }
1727 if atom.maybe_activate_decoder_frame()?.is_some() {
1728 activated += 1;
1729 }
1730 }
1731 Ok(activated)
1732 }
1733
1734 /// Reconcile decoder-frame activation before a fit entry point. The
1735 /// user-facing `auto_activate_decoder_frames` contract returns only newly
1736 /// installed frames; this helper enforces the stronger invariant the large-p
1737 /// solver needs: every atom whose current decoder satisfies the activation
1738 /// predicate has an active frame after the pass.
1739 pub(crate) fn ensure_decoder_frames_active_for_current_decoder(
1740 &mut self,
1741 ) -> Result<(), String> {
1742 self.auto_activate_decoder_frames()?;
1743 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1744 let expected_rank = atom.decoder_frame_activation_rank()?;
1745 if let Some(expected_rank) = expected_rank {
1746 match atom.decoder_frame.as_ref() {
1747 Some(frame) if frame.rank() == expected_rank => {}
1748 Some(frame) => {
1749 return Err(format!(
1750 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1751 atom {atom_idx} frame rank {} must equal audited rank {expected_rank}",
1752 frame.rank()
1753 ));
1754 }
1755 None => {
1756 return Err(format!(
1757 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1758 atom {atom_idx} has audited rank {expected_rank} but no active frame"
1759 ));
1760 }
1761 }
1762 } else if atom.decoder_frame.is_some() {
1763 return Err(format!(
1764 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1765 atom {atom_idx} kept a frame after the full-B predicate won"
1766 ));
1767 }
1768 }
1769 Ok(())
1770 }
1771
1772 /// Closed-form streaming POLAR refresh of every ACTIVE decoder frame from the
1773 /// current data evidence (issue #972 / #977 T1) — the U-block of the
1774 /// alternating block-coordinate ascent that complements the border's
1775 /// C-block Newton step.
1776 ///
1777 /// For each framed atom `k` we accumulate the `p × r_k` cross-moment
1778 /// `A_k = Σ_n a_{n,k} · e_{n,k} · ĉ_{n,k}ᵀ`,
1779 /// where `e_{n,k} = z_n − Σ_{k'≠k} a_{n,k'}·decoded_{k'}(n)` is the row's
1780 /// partial reconstruction residual (everything except atom `k`) and
1781 /// `ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^{r_k}` is atom `k`'s in-span decoded
1782 /// coordinate. The polar factor `U_new = polar(A_k)` is the closed-form MAP
1783 /// frame on `Gr(r_k, p)` given the C-coordinates held fixed — the same
1784 /// `O(p r²)` thin SVD the issue prescribes, run OUTSIDE the border. The frame
1785 /// is then re-installed and the decoder re-projected onto it so the
1786 /// authoritative `B_k = C_k U_newᵀ` and the `(C_k, U_new)` pair stay
1787 /// consistent (a no-op in span for a truly rank-`r` atom). Un-framed atoms
1788 /// are skipped. Returns the number of frames refreshed.
1789 pub(crate) fn refresh_active_frames_from_data(
1790 &mut self,
1791 target: ArrayView2<'_, f64>,
1792 rho: &SaeManifoldRho,
1793 ) -> Result<usize, String> {
1794 let n = self.n_obs();
1795 let p = self.output_dim();
1796 let k_atoms = self.k_atoms();
1797 if n == 0 {
1798 return Ok(0);
1799 }
1800 // Per-row assignments and per-(row, atom) decoded outputs, computed once.
1801 let mut assignments = Vec::with_capacity(n);
1802 for row in 0..n {
1803 assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
1804 }
1805 let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
1806 let mut dbuf = vec![0.0_f64; p];
1807 for row in 0..n {
1808 for atom_idx in 0..k_atoms {
1809 self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
1810 for c in 0..p {
1811 decoded[[row, atom_idx, c]] = dbuf[c];
1812 }
1813 }
1814 }
1815 // Full fitted reconstruction `Σ_k a_k decoded_k`, so the per-atom partial
1816 // residual is `e_k = (z − fitted) + a_k decoded_k` (add atom k back in).
1817 let mut fitted = Array2::<f64>::zeros((n, p));
1818 for row in 0..n {
1819 for atom_idx in 0..k_atoms {
1820 let a = assignments[row][atom_idx];
1821 if a == 0.0 {
1822 continue;
1823 }
1824 for c in 0..p {
1825 fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
1826 }
1827 }
1828 }
1829 let mut refreshed = 0usize;
1830 for atom_idx in 0..k_atoms {
1831 // Only atoms with an active frame are refreshed.
1832 let Some(coords_c) = self.atoms[atom_idx].factored_coordinates()? else {
1833 continue;
1834 };
1835 let r = self.atoms[atom_idx].border_frame_rank();
1836 let m = self.atoms[atom_idx].basis_size();
1837 // Accumulate `A_k = Σ_n a_k · e_{n,k} · ĉ_{n,k}ᵀ` directly (p × r).
1838 let mut cross = GrassmannCrossMoment::new(p, r);
1839 // Build per-row p-target `a_k·e_k` and r-coord `a_k·ĉ` batched, then
1840 // accumulate as one outer-product sum. `accumulate` forms
1841 // `targetsᵀ·coords`, so scaling EITHER side by `a_k` once gives the
1842 // `a_k²` weight on the cross-moment that matches the C-block normal
1843 // equations (residual leg carries `a_k`, coordinate leg carries
1844 // `a_k`).
1845 let mut targets = Array2::<f64>::zeros((n, p));
1846 let mut rcoords = Array2::<f64>::zeros((n, r));
1847 for row in 0..n {
1848 let a = assignments[row][atom_idx];
1849 // Partial residual e_{n,k} = z_n − (fitted − a_k decoded_k).
1850 for c in 0..p {
1851 let e = target[[row, c]] - fitted[[row, c]] + a * decoded[[row, atom_idx, c]];
1852 targets[[row, c]] = a * e;
1853 }
1854 // In-span coordinate ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^r.
1855 for j in 0..r {
1856 let mut acc = 0.0_f64;
1857 for basis_col in 0..m {
1858 acc += self.atoms[atom_idx].basis_values[[row, basis_col]]
1859 * coords_c[[basis_col, j]];
1860 }
1861 rcoords[[row, j]] = a * acc;
1862 }
1863 }
1864 cross.accumulate(targets.view(), rcoords.view())?;
1865 // `polar(A_k)` is well-defined only when the moment is non-trivial;
1866 // a zero moment (e.g. a fully collapsed atom) leaves the frame as-is.
1867 if cross.moment().iter().all(|&v| v == 0.0) {
1868 continue;
1869 }
1870 self.atoms[atom_idx].refresh_frame_from_cross_moment(cross.moment())?;
1871 refreshed += 1;
1872 }
1873 Ok(refreshed)
1874 }
1875
1876 pub fn beta_offsets(&self) -> Vec<usize> {
1877 let p = self.output_dim();
1878 let mut out = Vec::with_capacity(self.k_atoms());
1879 let mut cursor = 0usize;
1880 for atom in &self.atoms {
1881 out.push(cursor);
1882 cursor += atom.basis_size() * p;
1883 }
1884 out
1885 }
1886
1887 /// Per-atom β column ranges for the block-Jacobi Schur preconditioner.
1888 ///
1889 /// Returns one `Range<usize>` per atom, covering that atom's decoder
1890 /// coefficients in the flat β vector:
1891 /// `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
1892 ///
1893 /// Pass to [`ArrowSchurSystem::set_block_offsets`] so that
1894 /// [`gam_solve::arrow_schur::JacobiPreconditioner`] builds one dense
1895 /// Schur sub-block per atom instead of scalar-diagonal inversion.
1896 pub fn beta_block_offsets(&self) -> Arc<[std::ops::Range<usize>]> {
1897 let p = self.output_dim();
1898 let mut ranges: Vec<std::ops::Range<usize>> = Vec::with_capacity(self.k_atoms());
1899 let mut cursor = 0usize;
1900 for atom in &self.atoms {
1901 let width = atom.basis_size() * p;
1902 ranges.push(cursor..cursor + width);
1903 cursor += width;
1904 }
1905 Arc::from(ranges.into_boxed_slice())
1906 }
1907
1908 /// Decide whether the sparse per-row active-set layout is engaged for a
1909 /// dense-weight assignment mode, and if so derive the per-row active-atom
1910 /// cap and magnitude cutoff.
1911 ///
1912 /// #1408: this plan is mode-agnostic. `assemble_arrow_schur` consults it
1913 /// directly for IBP-MAP, and for `AssignmentMode::Softmax` via
1914 /// [`Self::softmax_active_plan`], which tightens it with an explicit `top_k`
1915 /// (`softmax_active_cap`). Softmax therefore engages the compact active-set
1916 /// layout whenever `top_k` or the budget bounds the active set (the
1917 /// active-sub-block Gershgorin majorizer + coherent logdet/θ-adjoint are
1918 /// landed — see `SaeRowLayout`'s doc); it keeps the full `K`-atom layout only
1919 /// when neither lever engages. The decision is auto-derived from
1920 /// the problem size and the device/host working-set budget — never a CLI flag
1921 /// or kwarg. JumpReLU is not handled here (it always uses its structural gate
1922 /// via [`SaeRowLayout::from_jumprelu`]). The dense Gauss-Newton data Gram `G`
1923 /// is `(m_total × m_total)` f64; if its dense form fits the budget we keep
1924 /// the exact full-support solve (every atom active per row), so small-`K`
1925 /// problems are bit-for-bit unchanged. Above that, we cap each row to the
1926 /// `k_active` atoms that make the *sparse* Gram fit the same budget, with a
1927 /// relative magnitude cutoff that drops assignment mass contributing
1928 /// negligible `O(a²)` curvature.
1929 ///
1930 /// Returns `Some((k_active_cap, cutoff))` to engage sparsity, or `None` to
1931 /// keep the dense full-support layout.
1932 pub(crate) fn sparse_active_plan(&self) -> Option<(usize, f64)> {
1933 // The per-row Riemannian tangent projection for non-Euclidean atom
1934 // latents is now applied directly on the compact active-set rows (see
1935 // the `Some(layout)` arm in `assemble_arrow_schur`, via
1936 // `compact_row_ext_manifold_and_point`), which rebuilds each row's
1937 // product manifold in its compact column order and applies the SAME
1938 // gt/htt/htbeta + Kronecker-Jacobian projections the dense path uses. So
1939 // the sparse plan may engage on curved ext-coord manifolds (circle /
1940 // torus / sphere atoms) — the affordability lever for manifold-SAE at
1941 // large `K`, where the dense `K²` co-assignment Gram is the cost. (The
1942 // former `is_euclidean()`-only restriction punted every curved atom to
1943 // the dense layout; it is lifted.) The host/device in-core budget is the
1944 // single gate now; it is parameterised in `sparse_active_plan_for_budget`
1945 // so the engagement regression can pin a small budget without allocating
1946 // a multi-GB dense Gram.
1947 let budget = match crate::gpu::device_runtime::GpuRuntime::global() {
1948 // Allow up to one quarter of the AGGREGATE device budget for the dense
1949 // Gram, matching the streaming dispatcher's in-core fraction. The
1950 // per-atom-pair Gram blocks fan out across the whole device pool, so
1951 // the in-core fraction sums every ordinal's budget, not just the
1952 // primary's.
1953 Some(rt) => {
1954 let aggregate: usize = rt
1955 .device_ordinals()
1956 .iter()
1957 .map(|&ord| rt.memory_budget_for(ord))
1958 .sum();
1959 aggregate / 4
1960 }
1961 None => sae_host_in_core_budget_bytes().0,
1962 };
1963 self.sparse_active_plan_for_budget(budget)
1964 }
1965
1966 /// Budget-parameterised core of [`Self::sparse_active_plan`]. The dense data
1967 /// Gram footprint `(m_total · m_total) f64` is compared against `budget`; a
1968 /// term whose dense Gram exceeds the budget engages the compact active-set
1969 /// plan (returns `Some((k_active_cap, cutoff))`), regardless of whether any
1970 /// atom latent is curved. Pulled out so the curved-atom engagement
1971 /// regression can pin a small budget deterministically.
1972 pub(crate) fn sparse_active_plan_for_budget(&self, budget: usize) -> Option<(usize, f64)> {
1973 // Relative magnitude cutoff: assignment mass below this fraction of the
1974 // row's peak `|a_k|` enters the Gram only as `O(a²)` curvature and is
1975 // dropped. Chosen so dropped terms are ~1e-6 of the peak self-coupling.
1976 const RELATIVE_CUTOFF: f64 = 1.0e-3;
1977
1978 let k_atoms = self.k_atoms();
1979 if k_atoms <= 1 {
1980 return None;
1981 }
1982 let p = self.output_dim();
1983 let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
1984 // Dense data Gram footprint: (m_total · m_total) f64.
1985 let dense_gram_bytes = m_total
1986 .saturating_mul(m_total)
1987 .saturating_mul(SAE_BYTES_PER_F64);
1988 if dense_gram_bytes <= budget {
1989 return None;
1990 }
1991
1992 // Sparse Gram footprint scales with the per-row active basis count
1993 // `k_active · m_atom`. Solve for the largest `k_active` whose sparse
1994 // Gram `(k_active · m_atom)²` still fits the budget.
1995 let m_atom = (m_total as f64 / k_atoms as f64).max(1.0);
1996 let max_active_basis = ((budget as f64 / SAE_BYTES_PER_F64 as f64).sqrt() / m_atom).floor();
1997 let k_active_cap = (max_active_basis as usize).clamp(1, k_atoms);
1998 // p does not enter the Gram dimension (it is carried by the `⊗ I_p`
1999 // structure), but a degenerate `p == 0` term has no decoder columns.
2000 if p == 0 {
2001 return None;
2002 }
2003 Some((k_active_cap, RELATIVE_CUTOFF))
2004 }
2005
2006 /// #1408/#1409 — per-row active-set plan for the Softmax assignment.
2007 ///
2008 /// Engages the compact top-`k` row layout when EITHER the user supplied a
2009 /// hard `top_k` cap ([`Self::softmax_active_cap`], `1 <= k < K`) OR the
2010 /// dense data Gram exceeds the in-core budget (the same memory lever the
2011 /// IBP path uses via [`Self::sparse_active_plan`]). The returned
2012 /// `k_active_cap` is the tighter of the two, so an explicit `top_k`
2013 /// genuinely bounds the optimization even below the memory threshold and a
2014 /// large-K budget breach still bounds it when no `top_k` is set. Returns
2015 /// `None` (keep the exact full-`K` dense softmax layout) when neither lever
2016 /// engages.
2017 ///
2018 /// The cutoff is the same relative magnitude floor as the budget plan
2019 /// (`1e-3` of the row peak); under an explicit `top_k` cap alone (no budget
2020 /// breach) it is `0.0` so exactly the top-`k` atoms are retained.
2021 pub(crate) fn softmax_active_plan(&self) -> Option<(usize, f64)> {
2022 if self.k_atoms() <= 1 {
2023 return None;
2024 }
2025 let budget_plan = self.sparse_active_plan();
2026 match (self.softmax_active_cap, budget_plan) {
2027 (Some(cap), Some((budget_cap, cutoff))) => Some((cap.min(budget_cap), cutoff)),
2028 // Explicit cap only: retain exactly the top-`cap` atoms (no extra
2029 // magnitude cutoff beyond the cap).
2030 (Some(cap), None) => Some((cap, 0.0)),
2031 (None, plan) => plan,
2032 }
2033 }
2034
2035 pub fn flatten_beta(&self) -> Array1<f64> {
2036 let p = self.output_dim();
2037 let offsets = self.beta_offsets();
2038 let mut out = Array1::<f64>::zeros(self.beta_dim());
2039 for (atom_idx, atom) in self.atoms.iter().enumerate() {
2040 let m = atom.basis_size();
2041 let off = offsets[atom_idx];
2042 for basis_col in 0..m {
2043 for out_col in 0..p {
2044 out[off + basis_col * p + out_col] =
2045 atom.decoder_coefficients[[basis_col, out_col]];
2046 }
2047 }
2048 }
2049 out
2050 }
2051
2052 pub fn set_flat_beta(&mut self, beta: ArrayView1<'_, f64>) -> Result<(), String> {
2053 if beta.len() != self.beta_dim() {
2054 return Err(format!(
2055 "set_flat_beta: beta length {} != expected {}",
2056 beta.len(),
2057 self.beta_dim()
2058 ));
2059 }
2060 let p = self.output_dim();
2061 let offsets = self.beta_offsets();
2062 for (atom_idx, atom) in self.atoms.iter_mut().enumerate() {
2063 let m = atom.basis_size();
2064 let off = offsets[atom_idx];
2065 for basis_col in 0..m {
2066 for out_col in 0..p {
2067 atom.decoder_coefficients[[basis_col, out_col]] =
2068 beta[off + basis_col * p + out_col];
2069 }
2070 }
2071 }
2072 Ok(())
2073 }
2074
2075 pub fn refit_decoder_least_squares_at_current_state(
2076 &mut self,
2077 target: ArrayView2<'_, f64>,
2078 rho: Option<&SaeManifoldRho>,
2079 ) -> Result<(), String> {
2080 let n = self.n_obs();
2081 let p = self.output_dim();
2082 if target.dim() != (n, p) {
2083 return Err(format!(
2084 "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: target shape {:?} != ({n}, {p})",
2085 target.dim()
2086 ));
2087 }
2088 let k_atoms = self.k_atoms();
2089 let offsets = self.beta_offsets();
2090 let m_total = self.beta_dim() / p;
2091 let mut design = Array2::<f64>::zeros((n, m_total));
2092 for row in 0..n {
2093 let assignments = match rho {
2094 Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2095 None => self.assignment.try_assignments_row(row)?,
2096 };
2097 for atom_idx in 0..k_atoms {
2098 let atom = &self.atoms[atom_idx];
2099 let weight = assignments[atom_idx];
2100 let m = atom.basis_size();
2101 let off = offsets[atom_idx] / p;
2102 for basis_col in 0..m {
2103 design[[row, off + basis_col]] = weight * atom.basis_values[[row, basis_col]];
2104 }
2105 }
2106 }
2107 let beta = solve_design_least_squares(design.view(), target)?;
2108 if beta.dim() != (m_total, p) {
2109 return Err(format!(
2110 "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: beta shape {:?} != ({m_total}, {p})",
2111 beta.dim()
2112 ));
2113 }
2114 for atom_idx in 0..k_atoms {
2115 let m = self.atoms[atom_idx].basis_size();
2116 let off = offsets[atom_idx] / p;
2117 for basis_col in 0..m {
2118 for out_col in 0..p {
2119 self.atoms[atom_idx].decoder_coefficients[[basis_col, out_col]] =
2120 beta[[off + basis_col, out_col]];
2121 }
2122 }
2123 self.atoms[atom_idx].refresh_intrinsic_smooth_penalty();
2124 }
2125 Ok(())
2126 }
2127
2128 pub fn fitted(&self) -> Array2<f64> {
2129 self.try_fitted().expect("assignment logits must be finite")
2130 }
2131
2132 /// The #1026 hybrid-collapse substitution map: `atom_idx → &AtomLinearImage`
2133 /// for every `d = 1` slot whose post-fit verdict selected its straight
2134 /// (`Θ → 0`) sub-model. Empty when no report has been computed
2135 /// (`hybrid_split_report == None`, e.g. mid-fit) or no slot collapsed. The
2136 /// SINGLE source of the collapse policy — every reconstruction path (the
2137 /// rho-keyed `try_fitted_with_rho`, the explicit-assignment
2138 /// [`Self::reconstruct_from_assignments`] used by the top-k projection)
2139 /// reads it so train, OOS, and top-k reconstructions decode collapsed slots
2140 /// identically (#1228, #1233).
2141 pub(crate) fn hybrid_linear_image_map(
2142 &self,
2143 ) -> std::collections::HashMap<usize, &crate::hybrid_split::AtomLinearImage> {
2144 // A fitted term carries its collapse policy on the post-fit
2145 // `hybrid_split_report`; an OOS term carries the same trained images on
2146 // `oos_linear_images` (#1228). At most one is `Some` in practice, but
2147 // prefer the report when both are present.
2148 if let Some(report) = self.hybrid_split_report.as_ref() {
2149 return report
2150 .verdicts
2151 .iter()
2152 .filter_map(|v| v.linear_image.as_ref().map(|img| (img.atom_idx, img)))
2153 .collect();
2154 }
2155 if let Some(images) = self.oos_linear_images.as_ref() {
2156 return images.iter().map(|img| (img.atom_idx, img)).collect();
2157 }
2158 std::collections::HashMap::new()
2159 }
2160
2161 /// #1228 — attach the trained dictionary's hybrid-collapsed linear images to
2162 /// this (typically OOS) term so its reconstruction (`fitted` / the top-k
2163 /// assembler) decodes verdict-linear `d = 1` slots by the SAME straight
2164 /// sub-model the training reconstruction used, instead of the original
2165 /// curved decoder. Each image's `atom_idx` must index a real slot; an image
2166 /// whose channel count `p` disagrees with this term's output dim, or whose
2167 /// `atom_idx` is out of range, is rejected so a stale/mismatched payload
2168 /// cannot silently corrupt the reconstruction. Pass an empty slice (or never
2169 /// call this) for an all-curved OOS reconstruction.
2170 ///
2171 /// `pub` (not `pub(crate)`): this is part of the FFI surface — the gam-pyffi
2172 /// crate calls it from `latent_basis_and_sae_ffi.rs` to attach a trained
2173 /// dictionary's hybrid-linear images to an OOS reconstruction term (#1228).
2174 /// Downgrading it to `pub(crate)` breaks the gam-pyffi cdylib build with
2175 /// E0624 (the gam lib still compiles, so the lib build does not catch it).
2176 pub fn set_hybrid_linear_images(
2177 &mut self,
2178 images: Vec<crate::hybrid_split::AtomLinearImage>,
2179 ) -> Result<(), String> {
2180 let p = self.output_dim();
2181 let k_atoms = self.k_atoms();
2182 for img in &images {
2183 if img.atom_idx >= k_atoms {
2184 return Err(format!(
2185 "set_hybrid_linear_images: atom_idx {} out of range (k_atoms={k_atoms})",
2186 img.atom_idx
2187 ));
2188 }
2189 if img.b0.len() != p || img.b1.len() != p {
2190 return Err(format!(
2191 "set_hybrid_linear_images: atom {} linear image has p=({}, {}) != output_dim {p}",
2192 img.atom_idx,
2193 img.b0.len(),
2194 img.b1.len()
2195 ));
2196 }
2197 if self.atoms[img.atom_idx].latent_dim != 1 {
2198 return Err(format!(
2199 "set_hybrid_linear_images: atom {} is not d=1; only d=1 slots collapse to a straight image",
2200 img.atom_idx
2201 ));
2202 }
2203 }
2204 self.oos_linear_images = if images.is_empty() {
2205 None
2206 } else {
2207 Some(images)
2208 };
2209 Ok(())
2210 }
2211
2212 /// Assemble the reconstruction `Σ_k a[i,k]·g_k(t_{ik})` from an EXPLICIT
2213 /// per-row assignment matrix (e.g. a hard top-k projection of the fitted
2214 /// soft assignments), honouring the #1026 hybrid collapse when `collapse` is
2215 /// set: a verdict-linear `d = 1` slot decodes its straight sub-model image
2216 /// instead of its curved curve, exactly as the production `try_fitted` does.
2217 /// This is the shared assembler the FFI top-k path uses so the projected
2218 /// reconstruction composes with hybrid collapse (#1233) instead of
2219 /// re-deriving the curved image by hand and silently bypassing the verdict.
2220 /// The atom coordinates (`t`) and decoded curves are the term's own fitted
2221 /// ones; only the assignment masses come from `assignments`.
2222 pub fn reconstruct_from_assignments(
2223 &self,
2224 assignments: ArrayView2<'_, f64>,
2225 collapse: bool,
2226 ) -> Result<Array2<f64>, String> {
2227 let n = self.n_obs();
2228 let p = self.output_dim();
2229 let k_atoms = self.k_atoms();
2230 if assignments.dim() != (n, k_atoms) {
2231 return Err(format!(
2232 "SaeManifoldTerm::reconstruct_from_assignments: assignments {:?} != ({n}, {k_atoms})",
2233 assignments.dim()
2234 ));
2235 }
2236 let linear_images = if collapse {
2237 self.hybrid_linear_image_map()
2238 } else {
2239 std::collections::HashMap::new()
2240 };
2241 let mut out = Array2::<f64>::zeros((n, p));
2242 let mut g_buf = vec![0.0_f64; p];
2243 for row in 0..n {
2244 for atom_idx in 0..k_atoms {
2245 let a_k = assignments[[row, atom_idx]];
2246 if a_k == 0.0 {
2247 continue;
2248 }
2249 if let Some(image) = linear_images.get(&atom_idx) {
2250 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2251 image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2252 } else {
2253 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2254 }
2255 let mut out_row = out.row_mut(row);
2256 for out_col in 0..p {
2257 out_row[out_col] += a_k * g_buf[out_col];
2258 }
2259 }
2260 }
2261 Ok(out)
2262 }
2263
2264 pub fn try_fitted(&self) -> Result<Array2<f64>, String> {
2265 // Production/user-facing reconstruction: honours the #1026 hybrid-split
2266 // verdict (verdict-linear `d = 1` slots decode their straight sub-model).
2267 self.try_fitted_with_rho(None, true)
2268 }
2269
2270 pub(crate) fn try_fitted_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
2271 // Internal/fitting reconstruction: the pure CURVED image (the joint fit
2272 // and the #1026 adjudication both require the uncollapsed curve).
2273 self.try_fitted_with_rho(Some(rho), false)
2274 }
2275
2276 pub(crate) fn try_fitted_with_rho(
2277 &self,
2278 rho: Option<&SaeManifoldRho>,
2279 collapse: bool,
2280 ) -> Result<Array2<f64>, String> {
2281 let n = self.n_obs();
2282 let p = self.output_dim();
2283 let k_atoms = self.k_atoms();
2284 let mut out = Array2::<f64>::zeros((n, p));
2285 // #1026 — the curved/linear hybrid-split verdict is LOAD-BEARING on the
2286 // production reconstruction, not just a side report. When
2287 // [`Self::compute_hybrid_split_report`] (run post-fit in
2288 // `canonicalize_charts_post_fit`) adjudicated a `d = 1` atom's evidence
2289 // in favour of its straight (Θ→0) sub-model, the model's output
2290 // reconstruction (`fitted()` / `try_fitted` → predict and the user-facing
2291 // output) decodes that slot with its fitted linear image instead of its
2292 // curved decoded curve. The linear images are coordinate-keyed and
2293 // rho-independent (exact weighted-LS lines realised inside the
2294 // adjudication — no re-fit, no #1051 outer continuation).
2295 //
2296 // The collapse engages only when the caller asks for it (`collapse`):
2297 // the production `try_fitted` path and the explicit
2298 // `hybrid_collapsed_reconstruction` entry point. The pure-curved
2299 // `try_fitted_for_rho` opts out — the joint fit's loss/assembly optimise
2300 // the curved decoder coefficients and must see the curved image, and the
2301 // #1026 adjudication itself compares the curved fit against its straight
2302 // sub-model — both require the uncollapsed curve. (During fitting the
2303 // report is `None` regardless; it is only computed post-fit.)
2304 let linear_images = if collapse {
2305 self.hybrid_linear_image_map()
2306 } else {
2307 std::collections::HashMap::new()
2308 };
2309 // Reuse a single scratch buffer across all (row, atom) pairs instead of
2310 // allocating a fresh `Array1<f64>` of length p per call.
2311 let mut g_buf = vec![0.0_f64; p];
2312 for row in 0..n {
2313 let a = match rho {
2314 Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2315 None => self.assignment.try_assignments_row(row)?,
2316 };
2317 for atom_idx in 0..k_atoms {
2318 let a_k = a[atom_idx];
2319 if let Some(image) = linear_images.get(&atom_idx) {
2320 // Verdict-linear slot: substitute the straight sub-model image
2321 // at this row's fitted on-atom coordinate — or, for a #1026
2322 // collapse-rescued slot, at its fresh per-row code.
2323 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2324 image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2325 } else {
2326 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2327 }
2328 let mut out_row = out.row_mut(row);
2329 for out_col in 0..p {
2330 out_row[out_col] += a_k * g_buf[out_col];
2331 }
2332 }
2333 }
2334 Ok(out)
2335 }
2336
2337 /// Per-atom **leave-one-atom-out (LOAO) explained-variance contribution**
2338 /// (#1026): for each atom `k`, the drop in reconstruction explained variance
2339 /// `ΔEV_k = EV(full) − EV(full ⊖ atom_k)` when that atom's contribution
2340 /// `a[i,k]·g_k(coord[i,k])` is removed from the assembled reconstruction and
2341 /// nothing else is refit. Because every atom adds linearly into the same
2342 /// fitted reconstruction (`fitted[i] = Σ_k a[i,k]·g_k`), zeroing one atom is
2343 /// the exact "this atom withheld" counterfactual, and the EV it was earning
2344 /// is `EV(full) − EV(without k)`. This is the per-atom held-out EV
2345 /// attribution the #1026 roadmap pairs with each atom's fitted turning `Θ`:
2346 /// a `Θ ≈ 0` atom earning a large `ΔEV` is a linear-tail direction; a
2347 /// high-`Θ` atom earning a large `ΔEV` is a genuine curved family carrying
2348 /// reconstruction it would otherwise shatter into `N(ε) ≈ Θ/(2√(2ε))` linear
2349 /// directions. Pure read-only diagnostic — never mutates any atom.
2350 ///
2351 /// Returns one `Option<f64>` per atom in atom order; `None` for an atom
2352 /// whose ⊖-reconstruction EV is undefined (degenerate target variance), and
2353 /// `None` for the whole vector if the full-reconstruction EV is undefined.
2354 /// #1026: the load-bearing curved-vs-linear hybrid-split verdict for the
2355 /// fitted dictionary, or `None` until [`Self::canonicalize_charts_post_fit`]
2356 /// has run (or when no `d = 1` atom is eligible). Surfaced in the Python model
2357 /// output so the user sees which atoms genuinely earn their curvature.
2358 pub fn hybrid_split_report(
2359 &self,
2360 ) -> Option<&crate::hybrid_split::SaeHybridSplitReport> {
2361 self.hybrid_split_report.as_ref()
2362 }
2363
2364 /// Build the #1026 curved-vs-linear hybrid-split report by adjudicating each
2365 /// eligible `d = 1` atom's fitted curved image against its straight (linear
2366 /// special-case) sub-model on the common rank-aware Laplace evidence scale.
2367 ///
2368 /// Both candidates are scored against the SAME data — the atom's
2369 /// leave-this-atom-out response residual `y_resp = target − (full − a_k·γ_k)`
2370 /// (#1202) — over its assigned rows: the curved candidate predicts its actual
2371 /// mass-scaled contribution `a_k·γ_k`, the linear candidate the best
2372 /// mass-weighted straight line fit to `y_resp` (the collapsed linear lane —
2373 /// closed form, NOT the broken euclidean outer fit path of #1051). Linear is
2374 /// the curved family's nested `Θ = 0` sub-model on common data, so the
2375 /// per-slot evidence argmin is a genuine match-or-beat comparison. Eligible
2376 /// atoms are `d = 1` atoms with an installed evaluator at the full curvature
2377 /// dial (`homotopy_eta == 1.0`) whose live coordinate dim still matches the
2378 /// atom's latent dim. Returns `None` when no reconstruction `target` is
2379 /// supplied (there is no data to adjudicate against).
2380 pub fn compute_hybrid_split_report(
2381 &self,
2382 rho: &SaeManifoldRho,
2383 target: Option<ArrayView2<'_, f64>>,
2384 ) -> Result<Option<crate::hybrid_split::SaeHybridSplitReport>, String> {
2385 let n = self.n_obs();
2386 let p = self.output_dim();
2387 // Per-atom held-out `ΔEV_k` (leave-one-atom-out explained-variance drop),
2388 // paired with each atom's fitted turning Θ onto the verdict so the report
2389 // carries the #1026 `(Θ, ΔEV)` frontier point as structured data. Absent
2390 // when no reconstruction target is supplied.
2391 let loao_ev: Vec<Option<f64>> = match target {
2392 Some(t) => self.per_atom_loao_explained_variance(t, rho)?,
2393 None => vec![None; self.k_atoms()],
2394 };
2395 let delta_ev_for =
2396 |atom_idx: usize| -> Option<f64> { loao_ev.get(atom_idx).copied().flatten() };
2397 // The common-evidence comparison (#1202) scores both candidates against
2398 // the response data the atom is responsible for. That requires a target;
2399 // with none supplied there is nothing to adjudicate against, so no report.
2400 let Some(target) = target else {
2401 return Ok(None);
2402 };
2403 if target.dim() != (n, p) {
2404 return Err(format!(
2405 "SaeManifoldTerm::compute_hybrid_split_report: target {:?} != ({n}, {p})",
2406 target.dim()
2407 ));
2408 }
2409 // Per-row assignment masses (once), so each atom's weighted straight-line
2410 // fit uses the same row weighting the joint reconstruction loss does.
2411 let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2412 for row in 0..n {
2413 weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2414 }
2415 // The full assembled reconstruction `Σ_k a[i,k]·γ_k`, computed once. Each
2416 // atom's leave-this-atom-out response residual is `y_resp = target −
2417 // (full − a_k·γ_k)`, the data both that atom's candidates fit (#1202).
2418 let full = self.try_fitted_for_rho(rho)?;
2419 let eligible: Vec<usize> = (0..self.k_atoms())
2420 .filter(|&atom_idx| {
2421 let atom = &self.atoms[atom_idx];
2422 atom.latent_dim == 1
2423 && atom.basis_evaluator.is_some()
2424 && atom.homotopy_eta == 1.0
2425 && self.assignment.coords[atom_idx].latent_dim() == atom.latent_dim
2426 })
2427 .collect();
2428 // Per-atom fitted decoded image at every row (the curved candidate's
2429 // realized curve, which the linear candidate must approximate).
2430 let coords_for = |atom_idx: usize| -> Array1<f64> {
2431 self.assignment.coords[atom_idx]
2432 .as_matrix()
2433 .column(0)
2434 .to_owned()
2435 };
2436 let assign_for = |atom_idx: usize| -> Array1<f64> {
2437 Array1::from_iter((0..n).map(|row| weights[row][atom_idx]))
2438 };
2439 let decoded_for = |atom_idx: usize| -> Array2<f64> {
2440 let mut decoded = Array2::<f64>::zeros((n, p));
2441 let mut buf = vec![0.0_f64; p];
2442 for row in 0..n {
2443 self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2444 for col in 0..p {
2445 decoded[[row, col]] = buf[col];
2446 }
2447 }
2448 decoded
2449 };
2450 // The atom's leave-this-atom-out response residual `y_resp = target −
2451 // (full − a_k·γ_k) = (target − full) + a_k·γ_k`. Both the curved and the
2452 // linear candidate are scored against this on common data (#1202).
2453 let target_resid_for = |atom_idx: usize| -> Array2<f64> {
2454 let mut resid = Array2::<f64>::zeros((n, p));
2455 let mut buf = vec![0.0_f64; p];
2456 for row in 0..n {
2457 let a_k = weights[row][atom_idx];
2458 self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2459 for col in 0..p {
2460 resid[[row, col]] = target[[row, col]] - full[[row, col]] + a_k * buf[col];
2461 }
2462 }
2463 resid
2464 };
2465 let manifold_for = |atom_idx: usize| -> gam_terms::latent::LatentManifold {
2466 self.assignment.coords[atom_idx].manifold().clone()
2467 };
2468 // #1026 EV-preservation gate denominator: the full target's total
2469 // column-centered variance `SST_full` (the SAME `sst` the reconstruction
2470 // EV is measured against), so the gate vetoes any collapse that would drop
2471 // full-reconstruction EV by more than its tolerance.
2472 let total_centered_variance = {
2473 let mut tss = 0.0_f64;
2474 for col in 0..p {
2475 let mut mean = 0.0_f64;
2476 for row in 0..n {
2477 mean += target[[row, col]];
2478 }
2479 mean /= n as f64;
2480 for row in 0..n {
2481 let c = target[[row, col]] - mean;
2482 tss += c * c;
2483 }
2484 }
2485 tss
2486 };
2487 crate::hybrid_split::build_hybrid_split_report(
2488 &self.atoms,
2489 eligible.into_iter(),
2490 coords_for,
2491 assign_for,
2492 decoded_for,
2493 target_resid_for,
2494 manifold_for,
2495 delta_ev_for,
2496 total_centered_variance,
2497 )
2498 }
2499
2500 pub fn per_atom_loao_explained_variance(
2501 &self,
2502 target: ArrayView2<'_, f64>,
2503 rho: &SaeManifoldRho,
2504 ) -> Result<Vec<Option<f64>>, String> {
2505 let n = self.n_obs();
2506 let p = self.output_dim();
2507 let k_atoms = self.k_atoms();
2508 if target.dim() != (n, p) {
2509 return Err(format!(
2510 "SaeManifoldTerm::per_atom_loao_explained_variance: target {:?} != ({n}, {p})",
2511 target.dim()
2512 ));
2513 }
2514 let full = self.try_fitted_for_rho(rho)?;
2515 let Some(ev_full) = reconstruction_explained_variance(target, full.view()) else {
2516 return Ok(vec![None; k_atoms]);
2517 };
2518 // Cache each row's assignment weights once, then subtract a single
2519 // atom's decoded contribution per LOAO pass instead of reassembling the
2520 // whole dictionary k times.
2521 let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2522 for row in 0..n {
2523 weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2524 }
2525 let mut g_buf = vec![0.0_f64; p];
2526 let mut out = Vec::with_capacity(k_atoms);
2527 for atom_idx in 0..k_atoms {
2528 let mut without = full.clone();
2529 for row in 0..n {
2530 let a_k = weights[row][atom_idx];
2531 if a_k == 0.0 {
2532 continue;
2533 }
2534 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2535 let mut without_row = without.row_mut(row);
2536 for out_col in 0..p {
2537 without_row[out_col] -= a_k * g_buf[out_col];
2538 }
2539 }
2540 out.push(
2541 reconstruction_explained_variance(target, without.view())
2542 .map(|ev_without| ev_full - ev_without),
2543 );
2544 }
2545 Ok(out)
2546 }
2547
2548 /// #1026 — the LOAD-BEARING collapsed reconstruction: the assembled
2549 /// dictionary output `Σ_k a[i,k]·g_k(coord[i,k])` in which every slot whose
2550 /// hybrid-split verdict selected LINEAR has its curved decoded image replaced
2551 /// by its fitted straight sub-model `b₀ + (t − t̄)·b₁`. This is what makes the
2552 /// verdict *change the reconstruction* instead of merely logging a choice:
2553 /// the linear-collapsed atom no longer pays its `M·p` curved coefficients, it
2554 /// carries a `2·p` straight image whose decoded curve has zero turning.
2555 ///
2556 /// The straight images are the exact weighted-least-squares lines already
2557 /// realized inside [`Self::compute_hybrid_split_report`] (no re-fit, no outer
2558 /// continuation, sidestepping #1051). Returns the curved reconstruction
2559 /// unchanged when no verdict selected linear, or when the report has not been
2560 /// computed yet (`hybrid_split_report == None`).
2561 pub fn hybrid_collapsed_reconstruction(
2562 &self,
2563 rho: &SaeManifoldRho,
2564 ) -> Result<Array2<f64>, String> {
2565 // #1026 — the hybrid collapse is realised by the SINGLE reconstruction
2566 // path ([`Self::try_fitted_with_rho`]) with the collapse flag set: a
2567 // verdict-linear `d = 1` slot decodes its straight sub-model image
2568 // instead of its curved curve. This replaces the dedicated re-collapse
2569 // loop this method used to carry (a parallel layer). The production
2570 // `try_fitted` shares the identical routine at `rho = None`; this entry
2571 // point keeps the rho-keyed collapse for the #1026 EV-dominance reporting
2572 // (`hybrid_collapsed_explained_variance`) and the regression battery.
2573 self.try_fitted_with_rho(Some(rho), true)
2574 }
2575
2576 /// #1026 — the reconstruction explained variance of the hybrid-collapsed
2577 /// dictionary (every verdict-linear slot decoded by its straight sub-model)
2578 /// against `target`. The companion of [`Self::per_atom_loao_explained_variance`]
2579 /// for the dominance claim: because each linear-collapsed slot is the curved
2580 /// family's `Θ → 0` sub-model and is only kept when its evidence beats the
2581 /// curved candidate's parameter price, the collapsed dictionary match-or-beats
2582 /// the all-curved one on EV-per-parameter — the strict-generalization floor
2583 /// the #1026 hybrid argument rests on. `None` when EV is undefined (degenerate
2584 /// target variance).
2585 pub fn hybrid_collapsed_explained_variance(
2586 &self,
2587 target: ArrayView2<'_, f64>,
2588 rho: &SaeManifoldRho,
2589 ) -> Result<Option<f64>, String> {
2590 let n = self.n_obs();
2591 let p = self.output_dim();
2592 if target.dim() != (n, p) {
2593 return Err(format!(
2594 "SaeManifoldTerm::hybrid_collapsed_explained_variance: target {:?} != ({n}, {p})",
2595 target.dim()
2596 ));
2597 }
2598 let collapsed = self.hybrid_collapsed_reconstruction(rho)?;
2599 Ok(reconstruction_explained_variance(target, collapsed.view()))
2600 }
2601
2602 /// #1026 ladder item 2/3 — the AMORTIZED ENCODER, wired from the fitted
2603 /// dictionary. Builds the offline certified [`EncodeAtlas`] over this term's
2604 /// frozen atoms and encodes a target corpus `targets` (`n × p`) through the
2605 /// per-chart distilled Jacobian predictor, with the Kantorovich certificate
2606 /// gating each row and an exact-solve fallback for the rows the amortized
2607 /// predictor cannot certify. Returns one [`EncodeResult`] per atom (the
2608 /// per-atom encoded coordinates + per-row certificate mask), in dictionary
2609 /// order.
2610 ///
2611 /// This is the thread's "encoder + certificate-gated exact fallback"
2612 /// deployment made reachable from a fit: the distilled map approximates
2613 /// inference at one mat-vec/row, and any row whose amortized prediction fails
2614 /// `h ≤ ½` falls back to the certified IFT-warm-start Newton encode
2615 /// ([`EncodeAtlas::certified_encode_row`]); rows that still cannot be
2616 /// certified ride the [`EncodeResult::encode_uncertified_count`] flag for the
2617 /// upstream exact multi-start solve (honesty, never a silent wrong encode).
2618 ///
2619 /// Magic by default: the atlas's worst-case bounds are auto-derived from the
2620 /// fit — `amplitude_bound[k]` is the largest fitted assignment mass `a[i,k]`
2621 /// the encode can produce for atom `k` (the encode recovers `t` from
2622 /// `x ≈ z·γ_k(t)` at amplitude `z = a[i,k]`), and `target_norm_bound` is the
2623 /// largest target row norm — so no caller supplies a knob. Per-row amplitudes
2624 /// are the fitted assignment masses for the same target the dictionary was fit
2625 /// against; an external corpus reuses the per-row masses the assignment
2626 /// produces for it upstream (passed in `amplitudes`, one column per atom).
2627 pub fn amortized_encode_target(
2628 &self,
2629 targets: ArrayView2<'_, f64>,
2630 amplitudes: ArrayView2<'_, f64>,
2631 ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2632 let p = self.output_dim();
2633 let k_atoms = self.k_atoms();
2634 let n = targets.nrows();
2635 if targets.ncols() != p {
2636 return Err(format!(
2637 "SaeManifoldTerm::amortized_encode_target: targets have {} cols but output_dim is {p}",
2638 targets.ncols()
2639 ));
2640 }
2641 if amplitudes.dim() != (n, k_atoms) {
2642 return Err(format!(
2643 "SaeManifoldTerm::amortized_encode_target: amplitudes {:?} must be (n={n}, K={k_atoms})",
2644 amplitudes.dim()
2645 ));
2646 }
2647
2648 // Magic-by-default offline bounds, auto-derived from the fit so no caller
2649 // supplies a knob. `target_norm_bound` is the largest target row L2 norm
2650 // (bounds `‖x‖` over the corpus); `amplitude_bound[k]` is the largest
2651 // fitted assignment mass for atom `k` (bounds `|z_k|`), with a strictly
2652 // positive floor so a near-inactive atom still certifies a finite radius.
2653 let mut target_norm_bound = 0.0_f64;
2654 for row in 0..n {
2655 let norm = targets.row(row).dot(&targets.row(row)).sqrt();
2656 if norm.is_finite() && norm > target_norm_bound {
2657 target_norm_bound = norm;
2658 }
2659 }
2660 let mut amplitude_bound = vec![0.0_f64; k_atoms];
2661 for atom_idx in 0..k_atoms {
2662 let mut bound = 0.0_f64;
2663 for row in 0..n {
2664 let z = amplitudes[[row, atom_idx]].abs();
2665 if z.is_finite() && z > bound {
2666 bound = z;
2667 }
2668 }
2669 // A strictly positive amplitude floor keeps the offline Lipschitz
2670 // scaling finite for atoms with no active row in this corpus (those
2671 // rows encode to the chart center via the certificate anyway).
2672 amplitude_bound[atom_idx] = bound.max(1.0);
2673 }
2674
2675 let atlas = crate::encode::EncodeAtlas::build(
2676 &self.atoms,
2677 &litude_bound,
2678 target_norm_bound,
2679 crate::encode::AtlasConfig::default(),
2680 )?;
2681
2682 // Per-atom amortized encode with a certificate-gated exact-solve fallback:
2683 // a row whose distilled prediction fails `h ≤ ½` is retried through the
2684 // certified IFT-warm-start Newton path; a row that still cannot be
2685 // certified stays flagged for the upstream multi-start solve.
2686 // (The atlas is rho-free; the per-row amplitudes already carry the
2687 // rho-resolved assignment masses the caller produced upstream.)
2688 let mut results = Vec::with_capacity(k_atoms);
2689 for atom_idx in 0..k_atoms {
2690 let atom = &self.atoms[atom_idx];
2691 let amp_col = amplitudes.column(atom_idx).to_owned();
2692 let amortized =
2693 atlas.amortized_encode_batch(atom, atom_idx, targets, amp_col.view())?;
2694 let mut coords = amortized.coords;
2695 let mut certified = amortized.certified;
2696 for row in 0..n {
2697 if certified[row] {
2698 continue;
2699 }
2700 let (t, cert) =
2701 atlas.certified_encode_row(atom, atom_idx, targets.row(row), amp_col[row])?;
2702 if cert.certified() {
2703 coords.row_mut(row).assign(&t);
2704 certified[row] = true;
2705 }
2706 }
2707 results.push(crate::encode::EncodeResult::from_rows(
2708 coords, certified,
2709 ));
2710 }
2711 Ok(results)
2712 }
2713
2714 /// #1026 — the fitted per-row assignment masses `a[i,k]` (the activation
2715 /// amplitudes `z_k` the amortized encode recovers `t` against), as an
2716 /// `n × K` matrix. These are exactly the masses
2717 /// [`Self::try_fitted_with_rho`] assembles the reconstruction from, so
2718 /// feeding them to [`Self::amortized_encode_target`] re-encodes the SAME
2719 /// inference the dictionary was fit against — the self-consistency the
2720 /// distilled encoder is supervised to approximate.
2721 pub fn fitted_assignment_amplitudes(
2722 &self,
2723 rho: &SaeManifoldRho,
2724 ) -> Result<Array2<f64>, String> {
2725 let n = self.n_obs();
2726 let k_atoms = self.k_atoms();
2727 let mut amplitudes = Array2::<f64>::zeros((n, k_atoms));
2728 for row in 0..n {
2729 let a = self.assignment.try_assignments_row_for_rho(row, rho)?;
2730 for atom_idx in 0..k_atoms {
2731 amplitudes[[row, atom_idx]] = a[atom_idx];
2732 }
2733 }
2734 Ok(amplitudes)
2735 }
2736
2737 /// #1026 — encode the dictionary's own fit-time target with the amortized
2738 /// encoder, deriving the per-row amplitudes from the fitted assignment so the
2739 /// caller supplies neither bounds nor amplitudes (magic by default). The
2740 /// end-to-end "fit → distilled encoder → certificate-gated encode" path.
2741 pub fn amortized_encode_fitted(
2742 &self,
2743 targets: ArrayView2<'_, f64>,
2744 rho: &SaeManifoldRho,
2745 ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2746 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2747 self.amortized_encode_target(targets, amplitudes.view())
2748 }
2749
2750 /// #1154 — amortized-encoder consistency of the CURRENT dictionary against
2751 /// its own fit-time target. This is the co-training signal of the joint
2752 /// amortized-encoder + REML loop (Design A): the amortized (one-mat-vec)
2753 /// encode is built from the *current* fitted decoder, run on `targets`, and
2754 /// scored on two principled axes —
2755 ///
2756 /// * `recon_consistency` (the bilinear part of the co-training loss): the
2757 /// mean per-element squared gap between the **amortized** reconstruction
2758 /// `Σ_k z_k · Φ_k(t̂_k) B_k` (decode the amortized coords) and the
2759 /// **exact** fitted reconstruction `Σ_k z_k · Φ_k(t_k^*) B_k` the inner
2760 /// solve converged to. A dictionary whose encode map is well-approximated
2761 /// to first order by the per-chart IFT predictor scores near zero; a
2762 /// dictionary the amortized encoder *cannot* invert faithfully (sharp
2763 /// curvature, poorly-charted regions) scores high. Minimising this jointly
2764 /// with REML steers the fit toward dictionaries that admit a fast,
2765 /// faithful amortized encode — the architectural co-adaptation #1154 adds.
2766 /// * `uncertified_fraction`: the share of (row, atom) encodes whose
2767 /// Kantorovich certificate failed (`h > ½`), i.e. that fell back to the
2768 /// certified IFT-warm-start Newton. This is the encoder's *certifiable coverage*
2769 /// of the dictionary; co-training rewards dictionaries the cheap encode
2770 /// certifies, not just ones it happens to land.
2771 ///
2772 /// The certificate keeps every accepted amortized coord honest (uncertified
2773 /// rows already ride the exact fallback inside `amortized_encode_target`), so
2774 /// this metric never silently trusts a wrong encode — it MEASURES how much of
2775 /// the dictionary the cheap encoder can faithfully and certifiably invert.
2776 pub fn amortized_encoder_consistency(
2777 &self,
2778 targets: ArrayView2<'_, f64>,
2779 rho: &SaeManifoldRho,
2780 ) -> Result<AmortizedEncoderConsistency, String> {
2781 let n = self.n_obs();
2782 let p = self.output_dim();
2783 let k_atoms = self.k_atoms();
2784 if targets.dim() != (n, p) {
2785 return Err(format!(
2786 "SaeManifoldTerm::amortized_encoder_consistency: targets {:?} must be (n={n}, p={p})",
2787 targets.dim()
2788 ));
2789 }
2790 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2791 let encodes = self.amortized_encode_target(targets, amplitudes.view())?;
2792 // The EXACT fitted reconstruction the inner solve converged to (pure
2793 // curved image, rho-keyed) is the supervision target for the amortized
2794 // reconstruction. Both are n×p ambient, so the comparison is layout-free.
2795 let exact_recon = self.try_fitted_for_rho(rho)?;
2796
2797 // Build the amortized reconstruction Σ_k z_k · Φ_k(t̂_k) B_k by decoding
2798 // each atom's amortized coords through that atom's own basis evaluator.
2799 let mut amortized_recon = Array2::<f64>::zeros((n, p));
2800 let mut uncertified = 0usize;
2801 for atom_idx in 0..k_atoms {
2802 let atom = &self.atoms[atom_idx];
2803 let result = &encodes[atom_idx];
2804 // An atom with no basis evaluator cannot decode an amortized
2805 // reconstruction; every one of its rows is necessarily uncertified
2806 // (the encode flagged them all), so it contributes nothing to the
2807 // amortized recon and its full row-count to the uncertified tally.
2808 // Count it and skip the decode rather than erroring — the consistency
2809 // fold stays a bounded penalty, never a hard abort of the criterion.
2810 let Some(evaluator) = atom.basis_evaluator.as_ref() else {
2811 uncertified += n;
2812 continue;
2813 };
2814 uncertified += result.encode_uncertified_count;
2815 // Decode the amortized coords: Φ_k(t̂) is (n × M_k); B_k is (M_k × p).
2816 let (phi, _jac) = evaluator.evaluate(result.coords.view())?;
2817 let decoded = phi.dot(&atom.decoder_coefficients); // (n × p)
2818 for row in 0..n {
2819 let z = amplitudes[[row, atom_idx]];
2820 if z == 0.0 {
2821 continue;
2822 }
2823 for col in 0..p {
2824 amortized_recon[[row, col]] += z * decoded[[row, col]];
2825 }
2826 }
2827 }
2828
2829 let mut sse = 0.0_f64;
2830 for row in 0..n {
2831 for col in 0..p {
2832 let gap = amortized_recon[[row, col]] - exact_recon[[row, col]];
2833 sse += gap * gap;
2834 }
2835 }
2836 let denom = (n.max(1) * p.max(1)) as f64;
2837 let recon_consistency = sse / denom;
2838 let total_encodes = (n * k_atoms).max(1) as f64;
2839 let uncertified_fraction = uncertified as f64 / total_encodes;
2840
2841 Ok(AmortizedEncoderConsistency {
2842 recon_consistency,
2843 uncertified_fraction,
2844 n_uncertified: uncertified,
2845 n_encodes: n * k_atoms,
2846 })
2847 }
2848
2849 /// #1154 — the co-trained REML criterion: the exact REML criterion at `rho`
2850 /// PLUS the amortized-encoder consistency penalty, so the outer optimizer
2851 /// co-adapts the dictionary + smoothing parameters λ TOWARD a dictionary the
2852 /// fast amortized encoder can faithfully and certifiably invert.
2853 ///
2854 /// This is Design A of #1154. The inner solve still converges the `(t, β)`
2855 /// system to stationarity at the engine's current ρ (so the implicit-function
2856 /// REML λ-gradient `dβ̂/dλ = −(H+S_λ)⁻¹(dS_λ/dλ)β̂` stays EXACT — the encoder
2857 /// only warm-starts/co-adapts, it never replaces the stationary point). The
2858 /// added term
2859 ///
2860 /// ```text
2861 /// J_cotrain(ρ) = REML(ρ) + w · ‖x̂_amortized − x̂_exact‖²/(n·p)
2862 /// + w_cert · uncertified_fraction
2863 /// ```
2864 ///
2865 /// folds the post-fit amortized-encode quality into the ranked objective. The
2866 /// weights are auto-scaled to the REML criterion magnitude (magic by default:
2867 /// no caller knob) so the consistency term is a meaningful but non-dominant
2868 /// fraction of the objective regardless of problem scale.
2869 pub fn reml_criterion_cotrained(
2870 &mut self,
2871 target: ArrayView2<'_, f64>,
2872 rho: &SaeManifoldRho,
2873 registry: Option<&AnalyticPenaltyRegistry>,
2874 inner_max_iter: usize,
2875 learning_rate: f64,
2876 ridge_ext_coord: f64,
2877 ridge_beta: f64,
2878 ) -> Result<(f64, SaeManifoldLoss, AmortizedEncoderConsistency), String> {
2879 // #1154: always attempt the amortized warm-start first inside
2880 // `reml_criterion_cotrained` (the encode/warm path for the cotrained
2881 // objective). Good warm-starts from the running dictionary land the
2882 // inner solve closer to the stationary point used for the fold.
2883 // Advisory only (0 or err falls back to cold); telemetry recorded by
2884 // outer objective callers when present.
2885 self.warm_start_latents_from_amortized_encoder(target, rho)
2886 .unwrap_or(0);
2887 let (reml, loss) = self.reml_criterion_with_refine_policy(
2888 target,
2889 rho,
2890 registry,
2891 inner_max_iter,
2892 learning_rate,
2893 ridge_ext_coord,
2894 ridge_beta,
2895 true,
2896 )?;
2897 let consistency = self.amortized_encoder_consistency(target, rho)?;
2898 // Auto-scale the co-training weights to the REML magnitude so the
2899 // consistency penalty is a bounded, scale-free fraction of the objective
2900 // (magic by default: no caller knob). `reml_scale` floors at 1 so a
2901 // near-zero criterion still admits a meaningful consistency contribution.
2902 let cotrained = Self::fold_cotrain_consistency(reml, &consistency);
2903 Ok((cotrained, loss, consistency))
2904 }
2905
2906 /// #1154 — the single source of the co-training fold arithmetic: add the
2907 /// auto-scaled amortized-encoder consistency penalty to an already-computed
2908 /// REML criterion at the converged dictionary. Both the public
2909 /// [`Self::reml_criterion_cotrained`] entry point and the outer-loop value /
2910 /// gradient lanes (`SaeManifoldOuterObjective::fold_cotrain_consistency`)
2911 /// route through THIS function, so the folded objective cannot drift between
2912 /// the criterion and the cascade-ranked cost (the objective↔gradient desync
2913 /// bug class). The weights are auto-scaled to the REML magnitude (`max(|REML|,
2914 /// 1)`) so the penalty is a bounded, scale-free fraction of the objective
2915 /// regardless of problem scale; the fold carries no analytic gradient (under
2916 /// Design A the REML λ-gradient stays the exact implicit-function path).
2917 #[must_use]
2918 pub fn fold_cotrain_consistency(
2919 reml_cost: f64,
2920 consistency: &AmortizedEncoderConsistency,
2921 ) -> f64 {
2922 let reml_scale = reml_cost.abs().max(1.0);
2923 reml_cost
2924 + COTRAIN_RECON_WEIGHT * reml_scale * consistency.recon_consistency
2925 + COTRAIN_CERT_WEIGHT * reml_scale * consistency.uncertified_fraction
2926 }
2927
2928 /// #1154 item 2 — warm-start the inner latent coordinates from the amortized
2929 /// encoder (Design A). Builds the per-chart IFT-Jacobian atlas from the
2930 /// CURRENT dictionary, runs the one-mat-vec amortized encode of `target`
2931 /// against each atom at the rho-resolved assignment masses, and overwrites
2932 /// each atom's stored latent coords with the predicted `t̂` ON THE ROWS THE
2933 /// KANTOROVICH CERTIFICATE ACCEPTS. Uncertified rows are left at their
2934 /// current coords (the previous-iterate start), so the
2935 /// warm-start can only HELP — a row the cheap predictor cannot certify never
2936 /// corrupts the seed. The subsequent inner Newton refines from this seed to
2937 /// the SAME stationary point (the warm-start changes only the basin entry,
2938 /// not the root), so the REML λ-gradient stays exactly the implicit-function
2939 /// path and the criterion is unchanged at convergence — the amortized encoder
2940 /// only accelerates/co-adapts the inner solve, it never replaces the
2941 /// stationary point.
2942 ///
2943 /// Returns the number of (row, atom) coords actually warm-started (the
2944 /// certified-prediction count), for instrumentation / tests. A first-build
2945 /// dictionary with no usable charts simply warm-starts nothing and returns 0
2946 /// (the cold path is byte-for-byte unchanged).
2947 pub fn warm_start_latents_from_amortized_encoder(
2948 &mut self,
2949 target: ArrayView2<'_, f64>,
2950 rho: &SaeManifoldRho,
2951 ) -> Result<usize, String> {
2952 let n = self.n_obs();
2953 let k_atoms = self.k_atoms();
2954 if n == 0 || k_atoms == 0 {
2955 return Ok(0);
2956 }
2957 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2958 let encodes = self.amortized_encode_target(target, amplitudes.view())?;
2959 let mut warm_started = 0usize;
2960 for atom_idx in 0..k_atoms {
2961 let d = self.atoms[atom_idx].latent_dim;
2962 if d == 0 {
2963 continue;
2964 }
2965 let result = &encodes[atom_idx];
2966 // Start from the atom's CURRENT coords so uncertified rows are left
2967 // exactly as they were; overwrite only the certified predictions.
2968 let mut coords = self.assignment.coords[atom_idx].as_matrix();
2969 if coords.dim() != (n, d) {
2970 return Err(format!(
2971 "warm_start_latents_from_amortized_encoder: atom {atom_idx} coords {:?} != (n={n}, d={d})",
2972 coords.dim()
2973 ));
2974 }
2975 for row in 0..n {
2976 if !result.certified[row] {
2977 continue;
2978 }
2979 for axis in 0..d {
2980 coords[[row, axis]] = result.coords[[row, axis]];
2981 }
2982 warm_started += 1;
2983 }
2984 // `as_matrix` lays coords out row-major (`[[row, axis]]`), exactly the
2985 // `values[row*d + axis]` order `set_flat` expects, so a plain
2986 // row-major iterator reconstructs the flat vector.
2987 let flat = Array1::from_iter(coords.iter().copied());
2988 self.assignment.coords[atom_idx].set_flat(flat.view());
2989 }
2990 // The basis caches must follow the freshly-seeded coords so the next
2991 // inner solve evaluates Φ at the warm-started t̂, not the stale coords.
2992 self.refresh_basis_from_current_coords()?;
2993 Ok(warm_started)
2994 }
2995
2996 pub fn loss(
2997 &self,
2998 target: ArrayView2<'_, f64>,
2999 rho: &SaeManifoldRho,
3000 ) -> Result<SaeManifoldLoss, String> {
3001 self.loss_scaled(target, rho, 1.0)
3002 }
3003
3004 /// Penalized objective with a `penalty_scale` applied to the β-tier
3005 /// (decoder smoothness) penalty, mirroring
3006 /// [`Self::assemble_arrow_schur_scaled`]. The streaming line search sums
3007 /// per-chunk `loss_scaled(..., n_chunk / N)` so that the global smoothness
3008 /// penalty is counted exactly once across a pass while the per-row data,
3009 /// assignment-prior, and ARD terms sum naturally. `penalty_scale == 1.0`
3010 /// recovers the full-batch objective.
3011 pub fn loss_scaled(
3012 &self,
3013 target: ArrayView2<'_, f64>,
3014 rho: &SaeManifoldRho,
3015 penalty_scale: f64,
3016 ) -> Result<SaeManifoldLoss, String> {
3017 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3018 return Err(format!(
3019 "SaeManifoldTerm::loss_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3020 ));
3021 }
3022 if target.dim() != (self.n_obs(), self.output_dim()) {
3023 return Err(format!(
3024 "SaeManifoldTerm::loss: Z must be ({}, {}); got {:?}",
3025 self.n_obs(),
3026 self.output_dim(),
3027 target.dim()
3028 ));
3029 }
3030 // The likelihood whitens through the RowMetric **only** when the metric
3031 // is a genuinely estimated noise model (`metric.whitens_likelihood()`,
3032 // i.e. `WhitenedStructured` — the #974 residual-covariance seam). For
3033 // Euclidean (default `None`) and for the OutputFisher *gauge* metric the
3034 // reconstruction data-fit stays the isotropic `0.5 * Σ r²`: a gauge /
3035 // output-Fisher inner product must NOT silently replace the
3036 // reconstruction loss with a Fisher pullback (#980). It only drives the
3037 // gauge (see `analytic_penalties::corrected_isometry_penalty`). The
3038 // producer of `WhitenedStructured` is
3039 // `inference::residual_factor::StructuredResidualModel::row_metric`; the
3040 // SAME metric whitens the assembled gradient/Hessian in
3041 // `assemble_arrow_schur` (the single #974 seam), so this value and that
3042 // gradient cannot desync. Without a whitening metric this path is
3043 // bit-for-bit the historical isotropic data-fit.
3044 let whitens = self
3045 .row_metric
3046 .as_ref()
3047 .is_some_and(|metric| metric.whitens_likelihood());
3048 // #991 design honesty weights: the reconstruction channel of row `i`
3049 // is weighted by `w_i` (mean-1 HT inclusion correction). The assembly
3050 // applies the same `w_i` via a `√w_i` scaling of the row residual /
3051 // Jacobian / β load at its single seam, so this value and that
3052 // gradient/Hessian carry the identical per-row factor. `None` ⇒ the
3053 // historical unweighted sum, bit-for-bit.
3054 let row_loss_w = self.row_loss_weights.as_deref();
3055 let n = self.n_obs();
3056 let p = self.output_dim();
3057 let k_atoms = self.k_atoms();
3058 // #1017: the data-fit is the dominant per-line-search-trial cost (it
3059 // re-runs every Armijo halving × every inner Newton iteration × every
3060 // outer ρ evaluation). The old path materialised the whole `n × p`
3061 // fitted matrix (`try_fitted_for_rho`) and then walked it AGAIN to form
3062 // the residual sum — two sequential `n·p` passes plus an `n·p`
3063 // allocation per trial. Fuse the reconstruction and the residual reduce
3064 // into ONE row-parallel pass that never materialises the fitted matrix:
3065 // each row decodes its atoms into per-worker scratch, differences
3066 // against the target, and contributes its scalar `0.5·w·‖r‖²` to a
3067 // chunk-ordered fold (bit-identical run-to-run). Per-worker scratch
3068 // (`map_init`) keeps the only allocations one `g_buf`/`fitted_row` pair
3069 // per rayon thread rather than per row. Stay sequential inside a worker
3070 // (the topology race owns the outer pool) to avoid nested
3071 // oversubscription.
3072 let parallel = n >= SAE_LOSS_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
3073 let row_data_fit =
3074 |row: usize,
3075 g_buf: &mut [f64],
3076 fitted_row: &mut [f64],
3077 assign_buf: &mut [f64]|
3078 -> Result<f64, String> {
3079 // #1557 — fill the per-atom assignment row into reused per-worker
3080 // scratch via the `_into` twin instead of heap-allocating a fresh
3081 // `Array1` per row per loss eval. Bit-identical to the allocating
3082 // `try_assignments_row_for_rho` (same arithmetic, same order); this
3083 // loss reruns every Armijo halving × inner Newton iter × outer ρ
3084 // eval, so the per-row K-sized allocation was a hot-path churn.
3085 self.assignment
3086 .try_assignments_row_for_rho_into(row, rho, assign_buf)?;
3087 let a = &*assign_buf;
3088 for slot in fitted_row.iter_mut() {
3089 *slot = 0.0;
3090 }
3091 for atom_idx in 0..k_atoms {
3092 self.atoms[atom_idx].fill_decoded_row(row, g_buf);
3093 let a_k = a[atom_idx];
3094 for out_col in 0..p {
3095 fitted_row[out_col] += a_k * g_buf[out_col];
3096 }
3097 }
3098 for out_col in 0..p {
3099 fitted_row[out_col] = target[[row, out_col]] - fitted_row[out_col];
3100 }
3101 let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3102 let mut acc = 0.0_f64;
3103 match self.row_metric.as_ref() {
3104 Some(metric) if whitens => {
3105 let resid = ArrayView1::from(&fitted_row[..p]);
3106 for w in metric.whiten_residual_row(row, resid) {
3107 acc += 0.5 * w_row * w * w;
3108 }
3109 }
3110 _ => {
3111 for &r in fitted_row[..p].iter() {
3112 acc += 0.5 * w_row * r * r;
3113 }
3114 }
3115 }
3116 Ok(acc)
3117 };
3118 let data_fit = if parallel {
3119 use rayon::prelude::*;
3120 const CHUNK: usize = 32;
3121 let partials: Vec<Result<f64, String>> = (0..n)
3122 .into_par_iter()
3123 .chunks(CHUNK)
3124 .map_init(
3125 || (vec![0.0_f64; p], vec![0.0_f64; p], vec![0.0_f64; k_atoms]),
3126 |(g_buf, fitted_row, assign_buf), idxs| {
3127 // #1557 — pin any faer GEMM reached from this row-parallel
3128 // data-fit chunk to `Par::Seq` (no nested Rayon re-fan); the
3129 // per-row reductions are tiny, so the result is bit-identical.
3130 with_nested_parallel(|| {
3131 let mut acc = 0.0_f64;
3132 for row in idxs {
3133 acc += row_data_fit(row, g_buf, fitted_row, assign_buf)?;
3134 }
3135 Ok(acc)
3136 })
3137 },
3138 )
3139 .collect();
3140 let mut total = 0.0_f64;
3141 for partial in partials {
3142 total += partial?;
3143 }
3144 total
3145 } else {
3146 let mut g_buf = vec![0.0_f64; p];
3147 let mut fitted_row = vec![0.0_f64; p];
3148 let mut assign_buf = vec![0.0_f64; k_atoms];
3149 let mut total = 0.0_f64;
3150 for row in 0..n {
3151 total += row_data_fit(row, &mut g_buf, &mut fitted_row, &mut assign_buf)?;
3152 }
3153 total
3154 };
3155 let assignment_sparsity = assignment_prior_value(&self.assignment, rho);
3156 let smoothness = penalty_scale * self.decoder_smoothness_value(&rho.lambda_smooth_vec());
3157 let ard = self.ard_value(rho)?;
3158 Ok(SaeManifoldLoss {
3159 data_fit,
3160 assignment_sparsity,
3161 smoothness,
3162 ard,
3163 evidence_gauge_deflated_directions: 0,
3164 })
3165 }
3166
3167 /// Reconstruction data-fit `0.5·Σ_i w_i·‖whiten(Z_i − R_i)‖²` for an EXPLICIT
3168 /// reconstruction matrix `R` (e.g. the hard top-k–projected `fitted`), using
3169 /// the SAME per-row metric and design-honesty weights as [`Self::loss_scaled`]
3170 /// (the soft-assignment data-fit). The only difference is the residual source:
3171 /// `loss_scaled` decodes the soft assignments on the fly, this consumes a
3172 /// reconstruction the caller already assembled (so the projected loss and the
3173 /// returned projected `fitted` describe one and the same model). The penalty
3174 /// terms (`assignment_sparsity`/`smoothness`/`ard`) are decoder/ρ properties
3175 /// the top-k gate does not change, so the caller keeps them from the soft
3176 /// `loss_scaled` and only swaps this data-fit in — see #1232.
3177 pub fn data_fit_for_reconstruction(
3178 &self,
3179 target: ArrayView2<'_, f64>,
3180 reconstruction: ArrayView2<'_, f64>,
3181 ) -> Result<f64, String> {
3182 let n = self.n_obs();
3183 let p = self.output_dim();
3184 if target.dim() != (n, p) {
3185 return Err(format!(
3186 "SaeManifoldTerm::data_fit_for_reconstruction: Z must be ({n}, {p}); got {:?}",
3187 target.dim()
3188 ));
3189 }
3190 if reconstruction.dim() != (n, p) {
3191 return Err(format!(
3192 "SaeManifoldTerm::data_fit_for_reconstruction: reconstruction must be ({n}, {p}); got {:?}",
3193 reconstruction.dim()
3194 ));
3195 }
3196 let whitens = self
3197 .row_metric
3198 .as_ref()
3199 .is_some_and(|metric| metric.whitens_likelihood());
3200 let row_loss_w = self.row_loss_weights.as_deref();
3201 let mut resid = vec![0.0_f64; p];
3202 let mut total = 0.0_f64;
3203 for row in 0..n {
3204 for out_col in 0..p {
3205 resid[out_col] = target[[row, out_col]] - reconstruction[[row, out_col]];
3206 }
3207 let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3208 match self.row_metric.as_ref() {
3209 Some(metric) if whitens => {
3210 let r = ArrayView1::from(&resid[..p]);
3211 for w in metric.whiten_residual_row(row, r) {
3212 total += 0.5 * w_row * w * w;
3213 }
3214 }
3215 _ => {
3216 for &r in resid[..p].iter() {
3217 total += 0.5 * w_row * r * r;
3218 }
3219 }
3220 }
3221 }
3222 Ok(total)
3223 }
3224
    /// Total analytic-penalty energy of `registry` at the current decoder `β`
    /// and SAE coordinates. `penalized_objective_total` adds this term to the
    /// line-search objective when a registry is supplied.
    ///
    /// Each penalty is evaluated with its registry-local ρ slice, taken from an
    /// all-zeros global ρ vector. Scaling by `penalty_scale` depends on the tier:
    ///
    /// - Psi-tier nuclear-norm penalties and every β-tier penalty are multiplied
    ///   by it.
    /// - The per-atom coordinate (Psi-tier) penalties are not scaled.
    /// - Rho-tier entries contribute nothing.
    ///
    /// The registry ARD penalty is skipped; the in-body comment explains why.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowSchurError::SchurFactorFailed`] when `penalty_scale` is not
    /// finite and positive, or when a non-nuclear-norm Psi-tier penalty is not
    /// row-block supported. Errors from `corrected_isometry_penalty` are
    /// propagated.
    pub fn analytic_penalty_value_total(
        &self,
        registry: &AnalyticPenaltyRegistry,
        penalty_scale: f64,
    ) -> Result<f64, ArrowSchurError> {
        // `penalty_scale` multiplies every β-tier / nuclear-norm energy below;
        // reject a non-finite or non-positive value up front.
        if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
            return Err(ArrowSchurError::SchurFactorFailed {
                reason: format!(
                    "SaeManifoldTerm::analytic_penalty_value_total: penalty_scale must be finite \
                     and positive; got {penalty_scale}"
                ),
            });
        }
        // All penalties are evaluated at registry-local ρ = 0.
        let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
        let layout = registry.rho_layout();
        let beta = self.flatten_beta();
        let mut value = 0.0_f64;
        for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
            let rho_local = rho_global.slice(s![rho_slice.clone()]);
            // Skip the registry `ARDPenalty` here for the same reason it is
            // skipped in `add_sae_analytic_penalty_contributions`: the coordinate
            // ARD energy is already counted by `loss.ard` (the von-Mises
            // `ard_value`), and the registry penalty's legacy Gaussian `½λt²` is
            // period-discontinuous. Including it would double-count the energy and
            // make this line-search objective jump across the branch cut while the
            // assembled gradient (von-Mises only, after the assembly fix) stays
            // continuous — i.e. a near-zero step would change the objective by a
            // finite amount and Armijo would wrongly reject it.
            if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
                continue;
            }
            match tier {
                PenaltyTier::Psi => {
                    if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
                        // Nuclear norm acts on each atom's β sub-range and is
                        // scaled like the β tier.
                        for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
                            value += penalty_scale
                                * per_atom.value(beta.slice(s![start..end]), rho_local);
                        }
                    } else {
                        // Invariant: registry validation only admits row-block
                        // Psi-tier penalties; anything else is an internal error.
                        if !sae_penalty_is_row_block_supported(penalty) {
                            return Err(ArrowSchurError::SchurFactorFailed {
                                reason: format!(
                                    "validate_analytic_penalty_registry should have refused \
                                     non-row-block Psi-tier penalty {:?} (registry layout name \
                                     {name:?})",
                                    penalty.name()
                                ),
                            });
                        }
                        // Per-atom coordinate penalties are NOT multiplied by
                        // `penalty_scale`.
                        for atom_idx in 0..self.k_atoms() {
                            let coord = &self.assignment.coords[atom_idx];
                            if let AnalyticPenaltyKind::Isometry(iso) = penalty {
                                // Isometry is rebuilt against the live
                                // decoder/coordinate state before evaluation.
                                let corrected_kind =
                                    self.corrected_isometry_penalty(iso, atom_idx, coord)?;
                                value += corrected_kind.value(coord.as_flat().view(), rho_local);
                            } else if sae_coord_penalty_is_origin_anchored_magnitude(penalty) {
                                // Origin-anchored magnitude shrinkage (SCAD/MCP) is
                                // restricted to the Euclidean axes; periodic axes have
                                // no chart origin and would make this energy
                                // period-discontinuous (issue #795). This must mirror
                                // the gradient/curvature assembly in
                                // `add_sae_coord_penalty` exactly.
                                match sae_coord_penalty_euclidean_restriction(coord) {
                                    Some((_axes, compacted)) => {
                                        value += penalty.value(compacted.view(), rho_local);
                                    }
                                    None => {
                                        value += penalty.value(coord.as_flat().view(), rho_local);
                                    }
                                }
                            } else {
                                value += penalty.value(coord.as_flat().view(), rho_local);
                            }
                        }
                    }
                }
                PenaltyTier::Beta => {
                    // Every β-tier energy is multiplied by `penalty_scale`. The
                    // first two kinds need their live per-fit / per-atom form
                    // before evaluation.
                    if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
                        if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
                            value += penalty_scale * per_fit.value(beta.view(), rho_local);
                        }
                    } else if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
                        // Empty β ranges (`start == end`) contribute nothing.
                        for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
                            if start < end {
                                value += penalty_scale * per_atom.value(beta.view(), rho_local);
                            }
                        }
                    } else {
                        value += penalty_scale * penalty.value(beta.view(), rho_local);
                    }
                }
                // ρ-tier penalties carry no β/coordinate energy here.
                PenaltyTier::Rho => {}
            }
        }
        Ok(value)
    }
3321
3322 /// Energy of the decoder-block analytic penalties that have no native
3323 /// `SaeManifoldLoss` counterpart, evaluated at the current decoder `β` and
3324 /// the converged SAE state. These act on the per-atom decoder coefficient
3325 /// matrices: cross-atom decoder incoherence (#671), mechanism
3326 /// (feature-group) sparsity, and nuclear-norm embedding rank (#672). Each
3327 /// is injected with its live per-atom shape / co-activation before its
3328 /// value is taken, mirroring the assemble path.
3329 ///
3330 /// This is deliberately narrower than [`Self::analytic_penalty_value_total`]:
3331 /// it excludes the Psi-tier coordinate / assignment penalties (ARD,
3332 /// Isometry, ScadMcp, BlockOrthogonality, IBP/softmax assignment sparsity).
3333 /// The SAE already carries its own ARD (`loss.ard`) and assignment sparsity
3334 /// (`loss.assignment_sparsity`) energy, so adding the registry ARD /
3335 /// assignment value on top would double-count, and the gauge-only
3336 /// coordinate penalties are not part of the penalized deviance the
3337 /// REML/Laplace criterion scores. The decoder-block penalties, by contrast,
3338 /// are real penalized-energy terms with no `loss.*` representative: the
3339 /// inner solve minimizes them (they enter `gb`/`hbb`) but they were absent
3340 /// from the criterion scalar `v`. This restores that consistency so the
3341 /// ρ-sweep ranks the same objective the inner solve descends — the #671
3342 /// incoherence lever in particular now shapes model selection, not just the
3343 /// Newton step.
3344 ///
3345 /// NOTE: the coordinate-block penalties with no native `loss.*` twin
3346 /// (`ScadMcp`, `BlockOrthogonality`) carry the same residual inconsistency
3347 /// (scored in the line search via `penalized_objective_total`, absent from
3348 /// the REML scalar). They are left out here because they share a registry
3349 /// dispatch with the always-on `Isometry` gauge, whose inclusion in the
3350 /// topology-comparison criterion is a separate design question (#673:
3351 /// topology evidence is gauge-conditional). Folding the coord-tier energy in
3352 /// is tracked apart from this #671 decoder fix.
3353 pub fn analytic_decoder_penalty_value_total(
3354 &self,
3355 registry: &AnalyticPenaltyRegistry,
3356 ) -> Result<f64, ArrowSchurError> {
3357 // Resolve each penalty's rho slice exactly as `analytic_penalty_value_total`
3358 // does (registry-local rho at zeros), so a learnable decoder-penalty weight
3359 // is honoured rather than indexing into an empty view.
3360 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3361 let layout = registry.rho_layout();
3362 let beta = self.flatten_beta();
3363 let mut value = 0.0_f64;
3364 for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3365 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3366 match penalty {
3367 AnalyticPenaltyKind::DecoderIncoherence(base) => {
3368 if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3369 value += per_fit.value(beta.view(), rho_local);
3370 }
3371 }
3372 AnalyticPenaltyKind::MechanismSparsity(base) => {
3373 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3374 if start < end {
3375 value += per_atom.value(beta.view(), rho_local);
3376 }
3377 }
3378 }
3379 AnalyticPenaltyKind::NuclearNorm(base) => {
3380 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3381 value += per_atom.value(beta.slice(s![start..end]), rho_local);
3382 }
3383 }
3384 _ => {}
3385 }
3386 }
3387 Ok(value)
3388 }
3389
3390 /// Energy of the COORDINATE-tier isometry penalty(ies) at the converged
3391 /// SAE state. This is the per-atom `½μ Σ_n ‖J_n^T W_n J_n / gbar − g_ref‖²`
3392 /// summed over atoms, evaluated through `corrected_isometry_penalty` so the
3393 /// live decoder/coordinate caches drive the value exactly as the assemble
3394 /// path does. It has no `SaeManifoldLoss` twin (the loss carries only
3395 /// data-fit / assignment / smoothness / ARD), so the Laplace/REML criterion
3396 /// must add it explicitly to score the same penalized objective the inner
3397 /// solve descends.
3398 pub fn isometry_penalty_value_total(
3399 &self,
3400 registry: &AnalyticPenaltyRegistry,
3401 ) -> Result<f64, ArrowSchurError> {
3402 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3403 let layout = registry.rho_layout();
3404 let mut value = 0.0_f64;
3405 for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3406 if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3407 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3408 for atom_idx in 0..self.k_atoms() {
3409 let coord = &self.assignment.coords[atom_idx];
3410 let corrected_kind = self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3411 value += corrected_kind.value(coord.as_flat().view(), rho_local);
3412 }
3413 }
3414 }
3415 Ok(value)
3416 }
3417
3418 /// Whether assembling `registry` will scatter an isometry Gauss-Newton
3419 /// cross-block (`H_tβ`) into the per-row dense `htbeta` slabs.
3420 ///
3421 /// `add_sae_isometry_metric_gn_blocks` writes the coupled cross-block (and
3422 /// flips on `activate_dense_htbeta_supplement`) only when (a) the registry
3423 /// carries an `Isometry` penalty and (b) the atom's chart
3424 /// `preserves_isometry_cross_block_coherence` (flat charts — `Euclidean`,
3425 /// `Circle`, and flat products — keep the full `μ AᵀA` coupling; curved /
3426 /// boundary charts drop it to stay PSD). On the non-frames matrix-free path
3427 /// the data-fit cross-block is carried by the Kronecker row operator and the
3428 /// per-row `htbeta` slab is allocated at zero width (#1406/#1407 anti-leak),
3429 /// so this dense isometry supplement has nowhere to land unless the slab is
3430 /// widened to the full `beta_dim`. This predicate decides exactly that. The
3431 /// effective isometry weight `μ` is NOT consulted here: a near-zero `μ`
3432 /// short-circuits the per-row write, but the slab must still exist so the
3433 /// solver's `htbeta_dense_supplement` read is well-shaped.
3434 pub(crate) fn registry_writes_dense_isometry_cross_block(
3435 &self,
3436 registry: &AnalyticPenaltyRegistry,
3437 ) -> bool {
3438 registry
3439 .penalties
3440 .iter()
3441 .any(|p| matches!(p, AnalyticPenaltyKind::Isometry(_)))
3442 && self
3443 .assignment
3444 .coords
3445 .iter()
3446 .any(|coord| coord.manifold().preserves_isometry_cross_block_coherence())
3447 }
3448
3449 /// Extra analytic-penalty energy that has no native `SaeManifoldLoss`
3450 /// component but is part of the penalized objective ranked by the SAE
3451 /// Laplace/REML criterion.
3452 pub fn reml_extra_penalty_value_total(
3453 &self,
3454 registry: &AnalyticPenaltyRegistry,
3455 ) -> Result<f64, ArrowSchurError> {
3456 Ok(self.analytic_decoder_penalty_value_total(registry)?
3457 + self.isometry_penalty_value_total(registry)?)
3458 }
3459
3460 pub fn penalized_objective_total(
3461 &self,
3462 target: ArrayView2<'_, f64>,
3463 rho: &SaeManifoldRho,
3464 registry: Option<&AnalyticPenaltyRegistry>,
3465 penalty_scale: f64,
3466 ) -> Result<f64, String> {
3467 let mut total = self.loss_scaled(target, rho, penalty_scale)?.total();
3468 if let Some(analytic_registry) = registry {
3469 total += self
3470 .analytic_penalty_value_total(analytic_registry, penalty_scale)
3471 .map_err(|err| format!("SaeManifoldTerm::penalized_objective_total: {err}"))?;
3472 }
3473 // #1026 — decoder-repulsion value, on the SAME frozen gate the assembly
3474 // used, so the line search sees the term the Newton step optimizes. 0
3475 // unless two atoms are near-collinear (the no-op case).
3476 total += self.decoder_repulsion_value(penalty_scale);
3477 // #1026/#1522 — interior-point collapse-prevention barriers, on the SAME
3478 // decoders the assembly's gradient/curvature used, so the line search sees
3479 // exactly the term the inner Newton step optimises (no value/grad desync).
3480 total += self.separation_barrier_value(penalty_scale);
3481 Ok(total)
3482 }
3483
3484 pub(crate) fn decoder_smoothness_value(&self, lambda_smooth: &[f64]) -> f64 {
3485 // Smoothness penalty value is `0.5·λ·Σ_oc B[:,oc]ᵀ S B[:,oc]`. Form the
3486 // `S·B` matrix product once per atom (O(M²·p)) and reduce against `B`
3487 // with a single O(M·p) Hadamard sum, instead of the previous
3488 // four-factor multiply-accumulate inside an `O(M²·p)` triple loop.
3489 // The quadratic form only sees the symmetric part of `S`, so reusing
3490 // the raw (un-symmetrised) `smooth_penalty` here is numerically
3491 // identical to the symmetrised assembly form.
3492 // Per-atom `S_k · B_k` products are independent across atoms, so they ride
3493 // the multi-GPU batched smoothness GEMM (uniform-shape groups tiled across
3494 // every device); `symmetrize = false` because the quadratic form only sees
3495 // the symmetric part of `S` regardless. Exact CPU fallback per atom.
3496 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3497 .atoms
3498 .iter()
3499 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3500 .collect();
3501 let sb_all = batched_smooth_sb(&sb_inputs, false);
3502 let mut acc = 0.0;
3503 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3504 acc += 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3505 }
3506 acc
3507 }
3508
3509 /// Per-atom decoder-smoothness values (#1556): entry `k` is
3510 /// `0.5·λ_smooth[k]·<B_k, S_k B_k>` (sum = [`Self::decoder_smoothness_value`]).
3511 /// This is the explicit `∂loss.smoothness/∂log λ_smooth[k]` gradient entry.
3512 pub(crate) fn decoder_smoothness_value_per_atom(&self, lambda_smooth: &[f64]) -> Vec<f64> {
3513 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3514 .atoms
3515 .iter()
3516 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3517 .collect();
3518 let sb_all = batched_smooth_sb(&sb_inputs, false);
3519 let mut per_atom = vec![0.0_f64; self.atoms.len()];
3520 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3521 per_atom[atom_idx] =
3522 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3523 }
3524 per_atom
3525 }
3526
    /// ARD coordinate-prior energy, summed over atoms and latent axes and
    /// including each axis' precision normaliser.
    ///
    /// For atom `k`, axis `a`, let `log_alpha = rho.log_ard[k][a]` and
    /// `alpha = SaeManifoldRho::stable_exp_strength(log_alpha)`. The data term
    /// is `Σ_rows ArdAxisPrior::eval(alpha, t_row[a], period).value`. The
    /// normaliser added to it depends on the axis:
    ///
    /// - Euclidean axes (`period == None`) add `-0.5·n·log_alpha`.
    /// - Periodic axes add `n·(−η + log I0(η))`, with `η = alpha/κ²` and
    ///   `κ = 2π/P`.
    ///
    /// Atoms with an empty `log_ard` entry are skipped (ARD disabled).
    ///
    /// # Errors
    ///
    /// Returns `Err` when `rho.log_ard` does not hold one entry per atom, or
    /// when a non-empty per-atom entry's length differs from that atom's
    /// latent dimension.
    pub(crate) fn ard_value(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
        if rho.log_ard.len() != self.k_atoms() {
            return Err(format!(
                "ARD rho has {} atoms but term has {}",
                rho.log_ard.len(),
                self.k_atoms()
            ));
        }
        let n = self.n_obs();
        let mut acc = 0.0;
        for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
            let d = coord.latent_dim();
            // Empty per-atom log_ard ⇒ ARD disabled for this atom.
            if rho.log_ard[atom_idx].is_empty() {
                continue;
            }
            if rho.log_ard[atom_idx].len() != d {
                return Err(format!(
                    "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
                    rho.log_ard[atom_idx].len()
                ));
            }
            // Per-axis periodicity selects the smooth von-Mises energy on
            // wrapped (Circle) axes and the Gaussian on Euclidean axes.
            let periods = coord.effective_axis_periods();
            for axis in 0..d {
                let log_alpha = rho.log_ard[atom_idx][axis];
                // Clamp the log-precision before exponentiating: a raw
                // `exp(log_ard)` overflows to `inf` for `log_ard ≳ 709`, and the
                // `inf` precision then poisons the ARD energy / curvature with
                // `inf · 0.0 = NaN` (#742, Issue 4).
                let alpha = SaeManifoldRho::stable_exp_strength(log_alpha);
                let period = periods[axis];
                // Data term: the per-row coordinate prior energy on this axis.
                let mut energy = 0.0;
                for row in 0..n {
                    let v = coord.row(row)[axis];
                    energy += ArdAxisPrior::eval(alpha, v, period).value;
                }
                // Negative-log prior for precision alpha. The data-dependent
                // energy is the (Gaussian or von-Mises) coordinate prior; the
                // accompanying normaliser is the precision log-partition.
                //
                // Euclidean axes keep the Gaussian normaliser `-0.5 n log α`.
                // Periodic (von-Mises) axes use the EXACT von-Mises precision
                // log-partition `n[-η + log I0(η)]`, η = α/κ², κ = 2π/P, rather
                // than the Gaussian surrogate: the von-Mises partition function
                // is `2π I0(η)` (up to the κ Jacobian), so the per-observation
                // normaliser is `-η + log I0(η)` and is exact across the cut.
                match period {
                    None => {
                        // NOTE(review): this uses the raw `log_alpha`, while the
                        // energy above uses the clamped `alpha`. If
                        // `stable_exp_strength` clamps, the two disagree beyond
                        // the clamp. Confirm this matches the ARD gradient in
                        // the assembly path before changing it.
                        acc += energy - 0.5 * (n as f64) * log_alpha;
                    }
                    Some(p) => {
                        let kappa = std::f64::consts::TAU / p;
                        let eta = alpha / (kappa * kappa);
                        // Overflow-free `log I0(η)`; `bessel_i0(η).ln()` would be
                        // `+inf` for `η ≳ 709` (#1113).
                        let log_i0 = bessel_i0_log_and_ratio(eta).0;
                        acc += energy + (n as f64) * (-eta + log_i0);
                    }
                }
            }
        }
        Ok(acc)
    }
3591
3592 /// Assemble the enlarged `(logits, t)` row-local Arrow-Schur system.
3593 ///
3594 /// Full-batch entry point: a single chunk covering all rows, with the
3595 /// β-tier penalties (decoder smoothness, ARD, analytic β penalties) carrying
3596 /// their full strength. The streaming driver calls
3597 /// [`Self::assemble_arrow_schur_scaled`] directly with a `penalty_scale`
3598 /// equal to the minibatch fraction `n_chunk / N`, so that the sum of the
3599 /// per-chunk β-tier contributions over a full pass reconstructs exactly the
3600 /// single global β penalty (the smoothness/ARD/β terms are functions of `B`
3601 /// and the global coordinates, not of the chunk's rows).
3602 pub fn assemble_arrow_schur(
3603 &mut self,
3604 target: ArrayView2<'_, f64>,
3605 rho: &SaeManifoldRho,
3606 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3607 ) -> Result<ArrowSchurSystem, String> {
3608 self.assemble_arrow_schur_scaled(target, rho, analytic_penalties, 1.0)
3609 }
3610
3611 /// Assemble the row-local Arrow-Schur system with a `penalty_scale` applied
3612 /// to the β-tier (decoder smoothness, ARD prior, analytic β penalties).
3613 ///
3614 /// `penalty_scale == 1.0` recovers the full-batch assembly. The streaming
3615 /// driver passes the minibatch fraction `n_chunk / N` so that the β-tier
3616 /// reduced-Schur and gradient contributions of the chunks sum to exactly one
3617 /// global copy across a full pass (data-fit, assignment-prior, and per-row
3618 /// coord/logit analytic terms are *not* scaled — they are genuine per-row
3619 /// sums).
3620 pub fn assemble_arrow_schur_scaled(
3621 &mut self,
3622 target: ArrayView2<'_, f64>,
3623 rho: &SaeManifoldRho,
3624 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3625 penalty_scale: f64,
3626 ) -> Result<ArrowSchurSystem, String> {
3627 self.assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3628 target,
3629 rho,
3630 analytic_penalties,
3631 penalty_scale,
3632 SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM,
3633 )
3634 }
3635
3636 pub(crate) fn assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3637 &mut self,
3638 target: ArrayView2<'_, f64>,
3639 rho: &SaeManifoldRho,
3640 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3641 penalty_scale: f64,
3642 dense_beta_penalty_probe_max_dim: usize,
3643 ) -> Result<ArrowSchurSystem, String> {
3644 self.assemble_arrow_schur_inner(
3645 target,
3646 rho,
3647 analytic_penalties,
3648 penalty_scale,
3649 dense_beta_penalty_probe_max_dim,
3650 None,
3651 )
3652 }
3653
3654 /// Innermost assembly entry. `forced_layout` overrides the budget-derived
3655 /// active-set layout so a caller can pin the dense (`Forced(None)`) or a
3656 /// specific compact (`Forced(Some(layout))`) path — used by the
3657 /// compact-vs-dense Riemannian-geometry equality regression test to drive
3658 /// both layouts on identical data. `Computed` is the production path:
3659 /// the layout is derived from the assignment mode + `sparse_active_plan`.
3660 pub(crate) fn assemble_arrow_schur_inner(
3661 &mut self,
3662 target: ArrayView2<'_, f64>,
3663 rho: &SaeManifoldRho,
3664 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3665 penalty_scale: f64,
3666 dense_beta_penalty_probe_max_dim: usize,
3667 forced_layout: ForcedRowLayout,
3668 ) -> Result<ArrowSchurSystem, String> {
3669 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3670 return Err(format!(
3671 "SaeManifoldTerm::assemble_arrow_schur_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3672 ));
3673 }
3674 if target.dim() != (self.n_obs(), self.output_dim()) {
3675 return Err(format!(
3676 "SaeManifoldTerm::assemble_arrow_schur: Z must be ({}, {}); got {:?}",
3677 self.n_obs(),
3678 self.output_dim(),
3679 target.dim()
3680 ));
3681 }
3682 if rho.log_ard.len() != self.k_atoms() {
3683 return Err(format!(
3684 "SaeManifoldTerm::assemble_arrow_schur: log_ard length {} != K {}",
3685 rho.log_ard.len(),
3686 self.k_atoms()
3687 ));
3688 }
3689 // `lambda_smooth` is indexed per-atom in the smoothness gradient/curvature
3690 // assembly (`lambda_smooth[atom_idx]`); a too-short vector (e.g. a growth
3691 // move that grew `k_atoms()` without extending ρ — #1556) would panic deep
3692 // in the assembly loop with an opaque index-out-of-bounds. Validate it here
3693 // alongside `log_ard` so the contract violation surfaces as a clear Err.
3694 if rho.log_lambda_smooth.len() != self.k_atoms() {
3695 return Err(format!(
3696 "SaeManifoldTerm::assemble_arrow_schur: log_lambda_smooth length {} != K {}",
3697 rho.log_lambda_smooth.len(),
3698 self.k_atoms()
3699 ));
3700 }
3701 for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3702 let ard_len = rho.log_ard[atom_idx].len();
3703 let d = coord.latent_dim();
3704 if ard_len != 0 && ard_len != d {
3705 return Err(format!(
3706 "SaeManifoldTerm::assemble_arrow_schur: log_ard atom {atom_idx} \
3707 has len {ard_len}; expected 0 (disabled) or atom dim {d}"
3708 ));
3709 }
3710 }
3711 // Reparameterize each atom's roughness Gram into arc length at the
3712 // current decoder/coordinates (issue #673). This is the single
3713 // chokepoint for both the inner Newton assembly and the undamped
3714 // evidence factorization, so freezing the pullback-metric weight here
3715 // (lagged-diffusivity) keeps the smoothness value, gradient, Kronecker
3716 // Hessian, and REML log-det mutually consistent within each assembly
3717 // and makes the converged penalty — hence the topology evidence —
3718 // gauge-invariant. Constant-speed (periodic) atoms are unaffected.
3719 for atom in &mut self.atoms {
3720 atom.refresh_intrinsic_smooth_penalty();
3721 }
3722 // #1026 — freeze the decoder-repulsion collinearity gate at the SAME
3723 // assembly chokepoint as the smoothness Gram, so the repulsion's
3724 // gradient/curvature (assembled below) and its value (read by the
3725 // line-search `penalized_objective_total`) share one frozen gate.
3726 self.refresh_decoder_repulsion_gate();
3727 // #1625 — freeze the SEPARATION barrier's normalized-coactivation `q_jk`
3728 // at the same chokepoint. The barrier weights its decoder-shape repulsion
3729 // by the routing coactivation, but its gradient treats that weight as a
3730 // constant; recomputing it from the trial logits in the line-search value
3731 // desyncs value vs gradient in the logit block and stalls the inner solve
3732 // (#1625). Freezing it here makes value/gradient/curvature consistent.
3733 self.refresh_barrier_coactivation_gate();
3734 let n = self.n_obs();
3735 let p = self.output_dim();
3736 let k_atoms = self.k_atoms();
3737 let assignment_dim = self.assignment.assignment_coord_dim();
3738 let q = self.assignment.row_block_dim();
3739 let beta_dim = self.beta_dim();
3740 let frame_projection = FrameProjection::new(self);
3741 let beta_offsets = frame_projection.beta_offsets.clone();
3742 let coord_offsets = self.assignment.coord_offsets();
3743 // β-tier decoder smoothness is a global (B-only) penalty; under a
3744 // minibatch pass it is scaled by the chunk fraction so the per-chunk
3745 // contributions sum to one global copy.
3746 // Per-atom decoder-smoothness strengths (#1556): atom k's penalty `S_k`
3747 // is scaled by `λ_smooth[k]·penalty_scale`. The minibatch `penalty_scale`
3748 // multiplies every atom uniformly.
3749 let lambda_smooth: Vec<f64> = rho
3750 .lambda_smooth_vec()
3751 .iter()
3752 .map(|&l| l * penalty_scale)
3753 .collect();
3754 let (assignment_grad, assignment_hdiag) =
3755 assignment_prior_grad_hdiag(&self.assignment, rho)?;
3756
3757 // #1038 softmax entropy: the exact per-row Hessian in logits is dense
3758 // (`H_kj = (λ/τ²) a_k[δ_kj(m−L_k−1)+a_j(L_k+L_j+1−2m)]`), not just the
3759 // `assignment_hdiag` diagonal. Build the shared penalty + `scale = λ/τ²`
3760 // once here so the dense row block written into `block.htt` below, the
3761 // criterion's `log|H|`, and the #1006 θ-adjoint all differentiate the
3762 // SAME operator. JumpReLU / IBP keep their (separately exact) diagonal /
3763 // cross-row channels and leave this `None`. The block is gauge-null in
3764 // isolation (`H·𝟙 = 0`); it is only ever summed onto the gauge-breaking
3765 // data-fit row block before the Cholesky factor, never factored alone.
3766 let softmax_dense: Option<(
3767 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
3768 f64,
3769 )> = match self.assignment.mode {
3770 AssignmentMode::Softmax {
3771 temperature,
3772 sparsity,
3773 } if k_atoms > 1 => {
3774 let inv_tau = 1.0 / temperature;
3775 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
3776 Some((
3777 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
3778 k_atoms,
3779 temperature,
3780 ),
3781 scale,
3782 ))
3783 }
3784 _ => None,
3785 };
3786
3787 // Decoder smoothness penalty: build one KroneckerPenaltyOp per atom
3788 // (structure = λ·S_k ⊗ I_p, offset = beta_offsets[k]) instead of
3789 // materialising the dense K×K block. The gradient is a dense K-vector
3790 // accumulated into `smooth_grad_gb` and written into sys.gb after sys
3791 // is constructed (#296).
3792 let mut smooth_ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len());
3793 // #972 / #977 T1: retain each atom's symmetrised `λ S_k` (`M_k × M_k`) so
3794 // the frame transform can rebuild the smooth penalty in the factored
3795 // coordinate space as `λ S_k ⊗ I_{r_k}` (the `tr(C_kᵀ S_k C_k)` form,
3796 // using `U_kᵀU_k = I`). Unused — and not even read — on the full-`B`
3797 // path, so this is a zero-cost capture there.
3798 let mut smooth_scaled_s: Vec<Array2<f64>> = Vec::with_capacity(self.atoms.len());
3799 let mut smooth_grad_gb = vec![0.0_f64; beta_dim];
3800 // #1117 — rank deficiency is handled at the basis layer: any
3801 // rank-deficient atom was reparametrized onto its data-supported subspace
3802 // at fit entry (`reduce_atoms_to_data_supported_rank`), so the β-tier here
3803 // always sees a full-rank design and needs no step-time data-null
3804 // deflation operator. The well-conditioned (full-rank) path is unchanged.
3805 // Per-atom smoothness-gradient GEMMs `½(S_k+S_kᵀ)·B_k` are independent
3806 // across atoms; batch them across ALL GPUs (uniform-shape tiles) and
3807 // scale by `lambda_smooth` below. `symmetrize = true` reproduces the
3808 // per-atom symmetrised `scaled_s/λ` used by the Kronecker op. Exact CPU
3809 // fallback per atom keeps the result bit-for-bit with the all-CPU path.
3810 let sym_sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3811 .atoms
3812 .iter()
3813 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3814 .collect();
3815 let sym_sb_all = batched_smooth_sb(&sym_sb_inputs, true);
3816 for (atom_idx, atom) in self.atoms.iter().enumerate() {
3817 let m = atom.basis_size();
3818 let off = beta_offsets[atom_idx];
3819 // Symmetrise and scale the smoothness penalty matrix.
3820 let mut scaled_s = Array2::<f64>::zeros((m, m));
3821 for i in 0..m {
3822 for j in 0..m {
3823 let s_ij = 0.5 * (atom.smooth_penalty[[i, j]] + atom.smooth_penalty[[j, i]]);
3824 scaled_s[[i, j]] = lambda_smooth[atom_idx] * s_ij;
3825 }
3826 }
3827 // Gradient: g[beta_i] += (λ_k S_k B_k)[i, out_col]. The (m×m)·(m×p)
3828 // GEMM `½(S+Sᵀ)·B_k` was computed in the multi-GPU batch above; here
3829 // we only apply atom k's `lambda_smooth[atom_idx]`.
3830 let sb = &sym_sb_all[atom_idx] * lambda_smooth[atom_idx];
3831 for out_col in 0..p {
3832 for i in 0..m {
3833 let beta_i = off + i * p + out_col;
3834 smooth_grad_gb[beta_i] += sb[[i, out_col]];
3835 }
3836 }
3837 // IdentityRightKroneckerPenaltyOp: factor_a = λ·S_k (m×m), factor_b = I_p.
3838 smooth_ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
3839 factor_a: scaled_s.clone(),
3840 p,
3841 global_offset: off,
3842 k: beta_dim,
3843 }));
3844 // Retain `λ S_k` for the factored rebuild (no-op cost on full-`B`).
3845 smooth_scaled_s.push(scaled_s);
3846 }
3847
// Per-row active-set layout (a caller-supplied `forced_layout` takes
// precedence). Otherwise it is engaged for three regimes:
//   * JumpReLU — structural gate plus the smooth prior's
//     machine-precision support: atoms with
//     `(logit - threshold)/tau > -36` enter the compact solve
//     ([`jumprelu_in_optimization_band`]). Strictly gated-off atoms
//     (logit ≤ threshold) carry zero assignment mass so their data-fit
//     reconstruction contribution and data-fit logit JVP are zero, but
//     supported atoms keep value-consistent prior gradient in the row block.
//   * Softmax with an active cap — each row keeps its top-`k_active_cap`
//     softmax atoms above a relative cutoff
//     ([`Self::softmax_active_plan`]); an uncapped plan returns `None`
//     and keeps the full-`K` layout.
//   * IBP-MAP at large `K` — the dense `(m_total · p)²` data
//     Gram is infeasible, so each row is truncated to its
//     top-`k_active` atoms above a relative magnitude cutoff
//     ([`Self::sparse_active_plan`]). Small-`K` problems return `None`
//     and keep the exact full-support layout.
// The compact row block is sized `q_active = |active| + Σ_{k∈active} d_k`
// instead of the full `q`.
3863 let coord_dims: Vec<usize> = self
3864 .assignment
3865 .coords
3866 .iter()
3867 .map(|c| c.latent_dim())
3868 .collect();
3869 let row_layout: Option<SaeRowLayout> = match forced_layout {
3870 Some(layout) => layout,
3871 None => match self.assignment.mode {
3872 AssignmentMode::JumpReLU {
3873 threshold,
3874 temperature,
3875 } => Some(SaeRowLayout::from_jumprelu(
3876 n,
3877 k_atoms,
3878 threshold,
3879 temperature,
3880 &self.assignment.logits,
3881 coord_dims.clone(),
3882 self.assignment.coord_offsets(),
3883 )),
3884 // #1408/#1409 — Softmax engages the COMPACT top-`k` row layout
3885 // inside the optimization (no longer a post-fit projection).
3886 // The active set is each row's top-`k_active_cap` softmax atoms
3887 // above the relative cutoff; the cap comes from the user's
3888 // `top_k` (`softmax_active_cap`) and/or the in-core memory budget
3889 // ([`Self::softmax_active_plan`]). The full-`K` softmax
3890 // normalization still forms `a` (the gate map); only the dropped
3891 // tail logits, carrying negligible `O(a)` reconstruction mass and
3892 // `O(a²)` curvature, leave the per-row block.
3893 //
3894 // Coherence (the load-bearing correctness invariant): the
3895 // assembly's softmax curvature branch writes the ACTIVE×ACTIVE
3896 // principal sub-block of the Gershgorin Loewner majorizer
3897 // `D = diag(Σ_j|H_kj|)` (#1419; PSD and `D ⪰ H_entropy`) on the
3898 // compact logit slots — NOT the indefinite `assignment_hdiag`
3899 // diagonal. The logdet ρ-trace
3900 // (`assignment_log_strength_hessian_trace`) iterates the row's
3901 // active logit slots and indexes that SAME majorizer by global
3902 // atom, and the θ-adjoint reads its derivative via `jets.vars`
3903 // (global-atom indexed), so value, log|H|, and Γ differentiate
3904 // ONE operator on the compact support. The FFI's after-the-fit
3905 // top-`k` projection is then a no-op at the optimum.
3906 AssignmentMode::Softmax { .. } => match self.softmax_active_plan() {
3907 Some((k_active_cap, relative_cutoff)) => {
3908 let mut assignments_all = Vec::with_capacity(n);
3909 for row in 0..n {
3910 assignments_all
3911 .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3912 }
3913 Some(SaeRowLayout::from_dense_weights(
3914 &assignments_all,
3915 k_active_cap,
3916 relative_cutoff,
3917 coord_dims.clone(),
3918 self.assignment.coord_offsets(),
3919 ))
3920 }
3921 None => None,
3922 },
3923 AssignmentMode::IBPMap { .. } => {
3924 match self.sparse_active_plan() {
3925 Some((k_active_cap, relative_cutoff)) => {
3926 // Build per-row dense assignments once to derive the
3927 // active set; the row loop re-derives `assignments`
3928 // (cheap gate map at the same rho) and reuses these
3929 // active sets.
3930 let mut assignments_all = Vec::with_capacity(n);
3931 for row in 0..n {
3932 assignments_all
3933 .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3934 }
3935 // #1414: pass the RELATIVE cutoff through;
3936 // `from_dense_weights` applies it per row against that
3937 // row's own peak `max_k |a_{n,k}|`, matching the
3938 // documented `sparse_active_plan` contract. A single
3939 // global threshold (relative_cutoff · whole-dataset
3940 // peak) wrongly drops every atom of a uniformly-small
3941 // row when another row peaks high.
3942 Some(SaeRowLayout::from_dense_weights(
3943 &assignments_all,
3944 k_active_cap,
3945 relative_cutoff,
3946 coord_dims.clone(),
3947 self.assignment.coord_offsets(),
3948 ))
3949 }
3950 None => None,
3951 }
3952 }
3953 },
3954 };
3955 // #974 likelihood-whitening seam. The single per-row decision: when the
3956 // installed `RowMetric` is a genuinely estimated noise model
3957 // (`whitens_likelihood()` — only `WhitenedStructured`), the
3958 // reconstruction data-fit, its t-block Gauss-Newton row block, AND the
3959 // β-tier data-fit gradient are all assembled through the SAME per-row
3960 // metric `M_n = U_n U_nᵀ = Σ_n^{-1}`. There is exactly ONE construction
3961 // site (the `whiten_rows` closure below), so the value the line-search
3962 // sums and the gradient/Hessian the Newton step solves cannot drift apart
3963 // (the objective↔gradient-desync cure). For Euclidean / OutputFisher /
3964 // no-metric the closure is the identity and every downstream loop is
3965 // byte-identical to the historical isotropic path.
3966 let whitens_likelihood = self
3967 .row_metric
3968 .as_ref()
3969 .is_some_and(|metric| metric.whitens_likelihood());
3970 // #972 / #977 T1: engage the FACTORED Grassmann-coordinate β-tier when
3971 // any atom has an active decoder frame. The closed-form factorization
3972 // `Φᵀ(G ⊗ I_p)Φ = G ⊗ (U_iᵀU_j)` is EXACT only for the isotropic
3973 // likelihood; under an active whitening metric (`whitens_likelihood()`,
3974 // only `WhitenedStructured`) the per-row output factor would be
3975 // `U_iᵀ M_n U_j` and does NOT factor out of the basis Gram, so we fall
3976 // back to the full-`B` path there (frames + whitening is out of scope —
3977 // see #974). The common Euclidean / OutputFisher / no-metric case factors
3978 // cleanly. When `frames_engaged` is false, EVERY β-tier object below is
3979 // assembled bit-for-bit as the historical full-`B` path.
3980 let frames_engaged = self.any_frame_active() && !whitens_likelihood;
3981 // #1407: fixed-decoder mode skips the entire β decoder tier (G/gb/htbeta
3982 // operator/hbb/β-penalties); only per-row htt/gt are produced.
3983 let fixed_decoder = self.fixed_decoder_assembly;
3984 let admission_plan = self
3985 .streaming_plan()
3986 .admitted_or_error(self.n_obs(), self.output_dim(), self.k_atoms())
3987 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
3988 // #1407: fixed-decoder builds NO dense β-Hessian (hbb) — force the
3989 // empty-hbb system constructor so no `beta_dim × beta_dim` workspace is
3990 // taken (the early return skips `reclaim_border_hbb_workspace`).
3991 let dense_beta_curvature = !fixed_decoder
3992 && admission_plan.direct_admitted
3993 && !(frames_engaged && beta_dim > dense_beta_penalty_probe_max_dim);
// #795/#1405/#1406/#1407: width of the dense per-row cross-block slab
// `block.htbeta`.
//   * Fixed-decoder mode skips the β tier entirely ⇒ zero width.
//   * Frames engaged (factored path): no matrix-free row operator is
//     installed, so the solver's `sys_htbeta_apply_row` falls back to the
//     dense slab. That slab uses the (much smaller) factored border width.
//   * Non-frames path: the data-fit `H_tβ` is carried entirely by the
//     matrix-free Kronecker row operator (`set_row_htbeta_operator`), so the
//     slab normally has ZERO width. Allocating it at `beta_dim = K·M·p` is
//     the ~6 TiB high-K leak (#1405/#1406).
//     The one exception is an ISOMETRY penalty on a coherence-preserving
//     (flat) chart. It scatters an ADDITIONAL Gauss-Newton cross-block into
//     the dense slab and turns on `activate_dense_htbeta_supplement`.
//     Dropping that block would leave the Newton system block-diagonal and
//     lose the strong `t↔B` isometry coupling the circle fit needs to reach
//     KKT stationarity (#795). So widen the slab to `beta_dim` exactly when
//     that dense supplement will be written, and keep zero width otherwise.
4015 let dense_isometry_cross_block = !fixed_decoder
4016 && analytic_penalties
4017 .map(|registry| self.registry_writes_dense_isometry_cross_block(registry))
4018 .unwrap_or(false);
4019 let row_htbeta_dim = if fixed_decoder {
4020 // Fixed-decoder mode skips the β tier entirely.
4021 0
4022 } else if frames_engaged {
4023 self.factored_border_dim()
4024 } else if dense_isometry_cross_block {
4025 // Matrix-free data-fit cross-block + dense isometry supplement: the
4026 // supplement is written/read in the full-`B` β coordinate system.
4027 beta_dim
4028 } else {
4029 // Matrix-free path with no dense cross-block supplement.
4030 0
4031 };
4032 // Build the Arrow-Schur system: heterogeneous row dims when a compact
4033 // layout is active, uniform `q` otherwise.
4034 let mut sys = if let Some(ref layout) = row_layout {
4035 let per_row_dims: Vec<usize> = (0..n).map(|row| layout.row_q_active(row)).collect();
4036 if dense_beta_curvature {
4037 let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4038 ArrowSchurSystem::new_with_per_row_dims_and_hbb_and_htbeta_cols(
4039 per_row_dims,
4040 beta_dim,
4041 hbb_workspace,
4042 row_htbeta_dim,
4043 )
4044 } else {
4045 self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4046 ArrowSchurSystem::new_with_per_row_dims_empty_hbb_and_htbeta_cols(
4047 per_row_dims,
4048 beta_dim,
4049 row_htbeta_dim,
4050 )
4051 }
4052 } else if dense_beta_curvature {
4053 let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4054 ArrowSchurSystem::new_with_hbb_and_htbeta_cols(
4055 n,
4056 q,
4057 beta_dim,
4058 hbb_workspace,
4059 row_htbeta_dim,
4060 )
4061 } else {
4062 self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4063 ArrowSchurSystem::new_with_empty_hbb_and_htbeta_cols(n, q, beta_dim, row_htbeta_dim)
4064 };
4065 // Apply accumulated smoothness-penalty gradients into sys.gb.
4066 for (i, g) in smooth_grad_gb.iter().enumerate() {
4067 sys.gb[i] += g;
4068 }
4069 // `w_dim` is the whitened output dimension: `rank` of the metric factor
4070 // when whitening, else `p` (identity). `error_white` is the whitened
4071 // residual `U_nᵀ r_n ∈ ℝ^{w_dim}` whose squared norm is `r_nᵀ M_n r_n`,
4072 // shared by the value path, the t-block GN, and (lifted back to p-space)
4073 // the β-tier gradient.
4074 let w_dim = match self.row_metric.as_ref() {
4075 Some(metric) if whitens_likelihood => metric.metric_rank(),
4076 _ => p,
4077 };
4078 // Data-fit Gauss-Newton β-Hessian is block-diagonal across the `p`
4079 // output channels and identical in each: with the flat β layout
4080 // `β[μ·p + oc] = B[μ, oc]` (μ enumerating (atom, basis_col)) the GN
4081 // outer product `Jβᵀ Jβ` couples only equal `oc`, with the same
4082 // `(M_total × M_total)` block `G[μ, μ'] = Σ_rows (a_k φ_k[m])(a_{k'} φ_{k'}[m'])`
4083 // for every channel. So `H_data = G ⊗ I_p`. The `μ` index of an `a_phi`
4084 // entry whose global β base is `beta_base` is `beta_base / p` (every
4085 // `beta_offset` and the `basis_col·p` stride are multiples of `p`).
4086 //
4087 // `G` is only non-zero on `(atom_i, atom_j)` pairs that co-occur in
4088 // some row's active set, so we accumulate it as a sparse map of dense
4089 // per-atom-pair `(m_i × m_j)` blocks keyed by `(atom_i, atom_j)` rather
4090 // than as a dense `(m_total × m_total)` matrix. At `K = 100K` with
4091 // per-row active sets of size `k_active ≪ K`, only `O(N · k_active²)`
4092 // pairs are ever touched, so the data Gram (and every matvec /
4093 // diagonal pass over it via `SparseBlockKroneckerPenaltyOp`) tracks the
4094 // active atoms instead of `K²`. In the dense full-support layout the
4095 // map degenerates to every co-occurring pair, reproducing the dense
4096 // Gram exactly. A `BTreeMap` key order keeps the installed op's
4097 // fingerprint deterministic. The `μ`-space offset of atom `k` is
4098 // `beta_offsets[k] / p`.
4099 type SaeGBlocks = std::collections::BTreeMap<(usize, usize), Array2<f64>>;
4100 let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
4101 let mu_offsets: Vec<usize> = beta_offsets.iter().map(|&off| off / p).collect();
4102 // Stick-breaking prior for IBP-MAP depends only on (k_atoms, alpha_eff)
4103 // which are constant across rows for the current rho; precompute once.
4104 let ibp_prior_vec = match self.assignment.mode {
4105 AssignmentMode::IBPMap { .. } => {
4106 let alpha = self
4107 .assignment
4108 .mode
4109 .resolved_ibp_alpha(rho)
4110 .ok_or_else(|| "IBP assignment alpha resolution failed".to_string())?;
4111 Some(ordered_geometric_shrinkage_prior(k_atoms, alpha).to_vec())
4112 }
4113 _ => None,
4114 };
4115 let ibp_prior_slice = ibp_prior_vec.as_deref();
4116 // #991 design honesty weights (mean-1 HT inclusion corrections); see
4117 // the seam comment at the per-row residual below.
4118 let row_loss_w = self.row_loss_weights.as_deref();
4119 // Dense full-support index `[0, k_atoms)`, used by the row loop when no
4120 // compact layout is engaged so the active-atom iteration is uniform.
4121 let all_atoms_index: Vec<usize> = (0..k_atoms).collect();
4122 // Per-atom per-axis periodicity, hoisted out of the row loop. Selects
4123 // the smooth von-Mises coordinate prior on wrapped (Circle) axes and
4124 // the Gaussian prior on Euclidean axes; see `ArdAxisPrior`.
4125 let ard_axis_periods: Vec<Vec<Option<f64>>> = self
4126 .assignment
4127 .coords
4128 .iter()
4129 .map(|coord| coord.effective_axis_periods())
4130 .collect();
4131 struct SaeAssemblyRow {
4132 pub(crate) row: usize,
4133 pub(crate) block: ArrowRowBlock,
4134 pub(crate) gb_delta: Vec<(usize, f64)>,
4135 pub(crate) g_blocks: SaeGBlocks,
4136 pub(crate) kron_a_phi: Option<Vec<(usize, f64)>>,
4137 pub(crate) kron_jac: Option<Vec<f64>>,
4138 }
4139
// Per-row scratch reused across all rows a rayon worker processes
// (#1017). The assembly closure is re-run every inner Newton iteration ×
// every outer ρ evaluation. These nine loop-invariant-sized buffers are:
// `decoded_rows·p`, several of length `p`, one of length `w_dim`, one of
// `scratch_q·max(w_dim,p)`, and the `k_atoms` assignment vector. We allocate
// them once per worker via `map_init`, rather than once per (row × assembly)
// inside the closure. That removes the dominant small-allocation traffic the
// eu-stack profile attributed to allocator/barrier spin at the SAE LLM shape
// (p≈5120). Every buffer is fully filled (or `.fill(0.0)`'d) before it is
// read each row, so reuse is bit-identical to the fresh-alloc path.
// `gb_delta`/`g_blocks` are NOT scratch: they move into the returned
// `SaeAssemblyRow` and stay allocated per row.
4151 struct RowScratch {
4152 pub(crate) decoded: Array2<f64>,
4153 pub(crate) dg_buf: Vec<f64>,
4154 pub(crate) fitted: Array1<f64>,
4155 pub(crate) error: Array1<f64>,
4156 pub(crate) error_white: Vec<f64>,
4157 pub(crate) error_metric: Array1<f64>,
4158 pub(crate) jac_white: Vec<f64>,
4159 pub(crate) decoded_scratch: Vec<f64>,
4160 // #1557 — per-worker scratch for the row assignment vector (filled via
4161 // `_into`, not allocated per row); full `k_atoms`, global-atom indexed.
4162 pub(crate) assignments: Array1<f64>,
4163 }
4164 // #1410: size the per-worker scratch by the COMPACT row dimensions, not
4165 // full `K`/`q`. With a compact layout the assembly only ever touches each
4166 // row's active atoms (≤ `max_active`) and its compact tangent block
4167 // (≤ `max_q_row`); allocating `decoded` at `k_atoms·p` and `jac_white` at
4168 // `q·max(w_dim,p)` was the per-worker `O(K)` blow-up (≈11 GiB/worker at
4169 // K=100k, p=5120 — and `map_init` gives every Rayon worker its own copy).
4170 // Without a layout the dense path needs full `k_atoms`/`q`. `decoded` rows
4171 // are addressed by COMPACT SLOT in the compact branch below (the dense
4172 // branch keeps global-atom rows), so the row count is the max active set.
4173 //
4174 // #1410/#1408/#1409: SOFTMAX now ALSO takes the `Some(layout)` branch
4175 // whenever a `top_k` cap (`set_softmax_active_cap`) or an in-core memory
4176 // breach engages `softmax_active_plan` → `from_dense_weights`, so its
4177 // per-worker `decoded`/`jac_white` scratch is the COMPACT
4178 // `max_active`/`max_q_row` size too — no longer the full `(k_atoms·p)` /
4179 // `(q·max(w_dim,p))` blow-up. JumpReLU / IBP-MAP likewise pay only
4180 // `max_active`. The remaining `None` (full-`K`) branch is the UNCAPPED
4181 // softmax / no-budget-breach case, which genuinely assembles the dense
4182 // entropy block over all `K`; capping it (the compact contract) removes
4183 // the per-worker `O(K)` footprint entirely. (#1410: the residual per-row
4184 // `O(K)` softmax-majorizer scratch — a `row_logits` copy and the full-`K`
4185 // `d`/`H_entropy` blocks — is removed separately; see the active-only
4186 // `active_softmax_gershgorin_majorizer_entry` /
4187 // `softmax_dense_entropy_hessian_entry` helpers below.)
4188 let (decoded_rows, scratch_q) = match row_layout.as_ref() {
4189 Some(layout) => {
4190 let max_active = (0..n)
4191 .map(|r| layout.active_atoms[r].len())
4192 .max()
4193 .unwrap_or(0)
4194 .max(1);
4195 let max_q_row = (0..n)
4196 .map(|r| layout.row_q_active(r))
4197 .max()
4198 .unwrap_or(q)
4199 .max(1);
4200 (max_active, max_q_row)
4201 }
4202 None => (k_atoms, q),
4203 };
4204 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4205 // #1033 large-n: fold the per-row assembly results in row-ordered CHUNKS
4206 // rather than collecting all `n` `SaeAssemblyRow`s at once. The previous
4207 // path materialized the FULL `Vec<SaeAssemblyRow>` (every row's htt/gt
4208 // block + per-row `g_blocks` + `kron_a_phi`/`kron_jac`) AND the fold
4209 // destinations simultaneously — a ~2× transient peak over the resident
4210 // system during the fold, the assembly-side OOM cliff at large `n`. By
4211 // collecting one chunk, folding it into `sys.rows`/`g_blocks`/`kron_*`,
4212 // and dropping the chunk's `Vec` before the next chunk, the transient
4213 // intermediate is bounded to `O(chunk_size)` while the resident output is
4214 // unchanged. The fold stays STRICTLY row-ascending (chunk `[c0..c1)` then
4215 // `[c1..c2)`, rows in order within each chunk), so every `+=` into
4216 // `sys.gb`, the `g_blocks` BTreeMap, and the `kron_*` pushes lands in the
4217 // identical order as the single-pass fold — bit-for-bit the same system.
4218 // Chunk width is the admission plan's `chunk_size` (the same value
4219 // `streaming_plan` sizes for the matrix-free window), floored so a tiny
4220 // plan still makes forward progress.
4221 let assembly_chunk_rows = self
4222 .assembly_chunk_override
4223 .unwrap_or(admission_plan.chunk_size)
4224 .clamp(1, n.max(1));
4225 let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4226 let mut kron_a_phi: Vec<Vec<(usize, f64)>> = Vec::with_capacity(n);
4227 let mut kron_jac: Vec<Vec<f64>> = Vec::with_capacity(n);
4228 let mut chunk_start = 0usize;
4229 while chunk_start < n {
4230 let chunk_end = (chunk_start + assembly_chunk_rows).min(n);
4231 let mut fold_offset_in_chunk = 0usize;
4232 let row_results: Vec<SaeAssemblyRow> = (chunk_start..chunk_end)
4233 .into_par_iter()
4234 .map_init(
4235 || RowScratch {
4236 decoded: Array2::<f64>::zeros((decoded_rows, p)),
4237 dg_buf: vec![0.0_f64; p],
4238 fitted: Array1::<f64>::zeros(p),
4239 error: Array1::<f64>::zeros(p),
4240 error_white: vec![0.0_f64; w_dim],
4241 error_metric: Array1::<f64>::zeros(p),
4242 jac_white: vec![0.0_f64; scratch_q * w_dim.max(p)],
4243 decoded_scratch: vec![0.0_f64; p],
4244 assignments: Array1::<f64>::zeros(k_atoms),
4245 },
4246 |scratch, row| -> Result<SaeAssemblyRow, String> {
4247 // #1557 — mark this rayon row worker as a nested data-parallel
4248 // region so any faer GEMM reached transitively from the per-row
4249 // assembly (frame `Uᵀ` products, the per-row cross-block /
4250 // Schur-accumulation matmuls, the Riemannian projections) pins to
4251 // `Par::Seq` via `effective_global_parallelism` instead of
4252 // re-fanning the global Rayon pool against this outer fan-out
4253 // (the `spindle` barrier-spin). Serial vs parallel over these tiny
4254 // per-row blocks is a single small product, so the result is
4255 // bit-identical. The guard is held for the whole closure body
4256 // including its `?`/`return` paths.
4257 with_nested_parallel(|| {
4258 let RowScratch {
4259 decoded,
4260 dg_buf,
4261 fitted,
4262 error,
4263 error_white,
4264 error_metric,
4265 jac_white,
4266 decoded_scratch,
4267 assignments,
4268 } = scratch;
4269 let mut gb_delta: Vec<(usize, f64)> = Vec::new();
4270 let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4271 // #1557 — fill per-worker scratch (bit-identical to alloc path).
4272 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
4273 self.assignment
4274 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
4275 // Reconstruction uses the row's active support: for the dense
4276 // full-support layout this is all atoms (exact); for a compact
4277 // layout the dropped atoms carry negligible `O(a)` reconstruction
4278 // mass and zero curvature, so excluding them keeps `fitted`,
4279 // `error`, and the logit-JVP cross term `(decoded[k] − fitted)`
4280 // mutually consistent with the curvature actually assembled.
4281 fitted.fill(0.0);
4282 let row_active_owned: Option<&[usize]> =
4283 row_layout.as_ref().map(|l| l.active_atoms[row].as_slice());
4284 match row_active_owned {
4285 Some(active) => {
4286 // #1410: `decoded` is a compact (max_active × p) buffer
4287 // here; index it by the active-set SLOT `j` (the same
4288 // index the compact tangent block / `coord_starts` use),
4289 // NOT the global `atom_idx`.
4290 for (j, &atom_idx) in active.iter().enumerate() {
4291 let a_k = assignments[atom_idx];
4292 self.atoms[atom_idx]
4293 .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4294 for out_col in 0..p {
4295 decoded[[j, out_col]] = decoded_scratch[out_col];
4296 fitted[out_col] += a_k * decoded_scratch[out_col];
4297 }
4298 }
4299 }
4300 None => {
4301 for atom_idx in 0..k_atoms {
4302 let a_k = assignments[atom_idx];
4303 self.atoms[atom_idx]
4304 .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4305 for out_col in 0..p {
4306 decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
4307 fitted[out_col] += a_k * decoded_scratch[out_col];
4308 }
4309 }
4310 }
4311 }
4312 for out_col in 0..p {
4313 error[out_col] = fitted[out_col] - target[[row, out_col]];
4314 }
4315 // #991 design-honesty seam: a per-row scalar weight `w_row` on the
4316 // reconstruction channel is exactly the metric `w_row · I_p`, so it
4317 // is realized as a `√w_row` scaling of the THREE row-local data
4318 // quantities at their construction sites — this residual, the
4319 // latent Jacobian (below), and the β basis load `a·φ` (below).
4320 // Every downstream data object then carries exactly one factor of
4321 // `w_row` (gt, htt, htbeta, the β Gram `G`, and the β gradient),
4322 // matching the `w_row`-weighted value `loss_scaled` sums; the
4323 // per-row latent priors (assignment / ARD, added to `gt`/`htt`
4324 // further down) are deliberately unweighted — see the
4325 // `row_loss_weights` field docs. `None` ⇒ `sqrt_row_w == 1.0` and
4326 // no multiply is applied (bit-identical unweighted path).
4327 let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
4328 if sqrt_row_w != 1.0 {
4329 for out_col in 0..p {
4330 error[out_col] *= sqrt_row_w;
4331 }
4332 }
4333 // #974 seam (step 1/2): whiten the per-row residual ONCE.
4334 // * not whitening ⇒ `error_white == error` (length p) and
4335 // `error_metric == error`; every downstream loop is the
4336 // historical isotropic path bit-for-bit.
4337 // * whitening ⇒ `error_white = U_nᵀ r_n ∈ ℝ^{w_dim}` (its squared
4338 // norm is `r_nᵀ M_n r_n`, the value the data-fit sums) and
4339 // `error_metric = U_n (U_nᵀ r_n) = M_n r_n ∈ ℝ^p` (the p-space
4340 // metric-applied residual the β-tier gradient contracts).
4341 match self.row_metric.as_ref() {
4342 Some(metric) if whitens_likelihood => {
4343 let wr = metric.whiten_residual_row(row, error.view());
4344 for (slot, &v) in error_white.iter_mut().zip(wr.iter()) {
4345 *slot = v;
4346 }
4347 let mr = metric.apply_metric_row(row, error.view());
4348 for (slot, &v) in error_metric.iter_mut().zip(mr.iter()) {
4349 *slot = v;
4350 }
4351 }
4352 _ => {
4353 for out_col in 0..p {
4354 error_white[out_col] = error[out_col];
4355 error_metric[out_col] = error[out_col];
4356 }
4357 }
4358 }
4359
// Build this row's latent Jacobian on the row's layout:
//   * compact active-set layout — JumpReLU (gated atoms plus the smooth
//     prior's machine-precision support), capped Softmax (top-`k`),
//     IBP-MAP at large K (top-`k_active`), or a caller-forced layout:
//     only the active atoms' logit and coordinate rows, sized `q_active`;
//   * no layout (uncapped / small K): the dense uniform-q layout.
4365 let (q_row, mut local_jac_row) = if let Some(layout) = row_layout.as_ref() {
4366 let active = &layout.active_atoms[row];
4367 let starts = &layout.coord_starts[row];
4368 let q_active = layout.row_q_active(row);
4369 let mut jac_compact = Array2::<f64>::zeros((q_active, p));
4370 // Logit JVP rows for active atoms only, using the per-mode
4371 // assignment sensitivity `da_k/dl_k` contracted into the
4372 // decoded / fitted-corrected output direction.
4373 let logits_row = self.assignment.logits.row(row);
4374 for (j, &k) in active.iter().enumerate() {
4375 fill_active_atom_logit_jvp(
4376 ActiveAtomLogitJvp {
4377 mode: self.assignment.mode,
4378 k,
4379 logit_k: logits_row[k],
4380 a_k: assignments[k],
4381 // #1410: compact slot `j`, not global atom `k`.
4382 decoded_k: decoded.row(j),
4383 fitted: fitted.view(),
4384 ibp_prior: ibp_prior_slice,
4385 compact_index: j,
4386 // #1026/#1033: a FIXED logit (ungated, or every
4387 // atom under frozen routing) has a constant gate
4388 // ⇒ zero logit-JVP.
4389 ungated: self.assignment.logit_is_fixed(k),
4390 },
4391 &mut jac_compact,
4392 );
4393 }
4394 // Coordinate JVP rows for active atoms only.
4395 for (j, &k) in active.iter().enumerate() {
4396 let d = self.atoms[k].latent_dim;
4397 let a_k = assignments[k];
4398 let coord_start = starts[j];
4399 for axis in 0..d {
4400 self.atoms[k].fill_decoded_derivative_row(
4401 row,
4402 axis,
4403 dg_buf.as_mut_slice(),
4404 );
4405 for out_col in 0..p {
4406 jac_compact[[coord_start + axis, out_col]] =
4407 a_k * dg_buf[out_col];
4408 }
4409 }
4410 }
4411 (q_active, jac_compact)
4412 } else {
// Fresh per-row Jacobian, structurally identical to the
// compact-layout branch above: every (q × p) element is
// unconditionally overwritten below (assignment-chart JVP rows +
// coordinate rows). The `Array2::zeros` allocation therefore needs no
// separate `fill(0.0)`, and the populated buffer is returned by move
// without a clone.
4418 let mut jac_row = Array2::<f64>::zeros((q, p));
4419 fill_assignment_logit_jvp_rows(
4420 self.assignment.mode,
4421 self.assignment.logits.row(row),
4422 assignments.view(),
4423 decoded.view(),
4424 fitted.view(),
4425 ibp_prior_slice,
4426 // #1026/#1033: zero logit-JVP rows for FIXED-logit atoms
4427 // (ungated, and all atoms under frozen routing).
4428 &self.assignment.fixed_logit_mask(),
4429 &mut jac_row,
4430 );
4431 // Coordinate columns for all atoms.
4432 for atom_idx in 0..k_atoms {
4433 let d = self.atoms[atom_idx].latent_dim;
4434 let off = coord_offsets[atom_idx];
4435 let a_k = assignments[atom_idx];
4436 for axis in 0..d {
4437 self.atoms[atom_idx].fill_decoded_derivative_row(
4438 row,
4439 axis,
4440 dg_buf.as_mut_slice(),
4441 );
4442 for out_col in 0..p {
4443 jac_row[[off + axis, out_col]] = a_k * dg_buf[out_col];
4444 }
4445 }
4446 }
4447 (q, jac_row)
4448 };
4449
4450 // #991 design-honesty seam, Jacobian leg: scale the row's latent
4451 // Jacobian by `√w_row` BEFORE the whitening / Kronecker capture so
4452 // htt (= J̃J̃ᵀ), the data part of gt (= J̃ẽ, the residual already
4453 // carries its own √w_row), and the htbeta cross block (J paired
4454 // with the √w_row-scaled β load below) each carry exactly one
4455 // factor of `w_row`. No-op on the unweighted path.
4456 if sqrt_row_w != 1.0 {
4457 for a in 0..q_row {
4458 for out_col in 0..p {
4459 local_jac_row[[a, out_col]] *= sqrt_row_w;
4460 }
4461 }
4462 }
4463
4464 // #974 seam (step 2/2): whiten the per-row Jacobian through the SAME
4465 // metric the residual was whitened by. `jac_white[a*w_dim + k]` holds
4466 // `J̃[a, k] = Σ_out U_n[out, k] · J_n[a, out]` so the t-block
4467 // Gauss-Newton row block is `htt = J̃ J̃ᵀ = J_n M_n J_nᵀ` and
4468 // `gt = J̃ ẽ = J_nᵀ M_n r_n`. When not whitening, `w_dim == p` and the
4469 // whitened jac equals the raw Jacobian, so htt/gt are byte-identical
4470 // to the historical isotropic assembly. Because the SAME `error_white`
4471 // feeds both the value-path data-fit (Σ½ ẽ²) and this gradient
4472 // (J̃ ẽ), the objective and its t-block gradient share one whitening
4473 // — they cannot desync.
4474 if whitens_likelihood {
4475 if let Some(metric) = self.row_metric.as_ref() {
4476 for a in 0..q_row {
4477 for k in 0..w_dim {
4478 let mut acc = 0.0;
4479 // U_n[out, k] read through the metric's factor layout.
4480 for out_col in 0..p {
4481 acc += metric.factor_entry(row, out_col, k)
4482 * local_jac_row[[a, out_col]];
4483 }
4484 jac_white[a * w_dim + k] = acc;
4485 }
4486 }
4487 }
4488 } else {
4489 for a in 0..q_row {
4490 for out_col in 0..p {
4491 jac_white[a * w_dim + out_col] = local_jac_row[[a, out_col]];
4492 }
4493 }
4494 }
4495
4496 // Build the per-row Arrow-Schur block at the row's active dim.
4497 let mut block = ArrowRowBlock::new(q_row, row_htbeta_dim);
4498 for a in 0..q_row {
4499 let jac_a = &jac_white[a * w_dim..(a + 1) * w_dim];
4500 let g = jac_a
4501 .iter()
4502 .zip(error_white.iter())
4503 .map(|(&j, &e)| j * e)
4504 .sum::<f64>();
4505 block.gt[a] += g;
4506 for b in 0..q_row {
4507 let jac_b = &jac_white[b * w_dim..(b + 1) * w_dim];
4508 let h = jac_a
4509 .iter()
4510 .zip(jac_b.iter())
4511 .map(|(&ja, &jb)| ja * jb)
4512 .sum::<f64>();
4513 block.htt[[a, b]] += h;
4514 }
4515 }
4516
4517 // Assignment prior in logit space.
4518 // For compact layout: position `j` = active_atoms index.
4519 // For dense layout: position `atom_idx` directly.
4520 //
4521 // H-consistency note (#1006 audit / #1416 update). This
4522 // `assignment_hdiag` is the assignment channel's raw diagonal
4523 // curvature, added un-majorized. It is exact for JumpReLU and exact
4524 // within each IBP row/column diagonal, and stores ONLY the diagonal of
4525 // two full-Hessian structures — but those off-diagonal structures are
4526 // now carried elsewhere, not dropped:
4527 //
4528 // * softmax entropy has dense within-row Hessian
4529 // H_kj = (λ/τ²) a_k[δ_kj(m-L_k-1) + a_j(L_k+L_j+1-2m)];
4530 // this diagonal stores its Gershgorin Loewner majorizer (#1419).
4531 // * IBP empirical-π has cross-row rank-one terms per column
4532 // H_(i,k),(j,k) = w score_derivative_k z'_ik z'_jk for i != j.
4533 // This per-row diagonal stores only the diagonal/self-row part;
4534 // the FULL rank-one cross-row block `U D Uᵀ` is now INSTALLED as a
4535 // separate Woodbury source by `set_ibp_cross_row_source` (#1038),
4536 // so the assembled operator is `H_full = H₀' + U D Uᵀ` on the
4537 // NO-SELF base `H₀' = H₀ − Σ_k d_k diag(z'_ik²)` (self term
4538 // downdated, see `IbpCrossRowSource::self_term_downdate`). The
4539 // scalar `D`-coefficient `d_k = w·s'_k` is
4540 // `IbpHessianDiagThirdChannels::cross_row_d` (FD-verified against
4541 // ∂²value/∂ℓ_ik∂ℓ_jk in
4542 // `ibp_cross_row_woodbury_d_matches_full_off_diagonal_hessian`),
4543 // and `z_jac` carries `u_k`'s entries `z'_ik`.
4544 //
4545 // The criterion's log|H| and Γ adjoint differentiate this SAME
4546 // `H_full`: the ρ-trace adds the cross-row off-diagonal in
4547 // `assignment_log_strength_hessian_trace` (#1416, dense AND compact
4548 // layouts) and the θ-adjoint adds it in `logdet_theta_adjoint`
4549 // (#1416/#1641), so value and gradient stay on one operator.
4550 let assignment_base = row * k_atoms;
4551 if let Some(layout) = row_layout.as_ref() {
4552 let active = &layout.active_atoms[row];
4553 // #1408/#1409 softmax compact curvature: the entropy
4554 // Hessian diagonal in `assignment_hdiag` is INDEFINITE,
4555 // so on a compact softmax layout write the Gershgorin
4556 // Loewner majorizer `D_kk = Σ_j|H_kj|` (#1419) — the same
4557 // PSD operator the dense softmax branch writes — at each
4558 // active logit slot. `D` is diagonal, so its active
4559 // principal sub-block is `diag(D_kk : k ∈ active)`; each
4560 // `D_kk` is the FULL-`K` abs-row-sum, so it still
4561 // dominates the active principal sub-block of `H_entropy`
4562 // (a genuine majorizer on the retained support). The
4563 // gradient stays the EXACT entropy gradient (it sets the
4564 // fixed point), so majorizing only conditions the Newton
4565 // step. JumpReLU/IBP keep their (exact) diagonal.
4566 //
4567 // #1410: compute only the active `D_kk` directly from this
4568 // row's softmax assignments `a` (= `assignments`, already
4569 // in hand), via `active_softmax_gershgorin_majorizer_entry`.
4570 // The previous `psd_majorizer_abs_row_sums(&row_logits, ..)`
4571 // call allocated TWO length-`K` per-row scratch vectors (a
4572 // fresh `row_logits` copy and the full-`K` returned `d`)
4573 // only to read `d[k]` for the `≤ top_k` active `k` — an
4574 // `O(K)` per-row allocation on the path the compact
4575 // contract keeps `K`-free. The shared `m = Σ_j a_j l_j` is
4576 // the one irreducible `O(K)` pass, computed once per row.
4577 let assignments_slice = assignments
4578 .as_slice()
4579 .expect("softmax assignments row must be contiguous");
4580 let majorizer_log_mean: Option<f64> = softmax_dense
4581 .as_ref()
4582 .map(|_| softmax_majorizer_log_mean(assignments_slice));
4583 for (j, &k) in active.iter().enumerate() {
4584 block.gt[j] += assignment_grad[assignment_base + k];
4585 match (softmax_dense.as_ref(), majorizer_log_mean) {
4586 (Some((_penalty, scale)), Some(m)) => {
4587 block.htt[[j, j]] +=
4588 active_softmax_gershgorin_majorizer_entry(
4589 assignments_slice,
4590 k,
4591 m,
4592 *scale,
4593 );
4594 }
4595 _ => block.htt[[j, j]] += assignment_hdiag[assignment_base + k],
4596 }
4597 }
4598 } else {
4599 for free_idx in 0..assignment_dim {
4600 block.gt[free_idx] += assignment_grad[assignment_base + free_idx];
4601 }
4602 if let Some((penalty, scale)) = softmax_dense.as_ref() {
4603 // #1419: write the genuine Gershgorin Loewner majorizer
4604 // `D = diag(Σ_j|H_kj|)` of the exact entropy Hessian onto the
4605 // row's logit block in place of the EXACT entropy Hessian. The
4606 // entropy Hessian is INDEFINITE (concave directions on
4607 // long-tailed rows), which drove the per-row evidence block
4608 // non-PD and forced the downstream Faddeev–Popov deflation to
4609 // flatten data-relevant logit directions (under-identifying the
4610 // atoms). `D` is a nonnegative diagonal, hence exactly PSD and
4611 // PD-preserving like the previous Fisher surrogate, so the block
4612 // stays PD and the deflation no longer fires on the entropy
4613 // block. Unlike the Fisher metric `G = scale·(diag(a) − a aᵀ)`,
4614 // which is PSD but NOT a majorizer (`G − H_entropy` can be
4615 // indefinite — K=2, a=(0.95,0.05): G₁₁=0.0475 < H₁₁=0.0784,
4616 // #1419), `D` actually satisfies `D ⪰ H_entropy` and `D ⪰ 0`,
4617 // so it is a true MM/Loewner curvature majorizer. Because the
4618 // entropy penalty is a FIXED prior whose stationary point is set
4619 // by its (unchanged) EXACT gradient, replacing its curvature
4620 // with the majorizer only conditions the Newton step and the
4621 // Laplace normalizer's curvature operator — it does NOT move the
4622 // optimum.
4623 //
4624 // Softmax uses the REDUCED K−1 free-logit chart (the last
4625 // reference logit is fixed at 0, `assignment_coord_dim() = K−1`).
4626 // Holding z_{K-1} fixed, the reduced curvature over the free
4627 // logits 0..K−1 is exactly the top-left (K−1)×(K−1) submatrix of
4628 // the full K×K majorizer (the fixed logit contributes no
4629 // row/column to the free curvature). The criterion's `log|H|`
4630 // and the #1006 θ-adjoint differentiate this SAME `D` (see the
4631 // `row_psd_majorizer_logit_derivative` site below), so value and
4632 // adjoint stay on one exact branch.
4633 let row_logits: Vec<f64> = (0..k_atoms)
4634 .map(|k| self.assignment.logits[[row, k]])
4635 .collect();
4636 let h_dense = penalty.row_psd_majorizer(&row_logits, *scale);
4637 for ki in 0..assignment_dim {
4638 for kj in 0..assignment_dim {
4639 block.htt[[ki, kj]] += h_dense[[ki, kj]];
4640 }
4641 }
4642 } else {
4643 for free_idx in 0..assignment_dim {
4644 block.htt[[free_idx, free_idx]] +=
4645 assignment_hdiag[assignment_base + free_idx];
4646 }
4647 }
4648 }
4649
4650 // ARD on each on-atom coordinate.
4651 // For compact layout: only active atoms; coord positions use compact starts.
4652 // For dense layout: all atoms; coord positions use coord_offsets.
4653 if let Some(layout) = row_layout.as_ref() {
4654 let active = &layout.active_atoms[row];
4655 let starts = &layout.coord_starts[row];
4656 for (j, &k) in active.iter().enumerate() {
4657 let coord = &self.assignment.coords[k];
4658 let d = coord.latent_dim();
4659 if rho.log_ard[k].is_empty() {
4660 continue;
4661 }
4662 if rho.log_ard[k].len() != d {
4663 return Err(format!(
4664 "ARD rho atom {k} has len {} but atom dim is {d}",
4665 rho.log_ard[k].len()
4666 ));
4667 }
4668 let row_t = coord.row(row);
4669 let periods = &ard_axis_periods[k];
4670 for axis in 0..d {
4671 // ARD on coords is a genuine per-row prior (each row
4672 // contributes the per-axis prior energy), so it is NOT
4673 // minibatch-scaled — the per-chunk row sums already
4674 // reconstruct the full coordinate prior across a pass.
4675 // The value (`ard_value`/`loss.ard`) and the gradient
4676 // both come from the SAME `ArdAxisPrior` energy, so they
4677 // stay FD-consistent on periodic axes. The exact
4678 // von-Mises curvature `V'' = α·cos(κt)` is INDEFINITE —
4679 // it goes negative for |t| past a quarter period — so
4680 // writing it raw into the Newton/Schur `htt` diagonal
4681 // makes that PSD curvature block indefinite and the Schur
4682 // Cholesky (used both for the Newton step and the exact
4683 // log-det) fails on a non-PD pivot. Accumulate the PSD
4684 // majorizer `max(V'', 0)` instead, exactly as
4685 // `add_sae_coord_penalty` does for the registry coord
4686 // penalties: the positive part keeps `htt` PSD so the
4687 // factorization succeeds, and majorizing the curvature of
4688 // a fixed prior only damps the Newton step — it does not
4689 // move the stationary point (the gradient, which sets the
4690 // fixed point, stays the exact `V'`).
4691 let alpha =
4692 SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
4693 let prior =
4694 ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4695 block.gt[starts[j] + axis] += prior.grad;
4696 block.htt[[starts[j] + axis, starts[j] + axis]] +=
4697 prior.hess.max(0.0);
4698 }
4699 }
4700 } else {
4701 for atom_idx in 0..k_atoms {
4702 let coord = &self.assignment.coords[atom_idx];
4703 let d = coord.latent_dim();
4704 if rho.log_ard[atom_idx].is_empty() {
4705 continue;
4706 }
4707 if rho.log_ard[atom_idx].len() != d {
4708 return Err(format!(
4709 "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
4710 rho.log_ard[atom_idx].len()
4711 ));
4712 }
4713 let off = coord_offsets[atom_idx];
4714 let row_t = coord.row(row);
4715 let periods = &ard_axis_periods[atom_idx];
4716 for axis in 0..d {
4717 // PSD-majorize the (possibly negative) von-Mises curvature
4718 // into the Newton/Schur `htt` block; see the compact-layout
4719 // branch above for why `max(V'', 0)` is required to keep
4720 // `htt` PD (the exact `V'' = α·cos κt` is indefinite past a
4721 // quarter period and breaks the Schur/log-det Cholesky).
4722 let alpha = SaeManifoldRho::stable_exp_strength(
4723 rho.log_ard[atom_idx][axis],
4724 );
4725 let prior =
4726 ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4727 block.gt[off + axis] += prior.grad;
4728 block.htt[[off + axis, off + axis]] += prior.hess.max(0.0);
4729 }
4730 }
4731 }
4732
4733 // Beta gradient/Hessian — Kronecker form J_β = φᵀ ⊗ I_p.
4734 //
4735 // The per-row beta Jacobian is
4736 // J_β[out_col, beta_idx] = a_k · phi_k[basis_col] if out_col == out_col(beta_idx)
4737 // 0 otherwise
4738 // so the data-fit Gauss-Newton beta-Hessian factors as a rank-`p`
4739 // sum of outer products. We pre-compute the per-(atom, basis_col)
4740 // scalar `a_k · phi_k` once and reuse it across the `out_col`
4741 // and inner `(atom_j, basis_col2)` loops.
4742 //
4743 // Full-B rows keep the matrix-free Kronecker path below. Factored
4744 // rows write the `q_i × Σ M_k r_k` C-space cross slab directly by
4745 // folding each output-channel contribution through the atom frame,
4746 // so no `q_i × β_dim` slab is ever materialized.
4747 //
4748 // Only the row's active atoms contribute `a_phi` support and data
4749 // curvature: in a compact layout (JumpReLU gate or large-K
4750 // top-`k_active` truncation) the inactive atoms carry zero (gated)
4751 // or sub-cutoff assignment mass and are excluded — this is what
4752 // keeps both the htbeta support and the `G` accumulation
4753 // `O(k_active)` rather than `O(K)`. In the dense full-support
4754 // layout `row_active` spans all atoms.
4755 let row_active: &[usize] = match row_layout.as_ref() {
4756 Some(layout) => layout.active_atoms[row].as_slice(),
4757 None => &all_atoms_index,
4758 };
4759 // #1407: in fixed-decoder mode the β tier is not assembled at
4760 // all — leave gb_delta/g_blocks empty and kron None. htt/gt
4761 // (built above) are the only outputs the frozen-decoder step
4762 // consumes.
4763 let mut a_phi: Vec<(usize, f64)> = Vec::with_capacity(row_active.len() * 4);
4764 // Per-active-atom weighted basis row `a_k · φ_k[·]`, retained so the
4765 // data Gram blocks can be accumulated as clean per-atom-pair outer
4766 // products `(a_k φ_k) (a_{k'} φ_{k'})ᵀ`.
4767 let mut weighted_phi: Vec<(usize, Vec<f64>)> =
4768 Vec::with_capacity(row_active.len());
4769 if !fixed_decoder {
4770 for &atom_idx in row_active {
4771 let atom = &self.atoms[atom_idx];
4772 let atom_beta_off = beta_offsets[atom_idx];
4773 let m = atom.basis_size();
4774 let a_k = assignments[atom_idx];
4775 let mut wphi = Vec::with_capacity(m);
4776 for basis_col in 0..m {
4777 let phi = atom.basis_values[[row, basis_col]];
4778 // #991 design-honesty seam, β leg: the `√w_row` here pairs
4779 // with the `√w_row` on the residual (β gradient =
4780 // `a·φ · M r` ⇒ w_row) and with itself (β Gram `G` and the
4781 // htbeta Kronecker capture ⇒ w_row). `1.0` when unweighted.
4782 let w = a_k * phi * sqrt_row_w;
4783 a_phi.push((atom_beta_off + basis_col * p, w));
4784 wphi.push(w);
4785 }
4786 weighted_phi.push((atom_idx, wphi));
4787 }
4788 // β data-fit gradient `gᵦ += J_βᵀ M_n r_n`. The β-Jacobian is
4789 // `J_β = φ_nᵀ ⊗ I_p`, so `J_βᵀ M_n r_n = φ_n ⊗ (M_n r_n)` —
4790 // contract the basis weight `a·φ` against the p-space metric-applied
4791 // residual `error_metric` (= `M_n r_n`), the SAME whitening the value
4792 // path and t-block share. When not whitening, `error_metric == error`
4793 // and this is byte-identical to the historical `J_βᵀ r`.
4794 for &(beta_base_i, j_beta_i) in a_phi.iter() {
4795 if j_beta_i == 0.0 {
4796 continue;
4797 }
4798 for out_col in 0..p {
4799 gb_delta.push((
4800 beta_base_i + out_col,
4801 j_beta_i * error_metric[out_col],
4802 ));
4803 // No dense hbb write — the sparse `G ⊗ I_p` op installed
4804 // after the loop carries the data-fit GN β-Hessian.
4805 }
4806 }
4807 if frames_engaged {
4808 for &atom_idx in row_active {
4809 let atom = &self.atoms[atom_idx];
4810 let m = atom.basis_size();
4811 let a_k = assignments[atom_idx];
4812 for basis_col in 0..m {
4813 let phi = atom.basis_values[[row, basis_col]];
4814 let w = a_k * phi * sqrt_row_w;
4815 if w == 0.0 {
4816 continue;
4817 }
4818 let c_base = frame_projection.border_offsets[atom_idx]
4819 + basis_col * frame_projection.ranks[atom_idx];
4820 for c in 0..q_row {
4821 let mut hrow = block.htbeta.row_mut(c);
4822 let hrow_slice = hrow
4823 .as_slice_mut()
4824 .expect("htbeta row is contiguous");
4825 for out_col in 0..p {
4826 let value = local_jac_row[[c, out_col]] * w;
4827 frame_projection.accumulate_output_project(
4828 atom_idx, c_base, out_col, value, hrow_slice,
4829 );
4830 }
4831 }
4832 }
4833 }
4834 }
4835 // Data-fit GN β-Hessian: accumulate the channel-independent block
4836 // `G[μ_i, μ_j] += (a_k φ_k)[μ_i] (a_{k'} φ_{k'})[μ_j]` into the
4837 // sparse per-atom-pair map (the `out_col` dimension is carried by
4838 // `I_p`). Only co-occurring `(atom_i, atom_j)` pairs are touched.
4839 for ai in 0..weighted_phi.len() {
4840 let (atom_i, ref wphi_i) = weighted_phi[ai];
4841 let m_i = wphi_i.len();
4842 for aj in 0..weighted_phi.len() {
4843 let (atom_j, ref wphi_j) = weighted_phi[aj];
4844 let m_j = wphi_j.len();
4845 let blk = g_blocks
4846 .entry((atom_i, atom_j))
4847 .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4848 for li in 0..m_i {
4849 let wi = wphi_i[li];
4850 if wi == 0.0 {
4851 continue;
4852 }
4853 for lj in 0..m_j {
4854 blk[[li, lj]] += wi * wphi_j[lj];
4855 }
4856 }
4857 }
4858 }
4859 } // #1407 end `if !fixed_decoder` β-tier accumulation
4860 let (kron_a_phi, kron_jac) = if !frames_engaged && !fixed_decoder {
4861 // Flatten local_jac_row row-major into a plain Vec<f64> (q_row * p entries).
4862 let mut jac_flat = vec![0.0_f64; q_row * p];
4863 for c in 0..q_row {
4864 for j in 0..p {
4865 jac_flat[c * p + j] = local_jac_row[[c, j]];
4866 }
4867 }
4868 (Some(a_phi), Some(jac_flat))
4869 } else {
4870 (None, None)
4871 };
4872 Ok(SaeAssemblyRow {
4873 row,
4874 block,
4875 gb_delta,
4876 g_blocks,
4877 kron_a_phi,
4878 kron_jac,
4879 })
4880 }) // #1557 with_nested_parallel
4881 },
4882 )
4883 .collect::<Result<Vec<_>, String>>()?;
4884
4885 // Fold THIS chunk's rows (ascending) into the global accumulators.
4886 // The parallel collect preserves index order within the chunk and
4887 // chunks are visited in ascending `chunk_start` order, so the overall
4888 // fold order is `0,1,2,…,n-1` — identical to the former single-pass
4889 // fold. The `row == chunk_start + fold_offset_in_chunk` assert pins
4890 // that strict sequential arrival (the invariant the `kron_*`
4891 // row-aligned pushes depend on).
4892 for row_result in row_results.into_iter() {
4893 let row = row_result.row;
4894 assert_eq!(
4895 row,
4896 chunk_start + fold_offset_in_chunk,
4897 "parallel SAE row assembly returned rows out of order"
4898 );
4899 fold_offset_in_chunk += 1;
4900 for (idx, value) in row_result.gb_delta {
4901 sys.gb[idx] += value;
4902 }
4903 for ((atom_i, atom_j), data) in row_result.g_blocks {
4904 let m_i = data.nrows();
4905 let m_j = data.ncols();
4906 let blk = g_blocks
4907 .entry((atom_i, atom_j))
4908 .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4909 for li in 0..m_i {
4910 for lj in 0..m_j {
4911 blk[[li, lj]] += data[[li, lj]];
4912 }
4913 }
4914 }
4915 if !frames_engaged && !fixed_decoder {
4916 // Rows arrive in ascending order across chunks, so pushing
4917 // here yields `kron_*[row]` aligned to the row index exactly
4918 // as the single-pass `push` did.
4919 kron_a_phi.push(
4920 row_result
4921 .kron_a_phi
4922 .expect("full-B SAE row assembly must return a_phi rows"),
4923 );
4924 kron_jac.push(
4925 row_result
4926 .kron_jac
4927 .expect("full-B SAE row assembly must return local Jacobian rows"),
4928 );
4929 }
4930 sys.rows[row] = row_result.block;
4931 }
4932 chunk_start = chunk_end;
4933 }
4934 // #1407: fixed-decoder early return. The per-row htt/gt are now fully
4935 // assembled (data GN + assignment/ARD prior). Apply only the htt/gt
4936 // Riemannian projection (the decoder/β tier is intentionally absent), then
4937 // return the block-diagonal system. `fixed_decoder_step_from_rows` reads
4938 // only `rows[*].htt`/`gt` + `row_offsets`, so no β-tier object is needed.
4939 if fixed_decoder {
4940 match row_layout.as_ref() {
4941 None => {
4942 // Dense uniform-q: project htt/gt (and the 0-width htbeta, a
4943 // no-op) through the ext-coord manifold.
4944 self.apply_sae_riemannian_geometry(&mut sys);
4945 }
4946 Some(layout) => {
4947 // Compact heterogeneous-q: project each row's htt/gt at its
4948 // own ext-coord point, mirroring the full path's compact
4949 // Riemannian block (htbeta is 0-width here, so skipped).
4950 if !self.ext_coord_manifold().is_euclidean() {
4951 for row_idx in 0..n {
4952 let (manifold_i, point_i) =
4953 self.compact_row_ext_manifold_and_point(row_idx, layout);
4954 let t_i = point_i.view();
4955 let gt_e = sys.rows[row_idx].gt.clone();
4956 let htt_e = sys.rows[row_idx].htt.clone();
4957 sys.rows[row_idx].gt =
4958 manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
4959 sys.rows[row_idx].htt = manifold_i.riemannian_hessian_matrix(
4960 t_i,
4961 gt_e.view(),
4962 htt_e.view(),
4963 );
4964 }
4965 }
4966 }
4967 }
4968 if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
4969 sys.set_row_gauge_deflation(deflation);
4970 }
4971 self.last_row_layout = row_layout;
4972 self.last_frames_active = frames_engaged;
4973 return Ok(sys);
4974 }
4975 // Apply Riemannian geometry to the per-row row blocks (htt, gt) and
4976 // also to the per-row Kronecker local Jacobians stored in kron_jac.
4977 // When the SAE ext-coord manifold is non-Euclidean (any atom latent
4978 // on sphere / circle / interval), the local Jacobian rows that map
4979 // into the t-block tangent space must be projected via the per-row
4980 // tangent projector P_i. This mirrors what
4981 // `apply_riemannian_latent_geometry` does to `row.htbeta`, applied
4982 // here to the (q × p) kron_jac so the Kronecker htbeta_matvec uses
4983 // the Riemannian-projected form.
4984 // Apply Riemannian geometry only for the dense uniform-q layout. Any
4985 // compact active-set layout (JumpReLU gate or large-K softmax/IBP
4986 // truncation) has heterogeneous q_i; the Riemannian projector path
4987 // requires a uniform latent dimension. The sparse plan only engages on
4988 // Euclidean ext-coord manifolds (see `sparse_active_plan`), so skipping
4989 // the projector here is correct — there is nothing to project.
4990 match row_layout.as_ref() {
4991 None => {
4992 let raw_gt_rows: Vec<Array1<f64>> =
4993 sys.rows.iter().map(|row| row.gt.clone()).collect();
4994 self.apply_sae_riemannian_geometry(&mut sys);
4995 let manifold = self.ext_coord_manifold();
4996 if !frames_engaged && !manifold.is_euclidean() {
4997 let ext = self.ext_coord_matrix();
4998 // Project the local Jacobian columns onto the tangent space at
4999 // each row's ext-coord point. Each column `j` of the row's
5000 // (q_row × p) Jacobian is an ambient-space vector of length
5001 // `q_row`; the manifold projector acts on one such column at a
5002 // time. Working directly on the row-major `jac_flat` storage via
5003 // a single reusable `col_buf` avoids the two dense (q × p) copies
5004 // (flatten→Array2, project, unflatten→Vec) that previously fired
5005 // per row. `t_buf` still holds the row's ext-coord vector.
5006 let mut t_buf = vec![0.0_f64; q];
5007 let mut col_buf = Array1::<f64>::zeros(q);
5008 for row_idx in 0..n {
5009 let ext_row = ext.row(row_idx);
5010 for (slot, &v) in t_buf.iter_mut().zip(ext_row.iter()) {
5011 *slot = v;
5012 }
5013 let t_i = ArrayView1::from(t_buf.as_slice());
5014 let raw_gt = raw_gt_rows[row_idx].view();
5015 let jac_flat = &mut kron_jac[row_idx];
5016 let q_row = jac_flat.len() / p;
5017 for j in 0..p {
5018 for c in 0..q_row {
5019 col_buf[c] = jac_flat[c * p + j];
5020 }
5021 let projected_col = manifold.project_vector_to_gradient_tangent(
5022 t_i,
5023 raw_gt.slice(ndarray::s![..q_row]),
5024 col_buf.slice(ndarray::s![..q_row]),
5025 );
5026 for c in 0..q_row {
5027 jac_flat[c * p + j] = projected_col[c];
5028 }
5029 }
5030 }
5031 }
5032 }
5033 Some(layout) => {
5034 // Compact active-set layout (#1117 follow-up): the dense
5035 // `ext_coord_manifold()` is keyed to the uniform full-`q` block
5036 // ordering, so it cannot be applied to the heterogeneous compact
5037 // rows directly. Instead we rebuild, PER ROW, the product manifold
5038 // and ext-coord point in that row's compact column order (see
5039 // `compact_row_ext_manifold_and_point`) and apply the SAME three
5040 // per-row Riemannian operations the dense
5041 // `apply_riemannian_latent_geometry` applies — gradient tangent
5042 // projection of `gt`, the Riemannian Hessian correction of `htt`,
5043 // and the column tangent projection of `htbeta` — plus the
5044 // identical Kronecker `kron_jac` column projection. On the shared
5045 // active support this is byte-identical to slicing the dense
5046 // product manifold, so engaging the sparse plan on a non-Euclidean
5047 // ext manifold is now correct (the former
5048 // `is_euclidean()`-only guard in `sparse_active_plan` is lifted).
5049 //
5050 // Euclidean ext manifolds still skip all of this (every
5051 // per-row manifold is a product of Euclidean parts whose
5052 // projector is the identity); we early-out so those rows stay
5053 // byte-for-byte the historical compact path.
5054 if !self.ext_coord_manifold().is_euclidean() {
5055 for row_idx in 0..n {
5056 let (manifold_i, point_i) =
5057 self.compact_row_ext_manifold_and_point(row_idx, layout);
5058 let t_i = point_i.view();
5059 // gt / htt / htbeta on the compact ArrowRowBlock, exactly
5060 // as `apply_riemannian_latent_geometry` does for dense
5061 // uniform-q rows.
5062 let gt_e = sys.rows[row_idx].gt.clone();
5063 let htt_e = sys.rows[row_idx].htt.clone();
5064 sys.rows[row_idx].gt =
5065 manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
5066 sys.rows[row_idx].htt =
5067 manifold_i.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
5068 // #1406: only the frames path holds a real dense `htbeta`
5069 // slab; the matrix-free path leaves it 0-width (the
5070 // cross-block geometry is applied to `kron_jac` below), so
5071 // projecting a zero-column matrix is a no-op we skip.
5072 if frames_engaged {
5073 let htbeta_e = sys.rows[row_idx].htbeta.clone();
5074 sys.rows[row_idx].htbeta = manifold_i
5075 .project_matrix_columns_to_gradient_tangent(
5076 t_i,
5077 gt_e.view(),
5078 htbeta_e.view(),
5079 );
5080 }
5081 // Kronecker local-Jacobian column projection (full-B path
5082 // only), using the SAME pre-projection gradient `gt_e` so
5083 // the cross-block geometry matches the dense branch.
5084 if !frames_engaged {
5085 let jac_flat = &mut kron_jac[row_idx];
5086 let q_row = jac_flat.len() / p;
5087 let mut col_buf = Array1::<f64>::zeros(q_row);
5088 for j in 0..p {
5089 for c in 0..q_row {
5090 col_buf[c] = jac_flat[c * p + j];
5091 }
5092 let projected_col = manifold_i.project_vector_to_gradient_tangent(
5093 t_i,
5094 gt_e.view(),
5095 col_buf.view(),
5096 );
5097 for c in 0..q_row {
5098 jac_flat[c * p + j] = projected_col[c];
5099 }
5100 }
5101 }
5102 }
5103 }
5104 }
5105 }
5106 // Build and install the full-B Kronecker htbeta_matvec.
5107 //
5108 // `SaeKroneckerRows` holds per-row `(a_phi, local_jac)` and implements
5109 // the cross-block operator without ever materialising the dense
5110 // `(q × K·p)` slab. The cross-block factorises as `H_tβ = L · J_β`,
5111 // where `J_β = φᵀ ⊗ I_p` projects a length-`K` β vector onto the
5112 // `p`-dimensional decoded output space (`apply_jbeta`) and `L_i` is
5113 // the per-row `(q_i × p)` assignment+coordinate Jacobian that lifts
5114 // that p-vector into the row's `q_i`-dim tangent block (`apply_l`).
5115 // Both factors are required: the contract of `set_row_htbeta_operator`
5116 // is `out.len() == d` (= `q_i`), so writing `apply_jbeta`'s p-vector
5117 // output directly into a length-`q_i` buffer overflows whenever
5118 // `p > q_i` (the common case once `p` reflects real feature width).
5119 // Symmetric for the transpose: `H_βt = J_βᵀ · Lᵀ`, so apply `Lᵀ`
5120 // first to map the q_i-vector back to p-space, then scatter through
5121 // the support.
5122 // #1017/#1026: the legacy full-B device PCG assumes `G ⊗ I_p`, while
5123 // framed systems carry `G_ij ⊗ W_ij` with rank-r atom blocks. Feeding a
5124 // framed system to that kernel would silently return the wrong Newton
5125 // step. Framed device PCG therefore needs the dedicated factored kernel.
5126 // #1033 large-n: the per-row support `kron_a_phi` and local Jacobians
5127 // `kron_jac` are consumed by BOTH the host matrix-free row operator
5128 // (`SaeKroneckerRows`) and the solver's `DeviceSaePcgData`. Previously
5129 // each took its own full `O(n·q·p)` / `O(n·k_active)` clone, so the
5130 // always-resident footprint of the CPU non-frames path carried TWO copies
5131 // of the dominant Jacobian slab. Promote each to a single `Arc<[…]>` once
5132 // and hand both consumers a refcount bump (`O(1)`) — the backing
5133 // allocation is shared, halving the resident per-row Jacobian memory.
5134 // Reads are identical (`&arc[row]`, `.len()`), so the assembled system and
5135 // every matvec are bit-for-bit unchanged.
5136 let device_rows = if frames_engaged {
5137 None
5138 } else {
5139 let a_phi_shared: Arc<[Vec<(usize, f64)>]> =
5140 Arc::from(std::mem::take(&mut kron_a_phi).into_boxed_slice());
5141 let jac_shared: Arc<[Vec<f64>]> =
5142 Arc::from(std::mem::take(&mut kron_jac).into_boxed_slice());
5143 Some((a_phi_shared, jac_shared))
5144 };
5145 if !frames_engaged {
5146 let (a_phi_shared, jac_shared) = device_rows
5147 .clone()
5148 .expect("non-frames path always populates device_rows");
5149 let kron = Arc::new(SaeKroneckerRows::new(p, a_phi_shared, jac_shared));
5150 let kron_t = Arc::clone(&kron);
5151 let p_dim = p;
5152 sys.set_row_htbeta_operator(
5153 move |row_idx, x, out| {
5154 // out = L_i · (J_β · x). Allocate a length-p scratch buffer
5155 // for the intermediate decoded-output vector; both factors
5156 // overwrite their output buffers (`apply_jbeta` zeroes
5157 // before accumulating, `apply_l` writes per-row), so no
5158 // pre-zeroing of `u_p`/`out` is needed.
5159 let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5160 let mut u_p = vec![0.0_f64; p_dim];
5161 if let Some(xs) = x.as_slice() {
5162 kron.apply_jbeta(row_idx, xs, &mut u_p);
5163 } else {
5164 let x_vec: Vec<f64> = x.iter().copied().collect();
5165 kron.apply_jbeta(row_idx, &x_vec, &mut u_p);
5166 }
5167 kron.apply_l(row_idx, &u_p, out_slice);
5168 },
5169 move |row_idx, v, out| {
5170 // out += J_βᵀ · (Lᵀ · v). `apply_l_t` accumulates into a
5171 // zero-initialised length-p buffer to produce the p-vector
5172 // `Lᵀ v`; `scatter_jbeta_t` then adds φ_i[s] · u_p[j] into
5173 // the length-K β accumulator at each active `(s, j)`.
5174 let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5175 let mut u_p = vec![0.0_f64; p_dim];
5176 if let Some(vs) = v.as_slice() {
5177 kron_t.apply_l_t(row_idx, vs, &mut u_p);
5178 } else {
5179 let v_vec: Vec<f64> = v.iter().copied().collect();
5180 kron_t.apply_l_t(row_idx, &v_vec, &mut u_p);
5181 }
5182 kron_t.scatter_jbeta_t(row_idx, &u_p, out_slice);
5183 },
5184 );
5185 }
5186 let mut beta_penalty_assembly = SaeBetaPenaltyAssembly::default();
5187 let factored_row_projection = if frames_engaged && analytic_penalties.is_some() {
5188 Some(&frame_projection)
5189 } else {
5190 None
5191 };
5192 if let Some(registry) = analytic_penalties {
5193 // Upfront validation: refuse penalty kinds the SAE row layout
5194 // cannot host, and refuse mixed-d row-block configurations.
5195 // This makes the dispatch loop below total — no runtime
5196 // "unsupported penalty" fallthrough, no K-gating.
5197 self.validate_analytic_penalty_registry(registry)
5198 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5199 beta_penalty_assembly = self
5200 .add_sae_analytic_penalty_contributions(
5201 &mut sys,
5202 registry,
5203 penalty_scale,
5204 row_layout.as_ref(),
5205 dense_beta_curvature,
5206 factored_row_projection,
5207 )
5208 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5209 }
5210 // #1026 — decoder repulsion (collinearity-gated, registry-independent):
5211 // accumulate into the full-`B` β-tier here, BEFORE the frame transform,
5212 // so a framed system carries it identically to the analytic β penalties.
5213 // No-op unless two atoms are near-collinear (the frozen gate is `None`).
5214 if self.add_sae_decoder_repulsion(&mut sys, penalty_scale, dense_beta_curvature) {
5215 beta_penalty_assembly.record_curvature(dense_beta_curvature);
5216 }
5217 // #1026/#1522 — interior-point collapse-prevention barriers. The amplitude
5218 // barrier supplies the OUTWARD radial force at the zero-decoder collapse
5219 // point (the principal failure state the threshold repulsion skips), and
5220 // the separation barrier supplies the alignment-divergent separating
5221 // curvature on normalized shapes weighted by coactivation. Both accumulate
5222 // into the full-`B` β-tier here, BEFORE the frame transform, so a framed
5223 // system carries them identically to the analytic β penalties.
5224 // #1610 — on the dense path the barrier's Levenberg majorizer scatters
5225 // onto `sys.hbb`; on the matrix-free / framed production path `sys.hbb` is
5226 // unused, so the barrier hands back a per-atom scalar ridge which we fold
5227 // into `smooth_scaled_s` (the single source for the CPU composite penalty
5228 // op AND the device smooth blocks), restoring the collapse-prevention
5229 // curvature the operator was silently dropping there.
5230 let mut sep_atom_curv = vec![0.0_f64; self.atoms.len()];
5231 if self.add_sae_separation_barrier(
5232 &mut sys,
5233 penalty_scale,
5234 dense_beta_curvature,
5235 &mut sep_atom_curv,
5236 ) {
5237 if dense_beta_curvature {
5238 beta_penalty_assembly.record_curvature(true);
5239 } else {
5240 // Fold the per-atom majorizer `lev_k·I_{M_k}` into the smooth
5241 // penalty factor `λ S_k`. With `⊗ I_p` (full-`B`) or `⊗ I_{r_k}`
5242 // (factored, `U_kᵀU_k = I`) this is exactly the `lev_k·I` block
5243 // diagonal the dense path writes — and it now flows through the
5244 // structured penalty op and the device smooth blocks. No
5245 // `deferred_factored` mark: the curvature is in the smooth op, not
5246 // a deferred dense block, so the device path stays engaged.
5247 for atom_idx in 0..self.atoms.len() {
5248 let c = sep_atom_curv[atom_idx];
5249 if c > 0.0 {
5250 let m = smooth_scaled_s[atom_idx].nrows();
5251 for i in 0..m {
5252 smooth_scaled_s[atom_idx][[i, i]] += c;
5253 }
5254 smooth_ops[atom_idx] = Arc::new(IdentityRightKroneckerPenaltyOp {
5255 factor_a: smooth_scaled_s[atom_idx].clone(),
5256 p,
5257 global_offset: beta_offsets[atom_idx],
5258 k: beta_dim,
5259 });
5260 }
5261 }
5262 }
5263 }
5264 if frames_engaged {
5265 // ── #972 / #977 T1 — FACTORED β-tier transform ──────────────────
5266 //
5267 // The entire β-tier above was assembled in the full-`B` (p-wide)
5268 // layout: `sys.gb` is `g_B` (length `beta_dim`), `sys.hbb` carries
5269 // any analytic Beta-tier penalty, and `g_blocks` is the
5270 // FRAME-INDEPENDENT basis Gram. We now rebuild the β-tier in the
5271 // factored coordinate space `C` (width `factored_border_dim`), the
5272 // full-`B` system sandwiched by `Φ = blkdiag(I_{M_k} ⊗ U_k)`:
5273 // * gradient `g_C = Φᵀ g_B` (per atom `(g_B U_k)`),
5274 // * data H `Φᵀ(G⊗I_p)Φ = G_{ij}⊗(U_iᵀU_j)`,
5275 // * smooth `λ S_k ⊗ I_{r_k}` (since `U_kᵀU_k = I`),
5276 // * analytic `Φᵀ hbb Φ` (dense, only if written).
5277 // Un-framed atoms ride the `r_k = p, U_k = I_p` identity special case.
5278 let off_c = &frame_projection.border_offsets;
5279 let ranks = &frame_projection.ranks;
5280 let basis_sizes = &frame_projection.basis_sizes;
5281 let border_dim = frame_projection.border_dim();
5282 let gb_c = frame_projection.project_border_vec(sys.gb.view());
5283
5284 // Data β-Hessian: `G_{ij} ⊗ W_{ij}` with `W_{ij} = U_iᵀU_j`. The
5285 // basis Gram `g_blocks` is unchanged; only the output factor is the
5286 // per-pair frame overlap (`I_{r_k}` within a framed atom, `I_p` for
5287 // un-framed).
5288 let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::with_capacity(g_blocks.len());
5289 for ((atom_i, atom_j), data) in g_blocks.into_iter() {
5290 if data.iter().all(|&v| v == 0.0) {
5291 continue;
5292 }
5293 // `W_{ij} = U_iᵀ U_j` from the precomputed per-atom frames.
5294 let w = self.frame_cross_factor(atom_i, atom_j);
5295 frame_blocks.push(FactoredFrameGBlock {
5296 atom_i,
5297 atom_j,
5298 g: data,
5299 w,
5300 });
5301 }
5302 // #1017/#1026 — snapshot the factored data-fit blocks for the
5303 // frames-engaged device PCG BEFORE `FactoredFrameKroneckerOp::new`
5304 // consumes them. Cheap clone (co-occurring blocks only).
5305 let device_frame_blocks = frame_blocks.clone();
5306 let data_op =
5307 FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks)?;
5308
5309 // Smooth penalty in factored space: `λ S_k ⊗ I_{r_k}` at `off_C[k]`.
5310 let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len() + 2);
5311 for k in 0..self.atoms.len() {
5312 let r = ranks[k];
5313 ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
5314 factor_a: smooth_scaled_s[k].clone(),
5315 p: r,
5316 global_offset: off_c[k],
5317 k: border_dim,
5318 }));
5319 }
5320 ops.push(Arc::new(data_op));
5321 // Analytic Beta-tier penalty: project the dense full-`B` `hbb` block
5322 // `Φᵀ hbb Φ` into the factored space. Only present when a Beta-tier
5323 // penalty actually wrote `hbb` (else `hbb` is all-zero and the dense
5324 // `(border_dim)²` op is skipped entirely, exactly as full-`B`).
5325 if beta_penalty_assembly.dense_written {
5326 let hbb_c =
5327 self.project_dense_penalty_to_factored(sys.hbb.view(), &frame_projection);
5328 ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5329 } else if beta_penalty_assembly.deferred_factored {
5330 // Registry Beta-tier curvature deferred to factored-space probing.
5331 // The registry may be absent when `deferred_factored` was set ONLY
5332 // by the frozen-gate decoder repulsion (which is
5333 // registry-independent), so start from a zero factored block in
5334 // that case instead of unwrapping.
5335 let mut hbb_c = match analytic_penalties {
5336 Some(registry) => self.build_factored_beta_penalty_curvature(
5337 registry,
5338 penalty_scale,
5339 &frame_projection,
5340 ),
5341 None => Array2::<f64>::zeros((
5342 frame_projection.border_dim(),
5343 frame_projection.border_dim(),
5344 )),
5345 };
5346 // #1610 — the frozen-gate decoder repulsion's PSD majorizer was
5347 // dropped on this matrix-free/framed path (only its gradient was
5348 // applied). Project it into the factored block via the same
5349 // `psd_majorizer_hvp` + frame-projection probe pattern the registry
5350 // DecoderIncoherence uses, so the collapse-prevention curvature
5351 // reaches the operator here too. No-op when no repulsion is active.
5352 self.add_factored_repulsion_curvature(
5353 &mut hbb_c,
5354 penalty_scale,
5355 &frame_projection,
5356 );
5357 ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5358 }
5359
5360 // Re-point the system's β-tier to the factored width. The t-tier
5361 // (per-row `htt`, `gt`) is frame-independent and untouched; row
5362 // cross-block slabs were allocated and assembled directly in
5363 // factored coordinates, so analytic row supplements and data-fit
5364 // cross terms already share shape `(q_i × factored_border_dim)`.
5365 sys.k = border_dim;
5366 sys.gb = gb_c;
5367 self.reclaim_border_hbb_workspace(&mut sys);
5368 // Factored per-atom block ranges for the block-Jacobi Schur
5369 // preconditioner: `[off_C[k] .. off_C[k] + M_k·r_k]`.
5370 let mut block_ranges: Vec<std::ops::Range<usize>> =
5371 Vec::with_capacity(self.atoms.len());
5372 for k in 0..self.atoms.len() {
5373 let start = off_c[k];
5374 block_ranges.push(start..start + basis_sizes[k] * ranks[k]);
5375 }
5376 sys.set_block_offsets(Arc::from(block_ranges.into_boxed_slice()));
5377 sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: border_dim, ops }));
5378 // #1017/#1026 — install the frames-engaged device SAE PCG data. Skipped
5379 // (CPU fallback) when a dense analytic Beta-tier penalty fired (the
5380 // device kernel does not model that extra dense term). Builder:
5381 // `crate::frames::build_framed_device_sae_data`.
5382 let has_dense_beta_penalty =
5383 beta_penalty_assembly.dense_written || beta_penalty_assembly.deferred_factored;
5384 if !has_dense_beta_penalty {
5385 let device = crate::frames::build_framed_device_sae_data(
5386 crate::frames::FramedDeviceArgs {
5387 p,
5388 border_dim,
5389 border_offsets: off_c.as_slice(),
5390 ranks: ranks.as_slice(),
5391 basis_sizes: basis_sizes.as_slice(),
5392 smooth_scaled_s: &smooth_scaled_s,
5393 frame_blocks: device_frame_blocks,
5394 rows: &sys.rows,
5395 },
5396 );
5397 sys.set_device_sae_pcg_data(device);
5398 }
5399 } else {
5400 let (device_a_phi, device_local_jac) =
5401 device_rows.expect("full-beta SAE PCG rows are cloned before row operator install");
5402 // Wire per-atom β block ranges so the Jacobi preconditioner builds one
5403 // dense Schur sub-block per atom (block-Jacobi) instead of scalar-diagonal
5404 // inversion. Each atom's decoder coefficients form a natural block:
5405 // `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
5406 sys.set_block_offsets(self.beta_block_offsets());
5407 // Install the composite BetaPenaltyOp (#296): smoothness contributions
5408 // via per-atom KroneckerPenaltyOp (avoid dense K×K materialisation), the
5409 // data-fit Gauss-Newton β-Hessian as the structured `G ⊗ I_p`
5410 // SparseBlockKroneckerPenaltyOp (block-sparse over co-occurring
5411 // `(atom, atom')` pairs, block-diagonal across the `p` output channels,
5412 // identical per channel), plus — only when a Beta-tier analytic penalty
5413 // was written — the dense `sys.hbb` residual contribution. When no beta
5414 // penalty fired, `sys.hbb` is all-zero and the dense `(K·p)²` operator
5415 // is skipped entirely. The sparse data op tracks only the active-atom
5416 // couplings, so its storage and matvec cost scale with `k_active`, not
5417 // `K`, at `K = 100K`.
5418 // Convert the per-atom-pair coupling map into `SparseGBlock`s keyed
5419 // by μ-space offsets. Empty blocks (no co-occurrence) are simply
5420 // absent from the map.
5421 let g_sparse_blocks: Vec<SparseGBlock> = g_blocks
5422 .into_iter()
5423 .filter_map(|((atom_i, atom_j), data)| {
5424 if data.iter().all(|&v| v == 0.0) {
5425 None
5426 } else {
5427 Some(SparseGBlock {
5428 row_off: mu_offsets[atom_i],
5429 col_off: mu_offsets[atom_j],
5430 data,
5431 })
5432 }
5433 })
5434 .collect();
5435 let device_smooth_blocks = smooth_scaled_s
5436 .iter()
5437 .enumerate()
5438 .map(|(atom_idx, factor_a)| {
5439 // #1117 — rank deficiency is removed at the basis layer, so the
5440 // device PCG smooth block is just `λ S_k ⊗ I_p` (full-rank
5441 // design); no data-null deflation is folded in here.
5442 DeviceSaeSmoothBlock {
5443 global_offset: beta_offsets[atom_idx],
5444 factor_a: factor_a.clone(),
5445 }
5446 })
5447 .collect();
5448 sys.set_device_sae_pcg_data(DeviceSaePcgData {
5449 p,
5450 beta_dim,
5451 a_phi: device_a_phi,
5452 local_jac: device_local_jac,
5453 smooth_blocks: device_smooth_blocks,
5454 sparse_g_blocks: g_sparse_blocks.clone(),
5455 frame: None,
5456 });
5457 let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = smooth_ops;
5458 ops.push(Arc::new(SparseBlockKroneckerPenaltyOp {
5459 p,
5460 dim_a: m_total,
5461 k: beta_dim,
5462 blocks: g_sparse_blocks,
5463 }));
5464 if beta_penalty_assembly.dense_written {
5465 ops.push(Arc::new(DensePenaltyOp(sys.hbb.clone())));
5466 }
5467 sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: beta_dim, ops }));
5468 self.reclaim_border_hbb_workspace(&mut sys);
5469 }
5470 if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
5471 sys.set_row_gauge_deflation(deflation);
5472 }
5473 // #1038 IBP cross-row Woodbury source. The exact IBP Hessian has the
5474 // per-column rank-one cross-row block `H_(i,k),(j,k) = w·s'_k·z'_ik·z'_jk`
5475 // (for ALL `i,j`, including the `i=j` self term) that couples DISTINCT
5476 // latent rows through the shared empirical mass `M_k = Σ_i z_ik`. The
5477 // assembled row-block-diagonal `htt` already carries the `i=j` self term
5478 // `w·s'_k·z'_ik²` — it is the first summand of `assignment_hdiag`'s
5479 // `hessian_diag` value `w·(score_derivative·z_jac² + score·c_ik)` written
5480 // at the logit diagonal above. So the consumer (`solver::arrow_schur`,
5481 // #1038 `IbpCrossRowSource`/`CrossRowWoodbury`) DOWNDATES exactly
5482 // `Σ_k d_k·z'_ik²` (`self_term_downdate`) to recover the NO-SELF base
5483 // `H₀'`, then re-adds the FULL rank-one `U D Uᵀ` via the determinant
5484 // lemma — so value, the evidence log-determinant, and the θ/ρ-adjoint all
5485 // differentiate the SAME `H_full = H₀' + U D Uᵀ`.
5486 //
5487 // The source is built from the SAME `ibp_assignment_third_channels`
5488 // operator the #1006 θ-adjoint consumes:
5489 // * `d[k] = cross_row_d[k] = w·s'_k = w·score_derivative_k` (the column
5490 // `D`-coefficient — NOT sign-definite, hence the consumer's
5491 // indefinite-capacitance LU);
5492 // * `entries[(i,k)] = (global_t_index, k, z'_ik)` with `z'_ik =
5493 // z_jac[i·K + k]`. For the DENSE layout (`assignment_coord_dim() = K`,
5494 // `last_row_layout = None`) atom `k`'s logit slot is local position `k`
5495 // of row `i`'s block, so `global_t_index = sys.row_offsets[i] + k`. For
5496 // the COMPACT layout (#1420) only the row's active atoms are
5497 // coordinates and atom `k` lives at local position `pos` of
5498 // `active_atoms[row]`, so `global_t_index = sys.row_offsets[i] + pos`.
5499 // Both pin the `U`-column convention bit-for-bit to the consumer's
5500 // `ibp_logit_sites`/`row_vars_for_cache_row` slot mapping.
5501 if let Some(channels) = ibp_assignment_third_channels(&self.assignment, rho)? {
5502 let mut entries: Vec<(usize, usize, f64)> = Vec::with_capacity(n * k_atoms);
5503 for row in 0..n {
5504 let start = row * k_atoms;
5505 let g_base = sys.row_offsets[row];
5506 match row_layout.as_ref() {
5507 // #1420: compact layout — the local logit slot `pos` (not the
5508 // global atom index `k`) is the t-coordinate. Atom `k`'s logit
5509 // lives at local position `pos` of `active_atoms[row]`, so emit
5510 // `(g_base + pos, atom, z_jac[row·K + atom])` for the active set
5511 // only. Using `g_base + k` would attach atom `k`'s derivative to
5512 // the wrong slot (and run out of range for compact rows),
5513 // violating the `IbpCrossRowSource` contract.
5514 Some(layout) => {
5515 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
5516 let z_prime = channels.z_jac[start + atom];
5517 entries.push((g_base + pos, atom, z_prime));
5518 }
5519 }
5520 // Dense layout: atom `k`'s logit slot is local position `k`.
5521 None => {
5522 for k in 0..k_atoms {
5523 let z_prime = channels.z_jac[start + k];
5524 entries.push((g_base + k, k, z_prime));
5525 }
5526 }
5527 }
5528 }
5529 let source = IbpCrossRowSource {
5530 r: k_atoms,
5531 d: channels.cross_row_d.clone(),
5532 entries,
5533 };
5534 sys.set_ibp_cross_row_source(source);
5535 }
5536 // Store the active-set layout for `apply_newton_step`.
5537 self.last_row_layout = row_layout;
5538 // Record whether `delta_beta` from this system is a factored ΔC (needs a
5539 // frame lift) or a full-`B` ΔB. Read by `apply_newton_step_impl`.
5540 self.last_frames_active = frames_engaged;
5541 Ok(sys)
5542 }
5543
5544 /// Project a dense full-`B` Beta-tier penalty Hessian `hbb` (`beta_dim ×
5545 /// beta_dim`, the analytic `∂²P/∂B∂B` block) into the factored coordinate
5546 /// space `Φᵀ hbb Φ` (`border_dim × border_dim`) for the #972 / #977 T1
5547 /// frame transform. `Φ = blkdiag(I_{M_k} ⊗ U_k)` maps C-space → B-space, so
5548 /// the projected block contracts both index legs through the per-atom frames.
5549 ///
5550 /// The projection is done in two passes to stay `O(beta_dim · border_dim +
5551 /// border_dim²)` instead of forming the dense `Φ` explicitly: first
5552 /// `T = hbb · Φ` (right multiply, columns fold `U`), then `Φᵀ · T` (left
5553 /// multiply, rows fold `U`). Analytic Beta-tier penalties are rare and small,
5554 /// so this only fires when one is actually installed.
5555 pub(crate) fn project_dense_penalty_to_factored(
5556 &self,
5557 hbb: ArrayView2<'_, f64>,
5558 projection: &FrameProjection,
5559 ) -> Array2<f64> {
5560 projection.project_block(hbb)
5561 }
5562
5563 pub(crate) fn build_factored_beta_penalty_curvature(
5564 &self,
5565 registry: &AnalyticPenaltyRegistry,
5566 penalty_scale: f64,
5567 projection: &FrameProjection,
5568 ) -> Array2<f64> {
5569 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
5570 let layout = registry.rho_layout();
5571 let target_beta = self.flatten_beta();
5572 let mut hbb_c = Array2::<f64>::zeros((projection.border_dim(), projection.border_dim()));
5573 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
5574 if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
5575 continue;
5576 }
5577 let rho_local = rho_global.slice(s![rho_slice.clone()]);
5578 match tier {
5579 PenaltyTier::Psi if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) => {
5580 self.add_factored_beta_penalty_curvature_for_penalty(
5581 &mut hbb_c,
5582 penalty,
5583 target_beta.view(),
5584 rho_local,
5585 penalty_scale,
5586 projection,
5587 );
5588 }
5589 PenaltyTier::Beta => {
5590 self.add_factored_beta_penalty_curvature_for_penalty(
5591 &mut hbb_c,
5592 penalty,
5593 target_beta.view(),
5594 rho_local,
5595 penalty_scale,
5596 projection,
5597 );
5598 }
5599 _ => {}
5600 }
5601 }
5602 hbb_c
5603 }
5604
5605 pub(crate) fn add_factored_beta_penalty_curvature_for_penalty(
5606 &self,
5607 hbb_c: &mut Array2<f64>,
5608 penalty: &AnalyticPenaltyKind,
5609 target_beta: ArrayView1<'_, f64>,
5610 rho_local: ArrayView1<'_, f64>,
5611 penalty_scale: f64,
5612 projection: &FrameProjection,
5613 ) {
5614 let p = self.output_dim();
5615 if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
5616 let Some(per_fit) = self.live_decoder_incoherence_penalty(base) else {
5617 return;
5618 };
5619 let beta_dim = self.beta_dim();
5620 let mut probe = Array1::<f64>::zeros(beta_dim);
5621 for k in 0..self.atoms.len() {
5622 for basis_col in 0..projection.basis_sizes[k] {
5623 for frame_col in 0..projection.ranks[k] {
5624 probe.fill(0.0);
5625 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5626 let col = projection.border_offsets[k]
5627 + basis_col * projection.ranks[k]
5628 + frame_col;
5629 let hv = per_fit.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5630 projection
5631 .project_border_vec(hv.view())
5632 .iter()
5633 .enumerate()
5634 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5635 }
5636 }
5637 }
5638 return;
5639 }
5640 if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
5641 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
5642 let atom_idx = projection
5643 .beta_offsets
5644 .iter()
5645 .position(|&offset| offset == start)
5646 .expect("live mechanism-sparsity offset must match an SAE atom");
5647 let block_len = end - start;
5648 let mut local_penalty = per_atom.clone();
5649 local_penalty.target = PsiSlice {
5650 range: 0..block_len,
5651 latent_dim: Some(projection.basis_sizes[atom_idx]),
5652 };
5653 let block = target_beta.slice(s![start..end]);
5654 let mut probe = Array1::<f64>::zeros(block_len);
5655 for basis_col in 0..projection.basis_sizes[atom_idx] {
5656 for frame_col in 0..projection.ranks[atom_idx] {
5657 probe.fill(0.0);
5658 projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5659 let col = projection.border_offsets[atom_idx]
5660 + basis_col * projection.ranks[atom_idx]
5661 + frame_col;
5662 let hv = local_penalty.psd_majorizer_hvp(block, rho_local, probe.view());
5663 projection.project_local_atom_vec_into(
5664 atom_idx,
5665 hv.view(),
5666 hbb_c.column_mut(col),
5667 penalty_scale,
5668 );
5669 }
5670 }
5671 }
5672 return;
5673 }
5674 if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
5675 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
5676 let atom_idx = projection
5677 .beta_offsets
5678 .iter()
5679 .position(|&offset| offset == start)
5680 .expect("live nuclear-norm offset must match an SAE atom");
5681 let block = target_beta.slice(s![start..end]);
5682 let block_len = end - start;
5683 let mut probe = Array1::<f64>::zeros(block_len);
5684 for basis_col in 0..projection.basis_sizes[atom_idx] {
5685 for frame_col in 0..projection.ranks[atom_idx] {
5686 probe.fill(0.0);
5687 projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5688 let col = projection.border_offsets[atom_idx]
5689 + basis_col * projection.ranks[atom_idx]
5690 + frame_col;
5691 let hv = per_atom.psd_majorizer_hvp(block, rho_local, probe.view());
5692 projection.project_local_atom_vec_into(
5693 atom_idx,
5694 hv.view(),
5695 hbb_c.column_mut(col),
5696 penalty_scale,
5697 );
5698 }
5699 }
5700 }
5701 return;
5702 }
5703 let beta_dim = self.beta_dim();
5704 let mut probe = Array1::<f64>::zeros(beta_dim);
5705 for k in 0..self.atoms.len() {
5706 for basis_col in 0..projection.basis_sizes[k] {
5707 for frame_col in 0..projection.ranks[k] {
5708 probe.fill(0.0);
5709 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5710 let col =
5711 projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5712 let hv = penalty.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5713 projection
5714 .project_border_vec(hv.view())
5715 .iter()
5716 .enumerate()
5717 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5718 }
5719 }
5720 }
5721 assert_eq!(p, self.output_dim());
5722 }
5723
5724 /// #1610 — project the frozen-gate decoder-repulsion PSD majorizer into the
5725 /// factored β block `hbb_c`. Mirrors the `DecoderIncoherence` arm of
5726 /// [`Self::add_factored_beta_penalty_curvature_for_penalty`] but sources the
5727 /// penalty from [`Self::live_decoder_repulsion_penalty`] (registry-independent,
5728 /// collinearity-gated), so the repulsion curvature reaches the operator on the
5729 /// matrix-free/framed path where the dense `sys.hbb` write is unused. No-op
5730 /// when no repulsion is active.
5731 pub(crate) fn add_factored_repulsion_curvature(
5732 &self,
5733 hbb_c: &mut Array2<f64>,
5734 penalty_scale: f64,
5735 projection: &FrameProjection,
5736 ) {
5737 let Some(per_fit) = self.live_decoder_repulsion_penalty() else {
5738 return;
5739 };
5740 let beta_dim = self.beta_dim();
5741 let target_beta = self.flatten_beta();
5742 // The repulsion penalty is non-learnable; its strength is already folded
5743 // into the frozen gate (see `live_decoder_repulsion_penalty`), so the rho
5744 // slice is empty/inert.
5745 let rho_local = Array1::<f64>::zeros(0);
5746 let mut probe = Array1::<f64>::zeros(beta_dim);
5747 for k in 0..self.atoms.len() {
5748 for basis_col in 0..projection.basis_sizes[k] {
5749 for frame_col in 0..projection.ranks[k] {
5750 probe.fill(0.0);
5751 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5752 let col =
5753 projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5754 let hv =
5755 per_fit.psd_majorizer_hvp(target_beta.view(), rho_local.view(), probe.view());
5756 projection
5757 .project_border_vec(hv.view())
5758 .iter()
5759 .enumerate()
5760 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5761 }
5762 }
5763 }
5764 }
5765
5766 pub(crate) fn ext_coord_matrix(&self) -> Array2<f64> {
5767 let n = self.n_obs();
5768 let q = self.assignment.row_block_dim();
5769 let flat = self.assignment.flatten_ext_coords();
5770 let mut out = Array2::<f64>::zeros((n, q));
5771 for row in 0..n {
5772 for col in 0..q {
5773 out[[row, col]] = flat[row * q + col];
5774 }
5775 }
5776 out
5777 }
5778
5779 pub(crate) fn ext_coord_manifold(&self) -> LatentManifold {
5780 let mut parts = Vec::with_capacity(self.assignment.row_block_dim());
5781 for _ in 0..self.assignment.assignment_coord_dim() {
5782 parts.push(LatentManifold::Euclidean);
5783 }
5784 let mut any_constrained = false;
5785 for coord in &self.assignment.coords {
5786 if coord.manifold().is_euclidean() {
5787 for _ in 0..coord.latent_dim() {
5788 parts.push(LatentManifold::Euclidean);
5789 }
5790 } else {
5791 any_constrained = true;
5792 parts.push(coord.manifold().clone());
5793 }
5794 }
5795 if any_constrained {
5796 LatentManifold::Product(parts)
5797 } else {
5798 LatentManifold::Euclidean
5799 }
5800 }
5801
5802 pub(crate) fn apply_sae_riemannian_geometry(&self, sys: &mut ArrowSchurSystem) {
5803 let manifold = self.ext_coord_manifold();
5804 if manifold.is_euclidean() {
5805 return;
5806 }
5807 let ext = self.ext_coord_matrix();
5808 let latent =
5809 LatentCoordValues::from_matrix_with_manifold(ext.view(), LatentIdMode::None, manifold);
5810 sys.apply_riemannian_latent_geometry(&latent);
5811 }
5812
5813 /// Build the compact-layout ext-coord product manifold and point for one row.
5814 ///
5815 /// The dense `ext_coord_manifold()` is keyed to the full-`q` block ordering
5816 /// `[assignment parts (all Euclidean for IBP-MAP / JumpReLU), then per-atom
5817 /// coord blocks in atom order]`. A compact active-set row instead lays its
5818 /// `q_active` columns out as `[one Euclidean logit slot per active atom,
5819 /// then each active atom's coord block in `active` order]` (see
5820 /// [`SaeRowLayout::from_active_atoms`] / `coord_starts`). To reuse the exact
5821 /// per-row Riemannian projector on the compact block we rebuild a product
5822 /// manifold and the matching ext-coord point in that compact order: the
5823 /// `active.len()` logit slots are `Euclidean` (the assignment channel is
5824 /// always Euclidean for the modes that engage sparsity — `assignment_coord_dim
5825 /// == k_atoms`), and each active atom contributes its own coordinate
5826 /// manifold. On the shared active support this is byte-identical to slicing
5827 /// the dense full-`q` product manifold, so the compact projection matches the
5828 /// dense path exactly — it only drops the inactive atoms' (negligible-mass)
5829 /// coordinate blocks the compact layout already excludes from curvature.
5830 ///
5831 /// Returns `(manifold, t_compact)` where `t_compact` has length `q_active`.
5832 /// The logit-slot entries of `t_compact` are filled from the row logits (the
5833 /// Euclidean projector ignores the point, so any finite value is equivalent;
5834 /// using the true logits keeps the point well-defined and finite).
5835 pub(crate) fn compact_row_ext_manifold_and_point(
5836 &self,
5837 row: usize,
5838 layout: &SaeRowLayout,
5839 ) -> (LatentManifold, Array1<f64>) {
5840 let active = &layout.active_atoms[row];
5841 let q_active = layout.row_q_active(row);
5842 let mut parts: Vec<LatentManifold> = Vec::with_capacity(active.len() + active.len());
5843 let mut point = Array1::<f64>::zeros(q_active);
5844 // Logit slots: one Euclidean part per active atom, in `active` order.
5845 let logits_row = self.assignment.logits.row(row);
5846 for (j, &k) in active.iter().enumerate() {
5847 parts.push(LatentManifold::Euclidean);
5848 point[j] = logits_row[k];
5849 }
5850 // Coordinate blocks: each active atom's coordinate manifold + point, at
5851 // the compact coord start the layout assigned it.
5852 for (j, &k) in active.iter().enumerate() {
5853 let coord = &self.assignment.coords[k];
5854 let d = coord.latent_dim();
5855 let coord_start = layout.coord_starts[row][j];
5856 let manifold_k = coord.manifold();
5857 // A `d`-dim coordinate whose manifold is a product (e.g. a torus =
5858 // Circle×Circle) already carries its `d` parts; a scalar manifold is
5859 // one part. Either way the manifold's ambient width must equal `d`,
5860 // matching the `d` compact columns at `coord_start`.
5861 parts.push(manifold_k.clone());
5862 let coord_point = coord.row(row);
5863 for axis in 0..d {
5864 point[coord_start + axis] = coord_point[axis];
5865 }
5866 }
5867 (LatentManifold::Product(parts), point)
5868 }
5869
5870 /// Numerical rank of a symmetric matrix: the count of eigenvalues
5871 /// exceeding `tol · max_eig`, with `tol = 1e-9` (the conventional
5872 /// relative spectral cutoff used elsewhere in the codebase).
5873 ///
5874 /// Used to count the penalised dimension of each atom's `smooth_penalty`
5875 /// `S_k` so the REML criterion's `−½·p·rank(S)·log λ_smooth` Occam term
5876 /// uses the *effective* penalty rank rather than the ambient basis size
5877 /// (a thin-plate / B-spline penalty has a non-trivial null space).
5878 pub(crate) fn symmetric_rank(s: &Array2<f64>) -> Result<usize, String> {
5879 if s.nrows() != s.ncols() {
5880 return Err(format!(
5881 "SaeManifoldTerm::symmetric_rank: matrix must be square, got {}x{}",
5882 s.nrows(),
5883 s.ncols()
5884 ));
5885 }
5886 let m = s.ncols();
5887 if m == 0 {
5888 return Ok(0);
5889 }
5890 // Symmetrize defensively through the shared ndarray helper. The SAE
5891 // rank cutoff is intentionally local to the SAE evidence contract; only
5892 // the symmetric cleanup is shared with the other construction modules.
5893 let mut sym = s.clone();
5894 gam_linalg::matrix::symmetrize_in_place(&mut sym);
5895 let (evals, _evecs) = sym
5896 .eigh(Side::Lower)
5897 .map_err(|e| format!("SaeManifoldTerm::symmetric_rank: eigh failed: {e}"))?;
5898 let max_eig = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v));
5899 if !(max_eig > 0.0) {
5900 return Ok(0);
5901 }
5902 let tol = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * max_eig;
5903 Ok(evals.iter().filter(|&&v| v > tol).count())
5904 }
5905
5906 /// Penalised quasi-Laplace evidence score for the SAE term at a FIXED ρ.
5907 ///
5908 /// #1421: this is NOT a true normalized-prior REML/evidence objective. The
5909 /// assignment priors (softmax entropy, JumpReLU) have NO finite normalizer:
5910 /// for softmax the reference-logit chart sends `P(ℓ)→0` as a free logit →±∞
5911 /// so `∫ e^{−λP} dℓ = ∞`, and JumpReLU's bounded penalty `0<P<λ` keeps
5912 /// `e^{−λP}` bounded below over an unbounded domain, also divergent. There is
5913 /// therefore no ρ-independent assignment-prior normalizer that can be dropped
5914 /// as a constant. The smoothing-penalty `−½log|λS|_+` term IS a genuine
5915 /// (proper-Gaussian) REML normalizer and is kept exactly; the rest is a
5916 /// penalized quasi-Laplace score (Laplace curvature term `½log|H|` around the
5917 /// inner optimum), which the engine minimizes over ρ.
5918 ///
5919 /// Runs the inner `(t, β)` arrow-Schur Newton solve to convergence at the
5920 /// supplied ρ (with NO in-loop ARD update — ρ is owned by the engine),
5921 /// then forms the Laplace/REML cost
5922 ///
5923 /// ```text
5924 /// V(ρ) = ℓ_pen(t̂, β̂; ρ) + ½ log|H(t̂, β̂; ρ)|
5925 /// − ½ · p · (Σ_k rank S_k) · log λ_smooth
5926 /// ```
5927 ///
5928 /// where `ℓ_pen = loss.total()` is the penalised objective at the inner
5929 /// optimum and `½ log|H|` is the Laplace normaliser. `H` is the joint
5930 /// `(t, β)` Hessian assembled by the arrow-Schur system; its `H_tt` block
5931 /// carries `α = exp(log_ard)` on its diagonal, so as α grows `½ log|H|`
5932 /// rises while the `−½·n·log α` already inside `loss.ard` falls — their
5933 /// balance IS the effective-dof term that the deleted `α = n/‖t‖²` rule
5934 /// dropped, which is why the criterion needs no clamp to stay finite on a
5935 /// collapsing axis.
5936 ///
5937 /// The final `−½·p·rank(S)·log λ_smooth` term is the smoothing-penalty
5938 /// normaliser `−½ log|λ S|_+` restricted to its ρ-dependent part: `S_k` is
5939 /// shared across all `p` decoder output channels (the `⊗ I_p` Kronecker
5940 /// structure), so `log|λ S|_+ = p·rank(S)·log λ + p·log|S|_+`, and the
5941 /// `½ p·log|S|_+` piece is ρ-independent. The ρ-independent additive
5942 /// constants that ARE dropped here (they shift `V` by a constant and do not
5943 /// affect the ρ-argmin) are the `2π` Laplace constant and the base
5944 /// `½ p·log|S|_+` penalty logdet. #1421: NO assignment-prior normalizer is
5945 /// dropped, because none exists (softmax/JumpReLU priors are improper — see
5946 /// the doc on this function): the quasi-Laplace score simply omits a
5947 /// normalizer that is not a finite constant.
5948 ///
5949 /// Returns `(V, loss)` so the engine can both rank ρ and surface the inner
5950 /// loss breakdown.
5951 pub fn reml_criterion(
5952 &mut self,
5953 target: ArrayView2<'_, f64>,
5954 rho: &SaeManifoldRho,
5955 registry: Option<&AnalyticPenaltyRegistry>,
5956 inner_max_iter: usize,
5957 learning_rate: f64,
5958 ridge_ext_coord: f64,
5959 ridge_beta: f64,
5960 ) -> Result<(f64, SaeManifoldLoss), String> {
5961 self.reml_criterion_with_refine_policy(
5962 target,
5963 rho,
5964 registry,
5965 inner_max_iter,
5966 learning_rate,
5967 ridge_ext_coord,
5968 ridge_beta,
5969 true,
5970 )
5971 }
5972
5973 pub(crate) fn reml_criterion_with_refine_policy(
5974 &mut self,
5975 target: ArrayView2<'_, f64>,
5976 rho: &SaeManifoldRho,
5977 registry: Option<&AnalyticPenaltyRegistry>,
5978 inner_max_iter: usize,
5979 learning_rate: f64,
5980 ridge_ext_coord: f64,
5981 ridge_beta: f64,
5982 refine_progress_extension: bool,
5983 ) -> Result<(f64, SaeManifoldLoss), String> {
5984 let plan = self.streaming_plan().admitted_or_error(
5985 self.n_obs(),
5986 self.output_dim(),
5987 self.k_atoms(),
5988 )?;
5989 if plan.streaming {
5990 // #1225: streaming and dense MUST optimize the SAME mathematical
5991 // objective — the full REML criterion `loss.total() + extra_penalty +
5992 // ½ log|H| − Occam`. The streaming branch previously returned only
5993 // `loss.total() + extra_penalty_energy`, dropping the Laplace
5994 // normalizer `½ log|H|` and the Occam term, so large shapes (exactly
5995 // where streaming is needed) were ranked by penalized loss rather than
5996 // REML — and dense vs streaming disagreed on the objective. Route
5997 // through the streaming exact-logdet path, which assembles the same
5998 // chunk-by-chunk-bit-identical `½ log|H|_stream` and the same
5999 // `−Occam`/extra-penalty terms as the dense `reml_criterion_with_cache`
6000 // (different memory strategy, same objective).
6001 self.reml_criterion_streaming_exact(
6002 target,
6003 rho,
6004 registry,
6005 inner_max_iter,
6006 learning_rate,
6007 ridge_ext_coord,
6008 ridge_beta,
6009 )
6010 } else {
6011 let (v, loss, _cache) = self.reml_criterion_with_cache_refine_policy(
6012 target,
6013 rho,
6014 registry,
6015 inner_max_iter,
6016 learning_rate,
6017 ridge_ext_coord,
6018 ridge_beta,
6019 refine_progress_extension,
6020 )?;
6021 Ok((v, loss))
6022 }
6023 }
6024
6025 /// As [`Self::reml_criterion`], but also returns the converged undamped
6026 /// `ArrowFactorCache` so callers (the EFS fixed-point step) can read the
6027 /// selected-inverse traces `(H⁻¹)_tt` / `(H⁻¹)_ββ` without re-factoring.
6028 /// The cache is the single shared O(K³) Direct factor; both the
6029 /// log-determinant criterion and the Fellner-Schall ρ-step consume it.
6030 pub fn reml_criterion_with_cache(
6031 &mut self,
6032 target: ArrayView2<'_, f64>,
6033 rho: &SaeManifoldRho,
6034 registry: Option<&AnalyticPenaltyRegistry>,
6035 inner_max_iter: usize,
6036 learning_rate: f64,
6037 ridge_ext_coord: f64,
6038 ridge_beta: f64,
6039 ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6040 self.reml_criterion_with_cache_refine_policy(
6041 target,
6042 rho,
6043 registry,
6044 inner_max_iter,
6045 learning_rate,
6046 ridge_ext_coord,
6047 ridge_beta,
6048 true,
6049 )
6050 }
6051
6052 pub(crate) fn reml_criterion_with_cache_refine_policy(
6053 &mut self,
6054 target: ArrayView2<'_, f64>,
6055 rho: &SaeManifoldRho,
6056 registry: Option<&AnalyticPenaltyRegistry>,
6057 inner_max_iter: usize,
6058 learning_rate: f64,
6059 ridge_ext_coord: f64,
6060 ridge_beta: f64,
6061 refine_progress_extension: bool,
6062 ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6063 let admission_plan = self.streaming_plan().admitted_or_error(
6064 self.n_obs(),
6065 self.output_dim(),
6066 self.k_atoms(),
6067 )?;
6068 if !admission_plan.direct_logdet_admitted() {
6069 return Err(format!(
6070 "SaeManifoldTerm::reml_criterion_with_cache: predicted working set {} bytes exceeds budget {} bytes for dense evidence cache; shape n={},p={},K={}; cost-only streaming route is required",
6071 admission_plan.estimated_direct_peak_bytes,
6072 admission_plan.in_core_budget_bytes,
6073 self.n_obs(),
6074 self.output_dim(),
6075 self.k_atoms()
6076 ));
6077 }
6078 // 1. Run the inner (t, β) Newton solve to convergence at FIXED ρ.
6079 // `run_joint_fit_arrow_schur` no longer touches ρ.
6080 let mut rho_fixed = rho.clone();
6081 let mut loss = self.run_joint_fit_arrow_schur(
6082 target,
6083 &mut rho_fixed,
6084 registry,
6085 inner_max_iter,
6086 learning_rate,
6087 ridge_ext_coord,
6088 ridge_beta,
6089 )?;
6090
6091 // 2. Drive the inner (t, β) solve to the KKT/step-converged optimum and
6092 // take one final UNDAMPED factor there to obtain the joint Hessian
6093 // log-determinant. We force ridge = 0 and the dense `Direct` Schur
6094 // mode so `arrow_log_det_from_cache` returns the exact
6095 // `log|H| = Σ_i log|H_tt^(i)| + log|Schur_β|` (it rejects damped
6096 // factors and InexactPCG caches, which have no dense Schur factor).
6097 // This is the same evidence convention the main GAM REML path uses.
6098 // The shared `converge_inner_for_undamped_logdet` driver guarantees
6099 // the per-row `H_tt^(i)` blocks are PD at the converged optimum so
6100 // the undamped (`ridge = 0`) factorization succeeds — the streaming
6101 // log-det path reuses the identical driver so both rank the same
6102 // converged Laplace optimum and stay bit-identical.
6103 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
6104 let cache = self.converge_inner_for_undamped_logdet(
6105 target,
6106 rho,
6107 &mut rho_fixed,
6108 registry,
6109 inner_max_iter,
6110 learning_rate,
6111 ridge_ext_coord,
6112 ridge_beta,
6113 &mut loss,
6114 &options,
6115 refine_progress_extension,
6116 )?;
6117 self.record_evidence_gauge_deflation_count(cache.gauge_deflated_directions)?;
6118 loss.evidence_gauge_deflated_directions = cache.gauge_deflated_directions;
6119 let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
6120 // Distinguish a GENUINE infeasibility — a probed ρ where the joint
6121 // Hessian is not PD so the Laplace evidence log-det is undefined —
6122 // from a real factorization defect. The cross-row IBP Woodbury
6123 // capacitance `C = I_R + D·Uᵀ H₀'⁻¹ U` can have det ≤ 0 at a ρ the
6124 // outer optimizer line-searches into (the indefinite basin adjacent
6125 // to the PD region); there the log-det legitimately does not exist.
6126 // That refusal must be RECOVERABLE (the outer BFGS should get +∞ and
6127 // steer back into the PD region), exactly like the "non-PD per-row
6128 // H_tt block" refusal — not a fatal `RemlOptimizationFailed` that
6129 // aborts the whole fit. See `is_recoverable_value_probe_refusal`.
6130 // (The old message claimed "no dense Schur factor", which is false
6131 // here — the Schur factor is present; the Woodbury correction is the
6132 // non-finite term.)
6133 if cache.cross_row_woodbury.is_some()
6134 && !cache.cross_row_woodbury_log_det().is_finite()
6135 {
6136 "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
6137 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
6138 .to_string()
6139 } else {
6140 "SaeManifoldTerm::reml_criterion: arrow_log_det_from_cache returned None \
6141 (undamped joint Hessian log-det unavailable for the Laplace normaliser)"
6142 .to_string()
6143 }
6144 })?;
6145
6146 // 3. Smoothing-penalty Occam term `−½·Σ_k r_k·rank(S_k)·log λ_smooth`
6147 // plus the profiled-frame evidence-dimension correction
6148 // `+½·Σ_k r_k·(p−r_k)·log λ_smooth` (issue #972). On the full-`B` path
6149 // (`r_k == p`, no frames) this is exactly the historical
6150 // `½·p·(Σ rank S_k)·log λ_smooth`, so the small-model criterion is
6151 // unchanged. The single seam is `reml_occam_term`, shared with the
6152 // streaming path so both rank the identical Laplace dimension count.
6153 let occam = self.reml_occam_term(rho)?;
6154
6155 // Decoder-block analytic-penalty energy (#671/#672). The inner solve
6156 // descended this energy (it enters `gb`/`hbb`) but it had no native
6157 // `loss.*` representative, so the Laplace criterion `v` was scoring a
6158 // different objective than the one minimized. Add the converged
6159 // decoder-penalty value so the ρ-sweep ranks the same penalized
6160 // deviance. Excludes the Psi-tier ARD/assignment penalties already
6161 // accounted for in `loss.total()` (see
6162 // `analytic_decoder_penalty_value_total`).
6163 // Extra analytic-penalty energy (#671/#737). Decoder-block penalties and
6164 // coordinate-tier isometry enter the inner solve but have no `loss.*`
6165 // representative, so the Laplace criterion must add them explicitly to
6166 // rank the same penalized deviance the Newton solve descends.
6167 let extra_penalty_energy = match registry {
6168 Some(reg) => self
6169 .reml_extra_penalty_value_total(reg)
6170 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?,
6171 None => 0.0,
6172 };
6173
6174 let v = loss.total() + extra_penalty_energy + 0.5 * log_det - occam;
6175 Ok((v, loss, cache))
6176 }
6177
6178 /// The #1037 quotient-dimension invariant: a Laplace normalizer `½log|H|` is
6179 /// only comparable across ρ at a COMMON quotient (gauge-deflation) dimension.
6180 /// The first observation pins the expected count; a later match is a no-op.
6181 ///
6182 /// A later observation that DIFFERS is, under the K>1 fit, a LEGITIMATE
6183 /// quotient-dimension event — an atom born, reseeded (the #976 collapse
6184 /// guards), or rank-reduced moves the number of gauge-flat rows. Because a
6185 /// deflated direction is lifted to unit stiffness and contributes the
6186 /// ρ-independent `log 1 = 0` to the evidence, re-anchoring the comparison to
6187 /// the new dimension is exactly evidence-preserving and keeps every future
6188 /// cross-ρ comparison consistent — the principled response, not an abort.
6189 ///
6190 /// The genuine pathology the guard still catches is a count that NEVER
6191 /// STABILIZES: re-anchors are bounded by the per-atom structural-event budget
6192 /// (`k·(reseed_budget+1)+1`), and a runaway quotient dimension past that
6193 /// bound refuses loudly. This supersedes the prior strict-constant guard and
6194 /// its ±1 flicker band (#1117) at root — the band was masking exactly the
6195 /// legitimate K>1 dimension changes this re-anchoring now handles.
6196 pub(crate) fn record_evidence_gauge_deflation_count(
6197 &mut self,
6198 count: usize,
6199 ) -> Result<(), String> {
6200 match self.expected_evidence_gauge_deflated_directions {
6201 Some(expected) if expected == count => Ok(()),
6202 Some(expected) => {
6203 // A change in the gauge-deflation count between two evidence
6204 // factorizations is a legitimate quotient-dimension event under
6205 // the K>1 fit: an atom can be born, reseeded (the #976 collapse
6206 // guards), or rank-reduced across the ρ-walk, and each such event
6207 // moves the number of gauge-flat rows. The #1037 invariant is
6208 // NOT "the count never changes" — it is "two Laplace normalizers
6209 // are only comparable at a COMMON quotient dimension". The
6210 // principled response to a legitimate change is therefore to
6211 // RE-ANCHOR the comparison to the new dimension (so every future
6212 // cross-ρ comparison within the optimization is consistent), not
6213 // to abort the fit. This is exactly evidence-preserving: each
6214 // gauge-deflated direction is lifted to unit stiffness and
6215 // contributes the ρ-independent `log 1 = 0` to `½log|H|`, so the
6216 // converged criterion value is identical whether a given row is
6217 // counted as deflated or not — only the BOOKKEEPING dimension
6218 // must agree across a comparison, and re-anchoring restores that.
6219 //
6220 // The genuine pathology the guard must still catch is a count
6221 // that NEVER STABILIZES — an OSCILLATING quotient dimension that
6222 // re-anchors without converging, signalling a truly ill-posed
6223 // evidence surface. But the deflation count is NOT a discrete
6224 // dictionary-level event count: it is the per-ROW-summed number of
6225 // near-null evidence directions across all N rows (#1217). On real
6226 // K≥2 activations it is an O(N) quantity that drifts SMOOTHLY and
6227 // monotonically as the conditioning improves over the ρ-walk
6228 // (e.g. 171→156→…→113 as smoothing increases) — a benign,
6229 // evidence-neutral change (each deflated direction contributes the
6230 // ρ-independent `log 1 = 0` to `½log|H|`, so re-anchoring never
6231 // moves the criterion value). Charging such a monotone drift
6232 // against a `k`-sized "structural event" budget was wrong: it
6233 // counts threshold crossings of a continuous per-row quantity, not
6234 // atom births/reseeds, so the budget tripped on a perfectly healthy
6235 // converging K=2 fit (#1217 regression from the #1189/#1190
6236 // basin-escape fixes, which shifted which rows sit near the
6237 // deflation floor).
6238 //
6239 // The principled discriminator is DIRECTION REVERSALS: a count
6240 // that drifts one way and settles is benign; a count that bounces
6241 // up and down without settling is the oscillating-quotient
6242 // pathology. We therefore charge the re-anchor budget ONLY on a
6243 // reversal of the change direction, and size the budget by the
6244 // number of distinct dictionary structural events (births/reseeds)
6245 // that can each legitimately flip the drift direction. A monotone
6246 // drift of any length re-anchors freely (it is consistently
6247 // re-anchored and evidence-neutral); a genuinely oscillating count
6248 // exhausts the reversal budget and refuses loudly.
6249 let delta_sign: i8 = if count > expected { 1 } else { -1 };
6250 let is_reversal = self.evidence_gauge_deflation_last_delta_sign != 0
6251 && delta_sign != self.evidence_gauge_deflation_last_delta_sign;
6252 self.evidence_gauge_deflation_last_delta_sign = delta_sign;
6253 // A reversal alone is NOT the pathology — a BOUNDED flicker of a
6254 // few rows crossing the near-null deflation floor reverses
6255 // direction every step yet is the discretization jitter of a
6256 // continuous evidence spectrum, fully evidence-neutral (each
6257 // deflated direction contributes `log 1 = 0` either way). The
6258 // genuine "quotient dimension not stabilizing" pathology is a
6259 // WIDE-amplitude oscillation: a substantial FRACTION of the
6260 // dimension flipping back and forth. The count is an O(N) per-row
6261 // sum, so the discriminator must be the reversal AMPLITUDE
6262 // relative to the dimension level, not the bare reversal. Charge
6263 // the reversal budget only when a reversal's step exceeds a
6264 // relative jitter band; a converged-but-flickering fit (e.g.
6265 // 150<->147 on N=200, ~2% of the level) re-anchors freely while a
6266 // true runaway (e.g. 9<->2, ~80% of the level) still trips every
6267 // reversal and exhausts the budget. This was the second #795 root
6268 // cause: the single-planted-circle fit's per-row count flickers
6269 // 150<->147 near the deflation floor, so the bare-reversal guard
6270 // refused the simplest possible fit — with the isometry gauge ON
6271 // *or* OFF — long before the gauge magnitude mattered.
6272 let amplitude = expected.abs_diff(count);
6273 let level = expected.max(count);
6274 let jitter_band = (level / 4).max(2);
6275 if is_reversal && amplitude > jitter_band {
6276 self.evidence_gauge_deflation_reanchors += 1;
6277 }
6278 let reversal_budget = self
6279 .k_atoms()
6280 .saturating_mul(
6281 SAE_ATOM_COLLAPSE_RESEED_BUDGET
6282 + SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET
6283 + 1,
6284 )
6285 .saturating_add(1);
6286 if self.evidence_gauge_deflation_reanchors > reversal_budget {
6287 return Err(format!(
6288 "SaeManifoldTerm::reml_criterion: row-gauge evidence deflation count \
6289 oscillated (reversed direction {} times, last {expected}->{count}) within \
6290 one optimization, exceeding the {reversal_budget}-reversal budget for {} \
6291 atoms; the quotient dimension is not stabilizing, refusing to compare \
6292 Laplace normalizers",
6293 self.evidence_gauge_deflation_reanchors,
6294 self.k_atoms()
6295 ));
6296 }
6297 log::debug!(
6298 "SaeManifoldTerm::reml_criterion: per-row evidence deflation count changed \
6299 {expected}->{count} (a benign per-row conditioning drift across the ρ-walk; \
6300 reversal {}/{reversal_budget}); re-anchoring the Laplace normalizer comparison \
6301 to the new dimension",
6302 self.evidence_gauge_deflation_reanchors
6303 );
6304 self.expected_evidence_gauge_deflated_directions = Some(count);
6305 Ok(())
6306 }
6307 None => {
6308 self.expected_evidence_gauge_deflated_directions = Some(count);
6309 Ok(())
6310 }
6311 }
6312 }
6313
6314 pub(crate) fn is_undamped_evidence_row_non_pd(err: &ArrowSchurError) -> bool {
6315 matches!(
6316 err,
6317 ArrowSchurError::PerRowFactorFailed { reason, .. }
6318 if reason.contains("H_tt is non-PD at base ridge")
6319 && reason.contains("evidence mode preserves the genuine Cholesky")
6320 )
6321 }
6322
6323 /// Drive the inner `(t, β)` Newton solve to the KKT/step-converged optimum
6324 /// and return the final UNDAMPED (`ridge = 0`) joint-Hessian factor cache.
6325 ///
6326 /// The Laplace normaliser `½log|H|` is only the correct REML criterion at
6327 /// the inner optimum `(t̂, β̂)`, so the criterion must refine the inner state
6328 /// until either the KKT gradient or the undamped Newton step meets tolerance
6329 /// before factoring. Crucially, **at the converged optimum the per-row
6330 /// `H_tt^(i)` blocks are PD**, so the undamped (`ridge = 0`) factorization
6331 /// succeeds; an off-optimum iterate (e.g. the initial seed, or a state
6332 /// stopped after only `inner_max_iter` steps) can have an indefinite /
6333 /// rank-deficient per-row block (`p_out = 1` → rank-1 `JᵀJ`, softmax
6334 /// assignment-sparsity negative logit curvature) that surfaces
6335 /// `PerRowFactorFailed` from the undamped `factor_one_row`. Both the dense
6336 /// (`reml_criterion_with_cache`) and the streaming
6337 /// (`reml_criterion_streaming_exact`) evidence paths route through this same
6338 /// driver, so they converge to the identical inner state and their
6339 /// `ridge = 0` log-determinants stay bit-identical (#847).
6340 pub(crate) fn converge_inner_for_undamped_logdet(
6341 &mut self,
6342 target: ArrayView2<'_, f64>,
6343 rho: &SaeManifoldRho,
6344 rho_fixed: &mut SaeManifoldRho,
6345 registry: Option<&AnalyticPenaltyRegistry>,
6346 inner_max_iter: usize,
6347 learning_rate: f64,
6348 ridge_ext_coord: f64,
6349 ridge_beta: f64,
6350 loss: &mut SaeManifoldLoss,
6351 options: &ArrowSolveOptions,
6352 refine_progress_extension: bool,
6353 ) -> Result<ArrowFactorCache, String> {
6354 // `inner_max_iter == 0` is a genuine FREEZE of the inner `(t, β)` state
6355 // — a verbatim warm-start reuse, not a convergence request (gam#577/#579,
6356 // #850). The convergence/refinement loop below MUST NOT run even one
6357 // Newton step in that case (the old `inner_max_iter.max(1)` floor moved
6358 // β off the seed), so we factor exactly once at the frozen iterate and
6359 // return that undamped cache without invoking the stationarity gate.
6360 // The caller has already run `run_joint_fit_arrow_schur(..., 0, ...)`,
6361 // which under the `max_iter == 0` freeze (gam#577/#579, #850) runs ONLY
6362 // the β-neutral basis refresh and returns the loss without touching β —
6363 // it skips the rank-reduction, frame activation, re-seed guards, and the
6364 // #1026 decoder-LSQ polish that would otherwise refit β off the seed — so
6365 // `self` is at the warm-start β here.
6366 if inner_max_iter == 0 {
6367 let sys = self
6368 .assemble_arrow_schur(target, rho, registry)
6369 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6370 let factored = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options)
6371 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6372 // The frozen-state Newton step (factored.0, factored.1) is discarded
6373 // — only the undamped factor cache (factored.2) is consumed for the
6374 // log-det / selected-inverse traces; β stays at the warm-start seed.
6375 return Ok(factored.2);
6376 }
6377 let mut total_inner_iter = inner_max_iter;
6378 let accepted_base_refine_iter = inner_max_iter.max(1).saturating_mul(16).max(64);
6379 let value_probe_base_refine_iter = inner_max_iter.max(1).saturating_mul(4).max(16);
6380 let base_refine_iter = if refine_progress_extension {
6381 accepted_base_refine_iter
6382 } else {
6383 value_probe_base_refine_iter
6384 };
6385 let progress_refine_iter = if refine_progress_extension {
6386 inner_max_iter.max(1).saturating_mul(64).max(256)
6387 } else {
6388 base_refine_iter
6389 };
6390 let mut previous_refine_grad_norm: Option<f64> = None;
6391 let mut saw_refine_progress = false;
6392 // #1051 — objective-stagnation convergence. On an ill-conditioned
6393 // penalised bilinear fit (the euclidean / Duchon decoder × latent
6394 // coordinate system on a trivial shape), the inner Newton crawls: each
6395 // refine round lowers the penalised objective by a shrinking amount while
6396 // the KKT gradient and the undamped step stay above their relative
6397 // tolerances (the near-singular Schur amplifies the step in the
6398 // weakly-identified decoder direction). The grad-OR-step gate then never
6399 // fires and the solve is rejected as "did not converge" — the 1e12
6400 // sentinel. A Newton/LM iterate whose objective has stopped decreasing
6401 // to within `√εmach` of its scale IS the numerical inner optimum; ranking
6402 // the Laplace criterion there is correct. We accept that fixed point
6403 // instead of grinding the budget.
6404 let entry_loss_total = loss.total();
6405 let mut previous_loss_total = entry_loss_total;
6406 let mut refine_rounds: usize = 0;
6407 // Consecutive stall rounds: counts how many successive refine rounds
6408 // ended in a stall AND a failed undamped factor. Once this reaches
6409 // `SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS` the iterate is at
6410 // its numerical fixed point and cannot be improved further; returning
6411 // `Err` here is the same "did not converge" signal that
6412 // `is_recoverable_value_probe_refusal` already handles, so the outer
6413 // BFGS treats it as an INFINITY probe and tries a different ρ instead
6414 // of looping forever burning the extended progress budget. Without
6415 // this counter the stagnation handler fell through when the undamped
6416 // factor failed and the loop kept extending via `saw_refine_progress`
6417 // from earlier rounds, accumulating minutes of wasted work (#1094).
6418 let mut consecutive_stall_factor_fail: usize = 0;
6419 loop {
6420 let sys = self
6421 .assemble_arrow_schur(target, rho, registry)
6422 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6423 // Evidence-only factorization: the Newton step (Δt, Δβ) is discarded
6424 // and only the factor cache is consumed — the exact undamped log-det
6425 // and the selected-inverse traces. As ρ sweeps to extremes (e.g. a
6426 // wide ARD-α sweep), H_tt is genuinely PD but can be ill-conditioned;
6427 // the standard Direct guard rejects that to protect Newton-step
6428 // accuracy, but the log-det is exact from diag(L) regardless of the
6429 // condition number and the traces only need the (PD) factor. So
6430 // tolerate the ill-conditioning rejection here (a genuine non-PD pivot
6431 // still errors). The cache stays undamped at ridge=0, so
6432 // `arrow_log_det_from_cache` remains exact.
6433 // The exact KKT stationarity residual is the joint gradient
6434 // ‖g‖ = √(Σ_i ‖g_t^(i)‖² + ‖g_β‖²), read straight off the assembled
6435 // system. Unlike the Newton step Δ = H⁻¹g, the gradient is
6436 // factorisation-independent: it is NOT amplified by an inverse, so a
6437 // genuinely stationary but ill-conditioned fit (tiny g, possibly large
6438 // Δ in a flat direction) is correctly recognised as converged. The
6439 // `with_ill_conditioning_tolerated` Direct factor below documents that
6440 // its Δ may be inaccurate in exactly those flat directions, so using Δ
6441 // alone as the convergence gate would falsely reject healthy fits.
6442 let grad_norm_sq: f64 = sys
6443 .rows
6444 .iter()
6445 .map(|row| row.gt.iter().map(|&v| v * v).sum::<f64>())
6446 .sum::<f64>()
6447 + sys.gb.iter().map(|&v| v * v).sum::<f64>();
6448 let grad_norm = grad_norm_sq.sqrt();
6449 // Quotient KKT-gradient (#1117): the raw joint gradient retains a
6450 // persistent small component in the chart-gauge orbit and the
6451 // rank-deficient decoder β-null even at a stationary fit, so the raw
6452 // grad gate never clears on a rank-deficient circle and the inner
6453 // refine loop crawls until the (large) progress budget dies — the
6454 // 2-min stall. Measure the gradient on the SAME identified quotient
6455 // the step gate already uses: a fit whose only remaining gradient
6456 // lives in those flat directions is stationary on the quotient, so
6457 // ranking the Laplace criterion there is correct. The dense per-row
6458 // g_t is laid into the `n·q` coordinate layout the gauge basis spans;
6459 // non-dense/heterogeneous systems fall back to the raw norm.
6460 let quotient_grad_norm = {
6461 let n = self.n_obs();
6462 let q = self.assignment.row_block_dim();
6463 let dense_len = n.saturating_mul(q);
6464 let mut grad_ext_coord = Array1::<f64>::zeros(dense_len);
6465 let mut dense_layout_ok = sys.rows.len() == n;
6466 if dense_layout_ok {
6467 for (row_idx, row) in sys.rows.iter().enumerate() {
6468 let base = sys.row_offsets[row_idx];
6469 let di = sys.row_dims[row_idx];
6470 if base + di > dense_len || row.gt.len() < di {
6471 dense_layout_ok = false;
6472 break;
6473 }
6474 for axis in 0..di {
6475 grad_ext_coord[base + axis] = row.gt[axis];
6476 }
6477 }
6478 }
6479 if dense_layout_ok {
6480 self.quotient_gradient_norm_sq(
6481 grad_ext_coord.view(),
6482 sys.gb.view(),
6483 grad_norm_sq,
6484 &rho_fixed.lambda_smooth_vec(),
6485 )
6486 .map(|v| v.sqrt())
6487 .unwrap_or(grad_norm)
6488 } else {
6489 grad_norm
6490 }
6491 };
6492 let iterate_scale = self.inner_iterate_scale();
6493 // Relative parameter-step tolerance for Δ (well-conditioned charts)
6494 // and a scaled KKT-gradient tolerance. Convergence is accepted on
6495 // EITHER a small KKT gradient OR a small undamped Newton step: SAE
6496 // manifold fits contain gauge-like coordinate/decoder directions (the
6497 // circle's rotation gauge, decoder column-space rotations) where the
6498 // shared-block Hessian is near-singular, so the undamped step can stay
6499 // large in that flat direction even at a genuine stationary point; the
6500 // gradient, which is not amplified by the inverse, recognises it. With
6501 // the isometry Gauss-Newton block now a coherent PSD pullback (no
6502 // indefinite Schur pivot), the inner solve reaches true stationarity,
6503 // so the gradient tolerance is a standard relative KKT residual rather
6504 // than the 0.1.154-regression band-aid (3e-3) that masked the
6505 // non-convergence the indefinite curvature caused.
6506 let step_tolerance = SAE_MANIFOLD_INNER_STEP_REL_TOL * iterate_scale;
6507 let grad_tolerance = SAE_MANIFOLD_INNER_GRAD_REL_TOL * iterate_scale;
6508 if !grad_norm_sq.is_finite() {
6509 return Err(format!(
6510 "SaeManifoldTerm::reml_criterion: undamped inner KKT residual is non-finite \
6511 at the inner optimum (‖g‖²={grad_norm_sq}); the joint Hessian \
6512 factorisation is degenerate at this ρ"
6513 ));
6514 }
6515 let (delta_t, delta_beta, cache): (Array1<f64>, Array1<f64>, ArrowFactorCache) =
6516 match solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options) {
6517 Ok(factored) => factored,
6518 Err(err) if Self::is_undamped_evidence_row_non_pd(&err) => {
6519 if grad_norm <= grad_tolerance || quotient_grad_norm <= grad_tolerance {
6520 // K>1: the softmax/IBP logit–coordinate Gauss-Newton
6521 // cross-terms (H_zt = J_z^T J_t, assembled row-locally from
6522 // the assignment JVP × basis JVP) can make a per-row H_tt
6523 // indefinite at the TRUE KKT stationary point — when two
6524 // atoms' decoders specialise in opposite directions the
6525 // Schur complement of the logit block goes negative even
6526 // though the priors and the full-joint GN term are PSD.
6527 //
6528 // The undamped evidence factor already conditions that
6529 // block the PRINCIPLED way: `factor_spectral_deflated_
6530 // evidence_row` discovers the negative/flat eigen-direction
6531 // and stiffens it to UNIT curvature (eigenvalue → +1), so it
6532 // contributes a ρ-INDEPENDENT log 1 = 0 to the evidence —
6533 // the same quotient pseudo-determinant convention the gauge
6534 // (#1037) and data-null (#1117) deflations use. Reaching
6535 // THIS arm at stationarity therefore means even the spectral
6536 // deflation declined (a non-finite block or a failed
6537 // eigendecomposition): the state is genuinely broken, so we
6538 // surface the hard refusal and let the outer BFGS treat this
6539 // ρ as an INFINITY probe (`is_recoverable_value_probe_
6540 // refusal`). We must NOT ridge-damp here: a `+ridge·I`
6541 // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6542 // bias into the VALUE that the analytic ρ-gradient (built
6543 // for the undamped Laplace log-det) never sees, desyncing
6544 // the outer line-search — the multi-atom non-convergence
6545 // this fix (#1117) removes.
6546 return Err(format!(
6547 "SaeManifoldTerm::reml_criterion: stationary undamped \
6548 evidence factorization has a non-PD per-row H_tt block \
6549 that spectral unit-stiffness deflation could not \
6550 condition (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}); \
6551 {err}"
6552 ));
6553 }
6554 let refine_limit = Self::refine_iteration_limit(
6555 total_inner_iter,
6556 base_refine_iter,
6557 progress_refine_iter,
6558 previous_refine_grad_norm,
6559 grad_norm,
6560 saw_refine_progress,
6561 );
6562 if total_inner_iter >= refine_limit {
6563 // #1117/#1118 — pre-stationarity genuinely-indefinite
6564 // non-gauge H_tt under K>1 IBP/softmax row-sharing. The
6565 // logit × coordinate Gauss-Newton cross term H_zt = J_zᵀJ_t
6566 // can drive a shared row's H_tt Schur complement NEGATIVE off
6567 // the gauge orbit; the LM-escalated refinement above cannot
6568 // always cross the indefinite basin into the PD region within
6569 // the descent-extended budget.
6570 //
6571 // The undamped (ridge=0) evidence factor already conditions
6572 // that block the PRINCIPLED way: `factor_spectral_deflated_
6573 // evidence_row` discovers the negative/flat eigen-direction
6574 // and stiffens it to UNIT curvature (eigenvalue → +1), a
6575 // ρ-INDEPENDENT log 1 = 0 evidence contribution — so the
6576 // `Ok(factored)` arm above accepts the indefinite block and
6577 // returns a finite, monotone-comparable value to the outer
6578 // BFGS WITHOUT a ρ-dependent bias. Reaching THIS arm means
6579 // even that spectral deflation declined (a non-finite block
6580 // or a failed eigendecomposition): the iterate is genuinely
6581 // broken, so we surface the hard refusal and let the outer
6582 // BFGS treat this ρ as an INFINITY probe.
6583 //
6584 // We must NOT ridge-damp here: a `+ridge·I` evidence
6585 // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6586 // bias into the VALUE that the analytic ρ-gradient (built
6587 // for the undamped Laplace log-det) never sees, desyncing
6588 // the outer line-search — the multi-atom non-convergence this
6589 // fix removes. K=1 (and any already-PD or spectral-deflatable
6590 // K>1 row) never reaches this branch.
6591 return Err(format!(
6592 "SaeManifoldTerm::reml_criterion: undamped evidence \
6593 factorization hit a non-PD per-row H_tt block before KKT \
6594 stationarity (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) \
6595 and the refinement budget was exhausted after \
6596 {total_inner_iter} inner iterations; {err}"
6597 ));
6598 }
6599 let remaining = refine_limit - total_inner_iter;
6600 let refine_iter = inner_max_iter.max(1).min(remaining);
6601 saw_refine_progress |=
6602 Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6603 previous_refine_grad_norm = Some(grad_norm);
6604 *loss = self.run_joint_fit_arrow_schur(
6605 target,
6606 rho_fixed,
6607 registry,
6608 refine_iter,
6609 learning_rate,
6610 ridge_ext_coord,
6611 ridge_beta,
6612 )?;
6613 total_inner_iter += refine_iter;
6614 continue;
6615 }
6616 Err(err) => {
6617 return Err(format!("SaeManifoldTerm::reml_criterion: {err}"));
6618 }
6619 };
6620 // The Laplace normaliser ½log|H| is only the correct REML criterion at
6621 // the inner optimum (t̂, β̂). Convergence is judged by EITHER a small
6622 // gradient (KKT stationarity) OR a small undamped Newton step; the
6623 // solve is only rejected as non-converged when BOTH are large, i.e.
6624 // the iterate is neither stationary nor about to move negligibly. That
6625 // disjunction is what keeps an ill-conditioned-but-stationary fit
6626 // (small g, large Δ) from being rejected while still refusing to rank
6627 // an off-optimum Laplace criterion that is genuinely mid-flight.
6628 let step_norm_sq: f64 = delta_t.iter().map(|&v| v * v).sum::<f64>()
6629 + delta_beta.iter().map(|&v| v * v).sum::<f64>();
6630 if !step_norm_sq.is_finite() {
6631 return Err(format!(
6632 "SaeManifoldTerm::reml_criterion: undamped inner residual is non-finite at \
6633 the inner optimum (‖Δ‖²={step_norm_sq}, ‖g‖²={grad_norm_sq}); the joint \
6634 Hessian factorisation is degenerate at this ρ"
6635 ));
6636 }
6637 let step_norm = step_norm_sq.sqrt();
6638 let quotient_step_norm_sq = self.quotient_newton_step_norm_sq(
6639 delta_t.view(),
6640 delta_beta.view(),
6641 step_norm_sq,
6642 &rho_fixed.lambda_smooth_vec(),
6643 )?;
6644 let quotient_step_norm = quotient_step_norm_sq.sqrt();
6645 // Converge on ANY of: the raw KKT gradient (well-conditioned fit),
6646 // the QUOTIENT KKT gradient (#1117 — rank-deficient fit whose only
6647 // residual gradient is gauge/null flat-direction crawl), or the
6648 // quotient Newton step. The quotient-gradient disjunct is what lets
6649 // a rank-deficient K=1 circle terminate in budget instead of crawling
6650 // the weakly-identified valley until the refine budget dies.
6651 if grad_norm <= grad_tolerance
6652 || quotient_grad_norm <= grad_tolerance
6653 || quotient_step_norm <= step_tolerance
6654 {
6655 return Ok(cache);
6656 }
6657 let refine_limit = Self::refine_iteration_limit(
6658 total_inner_iter,
6659 base_refine_iter,
6660 progress_refine_iter,
6661 previous_refine_grad_norm,
6662 grad_norm,
6663 saw_refine_progress,
6664 );
6665 if total_inner_iter >= refine_limit {
6666 // Inner solve did not converge in reml_criterion; the returned
6667 // Err below carries the full non-convergence diagnostic
6668 // (gradient / quotient-step norms and tolerances) to the caller.
6669 return Err(format!(
6670 "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6671 neither the KKT gradient ‖g‖={grad_norm:.6e} (tol {grad_tolerance:.6e}) nor \
6672 the quotient Newton step ‖Π⊥gauge Δ‖={quotient_step_norm:.6e} \
6673 (raw ‖Δ‖={step_norm:.6e}, tol {step_tolerance:.6e}) met \
6674 tolerance after {total_inner_iter} inner iterations. Refusing to rank an \
6675 off-optimum Laplace criterion."
6676 ));
6677 }
6678 let remaining = refine_limit - total_inner_iter;
6679 let refine_iter = inner_max_iter.max(1).min(remaining);
6680 saw_refine_progress |=
6681 Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6682 previous_refine_grad_norm = Some(grad_norm);
6683 *loss = self.run_joint_fit_arrow_schur(
6684 target,
6685 rho_fixed,
6686 registry,
6687 refine_iter,
6688 learning_rate,
6689 ridge_ext_coord,
6690 ridge_beta,
6691 )?;
6692 total_inner_iter += refine_iter;
6693 refine_rounds += 1;
6694 // #1051 — objective-stagnation fixed point. A whole refine round that
6695 // failed to lower the penalised objective by a meaningful FRACTION of
6696 // the total since-entry reduction means the Newton/LM iterate is at
6697 // its numerical optimum: the remaining KKT residual lives in the
6698 // weakly-identified decoder / gauge directions the near-singular Schur
6699 // cannot resolve. Ranking the Laplace criterion at this fixed point is
6700 // correct (the only further motion is cosmetic flat-valley crawl), so
6701 // accept the current cache instead of refining until the budget dies.
6702 // Requires a few completed refine rounds (so the fraction baseline is
6703 // meaningful) but is NOT gated behind the full refine budget — the
6704 // whole point is to terminate the crawl long before that.
6705 let new_loss_total = loss.total();
6706 // Two stagnation signals, both required: (1) the latest refine round
6707 // contributed a negligible FRACTION of the total objective reduction
6708 // achieved since entry — the fit has captured essentially all the
6709 // achievable improvement and is now crawling cosmetically along the
6710 // weakly-identified valley; (2) the absolute relative decrease is
6711 // itself tiny. The fraction test is scale- and rate-free (it fires
6712 // whether the crawl decays fast or slow), so it recognises the
6713 // over-smoothed / rank-deficient fixed point the bare relative floor
6714 // misses, while still never firing on a fit that is materially
6715 // improving round over round.
6716 let total_improvement = (entry_loss_total - new_loss_total).max(0.0);
6717 let round_improvement = (previous_loss_total - new_loss_total).max(0.0);
6718 let objective_scale = previous_loss_total.abs().max(new_loss_total.abs()) + 1.0;
6719 let relative_decrease = round_improvement / objective_scale;
6720 let captured_fraction = if total_improvement > 0.0 {
6721 round_improvement / total_improvement
6722 } else {
6723 0.0
6724 };
6725 let stalled = new_loss_total.is_finite()
6726 && relative_decrease.is_finite()
6727 && (relative_decrease < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_REL_TOL
6728 || captured_fraction < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_FRACTION);
6729 previous_loss_total = new_loss_total;
6730 if stalled && refine_rounds >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6731 let stationary_sys = self
6732 .assemble_arrow_schur(target, rho_fixed, registry)
6733 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6734 if let Ok((_dt, _db, stationary_cache)) =
6735 solve_arrow_newton_step_with_options(&stationary_sys, 0.0, 0.0, options)
6736 {
6737 return Ok(stationary_cache);
6738 }
6739 // Stagnated AND the undamped factor still fails: this is the
6740 // numerical fixed point of the inner solve under rank-deficient
6741 // or ill-conditioned geometry (e.g. multi-atom euclidean with
6742 // near-zero initial latent coords, #1094). The iterate cannot
6743 // be improved further at this ρ. Treat it as "inner solve did
6744 // not converge" — the same signal `is_recoverable_value_probe_refusal`
6745 // already handles, causing the outer BFGS to return INFINITY for
6746 // this ρ probe and try a different one. Without this early
6747 // return the stagnation handler fell through and the loop kept
6748 // burning the extended `progress_refine_iter` budget indefinitely.
6749 consecutive_stall_factor_fail += 1;
6750 if consecutive_stall_factor_fail >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6751 return Err(format!(
6752 "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6753 objective stalled for {consecutive_stall_factor_fail} consecutive refine \
6754 rounds (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) and the undamped \
6755 evidence factorization failed at each stall point — the iterate is at the \
6756 numerical fixed point under rank-deficient geometry (#{consecutive_stall_factor_fail} \
6757 stall-factor-fail rounds; refusing to rank an off-optimum Laplace criterion)"
6758 ));
6759 }
6760 } else {
6761 consecutive_stall_factor_fail = 0;
6762 }
6763 }
6764 }
6765
6766 pub(crate) fn refine_iteration_limit(
6767 total_inner_iter: usize,
6768 base_refine_iter: usize,
6769 progress_refine_iter: usize,
6770 previous_grad_norm: Option<f64>,
6771 grad_norm: f64,
6772 saw_refine_progress: bool,
6773 ) -> usize {
6774 // Flat affine-gauge valleys can keep crawling productively after the
6775 // historical base budget. Extend only when the measured KKT residual has
6776 // shown a real finite round-to-round drop; true stalls end at the base
6777 // work budget (#968/#1029). Value-order probes pass the base budget as
6778 // their progress budget, so this branch cannot make probes expensive.
6779 if total_inner_iter < base_refine_iter {
6780 return base_refine_iter;
6781 }
6782 let making_progress =
6783 saw_refine_progress || Self::refine_round_made_progress(previous_grad_norm, grad_norm);
6784 if making_progress && grad_norm.is_finite() {
6785 progress_refine_iter
6786 } else {
6787 base_refine_iter
6788 }
6789 }
6790
6791 pub(crate) fn refine_round_made_progress(
6792 previous_grad_norm: Option<f64>,
6793 grad_norm: f64,
6794 ) -> bool {
6795 previous_grad_norm
6796 .is_some_and(|prev| prev.is_finite() && grad_norm.is_finite() && grad_norm < prev)
6797 }
6798
6799 pub(crate) fn outer_gradient_arrow_solver<'a>(
6800 &'a self,
6801 cache: &'a ArrowFactorCache,
6802 penalized_gram_scale: &[f64],
6803 ) -> Result<DeflatedArrowSolver<'a>, OuterGradientError> {
6804 let Err(conditioning_err) = Self::outer_gradient_conditioning_error(cache) else {
6805 return Ok(DeflatedArrowSolver::plain(cache));
6806 };
6807 let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
6808 return Err(conditioning_err);
6809 };
6810 if !(max_pivot.is_finite() && max_pivot > 0.0) {
6811 return Err(conditioning_err);
6812 }
6813
6814 // The conditioning gate has already flagged a near-singular joint Hessian
6815 // (`conditioning_err`). Below we attempt to attribute that flatness to the
6816 // closed-form gauge orbit (chart step gauges) plus the penalty-aware
6817 // decoder-null directions and deflate it. When NO such deflatable
6818 // direction can be recovered, the flat subspace is genuinely
6819 // non-identifiable -- a degenerate direction OUTSIDE the gauge orbit -- a
6820 // diagnosis distinct from the raw pivot-ratio conditioning trip. Both
6821 // classes are #1273 FD-eligible, but surfacing the gauge-degenerate case
6822 // as its own [`OuterGradientError::NonIdentifiable`] keeps the diagnostic
6823 // distinction the FD-eligibility contract is built around.
6824 let non_identifiable_err = OuterGradientError::NonIdentifiable {
6825 reason: format!(
6826 "near-singular joint Hessian with no deflatable gauge/decoder-null \
6827 direction (max pivot {max_pivot:.3e})"
6828 ),
6829 };
6830
6831 let full_len = cache.delta_t_len() + cache.k;
6832 let mut raw_gauges = Vec::new();
6833 for gauge in self
6834 .dense_step_gauge_vectors()
6835 .map_err(OuterGradientError::internal)?
6836 {
6837 if gauge.len() != full_len {
6838 continue;
6839 }
6840 let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6841 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6842 continue;
6843 }
6844 raw_gauges.push(gauge);
6845 }
6846 // #1051/#1273: admit the penalty-aware decoder-β null directions as
6847 // additional deflation candidates. A rank-deficient decoder design
6848 // (e.g. a euclidean-1D line in a p=2 ambient: decoder column rank 1 of
6849 // 3) puts a genuine near-null direction of the joint Hessian in the β
6850 // block, OUTSIDE the closed-form chart gauge orbit. #1273: probing the
6851 // RAW unit-β basis `e_j` produced an INCOMPLETE candidate set — the
6852 // true flat direction is the penalised null of `G_k + λ_smooth·S_k`,
6853 // not an axis-aligned coordinate, so the outer gate rejected trial ρ
6854 // with a pivot ratio (5.3e-16 < 1e-12) that the inner gate (which
6855 // already uses `decoder_beta_null_directions(λ_smooth)`) accepts. Use
6856 // the SAME penalty-aware null directions here, evaluated at the smooth
6857 // scale the Schur factor used, so the outer and inner gates agree.
6858 // These full (n·q + beta_dim)-length vectors drop into the same
6859 // Gram-Schmidt + Rayleigh + Faddeev-Popov path below; the Rayleigh
6860 // floor still keeps only genuinely flat (sub-floor) directions, so a
6861 // well-conditioned decoder is unaffected.
6862 for dir in self
6863 .decoder_beta_null_directions(penalized_gram_scale)
6864 .map_err(OuterGradientError::internal)?
6865 {
6866 if dir.len() == full_len {
6867 raw_gauges.push(dir);
6868 }
6869 }
6870 // #1051/#1273: also admit the decoder COLUMN-SPAN null (an unrealised
6871 // ambient output channel of a rank-deficient decoder), which the
6872 // channel-free basis-null above structurally cannot represent. The
6873 // rank-1-decoder-line geometry (e.g. a 1-D euclidean line in p=2
6874 // ambient: decoder column rank 1 of 2) puts the joint Hessian's
6875 // sub-floor pivot entirely in one output channel; without this
6876 // candidate the outer gate had nothing to deflate it with and rejected
6877 // the trial ρ. The Rayleigh floor below still prunes any candidate that
6878 // is not genuinely flat against the cached Hessian.
6879 for dir in self
6880 .decoder_channel_null_directions()
6881 .map_err(OuterGradientError::internal)?
6882 {
6883 if dir.len() == full_len {
6884 raw_gauges.push(dir);
6885 }
6886 }
6887 if raw_gauges.is_empty() {
6888 return Err(non_identifiable_err);
6889 }
6890
6891 let mut gauge_span: Vec<Array1<f64>> = Vec::new();
6892 for mut gauge in raw_gauges {
6893 for basis in &gauge_span {
6894 let coeff = gauge.dot(basis);
6895 for i in 0..gauge.len() {
6896 gauge[i] -= coeff * basis[i];
6897 }
6898 }
6899 let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6900 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6901 continue;
6902 }
6903 let inv_norm = norm_sq.sqrt().recip();
6904 for value in gauge.iter_mut() {
6905 *value *= inv_norm;
6906 }
6907 gauge_span.push(gauge);
6908 }
6909 if gauge_span.is_empty() {
6910 return Err(non_identifiable_err);
6911 }
6912
6913 let span_rank = gauge_span.len();
6914 let mut h_span = Array2::<f64>::zeros((span_rank, span_rank));
6915 for col in 0..span_rank {
6916 let h_gauge = match apply_cached_arrow_hessian(
6917 cache,
6918 gauge_span[col].slice(s![..cache.delta_t_len()]),
6919 gauge_span[col].slice(s![cache.delta_t_len()..]),
6920 ) {
6921 Ok(value) => value,
6922 // #1451: a shape/dimension mismatch or non-finite intermediate
6923 // from the Hessian apply is an internal-invariant defect and MUST
6924 // propagate; only a genuine numeric failure on a finite,
6925 // correctly-shaped input keeps the FD-eligible conditioning class.
6926 Err(err) => {
6927 return Err(OuterGradientError::classify_arrow_solver_error(
6928 &err,
6929 conditioning_err.clone(),
6930 ));
6931 }
6932 };
6933 let h_flat = flatten_arrow_parts(h_gauge.t.view(), h_gauge.beta.view());
6934 for row in 0..span_rank {
6935 h_span[[row, col]] = gauge_span[row].dot(&h_flat);
6936 }
6937 }
6938 for row in 0..span_rank {
6939 for col in 0..row {
6940 let sym = 0.5 * (h_span[[row, col]] + h_span[[col, row]]);
6941 h_span[[row, col]] = sym;
6942 h_span[[col, row]] = sym;
6943 }
6944 }
6945 // #1451: a non-finite entry in the projected gauge Hessian is an
6946 // internal-invariant defect (a NaN/Inf intermediate leaked into the
6947 // span), not a conditioning failure — it MUST propagate rather than be
6948 // masked behind an FD descent. Guard finiteness BEFORE the eigh so only a
6949 // genuine decomposition failure on a finite, correctly-shaped matrix keeps
6950 // the FD-eligible conditioning class.
6951 if !h_span.iter().all(|v| v.is_finite()) {
6952 return Err(OuterGradientError::internal(format!(
6953 "outer_gradient_arrow_solver: non-finite entry in projected gauge \
6954 Hessian (h_span is {span_rank}x{span_rank})"
6955 )));
6956 }
6957 let (evals, evecs) = h_span
6958 .eigh(Side::Lower)
6959 .map_err(|_| conditioning_err.clone())?;
6960 let strict_gauge_floor = SAE_OUTER_GRADIENT_GAUGE_RAYLEIGH_FACTOR * max_pivot;
6961 let mut orthonormal: Vec<Array1<f64>> = Vec::new();
6962 for eig_idx in 0..evals.len() {
6963 let rayleigh = evals[eig_idx];
6964 if !(rayleigh.is_finite() && rayleigh <= strict_gauge_floor) {
6965 continue;
6966 }
6967 let mut direction = Array1::<f64>::zeros(full_len);
6968 for basis_idx in 0..span_rank {
6969 let coeff = evecs[[basis_idx, eig_idx]];
6970 for row in 0..full_len {
6971 direction[row] += coeff * gauge_span[basis_idx][row];
6972 }
6973 }
6974 let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
6975 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6976 continue;
6977 }
6978 let inv_norm = norm_sq.sqrt().recip();
6979 for value in direction.iter_mut() {
6980 *value *= inv_norm;
6981 }
6982 orthonormal.push(direction);
6983 }
6984 if orthonormal.is_empty() {
6985 // #1273/#1440: the conditioning gate has ALREADY certified a
6986 // near-singular joint Hessian (`conditioning_err`), so a genuine flat
6987 // direction exists inside the assembled gauge/decoder-null span even
6988 // when no projected-Hessian eigenvector cleared the strict or the
6989 // `fallback_gauge_floor` Rayleigh band. Rather than declining
6990 // (which historically routed the outer step to a finite-difference
6991 // descent direction — the FD instrument #1440 removes), deflate the
6992 // SMALLEST-Rayleigh eigenvector of the projected gauge Hessian
6993 // UNCONDITIONALLY. That eigenvector is the least-curvature member of
6994 // the validated gauge span (a Faddeev-Popov gauge candidate), so the
6995 // Tikhonov stiffness `max_pivot` in `from_orthonormal_gauges` bounds
6996 // its contribution at the Hessian scale and the components orthogonal
6997 // to it are byte-for-byte the plain analytic inverse solve. This keeps
6998 // the descent direction fully ANALYTIC (a projected/damped gradient),
6999 // never a differenced value path.
7000 let mut best_idx = None;
7001 let mut best_rayleigh = f64::INFINITY;
7002 for eig_idx in 0..evals.len() {
7003 let rayleigh = evals[eig_idx];
7004 if rayleigh.is_finite() && rayleigh < best_rayleigh {
7005 best_idx = Some(eig_idx);
7006 best_rayleigh = rayleigh;
7007 }
7008 }
7009 if let Some(eig_idx) = best_idx {
7010 let mut direction = Array1::<f64>::zeros(full_len);
7011 for basis_idx in 0..span_rank {
7012 let coeff = evecs[[basis_idx, eig_idx]];
7013 for row in 0..full_len {
7014 direction[row] += coeff * gauge_span[basis_idx][row];
7015 }
7016 }
7017 let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
7018 if norm_sq.is_finite() && norm_sq > 1.0e-24 {
7019 let inv_norm = norm_sq.sqrt().recip();
7020 for value in direction.iter_mut() {
7021 *value *= inv_norm;
7022 }
7023 orthonormal.push(direction);
7024 }
7025 }
7026 }
7027 if orthonormal.is_empty() {
7028 return Err(non_identifiable_err);
7029 }
7030
7031 // Quotient-geometry gauge fixing: add stiffness only along the closed-form
7032 // gauge orbit (Faddeev-Popov style). Components orthogonal to that orbit
7033 // are identical to the original inverse solve, while gauge components are
7034 // bounded at the Hessian scale `max_pivot`.
7035 // #1451: a shape/length mismatch or non-finite stiffness/intermediate in
7036 // the deflated-solver assembly is an internal-invariant defect and MUST
7037 // propagate; only a genuine near-singular gauge Woodbury/back-solve keeps
7038 // the FD-eligible conditioning class.
7039 DeflatedArrowSolver::from_orthonormal_gauges(cache, orthonormal, max_pivot)
7040 .map_err(|err| OuterGradientError::classify_arrow_solver_error(&err, conditioning_err))
7041 }
7042
7043 pub(crate) fn outer_gradient_conditioning_error(
7044 cache: &ArrowFactorCache,
7045 ) -> Result<(), OuterGradientError> {
7046 let pivot = arrow_factor_min_pivot(cache);
7047 let Some(min_pivot) = pivot.min_pivot else {
7048 return Err(OuterGradientError::IllConditioned {
7049 reason: "joint Hessian numerically singular (no cached Cholesky pivots)"
7050 .to_string(),
7051 });
7052 };
7053 let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
7054 return Err(OuterGradientError::IllConditioned {
7055 reason: "joint Hessian numerically singular (no cached Cholesky pivot scale)"
7056 .to_string(),
7057 });
7058 };
7059 let ratio = min_pivot / max_pivot;
7060 if min_pivot.is_finite()
7061 && max_pivot.is_finite()
7062 && max_pivot > 0.0
7063 && ratio.is_finite()
7064 && ratio >= SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR
7065 {
7066 return Ok(());
7067 }
7068 Err(OuterGradientError::IllConditioned {
7069 reason: format!(
7070 "joint Hessian numerically singular (min/max pivot ratio {ratio:.3e} < floor {floor:.3e}; min pivot {min_pivot:.3e}, max pivot {max_pivot:.3e})",
7071 floor = SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR,
7072 ),
7073 })
7074 }
7075
7076 /// Smoothing-penalty Occam normalizer `−½ Σ_k r_k·rank(S_k)·log λ_smooth`
7077 /// PLUS the profiled-frame evidence-dimension term `½ Σ_k r_k·(p−r_k)·log
7078 /// λ_smooth` (issue #972).
7079 ///
7080 /// On the full-`B` path every atom's frame rank `r_k == p`, so the first
7081 /// piece reduces to the historical `½ p·(Σ rank S_k)·log λ_smooth` and the
7082 /// Grassmann term is zero — bit-for-bit unchanged. When a frame is active the
7083 /// decoder coordinates `C_k` carry the `⊗ I_{r_k}` Kronecker structure (the
7084 /// smoothing penalty `S_k` now acts on `r_k` channels, not `p`), so the
7085 /// penalty-logdet normalizer uses `r_k·rank(S_k)`; and the `r_k·(p−r_k)`
7086 /// frame degrees of freedom profiled OUT of the border are counted explicitly
7087 /// in the Laplace dimension accounting (evidence honesty) so the criterion
7088 /// cannot buy a free evidence boost by hiding decoder freedom in the frame.
7089 pub(crate) fn reml_occam_term(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
7090 // #1556: λ_smooth is per-atom, so the Occam penalty normalizer and the
7091 // profiled-frame evidence-dimension term are both per-atom sums, each
7092 // atom `k` weighted by its own `log λ_smooth[k]`. With a uniform
7093 // (broadcast) vector this is bit-for-bit the historical global form.
7094 let mut acc = 0.0_f64;
7095 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7096 let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7097 // Penalized decoder dimension: `r_k` coordinate channels carry the
7098 // `S_k` roughness penalty (full-`B` path ⇒ `r_k == p`).
7099 let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7100 // Profiled Grassmann dimensions enter the Laplace evidence dimension
7101 // count with the OPPOSITE sign of the penalty Occam term (they are
7102 // free, unpenalized-by-`S` profiled directions), so `−occam` adds
7103 // `+½ r(p−r) log λ_k` to the criterion `V` — the honesty correction.
7104 let frame_dim = atom.frame_manifold_dimension();
7105 let log_lambda = rho.log_lambda_smooth[atom_idx];
7106 acc += 0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)) * log_lambda;
7107 }
7108 // `V = … − occam`, so the net occam SUBTRACTS the penalty normalizer and
7109 // ADDS the frame-dimension count after the caller's `− occam`.
7110 Ok(acc)
7111 }
7112
7113 /// Per-atom derivative `∂(occam)/∂log λ_smooth[k]` (#1556): atom `k`'s entry
7114 /// is `½·(r_k·rank(S_k) − frame_dim_k)`, matching the per-atom Occam term in
7115 /// [`Self::reml_occam_term`]. Returns one entry per atom in atom order.
7116 pub(crate) fn reml_occam_log_lambda_smooth_derivative(&self) -> Result<Vec<f64>, String> {
7117 let mut out = Vec::with_capacity(self.atoms.len());
7118 for atom in &self.atoms {
7119 let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7120 let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7121 let frame_dim = atom.frame_manifold_dimension();
7122 out.push(0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)));
7123 }
7124 Ok(out)
7125 }
7126
7127 pub fn reml_criterion_streaming_exact(
7128 &mut self,
7129 target: ArrayView2<'_, f64>,
7130 rho: &SaeManifoldRho,
7131 registry: Option<&AnalyticPenaltyRegistry>,
7132 inner_max_iter: usize,
7133 learning_rate: f64,
7134 ridge_ext_coord: f64,
7135 ridge_beta: f64,
7136 ) -> Result<(f64, SaeManifoldLoss), String> {
7137 let mut rho_fixed = rho.clone();
7138 let mut loss = self.run_joint_fit_arrow_schur(
7139 target,
7140 &mut rho_fixed,
7141 registry,
7142 inner_max_iter,
7143 learning_rate,
7144 ridge_ext_coord,
7145 ridge_beta,
7146 )?;
7147 // Drive the inner (t, β) state to the SAME KKT/step-converged optimum the
7148 // dense `reml_criterion_with_cache` reaches before factoring. At that
7149 // optimum the per-row `H_tt^(i)` blocks are PD, so the undamped
7150 // (`ridge_t = 0`) streaming factorization in `streaming_exact_arrow_log_det`
7151 // succeeds — without this, a state stopped after only `inner_max_iter`
7152 // steps can leave a rank-deficient / indefinite row block (`p_out = 1` →
7153 // rank-1 `JᵀJ`, softmax negative-logit curvature) that surfaces
7154 // `PerRowFactorFailed` at base ridge 0. Sharing the driver also keeps the
7155 // streaming and dense log-determinants bit-identical (#847).
7156 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7157 // The dense factor cache from convergence is surplus here — the streaming
7158 // path recomputes the (bit-identical) log-determinant chunk-by-chunk in
7159 // `streaming_exact_arrow_log_det` to bound peak memory — so it is dropped.
7160 let converged_cache = self.converge_inner_for_undamped_logdet(
7161 target,
7162 rho,
7163 &mut rho_fixed,
7164 registry,
7165 inner_max_iter,
7166 learning_rate,
7167 ridge_ext_coord,
7168 ridge_beta,
7169 &mut loss,
7170 &options,
7171 true,
7172 )?;
7173 drop(converged_cache);
7174 let log_det = self.streaming_exact_arrow_log_det(target, rho, registry)?;
7175 let occam = self.reml_occam_term(rho)?;
7176 // Extra analytic-penalty energy (#671/#737), matching the full-batch
7177 // `reml_criterion_with_cache` path so streaming and dense criteria rank
7178 // the identical penalized objective.
7179 let extra_penalty_energy = match registry {
7180 Some(reg) => self
7181 .reml_extra_penalty_value_total(reg)
7182 .map_err(|err| format!("SaeManifoldTerm::reml_criterion_streaming_exact: {err}"))?,
7183 None => 0.0,
7184 };
7185 Ok((
7186 loss.total() + extra_penalty_energy + 0.5 * log_det - occam,
7187 loss,
7188 ))
7189 }
7190
7191 pub fn streaming_exact_arrow_log_det(
7192 &mut self,
7193 target: ArrayView2<'_, f64>,
7194 rho: &SaeManifoldRho,
7195 registry: Option<&AnalyticPenaltyRegistry>,
7196 ) -> Result<f64, String> {
7197 if target.dim() != (self.n_obs(), self.output_dim()) {
7198 return Err(format!(
7199 "SaeManifoldTerm::streaming_exact_arrow_log_det: target must be ({}, {}); got {:?}",
7200 self.n_obs(),
7201 self.output_dim(),
7202 target.dim()
7203 ));
7204 }
7205 let plan = self.streaming_plan().admitted_or_error(
7206 self.n_obs(),
7207 self.output_dim(),
7208 self.k_atoms(),
7209 )?;
7210 if plan.estimated_dense_schur_bytes > plan.in_core_budget_bytes {
7211 return Err(format!(
7212 "SaeManifoldTerm::streaming_exact_arrow_log_det: predicted dense reduced Schur {} bytes exceeds budget {} bytes; cost-only matrix-free route is required",
7213 plan.estimated_dense_schur_bytes, plan.in_core_budget_bytes
7214 ));
7215 }
7216 let n_total = self.n_obs();
7217 let chunk_size = plan.chunk_size.min(n_total.max(1));
7218 // #972 / #977 T1: the reduced β-Schur is over the FACTORED border when
7219 // frames are active (each chunk inherits the frames via
7220 // `materialize_chunk`, so every `chunk_schur` is `border_dim²`), matching
7221 // the dense path's factored log-det. Full-`B` ⇒ `border_dim == beta_dim`.
7222 let border_dim = if self.frames_active() {
7223 self.factored_border_dim()
7224 } else {
7225 self.beta_dim()
7226 };
7227 let mut schur_acc = Array2::<f64>::zeros((border_dim, border_dim));
7228 let mut log_det_tt = 0.0_f64;
7229 // #1038 cross-row IBP Woodbury accumulators. `M = Uᵀ H₀'⁻¹ U` is
7230 // chunk-additive in `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` and `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ`
7231 // (`A = H₀'` block-diagonal, `U` row-supported), closed against the
7232 // GLOBAL reduced Schur `S = schur_acc` after the loop. `None` for every
7233 // non-IBP (softmax / JumpReLU) term, where the streaming log-det is
7234 // exactly the bare `log_det_tt + log_det_schur` as before.
7235 let mut wood_m0: Option<Array2<f64>> = None;
7236 let mut wood_w: Option<Array2<f64>> = None;
7237 let mut wood_d: Option<Array1<f64>> = None;
7238 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7239 let mut start = 0usize;
7240 while start < n_total {
7241 let end = (start + chunk_size).min(n_total);
7242 let penalty_scale = (end - start) as f64 / n_total as f64;
7243 let chunk_logits = self.assignment.logits.slice(s![start..end, ..]).to_owned();
7244 let chunk_coords: Vec<Array2<f64>> = self
7245 .assignment
7246 .coords
7247 .iter()
7248 .map(|coord| coord.as_matrix().slice(s![start..end, ..]).to_owned())
7249 .collect();
7250 let mut chunk = self.materialize_chunk(chunk_logits, chunk_coords)?;
7251 // #1117 — rank deficiency is removed at the basis layer at fit entry
7252 // (`reduce_atoms_to_data_supported_rank`), so each chunk inherits the
7253 // already-reduced full-rank atoms via `materialize_chunk`; there are
7254 // no global deflation projectors to propagate.
7255 // #991: chunk terms inherit the row's design honesty weight slice
7256 // (global mean-1 normalization preserved — NOT re-normalized per
7257 // chunk — so the per-chunk sums reconstruct the global weighted
7258 // objective exactly).
7259 if let Some(w) = self.row_loss_weights.as_deref() {
7260 chunk.row_loss_weights = Some(w[start..end].to_vec());
7261 }
7262 let z_chunk = target.slice(s![start..end, ..]);
7263 let sys = chunk
7264 .assemble_arrow_schur_scaled(z_chunk, rho, registry, penalty_scale)
7265 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7266 let mut streaming = StreamingArrowSchur::from_system(&sys, sys.rows.len().max(1));
7267 let (chunk_log_det_tt, chunk_schur, chunk_wood) = streaming
7268 .reduced_schur_log_det_tt_woodbury(0.0, 0.0, &options)
7269 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7270 log_det_tt += chunk_log_det_tt;
7271 for row in 0..border_dim {
7272 for col in 0..border_dim {
7273 schur_acc[[row, col]] += chunk_schur[[row, col]];
7274 }
7275 }
7276 if chunk_wood.is_some() && chunk_size < n_total {
7277 // The cross-row IBP empirical mass `M_k = Σ_i z_ik` couples ALL
7278 // rows, so the per-row `H₀'` diagonal (`score_derivative_k(M_k)`)
7279 // and the column coefficient `d_k = w·s'_k(M_k)` are only exact
7280 // when every row is assembled together — a SINGLE chunk. Under a
7281 // genuine multi-chunk pass each chunk would see a partial mass and
7282 // the Woodbury (and the bare per-row log-det) would be inexact, so
7283 // refuse loudly and route to the dense resident path rather than
7284 // return a silently-wrong evidence. The streaming log-det only
7285 // runs when the dense reduced Schur fits budget, so the single-
7286 // chunk regime is the common case; this guards the rest.
7287 return Err(
7288 "SaeManifoldTerm::streaming_exact_arrow_log_det: exact cross-row IBP \
7289 Woodbury evidence requires a single-chunk pass (the empirical mass \
7290 M_k = Σ_i z_ik couples all rows); this shape needs >1 chunk. Route \
7291 IBP-active large-n fits through the dense resident \
7292 ArrowFactorCache::arrow_log_det."
7293 .to_string(),
7294 );
7295 }
7296 if let Some(cw) = chunk_wood {
7297 wood_m0 = Some(match wood_m0.take() {
7298 Some(mut acc) => {
7299 acc += &cw.m0;
7300 acc
7301 }
7302 None => cw.m0,
7303 });
7304 wood_w = Some(match wood_w.take() {
7305 Some(mut acc) => {
7306 acc += &cw.w;
7307 acc
7308 }
7309 None => cw.w,
7310 });
7311 // `D = diag(d_k)` is per-atom; identical across chunks for a
7312 // single-chunk evidence pass (the regime the streaming log-det
7313 // runs in — the dense reduced Schur must fit budget here), where
7314 // it equals the global mass-derived `cross_row_d`.
7315 wood_d = Some(cw.d);
7316 }
7317 start = end;
7318 }
7319 let log_det_schur = StreamingArrowSchur::reduced_schur_log_det(&schur_acc, &options)
7320 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7321 let mut total = log_det_tt + log_det_schur;
7322 // #1038/#1225: close the exact cross-row IBP Woodbury correction
7323 // `log det(I_R + D Uᵀ H₀'⁻¹ U)` so the streaming evidence equals the
7324 // dense `arrow_log_det_from_cache` (which adds the SAME term). Without
7325 // it the streaming criterion would silently drop the entire cross-row
7326 // coupling and disagree with the dense path by exactly `log|C|`.
7327 if let (Some(m0), Some(w), Some(d)) = (wood_m0, wood_w, wood_d) {
7328 let correction = streaming_cross_row_woodbury_log_det(&schur_acc, &m0, &w, &d)
7329 .map_err(|err| {
7330 format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}")
7331 })?
7332 .ok_or_else(|| {
7333 "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
7334 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
7335 .to_string()
7336 })?;
7337 total += correction;
7338 }
7339 Ok(total)
7340 }
7341
7342 /// Per-atom, per-axis coordinate sum-of-squares `‖t_kj‖² = Σ_i t_{i,k,j}²`.
7343 ///
7344 /// This is the data-fit sufficient statistic for the ARD precision update
7345 /// (the numerator-side `‖t‖²` of the deleted `α = n/‖t‖²` rule). Returned
7346 /// per atom as an `Array1` of length `d_k`.
7347 ///
7348 /// On a *periodic* (Circle) axis the relevant statistic is the von-Mises
7349 /// energy-equivalent `Σ_i 2/α·V(t_i) = Σ_i (2/κ²)(1−cos κ t_i)` (independent
7350 /// of α), so that `½·α·sumsq == Σ_i V(t_i)` matches `ard_value`. This keeps
7351 /// the Mackay/Fellner–Schall fixed point `α ← n / (sumsq + tr H⁻¹)`
7352 /// consistent with the actual periodic prior energy rather than the
7353 /// origin-dependent raw `t²`.
7354 pub(crate) fn ard_coord_sumsq(&self) -> Vec<Array1<f64>> {
7355 let mut out = Vec::with_capacity(self.k_atoms());
7356 for coord in &self.assignment.coords {
7357 let d = coord.latent_dim();
7358 let periods = coord.effective_axis_periods();
7359 let mut sq = Array1::<f64>::zeros(d);
7360 for row in 0..coord.n_obs() {
7361 let t = coord.row(row);
7362 for axis in 0..d {
7363 // `sq_equiv` is independent of `alpha`; pass 1.0.
7364 sq[axis] += ArdAxisPrior::eval(1.0, t[axis], periods[axis]).sq_equiv;
7365 }
7366 }
7367 out.push(sq);
7368 }
7369 out
7370 }
7371
7372 /// Per-atom, per-axis posterior-variance trace `tr_kj(H⁻¹) =
7373 /// Σ_i [(H⁻¹)_tt]_{(i,k,j),(i,k,j)}` from the converged factor cache.
7374 ///
7375 /// `cache.latent_block_inverse_diagonal()` returns the diagonal of the
7376 /// latent block `(H⁻¹)_tt` in the cache's compact per-row `delta_t`
7377 /// layout (length `row_offsets[N]`); each per-row block is laid out as
7378 /// `[logit scalars…, then per-active-atom coord axes…]`. This routine
7379 /// sums those diagonal entries over the coord positions belonging to each
7380 /// `(atom k, axis j)` across all observation rows where atom `k` is active.
7381 ///
7382 /// `self.last_row_layout` must be the layout from the *same* assemble that
7383 /// produced `cache`:
7384 /// - `Some(layout)`: compact active-set mode (JumpReLU / large-K
7385 /// softmax-IBP truncation). For row `i`, atom `k`'s position in the
7386 /// active list gives its compact coord-block start `coord_starts[i][pos]`;
7387 /// inactive atoms contribute 0 (the prior dominates there anyway).
7388 /// - `None`: dense full-support layout, uniform row dim
7389 /// `q = assignment_dim + Σ d_k`; atom `k`'s coord block sits at the
7390 /// fixed full-row offset `coord_offsets[k]` after the assignment chart.
7391 ///
7392 /// This `tr_kj(H⁻¹)` is exactly the posterior-variance term the deleted
7393 /// `α = n/‖t‖²` rule dropped; the corrected Mackay/Fellner-Schall fixed
7394 /// point is `α_new = n / (‖t_kj‖² + tr_kj(H⁻¹))`.
7395 pub(crate) fn ard_inverse_traces(
7396 &self,
7397 cache: &ArrowFactorCache,
7398 ) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
7399 let inv_diag = cache.latent_block_inverse_diagonal()?;
7400 let n = self.n_obs();
7401 let coord_offsets = self.assignment.coord_offsets();
7402 let mut traces: Vec<Array1<f64>> = self
7403 .assignment
7404 .coords
7405 .iter()
7406 .map(|c| Array1::<f64>::zeros(c.latent_dim()))
7407 .collect();
7408 for row in 0..n {
7409 let row_base = cache.row_offsets[row];
7410 match self.last_row_layout {
7411 Some(ref layout) => {
7412 let active = &layout.active_atoms[row];
7413 let starts = &layout.coord_starts[row];
7414 for (pos, &k) in active.iter().enumerate() {
7415 let d = self.assignment.coords[k].latent_dim();
7416 let block_start = starts[pos];
7417 for axis in 0..d {
7418 traces[k][axis] += inv_diag[row_base + block_start + axis];
7419 }
7420 }
7421 }
7422 None => {
7423 for k in 0..self.k_atoms() {
7424 let d = self.assignment.coords[k].latent_dim();
7425 let block_start = coord_offsets[k];
7426 for axis in 0..d {
7427 traces[k][axis] += inv_diag[row_base + block_start + axis];
7428 }
7429 }
7430 }
7431 }
7432 }
7433 Ok(traces)
7434 }
7435
7436 pub(crate) fn ard_log_precision_explicit_derivatives(
7437 &self,
7438 rho: &SaeManifoldRho,
7439 ) -> Result<Vec<Array1<f64>>, String> {
7440 if rho.log_ard.len() != self.k_atoms() {
7441 return Err(format!(
7442 "ARD rho has {} atoms but term has {}",
7443 rho.log_ard.len(),
7444 self.k_atoms()
7445 ));
7446 }
7447 let n = self.n_obs() as f64;
7448 let mut out = Vec::with_capacity(self.k_atoms());
7449 for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
7450 let d = coord.latent_dim();
7451 let mut atom_out = Array1::<f64>::zeros(rho.log_ard[atom_idx].len());
7452 if rho.log_ard[atom_idx].is_empty() {
7453 out.push(atom_out);
7454 continue;
7455 }
7456 if rho.log_ard[atom_idx].len() != d {
7457 return Err(format!(
7458 "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
7459 rho.log_ard[atom_idx].len()
7460 ));
7461 }
7462 let periods = coord.effective_axis_periods();
7463 for axis in 0..d {
7464 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom_idx][axis]);
7465 let period = periods[axis];
7466 let mut energy_deriv = 0.0_f64;
7467 for row in 0..coord.n_obs() {
7468 let t = coord.row(row)[axis];
7469 energy_deriv += ArdAxisPrior::eval(alpha, t, period).value;
7470 }
7471 let normalizer_deriv = match period {
7472 None => -0.5 * n,
7473 Some(p) => {
7474 let kappa = std::f64::consts::TAU / p;
7475 let eta = alpha / (kappa * kappa);
7476 // d/d(log α) of `n[-η + log I0(η)]` = `n η (I1/I0 - 1)`.
7477 // The ratio is computed without forming `e^{η}`, so it
7478 // stays finite for large `η` instead of the `inf/inf =
7479 // NaN` that `bessel_i1(η)/bessel_i0(η)` produces (#1113).
7480 let ratio = bessel_i0_log_and_ratio(eta).1;
7481 n * eta * (-1.0 + ratio)
7482 }
7483 };
7484 atom_out[axis] = energy_deriv + normalizer_deriv;
7485 }
7486 out.push(atom_out);
7487 }
7488 Ok(out)
7489 }
7490
7491 pub(crate) fn ard_log_precision_hessian_trace(
7492 &self,
7493 rho: &SaeManifoldRho,
7494 cache: &ArrowFactorCache,
7495 solver: &DeflatedArrowSolver<'_>,
7496 ) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
7497 // RAW selected-inverse diagonal: the per-axis diagonal contraction uses
7498 // the DEFLATED inverse; the full kept-subspace + rotation deflation
7499 // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per (row, axis)
7500 // afterwards via the Daleckii–Krein helper. Each ARD ρ-component
7501 // `(atom k, axis)` differentiates a SINGLE coordinate-slot diagonal entry,
7502 // so its `D` is the rank-one `hess·e_s e_sᵀ` at that local slot `s`.
7503 let inv_diag = solver
7504 .latent_inverse_diagonal()
7505 .map_err(|err| ArrowSchurError::SchurFactorFailed { reason: err })?;
7506 let n = self.n_obs();
7507 let total_t = cache.delta_t_len();
7508 let coord_offsets = self.assignment.coord_offsets();
7509 let ard_axis_periods: Vec<Vec<Option<f64>>> = self
7510 .assignment
7511 .coords
7512 .iter()
7513 .map(LatentCoordValues::effective_axis_periods)
7514 .collect();
7515 let mut traces: Vec<Array1<f64>> = self
7516 .assignment
7517 .coords
7518 .iter()
7519 .enumerate()
7520 .map(|(k, c)| {
7521 if rho.log_ard[k].is_empty() {
7522 Array1::<f64>::zeros(0)
7523 } else {
7524 Array1::<f64>::zeros(c.latent_dim())
7525 }
7526 })
7527 .collect();
7528 for row in 0..n {
7529 let row_base = cache.row_offsets[row];
7530 let q = cache.row_dims[row];
7531 let dirs = cache
7532 .deflated_row_directions
7533 .get(row)
7534 .map(Vec::as_slice)
7535 .unwrap_or(&[]);
7536 let spectrum = cache
7537 .deflation_row_spectra
7538 .get(row)
7539 .and_then(Option::as_ref);
7540 // Per-row selected-inverse t-block, built once (only when deflated).
7541 let inv_vv = if dirs.is_empty() {
7542 None
7543 } else {
7544 let mut m = Array2::<f64>::zeros((q, q));
7545 for col in 0..q {
7546 let mut rhs_t = Array1::<f64>::zeros(total_t);
7547 let rhs_beta = Array1::<f64>::zeros(cache.k);
7548 rhs_t[row_base + col] = 1.0;
7549 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
7550 ArrowSchurError::SchurFactorFailed { reason: err }
7551 })?;
7552 for r in 0..q {
7553 m[[r, col]] = solved.t[row_base + r];
7554 }
7555 }
7556 Some(m)
7557 };
7558 // Correction for one local coordinate slot `s` with curvature `hess`.
7559 let slot_correction = |s: usize, hess: f64| -> f64 {
7560 let Some(iv) = inv_vv.as_ref() else {
7561 return 0.0;
7562 };
7563 if s >= q || hess == 0.0 {
7564 return 0.0;
7565 }
7566 let mut d = Array2::<f64>::zeros((q, q));
7567 d[[s, s]] = hess;
7568 Self::deflation_block_correction(iv, &d, dirs, spectrum)
7569 };
7570 match self.last_row_layout {
7571 Some(ref layout) => {
7572 let active = &layout.active_atoms[row];
7573 let starts = &layout.coord_starts[row];
7574 for (pos, &k) in active.iter().enumerate() {
7575 if rho.log_ard[k].is_empty() {
7576 continue;
7577 }
7578 let coord = &self.assignment.coords[k];
7579 let d = coord.latent_dim();
7580 let block_start = starts[pos];
7581 for axis in 0..d {
7582 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
7583 let t = coord.row(row)[axis];
7584 let prior = ArdAxisPrior::eval(alpha, t, ard_axis_periods[k][axis]);
7585 let hess = prior.hess.max(0.0);
7586 let s = block_start + axis;
7587 traces[k][axis] += 0.5 * inv_diag[row_base + s] * hess;
7588 traces[k][axis] -= 0.5 * slot_correction(s, hess);
7589 }
7590 }
7591 }
7592 None => {
7593 for k in 0..self.k_atoms() {
7594 if rho.log_ard[k].is_empty() {
7595 continue;
7596 }
7597 let coord = &self.assignment.coords[k];
7598 let d = coord.latent_dim();
7599 let block_start = coord_offsets[k];
7600 for axis in 0..d {
7601 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
7602 let t = coord.row(row)[axis];
7603 let prior = ArdAxisPrior::eval(alpha, t, ard_axis_periods[k][axis]);
7604 let hess = prior.hess.max(0.0);
7605 let s = block_start + axis;
7606 traces[k][axis] += 0.5 * inv_diag[row_base + s] * hess;
7607 traces[k][axis] -= 0.5 * slot_correction(s, hess);
7608 }
7609 }
7610 }
7611 }
7612 }
7613 Ok(traces)
7614 }
7615
7616 /// Per-atom decoder-smoothness penalty quadratic form (#1556): entry `k` is
7617 /// the λ-free `<B_k, ½(S_k+S_kᵀ)·B_k> = Σ_oc B_k[:,oc]ᵀ S_k B_k[:,oc]`, the
7618 /// per-atom denominator of atom `k`'s λ_smooth Fellner-Schall update. The sum
7619 /// over atoms is `βᵀ(⊕_k S_k ⊗ I_p)β`, the un-scaled total penalty energy.
7620 /// `S_k` is symmetrised defensively (as the assembler does); the per-atom
7621 /// `½(S+Sᵀ)·B_k` GEMMs ride the multi-GPU batched smoothness GEMM with an
7622 /// exact per-atom CPU fallback.
7623 pub(crate) fn decoder_smoothness_quadratic_form_per_atom(&self) -> Vec<f64> {
7624 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
7625 .atoms
7626 .iter()
7627 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
7628 .collect();
7629 let sb_all = batched_smooth_sb(&sb_inputs, true);
7630 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7631 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
7632 per_atom[atom_idx] = (&atom.decoder_coefficients * sb).sum();
7633 }
7634 per_atom
7635 }
7636
7637 /// Per-atom effective penalized dof of the decoder smoothness penalty
7638 /// (#1556): entry `k` is `tr(S_β⁻¹ · M_k)` with `M_k = (λ_smooth[k]·S_k) ⊗ I`
7639 /// and `S_β⁻¹ = (H⁻¹)_ββ` the Schur-complement inverse, each atom scaled by
7640 /// its OWN `lambda_smooth[atom_idx]`. Built on
7641 /// [`ArrowFactorCache::schur_inverse_apply`]: column `(k,μ,oc)` of `M_k` is
7642 /// `λ_k·S_k[:,μ] ⊗ e_oc` (sparse), so we apply `S_β⁻¹` to that K-vector and
7643 /// read back `result[col]`. The total edf is the sum of the returned vector
7644 /// (a uniform/broadcast λ reproduces the historical global trace).
7645 pub(crate) fn decoder_smoothness_effective_dof_per_atom(
7646 &self,
7647 cache: &ArrowFactorCache,
7648 lambda_smooth: &[f64],
7649 ) -> Result<Vec<f64>, ArrowSchurError> {
7650 let p = self.output_dim();
7651 let frames_active = self.frames_active();
7652 let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7653 let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7654 (
7655 self.factored_beta_offsets(),
7656 Box::new(move |k: usize| ranks[k]),
7657 )
7658 } else {
7659 (self.beta_offsets(), Box::new(move |_k: usize| p))
7660 };
7661 let k = cache.k;
7662 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7663 let mut m_col = Array1::<f64>::zeros(k);
7664 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7665 let s = &atom.smooth_penalty;
7666 let m = atom.basis_size();
7667 let off = offsets[atom_idx];
7668 let r = out_dim(atom_idx);
7669 let lambda = lambda_smooth[atom_idx];
7670 let mut trace = 0.0_f64;
7671 for mu in 0..m {
7672 for oc in 0..r {
7673 let col = off + mu * r + oc;
7674 m_col.fill(0.0);
7675 for nu in 0..m {
7676 let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7677 m_col[off + nu * r + oc] = lambda * s_nu_mu;
7678 }
7679 let z = cache.schur_inverse_apply(m_col.view())?;
7680 trace += z[col];
7681 }
7682 }
7683 per_atom[atom_idx] = trace;
7684 }
7685 Ok(per_atom)
7686 }
7687
7688 /// Per-atom effective penalized dof via the deflated solver (#1556): entry
7689 /// `k` is `tr((H⁻¹)_ββ · M_k)` for `M_k = (λ_smooth[k]·S_k) ⊗ I`, each atom
7690 /// scaled by its OWN `lambda_smooth[atom_idx]`. The total is the sum.
7691 pub(crate) fn decoder_smoothness_effective_dof_with_solver_per_atom(
7692 &self,
7693 cache: &ArrowFactorCache,
7694 solver: &DeflatedArrowSolver<'_>,
7695 lambda_smooth: &[f64],
7696 ) -> Result<Vec<f64>, String> {
7697 let p = self.output_dim();
7698 // #972 / #977 T1: the cache's β block is the FACTORED border when frames
7699 // are active (`cache.k == factored_border_dim`), so the smoothness edf
7700 // trace `tr((H⁻¹)_ββ · M)` is taken over the same factored layout, with
7701 // `M = ⊕_k (λ_k S_k) ⊗ I_{r_k}` at the factored offsets (the `U_kᵀU_k = I`
7702 // collapse means the per-coordinate-channel penalty is `λ_k S_k`, exactly
7703 // as in the full-`B` `⊗ I_p` case but with `r_k` channels). On the
7704 // full-`B` path `frames_active` is false: `out_dim_k = p`, the offsets
7705 // are `beta_offsets`, and this is bit-for-bit the historical trace.
7706 let frames_active = self.frames_active();
7707 let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7708 let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7709 (
7710 self.factored_beta_offsets(),
7711 Box::new(move |k: usize| ranks[k]),
7712 )
7713 } else {
7714 (self.beta_offsets(), Box::new(move |_k: usize| p))
7715 };
7716 let k = cache.k;
7717 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7718 let mut m_col = Array1::<f64>::zeros(k);
7719 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7720 let s = &atom.smooth_penalty;
7721 let m = atom.basis_size();
7722 let off = offsets[atom_idx];
7723 let r = out_dim(atom_idx);
7724 let lambda = lambda_smooth[atom_idx];
7725 let mut trace = 0.0_f64;
7726 for mu in 0..m {
7727 for oc in 0..r {
7728 let col = off + mu * r + oc;
7729 // M[:,col] = λ_k · S_k[:,mu] ⊗ e_oc (nonzero at off+ν·r+oc).
7730 m_col.fill(0.0);
7731 for nu in 0..m {
7732 let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7733 m_col[off + nu * r + oc] = lambda * s_nu_mu;
7734 }
7735 let zero_t = Array1::<f64>::zeros(cache.delta_t_len());
7736 let z = solver.solve(zero_t.view(), m_col.view())?.beta;
7737 trace += z[col];
7738 }
7739 }
7740 per_atom[atom_idx] = trace;
7741 }
7742 Ok(per_atom)
7743 }
7744
7745 pub(crate) fn assignment_log_strength_hessian_trace(
7746 &self,
7747 rho: &SaeManifoldRho,
7748 cache: &ArrowFactorCache,
7749 solver: &DeflatedArrowSolver<'_>,
7750 ) -> Result<f64, String> {
7751 let k_atoms = self.k_atoms();
7752 // #1038 softmax: `H` carries the DENSE entropy block, and since the
7753 // entropy curvature scales linearly with `λ_sparse = exp(ρ)`,
7754 // `∂H/∂ρ = H_entropy` (the full dense per-row block, not just its
7755 // diagonal). The trace `½ tr(H⁻¹ ∂H/∂ρ)` must therefore contract the
7756 // dense `∂H/∂ρ` against the per-row selected-inverse BLOCK, mirroring the
7757 // dense `log|H|` and θ-adjoint — a diagonal-only contraction would
7758 // desync the ρ-gradient from the criterion. The assembled majorizer
7759 // `D = diag(Σ_j|H_kj|)` is itself DIAGONAL (#1419), so the contraction
7760 // reduces to `½ Σ_slot (H⁻¹)_{slot,slot}·D_atom`. On the dense `None`
7761 // layout the logit slot equals the atom position; on the compact
7762 // softmax top-`k` layout (#1408/#1409) the slots are the row's active
7763 // atoms — the SAME `D_atom` (full-`K` abs-row-sum) the assembly wrote.
7764 if let AssignmentMode::Softmax {
7765 temperature,
7766 sparsity,
7767 } = self.assignment.mode
7768 {
7769 if k_atoms <= 1 {
7770 return Ok(0.0);
7771 }
7772 let inv_tau = 1.0 / temperature;
7773 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
7774 let penalty = gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
7775 k_atoms,
7776 temperature,
7777 );
7778 // Softmax uses the reduced K−1 free-logit chart on the dense layout
7779 // (last reference logit fixed); the compact layout carries one slot
7780 // per active atom. The diagonal selected inverse gives each slot's
7781 // (H⁻¹)_{slot,slot}.
7782 let assignment_dim = self.assignment.assignment_coord_dim();
7783 // Kept-subspace inverse diagonal: the deflated inverse assigns
7784 // `1/λ̃ = 1` to each per-row UNIT-stiffness direction `vᵢ`, so a raw
7785 // diagonal `D` contraction would spuriously add `½ Σ_i vᵢᵀ D vᵢ` (a
7786 // ρ-independent direction must add 0). `latent_inverse_diagonal_kept`
7787 // removes that per-row deflated diagonal centrally.
7788 let inv_diag = solver
7789 .latent_inverse_diagonal_kept()
7790 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7791 let mut trace = 0.0_f64;
7792 for row in 0..self.n_obs() {
7793 let row_base = cache.row_offsets[row];
7794 // ∂(scale·D)/∂ρ = scale·D (linear in λ_sparse = eᵖ) — the SAME
7795 // operator the assembly and θ-adjoint differentiate.
7796 match self.last_row_layout {
7797 Some(ref layout) => {
7798 // #1410: the compact adjoint reads `D_kk` only for this
7799 // row's `≤ top_k` active atoms, so compute those entries
7800 // directly from the softmax row `a` via the active-only
7801 // Gershgorin helper — no full-`K` `row_logits` copy and no
7802 // full-`K` `d` vector. `a` itself is the irreducible `O(K)`
7803 // softmax normalisation, computed once per row and shared
7804 // across the row's active slots.
7805 let a = crate::assignment::softmax_row(
7806 self.assignment.logits.row(row),
7807 temperature,
7808 );
7809 let a = a.as_slice().expect("softmax row must be contiguous");
7810 let m = softmax_majorizer_log_mean(a);
7811 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7812 let d_atom =
7813 active_softmax_gershgorin_majorizer_entry(a, atom, m, scale);
7814 trace += inv_diag[row_base + pos] * d_atom;
7815 }
7816 }
7817 None => {
7818 // Dense layout genuinely contracts every free logit slot's
7819 // `D_kk`, so the full-`K` `d` is intrinsic here; keep the
7820 // single-source dense majorizer call.
7821 let row_logits: Vec<f64> = (0..k_atoms)
7822 .map(|k| self.assignment.logits[[row, k]])
7823 .collect();
7824 let d = penalty.psd_majorizer_abs_row_sums(&row_logits, scale);
7825 let q = cache.row_dims[row];
7826 let logit_dim = assignment_dim.min(q);
7827 for atom in 0..logit_dim {
7828 trace += inv_diag[row_base + atom] * d[atom];
7829 }
7830 }
7831 }
7832 }
7833 return Ok(0.5 * trace);
7834 }
7835 let hdiag = assignment_prior_log_strength_hdiag(&self.assignment, rho)?;
7836 if hdiag.is_empty() {
7837 return Ok(0.0);
7838 }
7839 // RAW selected-inverse diagonal: the per-row diagonal contraction uses the
7840 // DEFLATED inverse; the full kept-subspace + β-Schur/rotation deflation
7841 // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per row afterwards
7842 // (`deflation_block_correction`), exactly as the data trace does. The
7843 // cross-row off-diagonal pass below contracts only DISTINCT rows `i ≠ j`,
7844 // off any single-row `vᵢ`'s support, so it needs no deflation correction.
7845 let inv_diag = solver
7846 .latent_inverse_diagonal()
7847 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7848 let assignment_dim = self.assignment.assignment_coord_dim();
7849 let total_t = cache.delta_t_len();
7850 // #932 FRONT C: row-local Takahashi selected inverse on the plain arrow
7851 // for the per-row deflation correction below (the diagonal trace already
7852 // uses the cheap `latent_inverse_diagonal`); gauge / cross-row Woodbury
7853 // fall back to the per-row full-system `solve` loop.
7854 let fast_selected = solver.plain_selected_inverse_available();
7855 let selected_beta_inv = if fast_selected && cache.k > 0 {
7856 solver
7857 .beta_inv()
7858 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?
7859 } else {
7860 Array2::<f64>::zeros((0, 0))
7861 };
7862 // #1416 cross-row IBP source: the per-row block that the deflation
7863 // factorizes is the NO-SELF base `H₀'` — the rank-one self curvature
7864 // `d_k·J_ik²` is DOWNDATED from each logit diagonal and re-applied through
7865 // the Woodbury carrier. The full-`H` diagonal contraction below still uses
7866 // the full `hdiag` (which carries that self term), but the per-row
7867 // DEFLATION correction must use `(∂H₀'/∂ρ)_tt`, i.e. `hdiag` MINUS the
7868 // downdated self term — otherwise the Daleckii–Krein correction
7869 // mis-attributes the (un-deflated) Woodbury self curvature's derivative to
7870 // the deflated subspace. For non-IBP modes there is no Woodbury source and
7871 // the self term is `0` (the deflated block IS the full block).
7872 // #1416 (compact-layout completion): the IBP cross-row Woodbury source is
7873 // installed for BOTH the dense and the compact (#1420 top-`k`) layouts (see
7874 // `set_ibp_cross_row_source`, which emits `(g_base + pos, atom, z'_ik)` for
7875 // the active set under a compact layout), so the deflated base `H₀'` is the
7876 // no-self block in BOTH layouts. The self-curvature downdate below must
7877 // therefore run regardless of layout — gating it to the dense path (the
7878 // pre-fix bug) left the compact deflation correction differentiating the
7879 // un-downdated full block. For non-IBP modes `ibp_assignment_third_channels`
7880 // returns `None`, there is no Woodbury source, and `self_curv` is
7881 // identically 0 (the deflated block IS the full block).
7882 let cross_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
7883 let learnable_alpha = matches!(
7884 self.assignment.mode,
7885 AssignmentMode::IBPMap {
7886 learnable_alpha: true,
7887 ..
7888 }
7889 );
7890 let self_curv = |row: usize, atom: usize| -> f64 {
7891 let Some(ch) = cross_channels.as_ref() else {
7892 return 0.0;
7893 };
7894 let d_k = if learnable_alpha {
7895 ch.cross_row_d_logalpha[atom]
7896 } else {
7897 ch.cross_row_d[atom]
7898 };
7899 let j = ch.z_jac[row * k_atoms + atom];
7900 d_k * j * j
7901 };
7902 let mut trace = 0.0_f64;
7903 for row in 0..self.n_obs() {
7904 let row_base = cache.row_offsets[row];
7905 let assignment_base = row * k_atoms;
7906 let q = cache.row_dims[row];
7907 // Per-row diagonal `(∂H₀'/∂ρ)_tt` for the deflation correction: the
7908 // assignment prior curves only the logit/assignment slots (coordinate
7909 // slots are 0 — ARD handles those), MINUS the downdated cross-row self
7910 // curvature. The full-`H` trace contraction keeps the full `hdiag`.
7911 let mut d_diag = Array1::<f64>::zeros(q);
7912 match self.last_row_layout {
7913 Some(ref layout) => {
7914 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7915 let d_slot = hdiag[assignment_base + atom];
7916 trace += inv_diag[row_base + pos] * d_slot;
7917 if pos < q {
7918 d_diag[pos] = d_slot - self_curv(row, atom);
7919 }
7920 }
7921 }
7922 None => {
7923 for free_idx in 0..assignment_dim {
7924 let d_slot = hdiag[assignment_base + free_idx];
7925 trace += inv_diag[row_base + free_idx] * d_slot;
7926 if free_idx < q {
7927 d_diag[free_idx] = d_slot - self_curv(row, free_idx);
7928 }
7929 }
7930 }
7931 }
7932 let dirs = cache
7933 .deflated_row_directions
7934 .get(row)
7935 .map(Vec::as_slice)
7936 .unwrap_or(&[]);
7937 if !dirs.is_empty() {
7938 let inv_vv = if fast_selected {
7939 let (inv_vv, _inv_vbeta) = solver
7940 .selected_inverse_row_blocks(row, &selected_beta_inv)
7941 .map_err(|err| {
7942 format!("assignment_log_strength_hessian_trace: selected inverse: {err}")
7943 })?;
7944 inv_vv
7945 } else {
7946 let mut inv_vv = Array2::<f64>::zeros((q, q));
7947 for col in 0..q {
7948 let mut rhs_t = Array1::<f64>::zeros(total_t);
7949 let rhs_beta = Array1::<f64>::zeros(cache.k);
7950 rhs_t[row_base + col] = 1.0;
7951 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
7952 format!(
7953 "assignment_log_strength_hessian_trace: selected inverse: {err}"
7954 )
7955 })?;
7956 for r in 0..q {
7957 inv_vv[[r, col]] = solved.t[row_base + r];
7958 }
7959 }
7960 inv_vv
7961 };
7962 let mut d_mat = Array2::<f64>::zeros((q, q));
7963 for s in 0..q {
7964 d_mat[[s, s]] = d_diag[s];
7965 }
7966 let spectrum = cache
7967 .deflation_row_spectra
7968 .get(row)
7969 .and_then(Option::as_ref);
7970 trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
7971 }
7972 }
7973 // #1416: the IBP prior Hessian is `H_p = d·J Jᵀ + diag(s, c)`, where the
7974 // rank-one `d·J Jᵀ` couples EVERY row pair `(i, j)` in a column `k`
7975 // through the shared empirical mass `M_k`. The assembled `H` carries the
7976 // full `H_full = H₀' + U D Uᵀ` (Woodbury, `set_ibp_cross_row_source`), and
7977 // for fixed alpha the entire IBP prior scales with `λ = eᵖ`, so
7978 // `∂H_p/∂ρ = H_p`. The diagonal loop above already captures the `i = j`
7979 // self terms (the `d·J_ik²` summand lives in `hdiag`); this pass adds the
7980 // omitted off-diagonal `½·d_k·Σ_{i≠j}(H⁻¹)_{ik,jk}·J_ik·J_jk`. Only IBP
7981 // has the cross-row rank-one source; for other diagonal modes
7982 // `ibp_assignment_third_channels` returns `None` and the trace stays the
7983 // pure diagonal contraction.
7984 //
7985 // #1416 (compact completion): this pass is LAYOUT-AGNOSTIC. Under the dense
7986 // layout atom `k`'s logit slot is local position `k`
7987 // (`row_offsets[i] + k`); under the compact (#1420 top-`k`) layout only the
7988 // row's active atoms carry coordinates and atom `k` lives at local position
7989 // `pos` of `active_atoms[row]` (`row_offsets[i] + pos`). The Woodbury source
7990 // and the θ-adjoint already use this active-slot mapping, so gating the
7991 // cross-row pass to the dense layout (the pre-fix bug) dropped the
7992 // off-diagonal term from `∂log|H|/∂ρ` whenever the budget/`top_k` engaged
7993 // the compact layout. We build per-column active sites `(row, t_index)` once
7994 // — exactly the θ-adjoint `col_sites` construction — then contract the
7995 // off-diagonal `i ≠ j` remainder with one solve per active site.
7996 if let Some(channels) = cross_channels.as_ref() {
7997 let n = self.n_obs();
7998 let total_t = cache.delta_t_len();
7999 // This trace is ½ ∂log|H|/∂ρ. For FIXED-α IBP the whole prior
8000 // scales with λ=eᵖ so ∂H_p/∂ρ = H_p and the rank-one coefficient
8001 // is the VALUE `cross_row_d[k] = w·s'_k`. For LEARNABLE-α this trace
8002 // is ½ ∂log|H|/∂logα, and the rank-one block's logα-derivative is
8003 // `∂d_k/∂logα = w·∂s'_k/∂logα` (`cross_row_d_logalpha[k]`) — the same
8004 // α-derivative the DIAGONAL channel (`hessian_diag_log_alpha_derivative`)
8005 // already uses. Using the value `s'_k` here (the pre-fix bug) made the
8006 // off-diagonal inconsistent with the diagonal and the α-gradient wrong.
8007 // (`learnable_alpha` is the same flag the self-curvature downdate uses.)
8008 // Per-column active sites `(row, global t-index)`. Layout-agnostic.
8009 let mut col_sites: Vec<Vec<(usize, usize)>> = vec![Vec::new(); k_atoms];
8010 match self.last_row_layout {
8011 Some(ref layout) => {
8012 for row in 0..n {
8013 let base = cache.row_offsets[row];
8014 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8015 col_sites[atom].push((row, base + pos));
8016 }
8017 }
8018 }
8019 None => {
8020 for row in 0..n {
8021 let base = cache.row_offsets[row];
8022 for k in 0..k_atoms {
8023 col_sites[k].push((row, base + k));
8024 }
8025 }
8026 }
8027 }
8028 let mut cross = 0.0_f64;
8029 for k in 0..k_atoms {
8030 let d_k = if learnable_alpha {
8031 channels.cross_row_d_logalpha[k]
8032 } else {
8033 channels.cross_row_d[k]
8034 };
8035 if d_k == 0.0 || col_sites[k].len() < 2 {
8036 continue;
8037 }
8038 for &(i, t_i) in &col_sites[k] {
8039 let j_ik = channels.z_jac[i * k_atoms + k];
8040 if j_ik == 0.0 {
8041 continue;
8042 }
8043 // (H⁻¹) column at row `i`'s active logit-`k` slot.
8044 let mut rhs_t = Array1::<f64>::zeros(total_t);
8045 let rhs_beta = Array1::<f64>::zeros(cache.k);
8046 rhs_t[t_i] = 1.0;
8047 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8048 format!("assignment_log_strength_hessian_trace: {err}")
8049 })?;
8050 for &(j, t_j) in &col_sites[k] {
8051 if j == i {
8052 continue;
8053 }
8054 let j_jk = channels.z_jac[j * k_atoms + k];
8055 if j_jk == 0.0 {
8056 continue;
8057 }
8058 cross += d_k * solved.t[t_j] * j_ik * j_jk;
8059 }
8060 }
8061 }
8062 trace += cross;
8063 }
8064 Ok(0.5 * trace)
8065 }
8066
8067 pub(crate) fn learnable_ibp_forward_alpha_data_derivative(
8068 &self,
8069 rho: &SaeManifoldRho,
8070 target: ArrayView2<'_, f64>,
8071 ) -> Result<f64, String> {
8072 let AssignmentMode::IBPMap {
8073 temperature: _,
8074 learnable_alpha: true,
8075 ..
8076 } = self.assignment.mode
8077 else {
8078 return Ok(0.0);
8079 };
8080 let alpha = self
8081 .assignment
8082 .mode
8083 .resolved_ibp_alpha(rho)
8084 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8085 let k_atoms = self.k_atoms();
8086 let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8087 let mut dprior = Array1::<f64>::zeros(k_atoms);
8088 for k in 0..k_atoms {
8089 // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8090 // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8091 // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8092 dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8093 }
8094 let n = self.n_obs();
8095 let p = self.output_dim();
8096 let row_loss_w = self.row_loss_weights.as_deref();
8097 let whitens = self
8098 .row_metric
8099 .as_ref()
8100 .is_some_and(|metric| metric.whitens_likelihood());
8101 let mut decoded = vec![0.0_f64; p];
8102 let mut fitted = Array1::<f64>::zeros(p);
8103 let mut f_rho = Array1::<f64>::zeros(p);
8104 let mut residual = Array1::<f64>::zeros(p);
8105 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8106 let mut assignments = vec![0.0_f64; k_atoms];
8107 let mut total = 0.0_f64;
8108 for row in 0..n {
8109 self.assignment
8110 .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8111 fitted.fill(0.0);
8112 f_rho.fill(0.0);
8113 for k in 0..k_atoms {
8114 self.atoms[k].fill_decoded_row(row, &mut decoded);
8115 // Ungated (#1026 background-tier) atoms have a force-fixed unit
8116 // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8117 // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8118 // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8119 // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8120 // but `a_k` still varies with α through `π_k`, so it must NOT be
8121 // skipped.)
8122 let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8123 0.0
8124 } else {
8125 (assignments[k] / prior[k]) * dprior[k]
8126 };
8127 for out_col in 0..p {
8128 fitted[out_col] += assignments[k] * decoded[out_col];
8129 f_rho[out_col] += da_rho * decoded[out_col];
8130 }
8131 }
8132 for out_col in 0..p {
8133 residual[out_col] = fitted[out_col] - target[[row, out_col]];
8134 }
8135 let residual_metric = match self.row_metric.as_ref() {
8136 Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8137 _ => residual.to_vec(),
8138 };
8139 let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8140 let mut row_dot = 0.0_f64;
8141 for out_col in 0..p {
8142 row_dot += residual_metric[out_col] * f_rho[out_col];
8143 }
8144 total += row_weight * row_dot;
8145 }
8146 Ok(total)
8147 }
8148
8149 /// Per-row spectral-deflation correction `tr((H⁻¹)_tt · (D − DΦ[D]))` for one
8150 /// evidence ρ-component, to be SUBTRACTED from the raw-derivative trace
8151 /// `tr((H⁻¹)_tt · D)` the trace otherwise accumulates.
8152 ///
8153 /// The criterion VALUE re-deflates each per-row `H_tt` at every ρ, so the
8154 /// correct evidence gradient contracts `(H⁻¹)_tt` against the deflation-map
8155 /// derivative `DΦ[D]`, not the raw `D = (∂H_raw/∂ρ)_tt`. By Daleckii–Krein,
8156 /// in the row's RAW eigenbasis `U`,
8157 /// `DΦ[D] = U (F ∘ (Uᵀ D U)) Uᵀ`, `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)`
8158 /// (raw `λ` in the denominator, conditioned `λ̃` in the numerator; the
8159 /// diagonal / degenerate entry is `f'(λₘ) = 1` for an unclamped kept
8160 /// direction and `0` otherwise). Hence `D − DΦ[D] = U ((1−F) ∘ (Uᵀ D U)) Uᵀ`,
8161 /// whose kept×kept block is `0`, deflated×deflated block is the full `M`, and
8162 /// kept(m)×deflated(i) block carries the ROTATION coefficient
8163 /// `(1−λᵢ)/(λₘ−λᵢ)`. Contracting against the FULL deflated selected-inverse
8164 /// t-block `inv_vv` (which carries the β-Schur back-substitution) captures
8165 /// both the within-row kept-subspace term and the deferred β-Schur/rotation
8166 /// coupling in one pass, matching the re-deflating fixed-state FD oracle.
8167 ///
8168 /// `spectrum = Some` (spectral deflation): exact Daleckii–Krein. `None` with a
8169 /// non-empty `dirs` (gauge-only deflation, ρ-independent structural null):
8170 /// fall back to the within-row kept-subspace term `Σᵢ vᵢᵀ D vᵢ`.
8171 /// `inv_vv` is assumed symmetric (selected inverse of a symmetric PD system).
8172 fn deflation_block_correction(
8173 inv_vv: &Array2<f64>,
8174 d_mat: &Array2<f64>,
8175 dirs: &[Array1<f64>],
8176 spectrum: Option<&RowDeflationSpectrum>,
8177 ) -> f64 {
8178 let q = inv_vv.nrows();
8179 let Some(spec) = spectrum else {
8180 // Gauge-only deflation: ρ-independent structural null → within-row term.
8181 let mut acc = 0.0_f64;
8182 for v in dirs {
8183 for a in 0..q {
8184 let va = if a < v.len() { v[a] } else { 0.0 };
8185 if va == 0.0 {
8186 continue;
8187 }
8188 for b in 0..q {
8189 let vb = if b < v.len() { v[b] } else { 0.0 };
8190 acc += va * vb * d_mat[[a, b]];
8191 }
8192 }
8193 }
8194 return acc;
8195 };
8196 let u = &spec.evecs;
8197 if u.nrows() != q || u.ncols() != q {
8198 return 0.0;
8199 }
8200 let raw = &spec.raw_evals;
8201 let cond = &spec.cond_evals;
8202 // M = Uᵀ D U, W = Uᵀ inv_vv U (both q×q, symmetric).
8203 let m = u.t().dot(d_mat).dot(u);
8204 let w = u.t().dot(inv_vv).dot(u);
8205 // correction = Σ_{m,l} W[m,l]·M[m,l]·(1 − F[m,l]).
8206 let mut acc = 0.0_f64;
8207 let eps = 1.0e-12;
8208 for a in 0..q {
8209 for b in 0..q {
8210 let denom = raw[a] - raw[b];
8211 let f1 = if denom.abs() > eps {
8212 (cond[a] - cond[b]) / denom
8213 } else if cond[a] == raw[a] {
8214 1.0
8215 } else {
8216 0.0
8217 };
8218 acc += w[[a, b]] * m[[a, b]] * (1.0 - f1);
8219 }
8220 }
8221 acc
8222 }
8223
    /// #1417: exact `½ tr(H⁻¹ ∂H_data/∂logα)` for LEARNABLE IBP alpha.
    ///
    /// The forward assignment is `a_ik = σ(ℓ_ik/τ)·π_k(α)`, with the #614
    /// consistent stick-breaking mean `π_k(α) = (α/(α+1))^(k+1)`. Hence
    /// `∂logπ_k/∂logα = (k+1)/(α+1)`.
    ///
    /// EVERY data-Jacobian column for atom `k` carries exactly ONE `a_k`/`π_k`
    /// factor, because `σ(ℓ/τ)` is α-independent. The columns are:
    /// - the logit-JVP row, which carries one `π_k`;
    /// - the coordinate rows, which carry one `a_k`;
    /// - the β-leg, `a_k·φ`.
    ///
    /// So each Jacobian column scales as
    /// `∂J_·k/∂logα = ((k+1)/(α+1))·J_·k`. The data Hessian block for the
    /// atom pair `(k_a, k_b)` then scales as
    /// ∂H_data[a,b]/∂logα = (((k_a+1) + (k_b+1))/(α+1))·H_data[a,b].
    /// The exact data-block contribution to the α-logdet trace is therefore
    /// ½ tr(H⁻¹ ∂H_data/∂logα)
    /// = ½/(α+1) · Σ_{a,b} ((k_a+1) + (k_b+1))·(H⁻¹)_{ba}·H_data[a,b],
    /// over the full joint `(t, β)` index set.
    ///
    /// `H_data[a,b]` is the data-fit Gauss-Newton block. It is built from the
    /// SAME `row_jets_for_logdet` first-jets the θ-adjoint uses:
    /// `H_tt = ⟨J_a,J_b⟩`, `H_tβ = ⟨J_a,J_β⟩`, `H_ββ = ⟨J_β,J_β'⟩`.
    /// `(H⁻¹)` is contracted through the same per-row selected-inverse blocks.
    ///
    /// Ungated (#1026) atoms have a unit, α-independent mass, so they use the
    /// exponent `0` in place of `k+1` (see `kfac` below). The t–t contraction
    /// is additionally corrected for per-row deflation via
    /// `deflation_block_correction`.
    ///
    /// This closes the learnable-α gradient: combined with the prior-Hessian
    /// trace (`assignment_log_strength_hessian_trace`), the full
    /// `½ tr(H⁻¹ ∂H/∂logα)` is now assembled. For FIXED alpha (and non-IBP modes)
    /// this is identically zero.
    ///
    /// # Errors
    /// Returns `Err` in any of these cases:
    /// - alpha cannot be resolved;
    /// - second jets, border channels, row assignments, or the jet window
    ///   cannot be built;
    /// - any selected-inverse computation or solve fails.
    pub(crate) fn learnable_ibp_data_logdet_alpha_trace(
        &self,
        rho: &SaeManifoldRho,
        cache: &ArrowFactorCache,
        solver: &DeflatedArrowSolver<'_>,
    ) -> Result<f64, String> {
        let AssignmentMode::IBPMap {
            learnable_alpha: true,
            ..
        } = self.assignment.mode
        else {
            return Ok(0.0);
        };
        let alpha = self
            .assignment
            .mode
            .resolved_ibp_alpha(rho)
            .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
        // Common `1/(α+1)` factor, applied once to the whole trace at the end.
        let inv_alpha1 = 1.0 / (alpha + 1.0);
        let n = self.n_obs();
        let total_t = cache.delta_t_len();
        let second_jets = self.atom_second_jets()?;
        let border = self.border_channels_for_cache(cache)?;

        // β-tier selected inverse `(H⁻¹)_ββ` (shared across rows). #932 FRONT C:
        // on the plain bordered arrow this is the cached dense `S⁻¹` formed once
        // (no `K` full-system solves); when a gauge / #1038 cross-row Woodbury is
        // active the row-local Takahashi blocks are NOT valid, so we fall back to
        // the per-β-coordinate `solve` loop (bit-identical, just O(n) per call).
        let fast_selected = solver.plain_selected_inverse_available();
        let beta_inv = if cache.k == 0 {
            Array2::<f64>::zeros((0, 0))
        } else if fast_selected {
            solver.beta_inv().map_err(|err| {
                format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
            })?
        } else {
            // Column-by-column: solve H x = e_col (β unit vector) and keep x_β.
            let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
            let rhs_t = Array1::<f64>::zeros(total_t);
            for col in 0..cache.k {
                let mut rhs_beta = Array1::<f64>::zeros(cache.k);
                rhs_beta[col] = 1.0;
                let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
                    format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
                })?;
                for r in 0..cache.k {
                    beta_inv[[r, col]] = solved.beta[r];
                }
            }
            beta_inv
        };
        // Atom index of each β border channel (the `k_b` weight for the β leg).
        let border_atom: Vec<usize> = border.iter().map(|c| c.atom).collect();

        let mut trace = 0.0_f64;
        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
        // NOTE(review): the per-row assignments written here are not read later
        // in this function; the call presumably remains so that assignment-
        // evaluation errors still propagate via `?` — confirm or drop.
        let mut assignments = Array1::<f64>::zeros(self.k_atoms());
        // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
        // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
        // rows fall back to the scalar per-row path (bit-identical either way).
        let mut jet_window: std::collections::VecDeque<SaeRowJets> =
            std::collections::VecDeque::new();
        let mut jet_window_next = 0usize;
        for row in 0..n {
            // Local t-block size and its offset in the flat t vector.
            let q = cache.row_dims[row];
            let base = cache.row_offsets[row];
            let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
            self.assignment
                .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
            if jet_window.is_empty() {
                jet_window_next = self.refill_jet_window(
                    rho,
                    jet_window_next,
                    cache,
                    &second_jets,
                    &border,
                    &mut jet_window,
                )?;
            }
            // Rows are consumed in order, so the window front is this row's jets.
            let jets = jet_window.pop_front().expect("jet window must be non-empty");
            // Atom index (k-weight) of each local t-var.
            let var_atom: Vec<usize> = jets
                .vars
                .iter()
                .map(|v| match *v {
                    SaeLocalRowVar::Logit { atom } => atom,
                    SaeLocalRowVar::Coord { atom, .. } => atom,
                })
                .collect();

            // Per-row selected inverse blocks `(H⁻¹)_tt` (q×q) and `(H⁻¹)_tβ`.
            // #932 FRONT C: row-local Takahashi (O(q·(q+K))) on the plain arrow;
            // per-row full-system `solve` loop (O(n·q)) under gauge / cross-row
            // Woodbury where the row-local blocks are not valid.
            let (inv_vv, inv_vbeta) = if fast_selected {
                solver
                    .selected_inverse_row_blocks(row, &beta_inv)
                    .map_err(|err| {
                        format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
                    })?
            } else {
                let mut inv_vv = Array2::<f64>::zeros((q, q));
                let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
                for col in 0..q {
                    // Unit vector at this row's local t-slot `col`.
                    let mut rhs_t = Array1::<f64>::zeros(total_t);
                    let rhs_beta = Array1::<f64>::zeros(cache.k);
                    rhs_t[base + col] = 1.0;
                    let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
                        format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
                    })?;
                    for r in 0..q {
                        inv_vv[[r, col]] = solved.t[base + r];
                    }
                    for b in 0..cache.k {
                        inv_vbeta[[col, b]] = solved.beta[b];
                    }
                }
                (inv_vv, inv_vbeta)
            };

            // #1026 — UNGATED (background-tier) atoms have a force-fixed unit gate,
            // so their mass `a_k ≡ 1` is α-INDEPENDENT: every data-Jacobian column
            // for an ungated atom carries `a_k = 1`, NOT `π_k(α)`, so its α-exponent
            // is `e_k = 0`, not `k+1`. Gated atoms keep `e_k = k+1`. (The prior trace
            // handles ungated separately by zeroing the fixed-logit `z_jac`.)
            let kfac = |atom: usize| -> f64 {
                if self.assignment.ungated.get(atom).copied().unwrap_or(false) {
                    0.0
                } else {
                    (atom + 1) as f64
                }
            };
            // t–t block: Σ_{a,b} (e_a + e_b)·(H⁻¹)_{ba}·⟨J_a, J_b⟩, where the
            // per-atom log-prior exponent is e_k = k+1 for the #614 consistent
            // stick-breaking mean π_k = (α/(α+1))^(k+1) (dlogπ_k/dlogα = (k+1)·inv_alpha1).
            for a in 0..q {
                for b in 0..q {
                    let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
                    if h_ab == 0.0 {
                        continue;
                    }
                    let kw = kfac(var_atom[a]) + kfac(var_atom[b]);
                    trace += kw * inv_vv[[b, a]] * h_ab;
                }
            }
            // Deflation correction (kept-subspace restriction + β-Schur/rotation).
            // `inv_vv` is the DEFLATED selected inverse, so the t–t contraction
            // above contracts the RAW derivative `D` where the re-deflating
            // criterion uses the deflation-map derivative `DΦ[D]`. Subtract the
            // exact over-count `tr(inv_vv·(D − DΦ[D]))` via the Daleckii–Krein
            // helper, with `D_{ab} = kw_ab·⟨J_a, J_b⟩` the SAME t–t operator the
            // trace contracts. The t–β/β–β blocks are not deflated, so only the
            // t–t contraction is corrected.
            let dirs = cache
                .deflated_row_directions
                .get(row)
                .map(Vec::as_slice)
                .unwrap_or(&[]);
            if !dirs.is_empty() {
                let mut d_mat = Array2::<f64>::zeros((q, q));
                for a in 0..q {
                    for b in 0..q {
                        let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
                        if h_ab == 0.0 {
                            continue;
                        }
                        d_mat[[a, b]] = (kfac(var_atom[a]) + kfac(var_atom[b])) * h_ab;
                    }
                }
                let spectrum = cache
                    .deflation_row_spectra
                    .get(row)
                    .and_then(Option::as_ref);
                trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
            }
            // t–β and β–t blocks: appear symmetrically, contract once with the
            // factor 2 (H, H⁻¹ symmetric; `(H⁻¹)_βt = (H⁻¹)_tβᵀ`).
            for a in 0..q {
                for (beta_pos, channel) in border.iter().enumerate() {
                    let h_ab = sae_dot(&jets.first[a], &jets.beta[beta_pos]);
                    if h_ab == 0.0 {
                        continue;
                    }
                    let kw = kfac(var_atom[a]) + kfac(border_atom[beta_pos]);
                    trace += 2.0 * kw * inv_vbeta[[a, channel.index]] * h_ab;
                }
            }
            // β–β block: Σ_{β,β'} (k_β + k_β')·(H⁻¹)_{β'β}·⟨J_β, J_β'⟩.
            for (beta_i, channel_i) in border.iter().enumerate() {
                for (beta_j, channel_j) in border.iter().enumerate() {
                    let h_ab = sae_dot(&jets.beta[beta_i], &jets.beta[beta_j]);
                    if h_ab == 0.0 {
                        continue;
                    }
                    let kw = kfac(border_atom[beta_i]) + kfac(border_atom[beta_j]);
                    trace += kw * beta_inv[[channel_i.index, channel_j.index]] * h_ab;
                }
            }
        }
        // ½ · 1/(α+1) · Σ (exponent-weighted) contractions.
        Ok(0.5 * inv_alpha1 * trace)
    }
8447
8448 pub(crate) fn add_learnable_ibp_forward_alpha_data_rhs(
8449 &self,
8450 rho: &SaeManifoldRho,
8451 target: ArrayView2<'_, f64>,
8452 cache: &ArrowFactorCache,
8453 t: &mut Array1<f64>,
8454 beta: &mut Array1<f64>,
8455 ) -> Result<(), String> {
8456 let AssignmentMode::IBPMap {
8457 temperature,
8458 learnable_alpha: true,
8459 ..
8460 } = self.assignment.mode
8461 else {
8462 return Ok(());
8463 };
8464 let alpha = self
8465 .assignment
8466 .mode
8467 .resolved_ibp_alpha(rho)
8468 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8469 let k_atoms = self.k_atoms();
8470 let p = self.output_dim();
8471 let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8472 let mut dprior = Array1::<f64>::zeros(k_atoms);
8473 for k in 0..k_atoms {
8474 // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8475 // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8476 // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8477 dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8478 }
8479 let inv_tau = 1.0 / temperature;
8480 let row_loss_w = self.row_loss_weights.as_deref();
8481 let whitens = self
8482 .row_metric
8483 .as_ref()
8484 .is_some_and(|metric| metric.whitens_likelihood());
8485 let border = self.border_channels_for_cache(cache)?;
8486 let mut decoded_rows = vec![vec![0.0_f64; p]; k_atoms];
8487 let mut decoded_deriv = vec![0.0_f64; p];
8488 let mut fitted = Array1::<f64>::zeros(p);
8489 let mut f_rho = Array1::<f64>::zeros(p);
8490 let mut residual = Array1::<f64>::zeros(p);
8491 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8492 let mut assignments = vec![0.0_f64; k_atoms];
8493 for row in 0..self.n_obs() {
8494 self.assignment
8495 .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8496 fitted.fill(0.0);
8497 f_rho.fill(0.0);
8498 for k in 0..k_atoms {
8499 self.atoms[k].fill_decoded_row(row, &mut decoded_rows[k]);
8500 // Ungated (#1026 background-tier) atoms have a force-fixed unit
8501 // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8502 // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8503 // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8504 // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8505 // but `a_k` still varies with α through `π_k`, so it must NOT be
8506 // skipped.)
8507 let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8508 0.0
8509 } else {
8510 (assignments[k] / prior[k]) * dprior[k]
8511 };
8512 for out_col in 0..p {
8513 fitted[out_col] += assignments[k] * decoded_rows[k][out_col];
8514 f_rho[out_col] += da_rho * decoded_rows[k][out_col];
8515 }
8516 }
8517 for out_col in 0..p {
8518 residual[out_col] = fitted[out_col] - target[[row, out_col]];
8519 }
8520 let residual_metric = match self.row_metric.as_ref() {
8521 Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8522 _ => residual.to_vec(),
8523 };
8524 let f_metric = match self.row_metric.as_ref() {
8525 Some(metric) if whitens => metric.apply_metric_row(row, f_rho.view()),
8526 _ => f_rho.to_vec(),
8527 };
8528 let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8529 let row_vars = self.row_vars_for_cache_row(row, cache)?;
8530 let row_base = cache.row_offsets[row];
8531 for (pos, var) in row_vars.iter().enumerate() {
8532 let mut contribution = 0.0_f64;
8533 match *var {
8534 SaeLocalRowVar::Logit { atom } => {
8535 let sigma = assignments[atom] / prior[atom];
8536 let sigma_jac = sigma * (1.0 - sigma) * inv_tau;
8537 let da_dl = sigma_jac * prior[atom];
8538 let d_da_rho_dl = sigma_jac * dprior[atom];
8539 for out_col in 0..p {
8540 contribution += da_dl * decoded_rows[atom][out_col] * f_metric[out_col];
8541 contribution += d_da_rho_dl
8542 * decoded_rows[atom][out_col]
8543 * residual_metric[out_col];
8544 }
8545 }
8546 SaeLocalRowVar::Coord { atom, axis } => {
8547 let sigma = assignments[atom] / prior[atom];
8548 let da_rho = sigma * dprior[atom];
8549 self.atoms[atom].fill_decoded_derivative_row(row, axis, &mut decoded_deriv);
8550 for out_col in 0..p {
8551 contribution +=
8552 assignments[atom] * decoded_deriv[out_col] * f_metric[out_col];
8553 contribution +=
8554 da_rho * decoded_deriv[out_col] * residual_metric[out_col];
8555 }
8556 }
8557 }
8558 t[row_base + pos] += row_weight * contribution;
8559 }
8560 for channel in &border {
8561 let phi = self.atoms[channel.atom].basis_values[[row, channel.basis_col]];
8562 let sigma = assignments[channel.atom] / prior[channel.atom];
8563 let da_rho = sigma * dprior[channel.atom];
8564 let mut contribution = 0.0_f64;
8565 for out_col in 0..p {
8566 let output = channel.output[out_col];
8567 contribution += assignments[channel.atom] * phi * output * f_metric[out_col];
8568 contribution += da_rho * phi * output * residual_metric[out_col];
8569 }
8570 beta[channel.index] += row_weight * contribution;
8571 }
8572 }
8573 Ok(())
8574 }
8575
8576 pub(crate) fn border_channels_for_cache(
8577 &self,
8578 cache: &ArrowFactorCache,
8579 ) -> Result<Vec<SaeBorderChannel>, String> {
8580 let p = self.output_dim();
8581 let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8582 let offsets = if frames_active {
8583 self.factored_beta_offsets()
8584 } else {
8585 self.beta_offsets()
8586 };
8587 let mut channels = Vec::with_capacity(cache.k);
8588 for (atom_idx, atom) in self.atoms.iter().enumerate() {
8589 let m = atom.basis_size();
8590 let frame = if frames_active {
8591 self.frame_output_matrix(atom_idx)
8592 } else {
8593 Array2::<f64>::eye(p)
8594 };
8595 let r = frame.ncols();
8596 for basis_col in 0..m {
8597 for channel in 0..r {
8598 let mut output = vec![0.0_f64; p];
8599 for out_col in 0..p {
8600 output[out_col] = frame[[out_col, channel]];
8601 }
8602 channels.push(SaeBorderChannel {
8603 atom: atom_idx,
8604 basis_col,
8605 index: offsets[atom_idx] + basis_col * r + channel,
8606 output,
8607 });
8608 }
8609 }
8610 }
8611 if channels.len() != cache.k {
8612 return Err(format!(
8613 "border channel layout has {} entries but cache border has {}",
8614 channels.len(),
8615 cache.k
8616 ));
8617 }
8618 Ok(channels)
8619 }
8620
8621 pub(crate) fn row_vars_for_cache_row(
8622 &self,
8623 row: usize,
8624 cache: &ArrowFactorCache,
8625 ) -> Result<Vec<SaeLocalRowVar>, String> {
8626 let q_row = cache.row_dims[row];
8627 let mut vars: Vec<Option<SaeLocalRowVar>> = vec![None; q_row];
8628 match self.last_row_layout {
8629 Some(ref layout) => {
8630 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8631 vars[pos] = Some(SaeLocalRowVar::Logit { atom });
8632 let start = layout.coord_starts[row][pos];
8633 let d = self.assignment.coords[atom].latent_dim();
8634 for axis in 0..d {
8635 vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8636 }
8637 }
8638 }
8639 None => {
8640 let assignment_dim = self.assignment.assignment_coord_dim();
8641 let coord_offsets = self.assignment.coord_offsets();
8642 for atom in 0..assignment_dim {
8643 vars[atom] = Some(SaeLocalRowVar::Logit { atom });
8644 }
8645 for atom in 0..self.k_atoms() {
8646 let start = coord_offsets[atom];
8647 let d = self.assignment.coords[atom].latent_dim();
8648 for axis in 0..d {
8649 vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8650 }
8651 }
8652 }
8653 }
8654 vars.into_iter()
8655 .enumerate()
8656 .map(|(idx, v)| {
8657 v.ok_or_else(|| {
8658 format!("row_vars_for_cache_row: row {row} position {idx} was not mapped")
8659 })
8660 })
8661 .collect()
8662 }
8663
8664 pub(crate) fn atom_second_jets(&self) -> Result<Vec<Array4<f64>>, String> {
8665 let mut out = Vec::with_capacity(self.k_atoms());
8666 for (atom_idx, atom) in self.atoms.iter().enumerate() {
8667 let coords = self.assignment.coords[atom_idx].as_matrix();
8668 let jet = if let Some(second) = atom.basis_second_jet.as_ref() {
8669 second.second_jet(coords.view())?
8670 } else {
8671 let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
8672 format!(
8673 "logdet_theta_adjoint: atom '{}' has no basis evaluator for second jets",
8674 atom.name
8675 )
8676 })?;
8677 evaluator
8678 .second_jet_dyn(coords.view())
8679 .ok_or_else(|| {
8680 format!(
8681 "logdet_theta_adjoint: atom '{}' basis does not expose analytic second jets",
8682 atom.name
8683 )
8684 })??
8685 };
8686 let expected = (
8687 atom.n_obs(),
8688 atom.basis_size(),
8689 atom.latent_dim,
8690 atom.latent_dim,
8691 );
8692 if jet.dim() != expected {
8693 return Err(format!(
8694 "logdet_theta_adjoint: atom '{}' second jet shape {:?}, expected {:?}",
8695 atom.name,
8696 jet.dim(),
8697 expected
8698 ));
8699 }
8700 out.push(jet);
8701 }
8702 Ok(out)
8703 }
8704
8705 // [#780 line-count gate] The per-row jet / reconstruction-channel cluster
8706 // (`reconstruction_row_program_for_logdet`, the const-generic
8707 // reconstruction / β-border channel fills and their dynamic dispatchers,
8708 // `row_jets_for_logdet`, `row_jets_for_logdet_batch4`, `batch4_assemble`,
8709 // and `refill_jet_window`) lives in the sibling
8710 // `construction_row_jet_logdet_channels.rs` file, inlined via `include!`
8711 // below at module scope as a second `impl SaeManifoldTerm` block. Splitting
8712 // it out keeps this tracked file under the 10k limit; `include!` preserves
8713 // the identical module scope and private-field access.
8714
8715 pub(crate) fn assignment_prior_hdiag_derivative_entry(
8716 &self,
8717 rho: &SaeManifoldRho,
8718 row: usize,
8719 diag_atom: usize,
8720 wrt: SaeLocalRowVar,
8721 ibp_channels: Option<&IbpHessianDiagThirdChannels>,
8722 ) -> f64 {
8723 let SaeLocalRowVar::Logit { atom: wrt_atom } = wrt else {
8724 return 0.0;
8725 };
8726 match self.assignment.mode {
8727 AssignmentMode::Softmax { .. } => {
8728 // #1038: the softmax entropy Hessian is now stored DENSE in
8729 // `block.htt` and its full θ-derivative `∂H_{k,j}/∂z_w` (diagonal
8730 // AND off-diagonal) is added inline in `logdet_theta_adjoint` from
8731 // the shared `row_dense_hessian_logit_derivative`. Returning the
8732 // diagonal contribution here too would double-count, so this
8733 // primitive is silent for softmax — the dense path is the single
8734 // source for value, logdet, and adjoint.
8735 0.0
8736 }
8737 AssignmentMode::JumpReLU {
8738 temperature,
8739 threshold,
8740 } => {
8741 if diag_atom != wrt_atom {
8742 return 0.0;
8743 }
8744 let logit = self.assignment.logits[[row, diag_atom]];
8745 if !crate::assignment::jumprelu_in_optimization_band(
8746 logit,
8747 threshold,
8748 temperature,
8749 ) {
8750 return 0.0;
8751 }
8752 let inv_tau = 1.0 / temperature;
8753 let activation =
8754 gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
8755 let slope = activation * (1.0 - activation);
8756 // #1415: P(ℓ)=λσ((ℓ−θ)/τ); P''(ℓ)=(λ/τ²)s(1−2a) so the third
8757 // derivative is P'''(ℓ)=(λ/τ³)·s·(1−6a+6a²), because
8758 // d/dℓ[s(1−2a)] = (1/τ)s[(1−2a)²−2s] = (1/τ)s(1−6a+6a²).
8759 rho.lambda_sparse()
8760 * slope
8761 * (1.0 - 6.0 * activation + 6.0 * activation * activation)
8762 * inv_tau
8763 * inv_tau
8764 * inv_tau
8765 }
8766 AssignmentMode::IBPMap { .. } => {
8767 // The assembled `htt` diagonal consumes
8768 // `IBPAssignmentPenalty::hessian_diag`, whose logit derivative
8769 // splits into a row-local direct-`z` channel and a global
8770 // empirical-`M_k` channel (π_k couples every row in column k).
8771 // This same-row primitive returns only the LOCAL direct-`z`
8772 // channel — and only on the matching logit (`diag_atom == w`),
8773 // since H_ik depends on no other row's z explicitly. The global
8774 // M_k channel is accumulated column-wise in
8775 // `logdet_theta_adjoint` (it needs the per-row selected-inverse
8776 // diagonals), so adding it here would double-count.
8777 if diag_atom != wrt_atom {
8778 return 0.0;
8779 }
8780 match ibp_channels {
8781 Some(ch) => ch.local_logit_third[row * ch.k_max + diag_atom],
8782 None => 0.0,
8783 }
8784 }
8785 }
8786 }
8787
8788 pub(crate) fn ard_majorized_hessian_derivative(
8789 &self,
8790 rho: &SaeManifoldRho,
8791 row: usize,
8792 atom: usize,
8793 axis: usize,
8794 ) -> f64 {
8795 if rho.log_ard[atom].is_empty() {
8796 return 0.0;
8797 }
8798 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8799 let periods = self.assignment.coords[atom].effective_axis_periods();
8800 let t = self.assignment.coords[atom].row(row)[axis];
8801 let prior = ArdAxisPrior::eval(alpha, t, periods[axis]);
8802 if prior.hess <= 0.0 {
8803 return 0.0;
8804 }
8805 match periods[axis] {
8806 None => 0.0,
8807 Some(period) => {
8808 let kappa = std::f64::consts::TAU / period;
8809 -alpha * kappa * (kappa * t).sin()
8810 }
8811 }
8812 }
8813
8814 pub fn outer_rho_gradient_ift_rhs(
8815 &self,
8816 rho: &SaeManifoldRho,
8817 target: ArrayView2<'_, f64>,
8818 j: usize,
8819 cache: &ArrowFactorCache,
8820 ) -> Result<SaeArrowVector, String> {
8821 let n_params = rho.to_flat().len();
8822 if j >= n_params {
8823 return Err(format!(
8824 "outer_rho_gradient_ift_rhs: coordinate {j} outside rho dim {n_params}"
8825 ));
8826 }
8827 let mut t = Array1::<f64>::zeros(cache.delta_t_len());
8828 let mut beta = Array1::<f64>::zeros(cache.k);
8829 if j == 0 {
8830 let assignment_grad =
8831 assignment_prior_log_strength_target_mixed(&self.assignment, rho)?;
8832 let k_atoms = self.k_atoms();
8833 let assignment_dim = self.assignment.assignment_coord_dim();
8834 for row in 0..self.n_obs() {
8835 let base = cache.row_offsets[row];
8836 let assignment_base = row * k_atoms;
8837 match self.last_row_layout {
8838 Some(ref layout) => {
8839 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8840 t[base + pos] = assignment_grad[assignment_base + atom];
8841 }
8842 }
8843 None => {
8844 for free_idx in 0..assignment_dim {
8845 t[base + free_idx] = assignment_grad[assignment_base + free_idx];
8846 }
8847 }
8848 }
8849 }
8850 self.add_learnable_ibp_forward_alpha_data_rhs(rho, target, cache, &mut t, &mut beta)?;
8851 } else if (1..=rho.log_lambda_smooth.len()).contains(&j) {
8852 // #1556: coordinate `j ∈ 1..=K` is the per-atom smoothness strength
8853 // `log λ_smooth[j-1]`. `∂(penalty)/∂log λ_k = λ_k·S_k C_k` touches ONLY
8854 // atom `k = j-1`'s decoder block; every other atom's RHS is zero.
8855 let target_atom = j - 1;
8856 let lambda = rho.lambda_smooth_for(target_atom);
8857 let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8858 let offsets = if frames_active {
8859 self.factored_beta_offsets()
8860 } else {
8861 self.beta_offsets()
8862 };
8863 let atom = &self.atoms[target_atom];
8864 let m = atom.basis_size();
8865 let coeffs = if frames_active {
8866 match &atom.decoder_frame {
8867 Some(frame) => frame.project_decoder(atom.decoder_coefficients.view())?,
8868 None => atom.decoder_coefficients.clone(),
8869 }
8870 } else {
8871 atom.decoder_coefficients.clone()
8872 };
8873 let r = coeffs.ncols();
8874 let off = offsets[target_atom];
8875 for mu in 0..m {
8876 for channel in 0..r {
8877 let mut acc = 0.0_f64;
8878 for nu in 0..m {
8879 let s_sym =
8880 0.5 * (atom.smooth_penalty[[mu, nu]] + atom.smooth_penalty[[nu, mu]]);
8881 acc += s_sym * coeffs[[nu, channel]];
8882 }
8883 beta[off + mu * r + channel] = lambda * acc;
8884 }
8885 }
8886 } else {
8887 let mut cursor = 1 + rho.log_lambda_smooth.len();
8888 for atom in 0..rho.log_ard.len() {
8889 for axis in 0..rho.log_ard[atom].len() {
8890 if cursor == j {
8891 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8892 let periods = self.assignment.coords[atom].effective_axis_periods();
8893 for row in 0..self.n_obs() {
8894 let row_t = self.assignment.coords[atom].row(row);
8895 let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
8896 let Some(pos) = sae_coord_penalty_offset(
8897 self.last_row_layout.as_ref(),
8898 self.assignment.coord_offsets()[atom] + axis,
8899 row,
8900 atom,
8901 ) else {
8902 continue;
8903 };
8904 t[cache.row_offsets[row] + pos] = prior.grad;
8905 }
8906 return Ok(SaeArrowVector { t, beta });
8907 }
8908 cursor += 1;
8909 }
8910 }
8911 }
8912 Ok(SaeArrowVector { t, beta })
8913 }
8914
    /// Θ-adjoint of the Laplace log-determinant: `Γ_w = tr(H⁻¹ ∂H/∂θ_w)` for every
    /// inner variable `θ_w` (each per-row `t` slot and each border `β` channel).
    ///
    /// `H` is the assembled arrow operator factored by `solver`. For each row the
    /// t–t, t–β and β–β blocks of `∂H/∂θ_w` are built from the row jets and
    /// contracted against the matching blocks of `H⁻¹` (`inv_vv`, `inv_vbeta`,
    /// `beta_inv`). When the row has deflated directions, a deflation correction
    /// is subtracted from the t–t trace. When IBP third-derivative channels are
    /// present, two column passes after the row loop add:
    /// - the cross-row empirical-`M_k` channel;
    /// - the off-diagonal remainder of the rank-one Woodbury block.
    ///
    /// Returns `SaeArrowVector { t: Γ_t, beta: Γ_β }`. `t` has length
    /// `cache.delta_t_len()` and `beta` has length `cache.k`.
    ///
    /// # Errors
    /// Propagates, as `String`, failures from:
    /// - jet and border construction;
    /// - per-row assignment evaluation;
    /// - the IBP channel builder;
    /// - every selected-inverse or full-system `solver` call.
    ///
    /// # Panics
    /// Panics if an assignment scratch row is not contiguous, or if the jet window
    /// is still empty after a refill. Both are internal invariants.
    pub(crate) fn logdet_theta_adjoint(
        &self,
        rho: &SaeManifoldRho,
        cache: &ArrowFactorCache,
        solver: &DeflatedArrowSolver<'_>,
    ) -> Result<SaeArrowVector, String> {
        // Γ_a = tr(H⁻¹ ∂H/∂θ_a) over the inner variables θ (#1006). `H` here is
        // the SAME object the evidence factor builds — Gauss-Newton data
        // curvature plus the prior majorizers / `hessian_diag` diagonals the
        // Newton/Schur Cholesky factorizes — so each block's θ-derivative channel
        // is differentiated on the criterion's own branch (no value/gradient
        // desync). The IBP-MAP assignment prior is the one block whose
        // `hessian_diag` couples every row in a column through the plug-in
        // empirical mass `M_k = Σ_i z_ik`; its logit derivative therefore has a
        // row-local channel (handled inline via
        // `assignment_prior_hdiag_derivative_entry`) and a cross-row channel
        // (accumulated column-wise after the row loop, below).
        let n = self.n_obs();
        let total_t = cache.delta_t_len();
        let mut gamma_t = Array1::<f64>::zeros(total_t);
        let mut gamma_beta = Array1::<f64>::zeros(cache.k);
        let second_jets = self.atom_second_jets()?;
        let border = self.border_channels_for_cache(cache)?;
        // #932 FRONT C: plain-arrow `(H⁻¹)_ββ = S⁻¹` formed once from the cached
        // Schur factor; gauge / #1038 cross-row Woodbury fall back to the per-β
        // `solve` loop where the row-local Takahashi blocks are not valid.
        let fast_selected = solver.plain_selected_inverse_available();
        let beta_inv = if cache.k == 0 {
            Array2::<f64>::zeros((0, 0))
        } else if fast_selected {
            solver
                .beta_inv()
                .map_err(|err| format!("logdet_theta_adjoint: beta selected inverse: {err}"))?
        } else {
            // Fallback: column `col` of (H⁻¹)_ββ is the β leg of H⁻¹·e_{β_col}.
            let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
            let rhs_t = Array1::<f64>::zeros(total_t);
            for col in 0..cache.k {
                let mut rhs_beta = Array1::<f64>::zeros(cache.k);
                rhs_beta[col] = 1.0;
                let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
                    format!("logdet_theta_adjoint: beta selected inverse solve: {err}")
                })?;
                for row in 0..cache.k {
                    beta_inv[[row, col]] = solved.beta[row];
                }
            }
            beta_inv
        };
        // IBP `hessian_diag` logit third-derivative channels (#1006). The full
        // IBP Hessian also has per-column cross-row rank-one terms
        // `H_(i,k),(j,k) = d_k·J_ik·J_jk`; these ARE carried in `H` via the #1038
        // Woodbury source (`IbpCrossRowSource`, construction.rs:4710-4752), the
        // ρ-trace differentiates them (#1416,
        // `assignment_log_strength_hessian_trace`), AND this θ-adjoint now
        // differentiates them exactly too: the empirical-`M_k` channel below
        // contracts the shared-mass coupling of the DIAGONAL curvature, and the
        // cross-row Woodbury pass (further below, using `cross_row_dd` and
        // `logit_curvature`) contracts the `∂/∂ℓ_w (d_k·J_ik·J_jk)` rank-one
        // derivative — so value, logdet, ρ-trace, and θ-adjoint all differentiate
        // the one operator `H = H₀ + Σ_k d_k u_k u_kᵀ`.
        let ibp_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
        let k_atoms = self.k_atoms();
        // Softmax assignment prior: capture the SAME
        // `scale = λ_sparse·sparsity/τ²` the assembly uses, so value, logdet and
        // adjoint all differentiate one operator. Only `scale` is read below.
        // Per #1419/#1410, the softmax logit curvature in `htt` is the diagonal
        // Gershgorin majorizer `D`. Its derivative is computed entry-wise by
        // `active_softmax_majorizer_logit_derivative_entry`.
        // NOTE(review): the penalty object built here is never read in this
        // function. The older #1038 wording ("dense per-row entropy Hessian with
        // off-diagonal logit terms") appears to be superseded by #1419 — confirm.
        // The value is `None` for non-softmax modes or when `k_atoms <= 1`; the
        // diagonal and cross-row channels of those modes are handled by
        // `assignment_prior_hdiag_derivative_entry` and the IBP column pass.
        let softmax_dense_adjoint: Option<(
            gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
            f64,
        )> = match self.assignment.mode {
            AssignmentMode::Softmax {
                temperature,
                sparsity,
            } if k_atoms > 1 => {
                let inv_tau = 1.0 / temperature;
                let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
                Some((
                    gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
                        k_atoms,
                        temperature,
                    ),
                    scale,
                ))
            }
            _ => None,
        };
        // Per active logit position: (row i, column k, global t-index,
        // (H⁻¹)_ik,ik) — the inputs to the IBP cross-row empirical-`M_k` channel.
        let mut ibp_logit_sites: Vec<(usize, usize, usize, f64)> = Vec::new();

        // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
        let mut assignments = Array1::<f64>::zeros(self.k_atoms());
        // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
        // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
        // rows fall back to the scalar per-row path (bit-identical either way).
        let mut jet_window: std::collections::VecDeque<SaeRowJets> =
            std::collections::VecDeque::new();
        let mut jet_window_next = 0usize;
        for row in 0..n {
            let q = cache.row_dims[row];
            let base = cache.row_offsets[row];
            let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
            self.assignment
                .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
            if jet_window.is_empty() {
                jet_window_next = self.refill_jet_window(
                    rho,
                    jet_window_next,
                    cache,
                    &second_jets,
                    &border,
                    &mut jet_window,
                )?;
            }
            let jets = jet_window.pop_front().expect("jet window must be non-empty");

            // #932 FRONT C: row-local Takahashi on the plain arrow; per-row
            // full-system `solve` loop under gauge / cross-row Woodbury.
            // `inv_vv` is the row's t–t block of H⁻¹ (q×q); `inv_vbeta` is its
            // t–β block (q×k). In the fallback each unit solve H⁻¹·e_{base+col}
            // fills column `col` of `inv_vv` and, by symmetry of H⁻¹, row `col`
            // of `inv_vbeta`.
            let (inv_vv, inv_vbeta) = if fast_selected {
                solver
                    .selected_inverse_row_blocks(row, &beta_inv)
                    .map_err(|err| {
                        format!("logdet_theta_adjoint: selected inverse: {err}")
                    })?
            } else {
                let mut inv_vv = Array2::<f64>::zeros((q, q));
                let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
                for col in 0..q {
                    let mut rhs_t = Array1::<f64>::zeros(total_t);
                    let rhs_beta = Array1::<f64>::zeros(cache.k);
                    rhs_t[base + col] = 1.0;
                    let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
                        format!("logdet_theta_adjoint: selected inverse solve: {err}")
                    })?;
                    for r in 0..q {
                        inv_vv[[r, col]] = solved.t[base + r];
                    }
                    for b in 0..cache.k {
                        inv_vbeta[[col, b]] = solved.beta[b];
                    }
                }
                (inv_vv, inv_vbeta)
            };

            // Record each active logit's column, global t-index, and
            // selected-inverse diagonal (H⁻¹)_ik,ik for the IBP cross-row pass.
            if ibp_channels.is_some() {
                for (pos, var) in jets.vars.iter().enumerate() {
                    if let SaeLocalRowVar::Logit { atom } = *var {
                        ibp_logit_sites.push((row, atom, base + pos, inv_vv[[pos, pos]]));
                    }
                }
            }

            // #1419: when `w` is a logit and the assignment is softmax, the
            // assembly writes the per-row Gershgorin majorizer
            // `D = diag(Σ_j|H_kj|)` into `htt`. `D` is the genuine Loewner
            // majorizer that replaces the indefinite exact entropy Hessian. Its
            // full θ-derivative is diagonal:
            // `∂D_kk/∂z_w = Σ_j sign(H_kj)·∂H_kj/∂z_w`.
            // This is the SAME operator the assembly and logdet differentiate, so
            // value and adjoint stay on ONE exact branch. It is added on the
            // matching-atom logit pairs `(a,b)` below (`atom_a == atom_b`). The
            // diagonal softmax case is therefore handled here, NOT in
            // `assignment_prior_hdiag_derivative_entry`, which returns 0 for
            // softmax to avoid double-counting.
            // #1410: the compact adjoint reads the majorizer θ-derivative
            // `∂D_kk/∂z_w` only for this row's `≤ top_k` active atoms. Each
            // needed diagonal entry is computed directly from the softmax row `a`
            // (= `assignments`, already in hand) via
            // `active_softmax_majorizer_logit_derivative_entry`. This replaces
            // the old per-(row, logit) full `K×K`
            // `row_psd_majorizer_logit_derivative` allocation.
            // `m = Σ_j a_j l_j` is shared by all `(w, k)` pairs of the row, so it
            // is computed once. `inv_tau` carries the softmax `∂a/∂z` convention.
            let softmax_adjoint_row: Option<(&[f64], f64, f64, f64)> =
                match (softmax_dense_adjoint.as_ref(), self.assignment.mode) {
                    (Some((_penalty, scale)), AssignmentMode::Softmax { temperature, .. }) => {
                        let a = assignments
                            .as_slice()
                            .expect("softmax assignments row must be contiguous");
                        let m = softmax_majorizer_log_mean(a);
                        Some((a, m, *scale, 1.0 / temperature))
                    }
                    _ => None,
                };
            // Per-row UNIT-stiffness deflated directions: the selected inverse
            // `inv_vv` is the DEFLATED inverse (it assigns `1/λ̃ = 1` to each
            // `vᵢ`), so every `inv_vv`-weighted t–t contraction of `∂H/∂θ_w`
            // below spuriously contracts the RAW derivative where the re-deflating
            // criterion uses the deflation-map derivative `DΦ`. The kept-subspace Γ
            // subtracts `tr(inv_vv·(D − DΦ[D]))` over the t–t block via the same
            // Daleckii–Krein helper the ρ-traces use (the t–β / β–β blocks are not
            // deflated). `θ` enters only the per-row block (no cross-row Woodbury
            // self-downdate on the θ path), so the raw t–t derivative `D` is used
            // directly.
            let defl_dirs = cache
                .deflated_row_directions
                .get(row)
                .map(Vec::as_slice)
                .unwrap_or(&[]);
            let defl_spectrum = cache
                .deflation_row_spectra
                .get(row)
                .and_then(Option::as_ref);
            // Γ for each row-local slot `w`: t–t, t–β and β–β contractions.
            for w in 0..q {
                let mut gamma = 0.0_f64;
                // The active logit `w` differentiates against; `None` unless this
                // slot is a softmax logit on the softmax path.
                let softmax_d_dw: Option<(&[f64], f64, f64, f64, usize)> =
                    match (softmax_adjoint_row, jets.vars[w]) {
                        (Some((a, m, scale, inv_tau)), SaeLocalRowVar::Logit { atom: atom_w }) => {
                            Some((a, m, scale, inv_tau, atom_w))
                        }
                        _ => None,
                    };
                // t–t block of ∂H/∂θ_w (kept in `dh_mat` for the deflation
                // correction) and its trace against `inv_vv`.
                let mut dh_mat = Array2::<f64>::zeros((q, q));
                for a in 0..q {
                    for b in 0..q {
                        let mut dh = sae_dot(&jets.second[a][w], &jets.first[b])
                            + sae_dot(&jets.first[a], &jets.second[b][w]);
                        // `∂D/∂z_w` is diagonal, so it contributes only when the two
                        // logit slots are the SAME atom (`atom_a == atom_b`).
                        if let (
                            Some((a_soft, m, scale, inv_tau, _atom_w)),
                            SaeLocalRowVar::Logit { atom: atom_a },
                            SaeLocalRowVar::Logit { atom: atom_b },
                        ) = (softmax_d_dw, jets.vars[a], jets.vars[b])
                        {
                            if atom_a == atom_b {
                                dh += active_softmax_majorizer_logit_derivative_entry(
                                    a_soft, atom_a, _atom_w, m, scale, inv_tau,
                                );
                            }
                        }
                        // Prior diagonals: assignment-prior `hessian_diag` for
                        // logits; the ARD majorized curvature only varies with its
                        // own coordinate (`a == w`).
                        if a == b {
                            dh += match jets.vars[a] {
                                SaeLocalRowVar::Logit { atom } => self
                                    .assignment_prior_hdiag_derivative_entry(
                                        rho,
                                        row,
                                        atom,
                                        jets.vars[w],
                                        ibp_channels.as_ref(),
                                    ),
                                SaeLocalRowVar::Coord { atom, axis } if a == w => {
                                    self.ard_majorized_hessian_derivative(rho, row, atom, axis)
                                }
                                _ => 0.0,
                            };
                        }
                        dh_mat[[a, b]] = dh;
                        gamma += inv_vv[[b, a]] * dh;
                    }
                }
                if !defl_dirs.is_empty() {
                    gamma -= Self::deflation_block_correction(
                        &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
                    );
                }
                // t–β block: appears twice in the symmetric trace (factor 2).
                for a in 0..q {
                    for (beta_pos, channel) in border.iter().enumerate() {
                        let dh = sae_dot(&jets.second[a][w], &jets.beta[beta_pos])
                            + sae_dot(&jets.first[a], &jets.beta_deriv[w][beta_pos]);
                        gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
                    }
                }
                // β–β block, contracted against (H⁻¹)_ββ.
                for (beta_i, channel_i) in border.iter().enumerate() {
                    for (beta_j, channel_j) in border.iter().enumerate() {
                        let dh = sae_dot(&jets.beta_deriv[w][beta_i], &jets.beta[beta_j])
                            + sae_dot(&jets.beta[beta_i], &jets.beta_deriv[w][beta_j]);
                        gamma += beta_inv[[channel_i.index, channel_j.index]] * dh;
                    }
                }
                gamma_t[base + w] = gamma;
            }

            // Γ for each border channel: this row's t–t and t–β contributions,
            // accumulated across rows into `gamma_beta`.
            for (w_beta_pos, w_channel) in border.iter().enumerate() {
                let mut gamma = 0.0_f64;
                let mut dh_mat = Array2::<f64>::zeros((q, q));
                for a in 0..q {
                    for b in 0..q {
                        let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.first[b])
                            + sae_dot(&jets.first[a], &jets.beta_l_deriv[b][w_beta_pos]);
                        dh_mat[[a, b]] = dh;
                        gamma += inv_vv[[b, a]] * dh;
                    }
                }
                if !defl_dirs.is_empty() {
                    gamma -= Self::deflation_block_correction(
                        &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
                    );
                }
                for a in 0..q {
                    for (beta_pos, channel) in border.iter().enumerate() {
                        let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.beta[beta_pos]);
                        gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
                    }
                }
                gamma_beta[w_channel.index] += gamma;
            }
        }

        // IBP cross-row empirical-`M_k` channel of Γ (#1006). The assembled
        // diagonal H_ik consumes `hessian_diag`, whose dependence on the column
        // mass M_k = Σ_i z_ik couples every row in a column. Differentiating
        // tr(H⁻¹ ∂H/∂ℓ_wk) on that shared branch:
        //   Γ_wk += [ Σ_i (H⁻¹)_ik,ik · ∂_M H_ik ] · J_wk = C_k · J_wk,
        // where ∂_M H_ik = `m_channel[i*K+k]` and J_wk = `z_jac[w*K+k]`. The
        // row-local direct-`z` channel was already added inline above, so this
        // pass adds only the cross-row remainder (it spans `w ≠ i` and the
        // self-row M_k self-coupling, which the row-local primitive deliberately
        // omits to avoid double-counting).
        if let Some(channels) = ibp_channels.as_ref() {
            let mut col_coeff = vec![0.0_f64; k_atoms];
            for &(row, atom, _t_index, inv_diag) in &ibp_logit_sites {
                col_coeff[atom] += inv_diag * channels.m_channel[row * k_atoms + atom];
            }
            for &(row, atom, t_index, _inv_diag) in &ibp_logit_sites {
                gamma_t[t_index] += col_coeff[atom] * channels.z_jac[row * k_atoms + atom];
            }

            // #1416 / #1641: the EXACT cross-row Woodbury derivative of Γ. The
            // assembled `H` carries the per-column rank-one block
            // `W_k = d_k·u_k u_kᵀ` with `u_k` the J-weighted column indicator
            // (`u_k[slot(i,k)] = J_ik`) and `d_k = w·s'_k` (`cross_row_d[k]`). Both
            // `d_k` (through `M_k`) and the `u_k` entries (through `ℓ_ik`) depend on
            // the logits, so
            //   ∂W_k/∂ℓ_wk = dd_k·J_wk·u_k u_kᵀ
            //              + d_k·c_wk·(e_w u_kᵀ + u_k e_wᵀ),
            // where `dd_k = ∂d_k/∂M_k = w·s''_k` (`cross_row_dd[k]`),
            // `c_wk = ∂J_wk/∂ℓ_wk` (`logit_curvature`), and `e_w` is the unit
            // vector at row `w`'s logit-`k` slot.
            //
            // The θ-adjoint contracts the FULL trace `Γ_wk = tr(H⁻¹ ∂H/∂ℓ_wk)`
            // (NOT the `½ tr` the ρ-trace uses — `fixed_state_logdet` differentiates
            // the full `log|H|`, and the per-row blocks above contract `inv_vv·dh`
            // with no ½). Critically, the `i=j` self curvature `w·s'_k·J_ik²` of the
            // rank-one block lives on the assembled `htt` DIAGONAL `H_ik`, so its
            // derivative is ALREADY differentiated by the row-local
            // `local_logit_third` channel (direct-z, `i=w`) and the `m_channel`
            // column pass (via `M_k`) above. This Woodbury pass must therefore add
            // ONLY the off-diagonal `i≠j` remainder — otherwise the self term is
            // double-counted (the #1641 defect: the pre-fix pass summed the full
            // `u_k u_kᵀ` including `i=j`, AND carried the ρ-trace ½, AND dropped the
            // factor 2 on the symmetric `e_w u_kᵀ + u_k e_wᵀ` term). Excluding `i=j`
            // is also why this pass needs no deflation correction: it contracts only
            // DISTINCT rows, off any single-row `vᵢ`'s support (matching the
            // #1416 ρ-trace cross-row pass).
            //
            // Contracting `tr(H⁻¹ ∂W_k/∂ℓ_wk)` over `i≠j` only:
            //   Γ_wk += dd_k·J_wk·( u_kᵀ H⁻¹ u_k − Σ_i P_ii·J_ik² )        (term A)
            //         + 2·d_k·c_wk·( (H⁻¹ u_k)_{slot(w,k)} − P_ww·J_wk )    (term B),
            // where `P_ii = (H⁻¹)_{slot(i,k),slot(i,k)}` is the selected-inverse
            // diagonal recorded in `ibp_logit_sites`. The subtracted self pieces are
            // exactly the `i=j` terms the diagonal channels own. Both `u_kᵀ H⁻¹ u_k`
            // and `(H⁻¹ u_k)` come from ONE solve per column, `x_k = H⁻¹ u_k` — so
            // the adjoint differentiates the SAME `H = H₀ + Σ_k W_k` the
            // value/logdet use, closing the one-operator contract on the rank-one
            // block too.
            //
            // Group the column sites once (the layout is mode-agnostic: dense or
            // compact, `ibp_logit_sites` already carries each active logit's
            // global t-index AND its selected-inverse diagonal `G_ii`), then per
            // column build `u_k`, solve, and distribute the OFF-DIAGONAL remainder.
            //
            // #1416 FIX: the diagonal (`i = w`) parts of term A and term B are
            // ALREADY supplied — `diag(term A) = dd_k·J_w·Σ_i G_ii·J_i²` by the
            // `m_channel` column pass above (whose `m_channel = w·(s''·J² + s'·c)`
            // carries the `s''·J²` self piece), and `diag(term B) = 2·d_k·c_w·G_ww·J_w`
            // by the inline `local_logit_third` self channel (whose
            // `s'·2J·∂_z J` piece is exactly that). So this pass must add ONLY the
            // cross-row off-diagonal remainder; double-counting the diagonal here
            // (the pre-fix `0.5·dd·J·uᵀGu + d·c·x_w` form, which is neither the
            // full nor the off-diagonal value) desynced the θ-adjoint from the FD
            // of `log|H|`. The exact `tr(H⁻¹ ∂W_k/∂ℓ_wk)` is
            //   Γ_wk += dd_k·J_wk·(uᵀ G u − Σ_i G_ii·J_ik²)      (term A, off-diagonal)
            //         + 2·d_k·c_wk·((G u)_w − G_ww·J_wk)          (term B, off-diagonal),
            // with `uᵀGu = Σ_i J_ik·(Gu)_i`, `(Gu) = x_k = H⁻¹ u_k` from one solve,
            // and `G_ii` the per-site selected-inverse diagonal.
            // Same value as the function-level `total_t` (re-read here, shadowing it).
            let total_t = cache.delta_t_len();
            let mut col_sites: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); k_atoms];
            for &(row, atom, t_index, inv_diag) in &ibp_logit_sites {
                col_sites[atom].push((row, t_index, inv_diag));
            }
            for atom in 0..k_atoms {
                let d_k = channels.cross_row_d[atom];
                let dd_k = channels.cross_row_dd[atom];
                // Nothing to add for an empty column or a column whose rank-one
                // weight and its M-derivative both vanish.
                if col_sites[atom].is_empty() || (d_k == 0.0 && dd_k == 0.0) {
                    continue;
                }
                // u_k as a full t-RHS: J at each active logit-k slot.
                let mut rhs_t = Array1::<f64>::zeros(total_t);
                let rhs_beta = Array1::<f64>::zeros(cache.k);
                for &(row, t_index, _g) in &col_sites[atom] {
                    rhs_t[t_index] = channels.z_jac[row * k_atoms + atom];
                }
                let x_k = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
                    format!("logdet_theta_adjoint: IBP cross-row Woodbury solve: {err}")
                })?;
                // (JᵀH⁻¹J)_k = u_kᵀ x_k, and the diagonal `Σ_i G_ii·J_ik²` that the
                // `m_channel` pass already counted (subtract it from term A so this
                // pass holds only the off-diagonal `i ≠ j` remainder).
                let mut jt_hinv_j = 0.0_f64;
                let mut diag_jt_g_j = 0.0_f64;
                for &(row, t_index, g_ii) in &col_sites[atom] {
                    let j = channels.z_jac[row * k_atoms + atom];
                    jt_hinv_j += j * x_k.t[t_index];
                    diag_jt_g_j += g_ii * j * j;
                }
                let off_diag_a = jt_hinv_j - diag_jt_g_j;
                for &(row, t_index, g_ii) in &col_sites[atom] {
                    let j_wk = channels.z_jac[row * k_atoms + atom];
                    let c_wk = channels.logit_curvature[row * k_atoms + atom];
                    // term A (off-diagonal) + term B (off-diagonal); the inline /
                    // `m_channel` passes already added the diagonal parts.
                    let off_diag_b = x_k.t[t_index] - g_ii * j_wk;
                    gamma_t[t_index] += dd_k * j_wk * off_diag_a + 2.0 * d_k * c_wk * off_diag_b;
                }
            }
        }

        Ok(SaeArrowVector {
            t: gamma_t,
            beta: gamma_beta,
        })
    }
9343
9344 /// #1418: apply the EXACT stationarity-Jacobian correction `ΔC·v = (A − B)·v`
9345 /// to a joint `(t, β)` vector, matrix-free and per row.
9346 ///
9347 /// `A = ∇²_θθ L` is the true inner-fit Hessian; `B` is the assembled
9348 /// evidence/Newton operator the solver factors. They differ ONLY by the three
9349 /// curvature substitutions the assembly makes for stability:
9350 /// 1. data: `B` uses Gauss-Newton `J̃J̃ᵀ`, dropping the residual curvature
9351 /// `R[a,b] = Σ_out r_out·∂²f_out/∂θ_a∂θ_b` (t–t via `jets.second`, t–β via
9352 /// `jets.beta_deriv`; the decoder is linear in β so the β–β block is 0);
9353 /// 2. softmax: `B` uses the Gershgorin majorizer `D = diag(Σ_j|H_kj|)`,
9354 /// dropping `H_entropy − D` (#1419);
9355 /// 3. periodic ARD: `B` uses `max(V'',0)`, dropping the negative part
9356 /// `min(V'',0)` (the indefinite tail past a quarter period).
9357 /// `ΔC` is the sum of exactly these three deltas, each built from the SAME
9358 /// jets / penalty curvatures the assembly and the θ-adjoint use, so
9359 /// `A = B + ΔC` is the one true Hessian. Exact on BOTH the isotropic and the
9360 /// whitened-metric paths: the data fit is `½ r_nᵀ M_n r_n`, so the residual
9361 /// curvature is `Σ_out (M_n r_n)_out·∂²f_out/∂θ_a∂θ_b` — contract the
9362 /// metric-applied √w-scaled residual `error_metric = √w·M_n r_n` (the SAME
9363 /// quantity the assembly's β-tier gradient uses) against the RAW second jets
9364 /// `jets.second`/`jets.beta_deriv` (the same raw-jet convention the whole
9365 /// θ-adjoint and the Gauss-Newton `htt = J̃J̃ᵀ = J M Jᵀ` assembly use). On the
9366 /// isotropic path `M_n = I` so `error_metric = √w·r` and `J M Jᵀ = JJᵀ`,
9367 /// recovering the plain case. The softmax / ARD deltas are logit/coord-space
9368 /// prior curvatures and carry no output metric, so they are path-independent.
9369 fn apply_exact_hessian_minus_b(
9370 &self,
9371 rho: &SaeManifoldRho,
9372 target: ArrayView2<'_, f64>,
9373 cache: &ArrowFactorCache,
9374 v: &SaeArrowVector,
9375 ) -> Result<SaeArrowVector, String> {
9376 let p = self.output_dim();
9377 let n = self.n_obs();
9378 let k_atoms = self.k_atoms();
9379 let total_t = cache.delta_t_len();
9380 let second_jets = self.atom_second_jets()?;
9381 let border = self.border_channels_for_cache(cache)?;
9382 let row_loss_w = self.row_loss_weights.as_deref();
9383 let ard_axis_periods: Vec<Vec<Option<f64>>> = self
9384 .assignment
9385 .coords
9386 .iter()
9387 .map(|coord| coord.effective_axis_periods())
9388 .collect();
9389
9390 // Optional softmax exact-entropy-minus-majorizer delta operator (#1419).
9391 let softmax_delta: Option<(
9392 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
9393 f64,
9394 )> = match self.assignment.mode {
9395 AssignmentMode::Softmax {
9396 temperature,
9397 sparsity,
9398 } if k_atoms > 1 => {
9399 let inv_tau = 1.0 / temperature;
9400 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
9401 Some((
9402 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
9403 k_atoms,
9404 temperature,
9405 ),
9406 scale,
9407 ))
9408 }
9409 _ => None,
9410 };
9411
9412 let mut out = SaeArrowVector {
9413 t: Array1::<f64>::zeros(total_t),
9414 beta: Array1::<f64>::zeros(cache.k),
9415 };
9416 let whitens = self
9417 .row_metric
9418 .as_ref()
9419 .is_some_and(|metric| metric.whitens_likelihood());
9420 let mut decoded = vec![0.0_f64; p];
9421 let mut fitted = Array1::<f64>::zeros(p);
9422 let mut error = Array1::<f64>::zeros(p);
9423 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
9424 let mut assignments = Array1::<f64>::zeros(self.k_atoms());
9425 // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
9426 // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
9427 // rows fall back to the scalar per-row path (bit-identical either way).
9428 let mut jet_window: std::collections::VecDeque<SaeRowJets> =
9429 std::collections::VecDeque::new();
9430 let mut jet_window_next = 0usize;
9431 for row in 0..n {
9432 let q = cache.row_dims[row];
9433 let base = cache.row_offsets[row];
9434 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
9435 self.assignment
9436 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
9437 if jet_window.is_empty() {
9438 jet_window_next = self.refill_jet_window(
9439 rho,
9440 jet_window_next,
9441 cache,
9442 &second_jets,
9443 &border,
9444 &mut jet_window,
9445 )?;
9446 }
9447 let jets = jet_window.pop_front().expect("jet window must be non-empty");
9448 let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
9449
9450 // √w-scaled metric-applied per-row residual `error_metric = √w·M_n r_n`
9451 // (the SAME object the assembly's β-tier gradient contracts). The
9452 // data-fit `½ r_nᵀ M_n r_n` has residual curvature `Σ (M_n r_n)·∂²f`,
9453 // so this is exactly the residual contracted against the raw `∂²f`
9454 // jets. `M_n = I` on the isotropic path ⇒ `error_metric = √w·r`.
9455 fitted.fill(0.0);
9456 for k in 0..k_atoms {
9457 self.atoms[k].fill_decoded_row(row, &mut decoded);
9458 let a_k = assignments[k];
9459 for out_col in 0..p {
9460 fitted[out_col] += a_k * decoded[out_col];
9461 }
9462 }
9463 for out_col in 0..p {
9464 error[out_col] = sqrt_row_w * (fitted[out_col] - target[[row, out_col]]);
9465 }
9466 let error_metric: Vec<f64> = match self.row_metric.as_ref() {
9467 Some(metric) if whitens => metric.apply_metric_row(row, error.view()),
9468 _ => error.to_vec(),
9469 };
9470
9471 // Local t-slice of `v` for this row.
9472 let v_t: Vec<f64> = (0..q).map(|c| v.t[base + c]).collect();
9473
9474 // (1a) residual curvature, t–t: ΔC_tt[a,b] = ⟨r, ∂²f_ab⟩.
9475 for a in 0..q {
9476 let mut acc = 0.0_f64;
9477 for b in 0..q {
9478 let r_ab = sae_dot(&error_metric, &jets.second[a][b]);
9479 acc += r_ab * v_t[b];
9480 }
9481 out.t[base + a] += acc;
9482 }
9483 // (1b) residual curvature, t–β and β–t: ΔC_tβ[a,β] = ⟨r, ∂²f_aβ⟩.
9484 // `jets.beta_deriv[a][β]` = ∂(∂f/∂β_β)/∂θ_a (the mixed second jet).
9485 for a in 0..q {
9486 for (beta_pos, channel) in border.iter().enumerate() {
9487 let r_ab = sae_dot(&error_metric, &jets.beta_deriv[a][beta_pos]);
9488 // t row picks up β leg of v; β row picks up t leg of v.
9489 out.t[base + a] += r_ab * v.beta[channel.index];
9490 out.beta[channel.index] += r_ab * v_t[a];
9491 }
9492 }
9493
9494 // (2) softmax: ΔC_logit = (H_entropy − D) over the free logits, where
9495 // `D = diag(Σ_j|H_kj|)` is the Gershgorin majorizer the assembled `B`
9496 // wrote into the logit block (#1419). Adding `H_entropy − D` recovers the
9497 // EXACT entropy curvature `A = B + ΔC`, so the solver's exact-Hessian
9498 // correction differentiates the SAME operator the assembly installed.
9499 if let Some((_penalty, scale)) = softmax_delta.as_ref() {
9500 let assignment_dim = self.assignment.assignment_coord_dim();
9501 // #1410: the correction only contracts the ACTIVE logit slots
9502 // (`jets.vars` carries the row's `≤ top_k` active atoms on the
9503 // compact layout), so build only the active sub-block of
9504 // `ΔC = H_entropy − D` ENTRY-WISE rather than materialising the
9505 // full `K×K` `row_dense_hessian` / `row_psd_majorizer` matrices per
9506 // row (an `O(K²)`-per-row allocation that defeated the compact
9507 // contract at the LLM shape). `D` is diagonal, so it subtracts only
9508 // on `ka == kb`; the off-diagonal `H_entropy` entries come from the
9509 // shared `(a, l, m)` algebra. The softmax row `a_soft` is the one
9510 // irreducible `O(K)` term, computed once per row.
9511 // #1557 — reuse this iteration's `assignments` (bit-identical).
9512 let a_soft = assignments
9513 .as_slice()
9514 .expect("softmax assignments row must be contiguous");
9515 let m = softmax_majorizer_log_mean(a_soft);
9516 for (a, va) in jets.vars.iter().enumerate() {
9517 let SaeLocalRowVar::Logit { atom: ka } = *va else {
9518 continue;
9519 };
9520 if ka >= assignment_dim {
9521 continue;
9522 }
9523 let mut acc = 0.0_f64;
9524 for (b, vb) in jets.vars.iter().enumerate() {
9525 let SaeLocalRowVar::Logit { atom: kb } = *vb else {
9526 continue;
9527 };
9528 if kb >= assignment_dim {
9529 continue;
9530 }
9531 let h_entropy =
9532 softmax_dense_entropy_hessian_entry(a_soft, ka, kb, m, *scale);
9533 // `D` is the diagonal Gershgorin majorizer (#1419), so it
9534 // contributes only on the diagonal `ka == kb`.
9535 let delta = if ka == kb {
9536 h_entropy
9537 - active_softmax_gershgorin_majorizer_entry(a_soft, ka, m, *scale)
9538 } else {
9539 h_entropy
9540 };
9541 acc += delta * v_t[b];
9542 }
9543 out.t[base + a] += acc;
9544 }
9545 }
9546
9547 // (3) periodic ARD: ΔC_coord = (V'' − max(V'',0)) = min(V'',0), diagonal.
9548 for (a, va) in jets.vars.iter().enumerate() {
9549 let SaeLocalRowVar::Coord { atom, axis } = *va else {
9550 continue;
9551 };
9552 if rho.log_ard[atom].is_empty() {
9553 continue;
9554 }
9555 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
9556 let t_val = self.assignment.coords[atom].row(row)[axis];
9557 let prior = ArdAxisPrior::eval(alpha, t_val, ard_axis_periods[atom][axis]);
9558 let neg = prior.hess.min(0.0);
9559 if neg != 0.0 {
9560 out.t[base + a] += neg * v_t[a];
9561 }
9562 }
9563 }
9564 Ok(out)
9565 }
9566
9567 /// #1418: matrix-free apply of the EXACT stationarity Jacobian `A = ∇²_θθ L`:
9568 /// `A v = B v + ΔC v`, the assembled arrow Hessian apply
9569 /// ([`apply_cached_arrow_hessian`]) plus the matrix-free dropped-curvature
9570 /// correction `ΔC = A − B` ([`Self::apply_exact_hessian_minus_b`]).
9571 fn apply_exact_hessian(
9572 &self,
9573 rho: &SaeManifoldRho,
9574 target: ArrayView2<'_, f64>,
9575 cache: &ArrowFactorCache,
9576 v: &SaeArrowVector,
9577 ) -> Result<SaeArrowVector, String> {
9578 let b_v = apply_cached_arrow_hessian(cache, v.t.view(), v.beta.view())?;
9579 let dc_v = self.apply_exact_hessian_minus_b(rho, target, cache, v)?;
9580 Ok(SaeArrowVector {
9581 t: &b_v.t + &dc_v.t,
9582 beta: &b_v.beta + &dc_v.beta,
9583 })
9584 }
9585
9586 /// #1418: solve `A x = rhs` for the EXACT stationarity Jacobian `A = ∇²_θθ L`
9587 /// via `B`-preconditioned CG ([`solve_b_preconditioned_cg`]) with the
9588 /// matrix-free `A v = B v + ΔC v` apply ([`Self::apply_exact_hessian`]). The
9589 /// IFT step `θ̂_ρ = −A⁻¹ g_ρ` must invert the EXACT `A`, not the surrogate `B`;
9590 /// CG converges for any `ρ(B⁻¹ΔC)`, where the earlier Neumann series diverged
9591 /// once the dropped curvature `ΔC = ⟨r, ∂²f⟩` grew (large unmodellable residual).
9592 fn solve_exact_stationarity(
9593 &self,
9594 rho: &SaeManifoldRho,
9595 target: ArrayView2<'_, f64>,
9596 cache: &ArrowFactorCache,
9597 solver: &DeflatedArrowSolver<'_>,
9598 rhs: &SaeArrowVector,
9599 ) -> Result<SaeArrowVector, String> {
9600 solve_b_preconditioned_cg(solver, rhs, |v| {
9601 self.apply_exact_hessian(rho, target, cache, v)
9602 })
9603 }
9604
9605 /// Analytic SAE REML outer-ρ gradient components at the already converged
9606 /// inner state represented by `loss` and `cache`.
9607 ///
9608 /// The returned gradient is the assembled analytic outer derivative:
9609 /// explicit penalty terms, direct logdet traces, Occam terms, and the #1006
9610 /// implicit-state third-order correction.
9611 pub(crate) fn analytic_outer_rho_gradient_components(
9612 &self,
9613 target: ArrayView2<'_, f64>,
9614 rho: &SaeManifoldRho,
9615 loss: &SaeManifoldLoss,
9616 cache: &ArrowFactorCache,
9617 solver: &DeflatedArrowSolver<'_>,
9618 ) -> Result<SaeOuterRhoGradientComponents, OuterGradientError> {
9619 let n_params = rho.to_flat().len();
9620 let mut explicit = Array1::<f64>::zeros(n_params);
9621 let mut logdet_trace = Array1::<f64>::zeros(n_params);
9622 let mut occam = Array1::<f64>::zeros(n_params);
9623 let mut third_order_correction = Array1::<f64>::zeros(n_params);
9624
9625 explicit[0] = assignment_prior_log_strength_derivative(&self.assignment, rho)
9626 + self
9627 .learnable_ibp_forward_alpha_data_derivative(rho, target)
9628 .map_err(OuterGradientError::internal)?;
9629 // #1417: the FULL `½ tr(H⁻¹ ∂H/∂logα)` for the assignment coordinate.
9630 // For LEARNABLE IBP alpha the forward assignments `a_ik = σ(ℓ/τ)·π_k(α)`
9631 // carry an explicit α-dependence (`∂logπ_k/∂logα = k/(α+1)`), so BOTH the
9632 // assignment-prior Hessian AND the data Gauss-Newton blocks
9633 // `H_ββ`, `H_tβ`, `H_tt` depend on logα. We assemble both traces:
9634 // • prior: `assignment_log_strength_hessian_trace`,
9635 // • data: `learnable_ibp_data_logdet_alpha_trace` (#1417), using the
9636 // exact `(k_a+k_b)/(α+1)` block-scaling identity.
9637 // For FIXED alpha (and non-IBP modes) the data term is identically zero,
9638 // so the fixed-alpha gradient is unchanged and exact.
9639 logdet_trace[0] = self
9640 .assignment_log_strength_hessian_trace(rho, cache, solver)
9641 .map_err(OuterGradientError::internal)?
9642 + self
9643 .learnable_ibp_data_logdet_alpha_trace(rho, cache, solver)
9644 .map_err(OuterGradientError::internal)?;
9645
9646 // #1556: λ_smooth is per-atom, so the smoothness gradient block occupies
9647 // flat indices `1..1+K` (one per atom), not a single index 1. Each atom
9648 // `k` carries its own explicit penalty-energy derivative, log|H| trace,
9649 // and Occam-normalizer derivative.
9650 let k_smooth = rho.log_lambda_smooth.len();
9651 let lambda_smooth_vec = rho.lambda_smooth_vec();
9652 // Explicit `∂loss.smoothness/∂log λ_k = 0.5·λ_k·<B_k, S_k B_k>` (the
9653 // per-atom split). Its sum is the λ-scaled penalty energy; renormalize to
9654 // `loss.smoothness` so the total matches the criterion's reported energy
9655 // bit-for-bit (folding in any minibatch `penalty_scale` baked into it).
9656 let mut smooth_explicit = self.decoder_smoothness_value_per_atom(&lambda_smooth_vec);
9657 let smooth_explicit_sum: f64 = smooth_explicit.iter().sum();
9658 if smooth_explicit_sum.abs() > 0.0 {
9659 let renorm = loss.smoothness / smooth_explicit_sum;
9660 for v in smooth_explicit.iter_mut() {
9661 *v *= renorm;
9662 }
9663 }
9664 let smooth_logdet = self
9665 .decoder_smoothness_effective_dof_with_solver_per_atom(
9666 cache,
9667 solver,
9668 &lambda_smooth_vec,
9669 )
9670 .map_err(|err| OuterGradientError::InternalInvariant {
9671 reason: format!("analytic_outer_rho_gradient_components: {err}"),
9672 })?;
9673 let smooth_occam = self
9674 .reml_occam_log_lambda_smooth_derivative()
9675 .map_err(OuterGradientError::internal)?;
9676 for atom_idx in 0..k_smooth {
9677 explicit[1 + atom_idx] = smooth_explicit[atom_idx];
9678 logdet_trace[1 + atom_idx] = 0.5 * smooth_logdet[atom_idx];
9679 occam[1 + atom_idx] = -smooth_occam[atom_idx];
9680 }
9681
9682 let ard_explicit = self
9683 .ard_log_precision_explicit_derivatives(rho)
9684 .map_err(OuterGradientError::internal)?;
9685 let ard_trace = self
9686 .ard_log_precision_hessian_trace(rho, cache, solver)
9687 .map_err(|err| OuterGradientError::InternalInvariant {
9688 reason: format!("analytic_outer_rho_gradient_components: {err}"),
9689 })?;
9690 let mut cursor = 1 + k_smooth;
9691 for k in 0..rho.log_ard.len() {
9692 for axis in 0..rho.log_ard[k].len() {
9693 explicit[cursor] = ard_explicit[k][axis];
9694 logdet_trace[cursor] = ard_trace[k][axis];
9695 cursor += 1;
9696 }
9697 }
9698
9699 let gamma = self
9700 .logdet_theta_adjoint(rho, cache, solver)
9701 .map_err(OuterGradientError::internal)?;
9702 // #1418: the implicit-function correction is `−½·Γᵀ·θ̂_ρ` with
9703 // `θ̂_ρ = −A⁻¹ g_ρ`, where `A = ∇²_θθ L` is the EXACT stationarity
9704 // Jacobian of the inner fit — data residual curvature, exact softmax
9705 // entropy Hessian, exact periodic ARD curvature. The matrix the `solver`
9706 // factors is `B` (Gauss-Newton data curvature, softmax Fisher metric,
9707 // `max(V'',0)` ARD majorizers): the `½log|B|` Laplace term is consistent
9708 // with `Γ = ½tr(B⁻¹ ∂B/∂θ)`, but the implicit step is governed by `A`.
9709 // `solve_exact_stationarity` applies the TRUE `A⁻¹` via a B⁻¹-
9710 // preconditioned Neumann fixed point (`A = B + ΔC`,
9711 // `ΔC = apply_exact_hessian_minus_b`), so the correction is no longer
9712 // biased by `(B⁻¹ − A⁻¹)`.
9713 for coord in 0..n_params {
9714 let rhs = self
9715 .outer_rho_gradient_ift_rhs(rho, target, coord, cache)
9716 .map_err(OuterGradientError::internal)?;
9717 let solved = self
9718 .solve_exact_stationarity(rho, target, cache, solver, &rhs)
9719 .map_err(OuterGradientError::internal)?;
9720 let mut dot = 0.0_f64;
9721 for idx in 0..gamma.t.len() {
9722 dot += gamma.t[idx] * solved.t[idx];
9723 }
9724 for idx in 0..gamma.beta.len() {
9725 dot += gamma.beta[idx] * solved.beta[idx];
9726 }
9727 third_order_correction[coord] = -0.5 * dot;
9728 }
9729
9730 Ok(SaeOuterRhoGradientComponents {
9731 explicit,
9732 logdet_trace,
9733 occam,
9734 third_order_correction,
9735 })
9736 }
9737
9738 /// Public analytic outer-ρ gradient at a converged inner state, constructing
9739 /// the deflated arrow solver from the supplied cache. Use this seam from
9740 /// integration tests and external consumers that have a converged
9741 /// `(loss, cache)` from [`Self::reml_criterion_with_cache`] but no access to
9742 /// the crate-private `DeflatedArrowSolver`.
9743 pub fn analytic_outer_rho_gradient_at_converged(
9744 &self,
9745 target: ArrayView2<'_, f64>,
9746 rho: &SaeManifoldRho,
9747 loss: &SaeManifoldLoss,
9748 cache: &ArrowFactorCache,
9749 ) -> Result<SaeOuterRhoGradientComponents, String> {
9750 let solver = self.outer_gradient_arrow_solver(cache, &rho.lambda_smooth_vec())?;
9751 self.analytic_outer_rho_gradient_components(target, rho, loss, cache, &solver)
9752 .map_err(|e| e.to_string())
9753 }
9754
9755 /// Compose the SAE LAML criterion as a sum of atoms (#931 SAE pilot).
9756 ///
9757 /// This is the single seam that establishes value↔gradient coherence for
9758 /// the SAE objective: it runs the inner solve once via
9759 /// [`Self::reml_criterion_with_cache`], reads the value decomposition
9760 /// (`loss.total() + extra_penalty_energy`, `log|H|`, `occam`) and the
9761 /// matching gradient channels (`SaeOuterRhoGradientComponents`) from the
9762 /// SAME converged cache, and hands them to [`SaeCriterion::assemble`]. The
9763 /// returned criterion's [`SaeCriterion::value`] and
9764 /// [`SaeCriterion::gradient`] are then projections of one factorization —
9765 /// the outer optimizer can no longer evaluate a value path and a gradient
9766 /// path that disagree (the #752/#748/#901 desync class). The
9767 /// implicit-stationarity envelope correction (#1006's Γ term) is its own
9768 /// named atom, so the channel the desync class keeps dropping is visible
9769 /// rather than a silent zero.
9770 pub fn criterion_as_atoms(
9771 &mut self,
9772 target: ArrayView2<'_, f64>,
9773 rho: &SaeManifoldRho,
9774 registry: Option<&AnalyticPenaltyRegistry>,
9775 inner_max_iter: usize,
9776 learning_rate: f64,
9777 ridge_ext_coord: f64,
9778 ridge_beta: f64,
9779 ) -> Result<SaeCriterion, String> {
9780 let (_v, loss, cache) = self.reml_criterion_with_cache(
9781 target,
9782 rho,
9783 registry,
9784 inner_max_iter,
9785 learning_rate,
9786 ridge_ext_coord,
9787 ridge_beta,
9788 )?;
9789 let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
9790 "criterion_as_atoms: arrow_log_det_from_cache returned None".to_string()
9791 })?;
9792 let occam = self.reml_occam_term(rho)?;
9793 let extra_penalty_energy = match registry {
9794 Some(reg) => self
9795 .reml_extra_penalty_value_total(reg)
9796 .map_err(|err| format!("SaeManifoldTerm::criterion_as_atoms: {err}"))?,
9797 None => 0.0,
9798 };
9799 let data_fit_priors_value = loss.total() + extra_penalty_energy;
9800
9801 let solver = self.outer_gradient_arrow_solver(&cache, &rho.lambda_smooth_vec())?;
9802 let components =
9803 self.analytic_outer_rho_gradient_components(target, rho, &loss, &cache, &solver)?;
9804 Ok(SaeCriterion::assemble(
9805 data_fit_priors_value,
9806 log_det,
9807 occam,
9808 components.explicit,
9809 components.logdet_trace,
9810 components.occam,
9811 components.third_order_correction,
9812 ))
9813 }
9814
    // [#780 line-count gate] `reconstruction_dispersion`, `assemble_shape_uncertainty`,
    // `complete_born_atom_shape_bands` and `shape_uncertainty_without_decoder_covariance`
    // (the contiguous trailing methods of this impl block) were split into the
    // sibling `construction_reconstruction.rs` module (declared in `mod.rs`).
    // Callers still reach them unqualified through `use super::*`.
9820}
9821
// [#780 line-count gate] Per-row jet / reconstruction-channel assembly for the
// streaming-exact arrow log-det lives in a sibling file as a second
// `impl SaeManifoldTerm` block. It is inlined with `include!` (a textual
// paste, not a module declaration) so it keeps the SAME module scope and
// private-field access. Keeps this tracked file under the 10k limit.
include!("construction_row_jet_logdet_channels.rs");

// [#780 line-count gate] `term_from_padded_blocks_with_mode` (the padded-FFI
// term builder) was split into the sibling `construction_padded_blocks.rs`
// module (declared and re-exported from `mod.rs`), keeping this tracked file
// under the 10k limit. Because it is a real module rather than an `include!`,
// nothing is pasted in here; callers still reach it bare through `use super::*`.

// [#780 line-count gate] `refresh_isometry_caches_from_atom` and
// `refresh_isometry_caches_from_term` were split into the sibling
// `construction_cache_refresh.rs` module (declared and re-exported from
// `mod.rs`), keeping this tracked file under the 10k limit. Callers still reach
// both functions bare through `use super::*`.

// [#780 line-count gate] The `#[cfg(test)]` modules that previously followed
// the production code were mechanically moved into a sibling `*_tests` file
// and are inlined here via `include!` (the sanctioned cohesive-module
// decomposition — see build.rs file_stem_is_exempt_test_module). Keeps this
// tracked file under the 10k limit.
include!("construction_tests.rs");