gam_sae/manifold/construction.rs
1use super::*;
2use gam_math::jet_scalar::JetScalar;
3
4// [#780] Softmax-entropy Gershgorin majorizer leaf helpers live in a sibling
5// cohesive module, inlined here so they share this module scope.
6include!("softmax_entropy_majorizer.rs");
7
8/// Typed error from the SAE outer-gradient analytic assembly path (#1436).
9///
10/// The `eval()` analytic fallback (#1273/#1440: the plain undeflated analytic
11/// outer gradient, NOT a finite difference) must fire ONLY for the genuine
12/// conditioning/identifiability failure modes it was designed for — a
13/// near-singular-but-valid joint Hessian or a gauge-degenerate direction.
14/// Shape/indexing bugs, non-finite intermediates, and violated internal
15/// invariants are [`OuterGradientError::InternalInvariant`] and MUST propagate
16/// as hard errors so regressions surface instead of being silently masked by a
17/// degraded descent direction.
18#[derive(Clone, Debug)]
19pub(crate) enum OuterGradientError {
20 /// Expected: near-singular or ill-conditioned joint Hessian at a feasible ρ
21 /// (the genuine #1273 flat-valley case). Eligible for the FD fallback.
22 IllConditioned { reason: String },
23 /// Expected: a non-identifiable / gauge-degenerate direction at this ρ.
24 /// Eligible for the FD fallback.
25 NonIdentifiable { reason: String },
26 /// Unexpected: shape/dimension mismatch, non-finite intermediate, or a
27 /// violated internal invariant. MUST propagate — never fall back to FD.
28 InternalInvariant { reason: String },
29}
30
31impl OuterGradientError {
32 /// Whether this error class is recoverable by the #1273/#1440 analytic
33 /// plain-solver fallback (i.e. it represents a legitimate
34 /// conditioning/identifiability failure, not a programming/invariant defect).
35 pub(crate) fn is_conditioning_recoverable(&self) -> bool {
36 matches!(
37 self,
38 Self::IllConditioned { .. } | Self::NonIdentifiable { .. }
39 )
40 }
41
42 /// Construct an [`OuterGradientError::InternalInvariant`] from any error
43 /// displayable — the default classification for unexpected assembly failures
44 /// (shape mismatches, non-finite intermediates, violated invariants).
45 pub(crate) fn internal<E: std::fmt::Display>(err: E) -> Self {
46 Self::InternalInvariant {
47 reason: err.to_string(),
48 }
49 }
50
51 /// #1451 — classify a `String` error surfaced by the deflation linear-algebra
52 /// path (`apply_cached_arrow_hessian`, `DeflatedArrowSolver::from_orthonormal_gauges`)
53 /// into the correct [`OuterGradientError`] class.
54 ///
55 /// A genuine rank-deficiency / near-singularity failure (a back-solve or
56 /// Cholesky/Woodbury factor that tripped on a finite, correctly-shaped input)
57 /// is a legitimate #1273 conditioning failure and keeps `conditioning_err`
58 /// (`IllConditioned`), so it stays recoverable by the analytic fallback. A
59 /// shape/dimension mismatch or a non-finite intermediate is an
60 /// internal-invariant defect and MUST propagate ([`Self::internal`]) instead
61 /// of being masked as a plausible-but-wrong descent direction — exactly the
62 /// #1436 contract.
63 ///
64 /// The two solver helpers return `String` (not a typed error), so the
65 /// distinction is drawn from the stable markers those helpers emit for their
66 /// shape/non-finite guards (`vector shapes`, `gauge length`, `must be finite`,
67 /// `non-finite`). Everything else — including the `cholesky`/back-solve
68 /// near-singular failures — is treated as a genuine conditioning trip.
69 pub(crate) fn classify_arrow_solver_error(message: &str, conditioning_err: Self) -> Self {
70 let lower = message.to_ascii_lowercase();
71 let is_internal = lower.contains("vector shapes")
72 || lower.contains("gauge length")
73 || lower.contains("solution length")
74 || lower.contains("!= cache")
75 || lower.contains("must be finite")
76 || lower.contains("non-finite")
77 || lower.contains("not finite")
78 || lower.contains("nan")
79 || lower.contains("inf");
80 if is_internal {
81 Self::internal(message)
82 } else {
83 conditioning_err
84 }
85 }
86
87 /// The exact gate the gradient lane (`SaeManifoldOuterObjective::eval`) uses
88 /// to decide whether to descend with the #1273/#1440 analytic plain-solver
89 /// fallback instead of propagating the error as a hard failure.
90 ///
91 /// The fallback is admissible ONLY when BOTH hold:
92 /// * the REML cost at this rho is finite (a genuinely feasible point -- the
93 /// plain analytic solver supplies a descent direction for a value the
94 /// analytic path already produced), and
95 /// * the error is a legitimate conditioning/identifiability failure
96 /// ([`Self::is_conditioning_recoverable`]) -- the genuine #1273 flat-valley
97 /// case.
98 ///
99 /// A non-finite cost or an [`OuterGradientError::InternalInvariant`] must
100 /// propagate: masking a shape/indexing bug, a non-finite intermediate, or a
101 /// violated invariant behind a plausible-but-wrong step is exactly the
102 /// regression #1436 closes. Centralising the decision here (rather than
103 /// inlining the boolean at the call site) makes the `cost x error-class`
104 /// contract a single, directly unit-testable predicate.
105 pub(crate) fn admits_plain_solver_fallback(&self, cost: f64) -> bool {
106 cost.is_finite() && self.is_conditioning_recoverable()
107 }
108}
109
110impl std::fmt::Display for OuterGradientError {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 match self {
113 Self::IllConditioned { reason } => write!(f, "ill-conditioned: {reason}"),
114 Self::NonIdentifiable { reason } => write!(f, "non-identifiable: {reason}"),
115 Self::InternalInvariant { reason } => write!(f, "internal invariant: {reason}"),
116 }
117 }
118}
119
120impl From<OuterGradientError> for String {
121 fn from(e: OuterGradientError) -> String {
122 e.to_string()
123 }
124}
125
126/// Active-set layout override for [`SaeManifoldTerm::assemble_arrow_schur_inner`].
127///
128/// `None` is the production path: the layout is derived from the assignment mode
129/// and `sparse_active_plan`. `Some(layout_opt)` pins a specific layout — dense
130/// (`Some(None)`) or a chosen compact `SaeRowLayout` (`Some(Some(..))`) — so the
131/// compact-vs-dense Riemannian-geometry equality regression can drive both code
132/// paths on identical data without depending on the host/device memory budget
133/// that gates the compact path in production.
134pub(crate) type ForcedRowLayout = Option<Option<SaeRowLayout>>;
135
136/// #1154 — base co-training weight for the amortized-encoder reconstruction
137/// consistency penalty, as a fraction of the REML criterion magnitude. The
138/// effective weight is `COTRAIN_RECON_WEIGHT · max(|REML|, 1)`, so the penalty
139/// is a bounded, scale-free share of the objective and needs no caller knob.
140pub(crate) const COTRAIN_RECON_WEIGHT: f64 = 0.1;
141
142/// #1154 — base co-training weight for the encoder's certifiable-coverage
143/// penalty (the fraction of (row, atom) encodes the Kantorovich certificate
144/// rejected). Scaled like [`COTRAIN_RECON_WEIGHT`].
145pub(crate) const COTRAIN_CERT_WEIGHT: f64 = 0.05;
146
147/// #1154 — amortized-encoder consistency of a fitted dictionary against its own
148/// fit-time target. The co-training signal of the joint amortized-encoder +
149/// REML loop: how faithfully (and how certifiably) the cheap one-mat-vec
150/// encoder inverts the dictionary the inner solve converged to.
151#[derive(Debug, Clone, Copy)]
152pub struct AmortizedEncoderConsistency {
153 /// Mean per-element squared gap between the amortized reconstruction and the
154 /// exact fitted reconstruction (`‖x̂_amortized − x̂_exact‖² / (n·p)`). Zero ⇒
155 /// the IFT predictor reproduces the encode map exactly to first order.
156 pub recon_consistency: f64,
157 /// Fraction of (row, atom) amortized encodes whose Kantorovich certificate
158 /// failed (`h > ½`) and fell back to the certified Newton encode.
159 pub uncertified_fraction: f64,
160 /// Count of uncertified (row, atom) encodes (numerator of the fraction).
161 pub n_uncertified: usize,
162 /// Total (row, atom) encodes scored (`n · K`).
163 pub n_encodes: usize,
164}
165
166impl SaeManifoldTerm {
167 #[must_use = "build error must be handled"]
168 pub fn new(atoms: Vec<SaeManifoldAtom>, assignment: SaeAssignment) -> Result<Self, String> {
169 if atoms.is_empty() {
170 return Err("SaeManifoldTerm::new: at least one atom required".into());
171 }
172 let n = atoms[0].n_obs();
173 let p = atoms[0].output_dim();
174 if assignment.n_obs() != n || assignment.k_atoms() != atoms.len() {
175 return Err(format!(
176 "SaeManifoldTerm::new: assignment shape ({}, {}) does not match atoms ({n}, {})",
177 assignment.n_obs(),
178 assignment.k_atoms(),
179 atoms.len()
180 ));
181 }
182 for (k, atom) in atoms.iter().enumerate() {
183 if atom.n_obs() != n {
184 return Err(format!(
185 "SaeManifoldTerm::new: atom {k} has n_obs={} but atom 0 has {n}",
186 atom.n_obs()
187 ));
188 }
189 if atom.output_dim() != p {
190 return Err(format!(
191 "SaeManifoldTerm::new: atom {k} output_dim={} but atom 0 has {p}",
192 atom.output_dim()
193 ));
194 }
195 if atom.latent_dim != assignment.coords[k].latent_dim() {
196 return Err(format!(
197 "SaeManifoldTerm::new: atom {k} latent_dim={} but assignment coord has {}",
198 atom.latent_dim,
199 assignment.coords[k].latent_dim()
200 ));
201 }
202 }
203 Ok(Self {
204 atoms,
205 assignment,
206 temperature_schedule: None,
207 last_row_layout: None,
208 row_metric: None,
209 collapse_events: Vec::new(),
210 row_loss_weights: None,
211 last_frames_active: false,
212 assembly_chunk_override: None,
213 fixed_decoder_assembly: false,
214 softmax_active_cap: None,
215 border_hbb_workspace: Array2::<f64>::zeros((0, 0)),
216 certificate_dispersion: None,
217 curvature_walk_report: None,
218 expected_evidence_gauge_deflated_directions: None,
219 evidence_gauge_deflation_reanchors: 0,
220 evidence_gauge_deflation_last_delta_sign: 0,
221 dictionary_cocollapse_reseeds: 0,
222 best_cocollapse_incumbent: None,
223 decoder_repulsion_gate: None,
224 barrier_coactivation_gate: None,
225 hybrid_split_report: None,
226 atom_inner_fits: None,
227 oos_linear_images: None,
228 })
229 }
230
231 /// #1408/#1409 — install the optional hard per-row active-atom cap for
232 /// Softmax mode (threaded from the fit/encode `top_k`). A `Some(k)` with
233 /// `1 <= k < K` makes the Softmax assignment optimize on the COMPACT
234 /// top-`k` row layout (see [`Self::softmax_active_cap`]); `Some(k) >= K`
235 /// and `None` are both no-ops (full support). Non-softmax modes ignore it.
236 pub fn set_softmax_active_cap(&mut self, top_k: Option<usize>) {
237 self.softmax_active_cap = match top_k {
238 Some(k) if k >= 1 && k < self.k_atoms() => Some(k),
239 _ => None,
240 };
241 }
242
243 /// Install the fitted reconstruction dispersion used by
244 /// [`dictionary_incoherence_report`]. This is a pure diagnostic scalar and
245 /// does not feed any loss, criterion, penalty, or optimizer state.
246 pub fn set_certificate_dispersion(&mut self, dispersion: f64) -> Result<(), String> {
247 if !dispersion.is_finite() || dispersion <= 0.0 {
248 return Err(format!(
249 "SaeManifoldTerm::set_certificate_dispersion: dispersion must be finite and positive, got {dispersion}"
250 ));
251 }
252 self.certificate_dispersion = Some(dispersion);
253 Ok(())
254 }
255
256 /// Harvest the per-atom inner-decoder-smooth byproducts (#1097 / #1103) the
257 /// residual-gauge certificate's post-PIRLS atom inference reports consume.
258 ///
259 /// This is the post-fit harness seam: it needs the reconstruction target `Z`
260 /// (`target`) and the fitted dispersion `φ` (`dispersion`), both available
261 /// only after the joint fit converges and the engine has discarded `Z` from
262 /// the objective. For each atom `k` it captures the Gaussian-identity
263 /// penalized smooth of the atom's leading decoder output channel `j`
264 /// (largest column 2-norm of `B_k`) against its partial residual
265 /// `e_{i} = z_i − fitted_i + a_{ik} g_k(t_i)` on channel `j`, holding all
266 /// other atoms and the assignment fixed at the fitted optimum — exactly the
267 /// fixed snapshot ([`crate::identifiability::AtomInnerFit`]) the Riesz
268 /// debiasing and split-LRT smooth-structure e-value read.
269 ///
270 /// A pure read of the fitted state: it mutates only the diagnostic
271 /// `atom_inner_fits` field, never a loss / criterion / penalty / optimizer
272 /// state. Atoms with no active rows or a degenerate (rank-deficient,
273 /// non-SPD) inner Hessian get a `None` slot — the genuine prerequisite (an
274 /// SPD penalized inner Hessian on a non-empty active set) is absent there.
275 pub fn set_atom_inner_fits(
276 &mut self,
277 target: ArrayView2<'_, f64>,
278 rho: &SaeManifoldRho,
279 dispersion: f64,
280 ) -> Result<(), String> {
281 if !dispersion.is_finite() || dispersion <= 0.0 {
282 return Err(format!(
283 "SaeManifoldTerm::set_atom_inner_fits: dispersion must be finite and positive, got {dispersion}"
284 ));
285 }
286 let n = self.n_obs();
287 let p = self.output_dim();
288 let k_atoms = self.k_atoms();
289 if target.dim() != (n, p) {
290 return Err(format!(
291 "SaeManifoldTerm::set_atom_inner_fits: target {:?} != ({n}, {p})",
292 target.dim()
293 ));
294 }
295
296 // #1026 — `atom_inner_fits` is a pure diagnostic; skip its dense (N×K×P)
297 // tensor (~256 GiB at K=32768,P=32) past a cell ceiling — all-None slots,
298 // never OOM. The fit is unaffected; only this audit field is absent.
299 if n.saturating_mul(k_atoms).saturating_mul(p) > 64_000_000 {
300 self.atom_inner_fits = Some((0..k_atoms).map(|_| None).collect());
301 return Ok(());
302 }
303
304 // Settled per-row assignments and per-(row, atom) decoded outputs, so the
305 // per-atom partial residual is `e_k = (z − fitted) + a_k decoded_k`.
306 let mut assignments = Vec::with_capacity(n);
307 for row in 0..n {
308 assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
309 }
310 let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
311 let mut dbuf = vec![0.0_f64; p];
312 for row in 0..n {
313 for atom_idx in 0..k_atoms {
314 self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
315 for c in 0..p {
316 decoded[[row, atom_idx, c]] = dbuf[c];
317 }
318 }
319 }
320 let mut fitted = Array2::<f64>::zeros((n, p));
321 for row in 0..n {
322 for atom_idx in 0..k_atoms {
323 let a = assignments[row][atom_idx];
324 if a == 0.0 {
325 continue;
326 }
327 for c in 0..p {
328 fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
329 }
330 }
331 }
332
333 let mut inner_fits: Vec<Option<crate::identifiability::AtomInnerFit>> =
334 Vec::with_capacity(k_atoms);
335 for atom_idx in 0..k_atoms {
336 inner_fits.push(self.build_atom_inner_fit(
337 atom_idx,
338 target,
339 &assignments,
340 decoded.view(),
341 fitted.view(),
342 dispersion,
343 )?);
344 }
345 self.atom_inner_fits = Some(inner_fits);
346 Ok(())
347 }
348
349 /// Build one atom's fixed inner-smooth snapshot for the post-PIRLS atom
350 /// inference reports, or `None` when the atom has no active rows or the
351 /// penalized inner Hessian is not SPD. Returns `Err` only on a structural
352 /// inconsistency (shape mismatch), never on a benign degenerate atom.
353 pub(crate) fn build_atom_inner_fit(
354 &self,
355 atom_idx: usize,
356 target: ArrayView2<'_, f64>,
357 assignments: &[Array1<f64>],
358 decoded: ArrayView3<'_, f64>,
359 fitted: ArrayView2<'_, f64>,
360 dispersion: f64,
361 ) -> Result<Option<crate::identifiability::AtomInnerFit>, String> {
362 let atom = &self.atoms[atom_idx];
363 let n = atom.n_obs();
364 let m = atom.basis_size();
365 let p = atom.output_dim();
366 if m == 0 || p == 0 {
367 return Ok(None);
368 }
369
370 // Leading decoder output channel j = argmax_j ‖B_k[:, j]‖, the channel
371 // that carries the atom's signal.
372 let mut j_lead = 0usize;
373 let mut best_norm = -1.0_f64;
374 for col in 0..p {
375 let mut norm = 0.0_f64;
376 for r in 0..m {
377 let v = atom.decoder_coefficients[[r, col]];
378 norm += v * v;
379 }
380 if norm > best_norm {
381 best_norm = norm;
382 j_lead = col;
383 }
384 }
385 let beta = atom.decoder_coefficients.column(j_lead).to_owned();
386
387 // Active rows: a_{ik} > 0.
388 let active: Vec<usize> = (0..n)
389 .filter(|&row| assignments[row][atom_idx] > 0.0)
390 .collect();
391 let n_active = active.len();
392 // The penalized smooth needs at least as many active rows as it has
393 // basis columns to give a non-degenerate data Gram; below that the inner
394 // fit's SPD prerequisite is genuinely unmet.
395 if n_active == 0 {
396 return Ok(None);
397 }
398
399 let mut design = Array2::<f64>::zeros((n_active, m));
400 let mut derivative_design = Array2::<f64>::zeros((n_active, m));
401 let mut row_scores = Array2::<f64>::zeros((n_active, m));
402 let mut weights = Array1::<f64>::zeros(n_active);
403 for (slot, &row) in active.iter().enumerate() {
404 let a_ik = assignments[row][atom_idx];
405 let w_i = a_ik * a_ik;
406 weights[slot] = w_i;
407 for col in 0..m {
408 design[[slot, col]] = atom.basis_values[[row, col]];
409 // Leading latent axis (axis 0) is the atom's primary coordinate;
410 // it is the one the average-derivative functional integrates.
411 derivative_design[[slot, col]] = atom.basis_jacobian[[row, col, 0]];
412 }
413 // Partial residual on channel j, then the inner-smooth working
414 // response z_i = e_i / a_ik so that w_i (z_i − Φᵀβ) = a_ik r_i.
415 let e_i = target[[row, j_lead]] - fitted[[row, j_lead]]
416 + a_ik * decoded[[row, atom_idx, j_lead]];
417 let mu_hat = design.row(slot).dot(&beta);
418 let z_i = e_i / a_ik;
419 let res_i = z_i - mu_hat;
420 // Gaussian-identity score s_i = −w_i res_i Φ_i / φ.
421 let scale = -w_i * res_i / dispersion;
422 for col in 0..m {
423 row_scores[[slot, col]] = scale * design[[slot, col]];
424 }
425 }
426
427 // Penalized inner Hessian H = ΦᵀWΦ + S̃_k.
428 let mut xtwx = Array2::<f64>::zeros((m, m));
429 for slot in 0..n_active {
430 let w_i = weights[slot];
431 for a in 0..m {
432 let xa = design[[slot, a]];
433 if xa == 0.0 {
434 continue;
435 }
436 for b in 0..m {
437 xtwx[[a, b]] += w_i * xa * design[[slot, b]];
438 }
439 }
440 }
441 let penalty = atom.smooth_penalty.clone();
442 if penalty.dim() != (m, m) {
443 return Err(format!(
444 "build_atom_inner_fit: atom {atom_idx} smooth penalty {:?} != ({m}, {m})",
445 penalty.dim()
446 ));
447 }
448 let penalized_hessian = &xtwx + &penalty;
449
450 // SPD prerequisite: the inner penalized Hessian must factor, else the
451 // atom's inner-smooth fit is degenerate and no report is producible.
452 if penalized_hessian.cholesky(Side::Lower).is_err() {
453 return Ok(None);
454 }
455
456 // Peak (largest fitted |g_k| on channel j) and mode (largest assignment
457 // mass) design rows, over the active set.
458 let mut peak_slot = 0usize;
459 let mut peak_val = -1.0_f64;
460 let mut mode_slot = 0usize;
461 let mut mode_mass = -1.0_f64;
462 for (slot, &row) in active.iter().enumerate() {
463 let g_val = design.row(slot).dot(&beta).abs();
464 if g_val > peak_val {
465 peak_val = g_val;
466 peak_slot = slot;
467 }
468 let mass = assignments[row][atom_idx];
469 if mass > mode_mass {
470 mode_mass = mass;
471 mode_slot = slot;
472 }
473 }
474 let peak_design_row = design.row(peak_slot).to_owned();
475 let mode_design_row = design.row(mode_slot).to_owned();
476
477 Ok(Some(crate::identifiability::AtomInnerFit {
478 design,
479 derivative_design,
480 beta,
481 penalty,
482 penalized_hessian,
483 row_scores,
484 weights,
485 dispersion,
486 peak_design_row,
487 mode_design_row,
488 }))
489 }
490
491 /// Profile the Gaussian reconstruction dispersion at the current seed
492 /// state. This is the scale used to make SAE penalty seeds dimensionless
493 /// before the outer rho search starts.
494 pub fn seed_reconstruction_dispersion(
495 &self,
496 target: ArrayView2<'_, f64>,
497 ) -> Result<f64, String> {
498 let fitted = self.try_fitted()?;
499 if fitted.dim() != target.dim() {
500 return Err(format!(
501 "SaeManifoldTerm::seed_reconstruction_dispersion: fitted {:?} != target {:?}",
502 fitted.dim(),
503 target.dim()
504 ));
505 }
506 let n_scalar = (target.nrows() * target.ncols()).max(1) as f64;
507 let mut rss = 0.0_f64;
508 for row in 0..target.nrows() {
509 for col in 0..target.ncols() {
510 let r = target[[row, col]] - fitted[[row, col]];
511 rss += r * r;
512 }
513 }
514 if !rss.is_finite() || rss < 0.0 {
515 return Err(format!(
516 "SaeManifoldTerm::seed_reconstruction_dispersion: non-finite seed RSS {rss}"
517 ));
518 }
519 Ok((rss / n_scalar).max(SAE_SEED_DISPERSION_FLOOR))
520 }
521
522 /// Install per-row design honesty weights (#991) — the `1/π` inclusion
523 /// corrections of a designed corpus subsample (see the field docs on
524 /// `row_loss_weights` for exactly where they enter the objective).
525 ///
526 /// Weights must be finite and strictly positive, one per term row. They
527 /// are self-normalized to mean `1.0` here (only the *relative* design
528 /// correction matters at the fitted sample size; the absolute `n/budget`
529 /// scale would silently inflate the dispersion estimate against the
530 /// sample-sized dof). Weights that are identically equal after
531 /// normalization (an exact full pass, or any uniform design) are stored
532 /// as `None`, so the unweighted path stays bit-for-bit identical rather
533 /// than "multiplied by 1.0".
534 pub fn set_row_loss_weights(&mut self, weights: Vec<f64>) -> Result<(), String> {
535 if weights.len() != self.n_obs() {
536 return Err(format!(
537 "SaeManifoldTerm::set_row_loss_weights: {} weights for {} rows",
538 weights.len(),
539 self.n_obs()
540 ));
541 }
542 if weights.is_empty() {
543 self.row_loss_weights = None;
544 return Ok(());
545 }
546 if !weights.iter().all(|w| w.is_finite() && *w > 0.0) {
547 return Err(
548 "SaeManifoldTerm::set_row_loss_weights: weights must be finite and strictly \
549 positive"
550 .to_string(),
551 );
552 }
553 let first = weights[0];
554 if weights.iter().all(|w| *w == first) {
555 // Uniform design (full pass, or flat measure): the normalized
556 // weight is exactly 1 everywhere — take the unweighted path.
557 self.row_loss_weights = None;
558 return Ok(());
559 }
560 let mean = weights.iter().sum::<f64>() / weights.len() as f64;
561 self.row_loss_weights = Some(weights.into_iter().map(|w| w / mean).collect());
562 Ok(())
563 }
564
565 /// The installed (mean-1 normalized) design honesty weights, `None` on the
566 /// exact unweighted path.
567 pub fn row_loss_weights(&self) -> Option<&[f64]> {
568 self.row_loss_weights.as_deref()
569 }
570
571 /// Drop any installed per-row reconstruction weights, returning the term to
572 /// the exact unweighted (full-pass) path. Used by the #997 structure-search
573 /// wiring to clear the internal estimation/evaluation mask off the adopted
574 /// term before the payload reconstruction is read over all rows.
575 pub fn clear_row_loss_weights(&mut self) {
576 self.row_loss_weights = None;
577 }
578
579 /// Huber-style OUTLIER-ROBUST per-row weights from the target activation
580 /// norms — the missing default *policy* for the existing
581 /// [`set_row_loss_weights`](Self::set_row_loss_weights) mechanism.
582 ///
583 /// The SAE fits unweighted least squares, which weights each token by its
584 /// squared residual ∝ `‖z_i‖²`. On real LLM residual streams the per-token
585 /// norm distribution is heavy-tailed (e.g. an OLMo mixed-layer slice has
586 /// `p99/median ≈ 4.7`), so a small **coherent** cluster of high-norm tokens —
587 /// typically special / attention-sink tokens, not semantic content —
588 /// dominates the objective (measured: the top 5% of tokens carry ~31% of the
589 /// total `‖z‖²` budget) and pulls dictionary atoms toward their direction.
590 /// Mean-centering does NOT address this (it is per-feature, not per-token).
591 ///
592 /// This returns Huber weights `w_i = min(1, δ·m / ‖z_i‖)` where `m` is the
593 /// MEDIAN token norm: tokens at or below `δ·m` keep full weight, higher-norm
594 /// tokens are downweighted so their objective share grows only LINEARLY (not
595 /// quadratically) with norm. `δ` is the robustness knob (`δ=1` thresholds at
596 /// the median; larger `δ` only touches the extreme tail). The result is
597 /// mean-normalized (overall objective scale preserved). OPT-IN: the caller
598 /// installs it via `set_row_loss_weights` — the default fit is unchanged.
599 pub fn robust_norm_row_weights(
600 target: ArrayView2<'_, f64>,
601 delta: f64,
602 ) -> Result<Vec<f64>, String> {
603 if !(delta.is_finite() && delta > 0.0) {
604 return Err(format!(
605 "robust_norm_row_weights: delta must be finite and positive; got {delta}"
606 ));
607 }
608 let n = target.nrows();
609 if n == 0 {
610 return Ok(Vec::new());
611 }
612 let norms: Vec<f64> = (0..n)
613 .map(|i| {
614 let r = target.row(i);
615 r.dot(&r).sqrt()
616 })
617 .collect();
618 let mut sorted = norms.clone();
619 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
620 // Median token norm (lower-median for even n; floored off zero so an
621 // all-zero/degenerate slice yields uniform weights instead of NaN).
622 let median = sorted[n / 2].max(f64::MIN_POSITIVE);
623 let thresh = delta * median;
624 let raw: Vec<f64> = norms
625 .iter()
626 .map(|&nm| if nm <= thresh { 1.0 } else { thresh / nm })
627 .collect();
628 let mean = raw.iter().sum::<f64>() / n as f64;
629 if !(mean.is_finite() && mean > 0.0) {
630 return Err("robust_norm_row_weights: degenerate weight normalizer".to_string());
631 }
632 Ok(raw.into_iter().map(|w| w / mean).collect())
633 }
634
635 /// Install the single per-row [`RowMetric`](gam_problem::RowMetric)
636 /// that both the reconstruction likelihood and the isometry gauge read.
637 /// Installing per-row output-Fisher factors here flips the provenance to
638 /// `OutputFisher` *and* is the only way the gauge acquires a non-identity
639 /// weight, so the two inner products cannot diverge. Passing a Euclidean
640 /// metric (or never calling this) keeps the bit-identical isotropic path.
641 ///
642 /// The metric's row count and output dimension must match the term.
643 pub fn set_row_metric(
644 &mut self,
645 metric: gam_problem::RowMetric,
646 ) -> Result<(), String> {
647 if metric.n_rows() != self.n_obs() {
648 return Err(format!(
649 "SaeManifoldTerm::set_row_metric: metric has {} rows but term has {}",
650 metric.n_rows(),
651 self.n_obs()
652 ));
653 }
654 if metric.p_out() != self.output_dim() {
655 return Err(format!(
656 "SaeManifoldTerm::set_row_metric: metric output dim {} but term has {}",
657 metric.p_out(),
658 self.output_dim()
659 ));
660 }
661 self.row_metric = Some(metric);
662 Ok(())
663 }
664
665 /// The installed per-row metric, if any. `None` ⇒ Euclidean / isotropic.
666 /// Consumed by the gauge wiring (to build the matching `WeightField`) and by
667 /// Object 4 (to read the [`MetricProvenance`](gam_problem::MetricProvenance)).
668 pub fn row_metric(&self) -> Option<&gam_problem::RowMetric> {
669 self.row_metric.as_ref()
670 }
671
672 /// The per-row inner product the additive diagnostics read through: the
673 /// installed [`RowMetric`](gam_problem::RowMetric) when one
674 /// was set (output-Fisher harvest present), otherwise a freshly-built
675 /// Euclidean metric of the term's own `(n_obs, output_dim)` shape. Either way
676 /// a metric always exists, so the diagnostics are never gated by a flag — the
677 /// Euclidean fallback is the bit-identical isotropic path.
678 pub(crate) fn diagnostic_metric(
679 &self,
680 ) -> Result<gam_problem::RowMetric, String> {
681 match self.row_metric() {
682 Some(metric) => Ok(metric.clone()),
683 None => {
684 gam_problem::RowMetric::euclidean(self.n_obs(), self.output_dim())
685 }
686 }
687 }
688
689 /// Build the additive post-fit diagnostic report for this fitted term: the
690 /// two-score per-atom [`AtomTwoLensReport`](crate::inference::atom_lens::AtomTwoLensReport)
691 /// (presence / behavioral coupling / discrepancy) and the residual-gauge
692 /// [`ResidualGaugeReport`](crate::identifiability::ResidualGaugeReport)
693 /// certificate.
694 ///
695 /// Both reports are read through the same single metric
696 /// ([`Self::diagnostic_metric`]): under a Euclidean / no-harvest provenance
697 /// the lens coupling is `None` and the gauge is certified under Euclidean
698 /// provenance — never an error, never gated by a flag (magic-by-default,
699 /// mirroring the metric selection itself).
700 ///
701 /// `per_atom_ard_variances`, when supplied, is one ARD variance vector per
702 /// atom (length = `latent_dim_k`), threaded into the certificate's
703 /// equal-ARD-rotation detection. `None` (or a per-atom `None`) ⇒ no ARD prior
704 /// on that atom. `isometry_pin_active` records whether an isometry gauge
705 /// penalty was installed on the fit: `false` escalates the certificate to the
706 /// `diffeomorphism-unpinned` verdict (the honest "no metric pin" statement),
707 /// exactly as the certificate's own escalation flag specifies.
708 ///
709 /// Pure read: it never mutates the term, never touches a loss / criterion /
710 /// penalty / optimizer state.
711 pub fn fit_diagnostics_report(
712 &self,
713 per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
714 isometry_pin_active: bool,
715 reconstruction_dispersion: Option<f64>,
716 assignments_override: Option<ArrayView2<'_, f64>>,
717 ) -> Result<SaeManifoldFitDiagnostics, String> {
718 if let Some(view) = assignments_override {
719 let n = self.n_obs();
720 let k = self.k_atoms();
721 if view.dim() != (n, k) {
722 return Err(format!(
723 "fit_diagnostics_report: assignments_override shape {:?} must be ({n}, {k})",
724 view.dim()
725 ));
726 }
727 }
728 let metric = self.diagnostic_metric()?;
729 let atom_two_lens =
730 crate::inference::atom_lens::atom_two_lens(self, &metric, assignments_override);
731
732 let (certificate_model, streamed_curvature) =
733 self.to_residual_gauge_model(metric, per_atom_ard_variances, isometry_pin_active)?;
734 // #998: within-atom gauge families are certified on their EXACT orbits
735 // in the model's own (decoder, coordinate) parameter space — compensated
736 // symmetries are data-nulls by construction there, no lowering-error
737 // calibration involved. This now holds whether or not an isometry pin is
738 // active:
739 // * pin INACTIVE ⇒ the orbit verdict is the data residual alone (no
740 // penalty operator);
741 // * pin ACTIVE ⇒ the orbit verdict adds the isometry pin's orbit-space
742 // curvature through an [`OrbitPenaltyOperator`] lowered from the
743 // atom's second jet `Φ''` (the pullback-metric change along the orbit
744 // differentiates `J = Φ'B` through `t`). A model-class symmetry that
745 // preserves the metric stays a certified freedom; a non-isometric
746 // orbit (a basis not closed under the action) is genuinely pinned.
747 // The relative-curvature fraction `cost/stiffness²` is invariant to the
748 // pin strength μ (both faces scale with μ), so the operator is built at a
749 // canonical unit weight. An atom whose basis exposes no analytic second
750 // jet supplies no operator and falls back to the data residual — never an
751 // error. Magic-by-default either way: the choice is derived from the fit,
752 // never a flag.
753 let views = self.atom_parameter_views();
754 let ops: Vec<Option<crate::identifiability::OrbitPenaltyOperator>> =
755 if isometry_pin_active {
756 views
757 .iter()
758 .map(|view| {
759 view.as_ref().and_then(|v| {
760 crate::identifiability::isometry_orbit_penalty_operator(
761 v, 1.0,
762 )
763 })
764 })
765 .collect()
766 } else {
767 (0..self.k_atoms()).map(|_| None).collect()
768 };
769 let residual_gauge = if isometry_pin_active {
770 // The pin-active path consumes the per-row Jacobian curvature
771 // directly (the certificate_model retains it under a pin), so route
772 // through the non-streamed exact entry point.
773 crate::identifiability::residual_gauge_exact(
774 &certificate_model,
775 &views,
776 &ops,
777 )?
778 } else {
779 let (curvature_gram, root_rows) = streamed_curvature.ok_or_else(|| {
780 "fit_diagnostics_report: missing streamed residual-gauge curvature for unpinned exact path"
781 .to_string()
782 })?;
783 crate::identifiability::residual_gauge_exact_from_curvature_gram(
784 &certificate_model,
785 &views,
786 &ops,
787 curvature_gram,
788 root_rows,
789 )?
790 };
791
792 // #1097 / #1103: per-atom Riesz-debiased functionals and the any-n-valid
793 // split-LRT smooth-structure e-value (non-constant vs constant inner
794 // decoder), read straight off the certificate model — which carries
795 // each atom's `inner_fit` snapshot when the caller harvested it via
796 // [`Self::set_atom_inner_fits`] before this report. Atoms without a
797 // harvested inner fit degrade their inference fields to `None` inside
798 // `atom_inference_reports`, so this is always populated (one entry per
799 // atom) and never gated by a flag.
800 let atom_inference =
801 crate::identifiability::atom_inference_reports(&certificate_model);
802
803 Ok(SaeManifoldFitDiagnostics {
804 atom_two_lens,
805 residual_gauge,
806 incoherence_report: match reconstruction_dispersion.or(self.certificate_dispersion) {
807 Some(dispersion) => Some(dictionary_incoherence_report_with_dispersion(
808 self, dispersion,
809 )?),
810 None => None,
811 },
812 atom_inference,
813 })
814 }
815
816 /// Build the trust-diagnostics producer for the Python `diagnostics` block.
817 ///
818 /// `assignments` is supplied by the payload assembly site so top-k projection,
819 /// when requested, is reflected in coverage/frequency and in the tangent
820 /// spectra. The active threshold is shared with the atom lens so all
821 /// assignment-support diagnostics agree on what "active" means.
822 pub fn trust_diagnostics_report(
823 &self,
824 assignments: ArrayView2<'_, f64>,
825 ) -> Result<SaeTrustDiagnostics, String> {
826 let n = self.n_obs();
827 let k_atoms = self.k_atoms();
828 if assignments.dim() != (n, k_atoms) {
829 return Err(format!(
830 "trust_diagnostics_report: assignments shape {:?} must be ({n}, {k_atoms})",
831 assignments.dim()
832 ));
833 }
834 if !assignments.iter().all(|v| v.is_finite()) {
835 return Err("trust_diagnostics_report: assignments must be finite".to_string());
836 }
837 let metric = self.diagnostic_metric()?;
838 let active_threshold = crate::inference::atom_lens::SAE_TRUST_ACTIVE_MASS_FLOOR;
839 let mut atoms = Vec::with_capacity(k_atoms);
840 let mut atom_trust = Vec::with_capacity(k_atoms);
841 for (atom_idx, atom) in self.atoms.iter().enumerate() {
842 let mut active_token_count = 0usize;
843 let mut activation_sum = 0.0_f64;
844 for row in 0..n {
845 let mass = assignments[[row, atom_idx]];
846 activation_sum += mass;
847 if mass > active_threshold {
848 active_token_count += 1;
849 }
850 }
851 let coverage = if n > 0 {
852 active_token_count as f64 / n as f64
853 } else {
854 0.0
855 };
856 let activation_frequency = if n > 0 {
857 activation_sum / n as f64
858 } else {
859 0.0
860 };
861 let (sigma_min_tangent, sigma_max_tangent) = self
862 .atom_tangent_spectrum_from_assignments(
863 atom_idx,
864 assignments,
865 &metric,
866 active_threshold,
867 )?;
868 let tangent_condition_score = if sigma_max_tangent > 0.0 {
869 (sigma_min_tangent / sigma_max_tangent).clamp(0.0, 1.0)
870 } else {
871 0.0
872 };
873 let trust_score = tangent_condition_score;
874 atom_trust.push(trust_score);
875 atoms.push(SaeAtomTrustDiagnostics {
876 trust_score,
877 sigma_min_tangent,
878 sigma_max_tangent,
879 tangent_condition_score,
880 coverage,
881 activation_frequency,
882 untyped: matches!(atom.basis_kind, SaeAtomBasisKind::Precomputed(_)),
883 active_token_count,
884 });
885 }
886 Ok(SaeTrustDiagnostics { atom_trust, atoms })
887 }
888
889 pub(crate) fn atom_tangent_spectrum_from_assignments(
890 &self,
891 atom_idx: usize,
892 assignments: ArrayView2<'_, f64>,
893 metric: &gam_problem::RowMetric,
894 active_threshold: f64,
895 ) -> Result<(f64, f64), String> {
896 let atom = &self.atoms[atom_idx];
897 let d = atom.latent_dim;
898 let p = self.output_dim();
899 if d == 0 || p == 0 {
900 return Ok((0.0, 0.0));
901 }
902 let mut gram = Array2::<f64>::zeros((d, d));
903 let mut active_mass_sum = 0.0_f64;
904 let mut jac_row = vec![0.0_f64; p * d];
905 for row in 0..self.n_obs() {
906 let mass = assignments[[row, atom_idx]];
907 if !(mass > active_threshold) {
908 continue;
909 }
910 active_mass_sum += mass;
911 for axis in 0..d {
912 let start = axis;
913 let mut tangent = vec![0.0_f64; p];
914 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
915 for out in 0..p {
916 jac_row[out * d + start] = tangent[out];
917 }
918 }
919 let row_pullback = metric.pullback(row, &jac_row, d);
920 for axis_a in 0..d {
921 for axis_b in 0..=axis_a {
922 gram[[axis_a, axis_b]] += mass * row_pullback[[axis_a, axis_b]];
923 }
924 }
925 jac_row.fill(0.0);
926 }
927 if !(active_mass_sum > 0.0) {
928 return Ok((0.0, 0.0));
929 }
930 let inv_mass = 1.0 / active_mass_sum;
931 for axis_a in 0..d {
932 for axis_b in 0..=axis_a {
933 let value = gram[[axis_a, axis_b]] * inv_mass;
934 gram[[axis_a, axis_b]] = value;
935 gram[[axis_b, axis_a]] = value;
936 }
937 }
938 let (evals, _) = gram.eigh(Side::Lower).map_err(|e| {
939 format!(
940 "trust_diagnostics_report: atom {atom_idx} tangent eigendecomposition failed: {e}"
941 )
942 })?;
943 let mut sigma_min = f64::INFINITY;
944 let mut sigma_max = 0.0_f64;
945 for value in evals.iter().copied() {
946 let clamped = value.max(0.0);
947 let sigma = clamped.sqrt();
948 sigma_min = sigma_min.min(sigma);
949 sigma_max = sigma_max.max(sigma);
950 }
951 if sigma_min.is_finite() {
952 Ok((sigma_min, sigma_max))
953 } else {
954 Ok((0.0, 0.0))
955 }
956 }
957
958 /// Per-atom exact parameter-space views for the #998 certificate path:
959 /// the basis values / first-derivative jet, decoder coefficients, latent
960 /// coordinates, and assignment mass each atom was actually fitted with.
961 /// Sphere atoms get `None` (their chart's group action is nonlinear, so
962 /// the exact-orbit realisation does not apply and they stay on the frame
963 /// path), as does any atom whose coordinate chart width disagrees with its
964 /// latent dimension (a structurally inconsistent atom must not masquerade
965 /// as exactly certified).
966 pub(crate) fn atom_parameter_views(
967 &self,
968 ) -> Vec<Option<crate::identifiability::AtomParameterView>> {
969 let assignments = self.assignment.assignments();
970 let n = self.n_obs();
971 self.atoms
972 .iter()
973 .enumerate()
974 .map(|(k, atom)| {
975 if matches!(atom.basis_kind, SaeAtomBasisKind::Sphere) {
976 return None;
977 }
978 let coords = self.assignment.coords[k].as_matrix().to_owned();
979 if coords.nrows() != n || coords.ncols() != atom.latent_dim {
980 return None;
981 }
982 let mut activations = Array1::<f64>::zeros(n);
983 for row in 0..n {
984 activations[row] = assignments[[row, k]];
985 }
986 // Second jet Φ'' (#998): supplied when the atom's evaluator
987 // exposes an analytic Hessian, so a pin-active fit can lower its
988 // orbit-space isometry penalty operator (the metric-change of the
989 // pullback gram differentiates Φ' through t). Absent ⇒ the orbit
990 // verdict stays on the data residual / no-pin path, never an
991 // error.
992 let basis_second_jet = atom
993 .basis_evaluator
994 .as_ref()
995 .and_then(|evaluator| evaluator.second_jet_dyn(coords.view()))
996 .and_then(|res| res.ok());
997 Some(crate::identifiability::AtomParameterView {
998 basis_values: atom.basis_values.clone(),
999 basis_jacobian: atom.basis_jacobian.clone(),
1000 decoder: atom.decoder_coefficients.clone(),
1001 coords,
1002 activations,
1003 basis_second_jet,
1004 })
1005 })
1006 .collect()
1007 }
1008
1009 /// Lower this fitted term into the self-contained
1010 /// [`FittedSaeManifold`](crate::identifiability::FittedSaeManifold) the
1011 /// residual-gauge certificate consumes.
1012 ///
1013 /// The certificate's parameter space is the per-atom decoder **frame** — the
1014 /// `(output_dim, latent_dim)` image of the atom's latent axes in output space.
1015 /// We realise it as the active-mass-weighted mean decoder tangent
1016 /// `frame_k[:, a] = (Σ_n a_{nk} · ∂g_k/∂t_a(n)) / Σ_n a_{nk}` over the atom's
1017 /// active rows (the centroid decoder Jacobian columns the certificate docs
1018 /// name). The per-row pinning Jacobian block `J_n ∈ ℝ^{p × param_dim}` is the
1019 /// assignment-weighted per-row decoder tangent placed at each atom's frame
1020 /// slot: column `(k, i, a)` of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i]` — exactly
1021 /// the directions the reconstruction data gives cost to, in the same metric
1022 /// the fit used (whitened by the certificate through `RowMetric`).
1023 ///
1024 /// The flattened frame layout matches the certificate's
1025 /// `vec(frame_0) ⊕ vec(frame_1) ⊕ …`, row-major within each frame
1026 /// (`frame_k[i, a]` at offset `atom_offset(k) + i·latent_dim_k + a`).
1027 pub(crate) fn to_residual_gauge_model(
1028 &self,
1029 metric: gam_problem::RowMetric,
1030 per_atom_ard_variances: Option<&[Option<Array1<f64>>]>,
1031 isometry_pin_active: bool,
1032 ) -> Result<
1033 (
1034 crate::identifiability::FittedSaeManifold,
1035 Option<(Array2<f64>, usize)>,
1036 ),
1037 String,
1038 > {
1039 use crate::identifiability::{AtomTopology, FittedAtom, FittedSaeManifold};
1040
1041 let n = self.n_obs();
1042 let p = self.output_dim();
1043 let k = self.k_atoms();
1044 let assignments = self.assignment.assignments();
1045
1046 // Per-atom frame `(p, d)` = active-mass-weighted mean decoder tangent,
1047 // and the flattened-frame column offset bookkeeping for the joint
1048 // parameter vector (`vec(frame_0) ⊕ …`, row-major within each frame).
1049 let mut fitted_atoms: Vec<FittedAtom> = Vec::with_capacity(k);
1050 let mut atom_offsets: Vec<usize> = Vec::with_capacity(k);
1051 let mut atom_axis_dim: Vec<usize> = Vec::with_capacity(k);
1052 let mut cursor = 0usize;
1053 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1054 let d = atom.latent_dim;
1055 let topology = match (&atom.basis_kind, d) {
1056 (SaeAtomBasisKind::Periodic, 1) | (SaeAtomBasisKind::Torus, 1) => {
1057 AtomTopology::Circle
1058 }
1059 (SaeAtomBasisKind::Periodic, _) | (SaeAtomBasisKind::Torus, _) => {
1060 AtomTopology::Torus { latent_dim: d }
1061 }
1062 (SaeAtomBasisKind::Sphere, _) => AtomTopology::Sphere,
1063 // `Cylinder` (`S¹ × ℝ`) has exactly one continuous gauge: the
1064 // rotation (shift) of the periodic axis. The unbounded line axis
1065 // carries no rotational gauge, and its translation is already
1066 // pinned by the design's constant column — so the identifiability
1067 // gauge is that of a single circle. Fixing it as `Torus` would
1068 // over-impose a second (nonexistent) circle shift; fixing it as
1069 // `EuclideanPatch { 2 }` would over-impose a frame rotation
1070 // mixing the periodic and linear axes. `Circle` fixes the one
1071 // real continuous gauge and leaves the linear axis ungauged.
1072 (SaeAtomBasisKind::Cylinder, _) => AtomTopology::Circle,
1073 (
1074 SaeAtomBasisKind::Linear
1075 | SaeAtomBasisKind::Duchon
1076 | SaeAtomBasisKind::EuclideanPatch
1077 | SaeAtomBasisKind::Poincare
1078 | SaeAtomBasisKind::Precomputed(_),
1079 _,
1080 ) => AtomTopology::EuclideanPatch { latent_dim: d },
1081 };
1082
1083 let mut frame = Array2::<f64>::zeros((p, d));
1084 let mut active_mass = 0.0_f64;
1085 let mut tangent = vec![0.0_f64; p];
1086 for row in 0..n {
1087 let a_nk = assignments[[row, atom_idx]];
1088 if !(a_nk > 0.0) {
1089 continue;
1090 }
1091 active_mass += a_nk;
1092 for axis in 0..d {
1093 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1094 for i in 0..p {
1095 frame[[i, axis]] += a_nk * tangent[i];
1096 }
1097 }
1098 }
1099 if active_mass > 0.0 {
1100 let inv = 1.0 / active_mass;
1101 frame.mapv_inplace(|v| v * inv);
1102 }
1103
1104 // #995 lowering-error scale: mass-weighted relative dispersion of
1105 // the per-row tangents around the mean frame just built,
1106 // Σ_n a_n Σ_ax ‖t_ax(n) − frame[:,ax]‖² / Σ_n a_n Σ_ax ‖t_ax(n)‖².
1107 // 0 ⇒ the frame represents every active row exactly (flat
1108 // decoder); → 1 ⇒ the tangent field disperses so strongly (e.g. a
1109 // full circle, whose tangents average out) that the mean-frame
1110 // compression cannot distinguish gauge motion from curvature. The
1111 // certificate calibrates its per-generator verdict tolerance to
1112 // this scale so it never claims a pin it cannot resolve.
1113 let mut disp_num = 0.0_f64;
1114 let mut disp_den = 0.0_f64;
1115 for row in 0..n {
1116 let a_nk = assignments[[row, atom_idx]];
1117 if !(a_nk > 0.0) {
1118 continue;
1119 }
1120 for axis in 0..d {
1121 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1122 for i in 0..p {
1123 let dev = tangent[i] - frame[[i, axis]];
1124 disp_num += a_nk * dev * dev;
1125 disp_den += a_nk * tangent[i] * tangent[i];
1126 }
1127 }
1128 }
1129 let lowering_error = if disp_den > 0.0 {
1130 (disp_num / disp_den).clamp(0.0, 1.0)
1131 } else {
1132 0.0
1133 };
1134
1135 let ard_variances = per_atom_ard_variances
1136 .and_then(|all| all.get(atom_idx))
1137 .and_then(|opt| opt.clone())
1138 .filter(|v| v.len() == d);
1139
1140 fitted_atoms.push(FittedAtom {
1141 name: atom.name.clone(),
1142 topology,
1143 frame,
1144 ard_variances,
1145 lowering_error,
1146 // #1019: post-fit chart canonicalization (arc length for
1147 // d = 1, isometry-flow for d = 2 torus, flat-reference
1148 // isometry-flow for d = 2 free/patch, round-sphere
1149 // conformal-boost flow for d = 2 sphere atoms) pins the chart;
1150 // the certificate downgrades this atom's chart freedom to the
1151 // finite isometry group with PinnedByCanonicalization
1152 // provenance.
1153 chart_canonicalized: atom.chart_canonicalized
1154 && (d == 1
1155 || (d == 2
1156 && matches!(
1157 atom.basis_kind,
1158 SaeAtomBasisKind::Torus
1159 | SaeAtomBasisKind::Linear
1160 | SaeAtomBasisKind::Duchon
1161 | SaeAtomBasisKind::EuclideanPatch
1162 | SaeAtomBasisKind::Sphere
1163 ))),
1164 // #1097 / #1103: the per-atom inner-decoder-smooth snapshot,
1165 // attached when the post-fit harness has run
1166 // [`Self::set_atom_inner_fits`] (it needs the reconstruction
1167 // target Z, dropped from the objective at fit end). `None` on a
1168 // bare certificate-only model, or for a degenerate atom whose
1169 // inner Hessian was not SPD.
1170 inner_fit: self
1171 .atom_inner_fits
1172 .as_ref()
1173 .and_then(|fits| fits.get(atom_idx))
1174 .and_then(|slot| slot.clone()),
1175 });
1176 atom_offsets.push(cursor);
1177 atom_axis_dim.push(d);
1178 cursor += p * d;
1179 }
1180 let param_dim = cursor;
1181
1182 // Per-row pinning Jacobian `J_n ∈ ℝ^{p × param_dim}` flattened row-major
1183 // (`J_n[i, c] = jacobian_rows[n][i · param_dim + c]`). Column `(k, i', a)`
1184 // of `J_n` is `a_{nk} · ∂g_k/∂t_a(n)[i']` placed at the atom-k frame slot
1185 // and read out on output coordinate `i = i'` (a frame perturbation of
1186 // output `i'` moves only the row's output coordinate `i'`).
1187 //
1188 // The pinned certificate still consumes the legacy row-block contract.
1189 // The unpinned exact path consumes only `RᵀR`, so stream each transient
1190 // row Jacobian through the metric whitening and discard it immediately.
1191 let (jacobian_rows, streamed_curvature) = if isometry_pin_active {
1192 let mut jacobian_rows: Vec<Vec<f64>> = Vec::with_capacity(n);
1193 let mut tangent = vec![0.0_f64; p];
1194 for row in 0..n {
1195 let mut j_flat = vec![0.0_f64; p * param_dim];
1196 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1197 let a_nk = assignments[[row, atom_idx]];
1198 if !(a_nk > 0.0) {
1199 continue;
1200 }
1201 let d = atom_axis_dim[atom_idx];
1202 let base = atom_offsets[atom_idx];
1203 for axis in 0..d {
1204 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1205 for i in 0..p {
1206 // Frame coordinate `(k, i, axis)` sits at column
1207 // `base + i·d + axis`; it sources output coordinate `i`.
1208 j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1209 }
1210 }
1211 }
1212 jacobian_rows.push(j_flat);
1213 }
1214 (jacobian_rows, None)
1215 } else {
1216 let streamed = self.residual_gauge_streamed_data_curvature(
1217 &metric,
1218 &atom_offsets,
1219 &atom_axis_dim,
1220 param_dim,
1221 )?;
1222 (Vec::new(), Some(streamed))
1223 };
1224
1225 // Isometry-penalty curvature root over the frame parameter space. When
1226 // the isometry gauge pin is active it gives curvature along every fitted
1227 // frame direction (it resists deviation of the decoder image from its
1228 // arc-length parameterization), so its row space is the span of the
1229 // per-atom frame columns: one root row per `(k, axis)` carrying that
1230 // atom's frame column at the atom's frame slot. Empty (`0 × param_dim`)
1231 // when the pin is inactive — exactly the certificate's escalation
1232 // condition to `diffeomorphism-unpinned`.
1233 let isometry_penalty_root = if isometry_pin_active && param_dim > 0 {
1234 let mut root_rows: Vec<Array1<f64>> = Vec::new();
1235 for (atom_idx, fitted) in fitted_atoms.iter().enumerate() {
1236 let d = atom_axis_dim[atom_idx];
1237 let base = atom_offsets[atom_idx];
1238 for axis in 0..d {
1239 let mut r = Array1::<f64>::zeros(param_dim);
1240 let mut any = false;
1241 for i in 0..p {
1242 let v = fitted.frame[[i, axis]];
1243 if v != 0.0 {
1244 any = true;
1245 }
1246 r[base + i * d + axis] = v;
1247 }
1248 if any {
1249 root_rows.push(r);
1250 }
1251 }
1252 }
1253 let mut root = Array2::<f64>::zeros((root_rows.len(), param_dim));
1254 for (ri, r) in root_rows.iter().enumerate() {
1255 root.row_mut(ri).assign(r);
1256 }
1257 root
1258 } else {
1259 Array2::<f64>::zeros((0, param_dim))
1260 };
1261
1262 Ok((
1263 FittedSaeManifold {
1264 atoms: fitted_atoms,
1265 jacobian_rows,
1266 isometry_penalty_root,
1267 metric,
1268 },
1269 streamed_curvature,
1270 ))
1271 }
1272
1273 pub(crate) fn residual_gauge_streamed_data_curvature(
1274 &self,
1275 metric: &gam_problem::RowMetric,
1276 atom_offsets: &[usize],
1277 atom_axis_dim: &[usize],
1278 param_dim: usize,
1279 ) -> Result<(Array2<f64>, usize), String> {
1280 let n = self.n_obs();
1281 let p = self.output_dim();
1282 if metric.p_out() != p {
1283 return Err(format!(
1284 "residual_gauge_streamed_data_curvature: metric output dim {} but term has {p}",
1285 metric.p_out()
1286 ));
1287 }
1288 let rank = metric.metric_rank();
1289 let mut gram = Array2::<f64>::zeros((param_dim, param_dim));
1290 if param_dim == 0 || n == 0 || rank == 0 {
1291 return Ok((gram, n * rank));
1292 }
1293
1294 let assignments = self.assignment.assignments();
1295 let mut tangent = vec![0.0_f64; p];
1296 let mut j_flat = vec![0.0_f64; p * param_dim];
1297 let mut root_row = Array1::<f64>::zeros(param_dim);
1298 for row in 0..n {
1299 j_flat.fill(0.0);
1300 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1301 let a_nk = assignments[[row, atom_idx]];
1302 if !(a_nk > 0.0) {
1303 continue;
1304 }
1305 let d = atom_axis_dim[atom_idx];
1306 let base = atom_offsets[atom_idx];
1307 for axis in 0..d {
1308 atom.fill_decoded_derivative_row(row, axis, &mut tangent);
1309 for i in 0..p {
1310 j_flat[i * param_dim + base + i * d + axis] += a_nk * tangent[i];
1311 }
1312 }
1313 }
1314
1315 if metric.drives_gauge() {
1316 for r in 0..rank {
1317 root_row.fill(0.0);
1318 for c in 0..param_dim {
1319 let mut acc = 0.0_f64;
1320 for i in 0..p {
1321 acc += metric.factor_entry(row, i, r) * j_flat[i * param_dim + c];
1322 }
1323 root_row[c] = acc;
1324 }
1325 let row_slice = root_row.as_slice().ok_or_else(|| {
1326 "residual_gauge_streamed_data_curvature: non-contiguous root row"
1327 .to_string()
1328 })?;
1329 Self::accumulate_residual_gauge_gram_row(&mut gram, row_slice);
1330 }
1331 } else {
1332 for i in 0..p {
1333 let start = i * param_dim;
1334 let end = start + param_dim;
1335 Self::accumulate_residual_gauge_gram_row(&mut gram, &j_flat[start..end]);
1336 }
1337 }
1338 }
1339
1340 for a in 0..param_dim {
1341 for b in 0..a {
1342 gram[[b, a]] = gram[[a, b]];
1343 }
1344 }
1345 Ok((gram, n * rank))
1346 }
1347
1348 pub(crate) fn accumulate_residual_gauge_gram_row(gram: &mut Array2<f64>, row: &[f64]) {
1349 for a in 0..row.len() {
1350 let va = row[a];
1351 if va == 0.0 {
1352 continue;
1353 }
1354 for b in 0..=a {
1355 let vb = row[b];
1356 if vb != 0.0 {
1357 gram[[a, b]] += va * vb;
1358 }
1359 }
1360 }
1361 }
1362
1363 pub fn set_temperature_schedule(
1364 &mut self,
1365 sched: GumbelTemperatureSchedule,
1366 ) -> Result<(), String> {
1367 sched.validate()?;
1368 self.assignment
1369 .mode
1370 .set_temperature(sched.current_tau(sched.iter_count))?;
1371 self.temperature_schedule = Some(sched);
1372 Ok(())
1373 }
1374
1375 pub(crate) fn advance_temperature_schedule(&mut self) -> Result<Option<f64>, String> {
1376 let Some(schedule) = self.temperature_schedule.as_mut() else {
1377 return Ok(None);
1378 };
1379 schedule.validate()?;
1380 let tau = schedule.step();
1381 self.assignment.mode.set_temperature(tau)?;
1382 Ok(Some(tau))
1383 }
1384
1385 pub fn n_obs(&self) -> usize {
1386 self.assignment.n_obs()
1387 }
1388
1389 pub fn k_atoms(&self) -> usize {
1390 self.atoms.len()
1391 }
1392
1393 /// Auto-derived in-core vs streaming plan for SAE Arrow-Schur work.
1394 ///
1395 /// This is intentionally not user-configurable: the route follows the
1396 /// retained full-batch working-set estimate and the currently selected GPU
1397 /// memory budget when CUDA is usable, otherwise a conservative host budget.
1398 pub fn streaming_plan(&self) -> SaeStreamingPlan {
1399 let n_obs = self.n_obs();
1400 let total_basis: usize = self.atoms.iter().map(|atom| atom.basis_size()).sum();
1401 let d_max = self
1402 .atoms
1403 .iter()
1404 .map(|atom| atom.latent_dim)
1405 .max()
1406 .unwrap_or(0);
1407 let border_dim = if self.any_frame_active() {
1408 self.factored_border_dim()
1409 } else {
1410 self.beta_dim()
1411 };
1412 sae_streaming_plan_for_shape(n_obs, total_basis, self.k_atoms(), d_max, border_dim)
1413 }
1414
1415 /// Construction-time validation: every Psi-tier analytic penalty in the
1416 /// registry must be dispatchable into the SAE arrow-Schur row layout.
1417 ///
1418 /// Two invariants are enforced upfront so the dispatch loop in
1419 /// `add_sae_analytic_penalty_contributions` is total (no runtime
1420 /// "unsupported penalty" fallthrough, no per-call K-gating):
1421 ///
1422 /// 1. Every Psi-tier penalty is either in [`sae_penalty_is_row_block_supported`],
1423 /// or `NuclearNorm` (which is redirected to the per-atom decoder (β) block
1424 /// rather than the coord "t" row block). Assignment sparsity penalties
1425 /// (`IBPAssignment`, `SoftmaxAssignmentSparsity`) are refused because the SAE
1426 /// term already owns them through its built-in assignment path
1427 /// (`loss.assignment_sparsity`). Penalty kinds with cross-row structure
1428 /// (`TotalVariation`, `Monotonicity`, `BlockSparsity`,
1429 /// `IvaeRidgeMeanGauge`, `Orthogonality`, `NestedPrefix`,
1430 /// `SheafConsistency`) cannot be expressed in the SAE row-block layout
1431 /// and are refused here.
1432 ///
1433 /// 2. If any Psi-tier row-block penalty is present, every atom shares
1434 /// the same coord latent dim. The current registry model carries one
1435 /// `latent_dim` per descriptor (the "t" latent block declares one
1436 /// `d` value); per-atom dispatch with heterogeneous `d_k` would
1437 /// require per-atom registry entries or per-kind in-place
1438 /// reshaping. Mixed-d row-block fits are rejected with an actionable
1439 /// error pointing at the configuration mismatch.
1440 ///
1441 /// The K=1 case trivially satisfies (2). Beta-tier and rho-tier
1442 /// penalties are not constrained here.
1443 pub(crate) fn validate_analytic_penalty_registry(
1444 &self,
1445 registry: &AnalyticPenaltyRegistry,
1446 ) -> Result<(), String> {
1447 let mut row_block_penalty_present = false;
1448 for penalty in ®istry.penalties {
1449 if penalty.tier() != PenaltyTier::Psi {
1450 continue;
1451 }
1452 if matches!(
1453 penalty,
1454 AnalyticPenaltyKind::IBPAssignment(_)
1455 | AnalyticPenaltyKind::SoftmaxAssignmentSparsity(_)
1456 ) {
1457 return Err(format!(
1458 "SAE-manifold term refuses analytic penalty {:?}: assignment sparsity \
1459 is owned by the built-in SAE assignment path (loss.assignment_sparsity). \
1460 Registering it would double-count the objective and gradient",
1461 penalty.name()
1462 ));
1463 }
1464 // NuclearNorm is redirected to the per-atom decoder (β) block in
1465 // `add_sae_beta_penalty` (it penalizes each atom's decoder matrix
1466 // singular spectrum, i.e. its embedding rank), so it bypasses the
1467 // coord "t" row-block requirement below.
1468 if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) {
1469 continue;
1470 }
1471 if !sae_penalty_is_row_block_supported(penalty) {
1472 return Err(format!(
1473 "SAE-manifold term refuses analytic penalty {:?}: this kind \
1474 has cross-row structure and cannot be expressed in the \
1475 arrow-Schur row layout. Use only row-block-supported \
1476 coord penalties (ARD, BlockOrthogonality, \
1477 Sparsity/TopK/JumpReLU, RowPrecisionPrior, \
1478 ParametricRowPrecisionPrior, ScadMcp, Isometry) on the \
1479 coord latent block, or move the penalty to a non-SAE \
1480 term",
1481 penalty.name()
1482 ));
1483 }
1484 row_block_penalty_present = true;
1485 }
1486 if row_block_penalty_present {
1487 let mut dims = self.assignment.coords.iter().map(|c| c.latent_dim());
1488 if let Some(first) = dims.next() {
1489 if let Some(mismatch) = dims.find(|d| *d != first) {
1490 return Err(format!(
1491 "SAE-manifold term refuses row-block analytic penalty: \
1492 atoms have heterogeneous coord latent dims (saw {first} \
1493 and {mismatch}). Row-block penalties (ARD, \
1494 BlockOrthogonality, ...) target the unified \"t\" \
1495 latent block whose declared `d` matches one shape; \
1496 per-atom dispatch with mixed `d_k` would silently \
1497 truncate or expand axes. Configure all atoms with the \
1498 same `atom_dim`, or split the row-block penalty into \
1499 per-atom descriptors keyed to per-atom latent blocks"
1500 ));
1501 }
1502 }
1503 }
1504 Ok(())
1505 }
1506
1507 pub fn output_dim(&self) -> usize {
1508 self.atoms[0].output_dim()
1509 }
1510
1511 pub fn beta_dim(&self) -> usize {
1512 let p = self.output_dim();
1513 self.atoms.iter().map(|a| a.basis_size() * p).sum()
1514 }
1515
1516 pub(crate) fn take_border_hbb_workspace(&mut self, border_dim: usize) -> Array2<f64> {
1517 let mut workspace =
1518 std::mem::replace(&mut self.border_hbb_workspace, Array2::<f64>::zeros((0, 0)));
1519 if workspace.dim() != (border_dim, border_dim) {
1520 workspace = Array2::<f64>::zeros((border_dim, border_dim));
1521 } else {
1522 workspace.fill(0.0);
1523 }
1524 workspace
1525 }
1526
1527 pub(crate) fn reclaim_border_hbb_workspace(&mut self, sys: &mut ArrowSchurSystem) {
1528 let workspace = std::mem::replace(&mut sys.hbb, Array2::<f64>::zeros((0, 0)));
1529 self.border_hbb_workspace = workspace;
1530 }
1531
1532 /// Factored arrow-Schur border dimension `Σ_k M_k · r_k` (issue #972): the
1533 /// number of decoder coordinates the border actually carries once the
1534 /// low-rank Grassmann frames are profiled out. Atoms with no active frame
1535 /// contribute their full `M_k · p` (`r_k == p`), so on the all-full-`B` path
1536 /// this equals [`Self::beta_dim`]. The border Cholesky / evidence log-det
1537 /// scale with THIS count, not `beta_dim`.
1538 pub fn factored_border_dim(&self) -> usize {
1539 self.atoms.iter().map(|a| a.border_coeff_count()).sum()
1540 }
1541
1542 /// Total profiled-out Grassmann manifold dimension `Σ_k r_k·(p − r_k)` across
1543 /// all active frames (issue #972). This is the count of decoder-frame degrees
1544 /// of freedom estimated OUTSIDE the border by closed-form polar steps, and it
1545 /// must enter the Laplace evidence dimension accounting (evidence honesty):
1546 /// the profiled frame is a MAP point on `∏_k Gr(r_k, p)`, contributing this
1547 /// many free dimensions to the model. `0` when every atom is on the full-`B`
1548 /// path. Threaded into [`Self::reml_occam_term`].
1549 pub fn grassmann_evidence_dimension(&self) -> usize {
1550 self.atoms
1551 .iter()
1552 .map(|a| a.frame_manifold_dimension())
1553 .sum()
1554 }
1555
1556 /// True iff any atom has an active low-rank Grassmann frame (issue #972).
1557 pub fn frames_active(&self) -> bool {
1558 self.atoms.iter().any(|a| a.decoder_frame.is_some())
1559 }
1560
1561 /// Alias of [`Self::frames_active`] (issue #972 / #977 T1): the predicate the
1562 /// assembly / step-lift branch on to decide whether the β-tier is built in
1563 /// the factored coordinate layout. Named to read as the question
1564 /// "is the factored path engaged?" at its call sites.
1565 pub fn any_frame_active(&self) -> bool {
1566 self.frames_active()
1567 }
1568
1569 /// Per-atom column offsets of the *factored* border (issue #972 / #977 T1):
1570 /// the running prefix sum of `M_k · r_k`, one entry per atom (the same
1571 /// convention as [`Self::beta_offsets`]). This is the start of each atom's
1572 /// `C_k` block in the reduced border vector; on the all-full-`B` path it
1573 /// equals `beta_offsets`. Distinct from [`Self::factored_border_offsets`]
1574 /// only in name (both compute the identical prefix sum) — this method is the
1575 /// one the frame transform reads, mirroring `beta_offsets` at the call site.
1576 pub fn factored_beta_offsets(&self) -> Vec<usize> {
1577 self.factored_border_offsets()
1578 }
1579
1580 /// Frame output matrix `U_k ∈ St(p, r_k)` for atom `k` (issue #972 / #977 T1).
1581 /// Returns the active frame `U_k` (`p × r_k`) when atom `k` is framed, else
1582 /// the identity `I_p` (the `r_k == p`, `U_k == I_p` full-`B` special case) so
1583 /// the projection / lift code is uniform across a mixed dictionary.
1584 pub fn frame_output_matrix(&self, atom_idx: usize) -> Array2<f64> {
1585 let atom = &self.atoms[atom_idx];
1586 match &atom.decoder_frame {
1587 Some(frame) => frame.frame().to_owned(),
1588 None => Array2::<f64>::eye(atom.output_dim()),
1589 }
1590 }
1591
1592 /// Per-pair frame factor `W_{ij} = U_iᵀ U_j` (`r_i × r_j`) used as the output
1593 /// factor of the factored data β-Hessian block `G_{ij} ⊗ W_{ij}` (issue #972
1594 /// / #977 T1). When both atoms are framed this is the dense principal-angle
1595 /// cosine matrix between the two frames; for `i == j` with an orthonormal
1596 /// frame it is exactly `I_{r_i}`; for any un-framed atom the corresponding
1597 /// `U` is `I_p`, so a same-atom un-framed pair gives `I_p` (the clean full-`B`
1598 /// `G ⊗ I_p` collapse) and a framed/un-framed cross pair gives the rectangular
1599 /// `U_iᵀ` / `U_j` overlap.
1600 pub fn frame_cross_factor(&self, atom_i: usize, atom_j: usize) -> Array2<f64> {
1601 let ui = self.frame_output_matrix(atom_i);
1602 let uj = self.frame_output_matrix(atom_j);
1603 // `U_iᵀ U_j`: `(r_i × p) · (p × r_j)`. `fast_atb` forms `U_iᵀ U_j` directly.
1604 fast_atb(&ui, &uj)
1605 }
1606
1607 /// Per-atom column offsets of the *factored* border (issue #972): the
1608 /// running prefix sum of `M_k · r_k`. The analogue of [`Self::beta_offsets`]
1609 /// for the reduced coordinate layout — atom `k`'s `C_k` occupies
1610 /// `[factored_border_offsets()[k] .. + M_k·r_k)`. On the full-`B` path this
1611 /// equals `beta_offsets`.
1612 pub fn factored_border_offsets(&self) -> Vec<usize> {
1613 let mut out = Vec::with_capacity(self.k_atoms());
1614 let mut cursor = 0usize;
1615 for atom in &self.atoms {
1616 out.push(cursor);
1617 cursor += atom.border_coeff_count();
1618 }
1619 out
1620 }
1621
1622 /// Assemble the factored border coordinate vector `C = [vec(C_1); …; vec(C_K)]`
1623 /// in row-major `C_k[m, j] → C[off_k + m·r_k + j]` layout (issue #972).
1624 ///
1625 /// This is the reduced state the arrow-Schur border carries when frames are
1626 /// active: its length is [`Self::factored_border_dim`] (`Σ M_k·r_k`), the
1627 /// border-size invariant verified by [`grassmann_assert_border_dim_invariant`].
1628 /// Atoms
1629 /// without an active frame contribute their full `vec(B_k)` (their `r_k == p`
1630 /// coordinates are the decoder itself), so on the all-full-`B` path this
1631 /// reproduces [`Self::flatten_beta`].
1632 pub fn flatten_factored_border(&self) -> Result<Array1<f64>, String> {
1633 let offsets = self.factored_border_offsets();
1634 let mut out = Array1::<f64>::zeros(self.factored_border_dim());
1635 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1636 let off = offsets[atom_idx];
1637 let r = atom.border_frame_rank();
1638 let m = atom.basis_size();
1639 let coords = match atom.factored_coordinates()? {
1640 Some(c) => c,
1641 // Full-`B` path: the decoder itself is the coordinate matrix.
1642 None => atom.decoder_coefficients.clone(),
1643 };
1644 for basis_col in 0..m {
1645 for j in 0..r {
1646 out[off + basis_col * r + j] = coords[[basis_col, j]];
1647 }
1648 }
1649 }
1650 Ok(out)
1651 }
1652
1653 /// Scatter a factored border coordinate vector `C` (length
1654 /// [`Self::factored_border_dim`]) back into the per-atom decoders, refreshing
1655 /// each `decoder_coefficients = C_k · U_kᵀ` so the full-`B` consumers stay
1656 /// consistent after a factored border solve (issue #972). The inverse of
1657 /// [`Self::flatten_factored_border`].
1658 pub fn scatter_factored_border(&mut self, border: ArrayView1<'_, f64>) -> Result<(), String> {
1659 let expected = self.factored_border_dim();
1660 if border.len() != expected {
1661 return Err(format!(
1662 "SaeManifoldTerm::scatter_factored_border: border length {} must equal \
1663 factored border dim {expected}",
1664 border.len()
1665 ));
1666 }
1667 let offsets = self.factored_border_offsets();
1668 for atom_idx in 0..self.atoms.len() {
1669 let off = offsets[atom_idx];
1670 let (r, m, has_frame) = {
1671 let atom = &self.atoms[atom_idx];
1672 (
1673 atom.border_frame_rank(),
1674 atom.basis_size(),
1675 atom.decoder_frame.is_some(),
1676 )
1677 };
1678 let mut coords = Array2::<f64>::zeros((m, r));
1679 for basis_col in 0..m {
1680 for j in 0..r {
1681 coords[[basis_col, j]] = border[off + basis_col * r + j];
1682 }
1683 }
1684 if has_frame {
1685 self.atoms[atom_idx].set_factored_coordinates(coords.view())?;
1686 } else {
1687 // Full-`B` path: the coordinates ARE the decoder.
1688 self.atoms[atom_idx].decoder_coefficients = coords;
1689 }
1690 }
1691 Ok(())
1692 }
1693
1694 /// Auto-derive and install low-rank Grassmann decoder frames across all
1695 /// atoms (issue #972) — magic-by-default, no flag. Each atom independently
1696 /// activates its frame iff the factorization materially shrinks its border
1697 /// (see [`SaeManifoldAtom::maybe_activate_decoder_frame`]). Returns the
1698 /// number of atoms that activated a frame. Idempotent: re-running re-derives
1699 /// each frame from the current decoder.
1700 ///
1701 /// The decision keys on the *frontier* regime the issue targets: at large
1702 /// ambient `p` the full border `Σ M_k · p` reaches `10^7`–`10^8` and the
1703 /// border Cholesky dies, while the decoder's effective column rank `r` stays
1704 /// `≪ p`. Small-`p` atoms (where `r` cannot beat the activation margin)
1705 /// keep the bit-for-bit full-`B` path, so the small-model evidence is
1706 /// unchanged (verified by `factored_evidence_matches_full_b_at_small_p`).
1707 pub fn auto_activate_decoder_frames(&mut self) -> Result<usize, String> {
1708 let mut activated = 0usize;
1709 for atom in &mut self.atoms {
1710 let expected_rank = atom.decoder_frame_activation_rank()?;
1711 match (
1712 expected_rank,
1713 atom.decoder_frame.as_ref().map(GrassmannFrame::rank),
1714 ) {
1715 (Some(expected), Some(current)) if expected == current => {
1716 continue;
1717 }
1718 (None, Some(_)) => {
1719 atom.deactivate_decoder_frame();
1720 continue;
1721 }
1722 (None, None) => {
1723 continue;
1724 }
1725 (Some(_), _) => {}
1726 }
1727 if atom.maybe_activate_decoder_frame()?.is_some() {
1728 activated += 1;
1729 }
1730 }
1731 Ok(activated)
1732 }
1733
1734 /// Reconcile decoder-frame activation before a fit entry point. The
1735 /// user-facing `auto_activate_decoder_frames` contract returns only newly
1736 /// installed frames; this helper enforces the stronger invariant the large-p
1737 /// solver needs: every atom whose current decoder satisfies the activation
1738 /// predicate has an active frame after the pass.
1739 pub(crate) fn ensure_decoder_frames_active_for_current_decoder(
1740 &mut self,
1741 ) -> Result<(), String> {
1742 self.auto_activate_decoder_frames()?;
1743 for (atom_idx, atom) in self.atoms.iter().enumerate() {
1744 let expected_rank = atom.decoder_frame_activation_rank()?;
1745 if let Some(expected_rank) = expected_rank {
1746 match atom.decoder_frame.as_ref() {
1747 Some(frame) if frame.rank() == expected_rank => {}
1748 Some(frame) => {
1749 return Err(format!(
1750 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1751 atom {atom_idx} frame rank {} must equal audited rank {expected_rank}",
1752 frame.rank()
1753 ));
1754 }
1755 None => {
1756 return Err(format!(
1757 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1758 atom {atom_idx} has audited rank {expected_rank} but no active frame"
1759 ));
1760 }
1761 }
1762 } else if atom.decoder_frame.is_some() {
1763 return Err(format!(
1764 "SaeManifoldTerm::ensure_decoder_frames_active_for_current_decoder: \
1765 atom {atom_idx} kept a frame after the full-B predicate won"
1766 ));
1767 }
1768 }
1769 Ok(())
1770 }
1771
1772 /// Closed-form streaming POLAR refresh of every ACTIVE decoder frame from the
1773 /// current data evidence (issue #972 / #977 T1) — the U-block of the
1774 /// alternating block-coordinate ascent that complements the border's
1775 /// C-block Newton step.
1776 ///
1777 /// For each framed atom `k` we accumulate the `p × r_k` cross-moment
1778 /// `A_k = Σ_n a_{n,k} · e_{n,k} · ĉ_{n,k}ᵀ`,
1779 /// where `e_{n,k} = z_n − Σ_{k'≠k} a_{n,k'}·decoded_{k'}(n)` is the row's
1780 /// partial reconstruction residual (everything except atom `k`) and
1781 /// `ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^{r_k}` is atom `k`'s in-span decoded
1782 /// coordinate. The polar factor `U_new = polar(A_k)` is the closed-form MAP
1783 /// frame on `Gr(r_k, p)` given the C-coordinates held fixed — the same
1784 /// `O(p r²)` thin SVD the issue prescribes, run OUTSIDE the border. The frame
1785 /// is then re-installed and the decoder re-projected onto it so the
1786 /// authoritative `B_k = C_k U_newᵀ` and the `(C_k, U_new)` pair stay
1787 /// consistent (a no-op in span for a truly rank-`r` atom). Un-framed atoms
1788 /// are skipped. Returns the number of frames refreshed.
1789 pub(crate) fn refresh_active_frames_from_data(
1790 &mut self,
1791 target: ArrayView2<'_, f64>,
1792 rho: &SaeManifoldRho,
1793 ) -> Result<usize, String> {
1794 let n = self.n_obs();
1795 let p = self.output_dim();
1796 let k_atoms = self.k_atoms();
1797 if n == 0 {
1798 return Ok(0);
1799 }
1800 // Per-row assignments and per-(row, atom) decoded outputs, computed once.
1801 let mut assignments = Vec::with_capacity(n);
1802 for row in 0..n {
1803 assignments.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
1804 }
1805 let mut decoded = Array3::<f64>::zeros((n, k_atoms, p));
1806 let mut dbuf = vec![0.0_f64; p];
1807 for row in 0..n {
1808 for atom_idx in 0..k_atoms {
1809 self.atoms[atom_idx].fill_decoded_row(row, &mut dbuf);
1810 for c in 0..p {
1811 decoded[[row, atom_idx, c]] = dbuf[c];
1812 }
1813 }
1814 }
1815 // Full fitted reconstruction `Σ_k a_k decoded_k`, so the per-atom partial
1816 // residual is `e_k = (z − fitted) + a_k decoded_k` (add atom k back in).
1817 let mut fitted = Array2::<f64>::zeros((n, p));
1818 for row in 0..n {
1819 for atom_idx in 0..k_atoms {
1820 let a = assignments[row][atom_idx];
1821 if a == 0.0 {
1822 continue;
1823 }
1824 for c in 0..p {
1825 fitted[[row, c]] += a * decoded[[row, atom_idx, c]];
1826 }
1827 }
1828 }
1829 let mut refreshed = 0usize;
1830 for atom_idx in 0..k_atoms {
1831 // Only atoms with an active frame are refreshed.
1832 let Some(coords_c) = self.atoms[atom_idx].factored_coordinates()? else {
1833 continue;
1834 };
1835 let r = self.atoms[atom_idx].border_frame_rank();
1836 let m = self.atoms[atom_idx].basis_size();
1837 // Accumulate `A_k = Σ_n a_k · e_{n,k} · ĉ_{n,k}ᵀ` directly (p × r).
1838 let mut cross = GrassmannCrossMoment::new(p, r);
1839 // Build per-row p-target `a_k·e_k` and r-coord `a_k·ĉ` batched, then
1840 // accumulate as one outer-product sum. `accumulate` forms
1841 // `targetsᵀ·coords`, so scaling EITHER side by `a_k` once gives the
1842 // `a_k²` weight on the cross-moment that matches the C-block normal
1843 // equations (residual leg carries `a_k`, coordinate leg carries
1844 // `a_k`).
1845 let mut targets = Array2::<f64>::zeros((n, p));
1846 let mut rcoords = Array2::<f64>::zeros((n, r));
1847 for row in 0..n {
1848 let a = assignments[row][atom_idx];
1849 // Partial residual e_{n,k} = z_n − (fitted − a_k decoded_k).
1850 for c in 0..p {
1851 let e = target[[row, c]] - fitted[[row, c]] + a * decoded[[row, atom_idx, c]];
1852 targets[[row, c]] = a * e;
1853 }
1854 // In-span coordinate ĉ_{n,k} = Φ_k(t_n)·C_k ∈ ℝ^r.
1855 for j in 0..r {
1856 let mut acc = 0.0_f64;
1857 for basis_col in 0..m {
1858 acc += self.atoms[atom_idx].basis_values[[row, basis_col]]
1859 * coords_c[[basis_col, j]];
1860 }
1861 rcoords[[row, j]] = a * acc;
1862 }
1863 }
1864 cross.accumulate(targets.view(), rcoords.view())?;
1865 // `polar(A_k)` is well-defined only when the moment is non-trivial;
1866 // a zero moment (e.g. a fully collapsed atom) leaves the frame as-is.
1867 if cross.moment().iter().all(|&v| v == 0.0) {
1868 continue;
1869 }
1870 self.atoms[atom_idx].refresh_frame_from_cross_moment(cross.moment())?;
1871 refreshed += 1;
1872 }
1873 Ok(refreshed)
1874 }
1875
1876 pub fn beta_offsets(&self) -> Vec<usize> {
1877 let p = self.output_dim();
1878 let mut out = Vec::with_capacity(self.k_atoms());
1879 let mut cursor = 0usize;
1880 for atom in &self.atoms {
1881 out.push(cursor);
1882 cursor += atom.basis_size() * p;
1883 }
1884 out
1885 }
1886
1887 /// Per-atom β column ranges for the block-Jacobi Schur preconditioner.
1888 ///
1889 /// Returns one `Range<usize>` per atom, covering that atom's decoder
1890 /// coefficients in the flat β vector:
1891 /// `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
1892 ///
1893 /// Pass to [`ArrowSchurSystem::set_block_offsets`] so that
1894 /// [`gam_solve::arrow_schur::JacobiPreconditioner`] builds one dense
1895 /// Schur sub-block per atom instead of scalar-diagonal inversion.
1896 pub fn beta_block_offsets(&self) -> Arc<[std::ops::Range<usize>]> {
1897 let p = self.output_dim();
1898 let mut ranges: Vec<std::ops::Range<usize>> = Vec::with_capacity(self.k_atoms());
1899 let mut cursor = 0usize;
1900 for atom in &self.atoms {
1901 let width = atom.basis_size() * p;
1902 ranges.push(cursor..cursor + width);
1903 cursor += width;
1904 }
1905 Arc::from(ranges.into_boxed_slice())
1906 }
1907
1908 /// Decide whether the sparse per-row active-set layout is engaged for a
1909 /// dense-weight assignment mode, and if so derive the per-row active-atom
1910 /// cap and magnitude cutoff.
1911 ///
1912 /// #1408: this plan is mode-agnostic. `assemble_arrow_schur` consults it
1913 /// directly for IBP-MAP, and for `AssignmentMode::Softmax` via
1914 /// [`Self::softmax_active_plan`], which tightens it with an explicit `top_k`
1915 /// (`softmax_active_cap`). Softmax therefore engages the compact active-set
1916 /// layout whenever `top_k` or the budget bounds the active set (the
1917 /// active-sub-block Gershgorin majorizer + coherent logdet/θ-adjoint are
1918 /// landed — see `SaeRowLayout`'s doc); it keeps the full `K`-atom layout only
1919 /// when neither lever engages. The decision is auto-derived from
1920 /// the problem size and the device/host working-set budget — never a CLI flag
1921 /// or kwarg. JumpReLU is not handled here (it always uses its structural gate
1922 /// via [`SaeRowLayout::from_jumprelu`]). The dense Gauss-Newton data Gram `G`
1923 /// is `(m_total × m_total)` f64; if its dense form fits the budget we keep
1924 /// the exact full-support solve (every atom active per row), so small-`K`
1925 /// problems are bit-for-bit unchanged. Above that, we cap each row to the
1926 /// `k_active` atoms that make the *sparse* Gram fit the same budget, with a
1927 /// relative magnitude cutoff that drops assignment mass contributing
1928 /// negligible `O(a²)` curvature.
1929 ///
1930 /// Returns `Some((k_active_cap, cutoff))` to engage sparsity, or `None` to
1931 /// keep the dense full-support layout.
1932 pub(crate) fn sparse_active_plan(&self) -> Option<(usize, f64)> {
1933 // The per-row Riemannian tangent projection for non-Euclidean atom
1934 // latents is now applied directly on the compact active-set rows (see
1935 // the `Some(layout)` arm in `assemble_arrow_schur`, via
1936 // `compact_row_ext_manifold_and_point`), which rebuilds each row's
1937 // product manifold in its compact column order and applies the SAME
1938 // gt/htt/htbeta + Kronecker-Jacobian projections the dense path uses. So
1939 // the sparse plan may engage on curved ext-coord manifolds (circle /
1940 // torus / sphere atoms) — the affordability lever for manifold-SAE at
1941 // large `K`, where the dense `K²` co-assignment Gram is the cost. (The
1942 // former `is_euclidean()`-only restriction punted every curved atom to
1943 // the dense layout; it is lifted.) The host/device in-core budget is the
1944 // single gate now; it is parameterised in `sparse_active_plan_for_budget`
1945 // so the engagement regression can pin a small budget without allocating
1946 // a multi-GB dense Gram.
1947 let budget = match crate::gpu::device_runtime::GpuRuntime::global() {
1948 // Allow up to one quarter of the AGGREGATE device budget for the dense
1949 // Gram, matching the streaming dispatcher's in-core fraction. The
1950 // per-atom-pair Gram blocks fan out across the whole device pool, so
1951 // the in-core fraction sums every ordinal's budget, not just the
1952 // primary's.
1953 Some(rt) => {
1954 let aggregate: usize = rt
1955 .device_ordinals()
1956 .iter()
1957 .map(|&ord| rt.memory_budget_for(ord))
1958 .sum();
1959 aggregate / 4
1960 }
1961 None => sae_host_in_core_budget_bytes().0,
1962 };
1963 self.sparse_active_plan_for_budget(budget)
1964 }
1965
1966 /// Budget-parameterised core of [`Self::sparse_active_plan`]. The dense data
1967 /// Gram footprint `(m_total · m_total) f64` is compared against `budget`; a
1968 /// term whose dense Gram exceeds the budget engages the compact active-set
1969 /// plan (returns `Some((k_active_cap, cutoff))`), regardless of whether any
1970 /// atom latent is curved. Pulled out so the curved-atom engagement
1971 /// regression can pin a small budget deterministically.
1972 pub(crate) fn sparse_active_plan_for_budget(&self, budget: usize) -> Option<(usize, f64)> {
1973 // Relative magnitude cutoff: assignment mass below this fraction of the
1974 // row's peak `|a_k|` enters the Gram only as `O(a²)` curvature and is
1975 // dropped. Chosen so dropped terms are ~1e-6 of the peak self-coupling.
1976 const RELATIVE_CUTOFF: f64 = 1.0e-3;
1977
1978 let k_atoms = self.k_atoms();
1979 if k_atoms <= 1 {
1980 return None;
1981 }
1982 let p = self.output_dim();
1983 let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
1984 // Dense data Gram footprint: (m_total · m_total) f64.
1985 let dense_gram_bytes = m_total
1986 .saturating_mul(m_total)
1987 .saturating_mul(SAE_BYTES_PER_F64);
1988 if dense_gram_bytes <= budget {
1989 return None;
1990 }
1991
1992 // Sparse Gram footprint scales with the per-row active basis count
1993 // `k_active · m_atom`. Solve for the largest `k_active` whose sparse
1994 // Gram `(k_active · m_atom)²` still fits the budget.
1995 let m_atom = (m_total as f64 / k_atoms as f64).max(1.0);
1996 let max_active_basis = ((budget as f64 / SAE_BYTES_PER_F64 as f64).sqrt() / m_atom).floor();
1997 let k_active_cap = (max_active_basis as usize).clamp(1, k_atoms);
1998 // p does not enter the Gram dimension (it is carried by the `⊗ I_p`
1999 // structure), but a degenerate `p == 0` term has no decoder columns.
2000 if p == 0 {
2001 return None;
2002 }
2003 Some((k_active_cap, RELATIVE_CUTOFF))
2004 }
2005
2006 /// #1408/#1409 — per-row active-set plan for the Softmax assignment.
2007 ///
2008 /// Engages the compact top-`k` row layout when EITHER the user supplied a
2009 /// hard `top_k` cap ([`Self::softmax_active_cap`], `1 <= k < K`) OR the
2010 /// dense data Gram exceeds the in-core budget (the same memory lever the
2011 /// IBP path uses via [`Self::sparse_active_plan`]). The returned
2012 /// `k_active_cap` is the tighter of the two, so an explicit `top_k`
2013 /// genuinely bounds the optimization even below the memory threshold and a
2014 /// large-K budget breach still bounds it when no `top_k` is set. Returns
2015 /// `None` (keep the exact full-`K` dense softmax layout) when neither lever
2016 /// engages.
2017 ///
2018 /// The cutoff is the same relative magnitude floor as the budget plan
2019 /// (`1e-3` of the row peak); under an explicit `top_k` cap alone (no budget
2020 /// breach) it is `0.0` so exactly the top-`k` atoms are retained.
2021 pub(crate) fn softmax_active_plan(&self) -> Option<(usize, f64)> {
2022 if self.k_atoms() <= 1 {
2023 return None;
2024 }
2025 let budget_plan = self.sparse_active_plan();
2026 match (self.softmax_active_cap, budget_plan) {
2027 (Some(cap), Some((budget_cap, cutoff))) => Some((cap.min(budget_cap), cutoff)),
2028 // Explicit cap only: retain exactly the top-`cap` atoms (no extra
2029 // magnitude cutoff beyond the cap).
2030 (Some(cap), None) => Some((cap, 0.0)),
2031 (None, plan) => plan,
2032 }
2033 }
2034
2035 pub fn flatten_beta(&self) -> Array1<f64> {
2036 let p = self.output_dim();
2037 let offsets = self.beta_offsets();
2038 let mut out = Array1::<f64>::zeros(self.beta_dim());
2039 for (atom_idx, atom) in self.atoms.iter().enumerate() {
2040 let m = atom.basis_size();
2041 let off = offsets[atom_idx];
2042 for basis_col in 0..m {
2043 for out_col in 0..p {
2044 out[off + basis_col * p + out_col] =
2045 atom.decoder_coefficients[[basis_col, out_col]];
2046 }
2047 }
2048 }
2049 out
2050 }
2051
2052 pub fn set_flat_beta(&mut self, beta: ArrayView1<'_, f64>) -> Result<(), String> {
2053 if beta.len() != self.beta_dim() {
2054 return Err(format!(
2055 "set_flat_beta: beta length {} != expected {}",
2056 beta.len(),
2057 self.beta_dim()
2058 ));
2059 }
2060 let p = self.output_dim();
2061 let offsets = self.beta_offsets();
2062 for (atom_idx, atom) in self.atoms.iter_mut().enumerate() {
2063 let m = atom.basis_size();
2064 let off = offsets[atom_idx];
2065 for basis_col in 0..m {
2066 for out_col in 0..p {
2067 atom.decoder_coefficients[[basis_col, out_col]] =
2068 beta[off + basis_col * p + out_col];
2069 }
2070 }
2071 }
2072 Ok(())
2073 }
2074
2075 pub fn refit_decoder_least_squares_at_current_state(
2076 &mut self,
2077 target: ArrayView2<'_, f64>,
2078 rho: Option<&SaeManifoldRho>,
2079 ) -> Result<(), String> {
2080 let n = self.n_obs();
2081 let p = self.output_dim();
2082 if target.dim() != (n, p) {
2083 return Err(format!(
2084 "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: target shape {:?} != ({n}, {p})",
2085 target.dim()
2086 ));
2087 }
2088 let k_atoms = self.k_atoms();
2089 let offsets = self.beta_offsets();
2090 let m_total = self.beta_dim() / p;
2091 let mut design = Array2::<f64>::zeros((n, m_total));
2092 for row in 0..n {
2093 let assignments = match rho {
2094 Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2095 None => self.assignment.try_assignments_row(row)?,
2096 };
2097 for atom_idx in 0..k_atoms {
2098 let atom = &self.atoms[atom_idx];
2099 let weight = assignments[atom_idx];
2100 let m = atom.basis_size();
2101 let off = offsets[atom_idx] / p;
2102 for basis_col in 0..m {
2103 design[[row, off + basis_col]] = weight * atom.basis_values[[row, basis_col]];
2104 }
2105 }
2106 }
2107 let beta = solve_design_least_squares(design.view(), target)?;
2108 if beta.dim() != (m_total, p) {
2109 return Err(format!(
2110 "SaeManifoldTerm::refit_decoder_least_squares_at_current_state: beta shape {:?} != ({m_total}, {p})",
2111 beta.dim()
2112 ));
2113 }
2114 for atom_idx in 0..k_atoms {
2115 let m = self.atoms[atom_idx].basis_size();
2116 let off = offsets[atom_idx] / p;
2117 for basis_col in 0..m {
2118 for out_col in 0..p {
2119 self.atoms[atom_idx].decoder_coefficients[[basis_col, out_col]] =
2120 beta[[off + basis_col, out_col]];
2121 }
2122 }
2123 self.atoms[atom_idx].refresh_intrinsic_smooth_penalty();
2124 }
2125 Ok(())
2126 }
2127
2128 pub fn fitted(&self) -> Array2<f64> {
2129 self.try_fitted().expect("assignment logits must be finite")
2130 }
2131
2132 /// The #1026 hybrid-collapse substitution map: `atom_idx → &AtomLinearImage`
2133 /// for every `d = 1` slot whose post-fit verdict selected its straight
2134 /// (`Θ → 0`) sub-model. Empty when no report has been computed
2135 /// (`hybrid_split_report == None`, e.g. mid-fit) or no slot collapsed. The
2136 /// SINGLE source of the collapse policy — every reconstruction path (the
2137 /// rho-keyed `try_fitted_with_rho`, the explicit-assignment
2138 /// [`Self::reconstruct_from_assignments`] used by the top-k projection)
2139 /// reads it so train, OOS, and top-k reconstructions decode collapsed slots
2140 /// identically (#1228, #1233).
2141 pub(crate) fn hybrid_linear_image_map(
2142 &self,
2143 ) -> std::collections::HashMap<usize, &crate::hybrid_split::AtomLinearImage> {
2144 // A fitted term carries its collapse policy on the post-fit
2145 // `hybrid_split_report`; an OOS term carries the same trained images on
2146 // `oos_linear_images` (#1228). At most one is `Some` in practice, but
2147 // prefer the report when both are present.
2148 if let Some(report) = self.hybrid_split_report.as_ref() {
2149 return report
2150 .verdicts
2151 .iter()
2152 .filter_map(|v| v.linear_image.as_ref().map(|img| (img.atom_idx, img)))
2153 .collect();
2154 }
2155 if let Some(images) = self.oos_linear_images.as_ref() {
2156 return images.iter().map(|img| (img.atom_idx, img)).collect();
2157 }
2158 std::collections::HashMap::new()
2159 }
2160
2161 /// #1228 — attach the trained dictionary's hybrid-collapsed linear images to
2162 /// this (typically OOS) term so its reconstruction (`fitted` / the top-k
2163 /// assembler) decodes verdict-linear `d = 1` slots by the SAME straight
2164 /// sub-model the training reconstruction used, instead of the original
2165 /// curved decoder. Each image's `atom_idx` must index a real slot; an image
2166 /// whose channel count `p` disagrees with this term's output dim, or whose
2167 /// `atom_idx` is out of range, is rejected so a stale/mismatched payload
2168 /// cannot silently corrupt the reconstruction. Pass an empty slice (or never
2169 /// call this) for an all-curved OOS reconstruction.
2170 ///
2171 /// `pub` (not `pub(crate)`): this is part of the FFI surface — the gam-pyffi
2172 /// crate calls it from `latent_basis_and_sae_ffi.rs` to attach a trained
2173 /// dictionary's hybrid-linear images to an OOS reconstruction term (#1228).
2174 /// Downgrading it to `pub(crate)` breaks the gam-pyffi cdylib build with
2175 /// E0624 (the gam lib still compiles, so the lib build does not catch it).
2176 pub fn set_hybrid_linear_images(
2177 &mut self,
2178 images: Vec<crate::hybrid_split::AtomLinearImage>,
2179 ) -> Result<(), String> {
2180 let p = self.output_dim();
2181 let k_atoms = self.k_atoms();
2182 for img in &images {
2183 if img.atom_idx >= k_atoms {
2184 return Err(format!(
2185 "set_hybrid_linear_images: atom_idx {} out of range (k_atoms={k_atoms})",
2186 img.atom_idx
2187 ));
2188 }
2189 if img.b0.len() != p || img.b1.len() != p {
2190 return Err(format!(
2191 "set_hybrid_linear_images: atom {} linear image has p=({}, {}) != output_dim {p}",
2192 img.atom_idx,
2193 img.b0.len(),
2194 img.b1.len()
2195 ));
2196 }
2197 if self.atoms[img.atom_idx].latent_dim != 1 {
2198 return Err(format!(
2199 "set_hybrid_linear_images: atom {} is not d=1; only d=1 slots collapse to a straight image",
2200 img.atom_idx
2201 ));
2202 }
2203 }
2204 self.oos_linear_images = if images.is_empty() {
2205 None
2206 } else {
2207 Some(images)
2208 };
2209 Ok(())
2210 }
2211
2212 /// Assemble the reconstruction `Σ_k a[i,k]·g_k(t_{ik})` from an EXPLICIT
2213 /// per-row assignment matrix (e.g. a hard top-k projection of the fitted
2214 /// soft assignments), honouring the #1026 hybrid collapse when `collapse` is
2215 /// set: a verdict-linear `d = 1` slot decodes its straight sub-model image
2216 /// instead of its curved curve, exactly as the production `try_fitted` does.
2217 /// This is the shared assembler the FFI top-k path uses so the projected
2218 /// reconstruction composes with hybrid collapse (#1233) instead of
2219 /// re-deriving the curved image by hand and silently bypassing the verdict.
2220 /// The atom coordinates (`t`) and decoded curves are the term's own fitted
2221 /// ones; only the assignment masses come from `assignments`.
2222 pub fn reconstruct_from_assignments(
2223 &self,
2224 assignments: ArrayView2<'_, f64>,
2225 collapse: bool,
2226 ) -> Result<Array2<f64>, String> {
2227 let n = self.n_obs();
2228 let p = self.output_dim();
2229 let k_atoms = self.k_atoms();
2230 if assignments.dim() != (n, k_atoms) {
2231 return Err(format!(
2232 "SaeManifoldTerm::reconstruct_from_assignments: assignments {:?} != ({n}, {k_atoms})",
2233 assignments.dim()
2234 ));
2235 }
2236 let linear_images = if collapse {
2237 self.hybrid_linear_image_map()
2238 } else {
2239 std::collections::HashMap::new()
2240 };
2241 let mut out = Array2::<f64>::zeros((n, p));
2242 let mut g_buf = vec![0.0_f64; p];
2243 for row in 0..n {
2244 for atom_idx in 0..k_atoms {
2245 let a_k = assignments[[row, atom_idx]];
2246 if a_k == 0.0 {
2247 continue;
2248 }
2249 if let Some(image) = linear_images.get(&atom_idx) {
2250 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2251 image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2252 } else {
2253 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2254 }
2255 let mut out_row = out.row_mut(row);
2256 for out_col in 0..p {
2257 out_row[out_col] += a_k * g_buf[out_col];
2258 }
2259 }
2260 }
2261 Ok(out)
2262 }
2263
2264 pub fn try_fitted(&self) -> Result<Array2<f64>, String> {
2265 // Production/user-facing reconstruction: honours the #1026 hybrid-split
2266 // verdict (verdict-linear `d = 1` slots decode their straight sub-model).
2267 self.try_fitted_with_rho(None, true)
2268 }
2269
2270 pub(crate) fn try_fitted_for_rho(&self, rho: &SaeManifoldRho) -> Result<Array2<f64>, String> {
2271 // Internal/fitting reconstruction: the pure CURVED image (the joint fit
2272 // and the #1026 adjudication both require the uncollapsed curve).
2273 self.try_fitted_with_rho(Some(rho), false)
2274 }
2275
2276 pub(crate) fn try_fitted_with_rho(
2277 &self,
2278 rho: Option<&SaeManifoldRho>,
2279 collapse: bool,
2280 ) -> Result<Array2<f64>, String> {
2281 let n = self.n_obs();
2282 let p = self.output_dim();
2283 let k_atoms = self.k_atoms();
2284 let mut out = Array2::<f64>::zeros((n, p));
2285 // #1026 — the curved/linear hybrid-split verdict is LOAD-BEARING on the
2286 // production reconstruction, not just a side report. When
2287 // [`Self::compute_hybrid_split_report`] (run post-fit in
2288 // `canonicalize_charts_post_fit`) adjudicated a `d = 1` atom's evidence
2289 // in favour of its straight (Θ→0) sub-model, the model's output
2290 // reconstruction (`fitted()` / `try_fitted` → predict and the user-facing
2291 // output) decodes that slot with its fitted linear image instead of its
2292 // curved decoded curve. The linear images are coordinate-keyed and
2293 // rho-independent (exact weighted-LS lines realised inside the
2294 // adjudication — no re-fit, no #1051 outer continuation).
2295 //
2296 // The collapse engages only when the caller asks for it (`collapse`):
2297 // the production `try_fitted` path and the explicit
2298 // `hybrid_collapsed_reconstruction` entry point. The pure-curved
2299 // `try_fitted_for_rho` opts out — the joint fit's loss/assembly optimise
2300 // the curved decoder coefficients and must see the curved image, and the
2301 // #1026 adjudication itself compares the curved fit against its straight
2302 // sub-model — both require the uncollapsed curve. (During fitting the
2303 // report is `None` regardless; it is only computed post-fit.)
2304 let linear_images = if collapse {
2305 self.hybrid_linear_image_map()
2306 } else {
2307 std::collections::HashMap::new()
2308 };
2309 // Reuse a single scratch buffer across all (row, atom) pairs instead of
2310 // allocating a fresh `Array1<f64>` of length p per call.
2311 let mut g_buf = vec![0.0_f64; p];
2312 for row in 0..n {
2313 let a = match rho {
2314 Some(rho) => self.assignment.try_assignments_row_for_rho(row, rho)?,
2315 None => self.assignment.try_assignments_row(row)?,
2316 };
2317 for atom_idx in 0..k_atoms {
2318 let a_k = a[atom_idx];
2319 if let Some(image) = linear_images.get(&atom_idx) {
2320 // Verdict-linear slot: substitute the straight sub-model image
2321 // at this row's fitted on-atom coordinate — or, for a #1026
2322 // collapse-rescued slot, at its fresh per-row code.
2323 let own_t = self.assignment.coords[atom_idx].as_matrix()[[row, 0]];
2324 image.fill_row(image.coordinate_for_row(row, own_t), &mut g_buf);
2325 } else {
2326 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2327 }
2328 let mut out_row = out.row_mut(row);
2329 for out_col in 0..p {
2330 out_row[out_col] += a_k * g_buf[out_col];
2331 }
2332 }
2333 }
2334 Ok(out)
2335 }
2336
2337 /// Per-atom **leave-one-atom-out (LOAO) explained-variance contribution**
2338 /// (#1026): for each atom `k`, the drop in reconstruction explained variance
2339 /// `ΔEV_k = EV(full) − EV(full ⊖ atom_k)` when that atom's contribution
2340 /// `a[i,k]·g_k(coord[i,k])` is removed from the assembled reconstruction and
2341 /// nothing else is refit. Because every atom adds linearly into the same
2342 /// fitted reconstruction (`fitted[i] = Σ_k a[i,k]·g_k`), zeroing one atom is
2343 /// the exact "this atom withheld" counterfactual, and the EV it was earning
2344 /// is `EV(full) − EV(without k)`. This is the per-atom held-out EV
2345 /// attribution the #1026 roadmap pairs with each atom's fitted turning `Θ`:
2346 /// a `Θ ≈ 0` atom earning a large `ΔEV` is a linear-tail direction; a
2347 /// high-`Θ` atom earning a large `ΔEV` is a genuine curved family carrying
2348 /// reconstruction it would otherwise shatter into `N(ε) ≈ Θ/(2√(2ε))` linear
2349 /// directions. Pure read-only diagnostic — never mutates any atom.
2350 ///
2351 /// Returns one `Option<f64>` per atom in atom order; `None` for an atom
2352 /// whose ⊖-reconstruction EV is undefined (degenerate target variance), and
2353 /// `None` for the whole vector if the full-reconstruction EV is undefined.
2354 /// #1026: the load-bearing curved-vs-linear hybrid-split verdict for the
2355 /// fitted dictionary, or `None` until [`Self::canonicalize_charts_post_fit`]
2356 /// has run (or when no `d = 1` atom is eligible). Surfaced in the Python model
2357 /// output so the user sees which atoms genuinely earn their curvature.
2358 pub fn hybrid_split_report(
2359 &self,
2360 ) -> Option<&crate::hybrid_split::SaeHybridSplitReport> {
2361 self.hybrid_split_report.as_ref()
2362 }
2363
2364 /// Build the #1026 curved-vs-linear hybrid-split report by adjudicating each
2365 /// eligible `d = 1` atom's fitted curved image against its straight (linear
2366 /// special-case) sub-model on the common rank-aware Laplace evidence scale.
2367 ///
2368 /// Both candidates are scored against the SAME data — the atom's
2369 /// leave-this-atom-out response residual `y_resp = target − (full − a_k·γ_k)`
2370 /// (#1202) — over its assigned rows: the curved candidate predicts its actual
2371 /// mass-scaled contribution `a_k·γ_k`, the linear candidate the best
2372 /// mass-weighted straight line fit to `y_resp` (the collapsed linear lane —
2373 /// closed form, NOT the broken euclidean outer fit path of #1051). Linear is
2374 /// the curved family's nested `Θ = 0` sub-model on common data, so the
2375 /// per-slot evidence argmin is a genuine match-or-beat comparison. Eligible
2376 /// atoms are `d = 1` atoms with an installed evaluator at the full curvature
2377 /// dial (`homotopy_eta == 1.0`) whose live coordinate dim still matches the
2378 /// atom's latent dim. Returns `None` when no reconstruction `target` is
2379 /// supplied (there is no data to adjudicate against).
2380 pub fn compute_hybrid_split_report(
2381 &self,
2382 rho: &SaeManifoldRho,
2383 target: Option<ArrayView2<'_, f64>>,
2384 ) -> Result<Option<crate::hybrid_split::SaeHybridSplitReport>, String> {
2385 let n = self.n_obs();
2386 let p = self.output_dim();
2387 // Per-atom held-out `ΔEV_k` (leave-one-atom-out explained-variance drop),
2388 // paired with each atom's fitted turning Θ onto the verdict so the report
2389 // carries the #1026 `(Θ, ΔEV)` frontier point as structured data. Absent
2390 // when no reconstruction target is supplied.
2391 let loao_ev: Vec<Option<f64>> = match target {
2392 Some(t) => self.per_atom_loao_explained_variance(t, rho)?,
2393 None => vec![None; self.k_atoms()],
2394 };
2395 let delta_ev_for =
2396 |atom_idx: usize| -> Option<f64> { loao_ev.get(atom_idx).copied().flatten() };
2397 // The common-evidence comparison (#1202) scores both candidates against
2398 // the response data the atom is responsible for. That requires a target;
2399 // with none supplied there is nothing to adjudicate against, so no report.
2400 let Some(target) = target else {
2401 return Ok(None);
2402 };
2403 if target.dim() != (n, p) {
2404 return Err(format!(
2405 "SaeManifoldTerm::compute_hybrid_split_report: target {:?} != ({n}, {p})",
2406 target.dim()
2407 ));
2408 }
2409 // Per-row assignment masses (once), so each atom's weighted straight-line
2410 // fit uses the same row weighting the joint reconstruction loss does.
2411 let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2412 for row in 0..n {
2413 weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2414 }
2415 // The full assembled reconstruction `Σ_k a[i,k]·γ_k`, computed once. Each
2416 // atom's leave-this-atom-out response residual is `y_resp = target −
2417 // (full − a_k·γ_k)`, the data both that atom's candidates fit (#1202).
2418 let full = self.try_fitted_for_rho(rho)?;
2419 let eligible: Vec<usize> = (0..self.k_atoms())
2420 .filter(|&atom_idx| {
2421 let atom = &self.atoms[atom_idx];
2422 atom.latent_dim == 1
2423 && atom.basis_evaluator.is_some()
2424 && atom.homotopy_eta == 1.0
2425 && self.assignment.coords[atom_idx].latent_dim() == atom.latent_dim
2426 })
2427 .collect();
2428 // Per-atom fitted decoded image at every row (the curved candidate's
2429 // realized curve, which the linear candidate must approximate).
2430 let coords_for = |atom_idx: usize| -> Array1<f64> {
2431 self.assignment.coords[atom_idx]
2432 .as_matrix()
2433 .column(0)
2434 .to_owned()
2435 };
2436 let assign_for = |atom_idx: usize| -> Array1<f64> {
2437 Array1::from_iter((0..n).map(|row| weights[row][atom_idx]))
2438 };
2439 let decoded_for = |atom_idx: usize| -> Array2<f64> {
2440 let mut decoded = Array2::<f64>::zeros((n, p));
2441 let mut buf = vec![0.0_f64; p];
2442 for row in 0..n {
2443 self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2444 for col in 0..p {
2445 decoded[[row, col]] = buf[col];
2446 }
2447 }
2448 decoded
2449 };
2450 // The atom's leave-this-atom-out response residual `y_resp = target −
2451 // (full − a_k·γ_k) = (target − full) + a_k·γ_k`. Both the curved and the
2452 // linear candidate are scored against this on common data (#1202).
2453 let target_resid_for = |atom_idx: usize| -> Array2<f64> {
2454 let mut resid = Array2::<f64>::zeros((n, p));
2455 let mut buf = vec![0.0_f64; p];
2456 for row in 0..n {
2457 let a_k = weights[row][atom_idx];
2458 self.atoms[atom_idx].fill_decoded_row(row, &mut buf);
2459 for col in 0..p {
2460 resid[[row, col]] = target[[row, col]] - full[[row, col]] + a_k * buf[col];
2461 }
2462 }
2463 resid
2464 };
2465 let manifold_for = |atom_idx: usize| -> gam_terms::latent::LatentManifold {
2466 self.assignment.coords[atom_idx].manifold().clone()
2467 };
2468 // #1026 EV-preservation gate denominator: the full target's total
2469 // column-centered variance `SST_full` (the SAME `sst` the reconstruction
2470 // EV is measured against), so the gate vetoes any collapse that would drop
2471 // full-reconstruction EV by more than its tolerance.
2472 let total_centered_variance = {
2473 let mut tss = 0.0_f64;
2474 for col in 0..p {
2475 let mut mean = 0.0_f64;
2476 for row in 0..n {
2477 mean += target[[row, col]];
2478 }
2479 mean /= n as f64;
2480 for row in 0..n {
2481 let c = target[[row, col]] - mean;
2482 tss += c * c;
2483 }
2484 }
2485 tss
2486 };
2487 crate::hybrid_split::build_hybrid_split_report(
2488 &self.atoms,
2489 eligible.into_iter(),
2490 coords_for,
2491 assign_for,
2492 decoded_for,
2493 target_resid_for,
2494 manifold_for,
2495 delta_ev_for,
2496 total_centered_variance,
2497 )
2498 }
2499
2500 pub fn per_atom_loao_explained_variance(
2501 &self,
2502 target: ArrayView2<'_, f64>,
2503 rho: &SaeManifoldRho,
2504 ) -> Result<Vec<Option<f64>>, String> {
2505 let n = self.n_obs();
2506 let p = self.output_dim();
2507 let k_atoms = self.k_atoms();
2508 if target.dim() != (n, p) {
2509 return Err(format!(
2510 "SaeManifoldTerm::per_atom_loao_explained_variance: target {:?} != ({n}, {p})",
2511 target.dim()
2512 ));
2513 }
2514 let full = self.try_fitted_for_rho(rho)?;
2515 let Some(ev_full) = reconstruction_explained_variance(target, full.view()) else {
2516 return Ok(vec![None; k_atoms]);
2517 };
2518 // Cache each row's assignment weights once, then subtract a single
2519 // atom's decoded contribution per LOAO pass instead of reassembling the
2520 // whole dictionary k times.
2521 let mut weights: Vec<Array1<f64>> = Vec::with_capacity(n);
2522 for row in 0..n {
2523 weights.push(self.assignment.try_assignments_row_for_rho(row, rho)?);
2524 }
2525 let mut g_buf = vec![0.0_f64; p];
2526 let mut out = Vec::with_capacity(k_atoms);
2527 for atom_idx in 0..k_atoms {
2528 let mut without = full.clone();
2529 for row in 0..n {
2530 let a_k = weights[row][atom_idx];
2531 if a_k == 0.0 {
2532 continue;
2533 }
2534 self.atoms[atom_idx].fill_decoded_row(row, &mut g_buf);
2535 let mut without_row = without.row_mut(row);
2536 for out_col in 0..p {
2537 without_row[out_col] -= a_k * g_buf[out_col];
2538 }
2539 }
2540 out.push(
2541 reconstruction_explained_variance(target, without.view())
2542 .map(|ev_without| ev_full - ev_without),
2543 );
2544 }
2545 Ok(out)
2546 }
2547
2548 /// #1026 — the LOAD-BEARING collapsed reconstruction: the assembled
2549 /// dictionary output `Σ_k a[i,k]·g_k(coord[i,k])` in which every slot whose
2550 /// hybrid-split verdict selected LINEAR has its curved decoded image replaced
2551 /// by its fitted straight sub-model `b₀ + (t − t̄)·b₁`. This is what makes the
2552 /// verdict *change the reconstruction* instead of merely logging a choice:
2553 /// the linear-collapsed atom no longer pays its `M·p` curved coefficients, it
2554 /// carries a `2·p` straight image whose decoded curve has zero turning.
2555 ///
2556 /// The straight images are the exact weighted-least-squares lines already
2557 /// realized inside [`Self::compute_hybrid_split_report`] (no re-fit, no outer
2558 /// continuation, sidestepping #1051). Returns the curved reconstruction
2559 /// unchanged when no verdict selected linear, or when the report has not been
2560 /// computed yet (`hybrid_split_report == None`).
2561 pub fn hybrid_collapsed_reconstruction(
2562 &self,
2563 rho: &SaeManifoldRho,
2564 ) -> Result<Array2<f64>, String> {
2565 // #1026 — the hybrid collapse is realised by the SINGLE reconstruction
2566 // path ([`Self::try_fitted_with_rho`]) with the collapse flag set: a
2567 // verdict-linear `d = 1` slot decodes its straight sub-model image
2568 // instead of its curved curve. This replaces the dedicated re-collapse
2569 // loop this method used to carry (a parallel layer). The production
2570 // `try_fitted` shares the identical routine at `rho = None`; this entry
2571 // point keeps the rho-keyed collapse for the #1026 EV-dominance reporting
2572 // (`hybrid_collapsed_explained_variance`) and the regression battery.
2573 self.try_fitted_with_rho(Some(rho), true)
2574 }
2575
2576 /// #1026 — the reconstruction explained variance of the hybrid-collapsed
2577 /// dictionary (every verdict-linear slot decoded by its straight sub-model)
2578 /// against `target`. The companion of [`Self::per_atom_loao_explained_variance`]
2579 /// for the dominance claim: because each linear-collapsed slot is the curved
2580 /// family's `Θ → 0` sub-model and is only kept when its evidence beats the
2581 /// curved candidate's parameter price, the collapsed dictionary match-or-beats
2582 /// the all-curved one on EV-per-parameter — the strict-generalization floor
2583 /// the #1026 hybrid argument rests on. `None` when EV is undefined (degenerate
2584 /// target variance).
2585 pub fn hybrid_collapsed_explained_variance(
2586 &self,
2587 target: ArrayView2<'_, f64>,
2588 rho: &SaeManifoldRho,
2589 ) -> Result<Option<f64>, String> {
2590 let n = self.n_obs();
2591 let p = self.output_dim();
2592 if target.dim() != (n, p) {
2593 return Err(format!(
2594 "SaeManifoldTerm::hybrid_collapsed_explained_variance: target {:?} != ({n}, {p})",
2595 target.dim()
2596 ));
2597 }
2598 let collapsed = self.hybrid_collapsed_reconstruction(rho)?;
2599 Ok(reconstruction_explained_variance(target, collapsed.view()))
2600 }
2601
2602 /// #1026 ladder item 2/3 — the AMORTIZED ENCODER, wired from the fitted
2603 /// dictionary. Builds the offline certified [`EncodeAtlas`] over this term's
2604 /// frozen atoms and encodes a target corpus `targets` (`n × p`) through the
2605 /// per-chart distilled Jacobian predictor, with the Kantorovich certificate
2606 /// gating each row and an exact-solve fallback for the rows the amortized
2607 /// predictor cannot certify. Returns one [`EncodeResult`] per atom (the
2608 /// per-atom encoded coordinates + per-row certificate mask), in dictionary
2609 /// order.
2610 ///
2611 /// This is the thread's "encoder + certificate-gated exact fallback"
2612 /// deployment made reachable from a fit: the distilled map approximates
2613 /// inference at one mat-vec/row, and any row whose amortized prediction fails
2614 /// `h ≤ ½` falls back to the certified IFT-warm-start Newton encode
2615 /// ([`EncodeAtlas::certified_encode_row`]); rows that still cannot be
2616 /// certified ride the [`EncodeResult::encode_uncertified_count`] flag for the
2617 /// upstream exact multi-start solve (honesty, never a silent wrong encode).
2618 ///
2619 /// Magic by default: the atlas's worst-case bounds are auto-derived from the
2620 /// fit — `amplitude_bound[k]` is the largest fitted assignment mass `a[i,k]`
2621 /// the encode can produce for atom `k` (the encode recovers `t` from
2622 /// `x ≈ z·γ_k(t)` at amplitude `z = a[i,k]`), and `target_norm_bound` is the
2623 /// largest target row norm — so no caller supplies a knob. Per-row amplitudes
2624 /// are the fitted assignment masses for the same target the dictionary was fit
2625 /// against; an external corpus reuses the per-row masses the assignment
2626 /// produces for it upstream (passed in `amplitudes`, one column per atom).
2627 pub fn amortized_encode_target(
2628 &self,
2629 targets: ArrayView2<'_, f64>,
2630 amplitudes: ArrayView2<'_, f64>,
2631 ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2632 let p = self.output_dim();
2633 let k_atoms = self.k_atoms();
2634 let n = targets.nrows();
2635 if targets.ncols() != p {
2636 return Err(format!(
2637 "SaeManifoldTerm::amortized_encode_target: targets have {} cols but output_dim is {p}",
2638 targets.ncols()
2639 ));
2640 }
2641 if amplitudes.dim() != (n, k_atoms) {
2642 return Err(format!(
2643 "SaeManifoldTerm::amortized_encode_target: amplitudes {:?} must be (n={n}, K={k_atoms})",
2644 amplitudes.dim()
2645 ));
2646 }
2647
2648 // Magic-by-default offline bounds, auto-derived from the fit so no caller
2649 // supplies a knob. `target_norm_bound` is the largest target row L2 norm
2650 // (bounds `‖x‖` over the corpus); `amplitude_bound[k]` is the largest
2651 // fitted assignment mass for atom `k` (bounds `|z_k|`), with a strictly
2652 // positive floor so a near-inactive atom still certifies a finite radius.
2653 let mut target_norm_bound = 0.0_f64;
2654 for row in 0..n {
2655 let norm = targets.row(row).dot(&targets.row(row)).sqrt();
2656 if norm.is_finite() && norm > target_norm_bound {
2657 target_norm_bound = norm;
2658 }
2659 }
2660 let mut amplitude_bound = vec![0.0_f64; k_atoms];
2661 for atom_idx in 0..k_atoms {
2662 let mut bound = 0.0_f64;
2663 for row in 0..n {
2664 let z = amplitudes[[row, atom_idx]].abs();
2665 if z.is_finite() && z > bound {
2666 bound = z;
2667 }
2668 }
2669 // A strictly positive amplitude floor keeps the offline Lipschitz
2670 // scaling finite for atoms with no active row in this corpus (those
2671 // rows encode to the chart center via the certificate anyway).
2672 amplitude_bound[atom_idx] = bound.max(1.0);
2673 }
2674
2675 let atlas = crate::encode::EncodeAtlas::build(
2676 &self.atoms,
2677 &litude_bound,
2678 target_norm_bound,
2679 crate::encode::AtlasConfig::default(),
2680 )?;
2681
2682 // Per-atom amortized encode with a certificate-gated exact-solve fallback:
2683 // a row whose distilled prediction fails `h ≤ ½` is retried through the
2684 // certified IFT-warm-start Newton path; a row that still cannot be
2685 // certified stays flagged for the upstream multi-start solve.
2686 // (The atlas is rho-free; the per-row amplitudes already carry the
2687 // rho-resolved assignment masses the caller produced upstream.)
2688 let mut results = Vec::with_capacity(k_atoms);
2689 for atom_idx in 0..k_atoms {
2690 let atom = &self.atoms[atom_idx];
2691 let amp_col = amplitudes.column(atom_idx).to_owned();
2692 let amortized =
2693 atlas.amortized_encode_batch(atom, atom_idx, targets, amp_col.view())?;
2694 let mut coords = amortized.coords;
2695 let mut certified = amortized.certified;
2696 for row in 0..n {
2697 if certified[row] {
2698 continue;
2699 }
2700 let (t, cert) =
2701 atlas.certified_encode_row(atom, atom_idx, targets.row(row), amp_col[row])?;
2702 if cert.certified() {
2703 coords.row_mut(row).assign(&t);
2704 certified[row] = true;
2705 }
2706 }
2707 results.push(crate::encode::EncodeResult::from_rows(
2708 coords, certified,
2709 ));
2710 }
2711 Ok(results)
2712 }
2713
2714 /// #1026 — the fitted per-row assignment masses `a[i,k]` (the activation
2715 /// amplitudes `z_k` the amortized encode recovers `t` against), as an
2716 /// `n × K` matrix. These are exactly the masses
2717 /// [`Self::try_fitted_with_rho`] assembles the reconstruction from, so
2718 /// feeding them to [`Self::amortized_encode_target`] re-encodes the SAME
2719 /// inference the dictionary was fit against — the self-consistency the
2720 /// distilled encoder is supervised to approximate.
2721 pub fn fitted_assignment_amplitudes(
2722 &self,
2723 rho: &SaeManifoldRho,
2724 ) -> Result<Array2<f64>, String> {
2725 let n = self.n_obs();
2726 let k_atoms = self.k_atoms();
2727 let mut amplitudes = Array2::<f64>::zeros((n, k_atoms));
2728 for row in 0..n {
2729 let a = self.assignment.try_assignments_row_for_rho(row, rho)?;
2730 for atom_idx in 0..k_atoms {
2731 amplitudes[[row, atom_idx]] = a[atom_idx];
2732 }
2733 }
2734 Ok(amplitudes)
2735 }
2736
2737 /// #1026 — encode the dictionary's own fit-time target with the amortized
2738 /// encoder, deriving the per-row amplitudes from the fitted assignment so the
2739 /// caller supplies neither bounds nor amplitudes (magic by default). The
2740 /// end-to-end "fit → distilled encoder → certificate-gated encode" path.
2741 pub fn amortized_encode_fitted(
2742 &self,
2743 targets: ArrayView2<'_, f64>,
2744 rho: &SaeManifoldRho,
2745 ) -> Result<Vec<crate::encode::EncodeResult>, String> {
2746 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2747 self.amortized_encode_target(targets, amplitudes.view())
2748 }
2749
2750 /// #1154 — amortized-encoder consistency of the CURRENT dictionary against
2751 /// its own fit-time target. This is the co-training signal of the joint
2752 /// amortized-encoder + REML loop (Design A): the amortized (one-mat-vec)
2753 /// encode is built from the *current* fitted decoder, run on `targets`, and
2754 /// scored on two principled axes —
2755 ///
2756 /// * `recon_consistency` (the bilinear part of the co-training loss): the
2757 /// mean per-element squared gap between the **amortized** reconstruction
2758 /// `Σ_k z_k · Φ_k(t̂_k) B_k` (decode the amortized coords) and the
2759 /// **exact** fitted reconstruction `Σ_k z_k · Φ_k(t_k^*) B_k` the inner
2760 /// solve converged to. A dictionary whose encode map is well-approximated
2761 /// to first order by the per-chart IFT predictor scores near zero; a
2762 /// dictionary the amortized encoder *cannot* invert faithfully (sharp
2763 /// curvature, poorly-charted regions) scores high. Minimising this jointly
2764 /// with REML steers the fit toward dictionaries that admit a fast,
2765 /// faithful amortized encode — the architectural co-adaptation #1154 adds.
2766 /// * `uncertified_fraction`: the share of (row, atom) encodes whose
2767 /// Kantorovich certificate failed (`h > ½`), i.e. that fell back to the
2768 /// certified IFT-warm-start Newton. This is the encoder's *certifiable coverage*
2769 /// of the dictionary; co-training rewards dictionaries the cheap encode
2770 /// certifies, not just ones it happens to land.
2771 ///
2772 /// The certificate keeps every accepted amortized coord honest (uncertified
2773 /// rows already ride the exact fallback inside `amortized_encode_target`), so
2774 /// this metric never silently trusts a wrong encode — it MEASURES how much of
2775 /// the dictionary the cheap encoder can faithfully and certifiably invert.
2776 pub fn amortized_encoder_consistency(
2777 &self,
2778 targets: ArrayView2<'_, f64>,
2779 rho: &SaeManifoldRho,
2780 ) -> Result<AmortizedEncoderConsistency, String> {
2781 let n = self.n_obs();
2782 let p = self.output_dim();
2783 let k_atoms = self.k_atoms();
2784 if targets.dim() != (n, p) {
2785 return Err(format!(
2786 "SaeManifoldTerm::amortized_encoder_consistency: targets {:?} must be (n={n}, p={p})",
2787 targets.dim()
2788 ));
2789 }
2790 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2791 let encodes = self.amortized_encode_target(targets, amplitudes.view())?;
2792 // The EXACT fitted reconstruction the inner solve converged to (pure
2793 // curved image, rho-keyed) is the supervision target for the amortized
2794 // reconstruction. Both are n×p ambient, so the comparison is layout-free.
2795 let exact_recon = self.try_fitted_for_rho(rho)?;
2796
2797 // Build the amortized reconstruction Σ_k z_k · Φ_k(t̂_k) B_k by decoding
2798 // each atom's amortized coords through that atom's own basis evaluator.
2799 let mut amortized_recon = Array2::<f64>::zeros((n, p));
2800 let mut uncertified = 0usize;
2801 for atom_idx in 0..k_atoms {
2802 let atom = &self.atoms[atom_idx];
2803 let result = &encodes[atom_idx];
2804 // An atom with no basis evaluator cannot decode an amortized
2805 // reconstruction; every one of its rows is necessarily uncertified
2806 // (the encode flagged them all), so it contributes nothing to the
2807 // amortized recon and its full row-count to the uncertified tally.
2808 // Count it and skip the decode rather than erroring — the consistency
2809 // fold stays a bounded penalty, never a hard abort of the criterion.
2810 let Some(evaluator) = atom.basis_evaluator.as_ref() else {
2811 uncertified += n;
2812 continue;
2813 };
2814 uncertified += result.encode_uncertified_count;
2815 // Decode the amortized coords: Φ_k(t̂) is (n × M_k); B_k is (M_k × p).
2816 let (phi, _jac) = evaluator.evaluate(result.coords.view())?;
2817 let decoded = phi.dot(&atom.decoder_coefficients); // (n × p)
2818 for row in 0..n {
2819 let z = amplitudes[[row, atom_idx]];
2820 if z == 0.0 {
2821 continue;
2822 }
2823 for col in 0..p {
2824 amortized_recon[[row, col]] += z * decoded[[row, col]];
2825 }
2826 }
2827 }
2828
2829 let mut sse = 0.0_f64;
2830 for row in 0..n {
2831 for col in 0..p {
2832 let gap = amortized_recon[[row, col]] - exact_recon[[row, col]];
2833 sse += gap * gap;
2834 }
2835 }
2836 let denom = (n.max(1) * p.max(1)) as f64;
2837 let recon_consistency = sse / denom;
2838 let total_encodes = (n * k_atoms).max(1) as f64;
2839 let uncertified_fraction = uncertified as f64 / total_encodes;
2840
2841 Ok(AmortizedEncoderConsistency {
2842 recon_consistency,
2843 uncertified_fraction,
2844 n_uncertified: uncertified,
2845 n_encodes: n * k_atoms,
2846 })
2847 }
2848
2849 /// #1154 — the co-trained REML criterion: the exact REML criterion at `rho`
2850 /// PLUS the amortized-encoder consistency penalty, so the outer optimizer
2851 /// co-adapts the dictionary + smoothing parameters λ TOWARD a dictionary the
2852 /// fast amortized encoder can faithfully and certifiably invert.
2853 ///
2854 /// This is Design A of #1154. The inner solve still converges the `(t, β)`
2855 /// system to stationarity at the engine's current ρ (so the implicit-function
2856 /// REML λ-gradient `dβ̂/dλ = −(H+S_λ)⁻¹(dS_λ/dλ)β̂` stays EXACT — the encoder
2857 /// only warm-starts/co-adapts, it never replaces the stationary point). The
2858 /// added term
2859 ///
2860 /// ```text
2861 /// J_cotrain(ρ) = REML(ρ) + w · ‖x̂_amortized − x̂_exact‖²/(n·p)
2862 /// + w_cert · uncertified_fraction
2863 /// ```
2864 ///
2865 /// folds the post-fit amortized-encode quality into the ranked objective. The
2866 /// weights are auto-scaled to the REML criterion magnitude (magic by default:
2867 /// no caller knob) so the consistency term is a meaningful but non-dominant
2868 /// fraction of the objective regardless of problem scale.
2869 pub fn reml_criterion_cotrained(
2870 &mut self,
2871 target: ArrayView2<'_, f64>,
2872 rho: &SaeManifoldRho,
2873 registry: Option<&AnalyticPenaltyRegistry>,
2874 inner_max_iter: usize,
2875 learning_rate: f64,
2876 ridge_ext_coord: f64,
2877 ridge_beta: f64,
2878 ) -> Result<(f64, SaeManifoldLoss, AmortizedEncoderConsistency), String> {
2879 // #1154: always attempt the amortized warm-start first inside
2880 // `reml_criterion_cotrained` (the encode/warm path for the cotrained
2881 // objective). Good warm-starts from the running dictionary land the
2882 // inner solve closer to the stationary point used for the fold.
2883 // Advisory only (0 or err falls back to cold); telemetry recorded by
2884 // outer objective callers when present.
2885 self.warm_start_latents_from_amortized_encoder(target, rho)
2886 .unwrap_or(0);
2887 let (reml, loss) = self.reml_criterion_with_refine_policy(
2888 target,
2889 rho,
2890 registry,
2891 inner_max_iter,
2892 learning_rate,
2893 ridge_ext_coord,
2894 ridge_beta,
2895 true,
2896 )?;
2897 let consistency = self.amortized_encoder_consistency(target, rho)?;
2898 // Auto-scale the co-training weights to the REML magnitude so the
2899 // consistency penalty is a bounded, scale-free fraction of the objective
2900 // (magic by default: no caller knob). `reml_scale` floors at 1 so a
2901 // near-zero criterion still admits a meaningful consistency contribution.
2902 let cotrained = Self::fold_cotrain_consistency(reml, &consistency);
2903 Ok((cotrained, loss, consistency))
2904 }
2905
2906 /// #1154 — the single source of the co-training fold arithmetic: add the
2907 /// auto-scaled amortized-encoder consistency penalty to an already-computed
2908 /// REML criterion at the converged dictionary. Both the public
2909 /// [`Self::reml_criterion_cotrained`] entry point and the outer-loop value /
2910 /// gradient lanes (`SaeManifoldOuterObjective::fold_cotrain_consistency`)
2911 /// route through THIS function, so the folded objective cannot drift between
2912 /// the criterion and the cascade-ranked cost (the objective↔gradient desync
2913 /// bug class). The weights are auto-scaled to the REML magnitude (`max(|REML|,
2914 /// 1)`) so the penalty is a bounded, scale-free fraction of the objective
2915 /// regardless of problem scale; the fold carries no analytic gradient (under
2916 /// Design A the REML λ-gradient stays the exact implicit-function path).
2917 #[must_use]
2918 pub fn fold_cotrain_consistency(
2919 reml_cost: f64,
2920 consistency: &AmortizedEncoderConsistency,
2921 ) -> f64 {
2922 let reml_scale = reml_cost.abs().max(1.0);
2923 reml_cost
2924 + COTRAIN_RECON_WEIGHT * reml_scale * consistency.recon_consistency
2925 + COTRAIN_CERT_WEIGHT * reml_scale * consistency.uncertified_fraction
2926 }
2927
2928 /// #1154 item 2 — warm-start the inner latent coordinates from the amortized
2929 /// encoder (Design A). Builds the per-chart IFT-Jacobian atlas from the
2930 /// CURRENT dictionary, runs the one-mat-vec amortized encode of `target`
2931 /// against each atom at the rho-resolved assignment masses, and overwrites
2932 /// each atom's stored latent coords with the predicted `t̂` ON THE ROWS THE
2933 /// KANTOROVICH CERTIFICATE ACCEPTS. Uncertified rows are left at their
2934 /// current coords (the previous-iterate start), so the
2935 /// warm-start can only HELP — a row the cheap predictor cannot certify never
2936 /// corrupts the seed. The subsequent inner Newton refines from this seed to
2937 /// the SAME stationary point (the warm-start changes only the basin entry,
2938 /// not the root), so the REML λ-gradient stays exactly the implicit-function
2939 /// path and the criterion is unchanged at convergence — the amortized encoder
2940 /// only accelerates/co-adapts the inner solve, it never replaces the
2941 /// stationary point.
2942 ///
2943 /// Returns the number of (row, atom) coords actually warm-started (the
2944 /// certified-prediction count), for instrumentation / tests. A first-build
2945 /// dictionary with no usable charts simply warm-starts nothing and returns 0
2946 /// (the cold path is byte-for-byte unchanged).
2947 pub fn warm_start_latents_from_amortized_encoder(
2948 &mut self,
2949 target: ArrayView2<'_, f64>,
2950 rho: &SaeManifoldRho,
2951 ) -> Result<usize, String> {
2952 let n = self.n_obs();
2953 let k_atoms = self.k_atoms();
2954 if n == 0 || k_atoms == 0 {
2955 return Ok(0);
2956 }
2957 let amplitudes = self.fitted_assignment_amplitudes(rho)?;
2958 let encodes = self.amortized_encode_target(target, amplitudes.view())?;
2959 let mut warm_started = 0usize;
2960 for atom_idx in 0..k_atoms {
2961 let d = self.atoms[atom_idx].latent_dim;
2962 if d == 0 {
2963 continue;
2964 }
2965 let result = &encodes[atom_idx];
2966 // Start from the atom's CURRENT coords so uncertified rows are left
2967 // exactly as they were; overwrite only the certified predictions.
2968 let mut coords = self.assignment.coords[atom_idx].as_matrix();
2969 if coords.dim() != (n, d) {
2970 return Err(format!(
2971 "warm_start_latents_from_amortized_encoder: atom {atom_idx} coords {:?} != (n={n}, d={d})",
2972 coords.dim()
2973 ));
2974 }
2975 for row in 0..n {
2976 if !result.certified[row] {
2977 continue;
2978 }
2979 for axis in 0..d {
2980 coords[[row, axis]] = result.coords[[row, axis]];
2981 }
2982 warm_started += 1;
2983 }
2984 // `as_matrix` lays coords out row-major (`[[row, axis]]`), exactly the
2985 // `values[row*d + axis]` order `set_flat` expects, so a plain
2986 // row-major iterator reconstructs the flat vector.
2987 let flat = Array1::from_iter(coords.iter().copied());
2988 self.assignment.coords[atom_idx].set_flat(flat.view());
2989 }
2990 // The basis caches must follow the freshly-seeded coords so the next
2991 // inner solve evaluates Φ at the warm-started t̂, not the stale coords.
2992 self.refresh_basis_from_current_coords()?;
2993 Ok(warm_started)
2994 }
2995
2996 pub fn loss(
2997 &self,
2998 target: ArrayView2<'_, f64>,
2999 rho: &SaeManifoldRho,
3000 ) -> Result<SaeManifoldLoss, String> {
3001 self.loss_scaled(target, rho, 1.0)
3002 }
3003
3004 /// Penalized objective with a `penalty_scale` applied to the β-tier
3005 /// (decoder smoothness) penalty, mirroring
3006 /// [`Self::assemble_arrow_schur_scaled`]. The streaming line search sums
3007 /// per-chunk `loss_scaled(..., n_chunk / N)` so that the global smoothness
3008 /// penalty is counted exactly once across a pass while the per-row data,
3009 /// assignment-prior, and ARD terms sum naturally. `penalty_scale == 1.0`
3010 /// recovers the full-batch objective.
3011 pub fn loss_scaled(
3012 &self,
3013 target: ArrayView2<'_, f64>,
3014 rho: &SaeManifoldRho,
3015 penalty_scale: f64,
3016 ) -> Result<SaeManifoldLoss, String> {
3017 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3018 return Err(format!(
3019 "SaeManifoldTerm::loss_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3020 ));
3021 }
3022 if target.dim() != (self.n_obs(), self.output_dim()) {
3023 return Err(format!(
3024 "SaeManifoldTerm::loss: Z must be ({}, {}); got {:?}",
3025 self.n_obs(),
3026 self.output_dim(),
3027 target.dim()
3028 ));
3029 }
3030 // The likelihood whitens through the RowMetric **only** when the metric
3031 // is a genuinely estimated noise model (`metric.whitens_likelihood()`,
3032 // i.e. `WhitenedStructured` — the #974 residual-covariance seam). For
3033 // Euclidean (default `None`) and for the OutputFisher *gauge* metric the
3034 // reconstruction data-fit stays the isotropic `0.5 * Σ r²`: a gauge /
3035 // output-Fisher inner product must NOT silently replace the
3036 // reconstruction loss with a Fisher pullback (#980). It only drives the
3037 // gauge (see `analytic_penalties::corrected_isometry_penalty`). The
3038 // producer of `WhitenedStructured` is
3039 // `inference::residual_factor::StructuredResidualModel::row_metric`; the
3040 // SAME metric whitens the assembled gradient/Hessian in
3041 // `assemble_arrow_schur` (the single #974 seam), so this value and that
3042 // gradient cannot desync. Without a whitening metric this path is
3043 // bit-for-bit the historical isotropic data-fit.
3044 let whitens = self
3045 .row_metric
3046 .as_ref()
3047 .is_some_and(|metric| metric.whitens_likelihood());
3048 // #991 design honesty weights: the reconstruction channel of row `i`
3049 // is weighted by `w_i` (mean-1 HT inclusion correction). The assembly
3050 // applies the same `w_i` via a `√w_i` scaling of the row residual /
3051 // Jacobian / β load at its single seam, so this value and that
3052 // gradient/Hessian carry the identical per-row factor. `None` ⇒ the
3053 // historical unweighted sum, bit-for-bit.
3054 let row_loss_w = self.row_loss_weights.as_deref();
3055 let n = self.n_obs();
3056 let p = self.output_dim();
3057 let k_atoms = self.k_atoms();
3058 // #1017: the data-fit is the dominant per-line-search-trial cost (it
3059 // re-runs every Armijo halving × every inner Newton iteration × every
3060 // outer ρ evaluation). The old path materialised the whole `n × p`
3061 // fitted matrix (`try_fitted_for_rho`) and then walked it AGAIN to form
3062 // the residual sum — two sequential `n·p` passes plus an `n·p`
3063 // allocation per trial. Fuse the reconstruction and the residual reduce
3064 // into ONE row-parallel pass that never materialises the fitted matrix:
3065 // each row decodes its atoms into per-worker scratch, differences
3066 // against the target, and contributes its scalar `0.5·w·‖r‖²` to a
3067 // chunk-ordered fold (bit-identical run-to-run). Per-worker scratch
3068 // (`map_init`) keeps the only allocations one `g_buf`/`fitted_row` pair
3069 // per rayon thread rather than per row. Stay sequential inside a worker
3070 // (the topology race owns the outer pool) to avoid nested
3071 // oversubscription.
3072 let parallel = n >= SAE_LOSS_PARALLEL_ROW_MIN && rayon::current_thread_index().is_none();
3073 let row_data_fit =
3074 |row: usize,
3075 g_buf: &mut [f64],
3076 fitted_row: &mut [f64],
3077 assign_buf: &mut [f64]|
3078 -> Result<f64, String> {
3079 // #1557 — fill the per-atom assignment row into reused per-worker
3080 // scratch via the `_into` twin instead of heap-allocating a fresh
3081 // `Array1` per row per loss eval. Bit-identical to the allocating
3082 // `try_assignments_row_for_rho` (same arithmetic, same order); this
3083 // loss reruns every Armijo halving × inner Newton iter × outer ρ
3084 // eval, so the per-row K-sized allocation was a hot-path churn.
3085 self.assignment
3086 .try_assignments_row_for_rho_into(row, rho, assign_buf)?;
3087 let a = &*assign_buf;
3088 for slot in fitted_row.iter_mut() {
3089 *slot = 0.0;
3090 }
3091 for atom_idx in 0..k_atoms {
3092 self.atoms[atom_idx].fill_decoded_row(row, g_buf);
3093 let a_k = a[atom_idx];
3094 for out_col in 0..p {
3095 fitted_row[out_col] += a_k * g_buf[out_col];
3096 }
3097 }
3098 for out_col in 0..p {
3099 fitted_row[out_col] = target[[row, out_col]] - fitted_row[out_col];
3100 }
3101 let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3102 let mut acc = 0.0_f64;
3103 match self.row_metric.as_ref() {
3104 Some(metric) if whitens => {
3105 let resid = ArrayView1::from(&fitted_row[..p]);
3106 for w in metric.whiten_residual_row(row, resid) {
3107 acc += 0.5 * w_row * w * w;
3108 }
3109 }
3110 _ => {
3111 for &r in fitted_row[..p].iter() {
3112 acc += 0.5 * w_row * r * r;
3113 }
3114 }
3115 }
3116 Ok(acc)
3117 };
3118 let data_fit = if parallel {
3119 use rayon::prelude::*;
3120 const CHUNK: usize = 32;
3121 let partials: Vec<Result<f64, String>> = (0..n)
3122 .into_par_iter()
3123 .chunks(CHUNK)
3124 .map_init(
3125 || (vec![0.0_f64; p], vec![0.0_f64; p], vec![0.0_f64; k_atoms]),
3126 |(g_buf, fitted_row, assign_buf), idxs| {
3127 // #1557 — pin any faer GEMM reached from this row-parallel
3128 // data-fit chunk to `Par::Seq` (no nested Rayon re-fan); the
3129 // per-row reductions are tiny, so the result is bit-identical.
3130 with_nested_parallel(|| {
3131 let mut acc = 0.0_f64;
3132 for row in idxs {
3133 acc += row_data_fit(row, g_buf, fitted_row, assign_buf)?;
3134 }
3135 Ok(acc)
3136 })
3137 },
3138 )
3139 .collect();
3140 let mut total = 0.0_f64;
3141 for partial in partials {
3142 total += partial?;
3143 }
3144 total
3145 } else {
3146 let mut g_buf = vec![0.0_f64; p];
3147 let mut fitted_row = vec![0.0_f64; p];
3148 let mut assign_buf = vec![0.0_f64; k_atoms];
3149 let mut total = 0.0_f64;
3150 for row in 0..n {
3151 total += row_data_fit(row, &mut g_buf, &mut fitted_row, &mut assign_buf)?;
3152 }
3153 total
3154 };
3155 let assignment_sparsity = assignment_prior_value(&self.assignment, rho);
3156 let smoothness = penalty_scale * self.decoder_smoothness_value(&rho.lambda_smooth_vec());
3157 let ard = self.ard_value(rho)?;
3158 Ok(SaeManifoldLoss {
3159 data_fit,
3160 assignment_sparsity,
3161 smoothness,
3162 ard,
3163 evidence_gauge_deflated_directions: 0,
3164 })
3165 }
3166
3167 /// Reconstruction data-fit `0.5·Σ_i w_i·‖whiten(Z_i − R_i)‖²` for an EXPLICIT
3168 /// reconstruction matrix `R` (e.g. the hard top-k–projected `fitted`), using
3169 /// the SAME per-row metric and design-honesty weights as [`Self::loss_scaled`]
3170 /// (the soft-assignment data-fit). The only difference is the residual source:
3171 /// `loss_scaled` decodes the soft assignments on the fly, this consumes a
3172 /// reconstruction the caller already assembled (so the projected loss and the
3173 /// returned projected `fitted` describe one and the same model). The penalty
3174 /// terms (`assignment_sparsity`/`smoothness`/`ard`) are decoder/ρ properties
3175 /// the top-k gate does not change, so the caller keeps them from the soft
3176 /// `loss_scaled` and only swaps this data-fit in — see #1232.
3177 pub fn data_fit_for_reconstruction(
3178 &self,
3179 target: ArrayView2<'_, f64>,
3180 reconstruction: ArrayView2<'_, f64>,
3181 ) -> Result<f64, String> {
3182 let n = self.n_obs();
3183 let p = self.output_dim();
3184 if target.dim() != (n, p) {
3185 return Err(format!(
3186 "SaeManifoldTerm::data_fit_for_reconstruction: Z must be ({n}, {p}); got {:?}",
3187 target.dim()
3188 ));
3189 }
3190 if reconstruction.dim() != (n, p) {
3191 return Err(format!(
3192 "SaeManifoldTerm::data_fit_for_reconstruction: reconstruction must be ({n}, {p}); got {:?}",
3193 reconstruction.dim()
3194 ));
3195 }
3196 let whitens = self
3197 .row_metric
3198 .as_ref()
3199 .is_some_and(|metric| metric.whitens_likelihood());
3200 let row_loss_w = self.row_loss_weights.as_deref();
3201 let mut resid = vec![0.0_f64; p];
3202 let mut total = 0.0_f64;
3203 for row in 0..n {
3204 for out_col in 0..p {
3205 resid[out_col] = target[[row, out_col]] - reconstruction[[row, out_col]];
3206 }
3207 let w_row = row_loss_w.map_or(1.0, |w| w[row]);
3208 match self.row_metric.as_ref() {
3209 Some(metric) if whitens => {
3210 let r = ArrayView1::from(&resid[..p]);
3211 for w in metric.whiten_residual_row(row, r) {
3212 total += 0.5 * w_row * w * w;
3213 }
3214 }
3215 _ => {
3216 for &r in resid[..p].iter() {
3217 total += 0.5 * w_row * r * r;
3218 }
3219 }
3220 }
3221 }
3222 Ok(total)
3223 }
3224
3225 pub fn analytic_penalty_value_total(
3226 &self,
3227 registry: &AnalyticPenaltyRegistry,
3228 penalty_scale: f64,
3229 ) -> Result<f64, ArrowSchurError> {
3230 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3231 return Err(ArrowSchurError::SchurFactorFailed {
3232 reason: format!(
3233 "SaeManifoldTerm::analytic_penalty_value_total: penalty_scale must be finite \
3234 and positive; got {penalty_scale}"
3235 ),
3236 });
3237 }
3238 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3239 let layout = registry.rho_layout();
3240 let beta = self.flatten_beta();
3241 let mut value = 0.0_f64;
3242 for (penalty, (rho_slice, tier, name)) in registry.penalties.iter().zip(layout.iter()) {
3243 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3244 // Skip the registry `ARDPenalty` here for the same reason it is
3245 // skipped in `add_sae_analytic_penalty_contributions`: the coordinate
3246 // ARD energy is already counted by `loss.ard` (the von-Mises
3247 // `ard_value`), and the registry penalty's legacy Gaussian `½λt²` is
3248 // period-discontinuous. Including it would double-count the energy and
3249 // make this line-search objective jump across the branch cut while the
3250 // assembled gradient (von-Mises only, after the assembly fix) stays
3251 // continuous — i.e. a near-zero step would change the objective by a
3252 // finite amount and Armijo would wrongly reject it.
3253 if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
3254 continue;
3255 }
3256 match tier {
3257 PenaltyTier::Psi => {
3258 if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
3259 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3260 value += penalty_scale
3261 * per_atom.value(beta.slice(s![start..end]), rho_local);
3262 }
3263 } else {
3264 if !sae_penalty_is_row_block_supported(penalty) {
3265 return Err(ArrowSchurError::SchurFactorFailed {
3266 reason: format!(
3267 "validate_analytic_penalty_registry should have refused \
3268 non-row-block Psi-tier penalty {:?} (registry layout name \
3269 {name:?})",
3270 penalty.name()
3271 ),
3272 });
3273 }
3274 for atom_idx in 0..self.k_atoms() {
3275 let coord = &self.assignment.coords[atom_idx];
3276 if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3277 let corrected_kind =
3278 self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3279 value += corrected_kind.value(coord.as_flat().view(), rho_local);
3280 } else if sae_coord_penalty_is_origin_anchored_magnitude(penalty) {
3281 // Origin-anchored magnitude shrinkage (SCAD/MCP) is
3282 // restricted to the Euclidean axes; periodic axes have
3283 // no chart origin and would make this energy
3284 // period-discontinuous (issue #795). This must mirror
3285 // the gradient/curvature assembly in
3286 // `add_sae_coord_penalty` exactly.
3287 match sae_coord_penalty_euclidean_restriction(coord) {
3288 Some((_axes, compacted)) => {
3289 value += penalty.value(compacted.view(), rho_local);
3290 }
3291 None => {
3292 value += penalty.value(coord.as_flat().view(), rho_local);
3293 }
3294 }
3295 } else {
3296 value += penalty.value(coord.as_flat().view(), rho_local);
3297 }
3298 }
3299 }
3300 }
3301 PenaltyTier::Beta => {
3302 if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
3303 if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3304 value += penalty_scale * per_fit.value(beta.view(), rho_local);
3305 }
3306 } else if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
3307 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3308 if start < end {
3309 value += penalty_scale * per_atom.value(beta.view(), rho_local);
3310 }
3311 }
3312 } else {
3313 value += penalty_scale * penalty.value(beta.view(), rho_local);
3314 }
3315 }
3316 PenaltyTier::Rho => {}
3317 }
3318 }
3319 Ok(value)
3320 }
3321
3322 /// Energy of the decoder-block analytic penalties that have no native
3323 /// `SaeManifoldLoss` counterpart, evaluated at the current decoder `β` and
3324 /// the converged SAE state. These act on the per-atom decoder coefficient
3325 /// matrices: cross-atom decoder incoherence (#671), mechanism
3326 /// (feature-group) sparsity, and nuclear-norm embedding rank (#672). Each
3327 /// is injected with its live per-atom shape / co-activation before its
3328 /// value is taken, mirroring the assemble path.
3329 ///
3330 /// This is deliberately narrower than [`Self::analytic_penalty_value_total`]:
3331 /// it excludes the Psi-tier coordinate / assignment penalties (ARD,
3332 /// Isometry, ScadMcp, BlockOrthogonality, IBP/softmax assignment sparsity).
3333 /// The SAE already carries its own ARD (`loss.ard`) and assignment sparsity
3334 /// (`loss.assignment_sparsity`) energy, so adding the registry ARD /
3335 /// assignment value on top would double-count, and the gauge-only
3336 /// coordinate penalties are not part of the penalized deviance the
3337 /// REML/Laplace criterion scores. The decoder-block penalties, by contrast,
3338 /// are real penalized-energy terms with no `loss.*` representative: the
3339 /// inner solve minimizes them (they enter `gb`/`hbb`) but they were absent
3340 /// from the criterion scalar `v`. This restores that consistency so the
3341 /// ρ-sweep ranks the same objective the inner solve descends — the #671
3342 /// incoherence lever in particular now shapes model selection, not just the
3343 /// Newton step.
3344 ///
3345 /// NOTE: the coordinate-block penalties with no native `loss.*` twin
3346 /// (`ScadMcp`, `BlockOrthogonality`) carry the same residual inconsistency
3347 /// (scored in the line search via `penalized_objective_total`, absent from
3348 /// the REML scalar). They are left out here because they share a registry
3349 /// dispatch with the always-on `Isometry` gauge, whose inclusion in the
3350 /// topology-comparison criterion is a separate design question (#673:
3351 /// topology evidence is gauge-conditional). Folding the coord-tier energy in
3352 /// is tracked apart from this #671 decoder fix.
3353 pub fn analytic_decoder_penalty_value_total(
3354 &self,
3355 registry: &AnalyticPenaltyRegistry,
3356 ) -> Result<f64, ArrowSchurError> {
3357 // Resolve each penalty's rho slice exactly as `analytic_penalty_value_total`
3358 // does (registry-local rho at zeros), so a learnable decoder-penalty weight
3359 // is honoured rather than indexing into an empty view.
3360 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3361 let layout = registry.rho_layout();
3362 let beta = self.flatten_beta();
3363 let mut value = 0.0_f64;
3364 for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3365 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3366 match penalty {
3367 AnalyticPenaltyKind::DecoderIncoherence(base) => {
3368 if let Some(per_fit) = self.live_decoder_incoherence_penalty(base) {
3369 value += per_fit.value(beta.view(), rho_local);
3370 }
3371 }
3372 AnalyticPenaltyKind::MechanismSparsity(base) => {
3373 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
3374 if start < end {
3375 value += per_atom.value(beta.view(), rho_local);
3376 }
3377 }
3378 }
3379 AnalyticPenaltyKind::NuclearNorm(base) => {
3380 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
3381 value += per_atom.value(beta.slice(s![start..end]), rho_local);
3382 }
3383 }
3384 _ => {}
3385 }
3386 }
3387 Ok(value)
3388 }
3389
3390 /// Energy of the COORDINATE-tier isometry penalty(ies) at the converged
3391 /// SAE state. This is the per-atom `½μ Σ_n ‖J_n^T W_n J_n / gbar − g_ref‖²`
3392 /// summed over atoms, evaluated through `corrected_isometry_penalty` so the
3393 /// live decoder/coordinate caches drive the value exactly as the assemble
3394 /// path does. It has no `SaeManifoldLoss` twin (the loss carries only
3395 /// data-fit / assignment / smoothness / ARD), so the Laplace/REML criterion
3396 /// must add it explicitly to score the same penalized objective the inner
3397 /// solve descends.
3398 pub fn isometry_penalty_value_total(
3399 &self,
3400 registry: &AnalyticPenaltyRegistry,
3401 ) -> Result<f64, ArrowSchurError> {
3402 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
3403 let layout = registry.rho_layout();
3404 let mut value = 0.0_f64;
3405 for (penalty, (rho_slice, _tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
3406 if let AnalyticPenaltyKind::Isometry(iso) = penalty {
3407 let rho_local = rho_global.slice(s![rho_slice.clone()]);
3408 for atom_idx in 0..self.k_atoms() {
3409 let coord = &self.assignment.coords[atom_idx];
3410 let corrected_kind = self.corrected_isometry_penalty(iso, atom_idx, coord)?;
3411 value += corrected_kind.value(coord.as_flat().view(), rho_local);
3412 }
3413 }
3414 }
3415 Ok(value)
3416 }
3417
3418 /// Whether assembling `registry` will scatter an isometry Gauss-Newton
3419 /// cross-block (`H_tβ`) into the per-row dense `htbeta` slabs.
3420 ///
3421 /// `add_sae_isometry_metric_gn_blocks` writes the coupled cross-block (and
3422 /// flips on `activate_dense_htbeta_supplement`) only when (a) the registry
3423 /// carries an `Isometry` penalty and (b) the atom's chart
3424 /// `preserves_isometry_cross_block_coherence` (flat charts — `Euclidean`,
3425 /// `Circle`, and flat products — keep the full `μ AᵀA` coupling; curved /
3426 /// boundary charts drop it to stay PSD). On the non-frames matrix-free path
3427 /// the data-fit cross-block is carried by the Kronecker row operator and the
3428 /// per-row `htbeta` slab is allocated at zero width (#1406/#1407 anti-leak),
3429 /// so this dense isometry supplement has nowhere to land unless the slab is
3430 /// widened to the full `beta_dim`. This predicate decides exactly that. The
3431 /// effective isometry weight `μ` is NOT consulted here: a near-zero `μ`
3432 /// short-circuits the per-row write, but the slab must still exist so the
3433 /// solver's `htbeta_dense_supplement` read is well-shaped.
3434 pub(crate) fn registry_writes_dense_isometry_cross_block(
3435 &self,
3436 registry: &AnalyticPenaltyRegistry,
3437 ) -> bool {
3438 registry
3439 .penalties
3440 .iter()
3441 .any(|p| matches!(p, AnalyticPenaltyKind::Isometry(_)))
3442 && self
3443 .assignment
3444 .coords
3445 .iter()
3446 .any(|coord| coord.manifold().preserves_isometry_cross_block_coherence())
3447 }
3448
3449 /// Extra analytic-penalty energy that has no native `SaeManifoldLoss`
3450 /// component but is part of the penalized objective ranked by the SAE
3451 /// Laplace/REML criterion.
3452 pub fn reml_extra_penalty_value_total(
3453 &self,
3454 registry: &AnalyticPenaltyRegistry,
3455 ) -> Result<f64, ArrowSchurError> {
3456 Ok(self.analytic_decoder_penalty_value_total(registry)?
3457 + self.isometry_penalty_value_total(registry)?)
3458 }
3459
3460 pub fn penalized_objective_total(
3461 &self,
3462 target: ArrayView2<'_, f64>,
3463 rho: &SaeManifoldRho,
3464 registry: Option<&AnalyticPenaltyRegistry>,
3465 penalty_scale: f64,
3466 ) -> Result<f64, String> {
3467 let mut total = self.loss_scaled(target, rho, penalty_scale)?.total();
3468 if let Some(analytic_registry) = registry {
3469 total += self
3470 .analytic_penalty_value_total(analytic_registry, penalty_scale)
3471 .map_err(|err| format!("SaeManifoldTerm::penalized_objective_total: {err}"))?;
3472 }
3473 // #1026 — decoder-repulsion value, on the SAME frozen gate the assembly
3474 // used, so the line search sees the term the Newton step optimizes. 0
3475 // unless two atoms are near-collinear (the no-op case).
3476 total += self.decoder_repulsion_value(penalty_scale);
3477 // #1026/#1522 — interior-point collapse-prevention barriers, on the SAME
3478 // decoders the assembly's gradient/curvature used, so the line search sees
3479 // exactly the term the inner Newton step optimises (no value/grad desync).
3480 total += self.separation_barrier_value(penalty_scale);
3481 Ok(total)
3482 }
3483
3484 pub(crate) fn decoder_smoothness_value(&self, lambda_smooth: &[f64]) -> f64 {
3485 // Smoothness penalty value is `0.5·λ·Σ_oc B[:,oc]ᵀ S B[:,oc]`. Form the
3486 // `S·B` matrix product once per atom (O(M²·p)) and reduce against `B`
3487 // with a single O(M·p) Hadamard sum, instead of the previous
3488 // four-factor multiply-accumulate inside an `O(M²·p)` triple loop.
3489 // The quadratic form only sees the symmetric part of `S`, so reusing
3490 // the raw (un-symmetrised) `smooth_penalty` here is numerically
3491 // identical to the symmetrised assembly form.
3492 // Per-atom `S_k · B_k` products are independent across atoms, so they ride
3493 // the multi-GPU batched smoothness GEMM (uniform-shape groups tiled across
3494 // every device); `symmetrize = false` because the quadratic form only sees
3495 // the symmetric part of `S` regardless. Exact CPU fallback per atom.
3496 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3497 .atoms
3498 .iter()
3499 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3500 .collect();
3501 let sb_all = batched_smooth_sb(&sb_inputs, false);
3502 let mut acc = 0.0;
3503 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3504 acc += 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3505 }
3506 acc
3507 }
3508
3509 /// Per-atom decoder-smoothness values (#1556): entry `k` is
3510 /// `0.5·λ_smooth[k]·<B_k, S_k B_k>` (sum = [`Self::decoder_smoothness_value`]).
3511 /// This is the explicit `∂loss.smoothness/∂log λ_smooth[k]` gradient entry.
3512 pub(crate) fn decoder_smoothness_value_per_atom(&self, lambda_smooth: &[f64]) -> Vec<f64> {
3513 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3514 .atoms
3515 .iter()
3516 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3517 .collect();
3518 let sb_all = batched_smooth_sb(&sb_inputs, false);
3519 let mut per_atom = vec![0.0_f64; self.atoms.len()];
3520 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
3521 per_atom[atom_idx] =
3522 0.5 * lambda_smooth[atom_idx] * (&atom.decoder_coefficients * sb).sum();
3523 }
3524 per_atom
3525 }
3526
3527 pub(crate) fn ard_value(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
3528 if rho.log_ard.len() != self.k_atoms() {
3529 return Err(format!(
3530 "ARD rho has {} atoms but term has {}",
3531 rho.log_ard.len(),
3532 self.k_atoms()
3533 ));
3534 }
3535 let n = self.n_obs();
3536 let mut acc = 0.0;
3537 for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3538 let d = coord.latent_dim();
3539 if rho.log_ard[atom_idx].is_empty() {
3540 continue;
3541 }
3542 if rho.log_ard[atom_idx].len() != d {
3543 return Err(format!(
3544 "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
3545 rho.log_ard[atom_idx].len()
3546 ));
3547 }
3548 // Per-axis periodicity selects the smooth von-Mises energy on
3549 // wrapped (Circle) axes and the Gaussian on Euclidean axes.
3550 let periods = coord.effective_axis_periods();
3551 for axis in 0..d {
3552 let log_alpha = rho.log_ard[atom_idx][axis];
3553 // Clamp the log-precision before exponentiating: a raw
3554 // `exp(log_ard)` overflows to `inf` for `log_ard ≳ 709`, and the
3555 // `inf` precision then poisons the ARD energy / curvature with
3556 // `inf · 0.0 = NaN` (#742, Issue 4).
3557 let alpha = SaeManifoldRho::stable_exp_strength(log_alpha);
3558 let period = periods[axis];
3559 let mut energy = 0.0;
3560 for row in 0..n {
3561 let v = coord.row(row)[axis];
3562 energy += ArdAxisPrior::eval(alpha, v, period).value;
3563 }
3564 // Negative-log prior for precision alpha. The data-dependent
3565 // energy is the (Gaussian or von-Mises) coordinate prior; the
3566 // accompanying normaliser is the precision log-partition.
3567 //
3568 // Euclidean axes keep the Gaussian normaliser `-0.5 n log α`.
3569 // Periodic (von-Mises) axes use the EXACT von-Mises precision
3570 // log-partition `n[-η + log I0(η)]`, η = α/κ², κ = 2π/P, rather
3571 // than the Gaussian surrogate: the von-Mises partition function
3572 // is `2π I0(η)` (up to the κ Jacobian), so the per-observation
3573 // normaliser is `-η + log I0(η)` and is exact across the cut.
3574 match period {
3575 None => {
3576 acc += energy - 0.5 * (n as f64) * log_alpha;
3577 }
3578 Some(p) => {
3579 let kappa = std::f64::consts::TAU / p;
3580 let eta = alpha / (kappa * kappa);
3581 // Overflow-free `log I0(η)`; `bessel_i0(η).ln()` would be
3582 // `+inf` for `η ≳ 709` (#1113).
3583 let log_i0 = bessel_i0_log_and_ratio(eta).0;
3584 acc += energy + (n as f64) * (-eta + log_i0);
3585 }
3586 }
3587 }
3588 }
3589 Ok(acc)
3590 }
3591
3592 /// Assemble the enlarged `(logits, t)` row-local Arrow-Schur system.
3593 ///
3594 /// Full-batch entry point: a single chunk covering all rows, with the
3595 /// β-tier penalties (decoder smoothness, ARD, analytic β penalties) carrying
3596 /// their full strength. The streaming driver calls
3597 /// [`Self::assemble_arrow_schur_scaled`] directly with a `penalty_scale`
3598 /// equal to the minibatch fraction `n_chunk / N`, so that the sum of the
3599 /// per-chunk β-tier contributions over a full pass reconstructs exactly the
3600 /// single global β penalty (the smoothness/ARD/β terms are functions of `B`
3601 /// and the global coordinates, not of the chunk's rows).
3602 pub fn assemble_arrow_schur(
3603 &mut self,
3604 target: ArrayView2<'_, f64>,
3605 rho: &SaeManifoldRho,
3606 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3607 ) -> Result<ArrowSchurSystem, String> {
3608 self.assemble_arrow_schur_scaled(target, rho, analytic_penalties, 1.0)
3609 }
3610
3611 /// Assemble the row-local Arrow-Schur system with a `penalty_scale` applied
3612 /// to the β-tier (decoder smoothness, ARD prior, analytic β penalties).
3613 ///
3614 /// `penalty_scale == 1.0` recovers the full-batch assembly. The streaming
3615 /// driver passes the minibatch fraction `n_chunk / N` so that the β-tier
3616 /// reduced-Schur and gradient contributions of the chunks sum to exactly one
3617 /// global copy across a full pass (data-fit, assignment-prior, and per-row
3618 /// coord/logit analytic terms are *not* scaled — they are genuine per-row
3619 /// sums).
3620 pub fn assemble_arrow_schur_scaled(
3621 &mut self,
3622 target: ArrayView2<'_, f64>,
3623 rho: &SaeManifoldRho,
3624 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3625 penalty_scale: f64,
3626 ) -> Result<ArrowSchurSystem, String> {
3627 self.assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3628 target,
3629 rho,
3630 analytic_penalties,
3631 penalty_scale,
3632 SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM,
3633 )
3634 }
3635
3636 pub(crate) fn assemble_arrow_schur_scaled_with_beta_penalty_probe_threshold(
3637 &mut self,
3638 target: ArrayView2<'_, f64>,
3639 rho: &SaeManifoldRho,
3640 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3641 penalty_scale: f64,
3642 dense_beta_penalty_probe_max_dim: usize,
3643 ) -> Result<ArrowSchurSystem, String> {
3644 self.assemble_arrow_schur_inner(
3645 target,
3646 rho,
3647 analytic_penalties,
3648 penalty_scale,
3649 dense_beta_penalty_probe_max_dim,
3650 None,
3651 )
3652 }
3653
3654 /// Innermost assembly entry. `forced_layout` overrides the budget-derived
3655 /// active-set layout so a caller can pin the dense (`Forced(None)`) or a
3656 /// specific compact (`Forced(Some(layout))`) path — used by the
3657 /// compact-vs-dense Riemannian-geometry equality regression test to drive
3658 /// both layouts on identical data. `Computed` is the production path:
3659 /// the layout is derived from the assignment mode + `sparse_active_plan`.
3660 pub(crate) fn assemble_arrow_schur_inner(
3661 &mut self,
3662 target: ArrayView2<'_, f64>,
3663 rho: &SaeManifoldRho,
3664 analytic_penalties: Option<&AnalyticPenaltyRegistry>,
3665 penalty_scale: f64,
3666 dense_beta_penalty_probe_max_dim: usize,
3667 forced_layout: ForcedRowLayout,
3668 ) -> Result<ArrowSchurSystem, String> {
3669 if !(penalty_scale.is_finite() && penalty_scale > 0.0) {
3670 return Err(format!(
3671 "SaeManifoldTerm::assemble_arrow_schur_scaled: penalty_scale must be finite and positive; got {penalty_scale}"
3672 ));
3673 }
3674 if target.dim() != (self.n_obs(), self.output_dim()) {
3675 return Err(format!(
3676 "SaeManifoldTerm::assemble_arrow_schur: Z must be ({}, {}); got {:?}",
3677 self.n_obs(),
3678 self.output_dim(),
3679 target.dim()
3680 ));
3681 }
3682 if rho.log_ard.len() != self.k_atoms() {
3683 return Err(format!(
3684 "SaeManifoldTerm::assemble_arrow_schur: log_ard length {} != K {}",
3685 rho.log_ard.len(),
3686 self.k_atoms()
3687 ));
3688 }
3689 // `lambda_smooth` is indexed per-atom in the smoothness gradient/curvature
3690 // assembly (`lambda_smooth[atom_idx]`); a too-short vector (e.g. a growth
3691 // move that grew `k_atoms()` without extending ρ — #1556) would panic deep
3692 // in the assembly loop with an opaque index-out-of-bounds. Validate it here
3693 // alongside `log_ard` so the contract violation surfaces as a clear Err.
3694 if rho.log_lambda_smooth.len() != self.k_atoms() {
3695 return Err(format!(
3696 "SaeManifoldTerm::assemble_arrow_schur: log_lambda_smooth length {} != K {}",
3697 rho.log_lambda_smooth.len(),
3698 self.k_atoms()
3699 ));
3700 }
3701 for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
3702 let ard_len = rho.log_ard[atom_idx].len();
3703 let d = coord.latent_dim();
3704 if ard_len != 0 && ard_len != d {
3705 return Err(format!(
3706 "SaeManifoldTerm::assemble_arrow_schur: log_ard atom {atom_idx} \
3707 has len {ard_len}; expected 0 (disabled) or atom dim {d}"
3708 ));
3709 }
3710 }
3711 // Reparameterize each atom's roughness Gram into arc length at the
3712 // current decoder/coordinates (issue #673). This is the single
3713 // chokepoint for both the inner Newton assembly and the undamped
3714 // evidence factorization, so freezing the pullback-metric weight here
3715 // (lagged-diffusivity) keeps the smoothness value, gradient, Kronecker
3716 // Hessian, and REML log-det mutually consistent within each assembly
3717 // and makes the converged penalty — hence the topology evidence —
3718 // gauge-invariant. Constant-speed (periodic) atoms are unaffected.
3719 for atom in &mut self.atoms {
3720 atom.refresh_intrinsic_smooth_penalty();
3721 }
3722 // #1026 — freeze the decoder-repulsion collinearity gate at the SAME
3723 // assembly chokepoint as the smoothness Gram, so the repulsion's
3724 // gradient/curvature (assembled below) and its value (read by the
3725 // line-search `penalized_objective_total`) share one frozen gate.
3726 self.refresh_decoder_repulsion_gate();
3727 // #1625 — freeze the SEPARATION barrier's normalized-coactivation `q_jk`
3728 // at the same chokepoint. The barrier weights its decoder-shape repulsion
3729 // by the routing coactivation, but its gradient treats that weight as a
3730 // constant; recomputing it from the trial logits in the line-search value
3731 // desyncs value vs gradient in the logit block and stalls the inner solve
3732 // (#1625). Freezing it here makes value/gradient/curvature consistent.
3733 self.refresh_barrier_coactivation_gate();
3734 let n = self.n_obs();
3735 let p = self.output_dim();
3736 let k_atoms = self.k_atoms();
3737 let assignment_dim = self.assignment.assignment_coord_dim();
3738 let q = self.assignment.row_block_dim();
3739 let beta_dim = self.beta_dim();
3740 let frame_projection = FrameProjection::new(self);
3741 let beta_offsets = frame_projection.beta_offsets.clone();
3742 let coord_offsets = self.assignment.coord_offsets();
3743 // β-tier decoder smoothness is a global (B-only) penalty; under a
3744 // minibatch pass it is scaled by the chunk fraction so the per-chunk
3745 // contributions sum to one global copy.
3746 // Per-atom decoder-smoothness strengths (#1556): atom k's penalty `S_k`
3747 // is scaled by `λ_smooth[k]·penalty_scale`. The minibatch `penalty_scale`
3748 // multiplies every atom uniformly.
3749 let lambda_smooth: Vec<f64> = rho
3750 .lambda_smooth_vec()
3751 .iter()
3752 .map(|&l| l * penalty_scale)
3753 .collect();
3754 let (assignment_grad, assignment_hdiag) =
3755 assignment_prior_grad_hdiag(&self.assignment, rho)?;
3756
3757 // #1038 softmax entropy: the exact per-row Hessian in logits is dense
3758 // (`H_kj = (λ/τ²) a_k[δ_kj(m−L_k−1)+a_j(L_k+L_j+1−2m)]`), not just the
3759 // `assignment_hdiag` diagonal. Build the shared penalty + `scale = λ/τ²`
3760 // once here so the dense row block written into `block.htt` below, the
3761 // criterion's `log|H|`, and the #1006 θ-adjoint all differentiate the
3762 // SAME operator. JumpReLU / IBP keep their (separately exact) diagonal /
3763 // cross-row channels and leave this `None`. The block is gauge-null in
3764 // isolation (`H·𝟙 = 0`); it is only ever summed onto the gauge-breaking
3765 // data-fit row block before the Cholesky factor, never factored alone.
3766 let softmax_dense: Option<(
3767 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
3768 f64,
3769 )> = match self.assignment.mode {
3770 AssignmentMode::Softmax {
3771 temperature,
3772 sparsity,
3773 } if k_atoms > 1 => {
3774 let inv_tau = 1.0 / temperature;
3775 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
3776 Some((
3777 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
3778 k_atoms,
3779 temperature,
3780 ),
3781 scale,
3782 ))
3783 }
3784 _ => None,
3785 };
3786
3787 // Decoder smoothness penalty: build one KroneckerPenaltyOp per atom
3788 // (structure = λ·S_k ⊗ I_p, offset = beta_offsets[k]) instead of
3789 // materialising the dense K×K block. The gradient is a dense K-vector
3790 // accumulated into `smooth_grad_gb` and written into sys.gb after sys
3791 // is constructed (#296).
3792 let mut smooth_ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len());
3793 // #972 / #977 T1: retain each atom's symmetrised `λ S_k` (`M_k × M_k`) so
3794 // the frame transform can rebuild the smooth penalty in the factored
3795 // coordinate space as `λ S_k ⊗ I_{r_k}` (the `tr(C_kᵀ S_k C_k)` form,
3796 // using `U_kᵀU_k = I`). Unused — and not even read — on the full-`B`
3797 // path, so this is a zero-cost capture there.
3798 let mut smooth_scaled_s: Vec<Array2<f64>> = Vec::with_capacity(self.atoms.len());
3799 let mut smooth_grad_gb = vec![0.0_f64; beta_dim];
3800 // #1117 — rank deficiency is handled at the basis layer: any
3801 // rank-deficient atom was reparametrized onto its data-supported subspace
3802 // at fit entry (`reduce_atoms_to_data_supported_rank`), so the β-tier here
3803 // always sees a full-rank design and needs no step-time data-null
3804 // deflation operator. The well-conditioned (full-rank) path is unchanged.
3805 // Per-atom smoothness-gradient GEMMs `½(S_k+S_kᵀ)·B_k` are independent
3806 // across atoms; batch them across ALL GPUs (uniform-shape tiles) and
3807 // scale by `lambda_smooth` below. `symmetrize = true` reproduces the
3808 // per-atom symmetrised `scaled_s/λ` used by the Kronecker op. Exact CPU
3809 // fallback per atom keeps the result bit-for-bit with the all-CPU path.
3810 let sym_sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
3811 .atoms
3812 .iter()
3813 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
3814 .collect();
3815 let sym_sb_all = batched_smooth_sb(&sym_sb_inputs, true);
3816 for (atom_idx, atom) in self.atoms.iter().enumerate() {
3817 let m = atom.basis_size();
3818 let off = beta_offsets[atom_idx];
3819 // Symmetrise and scale the smoothness penalty matrix.
3820 let mut scaled_s = Array2::<f64>::zeros((m, m));
3821 for i in 0..m {
3822 for j in 0..m {
3823 let s_ij = 0.5 * (atom.smooth_penalty[[i, j]] + atom.smooth_penalty[[j, i]]);
3824 scaled_s[[i, j]] = lambda_smooth[atom_idx] * s_ij;
3825 }
3826 }
3827 // Gradient: g[beta_i] += (λ_k S_k B_k)[i, out_col]. The (m×m)·(m×p)
3828 // GEMM `½(S+Sᵀ)·B_k` was computed in the multi-GPU batch above; here
3829 // we only apply atom k's `lambda_smooth[atom_idx]`.
3830 let sb = &sym_sb_all[atom_idx] * lambda_smooth[atom_idx];
3831 for out_col in 0..p {
3832 for i in 0..m {
3833 let beta_i = off + i * p + out_col;
3834 smooth_grad_gb[beta_i] += sb[[i, out_col]];
3835 }
3836 }
3837 // IdentityRightKroneckerPenaltyOp: factor_a = λ·S_k (m×m), factor_b = I_p.
3838 smooth_ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
3839 factor_a: scaled_s.clone(),
3840 p,
3841 global_offset: off,
3842 k: beta_dim,
3843 }));
3844 // Retain `λ S_k` for the factored rebuild (no-op cost on full-`B`).
3845 smooth_scaled_s.push(scaled_s);
3846 }
3847
3848 // Per-row active-set layout. Engaged for two regimes:
3849 // * JumpReLU — structural gate plus the smooth prior's
3850 // machine-precision support: atoms with
3851 // `(logit - threshold)/tau > -36` enter the compact solve
3852 // ([`jumprelu_in_optimization_band`]). Strictly gated-off atoms
3853 // (logit ≤ threshold) carry zero assignment mass so their data-fit
3854 // reconstruction contribution and data-fit logit JVP are zero, but
3855 // supported atoms keep value-consistent prior gradient in the row block.
3856 // * IBP-MAP at large `K` — the dense `(m_total · p)²` data
3857 // Gram is infeasible, so each row is truncated to its
3858 // top-`k_active` atoms above a relative magnitude cutoff
3859 // ([`Self::sparse_active_plan`]). Small-`K` problems return `None`
3860 // and keep the exact full-support layout.
3861 // The compact row block is sized `q_active = |active| + Σ_{k∈active}
3862 // d_k` instead of the full `q`.
3863 let coord_dims: Vec<usize> = self
3864 .assignment
3865 .coords
3866 .iter()
3867 .map(|c| c.latent_dim())
3868 .collect();
3869 let row_layout: Option<SaeRowLayout> = match forced_layout {
3870 Some(layout) => layout,
3871 None => match self.assignment.mode {
3872 AssignmentMode::JumpReLU {
3873 threshold,
3874 temperature,
3875 } => Some(SaeRowLayout::from_jumprelu(
3876 n,
3877 k_atoms,
3878 threshold,
3879 temperature,
3880 &self.assignment.logits,
3881 coord_dims.clone(),
3882 self.assignment.coord_offsets(),
3883 )),
3884 // #1408/#1409 — Softmax engages the COMPACT top-`k` row layout
3885 // inside the optimization (no longer a post-fit projection).
3886 // The active set is each row's top-`k_active_cap` softmax atoms
3887 // above the relative cutoff; the cap comes from the user's
3888 // `top_k` (`softmax_active_cap`) and/or the in-core memory budget
3889 // ([`Self::softmax_active_plan`]). The full-`K` softmax
3890 // normalization still forms `a` (the gate map); only the dropped
3891 // tail logits, carrying negligible `O(a)` reconstruction mass and
3892 // `O(a²)` curvature, leave the per-row block.
3893 //
3894 // Coherence (the load-bearing correctness invariant): the
3895 // assembly's softmax curvature branch writes the ACTIVE×ACTIVE
3896 // principal sub-block of the Gershgorin Loewner majorizer
3897 // `D = diag(Σ_j|H_kj|)` (#1419; PSD and `D ⪰ H_entropy`) on the
3898 // compact logit slots — NOT the indefinite `assignment_hdiag`
3899 // diagonal. The logdet ρ-trace
3900 // (`assignment_log_strength_hessian_trace`) iterates the row's
3901 // active logit slots and indexes that SAME majorizer by global
3902 // atom, and the θ-adjoint reads its derivative via `jets.vars`
3903 // (global-atom indexed), so value, log|H|, and Γ differentiate
3904 // ONE operator on the compact support. The FFI's after-the-fit
3905 // top-`k` projection is then a no-op at the optimum.
3906 AssignmentMode::Softmax { .. } => match self.softmax_active_plan() {
3907 Some((k_active_cap, relative_cutoff)) => {
3908 let mut assignments_all = Vec::with_capacity(n);
3909 for row in 0..n {
3910 assignments_all
3911 .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3912 }
3913 Some(SaeRowLayout::from_dense_weights(
3914 &assignments_all,
3915 k_active_cap,
3916 relative_cutoff,
3917 coord_dims.clone(),
3918 self.assignment.coord_offsets(),
3919 ))
3920 }
3921 None => None,
3922 },
3923 AssignmentMode::IBPMap { .. } => {
3924 match self.sparse_active_plan() {
3925 Some((k_active_cap, relative_cutoff)) => {
3926 // Build per-row dense assignments once to derive the
3927 // active set; the row loop re-derives `assignments`
3928 // (cheap gate map at the same rho) and reuses these
3929 // active sets.
3930 let mut assignments_all = Vec::with_capacity(n);
3931 for row in 0..n {
3932 assignments_all
3933 .push(self.assignment.try_assignments_row_for_rho(row, rho)?);
3934 }
3935 // #1414: pass the RELATIVE cutoff through;
3936 // `from_dense_weights` applies it per row against that
3937 // row's own peak `max_k |a_{n,k}|`, matching the
3938 // documented `sparse_active_plan` contract. A single
3939 // global threshold (relative_cutoff · whole-dataset
3940 // peak) wrongly drops every atom of a uniformly-small
3941 // row when another row peaks high.
3942 Some(SaeRowLayout::from_dense_weights(
3943 &assignments_all,
3944 k_active_cap,
3945 relative_cutoff,
3946 coord_dims.clone(),
3947 self.assignment.coord_offsets(),
3948 ))
3949 }
3950 None => None,
3951 }
3952 }
3953 },
3954 };
3955 // #974 likelihood-whitening seam. The single per-row decision: when the
3956 // installed `RowMetric` is a genuinely estimated noise model
3957 // (`whitens_likelihood()` — only `WhitenedStructured`), the
3958 // reconstruction data-fit, its t-block Gauss-Newton row block, AND the
3959 // β-tier data-fit gradient are all assembled through the SAME per-row
3960 // metric `M_n = U_n U_nᵀ = Σ_n^{-1}`. There is exactly ONE construction
3961 // site (the `whiten_rows` closure below), so the value the line-search
3962 // sums and the gradient/Hessian the Newton step solves cannot drift apart
3963 // (the objective↔gradient-desync cure). For Euclidean / OutputFisher /
3964 // no-metric the closure is the identity and every downstream loop is
3965 // byte-identical to the historical isotropic path.
3966 let whitens_likelihood = self
3967 .row_metric
3968 .as_ref()
3969 .is_some_and(|metric| metric.whitens_likelihood());
3970 // #972 / #977 T1: engage the FACTORED Grassmann-coordinate β-tier when
3971 // any atom has an active decoder frame. The closed-form factorization
3972 // `Φᵀ(G ⊗ I_p)Φ = G ⊗ (U_iᵀU_j)` is EXACT only for the isotropic
3973 // likelihood; under an active whitening metric (`whitens_likelihood()`,
3974 // only `WhitenedStructured`) the per-row output factor would be
3975 // `U_iᵀ M_n U_j` and does NOT factor out of the basis Gram, so we fall
3976 // back to the full-`B` path there (frames + whitening is out of scope —
3977 // see #974). The common Euclidean / OutputFisher / no-metric case factors
3978 // cleanly. When `frames_engaged` is false, EVERY β-tier object below is
3979 // assembled bit-for-bit as the historical full-`B` path.
3980 let frames_engaged = self.any_frame_active() && !whitens_likelihood;
3981 // #1407: fixed-decoder mode skips the entire β decoder tier (G/gb/htbeta
3982 // operator/hbb/β-penalties); only per-row htt/gt are produced.
3983 let fixed_decoder = self.fixed_decoder_assembly;
3984 let admission_plan = self
3985 .streaming_plan()
3986 .admitted_or_error(self.n_obs(), self.output_dim(), self.k_atoms())
3987 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
3988 // #1407: fixed-decoder builds NO dense β-Hessian (hbb) — force the
3989 // empty-hbb system constructor so no `beta_dim × beta_dim` workspace is
3990 // taken (the early return skips `reclaim_border_hbb_workspace`).
3991 let dense_beta_curvature = !fixed_decoder
3992 && admission_plan.direct_admitted
3993 && !(frames_engaged && beta_dim > dense_beta_penalty_probe_max_dim);
3994 // #1406: the dense per-row cross-block slab `block.htbeta` is only WRITTEN
3995 // (line ~4243) and READ by the solver when `frames_engaged` (the factored
3996 // full-B path, which installs NO matrix-free row operator → the solver's
3997 // `sys_htbeta_apply_row` falls back to the dense slab). On the
3998 // `!frames_engaged` path the cross block is carried entirely by the
3999 // matrix-free Kronecker operator (`set_row_htbeta_operator`, ~line 4491);
4000 // `activate_dense_htbeta_supplement` is never called, so the solver never
4001 // touches `block.htbeta`. Allocating it at `beta_dim = K·M·p` there is the
4002 // ~6 TiB high-K leak (#1405/#1406): allocate ZERO columns instead. Frames
4003 // still use the (much smaller) factored border width.
4004 // #795/#1406/#1407: the non-frames matrix-free path normally holds a
4005 // ZERO-width per-row cross-block slab — the data-fit `H_tβ` is carried by
4006 // the Kronecker row operator (`set_row_htbeta_operator`), and allocating
4007 // the dense slab at `beta_dim = K·M·p` is the high-K memory leak. But an
4008 // ISOMETRY penalty on a coherence-preserving (flat) chart scatters an
4009 // ADDITIONAL Gauss-Newton cross-block into the dense per-row `htbeta`
4010 // slab and flips on `activate_dense_htbeta_supplement` — dropping it would
4011 // leave the Newton system block-diagonal and forfeit the strong `t↔B`
4012 // isometry coupling the circle fit needs to reach KKT stationarity (#795).
4013 // So on the non-frames path widen the slab to `beta_dim` exactly when that
4014 // dense supplement will be written, and keep zero width otherwise.
4015 let dense_isometry_cross_block = !fixed_decoder
4016 && analytic_penalties
4017 .map(|registry| self.registry_writes_dense_isometry_cross_block(registry))
4018 .unwrap_or(false);
4019 let row_htbeta_dim = if fixed_decoder {
4020 // Fixed-decoder mode skips the β tier entirely.
4021 0
4022 } else if frames_engaged {
4023 self.factored_border_dim()
4024 } else if dense_isometry_cross_block {
4025 // Matrix-free data-fit cross-block + dense isometry supplement: the
4026 // supplement is written/read in the full-`B` β coordinate system.
4027 beta_dim
4028 } else {
4029 // Matrix-free path with no dense cross-block supplement.
4030 0
4031 };
4032 // Build the Arrow-Schur system: heterogeneous row dims when a compact
4033 // layout is active, uniform `q` otherwise.
4034 let mut sys = if let Some(ref layout) = row_layout {
4035 let per_row_dims: Vec<usize> = (0..n).map(|row| layout.row_q_active(row)).collect();
4036 if dense_beta_curvature {
4037 let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4038 ArrowSchurSystem::new_with_per_row_dims_and_hbb_and_htbeta_cols(
4039 per_row_dims,
4040 beta_dim,
4041 hbb_workspace,
4042 row_htbeta_dim,
4043 )
4044 } else {
4045 self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4046 ArrowSchurSystem::new_with_per_row_dims_empty_hbb_and_htbeta_cols(
4047 per_row_dims,
4048 beta_dim,
4049 row_htbeta_dim,
4050 )
4051 }
4052 } else if dense_beta_curvature {
4053 let hbb_workspace = self.take_border_hbb_workspace(beta_dim);
4054 ArrowSchurSystem::new_with_hbb_and_htbeta_cols(
4055 n,
4056 q,
4057 beta_dim,
4058 hbb_workspace,
4059 row_htbeta_dim,
4060 )
4061 } else {
4062 self.border_hbb_workspace = Array2::<f64>::zeros((0, 0));
4063 ArrowSchurSystem::new_with_empty_hbb_and_htbeta_cols(n, q, beta_dim, row_htbeta_dim)
4064 };
4065 // Apply accumulated smoothness-penalty gradients into sys.gb.
4066 for (i, g) in smooth_grad_gb.iter().enumerate() {
4067 sys.gb[i] += g;
4068 }
4069 // `w_dim` is the whitened output dimension: `rank` of the metric factor
4070 // when whitening, else `p` (identity). `error_white` is the whitened
4071 // residual `U_nᵀ r_n ∈ ℝ^{w_dim}` whose squared norm is `r_nᵀ M_n r_n`,
4072 // shared by the value path, the t-block GN, and (lifted back to p-space)
4073 // the β-tier gradient.
4074 let w_dim = match self.row_metric.as_ref() {
4075 Some(metric) if whitens_likelihood => metric.metric_rank(),
4076 _ => p,
4077 };
4078 // Data-fit Gauss-Newton β-Hessian is block-diagonal across the `p`
4079 // output channels and identical in each: with the flat β layout
4080 // `β[μ·p + oc] = B[μ, oc]` (μ enumerating (atom, basis_col)) the GN
4081 // outer product `Jβᵀ Jβ` couples only equal `oc`, with the same
4082 // `(M_total × M_total)` block `G[μ, μ'] = Σ_rows (a_k φ_k[m])(a_{k'} φ_{k'}[m'])`
4083 // for every channel. So `H_data = G ⊗ I_p`. The `μ` index of an `a_phi`
4084 // entry whose global β base is `beta_base` is `beta_base / p` (every
4085 // `beta_offset` and the `basis_col·p` stride are multiples of `p`).
4086 //
4087 // `G` is only non-zero on `(atom_i, atom_j)` pairs that co-occur in
4088 // some row's active set, so we accumulate it as a sparse map of dense
4089 // per-atom-pair `(m_i × m_j)` blocks keyed by `(atom_i, atom_j)` rather
4090 // than as a dense `(m_total × m_total)` matrix. At `K = 100K` with
4091 // per-row active sets of size `k_active ≪ K`, only `O(N · k_active²)`
4092 // pairs are ever touched, so the data Gram (and every matvec /
4093 // diagonal pass over it via `SparseBlockKroneckerPenaltyOp`) tracks the
4094 // active atoms instead of `K²`. In the dense full-support layout the
4095 // map degenerates to every co-occurring pair, reproducing the dense
4096 // Gram exactly. A `BTreeMap` key order keeps the installed op's
4097 // fingerprint deterministic. The `μ`-space offset of atom `k` is
4098 // `beta_offsets[k] / p`.
4099 type SaeGBlocks = std::collections::BTreeMap<(usize, usize), Array2<f64>>;
4100 let m_total: usize = self.atoms.iter().map(|a| a.basis_size()).sum();
4101 let mu_offsets: Vec<usize> = beta_offsets.iter().map(|&off| off / p).collect();
4102 // Stick-breaking prior for IBP-MAP depends only on (k_atoms, alpha_eff)
4103 // which are constant across rows for the current rho; precompute once.
4104 let ibp_prior_vec = match self.assignment.mode {
4105 AssignmentMode::IBPMap { .. } => {
4106 let alpha = self
4107 .assignment
4108 .mode
4109 .resolved_ibp_alpha(rho)
4110 .ok_or_else(|| "IBP assignment alpha resolution failed".to_string())?;
4111 Some(ordered_geometric_shrinkage_prior(k_atoms, alpha).to_vec())
4112 }
4113 _ => None,
4114 };
4115 let ibp_prior_slice = ibp_prior_vec.as_deref();
4116 // #991 design honesty weights (mean-1 HT inclusion corrections); see
4117 // the seam comment at the per-row residual below.
4118 let row_loss_w = self.row_loss_weights.as_deref();
4119 // Dense full-support index `[0, k_atoms)`, used by the row loop when no
4120 // compact layout is engaged so the active-atom iteration is uniform.
4121 let all_atoms_index: Vec<usize> = (0..k_atoms).collect();
4122 // Per-atom per-axis periodicity, hoisted out of the row loop. Selects
4123 // the smooth von-Mises coordinate prior on wrapped (Circle) axes and
4124 // the Gaussian prior on Euclidean axes; see `ArdAxisPrior`.
4125 let ard_axis_periods: Vec<Vec<Option<f64>>> = self
4126 .assignment
4127 .coords
4128 .iter()
4129 .map(|coord| coord.effective_axis_periods())
4130 .collect();
4131 struct SaeAssemblyRow {
4132 pub(crate) row: usize,
4133 pub(crate) block: ArrowRowBlock,
4134 pub(crate) gb_delta: Vec<(usize, f64)>,
4135 pub(crate) g_blocks: SaeGBlocks,
4136 pub(crate) kron_a_phi: Option<Vec<(usize, f64)>>,
4137 pub(crate) kron_jac: Option<Vec<f64>>,
4138 }
4139
4140 // Per-row scratch reused across all rows a rayon worker processes
4141 // (#1017). The assembly closure is re-run every inner Newton iteration ×
4142 // every outer ρ evaluation; allocating these eight loop-invariant-sized
4143 // buffers (`k_atoms·p`, several `p`, one `q·max(w_dim,p)`) once per
4144 // worker via `map_init` — rather than once per (row × assembly) inside
4145 // the closure — removes the dominant small-allocation traffic the
4146 // eu-stack profile attributed to allocator/barrier spin at the SAE LLM
4147 // shape (p≈5120). Every buffer is fully filled (or `.fill(0.0)`'d) before
4148 // it is read each row, so reuse is bit-identical to the fresh-alloc path;
4149 // `gb_delta`/`g_blocks` are NOT scratch (they move into the returned
4150 // `SaeAssemblyRow`) and stay allocated per row.
4151 struct RowScratch {
4152 pub(crate) decoded: Array2<f64>,
4153 pub(crate) dg_buf: Vec<f64>,
4154 pub(crate) fitted: Array1<f64>,
4155 pub(crate) error: Array1<f64>,
4156 pub(crate) error_white: Vec<f64>,
4157 pub(crate) error_metric: Array1<f64>,
4158 pub(crate) jac_white: Vec<f64>,
4159 pub(crate) decoded_scratch: Vec<f64>,
4160 // #1557 — per-worker scratch for the row assignment vector (filled via
4161 // `_into`, not allocated per row); full `k_atoms`, global-atom indexed.
4162 pub(crate) assignments: Array1<f64>,
4163 }
4164 // #1410: size the per-worker scratch by the COMPACT row dimensions, not
4165 // full `K`/`q`. With a compact layout the assembly only ever touches each
4166 // row's active atoms (≤ `max_active`) and its compact tangent block
4167 // (≤ `max_q_row`); allocating `decoded` at `k_atoms·p` and `jac_white` at
4168 // `q·max(w_dim,p)` was the per-worker `O(K)` blow-up (≈11 GiB/worker at
4169 // K=100k, p=5120 — and `map_init` gives every Rayon worker its own copy).
4170 // Without a layout the dense path needs full `k_atoms`/`q`. `decoded` rows
4171 // are addressed by COMPACT SLOT in the compact branch below (the dense
4172 // branch keeps global-atom rows), so the row count is the max active set.
4173 //
4174 // #1410/#1408/#1409: SOFTMAX now ALSO takes the `Some(layout)` branch
4175 // whenever a `top_k` cap (`set_softmax_active_cap`) or an in-core memory
4176 // breach engages `softmax_active_plan` → `from_dense_weights`, so its
4177 // per-worker `decoded`/`jac_white` scratch is the COMPACT
4178 // `max_active`/`max_q_row` size too — no longer the full `(k_atoms·p)` /
4179 // `(q·max(w_dim,p))` blow-up. JumpReLU / IBP-MAP likewise pay only
4180 // `max_active`. The remaining `None` (full-`K`) branch is the UNCAPPED
4181 // softmax / no-budget-breach case, which genuinely assembles the dense
4182 // entropy block over all `K`; capping it (the compact contract) removes
4183 // the per-worker `O(K)` footprint entirely. (#1410: the residual per-row
4184 // `O(K)` softmax-majorizer scratch — a `row_logits` copy and the full-`K`
4185 // `d`/`H_entropy` blocks — is removed separately; see the active-only
4186 // `active_softmax_gershgorin_majorizer_entry` /
4187 // `softmax_dense_entropy_hessian_entry` helpers below.)
4188 let (decoded_rows, scratch_q) = match row_layout.as_ref() {
4189 Some(layout) => {
4190 let max_active = (0..n)
4191 .map(|r| layout.active_atoms[r].len())
4192 .max()
4193 .unwrap_or(0)
4194 .max(1);
4195 let max_q_row = (0..n)
4196 .map(|r| layout.row_q_active(r))
4197 .max()
4198 .unwrap_or(q)
4199 .max(1);
4200 (max_active, max_q_row)
4201 }
4202 None => (k_atoms, q),
4203 };
4204 use rayon::iter::{IntoParallelIterator, ParallelIterator};
4205 // #1033 large-n: fold the per-row assembly results in row-ordered CHUNKS
4206 // rather than collecting all `n` `SaeAssemblyRow`s at once. The previous
4207 // path materialized the FULL `Vec<SaeAssemblyRow>` (every row's htt/gt
4208 // block + per-row `g_blocks` + `kron_a_phi`/`kron_jac`) AND the fold
4209 // destinations simultaneously — a ~2× transient peak over the resident
4210 // system during the fold, the assembly-side OOM cliff at large `n`. By
4211 // collecting one chunk, folding it into `sys.rows`/`g_blocks`/`kron_*`,
4212 // and dropping the chunk's `Vec` before the next chunk, the transient
4213 // intermediate is bounded to `O(chunk_size)` while the resident output is
4214 // unchanged. The fold stays STRICTLY row-ascending (chunk `[c0..c1)` then
4215 // `[c1..c2)`, rows in order within each chunk), so every `+=` into
4216 // `sys.gb`, the `g_blocks` BTreeMap, and the `kron_*` pushes lands in the
4217 // identical order as the single-pass fold — bit-for-bit the same system.
4218 // Chunk width is the admission plan's `chunk_size` (the same value
4219 // `streaming_plan` sizes for the matrix-free window), floored so a tiny
4220 // plan still makes forward progress.
4221 let assembly_chunk_rows = self
4222 .assembly_chunk_override
4223 .unwrap_or(admission_plan.chunk_size)
4224 .clamp(1, n.max(1));
4225 let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4226 let mut kron_a_phi: Vec<Vec<(usize, f64)>> = Vec::with_capacity(n);
4227 let mut kron_jac: Vec<Vec<f64>> = Vec::with_capacity(n);
4228 let mut chunk_start = 0usize;
4229 while chunk_start < n {
4230 let chunk_end = (chunk_start + assembly_chunk_rows).min(n);
4231 let mut fold_offset_in_chunk = 0usize;
4232 let row_results: Vec<SaeAssemblyRow> = (chunk_start..chunk_end)
4233 .into_par_iter()
4234 .map_init(
4235 || RowScratch {
4236 decoded: Array2::<f64>::zeros((decoded_rows, p)),
4237 dg_buf: vec![0.0_f64; p],
4238 fitted: Array1::<f64>::zeros(p),
4239 error: Array1::<f64>::zeros(p),
4240 error_white: vec![0.0_f64; w_dim],
4241 error_metric: Array1::<f64>::zeros(p),
4242 jac_white: vec![0.0_f64; scratch_q * w_dim.max(p)],
4243 decoded_scratch: vec![0.0_f64; p],
4244 assignments: Array1::<f64>::zeros(k_atoms),
4245 },
4246 |scratch, row| -> Result<SaeAssemblyRow, String> {
4247 // #1557 — mark this rayon row worker as a nested data-parallel
4248 // region so any faer GEMM reached transitively from the per-row
4249 // assembly (frame `Uᵀ` products, the per-row cross-block /
4250 // Schur-accumulation matmuls, the Riemannian projections) pins to
4251 // `Par::Seq` via `effective_global_parallelism` instead of
4252 // re-fanning the global Rayon pool against this outer fan-out
4253 // (the `spindle` barrier-spin). Serial vs parallel over these tiny
4254 // per-row blocks is a single small product, so the result is
4255 // bit-identical. The guard is held for the whole closure body
4256 // including its `?`/`return` paths.
4257 with_nested_parallel(|| {
4258 let RowScratch {
4259 decoded,
4260 dg_buf,
4261 fitted,
4262 error,
4263 error_white,
4264 error_metric,
4265 jac_white,
4266 decoded_scratch,
4267 assignments,
4268 } = scratch;
4269 let mut gb_delta: Vec<(usize, f64)> = Vec::new();
4270 let mut g_blocks: SaeGBlocks = std::collections::BTreeMap::new();
4271 // #1557 — fill per-worker scratch (bit-identical to alloc path).
4272 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
4273 self.assignment
4274 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
4275 // Reconstruction uses the row's active support: for the dense
4276 // full-support layout this is all atoms (exact); for a compact
4277 // layout the dropped atoms carry negligible `O(a)` reconstruction
4278 // mass and zero curvature, so excluding them keeps `fitted`,
4279 // `error`, and the logit-JVP cross term `(decoded[k] − fitted)`
4280 // mutually consistent with the curvature actually assembled.
4281 fitted.fill(0.0);
4282 let row_active_owned: Option<&[usize]> =
4283 row_layout.as_ref().map(|l| l.active_atoms[row].as_slice());
4284 match row_active_owned {
4285 Some(active) => {
4286 // #1410: `decoded` is a compact (max_active × p) buffer
4287 // here; index it by the active-set SLOT `j` (the same
4288 // index the compact tangent block / `coord_starts` use),
4289 // NOT the global `atom_idx`.
4290 for (j, &atom_idx) in active.iter().enumerate() {
4291 let a_k = assignments[atom_idx];
4292 self.atoms[atom_idx]
4293 .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4294 for out_col in 0..p {
4295 decoded[[j, out_col]] = decoded_scratch[out_col];
4296 fitted[out_col] += a_k * decoded_scratch[out_col];
4297 }
4298 }
4299 }
4300 None => {
4301 for atom_idx in 0..k_atoms {
4302 let a_k = assignments[atom_idx];
4303 self.atoms[atom_idx]
4304 .fill_decoded_row(row, decoded_scratch.as_mut_slice());
4305 for out_col in 0..p {
4306 decoded[[atom_idx, out_col]] = decoded_scratch[out_col];
4307 fitted[out_col] += a_k * decoded_scratch[out_col];
4308 }
4309 }
4310 }
4311 }
4312 for out_col in 0..p {
4313 error[out_col] = fitted[out_col] - target[[row, out_col]];
4314 }
4315 // #991 design-honesty seam: a per-row scalar weight `w_row` on the
4316 // reconstruction channel is exactly the metric `w_row · I_p`, so it
4317 // is realized as a `√w_row` scaling of the THREE row-local data
4318 // quantities at their construction sites — this residual, the
4319 // latent Jacobian (below), and the β basis load `a·φ` (below).
4320 // Every downstream data object then carries exactly one factor of
4321 // `w_row` (gt, htt, htbeta, the β Gram `G`, and the β gradient),
4322 // matching the `w_row`-weighted value `loss_scaled` sums; the
4323 // per-row latent priors (assignment / ARD, added to `gt`/`htt`
4324 // further down) are deliberately unweighted — see the
4325 // `row_loss_weights` field docs. `None` ⇒ `sqrt_row_w == 1.0` and
4326 // no multiply is applied (bit-identical unweighted path).
4327 let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
4328 if sqrt_row_w != 1.0 {
4329 for out_col in 0..p {
4330 error[out_col] *= sqrt_row_w;
4331 }
4332 }
4333 // #974 seam (step 1/2): whiten the per-row residual ONCE.
4334 // * not whitening ⇒ `error_white == error` (length p) and
4335 // `error_metric == error`; every downstream loop is the
4336 // historical isotropic path bit-for-bit.
4337 // * whitening ⇒ `error_white = U_nᵀ r_n ∈ ℝ^{w_dim}` (its squared
4338 // norm is `r_nᵀ M_n r_n`, the value the data-fit sums) and
4339 // `error_metric = U_n (U_nᵀ r_n) = M_n r_n ∈ ℝ^p` (the p-space
4340 // metric-applied residual the β-tier gradient contracts).
4341 match self.row_metric.as_ref() {
4342 Some(metric) if whitens_likelihood => {
4343 let wr = metric.whiten_residual_row(row, error.view());
4344 for (slot, &v) in error_white.iter_mut().zip(wr.iter()) {
4345 *slot = v;
4346 }
4347 let mr = metric.apply_metric_row(row, error.view());
4348 for (slot, &v) in error_metric.iter_mut().zip(mr.iter()) {
4349 *slot = v;
4350 }
4351 }
4352 _ => {
4353 for out_col in 0..p {
4354 error_white[out_col] = error[out_col];
4355 error_metric[out_col] = error[out_col];
4356 }
4357 }
4358 }
4359
4360 // Determine whether this row uses the compact active-set layout.
4361 // * JumpReLU: gated atoms plus the smooth prior's
4362 // machine-precision support enter.
4363 // * IBP-MAP at large K: only the top-`k_active` atoms.
4364 // * Otherwise (small K): the dense uniform-q layout.
4365 let (q_row, mut local_jac_row) = if let Some(layout) = row_layout.as_ref() {
4366 let active = &layout.active_atoms[row];
4367 let starts = &layout.coord_starts[row];
4368 let q_active = layout.row_q_active(row);
4369 let mut jac_compact = Array2::<f64>::zeros((q_active, p));
4370 // Logit JVP rows for active atoms only, using the per-mode
4371 // assignment sensitivity `da_k/dl_k` contracted into the
4372 // decoded / fitted-corrected output direction.
4373 let logits_row = self.assignment.logits.row(row);
4374 for (j, &k) in active.iter().enumerate() {
4375 fill_active_atom_logit_jvp(
4376 ActiveAtomLogitJvp {
4377 mode: self.assignment.mode,
4378 k,
4379 logit_k: logits_row[k],
4380 a_k: assignments[k],
4381 // #1410: compact slot `j`, not global atom `k`.
4382 decoded_k: decoded.row(j),
4383 fitted: fitted.view(),
4384 ibp_prior: ibp_prior_slice,
4385 compact_index: j,
4386 // #1026/#1033: a FIXED logit (ungated, or every
4387 // atom under frozen routing) has a constant gate
4388 // ⇒ zero logit-JVP.
4389 ungated: self.assignment.logit_is_fixed(k),
4390 },
4391 &mut jac_compact,
4392 );
4393 }
4394 // Coordinate JVP rows for active atoms only.
4395 for (j, &k) in active.iter().enumerate() {
4396 let d = self.atoms[k].latent_dim;
4397 let a_k = assignments[k];
4398 let coord_start = starts[j];
4399 for axis in 0..d {
4400 self.atoms[k].fill_decoded_derivative_row(
4401 row,
4402 axis,
4403 dg_buf.as_mut_slice(),
4404 );
4405 for out_col in 0..p {
4406 jac_compact[[coord_start + axis, out_col]] =
4407 a_k * dg_buf[out_col];
4408 }
4409 }
4410 }
4411 (q_active, jac_compact)
4412 } else {
4413 // Fresh per-row Jacobian, structurally identical to the
4414 // JumpReLU branch: every (q × p) element is unconditionally
4415 // overwritten below (assignment-chart JVP rows + coordinate rows), so the
4416 // `Array2::zeros` allocation needs no separate `fill(0.0)` and
4417 // the populated buffer is returned by move without a clone.
4418 let mut jac_row = Array2::<f64>::zeros((q, p));
4419 fill_assignment_logit_jvp_rows(
4420 self.assignment.mode,
4421 self.assignment.logits.row(row),
4422 assignments.view(),
4423 decoded.view(),
4424 fitted.view(),
4425 ibp_prior_slice,
4426 // #1026/#1033: zero logit-JVP rows for FIXED-logit atoms
4427 // (ungated, and all atoms under frozen routing).
4428 &self.assignment.fixed_logit_mask(),
4429 &mut jac_row,
4430 );
4431 // Coordinate columns for all atoms.
4432 for atom_idx in 0..k_atoms {
4433 let d = self.atoms[atom_idx].latent_dim;
4434 let off = coord_offsets[atom_idx];
4435 let a_k = assignments[atom_idx];
4436 for axis in 0..d {
4437 self.atoms[atom_idx].fill_decoded_derivative_row(
4438 row,
4439 axis,
4440 dg_buf.as_mut_slice(),
4441 );
4442 for out_col in 0..p {
4443 jac_row[[off + axis, out_col]] = a_k * dg_buf[out_col];
4444 }
4445 }
4446 }
4447 (q, jac_row)
4448 };
4449
4450 // #991 design-honesty seam, Jacobian leg: scale the row's latent
4451 // Jacobian by `√w_row` BEFORE the whitening / Kronecker capture so
4452 // htt (= J̃J̃ᵀ), the data part of gt (= J̃ẽ, the residual already
4453 // carries its own √w_row), and the htbeta cross block (J paired
4454 // with the √w_row-scaled β load below) each carry exactly one
4455 // factor of `w_row`. No-op on the unweighted path.
4456 if sqrt_row_w != 1.0 {
4457 for a in 0..q_row {
4458 for out_col in 0..p {
4459 local_jac_row[[a, out_col]] *= sqrt_row_w;
4460 }
4461 }
4462 }
4463
4464 // #974 seam (step 2/2): whiten the per-row Jacobian through the SAME
4465 // metric the residual was whitened by. `jac_white[a*w_dim + k]` holds
4466 // `J̃[a, k] = Σ_out U_n[out, k] · J_n[a, out]` so the t-block
4467 // Gauss-Newton row block is `htt = J̃ J̃ᵀ = J_n M_n J_nᵀ` and
4468 // `gt = J̃ ẽ = J_nᵀ M_n r_n`. When not whitening, `w_dim == p` and the
4469 // whitened jac equals the raw Jacobian, so htt/gt are byte-identical
4470 // to the historical isotropic assembly. Because the SAME `error_white`
4471 // feeds both the value-path data-fit (Σ½ ẽ²) and this gradient
4472 // (J̃ ẽ), the objective and its t-block gradient share one whitening
4473 // — they cannot desync.
4474 if whitens_likelihood {
4475 if let Some(metric) = self.row_metric.as_ref() {
4476 for a in 0..q_row {
4477 for k in 0..w_dim {
4478 let mut acc = 0.0;
4479 // U_n[out, k] read through the metric's factor layout.
4480 for out_col in 0..p {
4481 acc += metric.factor_entry(row, out_col, k)
4482 * local_jac_row[[a, out_col]];
4483 }
4484 jac_white[a * w_dim + k] = acc;
4485 }
4486 }
4487 }
4488 } else {
4489 for a in 0..q_row {
4490 for out_col in 0..p {
4491 jac_white[a * w_dim + out_col] = local_jac_row[[a, out_col]];
4492 }
4493 }
4494 }
4495
4496 // Build the per-row Arrow-Schur block at the row's active dim.
4497 let mut block = ArrowRowBlock::new(q_row, row_htbeta_dim);
4498 for a in 0..q_row {
4499 let jac_a = &jac_white[a * w_dim..(a + 1) * w_dim];
4500 let g = jac_a
4501 .iter()
4502 .zip(error_white.iter())
4503 .map(|(&j, &e)| j * e)
4504 .sum::<f64>();
4505 block.gt[a] += g;
4506 for b in 0..q_row {
4507 let jac_b = &jac_white[b * w_dim..(b + 1) * w_dim];
4508 let h = jac_a
4509 .iter()
4510 .zip(jac_b.iter())
4511 .map(|(&ja, &jb)| ja * jb)
4512 .sum::<f64>();
4513 block.htt[[a, b]] += h;
4514 }
4515 }
4516
4517 // Assignment prior in logit space.
4518 // For compact layout: position `j` = active_atoms index.
4519 // For dense layout: position `atom_idx` directly.
4520 //
4521 // H-consistency note (#1006 audit). This `assignment_hdiag` is the
4522 // assignment channel's raw diagonal curvature, added un-majorized. It
4523 // is exact for JumpReLU and exact within each IBP row/column diagonal,
4524 // but it is a deliberate diagonal approximation for two full-Hessian
4525 // structures that the current factorization does not yet carry (#1038):
4526 //
4527 // * softmax entropy has dense within-row Hessian
4528 // H_kj = (λ/τ²) a_k[δ_kj(m-L_k-1) + a_j(L_k+L_j+1-2m)];
4529 // this block stores only its diagonal.
4530 // * IBP empirical-π has cross-row rank-one terms per column
4531 // H_(i,k),(j,k) = w score_derivative_k z'_ik z'_jk for i != j;
4532 // this row-local block stores only the diagonal/self-row part.
4533 // The exact scalar `D`-coefficient `d_k = w·s'_k` is now
4534 // surfaced as `IbpHessianDiagThirdChannels::cross_row_d`
4535 // (FD-verified against ∂²value/∂ℓ_ik∂ℓ_jk in
4536 // `ibp_cross_row_woodbury_d_matches_full_off_diagonal_hessian`),
4537 // and `z_jac` carries `u_k`'s entries `z'_ik`. The exact
4538 // determinant-lemma consumer is
4539 // log det(I_K + D UᵀH₀'⁻¹U) on the NO-SELF base
4540 // H₀' = H₀ − Σ_k d_k diag(z'_ik²) — which requires re-factoring
4541 // the per-row logit-slot diagonal (a factorization-side change
4542 // in `solver::arrow_schur`, outside this assembly chokepoint).
4543 //
4544 // The criterion's log|H| and Γ adjoint differentiate this same
4545 // assembled diagonal/quasi-Laplace Hessian, so value and gradient stay
4546 // on one branch. A future dense-row softmax or IBP Woodbury correction
4547 // must update both assembly and the θ-adjoint together.
4548 let assignment_base = row * k_atoms;
4549 if let Some(layout) = row_layout.as_ref() {
4550 let active = &layout.active_atoms[row];
4551 // #1408/#1409 softmax compact curvature: the entropy
4552 // Hessian diagonal in `assignment_hdiag` is INDEFINITE,
4553 // so on a compact softmax layout write the Gershgorin
4554 // Loewner majorizer `D_kk = Σ_j|H_kj|` (#1419) — the same
4555 // PSD operator the dense softmax branch writes — at each
4556 // active logit slot. `D` is diagonal, so its active
4557 // principal sub-block is `diag(D_kk : k ∈ active)`; each
4558 // `D_kk` is the FULL-`K` abs-row-sum, so it still
4559 // dominates the active principal sub-block of `H_entropy`
4560 // (a genuine majorizer on the retained support). The
4561 // gradient stays the EXACT entropy gradient (it sets the
4562 // fixed point), so majorizing only conditions the Newton
4563 // step. JumpReLU/IBP keep their (exact) diagonal.
4564 //
4565 // #1410: compute only the active `D_kk` directly from this
4566 // row's softmax assignments `a` (= `assignments`, already
4567 // in hand), via `active_softmax_gershgorin_majorizer_entry`.
4568 // The previous `psd_majorizer_abs_row_sums(&row_logits, ..)`
4569 // call allocated TWO length-`K` per-row scratch vectors (a
4570 // fresh `row_logits` copy and the full-`K` returned `d`)
4571 // only to read `d[k]` for the `≤ top_k` active `k` — an
4572 // `O(K)` per-row allocation on the path the compact
4573 // contract keeps `K`-free. The shared `m = Σ_j a_j l_j` is
4574 // the one irreducible `O(K)` pass, computed once per row.
4575 let assignments_slice = assignments
4576 .as_slice()
4577 .expect("softmax assignments row must be contiguous");
4578 let majorizer_log_mean: Option<f64> = softmax_dense
4579 .as_ref()
4580 .map(|_| softmax_majorizer_log_mean(assignments_slice));
4581 for (j, &k) in active.iter().enumerate() {
4582 block.gt[j] += assignment_grad[assignment_base + k];
4583 match (softmax_dense.as_ref(), majorizer_log_mean) {
4584 (Some((_penalty, scale)), Some(m)) => {
4585 block.htt[[j, j]] +=
4586 active_softmax_gershgorin_majorizer_entry(
4587 assignments_slice,
4588 k,
4589 m,
4590 *scale,
4591 );
4592 }
4593 _ => block.htt[[j, j]] += assignment_hdiag[assignment_base + k],
4594 }
4595 }
4596 } else {
4597 for free_idx in 0..assignment_dim {
4598 block.gt[free_idx] += assignment_grad[assignment_base + free_idx];
4599 }
4600 if let Some((penalty, scale)) = softmax_dense.as_ref() {
4601 // #1419: write the genuine Gershgorin Loewner majorizer
4602 // `D = diag(Σ_j|H_kj|)` of the exact entropy Hessian onto the
4603 // row's logit block in place of the EXACT entropy Hessian. The
4604 // entropy Hessian is INDEFINITE (concave directions on
4605 // long-tailed rows), which drove the per-row evidence block
4606 // non-PD and forced the downstream Faddeev–Popov deflation to
4607 // flatten data-relevant logit directions (under-identifying the
4608 // atoms). `D` is a nonnegative diagonal, hence exactly PSD and
4609 // PD-preserving like the previous Fisher surrogate, so the block
4610 // stays PD and the deflation no longer fires on the entropy
4611 // block. Unlike the Fisher metric `G = scale·(diag(a) − a aᵀ)`,
4612 // which is PSD but NOT a majorizer (`G − H_entropy` can be
4613 // indefinite — K=2, a=(0.95,0.05): G₁₁=0.0475 < H₁₁=0.0784,
4614 // #1419), `D` actually satisfies `D ⪰ H_entropy` and `D ⪰ 0`,
4615 // so it is a true MM/Loewner curvature majorizer. Because the
4616 // entropy penalty is a FIXED prior whose stationary point is set
4617 // by its (unchanged) EXACT gradient, replacing its curvature
4618 // with the majorizer only conditions the Newton step and the
4619 // Laplace normalizer's curvature operator — it does NOT move the
4620 // optimum.
4621 //
4622 // Softmax uses the REDUCED K−1 free-logit chart (the last
4623 // reference logit is fixed at 0, `assignment_coord_dim() = K−1`).
4624 // Holding z_{K-1} fixed, the reduced curvature over the free
4625 // logits 0..K−1 is exactly the top-left (K−1)×(K−1) submatrix of
4626 // the full K×K majorizer (the fixed logit contributes no
4627 // row/column to the free curvature). The criterion's `log|H|`
4628 // and the #1006 θ-adjoint differentiate this SAME `D` (see the
4629 // `row_psd_majorizer_logit_derivative` site below), so value and
4630 // adjoint stay on one exact branch.
4631 let row_logits: Vec<f64> = (0..k_atoms)
4632 .map(|k| self.assignment.logits[[row, k]])
4633 .collect();
4634 let h_dense = penalty.row_psd_majorizer(&row_logits, *scale);
4635 for ki in 0..assignment_dim {
4636 for kj in 0..assignment_dim {
4637 block.htt[[ki, kj]] += h_dense[[ki, kj]];
4638 }
4639 }
4640 } else {
4641 for free_idx in 0..assignment_dim {
4642 block.htt[[free_idx, free_idx]] +=
4643 assignment_hdiag[assignment_base + free_idx];
4644 }
4645 }
4646 }
4647
4648 // ARD on each on-atom coordinate.
4649 // For compact layout: only active atoms; coord positions use compact starts.
4650 // For dense layout: all atoms; coord positions use coord_offsets.
4651 if let Some(layout) = row_layout.as_ref() {
4652 let active = &layout.active_atoms[row];
4653 let starts = &layout.coord_starts[row];
4654 for (j, &k) in active.iter().enumerate() {
4655 let coord = &self.assignment.coords[k];
4656 let d = coord.latent_dim();
4657 if rho.log_ard[k].is_empty() {
4658 continue;
4659 }
4660 if rho.log_ard[k].len() != d {
4661 return Err(format!(
4662 "ARD rho atom {k} has len {} but atom dim is {d}",
4663 rho.log_ard[k].len()
4664 ));
4665 }
4666 let row_t = coord.row(row);
4667 let periods = &ard_axis_periods[k];
4668 for axis in 0..d {
4669 // ARD on coords is a genuine per-row prior (each row
4670 // contributes the per-axis prior energy), so it is NOT
4671 // minibatch-scaled — the per-chunk row sums already
4672 // reconstruct the full coordinate prior across a pass.
4673 // The value (`ard_value`/`loss.ard`) and the gradient
4674 // both come from the SAME `ArdAxisPrior` energy, so they
4675 // stay FD-consistent on periodic axes. The exact
4676 // von-Mises curvature `V'' = α·cos(κt)` is INDEFINITE —
4677 // it goes negative for |t| past a quarter period — so
4678 // writing it raw into the Newton/Schur `htt` diagonal
4679 // makes that PSD curvature block indefinite and the Schur
4680 // Cholesky (used both for the Newton step and the exact
4681 // log-det) fails on a non-PD pivot. Accumulate the PSD
4682 // majorizer `max(V'', 0)` instead, exactly as
4683 // `add_sae_coord_penalty` does for the registry coord
4684 // penalties: the positive part keeps `htt` PSD so the
4685 // factorization succeeds, and majorizing the curvature of
4686 // a fixed prior only damps the Newton step — it does not
4687 // move the stationary point (the gradient, which sets the
4688 // fixed point, stays the exact `V'`).
4689 let alpha =
4690 SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
4691 let prior =
4692 ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4693 block.gt[starts[j] + axis] += prior.grad;
4694 block.htt[[starts[j] + axis, starts[j] + axis]] +=
4695 prior.hess.max(0.0);
4696 }
4697 }
4698 } else {
4699 for atom_idx in 0..k_atoms {
4700 let coord = &self.assignment.coords[atom_idx];
4701 let d = coord.latent_dim();
4702 if rho.log_ard[atom_idx].is_empty() {
4703 continue;
4704 }
4705 if rho.log_ard[atom_idx].len() != d {
4706 return Err(format!(
4707 "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
4708 rho.log_ard[atom_idx].len()
4709 ));
4710 }
4711 let off = coord_offsets[atom_idx];
4712 let row_t = coord.row(row);
4713 let periods = &ard_axis_periods[atom_idx];
4714 for axis in 0..d {
4715 // PSD-majorize the (possibly negative) von-Mises curvature
4716 // into the Newton/Schur `htt` block; see the compact-layout
4717 // branch above for why `max(V'', 0)` is required to keep
4718 // `htt` PD (the exact `V'' = α·cos κt` is indefinite past a
4719 // quarter period and breaks the Schur/log-det Cholesky).
4720 let alpha = SaeManifoldRho::stable_exp_strength(
4721 rho.log_ard[atom_idx][axis],
4722 );
4723 let prior =
4724 ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
4725 block.gt[off + axis] += prior.grad;
4726 block.htt[[off + axis, off + axis]] += prior.hess.max(0.0);
4727 }
4728 }
4729 }
4730
4731 // Beta gradient/Hessian — Kronecker form J_β = φᵀ ⊗ I_p.
4732 //
4733 // The per-row beta Jacobian is
4734 // J_β[out_col, beta_idx] = a_k · phi_k[basis_col] if out_col == out_col(beta_idx)
4735 // 0 otherwise
4736 // so the data-fit Gauss-Newton beta-Hessian factors as a rank-`p`
4737 // sum of outer products. We pre-compute the per-(atom, basis_col)
4738 // scalar `a_k · phi_k` once and reuse it across the `out_col`
4739 // and inner `(atom_j, basis_col2)` loops.
4740 //
4741 // Full-B rows keep the matrix-free Kronecker path below. Factored
4742 // rows write the `q_i × Σ M_k r_k` C-space cross slab directly by
4743 // folding each output-channel contribution through the atom frame,
4744 // so no `q_i × β_dim` slab is ever materialized.
4745 //
4746 // Only the row's active atoms contribute `a_phi` support and data
4747 // curvature: in a compact layout (JumpReLU gate or large-K
4748 // top-`k_active` truncation) the inactive atoms carry zero (gated)
4749 // or sub-cutoff assignment mass and are excluded — this is what
4750 // keeps both the htbeta support and the `G` accumulation
4751 // `O(k_active)` rather than `O(K)`. In the dense full-support
4752 // layout `row_active` spans all atoms.
4753 let row_active: &[usize] = match row_layout.as_ref() {
4754 Some(layout) => layout.active_atoms[row].as_slice(),
4755 None => &all_atoms_index,
4756 };
4757 // #1407: in fixed-decoder mode the β tier is not assembled at
4758 // all — leave gb_delta/g_blocks empty and kron None. htt/gt
4759 // (built above) are the only outputs the frozen-decoder step
4760 // consumes.
4761 let mut a_phi: Vec<(usize, f64)> = Vec::with_capacity(row_active.len() * 4);
4762 // Per-active-atom weighted basis row `a_k · φ_k[·]`, retained so the
4763 // data Gram blocks can be accumulated as clean per-atom-pair outer
4764 // products `(a_k φ_k) (a_{k'} φ_{k'})ᵀ`.
4765 let mut weighted_phi: Vec<(usize, Vec<f64>)> =
4766 Vec::with_capacity(row_active.len());
4767 if !fixed_decoder {
4768 for &atom_idx in row_active {
4769 let atom = &self.atoms[atom_idx];
4770 let atom_beta_off = beta_offsets[atom_idx];
4771 let m = atom.basis_size();
4772 let a_k = assignments[atom_idx];
4773 let mut wphi = Vec::with_capacity(m);
4774 for basis_col in 0..m {
4775 let phi = atom.basis_values[[row, basis_col]];
4776 // #991 design-honesty seam, β leg: the `√w_row` here pairs
4777 // with the `√w_row` on the residual (β gradient =
4778 // `a·φ · M r` ⇒ w_row) and with itself (β Gram `G` and the
4779 // htbeta Kronecker capture ⇒ w_row). `1.0` when unweighted.
4780 let w = a_k * phi * sqrt_row_w;
4781 a_phi.push((atom_beta_off + basis_col * p, w));
4782 wphi.push(w);
4783 }
4784 weighted_phi.push((atom_idx, wphi));
4785 }
4786 // β data-fit gradient `gᵦ += J_βᵀ M_n r_n`. The β-Jacobian is
4787 // `J_β = φ_nᵀ ⊗ I_p`, so `J_βᵀ M_n r_n = φ_n ⊗ (M_n r_n)` —
4788 // contract the basis weight `a·φ` against the p-space metric-applied
4789 // residual `error_metric` (= `M_n r_n`), the SAME whitening the value
4790 // path and t-block share. When not whitening, `error_metric == error`
4791 // and this is byte-identical to the historical `J_βᵀ r`.
4792 for &(beta_base_i, j_beta_i) in a_phi.iter() {
4793 if j_beta_i == 0.0 {
4794 continue;
4795 }
4796 for out_col in 0..p {
4797 gb_delta.push((
4798 beta_base_i + out_col,
4799 j_beta_i * error_metric[out_col],
4800 ));
4801 // No dense hbb write — the sparse `G ⊗ I_p` op installed
4802 // after the loop carries the data-fit GN β-Hessian.
4803 }
4804 }
4805 if frames_engaged {
4806 for &atom_idx in row_active {
4807 let atom = &self.atoms[atom_idx];
4808 let m = atom.basis_size();
4809 let a_k = assignments[atom_idx];
4810 for basis_col in 0..m {
4811 let phi = atom.basis_values[[row, basis_col]];
4812 let w = a_k * phi * sqrt_row_w;
4813 if w == 0.0 {
4814 continue;
4815 }
4816 let c_base = frame_projection.border_offsets[atom_idx]
4817 + basis_col * frame_projection.ranks[atom_idx];
4818 for c in 0..q_row {
4819 let mut hrow = block.htbeta.row_mut(c);
4820 let hrow_slice = hrow
4821 .as_slice_mut()
4822 .expect("htbeta row is contiguous");
4823 for out_col in 0..p {
4824 let value = local_jac_row[[c, out_col]] * w;
4825 frame_projection.accumulate_output_project(
4826 atom_idx, c_base, out_col, value, hrow_slice,
4827 );
4828 }
4829 }
4830 }
4831 }
4832 }
4833 // Data-fit GN β-Hessian: accumulate the channel-independent block
4834 // `G[μ_i, μ_j] += (a_k φ_k)[μ_i] (a_{k'} φ_{k'})[μ_j]` into the
4835 // sparse per-atom-pair map (the `out_col` dimension is carried by
4836 // `I_p`). Only co-occurring `(atom_i, atom_j)` pairs are touched.
4837 for ai in 0..weighted_phi.len() {
4838 let (atom_i, ref wphi_i) = weighted_phi[ai];
4839 let m_i = wphi_i.len();
4840 for aj in 0..weighted_phi.len() {
4841 let (atom_j, ref wphi_j) = weighted_phi[aj];
4842 let m_j = wphi_j.len();
4843 let blk = g_blocks
4844 .entry((atom_i, atom_j))
4845 .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4846 for li in 0..m_i {
4847 let wi = wphi_i[li];
4848 if wi == 0.0 {
4849 continue;
4850 }
4851 for lj in 0..m_j {
4852 blk[[li, lj]] += wi * wphi_j[lj];
4853 }
4854 }
4855 }
4856 }
4857 } // #1407 end `if !fixed_decoder` β-tier accumulation
4858 let (kron_a_phi, kron_jac) = if !frames_engaged && !fixed_decoder {
4859 // Flatten local_jac_row row-major into a plain Vec<f64> (q_row * p entries).
4860 let mut jac_flat = vec![0.0_f64; q_row * p];
4861 for c in 0..q_row {
4862 for j in 0..p {
4863 jac_flat[c * p + j] = local_jac_row[[c, j]];
4864 }
4865 }
4866 (Some(a_phi), Some(jac_flat))
4867 } else {
4868 (None, None)
4869 };
4870 Ok(SaeAssemblyRow {
4871 row,
4872 block,
4873 gb_delta,
4874 g_blocks,
4875 kron_a_phi,
4876 kron_jac,
4877 })
4878 }) // #1557 with_nested_parallel
4879 },
4880 )
4881 .collect::<Result<Vec<_>, String>>()?;
4882
4883 // Fold THIS chunk's rows (ascending) into the global accumulators.
4884 // The parallel collect preserves index order within the chunk and
4885 // chunks are visited in ascending `chunk_start` order, so the overall
4886 // fold order is `0,1,2,…,n-1` — identical to the former single-pass
4887 // fold. The `row == chunk_start + fold_offset_in_chunk` assert pins
4888 // that strict sequential arrival (the invariant the `kron_*`
4889 // row-aligned pushes depend on).
4890 for row_result in row_results.into_iter() {
4891 let row = row_result.row;
4892 assert_eq!(
4893 row,
4894 chunk_start + fold_offset_in_chunk,
4895 "parallel SAE row assembly returned rows out of order"
4896 );
4897 fold_offset_in_chunk += 1;
4898 for (idx, value) in row_result.gb_delta {
4899 sys.gb[idx] += value;
4900 }
4901 for ((atom_i, atom_j), data) in row_result.g_blocks {
4902 let m_i = data.nrows();
4903 let m_j = data.ncols();
4904 let blk = g_blocks
4905 .entry((atom_i, atom_j))
4906 .or_insert_with(|| Array2::<f64>::zeros((m_i, m_j)));
4907 for li in 0..m_i {
4908 for lj in 0..m_j {
4909 blk[[li, lj]] += data[[li, lj]];
4910 }
4911 }
4912 }
4913 if !frames_engaged && !fixed_decoder {
4914 // Rows arrive in ascending order across chunks, so pushing
4915 // here yields `kron_*[row]` aligned to the row index exactly
4916 // as the single-pass `push` did.
4917 kron_a_phi.push(
4918 row_result
4919 .kron_a_phi
4920 .expect("full-B SAE row assembly must return a_phi rows"),
4921 );
4922 kron_jac.push(
4923 row_result
4924 .kron_jac
4925 .expect("full-B SAE row assembly must return local Jacobian rows"),
4926 );
4927 }
4928 sys.rows[row] = row_result.block;
4929 }
4930 chunk_start = chunk_end;
4931 }
4932 // #1407: fixed-decoder early return. The per-row htt/gt are now fully
4933 // assembled (data GN + assignment/ARD prior). Apply only the htt/gt
4934 // Riemannian projection (the decoder/β tier is intentionally absent), then
4935 // return the block-diagonal system. `fixed_decoder_step_from_rows` reads
4936 // only `rows[*].htt`/`gt` + `row_offsets`, so no β-tier object is needed.
4937 if fixed_decoder {
4938 match row_layout.as_ref() {
4939 None => {
4940 // Dense uniform-q: project htt/gt (and the 0-width htbeta, a
4941 // no-op) through the ext-coord manifold.
4942 self.apply_sae_riemannian_geometry(&mut sys);
4943 }
4944 Some(layout) => {
4945 // Compact heterogeneous-q: project each row's htt/gt at its
4946 // own ext-coord point, mirroring the full path's compact
4947 // Riemannian block (htbeta is 0-width here, so skipped).
4948 if !self.ext_coord_manifold().is_euclidean() {
4949 for row_idx in 0..n {
4950 let (manifold_i, point_i) =
4951 self.compact_row_ext_manifold_and_point(row_idx, layout);
4952 let t_i = point_i.view();
4953 let gt_e = sys.rows[row_idx].gt.clone();
4954 let htt_e = sys.rows[row_idx].htt.clone();
4955 sys.rows[row_idx].gt =
4956 manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
4957 sys.rows[row_idx].htt = manifold_i.riemannian_hessian_matrix(
4958 t_i,
4959 gt_e.view(),
4960 htt_e.view(),
4961 );
4962 }
4963 }
4964 }
4965 }
4966 if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
4967 sys.set_row_gauge_deflation(deflation);
4968 }
4969 self.last_row_layout = row_layout;
4970 self.last_frames_active = frames_engaged;
4971 return Ok(sys);
4972 }
4973 // Apply Riemannian geometry to the per-row row blocks (htt, gt) and
4974 // also to the per-row Kronecker local Jacobians stored in kron_jac.
4975 // When the SAE ext-coord manifold is non-Euclidean (any atom latent
4976 // on sphere / circle / interval), the local Jacobian rows that map
4977 // into the t-block tangent space must be projected via the per-row
4978 // tangent projector P_i. This mirrors what
4979 // `apply_riemannian_latent_geometry` does to `row.htbeta`, applied
4980 // here to the (q × p) kron_jac so the Kronecker htbeta_matvec uses
4981 // the Riemannian-projected form.
4982 // Apply Riemannian geometry only for the dense uniform-q layout. Any
4983 // compact active-set layout (JumpReLU gate or large-K softmax/IBP
4984 // truncation) has heterogeneous q_i; the Riemannian projector path
4985 // requires a uniform latent dimension. The sparse plan only engages on
4986 // Euclidean ext-coord manifolds (see `sparse_active_plan`), so skipping
4987 // the projector here is correct — there is nothing to project.
4988 match row_layout.as_ref() {
4989 None => {
4990 let raw_gt_rows: Vec<Array1<f64>> =
4991 sys.rows.iter().map(|row| row.gt.clone()).collect();
4992 self.apply_sae_riemannian_geometry(&mut sys);
4993 let manifold = self.ext_coord_manifold();
4994 if !frames_engaged && !manifold.is_euclidean() {
4995 let ext = self.ext_coord_matrix();
4996 // Project the local Jacobian columns onto the tangent space at
4997 // each row's ext-coord point. Each column `j` of the row's
4998 // (q_row × p) Jacobian is an ambient-space vector of length
4999 // `q_row`; the manifold projector acts on one such column at a
5000 // time. Working directly on the row-major `jac_flat` storage via
5001 // a single reusable `col_buf` avoids the two dense (q × p) copies
5002 // (flatten→Array2, project, unflatten→Vec) that previously fired
5003 // per row. `t_buf` still holds the row's ext-coord vector.
5004 let mut t_buf = vec![0.0_f64; q];
5005 let mut col_buf = Array1::<f64>::zeros(q);
5006 for row_idx in 0..n {
5007 let ext_row = ext.row(row_idx);
5008 for (slot, &v) in t_buf.iter_mut().zip(ext_row.iter()) {
5009 *slot = v;
5010 }
5011 let t_i = ArrayView1::from(t_buf.as_slice());
5012 let raw_gt = raw_gt_rows[row_idx].view();
5013 let jac_flat = &mut kron_jac[row_idx];
5014 let q_row = jac_flat.len() / p;
5015 for j in 0..p {
5016 for c in 0..q_row {
5017 col_buf[c] = jac_flat[c * p + j];
5018 }
5019 let projected_col = manifold.project_vector_to_gradient_tangent(
5020 t_i,
5021 raw_gt.slice(ndarray::s![..q_row]),
5022 col_buf.slice(ndarray::s![..q_row]),
5023 );
5024 for c in 0..q_row {
5025 jac_flat[c * p + j] = projected_col[c];
5026 }
5027 }
5028 }
5029 }
5030 }
5031 Some(layout) => {
5032 // Compact active-set layout (#1117 follow-up): the dense
5033 // `ext_coord_manifold()` is keyed to the uniform full-`q` block
5034 // ordering, so it cannot be applied to the heterogeneous compact
5035 // rows directly. Instead we rebuild, PER ROW, the product manifold
5036 // and ext-coord point in that row's compact column order (see
5037 // `compact_row_ext_manifold_and_point`) and apply the SAME three
5038 // per-row Riemannian operations the dense
5039 // `apply_riemannian_latent_geometry` applies — gradient tangent
5040 // projection of `gt`, the Riemannian Hessian correction of `htt`,
5041 // and the column tangent projection of `htbeta` — plus the
5042 // identical Kronecker `kron_jac` column projection. On the shared
5043 // active support this is byte-identical to slicing the dense
5044 // product manifold, so engaging the sparse plan on a non-Euclidean
5045 // ext manifold is now correct (the former
5046 // `is_euclidean()`-only guard in `sparse_active_plan` is lifted).
5047 //
5048 // Euclidean ext manifolds still skip all of this (every
5049 // per-row manifold is a product of Euclidean parts whose
5050 // projector is the identity); we early-out so those rows stay
5051 // byte-for-byte the historical compact path.
5052 if !self.ext_coord_manifold().is_euclidean() {
5053 for row_idx in 0..n {
5054 let (manifold_i, point_i) =
5055 self.compact_row_ext_manifold_and_point(row_idx, layout);
5056 let t_i = point_i.view();
5057 // gt / htt / htbeta on the compact ArrowRowBlock, exactly
5058 // as `apply_riemannian_latent_geometry` does for dense
5059 // uniform-q rows.
5060 let gt_e = sys.rows[row_idx].gt.clone();
5061 let htt_e = sys.rows[row_idx].htt.clone();
5062 sys.rows[row_idx].gt =
5063 manifold_i.project_gradient_to_tangent(t_i, gt_e.view());
5064 sys.rows[row_idx].htt =
5065 manifold_i.riemannian_hessian_matrix(t_i, gt_e.view(), htt_e.view());
5066 // #1406: only the frames path holds a real dense `htbeta`
5067 // slab; the matrix-free path leaves it 0-width (the
5068 // cross-block geometry is applied to `kron_jac` below), so
5069 // projecting a zero-column matrix is a no-op we skip.
5070 if frames_engaged {
5071 let htbeta_e = sys.rows[row_idx].htbeta.clone();
5072 sys.rows[row_idx].htbeta = manifold_i
5073 .project_matrix_columns_to_gradient_tangent(
5074 t_i,
5075 gt_e.view(),
5076 htbeta_e.view(),
5077 );
5078 }
5079 // Kronecker local-Jacobian column projection (full-B path
5080 // only), using the SAME pre-projection gradient `gt_e` so
5081 // the cross-block geometry matches the dense branch.
5082 if !frames_engaged {
5083 let jac_flat = &mut kron_jac[row_idx];
5084 let q_row = jac_flat.len() / p;
5085 let mut col_buf = Array1::<f64>::zeros(q_row);
5086 for j in 0..p {
5087 for c in 0..q_row {
5088 col_buf[c] = jac_flat[c * p + j];
5089 }
5090 let projected_col = manifold_i.project_vector_to_gradient_tangent(
5091 t_i,
5092 gt_e.view(),
5093 col_buf.view(),
5094 );
5095 for c in 0..q_row {
5096 jac_flat[c * p + j] = projected_col[c];
5097 }
5098 }
5099 }
5100 }
5101 }
5102 }
5103 }
5104 // Build and install the full-B Kronecker htbeta_matvec.
5105 //
5106 // `SaeKroneckerRows` holds per-row `(a_phi, local_jac)` and implements
5107 // the cross-block operator without ever materialising the dense
5108 // `(q × K·p)` slab. The cross-block factorises as `H_tβ = L · J_β`,
5109 // where `J_β = φᵀ ⊗ I_p` projects a length-`K` β vector onto the
5110 // `p`-dimensional decoded output space (`apply_jbeta`) and `L_i` is
5111 // the per-row `(q_i × p)` assignment+coordinate Jacobian that lifts
5112 // that p-vector into the row's `q_i`-dim tangent block (`apply_l`).
5113 // Both factors are required: the contract of `set_row_htbeta_operator`
5114 // is `out.len() == d` (= `q_i`), so writing `apply_jbeta`'s p-vector
5115 // output directly into a length-`q_i` buffer overflows whenever
5116 // `p > q_i` (the common case once `p` reflects real feature width).
5117 // Symmetric for the transpose: `H_βt = J_βᵀ · Lᵀ`, so apply `Lᵀ`
5118 // first to map the q_i-vector back to p-space, then scatter through
5119 // the support.
5120 // #1017/#1026: the legacy full-B device PCG assumes `G ⊗ I_p`, while
5121 // framed systems carry `G_ij ⊗ W_ij` with rank-r atom blocks. Feeding a
5122 // framed system to that kernel would silently return the wrong Newton
5123 // step. Framed device PCG therefore needs the dedicated factored kernel.
5124 // #1033 large-n: the per-row support `kron_a_phi` and local Jacobians
5125 // `kron_jac` are consumed by BOTH the host matrix-free row operator
5126 // (`SaeKroneckerRows`) and the solver's `DeviceSaePcgData`. Previously
5127 // each took its own full `O(n·q·p)` / `O(n·k_active)` clone, so the
5128 // always-resident footprint of the CPU non-frames path carried TWO copies
5129 // of the dominant Jacobian slab. Promote each to a single `Arc<[…]>` once
5130 // and hand both consumers a refcount bump (`O(1)`) — the backing
5131 // allocation is shared, halving the resident per-row Jacobian memory.
5132 // Reads are identical (`&arc[row]`, `.len()`), so the assembled system and
5133 // every matvec are bit-for-bit unchanged.
5134 let device_rows = if frames_engaged {
5135 None
5136 } else {
5137 let a_phi_shared: Arc<[Vec<(usize, f64)>]> =
5138 Arc::from(std::mem::take(&mut kron_a_phi).into_boxed_slice());
5139 let jac_shared: Arc<[Vec<f64>]> =
5140 Arc::from(std::mem::take(&mut kron_jac).into_boxed_slice());
5141 Some((a_phi_shared, jac_shared))
5142 };
5143 if !frames_engaged {
5144 let (a_phi_shared, jac_shared) = device_rows
5145 .clone()
5146 .expect("non-frames path always populates device_rows");
5147 let kron = Arc::new(SaeKroneckerRows::new(p, a_phi_shared, jac_shared));
5148 let kron_t = Arc::clone(&kron);
5149 let p_dim = p;
5150 sys.set_row_htbeta_operator(
5151 move |row_idx, x, out| {
5152 // out = L_i · (J_β · x). Allocate a length-p scratch buffer
5153 // for the intermediate decoded-output vector; both factors
5154 // overwrite their output buffers (`apply_jbeta` zeroes
5155 // before accumulating, `apply_l` writes per-row), so no
5156 // pre-zeroing of `u_p`/`out` is needed.
5157 let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5158 let mut u_p = vec![0.0_f64; p_dim];
5159 if let Some(xs) = x.as_slice() {
5160 kron.apply_jbeta(row_idx, xs, &mut u_p);
5161 } else {
5162 let x_vec: Vec<f64> = x.iter().copied().collect();
5163 kron.apply_jbeta(row_idx, &x_vec, &mut u_p);
5164 }
5165 kron.apply_l(row_idx, &u_p, out_slice);
5166 },
5167 move |row_idx, v, out| {
5168 // out += J_βᵀ · (Lᵀ · v). `apply_l_t` accumulates into a
5169 // zero-initialised length-p buffer to produce the p-vector
5170 // `Lᵀ v`; `scatter_jbeta_t` then adds φ_i[s] · u_p[j] into
5171 // the length-K β accumulator at each active `(s, j)`.
5172 let out_slice = out.as_slice_mut().expect("out is always standard-layout");
5173 let mut u_p = vec![0.0_f64; p_dim];
5174 if let Some(vs) = v.as_slice() {
5175 kron_t.apply_l_t(row_idx, vs, &mut u_p);
5176 } else {
5177 let v_vec: Vec<f64> = v.iter().copied().collect();
5178 kron_t.apply_l_t(row_idx, &v_vec, &mut u_p);
5179 }
5180 kron_t.scatter_jbeta_t(row_idx, &u_p, out_slice);
5181 },
5182 );
5183 }
5184 let mut beta_penalty_assembly = SaeBetaPenaltyAssembly::default();
5185 let factored_row_projection = if frames_engaged && analytic_penalties.is_some() {
5186 Some(&frame_projection)
5187 } else {
5188 None
5189 };
5190 if let Some(registry) = analytic_penalties {
5191 // Upfront validation: refuse penalty kinds the SAE row layout
5192 // cannot host, and refuse mixed-d row-block configurations.
5193 // This makes the dispatch loop below total — no runtime
5194 // "unsupported penalty" fallthrough, no K-gating.
5195 self.validate_analytic_penalty_registry(registry)
5196 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5197 beta_penalty_assembly = self
5198 .add_sae_analytic_penalty_contributions(
5199 &mut sys,
5200 registry,
5201 penalty_scale,
5202 row_layout.as_ref(),
5203 dense_beta_curvature,
5204 factored_row_projection,
5205 )
5206 .map_err(|err| format!("SaeManifoldTerm::assemble_arrow_schur: {err}"))?;
5207 }
5208 // #1026 — decoder repulsion (collinearity-gated, registry-independent):
5209 // accumulate into the full-`B` β-tier here, BEFORE the frame transform,
5210 // so a framed system carries it identically to the analytic β penalties.
5211 // No-op unless two atoms are near-collinear (the frozen gate is `None`).
5212 if self.add_sae_decoder_repulsion(&mut sys, penalty_scale, dense_beta_curvature) {
5213 beta_penalty_assembly.record_curvature(dense_beta_curvature);
5214 }
5215 // #1026/#1522 — interior-point collapse-prevention barriers. The amplitude
5216 // barrier supplies the OUTWARD radial force at the zero-decoder collapse
5217 // point (the principal failure state the threshold repulsion skips), and
5218 // the separation barrier supplies the alignment-divergent separating
5219 // curvature on normalized shapes weighted by coactivation. Both accumulate
5220 // into the full-`B` β-tier here, BEFORE the frame transform, so a framed
5221 // system carries them identically to the analytic β penalties.
5222 // #1610 — on the dense path the barrier's Levenberg majorizer scatters
5223 // onto `sys.hbb`; on the matrix-free / framed production path `sys.hbb` is
5224 // unused, so the barrier hands back a per-atom scalar ridge which we fold
5225 // into `smooth_scaled_s` (the single source for the CPU composite penalty
5226 // op AND the device smooth blocks), restoring the collapse-prevention
5227 // curvature the operator was silently dropping there.
5228 let mut sep_atom_curv = vec![0.0_f64; self.atoms.len()];
5229 if self.add_sae_separation_barrier(
5230 &mut sys,
5231 penalty_scale,
5232 dense_beta_curvature,
5233 &mut sep_atom_curv,
5234 ) {
5235 if dense_beta_curvature {
5236 beta_penalty_assembly.record_curvature(true);
5237 } else {
5238 // Fold the per-atom majorizer `lev_k·I_{M_k}` into the smooth
5239 // penalty factor `λ S_k`. With `⊗ I_p` (full-`B`) or `⊗ I_{r_k}`
5240 // (factored, `U_kᵀU_k = I`) this is exactly the `lev_k·I` block
5241 // diagonal the dense path writes — and it now flows through the
5242 // structured penalty op and the device smooth blocks. No
5243 // `deferred_factored` mark: the curvature is in the smooth op, not
5244 // a deferred dense block, so the device path stays engaged.
5245 for atom_idx in 0..self.atoms.len() {
5246 let c = sep_atom_curv[atom_idx];
5247 if c > 0.0 {
5248 let m = smooth_scaled_s[atom_idx].nrows();
5249 for i in 0..m {
5250 smooth_scaled_s[atom_idx][[i, i]] += c;
5251 }
5252 smooth_ops[atom_idx] = Arc::new(IdentityRightKroneckerPenaltyOp {
5253 factor_a: smooth_scaled_s[atom_idx].clone(),
5254 p,
5255 global_offset: beta_offsets[atom_idx],
5256 k: beta_dim,
5257 });
5258 }
5259 }
5260 }
5261 }
5262 if frames_engaged {
5263 // ── #972 / #977 T1 — FACTORED β-tier transform ──────────────────
5264 //
5265 // The entire β-tier above was assembled in the full-`B` (p-wide)
5266 // layout: `sys.gb` is `g_B` (length `beta_dim`), `sys.hbb` carries
5267 // any analytic Beta-tier penalty, and `g_blocks` is the
5268 // FRAME-INDEPENDENT basis Gram. We now rebuild the β-tier in the
5269 // factored coordinate space `C` (width `factored_border_dim`), the
5270 // full-`B` system sandwiched by `Φ = blkdiag(I_{M_k} ⊗ U_k)`:
5271 // * gradient `g_C = Φᵀ g_B` (per atom `(g_B U_k)`),
5272 // * data H `Φᵀ(G⊗I_p)Φ = G_{ij}⊗(U_iᵀU_j)`,
5273 // * smooth `λ S_k ⊗ I_{r_k}` (since `U_kᵀU_k = I`),
5274 // * analytic `Φᵀ hbb Φ` (dense, only if written).
5275 // Un-framed atoms ride the `r_k = p, U_k = I_p` identity special case.
5276 let off_c = &frame_projection.border_offsets;
5277 let ranks = &frame_projection.ranks;
5278 let basis_sizes = &frame_projection.basis_sizes;
5279 let border_dim = frame_projection.border_dim();
5280 let gb_c = frame_projection.project_border_vec(sys.gb.view());
5281
5282 // Data β-Hessian: `G_{ij} ⊗ W_{ij}` with `W_{ij} = U_iᵀU_j`. The
5283 // basis Gram `g_blocks` is unchanged; only the output factor is the
5284 // per-pair frame overlap (`I_{r_k}` within a framed atom, `I_p` for
5285 // un-framed).
5286 let mut frame_blocks: Vec<FactoredFrameGBlock> = Vec::with_capacity(g_blocks.len());
5287 for ((atom_i, atom_j), data) in g_blocks.into_iter() {
5288 if data.iter().all(|&v| v == 0.0) {
5289 continue;
5290 }
5291 // `W_{ij} = U_iᵀ U_j` from the precomputed per-atom frames.
5292 let w = self.frame_cross_factor(atom_i, atom_j);
5293 frame_blocks.push(FactoredFrameGBlock {
5294 atom_i,
5295 atom_j,
5296 g: data,
5297 w,
5298 });
5299 }
5300 // #1017/#1026 — snapshot the factored data-fit blocks for the
5301 // frames-engaged device PCG BEFORE `FactoredFrameKroneckerOp::new`
5302 // consumes them. Cheap clone (co-occurring blocks only).
5303 let device_frame_blocks = frame_blocks.clone();
5304 let data_op =
5305 FactoredFrameKroneckerOp::new(ranks.clone(), basis_sizes.clone(), frame_blocks)?;
5306
5307 // Smooth penalty in factored space: `λ S_k ⊗ I_{r_k}` at `off_C[k]`.
5308 let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = Vec::with_capacity(self.atoms.len() + 2);
5309 for k in 0..self.atoms.len() {
5310 let r = ranks[k];
5311 ops.push(Arc::new(IdentityRightKroneckerPenaltyOp {
5312 factor_a: smooth_scaled_s[k].clone(),
5313 p: r,
5314 global_offset: off_c[k],
5315 k: border_dim,
5316 }));
5317 }
5318 ops.push(Arc::new(data_op));
5319 // Analytic Beta-tier penalty: project the dense full-`B` `hbb` block
5320 // `Φᵀ hbb Φ` into the factored space. Only present when a Beta-tier
5321 // penalty actually wrote `hbb` (else `hbb` is all-zero and the dense
5322 // `(border_dim)²` op is skipped entirely, exactly as full-`B`).
5323 if beta_penalty_assembly.dense_written {
5324 let hbb_c =
5325 self.project_dense_penalty_to_factored(sys.hbb.view(), &frame_projection);
5326 ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5327 } else if beta_penalty_assembly.deferred_factored {
5328 // Registry Beta-tier curvature deferred to factored-space probing.
5329 // The registry may be absent when `deferred_factored` was set ONLY
5330 // by the frozen-gate decoder repulsion (which is
5331 // registry-independent), so start from a zero factored block in
5332 // that case instead of unwrapping.
5333 let mut hbb_c = match analytic_penalties {
5334 Some(registry) => self.build_factored_beta_penalty_curvature(
5335 registry,
5336 penalty_scale,
5337 &frame_projection,
5338 ),
5339 None => Array2::<f64>::zeros((
5340 frame_projection.border_dim(),
5341 frame_projection.border_dim(),
5342 )),
5343 };
5344 // #1610 — the frozen-gate decoder repulsion's PSD majorizer was
5345 // dropped on this matrix-free/framed path (only its gradient was
5346 // applied). Project it into the factored block via the same
5347 // `psd_majorizer_hvp` + frame-projection probe pattern the registry
5348 // DecoderIncoherence uses, so the collapse-prevention curvature
5349 // reaches the operator here too. No-op when no repulsion is active.
5350 self.add_factored_repulsion_curvature(
5351 &mut hbb_c,
5352 penalty_scale,
5353 &frame_projection,
5354 );
5355 ops.push(Arc::new(DensePenaltyOp(hbb_c)));
5356 }
5357
5358 // Re-point the system's β-tier to the factored width. The t-tier
5359 // (per-row `htt`, `gt`) is frame-independent and untouched; row
5360 // cross-block slabs were allocated and assembled directly in
5361 // factored coordinates, so analytic row supplements and data-fit
5362 // cross terms already share shape `(q_i × factored_border_dim)`.
5363 sys.k = border_dim;
5364 sys.gb = gb_c;
5365 self.reclaim_border_hbb_workspace(&mut sys);
5366 // Factored per-atom block ranges for the block-Jacobi Schur
5367 // preconditioner: `[off_C[k] .. off_C[k] + M_k·r_k]`.
5368 let mut block_ranges: Vec<std::ops::Range<usize>> =
5369 Vec::with_capacity(self.atoms.len());
5370 for k in 0..self.atoms.len() {
5371 let start = off_c[k];
5372 block_ranges.push(start..start + basis_sizes[k] * ranks[k]);
5373 }
5374 sys.set_block_offsets(Arc::from(block_ranges.into_boxed_slice()));
5375 sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: border_dim, ops }));
5376 // #1017/#1026 — install the frames-engaged device SAE PCG data. Skipped
5377 // (CPU fallback) when a dense analytic Beta-tier penalty fired (the
5378 // device kernel does not model that extra dense term). Builder:
5379 // `crate::frames::build_framed_device_sae_data`.
5380 let has_dense_beta_penalty =
5381 beta_penalty_assembly.dense_written || beta_penalty_assembly.deferred_factored;
5382 if !has_dense_beta_penalty {
5383 let device = crate::frames::build_framed_device_sae_data(
5384 crate::frames::FramedDeviceArgs {
5385 p,
5386 border_dim,
5387 border_offsets: off_c.as_slice(),
5388 ranks: ranks.as_slice(),
5389 basis_sizes: basis_sizes.as_slice(),
5390 smooth_scaled_s: &smooth_scaled_s,
5391 frame_blocks: device_frame_blocks,
5392 rows: &sys.rows,
5393 },
5394 );
5395 sys.set_device_sae_pcg_data(device);
5396 }
5397 } else {
5398 let (device_a_phi, device_local_jac) =
5399 device_rows.expect("full-beta SAE PCG rows are cloned before row operator install");
5400 // Wire per-atom β block ranges so the Jacobi preconditioner builds one
5401 // dense Schur sub-block per atom (block-Jacobi) instead of scalar-diagonal
5402 // inversion. Each atom's decoder coefficients form a natural block:
5403 // `[beta_offsets[k] .. beta_offsets[k] + basis_size[k] * p_out]`.
5404 sys.set_block_offsets(self.beta_block_offsets());
5405 // Install the composite BetaPenaltyOp (#296): smoothness contributions
5406 // via per-atom KroneckerPenaltyOp (avoid dense K×K materialisation), the
5407 // data-fit Gauss-Newton β-Hessian as the structured `G ⊗ I_p`
5408 // SparseBlockKroneckerPenaltyOp (block-sparse over co-occurring
5409 // `(atom, atom')` pairs, block-diagonal across the `p` output channels,
5410 // identical per channel), plus — only when a Beta-tier analytic penalty
5411 // was written — the dense `sys.hbb` residual contribution. When no beta
5412 // penalty fired, `sys.hbb` is all-zero and the dense `(K·p)²` operator
5413 // is skipped entirely. The sparse data op tracks only the active-atom
5414 // couplings, so its storage and matvec cost scale with `k_active`, not
5415 // `K`, at `K = 100K`.
5416 // Convert the per-atom-pair coupling map into `SparseGBlock`s keyed
5417 // by μ-space offsets. Empty blocks (no co-occurrence) are simply
5418 // absent from the map.
5419 let g_sparse_blocks: Vec<SparseGBlock> = g_blocks
5420 .into_iter()
5421 .filter_map(|((atom_i, atom_j), data)| {
5422 if data.iter().all(|&v| v == 0.0) {
5423 None
5424 } else {
5425 Some(SparseGBlock {
5426 row_off: mu_offsets[atom_i],
5427 col_off: mu_offsets[atom_j],
5428 data,
5429 })
5430 }
5431 })
5432 .collect();
5433 let device_smooth_blocks = smooth_scaled_s
5434 .iter()
5435 .enumerate()
5436 .map(|(atom_idx, factor_a)| {
5437 // #1117 — rank deficiency is removed at the basis layer, so the
5438 // device PCG smooth block is just `λ S_k ⊗ I_p` (full-rank
5439 // design); no data-null deflation is folded in here.
5440 DeviceSaeSmoothBlock {
5441 global_offset: beta_offsets[atom_idx],
5442 factor_a: factor_a.clone(),
5443 }
5444 })
5445 .collect();
5446 sys.set_device_sae_pcg_data(DeviceSaePcgData {
5447 p,
5448 beta_dim,
5449 a_phi: device_a_phi,
5450 local_jac: device_local_jac,
5451 smooth_blocks: device_smooth_blocks,
5452 sparse_g_blocks: g_sparse_blocks.clone(),
5453 frame: None,
5454 });
5455 let mut ops: Vec<Arc<dyn BetaPenaltyOp>> = smooth_ops;
5456 ops.push(Arc::new(SparseBlockKroneckerPenaltyOp {
5457 p,
5458 dim_a: m_total,
5459 k: beta_dim,
5460 blocks: g_sparse_blocks,
5461 }));
5462 if beta_penalty_assembly.dense_written {
5463 ops.push(Arc::new(DensePenaltyOp(sys.hbb.clone())));
5464 }
5465 sys.set_penalty_op(Arc::new(CompositePenaltyOp { k: beta_dim, ops }));
5466 self.reclaim_border_hbb_workspace(&mut sys);
5467 }
5468 if let Some(deflation) = self.row_gauge_deflation_for_layout(row_layout.as_ref()) {
5469 sys.set_row_gauge_deflation(deflation);
5470 }
5471 // #1038 IBP cross-row Woodbury source. The exact IBP Hessian has the
5472 // per-column rank-one cross-row block `H_(i,k),(j,k) = w·s'_k·z'_ik·z'_jk`
5473 // (for ALL `i,j`, including the `i=j` self term) that couples DISTINCT
5474 // latent rows through the shared empirical mass `M_k = Σ_i z_ik`. The
5475 // assembled row-block-diagonal `htt` already carries the `i=j` self term
5476 // `w·s'_k·z'_ik²` — it is the first summand of `assignment_hdiag`'s
5477 // `hessian_diag` value `w·(score_derivative·z_jac² + score·c_ik)` written
5478 // at the logit diagonal above. So the consumer (`solver::arrow_schur`,
5479 // #1038 `IbpCrossRowSource`/`CrossRowWoodbury`) DOWNDATES exactly
5480 // `Σ_k d_k·z'_ik²` (`self_term_downdate`) to recover the NO-SELF base
5481 // `H₀'`, then re-adds the FULL rank-one `U D Uᵀ` via the determinant
5482 // lemma — so value, the evidence log-determinant, and the θ/ρ-adjoint all
5483 // differentiate the SAME `H_full = H₀' + U D Uᵀ`.
5484 //
5485 // The source is built from the SAME `ibp_assignment_third_channels`
5486 // operator the #1006 θ-adjoint consumes:
5487 // * `d[k] = cross_row_d[k] = w·s'_k = w·score_derivative_k` (the column
5488 // `D`-coefficient — NOT sign-definite, hence the consumer's
5489 // indefinite-capacitance LU);
5490 // * `entries[(i,k)] = (global_t_index, k, z'_ik)` with `z'_ik =
5491 // z_jac[i·K + k]`. For the DENSE layout (`assignment_coord_dim() = K`,
5492 // `last_row_layout = None`) atom `k`'s logit slot is local position `k`
5493 // of row `i`'s block, so `global_t_index = sys.row_offsets[i] + k`. For
5494 // the COMPACT layout (#1420) only the row's active atoms are
5495 // coordinates and atom `k` lives at local position `pos` of
5496 // `active_atoms[row]`, so `global_t_index = sys.row_offsets[i] + pos`.
5497 // Both pin the `U`-column convention bit-for-bit to the consumer's
5498 // `ibp_logit_sites`/`row_vars_for_cache_row` slot mapping.
5499 if let Some(channels) = ibp_assignment_third_channels(&self.assignment, rho)? {
5500 let mut entries: Vec<(usize, usize, f64)> = Vec::with_capacity(n * k_atoms);
5501 for row in 0..n {
5502 let start = row * k_atoms;
5503 let g_base = sys.row_offsets[row];
5504 match row_layout.as_ref() {
5505 // #1420: compact layout — the local logit slot `pos` (not the
5506 // global atom index `k`) is the t-coordinate. Atom `k`'s logit
5507 // lives at local position `pos` of `active_atoms[row]`, so emit
5508 // `(g_base + pos, atom, z_jac[row·K + atom])` for the active set
5509 // only. Using `g_base + k` would attach atom `k`'s derivative to
5510 // the wrong slot (and run out of range for compact rows),
5511 // violating the `IbpCrossRowSource` contract.
5512 Some(layout) => {
5513 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
5514 let z_prime = channels.z_jac[start + atom];
5515 entries.push((g_base + pos, atom, z_prime));
5516 }
5517 }
5518 // Dense layout: atom `k`'s logit slot is local position `k`.
5519 None => {
5520 for k in 0..k_atoms {
5521 let z_prime = channels.z_jac[start + k];
5522 entries.push((g_base + k, k, z_prime));
5523 }
5524 }
5525 }
5526 }
5527 let source = IbpCrossRowSource {
5528 r: k_atoms,
5529 d: channels.cross_row_d.clone(),
5530 entries,
5531 };
5532 sys.set_ibp_cross_row_source(source);
5533 }
5534 // Store the active-set layout for `apply_newton_step`.
5535 self.last_row_layout = row_layout;
5536 // Record whether `delta_beta` from this system is a factored ΔC (needs a
5537 // frame lift) or a full-`B` ΔB. Read by `apply_newton_step_impl`.
5538 self.last_frames_active = frames_engaged;
5539 Ok(sys)
5540 }
5541
5542 /// Project a dense full-`B` Beta-tier penalty Hessian `hbb` (`beta_dim ×
5543 /// beta_dim`, the analytic `∂²P/∂B∂B` block) into the factored coordinate
5544 /// space `Φᵀ hbb Φ` (`border_dim × border_dim`) for the #972 / #977 T1
5545 /// frame transform. `Φ = blkdiag(I_{M_k} ⊗ U_k)` maps C-space → B-space, so
5546 /// the projected block contracts both index legs through the per-atom frames.
5547 ///
5548 /// The projection is done in two passes to stay `O(beta_dim · border_dim +
5549 /// border_dim²)` instead of forming the dense `Φ` explicitly: first
5550 /// `T = hbb · Φ` (right multiply, columns fold `U`), then `Φᵀ · T` (left
5551 /// multiply, rows fold `U`). Analytic Beta-tier penalties are rare and small,
5552 /// so this only fires when one is actually installed.
5553 pub(crate) fn project_dense_penalty_to_factored(
5554 &self,
5555 hbb: ArrayView2<'_, f64>,
5556 projection: &FrameProjection,
5557 ) -> Array2<f64> {
5558 projection.project_block(hbb)
5559 }
5560
5561 pub(crate) fn build_factored_beta_penalty_curvature(
5562 &self,
5563 registry: &AnalyticPenaltyRegistry,
5564 penalty_scale: f64,
5565 projection: &FrameProjection,
5566 ) -> Array2<f64> {
5567 let rho_global = Array1::<f64>::zeros(registry.total_rho_count());
5568 let layout = registry.rho_layout();
5569 let target_beta = self.flatten_beta();
5570 let mut hbb_c = Array2::<f64>::zeros((projection.border_dim(), projection.border_dim()));
5571 for (penalty, (rho_slice, tier, _name)) in registry.penalties.iter().zip(layout.iter()) {
5572 if matches!(penalty, AnalyticPenaltyKind::Ard(_)) {
5573 continue;
5574 }
5575 let rho_local = rho_global.slice(s![rho_slice.clone()]);
5576 match tier {
5577 PenaltyTier::Psi if matches!(penalty, AnalyticPenaltyKind::NuclearNorm(_)) => {
5578 self.add_factored_beta_penalty_curvature_for_penalty(
5579 &mut hbb_c,
5580 penalty,
5581 target_beta.view(),
5582 rho_local,
5583 penalty_scale,
5584 projection,
5585 );
5586 }
5587 PenaltyTier::Beta => {
5588 self.add_factored_beta_penalty_curvature_for_penalty(
5589 &mut hbb_c,
5590 penalty,
5591 target_beta.view(),
5592 rho_local,
5593 penalty_scale,
5594 projection,
5595 );
5596 }
5597 _ => {}
5598 }
5599 }
5600 hbb_c
5601 }
5602
5603 pub(crate) fn add_factored_beta_penalty_curvature_for_penalty(
5604 &self,
5605 hbb_c: &mut Array2<f64>,
5606 penalty: &AnalyticPenaltyKind,
5607 target_beta: ArrayView1<'_, f64>,
5608 rho_local: ArrayView1<'_, f64>,
5609 penalty_scale: f64,
5610 projection: &FrameProjection,
5611 ) {
5612 let p = self.output_dim();
5613 if let AnalyticPenaltyKind::DecoderIncoherence(base) = penalty {
5614 let Some(per_fit) = self.live_decoder_incoherence_penalty(base) else {
5615 return;
5616 };
5617 let beta_dim = self.beta_dim();
5618 let mut probe = Array1::<f64>::zeros(beta_dim);
5619 for k in 0..self.atoms.len() {
5620 for basis_col in 0..projection.basis_sizes[k] {
5621 for frame_col in 0..projection.ranks[k] {
5622 probe.fill(0.0);
5623 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5624 let col = projection.border_offsets[k]
5625 + basis_col * projection.ranks[k]
5626 + frame_col;
5627 let hv = per_fit.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5628 projection
5629 .project_border_vec(hv.view())
5630 .iter()
5631 .enumerate()
5632 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5633 }
5634 }
5635 }
5636 return;
5637 }
5638 if let AnalyticPenaltyKind::MechanismSparsity(base) = penalty {
5639 for (per_atom, start, end) in self.live_mechanism_sparsity_penalties(base) {
5640 let atom_idx = projection
5641 .beta_offsets
5642 .iter()
5643 .position(|&offset| offset == start)
5644 .expect("live mechanism-sparsity offset must match an SAE atom");
5645 let block_len = end - start;
5646 let mut local_penalty = per_atom.clone();
5647 local_penalty.target = PsiSlice {
5648 range: 0..block_len,
5649 latent_dim: Some(projection.basis_sizes[atom_idx]),
5650 };
5651 let block = target_beta.slice(s![start..end]);
5652 let mut probe = Array1::<f64>::zeros(block_len);
5653 for basis_col in 0..projection.basis_sizes[atom_idx] {
5654 for frame_col in 0..projection.ranks[atom_idx] {
5655 probe.fill(0.0);
5656 projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5657 let col = projection.border_offsets[atom_idx]
5658 + basis_col * projection.ranks[atom_idx]
5659 + frame_col;
5660 let hv = local_penalty.psd_majorizer_hvp(block, rho_local, probe.view());
5661 projection.project_local_atom_vec_into(
5662 atom_idx,
5663 hv.view(),
5664 hbb_c.column_mut(col),
5665 penalty_scale,
5666 );
5667 }
5668 }
5669 }
5670 return;
5671 }
5672 if let AnalyticPenaltyKind::NuclearNorm(base) = penalty {
5673 for (per_atom, start, end) in self.live_nuclear_norm_penalties(base) {
5674 let atom_idx = projection
5675 .beta_offsets
5676 .iter()
5677 .position(|&offset| offset == start)
5678 .expect("live nuclear-norm offset must match an SAE atom");
5679 let block = target_beta.slice(s![start..end]);
5680 let block_len = end - start;
5681 let mut probe = Array1::<f64>::zeros(block_len);
5682 for basis_col in 0..projection.basis_sizes[atom_idx] {
5683 for frame_col in 0..projection.ranks[atom_idx] {
5684 probe.fill(0.0);
5685 projection.lift_local_axis_into(&mut probe, atom_idx, basis_col, frame_col);
5686 let col = projection.border_offsets[atom_idx]
5687 + basis_col * projection.ranks[atom_idx]
5688 + frame_col;
5689 let hv = per_atom.psd_majorizer_hvp(block, rho_local, probe.view());
5690 projection.project_local_atom_vec_into(
5691 atom_idx,
5692 hv.view(),
5693 hbb_c.column_mut(col),
5694 penalty_scale,
5695 );
5696 }
5697 }
5698 }
5699 return;
5700 }
5701 let beta_dim = self.beta_dim();
5702 let mut probe = Array1::<f64>::zeros(beta_dim);
5703 for k in 0..self.atoms.len() {
5704 for basis_col in 0..projection.basis_sizes[k] {
5705 for frame_col in 0..projection.ranks[k] {
5706 probe.fill(0.0);
5707 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5708 let col =
5709 projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5710 let hv = penalty.psd_majorizer_hvp(target_beta, rho_local, probe.view());
5711 projection
5712 .project_border_vec(hv.view())
5713 .iter()
5714 .enumerate()
5715 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5716 }
5717 }
5718 }
5719 assert_eq!(p, self.output_dim());
5720 }
5721
5722 /// #1610 — project the frozen-gate decoder-repulsion PSD majorizer into the
5723 /// factored β block `hbb_c`. Mirrors the `DecoderIncoherence` arm of
5724 /// [`Self::add_factored_beta_penalty_curvature_for_penalty`] but sources the
5725 /// penalty from [`Self::live_decoder_repulsion_penalty`] (registry-independent,
5726 /// collinearity-gated), so the repulsion curvature reaches the operator on the
5727 /// matrix-free/framed path where the dense `sys.hbb` write is unused. No-op
5728 /// when no repulsion is active.
5729 pub(crate) fn add_factored_repulsion_curvature(
5730 &self,
5731 hbb_c: &mut Array2<f64>,
5732 penalty_scale: f64,
5733 projection: &FrameProjection,
5734 ) {
5735 let Some(per_fit) = self.live_decoder_repulsion_penalty() else {
5736 return;
5737 };
5738 let beta_dim = self.beta_dim();
5739 let target_beta = self.flatten_beta();
5740 // The repulsion penalty is non-learnable; its strength is already folded
5741 // into the frozen gate (see `live_decoder_repulsion_penalty`), so the rho
5742 // slice is empty/inert.
5743 let rho_local = Array1::<f64>::zeros(0);
5744 let mut probe = Array1::<f64>::zeros(beta_dim);
5745 for k in 0..self.atoms.len() {
5746 for basis_col in 0..projection.basis_sizes[k] {
5747 for frame_col in 0..projection.ranks[k] {
5748 probe.fill(0.0);
5749 projection.lift_axis_into(&mut probe, k, basis_col, frame_col);
5750 let col =
5751 projection.border_offsets[k] + basis_col * projection.ranks[k] + frame_col;
5752 let hv =
5753 per_fit.psd_majorizer_hvp(target_beta.view(), rho_local.view(), probe.view());
5754 projection
5755 .project_border_vec(hv.view())
5756 .iter()
5757 .enumerate()
5758 .for_each(|(row, &v)| hbb_c[[row, col]] += penalty_scale * v);
5759 }
5760 }
5761 }
5762 }
5763
5764 pub(crate) fn ext_coord_matrix(&self) -> Array2<f64> {
5765 let n = self.n_obs();
5766 let q = self.assignment.row_block_dim();
5767 let flat = self.assignment.flatten_ext_coords();
5768 let mut out = Array2::<f64>::zeros((n, q));
5769 for row in 0..n {
5770 for col in 0..q {
5771 out[[row, col]] = flat[row * q + col];
5772 }
5773 }
5774 out
5775 }
5776
5777 pub(crate) fn ext_coord_manifold(&self) -> LatentManifold {
5778 let mut parts = Vec::with_capacity(self.assignment.row_block_dim());
5779 for _ in 0..self.assignment.assignment_coord_dim() {
5780 parts.push(LatentManifold::Euclidean);
5781 }
5782 let mut any_constrained = false;
5783 for coord in &self.assignment.coords {
5784 if coord.manifold().is_euclidean() {
5785 for _ in 0..coord.latent_dim() {
5786 parts.push(LatentManifold::Euclidean);
5787 }
5788 } else {
5789 any_constrained = true;
5790 parts.push(coord.manifold().clone());
5791 }
5792 }
5793 if any_constrained {
5794 LatentManifold::Product(parts)
5795 } else {
5796 LatentManifold::Euclidean
5797 }
5798 }
5799
5800 pub(crate) fn apply_sae_riemannian_geometry(&self, sys: &mut ArrowSchurSystem) {
5801 let manifold = self.ext_coord_manifold();
5802 if manifold.is_euclidean() {
5803 return;
5804 }
5805 let ext = self.ext_coord_matrix();
5806 let latent =
5807 LatentCoordValues::from_matrix_with_manifold(ext.view(), LatentIdMode::None, manifold);
5808 sys.apply_riemannian_latent_geometry(&latent);
5809 }
5810
5811 /// Build the compact-layout ext-coord product manifold and point for one row.
5812 ///
5813 /// The dense `ext_coord_manifold()` is keyed to the full-`q` block ordering
5814 /// `[assignment parts (all Euclidean for IBP-MAP / JumpReLU), then per-atom
5815 /// coord blocks in atom order]`. A compact active-set row instead lays its
5816 /// `q_active` columns out as `[one Euclidean logit slot per active atom,
5817 /// then each active atom's coord block in `active` order]` (see
5818 /// [`SaeRowLayout::from_active_atoms`] / `coord_starts`). To reuse the exact
5819 /// per-row Riemannian projector on the compact block we rebuild a product
5820 /// manifold and the matching ext-coord point in that compact order: the
5821 /// `active.len()` logit slots are `Euclidean` (the assignment channel is
5822 /// always Euclidean for the modes that engage sparsity — `assignment_coord_dim
5823 /// == k_atoms`), and each active atom contributes its own coordinate
5824 /// manifold. On the shared active support this is byte-identical to slicing
5825 /// the dense full-`q` product manifold, so the compact projection matches the
5826 /// dense path exactly — it only drops the inactive atoms' (negligible-mass)
5827 /// coordinate blocks the compact layout already excludes from curvature.
5828 ///
5829 /// Returns `(manifold, t_compact)` where `t_compact` has length `q_active`.
5830 /// The logit-slot entries of `t_compact` are filled from the row logits (the
5831 /// Euclidean projector ignores the point, so any finite value is equivalent;
5832 /// using the true logits keeps the point well-defined and finite).
5833 pub(crate) fn compact_row_ext_manifold_and_point(
5834 &self,
5835 row: usize,
5836 layout: &SaeRowLayout,
5837 ) -> (LatentManifold, Array1<f64>) {
5838 let active = &layout.active_atoms[row];
5839 let q_active = layout.row_q_active(row);
5840 let mut parts: Vec<LatentManifold> = Vec::with_capacity(active.len() + active.len());
5841 let mut point = Array1::<f64>::zeros(q_active);
5842 // Logit slots: one Euclidean part per active atom, in `active` order.
5843 let logits_row = self.assignment.logits.row(row);
5844 for (j, &k) in active.iter().enumerate() {
5845 parts.push(LatentManifold::Euclidean);
5846 point[j] = logits_row[k];
5847 }
5848 // Coordinate blocks: each active atom's coordinate manifold + point, at
5849 // the compact coord start the layout assigned it.
5850 for (j, &k) in active.iter().enumerate() {
5851 let coord = &self.assignment.coords[k];
5852 let d = coord.latent_dim();
5853 let coord_start = layout.coord_starts[row][j];
5854 let manifold_k = coord.manifold();
5855 // A `d`-dim coordinate whose manifold is a product (e.g. a torus =
5856 // Circle×Circle) already carries its `d` parts; a scalar manifold is
5857 // one part. Either way the manifold's ambient width must equal `d`,
5858 // matching the `d` compact columns at `coord_start`.
5859 parts.push(manifold_k.clone());
5860 let coord_point = coord.row(row);
5861 for axis in 0..d {
5862 point[coord_start + axis] = coord_point[axis];
5863 }
5864 }
5865 (LatentManifold::Product(parts), point)
5866 }
5867
5868 /// Numerical rank of a symmetric matrix: the count of eigenvalues
5869 /// exceeding `tol · max_eig`, with `tol = 1e-9` (the conventional
5870 /// relative spectral cutoff used elsewhere in the codebase).
5871 ///
5872 /// Used to count the penalised dimension of each atom's `smooth_penalty`
5873 /// `S_k` so the REML criterion's `−½·p·rank(S)·log λ_smooth` Occam term
5874 /// uses the *effective* penalty rank rather than the ambient basis size
5875 /// (a thin-plate / B-spline penalty has a non-trivial null space).
5876 pub(crate) fn symmetric_rank(s: &Array2<f64>) -> Result<usize, String> {
5877 if s.nrows() != s.ncols() {
5878 return Err(format!(
5879 "SaeManifoldTerm::symmetric_rank: matrix must be square, got {}x{}",
5880 s.nrows(),
5881 s.ncols()
5882 ));
5883 }
5884 let m = s.ncols();
5885 if m == 0 {
5886 return Ok(0);
5887 }
5888 // Symmetrize defensively through the shared ndarray helper. The SAE
5889 // rank cutoff is intentionally local to the SAE evidence contract; only
5890 // the symmetric cleanup is shared with the other construction modules.
5891 let mut sym = s.clone();
5892 gam_linalg::matrix::symmetrize_in_place(&mut sym);
5893 let (evals, _evecs) = sym
5894 .eigh(Side::Lower)
5895 .map_err(|e| format!("SaeManifoldTerm::symmetric_rank: eigh failed: {e}"))?;
5896 let max_eig = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v));
5897 if !(max_eig > 0.0) {
5898 return Ok(0);
5899 }
5900 let tol = SAE_MANIFOLD_SPECTRAL_RANK_CUTOFF * max_eig;
5901 Ok(evals.iter().filter(|&&v| v > tol).count())
5902 }
5903
5904 /// Penalised quasi-Laplace evidence score for the SAE term at a FIXED ρ.
5905 ///
5906 /// #1421: this is NOT a true normalized-prior REML/evidence objective. The
5907 /// assignment priors (softmax entropy, JumpReLU) have NO finite normalizer:
5908 /// for softmax the reference-logit chart sends `P(ℓ)→0` as a free logit →±∞
5909 /// so `∫ e^{−λP} dℓ = ∞`, and JumpReLU's bounded penalty `0<P<λ` keeps
5910 /// `e^{−λP}` bounded below over an unbounded domain, also divergent. There is
5911 /// therefore no ρ-independent assignment-prior normalizer that can be dropped
5912 /// as a constant. The smoothing-penalty `−½log|λS|_+` term IS a genuine
5913 /// (proper-Gaussian) REML normalizer and is kept exactly; the rest is a
5914 /// penalized quasi-Laplace score (Laplace curvature term `½log|H|` around the
5915 /// inner optimum), which the engine minimizes over ρ.
5916 ///
5917 /// Runs the inner `(t, β)` arrow-Schur Newton solve to convergence at the
5918 /// supplied ρ (with NO in-loop ARD update — ρ is owned by the engine),
5919 /// then forms the Laplace/REML cost
5920 ///
5921 /// ```text
5922 /// V(ρ) = ℓ_pen(t̂, β̂; ρ) + ½ log|H(t̂, β̂; ρ)|
5923 /// − ½ · p · (Σ_k rank S_k) · log λ_smooth
5924 /// ```
5925 ///
5926 /// where `ℓ_pen = loss.total()` is the penalised objective at the inner
5927 /// optimum and `½ log|H|` is the Laplace normaliser. `H` is the joint
5928 /// `(t, β)` Hessian assembled by the arrow-Schur system; its `H_tt` block
5929 /// carries `α = exp(log_ard)` on its diagonal, so as α grows `½ log|H|`
5930 /// rises while the `−½·n·log α` already inside `loss.ard` falls — their
5931 /// balance IS the effective-dof term that the deleted `α = n/‖t‖²` rule
5932 /// dropped, which is why the criterion needs no clamp to stay finite on a
5933 /// collapsing axis.
5934 ///
5935 /// The final `−½·p·rank(S)·log λ_smooth` term is the smoothing-penalty
5936 /// normaliser `−½ log|λ S|_+` restricted to its ρ-dependent part: `S_k` is
5937 /// shared across all `p` decoder output channels (the `⊗ I_p` Kronecker
5938 /// structure), so `log|λ S|_+ = p·rank(S)·log λ + p·log|S|_+`, and the
5939 /// `½ p·log|S|_+` piece is ρ-independent. The ρ-independent additive
5940 /// constants that ARE dropped here (they shift `V` by a constant and do not
5941 /// affect the ρ-argmin) are the `2π` Laplace constant and the base
5942 /// `½ p·log|S|_+` penalty logdet. #1421: NO assignment-prior normalizer is
5943 /// dropped, because none exists (softmax/JumpReLU priors are improper — see
5944 /// the doc on this function): the quasi-Laplace score simply omits a
5945 /// normalizer that is not a finite constant.
5946 ///
5947 /// Returns `(V, loss)` so the engine can both rank ρ and surface the inner
5948 /// loss breakdown.
5949 pub fn reml_criterion(
5950 &mut self,
5951 target: ArrayView2<'_, f64>,
5952 rho: &SaeManifoldRho,
5953 registry: Option<&AnalyticPenaltyRegistry>,
5954 inner_max_iter: usize,
5955 learning_rate: f64,
5956 ridge_ext_coord: f64,
5957 ridge_beta: f64,
5958 ) -> Result<(f64, SaeManifoldLoss), String> {
5959 self.reml_criterion_with_refine_policy(
5960 target,
5961 rho,
5962 registry,
5963 inner_max_iter,
5964 learning_rate,
5965 ridge_ext_coord,
5966 ridge_beta,
5967 true,
5968 )
5969 }
5970
5971 pub(crate) fn reml_criterion_with_refine_policy(
5972 &mut self,
5973 target: ArrayView2<'_, f64>,
5974 rho: &SaeManifoldRho,
5975 registry: Option<&AnalyticPenaltyRegistry>,
5976 inner_max_iter: usize,
5977 learning_rate: f64,
5978 ridge_ext_coord: f64,
5979 ridge_beta: f64,
5980 refine_progress_extension: bool,
5981 ) -> Result<(f64, SaeManifoldLoss), String> {
5982 let plan = self.streaming_plan().admitted_or_error(
5983 self.n_obs(),
5984 self.output_dim(),
5985 self.k_atoms(),
5986 )?;
5987 if plan.streaming {
5988 // #1225: streaming and dense MUST optimize the SAME mathematical
5989 // objective — the full REML criterion `loss.total() + extra_penalty +
5990 // ½ log|H| − Occam`. The streaming branch previously returned only
5991 // `loss.total() + extra_penalty_energy`, dropping the Laplace
5992 // normalizer `½ log|H|` and the Occam term, so large shapes (exactly
5993 // where streaming is needed) were ranked by penalized loss rather than
5994 // REML — and dense vs streaming disagreed on the objective. Route
5995 // through the streaming exact-logdet path, which assembles the same
5996 // chunk-by-chunk-bit-identical `½ log|H|_stream` and the same
5997 // `−Occam`/extra-penalty terms as the dense `reml_criterion_with_cache`
5998 // (different memory strategy, same objective).
5999 self.reml_criterion_streaming_exact(
6000 target,
6001 rho,
6002 registry,
6003 inner_max_iter,
6004 learning_rate,
6005 ridge_ext_coord,
6006 ridge_beta,
6007 )
6008 } else {
6009 let (v, loss, _cache) = self.reml_criterion_with_cache_refine_policy(
6010 target,
6011 rho,
6012 registry,
6013 inner_max_iter,
6014 learning_rate,
6015 ridge_ext_coord,
6016 ridge_beta,
6017 refine_progress_extension,
6018 )?;
6019 Ok((v, loss))
6020 }
6021 }
6022
6023 /// As [`Self::reml_criterion`], but also returns the converged undamped
6024 /// `ArrowFactorCache` so callers (the EFS fixed-point step) can read the
6025 /// selected-inverse traces `(H⁻¹)_tt` / `(H⁻¹)_ββ` without re-factoring.
6026 /// The cache is the single shared O(K³) Direct factor; both the
6027 /// log-determinant criterion and the Fellner-Schall ρ-step consume it.
6028 pub fn reml_criterion_with_cache(
6029 &mut self,
6030 target: ArrayView2<'_, f64>,
6031 rho: &SaeManifoldRho,
6032 registry: Option<&AnalyticPenaltyRegistry>,
6033 inner_max_iter: usize,
6034 learning_rate: f64,
6035 ridge_ext_coord: f64,
6036 ridge_beta: f64,
6037 ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6038 self.reml_criterion_with_cache_refine_policy(
6039 target,
6040 rho,
6041 registry,
6042 inner_max_iter,
6043 learning_rate,
6044 ridge_ext_coord,
6045 ridge_beta,
6046 true,
6047 )
6048 }
6049
6050 pub(crate) fn reml_criterion_with_cache_refine_policy(
6051 &mut self,
6052 target: ArrayView2<'_, f64>,
6053 rho: &SaeManifoldRho,
6054 registry: Option<&AnalyticPenaltyRegistry>,
6055 inner_max_iter: usize,
6056 learning_rate: f64,
6057 ridge_ext_coord: f64,
6058 ridge_beta: f64,
6059 refine_progress_extension: bool,
6060 ) -> Result<(f64, SaeManifoldLoss, ArrowFactorCache), String> {
6061 let admission_plan = self.streaming_plan().admitted_or_error(
6062 self.n_obs(),
6063 self.output_dim(),
6064 self.k_atoms(),
6065 )?;
6066 if !admission_plan.direct_logdet_admitted() {
6067 return Err(format!(
6068 "SaeManifoldTerm::reml_criterion_with_cache: predicted working set {} bytes exceeds budget {} bytes for dense evidence cache; shape n={},p={},K={}; cost-only streaming route is required",
6069 admission_plan.estimated_direct_peak_bytes,
6070 admission_plan.in_core_budget_bytes,
6071 self.n_obs(),
6072 self.output_dim(),
6073 self.k_atoms()
6074 ));
6075 }
6076 // 1. Run the inner (t, β) Newton solve to convergence at FIXED ρ.
6077 // `run_joint_fit_arrow_schur` no longer touches ρ.
6078 let mut rho_fixed = rho.clone();
6079 let mut loss = self.run_joint_fit_arrow_schur(
6080 target,
6081 &mut rho_fixed,
6082 registry,
6083 inner_max_iter,
6084 learning_rate,
6085 ridge_ext_coord,
6086 ridge_beta,
6087 )?;
6088
6089 // 2. Drive the inner (t, β) solve to the KKT/step-converged optimum and
6090 // take one final UNDAMPED factor there to obtain the joint Hessian
6091 // log-determinant. We force ridge = 0 and the dense `Direct` Schur
6092 // mode so `arrow_log_det_from_cache` returns the exact
6093 // `log|H| = Σ_i log|H_tt^(i)| + log|Schur_β|` (it rejects damped
6094 // factors and InexactPCG caches, which have no dense Schur factor).
6095 // This is the same evidence convention the main GAM REML path uses.
6096 // The shared `converge_inner_for_undamped_logdet` driver guarantees
6097 // the per-row `H_tt^(i)` blocks are PD at the converged optimum so
6098 // the undamped (`ridge = 0`) factorization succeeds — the streaming
6099 // log-det path reuses the identical driver so both rank the same
6100 // converged Laplace optimum and stay bit-identical.
6101 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
6102 let cache = self.converge_inner_for_undamped_logdet(
6103 target,
6104 rho,
6105 &mut rho_fixed,
6106 registry,
6107 inner_max_iter,
6108 learning_rate,
6109 ridge_ext_coord,
6110 ridge_beta,
6111 &mut loss,
6112 &options,
6113 refine_progress_extension,
6114 )?;
6115 self.record_evidence_gauge_deflation_count(cache.gauge_deflated_directions)?;
6116 loss.evidence_gauge_deflated_directions = cache.gauge_deflated_directions;
6117 let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
6118 // Distinguish a GENUINE infeasibility — a probed ρ where the joint
6119 // Hessian is not PD so the Laplace evidence log-det is undefined —
6120 // from a real factorization defect. The cross-row IBP Woodbury
6121 // capacitance `C = I_R + D·Uᵀ H₀'⁻¹ U` can have det ≤ 0 at a ρ the
6122 // outer optimizer line-searches into (the indefinite basin adjacent
6123 // to the PD region); there the log-det legitimately does not exist.
6124 // That refusal must be RECOVERABLE (the outer BFGS should get +∞ and
6125 // steer back into the PD region), exactly like the "non-PD per-row
6126 // H_tt block" refusal — not a fatal `RemlOptimizationFailed` that
6127 // aborts the whole fit. See `is_recoverable_value_probe_refusal`.
6128 // (The old message claimed "no dense Schur factor", which is false
6129 // here — the Schur factor is present; the Woodbury correction is the
6130 // non-finite term.)
6131 if cache.cross_row_woodbury.is_some()
6132 && !cache.cross_row_woodbury_log_det().is_finite()
6133 {
6134 "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
6135 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
6136 .to_string()
6137 } else {
6138 "SaeManifoldTerm::reml_criterion: arrow_log_det_from_cache returned None \
6139 (undamped joint Hessian log-det unavailable for the Laplace normaliser)"
6140 .to_string()
6141 }
6142 })?;
6143
6144 // 3. Smoothing-penalty Occam term `−½·Σ_k r_k·rank(S_k)·log λ_smooth`
6145 // plus the profiled-frame evidence-dimension correction
6146 // `+½·Σ_k r_k·(p−r_k)·log λ_smooth` (issue #972). On the full-`B` path
6147 // (`r_k == p`, no frames) this is exactly the historical
6148 // `½·p·(Σ rank S_k)·log λ_smooth`, so the small-model criterion is
6149 // unchanged. The single seam is `reml_occam_term`, shared with the
6150 // streaming path so both rank the identical Laplace dimension count.
6151 let occam = self.reml_occam_term(rho)?;
6152
6153 // Decoder-block analytic-penalty energy (#671/#672). The inner solve
6154 // descended this energy (it enters `gb`/`hbb`) but it had no native
6155 // `loss.*` representative, so the Laplace criterion `v` was scoring a
6156 // different objective than the one minimized. Add the converged
6157 // decoder-penalty value so the ρ-sweep ranks the same penalized
6158 // deviance. Excludes the Psi-tier ARD/assignment penalties already
6159 // accounted for in `loss.total()` (see
6160 // `analytic_decoder_penalty_value_total`).
6161 // Extra analytic-penalty energy (#671/#737). Decoder-block penalties and
6162 // coordinate-tier isometry enter the inner solve but have no `loss.*`
6163 // representative, so the Laplace criterion must add them explicitly to
6164 // rank the same penalized deviance the Newton solve descends.
6165 let extra_penalty_energy = match registry {
6166 Some(reg) => self
6167 .reml_extra_penalty_value_total(reg)
6168 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?,
6169 None => 0.0,
6170 };
6171
6172 let v = loss.total() + extra_penalty_energy + 0.5 * log_det - occam;
6173 Ok((v, loss, cache))
6174 }
6175
6176 /// The #1037 quotient-dimension invariant: a Laplace normalizer `½log|H|` is
6177 /// only comparable across ρ at a COMMON quotient (gauge-deflation) dimension.
6178 /// The first observation pins the expected count; a later match is a no-op.
6179 ///
6180 /// A later observation that DIFFERS is, under the K>1 fit, a LEGITIMATE
6181 /// quotient-dimension event — an atom born, reseeded (the #976 collapse
6182 /// guards), or rank-reduced moves the number of gauge-flat rows. Because a
6183 /// deflated direction is lifted to unit stiffness and contributes the
6184 /// ρ-independent `log 1 = 0` to the evidence, re-anchoring the comparison to
6185 /// the new dimension is exactly evidence-preserving and keeps every future
6186 /// cross-ρ comparison consistent — the principled response, not an abort.
6187 ///
6188 /// The genuine pathology the guard still catches is a count that NEVER
6189 /// STABILIZES: re-anchors are bounded by the per-atom structural-event budget
6190 /// (`k·(reseed_budget+1)+1`), and a runaway quotient dimension past that
6191 /// bound refuses loudly. This supersedes the prior strict-constant guard and
6192 /// its ±1 flicker band (#1117) at root — the band was masking exactly the
6193 /// legitimate K>1 dimension changes this re-anchoring now handles.
6194 pub(crate) fn record_evidence_gauge_deflation_count(
6195 &mut self,
6196 count: usize,
6197 ) -> Result<(), String> {
6198 match self.expected_evidence_gauge_deflated_directions {
6199 Some(expected) if expected == count => Ok(()),
6200 Some(expected) => {
6201 // A change in the gauge-deflation count between two evidence
6202 // factorizations is a legitimate quotient-dimension event under
6203 // the K>1 fit: an atom can be born, reseeded (the #976 collapse
6204 // guards), or rank-reduced across the ρ-walk, and each such event
6205 // moves the number of gauge-flat rows. The #1037 invariant is
6206 // NOT "the count never changes" — it is "two Laplace normalizers
6207 // are only comparable at a COMMON quotient dimension". The
6208 // principled response to a legitimate change is therefore to
6209 // RE-ANCHOR the comparison to the new dimension (so every future
6210 // cross-ρ comparison within the optimization is consistent), not
6211 // to abort the fit. This is exactly evidence-preserving: each
6212 // gauge-deflated direction is lifted to unit stiffness and
6213 // contributes the ρ-independent `log 1 = 0` to `½log|H|`, so the
6214 // converged criterion value is identical whether a given row is
6215 // counted as deflated or not — only the BOOKKEEPING dimension
6216 // must agree across a comparison, and re-anchoring restores that.
6217 //
6218 // The genuine pathology the guard must still catch is a count
6219 // that NEVER STABILIZES — an OSCILLATING quotient dimension that
6220 // re-anchors without converging, signalling a truly ill-posed
6221 // evidence surface. But the deflation count is NOT a discrete
6222 // dictionary-level event count: it is the per-ROW-summed number of
6223 // near-null evidence directions across all N rows (#1217). On real
6224 // K≥2 activations it is an O(N) quantity that drifts SMOOTHLY and
6225 // monotonically as the conditioning improves over the ρ-walk
6226 // (e.g. 171→156→…→113 as smoothing increases) — a benign,
6227 // evidence-neutral change (each deflated direction contributes the
6228 // ρ-independent `log 1 = 0` to `½log|H|`, so re-anchoring never
6229 // moves the criterion value). Charging such a monotone drift
6230 // against a `k`-sized "structural event" budget was wrong: it
6231 // counts threshold crossings of a continuous per-row quantity, not
6232 // atom births/reseeds, so the budget tripped on a perfectly healthy
6233 // converging K=2 fit (#1217 regression from the #1189/#1190
6234 // basin-escape fixes, which shifted which rows sit near the
6235 // deflation floor).
6236 //
6237 // The principled discriminator is DIRECTION REVERSALS: a count
6238 // that drifts one way and settles is benign; a count that bounces
6239 // up and down without settling is the oscillating-quotient
6240 // pathology. We therefore charge the re-anchor budget ONLY on a
6241 // reversal of the change direction, and size the budget by the
6242 // number of distinct dictionary structural events (births/reseeds)
6243 // that can each legitimately flip the drift direction. A monotone
6244 // drift of any length re-anchors freely (it is consistently
6245 // re-anchored and evidence-neutral); a genuinely oscillating count
6246 // exhausts the reversal budget and refuses loudly.
6247 let delta_sign: i8 = if count > expected { 1 } else { -1 };
6248 let is_reversal = self.evidence_gauge_deflation_last_delta_sign != 0
6249 && delta_sign != self.evidence_gauge_deflation_last_delta_sign;
6250 self.evidence_gauge_deflation_last_delta_sign = delta_sign;
6251 if is_reversal {
6252 self.evidence_gauge_deflation_reanchors += 1;
6253 }
6254 let reversal_budget = self
6255 .k_atoms()
6256 .saturating_mul(
6257 SAE_ATOM_COLLAPSE_RESEED_BUDGET
6258 + SAE_DICTIONARY_COCOLLAPSE_RESEED_BUDGET
6259 + 1,
6260 )
6261 .saturating_add(1);
6262 if self.evidence_gauge_deflation_reanchors > reversal_budget {
6263 return Err(format!(
6264 "SaeManifoldTerm::reml_criterion: row-gauge evidence deflation count \
6265 oscillated (reversed direction {} times, last {expected}->{count}) within \
6266 one optimization, exceeding the {reversal_budget}-reversal budget for {} \
6267 atoms; the quotient dimension is not stabilizing, refusing to compare \
6268 Laplace normalizers",
6269 self.evidence_gauge_deflation_reanchors,
6270 self.k_atoms()
6271 ));
6272 }
6273 log::debug!(
6274 "SaeManifoldTerm::reml_criterion: per-row evidence deflation count changed \
6275 {expected}->{count} (a benign per-row conditioning drift across the ρ-walk; \
6276 reversal {}/{reversal_budget}); re-anchoring the Laplace normalizer comparison \
6277 to the new dimension",
6278 self.evidence_gauge_deflation_reanchors
6279 );
6280 self.expected_evidence_gauge_deflated_directions = Some(count);
6281 Ok(())
6282 }
6283 None => {
6284 self.expected_evidence_gauge_deflated_directions = Some(count);
6285 Ok(())
6286 }
6287 }
6288 }
6289
6290 pub(crate) fn is_undamped_evidence_row_non_pd(err: &ArrowSchurError) -> bool {
6291 matches!(
6292 err,
6293 ArrowSchurError::PerRowFactorFailed { reason, .. }
6294 if reason.contains("H_tt is non-PD at base ridge")
6295 && reason.contains("evidence mode preserves the genuine Cholesky")
6296 )
6297 }
6298
6299 /// Drive the inner `(t, β)` Newton solve to the KKT/step-converged optimum
6300 /// and return the final UNDAMPED (`ridge = 0`) joint-Hessian factor cache.
6301 ///
6302 /// The Laplace normaliser `½log|H|` is only the correct REML criterion at
6303 /// the inner optimum `(t̂, β̂)`, so the criterion must refine the inner state
6304 /// until either the KKT gradient or the undamped Newton step meets tolerance
6305 /// before factoring. Crucially, **at the converged optimum the per-row
6306 /// `H_tt^(i)` blocks are PD**, so the undamped (`ridge = 0`) factorization
6307 /// succeeds; an off-optimum iterate (e.g. the initial seed, or a state
6308 /// stopped after only `inner_max_iter` steps) can have an indefinite /
6309 /// rank-deficient per-row block (`p_out = 1` → rank-1 `JᵀJ`, softmax
6310 /// assignment-sparsity negative logit curvature) that surfaces
6311 /// `PerRowFactorFailed` from the undamped `factor_one_row`. Both the dense
6312 /// (`reml_criterion_with_cache`) and the streaming
6313 /// (`reml_criterion_streaming_exact`) evidence paths route through this same
6314 /// driver, so they converge to the identical inner state and their
6315 /// `ridge = 0` log-determinants stay bit-identical (#847).
6316 pub(crate) fn converge_inner_for_undamped_logdet(
6317 &mut self,
6318 target: ArrayView2<'_, f64>,
6319 rho: &SaeManifoldRho,
6320 rho_fixed: &mut SaeManifoldRho,
6321 registry: Option<&AnalyticPenaltyRegistry>,
6322 inner_max_iter: usize,
6323 learning_rate: f64,
6324 ridge_ext_coord: f64,
6325 ridge_beta: f64,
6326 loss: &mut SaeManifoldLoss,
6327 options: &ArrowSolveOptions,
6328 refine_progress_extension: bool,
6329 ) -> Result<ArrowFactorCache, String> {
6330 // `inner_max_iter == 0` is a genuine FREEZE of the inner `(t, β)` state
6331 // — a verbatim warm-start reuse, not a convergence request (gam#577/#579,
6332 // #850). The convergence/refinement loop below MUST NOT run even one
6333 // Newton step in that case (the old `inner_max_iter.max(1)` floor moved
6334 // β off the seed), so we factor exactly once at the frozen iterate and
6335 // return that undamped cache without invoking the stationarity gate.
6336 // The caller has already run `run_joint_fit_arrow_schur(..., 0, ...)`,
6337 // which under the `max_iter == 0` freeze (gam#577/#579, #850) runs ONLY
6338 // the β-neutral basis refresh and returns the loss without touching β —
6339 // it skips the rank-reduction, frame activation, re-seed guards, and the
6340 // #1026 decoder-LSQ polish that would otherwise refit β off the seed — so
6341 // `self` is at the warm-start β here.
6342 if inner_max_iter == 0 {
6343 let sys = self
6344 .assemble_arrow_schur(target, rho, registry)
6345 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6346 let factored = solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options)
6347 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6348 // The frozen-state Newton step (factored.0, factored.1) is discarded
6349 // — only the undamped factor cache (factored.2) is consumed for the
6350 // log-det / selected-inverse traces; β stays at the warm-start seed.
6351 return Ok(factored.2);
6352 }
6353 let mut total_inner_iter = inner_max_iter;
6354 let accepted_base_refine_iter = inner_max_iter.max(1).saturating_mul(16).max(64);
6355 let value_probe_base_refine_iter = inner_max_iter.max(1).saturating_mul(4).max(16);
6356 let base_refine_iter = if refine_progress_extension {
6357 accepted_base_refine_iter
6358 } else {
6359 value_probe_base_refine_iter
6360 };
6361 let progress_refine_iter = if refine_progress_extension {
6362 inner_max_iter.max(1).saturating_mul(64).max(256)
6363 } else {
6364 base_refine_iter
6365 };
6366 let mut previous_refine_grad_norm: Option<f64> = None;
6367 let mut saw_refine_progress = false;
6368 // #1051 — objective-stagnation convergence. On an ill-conditioned
6369 // penalised bilinear fit (the euclidean / Duchon decoder × latent
6370 // coordinate system on a trivial shape), the inner Newton crawls: each
6371 // refine round lowers the penalised objective by a shrinking amount while
6372 // the KKT gradient and the undamped step stay above their relative
6373 // tolerances (the near-singular Schur amplifies the step in the
6374 // weakly-identified decoder direction). The grad-OR-step gate then never
6375 // fires and the solve is rejected as "did not converge" — the 1e12
6376 // sentinel. A Newton/LM iterate whose objective has stopped decreasing
6377 // to within `√εmach` of its scale IS the numerical inner optimum; ranking
6378 // the Laplace criterion there is correct. We accept that fixed point
6379 // instead of grinding the budget.
6380 let entry_loss_total = loss.total();
6381 let mut previous_loss_total = entry_loss_total;
6382 let mut refine_rounds: usize = 0;
6383 // Consecutive stall rounds: counts how many successive refine rounds
6384 // ended in a stall AND a failed undamped factor. Once this reaches
6385 // `SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS` the iterate is at
6386 // its numerical fixed point and cannot be improved further; returning
6387 // `Err` here is the same "did not converge" signal that
6388 // `is_recoverable_value_probe_refusal` already handles, so the outer
6389 // BFGS treats it as an INFINITY probe and tries a different ρ instead
6390 // of looping forever burning the extended progress budget. Without
6391 // this counter the stagnation handler fell through when the undamped
6392 // factor failed and the loop kept extending via `saw_refine_progress`
6393 // from earlier rounds, accumulating minutes of wasted work (#1094).
6394 let mut consecutive_stall_factor_fail: usize = 0;
6395 loop {
6396 let sys = self
6397 .assemble_arrow_schur(target, rho, registry)
6398 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6399 // Evidence-only factorization: the Newton step (Δt, Δβ) is discarded
6400 // and only the factor cache is consumed — the exact undamped log-det
6401 // and the selected-inverse traces. As ρ sweeps to extremes (e.g. a
6402 // wide ARD-α sweep), H_tt is genuinely PD but can be ill-conditioned;
6403 // the standard Direct guard rejects that to protect Newton-step
6404 // accuracy, but the log-det is exact from diag(L) regardless of the
6405 // condition number and the traces only need the (PD) factor. So
6406 // tolerate the ill-conditioning rejection here (a genuine non-PD pivot
6407 // still errors). The cache stays undamped at ridge=0, so
6408 // `arrow_log_det_from_cache` remains exact.
6409 // The exact KKT stationarity residual is the joint gradient
6410 // ‖g‖ = √(Σ_i ‖g_t^(i)‖² + ‖g_β‖²), read straight off the assembled
6411 // system. Unlike the Newton step Δ = H⁻¹g, the gradient is
6412 // factorisation-independent: it is NOT amplified by an inverse, so a
6413 // genuinely stationary but ill-conditioned fit (tiny g, possibly large
6414 // Δ in a flat direction) is correctly recognised as converged. The
6415 // `with_ill_conditioning_tolerated` Direct factor below documents that
6416 // its Δ may be inaccurate in exactly those flat directions, so using Δ
6417 // alone as the convergence gate would falsely reject healthy fits.
6418 let grad_norm_sq: f64 = sys
6419 .rows
6420 .iter()
6421 .map(|row| row.gt.iter().map(|&v| v * v).sum::<f64>())
6422 .sum::<f64>()
6423 + sys.gb.iter().map(|&v| v * v).sum::<f64>();
6424 let grad_norm = grad_norm_sq.sqrt();
6425 // Quotient KKT-gradient (#1117): the raw joint gradient retains a
6426 // persistent small component in the chart-gauge orbit and the
6427 // rank-deficient decoder β-null even at a stationary fit, so the raw
6428 // grad gate never clears on a rank-deficient circle and the inner
6429 // refine loop crawls until the (large) progress budget dies — the
6430 // 2-min stall. Measure the gradient on the SAME identified quotient
6431 // the step gate already uses: a fit whose only remaining gradient
6432 // lives in those flat directions is stationary on the quotient, so
6433 // ranking the Laplace criterion there is correct. The dense per-row
6434 // g_t is laid into the `n·q` coordinate layout the gauge basis spans;
6435 // non-dense/heterogeneous systems fall back to the raw norm.
6436 let quotient_grad_norm = {
6437 let n = self.n_obs();
6438 let q = self.assignment.row_block_dim();
6439 let dense_len = n.saturating_mul(q);
6440 let mut grad_ext_coord = Array1::<f64>::zeros(dense_len);
6441 let mut dense_layout_ok = sys.rows.len() == n;
6442 if dense_layout_ok {
6443 for (row_idx, row) in sys.rows.iter().enumerate() {
6444 let base = sys.row_offsets[row_idx];
6445 let di = sys.row_dims[row_idx];
6446 if base + di > dense_len || row.gt.len() < di {
6447 dense_layout_ok = false;
6448 break;
6449 }
6450 for axis in 0..di {
6451 grad_ext_coord[base + axis] = row.gt[axis];
6452 }
6453 }
6454 }
6455 if dense_layout_ok {
6456 self.quotient_gradient_norm_sq(
6457 grad_ext_coord.view(),
6458 sys.gb.view(),
6459 grad_norm_sq,
6460 &rho_fixed.lambda_smooth_vec(),
6461 )
6462 .map(|v| v.sqrt())
6463 .unwrap_or(grad_norm)
6464 } else {
6465 grad_norm
6466 }
6467 };
6468 let iterate_scale = self.inner_iterate_scale();
6469 // Relative parameter-step tolerance for Δ (well-conditioned charts)
6470 // and a scaled KKT-gradient tolerance. Convergence is accepted on
6471 // EITHER a small KKT gradient OR a small undamped Newton step: SAE
6472 // manifold fits contain gauge-like coordinate/decoder directions (the
6473 // circle's rotation gauge, decoder column-space rotations) where the
6474 // shared-block Hessian is near-singular, so the undamped step can stay
6475 // large in that flat direction even at a genuine stationary point; the
6476 // gradient, which is not amplified by the inverse, recognises it. With
6477 // the isometry Gauss-Newton block now a coherent PSD pullback (no
6478 // indefinite Schur pivot), the inner solve reaches true stationarity,
6479 // so the gradient tolerance is a standard relative KKT residual rather
6480 // than the 0.1.154-regression band-aid (3e-3) that masked the
6481 // non-convergence the indefinite curvature caused.
6482 let step_tolerance = SAE_MANIFOLD_INNER_STEP_REL_TOL * iterate_scale;
6483 let grad_tolerance = SAE_MANIFOLD_INNER_GRAD_REL_TOL * iterate_scale;
6484 if !grad_norm_sq.is_finite() {
6485 return Err(format!(
6486 "SaeManifoldTerm::reml_criterion: undamped inner KKT residual is non-finite \
6487 at the inner optimum (‖g‖²={grad_norm_sq}); the joint Hessian \
6488 factorisation is degenerate at this ρ"
6489 ));
6490 }
6491 let (delta_t, delta_beta, cache): (Array1<f64>, Array1<f64>, ArrowFactorCache) =
6492 match solve_arrow_newton_step_with_options(&sys, 0.0, 0.0, options) {
6493 Ok(factored) => factored,
6494 Err(err) if Self::is_undamped_evidence_row_non_pd(&err) => {
6495 if grad_norm <= grad_tolerance || quotient_grad_norm <= grad_tolerance {
6496 // K>1: the softmax/IBP logit–coordinate Gauss-Newton
6497 // cross-terms (H_zt = J_z^T J_t, assembled row-locally from
6498 // the assignment JVP × basis JVP) can make a per-row H_tt
6499 // indefinite at the TRUE KKT stationary point — when two
6500 // atoms' decoders specialise in opposite directions the
6501 // Schur complement of the logit block goes negative even
6502 // though the priors and the full-joint GN term are PSD.
6503 //
6504 // The undamped evidence factor already conditions that
6505 // block the PRINCIPLED way: `factor_spectral_deflated_
6506 // evidence_row` discovers the negative/flat eigen-direction
6507 // and stiffens it to UNIT curvature (eigenvalue → +1), so it
6508 // contributes a ρ-INDEPENDENT log 1 = 0 to the evidence —
6509 // the same quotient pseudo-determinant convention the gauge
6510 // (#1037) and data-null (#1117) deflations use. Reaching
6511 // THIS arm at stationarity therefore means even the spectral
6512 // deflation declined (a non-finite block or a failed
6513 // eigendecomposition): the state is genuinely broken, so we
6514 // surface the hard refusal and let the outer BFGS treat this
6515 // ρ as an INFINITY probe (`is_recoverable_value_probe_
6516 // refusal`). We must NOT ridge-damp here: a `+ridge·I`
6517 // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6518 // bias into the VALUE that the analytic ρ-gradient (built
6519 // for the undamped Laplace log-det) never sees, desyncing
6520 // the outer line-search — the multi-atom non-convergence
6521 // this fix (#1117) removes.
6522 return Err(format!(
6523 "SaeManifoldTerm::reml_criterion: stationary undamped \
6524 evidence factorization has a non-PD per-row H_tt block \
6525 that spectral unit-stiffness deflation could not \
6526 condition (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}); \
6527 {err}"
6528 ));
6529 }
6530 let refine_limit = Self::refine_iteration_limit(
6531 total_inner_iter,
6532 base_refine_iter,
6533 progress_refine_iter,
6534 previous_refine_grad_norm,
6535 grad_norm,
6536 saw_refine_progress,
6537 );
6538 if total_inner_iter >= refine_limit {
6539 // #1117/#1118 — pre-stationarity genuinely-indefinite
6540 // non-gauge H_tt under K>1 IBP/softmax row-sharing. The
6541 // logit × coordinate Gauss-Newton cross term H_zt = J_zᵀJ_t
6542 // can drive a shared row's H_tt Schur complement NEGATIVE off
6543 // the gauge orbit; the LM-escalated refinement above cannot
6544 // always cross the indefinite basin into the PD region within
6545 // the descent-extended budget.
6546 //
6547 // The undamped (ridge=0) evidence factor already conditions
6548 // that block the PRINCIPLED way: `factor_spectral_deflated_
6549 // evidence_row` discovers the negative/flat eigen-direction
6550 // and stiffens it to UNIT curvature (eigenvalue → +1), a
6551 // ρ-INDEPENDENT log 1 = 0 evidence contribution — so the
6552 // `Ok(factored)` arm above accepts the indefinite block and
6553 // returns a finite, monotone-comparable value to the outer
6554 // BFGS WITHOUT a ρ-dependent bias. Reaching THIS arm means
6555 // even that spectral deflation declined (a non-finite block
6556 // or a failed eigendecomposition): the iterate is genuinely
6557 // broken, so we surface the hard refusal and let the outer
6558 // BFGS treat this ρ as an INFINITY probe.
6559 //
6560 // We must NOT ridge-damp here: a `+ridge·I` evidence
6561 // fallback injects a ρ-dependent ½·log|I + ridge·H_tt⁻¹|
6562 // bias into the VALUE that the analytic ρ-gradient (built
6563 // for the undamped Laplace log-det) never sees, desyncing
6564 // the outer line-search — the multi-atom non-convergence this
6565 // fix removes. K=1 (and any already-PD or spectral-deflatable
6566 // K>1 row) never reaches this branch.
6567 return Err(format!(
6568 "SaeManifoldTerm::reml_criterion: undamped evidence \
6569 factorization hit a non-PD per-row H_tt block before KKT \
6570 stationarity (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) \
6571 and the refinement budget was exhausted after \
6572 {total_inner_iter} inner iterations; {err}"
6573 ));
6574 }
6575 let remaining = refine_limit - total_inner_iter;
6576 let refine_iter = inner_max_iter.max(1).min(remaining);
6577 saw_refine_progress |=
6578 Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6579 previous_refine_grad_norm = Some(grad_norm);
6580 *loss = self.run_joint_fit_arrow_schur(
6581 target,
6582 rho_fixed,
6583 registry,
6584 refine_iter,
6585 learning_rate,
6586 ridge_ext_coord,
6587 ridge_beta,
6588 )?;
6589 total_inner_iter += refine_iter;
6590 continue;
6591 }
6592 Err(err) => {
6593 return Err(format!("SaeManifoldTerm::reml_criterion: {err}"));
6594 }
6595 };
6596 // The Laplace normaliser ½log|H| is only the correct REML criterion at
6597 // the inner optimum (t̂, β̂). Convergence is judged by EITHER a small
6598 // gradient (KKT stationarity) OR a small undamped Newton step; the
6599 // solve is only rejected as non-converged when BOTH are large, i.e.
6600 // the iterate is neither stationary nor about to move negligibly. That
6601 // disjunction is what keeps an ill-conditioned-but-stationary fit
6602 // (small g, large Δ) from being rejected while still refusing to rank
6603 // an off-optimum Laplace criterion that is genuinely mid-flight.
6604 let step_norm_sq: f64 = delta_t.iter().map(|&v| v * v).sum::<f64>()
6605 + delta_beta.iter().map(|&v| v * v).sum::<f64>();
6606 if !step_norm_sq.is_finite() {
6607 return Err(format!(
6608 "SaeManifoldTerm::reml_criterion: undamped inner residual is non-finite at \
6609 the inner optimum (‖Δ‖²={step_norm_sq}, ‖g‖²={grad_norm_sq}); the joint \
6610 Hessian factorisation is degenerate at this ρ"
6611 ));
6612 }
6613 let step_norm = step_norm_sq.sqrt();
6614 let quotient_step_norm_sq = self.quotient_newton_step_norm_sq(
6615 delta_t.view(),
6616 delta_beta.view(),
6617 step_norm_sq,
6618 &rho_fixed.lambda_smooth_vec(),
6619 )?;
6620 let quotient_step_norm = quotient_step_norm_sq.sqrt();
6621 // Converge on ANY of: the raw KKT gradient (well-conditioned fit),
6622 // the QUOTIENT KKT gradient (#1117 — rank-deficient fit whose only
6623 // residual gradient is gauge/null flat-direction crawl), or the
6624 // quotient Newton step. The quotient-gradient disjunct is what lets
6625 // a rank-deficient K=1 circle terminate in budget instead of crawling
6626 // the weakly-identified valley until the refine budget dies.
6627 if grad_norm <= grad_tolerance
6628 || quotient_grad_norm <= grad_tolerance
6629 || quotient_step_norm <= step_tolerance
6630 {
6631 return Ok(cache);
6632 }
6633 let refine_limit = Self::refine_iteration_limit(
6634 total_inner_iter,
6635 base_refine_iter,
6636 progress_refine_iter,
6637 previous_refine_grad_norm,
6638 grad_norm,
6639 saw_refine_progress,
6640 );
6641 if total_inner_iter >= refine_limit {
6642 // Inner solve did not converge in reml_criterion; the returned
6643 // Err below carries the full non-convergence diagnostic
6644 // (gradient / quotient-step norms and tolerances) to the caller.
6645 return Err(format!(
6646 "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6647 neither the KKT gradient ‖g‖={grad_norm:.6e} (tol {grad_tolerance:.6e}) nor \
6648 the quotient Newton step ‖Π⊥gauge Δ‖={quotient_step_norm:.6e} \
6649 (raw ‖Δ‖={step_norm:.6e}, tol {step_tolerance:.6e}) met \
6650 tolerance after {total_inner_iter} inner iterations. Refusing to rank an \
6651 off-optimum Laplace criterion."
6652 ));
6653 }
6654 let remaining = refine_limit - total_inner_iter;
6655 let refine_iter = inner_max_iter.max(1).min(remaining);
6656 saw_refine_progress |=
6657 Self::refine_round_made_progress(previous_refine_grad_norm, grad_norm);
6658 previous_refine_grad_norm = Some(grad_norm);
6659 *loss = self.run_joint_fit_arrow_schur(
6660 target,
6661 rho_fixed,
6662 registry,
6663 refine_iter,
6664 learning_rate,
6665 ridge_ext_coord,
6666 ridge_beta,
6667 )?;
6668 total_inner_iter += refine_iter;
6669 refine_rounds += 1;
6670 // #1051 — objective-stagnation fixed point. A whole refine round that
6671 // failed to lower the penalised objective by a meaningful FRACTION of
6672 // the total since-entry reduction means the Newton/LM iterate is at
6673 // its numerical optimum: the remaining KKT residual lives in the
6674 // weakly-identified decoder / gauge directions the near-singular Schur
6675 // cannot resolve. Ranking the Laplace criterion at this fixed point is
6676 // correct (the only further motion is cosmetic flat-valley crawl), so
6677 // accept the current cache instead of refining until the budget dies.
6678 // Requires a few completed refine rounds (so the fraction baseline is
6679 // meaningful) but is NOT gated behind the full refine budget — the
6680 // whole point is to terminate the crawl long before that.
6681 let new_loss_total = loss.total();
6682 // Two stagnation signals, both required: (1) the latest refine round
6683 // contributed a negligible FRACTION of the total objective reduction
6684 // achieved since entry — the fit has captured essentially all the
6685 // achievable improvement and is now crawling cosmetically along the
6686 // weakly-identified valley; (2) the absolute relative decrease is
6687 // itself tiny. The fraction test is scale- and rate-free (it fires
6688 // whether the crawl decays fast or slow), so it recognises the
6689 // over-smoothed / rank-deficient fixed point the bare relative floor
6690 // misses, while still never firing on a fit that is materially
6691 // improving round over round.
6692 let total_improvement = (entry_loss_total - new_loss_total).max(0.0);
6693 let round_improvement = (previous_loss_total - new_loss_total).max(0.0);
6694 let objective_scale = previous_loss_total.abs().max(new_loss_total.abs()) + 1.0;
6695 let relative_decrease = round_improvement / objective_scale;
6696 let captured_fraction = if total_improvement > 0.0 {
6697 round_improvement / total_improvement
6698 } else {
6699 0.0
6700 };
6701 let stalled = new_loss_total.is_finite()
6702 && relative_decrease.is_finite()
6703 && (relative_decrease < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_REL_TOL
6704 || captured_fraction < SAE_MANIFOLD_INNER_OBJECTIVE_STALL_FRACTION);
6705 previous_loss_total = new_loss_total;
6706 if stalled && refine_rounds >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6707 let stationary_sys = self
6708 .assemble_arrow_schur(target, rho_fixed, registry)
6709 .map_err(|err| format!("SaeManifoldTerm::reml_criterion: {err}"))?;
6710 if let Ok((_dt, _db, stationary_cache)) =
6711 solve_arrow_newton_step_with_options(&stationary_sys, 0.0, 0.0, options)
6712 {
6713 return Ok(stationary_cache);
6714 }
6715 // Stagnated AND the undamped factor still fails: this is the
6716 // numerical fixed point of the inner solve under rank-deficient
6717 // or ill-conditioned geometry (e.g. multi-atom euclidean with
6718 // near-zero initial latent coords, #1094). The iterate cannot
6719 // be improved further at this ρ. Treat it as "inner solve did
6720 // not converge" — the same signal `is_recoverable_value_probe_refusal`
6721 // already handles, causing the outer BFGS to return INFINITY for
6722 // this ρ probe and try a different one. Without this early
6723 // return the stagnation handler fell through and the loop kept
6724 // burning the extended `progress_refine_iter` budget indefinitely.
6725 consecutive_stall_factor_fail += 1;
6726 if consecutive_stall_factor_fail >= SAE_MANIFOLD_INNER_OBJECTIVE_STALL_MIN_ROUNDS {
6727 return Err(format!(
6728 "SaeManifoldTerm::reml_criterion: inner solve did not converge at fixed ρ; \
6729 objective stalled for {consecutive_stall_factor_fail} consecutive refine \
6730 rounds (‖g‖={grad_norm:.6e}, tol {grad_tolerance:.6e}) and the undamped \
6731 evidence factorization failed at each stall point — the iterate is at the \
6732 numerical fixed point under rank-deficient geometry (#{consecutive_stall_factor_fail} \
6733 stall-factor-fail rounds; refusing to rank an off-optimum Laplace criterion)"
6734 ));
6735 }
6736 } else {
6737 consecutive_stall_factor_fail = 0;
6738 }
6739 }
6740 }
6741
6742 pub(crate) fn refine_iteration_limit(
6743 total_inner_iter: usize,
6744 base_refine_iter: usize,
6745 progress_refine_iter: usize,
6746 previous_grad_norm: Option<f64>,
6747 grad_norm: f64,
6748 saw_refine_progress: bool,
6749 ) -> usize {
6750 // Flat affine-gauge valleys can keep crawling productively after the
6751 // historical base budget. Extend only when the measured KKT residual has
6752 // shown a real finite round-to-round drop; true stalls end at the base
6753 // work budget (#968/#1029). Value-order probes pass the base budget as
6754 // their progress budget, so this branch cannot make probes expensive.
6755 if total_inner_iter < base_refine_iter {
6756 return base_refine_iter;
6757 }
6758 let making_progress =
6759 saw_refine_progress || Self::refine_round_made_progress(previous_grad_norm, grad_norm);
6760 if making_progress && grad_norm.is_finite() {
6761 progress_refine_iter
6762 } else {
6763 base_refine_iter
6764 }
6765 }
6766
6767 pub(crate) fn refine_round_made_progress(
6768 previous_grad_norm: Option<f64>,
6769 grad_norm: f64,
6770 ) -> bool {
6771 previous_grad_norm
6772 .is_some_and(|prev| prev.is_finite() && grad_norm.is_finite() && grad_norm < prev)
6773 }
6774
6775 pub(crate) fn outer_gradient_arrow_solver<'a>(
6776 &'a self,
6777 cache: &'a ArrowFactorCache,
6778 penalized_gram_scale: &[f64],
6779 ) -> Result<DeflatedArrowSolver<'a>, OuterGradientError> {
6780 let Err(conditioning_err) = Self::outer_gradient_conditioning_error(cache) else {
6781 return Ok(DeflatedArrowSolver::plain(cache));
6782 };
6783 let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
6784 return Err(conditioning_err);
6785 };
6786 if !(max_pivot.is_finite() && max_pivot > 0.0) {
6787 return Err(conditioning_err);
6788 }
6789
6790 // The conditioning gate has already flagged a near-singular joint Hessian
6791 // (`conditioning_err`). Below we attempt to attribute that flatness to the
6792 // closed-form gauge orbit (chart step gauges) plus the penalty-aware
6793 // decoder-null directions and deflate it. When NO such deflatable
6794 // direction can be recovered, the flat subspace is genuinely
6795 // non-identifiable -- a degenerate direction OUTSIDE the gauge orbit -- a
6796 // diagnosis distinct from the raw pivot-ratio conditioning trip. Both
6797 // classes are #1273 FD-eligible, but surfacing the gauge-degenerate case
6798 // as its own [`OuterGradientError::NonIdentifiable`] keeps the diagnostic
6799 // distinction the FD-eligibility contract is built around.
6800 let non_identifiable_err = OuterGradientError::NonIdentifiable {
6801 reason: format!(
6802 "near-singular joint Hessian with no deflatable gauge/decoder-null \
6803 direction (max pivot {max_pivot:.3e})"
6804 ),
6805 };
6806
6807 let full_len = cache.delta_t_len() + cache.k;
6808 let mut raw_gauges = Vec::new();
6809 for gauge in self
6810 .dense_step_gauge_vectors()
6811 .map_err(OuterGradientError::internal)?
6812 {
6813 if gauge.len() != full_len {
6814 continue;
6815 }
6816 let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6817 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6818 continue;
6819 }
6820 raw_gauges.push(gauge);
6821 }
6822 // #1051/#1273: admit the penalty-aware decoder-β null directions as
6823 // additional deflation candidates. A rank-deficient decoder design
6824 // (e.g. a euclidean-1D line in a p=2 ambient: decoder column rank 1 of
6825 // 3) puts a genuine near-null direction of the joint Hessian in the β
6826 // block, OUTSIDE the closed-form chart gauge orbit. #1273: probing the
6827 // RAW unit-β basis `e_j` produced an INCOMPLETE candidate set — the
6828 // true flat direction is the penalised null of `G_k + λ_smooth·S_k`,
6829 // not an axis-aligned coordinate, so the outer gate rejected trial ρ
6830 // with a pivot ratio (5.3e-16 < 1e-12) that the inner gate (which
6831 // already uses `decoder_beta_null_directions(λ_smooth)`) accepts. Use
6832 // the SAME penalty-aware null directions here, evaluated at the smooth
6833 // scale the Schur factor used, so the outer and inner gates agree.
6834 // These full (n·q + beta_dim)-length vectors drop into the same
6835 // Gram-Schmidt + Rayleigh + Faddeev-Popov path below; the Rayleigh
6836 // floor still keeps only genuinely flat (sub-floor) directions, so a
6837 // well-conditioned decoder is unaffected.
6838 for dir in self
6839 .decoder_beta_null_directions(penalized_gram_scale)
6840 .map_err(OuterGradientError::internal)?
6841 {
6842 if dir.len() == full_len {
6843 raw_gauges.push(dir);
6844 }
6845 }
6846 // #1051/#1273: also admit the decoder COLUMN-SPAN null (an unrealised
6847 // ambient output channel of a rank-deficient decoder), which the
6848 // channel-free basis-null above structurally cannot represent. The
6849 // rank-1-decoder-line geometry (e.g. a 1-D euclidean line in p=2
6850 // ambient: decoder column rank 1 of 2) puts the joint Hessian's
6851 // sub-floor pivot entirely in one output channel; without this
6852 // candidate the outer gate had nothing to deflate it with and rejected
6853 // the trial ρ. The Rayleigh floor below still prunes any candidate that
6854 // is not genuinely flat against the cached Hessian.
6855 for dir in self
6856 .decoder_channel_null_directions()
6857 .map_err(OuterGradientError::internal)?
6858 {
6859 if dir.len() == full_len {
6860 raw_gauges.push(dir);
6861 }
6862 }
6863 if raw_gauges.is_empty() {
6864 return Err(non_identifiable_err);
6865 }
6866
6867 let mut gauge_span: Vec<Array1<f64>> = Vec::new();
6868 for mut gauge in raw_gauges {
6869 for basis in &gauge_span {
6870 let coeff = gauge.dot(basis);
6871 for i in 0..gauge.len() {
6872 gauge[i] -= coeff * basis[i];
6873 }
6874 }
6875 let norm_sq = gauge.iter().map(|v| v * v).sum::<f64>();
6876 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6877 continue;
6878 }
6879 let inv_norm = norm_sq.sqrt().recip();
6880 for value in gauge.iter_mut() {
6881 *value *= inv_norm;
6882 }
6883 gauge_span.push(gauge);
6884 }
6885 if gauge_span.is_empty() {
6886 return Err(non_identifiable_err);
6887 }
6888
6889 let span_rank = gauge_span.len();
6890 let mut h_span = Array2::<f64>::zeros((span_rank, span_rank));
6891 for col in 0..span_rank {
6892 let h_gauge = match apply_cached_arrow_hessian(
6893 cache,
6894 gauge_span[col].slice(s![..cache.delta_t_len()]),
6895 gauge_span[col].slice(s![cache.delta_t_len()..]),
6896 ) {
6897 Ok(value) => value,
6898 // #1451: a shape/dimension mismatch or non-finite intermediate
6899 // from the Hessian apply is an internal-invariant defect and MUST
6900 // propagate; only a genuine numeric failure on a finite,
6901 // correctly-shaped input keeps the FD-eligible conditioning class.
6902 Err(err) => {
6903 return Err(OuterGradientError::classify_arrow_solver_error(
6904 &err,
6905 conditioning_err.clone(),
6906 ));
6907 }
6908 };
6909 let h_flat = flatten_arrow_parts(h_gauge.t.view(), h_gauge.beta.view());
6910 for row in 0..span_rank {
6911 h_span[[row, col]] = gauge_span[row].dot(&h_flat);
6912 }
6913 }
6914 for row in 0..span_rank {
6915 for col in 0..row {
6916 let sym = 0.5 * (h_span[[row, col]] + h_span[[col, row]]);
6917 h_span[[row, col]] = sym;
6918 h_span[[col, row]] = sym;
6919 }
6920 }
6921 // #1451: a non-finite entry in the projected gauge Hessian is an
6922 // internal-invariant defect (a NaN/Inf intermediate leaked into the
6923 // span), not a conditioning failure — it MUST propagate rather than be
6924 // masked behind an FD descent. Guard finiteness BEFORE the eigh so only a
6925 // genuine decomposition failure on a finite, correctly-shaped matrix keeps
6926 // the FD-eligible conditioning class.
6927 if !h_span.iter().all(|v| v.is_finite()) {
6928 return Err(OuterGradientError::internal(format!(
6929 "outer_gradient_arrow_solver: non-finite entry in projected gauge \
6930 Hessian (h_span is {span_rank}x{span_rank})"
6931 )));
6932 }
6933 let (evals, evecs) = h_span
6934 .eigh(Side::Lower)
6935 .map_err(|_| conditioning_err.clone())?;
6936 let strict_gauge_floor = SAE_OUTER_GRADIENT_GAUGE_RAYLEIGH_FACTOR * max_pivot;
6937 let mut orthonormal: Vec<Array1<f64>> = Vec::new();
6938 for eig_idx in 0..evals.len() {
6939 let rayleigh = evals[eig_idx];
6940 if !(rayleigh.is_finite() && rayleigh <= strict_gauge_floor) {
6941 continue;
6942 }
6943 let mut direction = Array1::<f64>::zeros(full_len);
6944 for basis_idx in 0..span_rank {
6945 let coeff = evecs[[basis_idx, eig_idx]];
6946 for row in 0..full_len {
6947 direction[row] += coeff * gauge_span[basis_idx][row];
6948 }
6949 }
6950 let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
6951 if !(norm_sq.is_finite() && norm_sq > 1.0e-24) {
6952 continue;
6953 }
6954 let inv_norm = norm_sq.sqrt().recip();
6955 for value in direction.iter_mut() {
6956 *value *= inv_norm;
6957 }
6958 orthonormal.push(direction);
6959 }
6960 if orthonormal.is_empty() {
6961 // #1273/#1440: the conditioning gate has ALREADY certified a
6962 // near-singular joint Hessian (`conditioning_err`), so a genuine flat
6963 // direction exists inside the assembled gauge/decoder-null span even
6964 // when no projected-Hessian eigenvector cleared the strict or the
6965 // `fallback_gauge_floor` Rayleigh band. Rather than declining
6966 // (which historically routed the outer step to a finite-difference
6967 // descent direction — the FD instrument #1440 removes), deflate the
6968 // SMALLEST-Rayleigh eigenvector of the projected gauge Hessian
6969 // UNCONDITIONALLY. That eigenvector is the least-curvature member of
6970 // the validated gauge span (a Faddeev-Popov gauge candidate), so the
6971 // Tikhonov stiffness `max_pivot` in `from_orthonormal_gauges` bounds
6972 // its contribution at the Hessian scale and the components orthogonal
6973 // to it are byte-for-byte the plain analytic inverse solve. This keeps
6974 // the descent direction fully ANALYTIC (a projected/damped gradient),
6975 // never a differenced value path.
6976 let mut best_idx = None;
6977 let mut best_rayleigh = f64::INFINITY;
6978 for eig_idx in 0..evals.len() {
6979 let rayleigh = evals[eig_idx];
6980 if rayleigh.is_finite() && rayleigh < best_rayleigh {
6981 best_idx = Some(eig_idx);
6982 best_rayleigh = rayleigh;
6983 }
6984 }
6985 if let Some(eig_idx) = best_idx {
6986 let mut direction = Array1::<f64>::zeros(full_len);
6987 for basis_idx in 0..span_rank {
6988 let coeff = evecs[[basis_idx, eig_idx]];
6989 for row in 0..full_len {
6990 direction[row] += coeff * gauge_span[basis_idx][row];
6991 }
6992 }
6993 let norm_sq = direction.iter().map(|v| v * v).sum::<f64>();
6994 if norm_sq.is_finite() && norm_sq > 1.0e-24 {
6995 let inv_norm = norm_sq.sqrt().recip();
6996 for value in direction.iter_mut() {
6997 *value *= inv_norm;
6998 }
6999 orthonormal.push(direction);
7000 }
7001 }
7002 }
7003 if orthonormal.is_empty() {
7004 return Err(non_identifiable_err);
7005 }
7006
7007 // Quotient-geometry gauge fixing: add stiffness only along the closed-form
7008 // gauge orbit (Faddeev-Popov style). Components orthogonal to that orbit
7009 // are identical to the original inverse solve, while gauge components are
7010 // bounded at the Hessian scale `max_pivot`.
7011 // #1451: a shape/length mismatch or non-finite stiffness/intermediate in
7012 // the deflated-solver assembly is an internal-invariant defect and MUST
7013 // propagate; only a genuine near-singular gauge Woodbury/back-solve keeps
7014 // the FD-eligible conditioning class.
7015 DeflatedArrowSolver::from_orthonormal_gauges(cache, orthonormal, max_pivot)
7016 .map_err(|err| OuterGradientError::classify_arrow_solver_error(&err, conditioning_err))
7017 }
7018
7019 pub(crate) fn outer_gradient_conditioning_error(
7020 cache: &ArrowFactorCache,
7021 ) -> Result<(), OuterGradientError> {
7022 let pivot = arrow_factor_min_pivot(cache);
7023 let Some(min_pivot) = pivot.min_pivot else {
7024 return Err(OuterGradientError::IllConditioned {
7025 reason: "joint Hessian numerically singular (no cached Cholesky pivots)"
7026 .to_string(),
7027 });
7028 };
7029 let Some(max_pivot) = arrow_factor_max_pivot(cache) else {
7030 return Err(OuterGradientError::IllConditioned {
7031 reason: "joint Hessian numerically singular (no cached Cholesky pivot scale)"
7032 .to_string(),
7033 });
7034 };
7035 let ratio = min_pivot / max_pivot;
7036 if min_pivot.is_finite()
7037 && max_pivot.is_finite()
7038 && max_pivot > 0.0
7039 && ratio.is_finite()
7040 && ratio >= SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR
7041 {
7042 return Ok(());
7043 }
7044 Err(OuterGradientError::IllConditioned {
7045 reason: format!(
7046 "joint Hessian numerically singular (min/max pivot ratio {ratio:.3e} < floor {floor:.3e}; min pivot {min_pivot:.3e}, max pivot {max_pivot:.3e})",
7047 floor = SAE_OUTER_GRADIENT_PIVOT_RATIO_FLOOR,
7048 ),
7049 })
7050 }
7051
7052 /// Smoothing-penalty Occam normalizer `−½ Σ_k r_k·rank(S_k)·log λ_smooth`
7053 /// PLUS the profiled-frame evidence-dimension term `½ Σ_k r_k·(p−r_k)·log
7054 /// λ_smooth` (issue #972).
7055 ///
7056 /// On the full-`B` path every atom's frame rank `r_k == p`, so the first
7057 /// piece reduces to the historical `½ p·(Σ rank S_k)·log λ_smooth` and the
7058 /// Grassmann term is zero — bit-for-bit unchanged. When a frame is active the
7059 /// decoder coordinates `C_k` carry the `⊗ I_{r_k}` Kronecker structure (the
7060 /// smoothing penalty `S_k` now acts on `r_k` channels, not `p`), so the
7061 /// penalty-logdet normalizer uses `r_k·rank(S_k)`; and the `r_k·(p−r_k)`
7062 /// frame degrees of freedom profiled OUT of the border are counted explicitly
7063 /// in the Laplace dimension accounting (evidence honesty) so the criterion
7064 /// cannot buy a free evidence boost by hiding decoder freedom in the frame.
7065 pub(crate) fn reml_occam_term(&self, rho: &SaeManifoldRho) -> Result<f64, String> {
7066 // #1556: λ_smooth is per-atom, so the Occam penalty normalizer and the
7067 // profiled-frame evidence-dimension term are both per-atom sums, each
7068 // atom `k` weighted by its own `log λ_smooth[k]`. With a uniform
7069 // (broadcast) vector this is bit-for-bit the historical global form.
7070 let mut acc = 0.0_f64;
7071 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7072 let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7073 // Penalized decoder dimension: `r_k` coordinate channels carry the
7074 // `S_k` roughness penalty (full-`B` path ⇒ `r_k == p`).
7075 let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7076 // Profiled Grassmann dimensions enter the Laplace evidence dimension
7077 // count with the OPPOSITE sign of the penalty Occam term (they are
7078 // free, unpenalized-by-`S` profiled directions), so `−occam` adds
7079 // `+½ r(p−r) log λ_k` to the criterion `V` — the honesty correction.
7080 let frame_dim = atom.frame_manifold_dimension();
7081 let log_lambda = rho.log_lambda_smooth[atom_idx];
7082 acc += 0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)) * log_lambda;
7083 }
7084 // `V = … − occam`, so the net occam SUBTRACTS the penalty normalizer and
7085 // ADDS the frame-dimension count after the caller's `− occam`.
7086 Ok(acc)
7087 }
7088
7089 /// Per-atom derivative `∂(occam)/∂log λ_smooth[k]` (#1556): atom `k`'s entry
7090 /// is `½·(r_k·rank(S_k) − frame_dim_k)`, matching the per-atom Occam term in
7091 /// [`Self::reml_occam_term`]. Returns one entry per atom in atom order.
7092 pub(crate) fn reml_occam_log_lambda_smooth_derivative(&self) -> Result<Vec<f64>, String> {
7093 let mut out = Vec::with_capacity(self.atoms.len());
7094 for atom in &self.atoms {
7095 let rank_s = Self::symmetric_rank(&atom.smooth_penalty)?;
7096 let penalized_channel_dim = atom.border_frame_rank() * rank_s;
7097 let frame_dim = atom.frame_manifold_dimension();
7098 out.push(0.5 * ((penalized_channel_dim as f64) - (frame_dim as f64)));
7099 }
7100 Ok(out)
7101 }
7102
7103 pub fn reml_criterion_streaming_exact(
7104 &mut self,
7105 target: ArrayView2<'_, f64>,
7106 rho: &SaeManifoldRho,
7107 registry: Option<&AnalyticPenaltyRegistry>,
7108 inner_max_iter: usize,
7109 learning_rate: f64,
7110 ridge_ext_coord: f64,
7111 ridge_beta: f64,
7112 ) -> Result<(f64, SaeManifoldLoss), String> {
7113 let mut rho_fixed = rho.clone();
7114 let mut loss = self.run_joint_fit_arrow_schur(
7115 target,
7116 &mut rho_fixed,
7117 registry,
7118 inner_max_iter,
7119 learning_rate,
7120 ridge_ext_coord,
7121 ridge_beta,
7122 )?;
7123 // Drive the inner (t, β) state to the SAME KKT/step-converged optimum the
7124 // dense `reml_criterion_with_cache` reaches before factoring. At that
7125 // optimum the per-row `H_tt^(i)` blocks are PD, so the undamped
7126 // (`ridge_t = 0`) streaming factorization in `streaming_exact_arrow_log_det`
7127 // succeeds — without this, a state stopped after only `inner_max_iter`
7128 // steps can leave a rank-deficient / indefinite row block (`p_out = 1` →
7129 // rank-1 `JᵀJ`, softmax negative-logit curvature) that surfaces
7130 // `PerRowFactorFailed` at base ridge 0. Sharing the driver also keeps the
7131 // streaming and dense log-determinants bit-identical (#847).
7132 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7133 // The dense factor cache from convergence is surplus here — the streaming
7134 // path recomputes the (bit-identical) log-determinant chunk-by-chunk in
7135 // `streaming_exact_arrow_log_det` to bound peak memory — so it is dropped.
7136 let converged_cache = self.converge_inner_for_undamped_logdet(
7137 target,
7138 rho,
7139 &mut rho_fixed,
7140 registry,
7141 inner_max_iter,
7142 learning_rate,
7143 ridge_ext_coord,
7144 ridge_beta,
7145 &mut loss,
7146 &options,
7147 true,
7148 )?;
7149 drop(converged_cache);
7150 let log_det = self.streaming_exact_arrow_log_det(target, rho, registry)?;
7151 let occam = self.reml_occam_term(rho)?;
7152 // Extra analytic-penalty energy (#671/#737), matching the full-batch
7153 // `reml_criterion_with_cache` path so streaming and dense criteria rank
7154 // the identical penalized objective.
7155 let extra_penalty_energy = match registry {
7156 Some(reg) => self
7157 .reml_extra_penalty_value_total(reg)
7158 .map_err(|err| format!("SaeManifoldTerm::reml_criterion_streaming_exact: {err}"))?,
7159 None => 0.0,
7160 };
7161 Ok((
7162 loss.total() + extra_penalty_energy + 0.5 * log_det - occam,
7163 loss,
7164 ))
7165 }
7166
7167 pub fn streaming_exact_arrow_log_det(
7168 &mut self,
7169 target: ArrayView2<'_, f64>,
7170 rho: &SaeManifoldRho,
7171 registry: Option<&AnalyticPenaltyRegistry>,
7172 ) -> Result<f64, String> {
7173 if target.dim() != (self.n_obs(), self.output_dim()) {
7174 return Err(format!(
7175 "SaeManifoldTerm::streaming_exact_arrow_log_det: target must be ({}, {}); got {:?}",
7176 self.n_obs(),
7177 self.output_dim(),
7178 target.dim()
7179 ));
7180 }
7181 let plan = self.streaming_plan().admitted_or_error(
7182 self.n_obs(),
7183 self.output_dim(),
7184 self.k_atoms(),
7185 )?;
7186 if plan.estimated_dense_schur_bytes > plan.in_core_budget_bytes {
7187 return Err(format!(
7188 "SaeManifoldTerm::streaming_exact_arrow_log_det: predicted dense reduced Schur {} bytes exceeds budget {} bytes; cost-only matrix-free route is required",
7189 plan.estimated_dense_schur_bytes, plan.in_core_budget_bytes
7190 ));
7191 }
7192 let n_total = self.n_obs();
7193 let chunk_size = plan.chunk_size.min(n_total.max(1));
7194 // #972 / #977 T1: the reduced β-Schur is over the FACTORED border when
7195 // frames are active (each chunk inherits the frames via
7196 // `materialize_chunk`, so every `chunk_schur` is `border_dim²`), matching
7197 // the dense path's factored log-det. Full-`B` ⇒ `border_dim == beta_dim`.
7198 let border_dim = if self.frames_active() {
7199 self.factored_border_dim()
7200 } else {
7201 self.beta_dim()
7202 };
7203 let mut schur_acc = Array2::<f64>::zeros((border_dim, border_dim));
7204 let mut log_det_tt = 0.0_f64;
7205 // #1038 cross-row IBP Woodbury accumulators. `M = Uᵀ H₀'⁻¹ U` is
7206 // chunk-additive in `M0 = Σ Uᵢᵀ Aᵢ⁻¹ Uᵢ` and `W = Σ Bᵢᵀ Aᵢ⁻¹ Uᵢ`
7207 // (`A = H₀'` block-diagonal, `U` row-supported), closed against the
7208 // GLOBAL reduced Schur `S = schur_acc` after the loop. `None` for every
7209 // non-IBP (softmax / JumpReLU) term, where the streaming log-det is
7210 // exactly the bare `log_det_tt + log_det_schur` as before.
7211 let mut wood_m0: Option<Array2<f64>> = None;
7212 let mut wood_w: Option<Array2<f64>> = None;
7213 let mut wood_d: Option<Array1<f64>> = None;
7214 let options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
7215 let mut start = 0usize;
7216 while start < n_total {
7217 let end = (start + chunk_size).min(n_total);
7218 let penalty_scale = (end - start) as f64 / n_total as f64;
7219 let chunk_logits = self.assignment.logits.slice(s![start..end, ..]).to_owned();
7220 let chunk_coords: Vec<Array2<f64>> = self
7221 .assignment
7222 .coords
7223 .iter()
7224 .map(|coord| coord.as_matrix().slice(s![start..end, ..]).to_owned())
7225 .collect();
7226 let mut chunk = self.materialize_chunk(chunk_logits, chunk_coords)?;
7227 // #1117 — rank deficiency is removed at the basis layer at fit entry
7228 // (`reduce_atoms_to_data_supported_rank`), so each chunk inherits the
7229 // already-reduced full-rank atoms via `materialize_chunk`; there are
7230 // no global deflation projectors to propagate.
7231 // #991: chunk terms inherit the row's design honesty weight slice
7232 // (global mean-1 normalization preserved — NOT re-normalized per
7233 // chunk — so the per-chunk sums reconstruct the global weighted
7234 // objective exactly).
7235 if let Some(w) = self.row_loss_weights.as_deref() {
7236 chunk.row_loss_weights = Some(w[start..end].to_vec());
7237 }
7238 let z_chunk = target.slice(s![start..end, ..]);
7239 let sys = chunk
7240 .assemble_arrow_schur_scaled(z_chunk, rho, registry, penalty_scale)
7241 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7242 let mut streaming = StreamingArrowSchur::from_system(&sys, sys.rows.len().max(1));
7243 let (chunk_log_det_tt, chunk_schur, chunk_wood) = streaming
7244 .reduced_schur_log_det_tt_woodbury(0.0, 0.0, &options)
7245 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7246 log_det_tt += chunk_log_det_tt;
7247 for row in 0..border_dim {
7248 for col in 0..border_dim {
7249 schur_acc[[row, col]] += chunk_schur[[row, col]];
7250 }
7251 }
7252 if chunk_wood.is_some() && chunk_size < n_total {
7253 // The cross-row IBP empirical mass `M_k = Σ_i z_ik` couples ALL
7254 // rows, so the per-row `H₀'` diagonal (`score_derivative_k(M_k)`)
7255 // and the column coefficient `d_k = w·s'_k(M_k)` are only exact
7256 // when every row is assembled together — a SINGLE chunk. Under a
7257 // genuine multi-chunk pass each chunk would see a partial mass and
7258 // the Woodbury (and the bare per-row log-det) would be inexact, so
7259 // refuse loudly and route to the dense resident path rather than
7260 // return a silently-wrong evidence. The streaming log-det only
7261 // runs when the dense reduced Schur fits budget, so the single-
7262 // chunk regime is the common case; this guards the rest.
7263 return Err(
7264 "SaeManifoldTerm::streaming_exact_arrow_log_det: exact cross-row IBP \
7265 Woodbury evidence requires a single-chunk pass (the empirical mass \
7266 M_k = Σ_i z_ik couples all rows); this shape needs >1 chunk. Route \
7267 IBP-active large-n fits through the dense resident \
7268 ArrowFactorCache::arrow_log_det."
7269 .to_string(),
7270 );
7271 }
7272 if let Some(cw) = chunk_wood {
7273 wood_m0 = Some(match wood_m0.take() {
7274 Some(mut acc) => {
7275 acc += &cw.m0;
7276 acc
7277 }
7278 None => cw.m0,
7279 });
7280 wood_w = Some(match wood_w.take() {
7281 Some(mut acc) => {
7282 acc += &cw.w;
7283 acc
7284 }
7285 None => cw.w,
7286 });
7287 // `D = diag(d_k)` is per-atom; identical across chunks for a
7288 // single-chunk evidence pass (the regime the streaming log-det
7289 // runs in — the dense reduced Schur must fit budget here), where
7290 // it equals the global mass-derived `cross_row_d`.
7291 wood_d = Some(cw.d);
7292 }
7293 start = end;
7294 }
7295 let log_det_schur = StreamingArrowSchur::reduced_schur_log_det(&schur_acc, &options)
7296 .map_err(|err| format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}"))?;
7297 let mut total = log_det_tt + log_det_schur;
7298 // #1038/#1225: close the exact cross-row IBP Woodbury correction
7299 // `log det(I_R + D Uᵀ H₀'⁻¹ U)` so the streaming evidence equals the
7300 // dense `arrow_log_det_from_cache` (which adds the SAME term). Without
7301 // it the streaming criterion would silently drop the entire cross-row
7302 // coupling and disagree with the dense path by exactly `log|C|`.
7303 if let (Some(m0), Some(w), Some(d)) = (wood_m0, wood_w, wood_d) {
7304 let correction = streaming_cross_row_woodbury_log_det(&schur_acc, &m0, &w, &d)
7305 .map_err(|err| {
7306 format!("SaeManifoldTerm::streaming_exact_arrow_log_det: {err}")
7307 })?
7308 .ok_or_else(|| {
7309 "SaeManifoldTerm::reml_criterion: cross-row IBP joint Hessian is non-PD at \
7310 this ρ; evidence Laplace log-det undefined (infeasible ρ probe)"
7311 .to_string()
7312 })?;
7313 total += correction;
7314 }
7315 Ok(total)
7316 }
7317
7318 /// Per-atom, per-axis coordinate sum-of-squares `‖t_kj‖² = Σ_i t_{i,k,j}²`.
7319 ///
7320 /// This is the data-fit sufficient statistic for the ARD precision update
7321 /// (the numerator-side `‖t‖²` of the deleted `α = n/‖t‖²` rule). Returned
7322 /// per atom as an `Array1` of length `d_k`.
7323 ///
7324 /// On a *periodic* (Circle) axis the relevant statistic is the von-Mises
7325 /// energy-equivalent `Σ_i 2/α·V(t_i) = Σ_i (2/κ²)(1−cos κ t_i)` (independent
7326 /// of α), so that `½·α·sumsq == Σ_i V(t_i)` matches `ard_value`. This keeps
7327 /// the Mackay/Fellner–Schall fixed point `α ← n / (sumsq + tr H⁻¹)`
7328 /// consistent with the actual periodic prior energy rather than the
7329 /// origin-dependent raw `t²`.
7330 pub(crate) fn ard_coord_sumsq(&self) -> Vec<Array1<f64>> {
7331 let mut out = Vec::with_capacity(self.k_atoms());
7332 for coord in &self.assignment.coords {
7333 let d = coord.latent_dim();
7334 let periods = coord.effective_axis_periods();
7335 let mut sq = Array1::<f64>::zeros(d);
7336 for row in 0..coord.n_obs() {
7337 let t = coord.row(row);
7338 for axis in 0..d {
7339 // `sq_equiv` is independent of `alpha`; pass 1.0.
7340 sq[axis] += ArdAxisPrior::eval(1.0, t[axis], periods[axis]).sq_equiv;
7341 }
7342 }
7343 out.push(sq);
7344 }
7345 out
7346 }
7347
7348 /// Per-atom, per-axis posterior-variance trace `tr_kj(H⁻¹) =
7349 /// Σ_i [(H⁻¹)_tt]_{(i,k,j),(i,k,j)}` from the converged factor cache.
7350 ///
7351 /// `cache.latent_block_inverse_diagonal()` returns the diagonal of the
7352 /// latent block `(H⁻¹)_tt` in the cache's compact per-row `delta_t`
7353 /// layout (length `row_offsets[N]`); each per-row block is laid out as
7354 /// `[logit scalars…, then per-active-atom coord axes…]`. This routine
7355 /// sums those diagonal entries over the coord positions belonging to each
7356 /// `(atom k, axis j)` across all observation rows where atom `k` is active.
7357 ///
7358 /// `self.last_row_layout` must be the layout from the *same* assemble that
7359 /// produced `cache`:
7360 /// - `Some(layout)`: compact active-set mode (JumpReLU / large-K
7361 /// softmax-IBP truncation). For row `i`, atom `k`'s position in the
7362 /// active list gives its compact coord-block start `coord_starts[i][pos]`;
7363 /// inactive atoms contribute 0 (the prior dominates there anyway).
7364 /// - `None`: dense full-support layout, uniform row dim
7365 /// `q = assignment_dim + Σ d_k`; atom `k`'s coord block sits at the
7366 /// fixed full-row offset `coord_offsets[k]` after the assignment chart.
7367 ///
7368 /// This `tr_kj(H⁻¹)` is exactly the posterior-variance term the deleted
7369 /// `α = n/‖t‖²` rule dropped; the corrected Mackay/Fellner-Schall fixed
7370 /// point is `α_new = n / (‖t_kj‖² + tr_kj(H⁻¹))`.
7371 pub(crate) fn ard_inverse_traces(
7372 &self,
7373 cache: &ArrowFactorCache,
7374 ) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
7375 let inv_diag = cache.latent_block_inverse_diagonal()?;
7376 let n = self.n_obs();
7377 let coord_offsets = self.assignment.coord_offsets();
7378 let mut traces: Vec<Array1<f64>> = self
7379 .assignment
7380 .coords
7381 .iter()
7382 .map(|c| Array1::<f64>::zeros(c.latent_dim()))
7383 .collect();
7384 for row in 0..n {
7385 let row_base = cache.row_offsets[row];
7386 match self.last_row_layout {
7387 Some(ref layout) => {
7388 let active = &layout.active_atoms[row];
7389 let starts = &layout.coord_starts[row];
7390 for (pos, &k) in active.iter().enumerate() {
7391 let d = self.assignment.coords[k].latent_dim();
7392 let block_start = starts[pos];
7393 for axis in 0..d {
7394 traces[k][axis] += inv_diag[row_base + block_start + axis];
7395 }
7396 }
7397 }
7398 None => {
7399 for k in 0..self.k_atoms() {
7400 let d = self.assignment.coords[k].latent_dim();
7401 let block_start = coord_offsets[k];
7402 for axis in 0..d {
7403 traces[k][axis] += inv_diag[row_base + block_start + axis];
7404 }
7405 }
7406 }
7407 }
7408 }
7409 Ok(traces)
7410 }
7411
7412 pub(crate) fn ard_log_precision_explicit_derivatives(
7413 &self,
7414 rho: &SaeManifoldRho,
7415 ) -> Result<Vec<Array1<f64>>, String> {
7416 if rho.log_ard.len() != self.k_atoms() {
7417 return Err(format!(
7418 "ARD rho has {} atoms but term has {}",
7419 rho.log_ard.len(),
7420 self.k_atoms()
7421 ));
7422 }
7423 let n = self.n_obs() as f64;
7424 let mut out = Vec::with_capacity(self.k_atoms());
7425 for (atom_idx, coord) in self.assignment.coords.iter().enumerate() {
7426 let d = coord.latent_dim();
7427 let mut atom_out = Array1::<f64>::zeros(rho.log_ard[atom_idx].len());
7428 if rho.log_ard[atom_idx].is_empty() {
7429 out.push(atom_out);
7430 continue;
7431 }
7432 if rho.log_ard[atom_idx].len() != d {
7433 return Err(format!(
7434 "ARD rho atom {atom_idx} has len {} but atom dim is {d}",
7435 rho.log_ard[atom_idx].len()
7436 ));
7437 }
7438 let periods = coord.effective_axis_periods();
7439 for axis in 0..d {
7440 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom_idx][axis]);
7441 let period = periods[axis];
7442 let mut energy_deriv = 0.0_f64;
7443 for row in 0..coord.n_obs() {
7444 let t = coord.row(row)[axis];
7445 energy_deriv += ArdAxisPrior::eval(alpha, t, period).value;
7446 }
7447 let normalizer_deriv = match period {
7448 None => -0.5 * n,
7449 Some(p) => {
7450 let kappa = std::f64::consts::TAU / p;
7451 let eta = alpha / (kappa * kappa);
7452 // d/d(log α) of `n[-η + log I0(η)]` = `n η (I1/I0 - 1)`.
7453 // The ratio is computed without forming `e^{η}`, so it
7454 // stays finite for large `η` instead of the `inf/inf =
7455 // NaN` that `bessel_i1(η)/bessel_i0(η)` produces (#1113).
7456 let ratio = bessel_i0_log_and_ratio(eta).1;
7457 n * eta * (-1.0 + ratio)
7458 }
7459 };
7460 atom_out[axis] = energy_deriv + normalizer_deriv;
7461 }
7462 out.push(atom_out);
7463 }
7464 Ok(out)
7465 }
7466
7467 pub(crate) fn ard_log_precision_hessian_trace(
7468 &self,
7469 rho: &SaeManifoldRho,
7470 cache: &ArrowFactorCache,
7471 solver: &DeflatedArrowSolver<'_>,
7472 ) -> Result<Vec<Array1<f64>>, ArrowSchurError> {
7473 // RAW selected-inverse diagonal: the per-axis diagonal contraction uses
7474 // the DEFLATED inverse; the full kept-subspace + rotation deflation
7475 // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per (row, axis)
7476 // afterwards via the Daleckii–Krein helper. Each ARD ρ-component
7477 // `(atom k, axis)` differentiates a SINGLE coordinate-slot diagonal entry,
7478 // so its `D` is the rank-one `hess·e_s e_sᵀ` at that local slot `s`.
7479 let inv_diag = solver
7480 .latent_inverse_diagonal()
7481 .map_err(|err| ArrowSchurError::SchurFactorFailed { reason: err })?;
7482 let n = self.n_obs();
7483 let total_t = cache.delta_t_len();
7484 let coord_offsets = self.assignment.coord_offsets();
7485 let ard_axis_periods: Vec<Vec<Option<f64>>> = self
7486 .assignment
7487 .coords
7488 .iter()
7489 .map(LatentCoordValues::effective_axis_periods)
7490 .collect();
7491 let mut traces: Vec<Array1<f64>> = self
7492 .assignment
7493 .coords
7494 .iter()
7495 .enumerate()
7496 .map(|(k, c)| {
7497 if rho.log_ard[k].is_empty() {
7498 Array1::<f64>::zeros(0)
7499 } else {
7500 Array1::<f64>::zeros(c.latent_dim())
7501 }
7502 })
7503 .collect();
7504 for row in 0..n {
7505 let row_base = cache.row_offsets[row];
7506 let q = cache.row_dims[row];
7507 let dirs = cache
7508 .deflated_row_directions
7509 .get(row)
7510 .map(Vec::as_slice)
7511 .unwrap_or(&[]);
7512 let spectrum = cache
7513 .deflation_row_spectra
7514 .get(row)
7515 .and_then(Option::as_ref);
7516 // Per-row selected-inverse t-block, built once (only when deflated).
7517 let inv_vv = if dirs.is_empty() {
7518 None
7519 } else {
7520 let mut m = Array2::<f64>::zeros((q, q));
7521 for col in 0..q {
7522 let mut rhs_t = Array1::<f64>::zeros(total_t);
7523 let rhs_beta = Array1::<f64>::zeros(cache.k);
7524 rhs_t[row_base + col] = 1.0;
7525 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
7526 ArrowSchurError::SchurFactorFailed { reason: err }
7527 })?;
7528 for r in 0..q {
7529 m[[r, col]] = solved.t[row_base + r];
7530 }
7531 }
7532 Some(m)
7533 };
7534 // Correction for one local coordinate slot `s` with curvature `hess`.
7535 let slot_correction = |s: usize, hess: f64| -> f64 {
7536 let Some(iv) = inv_vv.as_ref() else {
7537 return 0.0;
7538 };
7539 if s >= q || hess == 0.0 {
7540 return 0.0;
7541 }
7542 let mut d = Array2::<f64>::zeros((q, q));
7543 d[[s, s]] = hess;
7544 Self::deflation_block_correction(iv, &d, dirs, spectrum)
7545 };
7546 match self.last_row_layout {
7547 Some(ref layout) => {
7548 let active = &layout.active_atoms[row];
7549 let starts = &layout.coord_starts[row];
7550 for (pos, &k) in active.iter().enumerate() {
7551 if rho.log_ard[k].is_empty() {
7552 continue;
7553 }
7554 let coord = &self.assignment.coords[k];
7555 let d = coord.latent_dim();
7556 let block_start = starts[pos];
7557 for axis in 0..d {
7558 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
7559 let t = coord.row(row)[axis];
7560 let prior = ArdAxisPrior::eval(alpha, t, ard_axis_periods[k][axis]);
7561 let hess = prior.hess.max(0.0);
7562 let s = block_start + axis;
7563 traces[k][axis] += 0.5 * inv_diag[row_base + s] * hess;
7564 traces[k][axis] -= 0.5 * slot_correction(s, hess);
7565 }
7566 }
7567 }
7568 None => {
7569 for k in 0..self.k_atoms() {
7570 if rho.log_ard[k].is_empty() {
7571 continue;
7572 }
7573 let coord = &self.assignment.coords[k];
7574 let d = coord.latent_dim();
7575 let block_start = coord_offsets[k];
7576 for axis in 0..d {
7577 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[k][axis]);
7578 let t = coord.row(row)[axis];
7579 let prior = ArdAxisPrior::eval(alpha, t, ard_axis_periods[k][axis]);
7580 let hess = prior.hess.max(0.0);
7581 let s = block_start + axis;
7582 traces[k][axis] += 0.5 * inv_diag[row_base + s] * hess;
7583 traces[k][axis] -= 0.5 * slot_correction(s, hess);
7584 }
7585 }
7586 }
7587 }
7588 }
7589 Ok(traces)
7590 }
7591
7592 /// Per-atom decoder-smoothness penalty quadratic form (#1556): entry `k` is
7593 /// the λ-free `<B_k, ½(S_k+S_kᵀ)·B_k> = Σ_oc B_k[:,oc]ᵀ S_k B_k[:,oc]`, the
7594 /// per-atom denominator of atom `k`'s λ_smooth Fellner-Schall update. The sum
7595 /// over atoms is `βᵀ(⊕_k S_k ⊗ I_p)β`, the un-scaled total penalty energy.
7596 /// `S_k` is symmetrised defensively (as the assembler does); the per-atom
7597 /// `½(S+Sᵀ)·B_k` GEMMs ride the multi-GPU batched smoothness GEMM with an
7598 /// exact per-atom CPU fallback.
7599 pub(crate) fn decoder_smoothness_quadratic_form_per_atom(&self) -> Vec<f64> {
7600 let sb_inputs: Vec<(ArrayView2<'_, f64>, ArrayView2<'_, f64>)> = self
7601 .atoms
7602 .iter()
7603 .map(|atom| (atom.smooth_penalty.view(), atom.decoder_coefficients.view()))
7604 .collect();
7605 let sb_all = batched_smooth_sb(&sb_inputs, true);
7606 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7607 for (atom_idx, (atom, sb)) in self.atoms.iter().zip(sb_all.iter()).enumerate() {
7608 per_atom[atom_idx] = (&atom.decoder_coefficients * sb).sum();
7609 }
7610 per_atom
7611 }
7612
7613 /// Per-atom effective penalized dof of the decoder smoothness penalty
7614 /// (#1556): entry `k` is `tr(S_β⁻¹ · M_k)` with `M_k = (λ_smooth[k]·S_k) ⊗ I`
7615 /// and `S_β⁻¹ = (H⁻¹)_ββ` the Schur-complement inverse, each atom scaled by
7616 /// its OWN `lambda_smooth[atom_idx]`. Built on
7617 /// [`ArrowFactorCache::schur_inverse_apply`]: column `(k,μ,oc)` of `M_k` is
7618 /// `λ_k·S_k[:,μ] ⊗ e_oc` (sparse), so we apply `S_β⁻¹` to that K-vector and
7619 /// read back `result[col]`. The total edf is the sum of the returned vector
7620 /// (a uniform/broadcast λ reproduces the historical global trace).
7621 pub(crate) fn decoder_smoothness_effective_dof_per_atom(
7622 &self,
7623 cache: &ArrowFactorCache,
7624 lambda_smooth: &[f64],
7625 ) -> Result<Vec<f64>, ArrowSchurError> {
7626 let p = self.output_dim();
7627 let frames_active = self.frames_active();
7628 let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7629 let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7630 (
7631 self.factored_beta_offsets(),
7632 Box::new(move |k: usize| ranks[k]),
7633 )
7634 } else {
7635 (self.beta_offsets(), Box::new(move |_k: usize| p))
7636 };
7637 let k = cache.k;
7638 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7639 let mut m_col = Array1::<f64>::zeros(k);
7640 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7641 let s = &atom.smooth_penalty;
7642 let m = atom.basis_size();
7643 let off = offsets[atom_idx];
7644 let r = out_dim(atom_idx);
7645 let lambda = lambda_smooth[atom_idx];
7646 let mut trace = 0.0_f64;
7647 for mu in 0..m {
7648 for oc in 0..r {
7649 let col = off + mu * r + oc;
7650 m_col.fill(0.0);
7651 for nu in 0..m {
7652 let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7653 m_col[off + nu * r + oc] = lambda * s_nu_mu;
7654 }
7655 let z = cache.schur_inverse_apply(m_col.view())?;
7656 trace += z[col];
7657 }
7658 }
7659 per_atom[atom_idx] = trace;
7660 }
7661 Ok(per_atom)
7662 }
7663
7664 /// Per-atom effective penalized dof via the deflated solver (#1556): entry
7665 /// `k` is `tr((H⁻¹)_ββ · M_k)` for `M_k = (λ_smooth[k]·S_k) ⊗ I`, each atom
7666 /// scaled by its OWN `lambda_smooth[atom_idx]`. The total is the sum.
7667 pub(crate) fn decoder_smoothness_effective_dof_with_solver_per_atom(
7668 &self,
7669 cache: &ArrowFactorCache,
7670 solver: &DeflatedArrowSolver<'_>,
7671 lambda_smooth: &[f64],
7672 ) -> Result<Vec<f64>, String> {
7673 let p = self.output_dim();
7674 // #972 / #977 T1: the cache's β block is the FACTORED border when frames
7675 // are active (`cache.k == factored_border_dim`), so the smoothness edf
7676 // trace `tr((H⁻¹)_ββ · M)` is taken over the same factored layout, with
7677 // `M = ⊕_k (λ_k S_k) ⊗ I_{r_k}` at the factored offsets (the `U_kᵀU_k = I`
7678 // collapse means the per-coordinate-channel penalty is `λ_k S_k`, exactly
7679 // as in the full-`B` `⊗ I_p` case but with `r_k` channels). On the
7680 // full-`B` path `frames_active` is false: `out_dim_k = p`, the offsets
7681 // are `beta_offsets`, and this is bit-for-bit the historical trace.
7682 let frames_active = self.frames_active();
7683 let (offsets, out_dim): (Vec<usize>, Box<dyn Fn(usize) -> usize>) = if frames_active {
7684 let ranks: Vec<usize> = self.atoms.iter().map(|a| a.border_frame_rank()).collect();
7685 (
7686 self.factored_beta_offsets(),
7687 Box::new(move |k: usize| ranks[k]),
7688 )
7689 } else {
7690 (self.beta_offsets(), Box::new(move |_k: usize| p))
7691 };
7692 let k = cache.k;
7693 let mut per_atom = vec![0.0_f64; self.atoms.len()];
7694 let mut m_col = Array1::<f64>::zeros(k);
7695 for (atom_idx, atom) in self.atoms.iter().enumerate() {
7696 let s = &atom.smooth_penalty;
7697 let m = atom.basis_size();
7698 let off = offsets[atom_idx];
7699 let r = out_dim(atom_idx);
7700 let lambda = lambda_smooth[atom_idx];
7701 let mut trace = 0.0_f64;
7702 for mu in 0..m {
7703 for oc in 0..r {
7704 let col = off + mu * r + oc;
7705 // M[:,col] = λ_k · S_k[:,mu] ⊗ e_oc (nonzero at off+ν·r+oc).
7706 m_col.fill(0.0);
7707 for nu in 0..m {
7708 let s_nu_mu = 0.5 * (s[[nu, mu]] + s[[mu, nu]]);
7709 m_col[off + nu * r + oc] = lambda * s_nu_mu;
7710 }
7711 let zero_t = Array1::<f64>::zeros(cache.delta_t_len());
7712 let z = solver.solve(zero_t.view(), m_col.view())?.beta;
7713 trace += z[col];
7714 }
7715 }
7716 per_atom[atom_idx] = trace;
7717 }
7718 Ok(per_atom)
7719 }
7720
7721 pub(crate) fn assignment_log_strength_hessian_trace(
7722 &self,
7723 rho: &SaeManifoldRho,
7724 cache: &ArrowFactorCache,
7725 solver: &DeflatedArrowSolver<'_>,
7726 ) -> Result<f64, String> {
7727 let k_atoms = self.k_atoms();
7728 // #1038 softmax: `H` carries the DENSE entropy block, and since the
7729 // entropy curvature scales linearly with `λ_sparse = exp(ρ)`,
7730 // `∂H/∂ρ = H_entropy` (the full dense per-row block, not just its
7731 // diagonal). The trace `½ tr(H⁻¹ ∂H/∂ρ)` must therefore contract the
7732 // dense `∂H/∂ρ` against the per-row selected-inverse BLOCK, mirroring the
7733 // dense `log|H|` and θ-adjoint — a diagonal-only contraction would
7734 // desync the ρ-gradient from the criterion. The assembled majorizer
7735 // `D = diag(Σ_j|H_kj|)` is itself DIAGONAL (#1419), so the contraction
7736 // reduces to `½ Σ_slot (H⁻¹)_{slot,slot}·D_atom`. On the dense `None`
7737 // layout the logit slot equals the atom position; on the compact
7738 // softmax top-`k` layout (#1408/#1409) the slots are the row's active
7739 // atoms — the SAME `D_atom` (full-`K` abs-row-sum) the assembly wrote.
7740 if let AssignmentMode::Softmax {
7741 temperature,
7742 sparsity,
7743 } = self.assignment.mode
7744 {
7745 if k_atoms <= 1 {
7746 return Ok(0.0);
7747 }
7748 let inv_tau = 1.0 / temperature;
7749 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
7750 let penalty = gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
7751 k_atoms,
7752 temperature,
7753 );
7754 // Softmax uses the reduced K−1 free-logit chart on the dense layout
7755 // (last reference logit fixed); the compact layout carries one slot
7756 // per active atom. The diagonal selected inverse gives each slot's
7757 // (H⁻¹)_{slot,slot}.
7758 let assignment_dim = self.assignment.assignment_coord_dim();
7759 // Kept-subspace inverse diagonal: the deflated inverse assigns
7760 // `1/λ̃ = 1` to each per-row UNIT-stiffness direction `vᵢ`, so a raw
7761 // diagonal `D` contraction would spuriously add `½ Σ_i vᵢᵀ D vᵢ` (a
7762 // ρ-independent direction must add 0). `latent_inverse_diagonal_kept`
7763 // removes that per-row deflated diagonal centrally.
7764 let inv_diag = solver
7765 .latent_inverse_diagonal_kept()
7766 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7767 let mut trace = 0.0_f64;
7768 for row in 0..self.n_obs() {
7769 let row_base = cache.row_offsets[row];
7770 // ∂(scale·D)/∂ρ = scale·D (linear in λ_sparse = eᵖ) — the SAME
7771 // operator the assembly and θ-adjoint differentiate.
7772 match self.last_row_layout {
7773 Some(ref layout) => {
7774 // #1410: the compact adjoint reads `D_kk` only for this
7775 // row's `≤ top_k` active atoms, so compute those entries
7776 // directly from the softmax row `a` via the active-only
7777 // Gershgorin helper — no full-`K` `row_logits` copy and no
7778 // full-`K` `d` vector. `a` itself is the irreducible `O(K)`
7779 // softmax normalisation, computed once per row and shared
7780 // across the row's active slots.
7781 let a = crate::assignment::softmax_row(
7782 self.assignment.logits.row(row),
7783 temperature,
7784 );
7785 let a = a.as_slice().expect("softmax row must be contiguous");
7786 let m = softmax_majorizer_log_mean(a);
7787 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7788 let d_atom =
7789 active_softmax_gershgorin_majorizer_entry(a, atom, m, scale);
7790 trace += inv_diag[row_base + pos] * d_atom;
7791 }
7792 }
7793 None => {
7794 // Dense layout genuinely contracts every free logit slot's
7795 // `D_kk`, so the full-`K` `d` is intrinsic here; keep the
7796 // single-source dense majorizer call.
7797 let row_logits: Vec<f64> = (0..k_atoms)
7798 .map(|k| self.assignment.logits[[row, k]])
7799 .collect();
7800 let d = penalty.psd_majorizer_abs_row_sums(&row_logits, scale);
7801 let q = cache.row_dims[row];
7802 let logit_dim = assignment_dim.min(q);
7803 for atom in 0..logit_dim {
7804 trace += inv_diag[row_base + atom] * d[atom];
7805 }
7806 }
7807 }
7808 }
7809 return Ok(0.5 * trace);
7810 }
7811 let hdiag = assignment_prior_log_strength_hdiag(&self.assignment, rho)?;
7812 if hdiag.is_empty() {
7813 return Ok(0.0);
7814 }
7815 // RAW selected-inverse diagonal: the per-row diagonal contraction uses the
7816 // DEFLATED inverse; the full kept-subspace + β-Schur/rotation deflation
7817 // correction `tr(inv_vv·(D − DΦ[D]))` is subtracted per row afterwards
7818 // (`deflation_block_correction`), exactly as the data trace does. The
7819 // cross-row off-diagonal pass below contracts only DISTINCT rows `i ≠ j`,
7820 // off any single-row `vᵢ`'s support, so it needs no deflation correction.
7821 let inv_diag = solver
7822 .latent_inverse_diagonal()
7823 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?;
7824 let assignment_dim = self.assignment.assignment_coord_dim();
7825 let total_t = cache.delta_t_len();
7826 // #932 FRONT C: row-local Takahashi selected inverse on the plain arrow
7827 // for the per-row deflation correction below (the diagonal trace already
7828 // uses the cheap `latent_inverse_diagonal`); gauge / cross-row Woodbury
7829 // fall back to the per-row full-system `solve` loop.
7830 let fast_selected = solver.plain_selected_inverse_available();
7831 let selected_beta_inv = if fast_selected && cache.k > 0 {
7832 solver
7833 .beta_inv()
7834 .map_err(|err| format!("assignment_log_strength_hessian_trace: {err}"))?
7835 } else {
7836 Array2::<f64>::zeros((0, 0))
7837 };
7838 // #1416 cross-row IBP source: the per-row block that the deflation
7839 // factorizes is the NO-SELF base `H₀'` — the rank-one self curvature
7840 // `d_k·J_ik²` is DOWNDATED from each logit diagonal and re-applied through
7841 // the Woodbury carrier. The full-`H` diagonal contraction below still uses
7842 // the full `hdiag` (which carries that self term), but the per-row
7843 // DEFLATION correction must use `(∂H₀'/∂ρ)_tt`, i.e. `hdiag` MINUS the
7844 // downdated self term — otherwise the Daleckii–Krein correction
7845 // mis-attributes the (un-deflated) Woodbury self curvature's derivative to
7846 // the deflated subspace. For non-IBP modes there is no Woodbury source and
7847 // the self term is `0` (the deflated block IS the full block).
7848 let cross_channels = if self.last_row_layout.is_none() {
7849 ibp_assignment_third_channels(&self.assignment, rho)?
7850 } else {
7851 None
7852 };
7853 let learnable_alpha = matches!(
7854 self.assignment.mode,
7855 AssignmentMode::IBPMap {
7856 learnable_alpha: true,
7857 ..
7858 }
7859 );
7860 let self_curv = |row: usize, atom: usize| -> f64 {
7861 let Some(ch) = cross_channels.as_ref() else {
7862 return 0.0;
7863 };
7864 let d_k = if learnable_alpha {
7865 ch.cross_row_d_logalpha[atom]
7866 } else {
7867 ch.cross_row_d[atom]
7868 };
7869 let j = ch.z_jac[row * k_atoms + atom];
7870 d_k * j * j
7871 };
7872 let mut trace = 0.0_f64;
7873 for row in 0..self.n_obs() {
7874 let row_base = cache.row_offsets[row];
7875 let assignment_base = row * k_atoms;
7876 let q = cache.row_dims[row];
7877 // Per-row diagonal `(∂H₀'/∂ρ)_tt` for the deflation correction: the
7878 // assignment prior curves only the logit/assignment slots (coordinate
7879 // slots are 0 — ARD handles those), MINUS the downdated cross-row self
7880 // curvature. The full-`H` trace contraction keeps the full `hdiag`.
7881 let mut d_diag = Array1::<f64>::zeros(q);
7882 match self.last_row_layout {
7883 Some(ref layout) => {
7884 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
7885 let d_slot = hdiag[assignment_base + atom];
7886 trace += inv_diag[row_base + pos] * d_slot;
7887 if pos < q {
7888 d_diag[pos] = d_slot - self_curv(row, atom);
7889 }
7890 }
7891 }
7892 None => {
7893 for free_idx in 0..assignment_dim {
7894 let d_slot = hdiag[assignment_base + free_idx];
7895 trace += inv_diag[row_base + free_idx] * d_slot;
7896 if free_idx < q {
7897 d_diag[free_idx] = d_slot - self_curv(row, free_idx);
7898 }
7899 }
7900 }
7901 }
7902 let dirs = cache
7903 .deflated_row_directions
7904 .get(row)
7905 .map(Vec::as_slice)
7906 .unwrap_or(&[]);
7907 if !dirs.is_empty() {
7908 let inv_vv = if fast_selected {
7909 let (inv_vv, _inv_vbeta) = solver
7910 .selected_inverse_row_blocks(row, &selected_beta_inv)
7911 .map_err(|err| {
7912 format!("assignment_log_strength_hessian_trace: selected inverse: {err}")
7913 })?;
7914 inv_vv
7915 } else {
7916 let mut inv_vv = Array2::<f64>::zeros((q, q));
7917 for col in 0..q {
7918 let mut rhs_t = Array1::<f64>::zeros(total_t);
7919 let rhs_beta = Array1::<f64>::zeros(cache.k);
7920 rhs_t[row_base + col] = 1.0;
7921 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
7922 format!(
7923 "assignment_log_strength_hessian_trace: selected inverse: {err}"
7924 )
7925 })?;
7926 for r in 0..q {
7927 inv_vv[[r, col]] = solved.t[row_base + r];
7928 }
7929 }
7930 inv_vv
7931 };
7932 let mut d_mat = Array2::<f64>::zeros((q, q));
7933 for s in 0..q {
7934 d_mat[[s, s]] = d_diag[s];
7935 }
7936 let spectrum = cache
7937 .deflation_row_spectra
7938 .get(row)
7939 .and_then(Option::as_ref);
7940 trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
7941 }
7942 }
7943 // #1416: the IBP prior Hessian is `H_p = d·J Jᵀ + diag(s, c)`, where the
7944 // rank-one `d·J Jᵀ` couples EVERY row pair `(i, j)` in a column `k`
7945 // through the shared empirical mass `M_k`. The assembled `H` carries the
7946 // full `H_full = H₀' + U D Uᵀ` (Woodbury, construction.rs:4710-4752), and
7947 // for fixed alpha the entire IBP prior scales with `λ = eᵖ`, so
7948 // `∂H_p/∂ρ = H_p`. The diagonal loop above already captures the `i = j`
7949 // self terms (the `d·J_ik²` summand lives in `hdiag`); this pass adds the
7950 // omitted off-diagonal `½·d_k·Σ_{i≠j}(H⁻¹)_{ik,jk}·J_ik·J_jk`. Only IBP
7951 // has the cross-row rank-one source; for other diagonal modes
7952 // `ibp_assignment_third_channels` returns `None` and the trace stays the
7953 // pure diagonal contraction. (IBP fixed-alpha uses the dense `None`
7954 // layout, so atom `k`'s logit slot is local position `k`.)
7955 if self.last_row_layout.is_none() {
7956 if let Some(channels) = ibp_assignment_third_channels(&self.assignment, rho)? {
7957 let n = self.n_obs();
7958 let total_t = cache.delta_t_len();
7959 // This trace is ½ ∂log|H|/∂ρ. For FIXED-α IBP the whole prior
7960 // scales with λ=eᵖ so ∂H_p/∂ρ = H_p and the rank-one coefficient
7961 // is the VALUE `cross_row_d[k] = w·s'_k`. For LEARNABLE-α this trace
7962 // is ½ ∂log|H|/∂logα, and the rank-one block's logα-derivative is
7963 // `∂d_k/∂logα = w·∂s'_k/∂logα` (`cross_row_d_logalpha[k]`) — the same
7964 // α-derivative the DIAGONAL channel (`hessian_diag_log_alpha_derivative`)
7965 // already uses. Using the value `s'_k` here (the pre-fix bug) made the
7966 // off-diagonal inconsistent with the diagonal and the α-gradient wrong.
7967 let learnable_alpha = matches!(
7968 self.assignment.mode,
7969 AssignmentMode::IBPMap {
7970 learnable_alpha: true,
7971 ..
7972 }
7973 );
7974 let mut cross = 0.0_f64;
7975 for k in 0..k_atoms {
7976 let d_k = if learnable_alpha {
7977 channels.cross_row_d_logalpha[k]
7978 } else {
7979 channels.cross_row_d[k]
7980 };
7981 if d_k == 0.0 {
7982 continue;
7983 }
7984 for i in 0..n {
7985 let j_ik = channels.z_jac[i * k_atoms + k];
7986 if j_ik == 0.0 {
7987 continue;
7988 }
7989 // (H⁻¹) column at row `i`'s logit-`k` slot.
7990 let mut rhs_t = Array1::<f64>::zeros(total_t);
7991 let rhs_beta = Array1::<f64>::zeros(cache.k);
7992 rhs_t[cache.row_offsets[i] + k] = 1.0;
7993 let solved =
7994 solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
7995 format!("assignment_log_strength_hessian_trace: {err}")
7996 })?;
7997 for j in 0..n {
7998 if j == i {
7999 continue;
8000 }
8001 let j_jk = channels.z_jac[j * k_atoms + k];
8002 if j_jk == 0.0 {
8003 continue;
8004 }
8005 let inv_ij = solved.t[cache.row_offsets[j] + k];
8006 cross += d_k * inv_ij * j_ik * j_jk;
8007 }
8008 }
8009 }
8010 trace += cross;
8011 }
8012 }
8013 Ok(0.5 * trace)
8014 }
8015
8016 pub(crate) fn learnable_ibp_forward_alpha_data_derivative(
8017 &self,
8018 rho: &SaeManifoldRho,
8019 target: ArrayView2<'_, f64>,
8020 ) -> Result<f64, String> {
8021 let AssignmentMode::IBPMap {
8022 temperature: _,
8023 learnable_alpha: true,
8024 ..
8025 } = self.assignment.mode
8026 else {
8027 return Ok(0.0);
8028 };
8029 let alpha = self
8030 .assignment
8031 .mode
8032 .resolved_ibp_alpha(rho)
8033 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8034 let k_atoms = self.k_atoms();
8035 let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8036 let mut dprior = Array1::<f64>::zeros(k_atoms);
8037 for k in 0..k_atoms {
8038 // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8039 // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8040 // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8041 dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8042 }
8043 let n = self.n_obs();
8044 let p = self.output_dim();
8045 let row_loss_w = self.row_loss_weights.as_deref();
8046 let whitens = self
8047 .row_metric
8048 .as_ref()
8049 .is_some_and(|metric| metric.whitens_likelihood());
8050 let mut decoded = vec![0.0_f64; p];
8051 let mut fitted = Array1::<f64>::zeros(p);
8052 let mut f_rho = Array1::<f64>::zeros(p);
8053 let mut residual = Array1::<f64>::zeros(p);
8054 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8055 let mut assignments = vec![0.0_f64; k_atoms];
8056 let mut total = 0.0_f64;
8057 for row in 0..n {
8058 self.assignment
8059 .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8060 fitted.fill(0.0);
8061 f_rho.fill(0.0);
8062 for k in 0..k_atoms {
8063 self.atoms[k].fill_decoded_row(row, &mut decoded);
8064 // Ungated (#1026 background-tier) atoms have a force-fixed unit
8065 // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8066 // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8067 // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8068 // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8069 // but `a_k` still varies with α through `π_k`, so it must NOT be
8070 // skipped.)
8071 let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8072 0.0
8073 } else {
8074 (assignments[k] / prior[k]) * dprior[k]
8075 };
8076 for out_col in 0..p {
8077 fitted[out_col] += assignments[k] * decoded[out_col];
8078 f_rho[out_col] += da_rho * decoded[out_col];
8079 }
8080 }
8081 for out_col in 0..p {
8082 residual[out_col] = fitted[out_col] - target[[row, out_col]];
8083 }
8084 let residual_metric = match self.row_metric.as_ref() {
8085 Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8086 _ => residual.to_vec(),
8087 };
8088 let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8089 let mut row_dot = 0.0_f64;
8090 for out_col in 0..p {
8091 row_dot += residual_metric[out_col] * f_rho[out_col];
8092 }
8093 total += row_weight * row_dot;
8094 }
8095 Ok(total)
8096 }
8097
8098 /// Per-row spectral-deflation correction `tr((H⁻¹)_tt · (D − DΦ[D]))` for one
8099 /// evidence ρ-component, to be SUBTRACTED from the raw-derivative trace
8100 /// `tr((H⁻¹)_tt · D)` the trace otherwise accumulates.
8101 ///
8102 /// The criterion VALUE re-deflates each per-row `H_tt` at every ρ, so the
8103 /// correct evidence gradient contracts `(H⁻¹)_tt` against the deflation-map
8104 /// derivative `DΦ[D]`, not the raw `D = (∂H_raw/∂ρ)_tt`. By Daleckii–Krein,
8105 /// in the row's RAW eigenbasis `U`,
8106 /// `DΦ[D] = U (F ∘ (Uᵀ D U)) Uᵀ`, `F_{ml} = (λ̃ₘ − λ̃ₗ)/(λₘ − λₗ)`
8107 /// (raw `λ` in the denominator, conditioned `λ̃` in the numerator; the
8108 /// diagonal / degenerate entry is `f'(λₘ) = 1` for an unclamped kept
8109 /// direction and `0` otherwise). Hence `D − DΦ[D] = U ((1−F) ∘ (Uᵀ D U)) Uᵀ`,
8110 /// whose kept×kept block is `0`, deflated×deflated block is the full `M`, and
8111 /// kept(m)×deflated(i) block carries the ROTATION coefficient
8112 /// `(1−λᵢ)/(λₘ−λᵢ)`. Contracting against the FULL deflated selected-inverse
8113 /// t-block `inv_vv` (which carries the β-Schur back-substitution) captures
8114 /// both the within-row kept-subspace term and the deferred β-Schur/rotation
8115 /// coupling in one pass, matching the re-deflating fixed-state FD oracle.
8116 ///
8117 /// `spectrum = Some` (spectral deflation): exact Daleckii–Krein. `None` with a
8118 /// non-empty `dirs` (gauge-only deflation, ρ-independent structural null):
8119 /// fall back to the within-row kept-subspace term `Σᵢ vᵢᵀ D vᵢ`.
8120 /// `inv_vv` is assumed symmetric (selected inverse of a symmetric PD system).
8121 fn deflation_block_correction(
8122 inv_vv: &Array2<f64>,
8123 d_mat: &Array2<f64>,
8124 dirs: &[Array1<f64>],
8125 spectrum: Option<&RowDeflationSpectrum>,
8126 ) -> f64 {
8127 let q = inv_vv.nrows();
8128 let Some(spec) = spectrum else {
8129 // Gauge-only deflation: ρ-independent structural null → within-row term.
8130 let mut acc = 0.0_f64;
8131 for v in dirs {
8132 for a in 0..q {
8133 let va = if a < v.len() { v[a] } else { 0.0 };
8134 if va == 0.0 {
8135 continue;
8136 }
8137 for b in 0..q {
8138 let vb = if b < v.len() { v[b] } else { 0.0 };
8139 acc += va * vb * d_mat[[a, b]];
8140 }
8141 }
8142 }
8143 return acc;
8144 };
8145 let u = &spec.evecs;
8146 if u.nrows() != q || u.ncols() != q {
8147 return 0.0;
8148 }
8149 let raw = &spec.raw_evals;
8150 let cond = &spec.cond_evals;
8151 // M = Uᵀ D U, W = Uᵀ inv_vv U (both q×q, symmetric).
8152 let m = u.t().dot(d_mat).dot(u);
8153 let w = u.t().dot(inv_vv).dot(u);
8154 // correction = Σ_{m,l} W[m,l]·M[m,l]·(1 − F[m,l]).
8155 let mut acc = 0.0_f64;
8156 let eps = 1.0e-12;
8157 for a in 0..q {
8158 for b in 0..q {
8159 let denom = raw[a] - raw[b];
8160 let f1 = if denom.abs() > eps {
8161 (cond[a] - cond[b]) / denom
8162 } else if cond[a] == raw[a] {
8163 1.0
8164 } else {
8165 0.0
8166 };
8167 acc += w[[a, b]] * m[[a, b]] * (1.0 - f1);
8168 }
8169 }
8170 acc
8171 }
8172
8173 /// #1417: exact `½ tr(H⁻¹ ∂H_data/∂logα)` for LEARNABLE IBP alpha.
8174 ///
8175 /// The forward assignment is `a_ik = σ(ℓ_ik/τ)·π_k(α)` with the #614
8176 /// consistent stick-breaking mean `π_k(α) = (α/(α+1))^(k+1)`, so
8177 /// `∂logπ_k/∂logα = (k+1)/(α+1)`. EVERY data-Jacobian column for atom `k` —
8178 /// the logit-JVP row (carries one `π_k`), the coordinate rows (carry one
8179 /// `a_k`), and the β-leg (`a_k·φ`) — carries exactly ONE `a_k`/`π_k` factor
8180 /// (`σ(ℓ/τ)` is α-independent). Hence each Jacobian column scales as
8181 /// `∂J_·k/∂logα = ((k+1)/(α+1))·J_·k`, and the data Hessian block for the
8182 /// atom pair `(k_a, k_b)` scales as
8183 /// ∂H_data[a,b]/∂logα = (((k_a+1) + (k_b+1))/(α+1))·H_data[a,b].
8184 /// Therefore the exact data-block contribution to the α-logdet trace is
8185 /// ½ tr(H⁻¹ ∂H_data/∂logα)
8186 /// = ½/(α+1) · Σ_{a,b} ((k_a+1) + (k_b+1))·(H⁻¹)_{ba}·H_data[a,b],
8187 /// over the full joint `(t, β)` index set. `H_data[a,b]` is the data-fit
8188 /// Gauss-Newton block built from the SAME `row_jets_for_logdet` first-jets the
8189 /// θ-adjoint uses (`H_tt = ⟨J_a,J_b⟩`, `H_tβ = ⟨J_a,J_β⟩`, `H_ββ = ⟨J_β,J_β'⟩`),
8190 /// and `(H⁻¹)` is contracted through the same per-row selected-inverse blocks.
8191 /// This closes the learnable-α gradient: combined with the prior-Hessian
8192 /// trace (`assignment_log_strength_hessian_trace`) the full
8193 /// `½ tr(H⁻¹ ∂H/∂logα)` is now assembled. For FIXED alpha (and non-IBP modes)
8194 /// this is identically zero.
8195 pub(crate) fn learnable_ibp_data_logdet_alpha_trace(
8196 &self,
8197 rho: &SaeManifoldRho,
8198 cache: &ArrowFactorCache,
8199 solver: &DeflatedArrowSolver<'_>,
8200 ) -> Result<f64, String> {
8201 let AssignmentMode::IBPMap {
8202 learnable_alpha: true,
8203 ..
8204 } = self.assignment.mode
8205 else {
8206 return Ok(0.0);
8207 };
8208 let alpha = self
8209 .assignment
8210 .mode
8211 .resolved_ibp_alpha(rho)
8212 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8213 let inv_alpha1 = 1.0 / (alpha + 1.0);
8214 let n = self.n_obs();
8215 let total_t = cache.delta_t_len();
8216 let second_jets = self.atom_second_jets()?;
8217 let border = self.border_channels_for_cache(cache)?;
8218
8219 // β-tier selected inverse `(H⁻¹)_ββ` (shared across rows). #932 FRONT C:
8220 // on the plain bordered arrow this is the cached dense `S⁻¹` formed once
8221 // (no `K` full-system solves); when a gauge / #1038 cross-row Woodbury is
8222 // active the row-local Takahashi blocks are NOT valid, so we fall back to
8223 // the per-β-coordinate `solve` loop (bit-identical, just O(n) per call).
8224 let fast_selected = solver.plain_selected_inverse_available();
8225 let beta_inv = if cache.k == 0 {
8226 Array2::<f64>::zeros((0, 0))
8227 } else if fast_selected {
8228 solver.beta_inv().map_err(|err| {
8229 format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8230 })?
8231 } else {
8232 let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8233 let rhs_t = Array1::<f64>::zeros(total_t);
8234 for col in 0..cache.k {
8235 let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8236 rhs_beta[col] = 1.0;
8237 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8238 format!("learnable_ibp_data_logdet_alpha_trace: beta inverse: {err}")
8239 })?;
8240 for r in 0..cache.k {
8241 beta_inv[[r, col]] = solved.beta[r];
8242 }
8243 }
8244 beta_inv
8245 };
8246 // Atom index of each β border channel (the `k_b` weight for the β leg).
8247 let border_atom: Vec<usize> = border.iter().map(|c| c.atom).collect();
8248
8249 let mut trace = 0.0_f64;
8250 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8251 let mut assignments = Array1::<f64>::zeros(self.k_atoms());
8252 // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
8253 // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
8254 // rows fall back to the scalar per-row path (bit-identical either way).
8255 let mut jet_window: std::collections::VecDeque<SaeRowJets> =
8256 std::collections::VecDeque::new();
8257 let mut jet_window_next = 0usize;
8258 for row in 0..n {
8259 let q = cache.row_dims[row];
8260 let base = cache.row_offsets[row];
8261 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
8262 self.assignment
8263 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
8264 if jet_window.is_empty() {
8265 jet_window_next = self.refill_jet_window(
8266 rho,
8267 jet_window_next,
8268 cache,
8269 &second_jets,
8270 &border,
8271 &mut jet_window,
8272 )?;
8273 }
8274 let jets = jet_window.pop_front().expect("jet window must be non-empty");
8275 // Atom index (k-weight) of each local t-var.
8276 let var_atom: Vec<usize> = jets
8277 .vars
8278 .iter()
8279 .map(|v| match *v {
8280 SaeLocalRowVar::Logit { atom } => atom,
8281 SaeLocalRowVar::Coord { atom, .. } => atom,
8282 })
8283 .collect();
8284
8285 // Per-row selected inverse blocks `(H⁻¹)_tt` (q×q) and `(H⁻¹)_tβ`.
8286 // #932 FRONT C: row-local Takahashi (O(q·(q+K))) on the plain arrow;
8287 // per-row full-system `solve` loop (O(n·q)) under gauge / cross-row
8288 // Woodbury where the row-local blocks are not valid.
8289 let (inv_vv, inv_vbeta) = if fast_selected {
8290 solver
8291 .selected_inverse_row_blocks(row, &beta_inv)
8292 .map_err(|err| {
8293 format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8294 })?
8295 } else {
8296 let mut inv_vv = Array2::<f64>::zeros((q, q));
8297 let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
8298 for col in 0..q {
8299 let mut rhs_t = Array1::<f64>::zeros(total_t);
8300 let rhs_beta = Array1::<f64>::zeros(cache.k);
8301 rhs_t[base + col] = 1.0;
8302 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8303 format!("learnable_ibp_data_logdet_alpha_trace: selected inverse: {err}")
8304 })?;
8305 for r in 0..q {
8306 inv_vv[[r, col]] = solved.t[base + r];
8307 }
8308 for b in 0..cache.k {
8309 inv_vbeta[[col, b]] = solved.beta[b];
8310 }
8311 }
8312 (inv_vv, inv_vbeta)
8313 };
8314
8315 // #1026 — UNGATED (background-tier) atoms have a force-fixed unit gate,
8316 // so their mass `a_k ≡ 1` is α-INDEPENDENT: every data-Jacobian column
8317 // for an ungated atom carries `a_k = 1`, NOT `π_k(α)`, so its α-exponent
8318 // is `e_k = 0`, not `k+1`. Gated atoms keep `e_k = k+1`. (The prior trace
8319 // handles ungated separately by zeroing the fixed-logit `z_jac`.)
8320 let kfac = |atom: usize| -> f64 {
8321 if self.assignment.ungated.get(atom).copied().unwrap_or(false) {
8322 0.0
8323 } else {
8324 (atom + 1) as f64
8325 }
8326 };
8327 // t–t block: Σ_{a,b} (e_a + e_b)·(H⁻¹)_{ba}·⟨J_a, J_b⟩, where the
8328 // per-atom log-prior exponent is e_k = k+1 for the #614 consistent
8329 // stick-breaking mean π_k = (α/(α+1))^(k+1) (dlogπ_k/dlogα = (k+1)·inv_alpha1).
8330 for a in 0..q {
8331 for b in 0..q {
8332 let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8333 if h_ab == 0.0 {
8334 continue;
8335 }
8336 let kw = kfac(var_atom[a]) + kfac(var_atom[b]);
8337 trace += kw * inv_vv[[b, a]] * h_ab;
8338 }
8339 }
8340 // Deflation correction (kept-subspace restriction + β-Schur/rotation).
8341 // `inv_vv` is the DEFLATED selected inverse, so the t–t contraction
8342 // above contracts the RAW derivative `D` where the re-deflating
8343 // criterion uses the deflation-map derivative `DΦ[D]`. Subtract the
8344 // exact over-count `tr(inv_vv·(D − DΦ[D]))` via the Daleckii–Krein
8345 // helper, with `D_{ab} = kw_ab·⟨J_a, J_b⟩` the SAME t–t operator the
8346 // trace contracts. The t–β/β–β blocks are not deflated, so only the
8347 // t–t contraction is corrected.
8348 let dirs = cache
8349 .deflated_row_directions
8350 .get(row)
8351 .map(Vec::as_slice)
8352 .unwrap_or(&[]);
8353 if !dirs.is_empty() {
8354 let mut d_mat = Array2::<f64>::zeros((q, q));
8355 for a in 0..q {
8356 for b in 0..q {
8357 let h_ab = sae_dot(&jets.first[a], &jets.first[b]);
8358 if h_ab == 0.0 {
8359 continue;
8360 }
8361 d_mat[[a, b]] = (kfac(var_atom[a]) + kfac(var_atom[b])) * h_ab;
8362 }
8363 }
8364 let spectrum = cache
8365 .deflation_row_spectra
8366 .get(row)
8367 .and_then(Option::as_ref);
8368 trace -= Self::deflation_block_correction(&inv_vv, &d_mat, dirs, spectrum);
8369 }
8370 // t–β and β–t blocks: appear symmetrically, contract once with the
8371 // factor 2 (H, H⁻¹ symmetric; `(H⁻¹)_βt = (H⁻¹)_tβᵀ`).
8372 for a in 0..q {
8373 for (beta_pos, channel) in border.iter().enumerate() {
8374 let h_ab = sae_dot(&jets.first[a], &jets.beta[beta_pos]);
8375 if h_ab == 0.0 {
8376 continue;
8377 }
8378 let kw = kfac(var_atom[a]) + kfac(border_atom[beta_pos]);
8379 trace += 2.0 * kw * inv_vbeta[[a, channel.index]] * h_ab;
8380 }
8381 }
8382 // β–β block: Σ_{β,β'} (k_β + k_β')·(H⁻¹)_{β'β}·⟨J_β, J_β'⟩.
8383 for (beta_i, channel_i) in border.iter().enumerate() {
8384 for (beta_j, channel_j) in border.iter().enumerate() {
8385 let h_ab = sae_dot(&jets.beta[beta_i], &jets.beta[beta_j]);
8386 if h_ab == 0.0 {
8387 continue;
8388 }
8389 let kw = kfac(border_atom[beta_i]) + kfac(border_atom[beta_j]);
8390 trace += kw * beta_inv[[channel_i.index, channel_j.index]] * h_ab;
8391 }
8392 }
8393 }
8394 Ok(0.5 * inv_alpha1 * trace)
8395 }
8396
8397 pub(crate) fn add_learnable_ibp_forward_alpha_data_rhs(
8398 &self,
8399 rho: &SaeManifoldRho,
8400 target: ArrayView2<'_, f64>,
8401 cache: &ArrowFactorCache,
8402 t: &mut Array1<f64>,
8403 beta: &mut Array1<f64>,
8404 ) -> Result<(), String> {
8405 let AssignmentMode::IBPMap {
8406 temperature,
8407 learnable_alpha: true,
8408 ..
8409 } = self.assignment.mode
8410 else {
8411 return Ok(());
8412 };
8413 let alpha = self
8414 .assignment
8415 .mode
8416 .resolved_ibp_alpha(rho)
8417 .ok_or_else(|| "learnable IBP alpha resolution failed".to_string())?;
8418 let k_atoms = self.k_atoms();
8419 let p = self.output_dim();
8420 let prior = ordered_geometric_shrinkage_prior(k_atoms, alpha);
8421 let mut dprior = Array1::<f64>::zeros(k_atoms);
8422 for k in 0..k_atoms {
8423 // dπ_k/dρ for π_k = (α/(α+1))^(k+1) (#614 consistent stick-breaking
8424 // prior mean): dπ_k/dα = π_k·(k+1)/(α(α+1)), and with α = α₀·exp(ρ)
8425 // the log-α chain factor α cancels the 1/α ⇒ dπ_k/dρ = π_k·(k+1)/(α+1).
8426 dprior[k] = prior[k] * (k + 1) as f64 / (alpha + 1.0);
8427 }
8428 let inv_tau = 1.0 / temperature;
8429 let row_loss_w = self.row_loss_weights.as_deref();
8430 let whitens = self
8431 .row_metric
8432 .as_ref()
8433 .is_some_and(|metric| metric.whitens_likelihood());
8434 let border = self.border_channels_for_cache(cache)?;
8435 let mut decoded_rows = vec![vec![0.0_f64; p]; k_atoms];
8436 let mut decoded_deriv = vec![0.0_f64; p];
8437 let mut fitted = Array1::<f64>::zeros(p);
8438 let mut f_rho = Array1::<f64>::zeros(p);
8439 let mut residual = Array1::<f64>::zeros(p);
8440 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8441 let mut assignments = vec![0.0_f64; k_atoms];
8442 for row in 0..self.n_obs() {
8443 self.assignment
8444 .try_assignments_row_for_rho_into(row, rho, &mut assignments)?;
8445 fitted.fill(0.0);
8446 f_rho.fill(0.0);
8447 for k in 0..k_atoms {
8448 self.atoms[k].fill_decoded_row(row, &mut decoded_rows[k]);
8449 // Ungated (#1026 background-tier) atoms have a force-fixed unit
8450 // gate (`has_ungated` override), so their mass `a_k ≡ 1` is
8451 // α-INDEPENDENT (∂a_k/∂logα = 0). The π_k(α) chain below applies
8452 // ONLY to gated atoms, whose mass is `a_k = σ(ℓ/τ)·π_k(α)`. (NB:
8453 // frozen routing is NOT ungated — there the gate is a fixed σ(ℓ/τ)
8454 // but `a_k` still varies with α through `π_k`, so it must NOT be
8455 // skipped.)
8456 let da_rho = if self.assignment.ungated.get(k).copied().unwrap_or(false) {
8457 0.0
8458 } else {
8459 (assignments[k] / prior[k]) * dprior[k]
8460 };
8461 for out_col in 0..p {
8462 fitted[out_col] += assignments[k] * decoded_rows[k][out_col];
8463 f_rho[out_col] += da_rho * decoded_rows[k][out_col];
8464 }
8465 }
8466 for out_col in 0..p {
8467 residual[out_col] = fitted[out_col] - target[[row, out_col]];
8468 }
8469 let residual_metric = match self.row_metric.as_ref() {
8470 Some(metric) if whitens => metric.apply_metric_row(row, residual.view()),
8471 _ => residual.to_vec(),
8472 };
8473 let f_metric = match self.row_metric.as_ref() {
8474 Some(metric) if whitens => metric.apply_metric_row(row, f_rho.view()),
8475 _ => f_rho.to_vec(),
8476 };
8477 let row_weight = row_loss_w.map_or(1.0, |w| w[row]);
8478 let row_vars = self.row_vars_for_cache_row(row, cache)?;
8479 let row_base = cache.row_offsets[row];
8480 for (pos, var) in row_vars.iter().enumerate() {
8481 let mut contribution = 0.0_f64;
8482 match *var {
8483 SaeLocalRowVar::Logit { atom } => {
8484 let sigma = assignments[atom] / prior[atom];
8485 let sigma_jac = sigma * (1.0 - sigma) * inv_tau;
8486 let da_dl = sigma_jac * prior[atom];
8487 let d_da_rho_dl = sigma_jac * dprior[atom];
8488 for out_col in 0..p {
8489 contribution += da_dl * decoded_rows[atom][out_col] * f_metric[out_col];
8490 contribution += d_da_rho_dl
8491 * decoded_rows[atom][out_col]
8492 * residual_metric[out_col];
8493 }
8494 }
8495 SaeLocalRowVar::Coord { atom, axis } => {
8496 let sigma = assignments[atom] / prior[atom];
8497 let da_rho = sigma * dprior[atom];
8498 self.atoms[atom].fill_decoded_derivative_row(row, axis, &mut decoded_deriv);
8499 for out_col in 0..p {
8500 contribution +=
8501 assignments[atom] * decoded_deriv[out_col] * f_metric[out_col];
8502 contribution +=
8503 da_rho * decoded_deriv[out_col] * residual_metric[out_col];
8504 }
8505 }
8506 }
8507 t[row_base + pos] += row_weight * contribution;
8508 }
8509 for channel in &border {
8510 let phi = self.atoms[channel.atom].basis_values[[row, channel.basis_col]];
8511 let sigma = assignments[channel.atom] / prior[channel.atom];
8512 let da_rho = sigma * dprior[channel.atom];
8513 let mut contribution = 0.0_f64;
8514 for out_col in 0..p {
8515 let output = channel.output[out_col];
8516 contribution += assignments[channel.atom] * phi * output * f_metric[out_col];
8517 contribution += da_rho * phi * output * residual_metric[out_col];
8518 }
8519 beta[channel.index] += row_weight * contribution;
8520 }
8521 }
8522 Ok(())
8523 }
8524
8525 pub(crate) fn border_channels_for_cache(
8526 &self,
8527 cache: &ArrowFactorCache,
8528 ) -> Result<Vec<SaeBorderChannel>, String> {
8529 let p = self.output_dim();
8530 let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8531 let offsets = if frames_active {
8532 self.factored_beta_offsets()
8533 } else {
8534 self.beta_offsets()
8535 };
8536 let mut channels = Vec::with_capacity(cache.k);
8537 for (atom_idx, atom) in self.atoms.iter().enumerate() {
8538 let m = atom.basis_size();
8539 let frame = if frames_active {
8540 self.frame_output_matrix(atom_idx)
8541 } else {
8542 Array2::<f64>::eye(p)
8543 };
8544 let r = frame.ncols();
8545 for basis_col in 0..m {
8546 for channel in 0..r {
8547 let mut output = vec![0.0_f64; p];
8548 for out_col in 0..p {
8549 output[out_col] = frame[[out_col, channel]];
8550 }
8551 channels.push(SaeBorderChannel {
8552 atom: atom_idx,
8553 basis_col,
8554 index: offsets[atom_idx] + basis_col * r + channel,
8555 output,
8556 });
8557 }
8558 }
8559 }
8560 if channels.len() != cache.k {
8561 return Err(format!(
8562 "border channel layout has {} entries but cache border has {}",
8563 channels.len(),
8564 cache.k
8565 ));
8566 }
8567 Ok(channels)
8568 }
8569
8570 pub(crate) fn row_vars_for_cache_row(
8571 &self,
8572 row: usize,
8573 cache: &ArrowFactorCache,
8574 ) -> Result<Vec<SaeLocalRowVar>, String> {
8575 let q_row = cache.row_dims[row];
8576 let mut vars: Vec<Option<SaeLocalRowVar>> = vec![None; q_row];
8577 match self.last_row_layout {
8578 Some(ref layout) => {
8579 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8580 vars[pos] = Some(SaeLocalRowVar::Logit { atom });
8581 let start = layout.coord_starts[row][pos];
8582 let d = self.assignment.coords[atom].latent_dim();
8583 for axis in 0..d {
8584 vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8585 }
8586 }
8587 }
8588 None => {
8589 let assignment_dim = self.assignment.assignment_coord_dim();
8590 let coord_offsets = self.assignment.coord_offsets();
8591 for atom in 0..assignment_dim {
8592 vars[atom] = Some(SaeLocalRowVar::Logit { atom });
8593 }
8594 for atom in 0..self.k_atoms() {
8595 let start = coord_offsets[atom];
8596 let d = self.assignment.coords[atom].latent_dim();
8597 for axis in 0..d {
8598 vars[start + axis] = Some(SaeLocalRowVar::Coord { atom, axis });
8599 }
8600 }
8601 }
8602 }
8603 vars.into_iter()
8604 .enumerate()
8605 .map(|(idx, v)| {
8606 v.ok_or_else(|| {
8607 format!("row_vars_for_cache_row: row {row} position {idx} was not mapped")
8608 })
8609 })
8610 .collect()
8611 }
8612
8613 pub(crate) fn atom_second_jets(&self) -> Result<Vec<Array4<f64>>, String> {
8614 let mut out = Vec::with_capacity(self.k_atoms());
8615 for (atom_idx, atom) in self.atoms.iter().enumerate() {
8616 let coords = self.assignment.coords[atom_idx].as_matrix();
8617 let jet = if let Some(second) = atom.basis_second_jet.as_ref() {
8618 second.second_jet(coords.view())?
8619 } else {
8620 let evaluator = atom.basis_evaluator.as_ref().ok_or_else(|| {
8621 format!(
8622 "logdet_theta_adjoint: atom '{}' has no basis evaluator for second jets",
8623 atom.name
8624 )
8625 })?;
8626 evaluator
8627 .second_jet_dyn(coords.view())
8628 .ok_or_else(|| {
8629 format!(
8630 "logdet_theta_adjoint: atom '{}' basis does not expose analytic second jets",
8631 atom.name
8632 )
8633 })??
8634 };
8635 let expected = (
8636 atom.n_obs(),
8637 atom.basis_size(),
8638 atom.latent_dim,
8639 atom.latent_dim,
8640 );
8641 if jet.dim() != expected {
8642 return Err(format!(
8643 "logdet_theta_adjoint: atom '{}' second jet shape {:?}, expected {:?}",
8644 atom.name,
8645 jet.dim(),
8646 expected
8647 ));
8648 }
8649 out.push(jet);
8650 }
8651 Ok(out)
8652 }
8653
8654 // [#780 line-count gate] The per-row jet / reconstruction-channel cluster
8655 // (`reconstruction_row_program_for_logdet`, the const-generic
8656 // reconstruction / β-border channel fills and their dynamic dispatchers,
8657 // `row_jets_for_logdet`, `row_jets_for_logdet_batch4`, `batch4_assemble`,
8658 // and `refill_jet_window`) lives in the sibling
8659 // `construction_row_jet_logdet_channels.rs` file, inlined via `include!`
8660 // below at module scope as a second `impl SaeManifoldTerm` block. Splitting
8661 // it out keeps this tracked file under the 10k limit; `include!` preserves
8662 // the identical module scope and private-field access.
8663
8664 pub(crate) fn assignment_prior_hdiag_derivative_entry(
8665 &self,
8666 rho: &SaeManifoldRho,
8667 row: usize,
8668 diag_atom: usize,
8669 wrt: SaeLocalRowVar,
8670 ibp_channels: Option<&IbpHessianDiagThirdChannels>,
8671 ) -> f64 {
8672 let SaeLocalRowVar::Logit { atom: wrt_atom } = wrt else {
8673 return 0.0;
8674 };
8675 match self.assignment.mode {
8676 AssignmentMode::Softmax { .. } => {
8677 // #1038: the softmax entropy Hessian is now stored DENSE in
8678 // `block.htt` and its full θ-derivative `∂H_{k,j}/∂z_w` (diagonal
8679 // AND off-diagonal) is added inline in `logdet_theta_adjoint` from
8680 // the shared `row_dense_hessian_logit_derivative`. Returning the
8681 // diagonal contribution here too would double-count, so this
8682 // primitive is silent for softmax — the dense path is the single
8683 // source for value, logdet, and adjoint.
8684 0.0
8685 }
8686 AssignmentMode::JumpReLU {
8687 temperature,
8688 threshold,
8689 } => {
8690 if diag_atom != wrt_atom {
8691 return 0.0;
8692 }
8693 let logit = self.assignment.logits[[row, diag_atom]];
8694 if !crate::assignment::jumprelu_in_optimization_band(
8695 logit,
8696 threshold,
8697 temperature,
8698 ) {
8699 return 0.0;
8700 }
8701 let inv_tau = 1.0 / temperature;
8702 let activation =
8703 gam_linalg::utils::stable_logistic((logit - threshold) * inv_tau);
8704 let slope = activation * (1.0 - activation);
8705 // #1415: P(ℓ)=λσ((ℓ−θ)/τ); P''(ℓ)=(λ/τ²)s(1−2a) so the third
8706 // derivative is P'''(ℓ)=(λ/τ³)·s·(1−6a+6a²), because
8707 // d/dℓ[s(1−2a)] = (1/τ)s[(1−2a)²−2s] = (1/τ)s(1−6a+6a²).
8708 rho.lambda_sparse()
8709 * slope
8710 * (1.0 - 6.0 * activation + 6.0 * activation * activation)
8711 * inv_tau
8712 * inv_tau
8713 * inv_tau
8714 }
8715 AssignmentMode::IBPMap { .. } => {
8716 // The assembled `htt` diagonal consumes
8717 // `IBPAssignmentPenalty::hessian_diag`, whose logit derivative
8718 // splits into a row-local direct-`z` channel and a global
8719 // empirical-`M_k` channel (π_k couples every row in column k).
8720 // This same-row primitive returns only the LOCAL direct-`z`
8721 // channel — and only on the matching logit (`diag_atom == w`),
8722 // since H_ik depends on no other row's z explicitly. The global
8723 // M_k channel is accumulated column-wise in
8724 // `logdet_theta_adjoint` (it needs the per-row selected-inverse
8725 // diagonals), so adding it here would double-count.
8726 if diag_atom != wrt_atom {
8727 return 0.0;
8728 }
8729 match ibp_channels {
8730 Some(ch) => ch.local_logit_third[row * ch.k_max + diag_atom],
8731 None => 0.0,
8732 }
8733 }
8734 }
8735 }
8736
8737 pub(crate) fn ard_majorized_hessian_derivative(
8738 &self,
8739 rho: &SaeManifoldRho,
8740 row: usize,
8741 atom: usize,
8742 axis: usize,
8743 ) -> f64 {
8744 if rho.log_ard[atom].is_empty() {
8745 return 0.0;
8746 }
8747 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8748 let periods = self.assignment.coords[atom].effective_axis_periods();
8749 let t = self.assignment.coords[atom].row(row)[axis];
8750 let prior = ArdAxisPrior::eval(alpha, t, periods[axis]);
8751 if prior.hess <= 0.0 {
8752 return 0.0;
8753 }
8754 match periods[axis] {
8755 None => 0.0,
8756 Some(period) => {
8757 let kappa = std::f64::consts::TAU / period;
8758 -alpha * kappa * (kappa * t).sin()
8759 }
8760 }
8761 }
8762
8763 pub fn outer_rho_gradient_ift_rhs(
8764 &self,
8765 rho: &SaeManifoldRho,
8766 target: ArrayView2<'_, f64>,
8767 j: usize,
8768 cache: &ArrowFactorCache,
8769 ) -> Result<SaeArrowVector, String> {
8770 let n_params = rho.to_flat().len();
8771 if j >= n_params {
8772 return Err(format!(
8773 "outer_rho_gradient_ift_rhs: coordinate {j} outside rho dim {n_params}"
8774 ));
8775 }
8776 let mut t = Array1::<f64>::zeros(cache.delta_t_len());
8777 let mut beta = Array1::<f64>::zeros(cache.k);
8778 if j == 0 {
8779 let assignment_grad =
8780 assignment_prior_log_strength_target_mixed(&self.assignment, rho)?;
8781 let k_atoms = self.k_atoms();
8782 let assignment_dim = self.assignment.assignment_coord_dim();
8783 for row in 0..self.n_obs() {
8784 let base = cache.row_offsets[row];
8785 let assignment_base = row * k_atoms;
8786 match self.last_row_layout {
8787 Some(ref layout) => {
8788 for (pos, &atom) in layout.active_atoms[row].iter().enumerate() {
8789 t[base + pos] = assignment_grad[assignment_base + atom];
8790 }
8791 }
8792 None => {
8793 for free_idx in 0..assignment_dim {
8794 t[base + free_idx] = assignment_grad[assignment_base + free_idx];
8795 }
8796 }
8797 }
8798 }
8799 self.add_learnable_ibp_forward_alpha_data_rhs(rho, target, cache, &mut t, &mut beta)?;
8800 } else if (1..=rho.log_lambda_smooth.len()).contains(&j) {
8801 // #1556: coordinate `j ∈ 1..=K` is the per-atom smoothness strength
8802 // `log λ_smooth[j-1]`. `∂(penalty)/∂log λ_k = λ_k·S_k C_k` touches ONLY
8803 // atom `k = j-1`'s decoder block; every other atom's RHS is zero.
8804 let target_atom = j - 1;
8805 let lambda = rho.lambda_smooth_for(target_atom);
8806 let frames_active = self.last_frames_active && cache.k == self.factored_border_dim();
8807 let offsets = if frames_active {
8808 self.factored_beta_offsets()
8809 } else {
8810 self.beta_offsets()
8811 };
8812 let atom = &self.atoms[target_atom];
8813 let m = atom.basis_size();
8814 let coeffs = if frames_active {
8815 match &atom.decoder_frame {
8816 Some(frame) => frame.project_decoder(atom.decoder_coefficients.view())?,
8817 None => atom.decoder_coefficients.clone(),
8818 }
8819 } else {
8820 atom.decoder_coefficients.clone()
8821 };
8822 let r = coeffs.ncols();
8823 let off = offsets[target_atom];
8824 for mu in 0..m {
8825 for channel in 0..r {
8826 let mut acc = 0.0_f64;
8827 for nu in 0..m {
8828 let s_sym =
8829 0.5 * (atom.smooth_penalty[[mu, nu]] + atom.smooth_penalty[[nu, mu]]);
8830 acc += s_sym * coeffs[[nu, channel]];
8831 }
8832 beta[off + mu * r + channel] = lambda * acc;
8833 }
8834 }
8835 } else {
8836 let mut cursor = 1 + rho.log_lambda_smooth.len();
8837 for atom in 0..rho.log_ard.len() {
8838 for axis in 0..rho.log_ard[atom].len() {
8839 if cursor == j {
8840 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
8841 let periods = self.assignment.coords[atom].effective_axis_periods();
8842 for row in 0..self.n_obs() {
8843 let row_t = self.assignment.coords[atom].row(row);
8844 let prior = ArdAxisPrior::eval(alpha, row_t[axis], periods[axis]);
8845 let Some(pos) = sae_coord_penalty_offset(
8846 self.last_row_layout.as_ref(),
8847 self.assignment.coord_offsets()[atom] + axis,
8848 row,
8849 atom,
8850 ) else {
8851 continue;
8852 };
8853 t[cache.row_offsets[row] + pos] = prior.grad;
8854 }
8855 return Ok(SaeArrowVector { t, beta });
8856 }
8857 cursor += 1;
8858 }
8859 }
8860 }
8861 Ok(SaeArrowVector { t, beta })
8862 }
8863
8864 pub(crate) fn logdet_theta_adjoint(
8865 &self,
8866 rho: &SaeManifoldRho,
8867 cache: &ArrowFactorCache,
8868 solver: &DeflatedArrowSolver<'_>,
8869 ) -> Result<SaeArrowVector, String> {
8870 // Γ_a = tr(H⁻¹ ∂H/∂θ_a) over the inner variables θ (#1006). `H` here is
8871 // the SAME object the evidence factor builds — Gauss-Newton data
8872 // curvature plus the prior majorizers / `hessian_diag` diagonals the
8873 // Newton/Schur Cholesky factorizes — so each block's θ-derivative channel
8874 // is differentiated on the criterion's own branch (no value/gradient
8875 // desync). The IBP-MAP assignment prior is the one block whose
8876 // `hessian_diag` couples every row in a column through the plug-in
8877 // empirical mass `M_k = Σ_i z_ik`; its logit derivative therefore has a
8878 // row-local channel (handled inline via
8879 // `assignment_prior_hdiag_derivative_entry`) and a cross-row channel
8880 // (accumulated column-wise after the row loop, below).
8881 let n = self.n_obs();
8882 let total_t = cache.delta_t_len();
8883 let mut gamma_t = Array1::<f64>::zeros(total_t);
8884 let mut gamma_beta = Array1::<f64>::zeros(cache.k);
8885 let second_jets = self.atom_second_jets()?;
8886 let border = self.border_channels_for_cache(cache)?;
8887 // #932 FRONT C: plain-arrow `(H⁻¹)_ββ = S⁻¹` formed once from the cached
8888 // Schur factor; gauge / #1038 cross-row Woodbury fall back to the per-β
8889 // `solve` loop where the row-local Takahashi blocks are not valid.
8890 let fast_selected = solver.plain_selected_inverse_available();
8891 let beta_inv = if cache.k == 0 {
8892 Array2::<f64>::zeros((0, 0))
8893 } else if fast_selected {
8894 solver
8895 .beta_inv()
8896 .map_err(|err| format!("logdet_theta_adjoint: beta selected inverse: {err}"))?
8897 } else {
8898 let mut beta_inv = Array2::<f64>::zeros((cache.k, cache.k));
8899 let rhs_t = Array1::<f64>::zeros(total_t);
8900 for col in 0..cache.k {
8901 let mut rhs_beta = Array1::<f64>::zeros(cache.k);
8902 rhs_beta[col] = 1.0;
8903 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8904 format!("logdet_theta_adjoint: beta selected inverse solve: {err}")
8905 })?;
8906 for row in 0..cache.k {
8907 beta_inv[[row, col]] = solved.beta[row];
8908 }
8909 }
8910 beta_inv
8911 };
8912 // IBP `hessian_diag` logit third-derivative channels (#1006). The full
8913 // IBP Hessian also has per-column cross-row rank-one terms
8914 // `H_(i,k),(j,k) = d_k·J_ik·J_jk`; these ARE carried in `H` via the #1038
8915 // Woodbury source (`IbpCrossRowSource`, construction.rs:4710-4752), the
8916 // ρ-trace differentiates them (#1416,
8917 // `assignment_log_strength_hessian_trace`), AND this θ-adjoint now
8918 // differentiates them exactly too: the empirical-`M_k` channel below
8919 // contracts the shared-mass coupling of the DIAGONAL curvature, and the
8920 // cross-row Woodbury pass (further below, using `cross_row_dd` and
8921 // `logit_curvature`) contracts the `∂/∂ℓ_w (d_k·J_ik·J_jk)` rank-one
8922 // derivative — so value, logdet, ρ-trace, and θ-adjoint all differentiate
8923 // the one operator `H = H₀ + Σ_k d_k u_k u_kᵀ`.
8924 let ibp_channels = ibp_assignment_third_channels(&self.assignment, rho)?;
8925 let k_atoms = self.k_atoms();
8926 // #1038 softmax entropy: the dense per-row entropy Hessian written into
8927 // `block.htt` has off-diagonal logit terms whose θ-derivative the adjoint
8928 // must contract too (not just the diagonal). Build the SAME penalty +
8929 // `scale = λ/τ²` the assembly uses so value/logdet/adjoint differentiate
8930 // one operator. `None` for non-softmax modes (their diagonal/cross-row
8931 // channels are handled by `assignment_prior_hdiag_derivative_entry` and
8932 // the IBP column pass).
8933 let softmax_dense_adjoint: Option<(
8934 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
8935 f64,
8936 )> = match self.assignment.mode {
8937 AssignmentMode::Softmax {
8938 temperature,
8939 sparsity,
8940 } if k_atoms > 1 => {
8941 let inv_tau = 1.0 / temperature;
8942 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
8943 Some((
8944 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
8945 k_atoms,
8946 temperature,
8947 ),
8948 scale,
8949 ))
8950 }
8951 _ => None,
8952 };
8953 // Per active logit position: (row i, column k, global t-index,
8954 // (H⁻¹)_ik,ik) — the inputs to the IBP cross-row empirical-`M_k` channel.
8955 let mut ibp_logit_sites: Vec<(usize, usize, usize, f64)> = Vec::new();
8956
8957 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
8958 let mut assignments = Array1::<f64>::zeros(self.k_atoms());
8959 // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
8960 // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
8961 // rows fall back to the scalar per-row path (bit-identical either way).
8962 let mut jet_window: std::collections::VecDeque<SaeRowJets> =
8963 std::collections::VecDeque::new();
8964 let mut jet_window_next = 0usize;
8965 for row in 0..n {
8966 let q = cache.row_dims[row];
8967 let base = cache.row_offsets[row];
8968 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
8969 self.assignment
8970 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
8971 if jet_window.is_empty() {
8972 jet_window_next = self.refill_jet_window(
8973 rho,
8974 jet_window_next,
8975 cache,
8976 &second_jets,
8977 &border,
8978 &mut jet_window,
8979 )?;
8980 }
8981 let jets = jet_window.pop_front().expect("jet window must be non-empty");
8982
8983 // #932 FRONT C: row-local Takahashi on the plain arrow; per-row
8984 // full-system `solve` loop under gauge / cross-row Woodbury.
8985 let (inv_vv, inv_vbeta) = if fast_selected {
8986 solver
8987 .selected_inverse_row_blocks(row, &beta_inv)
8988 .map_err(|err| {
8989 format!("logdet_theta_adjoint: selected inverse: {err}")
8990 })?
8991 } else {
8992 let mut inv_vv = Array2::<f64>::zeros((q, q));
8993 let mut inv_vbeta = Array2::<f64>::zeros((q, cache.k));
8994 for col in 0..q {
8995 let mut rhs_t = Array1::<f64>::zeros(total_t);
8996 let rhs_beta = Array1::<f64>::zeros(cache.k);
8997 rhs_t[base + col] = 1.0;
8998 let solved = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
8999 format!("logdet_theta_adjoint: selected inverse solve: {err}")
9000 })?;
9001 for r in 0..q {
9002 inv_vv[[r, col]] = solved.t[base + r];
9003 }
9004 for b in 0..cache.k {
9005 inv_vbeta[[col, b]] = solved.beta[b];
9006 }
9007 }
9008 (inv_vv, inv_vbeta)
9009 };
9010
9011 // Record each active logit's column, global t-index, and
9012 // selected-inverse diagonal (H⁻¹)_ik,ik for the IBP cross-row pass.
9013 if ibp_channels.is_some() {
9014 for (pos, var) in jets.vars.iter().enumerate() {
9015 if let SaeLocalRowVar::Logit { atom } = *var {
9016 ibp_logit_sites.push((row, atom, base + pos, inv_vv[[pos, pos]]));
9017 }
9018 }
9019 }
9020
9021 // #1419: when `w` is a logit and the assignment is softmax, the per-row
9022 // Gershgorin majorizer `D = diag(Σ_j|H_kj|)` is what the assembly wrote
9023 // into `htt` (the genuine Loewner majorizer that replaces the indefinite
9024 // exact entropy Hessian). Its full θ-derivative `∂D_{k,k}/∂z_w` (diagonal;
9025 // `∂D_kk/∂z_w = Σ_j sign(H_kj)·∂H_kj/∂z_w`) is the SAME operator the
9026 // assembly and logdet now differentiate, so value and adjoint stay on ONE
9027 // exact branch. Compute it once per logit `w` and add it at every logit
9028 // pair `(a,b)` below. The diagonal softmax case is therefore handled here,
9029 // NOT in `assignment_prior_hdiag_derivative_entry` (which returns 0 for
9030 // softmax to avoid double-counting).
9031 // #1410: the softmax majorizer θ-derivative `∂D_kk/∂z_w` is DIAGONAL
9032 // (`D` is diagonal), and the compact adjoint reads it only for this
9033 // row's `≤ top_k` active atoms. Compute the needed diagonal entry
9034 // directly from the softmax row `a` (= `assignments`, in hand) via
9035 // `active_softmax_majorizer_logit_derivative_entry`, instead of the old
9036 // per-(row, logit) full `K×K` `row_psd_majorizer_logit_derivative`
9037 // allocation. `m = Σ_j a_j l_j` is shared across all `(w, k)` pairs of
9038 // the row, so compute it once. `inv_tau` carries the softmax `∂a/∂z`
9039 // convention.
9040 let softmax_adjoint_row: Option<(&[f64], f64, f64, f64)> =
9041 match (softmax_dense_adjoint.as_ref(), self.assignment.mode) {
9042 (Some((_penalty, scale)), AssignmentMode::Softmax { temperature, .. }) => {
9043 let a = assignments
9044 .as_slice()
9045 .expect("softmax assignments row must be contiguous");
9046 let m = softmax_majorizer_log_mean(a);
9047 Some((a, m, *scale, 1.0 / temperature))
9048 }
9049 _ => None,
9050 };
9051 // Per-row UNIT-stiffness deflated directions: the selected inverse
9052 // `inv_vv` is the DEFLATED inverse (it assigns `1/λ̃ = 1` to each
9053 // `vᵢ`), so every `inv_vv`-weighted t–t contraction of `∂H/∂θ_w`
9054 // below spuriously contracts the RAW derivative where the re-deflating
9055 // criterion uses the deflation-map derivative `DΦ`. The kept-subspace Γ
9056 // subtracts `tr(inv_vv·(D − DΦ[D]))` over the t–t block via the same
9057 // Daleckii–Krein helper the ρ-traces use (the t–β / β–β blocks are not
9058 // deflated). `θ` enters only the per-row block (no cross-row Woodbury
9059 // self-downdate on the θ path), so the raw t–t derivative `D` is used
9060 // directly.
9061 let defl_dirs = cache
9062 .deflated_row_directions
9063 .get(row)
9064 .map(Vec::as_slice)
9065 .unwrap_or(&[]);
9066 let defl_spectrum = cache
9067 .deflation_row_spectra
9068 .get(row)
9069 .and_then(Option::as_ref);
9070 for w in 0..q {
9071 let mut gamma = 0.0_f64;
9072 // The active logit `w` differentiates against; `None` unless this
9073 // slot is a softmax logit on the softmax path.
9074 let softmax_d_dw: Option<(&[f64], f64, f64, f64, usize)> =
9075 match (softmax_adjoint_row, jets.vars[w]) {
9076 (Some((a, m, scale, inv_tau)), SaeLocalRowVar::Logit { atom: atom_w }) => {
9077 Some((a, m, scale, inv_tau, atom_w))
9078 }
9079 _ => None,
9080 };
9081 let mut dh_mat = Array2::<f64>::zeros((q, q));
9082 for a in 0..q {
9083 for b in 0..q {
9084 let mut dh = sae_dot(&jets.second[a][w], &jets.first[b])
9085 + sae_dot(&jets.first[a], &jets.second[b][w]);
9086 // `∂D/∂z_w` is diagonal, so it contributes only when the two
9087 // logit slots are the SAME atom (`atom_a == atom_b`).
9088 if let (
9089 Some((a_soft, m, scale, inv_tau, _atom_w)),
9090 SaeLocalRowVar::Logit { atom: atom_a },
9091 SaeLocalRowVar::Logit { atom: atom_b },
9092 ) = (softmax_d_dw, jets.vars[a], jets.vars[b])
9093 {
9094 if atom_a == atom_b {
9095 dh += active_softmax_majorizer_logit_derivative_entry(
9096 a_soft, atom_a, _atom_w, m, scale, inv_tau,
9097 );
9098 }
9099 }
9100 if a == b {
9101 dh += match jets.vars[a] {
9102 SaeLocalRowVar::Logit { atom } => self
9103 .assignment_prior_hdiag_derivative_entry(
9104 rho,
9105 row,
9106 atom,
9107 jets.vars[w],
9108 ibp_channels.as_ref(),
9109 ),
9110 SaeLocalRowVar::Coord { atom, axis } if a == w => {
9111 self.ard_majorized_hessian_derivative(rho, row, atom, axis)
9112 }
9113 _ => 0.0,
9114 };
9115 }
9116 dh_mat[[a, b]] = dh;
9117 gamma += inv_vv[[b, a]] * dh;
9118 }
9119 }
9120 if !defl_dirs.is_empty() {
9121 gamma -= Self::deflation_block_correction(
9122 &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9123 );
9124 }
9125 for a in 0..q {
9126 for (beta_pos, channel) in border.iter().enumerate() {
9127 let dh = sae_dot(&jets.second[a][w], &jets.beta[beta_pos])
9128 + sae_dot(&jets.first[a], &jets.beta_deriv[w][beta_pos]);
9129 gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9130 }
9131 }
9132 for (beta_i, channel_i) in border.iter().enumerate() {
9133 for (beta_j, channel_j) in border.iter().enumerate() {
9134 let dh = sae_dot(&jets.beta_deriv[w][beta_i], &jets.beta[beta_j])
9135 + sae_dot(&jets.beta[beta_i], &jets.beta_deriv[w][beta_j]);
9136 gamma += beta_inv[[channel_i.index, channel_j.index]] * dh;
9137 }
9138 }
9139 gamma_t[base + w] = gamma;
9140 }
9141
9142 for (w_beta_pos, w_channel) in border.iter().enumerate() {
9143 let mut gamma = 0.0_f64;
9144 let mut dh_mat = Array2::<f64>::zeros((q, q));
9145 for a in 0..q {
9146 for b in 0..q {
9147 let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.first[b])
9148 + sae_dot(&jets.first[a], &jets.beta_l_deriv[b][w_beta_pos]);
9149 dh_mat[[a, b]] = dh;
9150 gamma += inv_vv[[b, a]] * dh;
9151 }
9152 }
9153 if !defl_dirs.is_empty() {
9154 gamma -= Self::deflation_block_correction(
9155 &inv_vv, &dh_mat, defl_dirs, defl_spectrum,
9156 );
9157 }
9158 for a in 0..q {
9159 for (beta_pos, channel) in border.iter().enumerate() {
9160 let dh = sae_dot(&jets.beta_l_deriv[a][w_beta_pos], &jets.beta[beta_pos]);
9161 gamma += 2.0 * inv_vbeta[[a, channel.index]] * dh;
9162 }
9163 }
9164 gamma_beta[w_channel.index] += gamma;
9165 }
9166 }
9167
9168 // IBP cross-row empirical-`M_k` channel of Γ (#1006). The assembled
9169 // diagonal H_ik consumes `hessian_diag`, whose dependence on the column
9170 // mass M_k = Σ_i z_ik couples every row in a column. Differentiating
9171 // tr(H⁻¹ ∂H/∂ℓ_wk) on that shared branch:
9172 // Γ_wk += [ Σ_i (H⁻¹)_ik,ik · ∂_M H_ik ] · J_wk = C_k · J_wk,
9173 // where ∂_M H_ik = `m_channel[i*K+k]` and J_wk = `z_jac[w*K+k]`. The
9174 // row-local direct-`z` channel was already added inline above, so this
9175 // pass adds only the cross-row remainder (it spans `w ≠ i` and the
9176 // self-row M_k self-coupling, which the row-local primitive deliberately
9177 // omits to avoid double-counting).
9178 if let Some(channels) = ibp_channels.as_ref() {
9179 let mut col_coeff = vec![0.0_f64; k_atoms];
9180 for &(row, atom, _t_index, inv_diag) in &ibp_logit_sites {
9181 col_coeff[atom] += inv_diag * channels.m_channel[row * k_atoms + atom];
9182 }
9183 for &(row, atom, t_index, _inv_diag) in &ibp_logit_sites {
9184 gamma_t[t_index] += col_coeff[atom] * channels.z_jac[row * k_atoms + atom];
9185 }
9186
9187 // #1416 / #1641: the EXACT cross-row Woodbury derivative of Γ. The
9188 // assembled `H` carries the per-column rank-one block
9189 // `W_k = d_k·u_k u_kᵀ` with `u_k` the J-weighted column indicator
9190 // (`u_k[slot(i,k)] = J_ik`) and `d_k = w·s'_k` (`cross_row_d[k]`). Both
9191 // `d_k` (through `M_k`) and the `u_k` entries (through `ℓ_ik`) depend on
9192 // the logits, so
9193 // ∂W_k/∂ℓ_wk = dd_k·J_wk·u_k u_kᵀ
9194 // + d_k·c_wk·(e_w u_kᵀ + u_k e_wᵀ),
9195 // where `dd_k = ∂d_k/∂M_k = w·s''_k` (`cross_row_dd[k]`),
9196 // `c_wk = ∂J_wk/∂ℓ_wk` (`logit_curvature`), and `e_w` is the unit
9197 // vector at row `w`'s logit-`k` slot.
9198 //
9199 // The θ-adjoint contracts the FULL trace `Γ_wk = tr(H⁻¹ ∂H/∂ℓ_wk)`
9200 // (NOT the `½ tr` the ρ-trace uses — `fixed_state_logdet` differentiates
9201 // the full `log|H|`, and the per-row blocks above contract `inv_vv·dh`
9202 // with no ½). Critically, the `i=j` self curvature `w·s'_k·J_ik²` of the
9203 // rank-one block lives on the assembled `htt` DIAGONAL `H_ik`, so its
9204 // derivative is ALREADY differentiated by the row-local
9205 // `local_logit_third` channel (direct-z, `i=w`) and the `m_channel`
9206 // column pass (via `M_k`) above. This Woodbury pass must therefore add
9207 // ONLY the off-diagonal `i≠j` remainder — otherwise the self term is
9208 // double-counted (the #1641 defect: the pre-fix pass summed the full
9209 // `u_k u_kᵀ` including `i=j`, AND carried the ρ-trace ½, AND dropped the
9210 // factor 2 on the symmetric `e_w u_kᵀ + u_k e_wᵀ` term). Excluding `i=j`
9211 // is also why this pass needs no deflation correction: it contracts only
9212 // DISTINCT rows, off any single-row `vᵢ`'s support (matching the
9213 // #1416 ρ-trace cross-row pass).
9214 //
9215 // Contracting `tr(H⁻¹ ∂W_k/∂ℓ_wk)` over `i≠j` only:
9216 // Γ_wk += dd_k·J_wk·( u_kᵀ H⁻¹ u_k − Σ_i P_ii·J_ik² ) (term A)
9217 // + 2·d_k·c_wk·( (H⁻¹ u_k)_{slot(w,k)} − P_ww·J_wk ) (term B),
9218 // where `P_ii = (H⁻¹)_{slot(i,k),slot(i,k)}` is the selected-inverse
9219 // diagonal recorded in `ibp_logit_sites`. The subtracted self pieces are
9220 // exactly the `i=j` terms the diagonal channels own. Both `u_kᵀ H⁻¹ u_k`
9221 // and `(H⁻¹ u_k)` come from ONE solve per column, `x_k = H⁻¹ u_k` — so
9222 // the adjoint differentiates the SAME `H = H₀ + Σ_k W_k` the
9223 // value/logdet use, closing the one-operator contract on the rank-one
9224 // block too.
9225 //
9226 // Group the column sites once (the layout is mode-agnostic: dense or
9227 // compact, `ibp_logit_sites` already carries each active logit's global
9228 // t-index and selected-inverse diagonal), then per column build `u_k`,
9229 // solve, and distribute.
9230 let total_t = cache.delta_t_len();
9231 let mut col_sites: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); k_atoms];
9232 for &(row, atom, t_index, inv_diag) in &ibp_logit_sites {
9233 col_sites[atom].push((row, t_index, inv_diag));
9234 }
9235 for atom in 0..k_atoms {
9236 let d_k = channels.cross_row_d[atom];
9237 let dd_k = channels.cross_row_dd[atom];
9238 if col_sites[atom].is_empty() || (d_k == 0.0 && dd_k == 0.0) {
9239 continue;
9240 }
9241 // u_k as a full t-RHS: J at each active logit-k slot.
9242 let mut rhs_t = Array1::<f64>::zeros(total_t);
9243 let rhs_beta = Array1::<f64>::zeros(cache.k);
9244 for &(row, t_index, _inv_diag) in &col_sites[atom] {
9245 rhs_t[t_index] = channels.z_jac[row * k_atoms + atom];
9246 }
9247 let x_k = solver.solve(rhs_t.view(), rhs_beta.view()).map_err(|err| {
9248 format!("logdet_theta_adjoint: IBP cross-row Woodbury solve: {err}")
9249 })?;
9250 // (JᵀH⁻¹J)_k = u_kᵀ x_k, and the `i=j` self quadratic
9251 // Σ_i P_ii·J_ik² the diagonal channels already own.
9252 let mut jt_hinv_j = 0.0_f64;
9253 let mut self_quad = 0.0_f64;
9254 for &(row, t_index, inv_diag) in &col_sites[atom] {
9255 let j_ik = channels.z_jac[row * k_atoms + atom];
9256 jt_hinv_j += j_ik * x_k.t[t_index];
9257 self_quad += inv_diag * j_ik * j_ik;
9258 }
9259 for &(row, t_index, inv_diag) in &col_sites[atom] {
9260 let j_wk = channels.z_jac[row * k_atoms + atom];
9261 let c_wk = channels.logit_curvature[row * k_atoms + atom];
9262 // term A (off-diagonal dd) + term B (off-diagonal d·c), both with
9263 // their `i=j` self piece removed (owned by the diagonal channels).
9264 gamma_t[t_index] += dd_k * j_wk * (jt_hinv_j - self_quad)
9265 + 2.0 * d_k * c_wk * (x_k.t[t_index] - inv_diag * j_wk);
9266 }
9267 }
9268 }
9269
9270 Ok(SaeArrowVector {
9271 t: gamma_t,
9272 beta: gamma_beta,
9273 })
9274 }
9275
9276 /// #1418: apply the EXACT stationarity-Jacobian correction `ΔC·v = (A − B)·v`
9277 /// to a joint `(t, β)` vector, matrix-free and per row.
9278 ///
9279 /// `A = ∇²_θθ L` is the true inner-fit Hessian; `B` is the assembled
9280 /// evidence/Newton operator the solver factors. They differ ONLY by the three
9281 /// curvature substitutions the assembly makes for stability:
9282 /// 1. data: `B` uses Gauss-Newton `J̃J̃ᵀ`, dropping the residual curvature
9283 /// `R[a,b] = Σ_out r_out·∂²f_out/∂θ_a∂θ_b` (t–t via `jets.second`, t–β via
9284 /// `jets.beta_deriv`; the decoder is linear in β so the β–β block is 0);
9285 /// 2. softmax: `B` uses the Gershgorin majorizer `D = diag(Σ_j|H_kj|)`,
9286 /// dropping `H_entropy − D` (#1419);
9287 /// 3. periodic ARD: `B` uses `max(V'',0)`, dropping the negative part
9288 /// `min(V'',0)` (the indefinite tail past a quarter period).
9289 /// `ΔC` is the sum of exactly these three deltas, each built from the SAME
9290 /// jets / penalty curvatures the assembly and the θ-adjoint use, so
9291 /// `A = B + ΔC` is the one true Hessian. Exact on BOTH the isotropic and the
9292 /// whitened-metric paths: the data fit is `½ r_nᵀ M_n r_n`, so the residual
9293 /// curvature is `Σ_out (M_n r_n)_out·∂²f_out/∂θ_a∂θ_b` — contract the
9294 /// metric-applied √w-scaled residual `error_metric = √w·M_n r_n` (the SAME
9295 /// quantity the assembly's β-tier gradient uses) against the RAW second jets
9296 /// `jets.second`/`jets.beta_deriv` (the same raw-jet convention the whole
9297 /// θ-adjoint and the Gauss-Newton `htt = J̃J̃ᵀ = J M Jᵀ` assembly use). On the
9298 /// isotropic path `M_n = I` so `error_metric = √w·r` and `J M Jᵀ = JJᵀ`,
9299 /// recovering the plain case. The softmax / ARD deltas are logit/coord-space
9300 /// prior curvatures and carry no output metric, so they are path-independent.
9301 fn apply_exact_hessian_minus_b(
9302 &self,
9303 rho: &SaeManifoldRho,
9304 target: ArrayView2<'_, f64>,
9305 cache: &ArrowFactorCache,
9306 v: &SaeArrowVector,
9307 ) -> Result<SaeArrowVector, String> {
9308 let p = self.output_dim();
9309 let n = self.n_obs();
9310 let k_atoms = self.k_atoms();
9311 let total_t = cache.delta_t_len();
9312 let second_jets = self.atom_second_jets()?;
9313 let border = self.border_channels_for_cache(cache)?;
9314 let row_loss_w = self.row_loss_weights.as_deref();
9315 let ard_axis_periods: Vec<Vec<Option<f64>>> = self
9316 .assignment
9317 .coords
9318 .iter()
9319 .map(|coord| coord.effective_axis_periods())
9320 .collect();
9321
9322 // Optional softmax exact-entropy-minus-majorizer delta operator (#1419).
9323 let softmax_delta: Option<(
9324 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty,
9325 f64,
9326 )> = match self.assignment.mode {
9327 AssignmentMode::Softmax {
9328 temperature,
9329 sparsity,
9330 } if k_atoms > 1 => {
9331 let inv_tau = 1.0 / temperature;
9332 let scale = rho.lambda_sparse() * sparsity * inv_tau * inv_tau;
9333 Some((
9334 gam_terms::analytic_penalties::SoftmaxAssignmentSparsityPenalty::new(
9335 k_atoms,
9336 temperature,
9337 ),
9338 scale,
9339 ))
9340 }
9341 _ => None,
9342 };
9343
9344 let mut out = SaeArrowVector {
9345 t: Array1::<f64>::zeros(total_t),
9346 beta: Array1::<f64>::zeros(cache.k),
9347 };
9348 let whitens = self
9349 .row_metric
9350 .as_ref()
9351 .is_some_and(|metric| metric.whitens_likelihood());
9352 let mut decoded = vec![0.0_f64; p];
9353 let mut fitted = Array1::<f64>::zeros(p);
9354 let mut error = Array1::<f64>::zeros(p);
9355 // #1557 — reuse one K-sized scratch row across all N rows (alias-free).
9356 let mut assignments = Array1::<f64>::zeros(self.k_atoms());
9357 // #932 SIMD: jets are built in aligned 4-row SIMD batches through a
9358 // bounded (≤4-row) look-ahead window; unaligned / non-softmax / remainder
9359 // rows fall back to the scalar per-row path (bit-identical either way).
9360 let mut jet_window: std::collections::VecDeque<SaeRowJets> =
9361 std::collections::VecDeque::new();
9362 let mut jet_window_next = 0usize;
9363 for row in 0..n {
9364 let q = cache.row_dims[row];
9365 let base = cache.row_offsets[row];
9366 let a_scratch = assignments.as_slice_mut().expect("contiguous scratch");
9367 self.assignment
9368 .try_assignments_row_for_rho_into(row, rho, a_scratch)?;
9369 if jet_window.is_empty() {
9370 jet_window_next = self.refill_jet_window(
9371 rho,
9372 jet_window_next,
9373 cache,
9374 &second_jets,
9375 &border,
9376 &mut jet_window,
9377 )?;
9378 }
9379 let jets = jet_window.pop_front().expect("jet window must be non-empty");
9380 let sqrt_row_w = row_loss_w.map_or(1.0, |w| w[row].sqrt());
9381
9382 // √w-scaled metric-applied per-row residual `error_metric = √w·M_n r_n`
9383 // (the SAME object the assembly's β-tier gradient contracts). The
9384 // data-fit `½ r_nᵀ M_n r_n` has residual curvature `Σ (M_n r_n)·∂²f`,
9385 // so this is exactly the residual contracted against the raw `∂²f`
9386 // jets. `M_n = I` on the isotropic path ⇒ `error_metric = √w·r`.
9387 fitted.fill(0.0);
9388 for k in 0..k_atoms {
9389 self.atoms[k].fill_decoded_row(row, &mut decoded);
9390 let a_k = assignments[k];
9391 for out_col in 0..p {
9392 fitted[out_col] += a_k * decoded[out_col];
9393 }
9394 }
9395 for out_col in 0..p {
9396 error[out_col] = sqrt_row_w * (fitted[out_col] - target[[row, out_col]]);
9397 }
9398 let error_metric: Vec<f64> = match self.row_metric.as_ref() {
9399 Some(metric) if whitens => metric.apply_metric_row(row, error.view()),
9400 _ => error.to_vec(),
9401 };
9402
9403 // Local t-slice of `v` for this row.
9404 let v_t: Vec<f64> = (0..q).map(|c| v.t[base + c]).collect();
9405
9406 // (1a) residual curvature, t–t: ΔC_tt[a,b] = ⟨r, ∂²f_ab⟩.
9407 for a in 0..q {
9408 let mut acc = 0.0_f64;
9409 for b in 0..q {
9410 let r_ab = sae_dot(&error_metric, &jets.second[a][b]);
9411 acc += r_ab * v_t[b];
9412 }
9413 out.t[base + a] += acc;
9414 }
9415 // (1b) residual curvature, t–β and β–t: ΔC_tβ[a,β] = ⟨r, ∂²f_aβ⟩.
9416 // `jets.beta_deriv[a][β]` = ∂(∂f/∂β_β)/∂θ_a (the mixed second jet).
9417 for a in 0..q {
9418 for (beta_pos, channel) in border.iter().enumerate() {
9419 let r_ab = sae_dot(&error_metric, &jets.beta_deriv[a][beta_pos]);
9420 // t row picks up β leg of v; β row picks up t leg of v.
9421 out.t[base + a] += r_ab * v.beta[channel.index];
9422 out.beta[channel.index] += r_ab * v_t[a];
9423 }
9424 }
9425
9426 // (2) softmax: ΔC_logit = (H_entropy − D) over the free logits, where
9427 // `D = diag(Σ_j|H_kj|)` is the Gershgorin majorizer the assembled `B`
9428 // wrote into the logit block (#1419). Adding `H_entropy − D` recovers the
9429 // EXACT entropy curvature `A = B + ΔC`, so the solver's exact-Hessian
9430 // correction differentiates the SAME operator the assembly installed.
9431 if let Some((_penalty, scale)) = softmax_delta.as_ref() {
9432 let assignment_dim = self.assignment.assignment_coord_dim();
9433 // #1410: the correction only contracts the ACTIVE logit slots
9434 // (`jets.vars` carries the row's `≤ top_k` active atoms on the
9435 // compact layout), so build only the active sub-block of
9436 // `ΔC = H_entropy − D` ENTRY-WISE rather than materialising the
9437 // full `K×K` `row_dense_hessian` / `row_psd_majorizer` matrices per
9438 // row (an `O(K²)`-per-row allocation that defeated the compact
9439 // contract at the LLM shape). `D` is diagonal, so it subtracts only
9440 // on `ka == kb`; the off-diagonal `H_entropy` entries come from the
9441 // shared `(a, l, m)` algebra. The softmax row `a_soft` is the one
9442 // irreducible `O(K)` term, computed once per row.
9443 // #1557 — reuse this iteration's `assignments` (bit-identical).
9444 let a_soft = assignments
9445 .as_slice()
9446 .expect("softmax assignments row must be contiguous");
9447 let m = softmax_majorizer_log_mean(a_soft);
9448 for (a, va) in jets.vars.iter().enumerate() {
9449 let SaeLocalRowVar::Logit { atom: ka } = *va else {
9450 continue;
9451 };
9452 if ka >= assignment_dim {
9453 continue;
9454 }
9455 let mut acc = 0.0_f64;
9456 for (b, vb) in jets.vars.iter().enumerate() {
9457 let SaeLocalRowVar::Logit { atom: kb } = *vb else {
9458 continue;
9459 };
9460 if kb >= assignment_dim {
9461 continue;
9462 }
9463 let h_entropy =
9464 softmax_dense_entropy_hessian_entry(a_soft, ka, kb, m, *scale);
9465 // `D` is the diagonal Gershgorin majorizer (#1419), so it
9466 // contributes only on the diagonal `ka == kb`.
9467 let delta = if ka == kb {
9468 h_entropy
9469 - active_softmax_gershgorin_majorizer_entry(a_soft, ka, m, *scale)
9470 } else {
9471 h_entropy
9472 };
9473 acc += delta * v_t[b];
9474 }
9475 out.t[base + a] += acc;
9476 }
9477 }
9478
9479 // (3) periodic ARD: ΔC_coord = (V'' − max(V'',0)) = min(V'',0), diagonal.
9480 for (a, va) in jets.vars.iter().enumerate() {
9481 let SaeLocalRowVar::Coord { atom, axis } = *va else {
9482 continue;
9483 };
9484 if rho.log_ard[atom].is_empty() {
9485 continue;
9486 }
9487 let alpha = SaeManifoldRho::stable_exp_strength(rho.log_ard[atom][axis]);
9488 let t_val = self.assignment.coords[atom].row(row)[axis];
9489 let prior = ArdAxisPrior::eval(alpha, t_val, ard_axis_periods[atom][axis]);
9490 let neg = prior.hess.min(0.0);
9491 if neg != 0.0 {
9492 out.t[base + a] += neg * v_t[a];
9493 }
9494 }
9495 }
9496 Ok(out)
9497 }
9498
9499 /// #1418: matrix-free apply of the EXACT stationarity Jacobian `A = ∇²_θθ L`:
9500 /// `A v = B v + ΔC v`, the assembled arrow Hessian apply
9501 /// ([`apply_cached_arrow_hessian`]) plus the matrix-free dropped-curvature
9502 /// correction `ΔC = A − B` ([`Self::apply_exact_hessian_minus_b`]).
9503 fn apply_exact_hessian(
9504 &self,
9505 rho: &SaeManifoldRho,
9506 target: ArrayView2<'_, f64>,
9507 cache: &ArrowFactorCache,
9508 v: &SaeArrowVector,
9509 ) -> Result<SaeArrowVector, String> {
9510 let b_v = apply_cached_arrow_hessian(cache, v.t.view(), v.beta.view())?;
9511 let dc_v = self.apply_exact_hessian_minus_b(rho, target, cache, v)?;
9512 Ok(SaeArrowVector {
9513 t: &b_v.t + &dc_v.t,
9514 beta: &b_v.beta + &dc_v.beta,
9515 })
9516 }
9517
9518 /// #1418: solve `A x = rhs` for the EXACT stationarity Jacobian `A = ∇²_θθ L`
9519 /// via `B`-preconditioned CG ([`solve_b_preconditioned_cg`]) with the
9520 /// matrix-free `A v = B v + ΔC v` apply ([`Self::apply_exact_hessian`]). The
9521 /// IFT step `θ̂_ρ = −A⁻¹ g_ρ` must invert the EXACT `A`, not the surrogate `B`;
9522 /// CG converges for any `ρ(B⁻¹ΔC)`, where the earlier Neumann series diverged
9523 /// once the dropped curvature `ΔC = ⟨r, ∂²f⟩` grew (large unmodellable residual).
9524 fn solve_exact_stationarity(
9525 &self,
9526 rho: &SaeManifoldRho,
9527 target: ArrayView2<'_, f64>,
9528 cache: &ArrowFactorCache,
9529 solver: &DeflatedArrowSolver<'_>,
9530 rhs: &SaeArrowVector,
9531 ) -> Result<SaeArrowVector, String> {
9532 solve_b_preconditioned_cg(solver, rhs, |v| {
9533 self.apply_exact_hessian(rho, target, cache, v)
9534 })
9535 }
9536
9537 /// Analytic SAE REML outer-ρ gradient components at the already converged
9538 /// inner state represented by `loss` and `cache`.
9539 ///
9540 /// The returned gradient is the assembled analytic outer derivative:
9541 /// explicit penalty terms, direct logdet traces, Occam terms, and the #1006
9542 /// implicit-state third-order correction.
9543 pub(crate) fn analytic_outer_rho_gradient_components(
9544 &self,
9545 target: ArrayView2<'_, f64>,
9546 rho: &SaeManifoldRho,
9547 loss: &SaeManifoldLoss,
9548 cache: &ArrowFactorCache,
9549 solver: &DeflatedArrowSolver<'_>,
9550 ) -> Result<SaeOuterRhoGradientComponents, OuterGradientError> {
9551 let n_params = rho.to_flat().len();
9552 let mut explicit = Array1::<f64>::zeros(n_params);
9553 let mut logdet_trace = Array1::<f64>::zeros(n_params);
9554 let mut occam = Array1::<f64>::zeros(n_params);
9555 let mut third_order_correction = Array1::<f64>::zeros(n_params);
9556
9557 explicit[0] = assignment_prior_log_strength_derivative(&self.assignment, rho)
9558 + self
9559 .learnable_ibp_forward_alpha_data_derivative(rho, target)
9560 .map_err(OuterGradientError::internal)?;
9561 // #1417: the FULL `½ tr(H⁻¹ ∂H/∂logα)` for the assignment coordinate.
9562 // For LEARNABLE IBP alpha the forward assignments `a_ik = σ(ℓ/τ)·π_k(α)`
9563 // carry an explicit α-dependence (`∂logπ_k/∂logα = k/(α+1)`), so BOTH the
9564 // assignment-prior Hessian AND the data Gauss-Newton blocks
9565 // `H_ββ`, `H_tβ`, `H_tt` depend on logα. We assemble both traces:
9566 // • prior: `assignment_log_strength_hessian_trace`,
9567 // • data: `learnable_ibp_data_logdet_alpha_trace` (#1417), using the
9568 // exact `(k_a+k_b)/(α+1)` block-scaling identity.
9569 // For FIXED alpha (and non-IBP modes) the data term is identically zero,
9570 // so the fixed-alpha gradient is unchanged and exact.
9571 logdet_trace[0] = self
9572 .assignment_log_strength_hessian_trace(rho, cache, solver)
9573 .map_err(OuterGradientError::internal)?
9574 + self
9575 .learnable_ibp_data_logdet_alpha_trace(rho, cache, solver)
9576 .map_err(OuterGradientError::internal)?;
9577
9578 // #1556: λ_smooth is per-atom, so the smoothness gradient block occupies
9579 // flat indices `1..1+K` (one per atom), not a single index 1. Each atom
9580 // `k` carries its own explicit penalty-energy derivative, log|H| trace,
9581 // and Occam-normalizer derivative.
9582 let k_smooth = rho.log_lambda_smooth.len();
9583 let lambda_smooth_vec = rho.lambda_smooth_vec();
9584 // Explicit `∂loss.smoothness/∂log λ_k = 0.5·λ_k·<B_k, S_k B_k>` (the
9585 // per-atom split). Its sum is the λ-scaled penalty energy; renormalize to
9586 // `loss.smoothness` so the total matches the criterion's reported energy
9587 // bit-for-bit (folding in any minibatch `penalty_scale` baked into it).
9588 let mut smooth_explicit = self.decoder_smoothness_value_per_atom(&lambda_smooth_vec);
9589 let smooth_explicit_sum: f64 = smooth_explicit.iter().sum();
9590 if smooth_explicit_sum.abs() > 0.0 {
9591 let renorm = loss.smoothness / smooth_explicit_sum;
9592 for v in smooth_explicit.iter_mut() {
9593 *v *= renorm;
9594 }
9595 }
9596 let smooth_logdet = self
9597 .decoder_smoothness_effective_dof_with_solver_per_atom(
9598 cache,
9599 solver,
9600 &lambda_smooth_vec,
9601 )
9602 .map_err(|err| OuterGradientError::InternalInvariant {
9603 reason: format!("analytic_outer_rho_gradient_components: {err}"),
9604 })?;
9605 let smooth_occam = self
9606 .reml_occam_log_lambda_smooth_derivative()
9607 .map_err(OuterGradientError::internal)?;
9608 for atom_idx in 0..k_smooth {
9609 explicit[1 + atom_idx] = smooth_explicit[atom_idx];
9610 logdet_trace[1 + atom_idx] = 0.5 * smooth_logdet[atom_idx];
9611 occam[1 + atom_idx] = -smooth_occam[atom_idx];
9612 }
9613
9614 let ard_explicit = self
9615 .ard_log_precision_explicit_derivatives(rho)
9616 .map_err(OuterGradientError::internal)?;
9617 let ard_trace = self
9618 .ard_log_precision_hessian_trace(rho, cache, solver)
9619 .map_err(|err| OuterGradientError::InternalInvariant {
9620 reason: format!("analytic_outer_rho_gradient_components: {err}"),
9621 })?;
9622 let mut cursor = 1 + k_smooth;
9623 for k in 0..rho.log_ard.len() {
9624 for axis in 0..rho.log_ard[k].len() {
9625 explicit[cursor] = ard_explicit[k][axis];
9626 logdet_trace[cursor] = ard_trace[k][axis];
9627 cursor += 1;
9628 }
9629 }
9630
9631 let gamma = self
9632 .logdet_theta_adjoint(rho, cache, solver)
9633 .map_err(OuterGradientError::internal)?;
9634 // #1418: the implicit-function correction is `−½·Γᵀ·θ̂_ρ` with
9635 // `θ̂_ρ = −A⁻¹ g_ρ`, where `A = ∇²_θθ L` is the EXACT stationarity
9636 // Jacobian of the inner fit — data residual curvature, exact softmax
9637 // entropy Hessian, exact periodic ARD curvature. The matrix the `solver`
9638 // factors is `B` (Gauss-Newton data curvature, softmax Fisher metric,
9639 // `max(V'',0)` ARD majorizers): the `½log|B|` Laplace term is consistent
9640 // with `Γ = ½tr(B⁻¹ ∂B/∂θ)`, but the implicit step is governed by `A`.
9641 // `solve_exact_stationarity` applies the TRUE `A⁻¹` via a B⁻¹-
9642 // preconditioned Neumann fixed point (`A = B + ΔC`,
9643 // `ΔC = apply_exact_hessian_minus_b`), so the correction is no longer
9644 // biased by `(B⁻¹ − A⁻¹)`.
9645 for coord in 0..n_params {
9646 let rhs = self
9647 .outer_rho_gradient_ift_rhs(rho, target, coord, cache)
9648 .map_err(OuterGradientError::internal)?;
9649 let solved = self
9650 .solve_exact_stationarity(rho, target, cache, solver, &rhs)
9651 .map_err(OuterGradientError::internal)?;
9652 let mut dot = 0.0_f64;
9653 for idx in 0..gamma.t.len() {
9654 dot += gamma.t[idx] * solved.t[idx];
9655 }
9656 for idx in 0..gamma.beta.len() {
9657 dot += gamma.beta[idx] * solved.beta[idx];
9658 }
9659 third_order_correction[coord] = -0.5 * dot;
9660 }
9661
9662 Ok(SaeOuterRhoGradientComponents {
9663 explicit,
9664 logdet_trace,
9665 occam,
9666 third_order_correction,
9667 })
9668 }
9669
9670 /// Public analytic outer-ρ gradient at a converged inner state, constructing
9671 /// the deflated arrow solver from the supplied cache. Use this seam from
9672 /// integration tests and external consumers that have a converged
9673 /// `(loss, cache)` from [`Self::reml_criterion_with_cache`] but no access to
9674 /// the crate-private `DeflatedArrowSolver`.
9675 pub fn analytic_outer_rho_gradient_at_converged(
9676 &self,
9677 target: ArrayView2<'_, f64>,
9678 rho: &SaeManifoldRho,
9679 loss: &SaeManifoldLoss,
9680 cache: &ArrowFactorCache,
9681 ) -> Result<SaeOuterRhoGradientComponents, String> {
9682 let solver = self.outer_gradient_arrow_solver(cache, &rho.lambda_smooth_vec())?;
9683 self.analytic_outer_rho_gradient_components(target, rho, loss, cache, &solver)
9684 .map_err(|e| e.to_string())
9685 }
9686
9687 /// Compose the SAE LAML criterion as a sum of atoms (#931 SAE pilot).
9688 ///
9689 /// This is the single seam that establishes value↔gradient coherence for
9690 /// the SAE objective: it runs the inner solve once via
9691 /// [`Self::reml_criterion_with_cache`], reads the value decomposition
9692 /// (`loss.total() + extra_penalty_energy`, `log|H|`, `occam`) and the
9693 /// matching gradient channels (`SaeOuterRhoGradientComponents`) from the
9694 /// SAME converged cache, and hands them to [`SaeCriterion::assemble`]. The
9695 /// returned criterion's [`SaeCriterion::value`] and
9696 /// [`SaeCriterion::gradient`] are then projections of one factorization —
9697 /// the outer optimizer can no longer evaluate a value path and a gradient
9698 /// path that disagree (the #752/#748/#901 desync class). The
9699 /// implicit-stationarity envelope correction (#1006's Γ term) is its own
9700 /// named atom, so the channel the desync class keeps dropping is visible
9701 /// rather than a silent zero.
9702 pub fn criterion_as_atoms(
9703 &mut self,
9704 target: ArrayView2<'_, f64>,
9705 rho: &SaeManifoldRho,
9706 registry: Option<&AnalyticPenaltyRegistry>,
9707 inner_max_iter: usize,
9708 learning_rate: f64,
9709 ridge_ext_coord: f64,
9710 ridge_beta: f64,
9711 ) -> Result<SaeCriterion, String> {
9712 let (_v, loss, cache) = self.reml_criterion_with_cache(
9713 target,
9714 rho,
9715 registry,
9716 inner_max_iter,
9717 learning_rate,
9718 ridge_ext_coord,
9719 ridge_beta,
9720 )?;
9721 let log_det = arrow_log_det_from_cache(&cache).ok_or_else(|| {
9722 "criterion_as_atoms: arrow_log_det_from_cache returned None".to_string()
9723 })?;
9724 let occam = self.reml_occam_term(rho)?;
9725 let extra_penalty_energy = match registry {
9726 Some(reg) => self
9727 .reml_extra_penalty_value_total(reg)
9728 .map_err(|err| format!("SaeManifoldTerm::criterion_as_atoms: {err}"))?,
9729 None => 0.0,
9730 };
9731 let data_fit_priors_value = loss.total() + extra_penalty_energy;
9732
9733 let solver = self.outer_gradient_arrow_solver(&cache, &rho.lambda_smooth_vec())?;
9734 let components =
9735 self.analytic_outer_rho_gradient_components(target, rho, &loss, &cache, &solver)?;
9736 Ok(SaeCriterion::assemble(
9737 data_fit_priors_value,
9738 log_det,
9739 occam,
9740 components.explicit,
9741 components.logdet_trace,
9742 components.occam,
9743 components.third_order_correction,
9744 ))
9745 }
9746
9747 // [#780 line-count gate] reconstruction_dispersion + assemble_shape_uncertainty
9748 // + complete_born_atom_shape_bands + shape_uncertainty_without_decoder_covariance
9749 // (the contiguous trailing methods of this impl block) were split into the
9750 // sibling construction_reconstruction.rs (declared in mod.rs); callers reach
9751 // them bare via use super::*.
9752}
9753
9754// [#780 line-count gate] Per-row jet / reconstruction-channel assembly for the
9755// streaming-exact arrow log-det lives in a sibling file as a second
9756// `impl SaeManifoldTerm` block, inlined here so it keeps the SAME module scope
9757// and private-field access. Keeps this tracked file under the 10k limit.
9758include!("construction_row_jet_logdet_channels.rs");
9759
9760// [#780 line-count gate] `term_from_padded_blocks_with_mode` (the padded-FFI
9761// term builder) was split into the sibling `construction_padded_blocks.rs`
9762// module (declared and re-exported from `mod.rs`), keeping this tracked file
9763// under the 10k limit. Callers still reach it bare through `use super::*`.
9764
9765// [#780 line-count gate] `refresh_isometry_caches_from_atom` and
9766// `refresh_isometry_caches_from_term` were split into the sibling
9767// `construction_cache_refresh.rs` module (declared and re-exported from
9768// `mod.rs`), keeping this tracked file under the 10k limit. Callers still reach
9769// both functions bare through `use super::*`.
9770
9771// [#780 line-count gate] The `#[cfg(test)]` modules below the production code
9772// are mechanically split into a sibling `*_tests` file and inlined via
9773// `include!` (the sanctioned cohesive-module decomposition — see build.rs
9774// file_stem_is_exempt_test_module). Keeps this tracked file under the 10k limit.
9775include!("construction_tests.rs");