Skip to main content

gam_solve/
gaussian_reml.rs

1use crate::estimate::EstimationError;
2use gam_linalg::faer_ndarray::{
3    FaerCholesky, FaerEigh, fast_ab, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
4};
5use faer::Side;
6use ndarray::{
7    Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, ArrayViewMut1, ArrayViewMut2, Axis,
8    s,
9};
10use rayon::prelude::*;
11use std::sync::Once;
12
/// One-time warning latch for backward-pass graceful degradation on a
/// near-singular penalized Hessian `K = XᵀWX + λS`. When `λ_k` saturates
/// (e.g. 1e10+), `K` becomes effectively rank-deficient and the analytic VJP
/// cannot be evaluated. Rather than raising, the backward returns zero
/// gradients of the correct shape: this is the statistically correct
/// "shrink-out" gradient — when `λ` has saturated, the atom is unused, so
/// every input's contribution to the loss is zero in the limit.
static ILL_CONDITIONED_BACKWARD_WARNED: Once = Once::new();

/// Log the near-singular-`K` warning the first time it is requested.
///
/// The `log::warn!` call sits inside
/// `ILL_CONDITIONED_BACKWARD_WARNED.call_once`, so only the first invocation
/// emits the message; every later call returns without logging.
///
/// * `p`, `d` — printed verbatim in the message as `p=` and `d=`.
/// * `condition_number` — printed as `cond≈` in `{:.2e}` notation.
fn warn_ill_conditioned_backward_once(p: usize, d: usize, condition_number: f64) {
    // `Once::call_once` runs its closure only on the first call.
    ILL_CONDITIONED_BACKWARD_WARNED.call_once(|| {
        log::warn!(
            "gaussian_reml_fit_backward: K = XᵀWX + λS is near-singular \
             (p={p}, d={d}, cond≈{condition_number:.2e}); returning zero gradients \
             for this fit (λ has saturated, atom is effectively unused). \
             Further occurrences are silent."
        );
    });
}
32
33fn zero_backward_result(n: usize, p: usize, d: usize) -> GaussianRemlBackwardResult {
34    GaussianRemlBackwardResult {
35        grad_x: Array2::<f64>::zeros((n, p)),
36        grad_y: Array2::<f64>::zeros((n, d)),
37        grad_penalty: Array2::<f64>::zeros((p, p)),
38        grad_weights: Array1::<f64>::zeros(n),
39    }
40}
41
// NOTE(review): presumably the lower bound on ρ = log λ used by the ρ
// optimizer — not referenced in the visible part of the file; confirm.
const RHO_LOWER: f64 = -30.0;
// NOTE(review): presumably the matching upper bound on ρ — confirm.
const RHO_UPPER: f64 = 30.0;
// Relative eigenvalue tolerance: `block_penalty_rank_logdet` scales it by the
// largest absolute eigenvalue to decide which eigenvalues count as positive.
const EIGEN_REL_TOL: f64 = 1.0e-10;
// NOTE(review): presumably a gradient convergence tolerance for the ρ
// optimizer — not referenced in the visible part of the file; confirm.
const GRAD_TOL: f64 = 1.0e-12;
// Floor applied to the residual deviance `dp` in the dispersion term so that
// `ln(dp)` and the divisions by `dp` stay finite.
const MIN_DEVIANCE: f64 = 1.0e-300;
47
48/// Canonicalize a penalty matrix to its symmetric average.
49///
50/// Closed-form Gaussian REML treats `S` as symmetric throughout — the
51/// eigendecomposition, the pseudo-determinant `log|S|₊`, the rank detector,
52/// and every per-helper VJP all assume `S = Sᵀ`. To make that contract
53/// explicit (rather than implicit in `eigh(Side::Lower)` reading the lower
54/// triangle and silently ignoring the upper), every entry point that takes a
55/// penalty matrix replaces it with `0.5 (S + Sᵀ)` before any downstream use.
56/// For symmetric input this is a numerical no-op; for asymmetric input it
57/// defines the function as operating on the symmetric average.
58fn canonicalize_penalty(penalty: ArrayView2<'_, f64>) -> Array2<f64> {
59    let p = penalty.nrows();
60    let mut out = penalty.to_owned();
61    for i in 0..p {
62        for j in (i + 1)..p {
63            let avg = 0.5 * (out[[i, j]] + out[[j, i]]);
64            out[[i, j]] = avg;
65            out[[j, i]] = avg;
66        }
67    }
68    out
69}
70
/// Spectral quantities of a Gaussian REML problem that the closed-form
/// objective reuses across evaluations at different `λ`.
#[derive(Clone, Debug)]
pub struct GaussianRemlEigenCache {
    /// Eigenvalues `δ` that enter the REML terms through `t = λ·δ`.
    pub penalty_eigenvalues: Array1<f64>,
    /// NOTE(review): presumably the eigenvectors paired with
    /// `penalty_eigenvalues` — confirm against the cache builder.
    pub eigenvectors: Array2<f64>,
    /// NOTE(review): presumably the basis that maps spectral coordinates back
    /// to coefficient space — confirm against the cache builder.
    pub coefficient_basis: Array2<f64>,
    /// NOTE(review): presumably a fingerprint of the `XᵀWX` matrix the cache
    /// was built from — confirm against the cache builder.
    pub xtwx_fingerprint: u64,
    /// Fingerprint of the penalty; the no-alloc entry point compares it with
    /// `matrix_fingerprint` of the canonicalized penalty and rejects a mismatch.
    pub penalty_fingerprint: u64,
    /// Starting value of the `log|H|` accumulation in the log-determinant term.
    pub logdet_xtwx: f64,
    /// `λ`-independent part of `log|S|₊`; the log-determinant term adds
    /// `penalty_rank · ρ` to it.
    pub logdet_penalty_positive: f64,
    /// Penalty rank used in the `log|S|₊` term and its ρ-gradient.
    pub penalty_rank: usize,
    /// Null-space dimension; the no-alloc entry point requires more
    /// positive-weight rows than this value.
    pub nullity: usize,
}
83
/// Optional state carried from a previous fit into the warm-started entry
/// points. The derived `Default` leaves both fields `None`.
#[derive(Clone, Debug, Default)]
pub struct GaussianRemlWarmStart {
    /// Initial smoothing parameter `λ` handed to the fit, if any.
    pub lambda: Option<f64>,
    /// Previously built eigen cache handed to the fit, if any.
    pub eigen_cache: Option<GaussianRemlEigenCache>,
}
89
90impl GaussianRemlWarmStart {
91    pub fn from_multi_result(result: &GaussianRemlMultiResult) -> Self {
92        Self {
93            lambda: Some(result.lambda),
94            eigen_cache: Some(result.cache.clone()),
95        }
96    }
97}
98
/// Single-response Gaussian REML fit. `scalar_result_from_multi` fills it
/// from a [`GaussianRemlMultiResult`] by taking column/element `0` of the
/// per-output fields and copying the scalar fields unchanged.
#[derive(Clone, Debug)]
pub struct GaussianRemlResult {
    /// Selected smoothing parameter `λ`.
    pub lambda: f64,
    /// Log smoothing parameter `ρ` (the fit routines in this file use
    /// `λ = exp(ρ)`).
    pub rho: f64,
    /// Column `0` of the multi-output coefficient matrix.
    pub coefficients: Array1<f64>,
    /// Column `0` of the multi-output fitted-value matrix.
    pub fitted: Array1<f64>,
    /// REML objective value at the reported `ρ`.
    pub reml_score: f64,
    /// Derivative of the REML objective with respect to `λ`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the REML objective with respect to `λ`.
    pub reml_hess_lambda: f64,
    /// Derivative of the REML objective with respect to `ρ`.
    pub reml_grad_rho: f64,
    /// Second derivative of the REML objective with respect to `ρ`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom.
    pub edf: f64,
    /// Element `0` of the multi-output `sigma2` vector.
    pub sigma2: f64,
    /// Eigen cache carried over from the multi-output result.
    pub cache: GaussianRemlEigenCache,
}
114
/// Multi-response Gaussian REML fit sharing one smoothing parameter.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiResult {
    /// Selected smoothing parameter `λ`; reused as the warm-start value by
    /// `GaussianRemlWarmStart::from_multi_result`.
    pub lambda: f64,
    /// Log smoothing parameter `ρ` (the fit routines in this file use
    /// `λ = exp(ρ)`).
    pub rho: f64,
    /// Coefficient matrix — presumably one column per output (column `0` is
    /// what the scalar wrapper extracts); confirm the `(p, d)` layout.
    pub coefficients: Array2<f64>,
    /// Fitted values — presumably one column per output; confirm the
    /// `(n, d)` layout.
    pub fitted: Array2<f64>,
    /// REML objective value at the reported `ρ`.
    pub reml_score: f64,
    /// Derivative of the REML objective with respect to `λ`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the REML objective with respect to `λ`.
    pub reml_hess_lambda: f64,
    /// Derivative of the REML objective with respect to `ρ`.
    pub reml_grad_rho: f64,
    /// Second derivative of the REML objective with respect to `ρ`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom.
    pub edf: f64,
    /// Variance estimates — presumably one entry per output; confirm.
    pub sigma2: Array1<f64>,
    /// Eigen cache used by the fit; cloned into warm starts.
    pub cache: GaussianRemlEigenCache,
}
130
/// REML score together with gradients and fit summaries.
///
/// NOTE(review): the producer of this struct is outside the visible part of
/// the file; the field notes below are inferred from names and types only and
/// should be confirmed against it.
#[derive(Clone, Debug)]
pub struct GaussianRemlFreeBScore {
    /// REML objective value.
    pub reml_score: f64,
    /// Presumably the gradient of the score with respect to the coefficients.
    pub grad_coefficients: Array2<f64>,
    /// Presumably the gradient of the score with respect to the penalty matrix.
    pub grad_penalty: Array2<f64>,
    /// Presumably the gradient of the score with respect to `log λ`.
    pub grad_log_lambda: f64,
    /// Presumably the fitted values.
    pub fitted: Array2<f64>,
    /// Presumably the per-output variance estimates.
    pub sigma2: Array1<f64>,
    /// Presumably the effective degrees of freedom.
    pub edf: f64,
}
141
/// Gradients returned by the backward pass. The shapes below are the ones
/// `zero_backward_result(n, p, d)` allocates.
#[derive(Clone, Debug)]
pub struct GaussianRemlBackwardResult {
    /// Gradient with respect to the design matrix, shape `(n, p)`.
    pub grad_x: Array2<f64>,
    /// Gradient with respect to the responses, shape `(n, d)`.
    pub grad_y: Array2<f64>,
    /// Gradient with respect to the penalty matrix, shape `(p, p)`.
    pub grad_penalty: Array2<f64>,
    /// Gradient with respect to the observation weights, length `n`.
    pub grad_weights: Array1<f64>,
}
149
/// Inputs for the backward pass of a multi-output fit: the forward inputs,
/// the forward result, and upstream gradients.
///
/// NOTE(review): the consumer of this struct is outside the visible part of
/// the file; the roles of the `grad_*` fields are inferred from their names
/// and should be confirmed against it.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBackwardProblem<'a> {
    /// Design matrix view.
    pub x: ArrayView2<'a, f64>,
    /// Response matrix view.
    pub y: ArrayView2<'a, f64>,
    /// Optional observation weights.
    pub weights: Option<ArrayView1<'a, f64>>,
    /// Forward fit being differentiated.
    pub fit: &'a GaussianRemlMultiResult,
    /// Presumably the upstream gradient with respect to `fit.lambda`.
    pub grad_lambda: f64,
    /// Presumably the upstream gradient with respect to `fit.coefficients`.
    pub grad_coefficients: Option<ArrayView2<'a, f64>>,
    /// Presumably the upstream gradient with respect to `fit.fitted`.
    pub grad_fitted: Option<ArrayView2<'a, f64>>,
    /// Presumably the upstream gradient with respect to `fit.reml_score`.
    pub grad_reml_score: f64,
    /// Presumably the upstream gradient with respect to `fit.edf`.
    pub grad_edf: f64,
}
162
/// Caller-owned scratch buffers for the no-alloc fit. `new` allocates every
/// matrix as `(n_coefficients, n_outputs)` and `ywy` with length `n_outputs`;
/// `validate` checks those same shapes before a fit.
#[derive(Clone, Debug)]
pub struct GaussianRemlNoAllocWorkspace {
    /// `(p, d)` buffer — presumably holds `XᵀWY`; confirm against
    /// `fill_weighted_rhs_no_alloc`.
    pub xtwy: Array2<f64>,
    /// Length-`d` buffer passed to the objective as the per-output `yᵀWy`.
    pub ywy: Array1<f64>,
    /// `(p, d)` buffer — presumably the right-hand side in the cached
    /// eigenbasis; confirm against `project_rhs_no_alloc`.
    pub projected_rhs: Array2<f64>,
    /// `(p, d)` buffer passed to the objective as the squared projected
    /// right-hand side.
    pub projected_rhs_squared: Array2<f64>,
    /// `(p, d)` scratch buffer — presumably used while forming coefficients;
    /// confirm against `fill_coefficients_no_alloc`.
    pub scaled_projected_rhs: Array2<f64>,
}
171
172impl GaussianRemlNoAllocWorkspace {
173    pub fn new(n_coefficients: usize, n_outputs: usize) -> Self {
174        Self {
175            xtwy: Array2::zeros((n_coefficients, n_outputs)),
176            ywy: Array1::zeros(n_outputs),
177            projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
178            projected_rhs_squared: Array2::zeros((n_coefficients, n_outputs)),
179            scaled_projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
180        }
181    }
182
183    fn validate(&self, p: usize, d: usize) -> Result<(), EstimationError> {
184        if self.xtwy.dim() != (p, d)
185            || self.ywy.len() != d
186            || self.projected_rhs.dim() != (p, d)
187            || self.projected_rhs_squared.dim() != (p, d)
188            || self.scaled_projected_rhs.dim() != (p, d)
189        {
190            crate::bail_invalid_estim!(
191                "Gaussian REML no-alloc workspace shape mismatch: expected p={p}, d={d}"
192            );
193        }
194        Ok::<(), _>(())
195    }
196}
197
/// Scalar summary returned by the no-alloc fit; coefficients, fitted values
/// and variances are written into caller-provided buffers instead.
#[derive(Clone, Copy, Debug)]
pub struct GaussianRemlNoAllocFit {
    /// Selected smoothing parameter, computed as `exp(rho)`.
    pub lambda: f64,
    /// Optimized log smoothing parameter `ρ`.
    pub rho: f64,
    /// REML objective value at `rho`.
    pub reml_score: f64,
    /// Derivative of the objective with respect to `λ`, obtained from the
    /// ρ-derivatives via `rho_derivatives_to_lambda`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the objective with respect to `λ`, obtained the
    /// same way.
    pub reml_hess_lambda: f64,
    /// Derivative of the objective with respect to `ρ`.
    pub reml_grad_rho: f64,
    /// Second derivative of the objective with respect to `ρ`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom reported by the objective evaluation.
    pub edf: f64,
}
209
/// One problem of a batched multi-output fit; the penalty and null-space
/// dimension are shared across the batch and passed separately.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBatchProblem<'a> {
    /// Design matrix view.
    pub x: ArrayView2<'a, f64>,
    /// Response matrix view.
    pub y: ArrayView2<'a, f64>,
    /// Optional observation weights; the batch driver uses all-ones weights
    /// for the `XᵀWX` product when this is `None`.
    pub weights: Option<ArrayView1<'a, f64>>,
    /// Optional initial `ρ`; the batch driver exponentiates it into an
    /// initial `λ`.
    pub init_rho: Option<f64>,
}
217
/// Result of the block-orthogonal shared-scale fit.
///
/// NOTE(review): this struct is populated outside the visible part of the
/// file; the per-block interpretation of the fields below is inferred from
/// names and types and should be confirmed there.
#[derive(Clone, Debug)]
pub struct GaussianRemlBlockOrthogonalResult {
    /// Presumably one coefficient matrix per block.
    pub coefficients: Vec<Array2<f64>>,
    /// Presumably the combined fitted values.
    pub fitted: Array2<f64>,
    /// Presumably one smoothing parameter `λ` per block.
    pub lambdas: Array1<f64>,
    /// Presumably `ln` of `lambdas`, entry by entry.
    pub log_lambdas: Array1<f64>,
    /// REML objective value.
    pub reml_score: f64,
    /// Presumably the effective degrees of freedom per block.
    pub edf: Array1<f64>,
}
227
/// Pre-computed, `λ`-independent inputs of the closed-form REML objective.
///
/// NOTE(review): the code that builds and consumes this struct is outside the
/// visible part of the file; the unannotated fields mirror the identically
/// named workspace buffers and cache — confirm.
#[derive(Clone)]
struct GaussianRemlPrepared {
    /// Eigen cache for the problem.
    cache: GaussianRemlEigenCache,
    /// Presumably the per-output weighted response energy `yᵀWy`.
    ywy: Array1<f64>,
    /// Presumably the element-wise square of `projected_rhs`.
    projected_rhs_squared: Array2<f64>,
    /// Presumably the right-hand side expressed in the cached eigenbasis.
    projected_rhs: Array2<f64>,
    /// Number of rows with a strictly positive prior weight — the effective
    /// sample size that enters the REML residual degrees of freedom `ν`. Rows
    /// with weight `0` are excluded (see [`effective_observation_count`]).
    n_effective: usize,
    /// Number of response columns.
    n_outputs: usize,
}
240
/// Running totals of the REML objective at one `ρ`: value, ρ-gradient,
/// ρ-Hessian, plus the effective degrees of freedom reported alongside.
#[derive(Clone, Copy)]
struct ObjectiveEval {
    /// Accumulated objective value.
    cost: f64,
    /// Accumulated ρ-gradient.
    grad: f64,
    /// Accumulated ρ-Hessian.
    hess: f64,
    /// Effective degrees of freedom; not touched by term accumulation.
    edf: f64,
}

/// One closed-form Gaussian REML objective term, bundling its analytic value
/// with its analytic ρ-gradient and ρ-Hessian.
///
/// Single source of truth: a term's value and its hand-derived ρ-derivatives
/// come out of one function body, so editing the value formula cannot leave
/// the derivatives stale. This mirrors the `PenaltyLogdetDerivs` tuple pattern
/// of the unified outer evaluator — the structural cure for the
/// objective↔gradient desync class (#752/#748/#808). All three parts are
/// folded into [`ObjectiveEval`] at a single site, so they cannot drift apart.
#[derive(Clone, Copy)]
struct TermDerivs {
    /// Term value.
    value: f64,
    /// First derivative of the term with respect to `ρ`.
    grad: f64,
    /// Second derivative of the term with respect to `ρ`.
    hess: f64,
}

