Skip to main content

gam_solve/
gaussian_reml.rs

1use crate::estimate::EstimationError;
2use gam_linalg::faer_ndarray::{
3    FaerCholesky, FaerEigh, fast_ab, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
4};
5use faer::Side;
6use ndarray::{
7    Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, ArrayViewMut1, ArrayViewMut2, Axis,
8    s,
9};
10use rayon::prelude::*;
11use std::sync::Once;
12
13/// One-time warning latch for backward-pass graceful degradation on a
14/// near-singular penalized Hessian `K = XᵀWX + λS`. When `λ_k` saturates
15/// (e.g. 1e10+), `K` becomes effectively rank-deficient and the analytic VJP
16/// cannot be evaluated. Rather than raising, the backward returns zero
17/// gradients of the correct shape: this is the statistically correct
18/// "shrink-out" gradient — when `λ` has saturated, the atom is unused, so
19/// every input's contribution to the loss is zero in the limit.
20static ILL_CONDITIONED_BACKWARD_WARNED: Once = Once::new();
21
22fn warn_ill_conditioned_backward_once(p: usize, d: usize, condition_number: f64) {
23    ILL_CONDITIONED_BACKWARD_WARNED.call_once(|| {
24        log::warn!(
25            "gaussian_reml_fit_backward: K = XᵀWX + λS is near-singular \
26             (p={p}, d={d}, cond≈{condition_number:.2e}); returning zero gradients \
27             for this fit (λ has saturated, atom is effectively unused). \
28             Further occurrences are silent."
29        );
30    });
31}
32
33fn zero_backward_result(n: usize, p: usize, d: usize) -> GaussianRemlBackwardResult {
34    GaussianRemlBackwardResult {
35        grad_x: Array2::<f64>::zeros((n, p)),
36        grad_y: Array2::<f64>::zeros((n, d)),
37        grad_penalty: Array2::<f64>::zeros((p, p)),
38        grad_weights: Array1::<f64>::zeros(n),
39    }
40}
41
// Bounds on ρ = log λ.
// NOTE(review): no use is visible in this part of the file; presumably the
// ρ optimizer restricts its search to [RHO_LOWER, RHO_UPPER] — confirm.
const RHO_LOWER: f64 = -30.0;
const RHO_UPPER: f64 = 30.0;
// Relative eigenvalue cutoff. `block_penalty_rank_logdet` counts a penalty
// eigenvalue as positive only when it exceeds `EIGEN_REL_TOL · max|eig|`
// (with an absolute floor of 1e-14).
const EIGEN_REL_TOL: f64 = 1.0e-10;
// NOTE(review): presumably the gradient convergence tolerance of the ρ
// optimizer; its use is not visible here — confirm.
const GRAD_TOL: f64 = 1.0e-12;
// Floor applied to the per-output deviance in `gaussian_reml_dispersion_term`
// so that its logarithm and the divisions by it stay finite.
const MIN_DEVIANCE: f64 = 1.0e-300;
47
48/// Canonicalize a penalty matrix to its symmetric average.
49///
50/// Closed-form Gaussian REML treats `S` as symmetric throughout — the
51/// eigendecomposition, the pseudo-determinant `log|S|₊`, the rank detector,
52/// and every per-helper VJP all assume `S = Sᵀ`. To make that contract
53/// explicit (rather than implicit in `eigh(Side::Lower)` reading the lower
54/// triangle and silently ignoring the upper), every entry point that takes a
55/// penalty matrix replaces it with `0.5 (S + Sᵀ)` before any downstream use.
56/// For symmetric input this is a numerical no-op; for asymmetric input it
57/// defines the function as operating on the symmetric average.
58fn canonicalize_penalty(penalty: ArrayView2<'_, f64>) -> Array2<f64> {
59    let p = penalty.nrows();
60    let mut out = penalty.to_owned();
61    for i in 0..p {
62        for j in (i + 1)..p {
63            let avg = 0.5 * (out[[i, j]] + out[[j, i]]);
64            out[[i, j]] = avg;
65            out[[j, i]] = avg;
66        }
67    }
68    out
69}
70
/// Cached spectral quantities that the closed-form Gaussian REML routines
/// reuse across fits (it is carried by warm starts and accepted by the
/// `_with_cache` entry points).
#[derive(Clone, Debug)]
pub struct GaussianRemlEigenCache {
    /// Eigenvalues `δ` that enter the objective through `t = λ·δ`
    /// (see `gaussian_reml_logdet_term`).
    pub penalty_eigenvalues: Array1<f64>,
    // NOTE(review): producer not visible here; presumably the eigenvectors
    // matching `penalty_eigenvalues` — confirm.
    pub eigenvectors: Array2<f64>,
    // NOTE(review): presumably the map from the eigen-coordinates back to
    // coefficient space — confirm against the cache builder.
    pub coefficient_basis: Array2<f64>,
    // NOTE(review): presumably a hash of `XᵀWX` used to detect a stale cache;
    // its check is not visible in this part of the file — confirm.
    pub xtwx_fingerprint: u64,
    /// Compared against `matrix_fingerprint` of the canonicalized penalty; a
    /// mismatch rejects the cache.
    pub penalty_fingerprint: u64,
    /// Starting value of `log|H|` before the `ln(1 + λδ)` terms are added.
    pub logdet_xtwx: f64,
    /// Added to `penalty_rank · ρ` to form the `log|S|₊` part of the objective.
    pub logdet_penalty_positive: f64,
    /// Multiplies `ρ` in the penalty log-determinant and offsets the ρ-gradient.
    pub penalty_rank: usize,
    /// Fits require strictly more observations than this value.
    pub nullity: usize,
}
83
/// Optional starting information for a warm-started fit. Both parts are
/// independent; the `Default` value supplies neither.
#[derive(Clone, Debug, Default)]
pub struct GaussianRemlWarmStart {
    /// Initial smoothing parameter `λ` forwarded to the fit, if any.
    pub lambda: Option<f64>,
    /// Previously built eigen cache forwarded to the fit, if any.
    pub eigen_cache: Option<GaussianRemlEigenCache>,
}
89
90impl GaussianRemlWarmStart {
91    pub fn from_multi_result(result: &GaussianRemlMultiResult) -> Self {
92        Self {
93            lambda: Some(result.lambda),
94            eigen_cache: Some(result.cache.clone()),
95        }
96    }
97}
98
/// Result of a single-response closed-form Gaussian REML fit.
///
/// `scalar_result_from_multi` builds it from a `GaussianRemlMultiResult` by
/// taking column 0 of the coefficient/fitted matrices and element 0 of
/// `sigma2`; every other field is copied across unchanged.
#[derive(Clone, Debug)]
pub struct GaussianRemlResult {
    /// Smoothing parameter `λ`.
    pub lambda: f64,
    // NOTE(review): presumably `ρ = ln λ`, as in the no-alloc path — confirm.
    pub rho: f64,
    pub coefficients: Array1<f64>,
    pub fitted: Array1<f64>,
    pub reml_score: f64,
    // Derivatives of the REML score with respect to λ and to ρ, as named.
    pub reml_grad_lambda: f64,
    pub reml_hess_lambda: f64,
    pub reml_grad_rho: f64,
    pub reml_hess_rho: f64,
    pub edf: f64,
    pub sigma2: f64,
    /// Eigen cache used by the fit, kept so it can seed later fits.
    pub cache: GaussianRemlEigenCache,
}
114
/// Result of a multi-response closed-form Gaussian REML fit with one shared
/// smoothing parameter.
///
/// NOTE(review): the code that fills this struct is outside the visible part
/// of the file; matrix fields are presumably one column per response and
/// `sigma2` one entry per response — confirm.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiResult {
    /// Smoothing parameter `λ`; reused as the initial value by
    /// `GaussianRemlWarmStart::from_multi_result`.
    pub lambda: f64,
    pub rho: f64,
    pub coefficients: Array2<f64>,
    pub fitted: Array2<f64>,
    pub reml_score: f64,
    pub reml_grad_lambda: f64,
    pub reml_hess_lambda: f64,
    pub reml_grad_rho: f64,
    pub reml_hess_rho: f64,
    pub edf: f64,
    pub sigma2: Array1<f64>,
    /// Eigen cache used by the fit; cloned into warm starts.
    pub cache: GaussianRemlEigenCache,
}
130
/// REML score bundled with gradients with respect to coefficients, penalty
/// and `log λ`.
///
/// NOTE(review): neither the producer nor any consumer of this struct is
/// visible in this part of the file; the summary above is read off the field
/// names only — confirm against the function that returns it.
#[derive(Clone, Debug)]
pub struct GaussianRemlFreeBScore {
    pub reml_score: f64,
    pub grad_coefficients: Array2<f64>,
    pub grad_penalty: Array2<f64>,
    pub grad_log_lambda: f64,
    pub fitted: Array2<f64>,
    pub sigma2: Array1<f64>,
    pub edf: f64,
}
141
/// Gradients returned by the backward pass, one per differentiable input.
///
/// `zero_backward_result(n, p, d)` shows the expected shapes: `grad_x` is
/// `n × p`, `grad_y` is `n × d`, `grad_penalty` is `p × p` and `grad_weights`
/// has length `n`.
#[derive(Clone, Debug)]
pub struct GaussianRemlBackwardResult {
    pub grad_x: Array2<f64>,
    pub grad_y: Array2<f64>,
    pub grad_penalty: Array2<f64>,
    pub grad_weights: Array1<f64>,
}
149
/// Borrowed inputs for one backward pass through a multi-response fit.
///
/// NOTE(review): the backward routine that consumes this struct is not visible
/// here. The `grad_*` fields are presumably upstream gradients with respect to
/// the matching forward outputs (`None` meaning "no contribution") — confirm.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBackwardProblem<'a> {
    pub x: ArrayView2<'a, f64>,
    pub y: ArrayView2<'a, f64>,
    pub weights: Option<ArrayView1<'a, f64>>,
    /// The forward fit being differentiated.
    pub fit: &'a GaussianRemlMultiResult,
    pub grad_lambda: f64,
    pub grad_coefficients: Option<ArrayView2<'a, f64>>,
    pub grad_fitted: Option<ArrayView2<'a, f64>>,
    pub grad_reml_score: f64,
    pub grad_edf: f64,
}
162
/// Caller-owned scratch buffers for the no-alloc fit path.
///
/// With `p` coefficients and `d` outputs, every matrix buffer is `p × d` and
/// `ywy` has length `d` (these are the shapes `new` allocates and `validate`
/// enforces).
#[derive(Clone, Debug)]
pub struct GaussianRemlNoAllocWorkspace {
    // NOTE(review): presumably holds XᵀWY; it is filled by
    // `fill_weighted_rhs_no_alloc`, which is not visible here — confirm.
    pub xtwy: Array2<f64>,
    /// Per-output value that the dispersion term reads as `ywy[output]`.
    pub ywy: Array1<f64>,
    pub projected_rhs: Array2<f64>,
    /// Read by the dispersion term as the per-eigen-direction weight `c2`.
    pub projected_rhs_squared: Array2<f64>,
    pub scaled_projected_rhs: Array2<f64>,
}
171
172impl GaussianRemlNoAllocWorkspace {
173    pub fn new(n_coefficients: usize, n_outputs: usize) -> Self {
174        Self {
175            xtwy: Array2::zeros((n_coefficients, n_outputs)),
176            ywy: Array1::zeros(n_outputs),
177            projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
178            projected_rhs_squared: Array2::zeros((n_coefficients, n_outputs)),
179            scaled_projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
180        }
181    }
182
183    fn validate(&self, p: usize, d: usize) -> Result<(), EstimationError> {
184        if self.xtwy.dim() != (p, d)
185            || self.ywy.len() != d
186            || self.projected_rhs.dim() != (p, d)
187            || self.projected_rhs_squared.dim() != (p, d)
188            || self.scaled_projected_rhs.dim() != (p, d)
189        {
190            crate::bail_invalid_estim!(
191                "Gaussian REML no-alloc workspace shape mismatch: expected p={p}, d={d}"
192            );
193        }
194        Ok::<(), _>(())
195    }
196}
197
/// Scalar summary returned by the no-alloc fit; the array outputs
/// (coefficients, fitted values, `sigma2`) are written into caller buffers
/// instead of being returned here.
#[derive(Clone, Copy, Debug)]
pub struct GaussianRemlNoAllocFit {
    /// `exp(rho)`.
    pub lambda: f64,
    /// Optimized `ρ` returned by the ρ optimizer.
    pub rho: f64,
    /// Objective value (`cost`) evaluated at `rho`.
    pub reml_score: f64,
    /// λ-derivatives obtained from the ρ-derivatives via
    /// `rho_derivatives_to_lambda`.
    pub reml_grad_lambda: f64,
    pub reml_hess_lambda: f64,
    /// ρ-gradient and ρ-Hessian of the objective at `rho`.
    pub reml_grad_rho: f64,
    pub reml_hess_rho: f64,
    /// `edf` reported by the objective evaluation at `rho`.
    pub edf: f64,
}
209
/// One problem in a batch fit; all problems of a batch share the penalty that
/// is passed to `gaussian_reml_multi_closed_form_batch` separately.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBatchProblem<'a> {
    pub x: ArrayView2<'a, f64>,
    pub y: ArrayView2<'a, f64>,
    /// Optional observation weights; the batch routine substitutes all-ones
    /// when forming `XᵀWX` for a problem that has none.
    pub weights: Option<ArrayView1<'a, f64>>,
    /// Optional starting `ρ`; the batch routine exponentiates it into the
    /// initial `λ`.
    pub init_rho: Option<f64>,
}
217
/// Result of the block-orthogonal shared-scale fit.
///
/// NOTE(review): the code that fills this struct lies beyond the visible part
/// of the file. Presumably `coefficients`, `lambdas`, `log_lambdas` and `edf`
/// hold one entry per block — confirm.
#[derive(Clone, Debug)]
pub struct GaussianRemlBlockOrthogonalResult {
    pub coefficients: Vec<Array2<f64>>,
    pub fitted: Array2<f64>,
    pub lambdas: Array1<f64>,
    pub log_lambdas: Array1<f64>,
    pub reml_score: f64,
    pub edf: Array1<f64>,
}
227
/// Private bundle of per-fit quantities.
///
/// NOTE(review): no use of this struct is visible in this part of the file;
/// presumably it is the allocating counterpart of the no-alloc workspace
/// (eigen cache plus projected right-hand sides and problem sizes) — confirm.
#[derive(Clone)]
struct GaussianRemlPrepared {
    cache: GaussianRemlEigenCache,
    ywy: Array1<f64>,
    projected_rhs_squared: Array2<f64>,
    projected_rhs: Array2<f64>,
    n_observations: usize,
    n_outputs: usize,
}
237
/// Running totals of the REML objective and its ρ-derivatives.
///
/// `TermDerivs` contributions are folded in through `+=`, which updates
/// `cost`, `grad` and `hess` together; `edf` is not touched by that operator.
#[derive(Clone, Copy)]
struct ObjectiveEval {
    /// Accumulated objective value.
    cost: f64,
    /// Accumulated ρ-gradient.
    grad: f64,
    /// Accumulated ρ-Hessian.
    hess: f64,
    /// Effective degrees of freedom reported alongside the objective.
    edf: f64,
}
245
/// A single Gaussian closed-form REML objective term, carrying its analytic
/// VALUE together with its analytic ρ-GRADIENT and ρ-HESSIAN.
///
/// Single source of truth: each term's value and its (already hand-derived,
/// closed-form) ρ-derivatives are returned from ONE function body, so a future
/// edit to the value formula cannot silently leave the derivatives stale.
/// Mirrors the `PenaltyLogdetDerivs`-returning-tuple pattern used by the
/// unified outer evaluator — the structural cure for the objective↔gradient
/// desync class (#752/#748/#808). The three contributions are accumulated
/// through [`ObjectiveEval`] at one site, so they cannot drift apart.
#[derive(Clone, Copy)]
struct TermDerivs {
    /// Contribution to the objective value.
    value: f64,
    /// First derivative of `value` with respect to ρ.
    grad: f64,
    /// Second derivative of `value` with respect to ρ.
    hess: f64,
}
262
263impl std::ops::AddAssign<TermDerivs> for ObjectiveEval {
264    /// Fold a term's `(value, grad, hess)` triple into the running totals in
265    /// lock-step, so value and derivative can never be added at separate sites.
266    fn add_assign(&mut self, rhs: TermDerivs) {
267        self.cost += rhs.value;
268        self.grad += rhs.grad;
269        self.hess += rhs.hess;
270    }
271}
272
273/// `½d·(log|H| − log|S|_+)` value with its analytic ρ-gradient/Hessian.
274///
275/// The penalty-eigenvalue sum produces all three quantities from the SAME
276/// `t = λδ` intermediates in one pass, so the value (`log|1+t|`) and its
277/// derivatives (`t/(1+t)`, `t/(1+t)²`) are single-sourced.
278fn gaussian_reml_logdet_term(
279    cache: &GaussianRemlEigenCache,
280    rho: f64,
281    n_outputs: f64,
282) -> (TermDerivs, f64) {
283    let lambda = rho.exp();
284    let mut logdet_h = cache.logdet_xtwx;
285    let mut trace_h = 0.0;
286    let mut trace_h_deriv = 0.0;
287    let mut edf = 0.0;
288    for &delta in &cache.penalty_eigenvalues {
289        let t = lambda * delta;
290        logdet_h += (1.0 + t).ln();
291        if delta > 0.0 {
292            trace_h += t / (1.0 + t);
293            trace_h_deriv += t / ((1.0 + t) * (1.0 + t));
294        }
295        edf += 1.0 / (1.0 + t);
296    }
297    let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * rho;
298    let term = TermDerivs {
299        value: 0.5 * n_outputs * (logdet_h - logdet_s),
300        grad: 0.5 * n_outputs * (trace_h - cache.penalty_rank as f64),
301        hess: 0.5 * n_outputs * trace_h_deriv,
302    };
303    (term, edf)
304}
305
306/// Per-output dispersion-prior term `½ν·(1 + log(2π·dp/ν))` with its analytic
307/// ρ-gradient/Hessian.
308///
309/// `dp`, `dp_grad`, `dp_hess` are computed from the SAME eigenvalue sum, then
310/// the value `log(dp)` and its derivatives `dp_grad/dp`,
311/// `dp_hess/dp − (dp_grad/dp)²` are returned together so they cannot desync.
312fn gaussian_reml_dispersion_term(
313    cache: &GaussianRemlEigenCache,
314    ywy: ArrayView1<'_, f64>,
315    projected_rhs_squared: ArrayView2<'_, f64>,
316    output: usize,
317    nu: f64,
318    lambda: f64,
319) -> TermDerivs {
320    let mut fitted_quadratic = 0.0;
321    let mut dp_grad = 0.0;
322    let mut dp_hess = 0.0;
323    for eig in 0..cache.penalty_eigenvalues.len() {
324        let c2 = projected_rhs_squared[[eig, output]];
325        let t = lambda * cache.penalty_eigenvalues[eig];
326        let denom = 1.0 + t;
327        fitted_quadratic += c2 / denom;
328        dp_grad += c2 * t / (denom * denom);
329        dp_hess += c2 * t * (1.0 - t) / (denom * denom * denom);
330    }
331    let dp = (ywy[output] - fitted_quadratic).max(MIN_DEVIANCE);
332    TermDerivs {
333        value: 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln()),
334        grad: 0.5 * nu * dp_grad / dp,
335        hess: 0.5 * nu * (dp_hess / dp - (dp_grad * dp_grad) / (dp * dp)),
336    }
337}
338
339pub fn gaussian_reml_closed_form(
340    x: ArrayView2<'_, f64>,
341    y: ArrayView1<'_, f64>,
342    penalty: ArrayView2<'_, f64>,
343    weights: Option<ArrayView1<'_, f64>>,
344    init_rho: Option<f64>,
345) -> Result<GaussianRemlResult, EstimationError> {
346    gaussian_reml_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
347}
348
349pub fn gaussian_reml_closed_form_with_nullspace_dim(
350    x: ArrayView2<'_, f64>,
351    y: ArrayView1<'_, f64>,
352    penalty: ArrayView2<'_, f64>,
353    nullspace_dim: Option<usize>,
354    weights: Option<ArrayView1<'_, f64>>,
355    init_rho: Option<f64>,
356) -> Result<GaussianRemlResult, EstimationError> {
357    let y2 = y.insert_axis(Axis(1));
358    let result = gaussian_reml_multi_closed_form_with_nullspace_dim(
359        x,
360        y2,
361        penalty,
362        nullspace_dim,
363        weights,
364        init_rho,
365    )?;
366    scalar_result_from_multi(result)
367}
368
369fn scalar_result_from_multi(
370    result: GaussianRemlMultiResult,
371) -> Result<GaussianRemlResult, EstimationError> {
372    Ok(GaussianRemlResult {
373        lambda: result.lambda,
374        rho: result.rho,
375        coefficients: result.coefficients.column(0).to_owned(),
376        fitted: result.fitted.column(0).to_owned(),
377        reml_score: result.reml_score,
378        reml_grad_lambda: result.reml_grad_lambda,
379        reml_hess_lambda: result.reml_hess_lambda,
380        reml_grad_rho: result.reml_grad_rho,
381        reml_hess_rho: result.reml_hess_rho,
382        edf: result.edf,
383        sigma2: result.sigma2[0],
384        cache: result.cache,
385    })
386}
387
388pub fn gaussian_reml_multi_closed_form(
389    x: ArrayView2<'_, f64>,
390    y: ArrayView2<'_, f64>,
391    penalty: ArrayView2<'_, f64>,
392    weights: Option<ArrayView1<'_, f64>>,
393    init_rho: Option<f64>,
394) -> Result<GaussianRemlMultiResult, EstimationError> {
395    gaussian_reml_multi_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
396}
397
398pub fn gaussian_reml_multi_closed_form_with_nullspace_dim(
399    x: ArrayView2<'_, f64>,
400    y: ArrayView2<'_, f64>,
401    penalty: ArrayView2<'_, f64>,
402    nullspace_dim: Option<usize>,
403    weights: Option<ArrayView1<'_, f64>>,
404    init_rho: Option<f64>,
405) -> Result<GaussianRemlMultiResult, EstimationError> {
406    let init_lambda = init_rho.map(f64::exp);
407    gaussian_reml_multi_closed_form_from_parts(
408        x,
409        y,
410        penalty,
411        nullspace_dim,
412        weights,
413        init_lambda,
414        None,
415    )
416}
417
418pub fn gaussian_reml_multi_closed_form_warm_started(
419    x: ArrayView2<'_, f64>,
420    y: ArrayView2<'_, f64>,
421    penalty: ArrayView2<'_, f64>,
422    weights: Option<ArrayView1<'_, f64>>,
423    warm_start: Option<&GaussianRemlWarmStart>,
424) -> Result<GaussianRemlMultiResult, EstimationError> {
425    gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
426        x, y, penalty, None, weights, warm_start,
427    )
428}
429
430pub fn gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
431    x: ArrayView2<'_, f64>,
432    y: ArrayView2<'_, f64>,
433    penalty: ArrayView2<'_, f64>,
434    nullspace_dim: Option<usize>,
435    weights: Option<ArrayView1<'_, f64>>,
436    warm_start: Option<&GaussianRemlWarmStart>,
437) -> Result<GaussianRemlMultiResult, EstimationError> {
438    let init_lambda = warm_start.and_then(|start| start.lambda);
439    let eigen_cache = warm_start.and_then(|start| start.eigen_cache.as_ref());
440    gaussian_reml_multi_closed_form_from_parts(
441        x,
442        y,
443        penalty,
444        nullspace_dim,
445        weights,
446        init_lambda,
447        eigen_cache,
448    )
449}
450
451pub fn gaussian_reml_multi_closed_form_with_cache(
452    x: ArrayView2<'_, f64>,
453    y: ArrayView2<'_, f64>,
454    penalty: ArrayView2<'_, f64>,
455    weights: Option<ArrayView1<'_, f64>>,
456    init_lambda: Option<f64>,
457    eigen_cache: Option<&GaussianRemlEigenCache>,
458) -> Result<GaussianRemlMultiResult, EstimationError> {
459    gaussian_reml_multi_closed_form_from_parts(
460        x,
461        y,
462        penalty,
463        None,
464        weights,
465        init_lambda,
466        eigen_cache,
467    )
468}
469
470pub fn gaussian_reml_multi_closed_form_with_cache_no_alloc(
471    x: ArrayView2<'_, f64>,
472    y: ArrayView2<'_, f64>,
473    penalty: ArrayView2<'_, f64>,
474    weights: Option<ArrayView1<'_, f64>>,
475    init_lambda: Option<f64>,
476    eigen_cache: &GaussianRemlEigenCache,
477    workspace: &mut GaussianRemlNoAllocWorkspace,
478    mut coefficients: ArrayViewMut2<'_, f64>,
479    mut fitted: ArrayViewMut2<'_, f64>,
480    mut sigma2: ArrayViewMut1<'_, f64>,
481) -> Result<GaussianRemlNoAllocFit, EstimationError> {
482    // Match the symmetric-S contract used by the cache builder: the
483    // fingerprint check below compares against a fingerprint computed on the
484    // canonicalized penalty, so the input must be canonicalized first.
485    let penalty_owned = canonicalize_penalty(penalty);
486    let penalty = penalty_owned.view();
487    let n = x.nrows();
488    let p = x.ncols();
489    let d = y.ncols();
490    validate_gaussian_reml_design(x, penalty, weights)?;
491    validate_gaussian_reml_eigen_cache(eigen_cache, p)?;
492    if y.nrows() != n {
493        crate::bail_invalid_estim!(
494            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
495            y.nrows()
496        );
497    }
498    if y.iter().any(|value| !value.is_finite()) {
499        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
500    }
501    if n <= eigen_cache.nullity {
502        crate::bail_invalid_estim!(
503            "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
504            eigen_cache.nullity
505        );
506    }
507    let penalty_fingerprint = matrix_fingerprint(penalty);
508    if eigen_cache.penalty_fingerprint != penalty_fingerprint {
509        crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
510    }
511    workspace.validate(p, d)?;
512    if coefficients.dim() != (p, d) || fitted.dim() != (n, d) || sigma2.len() != d {
513        crate::bail_invalid_estim!(
514            "Gaussian REML no-alloc output shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
515        );
516    }
517    if let Some(lambda) = init_lambda {
518        validate_initial_lambda(lambda)?;
519    }
520
521    fill_weighted_rhs_no_alloc(x, y, weights, workspace)?;
522    project_rhs_no_alloc(eigen_cache, workspace);
523
524    let init_rho = init_lambda.map(f64::ln);
525    let rho = optimize_rho_no_alloc(
526        eigen_cache,
527        workspace.ywy.view(),
528        workspace.projected_rhs_squared.view(),
529        n,
530        d,
531        init_rho,
532    )?;
533    let eval = evaluate_reml_parts(
534        eigen_cache,
535        workspace.ywy.view(),
536        workspace.projected_rhs_squared.view(),
537        n,
538        d,
539        rho,
540    );
541    let lambda = rho.exp();
542    fill_coefficients_no_alloc(eigen_cache, workspace, lambda, coefficients.view_mut());
543    fill_fitted_no_alloc(x, coefficients.view(), fitted.view_mut());
544    fill_sigma2_no_alloc(
545        eigen_cache,
546        workspace.ywy.view(),
547        workspace.projected_rhs_squared.view(),
548        n,
549        d,
550        lambda,
551        sigma2.view_mut(),
552    );
553    let (reml_grad_lambda, reml_hess_lambda) =
554        rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
555    Ok(GaussianRemlNoAllocFit {
556        lambda,
557        rho,
558        reml_score: eval.cost,
559        reml_grad_lambda,
560        reml_hess_lambda,
561        reml_grad_rho: eval.grad,
562        reml_hess_rho: eval.hess,
563        edf: eval.edf,
564    })
565}
566
567
568pub fn gaussian_reml_multi_closed_form_batch<'a>(
569    problems: &[GaussianRemlMultiBatchProblem<'a>],
570    penalty: ArrayView2<'a, f64>,
571    nullspace_dim: Option<usize>,
572) -> Result<Vec<GaussianRemlMultiResult>, EstimationError> {
573    if problems.is_empty() {
574        return Ok(Vec::new());
575    }
576    // Phase A: par_iter compute X'WX per problem (the only per-fit step that
577    // depends on `n_b`; remaining work is `O(p)` and can amortize through
578    // `_with_cache`).
579    let xtwx_per_problem: Vec<Array2<f64>> = problems
580        .par_iter()
581        .map(|problem| {
582            let weight = match problem.weights.as_ref() {
583                Some(w) => w.to_owned(),
584                None => Array1::ones(problem.x.nrows()),
585            };
586            dense_xt_diag_x(problem.x.view(), weight.view())
587        })
588        .collect();
589    // Phase B: one batched cuSOLVER Cholesky when policy approves uniform p
590    // and K aggregate FLOPs; otherwise the cache builder uses the normal
591    // per-fit non-GPU factorization path.
592    let caches =
593        build_gaussian_reml_eigen_cache_batched(xtwx_per_problem, penalty.view(), nullspace_dim);
594    // Phase C: par_iter finish each fit with its prebuilt cache. A cache-build
595    // error is a real per-problem error, not a signal to rebuild through a
596    // second path.
597    let fits: Vec<Result<GaussianRemlMultiResult, EstimationError>> = problems
598        .par_iter()
599        .zip(caches.into_par_iter())
600        .map(|(problem, cache_result)| {
601            let init_lambda = problem.init_rho.map(f64::exp);
602            let cache = cache_result?;
603            gaussian_reml_multi_closed_form_from_parts(
604                problem.x.view(),
605                problem.y.view(),
606                penalty.view(),
607                nullspace_dim,
608                problem.weights.as_ref().map(|weights| weights.view()),
609                init_lambda,
610                Some(&cache),
611            )
612        })
613        .collect();
614    fits.into_iter().collect()
615}
616
/// Quantities computed by `block_orthogonal_eval` for one block at one `ρ`,
/// with `H = gram + λS` (symmetrized) and `λ = exp(ρ)`.
struct BlockOrthogonalEval {
    /// Solution of `H β = rhs`, one column per right-hand-side column.
    beta: Array2<f64>,
    /// `log|H|`, computed as twice the sum of the logs of the Cholesky diagonal.
    logdet: f64,
    /// `tr(H⁻¹ λS)`.
    trace: f64,
    /// `trace_of_product` of `H⁻¹ λS` with itself.
    trace_pair: f64,
    /// Column sums of `rhs ∘ β` (element-wise product).
    fitted_energy: Array1<f64>,
    /// Column sums of `β ∘ (λS β)`.
    penalty_energy: Array1<f64>,
    /// Column sums of `(λS β) ∘ (H⁻¹ λS β)`.
    curvature_energy: Array1<f64>,
    /// Number of penalty rows minus `trace`.
    edf: f64,
}
627
628fn block_penalty_rank_logdet(
629    penalty: ArrayView2<'_, f64>,
630) -> Result<(usize, f64), EstimationError> {
631    let eigs = penalty
632        .to_owned()
633        .eigh(Side::Lower)
634        .map_err(|_| EstimationError::ModelIsIllConditioned {
635            condition_number: f64::INFINITY,
636        })?
637        .0;
638    let max_abs = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
639    let tol = (EIGEN_REL_TOL * max_abs).max(1.0e-14);
640    let mut rank = 0_usize;
641    let mut logdet = 0.0;
642    for eig in eigs.iter().copied() {
643        if eig > tol {
644            rank += 1;
645            logdet += eig.ln();
646        }
647    }
648    Ok((rank, logdet))
649}
650
651fn block_orthogonal_eval(
652    gram: &Array2<f64>,
653    rhs: &Array2<f64>,
654    penalty: &Array2<f64>,
655    rho: f64,
656) -> Result<BlockOrthogonalEval, EstimationError> {
657    let lambda = rho.exp();
658    validate_initial_lambda(lambda)?;
659    let scaled_penalty = penalty * lambda;
660    let hessian = canonicalize_penalty((gram + &scaled_penalty).view());
661    let chol = gaussian_reml_cholesky_lower(hessian)?;
662    let beta = solve_spd_from_lower_factor(&chol, rhs)?;
663    let solved_penalty = solve_spd_from_lower_factor(&chol, &scaled_penalty)?;
664    let logdet = 2.0 * chol.diag().iter().map(|value| value.ln()).sum::<f64>();
665    let trace = (0..solved_penalty.nrows())
666        .map(|i| solved_penalty[[i, i]])
667        .sum::<f64>();
668    let trace_pair =
669        gam_linalg::utils::trace_of_product(solved_penalty.view(), solved_penalty.view());
670    let fitted_energy = (rhs * &beta).sum_axis(Axis(0));
671    let p_beta = scaled_penalty.dot(&beta);
672    let penalty_energy = (&beta * &p_beta).sum_axis(Axis(0));
673    let solved_p_beta = solve_spd_from_lower_factor(&chol, &p_beta)?;
674    let curvature_energy = (&p_beta * &solved_p_beta).sum_axis(Axis(0));
675    Ok(BlockOrthogonalEval {
676        beta,
677        logdet,
678        trace,
679        trace_pair,
680        fitted_energy,
681        penalty_energy,
682        curvature_energy,
683        edf: penalty.nrows() as f64 - trace,
684    })
685}
686
/// Block-orthogonal shared-scale REML objective VALUE together with its
/// analytic ρ-gradient and ρ-Hessian.
///
/// Single source of truth: the value `½d·logdet − ½·fit − ½d·rank·ρ` and its
/// ρ-derivatives are returned from ONE function body, so a future edit to the
/// objective cannot leave the Newton gradient/Hessian (previously written at a
/// physically separate site inside `solve_block_orthogonal_rho`) stale. This
/// closes a genuine `(value_here, gradient_there)` loose pair. Mirrors the
/// `PenaltyLogdetDerivs` single-source pattern; behavior is identical (the same
/// closed-form formulas, reorganized).
struct BlockOrthogonalScaleDerivs {
    /// Objective value at the evaluated `ρ`.
    value: f64,
    /// First derivative of `value` with respect to ρ.
    grad: f64,
    /// Second derivative of `value` with respect to ρ.
    hess: f64,
}
702
703fn block_orthogonal_scale_objective(
704    eval: &BlockOrthogonalEval,
705    rho: f64,
706    scale_precision: ArrayView1<'_, f64>,
707    rank: usize,
708) -> BlockOrthogonalScaleDerivs {
709    let d = scale_precision.len() as f64;
710    let fit_term = scale_precision
711        .iter()
712        .zip(eval.fitted_energy.iter())
713        .map(|(scale, energy)| scale * energy)
714        .sum::<f64>();
715    // VALUE: ½d·log|H| − ½ Σ_o w_o ⟨y_o, fit_o⟩ − ½d·rank·ρ.
716    let value = 0.5 * d * eval.logdet - 0.5 * fit_term - 0.5 * d * (rank as f64) * rho;
717    // ρ-GRADIENT: d/dρ of the same scalar. The logdet term contributes
718    // ½d·(tr(H⁻¹λS) − rank); the (data-independent-at-fixed-β envelope) fit term
719    // contributes +½ Σ_o w_o βᵀ(λS)β. Both share `eval`'s cached energies.
720    let grad = 0.5 * d * (eval.trace - rank as f64)
721        + 0.5
722            * scale_precision
723                .iter()
724                .zip(eval.penalty_energy.iter())
725                .map(|(scale, energy)| scale * energy)
726                .sum::<f64>();
727    // ρ-HESSIAN: d²/dρ². Logdet term: ½d·(tr(H⁻¹λS) − tr((H⁻¹λS)²)); penalty
728    // term: ½ Σ_o w_o (βᵀλSβ − 2 βᵀλS H⁻¹ λS β).
729    let hess = 0.5 * d * (eval.trace - eval.trace_pair)
730        + 0.5
731            * scale_precision
732                .iter()
733                .zip(eval.penalty_energy.iter().zip(eval.curvature_energy.iter()))
734                .map(|(scale, (energy, curvature))| scale * (energy - 2.0 * curvature))
735                .sum::<f64>();
736    BlockOrthogonalScaleDerivs { value, grad, hess }
737}
738
739fn solve_block_orthogonal_rho(
740    gram: &Array2<f64>,
741    rhs: &Array2<f64>,
742    penalty: &Array2<f64>,
743    rho0: f64,
744    scale_precision: ArrayView1<'_, f64>,
745    rank: usize,
746    max_iter: usize,
747) -> Result<(f64, BlockOrthogonalEval), EstimationError> {
748    let mut rho = rho0;
749    let mut current = block_orthogonal_eval(gram, rhs, penalty, rho)?;
750    for _ in 0..max_iter {
751        // Value, ρ-gradient, and ρ-Hessian all come from the SINGLE
752        // single-source objective evaluation — they cannot desync.
753        let derivs = block_orthogonal_scale_objective(&current, rho, scale_precision, rank);
754        let grad = derivs.grad;
755        let hess = derivs.hess;
756        if !(grad.is_finite() && hess.is_finite()) {
757            return Err(EstimationError::ModelIsIllConditioned {
758                condition_number: f64::INFINITY,
759            });
760        }
761        // Newton step where the curvature is reliably positive; a unit
762        // gradient-descent direction where it is not (non-convex / near-flat
763        // region) so we never step ALONG negative curvature (which ascends).
764        // There is deliberately NO magic step clamp: the line search below
765        // globalizes and SKIPS any candidate ρ that is infeasible (e.g. λ =
766        // exp(ρ) overflows `validate_initial_lambda`), so an over-long Newton
767        // step is simply rejected rather than bounded by an arbitrary constant
768        // or crashing the solve. This is the root fix the old `.clamp(-2,2)`
769        // was masking: the clamp existed only to keep an over-long step from
770        // reaching `block_orthogonal_eval`, which errors on a non-finite λ.
771        let descent = grad.signum();
772        let step = if hess > 1.0e-10 { grad / hess } else { descent };
773        let mut best_rho = rho;
774        let mut best_eval = current;
775        let mut best_phi =
776            block_orthogonal_scale_objective(&best_eval, best_rho, scale_precision, rank).value;
777        for candidate_rho in [
778            rho - step,
779            rho - 0.5 * step,
780            rho - 0.25 * step,
781            rho - descent,
782            rho - 0.25 * descent,
783        ] {
784            // Skip infeasible candidates (λ overflow / ill-conditioned Gram)
785            // instead of failing the whole solve — the bounded gradient
786            // candidates (`rho - descent`, `rho - 0.25·descent`) remain valid,
787            // so a too-long Newton step degrades to gradient descent.
788            let Ok(candidate_eval) = block_orthogonal_eval(gram, rhs, penalty, candidate_rho)
789            else {
790                continue;
791            };
792            let candidate_phi = block_orthogonal_scale_objective(
793                &candidate_eval,
794                candidate_rho,
795                scale_precision,
796                rank,
797            )
798            .value;
799            if candidate_phi < best_phi {
800                best_rho = candidate_rho;
801                best_eval = candidate_eval;
802                best_phi = candidate_phi;
803            }
804        }
805        let delta = (best_rho - rho).abs();
806        rho = best_rho;
807        current = best_eval;
808        if delta < 1.0e-12 || step.abs() < 1.0e-7 {
809            break;
810        }
811    }
812    Ok((rho, current))
813}
814
815pub fn gaussian_reml_blocks_orthogonal_shared_scale(
816    designs: &[Array2<f64>],
817    penalties: &[Array2<f64>],
818    y: ArrayView2<'_, f64>,
819    weights: Option<ArrayView1<'_, f64>>,
820    init_rhos: Option<&[f64]>,
821) -> Result<GaussianRemlBlockOrthogonalResult, EstimationError> {
822    if designs.is_empty() {
823        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one block");
824    }
825    if designs.len() != penalties.len() {
826        crate::bail_invalid_estim!(
827            "block-orthogonal Gaussian REML block mismatch: {} designs, {} penalties",
828            designs.len(),
829            penalties.len()
830        );
831    }
832    let n = y.nrows();
833    let d = y.ncols();
834    if d == 0 {
835        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one output");
836    }
837    if y.iter().any(|value| !value.is_finite()) {
838        crate::bail_invalid_estim!("block-orthogonal Gaussian REML response must be finite");
839    }
840    let weight = gaussian_reml_weights(n, weights)?;
841    if let Some(rhos) = init_rhos {
842        if rhos.len() != designs.len() {
843            crate::bail_invalid_estim!(
844                "block-orthogonal Gaussian REML init_rhos length mismatch: expected {}, got {}",
845                designs.len(),
846                rhos.len()
847            );
848        }
849        if rhos.iter().any(|value| !value.is_finite()) {
850            crate::bail_invalid_estim!("block-orthogonal Gaussian REML init_rhos must be finite");
851        }
852    }
853
854    let mut ywy = Array1::<f64>::zeros(d);
855    for row in 0..n {
856        for output in 0..d {
857            ywy[output] += weight[row] * y[[row, output]] * y[[row, output]];
858        }
859    }
860    let mut grams = Vec::with_capacity(designs.len());
861    let mut rhs_blocks = Vec::with_capacity(designs.len());
862    let mut penalties_owned = Vec::with_capacity(penalties.len());
863    let mut ranks = Vec::with_capacity(penalties.len());
864    let mut penalty_logdets = Vec::with_capacity(penalties.len());
865    let mut nullity_total = 0_usize;
866    for (block, (design, penalty)) in designs.iter().zip(penalties.iter()).enumerate() {
867        let penalty_owned = canonicalize_penalty(penalty.view());
868        validate_gaussian_reml_design(design.view(), penalty_owned.view(), Some(weight.view()))?;
869        if design.nrows() != n {
870            crate::bail_invalid_estim!(
871                "block-orthogonal Gaussian REML designs[{block}] has {} rows, expected {n}",
872                design.nrows()
873            );
874        }
875        let gram = dense_xt_diag_x(design.view(), weight.view());
876        let rhs = dense_xt_diag_y(design.view(), weight.view(), y);
877        let (rank, logdet) = block_penalty_rank_logdet(penalty_owned.view())?;
878        nullity_total += penalty_owned.nrows().saturating_sub(rank);
879        grams.push(canonicalize_penalty(gram.view()));
880        rhs_blocks.push(rhs);
881        penalties_owned.push(penalty_owned);
882        ranks.push(rank);
883        penalty_logdets.push(logdet);
884    }
885    if n <= nullity_total {
886        crate::bail_invalid_estim!(
887            "block-orthogonal Gaussian REML requires n > total penalty nullity; got n={n}, nullity={nullity_total}"
888        );
889    }
890    let nu = (n - nullity_total) as f64;
891    let mut rhos = match init_rhos {
892        Some(values) => Array1::from_vec(values.to_vec()),
893        None => Array1::zeros(designs.len()),
894    };
895    let mut scale_precision = ywy.mapv(|value| nu / value.max(MIN_DEVIANCE));
896    let mut evals = Vec::new();
897    for _ in 0..40 {
898        evals.clear();
899        for block in 0..designs.len() {
900            let (rho, eval) = solve_block_orthogonal_rho(
901                &grams[block],
902                &rhs_blocks[block],
903                &penalties_owned[block],
904                rhos[block],
905                scale_precision.view(),
906                ranks[block],
907                32,
908            )?;
909            rhos[block] = rho;
910            evals.push(eval);
911        }
912        let mut explained = Array1::<f64>::zeros(d);
913        for eval in evals.iter() {
914            explained += &eval.fitted_energy;
915        }
916        let q = &ywy - &explained;
917        if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
918            return Err(EstimationError::ModelIsIllConditioned {
919                condition_number: f64::INFINITY,
920            });
921        }
922        let next_scale = q.mapv(|value| nu / value);
923        let scale_step = next_scale
924            .iter()
925            .zip(scale_precision.iter())
926            .map(|(next, old)| (next.ln() - old.ln()).abs())
927            .fold(0.0_f64, f64::max);
928        scale_precision = next_scale;
929        if scale_step < 1.0e-7 {
930            break;
931        }
932    }
933    evals.clear();
934    for block in 0..designs.len() {
935        let (rho, eval) = solve_block_orthogonal_rho(
936            &grams[block],
937            &rhs_blocks[block],
938            &penalties_owned[block],
939            rhos[block],
940            scale_precision.view(),
941            ranks[block],
942            16,
943        )?;
944        rhos[block] = rho;
945        evals.push(eval);
946    }
947
948    let coefficients = evals
949        .iter()
950        .map(|eval| eval.beta.clone())
951        .collect::<Vec<_>>();
952    let mut fitted = Array2::<f64>::zeros((n, d));
953    for (design, coef) in designs.iter().zip(coefficients.iter()) {
954        fitted += &fast_ab(&design.view(), &coef.view());
955    }
956    let mut explained = Array1::<f64>::zeros(d);
957    for eval in evals.iter() {
958        explained += &eval.fitted_energy;
959    }
960    let q = &ywy - &explained;
961    if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
962        return Err(EstimationError::ModelIsIllConditioned {
963            condition_number: f64::INFINITY,
964        });
965    }
966    let lambdas = rhos.mapv(f64::exp);
967    let edf = Array1::from_iter(evals.iter().map(|eval| eval.edf));
968    let logdet_term = evals
969        .iter()
970        .enumerate()
971        .map(|(block, eval)| {
972            eval.logdet - penalty_logdets[block] - (ranks[block] as f64) * rhos[block]
973        })
974        .sum::<f64>();
975    let scale_term = q
976        .iter()
977        .map(|value| nu * (1.0 + (2.0 * std::f64::consts::PI * value / nu).ln()))
978        .sum::<f64>();
979    Ok(GaussianRemlBlockOrthogonalResult {
980        coefficients,
981        fitted,
982        lambdas,
983        log_lambdas: rhos,
984        reml_score: 0.5 * (d as f64) * logdet_term + 0.5 * scale_term,
985        edf,
986    })
987}
988
989fn gaussian_reml_multi_closed_form_from_parts(
990    x: ArrayView2<'_, f64>,
991    y: ArrayView2<'_, f64>,
992    penalty: ArrayView2<'_, f64>,
993    nullspace_dim: Option<usize>,
994    weights: Option<ArrayView1<'_, f64>>,
995    init_lambda: Option<f64>,
996    eigen_cache: Option<&GaussianRemlEigenCache>,
997) -> Result<GaussianRemlMultiResult, EstimationError> {
998    let prepared = prepare_gaussian_reml(x, y, penalty, nullspace_dim, weights, eigen_cache)?;
999    let init_rho = init_lambda
1000        .map(validate_initial_lambda)
1001        .transpose()?
1002        .map(f64::ln);
1003    let rho = optimize_rho(&prepared, init_rho)?;
1004    let eval = prepared.evaluate(rho);
1005    let lambda = rho.exp();
1006    let coefficients = prepared.coefficients(lambda);
1007    let fitted = dense_ab(x, coefficients.view());
1008    let sigma2 = prepared.sigma2(lambda);
1009    let (reml_grad_lambda, reml_hess_lambda) =
1010        rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
1011    Ok(GaussianRemlMultiResult {
1012        lambda,
1013        rho,
1014        coefficients,
1015        fitted,
1016        reml_score: eval.cost,
1017        reml_grad_lambda,
1018        reml_hess_lambda,
1019        reml_grad_rho: eval.grad,
1020        reml_hess_rho: eval.hess,
1021        edf: eval.edf,
1022        sigma2,
1023        cache: prepared.cache,
1024    })
1025}
1026
1027pub fn gaussian_reml_free_b_score(
1028    x: ArrayView2<'_, f64>,
1029    y: ArrayView2<'_, f64>,
1030    coefficients: ArrayView2<'_, f64>,
1031    log_lambda: f64,
1032    penalty: ArrayView2<'_, f64>,
1033    weights: Option<ArrayView1<'_, f64>>,
1034) -> Result<GaussianRemlFreeBScore, EstimationError> {
1035    if !log_lambda.is_finite() {
1036        crate::bail_invalid_estim!("Gaussian REML log_lambda must be finite; got {log_lambda}");
1037    }
1038    let lambda = log_lambda.exp();
1039    let penalty_owned = canonicalize_penalty(penalty);
1040    let penalty = penalty_owned.view();
1041    let n = x.nrows();
1042    let p = x.ncols();
1043    let d = y.ncols();
1044    validate_gaussian_reml_design(x, penalty, weights)?;
1045    if y.nrows() != n {
1046        crate::bail_invalid_estim!(
1047            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
1048            y.nrows()
1049        );
1050    }
1051    if coefficients.dim() != (p, d) {
1052        crate::bail_invalid_estim!(
1053            "Gaussian REML coefficient shape mismatch: expected {p}x{d}, got {}x{}",
1054            coefficients.nrows(),
1055            coefficients.ncols()
1056        );
1057    }
1058    if y.iter().chain(coefficients.iter()).any(|v| !v.is_finite()) {
1059        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
1060    }
1061
1062    let weight = gaussian_reml_weights(n, weights)?;
1063    let cache =
1064        build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, Some(weight.view()))?;
1065    if n <= cache.nullity {
1066        crate::bail_invalid_estim!(
1067            "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
1068            cache.nullity
1069        );
1070    }
1071    let nu = n as f64 - cache.nullity as f64;
1072    let fitted = dense_ab(x, coefficients);
1073    let residual = y.to_owned() - &fitted;
1074    let xtw_residual = dense_xt_diag_y(x, weight.view(), residual.view());
1075    let s_beta = dense_ab(penalty, coefficients);
1076
1077    let mut logdet_h = cache.logdet_xtwx;
1078    let mut trace_h = 0.0;
1079    let mut edf = 0.0;
1080    for &delta in &cache.penalty_eigenvalues {
1081        let t = lambda * delta;
1082        logdet_h += (1.0 + t).ln();
1083        if delta > 0.0 {
1084            trace_h += t / (1.0 + t);
1085        }
1086        edf += 1.0 / (1.0 + t);
1087    }
1088    let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * log_lambda;
1089    let mut reml_score = 0.5 * (d as f64) * (logdet_h - logdet_s);
1090    let mut grad_log_lambda = 0.5 * (d as f64) * (trace_h - cache.penalty_rank as f64);
1091    let mut grad_coefficients = Array2::<f64>::zeros((p, d));
1092    let inverse_hessian = {
1093        let xtwx = dense_xt_diag_x(x, weight.view());
1094        let mut hessian = xtwx;
1095        hessian += &(penalty.to_owned() * lambda);
1096        hessian
1097            .cholesky(Side::Lower)
1098            .map_err(EstimationError::LinearSystemSolveFailed)?
1099            .solve_mat(&Array2::<f64>::eye(p))
1100    };
1101    let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(&cache);
1102    let mut grad_penalty = Array2::<f64>::zeros((p, p));
1103    for row in 0..p {
1104        for col in 0..p {
1105            grad_penalty[[row, col]] += 0.5
1106                * (d as f64)
1107                * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
1108        }
1109    }
1110    let mut sigma2 = Array1::<f64>::zeros(d);
1111
1112    for output in 0..d {
1113        let mut weighted_rss = 0.0;
1114        for row in 0..n {
1115            let r = residual[[row, output]];
1116            weighted_rss += weight[row] * r * r;
1117        }
1118        let beta_col = coefficients.column(output);
1119        let s_beta_col = s_beta.column(output);
1120        let penalty_quadratic = beta_col.dot(&s_beta_col);
1121        let dp = (weighted_rss + lambda * penalty_quadratic).max(MIN_DEVIANCE);
1122        sigma2[output] = dp / nu;
1123        reml_score += 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln());
1124        grad_log_lambda += 0.5 * nu * lambda * penalty_quadratic / dp;
1125        let scale = nu / dp;
1126        for coeff in 0..p {
1127            grad_coefficients[[coeff, output]] =
1128                scale * (-xtw_residual[[coeff, output]] + lambda * s_beta[[coeff, output]]);
1129        }
1130        add_rank_one_penalty_vjp(0.5 * scale * lambda, beta_col, &mut grad_penalty);
1131    }
1132    for i in 0..p {
1133        for j in (i + 1)..p {
1134            let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
1135            grad_penalty[[i, j]] = avg;
1136            grad_penalty[[j, i]] = avg;
1137        }
1138    }
1139
1140    Ok(GaussianRemlFreeBScore {
1141        reml_score,
1142        grad_coefficients,
1143        grad_penalty,
1144        grad_log_lambda,
1145        fitted,
1146        sigma2,
1147        edf,
1148    })
1149}
1150
1151pub fn gaussian_reml_multi_closed_form_backward(
1152    x: ArrayView2<'_, f64>,
1153    y: ArrayView2<'_, f64>,
1154    penalty: ArrayView2<'_, f64>,
1155    weights: Option<ArrayView1<'_, f64>>,
1156    init_lambda: Option<f64>,
1157    upstream_lambda: f64,
1158    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1159    upstream_fitted: Option<ArrayView2<'_, f64>>,
1160    upstream_reml_score: f64,
1161    upstream_edf: f64,
1162) -> Result<GaussianRemlBackwardResult, EstimationError> {
1163    let fit =
1164        gaussian_reml_multi_closed_form_with_cache(x, y, penalty, weights, init_lambda, None)?;
1165    gaussian_reml_multi_closed_form_backward_from_fit(
1166        x,
1167        y,
1168        penalty,
1169        weights,
1170        &fit,
1171        upstream_lambda,
1172        upstream_coefficients,
1173        upstream_fitted,
1174        upstream_reml_score,
1175        upstream_edf,
1176    )
1177}
1178
1179pub fn gaussian_reml_multi_closed_form_backward_from_fit(
1180    x: ArrayView2<'_, f64>,
1181    y: ArrayView2<'_, f64>,
1182    penalty: ArrayView2<'_, f64>,
1183    weights: Option<ArrayView1<'_, f64>>,
1184    fit: &GaussianRemlMultiResult,
1185    upstream_lambda: f64,
1186    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1187    upstream_fitted: Option<ArrayView2<'_, f64>>,
1188    upstream_reml_score: f64,
1189    upstream_edf: f64,
1190) -> Result<GaussianRemlBackwardResult, EstimationError> {
1191    validate_gaussian_reml_backward_upstreams(
1192        x,
1193        y,
1194        penalty,
1195        upstream_lambda,
1196        upstream_coefficients,
1197        upstream_fitted,
1198        upstream_reml_score,
1199        upstream_edf,
1200    )?;
1201    validate_gaussian_reml_forward_fit(x, y, penalty, weights, fit)?;
1202    let lambda = fit.lambda;
1203    let n = x.nrows();
1204    let p = x.ncols();
1205    let d = y.ncols();
1206    if !(fit.reml_hess_rho.is_finite() && fit.reml_hess_rho.abs() > 1.0e-14) {
1207        // Graceful degradation: when λ saturates, K = XᵀWX + λS is
1208        // effectively rank-deficient and the analytic VJP is undefined.
1209        // Return zero gradients (the correct shrink-out limit) instead of
1210        // raising — production training at large F can have individual
1211        // atoms saturate λ in early batches and must not blow up here.
1212        warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1213        return Ok(zero_backward_result(n, p, d));
1214    }
1215    let weight = gaussian_reml_weights(n, weights)?;
1216    let inverse_hessian = match gaussian_reml_inverse_hessian_from_cache(&fit.cache, lambda) {
1217        Ok(inv) => inv,
1218        Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1219            warn_ill_conditioned_backward_once(p, d, condition_number);
1220            return Ok(zero_backward_result(n, p, d));
1221        }
1222        Err(err) => return Err(err),
1223    };
1224    gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1225        x,
1226        y,
1227        penalty,
1228        weight,
1229        fit,
1230        inverse_hessian,
1231        upstream_lambda,
1232        upstream_coefficients,
1233        upstream_fitted,
1234        upstream_reml_score,
1235        upstream_edf,
1236        n,
1237        p,
1238        d,
1239    )
1240}
1241
1242fn gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1243    x: ArrayView2<'_, f64>,
1244    y: ArrayView2<'_, f64>,
1245    penalty: ArrayView2<'_, f64>,
1246    weight: Array1<f64>,
1247    fit: &GaussianRemlMultiResult,
1248    inverse_hessian: Array2<f64>,
1249    upstream_lambda: f64,
1250    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1251    upstream_fitted: Option<ArrayView2<'_, f64>>,
1252    upstream_reml_score: f64,
1253    upstream_edf: f64,
1254    n: usize,
1255    p: usize,
1256    d: usize,
1257) -> Result<GaussianRemlBackwardResult, EstimationError> {
1258    // Backward sees the same symmetric S the forward used. Canonicalize on
1259    // entry so an asymmetric input (e.g. a single-entry gradcheck perturbation
1260    // around a symmetric base) cannot leak into the per-helper VJPs.
1261    let penalty_owned = canonicalize_penalty(penalty);
1262    let penalty = penalty_owned.view();
1263    let lambda = fit.lambda;
1264    let beta = &fit.coefficients;
1265    let residual = y.to_owned() - &fit.fitted;
1266    let nu = n as f64 - fit.cache.nullity as f64;
1267
1268    let mut grad_x = Array2::<f64>::zeros((n, p));
1269    let mut grad_y = Array2::<f64>::zeros((n, d));
1270    let mut grad_penalty = Array2::<f64>::zeros((p, p));
1271    let mut grad_weights = Array1::<f64>::zeros(n);
1272
1273    let mut upstream_beta = Array2::<f64>::zeros((p, d));
1274    if let Some(upstream_coefficients) = upstream_coefficients {
1275        upstream_beta += &upstream_coefficients;
1276    }
1277    if let Some(upstream_fitted) = upstream_fitted {
1278        upstream_beta += &dense_atb(x, upstream_fitted);
1279        grad_x += &dense_ab(upstream_fitted, beta.t());
1280    }
1281
1282    let mut lambda_adjoint = upstream_lambda;
1283    if upstream_beta.iter().any(|value| *value != 0.0) {
1284        // A downstream loss that explicitly uses beta_hat or fitted = X beta_hat
1285        // cannot use the REML envelope shortcut.  Route those seeds through
1286        // the fixed-rho KKT adjoint M u = upstream_beta, then differentiate
1287        // X, y, weights, and S through the ridge solve.
1288        add_ridge_profile_vjp_with_lambda_grad(
1289            1.0,
1290            x,
1291            y,
1292            penalty,
1293            &weight,
1294            lambda,
1295            &inverse_hessian,
1296            beta,
1297            upstream_beta.view(),
1298            &mut grad_x,
1299            &mut grad_y,
1300            &mut grad_penalty,
1301            &mut grad_weights,
1302            &mut lambda_adjoint,
1303        );
1304    }
1305
1306    if upstream_reml_score != 0.0 {
1307        add_reml_score_vjp(
1308            upstream_reml_score,
1309            x,
1310            &weight,
1311            &inverse_hessian,
1312            beta,
1313            &residual,
1314            &fit.sigma2,
1315            nu,
1316            lambda,
1317            &fit.cache,
1318            &mut grad_x,
1319            &mut grad_y,
1320            &mut grad_penalty,
1321            &mut grad_weights,
1322        );
1323        lambda_adjoint += upstream_reml_score * fit.reml_grad_lambda;
1324    }
1325
1326    if upstream_edf != 0.0 {
1327        lambda_adjoint += add_edf_vjp(
1328            upstream_edf,
1329            x,
1330            penalty,
1331            &weight,
1332            lambda,
1333            &inverse_hessian,
1334            &mut grad_x,
1335            &mut grad_penalty,
1336            &mut grad_weights,
1337        );
1338    }
1339
1340    if lambda_adjoint != 0.0 {
1341        let root_scale = -lambda_adjoint * lambda / fit.reml_hess_rho;
1342        add_reml_rho_gradient_vjp(
1343            root_scale,
1344            x,
1345            y,
1346            penalty,
1347            &weight,
1348            lambda,
1349            &inverse_hessian,
1350            beta,
1351            &residual,
1352            &fit.sigma2,
1353            nu,
1354            &mut grad_x,
1355            &mut grad_y,
1356            &mut grad_penalty,
1357            &mut grad_weights,
1358        );
1359    }
1360
1361    // The forward consumes `S` only through the canonicalization
1362    // `S_canon = 0.5 (S + Sᵀ)`. By the chain rule, the gradient w.r.t. an
1363    // input `S_input` is `0.5 (G + Gᵀ)` where `G = ∂L/∂S_canon` is what the
1364    // per-helper VJPs accumulate. Symmetrize the full matrix here so a
1365    // single-entry perturbation `δS = ε E_{i,j}` (asymmetric, as
1366    // `torch.autograd.gradcheck` produces) sees the gradient component
1367    // `0.5 (G[i,j] + G[j,i])` it expects from FD — no caller-side
1368    // bookkeeping required.
1369    let p = grad_penalty.nrows();
1370    for i in 0..p {
1371        for j in (i + 1)..p {
1372            let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
1373            grad_penalty[[i, j]] = avg;
1374            grad_penalty[[j, i]] = avg;
1375        }
1376    }
1377    Ok(GaussianRemlBackwardResult {
1378        grad_x,
1379        grad_y,
1380        grad_penalty,
1381        grad_weights,
1382    })
1383}
1384
1385pub fn gaussian_reml_multi_closed_form_backward_batch<'a>(
1386    problems: &[GaussianRemlMultiBackwardProblem<'a>],
1387    penalty: ArrayView2<'a, f64>,
1388) -> Vec<Result<GaussianRemlBackwardResult, EstimationError>> {
1389    let inverse_hessians = batched_inverse_hessians_from_caches(problems);
1390    let results: Vec<Result<GaussianRemlBackwardResult, EstimationError>> = problems
1391        .par_iter()
1392        .zip(inverse_hessians.into_par_iter())
1393        .map(|(problem, inverse_hessian_result)| {
1394            validate_gaussian_reml_backward_upstreams(
1395                problem.x.view(),
1396                problem.y.view(),
1397                penalty,
1398                problem.grad_lambda,
1399                problem.grad_coefficients.as_ref().map(|g| g.view()),
1400                problem.grad_fitted.as_ref().map(|g| g.view()),
1401                problem.grad_reml_score,
1402                problem.grad_edf,
1403            )?;
1404            validate_gaussian_reml_forward_fit(
1405                problem.x.view(),
1406                problem.y.view(),
1407                penalty,
1408                problem.weights.as_ref().map(|w| w.view()),
1409                problem.fit,
1410            )?;
1411            let n = problem.x.nrows();
1412            let p = problem.x.ncols();
1413            let d = problem.y.ncols();
1414            if !(problem.fit.reml_hess_rho.is_finite() && problem.fit.reml_hess_rho.abs() > 1.0e-14)
1415            {
1416                // Graceful degradation — see `gaussian_reml_multi_closed_form_backward_from_fit`.
1417                warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1418                return Ok(zero_backward_result(n, p, d));
1419            }
1420            let weight = gaussian_reml_weights(n, problem.weights.as_ref().map(|w| w.view()))?;
1421            let inverse_hessian = match inverse_hessian_result {
1422                Ok(inv) => inv,
1423                Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1424                    warn_ill_conditioned_backward_once(p, d, condition_number);
1425                    return Ok(zero_backward_result(n, p, d));
1426                }
1427                Err(err) => return Err(err),
1428            };
1429            gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1430                problem.x.view(),
1431                problem.y.view(),
1432                penalty,
1433                weight,
1434                problem.fit,
1435                inverse_hessian,
1436                problem.grad_lambda,
1437                problem.grad_coefficients.as_ref().map(|g| g.view()),
1438                problem.grad_fitted.as_ref().map(|g| g.view()),
1439                problem.grad_reml_score,
1440                problem.grad_edf,
1441                n,
1442                p,
1443                d,
1444            )
1445        })
1446        .collect();
1447    results
1448}
1449
/// Converts first and second derivatives taken with respect to
/// `rho = ln(lambda)` into derivatives with respect to `lambda`.
///
/// With `g = d/drho` and `h = d²/drho²`, the chain rule gives
/// `d/dlambda = g / lambda` and `d²/dlambda² = (h − g) / lambda²`.
/// Returns `(first, second)` in that order.
fn rho_derivatives_to_lambda(lambda: f64, grad_rho: f64, hess_rho: f64) -> (f64, f64) {
    let grad_lambda = grad_rho / lambda;
    let lambda_squared = lambda * lambda;
    let hess_lambda = (hess_rho - grad_rho) / lambda_squared;
    (grad_lambda, hess_lambda)
}
1453
1454fn validate_gaussian_reml_backward_upstreams(
1455    x: ArrayView2<'_, f64>,
1456    y: ArrayView2<'_, f64>,
1457    penalty: ArrayView2<'_, f64>,
1458    upstream_lambda: f64,
1459    upstream_coefficients: Option<ArrayView2<'_, f64>>,
1460    upstream_fitted: Option<ArrayView2<'_, f64>>,
1461    upstream_reml_score: f64,
1462    upstream_edf: f64,
1463) -> Result<(), EstimationError> {
1464    if !(upstream_lambda.is_finite() && upstream_reml_score.is_finite() && upstream_edf.is_finite())
1465    {
1466        crate::bail_invalid_estim!("Gaussian REML backward upstream scalars must be finite");
1467    }
1468    if let Some(upstream_coefficients) = upstream_coefficients {
1469        if upstream_coefficients.dim() != (x.ncols(), y.ncols()) {
1470            crate::bail_invalid_estim!(
1471                "Gaussian REML backward coefficient upstream shape mismatch: expected {}x{}, got {}x{}",
1472                x.ncols(),
1473                y.ncols(),
1474                upstream_coefficients.nrows(),
1475                upstream_coefficients.ncols()
1476            );
1477        }
1478        if upstream_coefficients.iter().any(|value| !value.is_finite()) {
1479            crate::bail_invalid_estim!(
1480                "Gaussian REML backward coefficient upstream must be finite"
1481            );
1482        }
1483    }
1484    if let Some(upstream_fitted) = upstream_fitted {
1485        if upstream_fitted.dim() != y.dim() {
1486            crate::bail_invalid_estim!(
1487                "Gaussian REML backward fitted upstream shape mismatch: expected {}x{}, got {}x{}",
1488                y.nrows(),
1489                y.ncols(),
1490                upstream_fitted.nrows(),
1491                upstream_fitted.ncols()
1492            );
1493        }
1494        if upstream_fitted.iter().any(|value| !value.is_finite()) {
1495            crate::bail_invalid_estim!("Gaussian REML backward fitted upstream must be finite");
1496        }
1497    }
1498    validate_gaussian_reml_design(x, penalty, None)?;
1499    Ok(())
1500}
1501
1502fn validate_gaussian_reml_forward_fit(
1503    x: ArrayView2<'_, f64>,
1504    y: ArrayView2<'_, f64>,
1505    penalty: ArrayView2<'_, f64>,
1506    weights: Option<ArrayView1<'_, f64>>,
1507    fit: &GaussianRemlMultiResult,
1508) -> Result<(), EstimationError> {
1509    // Fingerprint the canonicalized penalty: caches are keyed on the
1510    // symmetric average, and the caller may hand us a raw input (e.g. a
1511    // single-entry-perturbed matrix produced by ``torch.autograd.gradcheck``).
1512    let penalty_owned = canonicalize_penalty(penalty);
1513    let penalty = penalty_owned.view();
1514    let n = x.nrows();
1515    let p = x.ncols();
1516    let d = y.ncols();
1517    validate_gaussian_reml_design(x, penalty, weights)?;
1518    validate_gaussian_reml_eigen_cache(&fit.cache, p)?;
1519    if y.nrows() != n
1520        || fit.coefficients.dim() != (p, d)
1521        || fit.fitted.dim() != (n, d)
1522        || fit.sigma2.len() != d
1523    {
1524        crate::bail_invalid_estim!(
1525            "Gaussian REML backward forward-state shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
1526        );
1527    }
1528    if !(fit.lambda.is_finite()
1529        && fit.lambda > 0.0
1530        && fit.rho.is_finite()
1531        && fit.reml_score.is_finite()
1532        && fit.reml_hess_rho.is_finite()
1533        && fit.edf.is_finite())
1534        || fit.coefficients.iter().any(|value| !value.is_finite())
1535        || fit.fitted.iter().any(|value| !value.is_finite())
1536        || fit.sigma2.iter().any(|value| !value.is_finite())
1537    {
1538        crate::bail_invalid_estim!("Gaussian REML backward forward state must be finite");
1539    }
1540    let penalty_fingerprint = matrix_fingerprint(penalty);
1541    if fit.cache.penalty_fingerprint != penalty_fingerprint {
1542        crate::bail_invalid_estim!("Gaussian REML backward forward-state penalty mismatch");
1543    }
1544    let weight = gaussian_reml_weights(n, weights)?;
1545    let xtwx = dense_xt_diag_x(x, weight.view());
1546    if fit.cache.xtwx_fingerprint != matrix_fingerprint(xtwx.view()) {
1547        crate::bail_invalid_estim!("Gaussian REML backward forward-state X'WX mismatch");
1548    }
1549    Ok(())
1550}
1551
1552fn gaussian_reml_inverse_hessian_from_cache(
1553    cache: &GaussianRemlEigenCache,
1554    lambda: f64,
1555) -> Result<Array2<f64>, EstimationError> {
1556    if !(lambda.is_finite() && lambda > 0.0) {
1557        crate::bail_invalid_estim!(
1558            "Gaussian REML lambda must be finite and positive; got {lambda}"
1559        );
1560    }
1561    let p = cache.penalty_eigenvalues.len();
1562    let mut scaled_basis = cache.coefficient_basis.clone();
1563    for eig in 0..p {
1564        let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1565        for row in 0..p {
1566            scaled_basis[[row, eig]] *= scale;
1567        }
1568    }
1569    let inverse = dense_ab(scaled_basis.view(), cache.coefficient_basis.t());
1570    if inverse.iter().any(|value| !value.is_finite()) {
1571        return Err(EstimationError::ModelIsIllConditioned {
1572            condition_number: f64::INFINITY,
1573        });
1574    }
1575    Ok(inverse)
1576}
1577
1578fn batched_inverse_hessians_from_caches(
1579    problems: &[GaussianRemlMultiBackwardProblem<'_>],
1580) -> Vec<Result<Array2<f64>, EstimationError>> {
1581    if problems.is_empty() {
1582        return Vec::new();
1583    }
1584    let p = problems[0].fit.cache.coefficient_basis.nrows();
1585    let uniform = p > 0
1586        && problems.iter().all(|problem| {
1587            let cache = &problem.fit.cache;
1588            cache.coefficient_basis.dim() == (p, p) && cache.penalty_eigenvalues.len() == p
1589        });
1590    if uniform && problems.len() > 1 {
1591        let mut scaled_basis = Array3::<f64>::zeros((problems.len(), p, p));
1592        let mut basis = Array3::<f64>::zeros((problems.len(), p, p));
1593        let mut valid = true;
1594        for (idx, problem) in problems.iter().enumerate() {
1595            let lambda = problem.fit.lambda;
1596            if !(lambda.is_finite() && lambda > 0.0) {
1597                valid = false;
1598                break;
1599            }
1600            let cache = &problem.fit.cache;
1601            basis
1602                .slice_mut(s![idx, .., ..])
1603                .assign(&cache.coefficient_basis);
1604            for eig in 0..p {
1605                let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1606                for row in 0..p {
1607                    scaled_basis[[idx, row, eig]] = cache.coefficient_basis[[row, eig]] * scale;
1608                }
1609            }
1610        }
1611        if valid
1612            && let Some(inverses) =
1613                gam_gpu::try_fast_abt_strided_batched(scaled_basis.view(), basis.view())
1614        {
1615            return inverses
1616                .axis_iter(Axis(0))
1617                .map(|inverse| Ok(inverse.to_owned()))
1618                .collect();
1619        }
1620    }
1621    problems
1622        .iter()
1623        .map(|problem| {
1624            gaussian_reml_inverse_hessian_from_cache(&problem.fit.cache, problem.fit.lambda)
1625        })
1626        .collect()
1627}
1628
1629/// Side-effects of the ridge-profile VJP that are independent of λ.
1630///
1631/// Computes the KKT adjoint `m = M^{-1} u` for `u = upstream_beta` and accumulates
1632/// the partials w.r.t. `X`, `y`, `S`, and `w` into the provided gradient buffers.
1633/// Returns `m` so callers that also need `∂L/∂λ` can fold in the λ-adjoint dot
1634/// product `−scale · ⟨m, S β⟩` without recomputing the adjoint solve.
1635fn ridge_profile_vjp_data_partials(
1636    scale: f64,
1637    x: ArrayView2<'_, f64>,
1638    y: ArrayView2<'_, f64>,
1639    penalty: ArrayView2<'_, f64>,
1640    weights: &Array1<f64>,
1641    lambda: f64,
1642    inverse_hessian: &Array2<f64>,
1643    beta: &Array2<f64>,
1644    upstream_beta: ArrayView2<'_, f64>,
1645    grad_x: &mut Array2<f64>,
1646    grad_y: &mut Array2<f64>,
1647    grad_penalty: &mut Array2<f64>,
1648    grad_weights: &mut Array1<f64>,
1649) -> Array2<f64> {
1650    let m = dense_ab(inverse_hessian.view(), upstream_beta);
1651    let c = dense_ab(m.view(), beta.t());
1652    let c_sym = &c + &c.t();
1653    let ymt = dense_ab(y, m.t());
1654    let xcs = dense_ab(x, c_sym.view());
1655    for i in 0..x.nrows() {
1656        let wi = weights[i] * scale;
1657        for k in 0..x.ncols() {
1658            grad_x[[i, k]] += wi * (ymt[[i, k]] - xcs[[i, k]]);
1659        }
1660    }
1661
1662    let xm = dense_ab(x, m.view());
1663    for i in 0..x.nrows() {
1664        let wi = weights[i] * scale;
1665        for j in 0..y.ncols() {
1666            grad_y[[i, j]] += wi * xm[[i, j]];
1667        }
1668    }
1669
1670    let xc = dense_ab(x, c.view());
1671    for i in 0..x.nrows() {
1672        let mut from_b = 0.0;
1673        for j in 0..y.ncols() {
1674            from_b += y[[i, j]] * xm[[i, j]];
1675        }
1676        let mut from_a = 0.0;
1677        for k in 0..x.ncols() {
1678            from_a += x[[i, k]] * xc[[i, k]];
1679        }
1680        grad_weights[i] += scale * (from_b - from_a);
1681    }
1682
1683    for row in 0..penalty.nrows() {
1684        for col in 0..penalty.ncols() {
1685            let mut value = 0.0;
1686            for output in 0..beta.ncols() {
1687                value += m[[row, output]] * beta[[col, output]];
1688            }
1689            grad_penalty[[row, col]] -= scale * lambda * value;
1690        }
1691    }
1692    m
1693}
1694
1695/// Ridge-profile VJP for callers that also need `∂L/∂λ`.
1696///
1697/// Accumulates the data/penalty/weight partials and adds the implicit-function
1698/// λ-adjoint contribution `−scale · ⟨M^{-1} u, S β⟩` into `lambda_adjoint_out`.
1699fn add_ridge_profile_vjp_with_lambda_grad(
1700    scale: f64,
1701    x: ArrayView2<'_, f64>,
1702    y: ArrayView2<'_, f64>,
1703    penalty: ArrayView2<'_, f64>,
1704    weights: &Array1<f64>,
1705    lambda: f64,
1706    inverse_hessian: &Array2<f64>,
1707    beta: &Array2<f64>,
1708    upstream_beta: ArrayView2<'_, f64>,
1709    grad_x: &mut Array2<f64>,
1710    grad_y: &mut Array2<f64>,
1711    grad_penalty: &mut Array2<f64>,
1712    grad_weights: &mut Array1<f64>,
1713    lambda_adjoint_out: &mut f64,
1714) {
1715    let m = ridge_profile_vjp_data_partials(
1716        scale,
1717        x,
1718        y,
1719        penalty,
1720        weights,
1721        lambda,
1722        inverse_hessian,
1723        beta,
1724        upstream_beta,
1725        grad_x,
1726        grad_y,
1727        grad_penalty,
1728        grad_weights,
1729    );
1730    let penalty_beta = dense_ab(penalty, beta.view());
1731    let dot = m
1732        .iter()
1733        .zip(penalty_beta.iter())
1734        .map(|(left, right)| left * right)
1735        .sum::<f64>();
1736    *lambda_adjoint_out += -scale * dot;
1737}
1738
1739/// Ridge-profile VJP for callers that hold λ fixed (e.g. the implicit-root
1740/// partial inside `add_reml_rho_gradient_vjp`). The λ-adjoint dot product is
1741/// skipped entirely — it would be unused work in this branch.
1742fn add_ridge_profile_vjp_fixed_lambda(
1743    scale: f64,
1744    x: ArrayView2<'_, f64>,
1745    y: ArrayView2<'_, f64>,
1746    penalty: ArrayView2<'_, f64>,
1747    weights: &Array1<f64>,
1748    lambda: f64,
1749    inverse_hessian: &Array2<f64>,
1750    beta: &Array2<f64>,
1751    upstream_beta: ArrayView2<'_, f64>,
1752    grad_x: &mut Array2<f64>,
1753    grad_y: &mut Array2<f64>,
1754    grad_penalty: &mut Array2<f64>,
1755    grad_weights: &mut Array1<f64>,
1756) {
1757    ridge_profile_vjp_data_partials(
1758        scale,
1759        x,
1760        y,
1761        penalty,
1762        weights,
1763        lambda,
1764        inverse_hessian,
1765        beta,
1766        upstream_beta,
1767        grad_x,
1768        grad_y,
1769        grad_penalty,
1770        grad_weights,
1771    );
1772}
1773
1774fn add_reml_score_vjp(
1775    scale: f64,
1776    x: ArrayView2<'_, f64>,
1777    weights: &Array1<f64>,
1778    inverse_hessian: &Array2<f64>,
1779    beta: &Array2<f64>,
1780    residual: &Array2<f64>,
1781    sigma2: &Array1<f64>,
1782    nu: f64,
1783    lambda: f64,
1784    cache: &GaussianRemlEigenCache,
1785    grad_x: &mut Array2<f64>,
1786    grad_y: &mut Array2<f64>,
1787    grad_penalty: &mut Array2<f64>,
1788    grad_weights: &mut Array1<f64>,
1789) {
1790    let d = beta.ncols() as f64;
1791    let xp = dense_ab(x, inverse_hessian.view());
1792    let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(cache);
1793    for row in 0..grad_penalty.nrows() {
1794        for col in 0..grad_penalty.ncols() {
1795            grad_penalty[[row, col]] +=
1796                scale * 0.5 * d * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
1797        }
1798    }
1799    for i in 0..x.nrows() {
1800        let wi = weights[i] * scale * d;
1801        for k in 0..x.ncols() {
1802            grad_x[[i, k]] += wi * xp[[i, k]];
1803        }
1804        let mut leverage = 0.0;
1805        for k in 0..x.ncols() {
1806            leverage += x[[i, k]] * xp[[i, k]];
1807        }
1808        grad_weights[i] += scale * 0.5 * d * leverage;
1809    }
1810
1811    for j in 0..beta.ncols() {
1812        let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
1813        let coef = scale * 0.5 * nu / dp;
1814        add_deviance_profile_vjp(
1815            coef,
1816            j,
1817            x,
1818            weights,
1819            beta,
1820            residual,
1821            grad_x,
1822            grad_y,
1823            grad_weights,
1824        );
1825        add_rank_one_penalty_vjp(coef * lambda, beta.column(j), grad_penalty);
1826    }
1827}
1828
1829/// VJP contribution from an upstream gradient on `edf`.
1830///
1831/// With `M = X^T W X + λ S`, `edf = trace(M^{-1} · X^T W X) = p - λ trace(M^{-1} S)`.
1832/// Holding `λ` fixed, the direct partials are
1833///   ∂edf/∂A = λ M^{-1} S M^{-1}      (A = X^T W X, symmetric)
1834///   ∂edf/∂S = −λ M^{-1} A M^{-1} = −λ M^{-1} + λ² M^{-1} S M^{-1}
1835///   ∂edf/∂λ = −trace(M^{-1} S) + λ trace((M^{-1} S)²)
1836/// The λ-component is returned as the lambda_adjoint contribution and routed
1837/// through the implicit-function chain by the caller (same path as
1838/// `upstream_lambda` and `upstream_reml_score`).
1839fn add_edf_vjp(
1840    scale: f64,
1841    x: ArrayView2<'_, f64>,
1842    penalty: ArrayView2<'_, f64>,
1843    weights: &Array1<f64>,
1844    lambda: f64,
1845    inverse_hessian: &Array2<f64>,
1846    grad_x: &mut Array2<f64>,
1847    grad_penalty: &mut Array2<f64>,
1848    grad_weights: &mut Array1<f64>,
1849) -> f64 {
1850    // m_inv_s = M^{-1} S, then g_a = λ M^{-1} S M^{-1} = ∂edf/∂A.
1851    let m_inv_s = dense_ab(inverse_hessian.view(), penalty);
1852    let mut g_a = dense_ab(m_inv_s.view(), inverse_hessian.view());
1853    g_a.mapv_inplace(|v| v * lambda);
1854
1855    // Chain ∂edf/∂A through A = X^T W X.
1856    //   grad_X += scale · 2 · (W X) · G_A
1857    //   grad_w_i += scale · (X G_A X^T)_{ii}
1858    let xg = dense_ab(x, g_a.view());
1859    // Row-scaled dense accumulate: grad_x[i,:] += (2·scale·weights[i]) · xg[i,:].
1860    // (Inlined here — the former `assembly::add_row_scaled_dense_into` helper was
1861    // removed as "unused" by 0cb722d, which missed this gam-pyffi-reachable caller.)
1862    let leading_scale = 2.0 * scale;
1863    for i in 0..xg.nrows() {
1864        let row_scale = leading_scale * weights[i];
1865        for k in 0..xg.ncols() {
1866            grad_x[[i, k]] += row_scale * xg[[i, k]];
1867        }
1868    }
1869    for i in 0..x.nrows() {
1870        let mut quad = 0.0;
1871        for k in 0..x.ncols() {
1872            quad += x[[i, k]] * xg[[i, k]];
1873        }
1874        grad_weights[i] += scale * quad;
1875    }
1876
1877    // ∂edf/∂S = -λ M^{-1} + λ² M^{-1} S M^{-1} = -λ M^{-1} + λ · g_a
1878    // (since g_a = λ M^{-1} S M^{-1}, so λ · g_a = λ² M^{-1} S M^{-1}).
1879    for row in 0..grad_penalty.nrows() {
1880        for col in 0..grad_penalty.ncols() {
1881            grad_penalty[[row, col]] +=
1882                scale * (-lambda * inverse_hessian[[row, col]] + lambda * g_a[[row, col]]);
1883        }
1884    }
1885
1886    // ∂edf/∂λ (with A, S fixed) = -tr(M^{-1} S) + λ tr((M^{-1} S)²).
1887    let p_dim = m_inv_s.nrows();
1888    let mut tr_m_inv_s = 0.0;
1889    for i in 0..p_dim {
1890        tr_m_inv_s += m_inv_s[[i, i]];
1891    }
1892    let mut tr_squared = 0.0;
1893    for i in 0..p_dim {
1894        for j in 0..p_dim {
1895            tr_squared += m_inv_s[[i, j]] * m_inv_s[[j, i]];
1896        }
1897    }
1898    scale * (-tr_m_inv_s + lambda * tr_squared)
1899}
1900
1901fn add_reml_rho_gradient_vjp(
1902    scale: f64,
1903    x: ArrayView2<'_, f64>,
1904    y: ArrayView2<'_, f64>,
1905    penalty: ArrayView2<'_, f64>,
1906    weights: &Array1<f64>,
1907    lambda: f64,
1908    inverse_hessian: &Array2<f64>,
1909    beta: &Array2<f64>,
1910    residual: &Array2<f64>,
1911    sigma2: &Array1<f64>,
1912    nu: f64,
1913    grad_x: &mut Array2<f64>,
1914    grad_y: &mut Array2<f64>,
1915    grad_penalty: &mut Array2<f64>,
1916    grad_weights: &mut Array1<f64>,
1917) {
1918    let d = beta.ncols() as f64;
1919    let inverse_s = dense_ab(inverse_hessian.view(), penalty);
1920    let trace_kernel = dense_ab(inverse_s.view(), inverse_hessian.view());
1921    for row in 0..grad_penalty.nrows() {
1922        for col in 0..grad_penalty.ncols() {
1923            grad_penalty[[row, col]] += scale
1924                * 0.5
1925                * d
1926                * lambda
1927                * (inverse_hessian[[col, row]] - lambda * trace_kernel[[col, row]]);
1928        }
1929    }
1930    let xt = dense_ab(x, trace_kernel.view());
1931    for i in 0..x.nrows() {
1932        let wi = -scale * d * lambda * weights[i];
1933        for k in 0..x.ncols() {
1934            grad_x[[i, k]] += wi * xt[[i, k]];
1935        }
1936        let mut quad = 0.0;
1937        for k in 0..x.ncols() {
1938            quad += x[[i, k]] * xt[[i, k]];
1939        }
1940        grad_weights[i] -= scale * 0.5 * d * lambda * quad;
1941    }
1942
1943    let s_beta = dense_ab(penalty, beta.view());
1944    let mut upstream_beta = Array2::<f64>::zeros(beta.dim());
1945    for j in 0..beta.ncols() {
1946        let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
1947        let q = lambda * beta.column(j).dot(&s_beta.column(j));
1948        let q_coef = scale * nu / dp;
1949        for row in 0..beta.nrows() {
1950            upstream_beta[[row, j]] = q_coef * lambda * s_beta[[row, j]];
1951        }
1952        let dp_coef = -scale * 0.5 * nu * q / (dp * dp);
1953        add_rank_one_penalty_vjp(
1954            (0.5 * q_coef + dp_coef) * lambda,
1955            beta.column(j),
1956            grad_penalty,
1957        );
1958        add_deviance_profile_vjp(
1959            dp_coef,
1960            j,
1961            x,
1962            weights,
1963            beta,
1964            residual,
1965            grad_x,
1966            grad_y,
1967            grad_weights,
1968        );
1969    }
1970    // The implicit-root VJP holds lambda fixed inside this partial; only the
1971    // data, penalty, and weight side effects from the ridge solve are needed.
1972    add_ridge_profile_vjp_fixed_lambda(
1973        1.0,
1974        x,
1975        y,
1976        penalty,
1977        weights,
1978        lambda,
1979        inverse_hessian,
1980        beta,
1981        upstream_beta.view(),
1982        grad_x,
1983        grad_y,
1984        grad_penalty,
1985        grad_weights,
1986    );
1987}
1988
1989fn add_rank_one_penalty_vjp(
1990    scale: f64,
1991    beta_col: ArrayView1<'_, f64>,
1992    grad_penalty: &mut Array2<f64>,
1993) {
1994    for row in 0..beta_col.len() {
1995        for col in 0..beta_col.len() {
1996            grad_penalty[[row, col]] += scale * beta_col[row] * beta_col[col];
1997        }
1998    }
1999}
2000
2001fn gaussian_reml_penalty_pseudoinverse_from_cache(cache: &GaussianRemlEigenCache) -> Array2<f64> {
2002    let p = cache.penalty_eigenvalues.len();
2003    let mut scaled_basis = Array2::<f64>::zeros((p, p));
2004    for eig in 0..p {
2005        let delta = cache.penalty_eigenvalues[eig];
2006        if delta > 0.0 {
2007            for row in 0..p {
2008                scaled_basis[[row, eig]] = cache.coefficient_basis[[row, eig]] / delta;
2009            }
2010        }
2011    }
2012    dense_ab(scaled_basis.view(), cache.coefficient_basis.t())
2013}
2014
2015fn add_deviance_profile_vjp(
2016    scale: f64,
2017    output: usize,
2018    x: ArrayView2<'_, f64>,
2019    weights: &Array1<f64>,
2020    beta: &Array2<f64>,
2021    residual: &Array2<f64>,
2022    grad_x: &mut Array2<f64>,
2023    grad_y: &mut Array2<f64>,
2024    grad_weights: &mut Array1<f64>,
2025) {
2026    for i in 0..x.nrows() {
2027        let r = residual[[i, output]];
2028        let wr_scale = scale * weights[i] * r;
2029        grad_y[[i, output]] += 2.0 * wr_scale;
2030        for k in 0..x.ncols() {
2031            grad_x[[i, k]] -= 2.0 * wr_scale * beta[[k, output]];
2032        }
2033        grad_weights[i] += scale * r * r;
2034    }
2035}
2036
2037fn validate_initial_lambda(lambda: f64) -> Result<f64, EstimationError> {
2038    if lambda.is_finite() && lambda > 0.0 {
2039        Ok(lambda)
2040    } else {
2041        Err(EstimationError::InvalidInput(format!(
2042            "Gaussian REML initial lambda must be finite and positive; got {lambda}"
2043        )))
2044    }
2045}
2046
2047fn dense_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
2048    fast_ab(&a, &b)
2049}
2050
2051fn dense_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
2052    fast_atb(&a, &b)
2053}
2054
2055fn dense_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Array2<f64> {
2056    fast_xt_diag_x(&x, &w)
2057}
2058
2059fn dense_xt_diag_y(
2060    x: ArrayView2<'_, f64>,
2061    w: ArrayView1<'_, f64>,
2062    y: ArrayView2<'_, f64>,
2063) -> Array2<f64> {
2064    fast_xt_diag_y(&x, &w, &y)
2065}
2066
2067fn matrix_fingerprint(matrix: ArrayView2<'_, f64>) -> u64 {
2068    let mut hash = 0xcbf29ce484222325_u64;
2069    hash = fnv1a_mix(hash, matrix.nrows() as u64);
2070    hash = fnv1a_mix(hash, matrix.ncols() as u64);
2071    for &value in matrix {
2072        hash = fnv1a_mix(hash, value.to_bits());
2073    }
2074    hash
2075}
2076
/// One mixing step: XOR the running `hash` with a whole 64-bit `value`, then
/// multiply (wrapping) by the constant `0x100000001b3`.
fn fnv1a_mix(hash: u64, value: u64) -> u64 {
    const MULTIPLIER: u64 = 0x100000001b3;
    let mixed = hash ^ value;
    mixed.wrapping_mul(MULTIPLIER)
}
2080
2081/// Build eigen caches for K problems that share the same penalty matrix in a
2082/// single phased pipeline. X'WX construction is batched by the caller; each
2083/// cache then uses the same Cholesky/eigendecomposition implementation as the
2084/// single-fit path.
2085pub fn build_gaussian_reml_eigen_cache_batched(
2086    xtwx_matrices: Vec<Array2<f64>>,
2087    penalty: ArrayView2<'_, f64>,
2088    nullspace_dim: Option<usize>,
2089) -> Vec<Result<GaussianRemlEigenCache, EstimationError>> {
2090    let penalty_owned = canonicalize_penalty(penalty);
2091    let penalty = penalty_owned.view();
2092    let k = xtwx_matrices.len();
2093    if k == 0 {
2094        return Vec::new();
2095    }
2096    let fingerprints: Vec<u64> = xtwx_matrices
2097        .iter()
2098        .map(|m| matrix_fingerprint(m.view()))
2099        .collect();
2100
2101    let p = xtwx_matrices[0].nrows();
2102    let uniform_square = p > 0 && xtwx_matrices.iter().all(|matrix| matrix.dim() == (p, p));
2103    if uniform_square && k > 1 {
2104        let mut lower_matrices = xtwx_matrices.clone();
2105        if gam_gpu::try_cholesky_batched_lower_inplace(&mut lower_matrices).is_some() {
2106            // The batched penalty transform is an optional accelerator. On
2107            // failure we must NOT fabricate an empty Vec (indexing it per-block
2108            // would silently drop the transform for every block and could index
2109            // out of range) — instead route every block through the same
2110            // no-GPU-transform path used when the batched transform is
2111            // unavailable, which recomputes the whitened penalty on CPU from the
2112            // already-valid Cholesky factor `lower`.
2113            let transforms = batched_whitened_penalty_transforms(&lower_matrices, penalty);
2114            return lower_matrices
2115                .into_iter()
2116                .enumerate()
2117                .map(|(b, lower)| {
2118                    let precomputed_transform = transforms.as_ref().map(|t| t[b].clone());
2119                    gaussian_reml_eigen_cache_from_lower_with_transform(
2120                        lower,
2121                        penalty,
2122                        nullspace_dim,
2123                        fingerprints[b],
2124                        precomputed_transform,
2125                    )
2126                })
2127                .collect();
2128        }
2129    }
2130
2131    let mut results = Vec::with_capacity(k);
2132    for (b, xtwx) in xtwx_matrices.into_iter().enumerate() {
2133        let lower = match gaussian_reml_cholesky_lower(xtwx) {
2134            Ok(l) => l,
2135            Err(err) => {
2136                results.push(Err(err));
2137                continue;
2138            }
2139        };
2140        results.push(gaussian_reml_eigen_cache_from_lower_with_transform(
2141            lower,
2142            penalty,
2143            nullspace_dim,
2144            fingerprints[b],
2145            None,
2146        ));
2147    }
2148    results
2149}
2150
2151fn batched_whitened_penalty_transforms(
2152    lowers: &[Array2<f64>],
2153    penalty: ArrayView2<'_, f64>,
2154) -> Option<Vec<Array2<f64>>> {
2155    let first = lowers.first()?;
2156    let p = first.nrows();
2157    if p == 0 || first.ncols() != p || lowers.iter().any(|lower| lower.dim() != (p, p)) {
2158        return None;
2159    }
2160    let mut linv_stack = Array3::<f64>::zeros((lowers.len(), p, p));
2161    for (idx, lower) in lowers.iter().enumerate() {
2162        let l_inv = invert_lower_triangular(lower).ok()?;
2163        linv_stack.slice_mut(s![idx, .., ..]).assign(&l_inv);
2164    }
2165    let penalty_in_metric =
2166        gam_gpu::try_fast_ab_broadcast_b_batched(linv_stack.view(), penalty)?;
2167    let transformed =
2168        gam_gpu::try_fast_abt_strided_batched(penalty_in_metric.view(), linv_stack.view())?;
2169    Some(
2170        transformed
2171            .axis_iter(Axis(0))
2172            .map(|matrix| matrix.to_owned())
2173            .collect(),
2174    )
2175}
2176
2177pub fn build_gaussian_reml_eigen_cache(
2178    x: ArrayView2<'_, f64>,
2179    penalty: ArrayView2<'_, f64>,
2180    weights: Option<ArrayView1<'_, f64>>,
2181) -> Result<GaussianRemlEigenCache, EstimationError> {
2182    build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, weights)
2183}
2184
2185pub fn build_gaussian_reml_eigen_cache_with_nullspace_dim(
2186    x: ArrayView2<'_, f64>,
2187    penalty: ArrayView2<'_, f64>,
2188    nullspace_dim: Option<usize>,
2189    weights: Option<ArrayView1<'_, f64>>,
2190) -> Result<GaussianRemlEigenCache, EstimationError> {
2191    let penalty_owned = canonicalize_penalty(penalty);
2192    let penalty = penalty_owned.view();
2193    let n = x.nrows();
2194    validate_gaussian_reml_design(x, penalty, weights)?;
2195    let weight = gaussian_reml_weights(n, weights)?;
2196
2197    let xtwx = dense_xt_diag_x(x, weight.view());
2198    gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)
2199}
2200
2201fn validate_gaussian_reml_design(
2202    x: ArrayView2<'_, f64>,
2203    penalty: ArrayView2<'_, f64>,
2204    weights: Option<ArrayView1<'_, f64>>,
2205) -> Result<(), EstimationError> {
2206    let n = x.nrows();
2207    let p = x.ncols();
2208    if penalty.nrows() != p || penalty.ncols() != p {
2209        crate::bail_invalid_estim!(
2210            "Gaussian REML penalty shape mismatch: expected {p}x{p}, got {}x{}",
2211            penalty.nrows(),
2212            penalty.ncols()
2213        );
2214    }
2215    if x.iter().chain(penalty.iter()).any(|v| !v.is_finite()) {
2216        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2217    }
2218    if let Some(w) = weights {
2219        if w.len() != n {
2220            crate::bail_invalid_estim!(
2221                "Gaussian REML weights length mismatch: expected {n}, got {}",
2222                w.len()
2223            );
2224        }
2225        if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2226            crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2227        }
2228    }
2229    Ok(())
2230}
2231
2232fn gaussian_reml_weights(
2233    n: usize,
2234    weights: Option<ArrayView1<'_, f64>>,
2235) -> Result<Array1<f64>, EstimationError> {
2236    match weights {
2237        Some(w) => {
2238            if w.len() != n {
2239                crate::bail_invalid_estim!(
2240                    "Gaussian REML weights length mismatch: expected {n}, got {}",
2241                    w.len()
2242                );
2243            }
2244            if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2245                crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2246            }
2247            Ok(w.to_owned())
2248        }
2249        None => Ok(Array1::ones(n)),
2250    }
2251}
2252
2253fn gaussian_reml_eigen_cache_from_xtwx(
2254    xtwx: Array2<f64>,
2255    penalty: ArrayView2<'_, f64>,
2256    nullspace_dim: Option<usize>,
2257) -> Result<GaussianRemlEigenCache, EstimationError> {
2258    let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2259    let lower = gaussian_reml_cholesky_lower(xtwx)?;
2260    gaussian_reml_eigen_cache_from_lower(lower, penalty, nullspace_dim, xtwx_fingerprint)
2261}
2262
2263/// Cache-build entry point for callers that have already computed `L =
2264/// chol(X'WX, lower)`. Used by the batched K-way fit path so a single
2265/// `cusolverDnDpotrfBatched` call factors all K matrices, then each cache
2266/// finishes per-fit without re-doing the Cholesky.
2267fn gaussian_reml_eigen_cache_from_lower(
2268    lower: Array2<f64>,
2269    penalty: ArrayView2<'_, f64>,
2270    nullspace_dim: Option<usize>,
2271    xtwx_fingerprint: u64,
2272) -> Result<GaussianRemlEigenCache, EstimationError> {
2273    gaussian_reml_eigen_cache_from_lower_with_transform(
2274        lower,
2275        penalty,
2276        nullspace_dim,
2277        xtwx_fingerprint,
2278        None,
2279    )
2280}
2281
2282/// Cache-build variant that accepts a pre-computed whitened penalty
2283/// `L⁻¹·S·L⁻ᵀ`. Callers pass `None` to compute it from the Cholesky factor.
2284fn gaussian_reml_eigen_cache_from_lower_with_transform(
2285    lower: Array2<f64>,
2286    penalty: ArrayView2<'_, f64>,
2287    nullspace_dim: Option<usize>,
2288    xtwx_fingerprint: u64,
2289    precomputed_transform: Option<Array2<f64>>,
2290) -> Result<GaussianRemlEigenCache, EstimationError> {
2291    let p = lower.nrows();
2292    if lower.ncols() != p {
2293        crate::bail_invalid_estim!("Gaussian REML Cholesky factor must be square");
2294    }
2295    let penalty_fingerprint = matrix_fingerprint(penalty);
2296    let logdet_xtwx = 2.0 * lower.diag().iter().map(|v| v.ln()).sum::<f64>();
2297    let transformed_penalty = match precomputed_transform {
2298        Some(transformed) => transformed,
2299        None => {
2300            let l_inv = invert_lower_triangular(&lower)?;
2301            let penalty_in_metric = dense_ab(l_inv.view(), penalty);
2302            dense_ab(penalty_in_metric.view(), l_inv.t())
2303        }
2304    };
2305    let (mut penalty_eigenvalues, eigenvectors) =
2306        transformed_penalty.eigh(Side::Lower).map_err(|_| {
2307            EstimationError::ModelIsIllConditioned {
2308                condition_number: f64::INFINITY,
2309            }
2310        })?;
2311    // Rank tolerance must be RELATIVE to the largest eigenvalue — never
2312    // floored at an absolute value. The old `.max(1.0)` clamped the
2313    // tolerance up whenever max|eig| < 1, classifying genuine modes as
2314    // null for small-scale penalties (e.g. Wahba pseudo-spline `m=4`
2315    // with `K(p,p) ≈ 3e-4`). That broke REML's invariance under
2316    // `S → c·S` — the optimum λ rescales but the score landscape
2317    // diverges from the true marginal likelihood, and the smooth
2318    // contribution collapsed to ~0 on smooth truths.
2319    // Fully scale-invariant form: `safety · max|eig| · eps`.
2320    let max_abs_eig = penalty_eigenvalues
2321        .iter()
2322        .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2323    let eig_tol = max_abs_eig * EIGEN_REL_TOL;
2324    for value in &mut penalty_eigenvalues {
2325        if *value < 0.0 && value.abs() <= eig_tol {
2326            *value = 0.0;
2327        }
2328        if *value < 0.0 {
2329            crate::bail_invalid_estim!(
2330                "Gaussian REML penalty is not positive semidefinite; eigenvalue={value:.3e}"
2331            );
2332        }
2333    }
2334    let penalty_rank = penalty_eigenvalues
2335        .iter()
2336        .filter(|&&value| value > eig_tol)
2337        .count();
2338    let nullity = p - penalty_rank;
2339    if let Some(expected_nullity) = nullspace_dim
2340        && expected_nullity != nullity
2341    {
2342        crate::bail_invalid_estim!(
2343            "Gaussian REML penalty nullspace mismatch: expected {expected_nullity}, inferred {nullity}"
2344        );
2345    }
2346    let logdet_penalty_positive = gaussian_penalty_positive_logdet(penalty, penalty_rank)?;
2347    let coefficient_basis = solve_upper_triangular_matrix(&lower.t().to_owned(), &eigenvectors)?;
2348
2349    Ok(GaussianRemlEigenCache {
2350        penalty_eigenvalues,
2351        eigenvectors,
2352        coefficient_basis,
2353        xtwx_fingerprint,
2354        penalty_fingerprint,
2355        logdet_xtwx,
2356        logdet_penalty_positive,
2357        penalty_rank,
2358        nullity,
2359    })
2360}
2361
2362fn gaussian_reml_cholesky_lower(xtwx: Array2<f64>) -> Result<Array2<f64>, EstimationError> {
2363    // Attempt Cholesky directly; on failure, retry with a tiny diagonal jitter
2364    // proportional to the matrix trace. X'WX is symmetric positive semidefinite
2365    // by construction, but FP noise (e.g. in a basis whose kernel block is only
2366    // FP-orthogonal to its explicit polynomial nullspace columns, as the
2367    // periodic Duchon basis is) can push the smallest eigenvalue slightly
2368    // negative on adversarial inputs, intermittently failing Cholesky. A
2369    // jitter of 1e-12 * trace/p shifts every eigenvalue up by an amount well
2370    // below the natural scale of the well-conditioned eigenvalues but well
2371    // above f64 FP noise, eliminating the spurious-failure regime.
2372    let mut gpu_candidate = xtwx.clone();
2373    if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2374        return Ok(gpu_candidate);
2375    }
2376    if let Ok(chol) = xtwx.cholesky(Side::Lower) {
2377        return Ok(chol.lower_triangular());
2378    }
2379    let p = xtwx.nrows();
2380    let trace: f64 = (0..p).map(|i| xtwx[[i, i]]).sum();
2381    if !trace.is_finite() || trace <= 0.0 {
2382        return Err(EstimationError::ModelIsIllConditioned {
2383            condition_number: f64::INFINITY,
2384        });
2385    }
2386    let mut jitter = 1e-12 * trace / (p as f64);
2387    for _ in 0..6 {
2388        let mut jittered = xtwx.clone();
2389        for i in 0..p {
2390            jittered[[i, i]] += jitter;
2391        }
2392        let mut gpu_candidate = jittered.clone();
2393        if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2394            return Ok(gpu_candidate);
2395        }
2396        if let Ok(chol) = jittered.cholesky(Side::Lower) {
2397            return Ok(chol.lower_triangular());
2398        }
2399        jitter *= 10.0;
2400    }
2401    Err(EstimationError::ModelIsIllConditioned {
2402        condition_number: f64::INFINITY,
2403    })
2404}
2405
2406fn gaussian_penalty_positive_logdet(
2407    penalty: ArrayView2<'_, f64>,
2408    penalty_rank: usize,
2409) -> Result<f64, EstimationError> {
2410    if penalty_rank == 0 {
2411        return Ok(0.0);
2412    }
2413    let (pen_eigs, _) = penalty.to_owned().eigh(Side::Lower).map_err(|_| {
2414        EstimationError::ModelIsIllConditioned {
2415            condition_number: f64::INFINITY,
2416        }
2417    })?;
2418    // Scale-invariant relative tolerance — see the cousin site for the
2419    // rationale. Same `.max(1.0)` floor used to live here and corrupted
2420    // the positive-eigenvalue count for small-scale penalties.
2421    let pen_scale = pen_eigs
2422        .iter()
2423        .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2424    let pen_tol = pen_scale * EIGEN_REL_TOL;
2425    let mut positive_eigs: Vec<f64> = pen_eigs
2426        .iter()
2427        .copied()
2428        .filter(|&value| value > pen_tol)
2429        .collect();
2430    if positive_eigs.len() != penalty_rank {
2431        positive_eigs = pen_eigs
2432            .iter()
2433            .copied()
2434            .filter(|&value| value > 0.0)
2435            .collect();
2436        positive_eigs.sort_by(|a, b| b.total_cmp(a));
2437        if positive_eigs.len() < penalty_rank {
2438            return Err(EstimationError::ModelIsIllConditioned {
2439                condition_number: f64::INFINITY,
2440            });
2441        }
2442        positive_eigs.truncate(penalty_rank);
2443    }
2444    Ok(positive_eigs.iter().map(|value| value.ln()).sum())
2445}
2446
2447fn validate_gaussian_reml_eigen_cache(
2448    cache: &GaussianRemlEigenCache,
2449    p: usize,
2450) -> Result<(), EstimationError> {
2451    if cache.penalty_eigenvalues.len() != p
2452        || cache.eigenvectors.dim() != (p, p)
2453        || cache.coefficient_basis.dim() != (p, p)
2454    {
2455        crate::bail_invalid_estim!(
2456            "Gaussian REML eigen cache dimension mismatch: expected {p} coefficients"
2457        );
2458    }
2459    if cache.penalty_rank > p || cache.nullity > p || cache.penalty_rank + cache.nullity != p {
2460        crate::bail_invalid_estim!(
2461            "Gaussian REML eigen cache rank/nullity mismatch: rank={}, nullity={}, p={p}",
2462            cache.penalty_rank,
2463            cache.nullity
2464        );
2465    }
2466    if !(cache.logdet_xtwx.is_finite() && cache.logdet_penalty_positive.is_finite()) {
2467        crate::bail_invalid_estim!("Gaussian REML eigen cache log-determinants must be finite");
2468    }
2469    if cache
2470        .penalty_eigenvalues
2471        .iter()
2472        .any(|value| !value.is_finite() || *value < 0.0)
2473        || cache.eigenvectors.iter().any(|value| !value.is_finite())
2474        || cache
2475            .coefficient_basis
2476            .iter()
2477            .any(|value| !value.is_finite())
2478    {
2479        crate::bail_invalid_estim!(
2480            "Gaussian REML eigen cache entries must be finite with non-negative eigenvalues"
2481                .to_string(),
2482        );
2483    }
2484    Ok::<(), _>(())
2485}
2486
2487fn prepare_gaussian_reml(
2488    x: ArrayView2<'_, f64>,
2489    y: ArrayView2<'_, f64>,
2490    penalty: ArrayView2<'_, f64>,
2491    nullspace_dim: Option<usize>,
2492    weights: Option<ArrayView1<'_, f64>>,
2493    eigen_cache: Option<&GaussianRemlEigenCache>,
2494) -> Result<GaussianRemlPrepared, EstimationError> {
2495    // Enforce the symmetric-S contract once at the central forward chokepoint;
2496    // every closed-form forward path funnels through here.
2497    let penalty_owned = canonicalize_penalty(penalty);
2498    let penalty = penalty_owned.view();
2499    let n = x.nrows();
2500    let p = x.ncols();
2501    let d = y.ncols();
2502    validate_gaussian_reml_design(x, penalty, weights)?;
2503    if y.nrows() != n {
2504        crate::bail_invalid_estim!(
2505            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
2506            y.nrows()
2507        );
2508    }
2509    if y.iter().any(|v| !v.is_finite()) {
2510        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2511    }
2512    let weight = gaussian_reml_weights(n, weights)?;
2513
2514    let xtwy = dense_xt_diag_y(x, weight.view(), y);
2515    let ywy = Array1::from_iter((0..d).map(|j| {
2516        let mut value = 0.0;
2517        for row in 0..n {
2518            value += weight[row] * y[[row, j]] * y[[row, j]];
2519        }
2520        value
2521    }));
2522    let xtwx = dense_xt_diag_x(x, weight.view());
2523
2524    if let Some(cache) = eigen_cache {
2525        validate_gaussian_reml_eigen_cache(cache, p)?;
2526        let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2527        if cache.xtwx_fingerprint != xtwx_fingerprint {
2528            crate::bail_invalid_estim!("Gaussian REML eigen cache X'WX mismatch");
2529        }
2530        let penalty_fingerprint = matrix_fingerprint(penalty);
2531        if cache.penalty_fingerprint != penalty_fingerprint {
2532            crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
2533        }
2534        if let Some(expected_nullity) = nullspace_dim
2535            && expected_nullity != cache.nullity
2536        {
2537            crate::bail_invalid_estim!(
2538                "Gaussian REML eigen cache nullspace mismatch: expected {expected_nullity}, got {}",
2539                cache.nullity
2540            );
2541        }
2542        if n <= cache.nullity {
2543            crate::bail_invalid_estim!(
2544                "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
2545                cache.nullity
2546            );
2547        }
2548        let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2549        let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2550        return Ok(GaussianRemlPrepared {
2551            cache: cache.clone(),
2552            ywy,
2553            projected_rhs_squared,
2554            projected_rhs,
2555            n_observations: n,
2556            n_outputs: d,
2557        });
2558    }
2559
2560    let cache = gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)?;
2561    if n <= cache.nullity {
2562        crate::bail_invalid_estim!(
2563            "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
2564            cache.nullity
2565        );
2566    }
2567    let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2568    let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2569
2570    Ok(GaussianRemlPrepared {
2571        cache,
2572        ywy,
2573        projected_rhs_squared,
2574        projected_rhs,
2575        n_observations: n,
2576        n_outputs: d,
2577    })
2578}
2579
2580impl GaussianRemlPrepared {
2581    fn nu(&self) -> f64 {
2582        self.n_observations as f64 - self.cache.nullity as f64
2583    }
2584
2585    fn evaluate(&self, rho: f64) -> ObjectiveEval {
2586        evaluate_reml_parts(
2587            &self.cache,
2588            self.ywy.view(),
2589            self.projected_rhs_squared.view(),
2590            self.n_observations,
2591            self.n_outputs,
2592            rho,
2593        )
2594    }
2595
2596    fn coefficients(&self, lambda: f64) -> Array2<f64> {
2597        let mut scaled = self.projected_rhs.clone();
2598        for i in 0..self.cache.penalty_eigenvalues.len() {
2599            let scale = 1.0 / (1.0 + lambda * self.cache.penalty_eigenvalues[i]);
2600            for value in scaled.row_mut(i) {
2601                *value *= scale;
2602            }
2603        }
2604        dense_ab(self.cache.coefficient_basis.view(), scaled.view())
2605    }
2606
2607    fn sigma2(&self, lambda: f64) -> Array1<f64> {
2608        let nu = self.nu();
2609        Array1::from_iter((0..self.n_outputs).map(|j| {
2610            let mut fitted_quadratic = 0.0;
2611            for i in 0..self.cache.penalty_eigenvalues.len() {
2612                let denom = 1.0 + lambda * self.cache.penalty_eigenvalues[i];
2613                fitted_quadratic += self.projected_rhs_squared[[i, j]] / denom;
2614            }
2615            ((self.ywy[j] - fitted_quadratic).max(MIN_DEVIANCE)) / nu
2616        }))
2617    }
2618}
2619
2620fn optimize_rho(
2621    prepared: &GaussianRemlPrepared,
2622    init_rho: Option<f64>,
2623) -> Result<f64, EstimationError> {
2624    if prepared.cache.penalty_rank == 0 {
2625        return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2626    }
2627
2628    const GRID_INTERVALS: usize = 96;
2629    let mut stationary = Vec::<f64>::new();
2630    let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2631    let mut prev_rho = RHO_LOWER;
2632    let mut prev_eval = prepared.evaluate(prev_rho);
2633    grid.push((prev_rho, prev_eval.cost));
2634    for i in 1..=GRID_INTERVALS {
2635        let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2636        let eval = prepared.evaluate(rho);
2637        grid.push((rho, eval.cost));
2638        if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2639            push_candidate(
2640                &mut stationary,
2641                refine_stationary_rho(prepared, prev_rho, rho, 0.5 * (prev_rho + rho)),
2642            );
2643        }
2644        prev_rho = rho;
2645        prev_eval = eval;
2646    }
2647
2648    let mut candidates = stationary;
2649    push_candidate(&mut candidates, RHO_LOWER);
2650    push_candidate(&mut candidates, RHO_UPPER);
2651    if let Some(rho0) = init_rho {
2652        push_candidate(&mut candidates, rho0);
2653    }
2654    if let Some(rho) = refine_best_grid_cell(prepared, &grid) {
2655        push_candidate(&mut candidates, rho);
2656    }
2657
2658    // Evaluate each candidate exactly once. `min_by` over a comparator that
2659    // re-evaluates would do O(n log n) extra `prepared.evaluate` calls during
2660    // the sort.
2661    candidates
2662        .into_iter()
2663        .map(|rho| (rho, prepared.evaluate(rho).cost))
2664        .min_by(|(_, a), (_, b)| a.total_cmp(b))
2665        .map(|(rho, _)| rho)
2666        .ok_or_else(|| {
2667            EstimationError::InvalidInput(
2668                "Gaussian REML optimizer produced no candidates".to_string(),
2669            )
2670        })
2671}
2672
2673fn refine_best_grid_cell(prepared: &GaussianRemlPrepared, grid: &[(f64, f64)]) -> Option<f64> {
2674    let best_idx = grid
2675        .iter()
2676        .enumerate()
2677        .filter(|(_, (_, cost))| cost.is_finite())
2678        .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2679        .map(|(idx, _)| idx)?;
2680    if best_idx == 0 || best_idx + 1 == grid.len() {
2681        return Some(grid[best_idx].0);
2682    }
2683    // The best interior grid cell brackets a genuine REML minimum (its cost is
2684    // below both neighbours), so the objective gradient changes sign across
2685    // `[grid[i-1], grid[i+1]]`. Refine to that stationary point (∂V/∂ρ = 0)
2686    // rather than minimising the cost with a golden section: the cost-based
2687    // search only locates ρ to ~√ε of the cell (~1e-8), whereas the
2688    // grad-sign-change branch already contributes stationary candidates
2689    // converged to GRAD_TOL (~1e-12). When both target the same minimum, the
2690    // ~1e-16 cost ordering between a 1e-8-accurate and a 1e-12-accurate ρ is
2691    // numerical noise, so `min_by(cost)` used to pick between two ρ values
2692    // ~1e-8 apart essentially at random — making the selected λ̂ a
2693    // non-smooth function of the design X (its ~1e-8 jumps wrecked the
2694    // closed-form REML reverse-mode VJP's agreement with finite differences).
2695    // Returning the stationary point makes every interior candidate a
2696    // GRAD_TOL-accurate root, so the residual selection jitter collapses to
2697    // ~1e-12 and λ̂(X) is smooth to the IFT gradient.
2698    Some(refine_stationary_rho(
2699        prepared,
2700        grid[best_idx - 1].0,
2701        grid[best_idx + 1].0,
2702        grid[best_idx].0,
2703    ))
2704}
2705
2706fn fill_weighted_rhs_no_alloc(
2707    x: ArrayView2<'_, f64>,
2708    y: ArrayView2<'_, f64>,
2709    weights: Option<ArrayView1<'_, f64>>,
2710    workspace: &mut GaussianRemlNoAllocWorkspace,
2711) -> Result<(), EstimationError> {
2712    let d = y.ncols();
2713
2714    // XᵀWY and YᵀWY via faer BLAS. Both `fast_xt_diag_y` and `fast_atb`
2715    // dispatch to faer's SIMD-optimized GEMM (with chunked weight scaling
2716    // when weights are present), replacing the previous scalar triple loop
2717    // over (n, p, d). For YᵀWY we only need the diagonal entries, but d is
2718    // small (typically 1–10) so computing the full d×d Gram is negligible.
2719    let (xtwy, ywy_full) = match weights {
2720        Some(w) => (fast_xt_diag_y(&x, &w, &y), fast_xt_diag_y(&y, &w, &y)),
2721        None => (fast_atb(&x, &y), fast_atb(&y, &y)),
2722    };
2723    workspace.xtwy.assign(&xtwy);
2724    for output in 0..d {
2725        workspace.ywy[output] = ywy_full[[output, output]];
2726    }
2727
2728    if workspace
2729        .xtwy
2730        .iter()
2731        .chain(workspace.ywy.iter())
2732        .any(|value| !value.is_finite())
2733    {
2734        crate::bail_invalid_estim!("Gaussian REML weighted cross-products must be finite");
2735    }
2736    Ok(())
2737}
2738
2739fn project_rhs_no_alloc(
2740    cache: &GaussianRemlEigenCache,
2741    workspace: &mut GaussianRemlNoAllocWorkspace,
2742) {
2743    // projected_rhs = coefficient_basisᵀ · xtwy, computed via faer BLAS
2744    // (was previously a scalar triple loop over (p, d, p)).
2745    let projected = fast_atb(&cache.coefficient_basis, &workspace.xtwy);
2746    workspace.projected_rhs.assign(&projected);
2747    let p = cache.penalty_eigenvalues.len();
2748    let d = workspace.ywy.len();
2749    for eig in 0..p {
2750        for output in 0..d {
2751            let value = workspace.projected_rhs[[eig, output]];
2752            workspace.projected_rhs_squared[[eig, output]] = value * value;
2753        }
2754    }
2755}
2756
2757fn evaluate_reml_parts(
2758    cache: &GaussianRemlEigenCache,
2759    ywy: ArrayView1<'_, f64>,
2760    projected_rhs_squared: ArrayView2<'_, f64>,
2761    n_observations: usize,
2762    n_outputs: usize,
2763    rho: f64,
2764) -> ObjectiveEval {
2765    let lambda = rho.exp();
2766    let nu = n_observations as f64 - cache.nullity as f64;
2767    let d = n_outputs as f64;
2768
2769    // Each term's value and its ρ-derivatives come back from ONE function so
2770    // they cannot be edited independently; `+=` folds the triple in lock-step.
2771    let (logdet_term, edf) = gaussian_reml_logdet_term(cache, rho, d);
2772    let mut eval = ObjectiveEval {
2773        cost: 0.0,
2774        grad: 0.0,
2775        hess: 0.0,
2776        edf,
2777    };
2778    eval += logdet_term;
2779    for output in 0..n_outputs {
2780        eval +=
2781            gaussian_reml_dispersion_term(cache, ywy, projected_rhs_squared, output, nu, lambda);
2782    }
2783    eval
2784}
2785
2786fn optimize_rho_no_alloc(
2787    cache: &GaussianRemlEigenCache,
2788    ywy: ArrayView1<'_, f64>,
2789    projected_rhs_squared: ArrayView2<'_, f64>,
2790    n_observations: usize,
2791    n_outputs: usize,
2792    init_rho: Option<f64>,
2793) -> Result<f64, EstimationError> {
2794    if cache.penalty_rank == 0 {
2795        return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2796    }
2797
2798    let lower_eval = evaluate_reml_parts(
2799        cache,
2800        ywy,
2801        projected_rhs_squared,
2802        n_observations,
2803        n_outputs,
2804        RHO_LOWER,
2805    );
2806
2807    let mut best_rho = RHO_LOWER;
2808    let mut best_cost = lower_eval.cost;
2809
2810    const GRID_INTERVALS: usize = 96;
2811    let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2812    let mut prev_rho = RHO_LOWER;
2813    let mut prev_eval = lower_eval;
2814    grid.push((prev_rho, prev_eval.cost));
2815    for i in 1..=GRID_INTERVALS {
2816        let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2817        let eval = evaluate_reml_parts(
2818            cache,
2819            ywy,
2820            projected_rhs_squared,
2821            n_observations,
2822            n_outputs,
2823            rho,
2824        );
2825        grid.push((rho, eval.cost));
2826        if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2827            let stationary_rho = refine_stationary_rho_no_alloc(
2828                cache,
2829                ywy,
2830                projected_rhs_squared,
2831                n_observations,
2832                n_outputs,
2833                prev_rho,
2834                rho,
2835                0.5 * (prev_rho + rho),
2836            );
2837            consider_rho_no_alloc(
2838                cache,
2839                ywy,
2840                projected_rhs_squared,
2841                n_observations,
2842                n_outputs,
2843                stationary_rho,
2844                &mut best_rho,
2845                &mut best_cost,
2846            );
2847        }
2848        prev_rho = rho;
2849        prev_eval = eval;
2850    }
2851    if let Some(best_idx) = grid
2852        .iter()
2853        .enumerate()
2854        .filter(|(_, (_, cost))| cost.is_finite())
2855        .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2856        .map(|(idx, _)| idx)
2857    {
2858        let refined = if best_idx == 0 || best_idx + 1 == grid.len() {
2859            grid[best_idx].0
2860        } else {
2861            // Refine the best interior grid cell to the REML stationary point
2862            // (∂V/∂ρ = 0) rather than the golden-section cost minimum, mirroring
2863            // the allocating `refine_best_grid_cell`. A cost-based search locates
2864            // ρ only to ~1e-8, which competed against the GRAD_TOL-accurate
2865            // (~1e-12) stationary candidates in the cost `min_by` below and made
2866            // the selected λ̂ jump ~1e-8 with the design — a non-smoothness the
2867            // closed-form REML VJP could not match under finite differences.
2868            // (Keeping both optimizers' refinement identical preserves their
2869            // allocating/no-alloc bit-for-bit parity.)
2870            refine_stationary_rho_no_alloc(
2871                cache,
2872                ywy,
2873                projected_rhs_squared,
2874                n_observations,
2875                n_outputs,
2876                grid[best_idx - 1].0,
2877                grid[best_idx + 1].0,
2878                grid[best_idx].0,
2879            )
2880        };
2881        consider_rho_no_alloc(
2882            cache,
2883            ywy,
2884            projected_rhs_squared,
2885            n_observations,
2886            n_outputs,
2887            refined,
2888            &mut best_rho,
2889            &mut best_cost,
2890        );
2891    }
2892
2893    consider_rho_no_alloc(
2894        cache,
2895        ywy,
2896        projected_rhs_squared,
2897        n_observations,
2898        n_outputs,
2899        RHO_UPPER,
2900        &mut best_rho,
2901        &mut best_cost,
2902    );
2903    if let Some(rho0) = init_rho {
2904        consider_rho_no_alloc(
2905            cache,
2906            ywy,
2907            projected_rhs_squared,
2908            n_observations,
2909            n_outputs,
2910            rho0,
2911            &mut best_rho,
2912            &mut best_cost,
2913        );
2914    }
2915
2916    if best_cost.is_finite() {
2917        Ok(best_rho)
2918    } else {
2919        Err(EstimationError::InvalidInput(
2920            "Gaussian REML optimizer produced no finite candidates".to_string(),
2921        ))
2922    }
2923}
2924
2925fn consider_rho_no_alloc(
2926    cache: &GaussianRemlEigenCache,
2927    ywy: ArrayView1<'_, f64>,
2928    projected_rhs_squared: ArrayView2<'_, f64>,
2929    n_observations: usize,
2930    n_outputs: usize,
2931    rho: f64,
2932    best_rho: &mut f64,
2933    best_cost: &mut f64,
2934) {
2935    if !rho.is_finite() {
2936        return;
2937    }
2938    let candidate = rho.clamp(RHO_LOWER, RHO_UPPER);
2939    let eval = evaluate_reml_parts(
2940        cache,
2941        ywy,
2942        projected_rhs_squared,
2943        n_observations,
2944        n_outputs,
2945        candidate,
2946    );
2947    if eval.cost < *best_cost {
2948        *best_rho = candidate;
2949        *best_cost = eval.cost;
2950    }
2951}
2952
2953fn refine_stationary_rho_no_alloc(
2954    cache: &GaussianRemlEigenCache,
2955    ywy: ArrayView1<'_, f64>,
2956    projected_rhs_squared: ArrayView2<'_, f64>,
2957    n_observations: usize,
2958    n_outputs: usize,
2959    mut lo: f64,
2960    mut hi: f64,
2961    mut rho: f64,
2962) -> f64 {
2963    for _ in 0..80 {
2964        let eval = evaluate_reml_parts(
2965            cache,
2966            ywy,
2967            projected_rhs_squared,
2968            n_observations,
2969            n_outputs,
2970            rho,
2971        );
2972        if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
2973            return rho;
2974        }
2975        if eval.grad >= 0.0 {
2976            hi = rho;
2977        } else {
2978            lo = rho;
2979        }
2980        let newton = if eval.hess > 0.0 {
2981            let candidate = rho - eval.grad / eval.hess;
2982            (candidate > lo && candidate < hi).then_some(candidate)
2983        } else {
2984            None
2985        };
2986        if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
2987            break;
2988        }
2989        rho = newton.unwrap_or(0.5 * (lo + hi));
2990    }
2991    0.5 * (lo + hi)
2992}
2993
2994fn fill_coefficients_no_alloc(
2995    cache: &GaussianRemlEigenCache,
2996    workspace: &mut GaussianRemlNoAllocWorkspace,
2997    lambda: f64,
2998    mut coefficients: ArrayViewMut2<'_, f64>,
2999) {
3000    let p = cache.penalty_eigenvalues.len();
3001    let d = workspace.ywy.len();
3002    for eig in 0..p {
3003        let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
3004        for output in 0..d {
3005            workspace.scaled_projected_rhs[[eig, output]] =
3006                workspace.projected_rhs[[eig, output]] * scale;
3007        }
3008    }
3009
3010    for col in 0..p {
3011        for output in 0..d {
3012            let mut value = 0.0;
3013            for eig in 0..p {
3014                value += cache.coefficient_basis[[col, eig]]
3015                    * workspace.scaled_projected_rhs[[eig, output]];
3016            }
3017            coefficients[[col, output]] = value;
3018        }
3019    }
3020}
3021
3022fn fill_fitted_no_alloc(
3023    x: ArrayView2<'_, f64>,
3024    coefficients: ArrayView2<'_, f64>,
3025    mut fitted: ArrayViewMut2<'_, f64>,
3026) {
3027    let n = x.nrows();
3028    let p = x.ncols();
3029    let d = coefficients.ncols();
3030    for row in 0..n {
3031        for output in 0..d {
3032            let mut value = 0.0;
3033            for col in 0..p {
3034                value += x[[row, col]] * coefficients[[col, output]];
3035            }
3036            fitted[[row, output]] = value;
3037        }
3038    }
3039}
3040
3041fn fill_sigma2_no_alloc(
3042    cache: &GaussianRemlEigenCache,
3043    ywy: ArrayView1<'_, f64>,
3044    projected_rhs_squared: ArrayView2<'_, f64>,
3045    n_observations: usize,
3046    n_outputs: usize,
3047    lambda: f64,
3048    mut sigma2: ArrayViewMut1<'_, f64>,
3049) {
3050    let nu = n_observations as f64 - cache.nullity as f64;
3051    for output in 0..n_outputs {
3052        let mut fitted_quadratic = 0.0;
3053        for eig in 0..cache.penalty_eigenvalues.len() {
3054            let denom = 1.0 + lambda * cache.penalty_eigenvalues[eig];
3055            fitted_quadratic += projected_rhs_squared[[eig, output]] / denom;
3056        }
3057        sigma2[output] = ((ywy[output] - fitted_quadratic).max(MIN_DEVIANCE)) / nu;
3058    }
3059}
3060
3061fn push_candidate(candidates: &mut Vec<f64>, rho: f64) {
3062    if rho.is_finite() {
3063        candidates.push(rho.clamp(RHO_LOWER, RHO_UPPER));
3064    }
3065}
3066
3067fn refine_stationary_rho(
3068    prepared: &GaussianRemlPrepared,
3069    mut lo: f64,
3070    mut hi: f64,
3071    mut rho: f64,
3072) -> f64 {
3073    for _ in 0..80 {
3074        let eval = prepared.evaluate(rho);
3075        if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
3076            return rho;
3077        }
3078        if eval.grad >= 0.0 {
3079            hi = rho;
3080        } else {
3081            lo = rho;
3082        }
3083        let newton = if eval.hess > 0.0 {
3084            let candidate = rho - eval.grad / eval.hess;
3085            (candidate > lo && candidate < hi).then_some(candidate)
3086        } else {
3087            None
3088        };
3089        if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
3090            break;
3091        }
3092        rho = newton.unwrap_or(0.5 * (lo + hi));
3093    }
3094    0.5 * (lo + hi)
3095}
3096
3097fn invert_lower_triangular(lower: &Array2<f64>) -> Result<Array2<f64>, EstimationError> {
3098    let n = lower.nrows();
3099    if lower.ncols() != n {
3100        crate::bail_invalid_estim!("lower-triangular solve requires a square matrix");
3101    }
3102    let eye = Array2::eye(n);
3103    solve_lower_triangular_matrix(lower, &eye)
3104}
3105
3106fn solve_lower_triangular_matrix(
3107    lower: &Array2<f64>,
3108    rhs: &Array2<f64>,
3109) -> Result<Array2<f64>, EstimationError> {
3110    let n = lower.nrows();
3111    if lower.ncols() != n || rhs.nrows() != n {
3112        crate::bail_invalid_estim!("lower-triangular solve dimension mismatch");
3113    }
3114    if let Some(out) = gam_gpu::try_solve_lower_triangular_matrix(lower.view(), rhs.view()) {
3115        return Ok(out);
3116    }
3117    let mut out = Array2::<f64>::zeros(rhs.dim());
3118    for col in 0..rhs.ncols() {
3119        for i in 0..n {
3120            let mut value = rhs[[i, col]];
3121            for k in 0..i {
3122                value -= lower[[i, k]] * out[[k, col]];
3123            }
3124            let diag = lower[[i, i]];
3125            if !(diag.is_finite() && diag.abs() > 0.0) {
3126                return Err(EstimationError::ModelIsIllConditioned {
3127                    condition_number: f64::INFINITY,
3128                });
3129            }
3130            out[[i, col]] = value / diag;
3131        }
3132    }
3133    Ok(out)
3134}
3135
3136/// Solve the SPD system `L Lᵀ X = rhs` for `X` given the lower Cholesky factor
3137/// `L` (as returned by [`gaussian_reml_cholesky_lower`]): a forward solve
3138/// against `L` followed by a back solve against `Lᵀ`.
3139fn solve_spd_from_lower_factor(
3140    lower: &Array2<f64>,
3141    rhs: &Array2<f64>,
3142) -> Result<Array2<f64>, EstimationError> {
3143    let forward = solve_lower_triangular_matrix(lower, rhs)?;
3144    solve_upper_triangular_matrix(&lower.t().to_owned(), &forward)
3145}
3146
3147fn solve_upper_triangular_matrix(
3148    upper: &Array2<f64>,
3149    rhs: &Array2<f64>,
3150) -> Result<Array2<f64>, EstimationError> {
3151    let n = upper.nrows();
3152    if upper.ncols() != n || rhs.nrows() != n {
3153        crate::bail_invalid_estim!("upper-triangular solve dimension mismatch");
3154    }
3155    if let Some(out) = gam_gpu::try_solve_upper_triangular_matrix(upper.view(), rhs.view()) {
3156        return Ok(out);
3157    }
3158    let mut out = Array2::<f64>::zeros(rhs.dim());
3159    for col in 0..rhs.ncols() {
3160        for i_rev in 0..n {
3161            let i = n - 1 - i_rev;
3162            let mut value = rhs[[i, col]];
3163            for k in (i + 1)..n {
3164                value -= upper[[i, k]] * out[[k, col]];
3165            }
3166            let diag = upper[[i, i]];
3167            if !(diag.is_finite() && diag.abs() > 0.0) {
3168                return Err(EstimationError::ModelIsIllConditioned {
3169                    condition_number: f64::INFINITY,
3170                });
3171            }
3172            out[[i, col]] = value / diag;
3173        }
3174    }
3175    Ok(out)
3176}
3177
3178#[cfg(test)]
3179mod tests {
3180    use super::*;
3181    use ndarray::array;
3182
3183    #[test]
3184    fn edf_does_not_double_count_penalty_nullspace() {
3185        let x = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0],];
3186        let y = array![[0.0], [1.0], [1.8], [3.2], [4.1]];
3187        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3188        let result =
3189            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3190                .expect("small full-rank Gaussian REML fit");
3191
3192        assert!(result.edf >= result.cache.nullity as f64);
3193        assert!(result.edf <= x.ncols() as f64 + 1.0e-10);
3194    }
3195
3196    #[test]
3197    fn multi_output_duplicate_columns_match_scalar_fit() {
3198        let x = array![
3199            [1.0, -1.0],
3200            [1.0, -0.5],
3201            [1.0, 0.0],
3202            [1.0, 0.5],
3203            [1.0, 1.0],
3204            [1.0, 1.5],
3205        ];
3206        let y1 = array![0.5, 0.2, 0.0, 0.3, 1.1, 2.0];
3207        let y = Array2::from_shape_fn(
3208            (y1.len(), 2),
3209            |(i, j)| if j == 0 { y1[i] } else { 2.0 * y1[i] },
3210        );
3211        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3212
3213        let scalar =
3214            gaussian_reml_closed_form(x.view(), y1.view(), penalty.view(), None, Some(0.0))
3215                .expect("scalar Gaussian REML fit");
3216        let multi =
3217            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3218                .expect("multi-output Gaussian REML fit");
3219
3220        assert!((multi.rho - scalar.rho).abs() <= 1.0e-8);
3221        for i in 0..x.ncols() {
3222            assert!((multi.coefficients[[i, 0]] - scalar.coefficients[i]).abs() <= 1.0e-8);
3223            assert!((multi.coefficients[[i, 1]] - 2.0 * scalar.coefficients[i]).abs() <= 1.0e-8);
3224        }
3225    }
3226
3227    #[test]
3228    fn warm_start_reuses_cache_and_lambda_seed() {
3229        let x = array![
3230            [1.0, -1.0],
3231            [1.0, -0.25],
3232            [1.0, 0.5],
3233            [1.0, 1.25],
3234            [1.0, 2.0],
3235        ];
3236        let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3237        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3238
3239        let cold =
3240            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3241                .expect("cold fit");
3242        let warm_start = GaussianRemlWarmStart::from_multi_result(&cold);
3243        let warm = gaussian_reml_multi_closed_form_warm_started(
3244            x.view(),
3245            y.view(),
3246            penalty.view(),
3247            None,
3248            Some(&warm_start),
3249        )
3250        .expect("warm-started fit");
3251
3252        assert!((cold.lambda - warm.lambda).abs() <= 1.0e-10);
3253        assert_eq!(cold.cache.xtwx_fingerprint, warm.cache.xtwx_fingerprint);
3254        for i in 0..x.ncols() {
3255            assert!((cold.coefficients[[i, 0]] - warm.coefficients[[i, 0]]).abs() <= 1.0e-10);
3256        }
3257    }
3258
3259    #[test]
3260    fn warm_start_cache_rejects_different_penalty_geometry() {
3261        let x = array![
3262            [1.0, -1.0],
3263            [1.0, -0.25],
3264            [1.0, 0.5],
3265            [1.0, 1.25],
3266            [1.0, 2.0],
3267        ];
3268        let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3269        let penalty_a = array![[0.0, 0.0], [0.0, 1.0]];
3270        let penalty_b = array![[1.0, -1.0], [-1.0, 1.0]];
3271
3272        let first =
3273            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty_a.view(), None, Some(0.0))
3274                .expect("first fit");
3275        let warm_start = GaussianRemlWarmStart::from_multi_result(&first);
3276        let err = gaussian_reml_multi_closed_form_warm_started(
3277            x.view(),
3278            y.view(),
3279            penalty_b.view(),
3280            None,
3281            Some(&warm_start),
3282        )
3283        .expect_err("penalty-mismatched cache must be rejected");
3284
3285        assert!(err.to_string().contains("penalty mismatch"));
3286    }
3287
3288    #[test]
3289    fn no_alloc_cache_path_matches_allocating_fit() {
3290        let x = array![
3291            [1.0, -1.0, 0.25],
3292            [1.0, -0.5, 0.10],
3293            [1.0, 0.0, -0.20],
3294            [1.0, 0.5, -0.05],
3295            [1.0, 1.0, 0.30],
3296            [1.0, 1.5, 0.60],
3297        ];
3298        let y = array![
3299            [0.0, 0.2],
3300            [0.3, 0.1],
3301            [0.4, -0.1],
3302            [0.9, 0.3],
3303            [1.6, 0.8],
3304            [2.2, 1.2],
3305        ];
3306        let weights = array![1.0, 0.8, 1.2, 1.1, 0.9, 1.3];
3307        let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 4.0]];
3308
3309        let allocating = gaussian_reml_multi_closed_form_with_cache(
3310            x.view(),
3311            y.view(),
3312            penalty.view(),
3313            Some(weights.view()),
3314            Some(1.0),
3315            None,
3316        )
3317        .expect("allocating fit");
3318        let mut workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols());
3319        let mut coefficients = Array2::zeros((x.ncols(), y.ncols()));
3320        let mut fitted = Array2::zeros(y.dim());
3321        let mut sigma2 = Array1::zeros(y.ncols());
3322
3323        let no_alloc = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3324            x.view(),
3325            y.view(),
3326            penalty.view(),
3327            Some(weights.view()),
3328            Some(allocating.lambda),
3329            &allocating.cache,
3330            &mut workspace,
3331            coefficients.view_mut(),
3332            fitted.view_mut(),
3333            sigma2.view_mut(),
3334        )
3335        .expect("no-alloc cached fit");
3336
3337        assert!((no_alloc.lambda - allocating.lambda).abs() <= 1.0e-10);
3338        assert!((no_alloc.reml_score - allocating.reml_score).abs() <= 1.0e-8);
3339        assert!((no_alloc.reml_grad_rho - allocating.reml_grad_rho).abs() <= 1.0e-8);
3340        assert!((no_alloc.reml_hess_rho - allocating.reml_hess_rho).abs() <= 1.0e-8);
3341        assert!((no_alloc.edf - allocating.edf).abs() <= 1.0e-10);
3342        for i in 0..x.ncols() {
3343            for j in 0..y.ncols() {
3344                assert!((coefficients[[i, j]] - allocating.coefficients[[i, j]]).abs() <= 1.0e-8);
3345            }
3346        }
3347        for i in 0..x.nrows() {
3348            for j in 0..y.ncols() {
3349                assert!((fitted[[i, j]] - allocating.fitted[[i, j]]).abs() <= 1.0e-8);
3350            }
3351        }
3352        for j in 0..y.ncols() {
3353            assert!((sigma2[j] - allocating.sigma2[j]).abs() <= 1.0e-10);
3354        }
3355    }
3356
3357    #[test]
3358    fn no_alloc_cache_path_rejects_bad_shapes_and_penalty_mismatch() {
3359        let x = array![[1.0, -1.0], [1.0, 0.0], [1.0, 1.0], [1.0, 2.0]];
3360        let y = array![[0.0], [0.2], [0.9], [1.8]];
3361        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3362        let cache = build_gaussian_reml_eigen_cache(x.view(), penalty.view(), None)
3363            .expect("Gaussian REML cache");
3364
3365        let mut bad_workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols() + 1);
3366        let mut coefficients = Array2::zeros((x.ncols(), y.ncols()));
3367        let mut fitted = Array2::zeros(y.dim());
3368        let mut sigma2 = Array1::zeros(y.ncols());
3369        let err = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3370            x.view(),
3371            y.view(),
3372            penalty.view(),
3373            None,
3374            Some(1.0),
3375            &cache,
3376            &mut bad_workspace,
3377            coefficients.view_mut(),
3378            fitted.view_mut(),
3379            sigma2.view_mut(),
3380        )
3381        .expect_err("workspace shape mismatch must be rejected");
3382        assert!(err.to_string().contains("workspace shape mismatch"));
3383
3384        let penalty_mismatch = array![[1.0, -1.0], [-1.0, 1.0]];
3385        let mut workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols());
3386        let err = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3387            x.view(),
3388            y.view(),
3389            penalty_mismatch.view(),
3390            None,
3391            Some(1.0),
3392            &cache,
3393            &mut workspace,
3394            coefficients.view_mut(),
3395            fitted.view_mut(),
3396            sigma2.view_mut(),
3397        )
3398        .expect_err("penalty mismatch must be rejected");
3399        assert!(err.to_string().contains("penalty mismatch"));
3400    }
3401
    /// Selects the single scalar output of the forward fit that a
    /// finite-difference probe reads (`one_hot_objective_try`) and that the
    /// matching backward pass seeds with a unit upstream gradient
    /// (`one_hot_backward`).
    #[derive(Clone, Copy, Debug)]
    enum ForwardScalar {
        /// The fit's `lambda`.
        Lambda,
        /// The fit's `reml_score`.
        RemlScore,
        /// Entry `[row, col]` of the fit's `coefficients`.
        Coefficient(usize, usize),
        /// Entry `[row, col]` of the fit's `fitted` values.
        Fitted(usize, usize),
        /// The fit's `edf`.
        Edf,
    }
