Skip to main content

gam_sae/
hybrid_split.rs

1//! #1026 — load-bearing curved-vs-linear hybrid split for the fitted SAE
2//! dictionary.
3//!
4//! The selection machinery ([`gam_solve::evidence::select_hybrid_atom`],
5//! [`gam_solve::evidence::select_hybrid_split`]) and the per-atom
6//! integration helper
7//! ([`crate::assignment::select_hybrid_atom_parameterization`]) are
8//! correct and tested, but until now were called nowhere in the fitter: the
9//! post-fit pass only *logged* each `d = 1` atom's fitted turning `Θ`. This
10//! module makes the split LOAD-BEARING by building, per fitted `d = 1` atom, the
11//! two already-realized candidates and adjudicating them by the common
12//! evidence criterion.
13//!
14//! ## The common-evidence comparison on the data (#1202)
15//!
16//! Both candidates are scored against the SAME data: the portion of the
17//! response matrix the atom is responsible for reconstructing, namely its
18//! **leave-this-atom-out residual**
19//!
20//!     y_resp[i] = target[i] − ( Σ_j a[i,j]·γ_j(t_{ij}) − a[i,k]·γ_k(t_{ik}) )
21//!               = target[i] − without_k[i],
22//!
23//! the response with every OTHER atom's contribution subtracted. Over the rows
24//! assigned to atom `k` (assignment mass `a[i,k] = a_k`), the two candidates
25//! predict that residual:
26//!
27//!   * the CURVED candidate is scored at its CONSTRAINED MINIMUM over the atom's
28//!     decoder coefficients — re-fit on `y_resp` as
29//!     `min_B ‖y_resp − a_k·(Φ(t)·B)‖²` ([`curved_refit_rss`]) — so its data-fit
30//!     deviance is the smallest weighted RSS the curved family can attain on this
31//!     residual at the realized codes, not the possibly-collapsed realized curve.
32//!   * the LINEAR candidate predicts `a_k · (b₀ + (t − t̄)·b₁)`, the best
33//!     weighted least-squares straight line fit to `y_resp` (design column
34//!     scaled by the same assignment mass `a_k`), so its data-fit deviance is the
35//!     weighted RSS of the best line against the SAME residual.
36//!
37//! ## A genuine NESTED min-vs-min comparison (#1051)
38//!
39//! Both arms are now at their constrained minimum on the SAME leave-one-atom-out
40//! residual: the linear arm is the closed-form min-over-lines, and the curved arm
41//! is the min-over-decoders refit ([`curved_refit_rss`]). The linear special case
42//! `a_k·(b₀ + (t−t̄)·b₁)` is a MEMBER of the curved family whenever the straight
43//! lane `[1, (t−t̄)]` lies in the column span of the curved basis `Φ` — exactly for
44//! the interval / line-segment charts (whose basis carries the constant and linear
45//! terms) and to the basis's expressiveness for the periodic charts. So after the
46//! refit `curved_rss ≤ linear_rss` up to the least-squares solver tolerance: the
47//! curved family CANNOT do worse than its own `Θ = 0` sub-model. That is the
48//! nested-dominance property restored here — "curved match-or-beats linear" is now
49//! a floor established by re-optimizing the curved arm, not merely asserted.
50//!
51//! The broken (#1051) euclidean / multi-atom OUTER continuation is deliberately
52//! NOT re-entered: the direct per-atom `d = 1` decoder-only refit at the realized
53//! codes is sufficient to score the curved arm at its constrained minimum. When
54//! `Φ` is unavailable (no evaluator) or the refit solve is degenerate the arm
55//! falls back to the already-realized curve's RSS — an honest degradation, never a
56//! fabricated determinant. The argmin then trades the curved arm's (minimized)
57//! data fit against its larger parameter / complexity price, so a genuinely
58//! curved signal is preferred while a straight-line signal ties and collapses to
59//! the cheaper linear lane.
60//!
61//! This replaces an earlier diagnostic in which both candidates targeted the
62//! atom's already-fitted decoded image `γ_k(t)` (giving the curved arm a free
63//! zero residual against itself, #1202) and its successor which scored the
64//! already-REALIZED curved contribution (a post-fit heuristic that did not
65//! establish nested dominance); the curved arm is now re-fit to its minimum.
66
67use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
68
69use gam_linalg::faer_ndarray::FaerEigh;
70use gam_solve::evidence::{
71    HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
72};
73use gam_terms::latent::LatentManifold;
74use crate::chart_canonicalization::d1_atom_fitted_turning;
75use crate::manifold::{SaeManifoldAtom, solve_design_least_squares};
76
/// Negative log-evidence of one reduced per-atom Gaussian reconstruction
/// sub-model under the rank-aware Laplace approximation:
/// `residual_objective + ½·log|H|`.
///
/// This is the form [`gam_solve::evidence::laplace_evidence`] reduces to on this
/// comparison, where there is no smoothing-penalty logdet and the design is full
/// rank (no null space). It is evaluated inline rather than routed through
/// `EvidenceLogDetSource` because both candidates' Hessian logdets are already
/// closed-form scalar moments of their shared design, so there is no factor
/// cache or HVP callback to assemble.
///
/// # Scale caveat
///
/// This is a FIXED-DISPERSION Laplace / penalized criterion, not a profiled REML
/// marginal likelihood. It assumes unit dispersion (`σ² = 1`) after
/// preprocessing: the residual objective is the bare `½ RSS`, with no
/// `RSS/(2σ²)` rescaling and no profiled-over-σ² log-determinant correction.
/// The value is therefore SENSITIVE to the response scale — rescaling
/// `y → s·y` multiplies the residual term by `s²` while the `½ log|H|`
/// complexity term stays put — so the curved-vs-linear trade-off it expresses is
/// not scale-invariant. Callers must keep the targets on a consistent
/// (preprocessed, roughly unit-scale) footing for the comparison to be
/// meaningful.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    // Laplace complexity price: half of the Hessian log-determinant.
    let complexity_price = 0.5 * log_det_h;
    residual_objective + complexity_price
}
97
98/// Rank-aware `log|ΦᵀWΦ|_+` of the curved atom's weighted design Gram over its
99/// `M` decoder basis columns, with per-row weight `wᵢ = a_k²` (the same
100/// assignment-mass design weight the linear arm uses), summed over the
101/// eigenvalues above a relative spectral floor (#1223). This is the genuine
102/// weighted-design determinant the linear arm already reports — `log|XᵀWX|` —
103/// assembled for the curved basis so the two arms' Laplace complexity prices are
104/// computed on the SAME footing instead of pricing the curved arm with a
105/// parameter-count proxy `M·log(Σw)`.
106///
107/// Mirrors the linear arm exactly in what it does NOT include: no smoothing-
108/// penalty `λS` normalizer (the linear arm's Gram is the bare data Gram
109/// `diag(w_sum, s_tt)` too), so the comparison stays symmetric. The Gram is the
110/// design's outer Gram over its basis columns; it is identical across the `p`
111/// output channels (every channel shares the design `Φ`), so the per-channel
112/// `log|G|_+` is multiplied by `p` — matching the linear arm's `p·(…)` form.
113///
114/// `phi` is the curved design `Φ(t)` evaluated on the atom's assigned rows
115/// (`n × M`); `assign` is the per-row assignment mass `a_k` (NOT squared).
116/// Returns `None` when `Φ` is missing rows, the Gram is non-finite, or it has no
117/// positive eigenvalues (a fully rank-deficient design carries no determinant);
118/// the caller then falls back to the parameter-count proxy rather than fabricate
119/// a determinant.
120fn curved_design_gram_logdet(
121    phi: ArrayView2<'_, f64>,
122    assign: ArrayView1<'_, f64>,
123    p: usize,
124) -> Option<f64> {
125    let n = phi.nrows();
126    let m = phi.ncols();
127    if m == 0 || assign.len() != n || n == 0 {
128        return None;
129    }
130    // G = Φᵀ diag(a²) Φ  (M×M, symmetric PSD).
131    let mut gram = Array2::<f64>::zeros((m, m));
132    for i in 0..n {
133        let w = assign[i] * assign[i];
134        if !(w.is_finite() && w >= 0.0) {
135            return None;
136        }
137        if w == 0.0 {
138            continue;
139        }
140        let row = phi.row(i);
141        for a in 0..m {
142            let wa = w * row[a];
143            for b in a..m {
144                gram[[a, b]] += wa * row[b];
145            }
146        }
147    }
148    // Symmetrize the lower triangle (we only filled the upper).
149    for a in 0..m {
150        for b in 0..a {
151            gram[[a, b]] = gram[[b, a]];
152        }
153    }
154    if gram.iter().any(|v| !v.is_finite()) {
155        return None;
156    }
157    let (vals, _vecs) = gram.eigh(faer::Side::Lower).ok()?;
158    // Rank-aware log-determinant: sum log of eigenvalues above a relative floor
159    // tied to the largest eigenvalue, dropping the numerically-null directions
160    // (the curved design's null space, analogous to the linear arm's full-rank
161    // 2-D Gram). A design with no positive eigenvalue carries no determinant.
162    let lambda_max = vals.iter().cloned().fold(0.0_f64, f64::max);
163    if !(lambda_max > 0.0 && lambda_max.is_finite()) {
164        return None;
165    }
166    let floor = lambda_max * 1e-12;
167    let mut log_det = 0.0_f64;
168    let mut rank = 0usize;
169    for &lambda in vals.iter() {
170        if lambda > floor {
171            log_det += lambda.ln();
172            rank += 1;
173        }
174    }
175    if rank == 0 || !log_det.is_finite() {
176        return None;
177    }
178    // The design Gram is shared across the p output channels.
179    Some((p as f64) * log_det)
180}
181
/// The fitted straight sub-model `γ̃(t) = b₀ + (t − t̄)·b₁` of one `d = 1` atom.
///
/// It is the exact assignment-mass-weighted least-squares line fit to the atom's
/// leave-this-atom-out RESPONSE residual `y_resp` over its assigned rows — the
/// curved family's nested `Θ = 0` sub-model on common data (#1202). A verdict
/// that selects LINEAR carries it so the collapsed reconstruction can replace
/// the curved decoded row with this straight image at any coordinate WITHOUT
/// re-entering the (broken, #1051) outer fit: the coefficients are already
/// realized inside the adjudication.
#[derive(Clone, Debug)]
pub struct AtomLinearImage {
    /// The atom's slot index in the dictionary, so the collapsed assembly knows
    /// which atom's decoded row to substitute.
    pub atom_idx: usize,
    /// The mass-weighted coordinate mean `t̄` the line is centered on.
    pub t_bar: f64,
    /// Per-output-channel centered intercept `b₀ = γ̄` at `t̄` (length `p`).
    pub b0: Array1<f64>,
    /// Per-output-channel slope `b₁` (length `p`).
    pub b1: Array1<f64>,
    /// #1026 collapse-rescue per-row coordinates.
    ///
    /// `None` on the ordinary path, where the line is evaluated at the atom's
    /// OWN realized coordinate `t`. `Some(u)` only when the atom's circle codes
    /// had collapsed to a single point (`s_tt ≈ 0`), so its own coordinate
    /// carries no spread; the line is then fit against, and reconstruction
    /// evaluates it at, these FRESH per-row codes `uᵢ` (the projection of the
    /// leave-this-atom-out residual onto its top mass-weighted output
    /// direction). This lets a circle atom that the joint fit drove into the
    /// degenerate "chord-through-the-arc" fixed point still reconstruct its
    /// residual's best linear direction — recovering the linear-tail reach the
    /// hybrid split was designed to deliver — instead of a constant (its
    /// collapsed curve), which is the real-OLMo rank-1 co-collapse (held-out
    /// EV ≈ 0.13 vs the linear ceiling ≈ 0.74). Length `n` (one entry per
    /// reconstructed row); unassigned rows are gated to zero by `a_k` anyway.
    ///
    /// TRAIN-ONLY CAVEAT (#1777): these are the TRAIN rows' codes and are only
    /// meaningful for the exact rows the split was fit on. A held-out row has no
    /// entry here and used to fall back to the atom's own (collapsed) coordinate
    /// `own_t` — a DIFFERENT, degraded model out of sample. Prefer [`Self::v`]:
    /// projecting a held-out row's leave-this-atom-out residual onto `v`
    /// recovers that row's coordinate by the SAME math the train codes were
    /// built with, so train and OOS use one model. `row_codes` is retained for
    /// back-compat and as the exact cached train projection.
    pub row_codes: Option<Array1<f64>>,
    /// #1777 collapse-rescue projection DIRECTION `v` (length `p`, unit norm):
    /// the top mass-weighted output direction of the atom's leave-this-atom-out
    /// residual.
    ///
    /// `Some` exactly when this is a collapse-rescued image (paired with
    /// `row_codes`); `None` on the ordinary straight-image path, which decodes
    /// at the atom's own coordinate. This is the SERIALIZABLE quantity the FFI
    /// must persist so an OOS term can recompute any row's coordinate as
    /// `uᵢ = ⟨y_i − Σ_{j≠k} f_j(x_i), v⟩` — identical to the train code
    /// `row_codes[i]` on a train row, and the correct held-out coordinate on an
    /// OOS row (see [`Self::coordinate_from_residual`]). Its length must equal
    /// that of `b0` / `b1`.
    pub v: Option<Array1<f64>>,
}
235
236impl AtomLinearImage {
237    /// Evaluate the straight sub-model `b₀ + (t − t̄)·b₁` into `out` (length `p`).
238    pub fn fill_row(&self, t: f64, out: &mut [f64]) {
239        let dt = t - self.t_bar;
240        for (j, slot) in out.iter_mut().enumerate() {
241            *slot = self.b0[j] + dt * self.b1[j];
242        }
243    }
244
245    /// The coordinate at which row `row` should evaluate this image: the
246    /// collapse-rescue fresh code `uᵢ` when present (#1026), else the atom's own
247    /// realized coordinate `own_t` passed by the caller.
248    ///
249    /// TRAIN-ONLY: `row_codes` is indexed by TRAIN row, so this is correct only
250    /// for the rows the split was fit on. Out of sample use
251    /// [`Self::coordinate_from_residual`] (target-aware, model-identical to train).
252    pub fn coordinate_for_row(&self, row: usize, own_t: f64) -> f64 {
253        match &self.row_codes {
254            Some(u) if row < u.len() => u[row],
255            _ => own_t,
256        }
257    }
258
259    /// #1777 — the collapse-rescue coordinate of a row from ITS OWN
260    /// leave-this-atom-out residual `resid = y_i − Σ_{j≠k} f_j(x_i)` (length `p`),
261    /// namely `uᵢ = ⟨resid, v⟩`. `Some(uᵢ)` exactly when this is a collapse-rescued
262    /// image (`v` is set); `None` for the ordinary straight-image path (which has
263    /// no projection direction and decodes at the atom's own coordinate).
264    ///
265    /// This is the SAME math [`build_collapse_rescue_linear_image`] used to build
266    /// the train `row_codes` (`row_codes[i] = ⟨target_resid[i], v⟩`), so on a TRAIN
267    /// row it reproduces `row_codes[i]` exactly, and on a HELD-OUT row it yields
268    /// that row's correct coordinate — train and OOS share one model. Returns
269    /// `None` if `resid`'s length disagrees with `v`.
270    pub fn coordinate_from_residual(&self, resid: &[f64]) -> Option<f64> {
271        let v = self.v.as_ref()?;
272        if resid.len() != v.len() {
273            return None;
274        }
275        Some(v.iter().zip(resid).map(|(&vj, &rj)| vj * rj).sum())
276    }
277
278    /// Whether this image is a #1777 collapse-rescued image (carries a projection
279    /// direction `v` and per-row train codes) rather than an ordinary straight
280    /// image evaluated at the atom's own coordinate.
281    pub fn is_collapse_rescued(&self) -> bool {
282        self.v.is_some()
283    }
284}
285
/// The hybrid-split verdict for one fitted `d = 1` atom, surfaced in the model
/// output.
#[derive(Clone, Debug)]
pub struct AtomHybridVerdict {
    /// The atom's name (its slot identity in the dictionary).
    pub atom_name: String,
    /// The evidence-selected parameterization choice for this slot.
    pub choice: HybridAtomChoice,
    /// `true` iff the slot kept the CURVED parameterization (the fitted atom);
    /// `false` iff it yielded to the LINEAR special case (the straight tail).
    pub kept_curved: bool,
    /// The atom's fitted turning `Θ = ∫|κ| ds` (radians), the geometric quantity
    /// #1026 pairs against reconstruction EV: `Θ ≈ 0` is a linear-tail direction
    /// wearing a curved basis, `Θ ≈ 2π` is a full curved loop. `None` iff the
    /// evaluator has no analytic second jet or the curve is degenerate.
    /// Stored here (not just logged) so the EV-vs-Θ frontier is queryable
    /// structured data on the persisted report rather than a transient log line.
    pub fitted_turning: Option<f64>,
    /// The atom's training leave-one-atom-out explained-variance contribution
    /// `ΔEV_k = EV(full) − EV(full∖{k})`, i.e. how much reconstruction EV this
    /// single atom earns. Together with [`Self::fitted_turning`] it forms the
    /// `(Θ, ΔEV)` point the #1026 frontier reports: a `Θ ≈ 0` atom with large
    /// `ΔEV` is a genuine linear-tail direction; a high-`Θ` atom with large
    /// `ΔEV` is a genuine curved family. `None` iff the caller did not supply
    /// LOAO EV.
    pub train_loao_delta_ev: Option<f64>,
    /// The fitted straight sub-model for this slot, present iff the verdict
    /// selected LINEAR (`kept_curved == false`). The collapsed reconstruction
    /// substitutes it for the atom's curved decoded image, which is what makes
    /// the verdict load-bearing on the reconstruction rather than a passive
    /// diagnostic.
    pub linear_image: Option<AtomLinearImage>,
}
316
/// The hybrid-split report for the whole dictionary: one verdict per eligible
/// `d = 1` atom, plus the dictionary-level aggregates the EV-vs-Θ frontier
/// reports against.
#[derive(Clone, Debug)]
pub struct SaeHybridSplitReport {
    /// One adjudicated verdict per eligible `d = 1` atom, in slot order.
    /// Ineligible atoms (wrong dim, no evaluator, mid-homotopy) are absent: they
    /// carry no curved/linear adjudication.
    pub verdicts: Vec<AtomHybridVerdict>,
    /// The dictionary-level rolled-up selection (summed NLE, total parameters,
    /// curved/linear counts) over the eligible atoms.
    pub selection: HybridSplitSelection,
}
330
/// Minimum number of assigned rows a `d = 1` atom needs before it is
/// adjudicated. With fewer rows the atom cannot support the two-parameter
/// (intercept + slope) straight-line fit together with a residual estimate, so
/// the linear candidate's deviance is undefined. Such atoms are skipped (absent
/// from the report), never adjudicated on a fabricated deviance.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 3;
336
337/// #1610/#1026 — EV-PRESERVATION gate tolerance: a `d = 1` slot may collapse to
338/// its linear tail only if doing so costs at most this fraction of the target's
339/// total (centered) variance in full-reconstruction explained variance. The
340/// evidence argmin trades data-fit against the curved arm's parameter price in
341/// `NLE` units, but on a small / low-amplitude fixture that trade can prefer the
342/// cheaper line even when the curve carries real reconstruction signal — the
343/// collapse then DROPS EV (the observed 1.0 → 0.748 over-collapse). This gate is
344/// a direct guard on the quantity that actually regressed: the per-atom collapse
345/// EV impact equals `(linear_rss − curved_rss)/SST_full` exactly (collapsing
346/// atom `k` raises the full reconstruction SSR by `linear_rss − curved_rss` and
347/// the full target variance `SST_full` is fixed), so a collapse is vetoed when it
348/// would lose more than this fraction. EV is a dimensionless quantity in `[0,1]`,
349/// so an absolute EV-loss tolerance is itself scale-invariant — it does not
350/// reintroduce the scale-incommensurability that sank the evidence-reformulation
351/// attempt. A genuinely straight atom (the curved fit IS its own line) loses
352/// `≈ 0` EV and still collapses losslessly; only a curve doing real
353/// reconstruction work (loss `≫ 1e-3`) is kept. `1e-3` = 0.1% of total variance:
354/// comfortably above f64 round-off on an exact-line collapse yet far below any
355/// material EV loss, so it separates the lossless and load-bearing regimes
356/// without tuning.
357///
358/// WHY A FIXED DIMENSIONLESS TOLERANCE AND NOT A NOISE/DISPERSION-DERIVED ONE
359/// (#1610). This gate is a SAFETY BACKSTOP layered over the evidence/REML
360/// selection ([`select_hybrid_split`]), which is itself the nested-model
361/// statistical test: it already trades the curved arm's data-fit against its
362/// Laplace parameter price in NLE units and picks the line when the curve is not
363/// evidence-justified. The backstop exists only to catch the residual case where
364/// that argmin prefers the cheaper line yet doing so DROPS reconstruction EV (the
365/// observed 1.0 → 0.748). The correct instrument for a backstop over a
366/// statistical test is a conservative negligibility tolerance on the quantity it
367/// protects (full-reconstruction EV), NOT a second re-derived statistical
368/// threshold. A noise/dispersion-derived per-atom tolerance — e.g.
369/// `df_extra · σ̂² / SST_full`, the curve's expected spurious extra RSS under the
370/// null that the image is straight, with `σ̂²` the curved fit's residual
371/// dispersion — has, moreover, no SAFE calibration here: on an exact-line
372/// collapse the fit's dispersion `σ̂² → 0`, so a pure noise threshold falls BELOW
373/// the least-squares solver round-off of `linear_rss − curved_rss`
374/// (`≈ κ·εmach·SST_full`) and would spuriously VETO the lossless collapses the
375/// deterministic tests require; raising it to clear round-off, conversely, only
376/// relaxes the backstop TOWARD the over-collapse boundary it was added to hold
377/// (larger tolerance ⇒ fewer vetoes ⇒ more collapse). There is thus no safe
378/// direction to make it noise-adaptive. The safe window is the wide, well-
379/// separated band between solver round-off (`~1e-12` relative) and any material
380/// EV loss (`~1e-2`); `1e-3` is the standard "0.1% of variance is negligible"
381/// point inside it — dimensionless on an EV `∈ [0,1]` (hence scale-invariant, per
382/// this issue's scale-invariance contract) with a single explicit meaning, not a
383/// corpus-tuned magnitude. (A noise-adaptive backstop would also change
384/// reconstruction on real activations in a way only the real-OLMo behavioral
385/// battery can validate.)
386///
387/// PER-ATOM SCOPE: applied per slot, this tolerance bounds only ONE atom's
388/// individual EV loss. It does NOT by itself bound the dictionary-level EV loss
389/// when several atoms collapse at once: the true global RSS change is
390/// `Σ_k Δ_k + 2 Σ_{j<k} <Δrecon_j, Δrecon_k>` — both the accumulation of many
391/// individually-tolerable `Δ_k` AND the cross terms between co-active atoms are
392/// invisible to the per-atom gate. [`build_hybrid_split_report`] adds an
393/// aggregate global guard (interpreting this same fraction as a bound on
394/// `Σ_k max(Δ_k, 0)`) on top of the per-atom gate; see there for what that guard
395/// does and does not prove.
// Unit: a dimensionless explained-variance fraction (`1.0e-3` = 0.1% of the
// target's total centered variance), not an RSS or NLE magnitude.
const SAE_HYBRID_COLLAPSE_MAX_EV_LOSS: f64 = 1.0e-3;
397
398/// The full-reconstruction SSR INCREASE from collapsing one `d = 1` atom to its
399/// fitted straight sub-model: `linear_rss − curved_rss`, where both arms are
400/// scored against the atom's leave-this-atom-out response residual `y_resp`
401/// (`target_resid`) exactly as [`build_atom_candidates`] scores them. Because the
402/// full reconstruction differs from the collapsed one ONLY in this atom's
403/// contribution (`a_k·γ_k` → `a_k·line`) on its assigned rows, this scalar is the
404/// exact amount the full reconstruction's SSR rises when the slot is collapsed;
405/// dividing by the fixed total target variance gives the full-EV loss the
406/// EV-preservation gate keys on. Positive ⇒ the curve out-fits its straight
407/// projection (collapsing hurts); `≤ 0` ⇒ the line is at least as good (collapse
408/// is lossless or improving, never gated). Mirrors the `curved_rss` / `linear_rss`
409/// accumulation in [`build_atom_candidates`] bit-for-bit so the gate and the
410/// evidence comparison see the same residuals.
411fn collapse_ssr_increase(
412    coords: ArrayView1<'_, f64>,
413    assign: ArrayView1<'_, f64>,
414    decoded: ArrayView2<'_, f64>,
415    target_resid: ArrayView2<'_, f64>,
416    t_bar: f64,
417    b0: &Array1<f64>,
418    b1: &Array1<f64>,
419) -> f64 {
420    let n = assign.len();
421    let p = target_resid.ncols();
422    let mut curved_rss = 0.0_f64;
423    let mut linear_rss = 0.0_f64;
424    for i in 0..n {
425        let a = assign[i];
426        let dt = coords[i] - t_bar;
427        for j in 0..p {
428            let y = target_resid[[i, j]];
429            let r_curved = y - a * decoded[[i, j]];
430            curved_rss += r_curved * r_curved;
431            let r_linear = y - a * (b0[j] + dt * b1[j]);
432            linear_rss += r_linear * r_linear;
433        }
434    }
435    linear_rss - curved_rss
436}
437
438/// #1051/#1026 NESTED MIN — re-fit the curved atom's decoder on the SAME
439/// leave-this-atom-out residual `y_resp` and return its **minimum** weighted
440/// reconstruction RSS at the atom's realized codes.
441///
442/// This is the genuine constrained minimum of the curved family over its free
443/// decoder coefficients `B`:
444///
445/// ```text
446/// min_B Σᵢ ‖ y_resp[i] − a_k·( Φ(t_i)·B ) ‖²   (design = diag(a)·Φ, rhs = y_resp).
447/// ```
448///
449/// It restores the nested-dominance floor `curved_rss ≤ linear_rss`. The linear
450/// special case `a_k·(b₀ + (t−t̄)·b₁)` is itself a member of this family whenever
451/// the straight lane `[1, (t−t̄)]` lies in the column span of the curved basis
452/// `Φ` — exactly (interval / line-segment charts, whose basis carries the
453/// constant and linear terms) or to the basis's expressiveness (the periodic
454/// charts). So `min_B` over `Φ` cannot do WORSE than the best straight line: the
455/// returned RSS is `≤` the linear arm's RSS up to the least-squares solver
456/// tolerance. This is the direct per-atom `d = 1` refit the module owes; the
457/// broken (#1051) euclidean / multi-atom OUTER continuation is deliberately NOT
458/// re-entered — the decoder-only refit at the realized codes is sufficient to
459/// score the curved arm at its constrained minimum.
460///
461/// `phi` is the curved design `Φ(t)` on the atom's rows (`n × M`); `assign` the
462/// per-row mass `a_k` (NOT squared — the design weight is the mass itself, so the
463/// residual is on the SAME footing the linear arm and the joint loss use).
464/// Returns `None` when the solve is degenerate or non-finite; the caller then
465/// falls back to the already-realized curve's RSS rather than fabricate a value.
466fn curved_refit_rss(
467    phi: ArrayView2<'_, f64>,
468    assign: ArrayView1<'_, f64>,
469    target_resid: ArrayView2<'_, f64>,
470) -> Option<f64> {
471    let n = phi.nrows();
472    let m = phi.ncols();
473    let p = target_resid.ncols();
474    if m == 0 || n == 0 || assign.len() != n || target_resid.nrows() != n || p == 0 {
475        return None;
476    }
477    // Weighted design `diag(a)·Φ` (n×M). The refit minimizes ‖diag(a)·Φ·B − y_resp‖²,
478    // so the fitted prediction is diag(a)·Φ·B — the curved contribution `a_k·γ_k`
479    // at its best decoder `B` on this residual.
480    let mut design = Array2::<f64>::zeros((n, m));
481    for i in 0..n {
482        let a = assign[i];
483        if !a.is_finite() {
484            return None;
485        }
486        for c in 0..m {
487            design[[i, c]] = a * phi[[i, c]];
488        }
489    }
490    let b = solve_design_least_squares(design.view(), target_resid).ok()?;
491    if b.iter().any(|v| !v.is_finite()) {
492        return None;
493    }
494    let pred = design.dot(&b);
495    let mut rss = 0.0_f64;
496    for i in 0..n {
497        for j in 0..p {
498            let r = target_resid[[i, j]] - pred[[i, j]];
499            rss += r * r;
500        }
501    }
502    rss.is_finite().then_some(rss)
503}
504
505/// Build the curved + linear candidates for ONE fitted `d = 1` atom and return
506/// them as `(linear, curved, (t̄, b₀, b₁))`, or `None` if the atom cannot present
507/// an honest pair (too few rows, degenerate coordinate span, or non-finite
508/// numbers). Both candidates are scored against the SAME data — the atom's
509/// leave-this-atom-out response residual `y_resp` — at their CONSTRAINED MINIMUM:
510/// the linear arm is the freshly-fit min-over-lines and the curved arm is the
511/// min-over-decoders refit ([`curved_refit_rss`]), so the comparison is a genuine
512/// nested min-vs-min one and `curved_rss ≤ linear_rss` holds up to solver
513/// tolerance for a basis whose span contains the straight lane (#1051).
514///
515/// Inputs over the atom's assigned rows:
516///   * `coords` — the fitted on-atom coordinate `t`.
517///   * `assign` — the per-row assignment mass `a_k` (NOT squared; this routine
518///     squares it where the design weight `a_k²` is needed).
519///   * `decoded` — the atom's fitted decoded image `γ_k(t) = Φ(t) B_k` (`p` cols),
520///     whose mass-scaled value `a_k·γ_k` is the curved candidate's PREDICTION.
521///   * `target_resid` — the atom's leave-this-atom-out response residual `y_resp`
522///     (`p` cols): the response with every OTHER atom's contribution removed.
523///     This is the data both candidates fit.
524///
525/// The curved candidate's data-fit deviance is `½·min_B Σ ‖y_resp − a_k·(Φ·B)‖²`
526/// (its constrained minimum over the decoder; the mass lives in the prediction);
527/// the linear candidate fits the best mass-weighted straight line to `y_resp` and
528/// pays `½ Σ ‖y_resp − a_k·(b₀ + (t − t̄)·b₁)‖²`. Because the linear prediction is
529/// itself a curved-family member (the straight lane lies in `span(Φ)` for the
530/// eligible charts), the curved arm's minimized RSS is `≤` the linear arm's up to
531/// solver tolerance, so the argmin is a genuine nested min-vs-min dominance
532/// comparison, not a post-fit compression heuristic.
533fn build_atom_candidates(
534    coords: ArrayView1<'_, f64>,
535    assign: ArrayView1<'_, f64>,
536    decoded: ArrayView2<'_, f64>,
537    target_resid: ArrayView2<'_, f64>,
538    curved_num_params: usize,
539    curved_phi: Option<ArrayView2<'_, f64>>,
540    fitted_turning: Option<f64>,
541) -> Option<(
542    HybridAtomCandidate,
543    HybridAtomCandidate,
544    (f64, Array1<f64>, Array1<f64>),
545)> {
546    let n = coords.len();
547    let p = decoded.ncols();
548    if n < MIN_ROWS_FOR_LINEAR_FIT
549        || decoded.nrows() != n
550        || assign.len() != n
551        || target_resid.nrows() != n
552        || target_resid.ncols() != p
553        || p == 0
554    {
555        return None;
556    }
557
558    // The LINEAR candidate fits `a_k·(b₀ + (t − t̄)·b₁)` to the residual `y_resp`,
559    // so the natural design column is `a_k·[1, (t − t̄)]` and the per-row Gram
560    // weight is `wᵢ = a_k²`. We accumulate the mass-weighted coordinate mean `t̄`
561    // and spread `s_tt` under that weight; a row that barely belongs to the atom
562    // (`a_k ≈ 0`) contributes ≈ nothing, exactly as in the joint loss.
563    let mut w_sum = 0.0_f64;
564    let mut t_bar = 0.0_f64;
565    for i in 0..n {
566        let a = assign[i];
567        if !(a.is_finite() && a >= 0.0) {
568            return None;
569        }
570        let w = a * a;
571        w_sum += w;
572        t_bar += w * coords[i];
573    }
574    if !(w_sum > 0.0) {
575        return None;
576    }
577    t_bar /= w_sum;
578
579    // Weighted Σ wᵢ·(t − t̄)² with `wᵢ = a_k²` — the coordinate spread under the
580    // line's design weight. A degenerate (single-point mass) coordinate has no
581    // slope direction; refuse rather than divide by ~0.
582    let mut s_tt = 0.0_f64;
583    for i in 0..n {
584        let dt = coords[i] - t_bar;
585        s_tt += assign[i] * assign[i] * dt * dt;
586    }
587    if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
588        return None;
589    }
590
591    // Per-output-channel mass-weighted least squares for the line fit to the
592    // RESIDUAL `y_resp`. Minimizing `Σᵢ ‖y_resp[i] − a_k·(b₀ + (t − t̄)·b₁)‖²` in
593    // the centered basis has the diagonal normal equations
594    //   b₀[j] = (Σ a_k·y_resp[i,j]) / w_sum,   (recall the design intercept is a_k)
595    //   b₁[j] = (Σ a_k·(t − t̄)·y_resp[i,j]) / s_tt.
596    let mut b0 = Array1::<f64>::zeros(p);
597    let mut b1 = Array1::<f64>::zeros(p);
598    for j in 0..p {
599        let mut s_1y = 0.0_f64;
600        let mut s_ty = 0.0_f64;
601        for i in 0..n {
602            let a = assign[i];
603            let dt = coords[i] - t_bar;
604            let y = target_resid[[i, j]];
605            s_1y += a * y;
606            s_ty += a * dt * y;
607        }
608        b0[j] = s_1y / w_sum;
609        b1[j] = s_ty / s_tt;
610    }
611
612    // Data-fit residual sums of squares of BOTH candidates against `y_resp`, the
613    // common data. The linear candidate predicts the best line
614    // `a_k·(b₀ + (t − t̄)·b₁)`; the curved candidate is scored at its CONSTRAINED
615    // MINIMUM over the decoder coefficients (nested min-vs-min, #1051), re-fit on
616    // this same residual — not the possibly-collapsed already-realized curve. We
617    // also carry the realized curve's RSS as the honest fallback when the basis
618    // `Φ` is unavailable or its refit solve is degenerate.
619    let mut linear_rss = 0.0_f64;
620    let mut realized_curved_rss = 0.0_f64;
621    for i in 0..n {
622        let a = assign[i];
623        let dt = coords[i] - t_bar;
624        for j in 0..p {
625            let y = target_resid[[i, j]];
626            let r_linear = y - a * (b0[j] + dt * b1[j]);
627            linear_rss += r_linear * r_linear;
628            let r_curved = y - a * decoded[[i, j]];
629            realized_curved_rss += r_curved * r_curved;
630        }
631    }
632    // #1051 NESTED MIN — the curved arm's data fit is `min_B ‖y_resp − diag(a)Φ B‖²`,
633    // its constrained minimum over the decoder. Because the linear lane is a member
634    // of the curved family (the straight columns lie in `span(Φ)` for the eligible
635    // charts), this min-curved RSS is `≤ linear_rss` up to solver tolerance — the
636    // "curved match-or-beats linear" floor. Fall back to the realized curve's RSS
637    // only when Φ is absent or the refit is degenerate.
638    let curved_rss = match curved_phi {
639        Some(phi) if phi.nrows() == n => {
640            curved_refit_rss(phi, assign, target_resid).unwrap_or(realized_curved_rss)
641        }
642        _ => realized_curved_rss,
643    };
644
645    // Gaussian-reconstruction deviance: the residual objective `½ RSS` the
646    // Laplace normalizer is added to. The curved arm pays `½·curved_rss` (how
647    // well its REALIZED curve explains the residual) plus its larger `M·p`
648    // parameter price; the linear arm pays `½·linear_rss` plus a `2·p` price.
649    // `curved_rss` is the realized (not re-optimized) curve's misfit, so it is NOT
650    // guaranteed `≤ linear_rss`: when the realized curve underperforms its own best
651    // straight projection the cheaper line simply wins. The argmin trades whatever
652    // data-fit the realized curve buys against the curvature parameter price — a
653    // post-fit compression decision, not a nested match-or-beat floor.
654    let curved_residual_objective = 0.5 * curved_rss;
655    let linear_residual_objective = 0.5 * linear_rss;
656
657    // Linear candidate parameter price: intercept + slope per output channel.
658    let linear_num_params = 2 * p;
659
660    // Laplace logdet of the (weighted) design Gram for the LINEAR candidate.
661    //
662    // For the centered weighted line fit `a_k·(b₀ + (t − t̄)·b₁)`, the per-output-
663    // channel design column is `a_k·[1, (t − t̄)]`, whose Gram is DIAGONAL in the
664    // centered basis: `diag(Σ a_k², Σ a_k²(t − t̄)²) = diag(w_sum, s_tt)`. Its log
665    // determinant is `log(w_sum) + log(s_tt)` PER output channel, i.e.
666    //
667    //     log|H_linear| = p · ( log(w_sum) + log(s_tt) ).
668    //
669    // The `log(s_tt)` term is the slope direction's information: a line through a
670    // wide, heavily-massed coordinate spread is better-determined than one through
671    // a tiny spread, and the Laplace evidence must reflect that (#1203).
672    //
673    // The curved arm's Laplace determinant is now the genuine weighted-design
674    // Gram log-determinant `p · log|ΦᵀWΦ|_+` (#1223): the SAME quantity the
675    // linear arm reports (`p·(log w_sum + log s_tt) = p·log|XᵀWX|`), assembled
676    // from the curved basis `Φ` on the atom's assigned rows under the same
677    // assignment-mass design weight `wᵢ = a_k²`. Both arms omit the smoothing
678    // `λS` normalizer, so the complexity price is computed on a symmetric
679    // footing — no parameter-count proxy. Only when `Φ` is unavailable (the
680    // caller could not evaluate the basis) or its Gram is fully rank-deficient do
681    // we fall back to the historical `curved_num_params · log(w_sum)` proxy, so
682    // the comparison degrades gracefully rather than fabricating a determinant.
683    if !(w_sum > 0.0 && w_sum.is_finite() && s_tt.is_finite()) {
684        return None;
685    }
686    let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
687    let curved_log_det_h = curved_phi
688        .and_then(|phi| {
689            if phi.nrows() == n {
690                curved_design_gram_logdet(phi, assign, p)
691            } else {
692                None
693            }
694        })
695        .unwrap_or_else(|| (curved_num_params as f64) * w_sum.ln());
696
697    // Reduced Laplace NLE `residual_objective + ½ log|H|`. Both omit an explicit
698    // smoothing-penalty logdet (the intrinsic smoothness penalty is
699    // reparameterization-invariant and identical in expectation across the two
700    // parameterizations of the same image).
701    let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
702    let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
703    if !(linear_nle.is_finite() && curved_nle.is_finite()) {
704        return None;
705    }
706
707    let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
708    let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
709    Some((linear, curved, (t_bar, b0, b1)))
710}
711
712/// #1026 collapse rescue. When a `d = 1` atom's own coordinate has collapsed to a
713/// single point (`build_atom_candidates` refuses because `s_tt ≈ 0`), the atom is
714/// stuck in the degenerate "chord-through-the-arc" fixed point and its curved
715/// decode is a constant — the rank-1 dictionary co-collapse (real-OLMo held-out EV
716/// ≈ 0.13 vs the rank-K linear ceiling ≈ 0.74). The hybrid-split was DESIGNED to
717/// let such a linear-tail atom decode as a straight line; the only reason it can't
718/// here is that its own codes carry no spread to fit a slope against.
719///
720/// Recover FRESH per-row codes from the data instead: `uᵢ = yᵢ·v`, the projection
721/// of the leave-this-atom-out residual onto its top mass-weighted output direction
722/// `v` (the rank-1 of `Σᵢ wᵢ yᵢyᵢᵀ`, `wᵢ = a_k²` — the SAME design weight the line
723/// fit uses). These codes span the residual's strongest linear axis by
724/// construction, so the straight image `b₀ + (uᵢ − ū)·b₁` fit against them
725/// reconstructs that axis at LINEAR quality — exactly the linear-tail reach the
726/// split owes. Returns the forced-LINEAR candidate plus the image carrying `uᵢ`,
727/// or `None` when the residual itself carries no usable direction (a genuine zero
728/// atom the mass/decoder guards own).
729fn build_collapse_rescue_linear_image(
730    atom_idx: usize,
731    assign: ArrayView1<'_, f64>,
732    target_resid: ArrayView2<'_, f64>,
733) -> Option<(HybridAtomCandidate, AtomLinearImage)> {
734    let n = assign.len();
735    let p = target_resid.ncols();
736    if n < MIN_ROWS_FOR_LINEAR_FIT || target_resid.nrows() != n || p == 0 {
737        return None;
738    }
739    let mut w_sum = 0.0_f64;
740    for i in 0..n {
741        let a = assign[i];
742        if !(a.is_finite() && a >= 0.0) {
743            return None;
744        }
745        w_sum += a * a;
746    }
747    if !(w_sum > 0.0) {
748        return None;
749    }
750    // Top mass-weighted output direction `v` of the residual via power iteration on
751    // `M = Σᵢ wᵢ yᵢyᵢᵀ` (p×p, never materialized): `v ← normalize(Σᵢ wᵢ yᵢ (yᵢ·v))`.
752    // Seed from the per-channel weighted energy so a rank-1 residual converges in
753    // one step and the seed is deterministic (no RNG).
754    let mut v = Array1::<f64>::zeros(p);
755    for j in 0..p {
756        let mut e = 0.0_f64;
757        for i in 0..n {
758            let a = assign[i];
759            let y = target_resid[[i, j]];
760            e += a * a * y * y;
761        }
762        v[j] = e;
763    }
764    let mut vnorm = v.dot(&v).sqrt();
765    if !(vnorm > 0.0) {
766        return None;
767    }
768    v.mapv_inplace(|x| x / vnorm);
769    for _ in 0..32 {
770        let mut mv = Array1::<f64>::zeros(p);
771        for i in 0..n {
772            let a = assign[i];
773            let w = a * a;
774            let mut proj = 0.0_f64;
775            for j in 0..p {
776                proj += target_resid[[i, j]] * v[j];
777            }
778            let wp = w * proj;
779            for j in 0..p {
780                mv[j] += wp * target_resid[[i, j]];
781            }
782        }
783        vnorm = mv.dot(&mv).sqrt();
784        if !(vnorm > 0.0) {
785            return None;
786        }
787        mv.mapv_inplace(|x| x / vnorm);
788        let cos = mv.dot(&v).abs();
789        v = mv;
790        if cos > 1.0 - 1e-12 {
791            break;
792        }
793    }
794    // Fresh per-row codes `uᵢ = yᵢ·v` and the weighted line fit against them.
795    let mut u = Array1::<f64>::zeros(n);
796    let mut t_bar = 0.0_f64;
797    for i in 0..n {
798        let mut proj = 0.0_f64;
799        for j in 0..p {
800            proj += target_resid[[i, j]] * v[j];
801        }
802        u[i] = proj;
803        t_bar += assign[i] * assign[i] * proj;
804    }
805    t_bar /= w_sum;
806    let mut s_tt = 0.0_f64;
807    for i in 0..n {
808        let dt = u[i] - t_bar;
809        s_tt += assign[i] * assign[i] * dt * dt;
810    }
811    if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
812        return None;
813    }
814    let mut b0 = Array1::<f64>::zeros(p);
815    let mut b1 = Array1::<f64>::zeros(p);
816    let mut linear_rss = 0.0_f64;
817    for j in 0..p {
818        let mut s_1y = 0.0_f64;
819        let mut s_ty = 0.0_f64;
820        for i in 0..n {
821            let a = assign[i];
822            let dt = u[i] - t_bar;
823            let y = target_resid[[i, j]];
824            s_1y += a * y;
825            s_ty += a * dt * y;
826        }
827        b0[j] = s_1y / w_sum;
828        b1[j] = s_ty / s_tt;
829    }
830    for i in 0..n {
831        let a = assign[i];
832        let dt = u[i] - t_bar;
833        for j in 0..p {
834            let r = target_resid[[i, j]] - a * (b0[j] + dt * b1[j]);
835            linear_rss += r * r;
836        }
837    }
838    let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
839    let linear_nle = reduced_laplace_nle(0.5 * linear_rss, linear_log_det_h);
840    if !linear_nle.is_finite() {
841        return None;
842    }
843    let linear = HybridAtomCandidate::linear(linear_nle, 2 * p);
844    let image = AtomLinearImage {
845        atom_idx,
846        t_bar,
847        b0,
848        b1,
849        row_codes: Some(u),
850        // #1777 — persist the projection direction so an OOS row's coordinate can
851        // be recomputed as ⟨residual, v⟩ (identical to the train `row_codes`),
852        // rather than falling back to the atom's collapsed own coordinate.
853        v: Some(v),
854    };
855    Some((linear, image))
856}
857
858/// #1026 item-2 — one collapsed slot's TRUE dictionary-level reconstruction
859/// change `δ_k[i,j] = a_k·(γ_k(t_i) − line_k(t_i))` over ALL globally-aligned rows
860/// (the caller presents `coords`/`assign`/`decoded`/`image` on the same `n` rows,
861/// so these δ vectors ARE cross-atom aligned and their inner products are the
862/// genuine cross terms). `line_k` is evaluated at the image's own coordinate
863/// (collapse-rescue slots evaluate at their fresh per-row codes via
864/// [`AtomLinearImage::coordinate_for_row`], exactly as the collapsed reconstruction
865/// does). Collapsing atom `k` shifts the full reconstruction residual by `+δ_k`.
866fn slot_delta(
867    coords: &Array1<f64>,
868    assign: &Array1<f64>,
869    decoded: &Array2<f64>,
870    image: &AtomLinearImage,
871) -> Array2<f64> {
872    let n = decoded.nrows();
873    let p = decoded.ncols();
874    let mut d = Array2::<f64>::zeros((n, p));
875    for i in 0..n {
876        let a = assign[i];
877        let coord = image.coordinate_for_row(i, coords[i]);
878        let dt = coord - image.t_bar;
879        for j in 0..p {
880            let line = image.b0[j] + dt * image.b1[j];
881            d[[i, j]] = a * (decoded[[i, j]] - line);
882        }
883    }
884    d
885}
886
887/// #1026 item-2 — the ALL-CURVED global reconstruction residual
888/// `R0 = target − Σ_all a·γ` recovered from any single atom's leave-this-atom-out
889/// residual: `R0 = y_resp_k − a_k·γ_k = target_resid − a_k·decoded`. Identical for
890/// every atom (each `target_resid` adds back exactly that atom's own contribution),
891/// so the caller computes it once from the first slot.
892fn slot_r0(assign: &Array1<f64>, decoded: &Array2<f64>, target_resid: &Array2<f64>) -> Array2<f64> {
893    let n = decoded.nrows();
894    let p = decoded.ncols();
895    let mut r0 = Array2::<f64>::zeros((n, p));
896    for i in 0..n {
897        let a = assign[i];
898        for j in 0..p {
899            r0[[i, j]] = target_resid[[i, j]] - a * decoded[[i, j]];
900        }
901    }
902    r0
903}
904
905/// #1026 item-2 — GLOBAL cross-term collapse guard. Given the collapsed slots'
906/// dictionary-level reconstruction-change vectors `δ_k` (globally row-aligned,
907/// `n × p`) and the all-curved global residual `R0 = target − Σ_all a·γ`, decide
908/// which collapses must revert to curved so the TRUE global reconstruction SSR
909/// increase
910///
911/// ```text
912/// ΔRSS(S) = ‖R0 + Σ_{k∈S} δ_k‖² − ‖R0‖²
913///         = 2⟨R0, Σ_{k∈S} δ_k⟩ + ‖Σ_{k∈S} δ_k‖²
914///         = Σ_{k∈S} Δ_k + 2 Σ_{j<k∈S} ⟨δ_j, δ_k⟩
915/// ```
916///
917/// stays within `global_tol`. Unlike the per-atom / summed-loss guard this
918/// INCLUDES the cross terms `2 Σ_{j<k} ⟨δ_j, δ_k⟩` between simultaneously-collapsed
919/// atoms (correlated collapse errors), which the aggregate `Σ max(Δ_k, 0)` bound is
920/// blind to. `forced` are the always-collapsed δ (euclidean / collapse-rescue slots
921/// with no curved alternative): they stay in the reconstruction but cannot be
922/// rolled back. `eligible` are `(slot, curved_evidence_margin, δ)` for
923/// rollback-eligible collapses. When `ΔRSS` over ALL collapses exceeds tolerance,
924/// the least-justified eligible collapses (largest margin — the most marginal
925/// linear win) are reverted one at a time, RECOMPUTING the true global increase
926/// (cross terms and all) after each revert, until within tolerance or none remain.
927/// Returns the slot indices to revert to curved.
928fn global_collapse_rollback(
929    r0: ArrayView2<'_, f64>,
930    eligible: &[(usize, f64, &Array2<f64>)],
931    forced: &[&Array2<f64>],
932    global_tol: f64,
933) -> Vec<usize> {
934    let (n, p) = r0.dim();
935    // True global SSR increase for the active δ set: 2⟨R0, Σδ⟩ + ‖Σδ‖².
936    let increase = |active: &[&Array2<f64>]| -> f64 {
937        let mut cross = 0.0_f64;
938        let mut self_sq = 0.0_f64;
939        for i in 0..n {
940            for j in 0..p {
941                let mut s = 0.0_f64;
942                for d in active {
943                    s += d[[i, j]];
944                }
945                cross += r0[[i, j]] * s;
946                self_sq += s * s;
947            }
948        }
949        2.0 * cross + self_sq
950    };
951    let mut kept = vec![true; eligible.len()];
952    let build_active = |kept: &[bool]| -> Vec<&Array2<f64>> {
953        let mut v: Vec<&Array2<f64>> = forced.to_vec();
954        for (idx, &(_, _, d)) in eligible.iter().enumerate() {
955            if kept[idx] {
956                v.push(d);
957            }
958        }
959        v
960    };
961    if increase(&build_active(&kept)) <= global_tol {
962        return Vec::new();
963    }
964    // Revert least-justified first: largest curved_evidence_margin (the most
965    // marginal linear win is the cheapest to give back to curved).
966    let mut order: Vec<usize> = (0..eligible.len()).collect();
967    order.sort_by(|&a, &b| {
968        eligible[b]
969            .1
970            .partial_cmp(&eligible[a].1)
971            .unwrap_or(std::cmp::Ordering::Equal)
972    });
973    let mut reverted = Vec::new();
974    for idx in order {
975        kept[idx] = false;
976        reverted.push(eligible[idx].0);
977        if increase(&build_active(&kept)) <= global_tol {
978            break;
979        }
980    }
981    reverted
982}
983
984/// Assemble the per-atom candidate slots for [`select_hybrid_split`] from the
985/// fitted `d = 1` atoms, run the adjudication, and return the report.
986///
987/// `atoms` are the fitted dictionary atoms; `coords_for` yields the on-atom
988/// coordinate column for a slot, `assign_for` the per-row assignment mass `a_k`,
989/// `decoded_for` the fitted decoded image rows `γ_k`, and `target_resid_for` the
990/// atom's leave-this-atom-out response residual `y_resp` (the data both
991/// candidates are scored against, #1202). `manifold_for` yields the atom's chart
992/// manifold (a flat / Euclidean chart can present only the linear candidate,
993/// enforced inside the selector).
994///
995/// Returns `None` (no report) when no atom is eligible — there is nothing to
996/// adjudicate.
997pub fn build_hybrid_split_report<'a, C, W, D, R, M, E>(
998    atoms: &'a [SaeManifoldAtom],
999    eligible_d1: impl Iterator<Item = usize>,
1000    mut coords_for: C,
1001    mut assign_for: W,
1002    mut decoded_for: D,
1003    mut target_resid_for: R,
1004    mut manifold_for: M,
1005    mut delta_ev_for: E,
1006    // #1026 — the full target's total (column-centered) variance `SST_full`, the
1007    // fixed denominator of the EV-preservation gate. `≤ 0` / non-finite disables
1008    // the gate (a degenerate, varianceless target has no EV to preserve).
1009    total_centered_variance: f64,
1010) -> Result<Option<SaeHybridSplitReport>, String>
1011where
1012    C: FnMut(usize) -> Array1<f64>,
1013    W: FnMut(usize) -> Array1<f64>,
1014    D: FnMut(usize) -> Array2<f64>,
1015    R: FnMut(usize) -> Array2<f64>,
1016    M: FnMut(usize) -> LatentManifold,
1017    // The atom's held-out LOAO `ΔEV_k`, keyed by atom index. `None` when LOAO EV
1018    // is unavailable (e.g. the caller has no target to measure against).
1019    E: FnMut(usize) -> Option<f64>,
1020{
1021    let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
1022    let mut names: Vec<String> = Vec::new();
1023    let mut manifolds: Vec<LatentManifold> = Vec::new();
1024    // Per-slot fitted straight sub-model `(atom_idx, t̄, b₀, b₁)`, surfaced onto
1025    // the verdict iff the slot selects LINEAR so the collapsed reconstruction can
1026    // substitute it for the curved decoded image.
1027    let mut linear_images: Vec<AtomLinearImage> = Vec::new();
1028    // Per-slot `(Θ, ΔEV)` — the #1026 frontier point — carried onto each verdict
1029    // so the geometry/EV pairing is structured report data, not a log line.
1030    let mut turnings: Vec<Option<f64>> = Vec::new();
1031    let mut delta_evs: Vec<Option<f64>> = Vec::new();
1032    // #1026 item-2 — per-slot collapse loss `Δ_k = linear_rss − curved_rss` for the
1033    // GLOBAL EV-preservation guard below. `Some(Δ_k)` for a curveable slot that
1034    // retained a curved alternative (so a chosen collapse there can be rolled back);
1035    // `None` for euclidean slots (no curved option) and collapse-rescue slots (the
1036    // curve was already degenerate — collapsing recovers EV rather than losing it).
1037    let mut collapse_loss: Vec<Option<f64>> = Vec::new();
1038    // #1026 item-2 — per-slot dictionary-level reconstruction-change vector δ_k
1039    // (globally row-aligned n×p), and the all-curved global residual R0 (computed
1040    // once from the first slot). These feed the GLOBAL cross-term collapse guard,
1041    // which reconstructs the full dictionary with the selected collapses applied
1042    // and measures the TRUE global EV degradation (cross terms and all).
1043    let mut deltas: Vec<Array2<f64>> = Vec::new();
1044    let mut r0: Option<Array2<f64>> = None;
1045
1046    for atom_idx in eligible_d1 {
1047        let atom = &atoms[atom_idx];
1048        let coords = coords_for(atom_idx);
1049        let assign = assign_for(atom_idx);
1050        let decoded = decoded_for(atom_idx);
1051        let target_resid = target_resid_for(atom_idx);
1052        // Curved parameter price = the decoder's `M · p` coefficients.
1053        let curved_num_params = atom.decoder_coefficients.len();
1054        let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
1055            d1_atom_fitted_turning(
1056                evaluator.as_ref(),
1057                atom.decoder_coefficients.view(),
1058                coords.view(),
1059            )
1060            .ok()
1061            .flatten()
1062        });
1063        // Evaluate the curved design `Φ(t)` on this atom's assigned rows so the
1064        // curved arm's Laplace complexity is the real weighted-design Gram
1065        // log-determinant rather than a parameter-count proxy (#1223). A `d = 1`
1066        // atom's coordinate column is presented as an `n × 1` design input. If
1067        // the evaluator is absent or refuses, `curved_phi` stays `None` and
1068        // `build_atom_candidates` falls back to the proxy.
1069        let coords_col = coords
1070            .view()
1071            .into_shape_with_order((coords.len(), 1))
1072            .ok()
1073            .map(|v| v.to_owned());
1074        let curved_phi = match (atom.basis_evaluator.as_ref(), coords_col.as_ref()) {
1075            (Some(evaluator), Some(col)) => {
1076                evaluator.evaluate(col.view()).ok().map(|(phi, _jet)| phi)
1077            }
1078            _ => None,
1079        };
1080        // A flat (Euclidean) chart cannot honestly present a curved candidate;
1081        // the selector drops it. Present both for curveable charts.
1082        let manifold = manifold_for(atom_idx);
1083        match build_atom_candidates(
1084            coords.view(),
1085            assign.view(),
1086            decoded.view(),
1087            target_resid.view(),
1088            curved_num_params,
1089            curved_phi.as_ref().map(|phi| phi.view()),
1090            fitted_turning,
1091        ) {
1092            Some((linear, curved, (t_bar, b0, b1))) => {
1093                // #1026 PER-ATOM EV-PRESERVATION gate. Collapsing this slot raises
1094                // the full reconstruction SSR by `linear_rss − curved_rss`; if that
1095                // is more than `SAE_HYBRID_COLLAPSE_MAX_EV_LOSS` of the fixed total
1096                // target variance the collapse would DROP this ONE atom's EV
1097                // materially, so veto it by presenting only the curved candidate (the
1098                // selector must keep curved). A lossless / improving collapse (`≤ 0`)
1099                // and a negligible one stay free to collapse — EV-neutral cases (the
1100                // top-k / birth-topology lines) are untouched. Only curveable charts
1101                // are gated; a euclidean chart never had a curved option. NOTE: this
1102                // gate is PER-ATOM only — the accumulation of many small collapses
1103                // and the dictionary-level cross terms are handled by the aggregate
1104                // guard after selection.
1105                let loss = collapse_ssr_increase(
1106                    coords.view(),
1107                    assign.view(),
1108                    decoded.view(),
1109                    target_resid.view(),
1110                    t_bar,
1111                    &b0,
1112                    &b1,
1113                );
1114                let collapse_loses_ev = total_centered_variance.is_finite()
1115                    && total_centered_variance > 0.0
1116                    && loss > SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
1117                let euclidean = manifold.is_euclidean();
1118                let slot = if euclidean {
1119                    vec![linear]
1120                } else if collapse_loses_ev {
1121                    vec![curved]
1122                } else {
1123                    vec![linear, curved]
1124                };
1125                // Build the straight image, then its globally-aligned δ_k for the
1126                // GLOBAL cross-term guard, before moving it into the report.
1127                let image = AtomLinearImage {
1128                    atom_idx,
1129                    t_bar,
1130                    b0,
1131                    b1,
1132                    row_codes: None,
1133                    // Ordinary straight image: decoded at the atom's own
1134                    // coordinate, so it carries no residual-projection direction.
1135                    v: None,
1136                };
1137                let delta = slot_delta(&coords, &assign, &decoded, &image);
1138                if r0.is_none() {
1139                    r0 = Some(slot_r0(&assign, &decoded, &target_resid));
1140                }
1141                slots.push(slot);
1142                // A euclidean slot never had a curved alternative, so its collapse
1143                // carries no recoverable EV loss for the global guard; record `None`.
1144                collapse_loss.push(if euclidean { None } else { Some(loss) });
1145                names.push(atom.name.clone());
1146                manifolds.push(manifold);
1147                turnings.push(fitted_turning);
1148                delta_evs.push(delta_ev_for(atom_idx));
1149                deltas.push(delta);
1150                linear_images.push(image);
1151            }
1152            // #1026 collapse rescue: `build_atom_candidates` refused because the
1153            // atom's own coordinate collapsed (`s_tt ≈ 0`) — the rank-1 co-collapse
1154            // fixed point. Recover a FRESH linear image from the residual's top
1155            // direction (fresh per-row codes) and force the LINEAR verdict (a
1156            // single-option slot the selector must take) so the slot reconstructs
1157            // its residual's best linear axis at linear quality instead of the
1158            // collapsed-curve constant. `None` only when the residual itself is
1159            // degenerate — then there is genuinely nothing to recover and we skip.
1160            None => match build_collapse_rescue_linear_image(
1161                atom_idx,
1162                assign.view(),
1163                target_resid.view(),
1164            ) {
1165                Some((linear, image)) => {
1166                    let delta = slot_delta(&coords, &assign, &decoded, &image);
1167                    if r0.is_none() {
1168                        r0 = Some(slot_r0(&assign, &decoded, &target_resid));
1169                    }
1170                    slots.push(vec![linear]);
1171                    // Forced-linear rescue: the curve was degenerate, so there is no
1172                    // curved alternative to roll back to and no recoverable EV loss.
1173                    collapse_loss.push(None);
1174                    names.push(atom.name.clone());
1175                    manifolds.push(manifold);
1176                    turnings.push(fitted_turning);
1177                    delta_evs.push(delta_ev_for(atom_idx));
1178                    deltas.push(delta);
1179                    linear_images.push(image);
1180                }
1181                None => continue,
1182            },
1183        }
1184    }
1185
1186    if slots.is_empty() {
1187        return Ok(None);
1188    }
1189
1190    let mut selection = select_hybrid_split(&slots)?;
1191
1192    // #1026 item-2 — GLOBAL CROSS-TERM EV-preservation guard over the SELECTED
1193    // collapses.
1194    //
1195    // The per-atom gate above bounds each atom's individual EV loss, but the TRUE
1196    // dictionary-level RSS increase from collapsing a SET of atoms is
1197    //   ΔRSS = Σ_k Δ_k + 2 Σ_{j<k} ⟨δ_j, δ_k⟩,
1198    // so two effects escape any per-atom or summed-loss bound: (1) the accumulation
1199    // of many individually-tolerable `Δ_k`, and (2) the CROSS TERMS between
1200    // simultaneously-collapsed atoms. The old aggregate `Σ max(Δ_k, 0)` guard bounded
1201    // (1) but was blind to (2): correlated collapse errors whose per-atom losses each
1202    // sit under tolerance can still push the true global loss over it.
1203    //
1204    // Every slot now carries its exact dictionary-level reconstruction-change vector
1205    // δ_k (globally row-aligned — the caller presents all per-atom arrays on the same
1206    // n rows), so we reconstruct the full dictionary WITH the selected collapse set
1207    // applied and measure the real global increase `ΔRSS` DIRECTLY (cross terms
1208    // captured), reverting the least-justified collapses until the degradation is
1209    // within the same EV tolerance and re-adjudicating so `selection` stays consistent.
1210    if total_centered_variance.is_finite() && total_centered_variance > 0.0 {
1211        if let Some(r0) = r0.as_ref() {
1212            let global_tol = SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
1213            // Partition the SELECTED-collapsed slots: rollback-eligible ones (a curved
1214            // alternative still present and a finite per-atom loss) vs forced ones
1215            // (euclidean / collapse-rescue — no curved fallback, stay collapsed but
1216            // still enter the global reconstruction so their cross terms are counted).
1217            let mut eligible: Vec<(usize, f64, &Array2<f64>)> = Vec::new();
1218            let mut forced: Vec<&Array2<f64>> = Vec::new();
1219            for (i, choice) in selection.atoms.iter().enumerate() {
1220                if !choice.param.is_linear() {
1221                    continue;
1222                }
1223                let has_curved_alt = slots[i].iter().any(|c| !c.param.is_linear());
1224                let loss_finite = collapse_loss[i].map(|l| l.is_finite()).unwrap_or(false);
1225                if has_curved_alt && loss_finite {
1226                    eligible.push((i, choice.curved_evidence_margin, &deltas[i]));
1227                } else {
1228                    forced.push(&deltas[i]);
1229                }
1230            }
1231            let reverted = global_collapse_rollback(r0.view(), &eligible, &forced, global_tol);
1232            if !reverted.is_empty() {
1233                for slot in reverted {
1234                    if let Some(curved) =
1235                        slots[slot].iter().find(|c| !c.param.is_linear()).copied()
1236                    {
1237                        slots[slot] = vec![curved];
1238                    }
1239                }
1240                selection = select_hybrid_split(&slots)?;
1241            }
1242        }
1243    }
1244
1245    let verdicts: Vec<AtomHybridVerdict> = names
1246        .into_iter()
1247        .zip(selection.atoms.iter().copied())
1248        .zip(linear_images.into_iter())
1249        .zip(turnings.into_iter())
1250        .zip(delta_evs.into_iter())
1251        .map(
1252            |((((atom_name, choice), linear_image), fitted_turning), train_loao_delta_ev)| {
1253                let kept_curved = !choice.param.is_linear();
1254                AtomHybridVerdict {
1255                    atom_name,
1256                    choice,
1257                    kept_curved,
1258                    fitted_turning,
1259                    train_loao_delta_ev,
1260                    // Carry the straight sub-model only when the verdict collapses
1261                    // this slot to linear — the curved slots keep their fitted image.
1262                    linear_image: if kept_curved {
1263                        None
1264                    } else {
1265                        Some(linear_image)
1266                    },
1267                }
1268            },
1269        )
1270        .collect();
1271
1272    Ok(Some(SaeHybridSplitReport {
1273        verdicts,
1274        selection,
1275    }))
1276}
1277
1278#[cfg(test)]
1279mod tests {
1280    use super::*;
1281    use std::f64::consts::PI;
1282
1283    /// A straight RESPONSE residual (the atom's data is a line) is explained
1284    /// equally well by both candidates, so the cheaper linear special case wins.
1285    /// With `a_k = 1` the curved decoded image is straight too (Θ = 0), so both
1286    /// the dominance floor and the evidence argmin select linear. This is the
1287    /// common-data nested comparison (#1202): linear is the curved family's
1288    /// `Θ = 0` member, so it cannot lose when a line already explains the data.
1289    #[test]
1290    fn straight_residual_selects_linear() {
1291        let n = 40;
1292        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1293        let assign = Array1::<f64>::ones(n);
1294        // The data the atom must explain is a straight line in ℝ²; the curved
1295        // decoded image equals that same line (a Θ = 0 curved fit).
1296        let mut data = Array2::<f64>::zeros((n, 2));
1297        let mut decoded = Array2::<f64>::zeros((n, 2));
1298        for i in 0..n {
1299            data[[i, 0]] = coords[i];
1300            data[[i, 1]] = 0.6 * coords[i];
1301            decoded[[i, 0]] = coords[i];
1302            decoded[[i, 1]] = 0.6 * coords[i];
1303        }
1304        let (linear, curved, _) = build_atom_candidates(
1305            coords.view(),
1306            assign.view(),
1307            decoded.view(),
1308            data.view(),
1309            // a generous curved parameter price (M·p)
1310            10,
1311            None,
1312            Some(0.0),
1313        )
1314        .expect("straight residual yields a candidate pair");
1315        let choice =
1316            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1317        assert!(
1318            choice.param.is_linear(),
1319            "a straight response residual must keep the linear special case"
1320        );
1321    }
1322
1323    /// A turning RESPONSE residual (the atom's data traces a full circle) is fit
1324    /// well by the curved decoded image (curved_rss ≈ 0) but poorly by any
1325    /// straight line (large linear_rss), so the curved candidate wins the common
1326    /// evidence comparison once its data-fit gain exceeds its extra parameter
1327    /// price (#1202).
1328    #[test]
1329    fn turning_residual_selects_curved_on_evidence() {
1330        let n = 60;
1331        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
1332        let assign = Array1::<f64>::ones(n);
1333        // The data is a full circle; the curved decoded image is that same
1334        // circle (the curved atom reconstructs its assigned residual), so the
1335        // curved candidate has ≈ zero data-fit residual while a straight line
1336        // cannot follow the loop.
1337        let mut data = Array2::<f64>::zeros((n, 2));
1338        let mut decoded = Array2::<f64>::zeros((n, 2));
1339        for i in 0..n {
1340            let theta = 2.0 * PI * coords[i];
1341            data[[i, 0]] = theta.cos();
1342            data[[i, 1]] = theta.sin();
1343            decoded[[i, 0]] = theta.cos();
1344            decoded[[i, 1]] = theta.sin();
1345        }
1346        // The curved atom has 5 parameters (just above the 4 = 2·p linear budget);
1347        // the full-circle linear residual exceeds the extra-parameter overhead, so
1348        // curved wins on evidence.
1349        let (linear, curved, _) = build_atom_candidates(
1350            coords.view(),
1351            assign.view(),
1352            decoded.view(),
1353            data.view(),
1354            5,
1355            None,
1356            Some(2.0 * PI),
1357        )
1358        .expect("turning residual yields a candidate pair");
1359        assert!(
1360            linear.negative_log_evidence > curved.negative_log_evidence,
1361            "the line must misfit the circular residual worse than the curve does \
1362             (linear NLE {} should exceed curved NLE {})",
1363            linear.negative_log_evidence,
1364            curved.negative_log_evidence
1365        );
1366        let choice =
1367            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1368        assert_eq!(
1369            choice.param,
1370            gam_solve::evidence::HybridAtomParam::Curved { latent_dim: 1 },
1371            "a full-circle response residual must keep the curved parameterization"
1372        );
1373        assert!(
1374            choice.curved_evidence_margin > 0.0,
1375            "curved must win a positive evidence margin over the linear secant"
1376        );
1377    }
1378
1379    /// The nested-dominance floor on common data (#1202): when the curved decoded
1380    /// image is a WORSE fit to the response residual than its own best straight
1381    /// projection, linear must win — the curved family cannot be charged extra
1382    /// parameters to fit the residual no better than its `Θ = 0` member. Here the
1383    /// data is a line but the curved image bends away from it, so curved_rss >
1384    /// linear_rss and the cheaper, better-fitting line is selected.
1385    #[test]
1386    fn linear_beats_curved_when_curve_misfits_residual() {
1387        let n = 50;
1388        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
1389        let assign = Array1::<f64>::ones(n);
1390        // Data is a straight line; the curved decoded image is a parabola that
1391        // departs from it, so a straight line fits the data strictly better.
1392        let mut data = Array2::<f64>::zeros((n, 2));
1393        let mut decoded = Array2::<f64>::zeros((n, 2));
1394        for i in 0..n {
1395            let t = coords[i];
1396            data[[i, 0]] = t;
1397            data[[i, 1]] = 0.5 * t;
1398            decoded[[i, 0]] = t;
1399            decoded[[i, 1]] = t * t; // bends away from the linear data
1400        }
1401        let (linear, curved, _) = build_atom_candidates(
1402            coords.view(),
1403            assign.view(),
1404            decoded.view(),
1405            data.view(),
1406            // a real curved Θ above the floor so the dominance floor does not fire
1407            6,
1408            None,
1409            Some(1.0),
1410        )
1411        .expect("candidate pair");
1412        let choice =
1413            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
1414        assert!(
1415            choice.param.is_linear(),
1416            "a curved image that fits the data worse than its own line must yield \
1417             to the linear special case on common-data evidence (#1202)"
1418        );
1419    }
1420
1421    /// The LINEAR candidate's Laplace logdet is the genuine weighted-design Gram
1422    /// determinant `p·(log w_sum + log s_tt)` with `w_sum = Σ a_k²`, `s_tt =
1423    /// Σ a_k²(t − t̄)²` — it INCLUDES the coordinate-spread term `log(s_tt)`
1424    /// (#1203). Verify both contributions are present by reading the logdet off a
1425    /// candidate whose linear residual is exactly zero (response residual = the
1426    /// fitted line), so `NLE_linear = ½·logdet`. Doubling the coordinate spread
1427    /// (at fixed assignment mass) scales `s_tt` by 4 → logdet += `p·log(4)`;
1428    /// doubling all assignment masses scales BOTH `w_sum` and `s_tt` by 4 (they
1429    /// are quadratic in `a_k`) → logdet += `2p·log(4)`.
1430    #[test]
1431    fn linear_logdet_includes_weighted_coordinate_spread() {
1432        let n = 40;
1433        let p = 2usize;
1434        // Read the logdet back off a candidate with zero linear residual: the
1435        // response residual is exactly `a_k·(line)`, so the WLS line recovers it
1436        // with RSS == 0 and `NLE_linear = ½·logdet`.
1437        let logdet = |coords: &Array1<f64>, assign: &Array1<f64>| -> f64 {
1438            // A straight image; the response residual is the same line scaled by
1439            // the per-row assignment mass `a_k`, so the prediction `a_k·(b₀+dt·b₁)`
1440            // matches it exactly and linear_rss == 0.
1441            let line = |t: f64| -> [f64; 2] { [t, 0.6 * t] };
1442            let mut decoded = Array2::<f64>::zeros((n, p));
1443            let mut data = Array2::<f64>::zeros((n, p));
1444            for i in 0..n {
1445                let l = line(coords[i]);
1446                decoded[[i, 0]] = l[0];
1447                decoded[[i, 1]] = l[1];
1448                data[[i, 0]] = assign[i] * l[0];
1449                data[[i, 1]] = assign[i] * l[1];
1450            }
1451            let (linear, _curved, _) = build_atom_candidates(
1452                coords.view(),
1453                assign.view(),
1454                decoded.view(),
1455                data.view(),
1456                10,
1457                None,
1458                Some(0.0),
1459            )
1460            .expect("straight residual yields a pair");
1461            2.0 * linear.negative_log_evidence // = logdet (linear_rss == 0)
1462        };
1463
1464        let base_coords =
1465            Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1466        let ones = Array1::<f64>::ones(n);
1467
1468        // Doubling the coordinate spread → s_tt ×4, w_sum fixed → logdet += p·log(4).
1469        let wide_coords = base_coords.mapv(|t| 2.0 * t);
1470        let d_spread = logdet(&wide_coords, &ones) - logdet(&base_coords, &ones);
1471        assert!(
1472            (d_spread - (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
1473            "linear logdet must move by p·log(4) when coordinate spread doubles \
1474             (got {d_spread}); the spread term log(s_tt) must be present"
1475        );
1476
1477        // Doubling all assignment masses → w_sum ×4 AND s_tt ×4 (quadratic in a_k)
1478        // → logdet += 2p·log(4).
1479        let twos = Array1::<f64>::from_elem(n, 2.0);
1480        let d_weight = logdet(&base_coords, &twos) - logdet(&base_coords, &ones);
1481        assert!(
1482            (d_weight - 2.0 * (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
1483            "linear logdet must move by 2p·log(4) when all assignment masses double \
1484             (got {d_weight})"
1485        );
1486    }
1487
1488    /// #1223 — the curved arm's Laplace complexity is the REAL weighted-design
1489    /// Gram log-determinant `p·log|ΦᵀWΦ|_+`, not a parameter-count proxy. Build a
1490    /// curved design whose columns are the constant and the centered coordinate
1491    /// (a 2-column basis), so `ΦᵀWΦ = diag(w_sum, s_tt)` exactly matches the
1492    /// linear arm's data Gram, and assert `curved_design_gram_logdet` returns
1493    /// `p·(log w_sum + log s_tt)` — the same determinant the linear arm reports
1494    /// on the same design weight. A proxy `M·log(w_sum)` would instead omit the
1495    /// `log(s_tt)` spread term, so this pins the genuine determinant.
1496    #[test]
1497    fn curved_gram_logdet_is_real_weighted_design_determinant() {
1498        let n = 40;
1499        let p = 3usize;
1500        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1501        let assign = Array1::<f64>::from_iter((0..n).map(|i| 0.5 + 0.01 * (i as f64)));
1502
1503        // Mass-weighted coordinate mean and spread under wᵢ = a_k².
1504        let mut w_sum = 0.0;
1505        let mut t_bar = 0.0;
1506        for i in 0..n {
1507            let w = assign[i] * assign[i];
1508            w_sum += w;
1509            t_bar += w * coords[i];
1510        }
1511        t_bar /= w_sum;
1512        let mut s_tt = 0.0;
1513        for i in 0..n {
1514            let dt = coords[i] - t_bar;
1515            s_tt += assign[i] * assign[i] * dt * dt;
1516        }
1517
1518        // Curved design columns: [1, (t − t̄)]. Its weighted Gram is exactly
1519        // diag(w_sum, s_tt) (the cross term Σ w·(t−t̄) vanishes by construction),
1520        // so log|ΦᵀWΦ| = log(w_sum) + log(s_tt).
1521        let mut phi = Array2::<f64>::zeros((n, 2));
1522        for i in 0..n {
1523            phi[[i, 0]] = 1.0;
1524            phi[[i, 1]] = coords[i] - t_bar;
1525        }
1526        let got = curved_design_gram_logdet(phi.view(), assign.view(), p)
1527            .expect("non-degenerate curved design has a determinant");
1528        let want = (p as f64) * (w_sum.ln() + s_tt.ln());
1529        assert!(
1530            (got - want).abs() < 1e-9,
1531            "curved Gram logdet must be the real p·log|ΦᵀWΦ| = {want}, got {got}"
1532        );
1533
1534        // A rank-deficient design (a duplicated column) drops the null direction:
1535        // its determinant equals that of the single retained constant column,
1536        // p·log(w_sum), NOT a 2-column proxy.
1537        let mut phi_dup = Array2::<f64>::zeros((n, 2));
1538        for i in 0..n {
1539            phi_dup[[i, 0]] = 1.0;
1540            phi_dup[[i, 1]] = 1.0;
1541        }
1542        let got_dup = curved_design_gram_logdet(phi_dup.view(), assign.view(), p)
1543            .expect("rank-1 design still has a positive determinant");
1544        let want_dup = (p as f64) * (2.0 * w_sum).ln();
1545        assert!(
1546            (got_dup - want_dup).abs() < 1e-9,
1547            "rank-deficient curved Gram must report only its positive direction \
1548             (p·log(2·w_sum) = {want_dup}), got {got_dup}"
1549        );
1550    }
1551
1552    /// #1051 NESTED MIN — the curved arm re-fit on the residual match-or-beats the
1553    /// best straight line: `curved_refit_rss ≤ best_line_rss` up to solver tolerance
1554    /// on a basis whose span contains the straight lane (here `Φ = [1, t, t²]`). A
1555    /// genuinely curved (quadratic) signal is STRICTLY preferred by the curved arm,
1556    /// while an exactly-straight signal ties near zero (so the cheaper linear lane
1557    /// wins downstream). This is the property the realized-curve heuristic could not
1558    /// establish: comparing a possibly-collapsed realized curve against min-over-lines
1559    /// did NOT guarantee `curved ≤ linear`; re-fitting the curved decoder does.
1560    #[test]
1561    fn refit_curved_rss_matches_or_beats_best_line_nested() {
1562        let n = 40usize;
1563        let p = 2usize;
1564        let coords =
1565            Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
1566        let assign = Array1::<f64>::ones(n);
1567        // Φ = [1, t, t²]: its column span contains the straight lane [1, t], so the
1568        // decoder-only curved refit is a proper superset of the line fit.
1569        let mut phi = Array2::<f64>::zeros((n, 3));
1570        for i in 0..n {
1571            phi[[i, 0]] = 1.0;
1572            phi[[i, 1]] = coords[i];
1573            phi[[i, 2]] = coords[i] * coords[i];
1574        }
1575        // Best mass-weighted line RSS on `y` (assign = 1): fit design [1, (t − t̄)].
1576        let t_bar = coords.iter().sum::<f64>() / n as f64;
1577        let best_line_rss = |y: &Array2<f64>| -> f64 {
1578            let mut design = Array2::<f64>::zeros((n, 2));
1579            for i in 0..n {
1580                design[[i, 0]] = 1.0;
1581                design[[i, 1]] = coords[i] - t_bar;
1582            }
1583            let b = solve_design_least_squares(design.view(), y.view()).unwrap();
1584            let pred = design.dot(&b);
1585            let mut rss = 0.0_f64;
1586            for i in 0..n {
1587                for j in 0..p {
1588                    let r = y[[i, j]] - pred[[i, j]];
1589                    rss += r * r;
1590                }
1591            }
1592            rss
1593        };
1594        // (a) exactly straight, (b) quadratic curve, (c) noisy line.
1595        let mut y_line = Array2::<f64>::zeros((n, p));
1596        let mut y_curve = Array2::<f64>::zeros((n, p));
1597        let mut y_noisy = Array2::<f64>::zeros((n, p));
1598        for i in 0..n {
1599            let t = coords[i];
1600            y_line[[i, 0]] = 0.4 + 0.6 * t;
1601            y_line[[i, 1]] = -0.2 + 1.1 * t;
1602            y_curve[[i, 0]] = t * t;
1603            y_curve[[i, 1]] = 0.5 - t * t;
1604            y_noisy[[i, 0]] = 0.3 + 0.7 * t + 0.05 * (3.0 * t).sin();
1605            y_noisy[[i, 1]] = 0.9 * t;
1606        }
1607        for y in [&y_line, &y_curve, &y_noisy] {
1608            let curved = curved_refit_rss(phi.view(), assign.view(), y.view())
1609                .expect("non-degenerate refit");
1610            let line = best_line_rss(y);
1611            assert!(
1612                curved <= line + 1e-9 * (1.0 + line),
1613                "nested dominance: refit-curved RSS {curved} must be ≤ best-line RSS {line}"
1614            );
1615        }
1616        // A genuinely curved (quadratic) signal is STRICTLY preferred by the curve.
1617        let curved_c = curved_refit_rss(phi.view(), assign.view(), y_curve.view()).unwrap();
1618        let line_c = best_line_rss(&y_curve);
1619        assert!(
1620            curved_c < 0.5 * line_c,
1621            "a quadratic signal must be far better fit by the curve ({curved_c}) than \
1622             by the best line ({line_c})"
1623        );
1624        // A straight signal ties near zero — collapses to the cheaper linear lane.
1625        let curved_l = curved_refit_rss(phi.view(), assign.view(), y_line.view()).unwrap();
1626        assert!(
1627            curved_l < 1e-18 && best_line_rss(&y_line) < 1e-18,
1628            "an exactly-straight signal ties the two arms near zero (curved {curved_l})"
1629        );
1630    }
1631
1632    /// #1026 item-2 GLOBAL CROSS-TERM guard: two collapses with CORRELATED
1633    /// (parallel) reconstruction-change errors whose per-atom losses each sit under
1634    /// tolerance — and whose SUM `Σ Δ_k` is also under tolerance (so the OLD
1635    /// aggregate `Σ max(Δ_k,0)` guard would ACCEPT) — but whose cross term
1636    /// `2⟨δ_1,δ_2⟩` pushes the TRUE global loss over tolerance. The global guard must
1637    /// roll back the least-justified collapse. An ORTHOGONAL control (disjoint
1638    /// support) with the same per-atom losses is accepted, isolating the cross term.
1639    #[test]
1640    fn global_guard_rejects_correlated_collapses_the_aggregate_would_accept() {
1641        let n = 4usize;
1642        let p = 1usize;
1643        // R0 = 0 ⇒ Δ_k = ‖δ_k‖² and ΔRSS(S) = ‖Σ_S δ_k‖² exactly.
1644        let r0 = Array2::<f64>::zeros((n, p));
1645        // Parallel δ_1 = δ_2 = 1 on all rows: ‖δ_k‖² = 4 each, Σ Δ_k = 8,
1646        // ΔRSS_global = ‖δ_1+δ_2‖² = 16.
1647        let mut d1 = Array2::<f64>::zeros((n, p));
1648        let mut d2 = Array2::<f64>::zeros((n, p));
1649        for i in 0..n {
1650            d1[[i, 0]] = 1.0;
1651            d2[[i, 0]] = 1.0;
1652        }
1653        // tol = 10: each per-atom loss 4 ≤ 10, Σ Δ_k = 8 ≤ 10 (aggregate accepts),
1654        // but global 16 > 10 (cross term rejects).
1655        let global_tol = 10.0_f64;
1656        let eligible = vec![(0usize, 0.1_f64, &d1), (1usize, 0.2_f64, &d2)];
1657        let forced: Vec<&Array2<f64>> = Vec::new();
1658        let reverted = global_collapse_rollback(r0.view(), &eligible, &forced, global_tol);
1659        assert_eq!(
1660            reverted,
1661            vec![1usize],
1662            "the global cross-term guard must roll back the least-justified collapse \
1663             (largest margin = slot 1) that the summed-loss aggregate (8 ≤ 10) accepts"
1664        );
1665
1666        // ORTHOGONAL control: same per-atom losses (δ on disjoint rows) ⇒ cross term
1667        // 0 ⇒ ΔRSS_global = 8 ≤ 10 ⇒ no rollback. Isolates the cross term as the cause.
1668        let mut o1 = Array2::<f64>::zeros((n, p));
1669        let mut o2 = Array2::<f64>::zeros((n, p));
1670        o1[[0, 0]] = 2.0; // ‖o1‖² = 4
1671        o2[[2, 0]] = 2.0; // ‖o2‖² = 4, disjoint support
1672        let eligible_o = vec![(0usize, 0.1_f64, &o1), (1usize, 0.2_f64, &o2)];
1673        let reverted_o = global_collapse_rollback(r0.view(), &eligible_o, &forced, global_tol);
1674        assert!(
1675            reverted_o.is_empty(),
1676            "uncorrelated collapses (cross term 0, global loss 8 ≤ 10) must be accepted"
1677        );
1678    }
1679
1680    /// A degenerate (single-point-mass) coordinate has no slope direction and is
1681    /// refused rather than adjudicated on a fabricated deviance.
1682    #[test]
1683    fn degenerate_coordinate_is_refused() {
1684        let n = 5;
1685        let coords = Array1::<f64>::from_elem(n, 0.5); // no spread
1686        let assign = Array1::<f64>::ones(n);
1687        let decoded = Array2::<f64>::zeros((n, 2));
1688        let data = Array2::<f64>::zeros((n, 2));
1689        assert!(
1690            build_atom_candidates(
1691                coords.view(),
1692                assign.view(),
1693                decoded.view(),
1694                data.view(),
1695                6,
1696                None,
1697                Some(0.0)
1698            )
1699            .is_none(),
1700            "a degenerate coordinate span must be refused"
1701        );
1702    }
1703}