Skip to main content

gam_sae/
hybrid_split.rs

1//! #1026 — load-bearing curved-vs-linear hybrid split for the fitted SAE
2//! dictionary.
3//!
4//! The selection machinery ([`gam_solve::evidence::select_hybrid_atom`],
5//! [`gam_solve::evidence::select_hybrid_split`]) and the per-atom
6//! integration helper
7//! ([`crate::assignment::select_hybrid_atom_parameterization`]) are
8//! correct and tested, but until now were called nowhere in the fitter: the
9//! post-fit pass only *logged* each `d = 1` atom's fitted turning `Θ`. This
10//! module makes the split LOAD-BEARING by building, per fitted `d = 1` atom, the
11//! two already-realized candidates and adjudicating them by the common
12//! evidence criterion.
13//!
14//! ## The common-evidence comparison on the data (#1202)
15//!
16//! Both candidates are scored against the SAME data: the portion of the
17//! response matrix the atom is responsible for reconstructing, namely its
18//! **leave-this-atom-out residual**
19//!
20//!     y_resp[i] = target[i] − ( Σ_j a[i,j]·γ_j(t_{ij}) − a[i,k]·γ_k(t_{ik}) )
21//!               = target[i] − without_k[i],
22//!
23//! the response with every OTHER atom's contribution subtracted. Over the rows
24//! assigned to atom `k` (assignment mass `a[i,k] = a_k`), the two candidates
25//! predict that residual:
26//!
27//!   * the CURVED candidate is scored at its CONSTRAINED MINIMUM over the atom's
28//!     decoder coefficients — re-fit on `y_resp` as
29//!     `min_B ‖y_resp − a_k·(Φ(t)·B)‖²` ([`curved_refit_rss`]) — so its data-fit
30//!     deviance is the smallest weighted RSS the curved family can attain on this
31//!     residual at the realized codes, not the possibly-collapsed realized curve.
32//!   * the LINEAR candidate predicts `a_k · (b₀ + (t − t̄)·b₁)`, the best
33//!     weighted least-squares straight line fit to `y_resp` (design column
34//!     scaled by the same assignment mass `a_k`), so its data-fit deviance is the
35//!     weighted RSS of the best line against the SAME residual.
36//!
37//! ## A genuine NESTED min-vs-min comparison (#1051)
38//!
39//! Both arms are now at their constrained minimum on the SAME leave-one-atom-out
40//! residual: the linear arm is the closed-form min-over-lines, and the curved arm
41//! is the min-over-decoders refit ([`curved_refit_rss`]). The linear special case
42//! `a_k·(b₀ + (t−t̄)·b₁)` is a MEMBER of the curved family whenever the straight
//! lane `[1, (t−t̄)]` lies in the column span of the curved basis `Φ` — exactly
//! for the interval / line-segment charts (whose basis carries the constant and
//! linear terms), and only approximately, to within the basis's expressiveness,
//! for the periodic charts. So after the refit `curved_rss ≤ linear_rss` holds up
//! to the least-squares solver tolerance (and, for periodic charts, that basis
//! approximation): the
47//! curved family CANNOT do worse than its own `Θ = 0` sub-model. That is the
48//! nested-dominance property restored here — "curved match-or-beats linear" is now
49//! a floor established by re-optimizing the curved arm, not merely asserted.
50//!
51//! The broken (#1051) euclidean / multi-atom OUTER continuation is deliberately
52//! NOT re-entered: the direct per-atom `d = 1` decoder-only refit at the realized
53//! codes is sufficient to score the curved arm at its constrained minimum. When
//! `Φ` is unavailable (no evaluator) or the refit solve is degenerate, the arm
//! falls back to the already-realized curve's RSS — an honest degradation (the
//! nested floor `curved_rss ≤ linear_rss` is then no longer guaranteed), never a
//! fabricated value. The argmin then trades the curved arm's (minimized)
57//! data fit against its larger parameter / complexity price, so a genuinely
58//! curved signal is preferred while a straight-line signal ties and collapses to
59//! the cheaper linear lane.
60//!
61//! This replaces an earlier diagnostic in which both candidates targeted the
62//! atom's already-fitted decoded image `γ_k(t)` (giving the curved arm a free
63//! zero residual against itself, #1202) and its successor which scored the
64//! already-REALIZED curved contribution (a post-fit heuristic that did not
65//! establish nested dominance); the curved arm is now re-fit to its minimum.
66
67use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
68
69use gam_linalg::faer_ndarray::FaerEigh;
70use gam_solve::evidence::{
71    HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
72};
73use gam_terms::latent::LatentManifold;
74use crate::chart_canonicalization::d1_atom_fitted_turning;
75use crate::manifold::{SaeManifoldAtom, solve_design_least_squares};
76
/// Reduced Laplace negative-log-evidence of one per-atom Gaussian reconstruction
/// sub-model: the data-fit term plus half the Hessian log-determinant,
/// `residual_objective + ½ log|H|`.
///
/// This is what [`gam_solve::evidence::laplace_evidence`] collapses to for this
/// comparison: there is no smoothing-penalty log-determinant, and the design is
/// treated as full rank (no null space). It is computed inline instead of going
/// through `EvidenceLogDetSource` because each candidate's Hessian log-determinant
/// is already a closed-form scalar moment of its own design. There is no factor
/// cache or HVP callback to assemble.
///
/// SCALE CAVEAT: this is a FIXED-DISPERSION Laplace / penalized criterion, not a
/// profiled REML marginal likelihood. It assumes unit dispersion (`σ² = 1`) after
/// preprocessing. The residual objective is the bare `½ RSS`, with no
/// `RSS/(2σ²)` rescaling and no log-determinant correction from profiling out
/// σ². The criterion therefore DEPENDS on the response scale.
///
/// Rescaling `y → s·y` multiplies `residual_objective` (`RSS`) by `s²` but leaves
/// the `½ log|H|` complexity term untouched, so the curved-vs-linear trade-off is
/// not scale-invariant. Callers must keep the targets on a consistent
/// (preprocessed, roughly unit-scale) footing for the comparison to mean what it
/// says.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    let complexity_price = 0.5 * log_det_h;
    residual_objective + complexity_price
}
97
98/// Rank-aware `log|ΦᵀWΦ|_+` of the curved atom's weighted design Gram over its
99/// `M` decoder basis columns, with per-row weight `wᵢ = a_k²` (the same
100/// assignment-mass design weight the linear arm uses), summed over the
101/// eigenvalues above a relative spectral floor (#1223). This is the genuine
102/// weighted-design determinant the linear arm already reports — `log|XᵀWX|` —
103/// assembled for the curved basis so the two arms' Laplace complexity prices are
104/// computed on the SAME footing instead of pricing the curved arm with a
105/// parameter-count proxy `M·log(Σw)`.
106///
107/// Mirrors the linear arm exactly in what it does NOT include: no smoothing-
108/// penalty `λS` normalizer (the linear arm's Gram is the bare data Gram
109/// `diag(w_sum, s_tt)` too), so the comparison stays symmetric. The Gram is the
110/// design's outer Gram over its basis columns; it is identical across the `p`
111/// output channels (every channel shares the design `Φ`), so the per-channel
112/// `log|G|_+` is multiplied by `p` — matching the linear arm's `p·(…)` form.
113///
114/// `phi` is the curved design `Φ(t)` evaluated on the atom's assigned rows
115/// (`n × M`); `assign` is the per-row assignment mass `a_k` (NOT squared).
116/// Returns `None` when `Φ` is missing rows, the Gram is non-finite, or it has no
117/// positive eigenvalues (a fully rank-deficient design carries no determinant);
118/// the caller then falls back to the parameter-count proxy rather than fabricate
119/// a determinant.
120fn curved_design_gram_logdet(
121    phi: ArrayView2<'_, f64>,
122    assign: ArrayView1<'_, f64>,
123    p: usize,
124) -> Option<f64> {
125    let n = phi.nrows();
126    let m = phi.ncols();
127    if m == 0 || assign.len() != n || n == 0 {
128        return None;
129    }
130    // G = Φᵀ diag(a²) Φ  (M×M, symmetric PSD).
131    let mut gram = Array2::<f64>::zeros((m, m));
132    for i in 0..n {
133        let w = assign[i] * assign[i];
134        if !(w.is_finite() && w >= 0.0) {
135            return None;
136        }
137        if w == 0.0 {
138            continue;
139        }
140        let row = phi.row(i);
141        for a in 0..m {
142            let wa = w * row[a];
143            for b in a..m {
144                gram[[a, b]] += wa * row[b];
145            }
146        }
147    }
148    // Symmetrize the lower triangle (we only filled the upper).
149    for a in 0..m {
150        for b in 0..a {
151            gram[[a, b]] = gram[[b, a]];
152        }
153    }
154    if gram.iter().any(|v| !v.is_finite()) {
155        return None;
156    }
157    let (vals, _vecs) = gram.eigh(faer::Side::Lower).ok()?;
158    // Rank-aware log-determinant: sum log of eigenvalues above a relative floor
159    // tied to the largest eigenvalue, dropping the numerically-null directions
160    // (the curved design's null space, analogous to the linear arm's full-rank
161    // 2-D Gram). A design with no positive eigenvalue carries no determinant.
162    let lambda_max = vals.iter().cloned().fold(0.0_f64, f64::max);
163    if !(lambda_max > 0.0 && lambda_max.is_finite()) {
164        return None;
165    }
166    let floor = lambda_max * 1e-12;
167    let mut log_det = 0.0_f64;
168    let mut rank = 0usize;
169    for &lambda in vals.iter() {
170        if lambda > floor {
171            log_det += lambda.ln();
172            rank += 1;
173        }
174    }
175    if rank == 0 || !log_det.is_finite() {
176        return None;
177    }
178    // The design Gram is shared across the p output channels.
179    Some((p as f64) * log_det)
180}
181
182/// The fitted straight sub-model `γ̃(t) = b₀ + (t − t̄)·b₁` of one `d = 1` atom:
183/// the exact assignment-mass-weighted least-squares line fit to the atom's
184/// leave-this-atom-out RESPONSE residual `y_resp` over its assigned rows (the
185/// curved family's nested `Θ = 0` sub-model on common data, #1202). Carried on a
186/// verdict that selects LINEAR so the collapsed reconstruction can replace the
187/// curved decoded row with this straight image at any coordinate WITHOUT
188/// re-entering the (broken, #1051) outer fit — the coefficients are already
189/// realized inside the adjudication.
190#[derive(Clone, Debug)]
191pub struct AtomLinearImage {
192    /// The atom's slot index in the dictionary (so the collapsed assembly knows
193    /// which atom's decoded row to substitute).
194    pub atom_idx: usize,
195    /// The mass-weighted coordinate mean `t̄` the line is centered on.
196    pub t_bar: f64,
197    /// Per-output-channel centered intercept `b₀ = γ̄` at `t̄` (length `p`).
198    pub b0: Array1<f64>,
199    /// Per-output-channel slope `b₁` (length `p`).
200    pub b1: Array1<f64>,
201    /// #1026 collapse-rescue per-row coordinates. `None` for the ordinary path:
202    /// the line is evaluated at the atom's OWN realized coordinate `t`. `Some(u)`
203    /// only when the atom's circle codes had collapsed to a single point
204    /// (`s_tt ≈ 0`) so its own coordinate carries no spread — then the line is fit
205    /// against, and reconstruct evaluates it at, these FRESH per-row codes `uᵢ`
206    /// (the projection of the leave-this-atom-out residual onto its top
207    /// mass-weighted output direction). This is what lets a circle atom that the
208    /// joint fit drove into the degenerate "chord-through-the-arc" fixed point
209    /// still reconstruct its residual's best linear direction — recovering the
210    /// linear-tail reach the hybrid-split was designed to deliver — instead of a
211    /// constant (its collapsed curve), which is the real-OLMo rank-1 co-collapse
212    /// (held-out EV ≈ 0.13 vs the linear ceiling ≈ 0.74). Length `n` (one per
213    /// reconstructed row); unassigned rows are gated to zero by `a_k` anyway.
214    ///
215    /// TRAIN-ONLY CAVEAT (#1777): these are the TRAIN rows' codes. They are only
216    /// meaningful for the exact rows the split was fit on; a held-out row has no
217    /// entry here and used to fall back to the atom's own (collapsed) coordinate
218    /// `own_t` — a DIFFERENT, degraded model out of sample. Prefer [`Self::v`]:
219    /// projecting a held-out row's leave-this-atom-out residual onto `v` recovers
220    /// that row's coordinate by the SAME math the train codes were built with, so
221    /// train and OOS use one model. `row_codes` is retained for back-compat and as
222    /// the exact cached train projection.
223    pub row_codes: Option<Array1<f64>>,
224    /// #1777 collapse-rescue projection DIRECTION `v` (length `p`, unit norm), the
225    /// top mass-weighted output direction of the atom's leave-this-atom-out
226    /// residual. `Some` exactly when this is a collapse-rescued image (paired with
227    /// `row_codes`); `None` for the ordinary straight-image path (which decodes at
228    /// the atom's own coordinate). This is the SERIALIZABLE quantity the FFI must
229    /// persist so an OOS term can recompute any row's coordinate as
230    /// `uᵢ = ⟨y_i − Σ_{j≠k} f_j(x_i), v⟩` — identical to the train code
231    /// `row_codes[i]` on a train row, and the correct held-out coordinate on an OOS
232    /// row (see [`Self::coordinate_from_residual`]). Length must equal `b0`/`b1`.
233    pub v: Option<Array1<f64>>,
234}
235
236impl AtomLinearImage {
237    /// Evaluate the straight sub-model `b₀ + (t − t̄)·b₁` into `out` (length `p`).
238    pub fn fill_row(&self, t: f64, out: &mut [f64]) {
239        let dt = t - self.t_bar;
240        for (j, slot) in out.iter_mut().enumerate() {
241            *slot = self.b0[j] + dt * self.b1[j];
242        }
243    }
244
245    /// The coordinate at which row `row` should evaluate this image: the
246    /// collapse-rescue fresh code `uᵢ` when present (#1026), else the atom's own
247    /// realized coordinate `own_t` passed by the caller.
248    ///
249    /// TRAIN-ONLY: `row_codes` is indexed by TRAIN row, so this is correct only
250    /// for the rows the split was fit on. Out of sample use
251    /// [`Self::coordinate_from_residual`] (target-aware, model-identical to train).
252    pub fn coordinate_for_row(&self, row: usize, own_t: f64) -> f64 {
253        match &self.row_codes {
254            Some(u) if row < u.len() => u[row],
255            _ => own_t,
256        }
257    }
258
259    /// #1777 — the collapse-rescue coordinate of a row from ITS OWN
260    /// leave-this-atom-out residual `resid = y_i − Σ_{j≠k} f_j(x_i)` (length `p`),
261    /// namely `uᵢ = ⟨resid, v⟩`. `Some(uᵢ)` exactly when this is a collapse-rescued
262    /// image (`v` is set); `None` for the ordinary straight-image path (which has
263    /// no projection direction and decodes at the atom's own coordinate).
264    ///
265    /// This is the SAME math [`build_collapse_rescue_linear_image`] used to build
266    /// the train `row_codes` (`row_codes[i] = ⟨target_resid[i], v⟩`), so on a TRAIN
267    /// row it reproduces `row_codes[i]` exactly, and on a HELD-OUT row it yields
268    /// that row's correct coordinate — train and OOS share one model. Returns
269    /// `None` if `resid`'s length disagrees with `v`.
270    pub fn coordinate_from_residual(&self, resid: &[f64]) -> Option<f64> {
271        let v = self.v.as_ref()?;
272        if resid.len() != v.len() {
273            return None;
274        }
275        Some(v.iter().zip(resid).map(|(&vj, &rj)| vj * rj).sum())
276    }
277
278    /// Whether this image is a #1777 collapse-rescued image (carries a projection
279    /// direction `v` and per-row train codes) rather than an ordinary straight
280    /// image evaluated at the atom's own coordinate.
281    pub fn is_collapse_rescued(&self) -> bool {
282        self.v.is_some()
283    }
284}
285
286/// One fitted `d = 1` atom's hybrid-split verdict, surfaced in the model output.
287#[derive(Clone, Debug)]
288pub struct AtomHybridVerdict {
289    /// The atom's name (slot identity in the dictionary).
290    pub atom_name: String,
291    /// The evidence-selected parameterization choice for this slot.
292    pub choice: HybridAtomChoice,
293    /// `true` iff the slot kept the CURVED parameterization (the fitted atom);
294    /// `false` iff it yielded to the LINEAR special case (the straight tail).
295    pub kept_curved: bool,
296    /// The atom's fitted turning `Θ = ∫|κ| ds` (radians), the novel geometric
297    /// quantity #1026 pairs against reconstruction EV: `Θ ≈ 0` is a linear-tail
298    /// direction wearing a curved basis, `Θ ≈ 2π` is a full curved loop. `None`
299    /// iff the evaluator has no analytic second jet or the curve is degenerate.
300    /// Captured here (not just logged) so the EV-vs-Θ frontier is queryable
301    /// structured data on the persisted report rather than a transient log line.
302    pub fitted_turning: Option<f64>,
303    /// The atom's training leave-one-atom-out explained-variance contribution
304    /// `ΔEV_k = EV(full) − EV(full∖{k})` — how much reconstruction EV this single
305    /// atom earns. Paired with [`Self::fitted_turning`] this is the `(Θ, ΔEV)`
306    /// point the #1026 frontier reports: a `Θ ≈ 0` atom with large `ΔEV` is a
307    /// genuine linear-tail direction; a high-`Θ` atom with large `ΔEV` is a
308    /// genuine curved family. `None` iff the caller did not supply LOAO EV.
309    pub train_loao_delta_ev: Option<f64>,
310    /// The fitted straight sub-model for this slot, present iff the verdict
311    /// selected LINEAR (`kept_curved == false`). The collapsed reconstruction
312    /// substitutes this for the atom's curved decoded image, making the verdict
313    /// load-bearing on the reconstruction rather than a passive diagnostic.
314    pub linear_image: Option<AtomLinearImage>,
315}
316
317/// The whole dictionary's hybrid-split report: one verdict per eligible `d = 1`
318/// atom, plus the dictionary-level aggregates the EV-vs-Θ frontier reports
319/// against.
320#[derive(Clone, Debug)]
321pub struct SaeHybridSplitReport {
322    /// One adjudicated verdict per eligible `d = 1` atom, in slot order. Atoms
323    /// that are not eligible (wrong dim, no evaluator, mid-homotopy) are absent
324    /// — they carry no curved/linear adjudication.
325    pub verdicts: Vec<AtomHybridVerdict>,
326    /// The dictionary-level rolled-up selection (summed NLE, total parameters,
327    /// curved/linear counts) over the eligible atoms.
328    pub selection: HybridSplitSelection,
329}
330
/// Fewest assigned rows at which a `d = 1` atom's straight-line candidate is
/// defined.
///
/// The line spends two parameters (intercept and slope), and at least one more
/// row is needed to leave a residual to estimate. With fewer rows the linear
/// candidate's deviance is undefined. Such atoms are skipped (left out of the
/// report) rather than adjudicated on a fabricated deviance.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 2 + 1;
336
/// #1610/#1026 EV-PRESERVATION gate tolerance.
///
/// A `d = 1` slot may collapse to its linear tail only if doing so costs at most
/// this fraction of the target's total (centered) variance in
/// full-reconstruction explained variance.
///
/// Why the gate exists: the evidence argmin trades data fit against the curved
/// arm's parameter price in `NLE` units. On a small or low-amplitude fixture
/// that trade can favor the cheaper line even when the curve carries real
/// reconstruction signal. The collapse then DROPS EV (the observed 1.0 → 0.748
/// over-collapse). This gate directly guards the quantity that actually
/// regressed.
///
/// What it measures: the per-atom collapse EV impact is exactly
/// `(linear_rss − realized_curved_rss) / SST_full`.
/// * `realized_curved_rss` is the RSS of the atom's ALREADY-REALIZED curve
///   against its leave-this-atom-out residual. It is NOT the min-over-decoders
///   refit RSS the evidence comparison may use for the curved arm.
/// * The full reconstruction contains the realized curve, so collapsing atom `k`
///   raises the full reconstruction SSR by exactly `linear_rss − realized_curved_rss`.
/// * The full target variance `SST_full` is fixed.
///
/// A collapse is vetoed when it would lose more than this fraction.
///
/// Scale: EV is a dimensionless quantity in `[0, 1]`, so an absolute EV-loss
/// tolerance is itself scale-invariant. It does not reintroduce the scale
/// incommensurability that sank the evidence-reformulation attempt.
///
/// Choice of `1e-3` (0.1% of total variance):
/// * a genuinely straight atom, whose realized curve IS its own line, loses
///   `≈ 0` EV and still collapses losslessly;
/// * only a curve doing real reconstruction work (loss `≫ 1e-3`) is kept;
/// * the value sits comfortably above f64 round-off on an exact-line collapse
///   yet far below any material EV loss, so it separates the lossless and
///   load-bearing regimes without tuning.
///
/// PER-ATOM SCOPE: applied per slot, this tolerance bounds only ONE atom's
/// individual EV loss. It does NOT by itself bound the dictionary-level EV loss
/// when several atoms collapse at once. The true global RSS change is
/// `Σ_k Δ_k + 2 Σ_{j<k} <Δrecon_j, Δrecon_k>`. Both the accumulation of many
/// individually tolerable `Δ_k` AND the cross terms between co-active atoms are
/// invisible to the per-atom gate.
///
/// [`build_hybrid_split_report`] adds an aggregate global guard on top of the
/// per-atom gate, reading this same fraction as a bound on `Σ_k max(Δ_k, 0)`.
/// See there for what that guard does and does not prove.
const SAE_HYBRID_COLLAPSE_MAX_EV_LOSS: f64 = 1.0e-3;
368
369/// The full-reconstruction SSR INCREASE from collapsing one `d = 1` atom to its
370/// fitted straight sub-model: `linear_rss − curved_rss`, where both arms are
371/// scored against the atom's leave-this-atom-out response residual `y_resp`
372/// (`target_resid`) exactly as [`build_atom_candidates`] scores them. Because the
373/// full reconstruction differs from the collapsed one ONLY in this atom's
374/// contribution (`a_k·γ_k` → `a_k·line`) on its assigned rows, this scalar is the
375/// exact amount the full reconstruction's SSR rises when the slot is collapsed;
376/// dividing by the fixed total target variance gives the full-EV loss the
377/// EV-preservation gate keys on. Positive ⇒ the curve out-fits its straight
378/// projection (collapsing hurts); `≤ 0` ⇒ the line is at least as good (collapse
379/// is lossless or improving, never gated). Mirrors the `curved_rss` / `linear_rss`
380/// accumulation in [`build_atom_candidates`] bit-for-bit so the gate and the
381/// evidence comparison see the same residuals.
382fn collapse_ssr_increase(
383    coords: ArrayView1<'_, f64>,
384    assign: ArrayView1<'_, f64>,
385    decoded: ArrayView2<'_, f64>,
386    target_resid: ArrayView2<'_, f64>,
387    t_bar: f64,
388    b0: &Array1<f64>,
389    b1: &Array1<f64>,
390) -> f64 {
391    let n = assign.len();
392    let p = target_resid.ncols();
393    let mut curved_rss = 0.0_f64;
394    let mut linear_rss = 0.0_f64;
395    for i in 0..n {
396        let a = assign[i];
397        let dt = coords[i] - t_bar;
398        for j in 0..p {
399            let y = target_resid[[i, j]];
400            let r_curved = y - a * decoded[[i, j]];
401            curved_rss += r_curved * r_curved;
402            let r_linear = y - a * (b0[j] + dt * b1[j]);
403            linear_rss += r_linear * r_linear;
404        }
405    }
406    linear_rss - curved_rss
407}
408
409/// #1051/#1026 NESTED MIN — re-fit the curved atom's decoder on the SAME
410/// leave-this-atom-out residual `y_resp` and return its **minimum** weighted
411/// reconstruction RSS at the atom's realized codes.
412///
413/// This is the genuine constrained minimum of the curved family over its free
414/// decoder coefficients `B`:
415///
416/// ```text
417/// min_B Σᵢ ‖ y_resp[i] − a_k·( Φ(t_i)·B ) ‖²   (design = diag(a)·Φ, rhs = y_resp).
418/// ```
419///
420/// It restores the nested-dominance floor `curved_rss ≤ linear_rss`. The linear
421/// special case `a_k·(b₀ + (t−t̄)·b₁)` is itself a member of this family whenever
422/// the straight lane `[1, (t−t̄)]` lies in the column span of the curved basis
423/// `Φ` — exactly (interval / line-segment charts, whose basis carries the
424/// constant and linear terms) or to the basis's expressiveness (the periodic
425/// charts). So `min_B` over `Φ` cannot do WORSE than the best straight line: the
426/// returned RSS is `≤` the linear arm's RSS up to the least-squares solver
427/// tolerance. This is the direct per-atom `d = 1` refit the module owes; the
428/// broken (#1051) euclidean / multi-atom OUTER continuation is deliberately NOT
429/// re-entered — the decoder-only refit at the realized codes is sufficient to
430/// score the curved arm at its constrained minimum.
431///
432/// `phi` is the curved design `Φ(t)` on the atom's rows (`n × M`); `assign` the
433/// per-row mass `a_k` (NOT squared — the design weight is the mass itself, so the
434/// residual is on the SAME footing the linear arm and the joint loss use).
435/// Returns `None` when the solve is degenerate or non-finite; the caller then
436/// falls back to the already-realized curve's RSS rather than fabricate a value.
437fn curved_refit_rss(
438    phi: ArrayView2<'_, f64>,
439    assign: ArrayView1<'_, f64>,
440    target_resid: ArrayView2<'_, f64>,
441) -> Option<f64> {
442    let n = phi.nrows();
443    let m = phi.ncols();
444    let p = target_resid.ncols();
445    if m == 0 || n == 0 || assign.len() != n || target_resid.nrows() != n || p == 0 {
446        return None;
447    }
448    // Weighted design `diag(a)·Φ` (n×M). The refit minimizes ‖diag(a)·Φ·B − y_resp‖²,
449    // so the fitted prediction is diag(a)·Φ·B — the curved contribution `a_k·γ_k`
450    // at its best decoder `B` on this residual.
451    let mut design = Array2::<f64>::zeros((n, m));
452    for i in 0..n {
453        let a = assign[i];
454        if !a.is_finite() {
455            return None;
456        }
457        for c in 0..m {
458            design[[i, c]] = a * phi[[i, c]];
459        }
460    }
461    let b = solve_design_least_squares(design.view(), target_resid).ok()?;
462    if b.iter().any(|v| !v.is_finite()) {
463        return None;
464    }
465    let pred = design.dot(&b);
466    let mut rss = 0.0_f64;
467    for i in 0..n {
468        for j in 0..p {
469            let r = target_resid[[i, j]] - pred[[i, j]];
470            rss += r * r;
471        }
472    }
473    rss.is_finite().then_some(rss)
474}
475
476/// Build the curved + linear candidates for ONE fitted `d = 1` atom and return
477/// them as `(linear, curved, (t̄, b₀, b₁))`, or `None` if the atom cannot present
478/// an honest pair (too few rows, degenerate coordinate span, or non-finite
479/// numbers). Both candidates are scored against the SAME data — the atom's
480/// leave-this-atom-out response residual `y_resp` — at their CONSTRAINED MINIMUM:
481/// the linear arm is the freshly-fit min-over-lines and the curved arm is the
482/// min-over-decoders refit ([`curved_refit_rss`]), so the comparison is a genuine
483/// nested min-vs-min one and `curved_rss ≤ linear_rss` holds up to solver
484/// tolerance for a basis whose span contains the straight lane (#1051).
485///
486/// Inputs over the atom's assigned rows:
487///   * `coords` — the fitted on-atom coordinate `t`.
488///   * `assign` — the per-row assignment mass `a_k` (NOT squared; this routine
489///     squares it where the design weight `a_k²` is needed).
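///   * `decoded` — the atom's fitted decoded image `γ_k(t) = Φ(t) B_k` (`p` cols).
///     Its mass-scaled value `a_k·γ_k` is the curved candidate's FALLBACK
///     prediction, used only when `curved_phi` is absent or row-mismatched, or
///     when the decoder refit is degenerate.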
492///   * `target_resid` — the atom's leave-this-atom-out response residual `y_resp`
493///     (`p` cols): the response with every OTHER atom's contribution removed.
494///     This is the data both candidates fit.
495///
496/// The curved candidate's data-fit deviance is `½·min_B Σ ‖y_resp − a_k·(Φ·B)‖²`
497/// (its constrained minimum over the decoder; the mass lives in the prediction);
498/// the linear candidate fits the best mass-weighted straight line to `y_resp` and
499/// pays `½ Σ ‖y_resp − a_k·(b₀ + (t − t̄)·b₁)‖²`. Because the linear prediction is
500/// itself a curved-family member (the straight lane lies in `span(Φ)` for the
501/// eligible charts), the curved arm's minimized RSS is `≤` the linear arm's up to
502/// solver tolerance, so the argmin is a genuine nested min-vs-min dominance
503/// comparison, not a post-fit compression heuristic.
504fn build_atom_candidates(
505    coords: ArrayView1<'_, f64>,
506    assign: ArrayView1<'_, f64>,
507    decoded: ArrayView2<'_, f64>,
508    target_resid: ArrayView2<'_, f64>,
509    curved_num_params: usize,
510    curved_phi: Option<ArrayView2<'_, f64>>,
511    fitted_turning: Option<f64>,
512) -> Option<(
513    HybridAtomCandidate,
514    HybridAtomCandidate,
515    (f64, Array1<f64>, Array1<f64>),
516)> {
517    let n = coords.len();
518    let p = decoded.ncols();
519    if n < MIN_ROWS_FOR_LINEAR_FIT
520        || decoded.nrows() != n
521        || assign.len() != n
522        || target_resid.nrows() != n
523        || target_resid.ncols() != p
524        || p == 0
525    {
526        return None;
527    }
528
529    // The LINEAR candidate fits `a_k·(b₀ + (t − t̄)·b₁)` to the residual `y_resp`,
530    // so the natural design column is `a_k·[1, (t − t̄)]` and the per-row Gram
531    // weight is `wᵢ = a_k²`. We accumulate the mass-weighted coordinate mean `t̄`
532    // and spread `s_tt` under that weight; a row that barely belongs to the atom
533    // (`a_k ≈ 0`) contributes ≈ nothing, exactly as in the joint loss.
534    let mut w_sum = 0.0_f64;
535    let mut t_bar = 0.0_f64;
536    for i in 0..n {
537        let a = assign[i];
538        if !(a.is_finite() && a >= 0.0) {
539            return None;
540        }
541        let w = a * a;
542        w_sum += w;
543        t_bar += w * coords[i];
544    }
545    if !(w_sum > 0.0) {
546        return None;
547    }
548    t_bar /= w_sum;
549
550    // Weighted Σ wᵢ·(t − t̄)² with `wᵢ = a_k²` — the coordinate spread under the
551    // line's design weight. A degenerate (single-point mass) coordinate has no
552    // slope direction; refuse rather than divide by ~0.
553    let mut s_tt = 0.0_f64;
554    for i in 0..n {
555        let dt = coords[i] - t_bar;
556        s_tt += assign[i] * assign[i] * dt * dt;
557    }
558    if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
559        return None;
560    }
561
562    // Per-output-channel mass-weighted least squares for the line fit to the
563    // RESIDUAL `y_resp`. Minimizing `Σᵢ ‖y_resp[i] − a_k·(b₀ + (t − t̄)·b₁)‖²` in
564    // the centered basis has the diagonal normal equations
565    //   b₀[j] = (Σ a_k·y_resp[i,j]) / w_sum,   (recall the design intercept is a_k)
566    //   b₁[j] = (Σ a_k·(t − t̄)·y_resp[i,j]) / s_tt.
567    let mut b0 = Array1::<f64>::zeros(p);
568    let mut b1 = Array1::<f64>::zeros(p);
569    for j in 0..p {
570        let mut s_1y = 0.0_f64;
571        let mut s_ty = 0.0_f64;
572        for i in 0..n {
573            let a = assign[i];
574            let dt = coords[i] - t_bar;
575            let y = target_resid[[i, j]];
576            s_1y += a * y;
577            s_ty += a * dt * y;
578        }
579        b0[j] = s_1y / w_sum;
580        b1[j] = s_ty / s_tt;
581    }
582
583    // Data-fit residual sums of squares of BOTH candidates against `y_resp`, the
584    // common data. The linear candidate predicts the best line
585    // `a_k·(b₀ + (t − t̄)·b₁)`; the curved candidate is scored at its CONSTRAINED
586    // MINIMUM over the decoder coefficients (nested min-vs-min, #1051), re-fit on
587    // this same residual — not the possibly-collapsed already-realized curve. We
588    // also carry the realized curve's RSS as the honest fallback when the basis
589    // `Φ` is unavailable or its refit solve is degenerate.
590    let mut linear_rss = 0.0_f64;
591    let mut realized_curved_rss = 0.0_f64;
592    for i in 0..n {
593        let a = assign[i];
594        let dt = coords[i] - t_bar;
595        for j in 0..p {
596            let y = target_resid[[i, j]];
597            let r_linear = y - a * (b0[j] + dt * b1[j]);
598            linear_rss += r_linear * r_linear;
599            let r_curved = y - a * decoded[[i, j]];
600            realized_curved_rss += r_curved * r_curved;
601        }
602    }
603    // #1051 NESTED MIN — the curved arm's data fit is `min_B ‖y_resp − diag(a)Φ B‖²`,
604    // its constrained minimum over the decoder. Because the linear lane is a member
605    // of the curved family (the straight columns lie in `span(Φ)` for the eligible
606    // charts), this min-curved RSS is `≤ linear_rss` up to solver tolerance — the
607    // "curved match-or-beats linear" floor. Fall back to the realized curve's RSS
608    // only when Φ is absent or the refit is degenerate.
609    let curved_rss = match curved_phi {
610        Some(phi) if phi.nrows() == n => {
611            curved_refit_rss(phi, assign, target_resid).unwrap_or(realized_curved_rss)
612        }
613        _ => realized_curved_rss,
614    };
615
616    // Gaussian-reconstruction deviance: the residual objective `½ RSS` the
617    // Laplace normalizer is added to. The curved arm pays `½·curved_rss` (how
618    // well its REALIZED curve explains the residual) plus its larger `M·p`
619    // parameter price; the linear arm pays `½·linear_rss` plus a `2·p` price.
620    // `curved_rss` is the realized (not re-optimized) curve's misfit, so it is NOT
621    // guaranteed `≤ linear_rss`: when the realized curve underperforms its own best
622    // straight projection the cheaper line simply wins. The argmin trades whatever
623    // data-fit the realized curve buys against the curvature parameter price — a
624    // post-fit compression decision, not a nested match-or-beat floor.
625    let curved_residual_objective = 0.5 * curved_rss;
626    let linear_residual_objective = 0.5 * linear_rss;
627
628    // Linear candidate parameter price: intercept + slope per output channel.
629    let linear_num_params = 2 * p;
630
631    // Laplace logdet of the (weighted) design Gram for the LINEAR candidate.
632    //
633    // For the centered weighted line fit `a_k·(b₀ + (t − t̄)·b₁)`, the per-output-
634    // channel design column is `a_k·[1, (t − t̄)]`, whose Gram is DIAGONAL in the
635    // centered basis: `diag(Σ a_k², Σ a_k²(t − t̄)²) = diag(w_sum, s_tt)`. Its log
636    // determinant is `log(w_sum) + log(s_tt)` PER output channel, i.e.
637    //
638    //     log|H_linear| = p · ( log(w_sum) + log(s_tt) ).
639    //
640    // The `log(s_tt)` term is the slope direction's information: a line through a
641    // wide, heavily-massed coordinate spread is better-determined than one through
642    // a tiny spread, and the Laplace evidence must reflect that (#1203).
643    //
644    // The curved arm's Laplace determinant is now the genuine weighted-design
645    // Gram log-determinant `p · log|ΦᵀWΦ|_+` (#1223): the SAME quantity the
646    // linear arm reports (`p·(log w_sum + log s_tt) = p·log|XᵀWX|`), assembled
647    // from the curved basis `Φ` on the atom's assigned rows under the same
648    // assignment-mass design weight `wᵢ = a_k²`. Both arms omit the smoothing
649    // `λS` normalizer, so the complexity price is computed on a symmetric
650    // footing — no parameter-count proxy. Only when `Φ` is unavailable (the
651    // caller could not evaluate the basis) or its Gram is fully rank-deficient do
652    // we fall back to the historical `curved_num_params · log(w_sum)` proxy, so
653    // the comparison degrades gracefully rather than fabricating a determinant.
654    if !(w_sum > 0.0 && w_sum.is_finite() && s_tt.is_finite()) {
655        return None;
656    }
657    let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
658    let curved_log_det_h = curved_phi
659        .and_then(|phi| {
660            if phi.nrows() == n {
661                curved_design_gram_logdet(phi, assign, p)
662            } else {
663                None
664            }
665        })
666        .unwrap_or_else(|| (curved_num_params as f64) * w_sum.ln());
667
668    // Reduced Laplace NLE `residual_objective + ½ log|H|`. Both omit an explicit
669    // smoothing-penalty logdet (the intrinsic smoothness penalty is
670    // reparameterization-invariant and identical in expectation across the two
671    // parameterizations of the same image).
672    let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
673    let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
674    if !(linear_nle.is_finite() && curved_nle.is_finite()) {
675        return None;
676    }
677
678    let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
679    let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
680    Some((linear, curved, (t_bar, b0, b1)))
681}
682
683/// #1026 collapse rescue. When a `d = 1` atom's own coordinate has collapsed to a
684/// single point (`build_atom_candidates` refuses because `s_tt ≈ 0`), the atom is
685/// stuck in the degenerate "chord-through-the-arc" fixed point and its curved
686/// decode is a constant — the rank-1 dictionary co-collapse (real-OLMo held-out EV
687/// ≈ 0.13 vs the rank-K linear ceiling ≈ 0.74). The hybrid-split was DESIGNED to
688/// let such a linear-tail atom decode as a straight line; the only reason it can't
689/// here is that its own codes carry no spread to fit a slope against.
690///
691/// Recover FRESH per-row codes from the data instead: `uᵢ = yᵢ·v`, the projection
692/// of the leave-this-atom-out residual onto its top mass-weighted output direction
693/// `v` (the rank-1 of `Σᵢ wᵢ yᵢyᵢᵀ`, `wᵢ = a_k²` — the SAME design weight the line
694/// fit uses). These codes span the residual's strongest linear axis by
695/// construction, so the straight image `b₀ + (uᵢ − ū)·b₁` fit against them
696/// reconstructs that axis at LINEAR quality — exactly the linear-tail reach the
697/// split owes. Returns the forced-LINEAR candidate plus the image carrying `uᵢ`,
698/// or `None` when the residual itself carries no usable direction (a genuine zero
699/// atom the mass/decoder guards own).
700fn build_collapse_rescue_linear_image(
701    atom_idx: usize,
702    assign: ArrayView1<'_, f64>,
703    target_resid: ArrayView2<'_, f64>,
704) -> Option<(HybridAtomCandidate, AtomLinearImage)> {
705    let n = assign.len();
706    let p = target_resid.ncols();
707    if n < MIN_ROWS_FOR_LINEAR_FIT || target_resid.nrows() != n || p == 0 {
708        return None;
709    }
710    let mut w_sum = 0.0_f64;
711    for i in 0..n {
712        let a = assign[i];
713        if !(a.is_finite() && a >= 0.0) {
714            return None;
715        }
716        w_sum += a * a;
717    }
718    if !(w_sum > 0.0) {
719        return None;
720    }
721    // Top mass-weighted output direction `v` of the residual via power iteration on
722    // `M = Σᵢ wᵢ yᵢyᵢᵀ` (p×p, never materialized): `v ← normalize(Σᵢ wᵢ yᵢ (yᵢ·v))`.
723    // Seed from the per-channel weighted energy so a rank-1 residual converges in
724    // one step and the seed is deterministic (no RNG).
725    let mut v = Array1::<f64>::zeros(p);
726    for j in 0..p {
727        let mut e = 0.0_f64;
728        for i in 0..n {
729            let a = assign[i];
730            let y = target_resid[[i, j]];
731            e += a * a * y * y;
732        }
733        v[j] = e;
734    }
735    let mut vnorm = v.dot(&v).sqrt();
736    if !(vnorm > 0.0) {
737        return None;
738    }
739    v.mapv_inplace(|x| x / vnorm);
740    for _ in 0..32 {
741        let mut mv = Array1::<f64>::zeros(p);
742        for i in 0..n {
743            let a = assign[i];
744            let w = a * a;
745            let mut proj = 0.0_f64;
746            for j in 0..p {
747                proj += target_resid[[i, j]] * v[j];
748            }
749            let wp = w * proj;
750            for j in 0..p {
751                mv[j] += wp * target_resid[[i, j]];
752            }
753        }
754        vnorm = mv.dot(&mv).sqrt();
755        if !(vnorm > 0.0) {
756            return None;
757        }
758        mv.mapv_inplace(|x| x / vnorm);
759        let cos = mv.dot(&v).abs();
760        v = mv;
761        if cos > 1.0 - 1e-12 {
762            break;
763        }
764    }
765    // Fresh per-row codes `uᵢ = yᵢ·v` and the weighted line fit against them.
766    let mut u = Array1::<f64>::zeros(n);
767    let mut t_bar = 0.0_f64;
768    for i in 0..n {
769        let mut proj = 0.0_f64;
770        for j in 0..p {
771            proj += target_resid[[i, j]] * v[j];
772        }
773        u[i] = proj;
774        t_bar += assign[i] * assign[i] * proj;
775    }
776    t_bar /= w_sum;
777    let mut s_tt = 0.0_f64;
778    for i in 0..n {
779        let dt = u[i] - t_bar;
780        s_tt += assign[i] * assign[i] * dt * dt;
781    }
782    if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
783        return None;
784    }
785    let mut b0 = Array1::<f64>::zeros(p);
786    let mut b1 = Array1::<f64>::zeros(p);
787    let mut linear_rss = 0.0_f64;
788    for j in 0..p {
789        let mut s_1y = 0.0_f64;
790        let mut s_ty = 0.0_f64;
791        for i in 0..n {
792            let a = assign[i];
793            let dt = u[i] - t_bar;
794            let y = target_resid[[i, j]];
795            s_1y += a * y;
796            s_ty += a * dt * y;
797        }
798        b0[j] = s_1y / w_sum;
799        b1[j] = s_ty / s_tt;
800    }
801    for i in 0..n {
802        let a = assign[i];
803        let dt = u[i] - t_bar;
804        for j in 0..p {
805            let r = target_resid[[i, j]] - a * (b0[j] + dt * b1[j]);
806            linear_rss += r * r;
807        }
808    }
809    let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
810    let linear_nle = reduced_laplace_nle(0.5 * linear_rss, linear_log_det_h);
811    if !linear_nle.is_finite() {
812        return None;
813    }
814    let linear = HybridAtomCandidate::linear(linear_nle, 2 * p);
815    let image = AtomLinearImage {
816        atom_idx,
817        t_bar,
818        b0,
819        b1,
820        row_codes: Some(u),
821        // #1777 — persist the projection direction so an OOS row's coordinate can
822        // be recomputed as ⟨residual, v⟩ (identical to the train `row_codes`),
823        // rather than falling back to the atom's collapsed own coordinate.
824        v: Some(v),
825    };
826    Some((linear, image))
827}
828
829/// #1026 item-2 — one collapsed slot's TRUE dictionary-level reconstruction
830/// change `δ_k[i,j] = a_k·(γ_k(t_i) − line_k(t_i))` over ALL globally-aligned rows
831/// (the caller presents `coords`/`assign`/`decoded`/`image` on the same `n` rows,
832/// so these δ vectors ARE cross-atom aligned and their inner products are the
833/// genuine cross terms). `line_k` is evaluated at the image's own coordinate
834/// (collapse-rescue slots evaluate at their fresh per-row codes via
835/// [`AtomLinearImage::coordinate_for_row`], exactly as the collapsed reconstruction
836/// does). Collapsing atom `k` shifts the full reconstruction residual by `+δ_k`.
837fn slot_delta(
838    coords: &Array1<f64>,
839    assign: &Array1<f64>,
840    decoded: &Array2<f64>,
841    image: &AtomLinearImage,
842) -> Array2<f64> {
843    let n = decoded.nrows();
844    let p = decoded.ncols();
845    let mut d = Array2::<f64>::zeros((n, p));
846    for i in 0..n {
847        let a = assign[i];
848        let coord = image.coordinate_for_row(i, coords[i]);
849        let dt = coord - image.t_bar;
850        for j in 0..p {
851            let line = image.b0[j] + dt * image.b1[j];
852            d[[i, j]] = a * (decoded[[i, j]] - line);
853        }
854    }
855    d
856}
857
858/// #1026 item-2 — the ALL-CURVED global reconstruction residual
859/// `R0 = target − Σ_all a·γ` recovered from any single atom's leave-this-atom-out
860/// residual: `R0 = y_resp_k − a_k·γ_k = target_resid − a_k·decoded`. Identical for
861/// every atom (each `target_resid` adds back exactly that atom's own contribution),
862/// so the caller computes it once from the first slot.
863fn slot_r0(assign: &Array1<f64>, decoded: &Array2<f64>, target_resid: &Array2<f64>) -> Array2<f64> {
864    let n = decoded.nrows();
865    let p = decoded.ncols();
866    let mut r0 = Array2::<f64>::zeros((n, p));
867    for i in 0..n {
868        let a = assign[i];
869        for j in 0..p {
870            r0[[i, j]] = target_resid[[i, j]] - a * decoded[[i, j]];
871        }
872    }
873    r0
874}
875
876/// #1026 item-2 — GLOBAL cross-term collapse guard. Given the collapsed slots'
877/// dictionary-level reconstruction-change vectors `δ_k` (globally row-aligned,
878/// `n × p`) and the all-curved global residual `R0 = target − Σ_all a·γ`, decide
879/// which collapses must revert to curved so the TRUE global reconstruction SSR
880/// increase
881///
882/// ```text
883/// ΔRSS(S) = ‖R0 + Σ_{k∈S} δ_k‖² − ‖R0‖²
884///         = 2⟨R0, Σ_{k∈S} δ_k⟩ + ‖Σ_{k∈S} δ_k‖²
885///         = Σ_{k∈S} Δ_k + 2 Σ_{j<k∈S} ⟨δ_j, δ_k⟩
886/// ```
887///
888/// stays within `global_tol`. Unlike the per-atom / summed-loss guard this
889/// INCLUDES the cross terms `2 Σ_{j<k} ⟨δ_j, δ_k⟩` between simultaneously-collapsed
890/// atoms (correlated collapse errors), which the aggregate `Σ max(Δ_k, 0)` bound is
891/// blind to. `forced` are the always-collapsed δ (euclidean / collapse-rescue slots
892/// with no curved alternative): they stay in the reconstruction but cannot be
893/// rolled back. `eligible` are `(slot, curved_evidence_margin, δ)` for
894/// rollback-eligible collapses. When `ΔRSS` over ALL collapses exceeds tolerance,
895/// the least-justified eligible collapses (largest margin — the most marginal
896/// linear win) are reverted one at a time, RECOMPUTING the true global increase
897/// (cross terms and all) after each revert, until within tolerance or none remain.
898/// Returns the slot indices to revert to curved.
899fn global_collapse_rollback(
900    r0: ArrayView2<'_, f64>,
901    eligible: &[(usize, f64, &Array2<f64>)],
902    forced: &[&Array2<f64>],
903    global_tol: f64,
904) -> Vec<usize> {
905    let (n, p) = r0.dim();
906    // True global SSR increase for the active δ set: 2⟨R0, Σδ⟩ + ‖Σδ‖².
907    let increase = |active: &[&Array2<f64>]| -> f64 {
908        let mut cross = 0.0_f64;
909        let mut self_sq = 0.0_f64;
910        for i in 0..n {
911            for j in 0..p {
912                let mut s = 0.0_f64;
913                for d in active {
914                    s += d[[i, j]];
915                }
916                cross += r0[[i, j]] * s;
917                self_sq += s * s;
918            }
919        }
920        2.0 * cross + self_sq
921    };
922    let mut kept = vec![true; eligible.len()];
923    let build_active = |kept: &[bool]| -> Vec<&Array2<f64>> {
924        let mut v: Vec<&Array2<f64>> = forced.to_vec();
925        for (idx, &(_, _, d)) in eligible.iter().enumerate() {
926            if kept[idx] {
927                v.push(d);
928            }
929        }
930        v
931    };
932    if increase(&build_active(&kept)) <= global_tol {
933        return Vec::new();
934    }
935    // Revert least-justified first: largest curved_evidence_margin (the most
936    // marginal linear win is the cheapest to give back to curved).
937    let mut order: Vec<usize> = (0..eligible.len()).collect();
938    order.sort_by(|&a, &b| {
939        eligible[b]
940            .1
941            .partial_cmp(&eligible[a].1)
942            .unwrap_or(std::cmp::Ordering::Equal)
943    });
944    let mut reverted = Vec::new();
945    for idx in order {
946        kept[idx] = false;
947        reverted.push(eligible[idx].0);
948        if increase(&build_active(&kept)) <= global_tol {
949            break;
950        }
951    }
952    reverted
953}
954
/// Assemble the per-atom candidate slots for [`select_hybrid_split`] from the
/// fitted `d = 1` atoms, run the adjudication, and return the report.
///
/// # Parameters
///
/// - `atoms`: the fitted dictionary atoms.
/// - `eligible_d1`: the indices (into `atoms`) of the `d = 1` atoms to adjudicate.
/// - `coords_for`: yields the on-atom coordinate column for a slot.
/// - `assign_for`: yields the per-row assignment mass `a_k`.
/// - `decoded_for`: yields the fitted decoded image rows `γ_k`.
/// - `target_resid_for`: yields the atom's leave-this-atom-out response residual
///   `y_resp`, the data both candidates are scored against (#1202).
/// - `manifold_for`: yields the atom's chart manifold. A flat / Euclidean chart is
///   given a linear-only slot here.
/// - `delta_ev_for`: yields the atom's held-out LOAO `ΔEV_k`, which is carried
///   onto the verdict.
/// - `total_centered_variance`: scales both EV-preservation tolerances, the
///   per-atom gate and the global guard
///   (`SAE_HYBRID_COLLAPSE_MAX_EV_LOSS · total_centered_variance`). A non-positive
///   or non-finite value disables both.
///
/// # Closure call order
///
/// The per-atom closures are called once per eligible atom, in the order
/// `coords_for`, `assign_for`, `decoded_for`, `target_resid_for`, `manifold_for`.
/// `delta_ev_for` is called afterwards, and only for atoms that yield a slot.
///
/// All per-atom arrays are expected on the same `n` globally-aligned rows: the
/// global guard sums the slots' δ vectors element by element.
///
/// # Behaviour
///
/// - An atom whose candidate pair cannot be built AND whose collapse rescue also
///   fails is skipped.
/// - After selection, a global cross-term guard reverts the least-justified
///   selected collapses to curved whenever the full-dictionary SSR increase exceeds
///   tolerance, then re-runs the selection.
///
/// # Returns
///
/// `Ok(None)` (no report) when no atom yields a slot — there is nothing to
/// adjudicate.
///
/// # Errors
///
/// Propagates the error string returned by [`select_hybrid_split`].
///
/// # Panics
///
/// Panics if an index produced by `eligible_d1` is out of bounds for `atoms`.
pub fn build_hybrid_split_report<'a, C, W, D, R, M, E>(
    atoms: &'a [SaeManifoldAtom],
    eligible_d1: impl Iterator<Item = usize>,
    mut coords_for: C,
    mut assign_for: W,
    mut decoded_for: D,
    mut target_resid_for: R,
    mut manifold_for: M,
    mut delta_ev_for: E,
    // #1026 — the full target's total (column-centered) variance `SST_full`, the
    // fixed denominator of the EV-preservation gate and of the global guard.
    // `≤ 0` / non-finite disables both (a degenerate, varianceless target has no
    // EV to preserve).
    total_centered_variance: f64,
) -> Result<Option<SaeHybridSplitReport>, String>
where
    C: FnMut(usize) -> Array1<f64>,
    W: FnMut(usize) -> Array1<f64>,
    D: FnMut(usize) -> Array2<f64>,
    R: FnMut(usize) -> Array2<f64>,
    M: FnMut(usize) -> LatentManifold,
    // The atom's held-out LOAO `ΔEV_k`, keyed by atom index. `None` when LOAO EV
    // is unavailable (e.g. the caller has no target to measure against).
    E: FnMut(usize) -> Option<f64>,
{
    // The per-slot vectors below are pushed in lock-step (one entry per slot that
    // survives the loop), so index `i` refers to the same slot in every one of them.
    let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
    let mut names: Vec<String> = Vec::new();
    let mut manifolds: Vec<LatentManifold> = Vec::new();
    // Per-slot fitted straight sub-model `(atom_idx, t̄, b₀, b₁)`, surfaced onto
    // the verdict iff the slot selects LINEAR so the collapsed reconstruction can
    // substitute it for the curved decoded image.
    let mut linear_images: Vec<AtomLinearImage> = Vec::new();
    // Per-slot `(Θ, ΔEV)` — the #1026 frontier point — carried onto each verdict
    // so the geometry/EV pairing is structured report data, not a log line.
    let mut turnings: Vec<Option<f64>> = Vec::new();
    let mut delta_evs: Vec<Option<f64>> = Vec::new();
    // #1026 item-2 — per-slot collapse loss `Δ_k = linear_rss − curved_rss` (as
    // returned by `collapse_ssr_increase`).
    // - `Some(Δ_k)`: a curveable slot that retained a curved alternative, so a
    //   chosen collapse there can be rolled back.
    // - `None`: euclidean slots (no curved option) and collapse-rescue slots (the
    //   curve was already degenerate — collapsing recovers EV rather than losing it).
    // The global guard below uses it only to decide rollback eligibility (a finite
    // loss is required); it measures ΔRSS from δ_k directly.
    let mut collapse_loss: Vec<Option<f64>> = Vec::new();
    // #1026 item-2 — per-slot dictionary-level reconstruction-change vector δ_k
    // (globally row-aligned n×p), and the all-curved global residual R0 (computed
    // once from the first slot). These feed the GLOBAL cross-term collapse guard,
    // which reconstructs the full dictionary with the selected collapses applied
    // and measures the TRUE global EV degradation (cross terms and all).
    let mut deltas: Vec<Array2<f64>> = Vec::new();
    let mut r0: Option<Array2<f64>> = None;

    for atom_idx in eligible_d1 {
        let atom = &atoms[atom_idx];
        let coords = coords_for(atom_idx);
        let assign = assign_for(atom_idx);
        let decoded = decoded_for(atom_idx);
        let target_resid = target_resid_for(atom_idx);
        // Curved parameter price = the decoder's `M · p` coefficients.
        let curved_num_params = atom.decoder_coefficients.len();
        let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
            d1_atom_fitted_turning(
                evaluator.as_ref(),
                atom.decoder_coefficients.view(),
                coords.view(),
            )
            .ok()
            .flatten()
        });
        // Evaluate the curved design `Φ(t)` on this atom's assigned rows so the
        // curved arm's Laplace complexity is the real weighted-design Gram
        // log-determinant rather than a parameter-count proxy (#1223). A `d = 1`
        // atom's coordinate column is presented as an `n × 1` design input. If
        // the evaluator is absent or refuses, `curved_phi` stays `None` and
        // `build_atom_candidates` falls back to the proxy.
        let coords_col = coords
            .view()
            .into_shape_with_order((coords.len(), 1))
            .ok()
            .map(|v| v.to_owned());
        let curved_phi = match (atom.basis_evaluator.as_ref(), coords_col.as_ref()) {
            (Some(evaluator), Some(col)) => {
                evaluator.evaluate(col.view()).ok().map(|(phi, _jet)| phi)
            }
            _ => None,
        };
        // A flat (Euclidean) chart cannot honestly present a curved candidate;
        // it is given a linear-only slot below. Present both for curveable charts.
        let manifold = manifold_for(atom_idx);
        match build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            target_resid.view(),
            curved_num_params,
            curved_phi.as_ref().map(|phi| phi.view()),
            fitted_turning,
        ) {
            Some((linear, curved, (t_bar, b0, b1))) => {
                // #1026 PER-ATOM EV-PRESERVATION gate. Collapsing this slot raises
                // the full reconstruction SSR by `linear_rss − curved_rss`. If that
                // exceeds `SAE_HYBRID_COLLAPSE_MAX_EV_LOSS` of the fixed total target
                // variance, the collapse would materially DROP this ONE atom's EV.
                // Veto it by presenting only the curved candidate (the selector must
                // keep curved).
                //
                // A lossless / improving collapse (`≤ 0`) and a negligible one stay
                // free to collapse, so EV-neutral cases (the top-k / birth-topology
                // lines) are untouched. Only curveable charts are gated; a euclidean
                // chart never had a curved option.
                //
                // NOTE: this gate is PER-ATOM only. The accumulation of many small
                // collapses and the dictionary-level cross terms are handled by the
                // GLOBAL cross-term guard after selection.
                let loss = collapse_ssr_increase(
                    coords.view(),
                    assign.view(),
                    decoded.view(),
                    target_resid.view(),
                    t_bar,
                    &b0,
                    &b1,
                );
                let collapse_loses_ev = total_centered_variance.is_finite()
                    && total_centered_variance > 0.0
                    && loss > SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
                let euclidean = manifold.is_euclidean();
                let slot = if euclidean {
                    vec![linear]
                } else if collapse_loses_ev {
                    vec![curved]
                } else {
                    vec![linear, curved]
                };
                // Build the straight image, then its globally-aligned δ_k for the
                // GLOBAL cross-term guard, before moving it into the report.
                let image = AtomLinearImage {
                    atom_idx,
                    t_bar,
                    b0,
                    b1,
                    row_codes: None,
                    // Ordinary straight image: decoded at the atom's own
                    // coordinate, so it carries no residual-projection direction.
                    v: None,
                };
                let delta = slot_delta(&coords, &assign, &decoded, &image);
                if r0.is_none() {
                    r0 = Some(slot_r0(&assign, &decoded, &target_resid));
                }
                slots.push(slot);
                // A euclidean slot never had a curved alternative, so its collapse
                // carries no recoverable EV loss for the global guard; record `None`.
                collapse_loss.push(if euclidean { None } else { Some(loss) });
                names.push(atom.name.clone());
                manifolds.push(manifold);
                turnings.push(fitted_turning);
                delta_evs.push(delta_ev_for(atom_idx));
                deltas.push(delta);
                linear_images.push(image);
            }
            // #1026 collapse rescue: `build_atom_candidates` refused because the
            // atom's own coordinate collapsed (`s_tt ≈ 0`) — the rank-1 co-collapse
            // fixed point. Recover a FRESH linear image from the residual's top
            // direction (fresh per-row codes) and force the LINEAR verdict. This is
            // a single-option slot the selector must take, so the slot reconstructs
            // its residual's best linear axis at linear quality instead of the
            // collapsed-curve constant. `None` only when the residual itself is
            // degenerate — then there is genuinely nothing to recover and we skip.
            None => match build_collapse_rescue_linear_image(
                atom_idx,
                assign.view(),
                target_resid.view(),
            ) {
                Some((linear, image)) => {
                    let delta = slot_delta(&coords, &assign, &decoded, &image);
                    if r0.is_none() {
                        r0 = Some(slot_r0(&assign, &decoded, &target_resid));
                    }
                    slots.push(vec![linear]);
                    // Forced-linear rescue: the curve was degenerate, so there is no
                    // curved alternative to roll back to and no recoverable EV loss.
                    collapse_loss.push(None);
                    names.push(atom.name.clone());
                    manifolds.push(manifold);
                    turnings.push(fitted_turning);
                    delta_evs.push(delta_ev_for(atom_idx));
                    deltas.push(delta);
                    linear_images.push(image);
                }
                None => continue,
            },
        }
    }

    if slots.is_empty() {
        return Ok(None);
    }

    let mut selection = select_hybrid_split(&slots)?;

    // #1026 item-2 — GLOBAL CROSS-TERM EV-preservation guard over the SELECTED
    // collapses.
    //
    // The per-atom gate above bounds each atom's individual EV loss, but the TRUE
    // dictionary-level RSS increase from collapsing a SET of atoms is
    //   ΔRSS = Σ_k Δ_k + 2 Σ_{j<k} ⟨δ_j, δ_k⟩.
    // Two effects escape any per-atom or summed-loss bound:
    //   (1) the accumulation of many individually-tolerable `Δ_k`;
    //   (2) the CROSS TERMS between simultaneously-collapsed atoms.
    // The old aggregate `Σ max(Δ_k, 0)` guard bounded (1) but was blind to (2):
    // correlated collapse errors whose per-atom losses each sit under tolerance can
    // still push the true global loss over it.
    //
    // Every slot now carries its exact dictionary-level reconstruction-change vector
    // δ_k (globally row-aligned — the caller presents all per-atom arrays on the same
    // n rows). We therefore reconstruct the full dictionary WITH the selected
    // collapse set applied and measure the real global increase `ΔRSS` DIRECTLY
    // (cross terms captured). The least-justified collapses are reverted until the
    // degradation is within the same EV tolerance, and the split is re-adjudicated
    // so `selection` stays consistent.
    if total_centered_variance.is_finite() && total_centered_variance > 0.0 {
        if let Some(r0) = r0.as_ref() {
            let global_tol = SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
            // Partition the SELECTED-collapsed slots:
            // - rollback-eligible: a curved alternative is still present and the
            //   per-atom loss is finite;
            // - forced: euclidean / collapse-rescue, with no curved fallback. These
            //   stay collapsed but still enter the global reconstruction, so their
            //   cross terms are counted.
            let mut eligible: Vec<(usize, f64, &Array2<f64>)> = Vec::new();
            let mut forced: Vec<&Array2<f64>> = Vec::new();
            for (i, choice) in selection.atoms.iter().enumerate() {
                if !choice.param.is_linear() {
                    continue;
                }
                let has_curved_alt = slots[i].iter().any(|c| !c.param.is_linear());
                let loss_finite = collapse_loss[i].map(|l| l.is_finite()).unwrap_or(false);
                if has_curved_alt && loss_finite {
                    eligible.push((i, choice.curved_evidence_margin, &deltas[i]));
                } else {
                    forced.push(&deltas[i]);
                }
            }
            let reverted = global_collapse_rollback(r0.view(), &eligible, &forced, global_tol);
            if !reverted.is_empty() {
                // Pin each reverted slot to its curved candidate only, then
                // re-adjudicate the whole split.
                for slot in reverted {
                    if let Some(curved) =
                        slots[slot].iter().find(|c| !c.param.is_linear()).copied()
                    {
                        slots[slot] = vec![curved];
                    }
                }
                selection = select_hybrid_split(&slots)?;
            }
        }
    }

    let verdicts: Vec<AtomHybridVerdict> = names
        .into_iter()
        .zip(selection.atoms.iter().copied())
        .zip(linear_images.into_iter())
        .zip(turnings.into_iter())
        .zip(delta_evs.into_iter())
        .map(
            |((((atom_name, choice), linear_image), fitted_turning), train_loao_delta_ev)| {
                let kept_curved = !choice.param.is_linear();
                AtomHybridVerdict {
                    atom_name,
                    choice,
                    kept_curved,
                    fitted_turning,
                    train_loao_delta_ev,
                    // Carry the straight sub-model only when the verdict collapses
                    // this slot to linear — the curved slots keep their fitted image.
                    linear_image: if kept_curved {
                        None
                    } else {
                        Some(linear_image)
                    },
                }
            },
        )
        .collect();

    Ok(Some(SaeHybridSplitReport {
        verdicts,
        selection,
    }))
}
1248
/// Unit tests for the hybrid curved-vs-linear split: candidate construction and
/// selection, the linear and curved Laplace log-determinants, the nested curved
/// refit, the global collapse rollback, and the degenerate-coordinate refusal.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// A straight RESPONSE residual (the atom's data is a line) is explained
    /// equally well by both candidates, so the cheaper linear special case wins.
    /// With `a_k = 1` the curved decoded image is straight too (Θ = 0), so both
    /// the dominance floor and the evidence argmin select linear. This is the
    /// common-data nested comparison (#1202): linear is the curved family's
    /// `Θ = 0` member, so it cannot lose when a line already explains the data.
    #[test]
    fn straight_residual_selects_linear() {
        let n = 40;
        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        // The data the atom must explain is a straight line in ℝ²; the curved
        // decoded image equals that same line (a Θ = 0 curved fit).
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            data[[i, 0]] = coords[i];
            data[[i, 1]] = 0.6 * coords[i];
            decoded[[i, 0]] = coords[i];
            decoded[[i, 1]] = 0.6 * coords[i];
        }
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            // a generous curved parameter price (M·p)
            10,
            None,
            Some(0.0),
        )
        .expect("straight residual yields a candidate pair");
        let choice =
            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert!(
            choice.param.is_linear(),
            "a straight response residual must keep the linear special case"
        );
    }

    /// A turning RESPONSE residual (the atom's data traces a full circle) is fit
    /// well by the curved decoded image (curved_rss ≈ 0) but poorly by any
    /// straight line (large linear_rss). The curved candidate therefore wins the
    /// common evidence comparison, because its data-fit gain exceeds its extra
    /// parameter price (#1202).
    #[test]
    fn turning_residual_selects_curved_on_evidence() {
        let n = 60;
        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        // The data is a full circle, and the curved decoded image is that same
        // circle (the curved atom reconstructs its assigned residual). The curved
        // candidate thus has ≈ zero data-fit residual, while a straight line
        // cannot follow the loop.
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let theta = 2.0 * PI * coords[i];
            data[[i, 0]] = theta.cos();
            data[[i, 1]] = theta.sin();
            decoded[[i, 0]] = theta.cos();
            decoded[[i, 1]] = theta.sin();
        }
        // The curved atom has 5 parameters, just above the 4 = 2·p linear budget.
        // The full-circle linear residual exceeds the extra-parameter overhead, so
        // curved wins on evidence.
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            5,
            None,
            Some(2.0 * PI),
        )
        .expect("turning residual yields a candidate pair");
        assert!(
            linear.negative_log_evidence > curved.negative_log_evidence,
            "the line must misfit the circular residual worse than the curve does \
             (linear NLE {} should exceed curved NLE {})",
            linear.negative_log_evidence,
            curved.negative_log_evidence
        );
        let choice =
            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert_eq!(
            choice.param,
            gam_solve::evidence::HybridAtomParam::Curved { latent_dim: 1 },
            "a full-circle response residual must keep the curved parameterization"
        );
        assert!(
            choice.curved_evidence_margin > 0.0,
            "curved must win a positive evidence margin over the linear secant"
        );
    }

    /// The nested-dominance floor on common data (#1202): when the curved decoded
    /// image fits the response residual no better than its own best straight
    /// projection, linear must win. The curved family cannot be charged extra
    /// parameters to fit the residual no better than its `Θ = 0` member. Here the
    /// data is a line but the curved image bends away from it, so the cheaper
    /// straight line is selected.
    ///
    /// NOTE(review): the module docs state that the curved arm is scored by its
    /// decoder refit on the residual (`curved_refit_rss`). In that case the curved
    /// RSS ties the linear RSS here rather than exceeding it, and linear still wins
    /// on parameter price. Confirm which score `build_atom_candidates` reports.
    #[test]
    fn linear_beats_curved_when_curve_misfits_residual() {
        let n = 50;
        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        // The data is a straight line; the curved decoded image is a parabola that
        // departs from it, so a straight line fits the data strictly better.
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = coords[i];
            data[[i, 0]] = t;
            data[[i, 1]] = 0.5 * t;
            decoded[[i, 0]] = t;
            decoded[[i, 1]] = t * t; // bends away from the linear data
        }
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            // curved parameter price (M·p)
            6,
            None,
            // a real curved Θ above the floor so the dominance floor does not fire
            Some(1.0),
        )
        .expect("candidate pair");
        let choice =
            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert!(
            choice.param.is_linear(),
            "a curved image that fits the data worse than its own line must yield \
             to the linear special case on common-data evidence (#1202)"
        );
    }

    /// The LINEAR candidate's Laplace logdet is the genuine weighted-design Gram
    /// determinant `p·(log w_sum + log s_tt)`, with `w_sum = Σ a_k²` and
    /// `s_tt = Σ a_k²(t − t̄)²`. It INCLUDES the coordinate-spread term
    /// `log(s_tt)` (#1203).
    ///
    /// Both contributions are verified by reading the logdet off a candidate whose
    /// linear residual is exactly zero (response residual = the fitted line), so
    /// `NLE_linear = ½·logdet`.
    /// * Doubling the coordinate spread (at fixed assignment mass) scales `s_tt`
    ///   by 4, so logdet += `p·log(4)`.
    /// * Doubling all assignment masses scales BOTH `w_sum` and `s_tt` by 4
    ///   (they are quadratic in `a_k`), so logdet += `2p·log(4)`.
    #[test]
    fn linear_logdet_includes_weighted_coordinate_spread() {
        let n = 40;
        let p = 2usize;
        // Read the logdet back off a candidate with zero linear residual: the
        // response residual is exactly `a_k·(line)`, so the WLS line recovers it
        // with RSS == 0 and `NLE_linear = ½·logdet`.
        let logdet = |coords: &Array1<f64>, assign: &Array1<f64>| -> f64 {
            // A straight image; the response residual is the same line scaled by
            // the per-row assignment mass `a_k`, so the prediction `a_k·(b₀+dt·b₁)`
            // matches it exactly and linear_rss == 0.
            let line = |t: f64| -> [f64; 2] { [t, 0.6 * t] };
            let mut decoded = Array2::<f64>::zeros((n, p));
            let mut data = Array2::<f64>::zeros((n, p));
            for i in 0..n {
                let l = line(coords[i]);
                decoded[[i, 0]] = l[0];
                decoded[[i, 1]] = l[1];
                data[[i, 0]] = assign[i] * l[0];
                data[[i, 1]] = assign[i] * l[1];
            }
            let (linear, _curved, _) = build_atom_candidates(
                coords.view(),
                assign.view(),
                decoded.view(),
                data.view(),
                10,
                None,
                Some(0.0),
            )
            .expect("straight residual yields a pair");
            2.0 * linear.negative_log_evidence // = logdet (linear_rss == 0)
        };

        let base_coords =
            Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let ones = Array1::<f64>::ones(n);

        // Doubling the coordinate spread → s_tt ×4, w_sum fixed → logdet += p·log(4).
        let wide_coords = base_coords.mapv(|t| 2.0 * t);
        let d_spread = logdet(&wide_coords, &ones) - logdet(&base_coords, &ones);
        assert!(
            (d_spread - (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
            "linear logdet must move by p·log(4) when coordinate spread doubles \
             (got {d_spread}); the spread term log(s_tt) must be present"
        );

        // Doubling all assignment masses → w_sum ×4 AND s_tt ×4 (quadratic in a_k)
        // → logdet += 2p·log(4).
        let twos = Array1::<f64>::from_elem(n, 2.0);
        let d_weight = logdet(&base_coords, &twos) - logdet(&base_coords, &ones);
        assert!(
            (d_weight - 2.0 * (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
            "linear logdet must move by 2p·log(4) when all assignment masses double \
             (got {d_weight})"
        );
    }

    /// #1223: the curved arm's Laplace complexity is the REAL weighted-design
    /// Gram log-determinant `p·log|ΦᵀWΦ|_+` (with `W = diag(a_k²)`), not a
    /// parameter-count proxy.
    ///
    /// The test builds a curved design whose columns are the constant and the
    /// centered coordinate (a 2-column basis). Then `ΦᵀWΦ = diag(w_sum, s_tt)`
    /// exactly matches the linear arm's data Gram, and `curved_design_gram_logdet`
    /// must return `p·(log w_sum + log s_tt)`, the same determinant the linear arm
    /// reports on the same design weight. A proxy `M·log(w_sum)` would omit the
    /// `log(s_tt)` spread term, so this pins the genuine determinant.
    ///
    /// A rank-deficient (duplicated-column) design is also checked: only its
    /// positive eigen-direction may contribute.
    #[test]
    fn curved_gram_logdet_is_real_weighted_design_determinant() {
        let n = 40;
        let p = 3usize;
        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::from_iter((0..n).map(|i| 0.5 + 0.01 * (i as f64)));

        // Mass-weighted coordinate mean and spread under wᵢ = a_k².
        let mut w_sum = 0.0;
        let mut t_bar = 0.0;
        for i in 0..n {
            let w = assign[i] * assign[i];
            w_sum += w;
            t_bar += w * coords[i];
        }
        t_bar /= w_sum;
        let mut s_tt = 0.0;
        for i in 0..n {
            let dt = coords[i] - t_bar;
            s_tt += assign[i] * assign[i] * dt * dt;
        }

        // Curved design columns: [1, (t − t̄)]. Its weighted Gram is exactly
        // diag(w_sum, s_tt), because the cross term Σ w·(t−t̄) vanishes by
        // construction. Hence log|ΦᵀWΦ| = log(w_sum) + log(s_tt).
        let mut phi = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            phi[[i, 0]] = 1.0;
            phi[[i, 1]] = coords[i] - t_bar;
        }
        let got = curved_design_gram_logdet(phi.view(), assign.view(), p)
            .expect("non-degenerate curved design has a determinant");
        let want = (p as f64) * (w_sum.ln() + s_tt.ln());
        assert!(
            (got - want).abs() < 1e-9,
            "curved Gram logdet must be the real p·log|ΦᵀWΦ| = {want}, got {got}"
        );

        // A rank-deficient design (two identical constant columns) has weighted
        // Gram w_sum·[[1, 1], [1, 1]]. Its eigenvalues are 2·w_sum (along (1, 1))
        // and 0 (the null direction (1, −1)). Dropping the null direction leaves
        // the pseudo-determinant 2·w_sum, so the expected value is p·log(2·w_sum),
        // NOT a 2-column proxy.
        let mut phi_dup = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            phi_dup[[i, 0]] = 1.0;
            phi_dup[[i, 1]] = 1.0;
        }
        let got_dup = curved_design_gram_logdet(phi_dup.view(), assign.view(), p)
            .expect("rank-1 design still has a positive determinant");
        let want_dup = (p as f64) * (2.0 * w_sum).ln();
        assert!(
            (got_dup - want_dup).abs() < 1e-9,
            "rank-deficient curved Gram must report only its positive direction \
             (p·log(2·w_sum) = {want_dup}), got {got_dup}"
        );
    }

    /// #1051 NESTED MIN: the curved arm re-fit on the residual matches or beats the
    /// best straight line, i.e. `curved_refit_rss ≤ best_line_rss` up to solver
    /// tolerance, on a basis whose span contains the straight lane (here
    /// `Φ = [1, t, t²]`).
    ///
    /// A genuinely curved (quadratic) signal is STRICTLY preferred by the curved
    /// arm. An exactly-straight signal ties near zero, so the cheaper linear lane
    /// wins downstream. The realized-curve heuristic could not establish this
    /// property: comparing a possibly-collapsed realized curve against
    /// min-over-lines did NOT guarantee `curved ≤ linear`, whereas re-fitting the
    /// curved decoder does.
    #[test]
    fn refit_curved_rss_matches_or_beats_best_line_nested() {
        let n = 40usize;
        let p = 2usize;
        let coords =
            Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        // Φ = [1, t, t²]: its column span contains the straight lane [1, t], so the
        // decoder-only curved refit is a proper superset of the line fit.
        let mut phi = Array2::<f64>::zeros((n, 3));
        for i in 0..n {
            phi[[i, 0]] = 1.0;
            phi[[i, 1]] = coords[i];
            phi[[i, 2]] = coords[i] * coords[i];
        }
        // Best mass-weighted line RSS on `y` (assign = 1): fit design [1, (t − t̄)].
        let t_bar = coords.iter().sum::<f64>() / n as f64;
        let best_line_rss = |y: &Array2<f64>| -> f64 {
            let mut design = Array2::<f64>::zeros((n, 2));
            for i in 0..n {
                design[[i, 0]] = 1.0;
                design[[i, 1]] = coords[i] - t_bar;
            }
            let b = solve_design_least_squares(design.view(), y.view()).unwrap();
            let pred = design.dot(&b);
            let mut rss = 0.0_f64;
            for i in 0..n {
                for j in 0..p {
                    let r = y[[i, j]] - pred[[i, j]];
                    rss += r * r;
                }
            }
            rss
        };
        // Three signals: (a) exactly straight, (b) a quadratic curve, and (c) a line
        // plus a small deterministic sinusoidal wiggle (not random noise, so the test
        // is reproducible).
        let mut y_line = Array2::<f64>::zeros((n, p));
        let mut y_curve = Array2::<f64>::zeros((n, p));
        let mut y_noisy = Array2::<f64>::zeros((n, p));
        for i in 0..n {
            let t = coords[i];
            y_line[[i, 0]] = 0.4 + 0.6 * t;
            y_line[[i, 1]] = -0.2 + 1.1 * t;
            y_curve[[i, 0]] = t * t;
            y_curve[[i, 1]] = 0.5 - t * t;
            y_noisy[[i, 0]] = 0.3 + 0.7 * t + 0.05 * (3.0 * t).sin();
            y_noisy[[i, 1]] = 0.9 * t;
        }
        // Score each signal exactly once, then reuse the (curved, line) RSS pairs
        // below instead of re-solving both fits.
        let mut scored = Vec::with_capacity(3);
        for y in [&y_line, &y_curve, &y_noisy] {
            let curved = curved_refit_rss(phi.view(), assign.view(), y.view())
                .expect("non-degenerate refit");
            let line = best_line_rss(y);
            assert!(
                curved <= line + 1e-9 * (1.0 + line),
                "nested dominance: refit-curved RSS {curved} must be ≤ best-line RSS {line}"
            );
            scored.push((curved, line));
        }
        let (curved_l, line_l) = scored[0];
        let (curved_c, line_c) = scored[1];
        // A genuinely curved (quadratic) signal is STRICTLY preferred by the curve.
        assert!(
            curved_c < 0.5 * line_c,
            "a quadratic signal must be far better fit by the curve ({curved_c}) than \
             by the best line ({line_c})"
        );
        // A straight signal ties near zero and collapses to the cheaper linear lane.
        assert!(
            curved_l < 1e-18 && line_l < 1e-18,
            "an exactly-straight signal ties the two arms near zero \
             (curved {curved_l}, line {line_l})"
        );
    }

    /// #1026 item-2 GLOBAL CROSS-TERM guard. Two collapses have CORRELATED
    /// (parallel) reconstruction-change errors. Each per-atom loss sits under
    /// tolerance, and so does their SUM `Σ Δ_k` (the OLD aggregate
    /// `Σ max(Δ_k,0)` guard would ACCEPT). Yet the cross term `2⟨δ_1,δ_2⟩`
    /// pushes the TRUE global loss over tolerance, so the global guard must roll
    /// back the least-justified collapse.
    ///
    /// An ORTHOGONAL control (disjoint support) with the same per-atom losses is
    /// accepted, isolating the cross term as the cause.
    #[test]
    fn global_guard_rejects_correlated_collapses_the_aggregate_would_accept() {
        let n = 4usize;
        let p = 1usize;
        // R0 = 0 ⇒ Δ_k = ‖δ_k‖² and ΔRSS(S) = ‖Σ_S δ_k‖² exactly.
        let r0 = Array2::<f64>::zeros((n, p));
        // Parallel δ_1 = δ_2 = 1 on all rows: ‖δ_k‖² = 4 each, Σ Δ_k = 8,
        // ΔRSS_global = ‖δ_1+δ_2‖² = 16.
        let mut d1 = Array2::<f64>::zeros((n, p));
        let mut d2 = Array2::<f64>::zeros((n, p));
        for i in 0..n {
            d1[[i, 0]] = 1.0;
            d2[[i, 0]] = 1.0;
        }
        // tol = 10: each per-atom loss 4 ≤ 10, Σ Δ_k = 8 ≤ 10 (aggregate accepts),
        // but global 16 > 10 (cross term rejects).
        let global_tol = 10.0_f64;
        let eligible = vec![(0usize, 0.1_f64, &d1), (1usize, 0.2_f64, &d2)];
        let forced: Vec<&Array2<f64>> = Vec::new();
        let reverted = global_collapse_rollback(r0.view(), &eligible, &forced, global_tol);
        assert_eq!(
            reverted,
            vec![1usize],
            "the global cross-term guard must roll back the least-justified collapse \
             (largest margin = slot 1) that the summed-loss aggregate (8 ≤ 10) accepts"
        );

        // ORTHOGONAL control: same per-atom losses (δ on disjoint rows) ⇒ cross term
        // 0 ⇒ ΔRSS_global = 8 ≤ 10 ⇒ no rollback. Isolates the cross term as the cause.
        let mut o1 = Array2::<f64>::zeros((n, p));
        let mut o2 = Array2::<f64>::zeros((n, p));
        o1[[0, 0]] = 2.0; // ‖o1‖² = 4
        o2[[2, 0]] = 2.0; // ‖o2‖² = 4, disjoint support
        let eligible_o = vec![(0usize, 0.1_f64, &o1), (1usize, 0.2_f64, &o2)];
        let reverted_o = global_collapse_rollback(r0.view(), &eligible_o, &forced, global_tol);
        assert!(
            reverted_o.is_empty(),
            "uncorrelated collapses (cross term 0, global loss 8 ≤ 10) must be accepted"
        );
    }

    /// A degenerate (single-point-mass) coordinate has no slope direction, so it is
    /// refused rather than adjudicated on a fabricated deviance.
    #[test]
    fn degenerate_coordinate_is_refused() {
        let n = 5;
        let coords = Array1::<f64>::from_elem(n, 0.5); // no spread
        let assign = Array1::<f64>::ones(n);
        let decoded = Array2::<f64>::zeros((n, 2));
        let data = Array2::<f64>::zeros((n, 2));
        assert!(
            build_atom_candidates(
                coords.view(),
                assign.view(),
                decoded.view(),
                data.view(),
                6,
                None,
                Some(0.0)
            )
            .is_none(),
            "a degenerate coordinate span must be refused"
        );
    }
}