3410
3411    fn finite_difference_design() -> Array2<f64> {
3412        Array2::from_shape_fn((20, 5), |(row, col)| {
3413            let t = (row as f64 - 9.5) / 10.0;
3414            match col {
3415                0 => 1.0,
3416                1 => t,
3417                2 => 0.5 * (3.0 * t * t - 1.0),
3418                3 => 0.5 * (5.0 * t * t * t - 3.0 * t),
3419                4 => (35.0 * t.powi(4) - 30.0 * t * t + 3.0) / 8.0,
3420                _ => unreachable!(),
3421            }
3422        })
3423    }
3424
3425    fn finite_difference_response(outputs: usize) -> Array2<f64> {
3426        // The truth must NOT lie (essentially) in span(X). The 5-column design
3427        // is Legendre P_0..P_4, so a low-order polynomial + low-frequency sin
3428        // would be fit to near machine precision — driving σ² → 0, dp → 0,
3429        // and ∂score/∂y ≈ ν w r / dp → ∞. Central finite differences with
3430        // Richardson extrapolation cannot resolve such steep, highly-nonlinear
3431        // surfaces at 1e-6 relative because the truncation term scales with
3432        // f^(5)(y), which explodes in that regime. The high-frequency sin
3433        // below is well outside span(P_0..P_4) on t ∈ [-0.95, 0.95], leaving
3434        // a genuine residual (σ² ≈ 1e-3) and an interior REML optimum
3435        // (ρ ≈ -3) at which the analytic-vs-FD comparison is meaningful.
3436        Array2::from_shape_fn((20, outputs), |(row, output)| {
3437            let t = (row as f64 - 9.5) / 10.0;
3438            let phase = output as f64 + 1.0;
3439            0.2 + 0.25 * phase * t - 0.12 * t * t
3440                + (0.08 + 0.03 * phase) * (1.1 * t + 0.3 * phase).sin()
3441                + 0.05 * (7.0 * t + 0.5 * phase).sin()
3442        })
3443    }
3444
3445    fn finite_difference_penalty() -> Array2<f64> {
3446        Array2::from_diag(&array![0.0, 0.8, 1.2, 1.7, 2.3])
3447    }
3448
3449    fn finite_difference_weights() -> Array1<f64> {
3450        Array1::from_shape_fn(20, |row| {
3451            let t = (row as f64 - 9.5) / 10.0;
3452            1.0 + 0.025 * (1.1 * t).sin() + 0.01 * t
3453        })
3454    }
3455
3456    /// Fallible forward-scalar probe. Returns `None` when the closed-form fit
3457    /// rejects the inputs — the relevant case being a penalty perturbation that
3458    /// pushes `S` out of the PSD cone (a single-entry central bump on a
3459    /// null-direction entry drives one eigenvalue slightly negative). Such a
3460    /// point has no well-defined REML objective, so the caller skips it rather
3461    /// than panicking.
3462    fn one_hot_objective_try(
3463        x: ArrayView2<'_, f64>,
3464        y: ArrayView2<'_, f64>,
3465        penalty: ArrayView2<'_, f64>,
3466        weights: ArrayView1<'_, f64>,
3467        target: ForwardScalar,
3468    ) -> Option<f64> {
3469        let fit = gaussian_reml_multi_closed_form_with_cache(
3470            x,
3471            y,
3472            penalty,
3473            Some(weights),
3474            Some(0.85),
3475            None,
3476        )
3477        .ok()?;
3478        Some(match target {
3479            ForwardScalar::Lambda => fit.lambda,
3480            ForwardScalar::RemlScore => fit.reml_score,
3481            ForwardScalar::Coefficient(row, col) => fit.coefficients[[row, col]],
3482            ForwardScalar::Fitted(row, col) => fit.fitted[[row, col]],
3483            ForwardScalar::Edf => fit.edf,
3484        })
3485    }
3486
3487    fn one_hot_objective(
3488        x: ArrayView2<'_, f64>,
3489        y: ArrayView2<'_, f64>,
3490        penalty: ArrayView2<'_, f64>,
3491        weights: ArrayView1<'_, f64>,
3492        target: ForwardScalar,
3493    ) -> f64 {
3494        one_hot_objective_try(x, y, penalty, weights, target)
3495            .expect("finite-difference forward fit")
3496    }
3497
3498    fn one_hot_backward(
3499        x: ArrayView2<'_, f64>,
3500        y: ArrayView2<'_, f64>,
3501        penalty: ArrayView2<'_, f64>,
3502        weights: ArrayView1<'_, f64>,
3503        target: ForwardScalar,
3504    ) -> GaussianRemlBackwardResult {
3505        let mut grad_coefficients = Array2::<f64>::zeros((x.ncols(), y.ncols()));
3506        let mut grad_fitted = Array2::<f64>::zeros(y.dim());
3507        let (grad_lambda, grad_score, grad_edf, coefficient_upstream, fitted_upstream) =
3508            match target {
3509                ForwardScalar::Lambda => (1.0, 0.0, 0.0, None, None),
3510                ForwardScalar::RemlScore => (0.0, 1.0, 0.0, None, None),
3511                ForwardScalar::Coefficient(row, col) => {
3512                    grad_coefficients[[row, col]] = 1.0;
3513                    (0.0, 0.0, 0.0, Some(grad_coefficients.view()), None)
3514                }
3515                ForwardScalar::Fitted(row, col) => {
3516                    grad_fitted[[row, col]] = 1.0;
3517                    (0.0, 0.0, 0.0, None, Some(grad_fitted.view()))
3518                }
3519                ForwardScalar::Edf => (0.0, 0.0, 1.0, None, None),
3520            };
3521        gaussian_reml_multi_closed_form_backward(
3522            x,
3523            y,
3524            penalty,
3525            Some(weights),
3526            Some(0.85),
3527            grad_lambda,
3528            coefficient_upstream,
3529            fitted_upstream,
3530            grad_score,
3531            grad_edf,
3532        )
3533        .expect("analytic backward VJP")
3534    }
3535
3536    fn assert_fd_close(label: &str, analytic: f64, finite_difference: f64) {
3537        let rel_tol = 1.0e-6_f64;
3538        let abs_tol = 1.0e-6_f64;
3539        let tol = abs_tol.max(rel_tol * analytic.abs().max(finite_difference.abs()));
3540        let diff = (analytic - finite_difference).abs();
3541        assert!(
3542            diff <= tol,
3543            "{label}: analytic={analytic:.12e}, finite_difference={finite_difference:.12e}, diff={diff:.3e}, tol={tol:.3e}"
3544        );
3545    }
3546
3547    fn adaptive_central_difference(mut eval: impl FnMut(f64) -> f64) -> f64 {
3548        let steps: [f64; 5] = [1.0e-3, 5.0e-4, 2.5e-4, 1.25e-4, 6.25e-5];
3549        let mut best = f64::NAN;
3550        let mut best_delta = f64::INFINITY;
3551        let mut previous: Option<f64> = None;
3552        for h in steps {
3553            let d1 = (eval(h) - eval(-h)) / (2.0 * h);
3554            let half_h = 0.5 * h;
3555            let d2 = (eval(half_h) - eval(-half_h)) / (2.0 * half_h);
3556            let estimate: f64 = d2 + (d2 - d1) / 3.0;
3557            if let Some(prev) = previous {
3558                let delta = (estimate - prev).abs();
3559                if delta < best_delta {
3560                    best_delta = delta;
3561                    best = estimate;
3562                }
3563            } else {
3564                best = estimate;
3565            }
3566            previous = Some(estimate);
3567        }
3568        best
3569    }
3570
    /// Exhaustive analytic-vs-finite-difference check of the backward pass on
    /// the shared `finite_difference_*` fixtures.
    ///
    /// For each `ForwardScalar` target listed in `targets`, a one-hot backward
    /// pass (`one_hot_backward`) is compared with `assert_fd_close` against
    /// `adaptive_central_difference` of the forward fit, one entry at a time,
    /// for every element of `x`, `y`, and `weights`, and for the penalty
    /// entries that do not touch the null coordinate.
    ///
    /// `outputs` is the number of response columns; it must be at least 1
    /// because the coefficient and fitted targets index column `outputs - 1`.
    fn assert_backward_matches_forward_finite_difference(outputs: usize) {
        let x = finite_difference_design();
        let y = finite_difference_response(outputs);
        let penalty = finite_difference_penalty();
        let weights = finite_difference_weights();
        let targets = [
            ForwardScalar::Lambda,
            ForwardScalar::RemlScore,
            ForwardScalar::Coefficient(3, outputs - 1),
            ForwardScalar::Fitted(12, outputs - 1),
            ForwardScalar::Edf,
        ];
        for target in targets {
            // One analytic backward pass per target; every FD probe below is
            // compared against an entry of this result.
            let backward =
                one_hot_backward(x.view(), y.view(), penalty.view(), weights.view(), target);

            // ∂target/∂X: perturb one design entry at a time.
            for row in 0..x.nrows() {
                for col in 0..x.ncols() {
                    let eval = |delta: f64| {
                        let mut candidate = x.clone();
                        candidate[[row, col]] += delta;
                        one_hot_objective(
                            candidate.view(),
                            y.view(),
                            penalty.view(),
                            weights.view(),
                            target,
                        )
                    };
                    let fd = adaptive_central_difference(eval);
                    assert_fd_close(
                        &format!("target={target:?} x[{row},{col}]"),
                        backward.grad_x[[row, col]],
                        fd,
                    );
                }
            }

            // ∂target/∂Y: perturb one response entry at a time.
            for row in 0..y.nrows() {
                for col in 0..y.ncols() {
                    let eval = |delta: f64| {
                        let mut candidate = y.clone();
                        candidate[[row, col]] += delta;
                        one_hot_objective(
                            x.view(),
                            candidate.view(),
                            penalty.view(),
                            weights.view(),
                            target,
                        )
                    };
                    let fd = adaptive_central_difference(eval);
                    assert_fd_close(
                        &format!("target={target:?} y[{row},{col}]"),
                        backward.grad_y[[row, col]],
                        fd,
                    );
                }
            }

            // ∂target/∂w: perturb one observation weight at a time.
            for row in 0..weights.len() {
                let eval = |delta: f64| {
                    let mut candidate = weights.clone();
                    candidate[row] += delta;
                    one_hot_objective(x.view(), y.view(), penalty.view(), candidate.view(), target)
                };
                let fd = adaptive_central_difference(eval);
                assert_fd_close(
                    &format!("target={target:?} weights[{row}]"),
                    backward.grad_weights[row],
                    fd,
                );
            }

            // ∂L/∂S over the RANGE-SPACE penalty entries. The REML objective
            // carries −½d·log|S|₊ (the pseudo-determinant over the NONZERO
            // eigenvalues), so ∂L/∂S is only a finite, FD-verifiable derivative
            // where a central ±h bump keeps S inside the PSD cone WITHOUT
            // changing its rank. A single-entry bump touching the null
            // direction violates both: the −h side drives an eigenvalue
            // slightly negative (leaves the cone → fit Err) and the +h side
            // turns the zero eigenvalue into a tiny positive one that joins
            // log|S|₊ as a −log(ε) term (a rank-change discontinuity in L).
            // The null-direction component of the analytic S-gradient is a
            // gauge convention for the null space (the L-metric pseudoinverse
            // `penalty_pinv` = L⁻ᵀ T⁺ L⁻¹), validated by algebra/consumer, not
            // FD. So restrict to the strictly-positive diagonal block (both
            // indices in 1..p for the diag([0, 0.8, 1.2, 1.7, 2.3]) fixture,
            // where S_rr > 0 and ±h stays PSD at full rank). The forward
            // consumes only `S_canon = 0.5(S + Sᵀ)` and the backward returns
            // the symmetrized gradient, so a single-entry bump of S[r, c]
            // (asymmetric) compares directly against `grad_penalty[r, c]` =
            // 0.5(G[r, c] + G[c, r]). Defensively, any entry whose largest ±h
            // probe leaves the cone is skipped (cone membership is monotone in
            // |h| here, so probing the largest step suffices).
            let null_index = 0usize; // diag([0.0, ...]) ⇒ coordinate 0 is the null direction.
            let probe_h = 1.0e-3_f64; // matches the largest adaptive_central_difference step.
            for r in 0..penalty.nrows() {
                for c in 0..penalty.ncols() {
                    // Entries in the null row/column are not FD-verifiable (see above).
                    if r == null_index || c == null_index {
                        continue;
                    }
                    let eval = |delta: f64| {
                        let mut candidate = penalty.clone();
                        candidate[[r, c]] += delta;
                        one_hot_objective(
                            x.view(),
                            y.view(),
                            candidate.view(),
                            weights.view(),
                            target,
                        )
                    };
                    // Pre-flight with the fallible probe so that an out-of-cone
                    // perturbation skips the entry instead of panicking inside
                    // `eval`.
                    let cone_safe = {
                        let mut s_plus = penalty.clone();
                        let mut s_minus = penalty.clone();
                        s_plus[[r, c]] += probe_h;
                        s_minus[[r, c]] -= probe_h;
                        one_hot_objective_try(
                            x.view(),
                            y.view(),
                            s_plus.view(),
                            weights.view(),
                            target,
                        )
                        .is_some()
                            && one_hot_objective_try(
                                x.view(),
                                y.view(),
                                s_minus.view(),
                                weights.view(),
                                target,
                            )
                            .is_some()
                    };
                    if !cone_safe {
                        continue;
                    }
                    let fd = adaptive_central_difference(eval);
                    assert_fd_close(
                        &format!("target={target:?} penalty[{r},{c}]"),
                        backward.grad_penalty[[r, c]],
                        fd,
                    );
                }
            }
        }
    }
