Skip to main content

gam_sae/
hybrid_split.rs

//! #1026 — load-bearing curved-vs-linear hybrid split for the fitted SAE
//! dictionary.
//!
//! The selection machinery ([`gam_solve::evidence::select_hybrid_atom`],
//! [`gam_solve::evidence::select_hybrid_split`]) and the per-atom
//! integration helper
//! ([`crate::assignment::select_hybrid_atom_parameterization`]) are
//! correct and tested, but until now they were called nowhere in the fitter: the
//! post-fit pass only *logged* each `d = 1` atom's fitted turning `Θ`. This
//! module makes the split LOAD-BEARING by building, per fitted `d = 1` atom, the
//! two already-realized candidates and adjudicating them by the common
//! evidence criterion.
//!
//! ## The common-evidence comparison on the data (#1202)
//!
//! Both candidates are scored against the SAME data: the portion of the
//! response matrix the atom is responsible for reconstructing, namely its
//! **leave-this-atom-out residual**
//!
//! ```text
//! y_resp[i] = target[i] − ( Σ_j a[i,j]·γ_j(t_{ij}) − a[i,k]·γ_k(t_{ik}) )
//!           = target[i] − without_k[i],
//! ```
//!
//! i.e. the response with every OTHER atom's contribution subtracted. Over the
//! rows assigned to atom `k` (assignment mass `a[i,k] = a_k`), the two candidates
//! predict that residual:
//!
//!   * the CURVED candidate predicts `a_k · γ_k(t)` — the atom's actual
//!     already-fitted contribution; its data-fit deviance is the weighted RSS of
//!     that contribution against `y_resp`, which is no longer zero by construction.
//!   * the LINEAR candidate predicts `a_k · (b₀ + (t − t̄)·b₁)`, the best
//!     weighted least-squares straight line fit to `y_resp` (design column
//!     scaled by the same assignment mass `a_k`), so its data-fit deviance is the
//!     weighted RSS of the best line against the SAME residual.
//!
//! Because the curved family's `Θ = 0` member reproduces exactly the linear
//! prediction `a_k·(b₀ + (t − t̄)·b₁)` on this data, linear IS the nested `Θ = 0`
//! sub-model of the curved family on common data. The per-slot evidence argmin
//! is therefore a genuine "match-or-beat" comparison: the curved candidate is
//! preferred only when its extra curvature lowers the data-fit deviance by more
//! than its extra Laplace parameter price, and the linear special case wins
//! whenever a straight line already explains the residual.
//!
//! This replaces the earlier post-hoc curve-simplification diagnostic. There,
//! both candidates targeted the atom's already-fitted decoded image `γ_k(t)`
//! rather than the response data, which gave the curved arm a free zero residual
//! against itself. That version could not nest linear in curved on common data
//! and so carried no real dominance guarantee (#1202); it is removed. The
//! comparison here re-fits nothing in the euclidean / multi-atom outer
//! continuation (broken under #1051). It scores the already-realized curved
//! contribution and the closed-form linear lane against the realized residual,
//! both on the data, with no joint Hessian or continuation spine.
52
53use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
54
55use gam_linalg::faer_ndarray::FaerEigh;
56use gam_solve::evidence::{
57    HybridAtomCandidate, HybridAtomChoice, HybridSplitSelection, select_hybrid_split,
58};
59use gam_terms::latent::LatentManifold;
60use crate::chart_canonicalization::d1_atom_fitted_turning;
61use crate::manifold::SaeManifoldAtom;
62
/// Reduced Laplace negative log-evidence of one per-atom Gaussian
/// reconstruction sub-model.
///
/// Computes `residual_objective + ½·log|H|`. There is no smoothing-penalty
/// log-determinant and no null-space correction: the design is treated as full
/// rank. This is the form [`gam_solve::evidence::laplace_evidence`] collapses to
/// for the curved-vs-linear comparison in this module.
///
/// It is written inline rather than routed through `EvidenceLogDetSource`.
/// Each candidate's `log|H|` is already a closed-form scalar of its shared
/// design, so there is no factor cache or Hessian-vector callback to assemble.
fn reduced_laplace_nle(residual_objective: f64, log_det_h: f64) -> f64 {
    let half_log_det = 0.5 * log_det_h;
    residual_objective + half_log_det
}
73
74/// Rank-aware `log|ΦᵀWΦ|_+` of the curved atom's weighted design Gram over its
75/// `M` decoder basis columns, with per-row weight `wᵢ = a_k²` (the same
76/// assignment-mass design weight the linear arm uses), summed over the
77/// eigenvalues above a relative spectral floor (#1223). This is the genuine
78/// weighted-design determinant the linear arm already reports — `log|XᵀWX|` —
79/// assembled for the curved basis so the two arms' Laplace complexity prices are
80/// computed on the SAME footing instead of pricing the curved arm with a
81/// parameter-count proxy `M·log(Σw)`.
82///
83/// Mirrors the linear arm exactly in what it does NOT include: no smoothing-
84/// penalty `λS` normalizer (the linear arm's Gram is the bare data Gram
85/// `diag(w_sum, s_tt)` too), so the comparison stays symmetric. The Gram is the
86/// design's outer Gram over its basis columns; it is identical across the `p`
87/// output channels (every channel shares the design `Φ`), so the per-channel
88/// `log|G|_+` is multiplied by `p` — matching the linear arm's `p·(…)` form.
89///
90/// `phi` is the curved design `Φ(t)` evaluated on the atom's assigned rows
91/// (`n × M`); `assign` is the per-row assignment mass `a_k` (NOT squared).
92/// Returns `None` when `Φ` is missing rows, the Gram is non-finite, or it has no
93/// positive eigenvalues (a fully rank-deficient design carries no determinant);
94/// the caller then falls back to the parameter-count proxy rather than fabricate
95/// a determinant.
96fn curved_design_gram_logdet(
97    phi: ArrayView2<'_, f64>,
98    assign: ArrayView1<'_, f64>,
99    p: usize,
100) -> Option<f64> {
101    let n = phi.nrows();
102    let m = phi.ncols();
103    if m == 0 || assign.len() != n || n == 0 {
104        return None;
105    }
106    // G = Φᵀ diag(a²) Φ  (M×M, symmetric PSD).
107    let mut gram = Array2::<f64>::zeros((m, m));
108    for i in 0..n {
109        let w = assign[i] * assign[i];
110        if !(w.is_finite() && w >= 0.0) {
111            return None;
112        }
113        if w == 0.0 {
114            continue;
115        }
116        let row = phi.row(i);
117        for a in 0..m {
118            let wa = w * row[a];
119            for b in a..m {
120                gram[[a, b]] += wa * row[b];
121            }
122        }
123    }
124    // Symmetrize the lower triangle (we only filled the upper).
125    for a in 0..m {
126        for b in 0..a {
127            gram[[a, b]] = gram[[b, a]];
128        }
129    }
130    if gram.iter().any(|v| !v.is_finite()) {
131        return None;
132    }
133    let (vals, _vecs) = gram.eigh(faer::Side::Lower).ok()?;
134    // Rank-aware log-determinant: sum log of eigenvalues above a relative floor
135    // tied to the largest eigenvalue, dropping the numerically-null directions
136    // (the curved design's null space, analogous to the linear arm's full-rank
137    // 2-D Gram). A design with no positive eigenvalue carries no determinant.
138    let lambda_max = vals.iter().cloned().fold(0.0_f64, f64::max);
139    if !(lambda_max > 0.0 && lambda_max.is_finite()) {
140        return None;
141    }
142    let floor = lambda_max * 1e-12;
143    let mut log_det = 0.0_f64;
144    let mut rank = 0usize;
145    for &lambda in vals.iter() {
146        if lambda > floor {
147            log_det += lambda.ln();
148            rank += 1;
149        }
150    }
151    if rank == 0 || !log_det.is_finite() {
152        return None;
153    }
154    // The design Gram is shared across the p output channels.
155    Some((p as f64) * log_det)
156}
157
158/// The fitted straight sub-model `γ̃(t) = b₀ + (t − t̄)·b₁` of one `d = 1` atom:
159/// the exact assignment-mass-weighted least-squares line fit to the atom's
160/// leave-this-atom-out RESPONSE residual `y_resp` over its assigned rows (the
161/// curved family's nested `Θ = 0` sub-model on common data, #1202). Carried on a
162/// verdict that selects LINEAR so the collapsed reconstruction can replace the
163/// curved decoded row with this straight image at any coordinate WITHOUT
164/// re-entering the (broken, #1051) outer fit — the coefficients are already
165/// realized inside the adjudication.
166#[derive(Clone, Debug)]
167pub struct AtomLinearImage {
168    /// The atom's slot index in the dictionary (so the collapsed assembly knows
169    /// which atom's decoded row to substitute).
170    pub atom_idx: usize,
171    /// The mass-weighted coordinate mean `t̄` the line is centered on.
172    pub t_bar: f64,
173    /// Per-output-channel centered intercept `b₀ = γ̄` at `t̄` (length `p`).
174    pub b0: Array1<f64>,
175    /// Per-output-channel slope `b₁` (length `p`).
176    pub b1: Array1<f64>,
177    /// #1026 collapse-rescue per-row coordinates. `None` for the ordinary path:
178    /// the line is evaluated at the atom's OWN realized coordinate `t`. `Some(u)`
179    /// only when the atom's circle codes had collapsed to a single point
180    /// (`s_tt ≈ 0`) so its own coordinate carries no spread — then the line is fit
181    /// against, and reconstruct evaluates it at, these FRESH per-row codes `uᵢ`
182    /// (the projection of the leave-this-atom-out residual onto its top
183    /// mass-weighted output direction). This is what lets a circle atom that the
184    /// joint fit drove into the degenerate "chord-through-the-arc" fixed point
185    /// still reconstruct its residual's best linear direction — recovering the
186    /// linear-tail reach the hybrid-split was designed to deliver — instead of a
187    /// constant (its collapsed curve), which is the real-OLMo rank-1 co-collapse
188    /// (held-out EV ≈ 0.13 vs the linear ceiling ≈ 0.74). Length `n` (one per
189    /// reconstructed row); unassigned rows are gated to zero by `a_k` anyway.
190    pub row_codes: Option<Array1<f64>>,
191}
192
193impl AtomLinearImage {
194    /// Evaluate the straight sub-model `b₀ + (t − t̄)·b₁` into `out` (length `p`).
195    pub fn fill_row(&self, t: f64, out: &mut [f64]) {
196        let dt = t - self.t_bar;
197        for (j, slot) in out.iter_mut().enumerate() {
198            *slot = self.b0[j] + dt * self.b1[j];
199        }
200    }
201
202    /// The coordinate at which row `row` should evaluate this image: the
203    /// collapse-rescue fresh code `uᵢ` when present (#1026), else the atom's own
204    /// realized coordinate `own_t` passed by the caller.
205    pub fn coordinate_for_row(&self, row: usize, own_t: f64) -> f64 {
206        match &self.row_codes {
207            Some(u) if row < u.len() => u[row],
208            _ => own_t,
209        }
210    }
211}
212
213/// One fitted `d = 1` atom's hybrid-split verdict, surfaced in the model output.
214#[derive(Clone, Debug)]
215pub struct AtomHybridVerdict {
216    /// The atom's name (slot identity in the dictionary).
217    pub atom_name: String,
218    /// The evidence-selected parameterization choice for this slot.
219    pub choice: HybridAtomChoice,
220    /// `true` iff the slot kept the CURVED parameterization (the fitted atom);
221    /// `false` iff it yielded to the LINEAR special case (the straight tail).
222    pub kept_curved: bool,
223    /// The atom's fitted turning `Θ = ∫|κ| ds` (radians), the novel geometric
224    /// quantity #1026 pairs against reconstruction EV: `Θ ≈ 0` is a linear-tail
225    /// direction wearing a curved basis, `Θ ≈ 2π` is a full curved loop. `None`
226    /// iff the evaluator has no analytic second jet or the curve is degenerate.
227    /// Captured here (not just logged) so the EV-vs-Θ frontier is queryable
228    /// structured data on the persisted report rather than a transient log line.
229    pub fitted_turning: Option<f64>,
230    /// The atom's training leave-one-atom-out explained-variance contribution
231    /// `ΔEV_k = EV(full) − EV(full∖{k})` — how much reconstruction EV this single
232    /// atom earns. Paired with [`Self::fitted_turning`] this is the `(Θ, ΔEV)`
233    /// point the #1026 frontier reports: a `Θ ≈ 0` atom with large `ΔEV` is a
234    /// genuine linear-tail direction; a high-`Θ` atom with large `ΔEV` is a
235    /// genuine curved family. `None` iff the caller did not supply LOAO EV.
236    pub train_loao_delta_ev: Option<f64>,
237    /// The fitted straight sub-model for this slot, present iff the verdict
238    /// selected LINEAR (`kept_curved == false`). The collapsed reconstruction
239    /// substitutes this for the atom's curved decoded image, making the verdict
240    /// load-bearing on the reconstruction rather than a passive diagnostic.
241    pub linear_image: Option<AtomLinearImage>,
242}
243
244/// The whole dictionary's hybrid-split report: one verdict per eligible `d = 1`
245/// atom, plus the dictionary-level aggregates the EV-vs-Θ frontier reports
246/// against.
247#[derive(Clone, Debug)]
248pub struct SaeHybridSplitReport {
249    /// One adjudicated verdict per eligible `d = 1` atom, in slot order. Atoms
250    /// that are not eligible (wrong dim, no evaluator, mid-homotopy) are absent
251    /// — they carry no curved/linear adjudication.
252    pub verdicts: Vec<AtomHybridVerdict>,
253    /// The dictionary-level rolled-up selection (summed NLE, total parameters,
254    /// curved/linear counts) over the eligible atoms.
255    pub selection: HybridSplitSelection,
256}
257
/// Minimum assigned-row count for a `d = 1` atom to be adjudicated.
///
/// A straight line has two parameters (intercept and slope). With fewer than
/// three rows no residual degree of freedom remains, so the linear candidate's
/// deviance is undefined. Such atoms are skipped and left out of the report
/// rather than adjudicated on a fabricated deviance.
const MIN_ROWS_FOR_LINEAR_FIT: usize = 3;
263
/// #1610/#1026: tolerance of the EV-preservation gate on a linear collapse.
///
/// A `d = 1` slot may fall back to its linear tail only if doing so loses at
/// most this fraction of the target's total (centered) variance in
/// full-reconstruction explained variance.
///
/// Why a gate is needed: the evidence argmin trades data fit against the curved
/// arm's parameter price in NLE units. On a small or low-amplitude fixture that
/// trade can pick the cheaper line even when the curve carries real
/// reconstruction signal, and the collapse then DROPS EV (the observed
/// 1.0 → 0.748 over-collapse).
///
/// The gate guards that regressed quantity directly. Collapsing atom `k` raises
/// the full reconstruction SSR by exactly `linear_rss − curved_rss`, while the
/// total target variance `SST_full` is fixed. The EV cost of a collapse is
/// therefore `(linear_rss − curved_rss) / SST_full`, and a collapse is vetoed
/// when that cost exceeds this constant.
///
/// Because EV is dimensionless in `[0, 1]`, an absolute tolerance is already
/// scale-invariant. It does not revive the scale-incommensurability that sank
/// the evidence-reformulation attempt.
///
/// Choice of value: a truly straight atom (its curve is its own line) loses
/// `≈ 0` EV and still collapses losslessly; a curve doing real reconstruction
/// work (loss `≫ 1e-3`) is kept. `1e-3`, i.e. 0.1% of total variance, sits well
/// above f64 round-off on an exact-line collapse and well below any material EV
/// loss, so it separates the two regimes without tuning.
const SAE_HYBRID_COLLAPSE_MAX_EV_LOSS: f64 = 1.0e-3;
285
286/// The full-reconstruction SSR INCREASE from collapsing one `d = 1` atom to its
287/// fitted straight sub-model: `linear_rss − curved_rss`, where both arms are
288/// scored against the atom's leave-this-atom-out response residual `y_resp`
289/// (`target_resid`) exactly as [`build_atom_candidates`] scores them. Because the
290/// full reconstruction differs from the collapsed one ONLY in this atom's
291/// contribution (`a_k·γ_k` → `a_k·line`) on its assigned rows, this scalar is the
292/// exact amount the full reconstruction's SSR rises when the slot is collapsed;
293/// dividing by the fixed total target variance gives the full-EV loss the
294/// EV-preservation gate keys on. Positive ⇒ the curve out-fits its straight
295/// projection (collapsing hurts); `≤ 0` ⇒ the line is at least as good (collapse
296/// is lossless or improving, never gated). Mirrors the `curved_rss` / `linear_rss`
297/// accumulation in [`build_atom_candidates`] bit-for-bit so the gate and the
298/// evidence comparison see the same residuals.
299fn collapse_ssr_increase(
300    coords: ArrayView1<'_, f64>,
301    assign: ArrayView1<'_, f64>,
302    decoded: ArrayView2<'_, f64>,
303    target_resid: ArrayView2<'_, f64>,
304    t_bar: f64,
305    b0: &Array1<f64>,
306    b1: &Array1<f64>,
307) -> f64 {
308    let n = assign.len();
309    let p = target_resid.ncols();
310    let mut curved_rss = 0.0_f64;
311    let mut linear_rss = 0.0_f64;
312    for i in 0..n {
313        let a = assign[i];
314        let dt = coords[i] - t_bar;
315        for j in 0..p {
316            let y = target_resid[[i, j]];
317            let r_curved = y - a * decoded[[i, j]];
318            curved_rss += r_curved * r_curved;
319            let r_linear = y - a * (b0[j] + dt * b1[j]);
320            linear_rss += r_linear * r_linear;
321        }
322    }
323    linear_rss - curved_rss
324}
325
326/// Build the curved + linear candidates for ONE fitted `d = 1` atom and return
327/// them as `(linear, curved, (t̄, b₀, b₁))`, or `None` if the atom cannot present
328/// an honest pair (too few rows, degenerate coordinate span, or non-finite
329/// numbers). Both candidates are scored against the SAME data — the atom's
330/// leave-this-atom-out response residual `y_resp` — so the comparison is a
331/// genuine common-evidence one with linear nested as the curved family's `Θ = 0`
332/// sub-model (#1202).
333///
334/// Inputs over the atom's assigned rows:
335///   * `coords` — the fitted on-atom coordinate `t`.
336///   * `assign` — the per-row assignment mass `a_k` (NOT squared; this routine
337///     squares it where the design weight `a_k²` is needed).
338///   * `decoded` — the atom's fitted decoded image `γ_k(t) = Φ(t) B_k` (`p` cols),
339///     whose mass-scaled value `a_k·γ_k` is the curved candidate's PREDICTION.
340///   * `target_resid` — the atom's leave-this-atom-out response residual `y_resp`
341///     (`p` cols): the response with every OTHER atom's contribution removed.
342///     This is the data both candidates fit.
343///
344/// The curved candidate's data-fit deviance is `½ Σ ‖y_resp − a_k·γ_k‖²` (the
345/// plain reconstruction SSE, matching the joint loss — the mass already lives in
346/// the prediction `a_k·γ_k`); the linear candidate fits the best mass-weighted
347/// straight line to `y_resp` and pays `½ Σ ‖y_resp − a_k·(b₀ + (t − t̄)·b₁)‖²`.
348/// Because the curved family's `Θ = 0` member reproduces the linear prediction
349/// exactly, linear is the nested sub-model and the argmin is the honest
350/// match-or-beat criterion.
351fn build_atom_candidates(
352    coords: ArrayView1<'_, f64>,
353    assign: ArrayView1<'_, f64>,
354    decoded: ArrayView2<'_, f64>,
355    target_resid: ArrayView2<'_, f64>,
356    curved_num_params: usize,
357    curved_phi: Option<ArrayView2<'_, f64>>,
358    fitted_turning: Option<f64>,
359) -> Option<(
360    HybridAtomCandidate,
361    HybridAtomCandidate,
362    (f64, Array1<f64>, Array1<f64>),
363)> {
364    let n = coords.len();
365    let p = decoded.ncols();
366    if n < MIN_ROWS_FOR_LINEAR_FIT
367        || decoded.nrows() != n
368        || assign.len() != n
369        || target_resid.nrows() != n
370        || target_resid.ncols() != p
371        || p == 0
372    {
373        return None;
374    }
375
376    // The LINEAR candidate fits `a_k·(b₀ + (t − t̄)·b₁)` to the residual `y_resp`,
377    // so the natural design column is `a_k·[1, (t − t̄)]` and the per-row Gram
378    // weight is `wᵢ = a_k²`. We accumulate the mass-weighted coordinate mean `t̄`
379    // and spread `s_tt` under that weight; a row that barely belongs to the atom
380    // (`a_k ≈ 0`) contributes ≈ nothing, exactly as in the joint loss.
381    let mut w_sum = 0.0_f64;
382    let mut t_bar = 0.0_f64;
383    for i in 0..n {
384        let a = assign[i];
385        if !(a.is_finite() && a >= 0.0) {
386            return None;
387        }
388        let w = a * a;
389        w_sum += w;
390        t_bar += w * coords[i];
391    }
392    if !(w_sum > 0.0) {
393        return None;
394    }
395    t_bar /= w_sum;
396
397    // Weighted Σ wᵢ·(t − t̄)² with `wᵢ = a_k²` — the coordinate spread under the
398    // line's design weight. A degenerate (single-point mass) coordinate has no
399    // slope direction; refuse rather than divide by ~0.
400    let mut s_tt = 0.0_f64;
401    for i in 0..n {
402        let dt = coords[i] - t_bar;
403        s_tt += assign[i] * assign[i] * dt * dt;
404    }
405    if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
406        return None;
407    }
408
409    // Per-output-channel mass-weighted least squares for the line fit to the
410    // RESIDUAL `y_resp`. Minimizing `Σᵢ ‖y_resp[i] − a_k·(b₀ + (t − t̄)·b₁)‖²` in
411    // the centered basis has the diagonal normal equations
412    //   b₀[j] = (Σ a_k·y_resp[i,j]) / w_sum,   (recall the design intercept is a_k)
413    //   b₁[j] = (Σ a_k·(t − t̄)·y_resp[i,j]) / s_tt.
414    let mut b0 = Array1::<f64>::zeros(p);
415    let mut b1 = Array1::<f64>::zeros(p);
416    for j in 0..p {
417        let mut s_1y = 0.0_f64;
418        let mut s_ty = 0.0_f64;
419        for i in 0..n {
420            let a = assign[i];
421            let dt = coords[i] - t_bar;
422            let y = target_resid[[i, j]];
423            s_1y += a * y;
424            s_ty += a * dt * y;
425        }
426        b0[j] = s_1y / w_sum;
427        b1[j] = s_ty / s_tt;
428    }
429
430    // Data-fit residual sums of squares of BOTH candidates against `y_resp`, the
431    // common data. The curved candidate predicts the atom's actual mass-scaled
432    // contribution `a_k·γ_k`; the linear candidate predicts the best line
433    // `a_k·(b₀ + (t − t̄)·b₁)`. These are no longer trivially zero for the curved
434    // arm — both are real misfits to the response residual, so the argmin is a
435    // genuine common-evidence comparison (#1202).
436    let mut curved_rss = 0.0_f64;
437    let mut linear_rss = 0.0_f64;
438    for i in 0..n {
439        let a = assign[i];
440        let dt = coords[i] - t_bar;
441        for j in 0..p {
442            let y = target_resid[[i, j]];
443            let r_curved = y - a * decoded[[i, j]];
444            curved_rss += r_curved * r_curved;
445            let r_linear = y - a * (b0[j] + dt * b1[j]);
446            linear_rss += r_linear * r_linear;
447        }
448    }
449
450    // Gaussian-reconstruction deviance: the residual objective `½ RSS` the
451    // Laplace normalizer is added to. The curved arm pays `½·curved_rss` (how
452    // well its fitted curve explains the residual) plus its larger `M·p`
453    // parameter price; the linear arm pays `½·linear_rss` plus a `2·p` price.
454    // Because the curved family's `Θ = 0` member equals the linear prediction,
455    // `curved_rss ≤ linear_rss` whenever the fitted curve is at least as good a
456    // residual fit as its own straight projection — the match-or-beat floor — and
457    // the argmin trades that data-fit gain against the curvature parameter price.
458    let curved_residual_objective = 0.5 * curved_rss;
459    let linear_residual_objective = 0.5 * linear_rss;
460
461    // Linear candidate parameter price: intercept + slope per output channel.
462    let linear_num_params = 2 * p;
463
464    // Laplace logdet of the (weighted) design Gram for the LINEAR candidate.
465    //
466    // For the centered weighted line fit `a_k·(b₀ + (t − t̄)·b₁)`, the per-output-
467    // channel design column is `a_k·[1, (t − t̄)]`, whose Gram is DIAGONAL in the
468    // centered basis: `diag(Σ a_k², Σ a_k²(t − t̄)²) = diag(w_sum, s_tt)`. Its log
469    // determinant is `log(w_sum) + log(s_tt)` PER output channel, i.e.
470    //
471    //     log|H_linear| = p · ( log(w_sum) + log(s_tt) ).
472    //
473    // The `log(s_tt)` term is the slope direction's information: a line through a
474    // wide, heavily-massed coordinate spread is better-determined than one through
475    // a tiny spread, and the Laplace evidence must reflect that (#1203).
476    //
477    // The curved arm's Laplace determinant is now the genuine weighted-design
478    // Gram log-determinant `p · log|ΦᵀWΦ|_+` (#1223): the SAME quantity the
479    // linear arm reports (`p·(log w_sum + log s_tt) = p·log|XᵀWX|`), assembled
480    // from the curved basis `Φ` on the atom's assigned rows under the same
481    // assignment-mass design weight `wᵢ = a_k²`. Both arms omit the smoothing
482    // `λS` normalizer, so the complexity price is computed on a symmetric
483    // footing — no parameter-count proxy. Only when `Φ` is unavailable (the
484    // caller could not evaluate the basis) or its Gram is fully rank-deficient do
485    // we fall back to the historical `curved_num_params · log(w_sum)` proxy, so
486    // the comparison degrades gracefully rather than fabricating a determinant.
487    if !(w_sum > 0.0 && w_sum.is_finite() && s_tt.is_finite()) {
488        return None;
489    }
490    let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
491    let curved_log_det_h = curved_phi
492        .and_then(|phi| {
493            if phi.nrows() == n {
494                curved_design_gram_logdet(phi, assign, p)
495            } else {
496                None
497            }
498        })
499        .unwrap_or_else(|| (curved_num_params as f64) * w_sum.ln());
500
501    // Reduced Laplace NLE `residual_objective + ½ log|H|`. Both omit an explicit
502    // smoothing-penalty logdet (the intrinsic smoothness penalty is
503    // reparameterization-invariant and identical in expectation across the two
504    // parameterizations of the same image).
505    let linear_nle = reduced_laplace_nle(linear_residual_objective, linear_log_det_h);
506    let curved_nle = reduced_laplace_nle(curved_residual_objective, curved_log_det_h);
507    if !(linear_nle.is_finite() && curved_nle.is_finite()) {
508        return None;
509    }
510
511    let linear = HybridAtomCandidate::linear(linear_nle, linear_num_params);
512    let curved = HybridAtomCandidate::curved(1, curved_nle, curved_num_params, fitted_turning);
513    Some((linear, curved, (t_bar, b0, b1)))
514}
515
516/// #1026 collapse rescue. When a `d = 1` atom's own coordinate has collapsed to a
517/// single point (`build_atom_candidates` refuses because `s_tt ≈ 0`), the atom is
518/// stuck in the degenerate "chord-through-the-arc" fixed point and its curved
519/// decode is a constant — the rank-1 dictionary co-collapse (real-OLMo held-out EV
520/// ≈ 0.13 vs the rank-K linear ceiling ≈ 0.74). The hybrid-split was DESIGNED to
521/// let such a linear-tail atom decode as a straight line; the only reason it can't
522/// here is that its own codes carry no spread to fit a slope against.
523///
524/// Recover FRESH per-row codes from the data instead: `uᵢ = yᵢ·v`, the projection
525/// of the leave-this-atom-out residual onto its top mass-weighted output direction
526/// `v` (the rank-1 of `Σᵢ wᵢ yᵢyᵢᵀ`, `wᵢ = a_k²` — the SAME design weight the line
527/// fit uses). These codes span the residual's strongest linear axis by
528/// construction, so the straight image `b₀ + (uᵢ − ū)·b₁` fit against them
529/// reconstructs that axis at LINEAR quality — exactly the linear-tail reach the
530/// split owes. Returns the forced-LINEAR candidate plus the image carrying `uᵢ`,
531/// or `None` when the residual itself carries no usable direction (a genuine zero
532/// atom the mass/decoder guards own).
533fn build_collapse_rescue_linear_image(
534    atom_idx: usize,
535    assign: ArrayView1<'_, f64>,
536    target_resid: ArrayView2<'_, f64>,
537) -> Option<(HybridAtomCandidate, AtomLinearImage)> {
538    let n = assign.len();
539    let p = target_resid.ncols();
540    if n < MIN_ROWS_FOR_LINEAR_FIT || target_resid.nrows() != n || p == 0 {
541        return None;
542    }
543    let mut w_sum = 0.0_f64;
544    for i in 0..n {
545        let a = assign[i];
546        if !(a.is_finite() && a >= 0.0) {
547            return None;
548        }
549        w_sum += a * a;
550    }
551    if !(w_sum > 0.0) {
552        return None;
553    }
554    // Top mass-weighted output direction `v` of the residual via power iteration on
555    // `M = Σᵢ wᵢ yᵢyᵢᵀ` (p×p, never materialized): `v ← normalize(Σᵢ wᵢ yᵢ (yᵢ·v))`.
556    // Seed from the per-channel weighted energy so a rank-1 residual converges in
557    // one step and the seed is deterministic (no RNG).
558    let mut v = Array1::<f64>::zeros(p);
559    for j in 0..p {
560        let mut e = 0.0_f64;
561        for i in 0..n {
562            let a = assign[i];
563            let y = target_resid[[i, j]];
564            e += a * a * y * y;
565        }
566        v[j] = e;
567    }
568    let mut vnorm = v.dot(&v).sqrt();
569    if !(vnorm > 0.0) {
570        return None;
571    }
572    v.mapv_inplace(|x| x / vnorm);
573    for _ in 0..32 {
574        let mut mv = Array1::<f64>::zeros(p);
575        for i in 0..n {
576            let a = assign[i];
577            let w = a * a;
578            let mut proj = 0.0_f64;
579            for j in 0..p {
580                proj += target_resid[[i, j]] * v[j];
581            }
582            let wp = w * proj;
583            for j in 0..p {
584                mv[j] += wp * target_resid[[i, j]];
585            }
586        }
587        vnorm = mv.dot(&mv).sqrt();
588        if !(vnorm > 0.0) {
589            return None;
590        }
591        mv.mapv_inplace(|x| x / vnorm);
592        let cos = mv.dot(&v).abs();
593        v = mv;
594        if cos > 1.0 - 1e-12 {
595            break;
596        }
597    }
598    // Fresh per-row codes `uᵢ = yᵢ·v` and the weighted line fit against them.
599    let mut u = Array1::<f64>::zeros(n);
600    let mut t_bar = 0.0_f64;
601    for i in 0..n {
602        let mut proj = 0.0_f64;
603        for j in 0..p {
604            proj += target_resid[[i, j]] * v[j];
605        }
606        u[i] = proj;
607        t_bar += assign[i] * assign[i] * proj;
608    }
609    t_bar /= w_sum;
610    let mut s_tt = 0.0_f64;
611    for i in 0..n {
612        let dt = u[i] - t_bar;
613        s_tt += assign[i] * assign[i] * dt * dt;
614    }
615    if !(s_tt > 1e-12 * (1.0 + t_bar * t_bar)) {
616        return None;
617    }
618    let mut b0 = Array1::<f64>::zeros(p);
619    let mut b1 = Array1::<f64>::zeros(p);
620    let mut linear_rss = 0.0_f64;
621    for j in 0..p {
622        let mut s_1y = 0.0_f64;
623        let mut s_ty = 0.0_f64;
624        for i in 0..n {
625            let a = assign[i];
626            let dt = u[i] - t_bar;
627            let y = target_resid[[i, j]];
628            s_1y += a * y;
629            s_ty += a * dt * y;
630        }
631        b0[j] = s_1y / w_sum;
632        b1[j] = s_ty / s_tt;
633    }
634    for i in 0..n {
635        let a = assign[i];
636        let dt = u[i] - t_bar;
637        for j in 0..p {
638            let r = target_resid[[i, j]] - a * (b0[j] + dt * b1[j]);
639            linear_rss += r * r;
640        }
641    }
642    let linear_log_det_h = (p as f64) * (w_sum.ln() + s_tt.ln());
643    let linear_nle = reduced_laplace_nle(0.5 * linear_rss, linear_log_det_h);
644    if !linear_nle.is_finite() {
645        return None;
646    }
647    let linear = HybridAtomCandidate::linear(linear_nle, 2 * p);
648    let image = AtomLinearImage {
649        atom_idx,
650        t_bar,
651        b0,
652        b1,
653        row_codes: Some(u),
654    };
655    Some((linear, image))
656}
657
658/// Assemble the per-atom candidate slots for [`select_hybrid_split`] from the
659/// fitted `d = 1` atoms, run the adjudication, and return the report.
660///
661/// `atoms` are the fitted dictionary atoms; `coords_for` yields the on-atom
662/// coordinate column for a slot, `assign_for` the per-row assignment mass `a_k`,
663/// `decoded_for` the fitted decoded image rows `γ_k`, and `target_resid_for` the
664/// atom's leave-this-atom-out response residual `y_resp` (the data both
665/// candidates are scored against, #1202). `manifold_for` yields the atom's chart
666/// manifold (a flat / Euclidean chart can present only the linear candidate,
667/// enforced inside the selector).
668///
669/// Returns `None` (no report) when no atom is eligible — there is nothing to
670/// adjudicate.
671pub fn build_hybrid_split_report<'a, C, W, D, R, M, E>(
672    atoms: &'a [SaeManifoldAtom],
673    eligible_d1: impl Iterator<Item = usize>,
674    mut coords_for: C,
675    mut assign_for: W,
676    mut decoded_for: D,
677    mut target_resid_for: R,
678    mut manifold_for: M,
679    mut delta_ev_for: E,
680    // #1026 — the full target's total (column-centered) variance `SST_full`, the
681    // fixed denominator of the EV-preservation gate. `≤ 0` / non-finite disables
682    // the gate (a degenerate, varianceless target has no EV to preserve).
683    total_centered_variance: f64,
684) -> Result<Option<SaeHybridSplitReport>, String>
685where
686    C: FnMut(usize) -> Array1<f64>,
687    W: FnMut(usize) -> Array1<f64>,
688    D: FnMut(usize) -> Array2<f64>,
689    R: FnMut(usize) -> Array2<f64>,
690    M: FnMut(usize) -> LatentManifold,
691    // The atom's held-out LOAO `ΔEV_k`, keyed by atom index. `None` when LOAO EV
692    // is unavailable (e.g. the caller has no target to measure against).
693    E: FnMut(usize) -> Option<f64>,
694{
695    let mut slots: Vec<Vec<HybridAtomCandidate>> = Vec::new();
696    let mut names: Vec<String> = Vec::new();
697    let mut manifolds: Vec<LatentManifold> = Vec::new();
698    // Per-slot fitted straight sub-model `(atom_idx, t̄, b₀, b₁)`, surfaced onto
699    // the verdict iff the slot selects LINEAR so the collapsed reconstruction can
700    // substitute it for the curved decoded image.
701    let mut linear_images: Vec<AtomLinearImage> = Vec::new();
702    // Per-slot `(Θ, ΔEV)` — the #1026 frontier point — carried onto each verdict
703    // so the geometry/EV pairing is structured report data, not a log line.
704    let mut turnings: Vec<Option<f64>> = Vec::new();
705    let mut delta_evs: Vec<Option<f64>> = Vec::new();
706
707    for atom_idx in eligible_d1 {
708        let atom = &atoms[atom_idx];
709        let coords = coords_for(atom_idx);
710        let assign = assign_for(atom_idx);
711        let decoded = decoded_for(atom_idx);
712        let target_resid = target_resid_for(atom_idx);
713        // Curved parameter price = the decoder's `M · p` coefficients.
714        let curved_num_params = atom.decoder_coefficients.len();
715        let fitted_turning = atom.basis_evaluator.as_ref().and_then(|evaluator| {
716            d1_atom_fitted_turning(
717                evaluator.as_ref(),
718                atom.decoder_coefficients.view(),
719                coords.view(),
720            )
721            .ok()
722            .flatten()
723        });
724        // Evaluate the curved design `Φ(t)` on this atom's assigned rows so the
725        // curved arm's Laplace complexity is the real weighted-design Gram
726        // log-determinant rather than a parameter-count proxy (#1223). A `d = 1`
727        // atom's coordinate column is presented as an `n × 1` design input. If
728        // the evaluator is absent or refuses, `curved_phi` stays `None` and
729        // `build_atom_candidates` falls back to the proxy.
730        let coords_col = coords
731            .view()
732            .into_shape_with_order((coords.len(), 1))
733            .ok()
734            .map(|v| v.to_owned());
735        let curved_phi = match (atom.basis_evaluator.as_ref(), coords_col.as_ref()) {
736            (Some(evaluator), Some(col)) => {
737                evaluator.evaluate(col.view()).ok().map(|(phi, _jet)| phi)
738            }
739            _ => None,
740        };
741        // A flat (Euclidean) chart cannot honestly present a curved candidate;
742        // the selector drops it. Present both for curveable charts.
743        let manifold = manifold_for(atom_idx);
744        match build_atom_candidates(
745            coords.view(),
746            assign.view(),
747            decoded.view(),
748            target_resid.view(),
749            curved_num_params,
750            curved_phi.as_ref().map(|phi| phi.view()),
751            fitted_turning,
752        ) {
753            Some((linear, curved, (t_bar, b0, b1))) => {
754                // #1026 EV-PRESERVATION gate. Collapsing this slot raises the full
755                // reconstruction SSR by `linear_rss − curved_rss`; if that is more
756                // than `SAE_HYBRID_COLLAPSE_MAX_EV_LOSS` of the fixed total target
757                // variance the collapse would DROP reconstruction EV materially, so
758                // veto it by presenting only the curved candidate (the selector must
759                // keep curved). A lossless / improving collapse (`≤ 0`) and a
760                // negligible one stay free to collapse — EV-neutral cases (the
761                // top-k / birth-topology lines) are untouched. Only curveable charts
762                // are gated; a euclidean chart never had a curved option.
763                let collapse_loses_ev = total_centered_variance.is_finite()
764                    && total_centered_variance > 0.0
765                    && collapse_ssr_increase(
766                        coords.view(),
767                        assign.view(),
768                        decoded.view(),
769                        target_resid.view(),
770                        t_bar,
771                        &b0,
772                        &b1,
773                    ) > SAE_HYBRID_COLLAPSE_MAX_EV_LOSS * total_centered_variance;
774                let slot = if manifold.is_euclidean() {
775                    vec![linear]
776                } else if collapse_loses_ev {
777                    vec![curved]
778                } else {
779                    vec![linear, curved]
780                };
781                slots.push(slot);
782                names.push(atom.name.clone());
783                manifolds.push(manifold);
784                turnings.push(fitted_turning);
785                delta_evs.push(delta_ev_for(atom_idx));
786                linear_images.push(AtomLinearImage {
787                    atom_idx,
788                    t_bar,
789                    b0,
790                    b1,
791                    row_codes: None,
792                });
793            }
794            // #1026 collapse rescue: `build_atom_candidates` refused because the
795            // atom's own coordinate collapsed (`s_tt ≈ 0`) — the rank-1 co-collapse
796            // fixed point. Recover a FRESH linear image from the residual's top
797            // direction (fresh per-row codes) and force the LINEAR verdict (a
798            // single-option slot the selector must take) so the slot reconstructs
799            // its residual's best linear axis at linear quality instead of the
800            // collapsed-curve constant. `None` only when the residual itself is
801            // degenerate — then there is genuinely nothing to recover and we skip.
802            None => match build_collapse_rescue_linear_image(
803                atom_idx,
804                assign.view(),
805                target_resid.view(),
806            ) {
807                Some((linear, image)) => {
808                    slots.push(vec![linear]);
809                    names.push(atom.name.clone());
810                    manifolds.push(manifold);
811                    turnings.push(fitted_turning);
812                    delta_evs.push(delta_ev_for(atom_idx));
813                    linear_images.push(image);
814                }
815                None => continue,
816            },
817        }
818    }
819
820    if slots.is_empty() {
821        return Ok(None);
822    }
823
824    let selection = select_hybrid_split(&slots)?;
825    let verdicts: Vec<AtomHybridVerdict> = names
826        .into_iter()
827        .zip(selection.atoms.iter().copied())
828        .zip(linear_images.into_iter())
829        .zip(turnings.into_iter())
830        .zip(delta_evs.into_iter())
831        .map(
832            |((((atom_name, choice), linear_image), fitted_turning), train_loao_delta_ev)| {
833                let kept_curved = !choice.param.is_linear();
834                AtomHybridVerdict {
835                    atom_name,
836                    choice,
837                    kept_curved,
838                    fitted_turning,
839                    train_loao_delta_ev,
840                    // Carry the straight sub-model only when the verdict collapses
841                    // this slot to linear — the curved slots keep their fitted image.
842                    linear_image: if kept_curved {
843                        None
844                    } else {
845                        Some(linear_image)
846                    },
847                }
848            },
849        )
850        .collect();
851
852    Ok(Some(SaeHybridSplitReport {
853        verdicts,
854        selection,
855    }))
856}
857
// Unit tests for the hybrid-split candidate builders: the common-data
// linear-vs-curved evidence comparison (#1202), the linear arm's weighted
// Gram logdet (#1203), the curved arm's real design-Gram logdet (#1223), and
// refusal of a degenerate coordinate.
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// A straight RESPONSE residual (the atom's data is a line) is explained
    /// equally well by both candidates, so the cheaper linear special case wins.
    /// With `a_k = 1` the curved decoded image is straight too (Θ = 0), so both
    /// the dominance floor and the evidence argmin select linear. This is the
    /// common-data nested comparison (#1202): linear is the curved family's
    /// `Θ = 0` member, so it cannot lose when a line already explains the data.
    #[test]
    fn straight_residual_selects_linear() {
        let n = 40;
        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        // The data the atom must explain is a straight line in ℝ²; the curved
        // decoded image equals that same line (a Θ = 0 curved fit).
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            data[[i, 0]] = coords[i];
            data[[i, 1]] = 0.6 * coords[i];
            decoded[[i, 0]] = coords[i];
            decoded[[i, 1]] = 0.6 * coords[i];
        }
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            // a generous curved parameter price (M·p)
            10,
            // no curved design Φ: the curved complexity uses the proxy
            None,
            // fitted turning Θ = 0 (a straight curve)
            Some(0.0),
        )
        .expect("straight residual yields a candidate pair");
        let choice =
            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert!(
            choice.param.is_linear(),
            "a straight response residual must keep the linear special case"
        );
    }

    /// A turning RESPONSE residual (the atom's data traces a full circle) is fit
    /// well by the curved decoded image (curved_rss ≈ 0) but poorly by any
    /// straight line (large linear_rss), so the curved candidate wins the common
    /// evidence comparison once its data-fit gain exceeds its extra parameter
    /// price (#1202).
    #[test]
    fn turning_residual_selects_curved_on_evidence() {
        let n = 60;
        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        // The data is a full circle; the curved decoded image is that same
        // circle (the curved atom reconstructs its assigned residual), so the
        // curved candidate has ≈ zero data-fit residual while a straight line
        // cannot follow the loop.
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let theta = 2.0 * PI * coords[i];
            data[[i, 0]] = theta.cos();
            data[[i, 1]] = theta.sin();
            decoded[[i, 0]] = theta.cos();
            decoded[[i, 1]] = theta.sin();
        }
        // The curved atom has 5 parameters (just above the 4 = 2·p linear budget);
        // the full-circle linear residual exceeds the extra-parameter overhead, so
        // curved wins on evidence.
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            5,
            // no curved design Φ: the curved complexity uses the proxy
            None,
            // fitted turning Θ = 2π (one full revolution)
            Some(2.0 * PI),
        )
        .expect("turning residual yields a candidate pair");
        assert!(
            linear.negative_log_evidence > curved.negative_log_evidence,
            "the line must misfit the circular residual worse than the curve does \
             (linear NLE {} should exceed curved NLE {})",
            linear.negative_log_evidence,
            curved.negative_log_evidence
        );
        let choice =
            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert_eq!(
            choice.param,
            gam_solve::evidence::HybridAtomParam::Curved { latent_dim: 1 },
            "a full-circle response residual must keep the curved parameterization"
        );
        assert!(
            choice.curved_evidence_margin > 0.0,
            "curved must win a positive evidence margin over the linear secant"
        );
    }

    /// The nested-dominance floor on common data (#1202): when the curved decoded
    /// image is a WORSE fit to the response residual than its own best straight
    /// projection, linear must win — the curved family cannot be charged extra
    /// parameters to fit the residual no better than its `Θ = 0` member. Here the
    /// data is a line but the curved image bends away from it, so curved_rss >
    /// linear_rss and the cheaper, better-fitting line is selected.
    #[test]
    fn linear_beats_curved_when_curve_misfits_residual() {
        let n = 50;
        let coords = Array1::from_iter((0..n).map(|i| (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::ones(n);
        // Data is a straight line; the curved decoded image is a parabola that
        // departs from it, so a straight line fits the data strictly better.
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut decoded = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = coords[i];
            data[[i, 0]] = t;
            data[[i, 1]] = 0.5 * t;
            decoded[[i, 0]] = t;
            decoded[[i, 1]] = t * t; // bends away from the linear data
        }
        let (linear, curved, _) = build_atom_candidates(
            coords.view(),
            assign.view(),
            decoded.view(),
            data.view(),
            // curved parameter price (M·p)
            6,
            // no curved design Φ: the curved complexity uses the proxy
            None,
            // a real curved Θ above the floor so the dominance floor does not fire
            Some(1.0),
        )
        .expect("candidate pair");
        let choice =
            gam_solve::evidence::select_hybrid_atom(&[linear, curved]).expect("non-empty slot");
        assert!(
            choice.param.is_linear(),
            "a curved image that fits the data worse than its own line must yield \
             to the linear special case on common-data evidence (#1202)"
        );
    }

    /// The LINEAR candidate's Laplace logdet is the genuine weighted-design Gram
    /// determinant `p·(log w_sum + log s_tt)` with `w_sum = Σ a_k²`, `s_tt =
    /// Σ a_k²(t − t̄)²` — it INCLUDES the coordinate-spread term `log(s_tt)`
    /// (#1203). Verify both contributions are present by reading the logdet off a
    /// candidate whose linear residual is exactly zero (response residual = the
    /// fitted line), so `NLE_linear = ½·logdet`. Doubling the coordinate spread
    /// (at fixed assignment mass) scales `s_tt` by 4 → logdet += `p·log(4)`;
    /// doubling all assignment masses scales BOTH `w_sum` and `s_tt` by 4 (they
    /// are quadratic in `a_k`) → logdet += `2p·log(4)`.
    #[test]
    fn linear_logdet_includes_weighted_coordinate_spread() {
        let n = 40;
        let p = 2usize;
        // Read the logdet back off a candidate with zero linear residual: the
        // response residual is exactly `a_k·(line)`, so the WLS line recovers it
        // with RSS == 0 and `NLE_linear = ½·logdet`.
        let logdet = |coords: &Array1<f64>, assign: &Array1<f64>| -> f64 {
            // A straight image; the response residual is the same line scaled by
            // the per-row assignment mass `a_k`, so the prediction `a_k·(b₀+dt·b₁)`
            // matches it exactly and linear_rss == 0.
            let line = |t: f64| -> [f64; 2] { [t, 0.6 * t] };
            let mut decoded = Array2::<f64>::zeros((n, p));
            let mut data = Array2::<f64>::zeros((n, p));
            for i in 0..n {
                let l = line(coords[i]);
                decoded[[i, 0]] = l[0];
                decoded[[i, 1]] = l[1];
                data[[i, 0]] = assign[i] * l[0];
                data[[i, 1]] = assign[i] * l[1];
            }
            let (linear, _curved, _) = build_atom_candidates(
                coords.view(),
                assign.view(),
                decoded.view(),
                data.view(),
                10,
                None,
                Some(0.0),
            )
            .expect("straight residual yields a pair");
            2.0 * linear.negative_log_evidence // = logdet (linear_rss == 0)
        };

        let base_coords =
            Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let ones = Array1::<f64>::ones(n);

        // Doubling the coordinate spread → s_tt ×4, w_sum fixed → logdet += p·log(4).
        let wide_coords = base_coords.mapv(|t| 2.0 * t);
        let d_spread = logdet(&wide_coords, &ones) - logdet(&base_coords, &ones);
        assert!(
            (d_spread - (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
            "linear logdet must move by p·log(4) when coordinate spread doubles \
             (got {d_spread}); the spread term log(s_tt) must be present"
        );

        // Doubling all assignment masses → w_sum ×4 AND s_tt ×4 (quadratic in a_k)
        // → logdet += 2p·log(4). The weighted mean t̄ is unchanged by a uniform
        // rescaling of the masses.
        let twos = Array1::<f64>::from_elem(n, 2.0);
        let d_weight = logdet(&base_coords, &twos) - logdet(&base_coords, &ones);
        assert!(
            (d_weight - 2.0 * (p as f64) * 4.0_f64.ln()).abs() < 1e-9,
            "linear logdet must move by 2p·log(4) when all assignment masses double \
             (got {d_weight})"
        );
    }

    /// #1223 — the curved arm's Laplace complexity is the REAL weighted-design
    /// Gram log-determinant `p·log|ΦᵀWΦ|_+`, not a parameter-count proxy. Build a
    /// curved design whose columns are the constant and the centered coordinate
    /// (a 2-column basis), so `ΦᵀWΦ = diag(w_sum, s_tt)` exactly matches the
    /// linear arm's data Gram, and assert `curved_design_gram_logdet` returns
    /// `p·(log w_sum + log s_tt)` — the same determinant the linear arm reports
    /// on the same design weight. A proxy `M·log(w_sum)` would instead omit the
    /// `log(s_tt)` spread term, so this pins the genuine determinant.
    #[test]
    fn curved_gram_logdet_is_real_weighted_design_determinant() {
        let n = 40;
        let p = 3usize;
        let coords = Array1::from_iter((0..n).map(|i| -1.0 + 2.0 * (i as f64) / ((n - 1) as f64)));
        let assign = Array1::<f64>::from_iter((0..n).map(|i| 0.5 + 0.01 * (i as f64)));

        // Mass-weighted coordinate mean and spread under wᵢ = a_k².
        let mut w_sum = 0.0;
        let mut t_bar = 0.0;
        for i in 0..n {
            let w = assign[i] * assign[i];
            w_sum += w;
            t_bar += w * coords[i];
        }
        t_bar /= w_sum;
        let mut s_tt = 0.0;
        for i in 0..n {
            let dt = coords[i] - t_bar;
            s_tt += assign[i] * assign[i] * dt * dt;
        }

        // Curved design columns: [1, (t − t̄)]. Its weighted Gram is exactly
        // diag(w_sum, s_tt) (the cross term Σ w·(t−t̄) vanishes by construction),
        // so log|ΦᵀWΦ| = log(w_sum) + log(s_tt).
        let mut phi = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            phi[[i, 0]] = 1.0;
            phi[[i, 1]] = coords[i] - t_bar;
        }
        let got = curved_design_gram_logdet(phi.view(), assign.view(), p)
            .expect("non-degenerate curved design has a determinant");
        let want = (p as f64) * (w_sum.ln() + s_tt.ln());
        assert!(
            (got - want).abs() < 1e-9,
            "curved Gram logdet must be the real p·log|ΦᵀWΦ| = {want}, got {got}"
        );

        // A rank-deficient design (a duplicated constant column) has weighted
        // Gram w_sum·[[1, 1], [1, 1]], whose eigenvalues are 2·w_sum and 0.
        // Dropping the null direction leaves the positive pseudo-determinant
        // 2·w_sum, so the expected logdet is p·log(2·w_sum) — NOT a 2-column
        // proxy, and NOT the single-column p·log(w_sum).
        let mut phi_dup = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            phi_dup[[i, 0]] = 1.0;
            phi_dup[[i, 1]] = 1.0;
        }
        let got_dup = curved_design_gram_logdet(phi_dup.view(), assign.view(), p)
            .expect("rank-1 design still has a positive determinant");
        let want_dup = (p as f64) * (2.0 * w_sum).ln();
        assert!(
            (got_dup - want_dup).abs() < 1e-9,
            "rank-deficient curved Gram must report only its positive direction \
             (p·log(2·w_sum) = {want_dup}), got {got_dup}"
        );
    }

    /// A degenerate (single-point-mass) coordinate has no slope direction and is
    /// refused rather than adjudicated on a fabricated deviance.
    #[test]
    fn degenerate_coordinate_is_refused() {
        let n = 5;
        let coords = Array1::<f64>::from_elem(n, 0.5); // no spread
        let assign = Array1::<f64>::ones(n);
        let decoded = Array2::<f64>::zeros((n, 2));
        let data = Array2::<f64>::zeros((n, 2));
        assert!(
            build_atom_candidates(
                coords.view(),
                assign.view(),
                decoded.view(),
                data.view(),
                6,
                None,
                Some(0.0)
            )
            .is_none(),
            "a degenerate coordinate span must be refused"
        );
    }
}