impl std::ops::AddAssign<TermDerivs> for ObjectiveEval {
    /// Add a term's `(value, grad, hess)` triple to the running totals in one
    /// step, so value and derivatives are never accumulated at separate sites.
    /// `edf` is left unchanged.
    fn add_assign(&mut self, rhs: TermDerivs) {
        let TermDerivs { value, grad, hess } = rhs;
        self.cost += value;
        self.grad += grad;
        self.hess += hess;
    }
}
275
276/// `½d·(log|H| − log|S|_+)` value with its analytic ρ-gradient/Hessian.
277///
278/// The penalty-eigenvalue sum produces all three quantities from the SAME
279/// `t = λδ` intermediates in one pass, so the value (`log|1+t|`) and its
280/// derivatives (`t/(1+t)`, `t/(1+t)²`) are single-sourced.
281fn gaussian_reml_logdet_term(
282    cache: &GaussianRemlEigenCache,
283    rho: f64,
284    n_outputs: f64,
285) -> (TermDerivs, f64) {
286    let lambda = rho.exp();
287    let mut logdet_h = cache.logdet_xtwx;
288    let mut trace_h = 0.0;
289    let mut trace_h_deriv = 0.0;
290    let mut edf = 0.0;
291    for &delta in &cache.penalty_eigenvalues {
292        let t = lambda * delta;
293        logdet_h += (1.0 + t).ln();
294        if delta > 0.0 {
295            trace_h += t / (1.0 + t);
296            trace_h_deriv += t / ((1.0 + t) * (1.0 + t));
297        }
298        edf += 1.0 / (1.0 + t);
299    }
300    let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * rho;
301    let term = TermDerivs {
302        value: 0.5 * n_outputs * (logdet_h - logdet_s),
303        grad: 0.5 * n_outputs * (trace_h - cache.penalty_rank as f64),
304        hess: 0.5 * n_outputs * trace_h_deriv,
305    };
306    (term, edf)
307}
308
309/// Per-output dispersion-prior term `½ν·(1 + log(2π·dp/ν))` with its analytic
310/// ρ-gradient/Hessian.
311///
312/// `dp`, `dp_grad`, `dp_hess` are computed from the SAME eigenvalue sum, then
313/// the value `log(dp)` and its derivatives `dp_grad/dp`,
314/// `dp_hess/dp − (dp_grad/dp)²` are returned together so they cannot desync.
315fn gaussian_reml_dispersion_term(
316    cache: &GaussianRemlEigenCache,
317    ywy: ArrayView1<'_, f64>,
318    projected_rhs_squared: ArrayView2<'_, f64>,
319    output: usize,
320    nu: f64,
321    lambda: f64,
322) -> TermDerivs {
323    let mut fitted_quadratic = 0.0;
324    let mut dp_grad = 0.0;
325    let mut dp_hess = 0.0;
326    for eig in 0..cache.penalty_eigenvalues.len() {
327        let c2 = projected_rhs_squared[[eig, output]];
328        let t = lambda * cache.penalty_eigenvalues[eig];
329        let denom = 1.0 + t;
330        fitted_quadratic += c2 / denom;
331        dp_grad += c2 * t / (denom * denom);
332        dp_hess += c2 * t * (1.0 - t) / (denom * denom * denom);
333    }
334    let dp = (ywy[output] - fitted_quadratic).max(MIN_DEVIANCE);
335    TermDerivs {
336        value: 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln()),
337        grad: 0.5 * nu * dp_grad / dp,
338        hess: 0.5 * nu * (dp_hess / dp - (dp_grad * dp_grad) / (dp * dp)),
339    }
340}
341
342pub fn gaussian_reml_closed_form(
343    x: ArrayView2<'_, f64>,
344    y: ArrayView1<'_, f64>,
345    penalty: ArrayView2<'_, f64>,
346    weights: Option<ArrayView1<'_, f64>>,
347    init_rho: Option<f64>,
348) -> Result<GaussianRemlResult, EstimationError> {
349    gaussian_reml_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
350}
351
352pub fn gaussian_reml_closed_form_with_nullspace_dim(
353    x: ArrayView2<'_, f64>,
354    y: ArrayView1<'_, f64>,
355    penalty: ArrayView2<'_, f64>,
356    nullspace_dim: Option<usize>,
357    weights: Option<ArrayView1<'_, f64>>,
358    init_rho: Option<f64>,
359) -> Result<GaussianRemlResult, EstimationError> {
360    let y2 = y.insert_axis(Axis(1));
361    let result = gaussian_reml_multi_closed_form_with_nullspace_dim(
362        x,
363        y2,
364        penalty,
365        nullspace_dim,
366        weights,
367        init_rho,
368    )?;
369    scalar_result_from_multi(result)
370}
371
372fn scalar_result_from_multi(
373    result: GaussianRemlMultiResult,
374) -> Result<GaussianRemlResult, EstimationError> {
375    Ok(GaussianRemlResult {
376        lambda: result.lambda,
377        rho: result.rho,
378        coefficients: result.coefficients.column(0).to_owned(),
379        fitted: result.fitted.column(0).to_owned(),
380        reml_score: result.reml_score,
381        reml_grad_lambda: result.reml_grad_lambda,
382        reml_hess_lambda: result.reml_hess_lambda,
383        reml_grad_rho: result.reml_grad_rho,
384        reml_hess_rho: result.reml_hess_rho,
385        edf: result.edf,
386        sigma2: result.sigma2[0],
387        cache: result.cache,
388    })
389}
390
391pub fn gaussian_reml_multi_closed_form(
392    x: ArrayView2<'_, f64>,
393    y: ArrayView2<'_, f64>,
394    penalty: ArrayView2<'_, f64>,
395    weights: Option<ArrayView1<'_, f64>>,
396    init_rho: Option<f64>,
397) -> Result<GaussianRemlMultiResult, EstimationError> {
398    gaussian_reml_multi_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
399}
400
401pub fn gaussian_reml_multi_closed_form_with_nullspace_dim(
402    x: ArrayView2<'_, f64>,
403    y: ArrayView2<'_, f64>,
404    penalty: ArrayView2<'_, f64>,
405    nullspace_dim: Option<usize>,
406    weights: Option<ArrayView1<'_, f64>>,
407    init_rho: Option<f64>,
408) -> Result<GaussianRemlMultiResult, EstimationError> {
409    let init_lambda = init_rho.map(f64::exp);
410    gaussian_reml_multi_closed_form_from_parts(
411        x,
412        y,
413        penalty,
414        nullspace_dim,
415        weights,
416        init_lambda,
417        None,
418    )
419}
420
421pub fn gaussian_reml_multi_closed_form_warm_started(
422    x: ArrayView2<'_, f64>,
423    y: ArrayView2<'_, f64>,
424    penalty: ArrayView2<'_, f64>,
425    weights: Option<ArrayView1<'_, f64>>,
426    warm_start: Option<&GaussianRemlWarmStart>,
427) -> Result<GaussianRemlMultiResult, EstimationError> {
428    gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
429        x, y, penalty, None, weights, warm_start,
430    )
431}
432
433pub fn gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
434    x: ArrayView2<'_, f64>,
435    y: ArrayView2<'_, f64>,
436    penalty: ArrayView2<'_, f64>,
437    nullspace_dim: Option<usize>,
438    weights: Option<ArrayView1<'_, f64>>,
439    warm_start: Option<&GaussianRemlWarmStart>,
440) -> Result<GaussianRemlMultiResult, EstimationError> {
441    let init_lambda = warm_start.and_then(|start| start.lambda);
442    let eigen_cache = warm_start.and_then(|start| start.eigen_cache.as_ref());
443    gaussian_reml_multi_closed_form_from_parts(
444        x,
445        y,
446        penalty,
447        nullspace_dim,
448        weights,
449        init_lambda,
450        eigen_cache,
451    )
452}
453
454pub fn gaussian_reml_multi_closed_form_with_cache(
455    x: ArrayView2<'_, f64>,
456    y: ArrayView2<'_, f64>,
457    penalty: ArrayView2<'_, f64>,
458    weights: Option<ArrayView1<'_, f64>>,
459    init_lambda: Option<f64>,
460    eigen_cache: Option<&GaussianRemlEigenCache>,
461) -> Result<GaussianRemlMultiResult, EstimationError> {
462    gaussian_reml_multi_closed_form_from_parts(
463        x,
464        y,
465        penalty,
466        None,
467        weights,
468        init_lambda,
469        eigen_cache,
470    )
471}
472
/// Multi-response closed-form Gaussian REML fit that reuses a prebuilt eigen
/// cache and writes its array outputs into caller-provided buffers.
///
/// With `n = x.nrows()`, `p = x.ncols()` and `d = y.ncols()`:
///
/// * `workspace` must hold `(p, d)` matrices and a length-`d` `ywy`.
/// * `coefficients` must be `(p, d)`, `fitted` `(n, d)`, `sigma2` length `d`.
/// * `init_lambda`, when given, is validated and its logarithm is used as the
///   starting `ρ` for the optimizer.
///
/// NOTE(review): despite the name, the penalty is copied once by
/// `canonicalize_penalty` (it calls `to_owned`), so this entry point is not
/// strictly allocation-free.
///
/// # Errors
///
/// Returns an error, in this order of checks, when: the design/penalty/weights
/// validation or the eigen-cache validation fails; `y` has a different row
/// count than `x`; `y` contains a non-finite value; the number of
/// positive-weight rows does not exceed `eigen_cache.nullity`; the cache's
/// penalty fingerprint does not match the canonicalized penalty; the
/// workspace or output buffers have the wrong shape; `init_lambda` fails
/// validation; or any of the fallible fill/optimize helpers fails.
pub fn gaussian_reml_multi_closed_form_with_cache_no_alloc(
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    init_lambda: Option<f64>,
    eigen_cache: &GaussianRemlEigenCache,
    workspace: &mut GaussianRemlNoAllocWorkspace,
    mut coefficients: ArrayViewMut2<'_, f64>,
    mut fitted: ArrayViewMut2<'_, f64>,
    mut sigma2: ArrayViewMut1<'_, f64>,
) -> Result<GaussianRemlNoAllocFit, EstimationError> {
    // Match the symmetric-S contract used by the cache builder: the
    // fingerprint check below compares against a fingerprint computed on the
    // canonicalized penalty, so the input must be canonicalized first.
    let penalty_owned = canonicalize_penalty(penalty);
    let penalty = penalty_owned.view();
    let n = x.nrows();
    let p = x.ncols();
    let d = y.ncols();
    validate_gaussian_reml_design(x, penalty, weights)?;
    validate_gaussian_reml_eigen_cache(eigen_cache, p)?;
    if y.nrows() != n {
        crate::bail_invalid_estim!(
            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
            y.nrows()
        );
    }
    if y.iter().any(|value| !value.is_finite()) {
        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
    }
    // Effective sample size: every row when unweighted, otherwise the count
    // reported by `effective_observation_count`.
    let n_effective = match weights {
        Some(w) => effective_observation_count(w),
        None => n,
    };
    if n_effective <= eigen_cache.nullity {
        crate::bail_invalid_estim!(
            "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
            eigen_cache.nullity
        );
    }
    // Reject a cache that was built for a different penalty matrix.
    let penalty_fingerprint = matrix_fingerprint(penalty);
    if eigen_cache.penalty_fingerprint != penalty_fingerprint {
        crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
    }
    workspace.validate(p, d)?;
    if coefficients.dim() != (p, d) || fitted.dim() != (n, d) || sigma2.len() != d {
        crate::bail_invalid_estim!(
            "Gaussian REML no-alloc output shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
        );
    }
    if let Some(lambda) = init_lambda {
        validate_initial_lambda(lambda)?;
    }

    // Populate the workspace buffers that the objective reads (`ywy`,
    // `projected_rhs_squared`) from the data and the cached basis.
    fill_weighted_rhs_no_alloc(x, y, weights, workspace)?;
    project_rhs_no_alloc(eigen_cache, workspace);

    // The optimizer works on ρ = ln λ.
    let init_rho = init_lambda.map(f64::ln);
    let rho = optimize_rho_no_alloc(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n_effective,
        d,
        init_rho,
    )?;
    // Re-evaluate the objective at the optimum to report score, derivatives
    // and EDF.
    let eval = evaluate_reml_parts(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n_effective,
        d,
        rho,
    );
    let lambda = rho.exp();
    // Write the array outputs into the caller's buffers; `fitted` is computed
    // from the freshly written coefficients.
    fill_coefficients_no_alloc(eigen_cache, workspace, lambda, coefficients.view_mut());
    fill_fitted_no_alloc(x, coefficients.view(), fitted.view_mut());
    fill_sigma2_no_alloc(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n_effective,
        d,
        lambda,
        sigma2.view_mut(),
    );
    // Report derivatives in both parameterizations (ρ and λ).
    let (reml_grad_lambda, reml_hess_lambda) =
        rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
    Ok(GaussianRemlNoAllocFit {
        lambda,
        rho,
        reml_score: eval.cost,
        reml_grad_lambda,
        reml_hess_lambda,
        reml_grad_rho: eval.grad,
        reml_hess_rho: eval.hess,
        edf: eval.edf,
    })
}
573
574
575pub fn gaussian_reml_multi_closed_form_batch<'a>(
576    problems: &[GaussianRemlMultiBatchProblem<'a>],
577    penalty: ArrayView2<'a, f64>,
578    nullspace_dim: Option<usize>,
579) -> Result<Vec<GaussianRemlMultiResult>, EstimationError> {
580    if problems.is_empty() {
581        return Ok(Vec::new());
582    }
583    // Phase A: par_iter compute X'WX per problem (the only per-fit step that
584    // depends on `n_b`; remaining work is `O(p)` and can amortize through
585    // `_with_cache`).
586    let xtwx_per_problem: Vec<Array2<f64>> = problems
587        .par_iter()
588        .map(|problem| {
589            let weight = match problem.weights.as_ref() {
590                Some(w) => w.to_owned(),
591                None => Array1::ones(problem.x.nrows()),
592            };
593            dense_xt_diag_x(problem.x.view(), weight.view())
594        })
595        .collect();
596    // Phase B: one batched cuSOLVER Cholesky when policy approves uniform p
597    // and K aggregate FLOPs; otherwise the cache builder uses the normal
598    // per-fit non-GPU factorization path.
599    let caches =
600        build_gaussian_reml_eigen_cache_batched(xtwx_per_problem, penalty.view(), nullspace_dim);
601    // Phase C: par_iter finish each fit with its prebuilt cache. A cache-build
602    // error is a real per-problem error, not a signal to rebuild through a
603    // second path.
604    let fits: Vec<Result<GaussianRemlMultiResult, EstimationError>> = problems
605        .par_iter()
606        .zip(caches.into_par_iter())
607        .map(|(problem, cache_result)| {
608            let init_lambda = problem.init_rho.map(f64::exp);
609            let cache = cache_result?;
610            gaussian_reml_multi_closed_form_from_parts(
611                problem.x.view(),
612                problem.y.view(),
613                penalty.view(),
614                nullspace_dim,
615                problem.weights.as_ref().map(|weights| weights.view()),
616                init_lambda,
617                Some(&cache),
618            )
619        })
620        .collect();
621    fits.into_iter().collect()
622}
623
/// Quantities produced by `block_orthogonal_eval` for one block at one `ρ`,
/// with `H = gram + λS` and `λ = exp(ρ)`.
struct BlockOrthogonalEval {
    /// Solution `β = H⁻¹·rhs`, one column per output.
    beta: Array2<f64>,
    /// `log|H|`, computed as twice the sum of the log Cholesky diagonal.
    logdet: f64,
    /// `tr(H⁻¹ λS)`.
    trace: f64,
    /// `tr((H⁻¹ λS)²)`.
    trace_pair: f64,
    /// Per-output column sums of `rhs ∘ β`.
    fitted_energy: Array1<f64>,
    /// Per-output `βᵀ (λS) β`.
    penalty_energy: Array1<f64>,
    /// Per-output `(λSβ)ᵀ H⁻¹ (λSβ)`.
    curvature_energy: Array1<f64>,
    /// Effective degrees of freedom: penalty dimension minus `trace`.
    edf: f64,
}
634
635fn block_penalty_rank_logdet(
636    penalty: ArrayView2<'_, f64>,
637) -> Result<(usize, f64), EstimationError> {
638    let eigs = penalty
639        .to_owned()
640        .eigh(Side::Lower)
641        .map_err(|_| EstimationError::ModelIsIllConditioned {
642            condition_number: f64::INFINITY,
643        })?
644        .0;
645    let max_abs = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
646    let tol = (EIGEN_REL_TOL * max_abs).max(1.0e-14);
647    let mut rank = 0_usize;
648    let mut logdet = 0.0;
649    for eig in eigs.iter().copied() {
650        if eig > tol {
651            rank += 1;
652            logdet += eig.ln();
653        }
654    }
655    Ok((rank, logdet))
656}
657
658fn block_orthogonal_eval(
659    gram: &Array2<f64>,
660    rhs: &Array2<f64>,
661    penalty: &Array2<f64>,
662    rho: f64,
663) -> Result<BlockOrthogonalEval, EstimationError> {
664    let lambda = rho.exp();
665    validate_initial_lambda(lambda)?;
666    let scaled_penalty = penalty * lambda;
667    let hessian = canonicalize_penalty((gram + &scaled_penalty).view());
668    let chol = gaussian_reml_cholesky_lower(hessian)?;
669    let beta = solve_spd_from_lower_factor(&chol, rhs)?;
670    let solved_penalty = solve_spd_from_lower_factor(&chol, &scaled_penalty)?;
671    let logdet = 2.0 * chol.diag().iter().map(|value| value.ln()).sum::<f64>();
672    let trace = (0..solved_penalty.nrows())
673        .map(|i| solved_penalty[[i, i]])
674        .sum::<f64>();
675    let trace_pair =
676        gam_linalg::utils::trace_of_product(solved_penalty.view(), solved_penalty.view());
677    let fitted_energy = (rhs * &beta).sum_axis(Axis(0));
678    let p_beta = scaled_penalty.dot(&beta);
679    let penalty_energy = (&beta * &p_beta).sum_axis(Axis(0));
680    let solved_p_beta = solve_spd_from_lower_factor(&chol, &p_beta)?;
681    let curvature_energy = (&p_beta * &solved_p_beta).sum_axis(Axis(0));
682    Ok(BlockOrthogonalEval {
683        beta,
684        logdet,
685        trace,
686        trace_pair,
687        fitted_energy,
688        penalty_energy,
689        curvature_energy,
690        edf: penalty.nrows() as f64 - trace,
691    })
692}
693
/// Block-orthogonal shared-scale REML objective VALUE together with its
/// analytic ρ-gradient and ρ-Hessian.
///
/// Single source of truth: the value `½d·logdet − ½·fit − ½d·rank·ρ` and its
/// ρ-derivatives are returned from ONE function body, so a future edit to the
/// objective cannot leave the Newton gradient/Hessian (previously written at a
/// physically separate site inside `solve_block_orthogonal_rho`) stale. This
/// closes a genuine `(value_here, gradient_there)` loose pair. Mirrors the
/// `PenaltyLogdetDerivs` single-source pattern; behavior is identical (the same
/// closed-form formulas, reorganized).
struct BlockOrthogonalScaleDerivs {
    /// Objective value at the evaluated `ρ`.
    value: f64,
    /// First derivative of the objective with respect to `ρ`.
    grad: f64,
    /// Second derivative of the objective with respect to `ρ`.
    hess: f64,
}
709
710fn block_orthogonal_scale_objective(
711    eval: &BlockOrthogonalEval,
712    rho: f64,
713    scale_precision: ArrayView1<'_, f64>,
714    rank: usize,
715) -> BlockOrthogonalScaleDerivs {
716    let d = scale_precision.len() as f64;
717    let fit_term = scale_precision
718        .iter()
719        .zip(eval.fitted_energy.iter())
720        .map(|(scale, energy)| scale * energy)
721        .sum::<f64>();
722    // VALUE: ½d·log|H| − ½ Σ_o w_o ⟨y_o, fit_o⟩ − ½d·rank·ρ.
723    let value = 0.5 * d * eval.logdet - 0.5 * fit_term - 0.5 * d * (rank as f64) * rho;
724    // ρ-GRADIENT: d/dρ of the same scalar. The logdet term contributes
725    // ½d·(tr(H⁻¹λS) − rank); the (data-independent-at-fixed-β envelope) fit term
726    // contributes +½ Σ_o w_o βᵀ(λS)β. Both share `eval`'s cached energies.
727    let grad = 0.5 * d * (eval.trace - rank as f64)
728        + 0.5
729            * scale_precision
730                .iter()
731                .zip(eval.penalty_energy.iter())
732                .map(|(scale, energy)| scale * energy)
733                .sum::<f64>();
734    // ρ-HESSIAN: d²/dρ². Logdet term: ½d·(tr(H⁻¹λS) − tr((H⁻¹λS)²)); penalty
735    // term: ½ Σ_o w_o (βᵀλSβ − 2 βᵀλS H⁻¹ λS β).
736    let hess = 0.5 * d * (eval.trace - eval.trace_pair)
737        + 0.5
738            * scale_precision
739                .iter()
740                .zip(eval.penalty_energy.iter().zip(eval.curvature_energy.iter()))
741                .map(|(scale, (energy, curvature))| scale * (energy - 2.0 * curvature))
742                .sum::<f64>();
743    BlockOrthogonalScaleDerivs { value, grad, hess }
744}
745
746fn solve_block_orthogonal_rho(
747    gram: &Array2<f64>,
748    rhs: &Array2<f64>,
749    penalty: &Array2<f64>,
750    rho0: f64,
751    scale_precision: ArrayView1<'_, f64>,
752    rank: usize,
753    max_iter: usize,
754) -> Result<(f64, BlockOrthogonalEval), EstimationError> {
755    let mut rho = rho0;
756    let mut current = block_orthogonal_eval(gram, rhs, penalty, rho)?;
757    for _ in 0..max_iter {
758        // Value, ρ-gradient, and ρ-Hessian all come from the SINGLE
759        // single-source objective evaluation — they cannot desync.
760        let derivs = block_orthogonal_scale_objective(&current, rho, scale_precision, rank);
761        let grad = derivs.grad;
762        let hess = derivs.hess;
763        if !(grad.is_finite() && hess.is_finite()) {
764            return Err(EstimationError::ModelIsIllConditioned {
765                condition_number: f64::INFINITY,
766            });
767        }
768        // Newton step where the curvature is reliably positive; a unit
769        // gradient-descent direction where it is not (non-convex / near-flat
770        // region) so we never step ALONG negative curvature (which ascends).
771        // There is deliberately NO magic step clamp: the line search below
772        // globalizes and SKIPS any candidate ρ that is infeasible (e.g. λ =
773        // exp(ρ) overflows `validate_initial_lambda`), so an over-long Newton
774        // step is simply rejected rather than bounded by an arbitrary constant
775        // or crashing the solve. This is the root fix the old `.clamp(-2,2)`
776        // was masking: the clamp existed only to keep an over-long step from
777        // reaching `block_orthogonal_eval`, which errors on a non-finite λ.
778        let descent = grad.signum();
779        let step = if hess > 1.0e-10 { grad / hess } else { descent };
780        let mut best_rho = rho;
781        let mut best_eval = current;
782        let mut best_phi =
783            block_orthogonal_scale_objective(&best_eval, best_rho, scale_precision, rank).value;
784        for candidate_rho in [
785            rho - step,
786            rho - 0.5 * step,
787            rho - 0.25 * step,
788            rho - descent,
789            rho - 0.25 * descent,
790        ] {
791            // Skip infeasible candidates (λ overflow / ill-conditioned Gram)
792            // instead of failing the whole solve — the bounded gradient
793            // candidates (`rho - descent`, `rho - 0.25·descent`) remain valid,
794            // so a too-long Newton step degrades to gradient descent.
795            let Ok(candidate_eval) = block_orthogonal_eval(gram, rhs, penalty, candidate_rho)
796            else {
797                continue;
798            };
799            let candidate_phi = block_orthogonal_scale_objective(
800                &candidate_eval,
801                candidate_rho,
802                scale_precision,
803                rank,
804            )
805            .value;
806            if candidate_phi < best_phi {
807                best_rho = candidate_rho;
808                best_eval = candidate_eval;
809                best_phi = candidate_phi;
810            }
811        }
812        let delta = (best_rho - rho).abs();
813        rho = best_rho;
814        current = best_eval;
815        if delta < 1.0e-12 || step.abs() < 1.0e-7 {
816            break;
817        }
818    }
819    Ok((rho, current))
820}
821
822pub fn gaussian_reml_blocks_orthogonal_shared_scale(
823    designs: &[Array2<f64>],
824    penalties: &[Array2<f64>],
825    y: ArrayView2<'_, f64>,
826    weights: Option<ArrayView1<'_, f64>>,
827    init_rhos: Option<&[f64]>,
828) -> Result<GaussianRemlBlockOrthogonalResult, EstimationError> {
829    if designs.is_empty() {
830        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one block");
831    }
832    if designs.len() != penalties.len() {
833        crate::bail_invalid_estim!(
834            "block-orthogonal Gaussian REML block mismatch: {} designs, {} penalties",
835            designs.len(),
836            penalties.len()
837        );
838    }
839    let n = y.nrows();
840    let d = y.ncols();
841    if d == 0 {
842        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one output");
843    }
844    if y.iter().any(|value| !value.is_finite()) {
845        crate::bail_invalid_estim!("block-orthogonal Gaussian REML response must be finite");
846    }
847    let weight = gaussian_reml_weights(n, weights)?;
848    if let Some(rhos) = init_rhos {
849        if rhos.len() != designs.len() {
850            crate::bail_invalid_estim!(
851                "block-orthogonal Gaussian REML init_rhos length mismatch: expected {}, got {}",
852                designs.len(),
853                rhos.len()
854            );
855        }
856        if rhos.iter().any(|value| !value.is_finite()) {
857            crate::bail_invalid_estim!("block-orthogonal Gaussian REML init_rhos must be finite");
858        }
859    }
860
861    let mut ywy = Array1::<f64>::zeros(d);
862    for row in 0..n {
863        for output in 0..d {
864            ywy[output] += weight[row] * y[[row, output]] * y[[row, output]];
865        }
866    }
867    let mut grams = Vec::with_capacity(designs.len());
868    let mut rhs_blocks = Vec::with_capacity(designs.len());
869    let mut penalties_owned = Vec::with_capacity(penalties.len());
870    let mut ranks = Vec::with_capacity(penalties.len());
871    let mut penalty_logdets = Vec::with_capacity(penalties.len());
872    let mut nullity_total = 0_usize;
873    for (block, (design, penalty)) in designs.iter().zip(penalties.iter()).enumerate() {
874        let penalty_owned = canonicalize_penalty(penalty.view());
875        validate_gaussian_reml_design(design.view(), penalty_owned.view(), Some(weight.view()))?;
876        if design.nrows() != n {
877            crate::bail_invalid_estim!(
878                "block-orthogonal Gaussian REML designs[{block}] has {} rows, expected {n}",
879                design.nrows()
880            );
881        }
882        let gram = dense_xt_diag_x(design.view(), weight.view());
883        let rhs = dense_xt_diag_y(design.view(), weight.view(), y);
884        let (rank, logdet) = block_penalty_rank_logdet(penalty_owned.view())?;
885        nullity_total += penalty_owned.nrows().saturating_sub(rank);
886        grams.push(canonicalize_penalty(gram.view()));
887        rhs_blocks.push(rhs);
888        penalties_owned.push(penalty_owned);
889        ranks.push(rank);
890        penalty_logdets.push(logdet);
891    }
892    let n_effective = effective_observation_count(weight.view());
893    if n_effective <= nullity_total {
894        crate::bail_invalid_estim!(
895            "block-orthogonal Gaussian REML requires more positive-weight rows than the total penalty nullity; got n_effective={n_effective}, nullity={nullity_total}"
896        );
897    }
898    let nu = (n_effective - nullity_total) as f64;
899    let mut rhos = match init_rhos {
900        Some(values) => Array1::from_vec(values.to_vec()),
901        None => Array1::zeros(designs.len()),
902    };
903    let mut scale_precision = ywy.mapv(|value| nu / value.max(MIN_DEVIANCE));
904    let mut evals = Vec::new();
905    for _ in 0..40 {
906        evals.clear();
907        for block in 0..designs.len() {
908            let (rho, eval) = solve_block_orthogonal_rho(
909                &grams[block],
910                &rhs_blocks[block],
911                &penalties_owned[block],
912                rhos[block],
913                scale_precision.view(),
914                ranks[block],
915                32,
916            )?;
917            rhos[block] = rho;
918            evals.push(eval);
919        }
920        let mut explained = Array1::<f64>::zeros(d);
921        for eval in evals.iter() {
922            explained += &eval.fitted_energy;
923        }
924        let q = &ywy - &explained;
925        if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
926            return Err(EstimationError::ModelIsIllConditioned {
927                condition_number: f64::INFINITY,
928            });
929        }
930        let next_scale = q.mapv(|value| nu / value);
931        let scale_step = next_scale
932            .iter()
933            .zip(scale_precision.iter())
934            .map(|(next, old)| (next.ln() - old.ln()).abs())
935            .fold(0.0_f64, f64::max);
936        scale_precision = next_scale;
937        if scale_step < 1.0e-7 {
938            break;
939        }
940    }
941    evals.clear();
942    for block in 0..designs.len() {
943        let (rho, eval) = solve_block_orthogonal_rho(
944            &grams[block],
945            &rhs_blocks[block],
946            &penalties_owned[block],
947            rhos[block],
948            scale_precision.view(),
949            ranks[block],
950            16,
951        )?;
952        rhos[block] = rho;
953        evals.push(eval);
954    }
955
956    let coefficients = evals
957        .iter()
958        .map(|eval| eval.beta.clone())
959        .collect::<Vec<_>>();
960    let mut fitted = Array2::<f64>::zeros((n, d));
961    for (design, coef) in designs.iter().zip(coefficients.iter()) {
962        fitted += &fast_ab(&design.view(), &coef.view());
963    }
964    let mut explained = Array1::<f64>::zeros(d);
965    for eval in evals.iter() {
966        explained += &eval.fitted_energy;
967    }
968    let q = &ywy - &explained;
969    if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
970        return Err(EstimationError::ModelIsIllConditioned {
971            condition_number: f64::INFINITY,
972        });
973    }
974    let lambdas = rhos.mapv(f64::exp);
975    let edf = Array1::from_iter(evals.iter().map(|eval| eval.edf));
976    let logdet_term = evals
977        .iter()
978        .enumerate()
979        .map(|(block, eval)| {
980            eval.logdet - penalty_logdets[block] - (ranks[block] as f64) * rhos[block]
981        })
982        .sum::<f64>();
983    let scale_term = q
984        .iter()
985        .map(|value| nu * (1.0 + (2.0 * std::f64::consts::PI * value / nu).ln()))
986        .sum::<f64>();
987    Ok(GaussianRemlBlockOrthogonalResult {
988        coefficients,
989        fitted,
990        lambdas,
991        log_lambdas: rhos,
992        reml_score: 0.5 * (d as f64) * logdet_term + 0.5 * scale_term,
993        edf,
994    })
995}
996
997fn gaussian_reml_multi_closed_form_from_parts(
998    x: ArrayView2<'_, f64>,
999    y: ArrayView2<'_, f64>,
1000    penalty: ArrayView2<'_, f64>,
1001    nullspace_dim: Option<usize>,
1002    weights: Option<ArrayView1<'_, f64>>,
1003    init_lambda: Option<f64>,
1004    eigen_cache: Option<&GaussianRemlEigenCache>,
1005) -> Result<GaussianRemlMultiResult, EstimationError> {
1006    let prepared = prepare_gaussian_reml(x, y, penalty, nullspace_dim, weights, eigen_cache)?;
1007    let init_rho = init_lambda
1008        .map(validate_initial_lambda)
1009        .transpose()?
1010        .map(f64::ln);
1011    let rho = optimize_rho(&prepared, init_rho)?;
1012    let eval = prepared.evaluate(rho);
1013    let lambda = rho.exp();
1014    let coefficients = prepared.coefficients(lambda);
1015    let fitted = dense_ab(x, coefficients.view());
1016    let sigma2 = prepared.sigma2(lambda);
1017    let (reml_grad_lambda, reml_hess_lambda) =
1018        rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
1019    Ok(GaussianRemlMultiResult {
1020        lambda,
1021        rho,
1022        coefficients,
1023        fitted,
1024        reml_score: eval.cost,
1025        reml_grad_lambda,
1026        reml_hess_lambda,
1027        reml_grad_rho: eval.grad,
1028        reml_hess_rho: eval.hess,
1029        edf: eval.edf,
1030        sigma2,
1031        cache: prepared.cache,
1032    })
1033}
1034
1035pub fn gaussian_reml_free_b_score(
1036    x: ArrayView2<'_, f64>,
1037    y: ArrayView2<'_, f64>,
1038    coefficients: ArrayView2<'_, f64>,
1039    log_lambda: f64,
1040    penalty: ArrayView2<'_, f64>,
1041    weights: Option<ArrayView1<'_, f64>>,
1042) -> Result<GaussianRemlFreeBScore, EstimationError> {
1043    if !log_lambda.is_finite() {
1044        crate::bail_invalid_estim!("Gaussian REML log_lambda must be finite; got {log_lambda}");
1045    }
1046    let lambda = log_lambda.exp();
1047    let penalty_owned = canonicalize_penalty(penalty);
1048    let penalty = penalty_owned.view();
1049    let n = x.nrows();
1050    let p = x.ncols();
1051    let d = y.ncols();
1052    validate_gaussian_reml_design(x, penalty, weights)?;
1053    if y.nrows() != n {
1054        crate::bail_invalid_estim!(
1055            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
1056            y.nrows()
1057        );
1058    }
1059    if coefficients.dim() != (p, d) {
1060        crate::bail_invalid_estim!(
1061            "Gaussian REML coefficient shape mismatch: expected {p}x{d}, got {}x{}",
1062            coefficients.nrows(),
1063            coefficients.ncols()
1064        );
1065    }
1066    if y.iter().chain(coefficients.iter()).any(|v| !v.is_finite()) {
1067        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
1068    }
1069
1070    let weight = gaussian_reml_weights(n, weights)?;
1071    let n_effective = effective_observation_count(weight.view());
1072    let cache =
1073        build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, Some(weight.view()))?;
1074    if n_effective <= cache.nullity {
1075        crate::bail_invalid_estim!(
1076            "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
1077            cache.nullity
1078        );
1079    }
1080    let nu = n_effective as f64 - cache.nullity as f64;
1081    let fitted = dense_ab(x, coefficients);
1082    let residual = y.to_owned() - &fitted;
1083    let xtw_residual = dense_xt_diag_y(x, weight.view(), residual.view());
1084    let s_beta = dense_ab(penalty, coefficients);
1085
1086    let mut logdet_h = cache.logdet_xtwx;
1087    let mut trace_h = 0.0;
1088    let mut edf = 0.0;
1089    for &delta in &cache.penalty_eigenvalues {
1090        let t = lambda * delta;
1091        logdet_h += (1.0 + t).ln();
1092        if delta > 0.0 {
1093            trace_h += t / (1.0 + t);
1094        }
1095        edf += 1.0 / (1.0 + t);
1096    }
1097    let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * log_lambda;
1098    let mut reml_score = 0.5 * (d as f64) * (logdet_h - logdet_s);
1099    let mut grad_log_lambda = 0.5 * (d as f64) * (trace_h - cache.penalty_rank as f64);
1100    let mut grad_coefficients = Array2::<f64>::zeros((p, d));
1101    let inverse_hessian = {
1102        let xtwx = dense_xt_diag_x(x, weight.view());
1103        let mut hessian = xtwx;
1104        hessian += &(penalty.to_owned() * lambda);
1105        hessian
1106            .cholesky(Side::Lower)
1107            .map_err(EstimationError::LinearSystemSolveFailed)?
1108            .solve_mat(&Array2::<f64>::eye(p))
1109    };
1110    let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(&cache);
1111    let mut grad_penalty = Array2::<f64>::zeros((p, p));
1112    for row in 0..p {
1113        for col in 0..p {
1114            grad_penalty[[row, col]] += 0.5
1115                * (d as f64)
1116                * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
1117        }
1118    }
1119    let mut sigma2 = Array1::<f64>::zeros(d);
1120
1121    for output in 0..d {
1122        let mut weighted_rss = 0.0;
1123        for row in 0..n {
1124            let r = residual[[row, output]];
1125            weighted_rss += weight[row] * r * r;
1126        }
1127        let beta_col = coefficients.column(output);
1128        let s_beta_col = s_beta.column(output);
1129        let penalty_quadratic = beta_col.dot(&s_beta_col);
1130        let dp = (weighted_rss + lambda * penalty_quadratic).max(MIN_DEVIANCE);
1131        sigma2[output] = dp / nu;
1132        reml_score += 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln());
1133        grad_log_lambda += 0.5 * nu * lambda * penalty_quadratic / dp;
1134        let scale = nu / dp;
1135        for coeff in 0..p {
1136            grad_coefficients[[coeff, output]] =
1137                scale * (-xtw_residual[[coeff, output]] + lambda * s_beta[[coeff, output]]);
1138        }
1139        add_rank_one_penalty_vjp(0.5 * scale * lambda, beta_col, &mut grad_penalty);
1140    }
1141    for i in 0..p {
1142        for j in (i + 1)..p {
1143            let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
1144            grad_penalty[[i, j]] = avg;
1145            grad_penalty[[j, i]] = avg;
1146        }
1147    }
1148
1149    Ok(GaussianRemlFreeBScore {
1150        reml_score,
1151        grad_coefficients,
1152        grad_penalty,
1153        grad_log_lambda,
1154        fitted,
1155        sigma2,
1156        edf,
1157    })
1158}
1159
1160pub fn gaussian_reml_multi_closed_form_backward(
1161    x: ArrayView2<'_, f64>,
1162    y: ArrayView2<'_, f64>,
1163    penalty: ArrayView2<'_, f64>,
1164    weights: Option<ArrayView1<'_, f64>>,
1165    init_lambda: Option<f64>,
1166    upstream_lambda: f64,
1167    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1168    upstream_fitted: Option<ArrayView2<'_, f64>>,
1169    upstream_reml_score: f64,
1170    upstream_edf: f64,
1171) -> Result<GaussianRemlBackwardResult, EstimationError> {
1172    let fit =
1173        gaussian_reml_multi_closed_form_with_cache(x, y, penalty, weights, init_lambda, None)?;
1174    gaussian_reml_multi_closed_form_backward_from_fit(
1175        x,
1176        y,
1177        penalty,
1178        weights,
1179        &fit,
1180        upstream_lambda,
1181        upstream_coefficients,
1182        upstream_fitted,
1183        upstream_reml_score,
1184        upstream_edf,
1185    )
1186}
1187
1188pub fn gaussian_reml_multi_closed_form_backward_from_fit(
1189    x: ArrayView2<'_, f64>,
1190    y: ArrayView2<'_, f64>,
1191    penalty: ArrayView2<'_, f64>,
1192    weights: Option<ArrayView1<'_, f64>>,
1193    fit: &GaussianRemlMultiResult,
1194    upstream_lambda: f64,
1195    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1196    upstream_fitted: Option<ArrayView2<'_, f64>>,
1197    upstream_reml_score: f64,
1198    upstream_edf: f64,
1199) -> Result<GaussianRemlBackwardResult, EstimationError> {
1200    validate_gaussian_reml_backward_upstreams(
1201        x,
1202        y,
1203        penalty,
1204        upstream_lambda,
1205        upstream_coefficients,
1206        upstream_fitted,
1207        upstream_reml_score,
1208        upstream_edf,
1209    )?;
1210    validate_gaussian_reml_forward_fit(x, y, penalty, weights, fit)?;
1211    let lambda = fit.lambda;
1212    let n = x.nrows();
1213    let p = x.ncols();
1214    let d = y.ncols();
1215    if !(fit.reml_hess_rho.is_finite() && fit.reml_hess_rho.abs() > 1.0e-14) {
1216        // Graceful degradation: when λ saturates, K = XᵀWX + λS is
1217        // effectively rank-deficient and the analytic VJP is undefined.
1218        // Return zero gradients (the correct shrink-out limit) instead of
1219        // raising — production training at large F can have individual
1220        // atoms saturate λ in early batches and must not blow up here.
1221        warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1222        return Ok(zero_backward_result(n, p, d));
1223    }
1224    let weight = gaussian_reml_weights(n, weights)?;
1225    let inverse_hessian = match gaussian_reml_inverse_hessian_from_cache(&fit.cache, lambda) {
1226        Ok(inv) => inv,
1227        Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1228            warn_ill_conditioned_backward_once(p, d, condition_number);
1229            return Ok(zero_backward_result(n, p, d));
1230        }
1231        Err(err) => return Err(err),
1232    };
1233    gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1234        x,
1235        y,
1236        penalty,
1237        weight,
1238        fit,
1239        inverse_hessian,
1240        upstream_lambda,
1241        upstream_coefficients,
1242        upstream_fitted,
1243        upstream_reml_score,
1244        upstream_edf,
1245        n,
1246        p,
1247        d,
1248    )
1249}
1250
1251fn gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1252    x: ArrayView2<'_, f64>,
1253    y: ArrayView2<'_, f64>,
1254    penalty: ArrayView2<'_, f64>,
1255    weight: Array1<f64>,
1256    fit: &GaussianRemlMultiResult,
1257    inverse_hessian: Array2<f64>,
1258    upstream_lambda: f64,
1259    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1260    upstream_fitted: Option<ArrayView2<'_, f64>>,
1261    upstream_reml_score: f64,
1262    upstream_edf: f64,
1263    n: usize,
1264    p: usize,
1265    d: usize,
1266) -> Result<GaussianRemlBackwardResult, EstimationError> {
1267    // Backward sees the same symmetric S the forward used. Canonicalize on
1268    // entry so an asymmetric input (e.g. a single-entry gradcheck perturbation
1269    // around a symmetric base) cannot leak into the per-helper VJPs.
1270    let penalty_owned = canonicalize_penalty(penalty);
1271    let penalty = penalty_owned.view();
1272    let lambda = fit.lambda;
1273    let beta = &fit.coefficients;
1274    let residual = y.to_owned() - &fit.fitted;
1275    // Match the forward's REML residual DoF: zero prior-weight rows are excluded
1276    // from the effective sample size (see `effective_observation_count`), so the
1277    // adjoint of `ν` uses the same count the forward used.
1278    let nu = effective_observation_count(weight.view()) as f64 - fit.cache.nullity as f64;
1279
1280    let mut grad_x = Array2::<f64>::zeros((n, p));
1281    let mut grad_y = Array2::<f64>::zeros((n, d));
1282    let mut grad_penalty = Array2::<f64>::zeros((p, p));
1283    let mut grad_weights = Array1::<f64>::zeros(n);
1284
1285    let mut upstream_beta = Array2::<f64>::zeros((p, d));
1286    if let Some(upstream_coefficients) = upstream_coefficients {
1287        upstream_beta += &upstream_coefficients;
1288    }
1289    if let Some(upstream_fitted) = upstream_fitted {
1290        upstream_beta += &dense_atb(x, upstream_fitted);
1291        grad_x += &dense_ab(upstream_fitted, beta.t());
1292    }
1293
1294    let mut lambda_adjoint = upstream_lambda;
1295    if upstream_beta.iter().any(|value| *value != 0.0) {
1296        // A downstream loss that explicitly uses beta_hat or fitted = X beta_hat
1297        // cannot use the REML envelope shortcut.  Route those seeds through
1298        // the fixed-rho KKT adjoint M u = upstream_beta, then differentiate
1299        // X, y, weights, and S through the ridge solve.
1300        add_ridge_profile_vjp_with_lambda_grad(
1301            1.0,
1302            x,
1303            y,
1304            penalty,
1305            &weight,
1306            lambda,
1307            &inverse_hessian,
1308            beta,
1309            upstream_beta.view(),
1310            &mut grad_x,
1311            &mut grad_y,
1312            &mut grad_penalty,
1313            &mut grad_weights,
1314            &mut lambda_adjoint,
1315        );
1316    }
1317
1318    if upstream_reml_score != 0.0 {
1319        add_reml_score_vjp(
1320            upstream_reml_score,
1321            x,
1322            &weight,
1323            &inverse_hessian,
1324            beta,
1325            &residual,
1326            &fit.sigma2,
1327            nu,
1328            lambda,
1329            &fit.cache,
1330            &mut grad_x,
1331            &mut grad_y,
1332            &mut grad_penalty,
1333            &mut grad_weights,
1334        );
1335        lambda_adjoint += upstream_reml_score * fit.reml_grad_lambda;
1336    }
1337
1338    if upstream_edf != 0.0 {
1339        lambda_adjoint += add_edf_vjp(
1340            upstream_edf,
1341            x,
1342            penalty,
1343            &weight,
1344            lambda,
1345            &inverse_hessian,
1346            &mut grad_x,
1347            &mut grad_penalty,
1348            &mut grad_weights,
1349        );
1350    }
1351
1352    if lambda_adjoint != 0.0 {
1353        let root_scale = -lambda_adjoint * lambda / fit.reml_hess_rho;
1354        add_reml_rho_gradient_vjp(
1355            root_scale,
1356            x,
1357            y,
1358            penalty,
1359            &weight,
1360            lambda,
1361            &inverse_hessian,
1362            beta,
1363            &residual,
1364            &fit.sigma2,
1365            nu,
1366            &mut grad_x,
1367            &mut grad_y,
1368            &mut grad_penalty,
1369            &mut grad_weights,
1370        );
1371    }
1372
1373    // The forward consumes `S` only through the canonicalization
1374    // `S_canon = 0.5 (S + Sᵀ)`. By the chain rule, the gradient w.r.t. an
1375    // input `S_input` is `0.5 (G + Gᵀ)` where `G = ∂L/∂S_canon` is what the
1376    // per-helper VJPs accumulate. Symmetrize the full matrix here so a
1377    // single-entry perturbation `δS = ε E_{i,j}` (asymmetric, as
1378    // `torch.autograd.gradcheck` produces) sees the gradient component
1379    // `0.5 (G[i,j] + G[j,i])` it expects from FD — no caller-side
1380    // bookkeeping required.
1381    let p = grad_penalty.nrows();
1382    for i in 0..p {
1383        for j in (i + 1)..p {
1384            let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
1385            grad_penalty[[i, j]] = avg;
1386            grad_penalty[[j, i]] = avg;
1387        }
1388    }
1389    Ok(GaussianRemlBackwardResult {
1390        grad_x,
1391        grad_y,
1392        grad_penalty,
1393        grad_weights,
1394    })
1395}
1396
1397pub fn gaussian_reml_multi_closed_form_backward_batch<'a>(
1398    problems: &[GaussianRemlMultiBackwardProblem<'a>],
1399    penalty: ArrayView2<'a, f64>,
1400) -> Vec<Result<GaussianRemlBackwardResult, EstimationError>> {
1401    let inverse_hessians = batched_inverse_hessians_from_caches(problems);
1402    let results: Vec<Result<GaussianRemlBackwardResult, EstimationError>> = problems
1403        .par_iter()
1404        .zip(inverse_hessians.into_par_iter())
1405        .map(|(problem, inverse_hessian_result)| {
1406            validate_gaussian_reml_backward_upstreams(
1407                problem.x.view(),
1408                problem.y.view(),
1409                penalty,
1410                problem.grad_lambda,
1411                problem.grad_coefficients.as_ref().map(|g| g.view()),
1412                problem.grad_fitted.as_ref().map(|g| g.view()),
1413                problem.grad_reml_score,
1414                problem.grad_edf,
1415            )?;
1416            validate_gaussian_reml_forward_fit(
1417                problem.x.view(),
1418                problem.y.view(),
1419                penalty,
1420                problem.weights.as_ref().map(|w| w.view()),
1421                problem.fit,
1422            )?;
1423            let n = problem.x.nrows();
1424            let p = problem.x.ncols();
1425            let d = problem.y.ncols();
1426            if !(problem.fit.reml_hess_rho.is_finite() && problem.fit.reml_hess_rho.abs() > 1.0e-14)
1427            {
1428                // Graceful degradation — see `gaussian_reml_multi_closed_form_backward_from_fit`.
1429                warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1430                return Ok(zero_backward_result(n, p, d));
1431            }
1432            let weight = gaussian_reml_weights(n, problem.weights.as_ref().map(|w| w.view()))?;
1433            let inverse_hessian = match inverse_hessian_result {
1434                Ok(inv) => inv,
1435                Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1436                    warn_ill_conditioned_backward_once(p, d, condition_number);
1437                    return Ok(zero_backward_result(n, p, d));
1438                }
1439                Err(err) => return Err(err),
1440            };
1441            gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1442                problem.x.view(),
1443                problem.y.view(),
1444                penalty,
1445                weight,
1446                problem.fit,
1447                inverse_hessian,
1448                problem.grad_lambda,
1449                problem.grad_coefficients.as_ref().map(|g| g.view()),
1450                problem.grad_fitted.as_ref().map(|g| g.view()),
1451                problem.grad_reml_score,
1452                problem.grad_edf,
1453                n,
1454                p,
1455                d,
1456            )
1457        })
1458        .collect();
1459    results
1460}
1461
/// Convert first and second derivatives taken with respect to `rho = ln λ`
/// into derivatives with respect to `λ`.
///
/// Returns `(grad_rho / λ, (hess_rho − grad_rho) / λ²)`, which is the chain
/// rule for the substitution `rho = ln λ`.
fn rho_derivatives_to_lambda(lambda: f64, grad_rho: f64, hess_rho: f64) -> (f64, f64) {
    let grad_lambda = grad_rho / lambda;
    let lambda_squared = lambda * lambda;
    let hess_lambda = (hess_rho - grad_rho) / lambda_squared;
    (grad_lambda, hess_lambda)
}
1465
1466fn validate_gaussian_reml_backward_upstreams(
1467    x: ArrayView2<'_, f64>,
1468    y: ArrayView2<'_, f64>,
1469    penalty: ArrayView2<'_, f64>,
1470    upstream_lambda: f64,
1471    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1472    upstream_fitted: Option<ArrayView2<'_, f64>>,
1473    upstream_reml_score: f64,
1474    upstream_edf: f64,
1475) -> Result<(), EstimationError> {
1476    if !(upstream_lambda.is_finite() && upstream_reml_score.is_finite() && upstream_edf.is_finite())
1477    {
1478        crate::bail_invalid_estim!("Gaussian REML backward upstream scalars must be finite");
1479    }
1480    if let Some(upstream_coefficients) = upstream_coefficients {
1481        if upstream_coefficients.dim() != (x.ncols(), y.ncols()) {
1482            crate::bail_invalid_estim!(
1483                "Gaussian REML backward coefficient upstream shape mismatch: expected {}x{}, got {}x{}",
1484                x.ncols(),
1485                y.ncols(),
1486                upstream_coefficients.nrows(),
1487                upstream_coefficients.ncols()
1488            );
1489        }
1490        if upstream_coefficients.iter().any(|value| !value.is_finite()) {
1491            crate::bail_invalid_estim!(
1492                "Gaussian REML backward coefficient upstream must be finite"
1493            );
1494        }
1495    }
1496    if let Some(upstream_fitted) = upstream_fitted {
1497        if upstream_fitted.dim() != y.dim() {
1498            crate::bail_invalid_estim!(
1499                "Gaussian REML backward fitted upstream shape mismatch: expected {}x{}, got {}x{}",
1500                y.nrows(),
1501                y.ncols(),
1502                upstream_fitted.nrows(),
1503                upstream_fitted.ncols()
1504            );
1505        }
1506        if upstream_fitted.iter().any(|value| !value.is_finite()) {
1507            crate::bail_invalid_estim!("Gaussian REML backward fitted upstream must be finite");
1508        }
1509    }
1510    validate_gaussian_reml_design(x, penalty, None)?;
1511    Ok(())
1512}
1513
1514fn validate_gaussian_reml_forward_fit(
1515    x: ArrayView2<'_, f64>,
1516    y: ArrayView2<'_, f64>,
1517    penalty: ArrayView2<'_, f64>,
1518    weights: Option<ArrayView1<'_, f64>>,
1519    fit: &GaussianRemlMultiResult,
1520) -> Result<(), EstimationError> {
1521    // Fingerprint the canonicalized penalty: caches are keyed on the
1522    // symmetric average, and the caller may hand us a raw input (e.g. a
1523    // single-entry-perturbed matrix produced by ``torch.autograd.gradcheck``).
1524    let penalty_owned = canonicalize_penalty(penalty);
1525    let penalty = penalty_owned.view();
1526    let n = x.nrows();
1527    let p = x.ncols();
1528    let d = y.ncols();
1529    validate_gaussian_reml_design(x, penalty, weights)?;
1530    validate_gaussian_reml_eigen_cache(&fit.cache, p)?;
1531    if y.nrows() != n
1532        || fit.coefficients.dim() != (p, d)
1533        || fit.fitted.dim() != (n, d)
1534        || fit.sigma2.len() != d
1535    {
1536        crate::bail_invalid_estim!(
1537            "Gaussian REML backward forward-state shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
1538        );
1539    }
1540    if !(fit.lambda.is_finite()
1541        && fit.lambda > 0.0
1542        && fit.rho.is_finite()
1543        && fit.reml_score.is_finite()
1544        && fit.reml_hess_rho.is_finite()
1545        && fit.edf.is_finite())
1546        || fit.coefficients.iter().any(|value| !value.is_finite())
1547        || fit.fitted.iter().any(|value| !value.is_finite())
1548        || fit.sigma2.iter().any(|value| !value.is_finite())
1549    {
1550        crate::bail_invalid_estim!("Gaussian REML backward forward state must be finite");
1551    }
1552    let penalty_fingerprint = matrix_fingerprint(penalty);
1553    if fit.cache.penalty_fingerprint != penalty_fingerprint {
1554        crate::bail_invalid_estim!("Gaussian REML backward forward-state penalty mismatch");
1555    }
1556    let weight = gaussian_reml_weights(n, weights)?;
1557    let xtwx = dense_xt_diag_x(x, weight.view());
1558    if fit.cache.xtwx_fingerprint != matrix_fingerprint(xtwx.view()) {
1559        crate::bail_invalid_estim!("Gaussian REML backward forward-state X'WX mismatch");
1560    }
1561    Ok(())
1562}
1563
1564fn gaussian_reml_inverse_hessian_from_cache(
1565    cache: &GaussianRemlEigenCache,
1566    lambda: f64,
1567) -> Result<Array2<f64>, EstimationError> {
1568    if !(lambda.is_finite() && lambda > 0.0) {
1569        crate::bail_invalid_estim!(
1570            "Gaussian REML lambda must be finite and positive; got {lambda}"
1571        );
1572    }
1573    let p = cache.penalty_eigenvalues.len();
1574    let mut scaled_basis = cache.coefficient_basis.clone();
1575    for eig in 0..p {
1576        let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1577        for row in 0..p {
1578            scaled_basis[[row, eig]] *= scale;
1579        }
1580    }
1581    let inverse = dense_ab(scaled_basis.view(), cache.coefficient_basis.t());
1582    if inverse.iter().any(|value| !value.is_finite()) {
1583        return Err(EstimationError::ModelIsIllConditioned {
1584            condition_number: f64::INFINITY,
1585        });
1586    }
1587    Ok(inverse)
1588}
1589
1590fn batched_inverse_hessians_from_caches(
1591    problems: &[GaussianRemlMultiBackwardProblem<'_>],
1592) -> Vec<Result<Array2<f64>, EstimationError>> {
1593    if problems.is_empty() {
1594        return Vec::new();
1595    }
1596    let p = problems[0].fit.cache.coefficient_basis.nrows();
1597    let uniform = p > 0
1598        && problems.iter().all(|problem| {
1599            let cache = &problem.fit.cache;
1600            cache.coefficient_basis.dim() == (p, p) && cache.penalty_eigenvalues.len() == p
1601        });
1602    if uniform && problems.len() > 1 {
1603        let mut scaled_basis = Array3::<f64>::zeros((problems.len(), p, p));
1604        let mut basis = Array3::<f64>::zeros((problems.len(), p, p));
1605        let mut valid = true;
1606        for (idx, problem) in problems.iter().enumerate() {
1607            let lambda = problem.fit.lambda;
1608            if !(lambda.is_finite() && lambda > 0.0) {
1609                valid = false;
1610                break;
1611            }
1612            let cache = &problem.fit.cache;
1613            basis
1614                .slice_mut(s![idx, .., ..])
1615                .assign(&cache.coefficient_basis);
1616            for eig in 0..p {
1617                let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1618                for row in 0..p {
1619                    scaled_basis[[idx, row, eig]] = cache.coefficient_basis[[row, eig]] * scale;
1620                }
1621            }
1622        }
1623        if valid
1624            && let Some(inverses) =
1625                gam_gpu::try_fast_abt_strided_batched(scaled_basis.view(), basis.view())
1626        {
1627            return inverses
1628                .axis_iter(Axis(0))
1629                .map(|inverse| Ok(inverse.to_owned()))
1630                .collect();
1631        }
1632    }
1633    problems
1634        .iter()
1635        .map(|problem| {
1636            gaussian_reml_inverse_hessian_from_cache(&problem.fit.cache, problem.fit.lambda)
1637        })
1638        .collect()
1639}
1640
1641/// Side-effects of the ridge-profile VJP that are independent of λ.
1642///
1643/// Computes the KKT adjoint `m = M^{-1} u` for `u = upstream_beta` and accumulates
1644/// the partials w.r.t. `X`, `y`, `S`, and `w` into the provided gradient buffers.
1645/// Returns `m` so callers that also need `∂L/∂λ` can fold in the λ-adjoint dot
1646/// product `−scale · ⟨m, S β⟩` without recomputing the adjoint solve.
1647fn ridge_profile_vjp_data_partials(
1648    scale: f64,
1649    x: ArrayView2<'_, f64>,
1650    y: ArrayView2<'_, f64>,
1651    penalty: ArrayView2<'_, f64>,
1652    weights: &Array1<f64>,
1653    lambda: f64,
1654    inverse_hessian: &Array2<f64>,
1655    beta: &Array2<f64>,
1656    upstream_beta: ArrayView2<'_, f64>,
1657    grad_x: &mut Array2<f64>,
1658    grad_y: &mut Array2<f64>,
1659    grad_penalty: &mut Array2<f64>,
1660    grad_weights: &mut Array1<f64>,
1661) -> Array2<f64> {
1662    let m = dense_ab(inverse_hessian.view(), upstream_beta);
1663    let c = dense_ab(m.view(), beta.t());
1664    let c_sym = &c + &c.t();
1665    let ymt = dense_ab(y, m.t());
1666    let xcs = dense_ab(x, c_sym.view());
1667    for i in 0..x.nrows() {
1668        let wi = weights[i] * scale;
1669        for k in 0..x.ncols() {
1670            grad_x[[i, k]] += wi * (ymt[[i, k]] - xcs[[i, k]]);
1671        }
1672    }
1673
1674    let xm = dense_ab(x, m.view());
1675    for i in 0..x.nrows() {
1676        let wi = weights[i] * scale;
1677        for j in 0..y.ncols() {
1678            grad_y[[i, j]] += wi * xm[[i, j]];
1679        }
1680    }
1681
1682    let xc = dense_ab(x, c.view());
1683    for i in 0..x.nrows() {
1684        let mut from_b = 0.0;
1685        for j in 0..y.ncols() {
1686            from_b += y[[i, j]] * xm[[i, j]];
1687        }
1688        let mut from_a = 0.0;
1689        for k in 0..x.ncols() {
1690            from_a += x[[i, k]] * xc[[i, k]];
1691        }
1692        grad_weights[i] += scale * (from_b - from_a);
1693    }
1694
1695    for row in 0..penalty.nrows() {
1696        for col in 0..penalty.ncols() {
1697            let mut value = 0.0;
1698            for output in 0..beta.ncols() {
1699                value += m[[row, output]] * beta[[col, output]];
1700            }
1701            grad_penalty[[row, col]] -= scale * lambda * value;
1702        }
1703    }
1704    m
1705}
1706
1707/// Ridge-profile VJP for callers that also need `∂L/∂λ`.
1708///
1709/// Accumulates the data/penalty/weight partials and adds the implicit-function
1710/// λ-adjoint contribution `−scale · ⟨M^{-1} u, S β⟩` into `lambda_adjoint_out`.
1711fn add_ridge_profile_vjp_with_lambda_grad(
1712    scale: f64,
1713    x: ArrayView2<'_, f64>,
1714    y: ArrayView2<'_, f64>,
1715    penalty: ArrayView2<'_, f64>,
1716    weights: &Array1<f64>,
1717    lambda: f64,
1718    inverse_hessian: &Array2<f64>,
1719    beta: &Array2<f64>,
1720    upstream_beta: ArrayView2<'_, f64>,
1721    grad_x: &mut Array2<f64>,
1722    grad_y: &mut Array2<f64>,
1723    grad_penalty: &mut Array2<f64>,
1724    grad_weights: &mut Array1<f64>,
1725    lambda_adjoint_out: &mut f64,
1726) {
1727    let m = ridge_profile_vjp_data_partials(
1728        scale,
1729        x,
1730        y,
1731        penalty,
1732        weights,
1733        lambda,
1734        inverse_hessian,
1735        beta,
1736        upstream_beta,
1737        grad_x,
1738        grad_y,
1739        grad_penalty,
1740        grad_weights,
1741    );
1742    let penalty_beta = dense_ab(penalty, beta.view());
1743    let dot = m
1744        .iter()
1745        .zip(penalty_beta.iter())
1746        .map(|(left, right)| left * right)
1747        .sum::<f64>();
1748    *lambda_adjoint_out += -scale * dot;
1749}
1750
/// Ridge-profile VJP for callers that hold λ fixed (e.g. the implicit-root
/// partial inside `add_reml_rho_gradient_vjp`). The λ-adjoint dot product is
/// skipped entirely — it would be unused work in this branch.
///
/// All arguments are forwarded unchanged to `ridge_profile_vjp_data_partials`,
/// which accumulates into `grad_x`, `grad_y`, `grad_penalty`, and
/// `grad_weights`.
fn add_ridge_profile_vjp_fixed_lambda(
    scale: f64,
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: &Array1<f64>,
    lambda: f64,
    inverse_hessian: &Array2<f64>,
    beta: &Array2<f64>,
    upstream_beta: ArrayView2<'_, f64>,
    grad_x: &mut Array2<f64>,
    grad_y: &mut Array2<f64>,
    grad_penalty: &mut Array2<f64>,
    grad_weights: &mut Array1<f64>,
) {
    // The returned adjoint `m` is only needed for the λ-gradient, so it is
    // intentionally discarded here.
    ridge_profile_vjp_data_partials(
        scale,
        x,
        y,
        penalty,
        weights,
        lambda,
        inverse_hessian,
        beta,
        upstream_beta,
        grad_x,
        grad_y,
        grad_penalty,
        grad_weights,
    );
}
1785
1786fn add_reml_score_vjp(
1787    scale: f64,
1788    x: ArrayView2<'_, f64>,
1789    weights: &Array1<f64>,
1790    inverse_hessian: &Array2<f64>,
1791    beta: &Array2<f64>,
1792    residual: &Array2<f64>,
1793    sigma2: &Array1<f64>,
1794    nu: f64,
1795    lambda: f64,
1796    cache: &GaussianRemlEigenCache,
1797    grad_x: &mut Array2<f64>,
1798    grad_y: &mut Array2<f64>,
1799    grad_penalty: &mut Array2<f64>,
1800    grad_weights: &mut Array1<f64>,
1801) {
1802    let d = beta.ncols() as f64;
1803    let xp = dense_ab(x, inverse_hessian.view());
1804    let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(cache);
1805    for row in 0..grad_penalty.nrows() {
1806        for col in 0..grad_penalty.ncols() {
1807            grad_penalty[[row, col]] +=
1808                scale * 0.5 * d * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
1809        }
1810    }
1811    for i in 0..x.nrows() {
1812        let wi = weights[i] * scale * d;
1813        for k in 0..x.ncols() {
1814            grad_x[[i, k]] += wi * xp[[i, k]];
1815        }
1816        let mut leverage = 0.0;
1817        for k in 0..x.ncols() {
1818            leverage += x[[i, k]] * xp[[i, k]];
1819        }
1820        grad_weights[i] += scale * 0.5 * d * leverage;
1821    }
1822
1823    for j in 0..beta.ncols() {
1824        let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
1825        let coef = scale * 0.5 * nu / dp;
1826        add_deviance_profile_vjp(
1827            coef,
1828            j,
1829            x,
1830            weights,
1831            beta,
1832            residual,
1833            grad_x,
1834            grad_y,
1835            grad_weights,
1836        );
1837        add_rank_one_penalty_vjp(coef * lambda, beta.column(j), grad_penalty);
1838    }
1839}
1840
1841/// VJP contribution from an upstream gradient on `edf`.
1842///
1843/// With `M = X^T W X + λ S`, `edf = trace(M^{-1} · X^T W X) = p - λ trace(M^{-1} S)`.
1844/// Holding `λ` fixed, the direct partials are
1845///   ∂edf/∂A = λ M^{-1} S M^{-1}      (A = X^T W X, symmetric)
1846///   ∂edf/∂S = −λ M^{-1} A M^{-1} = −λ M^{-1} + λ² M^{-1} S M^{-1}
1847///   ∂edf/∂λ = −trace(M^{-1} S) + λ trace((M^{-1} S)²)
1848/// The λ-component is returned as the lambda_adjoint contribution and routed
1849/// through the implicit-function chain by the caller (same path as
1850/// `upstream_lambda` and `upstream_reml_score`).
1851fn add_edf_vjp(
1852    scale: f64,
1853    x: ArrayView2<'_, f64>,
1854    penalty: ArrayView2<'_, f64>,
1855    weights: &Array1<f64>,
1856    lambda: f64,
1857    inverse_hessian: &Array2<f64>,
1858    grad_x: &mut Array2<f64>,
1859    grad_penalty: &mut Array2<f64>,
1860    grad_weights: &mut Array1<f64>,
1861) -> f64 {
1862    // m_inv_s = M^{-1} S, then g_a = λ M^{-1} S M^{-1} = ∂edf/∂A.
1863    let m_inv_s = dense_ab(inverse_hessian.view(), penalty);
1864    let mut g_a = dense_ab(m_inv_s.view(), inverse_hessian.view());
1865    g_a.mapv_inplace(|v| v * lambda);
1866
1867    // Chain ∂edf/∂A through A = X^T W X.
1868    //   grad_X += scale · 2 · (W X) · G_A
1869    //   grad_w_i += scale · (X G_A X^T)_{ii}
1870    let xg = dense_ab(x, g_a.view());
1871    // Row-scaled dense accumulate: grad_x[i,:] += (2·scale·weights[i]) · xg[i,:].
1872    // (Inlined here — the former `assembly::add_row_scaled_dense_into` helper was
1873    // removed as "unused" by 0cb722d, which missed this gam-pyffi-reachable caller.)
1874    let leading_scale = 2.0 * scale;
1875    for i in 0..xg.nrows() {
1876        let row_scale = leading_scale * weights[i];
1877        for k in 0..xg.ncols() {
1878            grad_x[[i, k]] += row_scale * xg[[i, k]];
1879        }
1880    }
1881    for i in 0..x.nrows() {
1882        let mut quad = 0.0;
1883        for k in 0..x.ncols() {
1884            quad += x[[i, k]] * xg[[i, k]];
1885        }
1886        grad_weights[i] += scale * quad;
1887    }
1888
1889    // ∂edf/∂S = -λ M^{-1} + λ² M^{-1} S M^{-1} = -λ M^{-1} + λ · g_a
1890    // (since g_a = λ M^{-1} S M^{-1}, so λ · g_a = λ² M^{-1} S M^{-1}).
1891    for row in 0..grad_penalty.nrows() {
1892        for col in 0..grad_penalty.ncols() {
1893            grad_penalty[[row, col]] +=
1894                scale * (-lambda * inverse_hessian[[row, col]] + lambda * g_a[[row, col]]);
1895        }
1896    }
1897
1898    // ∂edf/∂λ (with A, S fixed) = -tr(M^{-1} S) + λ tr((M^{-1} S)²).
1899    let p_dim = m_inv_s.nrows();
1900    let mut tr_m_inv_s = 0.0;
1901    for i in 0..p_dim {
1902        tr_m_inv_s += m_inv_s[[i, i]];
1903    }
1904    let mut tr_squared = 0.0;
1905    for i in 0..p_dim {
1906        for j in 0..p_dim {
1907            tr_squared += m_inv_s[[i, j]] * m_inv_s[[j, i]];
1908        }
1909    }
1910    scale * (-tr_m_inv_s + lambda * tr_squared)
1911}
1912
1913fn add_reml_rho_gradient_vjp(
1914    scale: f64,
1915    x: ArrayView2<'_, f64>,
1916    y: ArrayView2<'_, f64>,
1917    penalty: ArrayView2<'_, f64>,
1918    weights: &Array1<f64>,
1919    lambda: f64,
1920    inverse_hessian: &Array2<f64>,
1921    beta: &Array2<f64>,
1922    residual: &Array2<f64>,
1923    sigma2: &Array1<f64>,
1924    nu: f64,
1925    grad_x: &mut Array2<f64>,
1926    grad_y: &mut Array2<f64>,
1927    grad_penalty: &mut Array2<f64>,
1928    grad_weights: &mut Array1<f64>,
1929) {
1930    let d = beta.ncols() as f64;
1931    let inverse_s = dense_ab(inverse_hessian.view(), penalty);
1932    let trace_kernel = dense_ab(inverse_s.view(), inverse_hessian.view());
1933    for row in 0..grad_penalty.nrows() {
1934        for col in 0..grad_penalty.ncols() {
1935            grad_penalty[[row, col]] += scale
1936                * 0.5
1937                * d
1938                * lambda
1939                * (inverse_hessian[[col, row]] - lambda * trace_kernel[[col, row]]);
1940        }
1941    }
1942    let xt = dense_ab(x, trace_kernel.view());
1943    for i in 0..x.nrows() {
1944        let wi = -scale * d * lambda * weights[i];
1945        for k in 0..x.ncols() {
1946            grad_x[[i, k]] += wi * xt[[i, k]];
1947        }
1948        let mut quad = 0.0;
1949        for k in 0..x.ncols() {
1950            quad += x[[i, k]] * xt[[i, k]];
1951        }
1952        grad_weights[i] -= scale * 0.5 * d * lambda * quad;
1953    }
1954
1955    let s_beta = dense_ab(penalty, beta.view());
1956    let mut upstream_beta = Array2::<f64>::zeros(beta.dim());
1957    for j in 0..beta.ncols() {
1958        let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
1959        let q = lambda * beta.column(j).dot(&s_beta.column(j));
1960        let q_coef = scale * nu / dp;
1961        for row in 0..beta.nrows() {
1962            upstream_beta[[row, j]] = q_coef * lambda * s_beta[[row, j]];
1963        }
1964        let dp_coef = -scale * 0.5 * nu * q / (dp * dp);
1965        add_rank_one_penalty_vjp(
1966            (0.5 * q_coef + dp_coef) * lambda,
1967            beta.column(j),
1968            grad_penalty,
1969        );
1970        add_deviance_profile_vjp(
1971            dp_coef,
1972            j,
1973            x,
1974            weights,
1975            beta,
1976            residual,
1977            grad_x,
1978            grad_y,
1979            grad_weights,
1980        );
1981    }
1982    // The implicit-root VJP holds lambda fixed inside this partial; only the
1983    // data, penalty, and weight side effects from the ridge solve are needed.
1984    add_ridge_profile_vjp_fixed_lambda(
1985        1.0,
1986        x,
1987        y,
1988        penalty,
1989        weights,
1990        lambda,
1991        inverse_hessian,
1992        beta,
1993        upstream_beta.view(),
1994        grad_x,
1995        grad_y,
1996        grad_penalty,
1997        grad_weights,
1998    );
1999}
2000
2001fn add_rank_one_penalty_vjp(
2002    scale: f64,
2003    beta_col: ArrayView1<'_, f64>,
2004    grad_penalty: &mut Array2<f64>,
2005) {
2006    for row in 0..beta_col.len() {
2007        for col in 0..beta_col.len() {
2008            grad_penalty[[row, col]] += scale * beta_col[row] * beta_col[col];
2009        }
2010    }
2011}
2012
2013fn gaussian_reml_penalty_pseudoinverse_from_cache(cache: &GaussianRemlEigenCache) -> Array2<f64> {
2014    let p = cache.penalty_eigenvalues.len();
2015    let mut scaled_basis = Array2::<f64>::zeros((p, p));
2016    for eig in 0..p {
2017        let delta = cache.penalty_eigenvalues[eig];
2018        if delta > 0.0 {
2019            for row in 0..p {
2020                scaled_basis[[row, eig]] = cache.coefficient_basis[[row, eig]] / delta;
2021            }
2022        }
2023    }
2024    dense_ab(scaled_basis.view(), cache.coefficient_basis.t())
2025}
2026
2027fn add_deviance_profile_vjp(
2028    scale: f64,
2029    output: usize,
2030    x: ArrayView2<'_, f64>,
2031    weights: &Array1<f64>,
2032    beta: &Array2<f64>,
2033    residual: &Array2<f64>,
2034    grad_x: &mut Array2<f64>,
2035    grad_y: &mut Array2<f64>,
2036    grad_weights: &mut Array1<f64>,
2037) {
2038    for i in 0..x.nrows() {
2039        let r = residual[[i, output]];
2040        let wr_scale = scale * weights[i] * r;
2041        grad_y[[i, output]] += 2.0 * wr_scale;
2042        for k in 0..x.ncols() {
2043            grad_x[[i, k]] -= 2.0 * wr_scale * beta[[k, output]];
2044        }
2045        grad_weights[i] += scale * r * r;
2046    }
2047}
2048
2049fn validate_initial_lambda(lambda: f64) -> Result<f64, EstimationError> {
2050    if lambda.is_finite() && lambda > 0.0 {
2051        Ok(lambda)
2052    } else {
2053        Err(EstimationError::InvalidInput(format!(
2054            "Gaussian REML initial lambda must be finite and positive; got {lambda}"
2055        )))
2056    }
2057}
2058
/// Dense matrix product `a · b`, delegated to `fast_ab`.
///
/// Used throughout this file for products such as `M^{-1} u` and `X · G`.
fn dense_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
    fast_ab(&a, &b)
}
2062
/// Thin wrapper over `fast_atb`.
///
/// NOTE(review): by its name this presumably computes `aᵀ · b`; confirm
/// against the `gam_linalg` implementation.
fn dense_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
    fast_atb(&a, &b)
}
2066
/// Thin wrapper over `fast_xt_diag_x`.
///
/// The cache builder stores the result for a design `x` and weight vector `w`
/// as `xtwx`, i.e. it is used as `Xᵀ diag(w) X`.
fn dense_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Array2<f64> {
    fast_xt_diag_x(&x, &w)
}
2070
/// Thin wrapper over `fast_xt_diag_y`.
///
/// NOTE(review): by its name this presumably computes `Xᵀ diag(w) Y`; confirm
/// against the `gam_linalg` implementation.
fn dense_xt_diag_y(
    x: ArrayView2<'_, f64>,
    w: ArrayView1<'_, f64>,
    y: ArrayView2<'_, f64>,
) -> Array2<f64> {
    fast_xt_diag_y(&x, &w, &y)
}
2078
2079fn matrix_fingerprint(matrix: ArrayView2<'_, f64>) -> u64 {
2080    let mut hash = 0xcbf29ce484222325_u64;
2081    hash = fnv1a_mix(hash, matrix.nrows() as u64);
2082    hash = fnv1a_mix(hash, matrix.ncols() as u64);
2083    for &value in matrix {
2084        hash = fnv1a_mix(hash, value.to_bits());
2085    }
2086    hash
2087}
2088
/// One FNV-1a style mixing step applied to a whole 64-bit word: XOR `value`
/// into `hash`, then multiply by the 64-bit FNV prime with wrapping overflow.
fn fnv1a_mix(hash: u64, value: u64) -> u64 {
    const FNV_PRIME_64: u64 = 0x0000_0100_0000_01b3;
    let mixed = hash ^ value;
    mixed.wrapping_mul(FNV_PRIME_64)
}
2092
2093/// Build eigen caches for K problems that share the same penalty matrix in a
2094/// single phased pipeline. X'WX construction is batched by the caller; each
2095/// cache then uses the same Cholesky/eigendecomposition implementation as the
2096/// single-fit path.
2097pub fn build_gaussian_reml_eigen_cache_batched(
2098    xtwx_matrices: Vec<Array2<f64>>,
2099    penalty: ArrayView2<'_, f64>,
2100    nullspace_dim: Option<usize>,
2101) -> Vec<Result<GaussianRemlEigenCache, EstimationError>> {
2102    let penalty_owned = canonicalize_penalty(penalty);
2103    let penalty = penalty_owned.view();
2104    let k = xtwx_matrices.len();
2105    if k == 0 {
2106        return Vec::new();
2107    }
2108    let fingerprints: Vec<u64> = xtwx_matrices
2109        .iter()
2110        .map(|m| matrix_fingerprint(m.view()))
2111        .collect();
2112
2113    let p = xtwx_matrices[0].nrows();
2114    let uniform_square = p > 0 && xtwx_matrices.iter().all(|matrix| matrix.dim() == (p, p));
2115    if uniform_square && k > 1 {
2116        let mut lower_matrices = xtwx_matrices.clone();
2117        if gam_gpu::try_cholesky_batched_lower_inplace(&mut lower_matrices).is_some() {
2118            // The batched penalty transform is an optional accelerator. On
2119            // failure we must NOT fabricate an empty Vec (indexing it per-block
2120            // would silently drop the transform for every block and could index
2121            // out of range) — instead route every block through the same
2122            // no-GPU-transform path used when the batched transform is
2123            // unavailable, which recomputes the whitened penalty on CPU from the
2124            // already-valid Cholesky factor `lower`.
2125            let transforms = batched_whitened_penalty_transforms(&lower_matrices, penalty);
2126            return lower_matrices
2127                .into_iter()
2128                .enumerate()
2129                .map(|(b, lower)| {
2130                    let precomputed_transform = transforms.as_ref().map(|t| t[b].clone());
2131                    gaussian_reml_eigen_cache_from_lower_with_transform(
2132                        lower,
2133                        penalty,
2134                        nullspace_dim,
2135                        fingerprints[b],
2136                        precomputed_transform,
2137                    )
2138                })
2139                .collect();
2140        }
2141    }
2142
2143    let mut results = Vec::with_capacity(k);
2144    for (b, xtwx) in xtwx_matrices.into_iter().enumerate() {
2145        let lower = match gaussian_reml_cholesky_lower(xtwx) {
2146            Ok(l) => l,
2147            Err(err) => {
2148                results.push(Err(err));
2149                continue;
2150            }
2151        };
2152        results.push(gaussian_reml_eigen_cache_from_lower_with_transform(
2153            lower,
2154            penalty,
2155            nullspace_dim,
2156            fingerprints[b],
2157            None,
2158        ));
2159    }
2160    results
2161}
2162
2163fn batched_whitened_penalty_transforms(
2164    lowers: &[Array2<f64>],
2165    penalty: ArrayView2<'_, f64>,
2166) -> Option<Vec<Array2<f64>>> {
2167    let first = lowers.first()?;
2168    let p = first.nrows();
2169    if p == 0 || first.ncols() != p || lowers.iter().any(|lower| lower.dim() != (p, p)) {
2170        return None;
2171    }
2172    let mut linv_stack = Array3::<f64>::zeros((lowers.len(), p, p));
2173    for (idx, lower) in lowers.iter().enumerate() {
2174        let l_inv = invert_lower_triangular(lower).ok()?;
2175        linv_stack.slice_mut(s![idx, .., ..]).assign(&l_inv);
2176    }
2177    let penalty_in_metric =
2178        gam_gpu::try_fast_ab_broadcast_b_batched(linv_stack.view(), penalty)?;
2179    let transformed =
2180        gam_gpu::try_fast_abt_strided_batched(penalty_in_metric.view(), linv_stack.view())?;
2181    Some(
2182        transformed
2183            .axis_iter(Axis(0))
2184            .map(|matrix| matrix.to_owned())
2185            .collect(),
2186    )
2187}
2188
/// Builds the Gaussian REML eigen cache for design `x`, penalty `penalty`, and
/// optional prior `weights`, without asserting an expected penalty nullspace
/// dimension.
///
/// Equivalent to calling
/// `build_gaussian_reml_eigen_cache_with_nullspace_dim` with
/// `nullspace_dim = None`.
///
/// # Errors
///
/// Propagates any error from that function.
pub fn build_gaussian_reml_eigen_cache(
    x: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<GaussianRemlEigenCache, EstimationError> {
    build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, weights)
}
2196
2197pub fn build_gaussian_reml_eigen_cache_with_nullspace_dim(
2198    x: ArrayView2<'_, f64>,
2199    penalty: ArrayView2<'_, f64>,
2200    nullspace_dim: Option<usize>,
2201    weights: Option<ArrayView1<'_, f64>>,
2202) -> Result<GaussianRemlEigenCache, EstimationError> {
2203    let penalty_owned = canonicalize_penalty(penalty);
2204    let penalty = penalty_owned.view();
2205    let n = x.nrows();
2206    validate_gaussian_reml_design(x, penalty, weights)?;
2207    let weight = gaussian_reml_weights(n, weights)?;
2208
2209    let xtwx = dense_xt_diag_x(x, weight.view());
2210    gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)
2211}
2212
2213fn validate_gaussian_reml_design(
2214    x: ArrayView2<'_, f64>,
2215    penalty: ArrayView2<'_, f64>,
2216    weights: Option<ArrayView1<'_, f64>>,
2217) -> Result<(), EstimationError> {
2218    let n = x.nrows();
2219    let p = x.ncols();
2220    if penalty.nrows() != p || penalty.ncols() != p {
2221        crate::bail_invalid_estim!(
2222            "Gaussian REML penalty shape mismatch: expected {p}x{p}, got {}x{}",
2223            penalty.nrows(),
2224            penalty.ncols()
2225        );
2226    }
2227    if x.iter().chain(penalty.iter()).any(|v| !v.is_finite()) {
2228        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2229    }
2230    if let Some(w) = weights {
2231        if w.len() != n {
2232            crate::bail_invalid_estim!(
2233                "Gaussian REML weights length mismatch: expected {n}, got {}",
2234                w.len()
2235            );
2236        }
2237        if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2238            crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2239        }
2240    }
2241    Ok(())
2242}
2243
2244/// Effective observation count for the REML residual degrees of freedom.
2245///
2246/// A prior weight of exactly `0` is the universal "excluded / infinite-variance"
2247/// convention (mgcv, statsmodels): such a row must be equivalent to omitting it
2248/// entirely. The weighted response energy already handles this (`weight[row] *
2249/// y² = 0` for a zero-weight row), and a zero-weight row likewise contributes
2250/// nothing to `XᵀWX` / `XᵀWy`, so it cannot move the coefficients at a fixed
2251/// smoothing parameter. The one place a zero-weight row used to leak in was the
2252/// residual degrees of freedom `ν = n − nullity`, which counted the raw row
2253/// count `n`. That deflated `σ²`, under-smoothed `λ`, and (through `λ`) biased
2254/// the coefficients — growing with the number of zero-weight rows. The residual
2255/// DoF must instead be built from the number of rows that actually enter the
2256/// likelihood, i.e. those with a strictly positive weight.
2257fn effective_observation_count(weight: ArrayView1<'_, f64>) -> usize {
2258    weight.iter().filter(|&&w| w > 0.0).count()
2259}
2260
2261fn gaussian_reml_weights(
2262    n: usize,
2263    weights: Option<ArrayView1<'_, f64>>,
2264) -> Result<Array1<f64>, EstimationError> {
2265    match weights {
2266        Some(w) => {
2267            if w.len() != n {
2268                crate::bail_invalid_estim!(
2269                    "Gaussian REML weights length mismatch: expected {n}, got {}",
2270                    w.len()
2271                );
2272            }
2273            if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2274                crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2275            }
2276            Ok(w.to_owned())
2277        }
2278        None => Ok(Array1::ones(n)),
2279    }
2280}
2281
2282fn gaussian_reml_eigen_cache_from_xtwx(
2283    xtwx: Array2<f64>,
2284    penalty: ArrayView2<'_, f64>,
2285    nullspace_dim: Option<usize>,
2286) -> Result<GaussianRemlEigenCache, EstimationError> {
2287    let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2288    let lower = gaussian_reml_cholesky_lower(xtwx)?;
2289    gaussian_reml_eigen_cache_from_lower(lower, penalty, nullspace_dim, xtwx_fingerprint)
2290}
2291
/// Cache-build entry point for callers that have already computed `L =
/// chol(X'WX, lower)`. Used by the batched K-way fit path so a single
/// `cusolverDnDpotrfBatched` call factors all K matrices, then each cache
/// finishes per-fit without re-doing the Cholesky.
///
/// `xtwx_fingerprint` is stored in the cache as given; it is not recomputed
/// from `lower`. No pre-computed whitened penalty is supplied, so it is
/// derived from `lower` by the callee.
fn gaussian_reml_eigen_cache_from_lower(
    lower: Array2<f64>,
    penalty: ArrayView2<'_, f64>,
    nullspace_dim: Option<usize>,
    xtwx_fingerprint: u64,
) -> Result<GaussianRemlEigenCache, EstimationError> {
    gaussian_reml_eigen_cache_from_lower_with_transform(
        lower,
        penalty,
        nullspace_dim,
        xtwx_fingerprint,
        None,
    )
}
2310
2311/// Cache-build variant that accepts a pre-computed whitened penalty
2312/// `L⁻¹·S·L⁻ᵀ`. Callers pass `None` to compute it from the Cholesky factor.
2313fn gaussian_reml_eigen_cache_from_lower_with_transform(
2314    lower: Array2<f64>,
2315    penalty: ArrayView2<'_, f64>,
2316    nullspace_dim: Option<usize>,
2317    xtwx_fingerprint: u64,
2318    precomputed_transform: Option<Array2<f64>>,
2319) -> Result<GaussianRemlEigenCache, EstimationError> {
2320    let p = lower.nrows();
2321    if lower.ncols() != p {
2322        crate::bail_invalid_estim!("Gaussian REML Cholesky factor must be square");
2323    }
2324    let penalty_fingerprint = matrix_fingerprint(penalty);
2325    let logdet_xtwx = 2.0 * lower.diag().iter().map(|v| v.ln()).sum::<f64>();
2326    let transformed_penalty = match precomputed_transform {
2327        Some(transformed) => transformed,
2328        None => {
2329            let l_inv = invert_lower_triangular(&lower)?;
2330            let penalty_in_metric = dense_ab(l_inv.view(), penalty);
2331            dense_ab(penalty_in_metric.view(), l_inv.t())
2332        }
2333    };
2334    let (mut penalty_eigenvalues, eigenvectors) =
2335        transformed_penalty.eigh(Side::Lower).map_err(|_| {
2336            EstimationError::ModelIsIllConditioned {
2337                condition_number: f64::INFINITY,
2338            }
2339        })?;
2340    // Rank tolerance must be RELATIVE to the largest eigenvalue — never
2341    // floored at an absolute value. The old `.max(1.0)` clamped the
2342    // tolerance up whenever max|eig| < 1, classifying genuine modes as
2343    // null for small-scale penalties (e.g. Wahba pseudo-spline `m=4`
2344    // with `K(p,p) ≈ 3e-4`). That broke REML's invariance under
2345    // `S → c·S` — the optimum λ rescales but the score landscape
2346    // diverges from the true marginal likelihood, and the smooth
2347    // contribution collapsed to ~0 on smooth truths.
2348    // Fully scale-invariant form: `safety · max|eig| · eps`.
2349    let max_abs_eig = penalty_eigenvalues
2350        .iter()
2351        .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2352    let eig_tol = max_abs_eig * EIGEN_REL_TOL;
2353    for value in &mut penalty_eigenvalues {
2354        if *value < 0.0 && value.abs() <= eig_tol {
2355            *value = 0.0;
2356        }
2357        if *value < 0.0 {
2358            crate::bail_invalid_estim!(
2359                "Gaussian REML penalty is not positive semidefinite; eigenvalue={value:.3e}"
2360            );
2361        }
2362    }
2363    let penalty_rank = penalty_eigenvalues
2364        .iter()
2365        .filter(|&&value| value > eig_tol)
2366        .count();
2367    let nullity = p - penalty_rank;
2368    if let Some(expected_nullity) = nullspace_dim
2369        && expected_nullity != nullity
2370    {
2371        crate::bail_invalid_estim!(
2372            "Gaussian REML penalty nullspace mismatch: expected {expected_nullity}, inferred {nullity}"
2373        );
2374    }
2375    let logdet_penalty_positive = gaussian_penalty_positive_logdet(penalty, penalty_rank)?;
2376    let coefficient_basis = solve_upper_triangular_matrix(&lower.t().to_owned(), &eigenvectors)?;
2377
2378    Ok(GaussianRemlEigenCache {
2379        penalty_eigenvalues,
2380        eigenvectors,
2381        coefficient_basis,
2382        xtwx_fingerprint,
2383        penalty_fingerprint,
2384        logdet_xtwx,
2385        logdet_penalty_positive,
2386        penalty_rank,
2387        nullity,
2388    })
2389}
2390
2391fn gaussian_reml_cholesky_lower(xtwx: Array2<f64>) -> Result<Array2<f64>, EstimationError> {
2392    // Attempt Cholesky directly; on failure, retry with a tiny diagonal jitter
2393    // proportional to the matrix trace. X'WX is symmetric positive semidefinite
2394    // by construction, but FP noise (e.g. in a basis whose kernel block is only
2395    // FP-orthogonal to its explicit polynomial nullspace columns, as the
2396    // periodic Duchon basis is) can push the smallest eigenvalue slightly
2397    // negative on adversarial inputs, intermittently failing Cholesky. A
2398    // jitter of 1e-12 * trace/p shifts every eigenvalue up by an amount well
2399    // below the natural scale of the well-conditioned eigenvalues but well
2400    // above f64 FP noise, eliminating the spurious-failure regime.
2401    let mut gpu_candidate = xtwx.clone();
2402    if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2403        return Ok(gpu_candidate);
2404    }
2405    if let Ok(chol) = xtwx.cholesky(Side::Lower) {
2406        return Ok(chol.lower_triangular());
2407    }
2408    let p = xtwx.nrows();
2409    let trace: f64 = (0..p).map(|i| xtwx[[i, i]]).sum();
2410    if !trace.is_finite() || trace <= 0.0 {
2411        return Err(EstimationError::ModelIsIllConditioned {
2412            condition_number: f64::INFINITY,
2413        });
2414    }
2415    let mut jitter = 1e-12 * trace / (p as f64);
2416    for _ in 0..6 {
2417        let mut jittered = xtwx.clone();
2418        for i in 0..p {
2419            jittered[[i, i]] += jitter;
2420        }
2421        let mut gpu_candidate = jittered.clone();
2422        if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2423            return Ok(gpu_candidate);
2424        }
2425        if let Ok(chol) = jittered.cholesky(Side::Lower) {
2426            return Ok(chol.lower_triangular());
2427        }
2428        jitter *= 10.0;
2429    }
2430    Err(EstimationError::ModelIsIllConditioned {
2431        condition_number: f64::INFINITY,
2432    })
2433}
2434
2435fn gaussian_penalty_positive_logdet(
2436    penalty: ArrayView2<'_, f64>,
2437    penalty_rank: usize,
2438) -> Result<f64, EstimationError> {
2439    if penalty_rank == 0 {
2440        return Ok(0.0);
2441    }
2442    let (pen_eigs, _) = penalty.to_owned().eigh(Side::Lower).map_err(|_| {
2443        EstimationError::ModelIsIllConditioned {
2444            condition_number: f64::INFINITY,
2445        }
2446    })?;
2447    // Scale-invariant relative tolerance — see the cousin site for the
2448    // rationale. Same `.max(1.0)` floor used to live here and corrupted
2449    // the positive-eigenvalue count for small-scale penalties.
2450    let pen_scale = pen_eigs
2451        .iter()
2452        .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2453    let pen_tol = pen_scale * EIGEN_REL_TOL;
2454    let mut positive_eigs: Vec<f64> = pen_eigs
2455        .iter()
2456        .copied()
2457        .filter(|&value| value > pen_tol)
2458        .collect();
2459    if positive_eigs.len() != penalty_rank {
2460        positive_eigs = pen_eigs
2461            .iter()
2462            .copied()
2463            .filter(|&value| value > 0.0)
2464            .collect();
2465        positive_eigs.sort_by(|a, b| b.total_cmp(a));
2466        if positive_eigs.len() < penalty_rank {
2467            return Err(EstimationError::ModelIsIllConditioned {
2468                condition_number: f64::INFINITY,
2469            });
2470        }
2471        positive_eigs.truncate(penalty_rank);
2472    }
2473    Ok(positive_eigs.iter().map(|value| value.ln()).sum())
2474}
2475
2476fn validate_gaussian_reml_eigen_cache(
2477    cache: &GaussianRemlEigenCache,
2478    p: usize,
2479) -> Result<(), EstimationError> {
2480    if cache.penalty_eigenvalues.len() != p
2481        || cache.eigenvectors.dim() != (p, p)
2482        || cache.coefficient_basis.dim() != (p, p)
2483    {
2484        crate::bail_invalid_estim!(
2485            "Gaussian REML eigen cache dimension mismatch: expected {p} coefficients"
2486        );
2487    }
2488    if cache.penalty_rank > p || cache.nullity > p || cache.penalty_rank + cache.nullity != p {
2489        crate::bail_invalid_estim!(
2490            "Gaussian REML eigen cache rank/nullity mismatch: rank={}, nullity={}, p={p}",
2491            cache.penalty_rank,
2492            cache.nullity
2493        );
2494    }
2495    if !(cache.logdet_xtwx.is_finite() && cache.logdet_penalty_positive.is_finite()) {
2496        crate::bail_invalid_estim!("Gaussian REML eigen cache log-determinants must be finite");
2497    }
2498    if cache
2499        .penalty_eigenvalues
2500        .iter()
2501        .any(|value| !value.is_finite() || *value < 0.0)
2502        || cache.eigenvectors.iter().any(|value| !value.is_finite())
2503        || cache
2504            .coefficient_basis
2505            .iter()
2506            .any(|value| !value.is_finite())
2507    {
2508        crate::bail_invalid_estim!(
2509            "Gaussian REML eigen cache entries must be finite with non-negative eigenvalues"
2510                .to_string(),
2511        );
2512    }
2513    Ok::<(), _>(())
2514}
2515
2516fn prepare_gaussian_reml(
2517    x: ArrayView2<'_, f64>,
2518    y: ArrayView2<'_, f64>,
2519    penalty: ArrayView2<'_, f64>,
2520    nullspace_dim: Option<usize>,
2521    weights: Option<ArrayView1<'_, f64>>,
2522    eigen_cache: Option<&GaussianRemlEigenCache>,
2523) -> Result<GaussianRemlPrepared, EstimationError> {
2524    // Enforce the symmetric-S contract once at the central forward chokepoint;
2525    // every closed-form forward path funnels through here.
2526    let penalty_owned = canonicalize_penalty(penalty);
2527    let penalty = penalty_owned.view();
2528    let n = x.nrows();
2529    let p = x.ncols();
2530    let d = y.ncols();
2531    validate_gaussian_reml_design(x, penalty, weights)?;
2532    if y.nrows() != n {
2533        crate::bail_invalid_estim!(
2534            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
2535            y.nrows()
2536        );
2537    }
2538    if y.iter().any(|v| !v.is_finite()) {
2539        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2540    }
2541    let weight = gaussian_reml_weights(n, weights)?;
2542    let n_effective = effective_observation_count(weight.view());
2543
2544    let xtwy = dense_xt_diag_y(x, weight.view(), y);
2545    let ywy = Array1::from_iter((0..d).map(|j| {
2546        let mut value = 0.0;
2547        for row in 0..n {
2548            value += weight[row] * y[[row, j]] * y[[row, j]];
2549        }
2550        value
2551    }));
2552    let xtwx = dense_xt_diag_x(x, weight.view());
2553
2554    if let Some(cache) = eigen_cache {
2555        validate_gaussian_reml_eigen_cache(cache, p)?;
2556        let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2557        if cache.xtwx_fingerprint != xtwx_fingerprint {
2558            crate::bail_invalid_estim!("Gaussian REML eigen cache X'WX mismatch");
2559        }
2560        let penalty_fingerprint = matrix_fingerprint(penalty);
2561        if cache.penalty_fingerprint != penalty_fingerprint {
2562            crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
2563        }
2564        if let Some(expected_nullity) = nullspace_dim
2565            && expected_nullity != cache.nullity
2566        {
2567            crate::bail_invalid_estim!(
2568                "Gaussian REML eigen cache nullspace mismatch: expected {expected_nullity}, got {}",
2569                cache.nullity
2570            );
2571        }
2572        if n_effective <= cache.nullity {
2573            crate::bail_invalid_estim!(
2574                "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
2575                cache.nullity
2576            );
2577        }
2578        let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2579        let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2580        return Ok(GaussianRemlPrepared {
2581            cache: cache.clone(),
2582            ywy,
2583            projected_rhs_squared,
2584            projected_rhs,
2585            n_effective,
2586            n_outputs: d,
2587        });
2588    }
2589
2590    let cache = gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)?;
2591    if n_effective <= cache.nullity {
2592        crate::bail_invalid_estim!(
2593            "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
2594            cache.nullity
2595        );
2596    }
2597    let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2598    let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2599
2600    Ok(GaussianRemlPrepared {
2601        cache,
2602        ywy,
2603        projected_rhs_squared,
2604        projected_rhs,
2605        n_effective,
2606        n_outputs: d,
2607    })
2608}
2609
2610impl GaussianRemlPrepared {
2611    fn nu(&self) -> f64 {
2612        self.n_effective as f64 - self.cache.nullity as f64
2613    }
2614
2615    fn evaluate(&self, rho: f64) -> ObjectiveEval {
2616        evaluate_reml_parts(
2617            &self.cache,
2618            self.ywy.view(),
2619            self.projected_rhs_squared.view(),
2620            self.n_effective,
2621            self.n_outputs,
2622            rho,
2623        )
2624    }
2625
2626    fn coefficients(&self, lambda: f64) -> Array2<f64> {
2627        let mut scaled = self.projected_rhs.clone();
2628        for i in 0..self.cache.penalty_eigenvalues.len() {
2629            let scale = 1.0 / (1.0 + lambda * self.cache.penalty_eigenvalues[i]);
2630            for value in scaled.row_mut(i) {
2631                *value *= scale;
2632            }
2633        }
2634        dense_ab(self.cache.coefficient_basis.view(), scaled.view())
2635    }
2636
2637    fn sigma2(&self, lambda: f64) -> Array1<f64> {
2638        let nu = self.nu();
2639        Array1::from_iter((0..self.n_outputs).map(|j| {
2640            let mut fitted_quadratic = 0.0;
2641            for i in 0..self.cache.penalty_eigenvalues.len() {
2642                let denom = 1.0 + lambda * self.cache.penalty_eigenvalues[i];
2643                fitted_quadratic += self.projected_rhs_squared[[i, j]] / denom;
2644            }
2645            ((self.ywy[j] - fitted_quadratic).max(MIN_DEVIANCE)) / nu
2646        }))
2647    }
2648}
2649
2650fn optimize_rho(
2651    prepared: &GaussianRemlPrepared,
2652    init_rho: Option<f64>,
2653) -> Result<f64, EstimationError> {
2654    if prepared.cache.penalty_rank == 0 {
2655        return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2656    }
2657
2658    const GRID_INTERVALS: usize = 96;
2659    let mut stationary = Vec::<f64>::new();
2660    let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2661    let mut prev_rho = RHO_LOWER;
2662    let mut prev_eval = prepared.evaluate(prev_rho);
2663    grid.push((prev_rho, prev_eval.cost));
2664    for i in 1..=GRID_INTERVALS {
2665        let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2666        let eval = prepared.evaluate(rho);
2667        grid.push((rho, eval.cost));
2668        if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2669            push_candidate(
2670                &mut stationary,
2671                refine_stationary_rho(prepared, prev_rho, rho, 0.5 * (prev_rho + rho)),
2672            );
2673        }
2674        prev_rho = rho;
2675        prev_eval = eval;
2676    }
2677
2678    let mut candidates = stationary;
2679    push_candidate(&mut candidates, RHO_LOWER);
2680    push_candidate(&mut candidates, RHO_UPPER);
2681    if let Some(rho0) = init_rho {
2682        push_candidate(&mut candidates, rho0);
2683    }
2684    if let Some(rho) = refine_best_grid_cell(prepared, &grid) {
2685        push_candidate(&mut candidates, rho);
2686    }
2687
2688    // Evaluate each candidate exactly once. `min_by` over a comparator that
2689    // re-evaluates would do O(n log n) extra `prepared.evaluate` calls during
2690    // the sort.
2691    candidates
2692        .into_iter()
2693        .map(|rho| (rho, prepared.evaluate(rho).cost))
2694        .min_by(|(_, a), (_, b)| a.total_cmp(b))
2695        .map(|(rho, _)| rho)
2696        .ok_or_else(|| {
2697            EstimationError::InvalidInput(
2698                "Gaussian REML optimizer produced no candidates".to_string(),
2699            )
2700        })
2701}
2702
2703fn refine_best_grid_cell(prepared: &GaussianRemlPrepared, grid: &[(f64, f64)]) -> Option<f64> {
2704    let best_idx = grid
2705        .iter()
2706        .enumerate()
2707        .filter(|(_, (_, cost))| cost.is_finite())
2708        .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2709        .map(|(idx, _)| idx)?;
2710    if best_idx == 0 || best_idx + 1 == grid.len() {
2711        return Some(grid[best_idx].0);
2712    }
2713    // The best interior grid cell brackets a genuine REML minimum (its cost is
2714    // below both neighbours), so the objective gradient changes sign across
2715    // `[grid[i-1], grid[i+1]]`. Refine to that stationary point (∂V/∂ρ = 0)
2716    // rather than minimising the cost with a golden section: the cost-based
2717    // search only locates ρ to ~√ε of the cell (~1e-8), whereas the
2718    // grad-sign-change branch already contributes stationary candidates
2719    // converged to GRAD_TOL (~1e-12). When both target the same minimum, the
2720    // ~1e-16 cost ordering between a 1e-8-accurate and a 1e-12-accurate ρ is
2721    // numerical noise, so `min_by(cost)` used to pick between two ρ values
2722    // ~1e-8 apart essentially at random — making the selected λ̂ a
2723    // non-smooth function of the design X (its ~1e-8 jumps wrecked the
2724    // closed-form REML reverse-mode VJP's agreement with finite differences).
2725    // Returning the stationary point makes every interior candidate a
2726    // GRAD_TOL-accurate root, so the residual selection jitter collapses to
2727    // ~1e-12 and λ̂(X) is smooth to the IFT gradient.
2728    Some(refine_stationary_rho(
2729        prepared,
2730        grid[best_idx - 1].0,
2731        grid[best_idx + 1].0,
2732        grid[best_idx].0,
2733    ))
2734}
2735
2736fn fill_weighted_rhs_no_alloc(
2737    x: ArrayView2<'_, f64>,
2738    y: ArrayView2<'_, f64>,
2739    weights: Option<ArrayView1<'_, f64>>,
2740    workspace: &mut GaussianRemlNoAllocWorkspace,
2741) -> Result<(), EstimationError> {
2742    let d = y.ncols();
2743
2744    // XᵀWY and YᵀWY via faer BLAS. Both `fast_xt_diag_y` and `fast_atb`
2745    // dispatch to faer's SIMD-optimized GEMM (with chunked weight scaling
2746    // when weights are present), replacing the previous scalar triple loop
2747    // over (n, p, d). For YᵀWY we only need the diagonal entries, but d is
2748    // small (typically 1–10) so computing the full d×d Gram is negligible.
2749    let (xtwy, ywy_full) = match weights {
2750        Some(w) => (fast_xt_diag_y(&x, &w, &y), fast_xt_diag_y(&y, &w, &y)),
2751        None => (fast_atb(&x, &y), fast_atb(&y, &y)),
2752    };
2753    workspace.xtwy.assign(&xtwy);
2754    for output in 0..d {
2755        workspace.ywy[output] = ywy_full[[output, output]];
2756    }
2757
2758    if workspace
2759        .xtwy
2760        .iter()
2761        .chain(workspace.ywy.iter())
2762        .any(|value| !value.is_finite())
2763    {
2764        crate::bail_invalid_estim!("Gaussian REML weighted cross-products must be finite");
2765    }
2766    Ok(())
2767}
2768
2769fn project_rhs_no_alloc(
2770    cache: &GaussianRemlEigenCache,
2771    workspace: &mut GaussianRemlNoAllocWorkspace,
2772) {
2773    // projected_rhs = coefficient_basisᵀ · xtwy, computed via faer BLAS
2774    // (was previously a scalar triple loop over (p, d, p)).
2775    let projected = fast_atb(&cache.coefficient_basis, &workspace.xtwy);
2776    workspace.projected_rhs.assign(&projected);
2777    let p = cache.penalty_eigenvalues.len();
2778    let d = workspace.ywy.len();
2779    for eig in 0..p {
2780        for output in 0..d {
2781            let value = workspace.projected_rhs[[eig, output]];
2782            workspace.projected_rhs_squared[[eig, output]] = value * value;
2783        }
2784    }
2785}
2786
2787fn evaluate_reml_parts(
2788    cache: &GaussianRemlEigenCache,
2789    ywy: ArrayView1<'_, f64>,
2790    projected_rhs_squared: ArrayView2<'_, f64>,
2791    n_effective: usize,
2792    n_outputs: usize,
2793    rho: f64,
2794) -> ObjectiveEval {
2795    let lambda = rho.exp();
2796    let nu = n_effective as f64 - cache.nullity as f64;
2797    let d = n_outputs as f64;
2798
2799    // Each term's value and its ρ-derivatives come back from ONE function so
2800    // they cannot be edited independently; `+=` folds the triple in lock-step.
2801    let (logdet_term, edf) = gaussian_reml_logdet_term(cache, rho, d);
2802    let mut eval = ObjectiveEval {
2803        cost: 0.0,
2804        grad: 0.0,
2805        hess: 0.0,
2806        edf,
2807    };
2808    eval += logdet_term;
2809    for output in 0..n_outputs {
2810        eval +=
2811            gaussian_reml_dispersion_term(cache, ywy, projected_rhs_squared, output, nu, lambda);
2812    }
2813    eval
2814}
2815
2816fn optimize_rho_no_alloc(
2817    cache: &GaussianRemlEigenCache,
2818    ywy: ArrayView1<'_, f64>,
2819    projected_rhs_squared: ArrayView2<'_, f64>,
2820    n_effective: usize,
2821    n_outputs: usize,
2822    init_rho: Option<f64>,
2823) -> Result<f64, EstimationError> {
2824    if cache.penalty_rank == 0 {
2825        return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2826    }
2827
2828    let lower_eval = evaluate_reml_parts(
2829        cache,
2830        ywy,
2831        projected_rhs_squared,
2832        n_effective,
2833        n_outputs,
2834        RHO_LOWER,
2835    );
2836
2837    let mut best_rho = RHO_LOWER;
2838    let mut best_cost = lower_eval.cost;
2839
2840    const GRID_INTERVALS: usize = 96;
2841    let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2842    let mut prev_rho = RHO_LOWER;
2843    let mut prev_eval = lower_eval;
2844    grid.push((prev_rho, prev_eval.cost));
2845    for i in 1..=GRID_INTERVALS {
2846        let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2847        let eval = evaluate_reml_parts(
2848            cache,
2849            ywy,
2850            projected_rhs_squared,
2851            n_effective,
2852            n_outputs,
2853            rho,
2854        );
2855        grid.push((rho, eval.cost));
2856        if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2857            let stationary_rho = refine_stationary_rho_no_alloc(
2858                cache,
2859                ywy,
2860                projected_rhs_squared,
2861                n_effective,
2862                n_outputs,
2863                prev_rho,
2864                rho,
2865                0.5 * (prev_rho + rho),
2866            );
2867            consider_rho_no_alloc(
2868                cache,
2869                ywy,
2870                projected_rhs_squared,
2871                n_effective,
2872                n_outputs,
2873                stationary_rho,
2874                &mut best_rho,
2875                &mut best_cost,
2876            );
2877        }
2878        prev_rho = rho;
2879        prev_eval = eval;
2880    }
2881    if let Some(best_idx) = grid
2882        .iter()
2883        .enumerate()
2884        .filter(|(_, (_, cost))| cost.is_finite())
2885        .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2886        .map(|(idx, _)| idx)
2887    {
2888        let refined = if best_idx == 0 || best_idx + 1 == grid.len() {
2889            grid[best_idx].0
2890        } else {
2891            // Refine the best interior grid cell to the REML stationary point
2892            // (∂V/∂ρ = 0) rather than the golden-section cost minimum, mirroring
2893            // the allocating `refine_best_grid_cell`. A cost-based search locates
2894            // ρ only to ~1e-8, which competed against the GRAD_TOL-accurate
2895            // (~1e-12) stationary candidates in the cost `min_by` below and made
2896            // the selected λ̂ jump ~1e-8 with the design — a non-smoothness the
2897            // closed-form REML VJP could not match under finite differences.
2898            // (Keeping both optimizers' refinement identical preserves their
2899            // allocating/no-alloc bit-for-bit parity.)
2900            refine_stationary_rho_no_alloc(
2901                cache,
2902                ywy,
2903                projected_rhs_squared,
2904                n_effective,
2905                n_outputs,
2906                grid[best_idx - 1].0,
2907                grid[best_idx + 1].0,
2908                grid[best_idx].0,
2909            )
2910        };
2911        consider_rho_no_alloc(
2912            cache,
2913            ywy,
2914            projected_rhs_squared,
2915            n_effective,
2916            n_outputs,
2917            refined,
2918            &mut best_rho,
2919            &mut best_cost,
2920        );
2921    }
2922
2923    consider_rho_no_alloc(
2924        cache,
2925        ywy,
2926        projected_rhs_squared,
2927        n_effective,
2928        n_outputs,
2929        RHO_UPPER,
2930        &mut best_rho,
2931        &mut best_cost,
2932    );
2933    if let Some(rho0) = init_rho {
2934        consider_rho_no_alloc(
2935            cache,
2936            ywy,
2937            projected_rhs_squared,
2938            n_effective,
2939            n_outputs,
2940            rho0,
2941            &mut best_rho,
2942            &mut best_cost,
2943        );
2944    }
2945
2946    if best_cost.is_finite() {
2947        Ok(best_rho)
2948    } else {
2949        Err(EstimationError::InvalidInput(
2950            "Gaussian REML optimizer produced no finite candidates".to_string(),
2951        ))
2952    }
2953}
2954
2955fn consider_rho_no_alloc(
2956    cache: &GaussianRemlEigenCache,
2957    ywy: ArrayView1<'_, f64>,
2958    projected_rhs_squared: ArrayView2<'_, f64>,
2959    n_effective: usize,
2960    n_outputs: usize,
2961    rho: f64,
2962    best_rho: &mut f64,
2963    best_cost: &mut f64,
2964) {
2965    if !rho.is_finite() {
2966        return;
2967    }
2968    let candidate = rho.clamp(RHO_LOWER, RHO_UPPER);
2969    let eval = evaluate_reml_parts(
2970        cache,
2971        ywy,
2972        projected_rhs_squared,
2973        n_effective,
2974        n_outputs,
2975        candidate,
2976    );
2977    if eval.cost < *best_cost {
2978        *best_rho = candidate;
2979        *best_cost = eval.cost;
2980    }
2981}
2982
2983fn refine_stationary_rho_no_alloc(
2984    cache: &GaussianRemlEigenCache,
2985    ywy: ArrayView1<'_, f64>,
2986    projected_rhs_squared: ArrayView2<'_, f64>,
2987    n_effective: usize,
2988    n_outputs: usize,
2989    mut lo: f64,
2990    mut hi: f64,
2991    mut rho: f64,
2992) -> f64 {
2993    for _ in 0..80 {
2994        let eval = evaluate_reml_parts(
2995            cache,
2996            ywy,
2997            projected_rhs_squared,
2998            n_effective,
2999            n_outputs,
3000            rho,
3001        );
3002        if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
3003            return rho;
3004        }
3005        if eval.grad >= 0.0 {
3006            hi = rho;
3007        } else {
3008            lo = rho;
3009        }
3010        let newton = if eval.hess > 0.0 {
3011            let candidate = rho - eval.grad / eval.hess;
3012            (candidate > lo && candidate < hi).then_some(candidate)
3013        } else {
3014            None
3015        };
3016        if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
3017            break;
3018        }
3019        rho = newton.unwrap_or(0.5 * (lo + hi));
3020    }
3021    0.5 * (lo + hi)
3022}
3023
3024fn fill_coefficients_no_alloc(
3025    cache: &GaussianRemlEigenCache,
3026    workspace: &mut GaussianRemlNoAllocWorkspace,
3027    lambda: f64,
3028    mut coefficients: ArrayViewMut2<'_, f64>,
3029) {
3030    let p = cache.penalty_eigenvalues.len();
3031    let d = workspace.ywy.len();
3032    for eig in 0..p {
3033        let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
3034        for output in 0..d {
3035            workspace.scaled_projected_rhs[[eig, output]] =
3036                workspace.projected_rhs[[eig, output]] * scale;
3037        }
3038    }
3039
3040    for col in 0..p {
3041        for output in 0..d {
3042            let mut value = 0.0;
3043            for eig in 0..p {
3044                value += cache.coefficient_basis[[col, eig]]
3045                    * workspace.scaled_projected_rhs[[eig, output]];
3046            }
3047            coefficients[[col, output]] = value;
3048        }
3049    }
3050}
3051
3052fn fill_fitted_no_alloc(
3053    x: ArrayView2<'_, f64>,
3054    coefficients: ArrayView2<'_, f64>,
3055    mut fitted: ArrayViewMut2<'_, f64>,
3056) {
3057    let n = x.nrows();
3058    let p = x.ncols();
3059    let d = coefficients.ncols();
3060    for row in 0..n {
3061        for output in 0..d {
3062            let mut value = 0.0;
3063            for col in 0..p {
3064                value += x[[row, col]] * coefficients[[col, output]];
3065            }
3066            fitted[[row, output]] = value;
3067        }
3068    }
3069}
3070
3071fn fill_sigma2_no_alloc(
3072    cache: &GaussianRemlEigenCache,
3073    ywy: ArrayView1<'_, f64>,
3074    projected_rhs_squared: ArrayView2<'_, f64>,
3075    n_effective: usize,
3076    n_outputs: usize,
3077    lambda: f64,
3078    mut sigma2: ArrayViewMut1<'_, f64>,
3079) {
3080    let nu = n_effective as f64 - cache.nullity as f64;
3081    for output in 0..n_outputs {
3082        let mut fitted_quadratic = 0.0;
3083        for eig in 0..cache.penalty_eigenvalues.len() {
3084            let denom = 1.0 + lambda * cache.penalty_eigenvalues[eig];
3085            fitted_quadratic += projected_rhs_squared[[eig, output]] / denom;
3086        }
3087        sigma2[output] = ((ywy[output] - fitted_quadratic).max(MIN_DEVIANCE)) / nu;
3088    }
3089}
3090
3091fn push_candidate(candidates: &mut Vec<f64>, rho: f64) {
3092    if rho.is_finite() {
3093        candidates.push(rho.clamp(RHO_LOWER, RHO_UPPER));
3094    }
3095}
3096
3097fn refine_stationary_rho(
3098    prepared: &GaussianRemlPrepared,
3099    mut lo: f64,
3100    mut hi: f64,
3101    mut rho: f64,
3102) -> f64 {
3103    for _ in 0..80 {
3104        let eval = prepared.evaluate(rho);
3105        if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
3106            return rho;
3107        }
3108        if eval.grad >= 0.0 {
3109            hi = rho;
3110        } else {
3111            lo = rho;
3112        }
3113        let newton = if eval.hess > 0.0 {
3114            let candidate = rho - eval.grad / eval.hess;
3115            (candidate > lo && candidate < hi).then_some(candidate)
3116        } else {
3117            None
3118        };
3119        if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
3120            break;
3121        }
3122        rho = newton.unwrap_or(0.5 * (lo + hi));
3123    }
3124    0.5 * (lo + hi)
3125}
3126
3127fn invert_lower_triangular(lower: &Array2<f64>) -> Result<Array2<f64>, EstimationError> {
3128    let n = lower.nrows();
3129    if lower.ncols() != n {
3130        crate::bail_invalid_estim!("lower-triangular solve requires a square matrix");
3131    }
3132    let eye = Array2::eye(n);
3133    solve_lower_triangular_matrix(lower, &eye)
3134}
3135
3136fn solve_lower_triangular_matrix(
3137    lower: &Array2<f64>,
3138    rhs: &Array2<f64>,
3139) -> Result<Array2<f64>, EstimationError> {
3140    let n = lower.nrows();
3141    if lower.ncols() != n || rhs.nrows() != n {
3142        crate::bail_invalid_estim!("lower-triangular solve dimension mismatch");
3143    }
3144    if let Some(out) = gam_gpu::try_solve_lower_triangular_matrix(lower.view(), rhs.view()) {
3145        return Ok(out);
3146    }
3147    let mut out = Array2::<f64>::zeros(rhs.dim());
3148    for col in 0..rhs.ncols() {
3149        for i in 0..n {
3150            let mut value = rhs[[i, col]];
3151            for k in 0..i {
3152                value -= lower[[i, k]] * out[[k, col]];
3153            }
3154            let diag = lower[[i, i]];
3155            if !(diag.is_finite() && diag.abs() > 0.0) {
3156                return Err(EstimationError::ModelIsIllConditioned {
3157                    condition_number: f64::INFINITY,
3158                });
3159            }
3160            out[[i, col]] = value / diag;
3161        }
3162    }
3163    Ok(out)
3164}
3165
3166/// Solve the SPD system `L Lᵀ X = rhs` for `X` given the lower Cholesky factor
3167/// `L` (as returned by [`gaussian_reml_cholesky_lower`]): a forward solve
3168/// against `L` followed by a back solve against `Lᵀ`.
3169fn solve_spd_from_lower_factor(
3170    lower: &Array2<f64>,
3171    rhs: &Array2<f64>,
3172) -> Result<Array2<f64>, EstimationError> {
3173    let forward = solve_lower_triangular_matrix(lower, rhs)?;
3174    solve_upper_triangular_matrix(&lower.t().to_owned(), &forward)
3175}
3176
3177fn solve_upper_triangular_matrix(
3178    upper: &Array2<f64>,
3179    rhs: &Array2<f64>,
3180) -> Result<Array2<f64>, EstimationError> {
3181    let n = upper.nrows();
3182    if upper.ncols() != n || rhs.nrows() != n {
3183        crate::bail_invalid_estim!("upper-triangular solve dimension mismatch");
3184    }
3185    if let Some(out) = gam_gpu::try_solve_upper_triangular_matrix(upper.view(), rhs.view()) {
3186        return Ok(out);
3187    }
3188    let mut out = Array2::<f64>::zeros(rhs.dim());
3189    for col in 0..rhs.ncols() {
3190        for i_rev in 0..n {
3191            let i = n - 1 - i_rev;
3192            let mut value = rhs[[i, col]];
3193            for k in (i + 1)..n {
3194                value -= upper[[i, k]] * out[[k, col]];
3195            }
3196            let diag = upper[[i, i]];
3197            if !(diag.is_finite() && diag.abs() > 0.0) {
3198                return Err(EstimationError::ModelIsIllConditioned {
3199                    condition_number: f64::INFINITY,
3200                });
3201            }
3202            out[[i, col]] = value / diag;
3203        }
3204    }
3205    Ok(out)
3206}
3207
3208#[cfg(test)]
3209mod tests {
3210    use super::*;
3211    use ndarray::array;
3212
3213    #[test]
3214    fn edf_does_not_double_count_penalty_nullspace() {
3215        let x = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0],];
3216        let y = array![[0.0], [1.0], [1.8], [3.2], [4.1]];
3217        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3218        let result =
3219            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3220                .expect("small full-rank Gaussian REML fit");
3221
3222        assert!(result.edf >= result.cache.nullity as f64);
3223        assert!(result.edf <= x.ncols() as f64 + 1.0e-10);
3224    }
3225
3226    #[test]
3227    fn multi_output_duplicate_columns_match_scalar_fit() {
3228        let x = array![
3229            [1.0, -1.0],
3230            [1.0, -0.5],
3231            [1.0, 0.0],
3232            [1.0, 0.5],
3233            [1.0, 1.0],
3234            [1.0, 1.5],
3235        ];
3236        let y1 = array![0.5, 0.2, 0.0, 0.3, 1.1, 2.0];
3237        let y = Array2::from_shape_fn(
3238            (y1.len(), 2),
3239            |(i, j)| if j == 0 { y1[i] } else { 2.0 * y1[i] },
3240        );
3241        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3242
3243        let scalar =
3244            gaussian_reml_closed_form(x.view(), y1.view(), penalty.view(), None, Some(0.0))
3245                .expect("scalar Gaussian REML fit");
3246        let multi =
3247            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3248                .expect("multi-output Gaussian REML fit");
3249
3250        assert!((multi.rho - scalar.rho).abs() <= 1.0e-8);
3251        for i in 0..x.ncols() {
3252            assert!((multi.coefficients[[i, 0]] - scalar.coefficients[i]).abs() <= 1.0e-8);
3253            assert!((multi.coefficients[[i, 1]] - 2.0 * scalar.coefficients[i]).abs() <= 1.0e-8);
3254        }
3255    }
3256
3257    #[test]
3258    fn warm_start_reuses_cache_and_lambda_seed() {
3259        let x = array![
3260            [1.0, -1.0],
3261            [1.0, -0.25],
3262            [1.0, 0.5],
3263            [1.0, 1.25],
3264            [1.0, 2.0],
3265        ];
3266        let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3267        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3268
3269        let cold =
3270            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3271                .expect("cold fit");
3272        let warm_start = GaussianRemlWarmStart::from_multi_result(&cold);
3273        let warm = gaussian_reml_multi_closed_form_warm_started(
3274            x.view(),
3275            y.view(),
3276            penalty.view(),
3277            None,
3278            Some(&warm_start),
3279        )
3280        .expect("warm-started fit");
3281
3282        assert!((cold.lambda - warm.lambda).abs() <= 1.0e-10);
3283        assert_eq!(cold.cache.xtwx_fingerprint, warm.cache.xtwx_fingerprint);
3284        for i in 0..x.ncols() {
3285            assert!((cold.coefficients[[i, 0]] - warm.coefficients[[i, 0]]).abs() <= 1.0e-10);
3286        }
3287    }
3288
3289    #[test]
3290    fn warm_start_cache_rejects_different_penalty_geometry() {
3291        let x = array![
3292            [1.0, -1.0],
3293            [1.0, -0.25],
3294            [1.0, 0.5],
3295            [1.0, 1.25],
3296            [1.0, 2.0],
3297        ];
3298        let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3299        let penalty_a = array![[0.0, 0.0], [0.0, 1.0]];
3300        let penalty_b = array![[1.0, -1.0], [-1.0, 1.0]];
3301
3302        let first =
3303            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty_a.view(), None, Some(0.0))
3304                .expect("first fit");
3305        let warm_start = GaussianRemlWarmStart::from_multi_result(&first);
3306        let err = gaussian_reml_multi_closed_form_warm_started(
3307            x.view(),
3308            y.view(),
3309            penalty_b.view(),
3310            None,
3311            Some(&warm_start),
3312        )
3313        .expect_err("penalty-mismatched cache must be rejected");
3314
3315        assert!(err.to_string().contains("penalty mismatch"));
3316    }
3317
3318    #[test]
3319    fn no_alloc_cache_path_matches_allocating_fit() {
3320        let x = array![
3321            [1.0, -1.0, 0.25],
3322            [1.0, -0.5, 0.10],
3323            [1.0, 0.0, -0.20],
3324            [1.0, 0.5, -0.05],
3325            [1.0, 1.0, 0.30],
3326            [1.0, 1.5, 0.60],
3327        ];
3328        let y = array![
3329            [0.0, 0.2],
3330            [0.3, 0.1],
3331            [0.4, -0.1],
3332            [0.9, 0.3],
3333            [1.6, 0.8],
3334            [2.2, 1.2],
3335        ];
3336        let weights = array![1.0, 0.8, 1.2, 1.1, 0.9, 1.3];
3337        let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 4.0]];
3338
3339        let allocating = gaussian_reml_multi_closed_form_with_cache(
3340            x.view(),
3341            y.view(),
3342            penalty.view(),
3343            Some(weights.view()),
3344            Some(1.0),
3345            None,
3346        )
3347        .expect("allocating fit");
3348        let mut workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols());
3349        let mut coefficients = Array2::zeros((x.ncols(), y.ncols()));
3350        let mut fitted = Array2::zeros(y.dim());
3351        let mut sigma2 = Array1::zeros(y.ncols());
3352
3353        let no_alloc = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3354            x.view(),
3355            y.view(),
3356            penalty.view(),
3357            Some(weights.view()),
3358            Some(allocating.lambda),
3359            &allocating.cache,
3360            &mut workspace,
3361            coefficients.view_mut(),
3362            fitted.view_mut(),
3363            sigma2.view_mut(),
3364        )
3365        .expect("no-alloc cached fit");
3366
3367        assert!((no_alloc.lambda - allocating.lambda).abs() <= 1.0e-10);
3368        assert!((no_alloc.reml_score - allocating.reml_score).abs() <= 1.0e-8);
3369        assert!((no_alloc.reml_grad_rho - allocating.reml_grad_rho).abs() <= 1.0e-8);
3370        assert!((no_alloc.reml_hess_rho - allocating.reml_hess_rho).abs() <= 1.0e-8);
3371        assert!((no_alloc.edf - allocating.edf).abs() <= 1.0e-10);
3372        for i in 0..x.ncols() {
3373            for j in 0..y.ncols() {
3374                assert!((coefficients[[i, j]] - allocating.coefficients[[i, j]]).abs() <= 1.0e-8);
3375            }
3376        }
3377        for i in 0..x.nrows() {
3378            for j in 0..y.ncols() {
3379                assert!((fitted[[i, j]] - allocating.fitted[[i, j]]).abs() <= 1.0e-8);
3380            }
3381        }
3382        for j in 0..y.ncols() {
3383            assert!((sigma2[j] - allocating.sigma2[j]).abs() <= 1.0e-10);
3384        }
3385    }
3386
3387    #[test]
3388    fn no_alloc_cache_path_rejects_bad_shapes_and_penalty_mismatch() {
3389        let x = array![[1.0, -1.0], [1.0, 0.0], [1.0, 1.0], [1.0, 2.0]];
3390        let y = array![[0.0], [0.2], [0.9], [1.8]];
3391        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3392        let cache = build_gaussian_reml_eigen_cache(x.view(), penalty.view(), None)
3393            .expect("Gaussian REML cache");
3394
3395        let mut bad_workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols() + 1);
3396        let mut coefficients = Array2::zeros((x.ncols(), y.ncols()));
3397        let mut fitted = Array2::zeros(y.dim());
3398        let mut sigma2 = Array1::zeros(y.ncols());
3399        let err = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3400            x.view(),
3401            y.view(),
3402            penalty.view(),
3403            None,
3404            Some(1.0),
3405            &cache,
3406            &mut bad_workspace,
3407            coefficients.view_mut(),
3408            fitted.view_mut(),
3409            sigma2.view_mut(),
3410        )
3411        .expect_err("workspace shape mismatch must be rejected");
3412        assert!(err.to_string().contains("workspace shape mismatch"));
3413
3414        let penalty_mismatch = array![[1.0, -1.0], [-1.0, 1.0]];
3415        let mut workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols());
3416        let err = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3417            x.view(),
3418            y.view(),
3419            penalty_mismatch.view(),
3420            None,
3421            Some(1.0),
3422            &cache,
3423            &mut workspace,
3424            coefficients.view_mut(),
3425            fitted.view_mut(),
3426            sigma2.view_mut(),
3427        )
3428        .expect_err("penalty mismatch must be rejected");
3429        assert!(err.to_string().contains("penalty mismatch"));
3430    }
3431
3432    #[derive(Clone, Copy, Debug)]
3433    enum ForwardScalar {
3434        Lambda,
3435        RemlScore,
3436        Coefficient(usize, usize),
3437        Fitted(usize, usize),
3438        Edf,
3439    }
3440
3441    fn finite_difference_design() -> Array2<f64> {
3442        Array2::from_shape_fn((20, 5), |(row, col)| {
3443            let t = (row as f64 - 9.5) / 10.0;
3444            match col {
3445                0 => 1.0,
3446                1 => t,
3447                2 => 0.5 * (3.0 * t * t - 1.0),
3448                3 => 0.5 * (5.0 * t * t * t - 3.0 * t),
3449                4 => (35.0 * t.powi(4) - 30.0 * t * t + 3.0) / 8.0,
3450                _ => unreachable!(),
3451            }
3452        })
3453    }
3454
3455    fn finite_difference_response(outputs: usize) -> Array2<f64> {
3456        // The truth must NOT lie (essentially) in span(X). The 5-column design
3457        // is Legendre P_0..P_4, so a low-order polynomial + low-frequency sin
3458        // would be fit to near machine precision — driving σ² → 0, dp → 0,
3459        // and ∂score/∂y ≈ ν w r / dp → ∞. Central finite differences with
3460        // Richardson extrapolation cannot resolve such steep, highly-nonlinear
3461        // surfaces at 1e-6 relative because the truncation term scales with
3462        // f^(5)(y), which explodes in that regime. The high-frequency sin
3463        // below is well outside span(P_0..P_4) on t ∈ [-0.95, 0.95], leaving
3464        // a genuine residual (σ² ≈ 1e-3) and an interior REML optimum
3465        // (ρ ≈ -3) at which the analytic-vs-FD comparison is meaningful.
3466        Array2::from_shape_fn((20, outputs), |(row, output)| {
3467            let t = (row as f64 - 9.5) / 10.0;
3468            let phase = output as f64 + 1.0;
3469            0.2 + 0.25 * phase * t - 0.12 * t * t
3470                + (0.08 + 0.03 * phase) * (1.1 * t + 0.3 * phase).sin()
3471                + 0.05 * (7.0 * t + 0.5 * phase).sin()
3472        })
3473    }
3474
3475    fn finite_difference_penalty() -> Array2<f64> {
3476        Array2::from_diag(&array![0.0, 0.8, 1.2, 1.7, 2.3])
3477    }
3478
3479    fn finite_difference_weights() -> Array1<f64> {
3480        Array1::from_shape_fn(20, |row| {
3481            let t = (row as f64 - 9.5) / 10.0;
3482            1.0 + 0.025 * (1.1 * t).sin() + 0.01 * t
3483        })
3484    }
3485
3486    /// Fallible forward-scalar probe. Returns `None` when the closed-form fit
3487    /// rejects the inputs — the relevant case being a penalty perturbation that
3488    /// pushes `S` out of the PSD cone (a single-entry central bump on a
3489    /// null-direction entry drives one eigenvalue slightly negative). Such a
3490    /// point has no well-defined REML objective, so the caller skips it rather
3491    /// than panicking.
3492    fn one_hot_objective_try(
3493        x: ArrayView2<'_, f64>,
3494        y: ArrayView2<'_, f64>,
3495        penalty: ArrayView2<'_, f64>,
3496        weights: ArrayView1<'_, f64>,
3497        target: ForwardScalar,
3498    ) -> Option<f64> {
3499        let fit = gaussian_reml_multi_closed_form_with_cache(
3500            x,
3501            y,
3502            penalty,
3503            Some(weights),
3504            Some(0.85),
3505            None,
3506        )
3507        .ok()?;
3508        Some(match target {
3509            ForwardScalar::Lambda => fit.lambda,
3510            ForwardScalar::RemlScore => fit.reml_score,
3511            ForwardScalar::Coefficient(row, col) => fit.coefficients[[row, col]],
3512            ForwardScalar::Fitted(row, col) => fit.fitted[[row, col]],
3513            ForwardScalar::Edf => fit.edf,
3514        })
3515    }
3516
3517    fn one_hot_objective(
3518        x: ArrayView2<'_, f64>,
3519        y: ArrayView2<'_, f64>,
3520        penalty: ArrayView2<'_, f64>,
3521        weights: ArrayView1<'_, f64>,
3522        target: ForwardScalar,
3523    ) -> f64 {
3524        one_hot_objective_try(x, y, penalty, weights, target)
3525            .expect("finite-difference forward fit")
3526    }
3527
3528    fn one_hot_backward(
3529        x: ArrayView2<'_, f64>,
3530        y: ArrayView2<'_, f64>,
3531        penalty: ArrayView2<'_, f64>,
3532        weights: ArrayView1<'_, f64>,
3533        target: ForwardScalar,
3534    ) -> GaussianRemlBackwardResult {
3535        let mut grad_coefficients = Array2::<f64>::zeros((x.ncols(), y.ncols()));
3536        let mut grad_fitted = Array2::<f64>::zeros(y.dim());
3537        let (grad_lambda, grad_score, grad_edf, coefficient_upstream, fitted_upstream) =
3538            match target {
3539                ForwardScalar::Lambda => (1.0, 0.0, 0.0, None, None),
3540                ForwardScalar::RemlScore => (0.0, 1.0, 0.0, None, None),
3541                ForwardScalar::Coefficient(row, col) => {
3542                    grad_coefficients[[row, col]] = 1.0;
3543                    (0.0, 0.0, 0.0, Some(grad_coefficients.view()), None)
3544                }
3545                ForwardScalar::Fitted(row, col) => {
3546                    grad_fitted[[row, col]] = 1.0;
3547                    (0.0, 0.0, 0.0, None, Some(grad_fitted.view()))
3548                }
3549                ForwardScalar::Edf => (0.0, 0.0, 1.0, None, None),
3550            };
3551        gaussian_reml_multi_closed_form_backward(
3552            x,
3553            y,
3554            penalty,
3555            Some(weights),
3556            Some(0.85),
3557            grad_lambda,
3558            coefficient_upstream,
3559            fitted_upstream,
3560            grad_score,
3561            grad_edf,
3562        )
3563        .expect("analytic backward VJP")
3564    }
3565
3566    fn assert_fd_close(label: &str, analytic: f64, finite_difference: f64) {
3567        let rel_tol = 1.0e-6_f64;
3568        let abs_tol = 1.0e-6_f64;
3569        let tol = abs_tol.max(rel_tol * analytic.abs().max(finite_difference.abs()));
3570        let diff = (analytic - finite_difference).abs();
3571        assert!(
3572            diff <= tol,
3573            "{label}: analytic={analytic:.12e}, finite_difference={finite_difference:.12e}, diff={diff:.3e}, tol={tol:.3e}"
3574        );
3575    }
3576
3577    fn adaptive_central_difference(mut eval: impl FnMut(f64) -> f64) -> f64 {
3578        let steps: [f64; 5] = [1.0e-3, 5.0e-4, 2.5e-4, 1.25e-4, 6.25e-5];
3579        let mut best = f64::NAN;
3580        let mut best_delta = f64::INFINITY;
3581        let mut previous: Option<f64> = None;
3582        for h in steps {
3583            let d1 = (eval(h) - eval(-h)) / (2.0 * h);
3584            let half_h = 0.5 * h;
3585            let d2 = (eval(half_h) - eval(-half_h)) / (2.0 * half_h);
3586            let estimate: f64 = d2 + (d2 - d1) / 3.0;
3587            if let Some(prev) = previous {
3588                let delta = (estimate - prev).abs();
3589                if delta < best_delta {
3590                    best_delta = delta;
3591                    best = estimate;
3592                }
3593            } else {
3594                best = estimate;
3595            }
3596            previous = Some(estimate);
3597        }
3598        best
3599    }
3600
3601    fn assert_backward_matches_forward_finite_difference(outputs: usize) {
3602        let x = finite_difference_design();
3603        let y = finite_difference_response(outputs);
3604        let penalty = finite_difference_penalty();
3605        let weights = finite_difference_weights();
3606        let targets = [
3607            ForwardScalar::Lambda,
3608            ForwardScalar::RemlScore,
3609            ForwardScalar::Coefficient(3, outputs - 1),
3610            ForwardScalar::Fitted(12, outputs - 1),
3611            ForwardScalar::Edf,
3612        ];
3613        for target in targets {
3614            let backward =
3615                one_hot_backward(x.view(), y.view(), penalty.view(), weights.view(), target);
3616
3617            for row in 0..x.nrows() {
3618                for col in 0..x.ncols() {
3619                    let eval = |delta: f64| {
3620                        let mut candidate = x.clone();
3621                        candidate[[row, col]] += delta;
3622                        one_hot_objective(
3623                            candidate.view(),
3624                            y.view(),
3625                            penalty.view(),
3626                            weights.view(),
3627                            target,
3628                        )
3629                    };
3630                    let fd = adaptive_central_difference(eval);
3631                    assert_fd_close(
3632                        &format!("target={target:?} x[{row},{col}]"),
3633                        backward.grad_x[[row, col]],
3634                        fd,
3635                    );
3636                }
3637            }
3638
3639            for row in 0..y.nrows() {
3640                for col in 0..y.ncols() {
3641                    let eval = |delta: f64| {
3642                        let mut candidate = y.clone();
3643                        candidate[[row, col]] += delta;
3644                        one_hot_objective(
3645                            x.view(),
3646                            candidate.view(),
3647                            penalty.view(),
3648                            weights.view(),
3649                            target,
3650                        )
3651                    };
3652                    let fd = adaptive_central_difference(eval);
3653                    assert_fd_close(
3654                        &format!("target={target:?} y[{row},{col}]"),
3655                        backward.grad_y[[row, col]],
3656                        fd,
3657                    );
3658                }
3659            }
3660
3661            for row in 0..weights.len() {
3662                let eval = |delta: f64| {
3663                    let mut candidate = weights.clone();
3664                    candidate[row] += delta;
3665                    one_hot_objective(x.view(), y.view(), penalty.view(), candidate.view(), target)
3666                };
3667                let fd = adaptive_central_difference(eval);
3668                assert_fd_close(
3669                    &format!("target={target:?} weights[{row}]"),
3670                    backward.grad_weights[row],
3671                    fd,
3672                );
3673            }
3674
3675            // ∂L/∂S over the RANGE-SPACE penalty entries. The REML objective
3676            // carries −½d·log|S|₊ (the pseudo-determinant over the NONZERO
3677            // eigenvalues), so ∂L/∂S is only a finite, FD-verifiable derivative
3678            // where a central ±h bump keeps S inside the PSD cone WITHOUT
3679            // changing its rank. A single-entry bump touching the null
3680            // direction violates both: the −h side drives an eigenvalue
3681            // slightly negative (leaves the cone → fit Err) and the +h side
3682            // turns the zero eigenvalue into a tiny positive one that joins
3683            // log|S|₊ as a −log(ε) term (a rank-change discontinuity in L).
3684            // The null-direction component of the analytic S-gradient is a
3685            // gauge convention for the null space (the L-metric pseudoinverse
3686            // `penalty_pinv` = L⁻ᵀ T⁺ L⁻¹), validated by algebra/consumer, not
3687            // FD. So restrict to the strictly-positive diagonal block (both
3688            // indices in 1..p for the diag([0, 0.8, 1.2, 1.7, 2.3]) fixture,
3689            // where S_rr > 0 and ±h stays PSD at full rank). The forward
3690            // consumes only `S_canon = 0.5(S + Sᵀ)` and the backward returns
3691            // the symmetrized gradient, so a single-entry bump of S[r, c]
3692            // (asymmetric) compares directly against `grad_penalty[r, c]` =
3693            // 0.5(G[r, c] + G[c, r]). Defensively, any entry whose largest ±h
3694            // probe leaves the cone is skipped (cone membership is monotone in
3695            // |h| here, so probing the largest step suffices).
3696            let null_index = 0usize; // diag([0.0, ...]) ⇒ coordinate 0 is the null direction.
3697            let probe_h = 1.0e-3_f64; // matches the largest adaptive_central_difference step.
3698            for r in 0..penalty.nrows() {
3699                for c in 0..penalty.ncols() {
3700                    if r == null_index || c == null_index {
3701                        continue;
3702                    }
3703                    let eval = |delta: f64| {
3704                        let mut candidate = penalty.clone();
3705                        candidate[[r, c]] += delta;
3706                        one_hot_objective(
3707                            x.view(),
3708                            y.view(),
3709                            candidate.view(),
3710                            weights.view(),
3711                            target,
3712                        )
3713                    };
3714                    let cone_safe = {
3715                        let mut s_plus = penalty.clone();
3716                        let mut s_minus = penalty.clone();
3717                        s_plus[[r, c]] += probe_h;
3718                        s_minus[[r, c]] -= probe_h;
3719                        one_hot_objective_try(
3720                            x.view(),
3721                            y.view(),
3722                            s_plus.view(),
3723                            weights.view(),
3724                            target,
3725                        )
3726                        .is_some()
3727                            && one_hot_objective_try(
3728                                x.view(),
3729                                y.view(),
3730                                s_minus.view(),
3731                                weights.view(),
3732                                target,
3733                            )
3734                            .is_some()
3735                    };
3736                    if !cone_safe {
3737                        continue;
3738                    }
3739                    let fd = adaptive_central_difference(eval);
3740                    assert_fd_close(
3741                        &format!("target={target:?} penalty[{r},{c}]"),
3742                        backward.grad_penalty[[r, c]],
3743                        fd,
3744                    );
3745                }
3746            }
3747        }
3748    }
3749
3750    #[test]
3751    fn scalar_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
3752        assert_backward_matches_forward_finite_difference(1);
3753    }
3754
3755    #[test]
3756    fn multi_output_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
3757        assert_backward_matches_forward_finite_difference(3);
3758    }
3759
3760    #[test]
3761    fn backward_vjp_matches_finite_difference() {
3762        let x = array![
3763            [1.0, -1.0, 0.2],
3764            [1.0, -0.3, -0.1],
3765            [1.0, 0.2, 0.4],
3766            [1.0, 0.8, 0.1],
3767            [1.0, 1.4, 0.5],
3768            [1.0, 2.0, 0.9],
3769        ];
3770        let y = array![
3771            [0.1, -0.2],
3772            [0.2, 0.1],
3773            [0.7, 0.0],
3774            [1.1, 0.3],
3775            [1.8, 0.9],
3776            [2.4, 1.4],
3777        ];
3778        let weights = array![1.0, 0.9, 1.1, 1.2, 0.8, 1.3];
3779        let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.2], [0.0, 0.2, 1.7]];
3780        let upstream_coefficients = array![[0.2, -0.1], [0.05, 0.03], [-0.04, 0.07]];
3781        let upstream_fitted = array![
3782            [0.01, -0.02],
3783            [0.03, 0.01],
3784            [-0.01, 0.02],
3785            [0.04, -0.03],
3786            [0.02, 0.05],
3787            [-0.02, 0.01],
3788        ];
3789        let upstream_lambda = 0.17;
3790        let upstream_score = -0.11;
3791
3792        let backward = gaussian_reml_multi_closed_form_backward(
3793            x.view(),
3794            y.view(),
3795            penalty.view(),
3796            Some(weights.view()),
3797            Some(0.8),
3798            upstream_lambda,
3799            Some(upstream_coefficients.view()),
3800            Some(upstream_fitted.view()),
3801            upstream_score,
3802            0.0,
3803        )
3804        .expect("backward VJP");
3805
3806        let objective = |x_eval: &Array2<f64>, y_eval: &Array2<f64>, w_eval: &Array1<f64>| {
3807            let fit = gaussian_reml_multi_closed_form_with_cache(
3808                x_eval.view(),
3809                y_eval.view(),
3810                penalty.view(),
3811                Some(w_eval.view()),
3812                Some(0.8),
3813                None,
3814            )
3815            .expect("fit for objective");
3816            upstream_lambda * fit.lambda
3817                + upstream_score * fit.reml_score
3818                + (&fit.coefficients * &upstream_coefficients).sum()
3819                + (&fit.fitted * &upstream_fitted).sum()
3820        };
3821        let eps = 1.0e-6;
3822        assert!(objective(&x, &y, &weights).is_finite());
3823
3824        let mut x_plus = x.clone();
3825        let mut x_minus = x.clone();
3826        x_plus[[3, 2]] += eps;
3827        x_minus[[3, 2]] -= eps;
3828        let fd_x =
3829            (objective(&x_plus, &y, &weights) - objective(&x_minus, &y, &weights)) / (2.0 * eps);
3830        assert!(
3831            (fd_x - backward.grad_x[[3, 2]]).abs() <= 2.0e-4,
3832            "grad_x mismatch: analytic={} fd={}",
3833            backward.grad_x[[3, 2]],
3834            fd_x
3835        );
3836
3837        let mut y_plus = y.clone();
3838        let mut y_minus = y.clone();
3839        y_plus[[4, 1]] += eps;
3840        y_minus[[4, 1]] -= eps;
3841        let fd_y =
3842            (objective(&x, &y_plus, &weights) - objective(&x, &y_minus, &weights)) / (2.0 * eps);
3843        assert!(
3844            (fd_y - backward.grad_y[[4, 1]]).abs() <= 2.0e-4,
3845            "grad_y mismatch: analytic={} fd={}",
3846            backward.grad_y[[4, 1]],
3847            fd_y
3848        );
3849
3850        let mut w_plus = weights.clone();
3851        let mut w_minus = weights.clone();
3852        w_plus[2] += eps;
3853        w_minus[2] -= eps;
3854        let fd_w = (objective(&x, &y, &w_plus) - objective(&x, &y, &w_minus)) / (2.0 * eps);
3855        assert!(
3856            (fd_w - backward.grad_weights[2]).abs() <= 2.0e-4,
3857            "grad_weight mismatch: analytic={} fd={}",
3858            backward.grad_weights[2],
3859            fd_w
3860        );
3861
3862        // Combined-seed ∂L/∂S spot-check: perturb individual penalty entries with
3863        // x/y/w held at base, under mixed (λ, score, β, fitted) seeds. The penalty
3864        // [[0,0,0],[0,1,0.2],[0,0.2,1.7]] is nullity 1 (coordinate 0 is the null
3865        // direction); ∂L/∂S is FD-verifiable only on the strictly-positive
3866        // RANGE block (indices 1,2), where a central ±h bump keeps S PSD at full
3867        // rank. Null-touching entries (any index 0) are non-FD-verifiable — the
3868        // −½d·log|S|₊ pseudo-determinant term makes L either cone-leaving or
3869        // rank-change-discontinuous there (see the exhaustive S loop above). A
3870        // single-entry asymmetric bump of S[r, c] compares directly to
3871        // grad_penalty[[r, c]] = 0.5(G[r,c] + G[c,r]), exercising the backward
3872        // symmetrization.
3873        let objective_s = |s_eval: &Array2<f64>| {
3874            let fit = gaussian_reml_multi_closed_form_with_cache(
3875                x.view(),
3876                y.view(),
3877                s_eval.view(),
3878                Some(weights.view()),
3879                Some(0.8),
3880                None,
3881            )
3882            .expect("fit for penalty objective");
3883            upstream_lambda * fit.lambda
3884                + upstream_score * fit.reml_score
3885                + (&fit.coefficients * &upstream_coefficients).sum()
3886                + (&fit.fitted * &upstream_fitted).sum()
3887        };
3888        // (1,1) full-rank diagonal; (1,2) pure off-diagonal between two penalized
3889        // directions; (2,2) full-rank diagonal. All in the strictly-positive
3890        // range block, so ±h stays PSD at full rank.
3891        for (r, c) in [(1usize, 1usize), (1, 2), (2, 2)] {
3892            let mut s_plus = penalty.clone();
3893            let mut s_minus = penalty.clone();
3894            s_plus[[r, c]] += eps;
3895            s_minus[[r, c]] -= eps;
3896            let fd_s = (objective_s(&s_plus) - objective_s(&s_minus)) / (2.0 * eps);
3897            assert!(
3898                (fd_s - backward.grad_penalty[[r, c]]).abs() <= 2.0e-4,
3899                "grad_penalty[{r},{c}] mismatch: analytic={} fd={}",
3900                backward.grad_penalty[[r, c]],
3901                fd_s
3902            );
3903        }
3904    }
3905
3906    #[test]
3907    fn batched_eigen_cache_matches_per_fit_build() {
3908        // Three K=3 problems sharing the same penalty matrix. The batched
3909        // pipeline must produce caches that are bit-exact identical to what
3910        // the per-fit `gaussian_reml_eigen_cache_from_xtwx` builder produces,
3911        // regardless of whether the GPU batched Cholesky kicks in or the
3912        // helper falls through to per-fit Cholesky.
3913        let xtwx_a = array![[4.0, 1.0], [1.0, 3.0]];
3914        let xtwx_b = array![[2.5, -0.5], [-0.5, 1.7]];
3915        let xtwx_c = array![[7.2, 0.3], [0.3, 5.1]];
3916        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3917
3918        let batched = build_gaussian_reml_eigen_cache_batched(
3919            vec![xtwx_a.clone(), xtwx_b.clone(), xtwx_c.clone()],
3920            penalty.view(),
3921            None,
3922        );
3923        assert_eq!(batched.len(), 3);
3924
3925        for (xtwx, batched_cache) in [&xtwx_a, &xtwx_b, &xtwx_c].into_iter().zip(batched.iter()) {
3926            let single = gaussian_reml_eigen_cache_from_xtwx(xtwx.clone(), penalty.view(), None)
3927                .expect("per-fit cache");
3928            let batched_cache = batched_cache.as_ref().expect("batched cache");
3929            assert_eq!(batched_cache.penalty_rank, single.penalty_rank);
3930            assert_eq!(batched_cache.nullity, single.nullity);
3931            assert_eq!(batched_cache.xtwx_fingerprint, single.xtwx_fingerprint);
3932            assert_eq!(
3933                batched_cache.penalty_fingerprint,
3934                single.penalty_fingerprint
3935            );
3936            assert!((batched_cache.logdet_xtwx - single.logdet_xtwx).abs() <= 1.0e-12);
3937            assert!(
3938                (batched_cache.logdet_penalty_positive - single.logdet_penalty_positive).abs()
3939                    <= 1.0e-12
3940            );
3941            for (a, b) in batched_cache
3942                .penalty_eigenvalues
3943                .iter()
3944                .zip(single.penalty_eigenvalues.iter())
3945            {
3946                assert!((a - b).abs() <= 1.0e-12);
3947            }
3948            for ((a, b), _) in batched_cache
3949                .coefficient_basis
3950                .iter()
3951                .zip(single.coefficient_basis.iter())
3952                .zip(0..)
3953            {
3954                assert!((a - b).abs() <= 1.0e-12);
3955            }
3956        }
3957    }
3958
3959    #[test]
3960    fn scalar_rho_optimizer_chooses_lowest_cost_stationary_point() {
3961        let cache = GaussianRemlEigenCache {
3962            penalty_eigenvalues: array![5.2430192311066924e-05, 81734184.18548436],
3963            eigenvectors: Array2::eye(2),
3964            coefficient_basis: Array2::eye(2),
3965            xtwx_fingerprint: 0,
3966            penalty_fingerprint: 0,
3967            logdet_xtwx: 0.0,
3968            logdet_penalty_positive: 0.0,
3969            penalty_rank: 2,
3970            nullity: 0,
3971        };
3972        let prepared = GaussianRemlPrepared {
3973            cache: cache.clone(),
3974            ywy: array![0.5021347226586624],
3975            projected_rhs_squared: array![[0.361060218768292], [0.01014486085547482]],
3976            projected_rhs: array![
3977                [0.361060218768292_f64.sqrt()],
3978                [0.01014486085547482_f64.sqrt()]
3979            ],
3980            n_effective: 100,
3981            n_outputs: 1,
3982        };
3983
3984        let rho = optimize_rho(&prepared, None).expect("allocating rho optimizer");
3985        let no_alloc_rho = optimize_rho_no_alloc(
3986            &cache,
3987            prepared.ywy.view(),
3988            prepared.projected_rhs_squared.view(),
3989            prepared.n_effective,
3990            prepared.n_outputs,
3991            None,
3992        )
3993        .expect("no-alloc rho optimizer");
3994
3995        assert!(
3996            (rho - 4.3251059890).abs() < 1.0e-6,
3997            "rho optimizer selected {rho}, expected the lower-cost later stationary point"
3998        );
3999        assert!(
4000            (no_alloc_rho - rho).abs() < 1.0e-8,
4001            "no-alloc optimizer selected {no_alloc_rho}, allocating selected {rho}"
4002        );
4003        assert!(prepared.evaluate(rho).cost < prepared.evaluate(-18.9277503549).cost);
4004    }
4005
4006    #[test]
4007    fn backward_from_fit_matches_backward_with_refit() {
4008        // The Task 3 state round-trip in pyffi calls `_from_fit`; that path
4009        // must be numerically identical to the refitting `_backward` entry
4010        // when fed the same forward result. This guards the optimization
4011        // against drift when either path is touched.
4012        let x = array![[1.0, -0.9], [1.0, -0.4], [1.0, 0.1], [1.0, 0.6], [1.0, 1.1],];
4013        let y = array![[0.2, -0.1], [0.4, 0.1], [0.7, 0.3], [1.0, 0.5], [1.5, 0.8]];
4014        let penalty = array![[0.0, 0.0], [0.0, 1.5]];
4015        let weights = array![1.05, 0.95, 1.01, 0.99, 1.03];
4016
4017        let refit = gaussian_reml_multi_closed_form_backward(
4018            x.view(),
4019            y.view(),
4020            penalty.view(),
4021            Some(weights.view()),
4022            Some(0.85),
4023            0.2,
4024            None,
4025            None,
4026            -0.1,
4027            0.0,
4028        )
4029        .expect("refit backward");
4030
4031        let fit = gaussian_reml_multi_closed_form_with_cache(
4032            x.view(),
4033            y.view(),
4034            penalty.view(),
4035            Some(weights.view()),
4036            Some(0.85),
4037            None,
4038        )
4039        .expect("forward fit");
4040        let from_fit = gaussian_reml_multi_closed_form_backward_from_fit(
4041            x.view(),
4042            y.view(),
4043            penalty.view(),
4044            Some(weights.view()),
4045            &fit,
4046            0.2,
4047            None,
4048            None,
4049            -0.1,
4050            0.0,
4051        )
4052        .expect("from_fit backward");
4053
4054        for (a, b) in refit.grad_x.iter().zip(from_fit.grad_x.iter()) {
4055            assert!((a - b).abs() <= 1.0e-12);
4056        }
4057        for (a, b) in refit.grad_y.iter().zip(from_fit.grad_y.iter()) {
4058            assert!((a - b).abs() <= 1.0e-12);
4059        }
4060        for (a, b) in refit.grad_weights.iter().zip(from_fit.grad_weights.iter()) {
4061            assert!((a - b).abs() <= 1.0e-12);
4062        }
4063    }
4064
4065    /// Regression: when `K = XᵀWX + λS` is effectively rank-deficient (e.g.
4066    /// `λ` has saturated very large), the backward must NOT error — it must
4067    /// degrade gracefully and return zero gradients of the correct shape.
4068    /// This is the production-training scenario where individual atoms can
4069    /// saturate `λ_k` in early batches; raising here would crash an entire
4070    /// step. We construct the degenerate state by running a real forward
4071    /// fit and then corrupting `reml_hess_rho` to 0 (the gate variable the
4072    /// backward checks). We assert: (a) no error, (b) all gradients finite,
4073    /// (c) shapes match the inputs.
4074    #[test]
4075    fn backward_degrades_gracefully_when_k_is_near_singular() {
4076        // Small, full-rank S with a moderately-conditioned X. The exact
4077        // numbers don't matter; what matters is that we then force the
4078        // ill-conditioned gate to fire.
4079        let x = array![
4080            [1.0, -1.0, 0.5],
4081            [1.0, -0.5, 0.2],
4082            [1.0, 0.0, -0.1],
4083            [1.0, 0.5, 0.3],
4084            [1.0, 1.0, 0.8],
4085            [1.0, 1.5, 1.1],
4086            [1.0, 2.0, 1.5],
4087            [1.0, 2.5, 2.0],
4088            [1.0, 3.0, 2.6],
4089            [1.0, 3.5, 3.1],
4090        ];
4091        let y = array![
4092            [0.1],
4093            [0.3],
4094            [0.4],
4095            [0.7],
4096            [1.0],
4097            [1.5],
4098            [2.0],
4099            [2.7],
4100            [3.3],
4101            [4.0]
4102        ];
4103        // Full-rank S to keep the forward well-posed.
4104        let penalty = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
4105
4106        let mut fit =
4107            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
4108                .expect("forward fit must succeed for well-posed input");
4109        // Force the ill-conditioned gate to fire by zeroing the REML
4110        // Hessian w.r.t. rho — this is exactly what happens in production
4111        // when `λ` saturates to 1e10+ and `d²ℓ/dρ² → 0`.
4112        fit.reml_hess_rho = 0.0;
4113
4114        let result = gaussian_reml_multi_closed_form_backward_from_fit(
4115            x.view(),
4116            y.view(),
4117            penalty.view(),
4118            None,
4119            &fit,
4120            // Nonzero upstreams to force the backward to actually try to
4121            // populate gradients (rather than short-circuit on zero seeds).
4122            1.0,
4123            None,
4124            None,
4125            1.0,
4126            1.0,
4127        )
4128        .expect("backward must NOT error on near-singular K");
4129
4130        assert_eq!(result.grad_x.dim(), (x.nrows(), x.ncols()));
4131        assert_eq!(result.grad_y.dim(), (y.nrows(), y.ncols()));
4132        assert_eq!(result.grad_penalty.dim(), (x.ncols(), x.ncols()));
4133        assert_eq!(result.grad_weights.dim(), x.nrows());
4134        for v in result.grad_x.iter() {
4135            assert!(v.is_finite(), "grad_x must be finite, got {v}");
4136        }
4137        for v in result.grad_y.iter() {
4138            assert!(v.is_finite(), "grad_y must be finite, got {v}");
4139        }
4140        for v in result.grad_penalty.iter() {
4141            assert!(v.is_finite(), "grad_penalty must be finite, got {v}");
4142        }
4143        for v in result.grad_weights.iter() {
4144            assert!(v.is_finite(), "grad_weights must be finite, got {v}");
4145        }
4146    }
4147}
4148
4149/// Vector–Jacobian products of the multi-block per-smooth-λ Gaussian REML
4150/// forward fit ([`gaussian_reml_blocks_orthogonal_shared_scale`]), back to the
4151/// design blocks, penalty blocks, response, and weights.
4152pub struct GaussianRemlBlocksBackwardAnalytic {
4153    pub grad_designs: Vec<Array2<f64>>,
4154    pub grad_penalties: Vec<Array2<f64>>,
4155    pub grad_y: Array2<f64>,
4156    pub grad_weights: Array1<f64>,
4157}
4158
4159/// Analytic backward for the multi-block per-smooth-λ Gaussian REML forward.
4160///
4161/// Computes VJPs of (coefficients, fitted, lambdas, log_lambdas, reml_score,
4162/// edf) back to (design_blocks, penalty_blocks, y, weights). The VJP is
4163/// assembled at the converged log-λ vector: fixed-ρ β/fitted/profiled-REML/EDF
4164/// terms are accumulated first, then the smoothing-parameter sensitivity is
4165/// routed through the F×F profiled REML score Hessian from the implicit optimum.
4166/// Pairs with the forward [`gaussian_reml_blocks_orthogonal_shared_scale`].
4167pub fn gaussian_reml_fit_blocks_backward_analytic(
4168    designs: &[Array2<f64>],
4169    penalties_raw: &[Array2<f64>],
4170    y: ArrayView1<'_, f64>,
4171    weights: ArrayView1<'_, f64>,
4172    rhos: &[f64],
4173    grad_coefficients: Option<ArrayView2<'_, f64>>,
4174    grad_fitted: Option<ArrayView2<'_, f64>>,
4175    grad_lambdas: Option<ArrayView1<'_, f64>>,
4176    grad_log_lambdas: Option<ArrayView1<'_, f64>>,
4177    grad_reml_score: f64,
4178    grad_edf: Option<ArrayView1<'_, f64>>,
4179) -> Result<GaussianRemlBlocksBackwardAnalytic, EstimationError> {
4180    let n = y.len();
4181    let f_blocks = designs.len();
4182    let mut offsets = Vec::with_capacity(f_blocks + 1);
4183    offsets.push(0_usize);
4184    for design in designs {
4185        offsets.push(offsets.last().copied().unwrap() + design.ncols());
4186    }
4187    let p_total = *offsets.last().unwrap();
4188    if n == 0 || p_total == 0 {
4189        return Err(EstimationError::InvalidInput(
4190            "gaussian_reml_fit_blocks_backward requires non-empty rows and at least one coefficient column"
4191                .to_string(),
4192        ));
4193    }
4194
4195    if rhos.len() != f_blocks {
4196        return Err(EstimationError::InvalidInput(format!(
4197            "log_lambdas length mismatch: expected {f_blocks}, got {}",
4198            rhos.len()
4199        )));
4200    }
4201    if let Some(gc) = grad_coefficients {
4202        if gc.dim() != (p_total, 1) {
4203            return Err(EstimationError::InvalidInput(format!(
4204                "grad_coefficients shape mismatch: expected {}x1, got {}x{}",
4205                p_total,
4206                gc.nrows(),
4207                gc.ncols()
4208            )));
4209        }
4210    }
4211    if let Some(gf) = grad_fitted {
4212        if gf.dim() != (n, 1) {
4213            return Err(EstimationError::InvalidInput(format!(
4214                "grad_fitted shape mismatch: expected {}x1, got {}x{}",
4215                n,
4216                gf.nrows(),
4217                gf.ncols()
4218            )));
4219        }
4220    }
4221    if !grad_reml_score.is_finite() {
4222        return Err(EstimationError::InvalidInput(format!(
4223            "grad_reml_score must be finite; got {grad_reml_score}"
4224        )));
4225    }
4226    if let Some(vec) = grad_lambdas {
4227        if vec.len() != f_blocks {
4228            return Err(EstimationError::InvalidInput(format!(
4229                "grad_lambdas length mismatch: expected {f_blocks}, got {}",
4230                vec.len()
4231            )));
4232        }
4233    }
4234    if let Some(vec) = grad_log_lambdas {
4235        if vec.len() != f_blocks {
4236            return Err(EstimationError::InvalidInput(format!(
4237                "grad_log_lambdas length mismatch: expected {f_blocks}, got {}",
4238                vec.len()
4239            )));
4240        }
4241    }
4242    if let Some(vec) = grad_edf {
4243        if vec.len() != f_blocks {
4244            return Err(EstimationError::InvalidInput(format!(
4245                "grad_edf length mismatch: expected {f_blocks}, got {}",
4246                vec.len()
4247            )));
4248        }
4249    }
4250    if let Some(gc) = grad_coefficients {
4251        if let Some(((row, col), value)) = gc.indexed_iter().find(|(_, value)| !value.is_finite()) {
4252            return Err(EstimationError::InvalidInput(format!(
4253                "grad_coefficients[{row},{col}] must be finite; got {value}"
4254            )));
4255        }
4256    }
4257    if let Some(gf) = grad_fitted {
4258        if let Some(((row, col), value)) = gf.indexed_iter().find(|(_, value)| !value.is_finite()) {
4259            return Err(EstimationError::InvalidInput(format!(
4260                "grad_fitted[{row},{col}] must be finite; got {value}"
4261            )));
4262        }
4263    }
4264    if let Some(vec) = grad_lambdas {
4265        if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4266            return Err(EstimationError::InvalidInput(format!(
4267                "grad_lambdas[{block}] must be finite; got {value}"
4268            )));
4269        }
4270    }
4271    if let Some(vec) = grad_log_lambdas {
4272        if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4273            return Err(EstimationError::InvalidInput(format!(
4274                "grad_log_lambdas[{block}] must be finite; got {value}"
4275            )));
4276        }
4277    }
4278    if let Some(vec) = grad_edf {
4279        if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4280            return Err(EstimationError::InvalidInput(format!(
4281                "grad_edf[{block}] must be finite; got {value}"
4282            )));
4283        }
4284    }
4285    for (block, design) in designs.iter().enumerate() {
4286        if let Some(((row, col), value)) =
4287            design.indexed_iter().find(|(_, value)| !value.is_finite())
4288        {
4289            return Err(EstimationError::InvalidInput(format!(
4290                "designs[{block}][{row},{col}] must be finite; got {value}"
4291            )));
4292        }
4293    }
4294    for (block, penalty) in penalties_raw.iter().enumerate() {
4295        if let Some(((row, col), value)) =
4296            penalty.indexed_iter().find(|(_, value)| !value.is_finite())
4297        {
4298            return Err(EstimationError::InvalidInput(format!(
4299                "penalties[{block}][{row},{col}] must be finite; got {value}"
4300            )));
4301        }
4302    }
4303    if let Some((row, value)) = y.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4304        return Err(EstimationError::InvalidInput(format!(
4305            "y[{row}] must be finite; got {value}"
4306        )));
4307    }
4308    if let Some((row, value)) = weights
4309        .iter()
4310        .enumerate()
4311        .find(|(_, value)| !value.is_finite() || **value < 0.0)
4312    {
4313        return Err(EstimationError::InvalidInput(format!(
4314            "weights[{row}] must be finite and non-negative; got {value}"
4315        )));
4316    }
4317
4318    let mut z = Array2::<f64>::zeros((n, p_total));
4319    for k in 0..f_blocks {
4320        z.slice_mut(s![.., offsets[k]..offsets[k + 1]])
4321            .assign(&designs[k]);
4322    }
4323
4324    let penalties: Vec<Array2<f64>> = penalties_raw
4325        .iter()
4326        .map(|p| {
4327            let mut out = p.clone();
4328            gam_linalg::matrix::symmetrize_in_place(&mut out);
4329            out
4330        })
4331        .collect();
4332    let mut ranks = Vec::with_capacity(f_blocks);
4333    let mut pinvs = Vec::with_capacity(f_blocks);
4334    for penalty in &penalties {
4335        let (rank, pinv) = gam_linalg::utils::block_penalty_rank_and_pinv(penalty)?;
4336        ranks.push(rank);
4337        pinvs.push(pinv);
4338    }
4339
4340    let lambdas = Array1::from_iter(rhos.iter().map(|rho| rho.exp()));
4341    if let Some((block, lambda)) = lambdas
4342        .iter()
4343        .enumerate()
4344        .find(|(_, lambda)| !lambda.is_finite() || **lambda <= 0.0)
4345    {
4346        return Err(EstimationError::InvalidInput(format!(
4347            "exp(log_lambdas[{block}]) must be finite and positive; got {lambda}"
4348        )));
4349    }
4350    let mut k_matrix = fast_xt_diag_x(&z.view(), &weights);
4351    for block in 0..f_blocks {
4352        let lambda = lambdas[block];
4353        for local_i in 0..penalties[block].nrows() {
4354            let global_i = offsets[block] + local_i;
4355            for local_j in 0..penalties[block].ncols() {
4356                let global_j = offsets[block] + local_j;
4357                k_matrix[[global_i, global_j]] += lambda * penalties[block][[local_i, local_j]];
4358            }
4359        }
4360    }
4361    let r = gam_linalg::utils::invert_spd_with_ridge(&k_matrix, 0.0)?;
4362
4363    let mut xtwy = Array1::<f64>::zeros(p_total);
4364    for row in 0..n {
4365        let wy = weights[row] * y[row];
4366        for col in 0..p_total {
4367            xtwy[col] += z[[row, col]] * wy;
4368        }
4369    }
4370    let beta = r.dot(&xtwy);
4371    let fitted = z.dot(&beta);
4372    if let Some((col, value)) = beta
4373        .iter()
4374        .enumerate()
4375        .find(|(_, value)| !value.is_finite())
4376    {
4377        return Err(EstimationError::InvalidInput(format!(
4378            "solved coefficient {col} is non-finite: {value}"
4379        )));
4380    }
4381    let residual = &y.to_owned() - &fitted;
4382    let weighted_residual = &residual * &weights.to_owned();
4383    let ywy = y
4384        .iter()
4385        .zip(weights.iter())
4386        .map(|(&yi, &wi)| wi * yi * yi)
4387        .sum::<f64>();
4388    let q_raw = ywy - xtwy.dot(&beta);
4389    if !q_raw.is_finite() {
4390        return Err(EstimationError::InvalidInput(format!(
4391            "Gaussian REML residual quadratic form must be finite; got {q_raw}"
4392        )));
4393    }
4394    let q = q_raw.max(1.0e-300);
4395    let nullity = penalties
4396        .iter()
4397        .zip(ranks.iter())
4398        .map(|(penalty, rank)| penalty.nrows().saturating_sub(*rank))
4399        .sum::<usize>();
4400    // Match the block-orthogonal forward's effective sample size: zero
4401    // prior-weight rows are excluded from the residual degrees of freedom.
4402    let nu = effective_observation_count(weights) as f64 - nullity as f64;
4403    if !(nu.is_finite() && nu > 0.0) {
4404        return Err(EstimationError::InvalidInput(format!(
4405            "Gaussian REML residual degrees of freedom must be positive; got {nu}"
4406        )));
4407    }
4408    let tau = nu / q;
4409    let tau_q = -nu / (q * q);
4410    if !(tau.is_finite() && tau_q.is_finite()) {
4411        return Err(EstimationError::InvalidInput(format!(
4412            "Gaussian REML scale derivatives are non-finite: tau={tau}, tau_q={tau_q}"
4413        )));
4414    }
4415
4416    let mut grad_z = Array2::<f64>::zeros((n, p_total));
4417    let mut g_kernel = Array2::<f64>::zeros((p_total, p_total));
4418    let mut h_kernel = Array1::<f64>::zeros(p_total);
4419    let mut q_kernel = 0.0_f64;
4420    let mut j_blocks: Vec<Array2<f64>> = penalties
4421        .iter()
4422        .map(|p| Array2::<f64>::zeros(p.dim()))
4423        .collect();
4424
4425    let mut beta_tilde = Array1::<f64>::zeros(p_total);
4426    if let Some(gc) = grad_coefficients {
4427        beta_tilde += &gc.column(0).to_owned();
4428    }
4429    if let Some(gf) = grad_fitted {
4430        let gf_col = gf.column(0).to_owned();
4431        beta_tilde += &z.t().dot(&gf_col);
4432        for row in 0..n {
4433            for col in 0..p_total {
4434                grad_z[[row, col]] += gf_col[row] * beta[col];
4435            }
4436        }
4437    }
4438
4439    // Generic downstream losses that explicitly seed beta_hat or fitted
4440    // values cannot use the REML envelope shortcut. Route those seeds through
4441    // the fixed-rho KKT adjoint K u = beta_tilde before differentiating
4442    // designs, penalties, y, weights, and rho.
4443    let u = r.dot(&beta_tilde);
4444    h_kernel += &u;
4445    for i in 0..p_total {
4446        for j in 0..p_total {
4447            g_kernel[[i, j]] -= 0.5 * (beta[i] * u[j] + u[i] * beta[j]);
4448        }
4449    }
4450
4451    let mut alpha = Array1::<f64>::zeros(f_blocks);
4452    if let Some(gl) = grad_lambdas {
4453        for block in 0..f_blocks {
4454            alpha[block] += gl[block] * lambdas[block];
4455        }
4456    }
4457    if let Some(grho) = grad_log_lambdas {
4458        alpha += &grho.to_owned();
4459    }
4460
4461    let mut p_betas = Vec::with_capacity(f_blocks);
4462    let mut m_vectors = Vec::with_capacity(f_blocks);
4463    let mut rp_matrices = Vec::with_capacity(f_blocks);
4464    let mut rpr_matrices = Vec::with_capacity(f_blocks);
4465    let mut b_values = Array1::<f64>::zeros(f_blocks);
4466    let mut t_values = Array1::<f64>::zeros(f_blocks);
4467
4468    for block in 0..f_blocks {
4469        let start = offsets[block];
4470        let end = offsets[block + 1];
4471        let beta_k = beta.slice(s![start..end]).to_owned();
4472        let s_beta = penalties[block].dot(&beta_k);
4473        let lambda = lambdas[block];
4474        let lambda_s_beta = s_beta.mapv(|value| lambda * value);
4475        let mut p_beta = Array1::<f64>::zeros(p_total);
4476        for local_i in 0..(end - start) {
4477            p_beta[start + local_i] = lambda_s_beta[local_i];
4478        }
4479        let weighted_penalty = penalties[block].mapv(|value| lambda * value);
4480        let rp_block = r.slice(s![.., start..end]).dot(&weighted_penalty);
4481        let mut rp = Array2::<f64>::zeros((p_total, p_total));
4482        rp.slice_mut(s![.., start..end]).assign(&rp_block);
4483        let rpr = rp_block.dot(&r.slice(s![start..end, ..]));
4484        let m = r.slice(s![.., start..end]).dot(&lambda_s_beta);
4485        b_values[block] = beta.dot(&p_beta);
4486        t_values[block] = (0..(end - start))
4487            .map(|local_i| rp_block[[start + local_i, local_i]])
4488            .sum::<f64>();
4489        alpha[block] -= u.dot(&p_beta);
4490        p_betas.push(p_beta);
4491        m_vectors.push(m);
4492        rp_matrices.push(rp);
4493        rpr_matrices.push(rpr);
4494    }
4495
4496    if grad_reml_score != 0.0 {
4497        q_kernel += 0.5 * grad_reml_score * tau;
4498        g_kernel += &(r.clone() * (0.5 * grad_reml_score));
4499        for block in 0..f_blocks {
4500            j_blocks[block] -= &(pinvs[block].clone() * (0.5 * grad_reml_score / lambdas[block]));
4501        }
4502    }
4503
4504    let mut trace_pairs = Array2::<f64>::zeros((f_blocks, f_blocks));
4505    for i in 0..f_blocks {
4506        for j in 0..f_blocks {
4507            trace_pairs[[i, j]] = gam_linalg::utils::trace_of_product(
4508                rp_matrices[i].view(),
4509                rp_matrices[j].view(),
4510            );
4511        }
4512    }
4513
4514    if let Some(ge) = grad_edf {
4515        for edf_block in 0..f_blocks {
4516            let scale = ge[edf_block];
4517            if scale == 0.0 {
4518                continue;
4519            }
4520            let start = offsets[edf_block];
4521            let end = offsets[edf_block + 1];
4522            g_kernel += &(rpr_matrices[edf_block].clone() * scale);
4523            j_blocks[edf_block] -= &(r.slice(s![start..end, start..end]).to_owned() * scale);
4524            for rho_block in 0..f_blocks {
4525                alpha[rho_block] += scale * trace_pairs[[edf_block, rho_block]];
4526                if rho_block == edf_block {
4527                    alpha[rho_block] -= scale * t_values[edf_block];
4528                }
4529            }
4530        }
4531    }
4532
4533    if let Some((block, value)) = alpha
4534        .iter()
4535        .enumerate()
4536        .find(|(_, value)| !value.is_finite())
4537    {
4538        return Err(EstimationError::InvalidInput(format!(
4539            "rho adjoint seed for block {block} is non-finite: {value}"
4540        )));
4541    }
4542
4543    if alpha.iter().any(|value| *value != 0.0) {
4544        let mut outer_h = Array2::<f64>::zeros((f_blocks, f_blocks));
4545        for k in 0..f_blocks {
4546            for j in 0..f_blocks {
4547                let beta_pk_r_pj_beta = p_betas[k].dot(&m_vectors[j]);
4548                outer_h[[k, j]] = 0.5 * trace_pairs[[k, j]] + tau * beta_pk_r_pj_beta
4549                    - if k == j {
4550                        0.5 * (t_values[k] + tau * b_values[k])
4551                    } else {
4552                        0.0
4553                    }
4554                    - 0.5 * tau_q * b_values[k] * b_values[j];
4555            }
4556        }
4557        // `outer_h` is the Jacobian of the negative profiled REML estimating
4558        // equation. Preserve signed curvature directions while flooring
4559        // near-zero modes; flipping negative eigenvalues would change the VJP.
4560        gam_linalg::matrix::symmetrize_in_place(&mut outer_h);
4561        if let Some(((row, col), value)) =
4562            outer_h.indexed_iter().find(|(_, value)| !value.is_finite())
4563        {
4564            return Err(EstimationError::InvalidInput(format!(
4565                "outer rho curvature entry ({row},{col}) is non-finite: {value}"
4566            )));
4567        }
4568        let rho_adj =
4569            gam_linalg::utils::solve_symmetric_vector_with_floor(&outer_h, &alpha, 1.0e-10)?;
4570        if let Some((block, value)) = rho_adj
4571            .iter()
4572            .enumerate()
4573            .find(|(_, value)| !value.is_finite())
4574        {
4575            return Err(EstimationError::InvalidInput(format!(
4576                "outer rho adjoint for block {block} is non-finite: {value}"
4577            )));
4578        }
4579        let weighted_b_sum = rho_adj
4580            .iter()
4581            .zip(b_values.iter())
4582            .map(|(&zk, &bk)| zk * bk)
4583            .sum::<f64>();
4584        q_kernel += 0.5 * tau_q * weighted_b_sum;
4585        for block in 0..f_blocks {
4586            let zk = rho_adj[block];
4587            if zk == 0.0 {
4588                continue;
4589            }
4590            g_kernel -= &(rpr_matrices[block].clone() * (0.5 * zk));
4591            let m = &m_vectors[block];
4592            for i in 0..p_total {
4593                h_kernel[i] += tau * zk * m[i];
4594                for j in 0..p_total {
4595                    g_kernel[[i, j]] -= 0.5 * tau * zk * (beta[i] * m[j] + m[i] * beta[j]);
4596                }
4597            }
4598            let start = offsets[block];
4599            let end = offsets[block + 1];
4600            j_blocks[block] += &(r.slice(s![start..end, start..end]).to_owned() * (0.5 * zk));
4601            for i in 0..(end - start) {
4602                for j in 0..(end - start) {
4603                    j_blocks[block][[i, j]] += 0.5 * tau * zk * beta[start + i] * beta[start + j];
4604                }
4605            }
4606        }
4607    }
4608
4609    for row in 0..n {
4610        for col in 0..p_total {
4611            grad_z[[row, col]] += -2.0 * q_kernel * weighted_residual[row] * beta[col];
4612        }
4613    }
4614    let zg = z.dot(&g_kernel);
4615    for row in 0..n {
4616        for col in 0..p_total {
4617            grad_z[[row, col]] += 2.0 * weights[row] * zg[[row, col]];
4618        }
4619    }
4620    let wy = y.to_owned() * &weights.to_owned();
4621    for row in 0..n {
4622        for col in 0..p_total {
4623            grad_z[[row, col]] += wy[row] * h_kernel[col];
4624        }
4625    }
4626
4627    let mut grad_y = Array2::<f64>::zeros((n, 1));
4628    let zh = z.dot(&h_kernel);
4629    for row in 0..n {
4630        grad_y[[row, 0]] = 2.0 * q_kernel * weighted_residual[row] + weights[row] * zh[row];
4631    }
4632
4633    let mut grad_weights = Array1::<f64>::zeros(n);
4634    for row in 0..n {
4635        let diag_zgz = (0..p_total)
4636            .map(|col| z[[row, col]] * zg[[row, col]])
4637            .sum::<f64>();
4638        grad_weights[row] = q_kernel * residual[row] * residual[row] + diag_zgz + y[row] * zh[row];
4639    }
4640
4641    // Weight-scale invariance of the REML score (issue #877). The Gaussian REML
4642    // criterion the score adjoint targets — the profiled cost assembled in
4643    // `reml_outer_engine::objective` — carries the data-density normalization
4644    // `−½ Σ_{wᵢ>0} log(wᵢ)` (the `|W|^{½}` factor of the weighted normal
4645    // likelihood) together with the geometric-mean weight anchor on ρ. Their
4646    // net effect is that the *score* depends on the observation weights only up
4647    // to a global scale: replacing `w → c·w` leaves it unchanged. By Euler's
4648    // homogeneity identity that invariance is exactly
4649    //   Σ_i wᵢ · ∂(score)/∂wᵢ = 0,
4650    // i.e. the score's weight-gradient is orthogonal to the scaling direction
4651    // `1/wᵢ`. The kernel propagation above produces the *raw* (un-normalized)
4652    // weight partials `aᵢ`, which do not satisfy this constraint; the missing
4653    // piece is the projection that removes the scaling component. Subtract the
4654    // multiple of `1/wᵢ` that restores `Σ_i wᵢ·gradᵢ = 0`:
4655    //   gradᵢ ← aᵢ − μ/wᵢ,  μ = (Σ_{j:wⱼ>0} wⱼ aⱼ) / n₊.
4656    // Only the score seed is scale-invariant — β̂, fitted = Zβ̂ and the EDF all
4657    // scale with the weights (the λS term in K = ZᵀWZ + λS does not), so their
4658    // adjoints must NOT be projected. We therefore form the score-only partials
4659    // `aᵢˢ = ½·grs·(τ·rᵢ² + zᵢᵀ R zᵢ)` from the score's own kernel
4660    // contributions (q_kernel ← ½·grs·τ, g_kernel ← ½·grs·R) and project just
4661    // those, leaving the coefficient/fitted/EDF/λ weight-gradients intact.
4662    if grad_reml_score != 0.0 {
4663        let q_kernel_score = 0.5 * grad_reml_score * tau;
4664        let zr = z.dot(&r);
4665        let n_pos = (0..n).filter(|&i| weights[i] > 0.0).count();
4666        if n_pos > 0 {
4667            let mut weighted_score_partial_sum = 0.0_f64;
4668            for row in 0..n {
4669                if weights[row] <= 0.0 {
4670                    continue;
4671                }
4672                let z_r_z = (0..p_total).map(|col| z[[row, col]] * zr[[row, col]]).sum::<f64>();
4673                let a_score = q_kernel_score * residual[row] * residual[row]
4674                    + 0.5 * grad_reml_score * z_r_z;
4675                weighted_score_partial_sum += weights[row] * a_score;
4676            }
4677            let projection = weighted_score_partial_sum / n_pos as f64;
4678            for row in 0..n {
4679                if weights[row] > 0.0 {
4680                    grad_weights[row] -= projection / weights[row];
4681                }
4682            }
4683        }
4684    }
4685
4686    let mut grad_penalties = Vec::with_capacity(f_blocks);
4687    for block in 0..f_blocks {
4688        let start = offsets[block];
4689        let end = offsets[block + 1];
4690        let mut local = g_kernel.slice(s![start..end, start..end]).to_owned();
4691        for i in 0..(end - start) {
4692            for j in 0..(end - start) {
4693                local[[i, j]] += q_kernel * beta[start + i] * beta[start + j];
4694            }
4695        }
4696        local += &j_blocks[block];
4697        local *= lambdas[block];
4698        gam_linalg::matrix::symmetrize_in_place(&mut local);
4699        grad_penalties.push(local);
4700    }
4701
4702    let mut grad_designs = Vec::with_capacity(f_blocks);
4703    for block in 0..f_blocks {
4704        grad_designs.push(
4705            grad_z
4706                .slice(s![.., offsets[block]..offsets[block + 1]])
4707                .to_owned(),
4708        );
4709    }
4710
4711    Ok(GaussianRemlBlocksBackwardAnalytic {
4712        grad_designs,
4713        grad_penalties,
4714        grad_y,
4715        grad_weights,
4716    })
4717}
4718
4719/// Fixed-λ multi-output Gaussian fit under a per-row dense Fisher–Rao precision
4720/// metric: coefficients, fitted values, per-output residual scale, and the
4721/// penalized Fisher-weighted objective.
4722pub struct DenseFisherGaussianFit {
4723    pub coefficients: Array2<f64>,
4724    pub fitted: Array2<f64>,
4725    pub sigma2: Array1<f64>,
4726    pub objective: f64,
4727}
4728
4729/// Add a block-diagonal `λ·S` penalty (one `S` block per output) into a stacked
4730/// `(k·n_outputs)` Hessian in place, symmetrizing `S`.
4731pub fn add_block_diagonal_penalty(
4732    hessian: &mut Array2<f64>,
4733    penalty: ArrayView2<'_, f64>,
4734    lambda: f64,
4735    n_outputs: usize,
4736) -> Result<(), EstimationError> {
4737    let k = penalty.ncols();
4738    if penalty.nrows() != k {
4739        return Err(EstimationError::InvalidInput(format!(
4740            "penalty must be square for dense Fisher fit; got {}x{}",
4741            penalty.nrows(),
4742            penalty.ncols()
4743        )));
4744    }
4745    if hessian.dim() != (k * n_outputs, k * n_outputs) {
4746        return Err(EstimationError::InvalidInput(
4747            "dense Fisher Hessian shape mismatch while adding penalty".to_string(),
4748        ));
4749    }
4750    for output in 0..n_outputs {
4751        let offset = output * k;
4752        for row in 0..k {
4753            for col in 0..k {
4754                let s_sym = 0.5 * (penalty[[row, col]] + penalty[[col, row]]);
4755                hessian[[offset + row, offset + col]] += lambda * s_sym;
4756            }
4757        }
4758    }
4759    Ok(())
4760}
4761
4762/// Closed-form fixed-λ multi-output Gaussian fit with a per-row dense Fisher–Rao
4763/// precision metric. Assembles the block `XᵀWX` (+ block-diagonal `λS`) and
4764/// `XᵀWY` via the dense Fisher block kernels, solves, then forms fitted values,
4765/// per-output residual scale `sigma2`, and the penalized Fisher-weighted
4766/// objective seeded by `latent_prior_score`. `row_weights` are the (already
4767/// resolved) per-observation likelihood weights.
4768pub fn dense_fisher_gaussian_fit(
4769    design: ArrayView2<'_, f64>,
4770    y: ArrayView2<'_, f64>,
4771    penalty: ArrayView2<'_, f64>,
4772    row_weights: ArrayView1<'_, f64>,
4773    fisher_w: ArrayView3<'_, f64>,
4774    lambda: f64,
4775    latent_prior_score: f64,
4776) -> Result<DenseFisherGaussianFit, EstimationError> {
4777    let n_obs = design.nrows();
4778    let k = design.ncols();
4779    let n_outputs = y.ncols();
4780    let mut hessian = crate::pirls::dense_block_xtwx(design, fisher_w, Some(row_weights))?;
4781    add_block_diagonal_penalty(&mut hessian, penalty, lambda, n_outputs)?;
4782    let rhs = crate::pirls::dense_block_xtwy(design, fisher_w, y, Some(row_weights))?;
4783    let beta_vec =
4784        gam_linalg::utils::solve_dense_block_system(&hessian, &rhs, "dense Fisher Gaussian")
4785            .map_err(EstimationError::InvalidInput)?;
4786    let mut coefficients = Array2::<f64>::zeros((k, n_outputs));
4787    for output in 0..n_outputs {
4788        for col in 0..k {
4789            coefficients[[col, output]] = beta_vec[output * k + col];
4790        }
4791    }
4792    let fitted = design.dot(&coefficients);
4793    let mut sigma2 = Array1::<f64>::zeros(n_outputs);
4794    let mut objective = latent_prior_score;
4795    for row in 0..n_obs {
4796        for a in 0..n_outputs {
4797            let ra = y[[row, a]] - fitted[[row, a]];
4798            sigma2[a] += row_weights[row] * ra * ra;
4799            for b in 0..n_outputs {
4800                objective += 0.5
4801                    * row_weights[row]
4802                    * ra
4803                    * fisher_w[[row, a, b]]
4804                    * (y[[row, b]] - fitted[[row, b]]);
4805            }
4806        }
4807    }
4808    for output in 0..n_outputs {
4809        sigma2[output] /= (n_obs.saturating_sub(k).max(1)) as f64;
4810        let beta_col = coefficients.column(output);
4811        let s_beta = penalty.dot(&beta_col);
4812        objective += 0.5 * lambda * beta_col.dot(&s_beta);
4813    }
4814    Ok(DenseFisherGaussianFit {
4815        coefficients,
4816        fitted,
4817        sigma2,
4818        objective,
4819    })
4820}