3719
3720    #[test]
3721    fn scalar_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
3722        assert_backward_matches_forward_finite_difference(1);
3723    }
3724
3725    #[test]
3726    fn multi_output_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
3727        assert_backward_matches_forward_finite_difference(3);
3728    }
3729
    /// Spot-check of the backward VJP under a mixed upstream seed.
    ///
    /// The scalar objective is `upstream_lambda·λ + upstream_score·score +
    /// ⟨coefficients, upstream_coefficients⟩ + ⟨fitted, upstream_fitted⟩` (the
    /// EDF seed passed to the backward is 0.0). Its plain central difference
    /// (`eps = 1e-6`) is compared, with absolute tolerance `2e-4`, against the
    /// analytic gradient at `x[3, 2]`, `y[4, 1]`, `weights[2]`, and the penalty
    /// entries `(1, 1)`, `(1, 2)`, `(2, 2)`.
    #[test]
    fn backward_vjp_matches_finite_difference() {
        let x = array![
            [1.0, -1.0, 0.2],
            [1.0, -0.3, -0.1],
            [1.0, 0.2, 0.4],
            [1.0, 0.8, 0.1],
            [1.0, 1.4, 0.5],
            [1.0, 2.0, 0.9],
        ];
        let y = array![
            [0.1, -0.2],
            [0.2, 0.1],
            [0.7, 0.0],
            [1.1, 0.3],
            [1.8, 0.9],
            [2.4, 1.4],
        ];
        let weights = array![1.0, 0.9, 1.1, 1.2, 0.8, 1.3];
        let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.2], [0.0, 0.2, 1.7]];
        // Upstream seeds: arbitrary nonzero values so every output contributes.
        let upstream_coefficients = array![[0.2, -0.1], [0.05, 0.03], [-0.04, 0.07]];
        let upstream_fitted = array![
            [0.01, -0.02],
            [0.03, 0.01],
            [-0.01, 0.02],
            [0.04, -0.03],
            [0.02, 0.05],
            [-0.02, 0.01],
        ];
        let upstream_lambda = 0.17;
        let upstream_score = -0.11;

        let backward = gaussian_reml_multi_closed_form_backward(
            x.view(),
            y.view(),
            penalty.view(),
            Some(weights.view()),
            Some(0.8),
            upstream_lambda,
            Some(upstream_coefficients.view()),
            Some(upstream_fitted.view()),
            upstream_score,
            0.0,
        )
        .expect("backward VJP");

        // Scalar objective whose gradient the backward pass above computes,
        // as a function of (x, y, weights) with the penalty held at base.
        let objective = |x_eval: &Array2<f64>, y_eval: &Array2<f64>, w_eval: &Array1<f64>| {
            let fit = gaussian_reml_multi_closed_form_with_cache(
                x_eval.view(),
                y_eval.view(),
                penalty.view(),
                Some(w_eval.view()),
                Some(0.8),
                None,
            )
            .expect("fit for objective");
            upstream_lambda * fit.lambda
                + upstream_score * fit.reml_score
                + (&fit.coefficients * &upstream_coefficients).sum()
                + (&fit.fitted * &upstream_fitted).sum()
        };
        let eps = 1.0e-6;
        // Sanity check: the objective is finite at the unperturbed point.
        assert!(objective(&x, &y, &weights).is_finite());

        // Design entry x[3, 2].
        let mut x_plus = x.clone();
        let mut x_minus = x.clone();
        x_plus[[3, 2]] += eps;
        x_minus[[3, 2]] -= eps;
        let fd_x =
            (objective(&x_plus, &y, &weights) - objective(&x_minus, &y, &weights)) / (2.0 * eps);
        assert!(
            (fd_x - backward.grad_x[[3, 2]]).abs() <= 2.0e-4,
            "grad_x mismatch: analytic={} fd={}",
            backward.grad_x[[3, 2]],
            fd_x
        );

        // Response entry y[4, 1].
        let mut y_plus = y.clone();
        let mut y_minus = y.clone();
        y_plus[[4, 1]] += eps;
        y_minus[[4, 1]] -= eps;
        let fd_y =
            (objective(&x, &y_plus, &weights) - objective(&x, &y_minus, &weights)) / (2.0 * eps);
        assert!(
            (fd_y - backward.grad_y[[4, 1]]).abs() <= 2.0e-4,
            "grad_y mismatch: analytic={} fd={}",
            backward.grad_y[[4, 1]],
            fd_y
        );

        // Observation weight weights[2].
        let mut w_plus = weights.clone();
        let mut w_minus = weights.clone();
        w_plus[2] += eps;
        w_minus[2] -= eps;
        let fd_w = (objective(&x, &y, &w_plus) - objective(&x, &y, &w_minus)) / (2.0 * eps);
        assert!(
            (fd_w - backward.grad_weights[2]).abs() <= 2.0e-4,
            "grad_weight mismatch: analytic={} fd={}",
            backward.grad_weights[2],
            fd_w
        );

        // Combined-seed ∂L/∂S spot-check: perturb individual penalty entries with
        // x/y/w held at base, under mixed (λ, score, β, fitted) seeds. The penalty
        // [[0,0,0],[0,1,0.2],[0,0.2,1.7]] is nullity 1 (coordinate 0 is the null
        // direction); ∂L/∂S is FD-verifiable only on the strictly-positive
        // RANGE block (indices 1,2), where a central ±h bump keeps S PSD at full
        // rank. Null-touching entries (any index 0) are non-FD-verifiable — the
        // −½d·log|S|₊ pseudo-determinant term makes L either cone-leaving or
        // rank-change-discontinuous there (see the exhaustive S loop above). A
        // single-entry asymmetric bump of S[r, c] compares directly to
        // grad_penalty[[r, c]] = 0.5(G[r,c] + G[c,r]), exercising the backward
        // symmetrization.
        let objective_s = |s_eval: &Array2<f64>| {
            let fit = gaussian_reml_multi_closed_form_with_cache(
                x.view(),
                y.view(),
                s_eval.view(),
                Some(weights.view()),
                Some(0.8),
                None,
            )
            .expect("fit for penalty objective");
            upstream_lambda * fit.lambda
                + upstream_score * fit.reml_score
                + (&fit.coefficients * &upstream_coefficients).sum()
                + (&fit.fitted * &upstream_fitted).sum()
        };
        // (1,1) full-rank diagonal; (1,2) pure off-diagonal between two penalized
        // directions; (2,2) full-rank diagonal. All in the strictly-positive
        // range block, so ±h stays PSD at full rank.
        for (r, c) in [(1usize, 1usize), (1, 2), (2, 2)] {
            let mut s_plus = penalty.clone();
            let mut s_minus = penalty.clone();
            s_plus[[r, c]] += eps;
            s_minus[[r, c]] -= eps;
            let fd_s = (objective_s(&s_plus) - objective_s(&s_minus)) / (2.0 * eps);
            assert!(
                (fd_s - backward.grad_penalty[[r, c]]).abs() <= 2.0e-4,
                "grad_penalty[{r},{c}] mismatch: analytic={} fd={}",
                backward.grad_penalty[[r, c]],
                fd_s
            );
        }
    }
