Skip to main content

gam_solve/inference/
residual_factor.rs

1//! #974 — the structured-residual covariance estimator and the single producer
2//! of [`MetricProvenance::WhitenedStructured`](gam_problem::MetricProvenance::WhitenedStructured).
3//!
4//! # What this estimates
5//!
6//! Given a residual matrix `R ∈ ℝ^{n×p}` (one `p`-dimensional reconstruction
7//! residual per row) and a smooth *activity coordinate* `z ∈ ℝ^n`, this fits the
8//! **structured residual-covariance model**
9//!
10//! ```text
11//!     Cov(r_n) = Σ_n = Λ · c(z_n) · Λᵀ + D ,
12//! ```
13//!
14//! where
15//!
16//! * `Λ ∈ ℝ^{p×r}` is a **low-rank interference factor** (the shared
17//!   off-isotropic subspace the residuals correlate along — e.g. a planted
18//!   interference subspace or a topology-race confound),
19//! * `D = diag(d) ≻ 0` is the **idiosyncratic diagonal** (per-channel
20//!   independent noise), and
21//! * `c(z) > 0` is the **smooth activity-scale law**: a strictly-positive scalar
22//!   that modulates the factor energy with the activity coordinate, recovered as
23//!   a binned-then-smoothed function of `z`.
24//!
25//! The fit is a deterministic, fixed-iteration **alternation** (no clock, no
26//! RNG; any tie is broken by index): it alternates
27//!
28//! 1. *(scale | Λ, D)* — re-estimate the per-row factor activity `c(z_n)` and
29//!    smooth it across `z`, holding the factor model fixed; and
30//! 2. *(Λ, D | scale)* — re-estimate the factor and diagonal from the
31//!    scale-deflated second-moment, holding the activity law fixed,
32//!
33//! a fixed small number of times. The **factor count `r`** is chosen by an
34//! evidence ladder: each candidate `r` is scored by its penalized Gaussian
35//! log-evidence and the best is kept.
36//!
37//! # What it produces
38//!
39//! [`StructuredResidualModel::row_metric`] materializes the **per-row precision
40//! factor** `U_n ∈ ℝ^{p×p}` with `U_n U_nᵀ = Σ_n^{-1}`, packaged as a
41//! [`RowMetric`](gam_problem::RowMetric) with
42//! [`MetricProvenance::WhitenedStructured`](gam_problem::MetricProvenance::WhitenedStructured).
43//! Whitening a residual `r_n` through it (`U_nᵀ r_n`) yields a vector whose
44//! squared Euclidean norm is `r_nᵀ Σ_n^{-1} r_n` — the Mahalanobis residual under
45//! the estimated noise model, which is exactly the likelihood-correct data-fit.
46//! The factor is built from `Σ_n^{-1}` computed in **Woodbury form** (an
47//! `r × r` solve, never a `p × p` inverse), so the estimator scales with the
48//! factor rank, not the dense output dimension.
49//!
50//! This is the first real producer of `WhitenedStructured`, and therefore the
51//! first metric whose `whitens_likelihood()` is `true`: see
52//! [`RowMetric::whitens_likelihood`](gam_problem::RowMetric::whitens_likelihood).
53
54use std::sync::Arc;
55
56use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
57
58use gam_problem::RowMetric;
59use gam_linalg::faer_ndarray::{FaerCholesky, FaerEigh};
60use faer::Side;
61
62/// Number of (scale | factor) ↔ (factor | scale) alternation sweeps. Fixed and
63/// deterministic: the alternation is a smooth descent on the structured-Gaussian
64/// objective and converges geometrically, so a small fixed budget is both
65/// sufficient and reproducible (no clock/RNG-driven stopping).
66const ALTERNATION_SWEEPS: usize = 8;
67
/// Number of bins the activity coordinate `z` is partitioned into for the smooth
/// activity-scale `c(z)`. The per-bin factor activity is estimated, smoothed
/// with a 3-point moving average across neighbouring bins, and every row then
/// takes the value of its own bin, giving a piecewise-constant scale law.
/// Chosen as a fixed structural constant (magic-by-default): enough bins to
/// resolve a smooth monotone or unimodal scale trend without over-fitting the
/// per-row noise.
74const ACTIVITY_SCALE_BINS: usize = 8;
75
76/// Relative floor on the idiosyncratic diagonal `D`, as a fraction of the mean
77/// residual variance. Keeps `Σ_n ≻ 0` and the Woodbury `r × r` capacitance
78/// invertible even when a channel is (near-)perfectly explained by the factor.
79const DIAGONAL_REL_FLOOR: f64 = 1e-6;
80
81/// Relative floor on the activity scale `c(z)`, as a fraction of its mean. Keeps
82/// `c(z) > 0` (a covariance scale) across the whole `z` range.
83const SCALE_REL_FLOOR: f64 = 1e-4;
84
85/// The fitted structured residual-covariance model: low-rank factor `Λ`,
86/// idiosyncratic diagonal `D`, and the smooth activity-scale `c(z)` evaluated at
87/// every row. Produces per-row precision factors and the
88/// [`MetricProvenance::WhitenedStructured`](gam_problem::MetricProvenance::WhitenedStructured)
89/// [`RowMetric`](gam_problem::RowMetric).
90#[derive(Clone, Debug)]
91pub struct StructuredResidualModel {
92    /// Output dimensionality `p` (residual width).
93    p: usize,
94    /// Selected factor rank `r` (`0 ≤ r ≤ p`). `0` ⇒ pure-diagonal noise model.
95    factor_rank: usize,
96    /// Interference factor `Λ ∈ ℝ^{p×r}` (the shared off-diagonal subspace).
97    lambda: Array2<f64>,
98    /// Idiosyncratic diagonal `d ∈ ℝ^p` (`D = diag(d)`), floored `≻ 0`.
99    diagonal: Array1<f64>,
100    /// Per-row activity scale `c(z_n) > 0`, length `n`.
101    row_scale: Array1<f64>,
102    /// Penalized Gaussian log-evidence of the selected model (higher is better).
103    /// The value the evidence ladder maximized over the candidate ranks.
104    log_evidence: f64,
105}
106
107/// Estimator inputs: the residual matrix and the smooth activity coordinate.
108///
109/// `residuals` is `R ∈ ℝ^{n×p}`. `activity` is `z ∈ ℝ^n` — the coordinate the
110/// scale law `c(z)` is smooth in (e.g. an assignment-mass or activation-strength
111/// summary per row). When no genuine activity coordinate is available, passing a
112/// constant `z` recovers a homoscedastic factor model (`c(z) ≡ const`).
113pub struct ResidualFactorInput<'a> {
114    /// Residual matrix `R ∈ ℝ^{n×p}`.
115    pub residuals: ArrayView2<'a, f64>,
116    /// Activity coordinate `z ∈ ℝ^n` the scale law is smooth in.
117    pub activity: ArrayView1<'a, f64>,
118    /// Maximum factor rank the evidence ladder is allowed to consider. The
119    /// ladder scores `r = 0, 1, …, min(max_factor_rank, p−1)` and keeps the
120    /// penalized-evidence maximizer. `0` forces the pure-diagonal model.
121    pub max_factor_rank: usize,
122}
123
124impl StructuredResidualModel {
125    /// Fit the structured residual-covariance model by the deterministic
126    /// fixed-iteration alternation, selecting the factor rank by the evidence
127    /// ladder. Returns an error only on shape / non-finite-input violations; the
128    /// numerical path is total (every floor and solve is guarded).
129    pub fn fit(input: ResidualFactorInput<'_>) -> Result<Self, String> {
130        let r = input.residuals;
131        let z = input.activity;
132        let n = r.nrows();
133        let p = r.ncols();
134        if n == 0 || p == 0 {
135            return Err(format!(
136                "StructuredResidualModel::fit: residuals must be non-empty; got ({n}, {p})"
137            ));
138        }
139        if z.len() != n {
140            return Err(format!(
141                "StructuredResidualModel::fit: activity length {} != residual rows {n}",
142                z.len()
143            ));
144        }
145        if !r.iter().all(|v| v.is_finite()) {
146            return Err("StructuredResidualModel::fit: residuals must be finite".to_string());
147        }
148        if !z.iter().all(|v| v.is_finite()) {
149            return Err("StructuredResidualModel::fit: activity must be finite".to_string());
150        }
151
152        // Bin assignment for the activity-scale law: deterministic equal-width
153        // bins over the observed z-range. A degenerate (zero-width) range maps
154        // every row to bin 0, recovering a single homoscedastic scale.
155        let bins = ACTIVITY_SCALE_BINS.max(1);
156        let z_min = z.iter().copied().fold(f64::INFINITY, f64::min);
157        let z_max = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
158        let z_span = z_max - z_min;
159        let row_bin: Vec<usize> = (0..n)
160            .map(|i| {
161                if z_span <= 0.0 {
162                    0
163                } else {
164                    let frac = (z[i] - z_min) / z_span;
165                    let idx = (frac * bins as f64).floor() as isize;
166                    idx.clamp(0, bins as isize - 1) as usize
167                }
168            })
169            .collect();
170
171        let max_rank = input.max_factor_rank.min(p.saturating_sub(1));
172
173        // Evidence ladder over candidate factor ranks. Each candidate is fit by
174        // the full alternation and scored by its penalized Gaussian log-evidence;
175        // the maximizer is kept. Index order breaks any tie (lowest rank wins on
176        // an exact tie — Occam).
177        let mut best: Option<StructuredResidualModel> = None;
178        for rank in 0..=max_rank {
179            let model = Self::fit_fixed_rank(r, &row_bin, bins, rank)?;
180            let take = match &best {
181                None => true,
182                Some(b) => model.log_evidence > b.log_evidence,
183            };
184            if take {
185                best = Some(model);
186            }
187        }
188        best.ok_or_else(|| "StructuredResidualModel::fit: evidence ladder empty".to_string())
189    }
190
191    /// Fit the model at a fixed factor rank by the deterministic alternation.
192    fn fit_fixed_rank(
193        r: ArrayView2<'_, f64>,
194        row_bin: &[usize],
195        bins: usize,
196        rank: usize,
197    ) -> Result<Self, String> {
198        let n = r.nrows();
199        let p = r.ncols();
200
201        // Mean residual variance — the scale reference for the diagonal floor.
202        let mut total_var = 0.0_f64;
203        for i in 0..n {
204            for j in 0..p {
205                total_var += r[[i, j]] * r[[i, j]];
206            }
207        }
208        let mean_var = (total_var / (n as f64 * p as f64)).max(f64::MIN_POSITIVE);
209        let diag_floor = DIAGONAL_REL_FLOOR * mean_var;
210
211        // Initialize the per-row scale to 1 (homoscedastic start), the diagonal
212        // to the per-channel sample variance, and Λ to the leading eigenvectors
213        // of the (scale-1) second moment. The alternation refines all three.
214        let mut row_scale = Array1::<f64>::ones(n);
215        let mut bin_scale = Array1::<f64>::ones(bins);
216        // Raw (undeflated) per-channel second moment — the D estimator's data
217        // term. Constant across sweeps.
218        let raw_diag = column_variances(r);
219        let mut diagonal = raw_diag.mapv(|v| v.max(diag_floor));
220        let mut lambda = Array2::<f64>::zeros((p, rank));
221
222        for _sweep in 0..ALTERNATION_SWEEPS {
223            // (Λ, D | scale): scale-deflated second moment
224            //   S = (1/n) Σ_n (r_n r_nᵀ) / c(z_n).
225            // Under the model E[r_n r_nᵀ] = c_n ΛΛᵀ + D, so S ≈ ΛΛᵀ + D̄ with
226            // D̄ the scale-averaged diagonal; the leading eigenpairs of S − D
227            // give Λ, the residual diagonal gives D.
228            let s = scaled_second_moment(r, &row_scale);
229            let (evals, evecs) = symmetric_eig_ascending(&s)?;
230            // Leading `rank` eigenpairs (eigenvalues ascending ⇒ take the tail).
231            if rank > 0 {
232                for k in 0..rank {
233                    let col = p - 1 - k;
234                    // Factor energy above the idiosyncratic floor: the part of
235                    // the eigenvalue not explained by the mean diagonal.
236                    let mean_diag = diagonal.iter().copied().sum::<f64>() / p as f64;
237                    let energy = (evals[col] - mean_diag).max(0.0);
238                    let amp = energy.sqrt();
239                    for row in 0..p {
240                        lambda[[row, k]] = amp * evecs[[row, col]];
241                    }
242                }
243            }
244            // D update from the RAW (undeflated) moment, floored ≻ 0. The model
245            // is Σ_n = c_n·ΛΛᵀ + D with D NOT scale-multiplied, and c is mean-1
246            // normalized, so E[(1/n)Σ r_n r_nᵀ] = ΛΛᵀ + D exactly. The deflated
247            // moment `s` is the right object for the FACTOR block (its factor
248            // part is scale-free) but its diagonal carries D·mean(1/c) — a
249            // Jensen-inflated D (mean(1/c) > 1 for any non-constant law), which
250            // biased D upward by exactly mean(1/c̃) and let a spurious
251            // higher-rank candidate win the evidence ladder on a better D
252            // alone (the probe's rank-2 winner had a zero second column).
253            for j in 0..p {
254                let mut factor_var = 0.0_f64;
255                for k in 0..rank {
256                    factor_var += lambda[[j, k]] * lambda[[j, k]];
257                }
258                diagonal[j] = (raw_diag[j] - factor_var).max(diag_floor);
259            }
260
261            // (scale | Λ, D): per-row factor activity. With residual r_n, the
262            // factor-subspace energy is r_nᵀ P r_n where P projects onto
263            // range(Λ) in the D-whitened metric; the maximum-likelihood scalar
264            // multiplier on ΛΛᵀ that matches the row's factor-subspace energy is
265            //   c_n = (r̃_nᵀ B (BᵀB)^{-1} Bᵀ r̃_n) / tr(...)-normalizer.
266            // We use a stable closed-form proxy: the row's factor-coordinate
267            // energy ‖Λ⁺ r_n‖² normalized by the unit-scale expectation, then
268            // bin-smoothed across z. With rank 0 there is no factor ⇒ c ≡ 1.
269            if rank > 0 {
270                let mut bin_num = Array1::<f64>::zeros(bins);
271                let mut bin_den = Array1::<f64>::zeros(bins);
272                let coords = factor_coordinates(&lambda, &diagonal, r)?;
273                for i in 0..n {
274                    let mut energy = 0.0_f64;
275                    for k in 0..rank {
276                        energy += coords[[i, k]] * coords[[i, k]];
277                    }
278                    let b = row_bin[i];
279                    bin_num[b] += energy;
280                    bin_den[b] += rank as f64;
281                }
282                // Per-bin mean factor energy = activity scale. Empty bins inherit
283                // the global mean so the scale law stays defined everywhere.
284                let global = {
285                    let num: f64 = bin_num.iter().sum();
286                    let den: f64 = bin_den.iter().sum();
287                    if den > 0.0 { num / den } else { 1.0 }
288                };
289                for b in 0..bins {
290                    bin_scale[b] = if bin_den[b] > 0.0 {
291                        bin_num[b] / bin_den[b]
292                    } else {
293                        global
294                    };
295                }
296                // Smooth (3-point moving average over bins) for a continuous law,
297                // then floor ≻ 0.
298                let scale_floor = SCALE_REL_FLOOR * global.max(f64::MIN_POSITIVE);
299                let smoothed = moving_average_3(&bin_scale);
300                for b in 0..bins {
301                    bin_scale[b] = smoothed[b].max(scale_floor);
302                }
303                // Re-normalize so the mean scale is 1 (the factor amplitude lives
304                // in Λ; c(z) carries only the relative activity law). This keeps
305                // the (Λ, D) ↔ (scale) split identified.
306                let mean_scale = bin_scale.iter().copied().sum::<f64>() / bins as f64;
307                if mean_scale > 0.0 {
308                    bin_scale.mapv_inplace(|v| v / mean_scale);
309                }
310                for i in 0..n {
311                    row_scale[i] = bin_scale[row_bin[i]].max(scale_floor);
312                }
313            }
314        }
315
316        let log_evidence = penalized_log_evidence(r, &lambda, &diagonal, &row_scale, rank);
317        let mut model = Self {
318            p,
319            factor_rank: rank,
320            lambda,
321            diagonal,
322            row_scale,
323            log_evidence,
324        };
325        // Guard against any non-finite leak from a degenerate fit: fall back to a
326        // pure-diagonal model with the same evidence accounting.
327        if !model.is_finite() {
328            model.lambda = Array2::<f64>::zeros((p, rank));
329            model.row_scale = Array1::<f64>::ones(n);
330        }
331        Ok(model)
332    }
333
334    fn is_finite(&self) -> bool {
335        self.lambda.iter().all(|v| v.is_finite())
336            && self.diagonal.iter().all(|v| v.is_finite() && *v > 0.0)
337            && self.row_scale.iter().all(|v| v.is_finite() && *v > 0.0)
338            && self.log_evidence.is_finite()
339    }
340
341    /// Selected factor rank `r`.
342    pub fn factor_rank(&self) -> usize {
343        self.factor_rank
344    }
345
346    /// The fitted interference factor `Λ ∈ ℝ^{p×r}` (the shared off-isotropic
347    /// residual subspace). Consumed by the planted-subspace recovery test to
348    /// compare `range(Λ)` against the planted interference subspace.
349    pub fn factor(&self) -> ArrayView2<'_, f64> {
350        self.lambda.view()
351    }
352
353    /// The idiosyncratic diagonal `d ∈ ℝ^p` (`D = diag(d)`).
354    pub fn diagonal(&self) -> ArrayView1<'_, f64> {
355        self.diagonal.view()
356    }
357
358    /// The per-row activity scale `c(z_n) > 0`, length `n`. Recovers the smooth
359    /// activity-scale law evaluated at every observed `z_n`.
360    pub fn row_scale(&self) -> ArrayView1<'_, f64> {
361        self.row_scale.view()
362    }
363
364    /// The penalized Gaussian log-evidence the rank-selection ladder maximized.
365    pub fn log_evidence(&self) -> f64 {
366        self.log_evidence
367    }
368
369    /// Build the per-row precision factor stack `U_n ∈ ℝ^{p×p}` with
370    /// `U_n U_nᵀ = Σ_n^{-1}` and package it as a
371    /// [`MetricProvenance::WhitenedStructured`](gam_problem::MetricProvenance::WhitenedStructured)
372    /// [`RowMetric`](gam_problem::RowMetric). This is the single
373    /// production site of `WhitenedStructured`.
374    ///
375    /// The precision is formed in **Woodbury form**:
376    /// ```text
377    ///   Σ_n^{-1} = D^{-1} − D^{-1} Λ ( c^{-1} I_r + Λᵀ D^{-1} Λ )^{-1} Λᵀ D^{-1},
378    /// ```
379    /// an `r × r` capacitance solve (never a `p × p` inverse). The factor `U_n`
380    /// is the lower-Cholesky of the assembled `Σ_n^{-1}` (`rank = p`), so
381    /// `whiten_residual_row` returns coordinates whose squared norm is the exact
382    /// Mahalanobis residual `r_nᵀ Σ_n^{-1} r_n`.
383    pub fn row_metric(&self, n_rows: usize) -> Result<RowMetric, String> {
384        if n_rows != self.row_scale.len() {
385            return Err(format!(
386                "StructuredResidualModel::row_metric: requested {n_rows} rows but model has {}",
387                self.row_scale.len()
388            ));
389        }
390        let p = self.p;
391        // Row-major flat factor matrix: u[n, i*p + k] = U_n[i, k].
392        let mut u = Array2::<f64>::zeros((n_rows, p * p));
393        for row in 0..n_rows {
394            let precision = self.row_precision(row)?;
395            let factor = lower_cholesky_psd(&precision)?;
396            for i in 0..p {
397                for k in 0..p {
398                    u[[row, i * p + k]] = factor[[i, k]];
399                }
400            }
401        }
402        RowMetric::whitened_structured(Arc::new(u), p, p)
403    }
404
405    /// Per-row precision `Σ_n^{-1}` via the Woodbury identity (an `r × r` solve).
406    fn row_precision(&self, row: usize) -> Result<Array2<f64>, String> {
407        let p = self.p;
408        let r = self.factor_rank;
409        let c = self.row_scale[row].max(f64::MIN_POSITIVE);
410        // D^{-1}.
411        let d_inv: Vec<f64> = (0..p).map(|i| 1.0 / self.diagonal[i]).collect();
412        // Start from D^{-1}.
413        let mut precision = Array2::<f64>::zeros((p, p));
414        for i in 0..p {
415            precision[[i, i]] = d_inv[i];
416        }
417        if r == 0 {
418            return Ok(precision);
419        }
420        // B = D^{-1} Λ  ∈ ℝ^{p×r}.
421        let mut b = Array2::<f64>::zeros((p, r));
422        for i in 0..p {
423            for k in 0..r {
424                b[[i, k]] = d_inv[i] * self.lambda[[i, k]];
425            }
426        }
427        // Capacitance M = c^{-1} I_r + Λᵀ D^{-1} Λ  ∈ ℝ^{r×r}.
428        let mut cap = Array2::<f64>::zeros((r, r));
429        for a in 0..r {
430            for bk in 0..r {
431                let mut acc = 0.0_f64;
432                for i in 0..p {
433                    acc += self.lambda[[i, a]] * b[[i, bk]];
434                }
435                cap[[a, bk]] = acc;
436            }
437            cap[[a, a]] += 1.0 / c;
438        }
439        // Σ_n^{-1} = D^{-1} − B M^{-1} Bᵀ. Solve M X = Bᵀ for X = M^{-1} Bᵀ
440        // (r × p) via Cholesky (M ≻ 0 since c^{-1} > 0 and ΛᵀD^{-1}Λ ⪰ 0).
441        let chol = cap
442            .cholesky(Side::Lower)
443            .map_err(|e| format!("StructuredResidualModel::row_precision capacitance: {e:?}"))?;
444        let mut bt = Array2::<f64>::zeros((r, p));
445        for k in 0..r {
446            for i in 0..p {
447                bt[[k, i]] = b[[i, k]];
448            }
449        }
450        let x = chol.solve_mat(&bt); // r × p
451        for i in 0..p {
452            for j in 0..p {
453                let mut acc = 0.0_f64;
454                for k in 0..r {
455                    acc += b[[i, k]] * x[[k, j]];
456                }
457                precision[[i, j]] -= acc;
458            }
459        }
460        // Symmetrize against round-off so the Cholesky downstream sees an exactly
461        // symmetric PSD matrix.
462        for i in 0..p {
463            for j in (i + 1)..p {
464                let avg = 0.5 * (precision[[i, j]] + precision[[j, i]]);
465                precision[[i, j]] = avg;
466                precision[[j, i]] = avg;
467            }
468        }
469        Ok(precision)
470    }
471}
472
473/// Per-channel (column) sample second moment of the residual matrix.
474fn column_variances(r: ArrayView2<'_, f64>) -> Array1<f64> {
475    let n = r.nrows();
476    let p = r.ncols();
477    let mut v = Array1::<f64>::zeros(p);
478    for j in 0..p {
479        let mut acc = 0.0_f64;
480        for i in 0..n {
481            acc += r[[i, j]] * r[[i, j]];
482        }
483        v[j] = acc / n as f64;
484    }
485    v
486}
487
488/// Scale-deflated second moment `S = (1/n) Σ_n (r_n r_nᵀ) / c_n`.
489fn scaled_second_moment(r: ArrayView2<'_, f64>, row_scale: &Array1<f64>) -> Array2<f64> {
490    let n = r.nrows();
491    let p = r.ncols();
492    let mut s = Array2::<f64>::zeros((p, p));
493    for i in 0..n {
494        let w = 1.0 / row_scale[i].max(f64::MIN_POSITIVE);
495        for a in 0..p {
496            let ra = r[[i, a]];
497            for b in 0..p {
498                s[[a, b]] += w * ra * r[[i, b]];
499            }
500        }
501    }
502    s.mapv_inplace(|v| v / n as f64);
503    // Symmetrize against accumulation round-off.
504    for a in 0..p {
505        for b in (a + 1)..p {
506            let avg = 0.5 * (s[[a, b]] + s[[b, a]]);
507            s[[a, b]] = avg;
508            s[[b, a]] = avg;
509        }
510    }
511    s
512}
513
514/// Factor coordinates `Λ⁺_D r_n` per row: the generalized-least-squares
515/// projection of each residual onto `range(Λ)` in the `D^{-1}` metric, returned
516/// as an `n × r` matrix. Solves the `r × r` normal equations
517/// `(Λᵀ D^{-1} Λ) γ = Λᵀ D^{-1} r_n` per row (shared factorization).
518fn factor_coordinates(
519    lambda: &Array2<f64>,
520    diagonal: &Array1<f64>,
521    r: ArrayView2<'_, f64>,
522) -> Result<Array2<f64>, String> {
523    let p = lambda.nrows();
524    let rank = lambda.ncols();
525    let n = r.nrows();
526    let d_inv: Vec<f64> = (0..p).map(|i| 1.0 / diagonal[i]).collect();
527    // Normal matrix ΛᵀD^{-1}Λ (+ tiny ridge for invertibility).
528    let mut normal = Array2::<f64>::zeros((rank, rank));
529    for a in 0..rank {
530        for b in 0..rank {
531            let mut acc = 0.0_f64;
532            for i in 0..p {
533                acc += lambda[[i, a]] * d_inv[i] * lambda[[i, b]];
534            }
535            normal[[a, b]] = acc;
536        }
537    }
538    let trace = (0..rank).map(|k| normal[[k, k]]).sum::<f64>().max(1.0);
539    let ridge = 1e-10 * trace / rank.max(1) as f64;
540    for k in 0..rank {
541        normal[[k, k]] += ridge;
542    }
543    let chol = normal
544        .cholesky(Side::Lower)
545        .map_err(|e| format!("factor_coordinates normal solve: {e:?}"))?;
546    let mut coords = Array2::<f64>::zeros((n, rank));
547    let mut rhs = Array1::<f64>::zeros(rank);
548    for i in 0..n {
549        for a in 0..rank {
550            let mut acc = 0.0_f64;
551            for j in 0..p {
552                acc += lambda[[j, a]] * d_inv[j] * r[[i, j]];
553            }
554            rhs[a] = acc;
555        }
556        let gamma = chol.solvevec(&rhs);
557        for a in 0..rank {
558            coords[[i, a]] = gamma[a];
559        }
560    }
561    Ok(coords)
562}
563
564/// 3-point moving average over a bin vector (edge-clamped), giving the smooth
565/// activity-scale law a continuous, low-curvature shape.
566fn moving_average_3(v: &Array1<f64>) -> Array1<f64> {
567    let m = v.len();
568    let mut out = Array1::<f64>::zeros(m);
569    for i in 0..m {
570        let lo = i.saturating_sub(1);
571        let hi = (i + 1).min(m - 1);
572        let mut acc = 0.0_f64;
573        let mut cnt = 0.0_f64;
574        for j in lo..=hi {
575            acc += v[j];
576            cnt += 1.0;
577        }
578        out[i] = acc / cnt;
579    }
580    out
581}
582
583/// Ascending-eigenvalue symmetric eigendecomposition (faer convention).
584fn symmetric_eig_ascending(m: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>), String> {
585    m.eigh(Side::Lower)
586        .map_err(|e| format!("symmetric_eig: {e:?}"))
587}
588
589/// Lower-triangular Cholesky factor `L` of a (numerically) PSD matrix `A` with
590/// `L Lᵀ = A`, with a relative spectral floor so a marginally-indefinite
591/// precision (round-off) still factors. Used to turn `Σ_n^{-1}` into the
592/// `RowMetric` factor `U_n` (here `U_n = L`).
593fn lower_cholesky_psd(a: &Array2<f64>) -> Result<Array2<f64>, String> {
594    if let Ok(chol) = a.cholesky(Side::Lower) {
595        return Ok(chol.lower_triangular());
596    }
597    // Eigen-repair: clamp eigenvalues to a small positive floor and rebuild a
598    // symmetric square root, then Cholesky that (always succeeds, PD).
599    let (evals, evecs) = symmetric_eig_ascending(a)?;
600    let max_ev = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
601    let floor = 1e-10 * max_ev;
602    let p = a.nrows();
603    let mut sqrt = Array2::<f64>::zeros((p, p));
604    for i in 0..p {
605        for j in 0..p {
606            let mut acc = 0.0_f64;
607            for k in 0..p {
608                let ev = evals[k].max(floor);
609                acc += evecs[[i, k]] * ev.sqrt() * evecs[[j, k]];
610            }
611            sqrt[[i, j]] = acc;
612        }
613    }
614    sqrt.cholesky(Side::Lower)
615        .map(|c| c.lower_triangular())
616        .map_err(|e| format!("lower_cholesky_psd eigen-repair: {e:?}"))
617}
618
619/// Penalized Gaussian log-evidence of the structured model at the fitted
620/// parameters — the evidence ladder's rank-selection score.
621///
622/// The per-row log-density of `r_n ~ N(0, Σ_n)` is
623/// `−½ ( log|Σ_n| + r_nᵀ Σ_n^{-1} r_n + p log 2π )`. We sum it across rows and
624/// subtract a parameter-count penalty `½ k_params · log n` (a BIC-style Occam
625/// term over the `p·r` factor entries + `p` diagonal entries + the bin scales),
626/// so adding a spurious factor that does not improve the fit is rejected. Both
627/// `log|Σ_n|` and the quadratic use the Woodbury / matrix-determinant lemma so no
628/// dense `p × p` inverse or determinant is formed.
629fn penalized_log_evidence(
630    r: ArrayView2<'_, f64>,
631    lambda: &Array2<f64>,
632    diagonal: &Array1<f64>,
633    row_scale: &Array1<f64>,
634    rank: usize,
635) -> f64 {
636    let n = r.nrows();
637    let p = r.ncols();
638    let d_inv: Vec<f64> = (0..p).map(|i| 1.0 / diagonal[i]).collect();
639    let log_det_d: f64 = diagonal.iter().map(|&d| d.ln()).sum();
640    let two_pi_ln = (2.0 * std::f64::consts::PI).ln();
641
642    let mut log_lik = 0.0_f64;
643    for i in 0..n {
644        let c = row_scale[i].max(f64::MIN_POSITIVE);
645        // Quadratic r_nᵀ Σ_n^{-1} r_n via Woodbury:
646        //   r_nᵀ D^{-1} r_n − (Bᵀ r_n)ᵀ M^{-1} (Bᵀ r_n),
647        // with B = D^{-1}Λ and M = c^{-1}I + ΛᵀD^{-1}Λ.
648        let mut quad = 0.0_f64;
649        for j in 0..p {
650            quad += r[[i, j]] * d_inv[j] * r[[i, j]];
651        }
652        let mut log_det = log_det_d;
653        if rank > 0 {
654            // M (r × r) and w = Bᵀ r_n = ΛᵀD^{-1} r_n.
655            let mut m = Array2::<f64>::zeros((rank, rank));
656            let mut w = Array1::<f64>::zeros(rank);
657            for a in 0..rank {
658                let mut wa = 0.0_f64;
659                for j in 0..p {
660                    wa += lambda[[j, a]] * d_inv[j] * r[[i, j]];
661                }
662                w[a] = wa;
663                for b in 0..rank {
664                    let mut acc = 0.0_f64;
665                    for j in 0..p {
666                        acc += lambda[[j, a]] * d_inv[j] * lambda[[j, b]];
667                    }
668                    m[[a, b]] = acc;
669                }
670                m[[a, a]] += 1.0 / c;
671            }
672            // Cholesky M = R Rᵀ → log|M|, and solve M y = w.
673            match m.cholesky(Side::Lower) {
674                Ok(chol) => {
675                    let y = chol.solvevec(&w);
676                    let mut wy = 0.0_f64;
677                    for a in 0..rank {
678                        wy += w[a] * y[a];
679                    }
680                    quad -= wy;
681                    // log|Σ_n| = log|D| + log|M| + r·log c   (matrix-determinant
682                    // lemma; the c^{-1}I shift carries the +r·log c).
683                    let diag = chol.diag();
684                    let log_det_m: f64 = diag.iter().map(|&l| (l * l).ln()).sum();
685                    log_det = log_det_d + log_det_m + rank as f64 * c.ln();
686                }
687                Err(_) => {
688                    // Degenerate capacitance — fall back to the diagonal model's
689                    // accounting for this row (no factor correction).
690                    log_det = log_det_d;
691                }
692            }
693        }
694        log_lik += -0.5 * (log_det + quad + p as f64 * two_pi_ln);
695    }
696
697    let k_params = (p * rank + p + ACTIVITY_SCALE_BINS) as f64;
698    log_lik - 0.5 * k_params * (n.max(2) as f64).ln()
699}
700
701#[cfg(test)]
702mod tests {
703    use super::*;
704    use ndarray::{Array1, Array2};
705
706    fn lcg_uniform(state: &mut u64) -> f64 {
707        *state = state
708            .wrapping_mul(6364136223846793005)
709            .wrapping_add(1442695040888963407);
710        ((*state >> 11) as f64) / ((1u64 << 53) as f64)
711    }
712
713    fn lcg_normal(state: &mut u64) -> f64 {
714        let u1 = lcg_uniform(state).max(1e-12);
715        let u2 = lcg_uniform(state);
716        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
717    }
718
719    /// Per-rank evidence breakdown on the planted single-factor activity-law
720    /// DGP (the `fitted_scale_recovers_planted_activity_law` plant). Pins the
721    /// rank-selection decision itself: the ladder must prefer rank 1, and this
722    /// test names the margin so an over-selection regression is diagnosable
723    /// from the failure message alone.
724    #[test]
725    fn evidence_ladder_prefers_planted_rank_one() {
726        let n = 5000usize;
727        let p = 4usize;
728        let lambda0 = ndarray::array![[1.5], [1.2], [-0.4], [0.3]];
729        let sigma_eps = 0.2_f64;
730        let slope = 1.3_f64;
731        let mut seed = 0xD1B54A32D192ED03_u64;
732        let mut residuals = Array2::<f64>::zeros((n, p));
733        let mut activity = Array1::<f64>::zeros(n);
734        for row in 0..n {
735            let z = (row as f64) / (n as f64 - 1.0);
736            activity[row] = z;
737            let amp = (slope * z).exp().sqrt();
738            let f = lcg_normal(&mut seed);
739            for i in 0..p {
740                residuals[[row, i]] = amp * lambda0[[i, 0]] * f + sigma_eps * lcg_normal(&mut seed);
741            }
742        }
743        // Reproduce fit()'s bin assignment, then score each rank directly.
744        let bins = ACTIVITY_SCALE_BINS.max(1);
745        let row_bin: Vec<usize> = (0..n)
746            .map(|i| {
747                let frac = activity[i];
748                (frac * bins as f64).floor().clamp(0.0, bins as f64 - 1.0) as usize
749            })
750            .collect();
751        let mut report = String::new();
752        let mut ev = Vec::new();
753        for rank in 0..=2usize {
754            let m = StructuredResidualModel::fit_fixed_rank(residuals.view(), &row_bin, bins, rank)
755                .expect("fixed-rank fit");
756            let k_params = (p * rank + p + ACTIVITY_SCALE_BINS) as f64;
757            let log_lik = m.log_evidence() + 0.5 * k_params * (n as f64).ln();
758            let col_norms: Vec<f64> = (0..rank)
759                .map(|k| {
760                    m.factor()
761                        .column(k)
762                        .iter()
763                        .map(|v| v * v)
764                        .sum::<f64>()
765                        .sqrt()
766                })
767                .collect();
768            report.push_str(&format!(
769                "rank {rank}: evidence={:.3} loglik={:.3} penalty={:.3} col_norms={:?} diag={:?}\n",
770                m.log_evidence(),
771                log_lik,
772                0.5 * k_params * (n as f64).ln(),
773                col_norms,
774                m.diagonal()
775                    .iter()
776                    .map(|v| (v * 1e4).round() / 1e4)
777                    .collect::<Vec<_>>()
778            ));
779            ev.push(m.log_evidence());
780        }
781        assert!(
782            ev[1] > ev[0] && ev[1] > ev[2],
783            "evidence ladder must prefer the planted rank 1; breakdown:\n{report}"
784        );
785    }
786
787    /// Orthonormalize the columns of `m` (modified Gram–Schmidt), dropping
788    /// numerically-null columns. Test-side helper for subspace comparisons.
789    fn orthonormal_columns(m: ArrayView2<'_, f64>) -> Vec<Array1<f64>> {
790        let mut basis: Vec<Array1<f64>> = Vec::new();
791        for k in 0..m.ncols() {
792            let mut v = m.column(k).to_owned();
793            for q in &basis {
794                let c = v.dot(q);
795                v = &v - &(q * c);
796            }
797            let norm = v.dot(&v).sqrt();
798            if norm > 1e-10 {
799                basis.push(v / norm);
800            }
801        }
802        basis
803    }
804
805    /// Squared norm of the projection of unit vector `v` onto span(basis) —
806    /// `cos²` of the principal angle between `v` and the subspace.
807    fn projection_energy(v: &Array1<f64>, basis: &[Array1<f64>]) -> f64 {
808        basis.iter().map(|q| v.dot(q).powi(2)).sum()
809    }
810
811    /// #974 verification arm (a): the fitted factor must recover the PLANTED
812    /// interference subspace. Two orthogonal planted directions with distinct
813    /// strengths; the principal angles between each planted direction and
814    /// range(Λ̂) must be small, and the evidence ladder must select rank 2.
815    #[test]
816    fn factor_recovers_planted_interference_subspace() {
817        let n = 6000usize;
818        let p = 6usize;
819        // Two orthogonal planted unit directions.
820        let raw1: Array1<f64> = ndarray::array![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
821        let raw2: Array1<f64> = ndarray::array![1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
822        let v1 = &raw1 / raw1.dot(&raw1).sqrt();
823        let v2 = &raw2 / raw2.dot(&raw2).sqrt();
824        let (amp1, amp2) = (1.4_f64, 0.9_f64);
825        let sigma_eps = 0.15_f64;
826
827        let mut seed = 0x9E3779B97F4A7C15_u64;
828        let mut residuals = Array2::<f64>::zeros((n, p));
829        let activity = Array1::<f64>::zeros(n); // constant ⇒ homoscedastic law
830        for row in 0..n {
831            let f1 = amp1 * lcg_normal(&mut seed);
832            let f2 = amp2 * lcg_normal(&mut seed);
833            for i in 0..p {
834                residuals[[row, i]] = f1 * v1[i] + f2 * v2[i] + sigma_eps * lcg_normal(&mut seed);
835            }
836        }
837
838        let model = StructuredResidualModel::fit(ResidualFactorInput {
839            residuals: residuals.view(),
840            activity: activity.view(),
841            max_factor_rank: 4,
842        })
843        .expect("fit");
844
845        assert_eq!(
846            model.factor_rank(),
847            2,
848            "ladder must select the planted rank 2 (got {}, evidence {:.3})",
849            model.factor_rank(),
850            model.log_evidence()
851        );
852        let basis = orthonormal_columns(model.factor());
853        assert_eq!(basis.len(), 2, "fitted factor must span 2 directions");
854        let e1 = projection_energy(&v1, &basis);
855        let e2 = projection_energy(&v2, &basis);
856        // cos² of each principal angle ≥ 0.95 ⇒ angle ≤ ~13°.
857        assert!(
858            e1 > 0.95 && e2 > 0.95,
859            "planted directions must lie in range(Λ̂): cos² = ({e1:.4}, {e2:.4})"
860        );
861    }
862
863    /// #974 verification arm (d): recovery of the planted activity-variance
864    /// law. Single planted factor with per-row energy `exp(slope·z)`; the
865    /// fitted `c(z_n)` must reproduce the law's shape — strongly correlated
866    /// with the planted log-scale and with the right dynamic range.
867    #[test]
868    fn fitted_scale_recovers_planted_activity_law() {
869        let n = 6000usize;
870        let p = 4usize;
871        let lambda0 = ndarray::array![1.5, 1.2, -0.4, 0.3];
872        let sigma_eps = 0.2_f64;
873        let slope = 1.3_f64;
874        let mut seed = 0xD1B54A32D192ED03_u64;
875        let mut residuals = Array2::<f64>::zeros((n, p));
876        let mut activity = Array1::<f64>::zeros(n);
877        for row in 0..n {
878            let z = (row as f64) / (n as f64 - 1.0);
879            activity[row] = z;
880            let amp = (slope * z).exp().sqrt();
881            let f = lcg_normal(&mut seed);
882            for i in 0..p {
883                residuals[[row, i]] = amp * lambda0[i] * f + sigma_eps * lcg_normal(&mut seed);
884            }
885        }
886
887        let model = StructuredResidualModel::fit(ResidualFactorInput {
888            residuals: residuals.view(),
889            activity: activity.view(),
890            max_factor_rank: 2,
891        })
892        .expect("fit");
893        assert_eq!(model.factor_rank(), 1, "planted rank is 1");
894
895        // Pearson correlation between fitted log c(z_n) and the planted
896        // log-law slope·z (mean-1 normalization cancels in the correlation).
897        let fitted_log: Vec<f64> = model.row_scale().iter().map(|c| c.ln()).collect();
898        let planted_log: Vec<f64> = activity.iter().map(|z| slope * z).collect();
899        let mean_f = fitted_log.iter().sum::<f64>() / n as f64;
900        let mean_p = planted_log.iter().sum::<f64>() / n as f64;
901        let mut cov = 0.0_f64;
902        let mut var_f = 0.0_f64;
903        let mut var_p = 0.0_f64;
904        for i in 0..n {
905            let df = fitted_log[i] - mean_f;
906            let dp = planted_log[i] - mean_p;
907            cov += df * dp;
908            var_f += df * df;
909            var_p += dp * dp;
910        }
911        let corr = cov / (var_f.sqrt() * var_p.sqrt());
912        assert!(
913            corr > 0.9,
914            "fitted activity law must track the planted exp({slope}·z): corr = {corr:.4}"
915        );
916
917        // Dynamic range: planted c(top)/c(bottom) over the inner bin centers
918        // is exp(slope·7/8) ≈ 3.1; the binned/smoothed estimate must land in
919        // a generous bracket around it (smoothing shrinks the edges).
920        let lo = model.row_scale()[n / 16]; // first-bin interior
921        let hi = model.row_scale()[n - 1 - n / 16]; // last-bin interior
922        let ratio = hi / lo;
923        assert!(
924            ratio > 1.8 && ratio < 5.5,
925            "fitted dynamic range {ratio:.3} must bracket the planted ≈3.1"
926        );
927    }
928}