3875
3876    #[test]
3877    fn batched_eigen_cache_matches_per_fit_build() {
3878        // Three K=3 problems sharing the same penalty matrix. The batched
3879        // pipeline must produce caches that are bit-exact identical to what
3880        // the per-fit `gaussian_reml_eigen_cache_from_xtwx` builder produces,
3881        // regardless of whether the GPU batched Cholesky kicks in or the
3882        // helper falls through to per-fit Cholesky.
3883        let xtwx_a = array![[4.0, 1.0], [1.0, 3.0]];
3884        let xtwx_b = array![[2.5, -0.5], [-0.5, 1.7]];
3885        let xtwx_c = array![[7.2, 0.3], [0.3, 5.1]];
3886        let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3887
3888        let batched = build_gaussian_reml_eigen_cache_batched(
3889            vec![xtwx_a.clone(), xtwx_b.clone(), xtwx_c.clone()],
3890            penalty.view(),
3891            None,
3892        );
3893        assert_eq!(batched.len(), 3);
3894
3895        for (xtwx, batched_cache) in [&xtwx_a, &xtwx_b, &xtwx_c].into_iter().zip(batched.iter()) {
3896            let single = gaussian_reml_eigen_cache_from_xtwx(xtwx.clone(), penalty.view(), None)
3897                .expect("per-fit cache");
3898            let batched_cache = batched_cache.as_ref().expect("batched cache");
3899            assert_eq!(batched_cache.penalty_rank, single.penalty_rank);
3900            assert_eq!(batched_cache.nullity, single.nullity);
3901            assert_eq!(batched_cache.xtwx_fingerprint, single.xtwx_fingerprint);
3902            assert_eq!(
3903                batched_cache.penalty_fingerprint,
3904                single.penalty_fingerprint
3905            );
3906            assert!((batched_cache.logdet_xtwx - single.logdet_xtwx).abs() <= 1.0e-12);
3907            assert!(
3908                (batched_cache.logdet_penalty_positive - single.logdet_penalty_positive).abs()
3909                    <= 1.0e-12
3910            );
3911            for (a, b) in batched_cache
3912                .penalty_eigenvalues
3913                .iter()
3914                .zip(single.penalty_eigenvalues.iter())
3915            {
3916                assert!((a - b).abs() <= 1.0e-12);
3917            }
3918            for ((a, b), _) in batched_cache
3919                .coefficient_basis
3920                .iter()
3921                .zip(single.coefficient_basis.iter())
3922                .zip(0..)
3923            {
3924                assert!((a - b).abs() <= 1.0e-12);
3925            }
3926        }
3927    }
3928
3929    #[test]
3930    fn scalar_rho_optimizer_chooses_lowest_cost_stationary_point() {
3931        let cache = GaussianRemlEigenCache {
3932            penalty_eigenvalues: array![5.2430192311066924e-05, 81734184.18548436],
3933            eigenvectors: Array2::eye(2),
3934            coefficient_basis: Array2::eye(2),
3935            xtwx_fingerprint: 0,
3936            penalty_fingerprint: 0,
3937            logdet_xtwx: 0.0,
3938            logdet_penalty_positive: 0.0,
3939            penalty_rank: 2,
3940            nullity: 0,
3941        };
3942        let prepared = GaussianRemlPrepared {
3943            cache: cache.clone(),
3944            ywy: array![0.5021347226586624],
3945            projected_rhs_squared: array![[0.361060218768292], [0.01014486085547482]],
3946            projected_rhs: array![
3947                [0.361060218768292_f64.sqrt()],
3948                [0.01014486085547482_f64.sqrt()]
3949            ],
3950            n_observations: 100,
3951            n_outputs: 1,
3952        };
3953
3954        let rho = optimize_rho(&prepared, None).expect("allocating rho optimizer");
3955        let no_alloc_rho = optimize_rho_no_alloc(
3956            &cache,
3957            prepared.ywy.view(),
3958            prepared.projected_rhs_squared.view(),
3959            prepared.n_observations,
3960            prepared.n_outputs,
3961            None,
3962        )
3963        .expect("no-alloc rho optimizer");
3964
3965        assert!(
3966            (rho - 4.3251059890).abs() < 1.0e-6,
3967            "rho optimizer selected {rho}, expected the lower-cost later stationary point"
3968        );
3969        assert!(
3970            (no_alloc_rho - rho).abs() < 1.0e-8,
3971            "no-alloc optimizer selected {no_alloc_rho}, allocating selected {rho}"
3972        );
3973        assert!(prepared.evaluate(rho).cost < prepared.evaluate(-18.9277503549).cost);
3974    }
3975
3976    #[test]
3977    fn backward_from_fit_matches_backward_with_refit() {
3978        // The Task 3 state round-trip in pyffi calls `_from_fit`; that path
3979        // must be numerically identical to the refitting `_backward` entry
3980        // when fed the same forward result. This guards the optimization
3981        // against drift when either path is touched.
3982        let x = array![[1.0, -0.9], [1.0, -0.4], [1.0, 0.1], [1.0, 0.6], [1.0, 1.1],];
3983        let y = array![[0.2, -0.1], [0.4, 0.1], [0.7, 0.3], [1.0, 0.5], [1.5, 0.8]];
3984        let penalty = array![[0.0, 0.0], [0.0, 1.5]];
3985        let weights = array![1.05, 0.95, 1.01, 0.99, 1.03];
3986
3987        let refit = gaussian_reml_multi_closed_form_backward(
3988            x.view(),
3989            y.view(),
3990            penalty.view(),
3991            Some(weights.view()),
3992            Some(0.85),
3993            0.2,
3994            None,
3995            None,
3996            -0.1,
3997            0.0,
3998        )
3999        .expect("refit backward");
4000
4001        let fit = gaussian_reml_multi_closed_form_with_cache(
4002            x.view(),
4003            y.view(),
4004            penalty.view(),
4005            Some(weights.view()),
4006            Some(0.85),
4007            None,
4008        )
4009        .expect("forward fit");
4010        let from_fit = gaussian_reml_multi_closed_form_backward_from_fit(
4011            x.view(),
4012            y.view(),
4013            penalty.view(),
4014            Some(weights.view()),
4015            &fit,
4016            0.2,
4017            None,
4018            None,
4019            -0.1,
4020            0.0,
4021        )
4022        .expect("from_fit backward");
4023
4024        for (a, b) in refit.grad_x.iter().zip(from_fit.grad_x.iter()) {
4025            assert!((a - b).abs() <= 1.0e-12);
4026        }
4027        for (a, b) in refit.grad_y.iter().zip(from_fit.grad_y.iter()) {
4028            assert!((a - b).abs() <= 1.0e-12);
4029        }
4030        for (a, b) in refit.grad_weights.iter().zip(from_fit.grad_weights.iter()) {
4031            assert!((a - b).abs() <= 1.0e-12);
4032        }
4033    }
4034
4035    /// Regression: when `K = XᵀWX + λS` is effectively rank-deficient (e.g.
4036    /// `λ` has saturated very large), the backward must NOT error — it must
4037    /// degrade gracefully and return zero gradients of the correct shape.
4038    /// This is the production-training scenario where individual atoms can
4039    /// saturate `λ_k` in early batches; raising here would crash an entire
4040    /// step. We construct the degenerate state by running a real forward
4041    /// fit and then corrupting `reml_hess_rho` to 0 (the gate variable the
4042    /// backward checks). We assert: (a) no error, (b) all gradients finite,
4043    /// (c) shapes match the inputs.
4044    #[test]
4045    fn backward_degrades_gracefully_when_k_is_near_singular() {
4046        // Small, full-rank S with a moderately-conditioned X. The exact
4047        // numbers don't matter; what matters is that we then force the
4048        // ill-conditioned gate to fire.
4049        let x = array![
4050            [1.0, -1.0, 0.5],
4051            [1.0, -0.5, 0.2],
4052            [1.0, 0.0, -0.1],
4053            [1.0, 0.5, 0.3],
4054            [1.0, 1.0, 0.8],
4055            [1.0, 1.5, 1.1],
4056            [1.0, 2.0, 1.5],
4057            [1.0, 2.5, 2.0],
4058            [1.0, 3.0, 2.6],
4059            [1.0, 3.5, 3.1],
4060        ];
4061        let y = array![
4062            [0.1],
4063            [0.3],
4064            [0.4],
4065            [0.7],
4066            [1.0],
4067            [1.5],
4068            [2.0],
4069            [2.7],
4070            [3.3],
4071            [4.0]
4072        ];
4073        // Full-rank S to keep the forward well-posed.
4074        let penalty = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
4075
4076        let mut fit =
4077            gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
4078                .expect("forward fit must succeed for well-posed input");
4079        // Force the ill-conditioned gate to fire by zeroing the REML
4080        // Hessian w.r.t. rho — this is exactly what happens in production
4081        // when `λ` saturates to 1e10+ and `d²ℓ/dρ² → 0`.
4082        fit.reml_hess_rho = 0.0;
4083
4084        let result = gaussian_reml_multi_closed_form_backward_from_fit(
4085            x.view(),
4086            y.view(),
4087            penalty.view(),
4088            None,
4089            &fit,
4090            // Nonzero upstreams to force the backward to actually try to
4091            // populate gradients (rather than short-circuit on zero seeds).
4092            1.0,
4093            None,
4094            None,
4095            1.0,
4096            1.0,
4097        )
4098        .expect("backward must NOT error on near-singular K");
4099
4100        assert_eq!(result.grad_x.dim(), (x.nrows(), x.ncols()));
4101        assert_eq!(result.grad_y.dim(), (y.nrows(), y.ncols()));
4102        assert_eq!(result.grad_penalty.dim(), (x.ncols(), x.ncols()));
4103        assert_eq!(result.grad_weights.dim(), x.nrows());
4104        for v in result.grad_x.iter() {
4105            assert!(v.is_finite(), "grad_x must be finite, got {v}");
4106        }
4107        for v in result.grad_y.iter() {
4108            assert!(v.is_finite(), "grad_y must be finite, got {v}");
4109        }
4110        for v in result.grad_penalty.iter() {
4111            assert!(v.is_finite(), "grad_penalty must be finite, got {v}");
4112        }
4113        for v in result.grad_weights.iter() {
4114            assert!(v.is_finite(), "grad_weights must be finite, got {v}");
4115        }
4116    }
4117}
4118
4119/// Vector–Jacobian products of the multi-block per-smooth-λ Gaussian REML
4120/// forward fit ([`gaussian_reml_blocks_orthogonal_shared_scale`]), back to the
4121/// design blocks, penalty blocks, response, and weights.
4122pub struct GaussianRemlBlocksBackwardAnalytic {
4123    pub grad_designs: Vec<Array2<f64>>,
4124    pub grad_penalties: Vec<Array2<f64>>,
4125    pub grad_y: Array2<f64>,
4126    pub grad_weights: Array1<f64>,
4127}
4128
4129/// Analytic backward for the multi-block per-smooth-λ Gaussian REML forward.
4130///
4131/// Computes VJPs of (coefficients, fitted, lambdas, log_lambdas, reml_score,
4132/// edf) back to (design_blocks, penalty_blocks, y, weights). The VJP is
4133/// assembled at the converged log-λ vector: fixed-ρ β/fitted/profiled-REML/EDF
4134/// terms are accumulated first, then the smoothing-parameter sensitivity is
4135/// routed through the F×F profiled REML score Hessian from the implicit optimum.
4136/// Pairs with the forward [`gaussian_reml_blocks_orthogonal_shared_scale`].
4137pub fn gaussian_reml_fit_blocks_backward_analytic(
4138    designs: &[Array2<f64>],
4139    penalties_raw: &[Array2<f64>],
4140    y: ArrayView1<'_, f64>,
4141    weights: ArrayView1<'_, f64>,
4142    rhos: &[f64],
4143    grad_coefficients: Option<ArrayView2<'_, f64>>,
4144    grad_fitted: Option<ArrayView2<'_, f64>>,
4145    grad_lambdas: Option<ArrayView1<'_, f64>>,
4146    grad_log_lambdas: Option<ArrayView1<'_, f64>>,
4147    grad_reml_score: f64,
4148    grad_edf: Option<ArrayView1<'_, f64>>,
4149) -> Result<GaussianRemlBlocksBackwardAnalytic, EstimationError> {
4150    let n = y.len();
4151    let f_blocks = designs.len();
4152    let mut offsets = Vec::with_capacity(f_blocks + 1);
4153    offsets.push(0_usize);
4154    for design in designs {
4155        offsets.push(offsets.last().copied().unwrap() + design.ncols());
4156    }
4157    let p_total = *offsets.last().unwrap();
4158    if n == 0 || p_total == 0 {
4159        return Err(EstimationError::InvalidInput(
4160            "gaussian_reml_fit_blocks_backward requires non-empty rows and at least one coefficient column"
4161                .to_string(),
4162        ));
4163    }
4164
4165    if rhos.len() != f_blocks {
4166        return Err(EstimationError::InvalidInput(format!(
4167            "log_lambdas length mismatch: expected {f_blocks}, got {}",
4168            rhos.len()
4169        )));
4170    }
4171    if let Some(gc) = grad_coefficients {
4172        if gc.dim() != (p_total, 1) {
4173            return Err(EstimationError::InvalidInput(format!(
4174                "grad_coefficients shape mismatch: expected {}x1, got {}x{}",
4175                p_total,
4176                gc.nrows(),
4177                gc.ncols()
4178            )));
4179        }
4180    }
4181    if let Some(gf) = grad_fitted {
4182        if gf.dim() != (n, 1) {
4183            return Err(EstimationError::InvalidInput(format!(
4184                "grad_fitted shape mismatch: expected {}x1, got {}x{}",
4185                n,
4186                gf.nrows(),
4187                gf.ncols()
4188            )));
4189        }
4190    }
4191    if !grad_reml_score.is_finite() {
4192        return Err(EstimationError::InvalidInput(format!(
4193            "grad_reml_score must be finite; got {grad_reml_score}"
4194        )));
4195    }
4196    if let Some(vec) = grad_lambdas {
4197        if vec.len() != f_blocks {
4198            return Err(EstimationError::InvalidInput(format!(
4199                "grad_lambdas length mismatch: expected {f_blocks}, got {}",
4200                vec.len()
4201            )));
4202        }
4203    }
4204    if let Some(vec) = grad_log_lambdas {
4205        if vec.len() != f_blocks {
4206            return Err(EstimationError::InvalidInput(format!(
4207                "grad_log_lambdas length mismatch: expected {f_blocks}, got {}",
4208                vec.len()
4209            )));
4210        }
4211    }
4212    if let Some(vec) = grad_edf {
4213        if vec.len() != f_blocks {
4214            return Err(EstimationError::InvalidInput(format!(
4215                "grad_edf length mismatch: expected {f_blocks}, got {}",
4216                vec.len()
4217            )));
4218        }
4219    }
4220    if let Some(gc) = grad_coefficients {
4221        if let Some(((row, col), value)) = gc.indexed_iter().find(|(_, value)| !value.is_finite()) {
4222            return Err(EstimationError::InvalidInput(format!(
4223                "grad_coefficients[{row},{col}] must be finite; got {value}"
4224            )));
4225        }
4226    }
4227    if let Some(gf) = grad_fitted {
4228        if let Some(((row, col), value)) = gf.indexed_iter().find(|(_, value)| !value.is_finite()) {
4229            return Err(EstimationError::InvalidInput(format!(
4230                "grad_fitted[{row},{col}] must be finite; got {value}"
4231            )));
4232        }
4233    }
4234    if let Some(vec) = grad_lambdas {
4235        if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4236            return Err(EstimationError::InvalidInput(format!(
4237                "grad_lambdas[{block}] must be finite; got {value}"
4238            )));
4239        }
4240    }
4241    if let Some(vec) = grad_log_lambdas {
4242        if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4243            return Err(EstimationError::InvalidInput(format!(
4244                "grad_log_lambdas[{block}] must be finite; got {value}"
4245            )));
4246        }
4247    }
4248    if let Some(vec) = grad_edf {
4249        if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4250            return Err(EstimationError::InvalidInput(format!(
4251                "grad_edf[{block}] must be finite; got {value}"
4252            )));
4253        }
4254    }
4255    for (block, design) in designs.iter().enumerate() {
4256        if let Some(((row, col), value)) =
4257            design.indexed_iter().find(|(_, value)| !value.is_finite())
4258        {
4259            return Err(EstimationError::InvalidInput(format!(
4260                "designs[{block}][{row},{col}] must be finite; got {value}"
4261            )));
4262        }
4263    }
4264    for (block, penalty) in penalties_raw.iter().enumerate() {
4265        if let Some(((row, col), value)) =
4266            penalty.indexed_iter().find(|(_, value)| !value.is_finite())
4267        {
4268            return Err(EstimationError::InvalidInput(format!(
4269                "penalties[{block}][{row},{col}] must be finite; got {value}"
4270            )));
4271        }
4272    }
4273    if let Some((row, value)) = y.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4274        return Err(EstimationError::InvalidInput(format!(
4275            "y[{row}] must be finite; got {value}"
4276        )));
4277    }
4278    if let Some((row, value)) = weights
4279        .iter()
4280        .enumerate()
4281        .find(|(_, value)| !value.is_finite() || **value < 0.0)
4282    {
4283        return Err(EstimationError::InvalidInput(format!(
4284            "weights[{row}] must be finite and non-negative; got {value}"
4285        )));
4286    }
4287
4288    let mut z = Array2::<f64>::zeros((n, p_total));
4289    for k in 0..f_blocks {
4290        z.slice_mut(s![.., offsets[k]..offsets[k + 1]])
4291            .assign(&designs[k]);
4292    }
4293
4294    let penalties: Vec<Array2<f64>> = penalties_raw
4295        .iter()
4296        .map(|p| {
4297            let mut out = p.clone();
4298            gam_linalg::matrix::symmetrize_in_place(&mut out);
4299            out
4300        })
4301        .collect();
4302    let mut ranks = Vec::with_capacity(f_blocks);
4303    let mut pinvs = Vec::with_capacity(f_blocks);
4304    for penalty in &penalties {
4305        let (rank, pinv) = gam_linalg::utils::block_penalty_rank_and_pinv(penalty)?;
4306        ranks.push(rank);
4307        pinvs.push(pinv);
4308    }
4309
4310    let lambdas = Array1::from_iter(rhos.iter().map(|rho| rho.exp()));
4311    if let Some((block, lambda)) = lambdas
4312        .iter()
4313        .enumerate()
4314        .find(|(_, lambda)| !lambda.is_finite() || **lambda <= 0.0)
4315    {
4316        return Err(EstimationError::InvalidInput(format!(
4317            "exp(log_lambdas[{block}]) must be finite and positive; got {lambda}"
4318        )));
4319    }
4320    let mut k_matrix = fast_xt_diag_x(&z.view(), &weights);
4321    for block in 0..f_blocks {
4322        let lambda = lambdas[block];
4323        for local_i in 0..penalties[block].nrows() {
4324            let global_i = offsets[block] + local_i;
4325            for local_j in 0..penalties[block].ncols() {
4326                let global_j = offsets[block] + local_j;
4327                k_matrix[[global_i, global_j]] += lambda * penalties[block][[local_i, local_j]];
4328            }
4329        }
4330    }
4331    let r = gam_linalg::utils::invert_spd_with_ridge(&k_matrix, 0.0)?;
4332
4333    let mut xtwy = Array1::<f64>::zeros(p_total);
4334    for row in 0..n {
4335        let wy = weights[row] * y[row];
4336        for col in 0..p_total {
4337            xtwy[col] += z[[row, col]] * wy;
4338        }
4339    }
4340    let beta = r.dot(&xtwy);
4341    let fitted = z.dot(&beta);
4342    if let Some((col, value)) = beta
4343        .iter()
4344        .enumerate()
4345        .find(|(_, value)| !value.is_finite())
4346    {
4347        return Err(EstimationError::InvalidInput(format!(
4348            "solved coefficient {col} is non-finite: {value}"
4349        )));
4350    }
4351    let residual = &y.to_owned() - &fitted;
4352    let weighted_residual = &residual * &weights.to_owned();
4353    let ywy = y
4354        .iter()
4355        .zip(weights.iter())
4356        .map(|(&yi, &wi)| wi * yi * yi)
4357        .sum::<f64>();
4358    let q_raw = ywy - xtwy.dot(&beta);
4359    if !q_raw.is_finite() {
4360        return Err(EstimationError::InvalidInput(format!(
4361            "Gaussian REML residual quadratic form must be finite; got {q_raw}"
4362        )));
4363    }
4364    let q = q_raw.max(1.0e-300);
4365    let nullity = penalties
4366        .iter()
4367        .zip(ranks.iter())
4368        .map(|(penalty, rank)| penalty.nrows().saturating_sub(*rank))
4369        .sum::<usize>();
4370    let nu = n as f64 - nullity as f64;
4371    if !(nu.is_finite() && nu > 0.0) {
4372        return Err(EstimationError::InvalidInput(format!(
4373            "Gaussian REML residual degrees of freedom must be positive; got {nu}"
4374        )));
4375    }
4376    let tau = nu / q;
4377    let tau_q = -nu / (q * q);
4378    if !(tau.is_finite() && tau_q.is_finite()) {
4379        return Err(EstimationError::InvalidInput(format!(
4380            "Gaussian REML scale derivatives are non-finite: tau={tau}, tau_q={tau_q}"
4381        )));
4382    }
4383
4384    let mut grad_z = Array2::<f64>::zeros((n, p_total));
4385    let mut g_kernel = Array2::<f64>::zeros((p_total, p_total));
4386    let mut h_kernel = Array1::<f64>::zeros(p_total);
4387    let mut q_kernel = 0.0_f64;
4388    let mut j_blocks: Vec<Array2<f64>> = penalties
4389        .iter()
4390        .map(|p| Array2::<f64>::zeros(p.dim()))
4391        .collect();
4392
4393    let mut beta_tilde = Array1::<f64>::zeros(p_total);
4394    if let Some(gc) = grad_coefficients {
4395        beta_tilde += &gc.column(0).to_owned();
4396    }
4397    if let Some(gf) = grad_fitted {
4398        let gf_col = gf.column(0).to_owned();
4399        beta_tilde += &z.t().dot(&gf_col);
4400        for row in 0..n {
4401            for col in 0..p_total {
4402                grad_z[[row, col]] += gf_col[row] * beta[col];
4403            }
4404        }
4405    }
4406
4407    // Generic downstream losses that explicitly seed beta_hat or fitted
4408    // values cannot use the REML envelope shortcut. Route those seeds through
4409    // the fixed-rho KKT adjoint K u = beta_tilde before differentiating
4410    // designs, penalties, y, weights, and rho.
4411    let u = r.dot(&beta_tilde);
4412    h_kernel += &u;
4413    for i in 0..p_total {
4414        for j in 0..p_total {
4415            g_kernel[[i, j]] -= 0.5 * (beta[i] * u[j] + u[i] * beta[j]);
4416        }
4417    }
4418
4419    let mut alpha = Array1::<f64>::zeros(f_blocks);
4420    if let Some(gl) = grad_lambdas {
4421        for block in 0..f_blocks {
4422            alpha[block] += gl[block] * lambdas[block];
4423        }
4424    }
4425    if let Some(grho) = grad_log_lambdas {
4426        alpha += &grho.to_owned();
4427    }
4428
4429    let mut p_betas = Vec::with_capacity(f_blocks);
4430    let mut m_vectors = Vec::with_capacity(f_blocks);
4431    let mut rp_matrices = Vec::with_capacity(f_blocks);
4432    let mut rpr_matrices = Vec::with_capacity(f_blocks);
4433    let mut b_values = Array1::<f64>::zeros(f_blocks);
4434    let mut t_values = Array1::<f64>::zeros(f_blocks);
4435
4436    for block in 0..f_blocks {
4437        let start = offsets[block];
4438        let end = offsets[block + 1];
4439        let beta_k = beta.slice(s![start..end]).to_owned();
4440        let s_beta = penalties[block].dot(&beta_k);
4441        let lambda = lambdas[block];
4442        let lambda_s_beta = s_beta.mapv(|value| lambda * value);
4443        let mut p_beta = Array1::<f64>::zeros(p_total);
4444        for local_i in 0..(end - start) {
4445            p_beta[start + local_i] = lambda_s_beta[local_i];
4446        }
4447        let weighted_penalty = penalties[block].mapv(|value| lambda * value);
4448        let rp_block = r.slice(s![.., start..end]).dot(&weighted_penalty);
4449        let mut rp = Array2::<f64>::zeros((p_total, p_total));
4450        rp.slice_mut(s![.., start..end]).assign(&rp_block);
4451        let rpr = rp_block.dot(&r.slice(s![start..end, ..]));
4452        let m = r.slice(s![.., start..end]).dot(&lambda_s_beta);
4453        b_values[block] = beta.dot(&p_beta);
4454        t_values[block] = (0..(end - start))
4455            .map(|local_i| rp_block[[start + local_i, local_i]])
4456            .sum::<f64>();
4457        alpha[block] -= u.dot(&p_beta);
4458        p_betas.push(p_beta);
4459        m_vectors.push(m);
4460        rp_matrices.push(rp);
4461        rpr_matrices.push(rpr);
4462    }
4463
4464    if grad_reml_score != 0.0 {
4465        q_kernel += 0.5 * grad_reml_score * tau;
4466        g_kernel += &(r.clone() * (0.5 * grad_reml_score));
4467        for block in 0..f_blocks {
4468            j_blocks[block] -= &(pinvs[block].clone() * (0.5 * grad_reml_score / lambdas[block]));
4469        }
4470    }
4471
4472    let mut trace_pairs = Array2::<f64>::zeros((f_blocks, f_blocks));
4473    for i in 0..f_blocks {
4474        for j in 0..f_blocks {
4475            trace_pairs[[i, j]] = gam_linalg::utils::trace_of_product(
4476                rp_matrices[i].view(),
4477                rp_matrices[j].view(),
4478            );
4479        }
4480    }
4481
4482    if let Some(ge) = grad_edf {
4483        for edf_block in 0..f_blocks {
4484            let scale = ge[edf_block];
4485            if scale == 0.0 {
4486                continue;
4487            }
4488            let start = offsets[edf_block];
4489            let end = offsets[edf_block + 1];
4490            g_kernel += &(rpr_matrices[edf_block].clone() * scale);
4491            j_blocks[edf_block] -= &(r.slice(s![start..end, start..end]).to_owned() * scale);
4492            for rho_block in 0..f_blocks {
4493                alpha[rho_block] += scale * trace_pairs[[edf_block, rho_block]];
4494                if rho_block == edf_block {
4495                    alpha[rho_block] -= scale * t_values[edf_block];
4496                }
4497            }
4498        }
4499    }
4500
4501    if let Some((block, value)) = alpha
4502        .iter()
4503        .enumerate()
4504        .find(|(_, value)| !value.is_finite())
4505    {
4506        return Err(EstimationError::InvalidInput(format!(
4507            "rho adjoint seed for block {block} is non-finite: {value}"
4508        )));
4509    }
4510
4511    if alpha.iter().any(|value| *value != 0.0) {
4512        let mut outer_h = Array2::<f64>::zeros((f_blocks, f_blocks));
4513        for k in 0..f_blocks {
4514            for j in 0..f_blocks {
4515                let beta_pk_r_pj_beta = p_betas[k].dot(&m_vectors[j]);
4516                outer_h[[k, j]] = 0.5 * trace_pairs[[k, j]] + tau * beta_pk_r_pj_beta
4517                    - if k == j {
4518                        0.5 * (t_values[k] + tau * b_values[k])
4519                    } else {
4520                        0.0
4521                    }
4522                    - 0.5 * tau_q * b_values[k] * b_values[j];
4523            }
4524        }
4525        // `outer_h` is the Jacobian of the negative profiled REML estimating
4526        // equation. Preserve signed curvature directions while flooring
4527        // near-zero modes; flipping negative eigenvalues would change the VJP.
4528        gam_linalg::matrix::symmetrize_in_place(&mut outer_h);
4529        if let Some(((row, col), value)) =
4530            outer_h.indexed_iter().find(|(_, value)| !value.is_finite())
4531        {
4532            return Err(EstimationError::InvalidInput(format!(
4533                "outer rho curvature entry ({row},{col}) is non-finite: {value}"
4534            )));
4535        }
4536        let rho_adj =
4537            gam_linalg::utils::solve_symmetric_vector_with_floor(&outer_h, &alpha, 1.0e-10)?;
4538        if let Some((block, value)) = rho_adj
4539            .iter()
4540            .enumerate()
4541            .find(|(_, value)| !value.is_finite())
4542        {
4543            return Err(EstimationError::InvalidInput(format!(
4544                "outer rho adjoint for block {block} is non-finite: {value}"
4545            )));
4546        }
4547        let weighted_b_sum = rho_adj
4548            .iter()
4549            .zip(b_values.iter())
4550            .map(|(&zk, &bk)| zk * bk)
4551            .sum::<f64>();
4552        q_kernel += 0.5 * tau_q * weighted_b_sum;
4553        for block in 0..f_blocks {
4554            let zk = rho_adj[block];
4555            if zk == 0.0 {
4556                continue;
4557            }
4558            g_kernel -= &(rpr_matrices[block].clone() * (0.5 * zk));
4559            let m = &m_vectors[block];
4560            for i in 0..p_total {
4561                h_kernel[i] += tau * zk * m[i];
4562                for j in 0..p_total {
4563                    g_kernel[[i, j]] -= 0.5 * tau * zk * (beta[i] * m[j] + m[i] * beta[j]);
4564                }
4565            }
4566            let start = offsets[block];
4567            let end = offsets[block + 1];
4568            j_blocks[block] += &(r.slice(s![start..end, start..end]).to_owned() * (0.5 * zk));
4569            for i in 0..(end - start) {
4570                for j in 0..(end - start) {
4571                    j_blocks[block][[i, j]] += 0.5 * tau * zk * beta[start + i] * beta[start + j];
4572                }
4573            }
4574        }
4575    }
4576
4577    for row in 0..n {
4578        for col in 0..p_total {
4579            grad_z[[row, col]] += -2.0 * q_kernel * weighted_residual[row] * beta[col];
4580        }
4581    }
4582    let zg = z.dot(&g_kernel);
4583    for row in 0..n {
4584        for col in 0..p_total {
4585            grad_z[[row, col]] += 2.0 * weights[row] * zg[[row, col]];
4586        }
4587    }
4588    let wy = y.to_owned() * &weights.to_owned();
4589    for row in 0..n {
4590        for col in 0..p_total {
4591            grad_z[[row, col]] += wy[row] * h_kernel[col];
4592        }
4593    }
4594
4595    let mut grad_y = Array2::<f64>::zeros((n, 1));
4596    let zh = z.dot(&h_kernel);
4597    for row in 0..n {
4598        grad_y[[row, 0]] = 2.0 * q_kernel * weighted_residual[row] + weights[row] * zh[row];
4599    }
4600
4601    let mut grad_weights = Array1::<f64>::zeros(n);
4602    for row in 0..n {
4603        let diag_zgz = (0..p_total)
4604            .map(|col| z[[row, col]] * zg[[row, col]])
4605            .sum::<f64>();
4606        grad_weights[row] = q_kernel * residual[row] * residual[row] + diag_zgz + y[row] * zh[row];
4607    }
4608
4609    let mut grad_penalties = Vec::with_capacity(f_blocks);
4610    for block in 0..f_blocks {
4611        let start = offsets[block];
4612        let end = offsets[block + 1];
4613        let mut local = g_kernel.slice(s![start..end, start..end]).to_owned();
4614        for i in 0..(end - start) {
4615            for j in 0..(end - start) {
4616                local[[i, j]] += q_kernel * beta[start + i] * beta[start + j];
4617            }
4618        }
4619        local += &j_blocks[block];
4620        local *= lambdas[block];
4621        gam_linalg::matrix::symmetrize_in_place(&mut local);
4622        grad_penalties.push(local);
4623    }
4624
4625    let mut grad_designs = Vec::with_capacity(f_blocks);
4626    for block in 0..f_blocks {
4627        grad_designs.push(
4628            grad_z
4629                .slice(s![.., offsets[block]..offsets[block + 1]])
4630                .to_owned(),
4631        );
4632    }
4633
4634    Ok(GaussianRemlBlocksBackwardAnalytic {
4635        grad_designs,
4636        grad_penalties,
4637        grad_y,
4638        grad_weights,
4639    })
4640}
4641
4642/// Fixed-λ multi-output Gaussian fit under a per-row dense Fisher–Rao precision
4643/// metric: coefficients, fitted values, per-output residual scale, and the
4644/// penalized Fisher-weighted objective.
4645pub struct DenseFisherGaussianFit {
4646    pub coefficients: Array2<f64>,
4647    pub fitted: Array2<f64>,
4648    pub sigma2: Array1<f64>,
4649    pub objective: f64,
4650}
4651
4652/// Add a block-diagonal `λ·S` penalty (one `S` block per output) into a stacked
4653/// `(k·n_outputs)` Hessian in place, symmetrizing `S`.
4654pub fn add_block_diagonal_penalty(
4655    hessian: &mut Array2<f64>,
4656    penalty: ArrayView2<'_, f64>,
4657    lambda: f64,
4658    n_outputs: usize,
4659) -> Result<(), EstimationError> {
4660    let k = penalty.ncols();
4661    if penalty.nrows() != k {
4662        return Err(EstimationError::InvalidInput(format!(
4663            "penalty must be square for dense Fisher fit; got {}x{}",
4664            penalty.nrows(),
4665            penalty.ncols()
4666        )));
4667    }
4668    if hessian.dim() != (k * n_outputs, k * n_outputs) {
4669        return Err(EstimationError::InvalidInput(
4670            "dense Fisher Hessian shape mismatch while adding penalty".to_string(),
4671        ));
4672    }
4673    for output in 0..n_outputs {
4674        let offset = output * k;
4675        for row in 0..k {
4676            for col in 0..k {
4677                let s_sym = 0.5 * (penalty[[row, col]] + penalty[[col, row]]);
4678                hessian[[offset + row, offset + col]] += lambda * s_sym;
4679            }
4680        }
4681    }
4682    Ok(())
4683}
4684
4685/// Closed-form fixed-λ multi-output Gaussian fit with a per-row dense Fisher–Rao
4686/// precision metric. Assembles the block `XᵀWX` (+ block-diagonal `λS`) and
4687/// `XᵀWY` via the dense Fisher block kernels, solves, then forms fitted values,
4688/// per-output residual scale `sigma2`, and the penalized Fisher-weighted
4689/// objective seeded by `latent_prior_score`. `row_weights` are the (already
4690/// resolved) per-observation likelihood weights.
4691pub fn dense_fisher_gaussian_fit(
4692    design: ArrayView2<'_, f64>,
4693    y: ArrayView2<'_, f64>,
4694    penalty: ArrayView2<'_, f64>,
4695    row_weights: ArrayView1<'_, f64>,
4696    fisher_w: ArrayView3<'_, f64>,
4697    lambda: f64,
4698    latent_prior_score: f64,
4699) -> Result<DenseFisherGaussianFit, EstimationError> {
4700    let n_obs = design.nrows();
4701    let k = design.ncols();
4702    let n_outputs = y.ncols();
4703    let mut hessian = crate::pirls::dense_block_xtwx(design, fisher_w, Some(row_weights))?;
4704    add_block_diagonal_penalty(&mut hessian, penalty, lambda, n_outputs)?;
4705    let rhs = crate::pirls::dense_block_xtwy(design, fisher_w, y, Some(row_weights))?;
4706    let beta_vec =
4707        gam_linalg::utils::solve_dense_block_system(&hessian, &rhs, "dense Fisher Gaussian")
4708            .map_err(EstimationError::InvalidInput)?;
4709    let mut coefficients = Array2::<f64>::zeros((k, n_outputs));
4710    for output in 0..n_outputs {
4711        for col in 0..k {
4712            coefficients[[col, output]] = beta_vec[output * k + col];
4713        }
4714    }
4715    let fitted = design.dot(&coefficients);
4716    let mut sigma2 = Array1::<f64>::zeros(n_outputs);
4717    let mut objective = latent_prior_score;
4718    for row in 0..n_obs {
4719        for a in 0..n_outputs {
4720            let ra = y[[row, a]] - fitted[[row, a]];
4721            sigma2[a] += row_weights[row] * ra * ra;
4722            for b in 0..n_outputs {
4723                objective += 0.5
4724                    * row_weights[row]
4725                    * ra
4726                    * fisher_w[[row, a, b]]
4727                    * (y[[row, b]] - fitted[[row, b]]);
4728            }
4729        }
4730    }
4731    for output in 0..n_outputs {
4732        sigma2[output] /= (n_obs.saturating_sub(k).max(1)) as f64;
4733        let beta_col = coefficients.column(output);
4734        let s_beta = penalty.dot(&beta_col);
4735        objective += 0.5 * lambda * beta_col.dot(&s_beta);
4736    }
4737    Ok(DenseFisherGaussianFit {
4738        coefficients,
4739        fitted,
4740        sigma2,
4741        objective,
4742    })
4743}