1use crate::estimate::EstimationError;
2use gam_linalg::faer_ndarray::{
3 FaerCholesky, FaerEigh, fast_ab, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
4};
5use faer::Side;
6use ndarray::{
7 Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, ArrayViewMut1, ArrayViewMut2, Axis,
8 s,
9};
10use rayon::prelude::*;
11use std::sync::Once;
12
13static ILL_CONDITIONED_BACKWARD_WARNED: Once = Once::new();
21
22fn warn_ill_conditioned_backward_once(p: usize, d: usize, condition_number: f64) {
23 ILL_CONDITIONED_BACKWARD_WARNED.call_once(|| {
24 log::warn!(
25 "gaussian_reml_fit_backward: K = XᵀWX + λS is near-singular \
26 (p={p}, d={d}, cond≈{condition_number:.2e}); returning zero gradients \
27 for this fit (λ has saturated, atom is effectively unused). \
28 Further occurrences are silent."
29 );
30 });
31}
32
33fn zero_backward_result(n: usize, p: usize, d: usize) -> GaussianRemlBackwardResult {
34 GaussianRemlBackwardResult {
35 grad_x: Array2::<f64>::zeros((n, p)),
36 grad_y: Array2::<f64>::zeros((n, d)),
37 grad_penalty: Array2::<f64>::zeros((p, p)),
38 grad_weights: Array1::<f64>::zeros(n),
39 }
40}
41
/// Lower bound on `rho`. Not referenced in this part of the file — presumably
/// the lower end of the search interval for `rho = ln(lambda)` (TODO confirm).
const RHO_LOWER: f64 = -30.0;
/// Upper bound on `rho`. Not referenced in this part of the file — presumably
/// the upper end of the search interval for `rho = ln(lambda)` (TODO confirm).
const RHO_UPPER: f64 = 30.0;
/// Relative tolerance, scaled by the largest absolute eigenvalue, below which a
/// penalty eigenvalue is not counted towards rank or log-determinant.
const EIGEN_REL_TOL: f64 = 1.0e-10;
/// Not referenced in this part of the file — presumably the gradient
/// convergence tolerance of the `rho` optimizer (TODO confirm).
const GRAD_TOL: f64 = 1.0e-12;
/// Floor applied to deviance-like quantities before they are divided by or
/// passed to `ln`.
const MIN_DEVIANCE: f64 = 1.0e-300;
47
48fn canonicalize_penalty(penalty: ArrayView2<'_, f64>) -> Array2<f64> {
59 let p = penalty.nrows();
60 let mut out = penalty.to_owned();
61 for i in 0..p {
62 for j in (i + 1)..p {
63 let avg = 0.5 * (out[[i, j]] + out[[j, i]]);
64 out[[i, j]] = avg;
65 out[[j, i]] = avg;
66 }
67 }
68 out
69}
70
/// Cached spectral quantities reused across Gaussian REML evaluations.
///
/// The REML terms in this file consume the cache as follows: each entry `δ` of
/// `penalty_eigenvalues` enters through `1 + λ·δ`, `logdet_xtwx` is the starting
/// value of the accumulated log-determinant, and
/// `logdet_penalty_positive + penalty_rank · ρ` is used as the log-determinant
/// of the scaled penalty.
#[derive(Clone, Debug)]
pub struct GaussianRemlEigenCache {
    /// Values `δ` that enter the fit as `1 + λ·δ`.
    pub penalty_eigenvalues: Array1<f64>,
    /// Not read in this part of the file — presumably the eigenvectors paired
    /// with `penalty_eigenvalues` (TODO confirm).
    pub eigenvectors: Array2<f64>,
    /// Not read in this part of the file — presumably the basis used to map
    /// projected solutions back to coefficients (TODO confirm).
    pub coefficient_basis: Array2<f64>,
    /// Not read in this part of the file — presumably a fingerprint of the
    /// XᵀWX matrix the cache was built from (TODO confirm).
    pub xtwx_fingerprint: u64,
    /// Fingerprint of the penalty; compared against `matrix_fingerprint(penalty)`
    /// to reject a cache built for a different penalty.
    pub penalty_fingerprint: u64,
    /// Base value of the log-determinant to which `ln(1 + λ·δ)` terms are added.
    pub logdet_xtwx: f64,
    /// Added to `penalty_rank · ρ` to form the scaled-penalty log-determinant.
    pub logdet_penalty_positive: f64,
    /// Rank of the penalty as used in the log-determinant and gradient terms.
    pub penalty_rank: usize,
    /// Dimension subtracted from the effective row count to get the residual
    /// degrees of freedom; fits require `n_effective > nullity`.
    pub nullity: usize,
}
83
84#[derive(Clone, Debug, Default)]
85pub struct GaussianRemlWarmStart {
86 pub lambda: Option<f64>,
87 pub eigen_cache: Option<GaussianRemlEigenCache>,
88}
89
90impl GaussianRemlWarmStart {
91 pub fn from_multi_result(result: &GaussianRemlMultiResult) -> Self {
92 Self {
93 lambda: Some(result.lambda),
94 eigen_cache: Some(result.cache.clone()),
95 }
96 }
97}
98
/// Result of a single-response Gaussian REML fit.
///
/// Produced from a one-column [`GaussianRemlMultiResult`] by taking column 0 of
/// the matrix-valued fields and element 0 of `sigma2`.
#[derive(Clone, Debug)]
pub struct GaussianRemlResult {
    /// Selected smoothing parameter `λ`.
    pub lambda: f64,
    /// Log smoothing parameter; the fit computes `lambda` as `rho.exp()`.
    pub rho: f64,
    /// Coefficient vector (column 0 of the multi-output coefficients).
    pub coefficients: Array1<f64>,
    /// Fitted values (column 0 of the multi-output fitted matrix).
    pub fitted: Array1<f64>,
    /// REML objective value at `rho`.
    pub reml_score: f64,
    /// First derivative of the objective with respect to `lambda`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the objective with respect to `lambda`.
    pub reml_hess_lambda: f64,
    /// First derivative of the objective with respect to `rho`.
    pub reml_grad_rho: f64,
    /// Second derivative of the objective with respect to `rho`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom reported by the objective evaluation.
    pub edf: f64,
    /// Variance estimate for the single response.
    pub sigma2: f64,
    /// Eigen cache used by the fit; can seed a later warm start.
    pub cache: GaussianRemlEigenCache,
}
114
/// Result of a multi-response Gaussian REML fit that shares one smoothing
/// parameter across all response columns.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiResult {
    /// Selected smoothing parameter `λ` (`rho.exp()`).
    pub lambda: f64,
    /// Log smoothing parameter returned by the optimizer.
    pub rho: f64,
    /// Coefficients, one column per response.
    pub coefficients: Array2<f64>,
    /// Fitted values: the product of the design matrix and `coefficients`.
    pub fitted: Array2<f64>,
    /// REML objective value at `rho`.
    pub reml_score: f64,
    /// First derivative of the objective with respect to `lambda`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the objective with respect to `lambda`.
    pub reml_hess_lambda: f64,
    /// First derivative of the objective with respect to `rho`.
    pub reml_grad_rho: f64,
    /// Second derivative of the objective with respect to `rho`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom reported by the objective evaluation.
    pub edf: f64,
    /// Per-response variance estimates.
    pub sigma2: Array1<f64>,
    /// Eigen cache used by the fit; can seed a later warm start.
    pub cache: GaussianRemlEigenCache,
}
130
/// REML-type score evaluated at caller-supplied coefficients, together with its
/// gradients, as returned by `gaussian_reml_free_b_score`.
#[derive(Clone, Debug)]
pub struct GaussianRemlFreeBScore {
    /// Score value.
    pub reml_score: f64,
    /// Gradient of the score with respect to the coefficients, shape `(p, d)`.
    pub grad_coefficients: Array2<f64>,
    /// Gradient of the score with respect to the penalty matrix, shape `(p, p)`,
    /// symmetrized before being returned.
    pub grad_penalty: Array2<f64>,
    /// Gradient of the score with respect to `log_lambda`.
    pub grad_log_lambda: f64,
    /// Product of the design matrix and the supplied coefficients.
    pub fitted: Array2<f64>,
    /// Per-response penalized deviance divided by the residual degrees of freedom.
    pub sigma2: Array1<f64>,
    /// Sum of `1 / (1 + λ·δ)` over the cached penalty eigenvalues.
    pub edf: f64,
}
141
/// Gradients with respect to the inputs of a Gaussian REML fit, as returned by
/// the backward pass.
#[derive(Clone, Debug)]
pub struct GaussianRemlBackwardResult {
    /// Gradient with respect to the design matrix, shape `(n, p)`.
    pub grad_x: Array2<f64>,
    /// Gradient with respect to the responses, shape `(n, d)`.
    pub grad_y: Array2<f64>,
    /// Gradient with respect to the penalty matrix, shape `(p, p)`.
    pub grad_penalty: Array2<f64>,
    /// Gradient with respect to the observation weights, length `n`.
    pub grad_weights: Array1<f64>,
}
149
/// Inputs of one backward problem: the forward data, the forward fit, and
/// per-output gradient terms.
///
/// NOTE(review): no consumer of this struct is visible in this part of the
/// file; the `grad_*` fields are presumably upstream gradients of a downstream
/// loss with respect to the named fit outputs — confirm against the consumer.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBackwardProblem<'a> {
    /// Design matrix used in the forward fit.
    pub x: ArrayView2<'a, f64>,
    /// Response matrix used in the forward fit.
    pub y: ArrayView2<'a, f64>,
    /// Optional observation weights used in the forward fit.
    pub weights: Option<ArrayView1<'a, f64>>,
    /// Forward fit to differentiate through.
    pub fit: &'a GaussianRemlMultiResult,
    /// Gradient term associated with `lambda`.
    pub grad_lambda: f64,
    /// Optional gradient term associated with the coefficients.
    pub grad_coefficients: Option<ArrayView2<'a, f64>>,
    /// Optional gradient term associated with the fitted values.
    pub grad_fitted: Option<ArrayView2<'a, f64>>,
    /// Gradient term associated with the REML score.
    pub grad_reml_score: f64,
    /// Gradient term associated with the effective degrees of freedom.
    pub grad_edf: f64,
}
162
163#[derive(Clone, Debug)]
164pub struct GaussianRemlNoAllocWorkspace {
165 pub xtwy: Array2<f64>,
166 pub ywy: Array1<f64>,
167 pub projected_rhs: Array2<f64>,
168 pub projected_rhs_squared: Array2<f64>,
169 pub scaled_projected_rhs: Array2<f64>,
170}
171
172impl GaussianRemlNoAllocWorkspace {
173 pub fn new(n_coefficients: usize, n_outputs: usize) -> Self {
174 Self {
175 xtwy: Array2::zeros((n_coefficients, n_outputs)),
176 ywy: Array1::zeros(n_outputs),
177 projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
178 projected_rhs_squared: Array2::zeros((n_coefficients, n_outputs)),
179 scaled_projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
180 }
181 }
182
183 fn validate(&self, p: usize, d: usize) -> Result<(), EstimationError> {
184 if self.xtwy.dim() != (p, d)
185 || self.ywy.len() != d
186 || self.projected_rhs.dim() != (p, d)
187 || self.projected_rhs_squared.dim() != (p, d)
188 || self.scaled_projected_rhs.dim() != (p, d)
189 {
190 crate::bail_invalid_estim!(
191 "Gaussian REML no-alloc workspace shape mismatch: expected p={p}, d={d}"
192 );
193 }
194 Ok::<(), _>(())
195 }
196}
197
/// Scalar summary returned by the no-alloc fit; the array-valued outputs
/// (coefficients, fitted values, `sigma2`) are written into caller buffers.
#[derive(Clone, Copy, Debug)]
pub struct GaussianRemlNoAllocFit {
    /// Selected smoothing parameter `λ` (`rho.exp()`).
    pub lambda: f64,
    /// Log smoothing parameter returned by the optimizer.
    pub rho: f64,
    /// REML objective value at `rho`.
    pub reml_score: f64,
    /// First derivative of the objective with respect to `lambda`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the objective with respect to `lambda`.
    pub reml_hess_lambda: f64,
    /// First derivative of the objective with respect to `rho`.
    pub reml_grad_rho: f64,
    /// Second derivative of the objective with respect to `rho`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom reported by the objective evaluation.
    pub edf: f64,
}
209
/// One problem of a batched multi-response fit; all problems in a batch share
/// the same penalty.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBatchProblem<'a> {
    /// Design matrix.
    pub x: ArrayView2<'a, f64>,
    /// Response matrix.
    pub y: ArrayView2<'a, f64>,
    /// Optional observation weights; the batch driver uses all-ones weights
    /// when forming the weighted Gram matrix if this is `None`.
    pub weights: Option<ArrayView1<'a, f64>>,
    /// Optional initial log smoothing parameter; exponentiated by the batch
    /// driver to obtain the initial `lambda`.
    pub init_rho: Option<f64>,
}
217
/// Result of the block-orthogonal Gaussian REML fit with one smoothing
/// parameter per block.
#[derive(Clone, Debug)]
pub struct GaussianRemlBlockOrthogonalResult {
    /// Per-block coefficient matrices, in the order the blocks were supplied.
    pub coefficients: Vec<Array2<f64>>,
    /// Sum over blocks of each design multiplied by its coefficients.
    pub fitted: Array2<f64>,
    /// Per-block smoothing parameters (`exp` of `log_lambdas`).
    pub lambdas: Array1<f64>,
    /// Per-block log smoothing parameters.
    pub log_lambdas: Array1<f64>,
    /// Value of the REML-type objective at the returned parameters.
    pub reml_score: f64,
    /// Per-block effective degrees of freedom.
    pub edf: Array1<f64>,
}
227
/// Precomputed state of one fit, built by `prepare_gaussian_reml` and consumed
/// by the optimizer and by the `evaluate` / `coefficients` / `sigma2` methods
/// (defined outside this part of the file).
///
/// NOTE(review): field meanings below are inferred from how the same-named
/// quantities are used by the REML term helpers in this file — confirm against
/// `prepare_gaussian_reml`.
#[derive(Clone)]
struct GaussianRemlPrepared {
    /// Eigen cache for this fit; moved into the result once fitting completes.
    cache: GaussianRemlEigenCache,
    /// Presumably the per-output weighted sum of squares of the responses.
    ywy: Array1<f64>,
    /// Presumably the element-wise square of `projected_rhs`.
    projected_rhs_squared: Array2<f64>,
    /// Presumably the weighted right-hand side projected onto the cached basis.
    projected_rhs: Array2<f64>,
    /// Presumably the number of rows with positive weight.
    n_effective: usize,
    /// Presumably the number of response columns.
    n_outputs: usize,
}
240
/// Accumulated objective at one `rho`: its value, its first and second
/// derivatives with respect to `rho`, and the effective degrees of freedom.
#[derive(Clone, Copy)]
struct ObjectiveEval {
    cost: f64,
    grad: f64,
    hess: f64,
    edf: f64,
}

/// One additive term of the objective: value plus first and second derivative.
#[derive(Clone, Copy)]
struct TermDerivs {
    value: f64,
    grad: f64,
    hess: f64,
}

impl std::ops::AddAssign<TermDerivs> for ObjectiveEval {
    /// Adds a term's value and derivatives into the running totals; `edf` is
    /// left untouched.
    fn add_assign(&mut self, rhs: TermDerivs) {
        let TermDerivs { value, grad, hess } = rhs;
        self.cost += value;
        self.grad += grad;
        self.hess += hess;
    }
}
275
276fn gaussian_reml_logdet_term(
282 cache: &GaussianRemlEigenCache,
283 rho: f64,
284 n_outputs: f64,
285) -> (TermDerivs, f64) {
286 let lambda = rho.exp();
287 let mut logdet_h = cache.logdet_xtwx;
288 let mut trace_h = 0.0;
289 let mut trace_h_deriv = 0.0;
290 let mut edf = 0.0;
291 for &delta in &cache.penalty_eigenvalues {
292 let t = lambda * delta;
293 logdet_h += (1.0 + t).ln();
294 if delta > 0.0 {
295 trace_h += t / (1.0 + t);
296 trace_h_deriv += t / ((1.0 + t) * (1.0 + t));
297 }
298 edf += 1.0 / (1.0 + t);
299 }
300 let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * rho;
301 let term = TermDerivs {
302 value: 0.5 * n_outputs * (logdet_h - logdet_s),
303 grad: 0.5 * n_outputs * (trace_h - cache.penalty_rank as f64),
304 hess: 0.5 * n_outputs * trace_h_deriv,
305 };
306 (term, edf)
307}
308
309fn gaussian_reml_dispersion_term(
316 cache: &GaussianRemlEigenCache,
317 ywy: ArrayView1<'_, f64>,
318 projected_rhs_squared: ArrayView2<'_, f64>,
319 output: usize,
320 nu: f64,
321 lambda: f64,
322) -> TermDerivs {
323 let mut fitted_quadratic = 0.0;
324 let mut dp_grad = 0.0;
325 let mut dp_hess = 0.0;
326 for eig in 0..cache.penalty_eigenvalues.len() {
327 let c2 = projected_rhs_squared[[eig, output]];
328 let t = lambda * cache.penalty_eigenvalues[eig];
329 let denom = 1.0 + t;
330 fitted_quadratic += c2 / denom;
331 dp_grad += c2 * t / (denom * denom);
332 dp_hess += c2 * t * (1.0 - t) / (denom * denom * denom);
333 }
334 let dp = (ywy[output] - fitted_quadratic).max(MIN_DEVIANCE);
335 TermDerivs {
336 value: 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln()),
337 grad: 0.5 * nu * dp_grad / dp,
338 hess: 0.5 * nu * (dp_hess / dp - (dp_grad * dp_grad) / (dp * dp)),
339 }
340}
341
342pub fn gaussian_reml_closed_form(
343 x: ArrayView2<'_, f64>,
344 y: ArrayView1<'_, f64>,
345 penalty: ArrayView2<'_, f64>,
346 weights: Option<ArrayView1<'_, f64>>,
347 init_rho: Option<f64>,
348) -> Result<GaussianRemlResult, EstimationError> {
349 gaussian_reml_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
350}
351
352pub fn gaussian_reml_closed_form_with_nullspace_dim(
353 x: ArrayView2<'_, f64>,
354 y: ArrayView1<'_, f64>,
355 penalty: ArrayView2<'_, f64>,
356 nullspace_dim: Option<usize>,
357 weights: Option<ArrayView1<'_, f64>>,
358 init_rho: Option<f64>,
359) -> Result<GaussianRemlResult, EstimationError> {
360 let y2 = y.insert_axis(Axis(1));
361 let result = gaussian_reml_multi_closed_form_with_nullspace_dim(
362 x,
363 y2,
364 penalty,
365 nullspace_dim,
366 weights,
367 init_rho,
368 )?;
369 scalar_result_from_multi(result)
370}
371
372fn scalar_result_from_multi(
373 result: GaussianRemlMultiResult,
374) -> Result<GaussianRemlResult, EstimationError> {
375 Ok(GaussianRemlResult {
376 lambda: result.lambda,
377 rho: result.rho,
378 coefficients: result.coefficients.column(0).to_owned(),
379 fitted: result.fitted.column(0).to_owned(),
380 reml_score: result.reml_score,
381 reml_grad_lambda: result.reml_grad_lambda,
382 reml_hess_lambda: result.reml_hess_lambda,
383 reml_grad_rho: result.reml_grad_rho,
384 reml_hess_rho: result.reml_hess_rho,
385 edf: result.edf,
386 sigma2: result.sigma2[0],
387 cache: result.cache,
388 })
389}
390
391pub fn gaussian_reml_multi_closed_form(
392 x: ArrayView2<'_, f64>,
393 y: ArrayView2<'_, f64>,
394 penalty: ArrayView2<'_, f64>,
395 weights: Option<ArrayView1<'_, f64>>,
396 init_rho: Option<f64>,
397) -> Result<GaussianRemlMultiResult, EstimationError> {
398 gaussian_reml_multi_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
399}
400
401pub fn gaussian_reml_multi_closed_form_with_nullspace_dim(
402 x: ArrayView2<'_, f64>,
403 y: ArrayView2<'_, f64>,
404 penalty: ArrayView2<'_, f64>,
405 nullspace_dim: Option<usize>,
406 weights: Option<ArrayView1<'_, f64>>,
407 init_rho: Option<f64>,
408) -> Result<GaussianRemlMultiResult, EstimationError> {
409 let init_lambda = init_rho.map(f64::exp);
410 gaussian_reml_multi_closed_form_from_parts(
411 x,
412 y,
413 penalty,
414 nullspace_dim,
415 weights,
416 init_lambda,
417 None,
418 )
419}
420
421pub fn gaussian_reml_multi_closed_form_warm_started(
422 x: ArrayView2<'_, f64>,
423 y: ArrayView2<'_, f64>,
424 penalty: ArrayView2<'_, f64>,
425 weights: Option<ArrayView1<'_, f64>>,
426 warm_start: Option<&GaussianRemlWarmStart>,
427) -> Result<GaussianRemlMultiResult, EstimationError> {
428 gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
429 x, y, penalty, None, weights, warm_start,
430 )
431}
432
433pub fn gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
434 x: ArrayView2<'_, f64>,
435 y: ArrayView2<'_, f64>,
436 penalty: ArrayView2<'_, f64>,
437 nullspace_dim: Option<usize>,
438 weights: Option<ArrayView1<'_, f64>>,
439 warm_start: Option<&GaussianRemlWarmStart>,
440) -> Result<GaussianRemlMultiResult, EstimationError> {
441 let init_lambda = warm_start.and_then(|start| start.lambda);
442 let eigen_cache = warm_start.and_then(|start| start.eigen_cache.as_ref());
443 gaussian_reml_multi_closed_form_from_parts(
444 x,
445 y,
446 penalty,
447 nullspace_dim,
448 weights,
449 init_lambda,
450 eigen_cache,
451 )
452}
453
454pub fn gaussian_reml_multi_closed_form_with_cache(
455 x: ArrayView2<'_, f64>,
456 y: ArrayView2<'_, f64>,
457 penalty: ArrayView2<'_, f64>,
458 weights: Option<ArrayView1<'_, f64>>,
459 init_lambda: Option<f64>,
460 eigen_cache: Option<&GaussianRemlEigenCache>,
461) -> Result<GaussianRemlMultiResult, EstimationError> {
462 gaussian_reml_multi_closed_form_from_parts(
463 x,
464 y,
465 penalty,
466 None,
467 weights,
468 init_lambda,
469 eigen_cache,
470 )
471}
472
/// Multi-response Gaussian REML fit that reuses a prebuilt eigen cache and
/// writes its array outputs into caller-provided buffers.
///
/// # Arguments
/// * `x`, `y` - design `(n, p)` and responses `(n, d)`.
/// * `penalty` - penalty matrix; symmetrized before use.
/// * `weights` - optional observation weights.
/// * `init_lambda` - optional starting `lambda`; validated, then passed to the
///   optimizer as `ln(lambda)`.
/// * `eigen_cache` - cache whose `penalty_fingerprint` must match `penalty`.
/// * `workspace` - scratch buffers sized for `(p, d)`.
/// * `coefficients`, `fitted`, `sigma2` - output buffers of shape `(p, d)`,
///   `(n, d)` and length `d`.
///
/// # Errors
/// Bails through `bail_invalid_estim!` on a row mismatch between `x` and `y`,
/// non-finite `y`, too few effective rows relative to the cache nullity, a
/// penalty fingerprint mismatch, or mis-shaped buffers; also propagates errors
/// from the validation, right-hand-side and optimizer helpers.
///
/// NOTE(review): `canonicalize_penalty` copies the penalty, so this path still
/// performs that allocation on every call despite its name.
/// NOTE(review): only `penalty_fingerprint` is checked here; the cache's
/// `xtwx_fingerprint` is not compared against `x`/`weights` in this function.
pub fn gaussian_reml_multi_closed_form_with_cache_no_alloc(
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    init_lambda: Option<f64>,
    eigen_cache: &GaussianRemlEigenCache,
    workspace: &mut GaussianRemlNoAllocWorkspace,
    mut coefficients: ArrayViewMut2<'_, f64>,
    mut fitted: ArrayViewMut2<'_, f64>,
    mut sigma2: ArrayViewMut1<'_, f64>,
) -> Result<GaussianRemlNoAllocFit, EstimationError> {
    // Work on an exactly symmetric copy of the penalty from here on.
    let penalty_owned = canonicalize_penalty(penalty);
    let penalty = penalty_owned.view();
    let n = x.nrows();
    let p = x.ncols();
    let d = y.ncols();
    validate_gaussian_reml_design(x, penalty, weights)?;
    validate_gaussian_reml_eigen_cache(eigen_cache, p)?;
    if y.nrows() != n {
        crate::bail_invalid_estim!(
            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
            y.nrows()
        );
    }
    if y.iter().any(|value| !value.is_finite()) {
        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
    }
    // Without weights every row counts towards the effective sample size.
    let n_effective = match weights {
        Some(w) => effective_observation_count(w),
        None => n,
    };
    if n_effective <= eigen_cache.nullity {
        crate::bail_invalid_estim!(
            "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
            eigen_cache.nullity
        );
    }
    // Reject a cache that was built for a different penalty.
    let penalty_fingerprint = matrix_fingerprint(penalty);
    if eigen_cache.penalty_fingerprint != penalty_fingerprint {
        crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
    }
    workspace.validate(p, d)?;
    if coefficients.dim() != (p, d) || fitted.dim() != (n, d) || sigma2.len() != d {
        crate::bail_invalid_estim!(
            "Gaussian REML no-alloc output shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
        );
    }
    if let Some(lambda) = init_lambda {
        validate_initial_lambda(lambda)?;
    }

    // Fill the workspace buffers that the optimizer and evaluator read below.
    fill_weighted_rhs_no_alloc(x, y, weights, workspace)?;
    project_rhs_no_alloc(eigen_cache, workspace);

    // The optimizer works in rho = ln(lambda).
    let init_rho = init_lambda.map(f64::ln);
    let rho = optimize_rho_no_alloc(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n_effective,
        d,
        init_rho,
    )?;
    // Re-evaluate the objective at the selected rho to report score, derivatives and edf.
    let eval = evaluate_reml_parts(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n_effective,
        d,
        rho,
    );
    let lambda = rho.exp();
    // Write the array outputs into the caller's buffers.
    fill_coefficients_no_alloc(eigen_cache, workspace, lambda, coefficients.view_mut());
    fill_fitted_no_alloc(x, coefficients.view(), fitted.view_mut());
    fill_sigma2_no_alloc(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n_effective,
        d,
        lambda,
        sigma2.view_mut(),
    );
    // Convert the rho-derivatives into lambda-derivatives for reporting.
    let (reml_grad_lambda, reml_hess_lambda) =
        rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
    Ok(GaussianRemlNoAllocFit {
        lambda,
        rho,
        reml_score: eval.cost,
        reml_grad_lambda,
        reml_hess_lambda,
        reml_grad_rho: eval.grad,
        reml_hess_rho: eval.hess,
        edf: eval.edf,
    })
}
573
574
575pub fn gaussian_reml_multi_closed_form_batch<'a>(
576 problems: &[GaussianRemlMultiBatchProblem<'a>],
577 penalty: ArrayView2<'a, f64>,
578 nullspace_dim: Option<usize>,
579) -> Result<Vec<GaussianRemlMultiResult>, EstimationError> {
580 if problems.is_empty() {
581 return Ok(Vec::new());
582 }
583 let xtwx_per_problem: Vec<Array2<f64>> = problems
587 .par_iter()
588 .map(|problem| {
589 let weight = match problem.weights.as_ref() {
590 Some(w) => w.to_owned(),
591 None => Array1::ones(problem.x.nrows()),
592 };
593 dense_xt_diag_x(problem.x.view(), weight.view())
594 })
595 .collect();
596 let caches =
600 build_gaussian_reml_eigen_cache_batched(xtwx_per_problem, penalty.view(), nullspace_dim);
601 let fits: Vec<Result<GaussianRemlMultiResult, EstimationError>> = problems
605 .par_iter()
606 .zip(caches.into_par_iter())
607 .map(|(problem, cache_result)| {
608 let init_lambda = problem.init_rho.map(f64::exp);
609 let cache = cache_result?;
610 gaussian_reml_multi_closed_form_from_parts(
611 problem.x.view(),
612 problem.y.view(),
613 penalty.view(),
614 nullspace_dim,
615 problem.weights.as_ref().map(|weights| weights.view()),
616 init_lambda,
617 Some(&cache),
618 )
619 })
620 .collect();
621 fits.into_iter().collect()
622}
623
/// Quantities of one block at a given `rho`, as computed by
/// `block_orthogonal_eval` with `H = gram + λ·penalty` and `λ = exp(rho)`.
struct BlockOrthogonalEval {
    /// Solution of `H · beta = rhs`.
    beta: Array2<f64>,
    /// Log-determinant of `H`, from twice the sum of log Cholesky diagonals.
    logdet: f64,
    /// Trace of `H⁻¹ · (λ·penalty)`.
    trace: f64,
    /// `trace_of_product` of `H⁻¹ · (λ·penalty)` with itself.
    trace_pair: f64,
    /// Per-output column sums of `rhs ∘ beta`.
    fitted_energy: Array1<f64>,
    /// Per-output column sums of `beta ∘ (λ·penalty·beta)`.
    penalty_energy: Array1<f64>,
    /// Per-output column sums of `(λ·penalty·beta) ∘ H⁻¹(λ·penalty·beta)`.
    curvature_energy: Array1<f64>,
    /// Number of penalty rows minus `trace`.
    edf: f64,
}
634
635fn block_penalty_rank_logdet(
636 penalty: ArrayView2<'_, f64>,
637) -> Result<(usize, f64), EstimationError> {
638 let eigs = penalty
639 .to_owned()
640 .eigh(Side::Lower)
641 .map_err(|_| EstimationError::ModelIsIllConditioned {
642 condition_number: f64::INFINITY,
643 })?
644 .0;
645 let max_abs = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
646 let tol = (EIGEN_REL_TOL * max_abs).max(1.0e-14);
647 let mut rank = 0_usize;
648 let mut logdet = 0.0;
649 for eig in eigs.iter().copied() {
650 if eig > tol {
651 rank += 1;
652 logdet += eig.ln();
653 }
654 }
655 Ok((rank, logdet))
656}
657
658fn block_orthogonal_eval(
659 gram: &Array2<f64>,
660 rhs: &Array2<f64>,
661 penalty: &Array2<f64>,
662 rho: f64,
663) -> Result<BlockOrthogonalEval, EstimationError> {
664 let lambda = rho.exp();
665 validate_initial_lambda(lambda)?;
666 let scaled_penalty = penalty * lambda;
667 let hessian = canonicalize_penalty((gram + &scaled_penalty).view());
668 let chol = gaussian_reml_cholesky_lower(hessian)?;
669 let beta = solve_spd_from_lower_factor(&chol, rhs)?;
670 let solved_penalty = solve_spd_from_lower_factor(&chol, &scaled_penalty)?;
671 let logdet = 2.0 * chol.diag().iter().map(|value| value.ln()).sum::<f64>();
672 let trace = (0..solved_penalty.nrows())
673 .map(|i| solved_penalty[[i, i]])
674 .sum::<f64>();
675 let trace_pair =
676 gam_linalg::utils::trace_of_product(solved_penalty.view(), solved_penalty.view());
677 let fitted_energy = (rhs * &beta).sum_axis(Axis(0));
678 let p_beta = scaled_penalty.dot(&beta);
679 let penalty_energy = (&beta * &p_beta).sum_axis(Axis(0));
680 let solved_p_beta = solve_spd_from_lower_factor(&chol, &p_beta)?;
681 let curvature_energy = (&p_beta * &solved_p_beta).sum_axis(Axis(0));
682 Ok(BlockOrthogonalEval {
683 beta,
684 logdet,
685 trace,
686 trace_pair,
687 fitted_energy,
688 penalty_energy,
689 curvature_energy,
690 edf: penalty.nrows() as f64 - trace,
691 })
692}
693
/// Value of a block's `rho` objective together with its first and second
/// derivatives in `rho`, as returned by `block_orthogonal_scale_objective`.
struct BlockOrthogonalScaleDerivs {
    /// Objective value.
    value: f64,
    /// First derivative with respect to `rho`.
    grad: f64,
    /// Second derivative with respect to `rho`.
    hess: f64,
}
709
710fn block_orthogonal_scale_objective(
711 eval: &BlockOrthogonalEval,
712 rho: f64,
713 scale_precision: ArrayView1<'_, f64>,
714 rank: usize,
715) -> BlockOrthogonalScaleDerivs {
716 let d = scale_precision.len() as f64;
717 let fit_term = scale_precision
718 .iter()
719 .zip(eval.fitted_energy.iter())
720 .map(|(scale, energy)| scale * energy)
721 .sum::<f64>();
722 let value = 0.5 * d * eval.logdet - 0.5 * fit_term - 0.5 * d * (rank as f64) * rho;
724 let grad = 0.5 * d * (eval.trace - rank as f64)
728 + 0.5
729 * scale_precision
730 .iter()
731 .zip(eval.penalty_energy.iter())
732 .map(|(scale, energy)| scale * energy)
733 .sum::<f64>();
734 let hess = 0.5 * d * (eval.trace - eval.trace_pair)
737 + 0.5
738 * scale_precision
739 .iter()
740 .zip(eval.penalty_energy.iter().zip(eval.curvature_energy.iter()))
741 .map(|(scale, (energy, curvature))| scale * (energy - 2.0 * curvature))
742 .sum::<f64>();
743 BlockOrthogonalScaleDerivs { value, grad, hess }
744}
745
746fn solve_block_orthogonal_rho(
747 gram: &Array2<f64>,
748 rhs: &Array2<f64>,
749 penalty: &Array2<f64>,
750 rho0: f64,
751 scale_precision: ArrayView1<'_, f64>,
752 rank: usize,
753 max_iter: usize,
754) -> Result<(f64, BlockOrthogonalEval), EstimationError> {
755 let mut rho = rho0;
756 let mut current = block_orthogonal_eval(gram, rhs, penalty, rho)?;
757 for _ in 0..max_iter {
758 let derivs = block_orthogonal_scale_objective(¤t, rho, scale_precision, rank);
761 let grad = derivs.grad;
762 let hess = derivs.hess;
763 if !(grad.is_finite() && hess.is_finite()) {
764 return Err(EstimationError::ModelIsIllConditioned {
765 condition_number: f64::INFINITY,
766 });
767 }
768 let descent = grad.signum();
779 let step = if hess > 1.0e-10 { grad / hess } else { descent };
780 let mut best_rho = rho;
781 let mut best_eval = current;
782 let mut best_phi =
783 block_orthogonal_scale_objective(&best_eval, best_rho, scale_precision, rank).value;
784 for candidate_rho in [
785 rho - step,
786 rho - 0.5 * step,
787 rho - 0.25 * step,
788 rho - descent,
789 rho - 0.25 * descent,
790 ] {
791 let Ok(candidate_eval) = block_orthogonal_eval(gram, rhs, penalty, candidate_rho)
796 else {
797 continue;
798 };
799 let candidate_phi = block_orthogonal_scale_objective(
800 &candidate_eval,
801 candidate_rho,
802 scale_precision,
803 rank,
804 )
805 .value;
806 if candidate_phi < best_phi {
807 best_rho = candidate_rho;
808 best_eval = candidate_eval;
809 best_phi = candidate_phi;
810 }
811 }
812 let delta = (best_rho - rho).abs();
813 rho = best_rho;
814 current = best_eval;
815 if delta < 1.0e-12 || step.abs() < 1.0e-7 {
816 break;
817 }
818 }
819 Ok((rho, current))
820}
821
/// Fits several design blocks to the same responses, with one smoothing
/// parameter per block and one scale (precision) per output shared by all
/// blocks.
///
/// The routine alternates between (a) optimizing each block's `rho` for the
/// current per-output precisions and (b) updating the precisions from the
/// unexplained energy `q = ywy − Σ_blocks fitted_energy`, for at most 40 outer
/// iterations or until the largest change in log-precision drops below `1e-7`.
/// A final pass re-optimizes every block at the converged precisions.
///
/// # Arguments
/// * `designs`, `penalties` - per-block design and penalty matrices; the two
///   slices must have the same non-zero length.
/// * `y` - responses with at least one column.
/// * `weights` - optional observation weights.
/// * `init_rhos` - optional finite starting `rho` per block (defaults to 0).
///
/// # Errors
/// Bails through `bail_invalid_estim!` on invalid inputs, and returns
/// `EstimationError::ModelIsIllConditioned` when any `q` entry is non-finite
/// or not strictly positive. Errors from the per-block helpers are propagated.
///
/// NOTE(review): summing per-block fitted energies and fitted values
/// presumably relies on the blocks being mutually orthogonal under the
/// weights; that property is not checked in this function.
pub fn gaussian_reml_blocks_orthogonal_shared_scale(
    designs: &[Array2<f64>],
    penalties: &[Array2<f64>],
    y: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    init_rhos: Option<&[f64]>,
) -> Result<GaussianRemlBlockOrthogonalResult, EstimationError> {
    if designs.is_empty() {
        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one block");
    }
    if designs.len() != penalties.len() {
        crate::bail_invalid_estim!(
            "block-orthogonal Gaussian REML block mismatch: {} designs, {} penalties",
            designs.len(),
            penalties.len()
        );
    }
    let n = y.nrows();
    let d = y.ncols();
    if d == 0 {
        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one output");
    }
    if y.iter().any(|value| !value.is_finite()) {
        crate::bail_invalid_estim!("block-orthogonal Gaussian REML response must be finite");
    }
    let weight = gaussian_reml_weights(n, weights)?;
    if let Some(rhos) = init_rhos {
        if rhos.len() != designs.len() {
            crate::bail_invalid_estim!(
                "block-orthogonal Gaussian REML init_rhos length mismatch: expected {}, got {}",
                designs.len(),
                rhos.len()
            );
        }
        if rhos.iter().any(|value| !value.is_finite()) {
            crate::bail_invalid_estim!("block-orthogonal Gaussian REML init_rhos must be finite");
        }
    }

    // Per-output weighted sum of squares of the responses.
    let mut ywy = Array1::<f64>::zeros(d);
    for row in 0..n {
        for output in 0..d {
            ywy[output] += weight[row] * y[[row, output]] * y[[row, output]];
        }
    }
    // Per-block precomputation: Gram matrix, right-hand side, symmetrized
    // penalty, penalty rank and log-determinant.
    let mut grams = Vec::with_capacity(designs.len());
    let mut rhs_blocks = Vec::with_capacity(designs.len());
    let mut penalties_owned = Vec::with_capacity(penalties.len());
    let mut ranks = Vec::with_capacity(penalties.len());
    let mut penalty_logdets = Vec::with_capacity(penalties.len());
    let mut nullity_total = 0_usize;
    for (block, (design, penalty)) in designs.iter().zip(penalties.iter()).enumerate() {
        let penalty_owned = canonicalize_penalty(penalty.view());
        validate_gaussian_reml_design(design.view(), penalty_owned.view(), Some(weight.view()))?;
        if design.nrows() != n {
            crate::bail_invalid_estim!(
                "block-orthogonal Gaussian REML designs[{block}] has {} rows, expected {n}",
                design.nrows()
            );
        }
        let gram = dense_xt_diag_x(design.view(), weight.view());
        let rhs = dense_xt_diag_y(design.view(), weight.view(), y);
        let (rank, logdet) = block_penalty_rank_logdet(penalty_owned.view())?;
        // Each block contributes (rows of penalty − rank) to the total nullity.
        nullity_total += penalty_owned.nrows().saturating_sub(rank);
        grams.push(canonicalize_penalty(gram.view()));
        rhs_blocks.push(rhs);
        penalties_owned.push(penalty_owned);
        ranks.push(rank);
        penalty_logdets.push(logdet);
    }
    let n_effective = effective_observation_count(weight.view());
    if n_effective <= nullity_total {
        crate::bail_invalid_estim!(
            "block-orthogonal Gaussian REML requires more positive-weight rows than the total penalty nullity; got n_effective={n_effective}, nullity={nullity_total}"
        );
    }
    // Residual degrees of freedom shared by all outputs.
    let nu = (n_effective - nullity_total) as f64;
    let mut rhos = match init_rhos {
        Some(values) => Array1::from_vec(values.to_vec()),
        None => Array1::zeros(designs.len()),
    };
    // Initial precision treats the whole response energy as unexplained.
    let mut scale_precision = ywy.mapv(|value| nu / value.max(MIN_DEVIANCE));
    let mut evals = Vec::new();
    // Outer fixed-point loop: block rho updates, then precision update.
    for _ in 0..40 {
        evals.clear();
        for block in 0..designs.len() {
            let (rho, eval) = solve_block_orthogonal_rho(
                &grams[block],
                &rhs_blocks[block],
                &penalties_owned[block],
                rhos[block],
                scale_precision.view(),
                ranks[block],
                32,
            )?;
            rhos[block] = rho;
            evals.push(eval);
        }
        let mut explained = Array1::<f64>::zeros(d);
        for eval in evals.iter() {
            explained += &eval.fitted_energy;
        }
        // Unexplained energy per output; must stay strictly positive.
        let q = &ywy - &explained;
        if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
            return Err(EstimationError::ModelIsIllConditioned {
                condition_number: f64::INFINITY,
            });
        }
        let next_scale = q.mapv(|value| nu / value);
        // Convergence is measured as the largest absolute change in log-precision.
        let scale_step = next_scale
            .iter()
            .zip(scale_precision.iter())
            .map(|(next, old)| (next.ln() - old.ln()).abs())
            .fold(0.0_f64, f64::max);
        scale_precision = next_scale;
        if scale_step < 1.0e-7 {
            break;
        }
    }
    // Final pass so that the returned evaluations match the final precisions.
    evals.clear();
    for block in 0..designs.len() {
        let (rho, eval) = solve_block_orthogonal_rho(
            &grams[block],
            &rhs_blocks[block],
            &penalties_owned[block],
            rhos[block],
            scale_precision.view(),
            ranks[block],
            16,
        )?;
        rhos[block] = rho;
        evals.push(eval);
    }

    let coefficients = evals
        .iter()
        .map(|eval| eval.beta.clone())
        .collect::<Vec<_>>();
    // Fitted values are the sum of every block's contribution.
    let mut fitted = Array2::<f64>::zeros((n, d));
    for (design, coef) in designs.iter().zip(coefficients.iter()) {
        fitted += &fast_ab(&design.view(), &coef.view());
    }
    let mut explained = Array1::<f64>::zeros(d);
    for eval in evals.iter() {
        explained += &eval.fitted_energy;
    }
    let q = &ywy - &explained;
    if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
        return Err(EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        });
    }
    let lambdas = rhos.mapv(f64::exp);
    let edf = Array1::from_iter(evals.iter().map(|eval| eval.edf));
    // Sum over blocks of log det H − log det of the scaled penalty.
    let logdet_term = evals
        .iter()
        .enumerate()
        .map(|(block, eval)| {
            eval.logdet - penalty_logdets[block] - (ranks[block] as f64) * rhos[block]
        })
        .sum::<f64>();
    // Sum over outputs of nu · (1 + ln(2π · q / nu)).
    let scale_term = q
        .iter()
        .map(|value| nu * (1.0 + (2.0 * std::f64::consts::PI * value / nu).ln()))
        .sum::<f64>();
    Ok(GaussianRemlBlockOrthogonalResult {
        coefficients,
        fitted,
        lambdas,
        log_lambdas: rhos,
        reml_score: 0.5 * (d as f64) * logdet_term + 0.5 * scale_term,
        edf,
    })
}
996
997fn gaussian_reml_multi_closed_form_from_parts(
998 x: ArrayView2<'_, f64>,
999 y: ArrayView2<'_, f64>,
1000 penalty: ArrayView2<'_, f64>,
1001 nullspace_dim: Option<usize>,
1002 weights: Option<ArrayView1<'_, f64>>,
1003 init_lambda: Option<f64>,
1004 eigen_cache: Option<&GaussianRemlEigenCache>,
1005) -> Result<GaussianRemlMultiResult, EstimationError> {
1006 let prepared = prepare_gaussian_reml(x, y, penalty, nullspace_dim, weights, eigen_cache)?;
1007 let init_rho = init_lambda
1008 .map(validate_initial_lambda)
1009 .transpose()?
1010 .map(f64::ln);
1011 let rho = optimize_rho(&prepared, init_rho)?;
1012 let eval = prepared.evaluate(rho);
1013 let lambda = rho.exp();
1014 let coefficients = prepared.coefficients(lambda);
1015 let fitted = dense_ab(x, coefficients.view());
1016 let sigma2 = prepared.sigma2(lambda);
1017 let (reml_grad_lambda, reml_hess_lambda) =
1018 rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
1019 Ok(GaussianRemlMultiResult {
1020 lambda,
1021 rho,
1022 coefficients,
1023 fitted,
1024 reml_score: eval.cost,
1025 reml_grad_lambda,
1026 reml_hess_lambda,
1027 reml_grad_rho: eval.grad,
1028 reml_hess_rho: eval.hess,
1029 edf: eval.edf,
1030 sigma2,
1031 cache: prepared.cache,
1032 })
1033}
1034
/// Evaluates the REML-type score at caller-supplied coefficients (not
/// necessarily the penalized least-squares solution) and returns its gradients
/// with respect to the coefficients, the penalty matrix and `log_lambda`.
///
/// # Arguments
/// * `x`, `y` - design `(n, p)` and responses `(n, d)`.
/// * `coefficients` - coefficients of shape `(p, d)` at which to evaluate.
/// * `log_lambda` - finite log smoothing parameter; `λ = exp(log_lambda)`.
/// * `penalty` - penalty matrix; symmetrized before use.
/// * `weights` - optional observation weights.
///
/// # Errors
/// Bails through `bail_invalid_estim!` on a non-finite `log_lambda`, shape
/// mismatches, non-finite `y`/`coefficients`, or too few effective rows
/// relative to the cache nullity. A failed Cholesky factorization is mapped to
/// `EstimationError::LinearSystemSolveFailed`; helper errors are propagated.
pub fn gaussian_reml_free_b_score(
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    coefficients: ArrayView2<'_, f64>,
    log_lambda: f64,
    penalty: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<GaussianRemlFreeBScore, EstimationError> {
    if !log_lambda.is_finite() {
        crate::bail_invalid_estim!("Gaussian REML log_lambda must be finite; got {log_lambda}");
    }
    let lambda = log_lambda.exp();
    // Work on an exactly symmetric copy of the penalty from here on.
    let penalty_owned = canonicalize_penalty(penalty);
    let penalty = penalty_owned.view();
    let n = x.nrows();
    let p = x.ncols();
    let d = y.ncols();
    validate_gaussian_reml_design(x, penalty, weights)?;
    if y.nrows() != n {
        crate::bail_invalid_estim!(
            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
            y.nrows()
        );
    }
    if coefficients.dim() != (p, d) {
        crate::bail_invalid_estim!(
            "Gaussian REML coefficient shape mismatch: expected {p}x{d}, got {}x{}",
            coefficients.nrows(),
            coefficients.ncols()
        );
    }
    if y.iter().chain(coefficients.iter()).any(|v| !v.is_finite()) {
        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
    }

    let weight = gaussian_reml_weights(n, weights)?;
    let n_effective = effective_observation_count(weight.view());
    // A fresh eigen cache is built on every call; no explicit null-space dimension is passed.
    let cache =
        build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, Some(weight.view()))?;
    if n_effective <= cache.nullity {
        crate::bail_invalid_estim!(
            "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
            cache.nullity
        );
    }
    // Residual degrees of freedom.
    let nu = n_effective as f64 - cache.nullity as f64;
    let fitted = dense_ab(x, coefficients);
    let residual = y.to_owned() - &fitted;
    let xtw_residual = dense_xt_diag_y(x, weight.view(), residual.view());
    let s_beta = dense_ab(penalty, coefficients);

    // Log-determinant part, accumulated from the cached eigenvalues.
    let mut logdet_h = cache.logdet_xtwx;
    let mut trace_h = 0.0;
    let mut edf = 0.0;
    for &delta in &cache.penalty_eigenvalues {
        let t = lambda * delta;
        logdet_h += (1.0 + t).ln();
        if delta > 0.0 {
            trace_h += t / (1.0 + t);
        }
        edf += 1.0 / (1.0 + t);
    }
    let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * log_lambda;
    let mut reml_score = 0.5 * (d as f64) * (logdet_h - logdet_s);
    let mut grad_log_lambda = 0.5 * (d as f64) * (trace_h - cache.penalty_rank as f64);
    let mut grad_coefficients = Array2::<f64>::zeros((p, d));
    // Explicit inverse of XᵀWX + λ·S, obtained by solving against the identity.
    let inverse_hessian = {
        let xtwx = dense_xt_diag_x(x, weight.view());
        let mut hessian = xtwx;
        hessian += &(penalty.to_owned() * lambda);
        hessian
            .cholesky(Side::Lower)
            .map_err(EstimationError::LinearSystemSolveFailed)?
            .solve_mat(&Array2::<f64>::eye(p))
    };
    let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(&cache);
    // Penalty gradient of the log-determinant part: 0.5·d·(λ·H⁻¹ − S⁺), read
    // with transposed indices.
    let mut grad_penalty = Array2::<f64>::zeros((p, p));
    for row in 0..p {
        for col in 0..p {
            grad_penalty[[row, col]] += 0.5
                * (d as f64)
                * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
        }
    }
    let mut sigma2 = Array1::<f64>::zeros(d);

    // Dispersion part, one output column at a time.
    for output in 0..d {
        let mut weighted_rss = 0.0;
        for row in 0..n {
            let r = residual[[row, output]];
            weighted_rss += weight[row] * r * r;
        }
        let beta_col = coefficients.column(output);
        let s_beta_col = s_beta.column(output);
        let penalty_quadratic = beta_col.dot(&s_beta_col);
        // Penalized deviance: weighted RSS plus λ·βᵀSβ, floored to stay positive.
        let dp = (weighted_rss + lambda * penalty_quadratic).max(MIN_DEVIANCE);
        sigma2[output] = dp / nu;
        reml_score += 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln());
        grad_log_lambda += 0.5 * nu * lambda * penalty_quadratic / dp;
        let scale = nu / dp;
        // Coefficient gradient of 0.5·nu·ln(dp): scale · (−XᵀW·residual + λ·S·β).
        for coeff in 0..p {
            grad_coefficients[[coeff, output]] =
                scale * (-xtw_residual[[coeff, output]] + lambda * s_beta[[coeff, output]]);
        }
        add_rank_one_penalty_vjp(0.5 * scale * lambda, beta_col, &mut grad_penalty);
    }
    // Symmetrize the penalty gradient by averaging mirrored entries.
    for i in 0..p {
        for j in (i + 1)..p {
            let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
            grad_penalty[[i, j]] = avg;
            grad_penalty[[j, i]] = avg;
        }
    }

    Ok(GaussianRemlFreeBScore {
        reml_score,
        grad_coefficients,
        grad_penalty,
        grad_log_lambda,
        fitted,
        sigma2,
        edf,
    })
}
1159
1160pub fn gaussian_reml_multi_closed_form_backward(
1161 x: ArrayView2<'_, f64>,
1162 y: ArrayView2<'_, f64>,
1163 penalty: ArrayView2<'_, f64>,
1164 weights: Option<ArrayView1<'_, f64>>,
1165 init_lambda: Option<f64>,
1166 upstream_lambda: f64,
1167 upstream_coefficients: Option<ArrayView2<'_, f64>>,
1168 upstream_fitted: Option<ArrayView2<'_, f64>>,
1169 upstream_reml_score: f64,
1170 upstream_edf: f64,
1171) -> Result<GaussianRemlBackwardResult, EstimationError> {
1172 let fit =
1173 gaussian_reml_multi_closed_form_with_cache(x, y, penalty, weights, init_lambda, None)?;
1174 gaussian_reml_multi_closed_form_backward_from_fit(
1175 x,
1176 y,
1177 penalty,
1178 weights,
1179 &fit,
1180 upstream_lambda,
1181 upstream_coefficients,
1182 upstream_fitted,
1183 upstream_reml_score,
1184 upstream_edf,
1185 )
1186}
1187
1188pub fn gaussian_reml_multi_closed_form_backward_from_fit(
1189 x: ArrayView2<'_, f64>,
1190 y: ArrayView2<'_, f64>,
1191 penalty: ArrayView2<'_, f64>,
1192 weights: Option<ArrayView1<'_, f64>>,
1193 fit: &GaussianRemlMultiResult,
1194 upstream_lambda: f64,
1195 upstream_coefficients: Option<ArrayView2<'_, f64>>,
1196 upstream_fitted: Option<ArrayView2<'_, f64>>,
1197 upstream_reml_score: f64,
1198 upstream_edf: f64,
1199) -> Result<GaussianRemlBackwardResult, EstimationError> {
1200 validate_gaussian_reml_backward_upstreams(
1201 x,
1202 y,
1203 penalty,
1204 upstream_lambda,
1205 upstream_coefficients,
1206 upstream_fitted,
1207 upstream_reml_score,
1208 upstream_edf,
1209 )?;
1210 validate_gaussian_reml_forward_fit(x, y, penalty, weights, fit)?;
1211 let lambda = fit.lambda;
1212 let n = x.nrows();
1213 let p = x.ncols();
1214 let d = y.ncols();
1215 if !(fit.reml_hess_rho.is_finite() && fit.reml_hess_rho.abs() > 1.0e-14) {
1216 warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1222 return Ok(zero_backward_result(n, p, d));
1223 }
1224 let weight = gaussian_reml_weights(n, weights)?;
1225 let inverse_hessian = match gaussian_reml_inverse_hessian_from_cache(&fit.cache, lambda) {
1226 Ok(inv) => inv,
1227 Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1228 warn_ill_conditioned_backward_once(p, d, condition_number);
1229 return Ok(zero_backward_result(n, p, d));
1230 }
1231 Err(err) => return Err(err),
1232 };
1233 gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1234 x,
1235 y,
1236 penalty,
1237 weight,
1238 fit,
1239 inverse_hessian,
1240 upstream_lambda,
1241 upstream_coefficients,
1242 upstream_fitted,
1243 upstream_reml_score,
1244 upstream_edf,
1245 n,
1246 p,
1247 d,
1248 )
1249}
1250
/// Shared backward kernel used by the single-problem and batched entry points.
///
/// Accumulates the gradients of the upstream-weighted outputs (`lambda`,
/// coefficients, fitted values, REML score, EDF) with respect to `x`, `y`,
/// `penalty` and the observation weights. Callers are expected to have
/// validated the inputs and to supply the resolved `weight` vector and the
/// inverse Hessian for `fit.lambda`.
///
/// `n`, `p`, `d` size the gradient buffers as `(n, p)`, `(n, d)`, `(p, p)` and
/// `n`.
fn gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weight: Array1<f64>,
    fit: &GaussianRemlMultiResult,
    inverse_hessian: Array2<f64>,
    upstream_lambda: f64,
    upstream_coefficients: Option<ArrayView2<'_, f64>>,
    upstream_fitted: Option<ArrayView2<'_, f64>>,
    upstream_reml_score: f64,
    upstream_edf: f64,
    n: usize,
    p: usize,
    d: usize,
) -> Result<GaussianRemlBackwardResult, EstimationError> {
    // Work on the symmetrized penalty (off-diagonal pairs averaged).
    let penalty_owned = canonicalize_penalty(penalty);
    let penalty = penalty_owned.view();
    let lambda = fit.lambda;
    let beta = &fit.coefficients;
    let residual = y.to_owned() - &fit.fitted;
    // Residual degrees of freedom: positive-weight rows minus the cached nullity.
    let nu = effective_observation_count(weight.view()) as f64 - fit.cache.nullity as f64;

    let mut grad_x = Array2::<f64>::zeros((n, p));
    let mut grad_y = Array2::<f64>::zeros((n, d));
    let mut grad_penalty = Array2::<f64>::zeros((p, p));
    let mut grad_weights = Array1::<f64>::zeros(n);

    // Fold both coefficient-level upstreams into one (p, d) adjoint: the direct
    // coefficient upstream plus the fitted-value upstream pulled back through x.
    let mut upstream_beta = Array2::<f64>::zeros((p, d));
    if let Some(upstream_coefficients) = upstream_coefficients {
        upstream_beta += &upstream_coefficients;
    }
    if let Some(upstream_fitted) = upstream_fitted {
        upstream_beta += &dense_atb(x, upstream_fitted);
        // Direct dependence of the fitted values on x at fixed coefficients.
        grad_x += &dense_ab(upstream_fitted, beta.t());
    }

    // Collects every contribution to the adjoint of lambda; it is pushed
    // through the REML stationarity condition at the end.
    let mut lambda_adjoint = upstream_lambda;
    if upstream_beta.iter().any(|value| *value != 0.0) {
        add_ridge_profile_vjp_with_lambda_grad(
            1.0,
            x,
            y,
            penalty,
            &weight,
            lambda,
            &inverse_hessian,
            beta,
            upstream_beta.view(),
            &mut grad_x,
            &mut grad_y,
            &mut grad_penalty,
            &mut grad_weights,
            &mut lambda_adjoint,
        );
    }

    if upstream_reml_score != 0.0 {
        add_reml_score_vjp(
            upstream_reml_score,
            x,
            &weight,
            &inverse_hessian,
            beta,
            &residual,
            &fit.sigma2,
            nu,
            lambda,
            &fit.cache,
            &mut grad_x,
            &mut grad_y,
            &mut grad_penalty,
            &mut grad_weights,
        );
        lambda_adjoint += upstream_reml_score * fit.reml_grad_lambda;
    }

    if upstream_edf != 0.0 {
        // add_edf_vjp returns the scaled partial derivative of the EDF in lambda.
        lambda_adjoint += add_edf_vjp(
            upstream_edf,
            x,
            penalty,
            &weight,
            lambda,
            &inverse_hessian,
            &mut grad_x,
            &mut grad_penalty,
            &mut grad_weights,
        );
    }

    if lambda_adjoint != 0.0 {
        // NOTE(review): this looks like the implicit-function step for the
        // REML-optimal lambda (adjoint scaled by -lambda / d²score/drho²);
        // callers guard against a vanishing `reml_hess_rho` before reaching here.
        let root_scale = -lambda_adjoint * lambda / fit.reml_hess_rho;
        add_reml_rho_gradient_vjp(
            root_scale,
            x,
            y,
            penalty,
            &weight,
            lambda,
            &inverse_hessian,
            beta,
            &residual,
            &fit.sigma2,
            nu,
            &mut grad_x,
            &mut grad_y,
            &mut grad_penalty,
            &mut grad_weights,
        );
    }

    // Symmetrize the penalty gradient by averaging each off-diagonal pair.
    let p = grad_penalty.nrows();
    for i in 0..p {
        for j in (i + 1)..p {
            let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
            grad_penalty[[i, j]] = avg;
            grad_penalty[[j, i]] = avg;
        }
    }
    Ok(GaussianRemlBackwardResult {
        grad_x,
        grad_y,
        grad_penalty,
        grad_weights,
    })
}
1396
1397pub fn gaussian_reml_multi_closed_form_backward_batch<'a>(
1398 problems: &[GaussianRemlMultiBackwardProblem<'a>],
1399 penalty: ArrayView2<'a, f64>,
1400) -> Vec<Result<GaussianRemlBackwardResult, EstimationError>> {
1401 let inverse_hessians = batched_inverse_hessians_from_caches(problems);
1402 let results: Vec<Result<GaussianRemlBackwardResult, EstimationError>> = problems
1403 .par_iter()
1404 .zip(inverse_hessians.into_par_iter())
1405 .map(|(problem, inverse_hessian_result)| {
1406 validate_gaussian_reml_backward_upstreams(
1407 problem.x.view(),
1408 problem.y.view(),
1409 penalty,
1410 problem.grad_lambda,
1411 problem.grad_coefficients.as_ref().map(|g| g.view()),
1412 problem.grad_fitted.as_ref().map(|g| g.view()),
1413 problem.grad_reml_score,
1414 problem.grad_edf,
1415 )?;
1416 validate_gaussian_reml_forward_fit(
1417 problem.x.view(),
1418 problem.y.view(),
1419 penalty,
1420 problem.weights.as_ref().map(|w| w.view()),
1421 problem.fit,
1422 )?;
1423 let n = problem.x.nrows();
1424 let p = problem.x.ncols();
1425 let d = problem.y.ncols();
1426 if !(problem.fit.reml_hess_rho.is_finite() && problem.fit.reml_hess_rho.abs() > 1.0e-14)
1427 {
1428 warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1430 return Ok(zero_backward_result(n, p, d));
1431 }
1432 let weight = gaussian_reml_weights(n, problem.weights.as_ref().map(|w| w.view()))?;
1433 let inverse_hessian = match inverse_hessian_result {
1434 Ok(inv) => inv,
1435 Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1436 warn_ill_conditioned_backward_once(p, d, condition_number);
1437 return Ok(zero_backward_result(n, p, d));
1438 }
1439 Err(err) => return Err(err),
1440 };
1441 gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1442 problem.x.view(),
1443 problem.y.view(),
1444 penalty,
1445 weight,
1446 problem.fit,
1447 inverse_hessian,
1448 problem.grad_lambda,
1449 problem.grad_coefficients.as_ref().map(|g| g.view()),
1450 problem.grad_fitted.as_ref().map(|g| g.view()),
1451 problem.grad_reml_score,
1452 problem.grad_edf,
1453 n,
1454 p,
1455 d,
1456 )
1457 })
1458 .collect();
1459 results
1460}
1461
/// Converts first and second derivatives taken with respect to
/// `rho = ln(lambda)` into derivatives with respect to `lambda`.
///
/// Returns `(grad_rho / lambda, (hess_rho - grad_rho) / lambda²)`.
fn rho_derivatives_to_lambda(lambda: f64, grad_rho: f64, hess_rho: f64) -> (f64, f64) {
    let grad_lambda = grad_rho / lambda;
    let lambda_squared = lambda * lambda;
    let hess_lambda = (hess_rho - grad_rho) / lambda_squared;
    (grad_lambda, hess_lambda)
}
1465
1466fn validate_gaussian_reml_backward_upstreams(
1467 x: ArrayView2<'_, f64>,
1468 y: ArrayView2<'_, f64>,
1469 penalty: ArrayView2<'_, f64>,
1470 upstream_lambda: f64,
1471 upstream_coefficients: Option<ArrayView2<'_, f64>>,
1472 upstream_fitted: Option<ArrayView2<'_, f64>>,
1473 upstream_reml_score: f64,
1474 upstream_edf: f64,
1475) -> Result<(), EstimationError> {
1476 if !(upstream_lambda.is_finite() && upstream_reml_score.is_finite() && upstream_edf.is_finite())
1477 {
1478 crate::bail_invalid_estim!("Gaussian REML backward upstream scalars must be finite");
1479 }
1480 if let Some(upstream_coefficients) = upstream_coefficients {
1481 if upstream_coefficients.dim() != (x.ncols(), y.ncols()) {
1482 crate::bail_invalid_estim!(
1483 "Gaussian REML backward coefficient upstream shape mismatch: expected {}x{}, got {}x{}",
1484 x.ncols(),
1485 y.ncols(),
1486 upstream_coefficients.nrows(),
1487 upstream_coefficients.ncols()
1488 );
1489 }
1490 if upstream_coefficients.iter().any(|value| !value.is_finite()) {
1491 crate::bail_invalid_estim!(
1492 "Gaussian REML backward coefficient upstream must be finite"
1493 );
1494 }
1495 }
1496 if let Some(upstream_fitted) = upstream_fitted {
1497 if upstream_fitted.dim() != y.dim() {
1498 crate::bail_invalid_estim!(
1499 "Gaussian REML backward fitted upstream shape mismatch: expected {}x{}, got {}x{}",
1500 y.nrows(),
1501 y.ncols(),
1502 upstream_fitted.nrows(),
1503 upstream_fitted.ncols()
1504 );
1505 }
1506 if upstream_fitted.iter().any(|value| !value.is_finite()) {
1507 crate::bail_invalid_estim!("Gaussian REML backward fitted upstream must be finite");
1508 }
1509 }
1510 validate_gaussian_reml_design(x, penalty, None)?;
1511 Ok(())
1512}
1513
1514fn validate_gaussian_reml_forward_fit(
1515 x: ArrayView2<'_, f64>,
1516 y: ArrayView2<'_, f64>,
1517 penalty: ArrayView2<'_, f64>,
1518 weights: Option<ArrayView1<'_, f64>>,
1519 fit: &GaussianRemlMultiResult,
1520) -> Result<(), EstimationError> {
1521 let penalty_owned = canonicalize_penalty(penalty);
1525 let penalty = penalty_owned.view();
1526 let n = x.nrows();
1527 let p = x.ncols();
1528 let d = y.ncols();
1529 validate_gaussian_reml_design(x, penalty, weights)?;
1530 validate_gaussian_reml_eigen_cache(&fit.cache, p)?;
1531 if y.nrows() != n
1532 || fit.coefficients.dim() != (p, d)
1533 || fit.fitted.dim() != (n, d)
1534 || fit.sigma2.len() != d
1535 {
1536 crate::bail_invalid_estim!(
1537 "Gaussian REML backward forward-state shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
1538 );
1539 }
1540 if !(fit.lambda.is_finite()
1541 && fit.lambda > 0.0
1542 && fit.rho.is_finite()
1543 && fit.reml_score.is_finite()
1544 && fit.reml_hess_rho.is_finite()
1545 && fit.edf.is_finite())
1546 || fit.coefficients.iter().any(|value| !value.is_finite())
1547 || fit.fitted.iter().any(|value| !value.is_finite())
1548 || fit.sigma2.iter().any(|value| !value.is_finite())
1549 {
1550 crate::bail_invalid_estim!("Gaussian REML backward forward state must be finite");
1551 }
1552 let penalty_fingerprint = matrix_fingerprint(penalty);
1553 if fit.cache.penalty_fingerprint != penalty_fingerprint {
1554 crate::bail_invalid_estim!("Gaussian REML backward forward-state penalty mismatch");
1555 }
1556 let weight = gaussian_reml_weights(n, weights)?;
1557 let xtwx = dense_xt_diag_x(x, weight.view());
1558 if fit.cache.xtwx_fingerprint != matrix_fingerprint(xtwx.view()) {
1559 crate::bail_invalid_estim!("Gaussian REML backward forward-state X'WX mismatch");
1560 }
1561 Ok(())
1562}
1563
1564fn gaussian_reml_inverse_hessian_from_cache(
1565 cache: &GaussianRemlEigenCache,
1566 lambda: f64,
1567) -> Result<Array2<f64>, EstimationError> {
1568 if !(lambda.is_finite() && lambda > 0.0) {
1569 crate::bail_invalid_estim!(
1570 "Gaussian REML lambda must be finite and positive; got {lambda}"
1571 );
1572 }
1573 let p = cache.penalty_eigenvalues.len();
1574 let mut scaled_basis = cache.coefficient_basis.clone();
1575 for eig in 0..p {
1576 let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1577 for row in 0..p {
1578 scaled_basis[[row, eig]] *= scale;
1579 }
1580 }
1581 let inverse = dense_ab(scaled_basis.view(), cache.coefficient_basis.t());
1582 if inverse.iter().any(|value| !value.is_finite()) {
1583 return Err(EstimationError::ModelIsIllConditioned {
1584 condition_number: f64::INFINITY,
1585 });
1586 }
1587 Ok(inverse)
1588}
1589
1590fn batched_inverse_hessians_from_caches(
1591 problems: &[GaussianRemlMultiBackwardProblem<'_>],
1592) -> Vec<Result<Array2<f64>, EstimationError>> {
1593 if problems.is_empty() {
1594 return Vec::new();
1595 }
1596 let p = problems[0].fit.cache.coefficient_basis.nrows();
1597 let uniform = p > 0
1598 && problems.iter().all(|problem| {
1599 let cache = &problem.fit.cache;
1600 cache.coefficient_basis.dim() == (p, p) && cache.penalty_eigenvalues.len() == p
1601 });
1602 if uniform && problems.len() > 1 {
1603 let mut scaled_basis = Array3::<f64>::zeros((problems.len(), p, p));
1604 let mut basis = Array3::<f64>::zeros((problems.len(), p, p));
1605 let mut valid = true;
1606 for (idx, problem) in problems.iter().enumerate() {
1607 let lambda = problem.fit.lambda;
1608 if !(lambda.is_finite() && lambda > 0.0) {
1609 valid = false;
1610 break;
1611 }
1612 let cache = &problem.fit.cache;
1613 basis
1614 .slice_mut(s![idx, .., ..])
1615 .assign(&cache.coefficient_basis);
1616 for eig in 0..p {
1617 let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1618 for row in 0..p {
1619 scaled_basis[[idx, row, eig]] = cache.coefficient_basis[[row, eig]] * scale;
1620 }
1621 }
1622 }
1623 if valid
1624 && let Some(inverses) =
1625 gam_gpu::try_fast_abt_strided_batched(scaled_basis.view(), basis.view())
1626 {
1627 return inverses
1628 .axis_iter(Axis(0))
1629 .map(|inverse| Ok(inverse.to_owned()))
1630 .collect();
1631 }
1632 }
1633 problems
1634 .iter()
1635 .map(|problem| {
1636 gaussian_reml_inverse_hessian_from_cache(&problem.fit.cache, problem.fit.lambda)
1637 })
1638 .collect()
1639}
1640
/// Accumulates the vector-Jacobian product of the ridge coefficients with
/// respect to `x`, `y`, `penalty` and the weights, at fixed `lambda`.
///
/// `upstream_beta` is the adjoint of the coefficients. The function forms
/// `m = inverse_hessian · upstream_beta` and `c = m · betaᵀ`, adds the
/// resulting contributions (each multiplied by `scale`) into the four gradient
/// buffers, and returns `m` so the caller can reuse it (e.g. for the `lambda`
/// partial).
fn ridge_profile_vjp_data_partials(
    scale: f64,
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: &Array1<f64>,
    lambda: f64,
    inverse_hessian: &Array2<f64>,
    beta: &Array2<f64>,
    upstream_beta: ArrayView2<'_, f64>,
    grad_x: &mut Array2<f64>,
    grad_y: &mut Array2<f64>,
    grad_penalty: &mut Array2<f64>,
    grad_weights: &mut Array1<f64>,
) -> Array2<f64> {
    let m = dense_ab(inverse_hessian.view(), upstream_beta);
    let c = dense_ab(m.view(), beta.t());
    let c_sym = &c + &c.t();
    let ymt = dense_ab(y, m.t());
    let xcs = dense_ab(x, c_sym.view());
    // grad_x[i, :] += scale * w_i * (y_i · mᵀ - x_i · (c + cᵀ))
    for i in 0..x.nrows() {
        let wi = weights[i] * scale;
        for k in 0..x.ncols() {
            grad_x[[i, k]] += wi * (ymt[[i, k]] - xcs[[i, k]]);
        }
    }

    // grad_y[i, :] += scale * w_i * (x_i · m)
    let xm = dense_ab(x, m.view());
    for i in 0..x.nrows() {
        let wi = weights[i] * scale;
        for j in 0..y.ncols() {
            grad_y[[i, j]] += wi * xm[[i, j]];
        }
    }

    // grad_weights[i] += scale * (y_i · (x_i m)ᵀ - x_i · c · x_iᵀ)
    let xc = dense_ab(x, c.view());
    for i in 0..x.nrows() {
        let mut from_b = 0.0;
        for j in 0..y.ncols() {
            from_b += y[[i, j]] * xm[[i, j]];
        }
        let mut from_a = 0.0;
        for k in 0..x.ncols() {
            from_a += x[[i, k]] * xc[[i, k]];
        }
        grad_weights[i] += scale * (from_b - from_a);
    }

    // grad_penalty -= scale * lambda * (m · betaᵀ), computed entry by entry.
    for row in 0..penalty.nrows() {
        for col in 0..penalty.ncols() {
            let mut value = 0.0;
            for output in 0..beta.ncols() {
                value += m[[row, output]] * beta[[col, output]];
            }
            grad_penalty[[row, col]] -= scale * lambda * value;
        }
    }
    m
}
1706
1707fn add_ridge_profile_vjp_with_lambda_grad(
1712 scale: f64,
1713 x: ArrayView2<'_, f64>,
1714 y: ArrayView2<'_, f64>,
1715 penalty: ArrayView2<'_, f64>,
1716 weights: &Array1<f64>,
1717 lambda: f64,
1718 inverse_hessian: &Array2<f64>,
1719 beta: &Array2<f64>,
1720 upstream_beta: ArrayView2<'_, f64>,
1721 grad_x: &mut Array2<f64>,
1722 grad_y: &mut Array2<f64>,
1723 grad_penalty: &mut Array2<f64>,
1724 grad_weights: &mut Array1<f64>,
1725 lambda_adjoint_out: &mut f64,
1726) {
1727 let m = ridge_profile_vjp_data_partials(
1728 scale,
1729 x,
1730 y,
1731 penalty,
1732 weights,
1733 lambda,
1734 inverse_hessian,
1735 beta,
1736 upstream_beta,
1737 grad_x,
1738 grad_y,
1739 grad_penalty,
1740 grad_weights,
1741 );
1742 let penalty_beta = dense_ab(penalty, beta.view());
1743 let dot = m
1744 .iter()
1745 .zip(penalty_beta.iter())
1746 .map(|(left, right)| left * right)
1747 .sum::<f64>();
1748 *lambda_adjoint_out += -scale * dot;
1749}
1750
1751fn add_ridge_profile_vjp_fixed_lambda(
1755 scale: f64,
1756 x: ArrayView2<'_, f64>,
1757 y: ArrayView2<'_, f64>,
1758 penalty: ArrayView2<'_, f64>,
1759 weights: &Array1<f64>,
1760 lambda: f64,
1761 inverse_hessian: &Array2<f64>,
1762 beta: &Array2<f64>,
1763 upstream_beta: ArrayView2<'_, f64>,
1764 grad_x: &mut Array2<f64>,
1765 grad_y: &mut Array2<f64>,
1766 grad_penalty: &mut Array2<f64>,
1767 grad_weights: &mut Array1<f64>,
1768) {
1769 ridge_profile_vjp_data_partials(
1770 scale,
1771 x,
1772 y,
1773 penalty,
1774 weights,
1775 lambda,
1776 inverse_hessian,
1777 beta,
1778 upstream_beta,
1779 grad_x,
1780 grad_y,
1781 grad_penalty,
1782 grad_weights,
1783 );
1784}
1785
/// Accumulates `scale ×` the partial derivatives of the REML score with
/// respect to `x`, `y`, `penalty` and the weights, at fixed `lambda`.
///
/// Two groups of terms are added:
/// * the log-determinant part, using `inverse_hessian` and the penalty
///   pseudo-inverse rebuilt from `cache`, multiplied by the number of outputs;
/// * one deviance term per output column, with coefficient
///   `scale · 0.5 · nu / dp` where `dp = max(sigma2[j] · nu, MIN_DEVIANCE)`.
fn add_reml_score_vjp(
    scale: f64,
    x: ArrayView2<'_, f64>,
    weights: &Array1<f64>,
    inverse_hessian: &Array2<f64>,
    beta: &Array2<f64>,
    residual: &Array2<f64>,
    sigma2: &Array1<f64>,
    nu: f64,
    lambda: f64,
    cache: &GaussianRemlEigenCache,
    grad_x: &mut Array2<f64>,
    grad_y: &mut Array2<f64>,
    grad_penalty: &mut Array2<f64>,
    grad_weights: &mut Array1<f64>,
) {
    // Number of output columns, as a float multiplier.
    let d = beta.ncols() as f64;
    let xp = dense_ab(x, inverse_hessian.view());
    let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(cache);
    // Penalty gradient of the log-determinant terms; both matrices are read
    // with transposed indices.
    for row in 0..grad_penalty.nrows() {
        for col in 0..grad_penalty.ncols() {
            grad_penalty[[row, col]] +=
                scale * 0.5 * d * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
        }
    }
    for i in 0..x.nrows() {
        let wi = weights[i] * scale * d;
        for k in 0..x.ncols() {
            grad_x[[i, k]] += wi * xp[[i, k]];
        }
        // x_i · inverse_hessian · x_iᵀ
        let mut leverage = 0.0;
        for k in 0..x.ncols() {
            leverage += x[[i, k]] * xp[[i, k]];
        }
        grad_weights[i] += scale * 0.5 * d * leverage;
    }

    for j in 0..beta.ncols() {
        // Recover the penalized deviance from sigma2, floored like the forward pass.
        let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
        let coef = scale * 0.5 * nu / dp;
        add_deviance_profile_vjp(
            coef,
            j,
            x,
            weights,
            beta,
            residual,
            grad_x,
            grad_y,
            grad_weights,
        );
        add_rank_one_penalty_vjp(coef * lambda, beta.column(j), grad_penalty);
    }
}
1840
/// Accumulates `scale ×` the partial derivatives of the effective degrees of
/// freedom with respect to `x`, `penalty` and the weights, and returns the
/// scaled partial with respect to `lambda`.
///
/// With `M = inverse_hessian · penalty` and
/// `G = lambda · inverse_hessian · penalty · inverse_hessian`, the return
/// value is `scale · (-tr(M) + lambda · tr(M · M))`.
fn add_edf_vjp(
    scale: f64,
    x: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: &Array1<f64>,
    lambda: f64,
    inverse_hessian: &Array2<f64>,
    grad_x: &mut Array2<f64>,
    grad_penalty: &mut Array2<f64>,
    grad_weights: &mut Array1<f64>,
) -> f64 {
    let m_inv_s = dense_ab(inverse_hessian.view(), penalty);
    // g_a = lambda * inverse_hessian * penalty * inverse_hessian
    let mut g_a = dense_ab(m_inv_s.view(), inverse_hessian.view());
    g_a.mapv_inplace(|v| v * lambda);

    let xg = dense_ab(x, g_a.view());
    let leading_scale = 2.0 * scale;
    // grad_x[i, :] += 2 * scale * w_i * (x_i · g_a)
    for i in 0..xg.nrows() {
        let row_scale = leading_scale * weights[i];
        for k in 0..xg.ncols() {
            grad_x[[i, k]] += row_scale * xg[[i, k]];
        }
    }
    // grad_weights[i] += scale * (x_i · g_a · x_iᵀ)
    for i in 0..x.nrows() {
        let mut quad = 0.0;
        for k in 0..x.ncols() {
            quad += x[[i, k]] * xg[[i, k]];
        }
        grad_weights[i] += scale * quad;
    }

    // grad_penalty += scale * lambda * (g_a - inverse_hessian)
    for row in 0..grad_penalty.nrows() {
        for col in 0..grad_penalty.ncols() {
            grad_penalty[[row, col]] +=
                scale * (-lambda * inverse_hessian[[row, col]] + lambda * g_a[[row, col]]);
        }
    }

    let p_dim = m_inv_s.nrows();
    // tr(M)
    let mut tr_m_inv_s = 0.0;
    for i in 0..p_dim {
        tr_m_inv_s += m_inv_s[[i, i]];
    }
    // tr(M · M)
    let mut tr_squared = 0.0;
    for i in 0..p_dim {
        for j in 0..p_dim {
            tr_squared += m_inv_s[[i, j]] * m_inv_s[[j, i]];
        }
    }
    scale * (-tr_m_inv_s + lambda * tr_squared)
}
1912
/// Accumulates `scale ×` the partial derivatives, with respect to `x`, `y`,
/// `penalty` and the weights, of the REML gradient in `rho`.
///
/// NOTE(review): the caller passes `scale = -lambda_adjoint · lambda /
/// reml_hess_rho`, which suggests this is the data side of an implicit
/// differentiation of the optimal `lambda` — confirm against the forward
/// definition of the rho-gradient.
///
/// Three groups of terms are added: the trace term built from
/// `inverse_hessian · penalty · inverse_hessian`; per-output deviance and
/// penalty-quadratic terms; and the dependence through the coefficients,
/// routed via `add_ridge_profile_vjp_fixed_lambda`.
fn add_reml_rho_gradient_vjp(
    scale: f64,
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: &Array1<f64>,
    lambda: f64,
    inverse_hessian: &Array2<f64>,
    beta: &Array2<f64>,
    residual: &Array2<f64>,
    sigma2: &Array1<f64>,
    nu: f64,
    grad_x: &mut Array2<f64>,
    grad_y: &mut Array2<f64>,
    grad_penalty: &mut Array2<f64>,
    grad_weights: &mut Array1<f64>,
) {
    // Number of output columns, as a float multiplier.
    let d = beta.ncols() as f64;
    let inverse_s = dense_ab(inverse_hessian.view(), penalty);
    // trace_kernel = inverse_hessian * penalty * inverse_hessian
    let trace_kernel = dense_ab(inverse_s.view(), inverse_hessian.view());
    for row in 0..grad_penalty.nrows() {
        for col in 0..grad_penalty.ncols() {
            grad_penalty[[row, col]] += scale
                * 0.5
                * d
                * lambda
                * (inverse_hessian[[col, row]] - lambda * trace_kernel[[col, row]]);
        }
    }
    let xt = dense_ab(x, trace_kernel.view());
    for i in 0..x.nrows() {
        let wi = -scale * d * lambda * weights[i];
        for k in 0..x.ncols() {
            grad_x[[i, k]] += wi * xt[[i, k]];
        }
        // x_i · trace_kernel · x_iᵀ
        let mut quad = 0.0;
        for k in 0..x.ncols() {
            quad += x[[i, k]] * xt[[i, k]];
        }
        grad_weights[i] -= scale * 0.5 * d * lambda * quad;
    }

    let s_beta = dense_ab(penalty, beta.view());
    // Adjoint of the coefficients, filled per output and pushed through the
    // ridge solve after the loop.
    let mut upstream_beta = Array2::<f64>::zeros(beta.dim());
    for j in 0..beta.ncols() {
        let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
        // q = lambda * beta_jᵀ · penalty · beta_j
        let q = lambda * beta.column(j).dot(&s_beta.column(j));
        let q_coef = scale * nu / dp;
        for row in 0..beta.nrows() {
            upstream_beta[[row, j]] = q_coef * lambda * s_beta[[row, j]];
        }
        let dp_coef = -scale * 0.5 * nu * q / (dp * dp);
        add_rank_one_penalty_vjp(
            (0.5 * q_coef + dp_coef) * lambda,
            beta.column(j),
            grad_penalty,
        );
        add_deviance_profile_vjp(
            dp_coef,
            j,
            x,
            weights,
            beta,
            residual,
            grad_x,
            grad_y,
            grad_weights,
        );
    }
    // `scale` is already folded into `upstream_beta`, hence the unit scale here.
    add_ridge_profile_vjp_fixed_lambda(
        1.0,
        x,
        y,
        penalty,
        weights,
        lambda,
        inverse_hessian,
        beta,
        upstream_beta.view(),
        grad_x,
        grad_y,
        grad_penalty,
        grad_weights,
    );
}
2000
2001fn add_rank_one_penalty_vjp(
2002 scale: f64,
2003 beta_col: ArrayView1<'_, f64>,
2004 grad_penalty: &mut Array2<f64>,
2005) {
2006 for row in 0..beta_col.len() {
2007 for col in 0..beta_col.len() {
2008 grad_penalty[[row, col]] += scale * beta_col[row] * beta_col[col];
2009 }
2010 }
2011}
2012
2013fn gaussian_reml_penalty_pseudoinverse_from_cache(cache: &GaussianRemlEigenCache) -> Array2<f64> {
2014 let p = cache.penalty_eigenvalues.len();
2015 let mut scaled_basis = Array2::<f64>::zeros((p, p));
2016 for eig in 0..p {
2017 let delta = cache.penalty_eigenvalues[eig];
2018 if delta > 0.0 {
2019 for row in 0..p {
2020 scaled_basis[[row, eig]] = cache.coefficient_basis[[row, eig]] / delta;
2021 }
2022 }
2023 }
2024 dense_ab(scaled_basis.view(), cache.coefficient_basis.t())
2025}
2026
2027fn add_deviance_profile_vjp(
2028 scale: f64,
2029 output: usize,
2030 x: ArrayView2<'_, f64>,
2031 weights: &Array1<f64>,
2032 beta: &Array2<f64>,
2033 residual: &Array2<f64>,
2034 grad_x: &mut Array2<f64>,
2035 grad_y: &mut Array2<f64>,
2036 grad_weights: &mut Array1<f64>,
2037) {
2038 for i in 0..x.nrows() {
2039 let r = residual[[i, output]];
2040 let wr_scale = scale * weights[i] * r;
2041 grad_y[[i, output]] += 2.0 * wr_scale;
2042 for k in 0..x.ncols() {
2043 grad_x[[i, k]] -= 2.0 * wr_scale * beta[[k, output]];
2044 }
2045 grad_weights[i] += scale * r * r;
2046 }
2047}
2048
2049fn validate_initial_lambda(lambda: f64) -> Result<f64, EstimationError> {
2050 if lambda.is_finite() && lambda > 0.0 {
2051 Ok(lambda)
2052 } else {
2053 Err(EstimationError::InvalidInput(format!(
2054 "Gaussian REML initial lambda must be finite and positive; got {lambda}"
2055 )))
2056 }
2057}
2058
/// Thin wrapper that forwards both operands to `fast_ab`.
///
/// Callers in this file use the result as the matrix product `a · b`
/// (e.g. fitted values from a design matrix and coefficients).
fn dense_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
    fast_ab(&a, &b)
}
2062
/// Thin wrapper that forwards both operands to `fast_atb`.
///
/// Presumably the product `aᵀ · b`; the backward pass uses it to pull a
/// fitted-value adjoint back to coefficient space.
fn dense_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
    fast_atb(&a, &b)
}
2066
/// Thin wrapper that forwards to `fast_xt_diag_x`.
///
/// Callers in this file treat the result as the weighted Gram matrix
/// `xᵀ · diag(w) · x`.
fn dense_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Array2<f64> {
    fast_xt_diag_x(&x, &w)
}
2070
/// Thin wrapper that forwards to `fast_xt_diag_y`.
///
/// Callers in this file treat the result as the weighted cross product
/// `xᵀ · diag(w) · y`.
fn dense_xt_diag_y(
    x: ArrayView2<'_, f64>,
    w: ArrayView1<'_, f64>,
    y: ArrayView2<'_, f64>,
) -> Array2<f64> {
    fast_xt_diag_y(&x, &w, &y)
}
2078
2079fn matrix_fingerprint(matrix: ArrayView2<'_, f64>) -> u64 {
2080 let mut hash = 0xcbf29ce484222325_u64;
2081 hash = fnv1a_mix(hash, matrix.nrows() as u64);
2082 hash = fnv1a_mix(hash, matrix.ncols() as u64);
2083 for &value in matrix {
2084 hash = fnv1a_mix(hash, value.to_bits());
2085 }
2086 hash
2087}
2088
/// One mixing step of the fingerprint hash: XOR the 64-bit `value` into
/// `hash`, then multiply (wrapping) by `0x100000001b3`.
fn fnv1a_mix(hash: u64, value: u64) -> u64 {
    const MULTIPLIER: u64 = 0x100000001b3;
    let mixed = hash ^ value;
    mixed.wrapping_mul(MULTIPLIER)
}
2092
2093pub fn build_gaussian_reml_eigen_cache_batched(
2098 xtwx_matrices: Vec<Array2<f64>>,
2099 penalty: ArrayView2<'_, f64>,
2100 nullspace_dim: Option<usize>,
2101) -> Vec<Result<GaussianRemlEigenCache, EstimationError>> {
2102 let penalty_owned = canonicalize_penalty(penalty);
2103 let penalty = penalty_owned.view();
2104 let k = xtwx_matrices.len();
2105 if k == 0 {
2106 return Vec::new();
2107 }
2108 let fingerprints: Vec<u64> = xtwx_matrices
2109 .iter()
2110 .map(|m| matrix_fingerprint(m.view()))
2111 .collect();
2112
2113 let p = xtwx_matrices[0].nrows();
2114 let uniform_square = p > 0 && xtwx_matrices.iter().all(|matrix| matrix.dim() == (p, p));
2115 if uniform_square && k > 1 {
2116 let mut lower_matrices = xtwx_matrices.clone();
2117 if gam_gpu::try_cholesky_batched_lower_inplace(&mut lower_matrices).is_some() {
2118 let transforms = batched_whitened_penalty_transforms(&lower_matrices, penalty);
2126 return lower_matrices
2127 .into_iter()
2128 .enumerate()
2129 .map(|(b, lower)| {
2130 let precomputed_transform = transforms.as_ref().map(|t| t[b].clone());
2131 gaussian_reml_eigen_cache_from_lower_with_transform(
2132 lower,
2133 penalty,
2134 nullspace_dim,
2135 fingerprints[b],
2136 precomputed_transform,
2137 )
2138 })
2139 .collect();
2140 }
2141 }
2142
2143 let mut results = Vec::with_capacity(k);
2144 for (b, xtwx) in xtwx_matrices.into_iter().enumerate() {
2145 let lower = match gaussian_reml_cholesky_lower(xtwx) {
2146 Ok(l) => l,
2147 Err(err) => {
2148 results.push(Err(err));
2149 continue;
2150 }
2151 };
2152 results.push(gaussian_reml_eigen_cache_from_lower_with_transform(
2153 lower,
2154 penalty,
2155 nullspace_dim,
2156 fingerprints[b],
2157 None,
2158 ));
2159 }
2160 results
2161}
2162
2163fn batched_whitened_penalty_transforms(
2164 lowers: &[Array2<f64>],
2165 penalty: ArrayView2<'_, f64>,
2166) -> Option<Vec<Array2<f64>>> {
2167 let first = lowers.first()?;
2168 let p = first.nrows();
2169 if p == 0 || first.ncols() != p || lowers.iter().any(|lower| lower.dim() != (p, p)) {
2170 return None;
2171 }
2172 let mut linv_stack = Array3::<f64>::zeros((lowers.len(), p, p));
2173 for (idx, lower) in lowers.iter().enumerate() {
2174 let l_inv = invert_lower_triangular(lower).ok()?;
2175 linv_stack.slice_mut(s![idx, .., ..]).assign(&l_inv);
2176 }
2177 let penalty_in_metric =
2178 gam_gpu::try_fast_ab_broadcast_b_batched(linv_stack.view(), penalty)?;
2179 let transformed =
2180 gam_gpu::try_fast_abt_strided_batched(penalty_in_metric.view(), linv_stack.view())?;
2181 Some(
2182 transformed
2183 .axis_iter(Axis(0))
2184 .map(|matrix| matrix.to_owned())
2185 .collect(),
2186 )
2187}
2188
/// Builds the Gaussian REML eigen cache for a design matrix, penalty and
/// optional observation weights.
///
/// Equivalent to calling
/// [`build_gaussian_reml_eigen_cache_with_nullspace_dim`] with
/// `nullspace_dim = None`.
///
/// # Errors
/// Propagates any error from that function.
pub fn build_gaussian_reml_eigen_cache(
    x: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<GaussianRemlEigenCache, EstimationError> {
    build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, weights)
}
2196
2197pub fn build_gaussian_reml_eigen_cache_with_nullspace_dim(
2198 x: ArrayView2<'_, f64>,
2199 penalty: ArrayView2<'_, f64>,
2200 nullspace_dim: Option<usize>,
2201 weights: Option<ArrayView1<'_, f64>>,
2202) -> Result<GaussianRemlEigenCache, EstimationError> {
2203 let penalty_owned = canonicalize_penalty(penalty);
2204 let penalty = penalty_owned.view();
2205 let n = x.nrows();
2206 validate_gaussian_reml_design(x, penalty, weights)?;
2207 let weight = gaussian_reml_weights(n, weights)?;
2208
2209 let xtwx = dense_xt_diag_x(x, weight.view());
2210 gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)
2211}
2212
2213fn validate_gaussian_reml_design(
2214 x: ArrayView2<'_, f64>,
2215 penalty: ArrayView2<'_, f64>,
2216 weights: Option<ArrayView1<'_, f64>>,
2217) -> Result<(), EstimationError> {
2218 let n = x.nrows();
2219 let p = x.ncols();
2220 if penalty.nrows() != p || penalty.ncols() != p {
2221 crate::bail_invalid_estim!(
2222 "Gaussian REML penalty shape mismatch: expected {p}x{p}, got {}x{}",
2223 penalty.nrows(),
2224 penalty.ncols()
2225 );
2226 }
2227 if x.iter().chain(penalty.iter()).any(|v| !v.is_finite()) {
2228 crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2229 }
2230 if let Some(w) = weights {
2231 if w.len() != n {
2232 crate::bail_invalid_estim!(
2233 "Gaussian REML weights length mismatch: expected {n}, got {}",
2234 w.len()
2235 );
2236 }
2237 if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2238 crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2239 }
2240 }
2241 Ok(())
2242}
2243
2244fn effective_observation_count(weight: ArrayView1<'_, f64>) -> usize {
2258 weight.iter().filter(|&&w| w > 0.0).count()
2259}
2260
2261fn gaussian_reml_weights(
2262 n: usize,
2263 weights: Option<ArrayView1<'_, f64>>,
2264) -> Result<Array1<f64>, EstimationError> {
2265 match weights {
2266 Some(w) => {
2267 if w.len() != n {
2268 crate::bail_invalid_estim!(
2269 "Gaussian REML weights length mismatch: expected {n}, got {}",
2270 w.len()
2271 );
2272 }
2273 if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2274 crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2275 }
2276 Ok(w.to_owned())
2277 }
2278 None => Ok(Array1::ones(n)),
2279 }
2280}
2281
2282fn gaussian_reml_eigen_cache_from_xtwx(
2283 xtwx: Array2<f64>,
2284 penalty: ArrayView2<'_, f64>,
2285 nullspace_dim: Option<usize>,
2286) -> Result<GaussianRemlEigenCache, EstimationError> {
2287 let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2288 let lower = gaussian_reml_cholesky_lower(xtwx)?;
2289 gaussian_reml_eigen_cache_from_lower(lower, penalty, nullspace_dim, xtwx_fingerprint)
2290}
2291
2292fn gaussian_reml_eigen_cache_from_lower(
2297 lower: Array2<f64>,
2298 penalty: ArrayView2<'_, f64>,
2299 nullspace_dim: Option<usize>,
2300 xtwx_fingerprint: u64,
2301) -> Result<GaussianRemlEigenCache, EstimationError> {
2302 gaussian_reml_eigen_cache_from_lower_with_transform(
2303 lower,
2304 penalty,
2305 nullspace_dim,
2306 xtwx_fingerprint,
2307 None,
2308 )
2309}
2310
2311fn gaussian_reml_eigen_cache_from_lower_with_transform(
2314 lower: Array2<f64>,
2315 penalty: ArrayView2<'_, f64>,
2316 nullspace_dim: Option<usize>,
2317 xtwx_fingerprint: u64,
2318 precomputed_transform: Option<Array2<f64>>,
2319) -> Result<GaussianRemlEigenCache, EstimationError> {
2320 let p = lower.nrows();
2321 if lower.ncols() != p {
2322 crate::bail_invalid_estim!("Gaussian REML Cholesky factor must be square");
2323 }
2324 let penalty_fingerprint = matrix_fingerprint(penalty);
2325 let logdet_xtwx = 2.0 * lower.diag().iter().map(|v| v.ln()).sum::<f64>();
2326 let transformed_penalty = match precomputed_transform {
2327 Some(transformed) => transformed,
2328 None => {
2329 let l_inv = invert_lower_triangular(&lower)?;
2330 let penalty_in_metric = dense_ab(l_inv.view(), penalty);
2331 dense_ab(penalty_in_metric.view(), l_inv.t())
2332 }
2333 };
2334 let (mut penalty_eigenvalues, eigenvectors) =
2335 transformed_penalty.eigh(Side::Lower).map_err(|_| {
2336 EstimationError::ModelIsIllConditioned {
2337 condition_number: f64::INFINITY,
2338 }
2339 })?;
2340 let max_abs_eig = penalty_eigenvalues
2350 .iter()
2351 .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2352 let eig_tol = max_abs_eig * EIGEN_REL_TOL;
2353 for value in &mut penalty_eigenvalues {
2354 if *value < 0.0 && value.abs() <= eig_tol {
2355 *value = 0.0;
2356 }
2357 if *value < 0.0 {
2358 crate::bail_invalid_estim!(
2359 "Gaussian REML penalty is not positive semidefinite; eigenvalue={value:.3e}"
2360 );
2361 }
2362 }
2363 let penalty_rank = penalty_eigenvalues
2364 .iter()
2365 .filter(|&&value| value > eig_tol)
2366 .count();
2367 let nullity = p - penalty_rank;
2368 if let Some(expected_nullity) = nullspace_dim
2369 && expected_nullity != nullity
2370 {
2371 crate::bail_invalid_estim!(
2372 "Gaussian REML penalty nullspace mismatch: expected {expected_nullity}, inferred {nullity}"
2373 );
2374 }
2375 let logdet_penalty_positive = gaussian_penalty_positive_logdet(penalty, penalty_rank)?;
2376 let coefficient_basis = solve_upper_triangular_matrix(&lower.t().to_owned(), &eigenvectors)?;
2377
2378 Ok(GaussianRemlEigenCache {
2379 penalty_eigenvalues,
2380 eigenvectors,
2381 coefficient_basis,
2382 xtwx_fingerprint,
2383 penalty_fingerprint,
2384 logdet_xtwx,
2385 logdet_penalty_positive,
2386 penalty_rank,
2387 nullity,
2388 })
2389}
2390
2391fn gaussian_reml_cholesky_lower(xtwx: Array2<f64>) -> Result<Array2<f64>, EstimationError> {
2392 let mut gpu_candidate = xtwx.clone();
2402 if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2403 return Ok(gpu_candidate);
2404 }
2405 if let Ok(chol) = xtwx.cholesky(Side::Lower) {
2406 return Ok(chol.lower_triangular());
2407 }
2408 let p = xtwx.nrows();
2409 let trace: f64 = (0..p).map(|i| xtwx[[i, i]]).sum();
2410 if !trace.is_finite() || trace <= 0.0 {
2411 return Err(EstimationError::ModelIsIllConditioned {
2412 condition_number: f64::INFINITY,
2413 });
2414 }
2415 let mut jitter = 1e-12 * trace / (p as f64);
2416 for _ in 0..6 {
2417 let mut jittered = xtwx.clone();
2418 for i in 0..p {
2419 jittered[[i, i]] += jitter;
2420 }
2421 let mut gpu_candidate = jittered.clone();
2422 if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2423 return Ok(gpu_candidate);
2424 }
2425 if let Ok(chol) = jittered.cholesky(Side::Lower) {
2426 return Ok(chol.lower_triangular());
2427 }
2428 jitter *= 10.0;
2429 }
2430 Err(EstimationError::ModelIsIllConditioned {
2431 condition_number: f64::INFINITY,
2432 })
2433}
2434
2435fn gaussian_penalty_positive_logdet(
2436 penalty: ArrayView2<'_, f64>,
2437 penalty_rank: usize,
2438) -> Result<f64, EstimationError> {
2439 if penalty_rank == 0 {
2440 return Ok(0.0);
2441 }
2442 let (pen_eigs, _) = penalty.to_owned().eigh(Side::Lower).map_err(|_| {
2443 EstimationError::ModelIsIllConditioned {
2444 condition_number: f64::INFINITY,
2445 }
2446 })?;
2447 let pen_scale = pen_eigs
2451 .iter()
2452 .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2453 let pen_tol = pen_scale * EIGEN_REL_TOL;
2454 let mut positive_eigs: Vec<f64> = pen_eigs
2455 .iter()
2456 .copied()
2457 .filter(|&value| value > pen_tol)
2458 .collect();
2459 if positive_eigs.len() != penalty_rank {
2460 positive_eigs = pen_eigs
2461 .iter()
2462 .copied()
2463 .filter(|&value| value > 0.0)
2464 .collect();
2465 positive_eigs.sort_by(|a, b| b.total_cmp(a));
2466 if positive_eigs.len() < penalty_rank {
2467 return Err(EstimationError::ModelIsIllConditioned {
2468 condition_number: f64::INFINITY,
2469 });
2470 }
2471 positive_eigs.truncate(penalty_rank);
2472 }
2473 Ok(positive_eigs.iter().map(|value| value.ln()).sum())
2474}
2475
2476fn validate_gaussian_reml_eigen_cache(
2477 cache: &GaussianRemlEigenCache,
2478 p: usize,
2479) -> Result<(), EstimationError> {
2480 if cache.penalty_eigenvalues.len() != p
2481 || cache.eigenvectors.dim() != (p, p)
2482 || cache.coefficient_basis.dim() != (p, p)
2483 {
2484 crate::bail_invalid_estim!(
2485 "Gaussian REML eigen cache dimension mismatch: expected {p} coefficients"
2486 );
2487 }
2488 if cache.penalty_rank > p || cache.nullity > p || cache.penalty_rank + cache.nullity != p {
2489 crate::bail_invalid_estim!(
2490 "Gaussian REML eigen cache rank/nullity mismatch: rank={}, nullity={}, p={p}",
2491 cache.penalty_rank,
2492 cache.nullity
2493 );
2494 }
2495 if !(cache.logdet_xtwx.is_finite() && cache.logdet_penalty_positive.is_finite()) {
2496 crate::bail_invalid_estim!("Gaussian REML eigen cache log-determinants must be finite");
2497 }
2498 if cache
2499 .penalty_eigenvalues
2500 .iter()
2501 .any(|value| !value.is_finite() || *value < 0.0)
2502 || cache.eigenvectors.iter().any(|value| !value.is_finite())
2503 || cache
2504 .coefficient_basis
2505 .iter()
2506 .any(|value| !value.is_finite())
2507 {
2508 crate::bail_invalid_estim!(
2509 "Gaussian REML eigen cache entries must be finite with non-negative eigenvalues"
2510 .to_string(),
2511 );
2512 }
2513 Ok::<(), _>(())
2514}
2515
2516fn prepare_gaussian_reml(
2517 x: ArrayView2<'_, f64>,
2518 y: ArrayView2<'_, f64>,
2519 penalty: ArrayView2<'_, f64>,
2520 nullspace_dim: Option<usize>,
2521 weights: Option<ArrayView1<'_, f64>>,
2522 eigen_cache: Option<&GaussianRemlEigenCache>,
2523) -> Result<GaussianRemlPrepared, EstimationError> {
2524 let penalty_owned = canonicalize_penalty(penalty);
2527 let penalty = penalty_owned.view();
2528 let n = x.nrows();
2529 let p = x.ncols();
2530 let d = y.ncols();
2531 validate_gaussian_reml_design(x, penalty, weights)?;
2532 if y.nrows() != n {
2533 crate::bail_invalid_estim!(
2534 "Gaussian REML row mismatch: X has {n} rows but Y has {}",
2535 y.nrows()
2536 );
2537 }
2538 if y.iter().any(|v| !v.is_finite()) {
2539 crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2540 }
2541 let weight = gaussian_reml_weights(n, weights)?;
2542 let n_effective = effective_observation_count(weight.view());
2543
2544 let xtwy = dense_xt_diag_y(x, weight.view(), y);
2545 let ywy = Array1::from_iter((0..d).map(|j| {
2546 let mut value = 0.0;
2547 for row in 0..n {
2548 value += weight[row] * y[[row, j]] * y[[row, j]];
2549 }
2550 value
2551 }));
2552 let xtwx = dense_xt_diag_x(x, weight.view());
2553
2554 if let Some(cache) = eigen_cache {
2555 validate_gaussian_reml_eigen_cache(cache, p)?;
2556 let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2557 if cache.xtwx_fingerprint != xtwx_fingerprint {
2558 crate::bail_invalid_estim!("Gaussian REML eigen cache X'WX mismatch");
2559 }
2560 let penalty_fingerprint = matrix_fingerprint(penalty);
2561 if cache.penalty_fingerprint != penalty_fingerprint {
2562 crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
2563 }
2564 if let Some(expected_nullity) = nullspace_dim
2565 && expected_nullity != cache.nullity
2566 {
2567 crate::bail_invalid_estim!(
2568 "Gaussian REML eigen cache nullspace mismatch: expected {expected_nullity}, got {}",
2569 cache.nullity
2570 );
2571 }
2572 if n_effective <= cache.nullity {
2573 crate::bail_invalid_estim!(
2574 "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
2575 cache.nullity
2576 );
2577 }
2578 let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2579 let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2580 return Ok(GaussianRemlPrepared {
2581 cache: cache.clone(),
2582 ywy,
2583 projected_rhs_squared,
2584 projected_rhs,
2585 n_effective,
2586 n_outputs: d,
2587 });
2588 }
2589
2590 let cache = gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)?;
2591 if n_effective <= cache.nullity {
2592 crate::bail_invalid_estim!(
2593 "Gaussian REML requires more positive-weight rows than the nullspace dimension; got n_effective={n_effective}, nullity={}",
2594 cache.nullity
2595 );
2596 }
2597 let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2598 let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2599
2600 Ok(GaussianRemlPrepared {
2601 cache,
2602 ywy,
2603 projected_rhs_squared,
2604 projected_rhs,
2605 n_effective,
2606 n_outputs: d,
2607 })
2608}
2609
2610impl GaussianRemlPrepared {
2611 fn nu(&self) -> f64 {
2612 self.n_effective as f64 - self.cache.nullity as f64
2613 }
2614
2615 fn evaluate(&self, rho: f64) -> ObjectiveEval {
2616 evaluate_reml_parts(
2617 &self.cache,
2618 self.ywy.view(),
2619 self.projected_rhs_squared.view(),
2620 self.n_effective,
2621 self.n_outputs,
2622 rho,
2623 )
2624 }
2625
2626 fn coefficients(&self, lambda: f64) -> Array2<f64> {
2627 let mut scaled = self.projected_rhs.clone();
2628 for i in 0..self.cache.penalty_eigenvalues.len() {
2629 let scale = 1.0 / (1.0 + lambda * self.cache.penalty_eigenvalues[i]);
2630 for value in scaled.row_mut(i) {
2631 *value *= scale;
2632 }
2633 }
2634 dense_ab(self.cache.coefficient_basis.view(), scaled.view())
2635 }
2636
2637 fn sigma2(&self, lambda: f64) -> Array1<f64> {
2638 let nu = self.nu();
2639 Array1::from_iter((0..self.n_outputs).map(|j| {
2640 let mut fitted_quadratic = 0.0;
2641 for i in 0..self.cache.penalty_eigenvalues.len() {
2642 let denom = 1.0 + lambda * self.cache.penalty_eigenvalues[i];
2643 fitted_quadratic += self.projected_rhs_squared[[i, j]] / denom;
2644 }
2645 ((self.ywy[j] - fitted_quadratic).max(MIN_DEVIANCE)) / nu
2646 }))
2647 }
2648}
2649
2650fn optimize_rho(
2651 prepared: &GaussianRemlPrepared,
2652 init_rho: Option<f64>,
2653) -> Result<f64, EstimationError> {
2654 if prepared.cache.penalty_rank == 0 {
2655 return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2656 }
2657
2658 const GRID_INTERVALS: usize = 96;
2659 let mut stationary = Vec::<f64>::new();
2660 let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2661 let mut prev_rho = RHO_LOWER;
2662 let mut prev_eval = prepared.evaluate(prev_rho);
2663 grid.push((prev_rho, prev_eval.cost));
2664 for i in 1..=GRID_INTERVALS {
2665 let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2666 let eval = prepared.evaluate(rho);
2667 grid.push((rho, eval.cost));
2668 if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2669 push_candidate(
2670 &mut stationary,
2671 refine_stationary_rho(prepared, prev_rho, rho, 0.5 * (prev_rho + rho)),
2672 );
2673 }
2674 prev_rho = rho;
2675 prev_eval = eval;
2676 }
2677
2678 let mut candidates = stationary;
2679 push_candidate(&mut candidates, RHO_LOWER);
2680 push_candidate(&mut candidates, RHO_UPPER);
2681 if let Some(rho0) = init_rho {
2682 push_candidate(&mut candidates, rho0);
2683 }
2684 if let Some(rho) = refine_best_grid_cell(prepared, &grid) {
2685 push_candidate(&mut candidates, rho);
2686 }
2687
2688 candidates
2692 .into_iter()
2693 .map(|rho| (rho, prepared.evaluate(rho).cost))
2694 .min_by(|(_, a), (_, b)| a.total_cmp(b))
2695 .map(|(rho, _)| rho)
2696 .ok_or_else(|| {
2697 EstimationError::InvalidInput(
2698 "Gaussian REML optimizer produced no candidates".to_string(),
2699 )
2700 })
2701}
2702
2703fn refine_best_grid_cell(prepared: &GaussianRemlPrepared, grid: &[(f64, f64)]) -> Option<f64> {
2704 let best_idx = grid
2705 .iter()
2706 .enumerate()
2707 .filter(|(_, (_, cost))| cost.is_finite())
2708 .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2709 .map(|(idx, _)| idx)?;
2710 if best_idx == 0 || best_idx + 1 == grid.len() {
2711 return Some(grid[best_idx].0);
2712 }
2713 Some(refine_stationary_rho(
2729 prepared,
2730 grid[best_idx - 1].0,
2731 grid[best_idx + 1].0,
2732 grid[best_idx].0,
2733 ))
2734}
2735
2736fn fill_weighted_rhs_no_alloc(
2737 x: ArrayView2<'_, f64>,
2738 y: ArrayView2<'_, f64>,
2739 weights: Option<ArrayView1<'_, f64>>,
2740 workspace: &mut GaussianRemlNoAllocWorkspace,
2741) -> Result<(), EstimationError> {
2742 let d = y.ncols();
2743
2744 let (xtwy, ywy_full) = match weights {
2750 Some(w) => (fast_xt_diag_y(&x, &w, &y), fast_xt_diag_y(&y, &w, &y)),
2751 None => (fast_atb(&x, &y), fast_atb(&y, &y)),
2752 };
2753 workspace.xtwy.assign(&xtwy);
2754 for output in 0..d {
2755 workspace.ywy[output] = ywy_full[[output, output]];
2756 }
2757
2758 if workspace
2759 .xtwy
2760 .iter()
2761 .chain(workspace.ywy.iter())
2762 .any(|value| !value.is_finite())
2763 {
2764 crate::bail_invalid_estim!("Gaussian REML weighted cross-products must be finite");
2765 }
2766 Ok(())
2767}
2768
2769fn project_rhs_no_alloc(
2770 cache: &GaussianRemlEigenCache,
2771 workspace: &mut GaussianRemlNoAllocWorkspace,
2772) {
2773 let projected = fast_atb(&cache.coefficient_basis, &workspace.xtwy);
2776 workspace.projected_rhs.assign(&projected);
2777 let p = cache.penalty_eigenvalues.len();
2778 let d = workspace.ywy.len();
2779 for eig in 0..p {
2780 for output in 0..d {
2781 let value = workspace.projected_rhs[[eig, output]];
2782 workspace.projected_rhs_squared[[eig, output]] = value * value;
2783 }
2784 }
2785}
2786
2787fn evaluate_reml_parts(
2788 cache: &GaussianRemlEigenCache,
2789 ywy: ArrayView1<'_, f64>,
2790 projected_rhs_squared: ArrayView2<'_, f64>,
2791 n_effective: usize,
2792 n_outputs: usize,
2793 rho: f64,
2794) -> ObjectiveEval {
2795 let lambda = rho.exp();
2796 let nu = n_effective as f64 - cache.nullity as f64;
2797 let d = n_outputs as f64;
2798
2799 let (logdet_term, edf) = gaussian_reml_logdet_term(cache, rho, d);
2802 let mut eval = ObjectiveEval {
2803 cost: 0.0,
2804 grad: 0.0,
2805 hess: 0.0,
2806 edf,
2807 };
2808 eval += logdet_term;
2809 for output in 0..n_outputs {
2810 eval +=
2811 gaussian_reml_dispersion_term(cache, ywy, projected_rhs_squared, output, nu, lambda);
2812 }
2813 eval
2814}
2815
2816fn optimize_rho_no_alloc(
2817 cache: &GaussianRemlEigenCache,
2818 ywy: ArrayView1<'_, f64>,
2819 projected_rhs_squared: ArrayView2<'_, f64>,
2820 n_effective: usize,
2821 n_outputs: usize,
2822 init_rho: Option<f64>,
2823) -> Result<f64, EstimationError> {
2824 if cache.penalty_rank == 0 {
2825 return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2826 }
2827
2828 let lower_eval = evaluate_reml_parts(
2829 cache,
2830 ywy,
2831 projected_rhs_squared,
2832 n_effective,
2833 n_outputs,
2834 RHO_LOWER,
2835 );
2836
2837 let mut best_rho = RHO_LOWER;
2838 let mut best_cost = lower_eval.cost;
2839
2840 const GRID_INTERVALS: usize = 96;
2841 let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2842 let mut prev_rho = RHO_LOWER;
2843 let mut prev_eval = lower_eval;
2844 grid.push((prev_rho, prev_eval.cost));
2845 for i in 1..=GRID_INTERVALS {
2846 let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2847 let eval = evaluate_reml_parts(
2848 cache,
2849 ywy,
2850 projected_rhs_squared,
2851 n_effective,
2852 n_outputs,
2853 rho,
2854 );
2855 grid.push((rho, eval.cost));
2856 if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2857 let stationary_rho = refine_stationary_rho_no_alloc(
2858 cache,
2859 ywy,
2860 projected_rhs_squared,
2861 n_effective,
2862 n_outputs,
2863 prev_rho,
2864 rho,
2865 0.5 * (prev_rho + rho),
2866 );
2867 consider_rho_no_alloc(
2868 cache,
2869 ywy,
2870 projected_rhs_squared,
2871 n_effective,
2872 n_outputs,
2873 stationary_rho,
2874 &mut best_rho,
2875 &mut best_cost,
2876 );
2877 }
2878 prev_rho = rho;
2879 prev_eval = eval;
2880 }
2881 if let Some(best_idx) = grid
2882 .iter()
2883 .enumerate()
2884 .filter(|(_, (_, cost))| cost.is_finite())
2885 .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2886 .map(|(idx, _)| idx)
2887 {
2888 let refined = if best_idx == 0 || best_idx + 1 == grid.len() {
2889 grid[best_idx].0
2890 } else {
2891 refine_stationary_rho_no_alloc(
2901 cache,
2902 ywy,
2903 projected_rhs_squared,
2904 n_effective,
2905 n_outputs,
2906 grid[best_idx - 1].0,
2907 grid[best_idx + 1].0,
2908 grid[best_idx].0,
2909 )
2910 };
2911 consider_rho_no_alloc(
2912 cache,
2913 ywy,
2914 projected_rhs_squared,
2915 n_effective,
2916 n_outputs,
2917 refined,
2918 &mut best_rho,
2919 &mut best_cost,
2920 );
2921 }
2922
2923 consider_rho_no_alloc(
2924 cache,
2925 ywy,
2926 projected_rhs_squared,
2927 n_effective,
2928 n_outputs,
2929 RHO_UPPER,
2930 &mut best_rho,
2931 &mut best_cost,
2932 );
2933 if let Some(rho0) = init_rho {
2934 consider_rho_no_alloc(
2935 cache,
2936 ywy,
2937 projected_rhs_squared,
2938 n_effective,
2939 n_outputs,
2940 rho0,
2941 &mut best_rho,
2942 &mut best_cost,
2943 );
2944 }
2945
2946 if best_cost.is_finite() {
2947 Ok(best_rho)
2948 } else {
2949 Err(EstimationError::InvalidInput(
2950 "Gaussian REML optimizer produced no finite candidates".to_string(),
2951 ))
2952 }
2953}
2954
2955fn consider_rho_no_alloc(
2956 cache: &GaussianRemlEigenCache,
2957 ywy: ArrayView1<'_, f64>,
2958 projected_rhs_squared: ArrayView2<'_, f64>,
2959 n_effective: usize,
2960 n_outputs: usize,
2961 rho: f64,
2962 best_rho: &mut f64,
2963 best_cost: &mut f64,
2964) {
2965 if !rho.is_finite() {
2966 return;
2967 }
2968 let candidate = rho.clamp(RHO_LOWER, RHO_UPPER);
2969 let eval = evaluate_reml_parts(
2970 cache,
2971 ywy,
2972 projected_rhs_squared,
2973 n_effective,
2974 n_outputs,
2975 candidate,
2976 );
2977 if eval.cost < *best_cost {
2978 *best_rho = candidate;
2979 *best_cost = eval.cost;
2980 }
2981}
2982
2983fn refine_stationary_rho_no_alloc(
2984 cache: &GaussianRemlEigenCache,
2985 ywy: ArrayView1<'_, f64>,
2986 projected_rhs_squared: ArrayView2<'_, f64>,
2987 n_effective: usize,
2988 n_outputs: usize,
2989 mut lo: f64,
2990 mut hi: f64,
2991 mut rho: f64,
2992) -> f64 {
2993 for _ in 0..80 {
2994 let eval = evaluate_reml_parts(
2995 cache,
2996 ywy,
2997 projected_rhs_squared,
2998 n_effective,
2999 n_outputs,
3000 rho,
3001 );
3002 if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
3003 return rho;
3004 }
3005 if eval.grad >= 0.0 {
3006 hi = rho;
3007 } else {
3008 lo = rho;
3009 }
3010 let newton = if eval.hess > 0.0 {
3011 let candidate = rho - eval.grad / eval.hess;
3012 (candidate > lo && candidate < hi).then_some(candidate)
3013 } else {
3014 None
3015 };
3016 if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
3017 break;
3018 }
3019 rho = newton.unwrap_or(0.5 * (lo + hi));
3020 }
3021 0.5 * (lo + hi)
3022}
3023
3024fn fill_coefficients_no_alloc(
3025 cache: &GaussianRemlEigenCache,
3026 workspace: &mut GaussianRemlNoAllocWorkspace,
3027 lambda: f64,
3028 mut coefficients: ArrayViewMut2<'_, f64>,
3029) {
3030 let p = cache.penalty_eigenvalues.len();
3031 let d = workspace.ywy.len();
3032 for eig in 0..p {
3033 let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
3034 for output in 0..d {
3035 workspace.scaled_projected_rhs[[eig, output]] =
3036 workspace.projected_rhs[[eig, output]] * scale;
3037 }
3038 }
3039
3040 for col in 0..p {
3041 for output in 0..d {
3042 let mut value = 0.0;
3043 for eig in 0..p {
3044 value += cache.coefficient_basis[[col, eig]]
3045 * workspace.scaled_projected_rhs[[eig, output]];
3046 }
3047 coefficients[[col, output]] = value;
3048 }
3049 }
3050}
3051
3052fn fill_fitted_no_alloc(
3053 x: ArrayView2<'_, f64>,
3054 coefficients: ArrayView2<'_, f64>,
3055 mut fitted: ArrayViewMut2<'_, f64>,
3056) {
3057 let n = x.nrows();
3058 let p = x.ncols();
3059 let d = coefficients.ncols();
3060 for row in 0..n {
3061 for output in 0..d {
3062 let mut value = 0.0;
3063 for col in 0..p {
3064 value += x[[row, col]] * coefficients[[col, output]];
3065 }
3066 fitted[[row, output]] = value;
3067 }
3068 }
3069}
3070
3071fn fill_sigma2_no_alloc(
3072 cache: &GaussianRemlEigenCache,
3073 ywy: ArrayView1<'_, f64>,
3074 projected_rhs_squared: ArrayView2<'_, f64>,
3075 n_effective: usize,
3076 n_outputs: usize,
3077 lambda: f64,
3078 mut sigma2: ArrayViewMut1<'_, f64>,
3079) {
3080 let nu = n_effective as f64 - cache.nullity as f64;
3081 for output in 0..n_outputs {
3082 let mut fitted_quadratic = 0.0;
3083 for eig in 0..cache.penalty_eigenvalues.len() {
3084 let denom = 1.0 + lambda * cache.penalty_eigenvalues[eig];
3085 fitted_quadratic += projected_rhs_squared[[eig, output]] / denom;
3086 }
3087 sigma2[output] = ((ywy[output] - fitted_quadratic).max(MIN_DEVIANCE)) / nu;
3088 }
3089}
3090
3091fn push_candidate(candidates: &mut Vec<f64>, rho: f64) {
3092 if rho.is_finite() {
3093 candidates.push(rho.clamp(RHO_LOWER, RHO_UPPER));
3094 }
3095}
3096
3097fn refine_stationary_rho(
3098 prepared: &GaussianRemlPrepared,
3099 mut lo: f64,
3100 mut hi: f64,
3101 mut rho: f64,
3102) -> f64 {
3103 for _ in 0..80 {
3104 let eval = prepared.evaluate(rho);
3105 if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
3106 return rho;
3107 }
3108 if eval.grad >= 0.0 {
3109 hi = rho;
3110 } else {
3111 lo = rho;
3112 }
3113 let newton = if eval.hess > 0.0 {
3114 let candidate = rho - eval.grad / eval.hess;
3115 (candidate > lo && candidate < hi).then_some(candidate)
3116 } else {
3117 None
3118 };
3119 if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
3120 break;
3121 }
3122 rho = newton.unwrap_or(0.5 * (lo + hi));
3123 }
3124 0.5 * (lo + hi)
3125}
3126
3127fn invert_lower_triangular(lower: &Array2<f64>) -> Result<Array2<f64>, EstimationError> {
3128 let n = lower.nrows();
3129 if lower.ncols() != n {
3130 crate::bail_invalid_estim!("lower-triangular solve requires a square matrix");
3131 }
3132 let eye = Array2::eye(n);
3133 solve_lower_triangular_matrix(lower, &eye)
3134}
3135
3136fn solve_lower_triangular_matrix(
3137 lower: &Array2<f64>,
3138 rhs: &Array2<f64>,
3139) -> Result<Array2<f64>, EstimationError> {
3140 let n = lower.nrows();
3141 if lower.ncols() != n || rhs.nrows() != n {
3142 crate::bail_invalid_estim!("lower-triangular solve dimension mismatch");
3143 }
3144 if let Some(out) = gam_gpu::try_solve_lower_triangular_matrix(lower.view(), rhs.view()) {
3145 return Ok(out);
3146 }
3147 let mut out = Array2::<f64>::zeros(rhs.dim());
3148 for col in 0..rhs.ncols() {
3149 for i in 0..n {
3150 let mut value = rhs[[i, col]];
3151 for k in 0..i {
3152 value -= lower[[i, k]] * out[[k, col]];
3153 }
3154 let diag = lower[[i, i]];
3155 if !(diag.is_finite() && diag.abs() > 0.0) {
3156 return Err(EstimationError::ModelIsIllConditioned {
3157 condition_number: f64::INFINITY,
3158 });
3159 }
3160 out[[i, col]] = value / diag;
3161 }
3162 }
3163 Ok(out)
3164}
3165
3166fn solve_spd_from_lower_factor(
3170 lower: &Array2<f64>,
3171 rhs: &Array2<f64>,
3172) -> Result<Array2<f64>, EstimationError> {
3173 let forward = solve_lower_triangular_matrix(lower, rhs)?;
3174 solve_upper_triangular_matrix(&lower.t().to_owned(), &forward)
3175}
3176
3177fn solve_upper_triangular_matrix(
3178 upper: &Array2<f64>,
3179 rhs: &Array2<f64>,
3180) -> Result<Array2<f64>, EstimationError> {
3181 let n = upper.nrows();
3182 if upper.ncols() != n || rhs.nrows() != n {
3183 crate::bail_invalid_estim!("upper-triangular solve dimension mismatch");
3184 }
3185 if let Some(out) = gam_gpu::try_solve_upper_triangular_matrix(upper.view(), rhs.view()) {
3186 return Ok(out);
3187 }
3188 let mut out = Array2::<f64>::zeros(rhs.dim());
3189 for col in 0..rhs.ncols() {
3190 for i_rev in 0..n {
3191 let i = n - 1 - i_rev;
3192 let mut value = rhs[[i, col]];
3193 for k in (i + 1)..n {
3194 value -= upper[[i, k]] * out[[k, col]];
3195 }
3196 let diag = upper[[i, i]];
3197 if !(diag.is_finite() && diag.abs() > 0.0) {
3198 return Err(EstimationError::ModelIsIllConditioned {
3199 condition_number: f64::INFINITY,
3200 });
3201 }
3202 out[[i, col]] = value / diag;
3203 }
3204 }
3205 Ok(out)
3206}
3207
3208#[cfg(test)]
3209mod tests {
3210 use super::*;
3211 use ndarray::array;
3212
3213 #[test]
3214 fn edf_does_not_double_count_penalty_nullspace() {
3215 let x = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0],];
3216 let y = array![[0.0], [1.0], [1.8], [3.2], [4.1]];
3217 let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3218 let result =
3219 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3220 .expect("small full-rank Gaussian REML fit");
3221
3222 assert!(result.edf >= result.cache.nullity as f64);
3223 assert!(result.edf <= x.ncols() as f64 + 1.0e-10);
3224 }
3225
3226 #[test]
3227 fn multi_output_duplicate_columns_match_scalar_fit() {
3228 let x = array![
3229 [1.0, -1.0],
3230 [1.0, -0.5],
3231 [1.0, 0.0],
3232 [1.0, 0.5],
3233 [1.0, 1.0],
3234 [1.0, 1.5],
3235 ];
3236 let y1 = array![0.5, 0.2, 0.0, 0.3, 1.1, 2.0];
3237 let y = Array2::from_shape_fn(
3238 (y1.len(), 2),
3239 |(i, j)| if j == 0 { y1[i] } else { 2.0 * y1[i] },
3240 );
3241 let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3242
3243 let scalar =
3244 gaussian_reml_closed_form(x.view(), y1.view(), penalty.view(), None, Some(0.0))
3245 .expect("scalar Gaussian REML fit");
3246 let multi =
3247 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3248 .expect("multi-output Gaussian REML fit");
3249
3250 assert!((multi.rho - scalar.rho).abs() <= 1.0e-8);
3251 for i in 0..x.ncols() {
3252 assert!((multi.coefficients[[i, 0]] - scalar.coefficients[i]).abs() <= 1.0e-8);
3253 assert!((multi.coefficients[[i, 1]] - 2.0 * scalar.coefficients[i]).abs() <= 1.0e-8);
3254 }
3255 }
3256
3257 #[test]
3258 fn warm_start_reuses_cache_and_lambda_seed() {
3259 let x = array![
3260 [1.0, -1.0],
3261 [1.0, -0.25],
3262 [1.0, 0.5],
3263 [1.0, 1.25],
3264 [1.0, 2.0],
3265 ];
3266 let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3267 let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3268
3269 let cold =
3270 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3271 .expect("cold fit");
3272 let warm_start = GaussianRemlWarmStart::from_multi_result(&cold);
3273 let warm = gaussian_reml_multi_closed_form_warm_started(
3274 x.view(),
3275 y.view(),
3276 penalty.view(),
3277 None,
3278 Some(&warm_start),
3279 )
3280 .expect("warm-started fit");
3281
3282 assert!((cold.lambda - warm.lambda).abs() <= 1.0e-10);
3283 assert_eq!(cold.cache.xtwx_fingerprint, warm.cache.xtwx_fingerprint);
3284 for i in 0..x.ncols() {
3285 assert!((cold.coefficients[[i, 0]] - warm.coefficients[[i, 0]]).abs() <= 1.0e-10);
3286 }
3287 }
3288
3289 #[test]
3290 fn warm_start_cache_rejects_different_penalty_geometry() {
3291 let x = array![
3292 [1.0, -1.0],
3293 [1.0, -0.25],
3294 [1.0, 0.5],
3295 [1.0, 1.25],
3296 [1.0, 2.0],
3297 ];
3298 let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3299 let penalty_a = array![[0.0, 0.0], [0.0, 1.0]];
3300 let penalty_b = array![[1.0, -1.0], [-1.0, 1.0]];
3301
3302 let first =
3303 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty_a.view(), None, Some(0.0))
3304 .expect("first fit");
3305 let warm_start = GaussianRemlWarmStart::from_multi_result(&first);
3306 let err = gaussian_reml_multi_closed_form_warm_started(
3307 x.view(),
3308 y.view(),
3309 penalty_b.view(),
3310 None,
3311 Some(&warm_start),
3312 )
3313 .expect_err("penalty-mismatched cache must be rejected");
3314
3315 assert!(err.to_string().contains("penalty mismatch"));
3316 }
3317
3318 #[test]
3319 fn no_alloc_cache_path_matches_allocating_fit() {
3320 let x = array![
3321 [1.0, -1.0, 0.25],
3322 [1.0, -0.5, 0.10],
3323 [1.0, 0.0, -0.20],
3324 [1.0, 0.5, -0.05],
3325 [1.0, 1.0, 0.30],
3326 [1.0, 1.5, 0.60],
3327 ];
3328 let y = array![
3329 [0.0, 0.2],
3330 [0.3, 0.1],
3331 [0.4, -0.1],
3332 [0.9, 0.3],
3333 [1.6, 0.8],
3334 [2.2, 1.2],
3335 ];
3336 let weights = array![1.0, 0.8, 1.2, 1.1, 0.9, 1.3];
3337 let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 4.0]];
3338
3339 let allocating = gaussian_reml_multi_closed_form_with_cache(
3340 x.view(),
3341 y.view(),
3342 penalty.view(),
3343 Some(weights.view()),
3344 Some(1.0),
3345 None,
3346 )
3347 .expect("allocating fit");
3348 let mut workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols());
3349 let mut coefficients = Array2::zeros((x.ncols(), y.ncols()));
3350 let mut fitted = Array2::zeros(y.dim());
3351 let mut sigma2 = Array1::zeros(y.ncols());
3352
3353 let no_alloc = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3354 x.view(),
3355 y.view(),
3356 penalty.view(),
3357 Some(weights.view()),
3358 Some(allocating.lambda),
3359 &allocating.cache,
3360 &mut workspace,
3361 coefficients.view_mut(),
3362 fitted.view_mut(),
3363 sigma2.view_mut(),
3364 )
3365 .expect("no-alloc cached fit");
3366
3367 assert!((no_alloc.lambda - allocating.lambda).abs() <= 1.0e-10);
3368 assert!((no_alloc.reml_score - allocating.reml_score).abs() <= 1.0e-8);
3369 assert!((no_alloc.reml_grad_rho - allocating.reml_grad_rho).abs() <= 1.0e-8);
3370 assert!((no_alloc.reml_hess_rho - allocating.reml_hess_rho).abs() <= 1.0e-8);
3371 assert!((no_alloc.edf - allocating.edf).abs() <= 1.0e-10);
3372 for i in 0..x.ncols() {
3373 for j in 0..y.ncols() {
3374 assert!((coefficients[[i, j]] - allocating.coefficients[[i, j]]).abs() <= 1.0e-8);
3375 }
3376 }
3377 for i in 0..x.nrows() {
3378 for j in 0..y.ncols() {
3379 assert!((fitted[[i, j]] - allocating.fitted[[i, j]]).abs() <= 1.0e-8);
3380 }
3381 }
3382 for j in 0..y.ncols() {
3383 assert!((sigma2[j] - allocating.sigma2[j]).abs() <= 1.0e-10);
3384 }
3385 }
3386
/// The cached no-alloc solver must reject (a) a workspace built for a
/// different number of output columns and (b) a penalty matrix that is not
/// the one the eigen cache was built from.
#[test]
fn no_alloc_cache_path_rejects_bad_shapes_and_penalty_mismatch() {
    let design = array![[1.0, -1.0], [1.0, 0.0], [1.0, 1.0], [1.0, 2.0]];
    let response = array![[0.0], [0.2], [0.9], [1.8]];
    let cached_penalty = array![[0.0, 0.0], [0.0, 1.0]];
    let cache = build_gaussian_reml_eigen_cache(design.view(), cached_penalty.view(), None)
        .expect("Gaussian REML cache");

    // Case 1: the workspace is sized for one output column too many.
    let mut oversized_workspace =
        GaussianRemlNoAllocWorkspace::new(design.ncols(), response.ncols() + 1);
    // Output buffers are correctly shaped and reused by both rejected calls.
    let mut coefficients = Array2::zeros((design.ncols(), response.ncols()));
    let mut fitted = Array2::zeros(response.dim());
    let mut sigma2 = Array1::zeros(response.ncols());
    let shape_error = gaussian_reml_multi_closed_form_with_cache_no_alloc(
        design.view(),
        response.view(),
        cached_penalty.view(),
        None,
        Some(1.0),
        &cache,
        &mut oversized_workspace,
        coefficients.view_mut(),
        fitted.view_mut(),
        sigma2.view_mut(),
    )
    .expect_err("workspace shape mismatch must be rejected");
    assert!(shape_error.to_string().contains("workspace shape mismatch"));

    // Case 2: correctly sized workspace, but a penalty the cache was not built for.
    let foreign_penalty = array![[1.0, -1.0], [-1.0, 1.0]];
    let mut matching_workspace =
        GaussianRemlNoAllocWorkspace::new(design.ncols(), response.ncols());
    let penalty_error = gaussian_reml_multi_closed_form_with_cache_no_alloc(
        design.view(),
        response.view(),
        foreign_penalty.view(),
        None,
        Some(1.0),
        &cache,
        &mut matching_workspace,
        coefficients.view_mut(),
        fitted.view_mut(),
        sigma2.view_mut(),
    )
    .expect_err("penalty mismatch must be rejected");
    assert!(penalty_error.to_string().contains("penalty mismatch"));
}
3431
/// Selects which scalar output of the forward fit is being differentiated in
/// the one-hot finite-difference gradient checks.
#[derive(Debug, Clone, Copy)]
enum ForwardScalar {
    /// The fit's `lambda`.
    Lambda,
    /// The fit's `reml_score`.
    RemlScore,
    /// Entry `[row, col]` of the fit's `coefficients`.
    Coefficient(usize, usize),
    /// Entry `[row, col]` of the fit's `fitted` values.
    Fitted(usize, usize),
    /// The fit's `edf`.
    Edf,
}
3440
3441 fn finite_difference_design() -> Array2<f64> {
3442 Array2::from_shape_fn((20, 5), |(row, col)| {
3443 let t = (row as f64 - 9.5) / 10.0;
3444 match col {
3445 0 => 1.0,
3446 1 => t,
3447 2 => 0.5 * (3.0 * t * t - 1.0),
3448 3 => 0.5 * (5.0 * t * t * t - 3.0 * t),
3449 4 => (35.0 * t.powi(4) - 30.0 * t * t + 3.0) / 8.0,
3450 _ => unreachable!(),
3451 }
3452 })
3453 }
3454
3455 fn finite_difference_response(outputs: usize) -> Array2<f64> {
3456 Array2::from_shape_fn((20, outputs), |(row, output)| {
3467 let t = (row as f64 - 9.5) / 10.0;
3468 let phase = output as f64 + 1.0;
3469 0.2 + 0.25 * phase * t - 0.12 * t * t
3470 + (0.08 + 0.03 * phase) * (1.1 * t + 0.3 * phase).sin()
3471 + 0.05 * (7.0 * t + 0.5 * phase).sin()
3472 })
3473 }
3474
3475 fn finite_difference_penalty() -> Array2<f64> {
3476 Array2::from_diag(&array![0.0, 0.8, 1.2, 1.7, 2.3])
3477 }
3478
3479 fn finite_difference_weights() -> Array1<f64> {
3480 Array1::from_shape_fn(20, |row| {
3481 let t = (row as f64 - 9.5) / 10.0;
3482 1.0 + 0.025 * (1.1 * t).sin() + 0.01 * t
3483 })
3484 }
3485
3486 fn one_hot_objective_try(
3493 x: ArrayView2<'_, f64>,
3494 y: ArrayView2<'_, f64>,
3495 penalty: ArrayView2<'_, f64>,
3496 weights: ArrayView1<'_, f64>,
3497 target: ForwardScalar,
3498 ) -> Option<f64> {
3499 let fit = gaussian_reml_multi_closed_form_with_cache(
3500 x,
3501 y,
3502 penalty,
3503 Some(weights),
3504 Some(0.85),
3505 None,
3506 )
3507 .ok()?;
3508 Some(match target {
3509 ForwardScalar::Lambda => fit.lambda,
3510 ForwardScalar::RemlScore => fit.reml_score,
3511 ForwardScalar::Coefficient(row, col) => fit.coefficients[[row, col]],
3512 ForwardScalar::Fitted(row, col) => fit.fitted[[row, col]],
3513 ForwardScalar::Edf => fit.edf,
3514 })
3515 }
3516
3517 fn one_hot_objective(
3518 x: ArrayView2<'_, f64>,
3519 y: ArrayView2<'_, f64>,
3520 penalty: ArrayView2<'_, f64>,
3521 weights: ArrayView1<'_, f64>,
3522 target: ForwardScalar,
3523 ) -> f64 {
3524 one_hot_objective_try(x, y, penalty, weights, target)
3525 .expect("finite-difference forward fit")
3526 }
3527
3528 fn one_hot_backward(
3529 x: ArrayView2<'_, f64>,
3530 y: ArrayView2<'_, f64>,
3531 penalty: ArrayView2<'_, f64>,
3532 weights: ArrayView1<'_, f64>,
3533 target: ForwardScalar,
3534 ) -> GaussianRemlBackwardResult {
3535 let mut grad_coefficients = Array2::<f64>::zeros((x.ncols(), y.ncols()));
3536 let mut grad_fitted = Array2::<f64>::zeros(y.dim());
3537 let (grad_lambda, grad_score, grad_edf, coefficient_upstream, fitted_upstream) =
3538 match target {
3539 ForwardScalar::Lambda => (1.0, 0.0, 0.0, None, None),
3540 ForwardScalar::RemlScore => (0.0, 1.0, 0.0, None, None),
3541 ForwardScalar::Coefficient(row, col) => {
3542 grad_coefficients[[row, col]] = 1.0;
3543 (0.0, 0.0, 0.0, Some(grad_coefficients.view()), None)
3544 }
3545 ForwardScalar::Fitted(row, col) => {
3546 grad_fitted[[row, col]] = 1.0;
3547 (0.0, 0.0, 0.0, None, Some(grad_fitted.view()))
3548 }
3549 ForwardScalar::Edf => (0.0, 0.0, 1.0, None, None),
3550 };
3551 gaussian_reml_multi_closed_form_backward(
3552 x,
3553 y,
3554 penalty,
3555 Some(weights),
3556 Some(0.85),
3557 grad_lambda,
3558 coefficient_upstream,
3559 fitted_upstream,
3560 grad_score,
3561 grad_edf,
3562 )
3563 .expect("analytic backward VJP")
3564 }
3565
/// Asserts that an analytic derivative agrees with its finite-difference
/// estimate.
///
/// The allowed gap is the larger of an absolute tolerance (`1e-6`) and a
/// relative tolerance (`1e-6`) scaled by the larger magnitude of the two
/// values. A NaN gap fails the assertion. `label` prefixes the panic message.
fn assert_fd_close(label: &str, analytic: f64, finite_difference: f64) {
    const RELATIVE_TOLERANCE: f64 = 1.0e-6;
    const ABSOLUTE_TOLERANCE: f64 = 1.0e-6;

    let magnitude = analytic.abs().max(finite_difference.abs());
    let tol = ABSOLUTE_TOLERANCE.max(RELATIVE_TOLERANCE * magnitude);
    let diff = (analytic - finite_difference).abs();
    assert!(
        diff <= tol,
        "{label}: analytic={analytic:.12e}, finite_difference={finite_difference:.12e}, diff={diff:.3e}, tol={tol:.3e}"
    );
}
3576
/// Estimates the derivative of `eval` at zero by Richardson-extrapolated
/// central differences over a ladder of halving step sizes.
///
/// `eval(delta)` must return the objective at perturbation `delta`. Central
/// differences are formed at `h_k = 1e-3 / 2^k` for `k = 0..=5`. Each
/// consecutive pair `(h_k, h_{k+1})` yields the estimate
/// `fine + (fine - coarse) / 3`, which cancels the leading `O(h^2)` error
/// term. The returned value is the estimate that changed least from the
/// estimate of the preceding pair.
///
/// The fine difference of one pair is exactly the coarse difference of the
/// next, because halving a step is exact in binary floating point. It is
/// therefore computed once and reused: `eval` is called 12 times instead of
/// 20, and for a deterministic `eval` the estimates are unchanged.
fn adaptive_central_difference(mut eval: impl FnMut(f64) -> f64) -> f64 {
    const COARSEST_STEP: f64 = 1.0e-3;
    const EXTRAPOLATION_LEVELS: usize = 5;

    let mut central = |h: f64| (eval(h) - eval(-h)) / (2.0 * h);

    let mut h = COARSEST_STEP;
    let mut coarse = central(h);
    let mut best = f64::NAN;
    let mut best_delta = f64::INFINITY;
    let mut previous: Option<f64> = None;
    for _ in 0..EXTRAPOLATION_LEVELS {
        h *= 0.5;
        let fine = central(h);
        let estimate = fine + (fine - coarse) / 3.0;
        match previous {
            Some(prev) => {
                // Prefer the level at which successive estimates agree best.
                let delta = (estimate - prev).abs();
                if delta < best_delta {
                    best_delta = delta;
                    best = estimate;
                }
            }
            // With nothing to compare against, the first estimate is the fallback.
            None => best = estimate,
        }
        previous = Some(estimate);
        coarse = fine;
    }
    best
}
3600
3601 fn assert_backward_matches_forward_finite_difference(outputs: usize) {
3602 let x = finite_difference_design();
3603 let y = finite_difference_response(outputs);
3604 let penalty = finite_difference_penalty();
3605 let weights = finite_difference_weights();
3606 let targets = [
3607 ForwardScalar::Lambda,
3608 ForwardScalar::RemlScore,
3609 ForwardScalar::Coefficient(3, outputs - 1),
3610 ForwardScalar::Fitted(12, outputs - 1),
3611 ForwardScalar::Edf,
3612 ];
3613 for target in targets {
3614 let backward =
3615 one_hot_backward(x.view(), y.view(), penalty.view(), weights.view(), target);
3616
3617 for row in 0..x.nrows() {
3618 for col in 0..x.ncols() {
3619 let eval = |delta: f64| {
3620 let mut candidate = x.clone();
3621 candidate[[row, col]] += delta;
3622 one_hot_objective(
3623 candidate.view(),
3624 y.view(),
3625 penalty.view(),
3626 weights.view(),
3627 target,
3628 )
3629 };
3630 let fd = adaptive_central_difference(eval);
3631 assert_fd_close(
3632 &format!("target={target:?} x[{row},{col}]"),
3633 backward.grad_x[[row, col]],
3634 fd,
3635 );
3636 }
3637 }
3638
3639 for row in 0..y.nrows() {
3640 for col in 0..y.ncols() {
3641 let eval = |delta: f64| {
3642 let mut candidate = y.clone();
3643 candidate[[row, col]] += delta;
3644 one_hot_objective(
3645 x.view(),
3646 candidate.view(),
3647 penalty.view(),
3648 weights.view(),
3649 target,
3650 )
3651 };
3652 let fd = adaptive_central_difference(eval);
3653 assert_fd_close(
3654 &format!("target={target:?} y[{row},{col}]"),
3655 backward.grad_y[[row, col]],
3656 fd,
3657 );
3658 }
3659 }
3660
3661 for row in 0..weights.len() {
3662 let eval = |delta: f64| {
3663 let mut candidate = weights.clone();
3664 candidate[row] += delta;
3665 one_hot_objective(x.view(), y.view(), penalty.view(), candidate.view(), target)
3666 };
3667 let fd = adaptive_central_difference(eval);
3668 assert_fd_close(
3669 &format!("target={target:?} weights[{row}]"),
3670 backward.grad_weights[row],
3671 fd,
3672 );
3673 }
3674
3675 let null_index = 0usize; let probe_h = 1.0e-3_f64; for r in 0..penalty.nrows() {
3699 for c in 0..penalty.ncols() {
3700 if r == null_index || c == null_index {
3701 continue;
3702 }
3703 let eval = |delta: f64| {
3704 let mut candidate = penalty.clone();
3705 candidate[[r, c]] += delta;
3706 one_hot_objective(
3707 x.view(),
3708 y.view(),
3709 candidate.view(),
3710 weights.view(),
3711 target,
3712 )
3713 };
3714 let cone_safe = {
3715 let mut s_plus = penalty.clone();
3716 let mut s_minus = penalty.clone();
3717 s_plus[[r, c]] += probe_h;
3718 s_minus[[r, c]] -= probe_h;
3719 one_hot_objective_try(
3720 x.view(),
3721 y.view(),
3722 s_plus.view(),
3723 weights.view(),
3724 target,
3725 )
3726 .is_some()
3727 && one_hot_objective_try(
3728 x.view(),
3729 y.view(),
3730 s_minus.view(),
3731 weights.view(),
3732 target,
3733 )
3734 .is_some()
3735 };
3736 if !cone_safe {
3737 continue;
3738 }
3739 let fd = adaptive_central_difference(eval);
3740 assert_fd_close(
3741 &format!("target={target:?} penalty[{r},{c}]"),
3742 backward.grad_penalty[[r, c]],
3743 fd,
3744 );
3745 }
3746 }
3747 }
3748 }
3749
/// Exhaustive finite-difference gradient check for a single response column.
#[test]
fn scalar_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
    let outputs = 1;
    assert_backward_matches_forward_finite_difference(outputs);
}
3754
/// Exhaustive finite-difference gradient check for three response columns.
#[test]
fn multi_output_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
    let outputs = 3;
    assert_backward_matches_forward_finite_difference(outputs);
}
3759
/// Spot-checks the backward VJP against plain central differences of a
/// scalarized forward fit.
///
/// The scalar objective contracts the forward outputs (lambda, REML score,
/// coefficients, fitted values) with fixed upstream gradients. One entry each
/// of `x`, `y`, and the weights is checked, plus three penalty entries.
#[test]
fn backward_vjp_matches_finite_difference() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.3, -0.1],
        [1.0, 0.2, 0.4],
        [1.0, 0.8, 0.1],
        [1.0, 1.4, 0.5],
        [1.0, 2.0, 0.9],
    ];
    let y = array![
        [0.1, -0.2],
        [0.2, 0.1],
        [0.7, 0.0],
        [1.1, 0.3],
        [1.8, 0.9],
        [2.4, 1.4],
    ];
    let weights = array![1.0, 0.9, 1.1, 1.2, 0.8, 1.3];
    let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.2], [0.0, 0.2, 1.7]];
    let upstream_coefficients = array![[0.2, -0.1], [0.05, 0.03], [-0.04, 0.07]];
    let upstream_fitted = array![
        [0.01, -0.02],
        [0.03, 0.01],
        [-0.01, 0.02],
        [0.04, -0.03],
        [0.02, 0.05],
        [-0.02, 0.01],
    ];
    let upstream_lambda = 0.17;
    let upstream_score = -0.11;

    let backward = gaussian_reml_multi_closed_form_backward(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        Some(0.8),
        upstream_lambda,
        Some(upstream_coefficients.view()),
        Some(upstream_fitted.view()),
        upstream_score,
        0.0,
    )
    .expect("backward VJP");

    // Scalar objective whose gradient the VJP above should reproduce.
    // `failure` is the panic message used when the forward fit errors.
    let scalarized = |x_eval: &Array2<f64>,
                      y_eval: &Array2<f64>,
                      s_eval: &Array2<f64>,
                      w_eval: &Array1<f64>,
                      failure: &str|
     -> f64 {
        let fit = gaussian_reml_multi_closed_form_with_cache(
            x_eval.view(),
            y_eval.view(),
            s_eval.view(),
            Some(w_eval.view()),
            Some(0.8),
            None,
        )
        .expect(failure);
        upstream_lambda * fit.lambda
            + upstream_score * fit.reml_score
            + (&fit.coefficients * &upstream_coefficients).sum()
            + (&fit.fitted * &upstream_fitted).sum()
    };
    let eps = 1.0e-6;
    let central = |plus: f64, minus: f64| (plus - minus) / (2.0 * eps);
    assert!(scalarized(&x, &y, &penalty, &weights, "fit for objective").is_finite());

    // d objective / d x[3, 2]
    let (mut x_plus, mut x_minus) = (x.clone(), x.clone());
    x_plus[[3, 2]] += eps;
    x_minus[[3, 2]] -= eps;
    let fd_x = central(
        scalarized(&x_plus, &y, &penalty, &weights, "fit for objective"),
        scalarized(&x_minus, &y, &penalty, &weights, "fit for objective"),
    );
    assert!(
        (fd_x - backward.grad_x[[3, 2]]).abs() <= 2.0e-4,
        "grad_x mismatch: analytic={} fd={}",
        backward.grad_x[[3, 2]],
        fd_x
    );

    // d objective / d y[4, 1]
    let (mut y_plus, mut y_minus) = (y.clone(), y.clone());
    y_plus[[4, 1]] += eps;
    y_minus[[4, 1]] -= eps;
    let fd_y = central(
        scalarized(&x, &y_plus, &penalty, &weights, "fit for objective"),
        scalarized(&x, &y_minus, &penalty, &weights, "fit for objective"),
    );
    assert!(
        (fd_y - backward.grad_y[[4, 1]]).abs() <= 2.0e-4,
        "grad_y mismatch: analytic={} fd={}",
        backward.grad_y[[4, 1]],
        fd_y
    );

    // d objective / d weights[2]
    let (mut w_plus, mut w_minus) = (weights.clone(), weights.clone());
    w_plus[2] += eps;
    w_minus[2] -= eps;
    let fd_w = central(
        scalarized(&x, &y, &penalty, &w_plus, "fit for objective"),
        scalarized(&x, &y, &penalty, &w_minus, "fit for objective"),
    );
    assert!(
        (fd_w - backward.grad_weights[2]).abs() <= 2.0e-4,
        "grad_weight mismatch: analytic={} fd={}",
        backward.grad_weights[2],
        fd_w
    );

    // d objective / d penalty[r, c]; each listed entry is perturbed on its own.
    for (r, c) in [(1usize, 1usize), (1, 2), (2, 2)] {
        let (mut s_plus, mut s_minus) = (penalty.clone(), penalty.clone());
        s_plus[[r, c]] += eps;
        s_minus[[r, c]] -= eps;
        let fd_s = central(
            scalarized(&x, &y, &s_plus, &weights, "fit for penalty objective"),
            scalarized(&x, &y, &s_minus, &weights, "fit for penalty objective"),
        );
        assert!(
            (fd_s - backward.grad_penalty[[r, c]]).abs() <= 2.0e-4,
            "grad_penalty[{r},{c}] mismatch: analytic={} fd={}",
            backward.grad_penalty[[r, c]],
            fd_s
        );
    }
}
3905
/// The batched eigen-cache builder must produce, for every Gram matrix, the
/// same cache as the single-fit builder called on that matrix alone.
///
/// Scalar metadata is compared exactly, log-determinants and array entries to
/// `1e-12`. Array shapes are asserted before the element-wise comparison,
/// because `zip` stops at the shorter side and would otherwise let a shape
/// mismatch pass unnoticed.
#[test]
fn batched_eigen_cache_matches_per_fit_build() {
    let xtwx_a = array![[4.0, 1.0], [1.0, 3.0]];
    let xtwx_b = array![[2.5, -0.5], [-0.5, 1.7]];
    let xtwx_c = array![[7.2, 0.3], [0.3, 5.1]];
    let penalty = array![[0.0, 0.0], [0.0, 1.0]];

    let batched = build_gaussian_reml_eigen_cache_batched(
        vec![xtwx_a.clone(), xtwx_b.clone(), xtwx_c.clone()],
        penalty.view(),
        None,
    );
    assert_eq!(batched.len(), 3);

    for (xtwx, batched_cache) in [&xtwx_a, &xtwx_b, &xtwx_c].into_iter().zip(batched.iter()) {
        let single = gaussian_reml_eigen_cache_from_xtwx(xtwx.clone(), penalty.view(), None)
            .expect("per-fit cache");
        let batched_cache = batched_cache.as_ref().expect("batched cache");

        assert_eq!(batched_cache.penalty_rank, single.penalty_rank);
        assert_eq!(batched_cache.nullity, single.nullity);
        assert_eq!(batched_cache.xtwx_fingerprint, single.xtwx_fingerprint);
        assert_eq!(
            batched_cache.penalty_fingerprint,
            single.penalty_fingerprint
        );
        assert!((batched_cache.logdet_xtwx - single.logdet_xtwx).abs() <= 1.0e-12);
        assert!(
            (batched_cache.logdet_penalty_positive - single.logdet_penalty_positive).abs()
                <= 1.0e-12
        );

        assert_eq!(
            batched_cache.penalty_eigenvalues.len(),
            single.penalty_eigenvalues.len()
        );
        for (a, b) in batched_cache
            .penalty_eigenvalues
            .iter()
            .zip(single.penalty_eigenvalues.iter())
        {
            assert!((a - b).abs() <= 1.0e-12);
        }

        assert_eq!(
            batched_cache.coefficient_basis.dim(),
            single.coefficient_basis.dim()
        );
        for (a, b) in batched_cache
            .coefficient_basis
            .iter()
            .zip(single.coefficient_basis.iter())
        {
            assert!((a - b).abs() <= 1.0e-12);
        }
    }
}
3958
/// Regression case in which the REML cost has more than one stationary point
/// in rho: both optimizers must settle on the later, lower-cost one
/// (rho ~ 4.3251) rather than the earlier one near rho ~ -18.93.
#[test]
fn scalar_rho_optimizer_chooses_lowest_cost_stationary_point() {
    let cache = GaussianRemlEigenCache {
        penalty_eigenvalues: array![5.2430192311066924e-05, 81734184.18548436],
        eigenvectors: Array2::eye(2),
        coefficient_basis: Array2::eye(2),
        xtwx_fingerprint: 0,
        penalty_fingerprint: 0,
        logdet_xtwx: 0.0,
        logdet_penalty_positive: 0.0,
        penalty_rank: 2,
        nullity: 0,
    };
    // The unsquared projections are the element-wise square roots of the
    // squared ones.
    let projected_rhs_squared = array![[0.361060218768292], [0.01014486085547482]];
    let projected_rhs = projected_rhs_squared.mapv(f64::sqrt);
    let prepared = GaussianRemlPrepared {
        cache: cache.clone(),
        ywy: array![0.5021347226586624],
        projected_rhs_squared,
        projected_rhs,
        n_effective: 100,
        n_outputs: 1,
    };

    let rho = optimize_rho(&prepared, None).expect("allocating rho optimizer");
    let no_alloc_rho = optimize_rho_no_alloc(
        &cache,
        prepared.ywy.view(),
        prepared.projected_rhs_squared.view(),
        prepared.n_effective,
        prepared.n_outputs,
        None,
    )
    .expect("no-alloc rho optimizer");

    assert!(
        (rho - 4.3251059890).abs() < 1.0e-6,
        "rho optimizer selected {rho}, expected the lower-cost later stationary point"
    );
    assert!(
        (no_alloc_rho - rho).abs() < 1.0e-8,
        "no-alloc optimizer selected {no_alloc_rho}, allocating selected {rho}"
    );
    // The chosen point must beat the competing stationary point on cost.
    let chosen_cost = prepared.evaluate(rho).cost;
    let rival_cost = prepared.evaluate(-18.9277503549).cost;
    assert!(chosen_cost < rival_cost);
}
4005
/// The backward pass that reuses an existing forward fit must agree with the
/// backward entry point that takes the raw inputs, for every gradient in the
/// result.
///
/// All four gradients (`grad_x`, `grad_y`, `grad_weights`, `grad_penalty`)
/// are compared to `1e-12`. Shapes are asserted first, because the
/// element-wise `zip` would silently truncate on a shape mismatch.
#[test]
fn backward_from_fit_matches_backward_with_refit() {
    let x = array![[1.0, -0.9], [1.0, -0.4], [1.0, 0.1], [1.0, 0.6], [1.0, 1.1],];
    let y = array![[0.2, -0.1], [0.4, 0.1], [0.7, 0.3], [1.0, 0.5], [1.5, 0.8]];
    let penalty = array![[0.0, 0.0], [0.0, 1.5]];
    let weights = array![1.05, 0.95, 1.01, 0.99, 1.03];

    let refit = gaussian_reml_multi_closed_form_backward(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        Some(0.85),
        0.2,
        None,
        None,
        -0.1,
        0.0,
    )
    .expect("refit backward");

    let fit = gaussian_reml_multi_closed_form_with_cache(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        Some(0.85),
        None,
    )
    .expect("forward fit");
    let from_fit = gaussian_reml_multi_closed_form_backward_from_fit(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        &fit,
        0.2,
        None,
        None,
        -0.1,
        0.0,
    )
    .expect("from_fit backward");

    assert_eq!(refit.grad_x.dim(), from_fit.grad_x.dim());
    for (a, b) in refit.grad_x.iter().zip(from_fit.grad_x.iter()) {
        assert!((a - b).abs() <= 1.0e-12, "grad_x: refit={a} from_fit={b}");
    }
    assert_eq!(refit.grad_y.dim(), from_fit.grad_y.dim());
    for (a, b) in refit.grad_y.iter().zip(from_fit.grad_y.iter()) {
        assert!((a - b).abs() <= 1.0e-12, "grad_y: refit={a} from_fit={b}");
    }
    assert_eq!(refit.grad_weights.dim(), from_fit.grad_weights.dim());
    for (a, b) in refit.grad_weights.iter().zip(from_fit.grad_weights.iter()) {
        assert!(
            (a - b).abs() <= 1.0e-12,
            "grad_weights: refit={a} from_fit={b}"
        );
    }
    assert_eq!(refit.grad_penalty.dim(), from_fit.grad_penalty.dim());
    for (a, b) in refit.grad_penalty.iter().zip(from_fit.grad_penalty.iter()) {
        assert!(
            (a - b).abs() <= 1.0e-12,
            "grad_penalty: refit={a} from_fit={b}"
        );
    }
}
4064
/// The from-fit backward pass must return finite, correctly shaped gradients
/// instead of an error when the forward fit handed to it has its
/// `reml_hess_rho` forced to zero.
#[test]
fn backward_degrades_gracefully_when_k_is_near_singular() {
    let x = array![
        [1.0, -1.0, 0.5],
        [1.0, -0.5, 0.2],
        [1.0, 0.0, -0.1],
        [1.0, 0.5, 0.3],
        [1.0, 1.0, 0.8],
        [1.0, 1.5, 1.1],
        [1.0, 2.0, 1.5],
        [1.0, 2.5, 2.0],
        [1.0, 3.0, 2.6],
        [1.0, 3.5, 3.1],
    ];
    let y = array![
        [0.1],
        [0.3],
        [0.4],
        [0.7],
        [1.0],
        [1.5],
        [2.0],
        [2.7],
        [3.3],
        [4.0]
    ];
    // Identity penalty on all three coefficients.
    let penalty = Array2::<f64>::eye(3);

    let mut fit =
        gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
            .expect("forward fit must succeed for well-posed input");
    // Zero the stored rho-curvature before running the backward pass.
    fit.reml_hess_rho = 0.0;

    let result = gaussian_reml_multi_closed_form_backward_from_fit(
        x.view(),
        y.view(),
        penalty.view(),
        None,
        &fit,
        1.0,
        None,
        None,
        1.0,
        1.0,
    )
    .expect("backward must NOT error on near-singular K");

    let (n, p) = x.dim();
    assert_eq!(result.grad_x.dim(), (n, p));
    assert_eq!(result.grad_y.dim(), (y.nrows(), y.ncols()));
    assert_eq!(result.grad_penalty.dim(), (p, p));
    assert_eq!(result.grad_weights.dim(), n);

    result
        .grad_x
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_x must be finite, got {v}"));
    result
        .grad_y
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_y must be finite, got {v}"));
    result
        .grad_penalty
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_penalty must be finite, got {v}"));
    result
        .grad_weights
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_weights must be finite, got {v}"));
}
4147}
4148
4149pub struct GaussianRemlBlocksBackwardAnalytic {
4153 pub grad_designs: Vec<Array2<f64>>,
4154 pub grad_penalties: Vec<Array2<f64>>,
4155 pub grad_y: Array2<f64>,
4156 pub grad_weights: Array1<f64>,
4157}
4158
4159pub fn gaussian_reml_fit_blocks_backward_analytic(
4168 designs: &[Array2<f64>],
4169 penalties_raw: &[Array2<f64>],
4170 y: ArrayView1<'_, f64>,
4171 weights: ArrayView1<'_, f64>,
4172 rhos: &[f64],
4173 grad_coefficients: Option<ArrayView2<'_, f64>>,
4174 grad_fitted: Option<ArrayView2<'_, f64>>,
4175 grad_lambdas: Option<ArrayView1<'_, f64>>,
4176 grad_log_lambdas: Option<ArrayView1<'_, f64>>,
4177 grad_reml_score: f64,
4178 grad_edf: Option<ArrayView1<'_, f64>>,
4179) -> Result<GaussianRemlBlocksBackwardAnalytic, EstimationError> {
4180 let n = y.len();
4181 let f_blocks = designs.len();
4182 let mut offsets = Vec::with_capacity(f_blocks + 1);
4183 offsets.push(0_usize);
4184 for design in designs {
4185 offsets.push(offsets.last().copied().unwrap() + design.ncols());
4186 }
4187 let p_total = *offsets.last().unwrap();
4188 if n == 0 || p_total == 0 {
4189 return Err(EstimationError::InvalidInput(
4190 "gaussian_reml_fit_blocks_backward requires non-empty rows and at least one coefficient column"
4191 .to_string(),
4192 ));
4193 }
4194
4195 if rhos.len() != f_blocks {
4196 return Err(EstimationError::InvalidInput(format!(
4197 "log_lambdas length mismatch: expected {f_blocks}, got {}",
4198 rhos.len()
4199 )));
4200 }
4201 if let Some(gc) = grad_coefficients {
4202 if gc.dim() != (p_total, 1) {
4203 return Err(EstimationError::InvalidInput(format!(
4204 "grad_coefficients shape mismatch: expected {}x1, got {}x{}",
4205 p_total,
4206 gc.nrows(),
4207 gc.ncols()
4208 )));
4209 }
4210 }
4211 if let Some(gf) = grad_fitted {
4212 if gf.dim() != (n, 1) {
4213 return Err(EstimationError::InvalidInput(format!(
4214 "grad_fitted shape mismatch: expected {}x1, got {}x{}",
4215 n,
4216 gf.nrows(),
4217 gf.ncols()
4218 )));
4219 }
4220 }
4221 if !grad_reml_score.is_finite() {
4222 return Err(EstimationError::InvalidInput(format!(
4223 "grad_reml_score must be finite; got {grad_reml_score}"
4224 )));
4225 }
4226 if let Some(vec) = grad_lambdas {
4227 if vec.len() != f_blocks {
4228 return Err(EstimationError::InvalidInput(format!(
4229 "grad_lambdas length mismatch: expected {f_blocks}, got {}",
4230 vec.len()
4231 )));
4232 }
4233 }
4234 if let Some(vec) = grad_log_lambdas {
4235 if vec.len() != f_blocks {
4236 return Err(EstimationError::InvalidInput(format!(
4237 "grad_log_lambdas length mismatch: expected {f_blocks}, got {}",
4238 vec.len()
4239 )));
4240 }
4241 }
4242 if let Some(vec) = grad_edf {
4243 if vec.len() != f_blocks {
4244 return Err(EstimationError::InvalidInput(format!(
4245 "grad_edf length mismatch: expected {f_blocks}, got {}",
4246 vec.len()
4247 )));
4248 }
4249 }
4250 if let Some(gc) = grad_coefficients {
4251 if let Some(((row, col), value)) = gc.indexed_iter().find(|(_, value)| !value.is_finite()) {
4252 return Err(EstimationError::InvalidInput(format!(
4253 "grad_coefficients[{row},{col}] must be finite; got {value}"
4254 )));
4255 }
4256 }
4257 if let Some(gf) = grad_fitted {
4258 if let Some(((row, col), value)) = gf.indexed_iter().find(|(_, value)| !value.is_finite()) {
4259 return Err(EstimationError::InvalidInput(format!(
4260 "grad_fitted[{row},{col}] must be finite; got {value}"
4261 )));
4262 }
4263 }
4264 if let Some(vec) = grad_lambdas {
4265 if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4266 return Err(EstimationError::InvalidInput(format!(
4267 "grad_lambdas[{block}] must be finite; got {value}"
4268 )));
4269 }
4270 }
4271 if let Some(vec) = grad_log_lambdas {
4272 if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4273 return Err(EstimationError::InvalidInput(format!(
4274 "grad_log_lambdas[{block}] must be finite; got {value}"
4275 )));
4276 }
4277 }
4278 if let Some(vec) = grad_edf {
4279 if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4280 return Err(EstimationError::InvalidInput(format!(
4281 "grad_edf[{block}] must be finite; got {value}"
4282 )));
4283 }
4284 }
4285 for (block, design) in designs.iter().enumerate() {
4286 if let Some(((row, col), value)) =
4287 design.indexed_iter().find(|(_, value)| !value.is_finite())
4288 {
4289 return Err(EstimationError::InvalidInput(format!(
4290 "designs[{block}][{row},{col}] must be finite; got {value}"
4291 )));
4292 }
4293 }
4294 for (block, penalty) in penalties_raw.iter().enumerate() {
4295 if let Some(((row, col), value)) =
4296 penalty.indexed_iter().find(|(_, value)| !value.is_finite())
4297 {
4298 return Err(EstimationError::InvalidInput(format!(
4299 "penalties[{block}][{row},{col}] must be finite; got {value}"
4300 )));
4301 }
4302 }
4303 if let Some((row, value)) = y.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4304 return Err(EstimationError::InvalidInput(format!(
4305 "y[{row}] must be finite; got {value}"
4306 )));
4307 }
4308 if let Some((row, value)) = weights
4309 .iter()
4310 .enumerate()
4311 .find(|(_, value)| !value.is_finite() || **value < 0.0)
4312 {
4313 return Err(EstimationError::InvalidInput(format!(
4314 "weights[{row}] must be finite and non-negative; got {value}"
4315 )));
4316 }
4317
4318 let mut z = Array2::<f64>::zeros((n, p_total));
4319 for k in 0..f_blocks {
4320 z.slice_mut(s![.., offsets[k]..offsets[k + 1]])
4321 .assign(&designs[k]);
4322 }
4323
4324 let penalties: Vec<Array2<f64>> = penalties_raw
4325 .iter()
4326 .map(|p| {
4327 let mut out = p.clone();
4328 gam_linalg::matrix::symmetrize_in_place(&mut out);
4329 out
4330 })
4331 .collect();
4332 let mut ranks = Vec::with_capacity(f_blocks);
4333 let mut pinvs = Vec::with_capacity(f_blocks);
4334 for penalty in &penalties {
4335 let (rank, pinv) = gam_linalg::utils::block_penalty_rank_and_pinv(penalty)?;
4336 ranks.push(rank);
4337 pinvs.push(pinv);
4338 }
4339
4340 let lambdas = Array1::from_iter(rhos.iter().map(|rho| rho.exp()));
4341 if let Some((block, lambda)) = lambdas
4342 .iter()
4343 .enumerate()
4344 .find(|(_, lambda)| !lambda.is_finite() || **lambda <= 0.0)
4345 {
4346 return Err(EstimationError::InvalidInput(format!(
4347 "exp(log_lambdas[{block}]) must be finite and positive; got {lambda}"
4348 )));
4349 }
4350 let mut k_matrix = fast_xt_diag_x(&z.view(), &weights);
4351 for block in 0..f_blocks {
4352 let lambda = lambdas[block];
4353 for local_i in 0..penalties[block].nrows() {
4354 let global_i = offsets[block] + local_i;
4355 for local_j in 0..penalties[block].ncols() {
4356 let global_j = offsets[block] + local_j;
4357 k_matrix[[global_i, global_j]] += lambda * penalties[block][[local_i, local_j]];
4358 }
4359 }
4360 }
4361 let r = gam_linalg::utils::invert_spd_with_ridge(&k_matrix, 0.0)?;
4362
4363 let mut xtwy = Array1::<f64>::zeros(p_total);
4364 for row in 0..n {
4365 let wy = weights[row] * y[row];
4366 for col in 0..p_total {
4367 xtwy[col] += z[[row, col]] * wy;
4368 }
4369 }
4370 let beta = r.dot(&xtwy);
4371 let fitted = z.dot(&beta);
4372 if let Some((col, value)) = beta
4373 .iter()
4374 .enumerate()
4375 .find(|(_, value)| !value.is_finite())
4376 {
4377 return Err(EstimationError::InvalidInput(format!(
4378 "solved coefficient {col} is non-finite: {value}"
4379 )));
4380 }
4381 let residual = &y.to_owned() - &fitted;
4382 let weighted_residual = &residual * &weights.to_owned();
4383 let ywy = y
4384 .iter()
4385 .zip(weights.iter())
4386 .map(|(&yi, &wi)| wi * yi * yi)
4387 .sum::<f64>();
4388 let q_raw = ywy - xtwy.dot(&beta);
4389 if !q_raw.is_finite() {
4390 return Err(EstimationError::InvalidInput(format!(
4391 "Gaussian REML residual quadratic form must be finite; got {q_raw}"
4392 )));
4393 }
4394 let q = q_raw.max(1.0e-300);
4395 let nullity = penalties
4396 .iter()
4397 .zip(ranks.iter())
4398 .map(|(penalty, rank)| penalty.nrows().saturating_sub(*rank))
4399 .sum::<usize>();
4400 let nu = effective_observation_count(weights) as f64 - nullity as f64;
4403 if !(nu.is_finite() && nu > 0.0) {
4404 return Err(EstimationError::InvalidInput(format!(
4405 "Gaussian REML residual degrees of freedom must be positive; got {nu}"
4406 )));
4407 }
4408 let tau = nu / q;
4409 let tau_q = -nu / (q * q);
4410 if !(tau.is_finite() && tau_q.is_finite()) {
4411 return Err(EstimationError::InvalidInput(format!(
4412 "Gaussian REML scale derivatives are non-finite: tau={tau}, tau_q={tau_q}"
4413 )));
4414 }
4415
4416 let mut grad_z = Array2::<f64>::zeros((n, p_total));
4417 let mut g_kernel = Array2::<f64>::zeros((p_total, p_total));
4418 let mut h_kernel = Array1::<f64>::zeros(p_total);
4419 let mut q_kernel = 0.0_f64;
4420 let mut j_blocks: Vec<Array2<f64>> = penalties
4421 .iter()
4422 .map(|p| Array2::<f64>::zeros(p.dim()))
4423 .collect();
4424
4425 let mut beta_tilde = Array1::<f64>::zeros(p_total);
4426 if let Some(gc) = grad_coefficients {
4427 beta_tilde += &gc.column(0).to_owned();
4428 }
4429 if let Some(gf) = grad_fitted {
4430 let gf_col = gf.column(0).to_owned();
4431 beta_tilde += &z.t().dot(&gf_col);
4432 for row in 0..n {
4433 for col in 0..p_total {
4434 grad_z[[row, col]] += gf_col[row] * beta[col];
4435 }
4436 }
4437 }
4438
4439 let u = r.dot(&beta_tilde);
4444 h_kernel += &u;
4445 for i in 0..p_total {
4446 for j in 0..p_total {
4447 g_kernel[[i, j]] -= 0.5 * (beta[i] * u[j] + u[i] * beta[j]);
4448 }
4449 }
4450
4451 let mut alpha = Array1::<f64>::zeros(f_blocks);
4452 if let Some(gl) = grad_lambdas {
4453 for block in 0..f_blocks {
4454 alpha[block] += gl[block] * lambdas[block];
4455 }
4456 }
4457 if let Some(grho) = grad_log_lambdas {
4458 alpha += &grho.to_owned();
4459 }
4460
4461 let mut p_betas = Vec::with_capacity(f_blocks);
4462 let mut m_vectors = Vec::with_capacity(f_blocks);
4463 let mut rp_matrices = Vec::with_capacity(f_blocks);
4464 let mut rpr_matrices = Vec::with_capacity(f_blocks);
4465 let mut b_values = Array1::<f64>::zeros(f_blocks);
4466 let mut t_values = Array1::<f64>::zeros(f_blocks);
4467
4468 for block in 0..f_blocks {
4469 let start = offsets[block];
4470 let end = offsets[block + 1];
4471 let beta_k = beta.slice(s![start..end]).to_owned();
4472 let s_beta = penalties[block].dot(&beta_k);
4473 let lambda = lambdas[block];
4474 let lambda_s_beta = s_beta.mapv(|value| lambda * value);
4475 let mut p_beta = Array1::<f64>::zeros(p_total);
4476 for local_i in 0..(end - start) {
4477 p_beta[start + local_i] = lambda_s_beta[local_i];
4478 }
4479 let weighted_penalty = penalties[block].mapv(|value| lambda * value);
4480 let rp_block = r.slice(s![.., start..end]).dot(&weighted_penalty);
4481 let mut rp = Array2::<f64>::zeros((p_total, p_total));
4482 rp.slice_mut(s![.., start..end]).assign(&rp_block);
4483 let rpr = rp_block.dot(&r.slice(s![start..end, ..]));
4484 let m = r.slice(s![.., start..end]).dot(&lambda_s_beta);
4485 b_values[block] = beta.dot(&p_beta);
4486 t_values[block] = (0..(end - start))
4487 .map(|local_i| rp_block[[start + local_i, local_i]])
4488 .sum::<f64>();
4489 alpha[block] -= u.dot(&p_beta);
4490 p_betas.push(p_beta);
4491 m_vectors.push(m);
4492 rp_matrices.push(rp);
4493 rpr_matrices.push(rpr);
4494 }
4495
4496 if grad_reml_score != 0.0 {
4497 q_kernel += 0.5 * grad_reml_score * tau;
4498 g_kernel += &(r.clone() * (0.5 * grad_reml_score));
4499 for block in 0..f_blocks {
4500 j_blocks[block] -= &(pinvs[block].clone() * (0.5 * grad_reml_score / lambdas[block]));
4501 }
4502 }
4503
4504 let mut trace_pairs = Array2::<f64>::zeros((f_blocks, f_blocks));
4505 for i in 0..f_blocks {
4506 for j in 0..f_blocks {
4507 trace_pairs[[i, j]] = gam_linalg::utils::trace_of_product(
4508 rp_matrices[i].view(),
4509 rp_matrices[j].view(),
4510 );
4511 }
4512 }
4513
4514 if let Some(ge) = grad_edf {
4515 for edf_block in 0..f_blocks {
4516 let scale = ge[edf_block];
4517 if scale == 0.0 {
4518 continue;
4519 }
4520 let start = offsets[edf_block];
4521 let end = offsets[edf_block + 1];
4522 g_kernel += &(rpr_matrices[edf_block].clone() * scale);
4523 j_blocks[edf_block] -= &(r.slice(s![start..end, start..end]).to_owned() * scale);
4524 for rho_block in 0..f_blocks {
4525 alpha[rho_block] += scale * trace_pairs[[edf_block, rho_block]];
4526 if rho_block == edf_block {
4527 alpha[rho_block] -= scale * t_values[edf_block];
4528 }
4529 }
4530 }
4531 }
4532
4533 if let Some((block, value)) = alpha
4534 .iter()
4535 .enumerate()
4536 .find(|(_, value)| !value.is_finite())
4537 {
4538 return Err(EstimationError::InvalidInput(format!(
4539 "rho adjoint seed for block {block} is non-finite: {value}"
4540 )));
4541 }
4542
4543 if alpha.iter().any(|value| *value != 0.0) {
4544 let mut outer_h = Array2::<f64>::zeros((f_blocks, f_blocks));
4545 for k in 0..f_blocks {
4546 for j in 0..f_blocks {
4547 let beta_pk_r_pj_beta = p_betas[k].dot(&m_vectors[j]);
4548 outer_h[[k, j]] = 0.5 * trace_pairs[[k, j]] + tau * beta_pk_r_pj_beta
4549 - if k == j {
4550 0.5 * (t_values[k] + tau * b_values[k])
4551 } else {
4552 0.0
4553 }
4554 - 0.5 * tau_q * b_values[k] * b_values[j];
4555 }
4556 }
4557 gam_linalg::matrix::symmetrize_in_place(&mut outer_h);
4561 if let Some(((row, col), value)) =
4562 outer_h.indexed_iter().find(|(_, value)| !value.is_finite())
4563 {
4564 return Err(EstimationError::InvalidInput(format!(
4565 "outer rho curvature entry ({row},{col}) is non-finite: {value}"
4566 )));
4567 }
4568 let rho_adj =
4569 gam_linalg::utils::solve_symmetric_vector_with_floor(&outer_h, &alpha, 1.0e-10)?;
4570 if let Some((block, value)) = rho_adj
4571 .iter()
4572 .enumerate()
4573 .find(|(_, value)| !value.is_finite())
4574 {
4575 return Err(EstimationError::InvalidInput(format!(
4576 "outer rho adjoint for block {block} is non-finite: {value}"
4577 )));
4578 }
4579 let weighted_b_sum = rho_adj
4580 .iter()
4581 .zip(b_values.iter())
4582 .map(|(&zk, &bk)| zk * bk)
4583 .sum::<f64>();
4584 q_kernel += 0.5 * tau_q * weighted_b_sum;
4585 for block in 0..f_blocks {
4586 let zk = rho_adj[block];
4587 if zk == 0.0 {
4588 continue;
4589 }
4590 g_kernel -= &(rpr_matrices[block].clone() * (0.5 * zk));
4591 let m = &m_vectors[block];
4592 for i in 0..p_total {
4593 h_kernel[i] += tau * zk * m[i];
4594 for j in 0..p_total {
4595 g_kernel[[i, j]] -= 0.5 * tau * zk * (beta[i] * m[j] + m[i] * beta[j]);
4596 }
4597 }
4598 let start = offsets[block];
4599 let end = offsets[block + 1];
4600 j_blocks[block] += &(r.slice(s![start..end, start..end]).to_owned() * (0.5 * zk));
4601 for i in 0..(end - start) {
4602 for j in 0..(end - start) {
4603 j_blocks[block][[i, j]] += 0.5 * tau * zk * beta[start + i] * beta[start + j];
4604 }
4605 }
4606 }
4607 }
4608
4609 for row in 0..n {
4610 for col in 0..p_total {
4611 grad_z[[row, col]] += -2.0 * q_kernel * weighted_residual[row] * beta[col];
4612 }
4613 }
4614 let zg = z.dot(&g_kernel);
4615 for row in 0..n {
4616 for col in 0..p_total {
4617 grad_z[[row, col]] += 2.0 * weights[row] * zg[[row, col]];
4618 }
4619 }
4620 let wy = y.to_owned() * &weights.to_owned();
4621 for row in 0..n {
4622 for col in 0..p_total {
4623 grad_z[[row, col]] += wy[row] * h_kernel[col];
4624 }
4625 }
4626
4627 let mut grad_y = Array2::<f64>::zeros((n, 1));
4628 let zh = z.dot(&h_kernel);
4629 for row in 0..n {
4630 grad_y[[row, 0]] = 2.0 * q_kernel * weighted_residual[row] + weights[row] * zh[row];
4631 }
4632
4633 let mut grad_weights = Array1::<f64>::zeros(n);
4634 for row in 0..n {
4635 let diag_zgz = (0..p_total)
4636 .map(|col| z[[row, col]] * zg[[row, col]])
4637 .sum::<f64>();
4638 grad_weights[row] = q_kernel * residual[row] * residual[row] + diag_zgz + y[row] * zh[row];
4639 }
4640
4641 if grad_reml_score != 0.0 {
4663 let q_kernel_score = 0.5 * grad_reml_score * tau;
4664 let zr = z.dot(&r);
4665 let n_pos = (0..n).filter(|&i| weights[i] > 0.0).count();
4666 if n_pos > 0 {
4667 let mut weighted_score_partial_sum = 0.0_f64;
4668 for row in 0..n {
4669 if weights[row] <= 0.0 {
4670 continue;
4671 }
4672 let z_r_z = (0..p_total).map(|col| z[[row, col]] * zr[[row, col]]).sum::<f64>();
4673 let a_score = q_kernel_score * residual[row] * residual[row]
4674 + 0.5 * grad_reml_score * z_r_z;
4675 weighted_score_partial_sum += weights[row] * a_score;
4676 }
4677 let projection = weighted_score_partial_sum / n_pos as f64;
4678 for row in 0..n {
4679 if weights[row] > 0.0 {
4680 grad_weights[row] -= projection / weights[row];
4681 }
4682 }
4683 }
4684 }
4685
4686 let mut grad_penalties = Vec::with_capacity(f_blocks);
4687 for block in 0..f_blocks {
4688 let start = offsets[block];
4689 let end = offsets[block + 1];
4690 let mut local = g_kernel.slice(s![start..end, start..end]).to_owned();
4691 for i in 0..(end - start) {
4692 for j in 0..(end - start) {
4693 local[[i, j]] += q_kernel * beta[start + i] * beta[start + j];
4694 }
4695 }
4696 local += &j_blocks[block];
4697 local *= lambdas[block];
4698 gam_linalg::matrix::symmetrize_in_place(&mut local);
4699 grad_penalties.push(local);
4700 }
4701
4702 let mut grad_designs = Vec::with_capacity(f_blocks);
4703 for block in 0..f_blocks {
4704 grad_designs.push(
4705 grad_z
4706 .slice(s![.., offsets[block]..offsets[block + 1]])
4707 .to_owned(),
4708 );
4709 }
4710
4711 Ok(GaussianRemlBlocksBackwardAnalytic {
4712 grad_designs,
4713 grad_penalties,
4714 grad_y,
4715 grad_weights,
4716 })
4717}
4718
4719pub struct DenseFisherGaussianFit {
4723 pub coefficients: Array2<f64>,
4724 pub fitted: Array2<f64>,
4725 pub sigma2: Array1<f64>,
4726 pub objective: f64,
4727}
4728
4729pub fn add_block_diagonal_penalty(
4732 hessian: &mut Array2<f64>,
4733 penalty: ArrayView2<'_, f64>,
4734 lambda: f64,
4735 n_outputs: usize,
4736) -> Result<(), EstimationError> {
4737 let k = penalty.ncols();
4738 if penalty.nrows() != k {
4739 return Err(EstimationError::InvalidInput(format!(
4740 "penalty must be square for dense Fisher fit; got {}x{}",
4741 penalty.nrows(),
4742 penalty.ncols()
4743 )));
4744 }
4745 if hessian.dim() != (k * n_outputs, k * n_outputs) {
4746 return Err(EstimationError::InvalidInput(
4747 "dense Fisher Hessian shape mismatch while adding penalty".to_string(),
4748 ));
4749 }
4750 for output in 0..n_outputs {
4751 let offset = output * k;
4752 for row in 0..k {
4753 for col in 0..k {
4754 let s_sym = 0.5 * (penalty[[row, col]] + penalty[[col, row]]);
4755 hessian[[offset + row, offset + col]] += lambda * s_sym;
4756 }
4757 }
4758 }
4759 Ok(())
4760}
4761
4762pub fn dense_fisher_gaussian_fit(
4769 design: ArrayView2<'_, f64>,
4770 y: ArrayView2<'_, f64>,
4771 penalty: ArrayView2<'_, f64>,
4772 row_weights: ArrayView1<'_, f64>,
4773 fisher_w: ArrayView3<'_, f64>,
4774 lambda: f64,
4775 latent_prior_score: f64,
4776) -> Result<DenseFisherGaussianFit, EstimationError> {
4777 let n_obs = design.nrows();
4778 let k = design.ncols();
4779 let n_outputs = y.ncols();
4780 let mut hessian = crate::pirls::dense_block_xtwx(design, fisher_w, Some(row_weights))?;
4781 add_block_diagonal_penalty(&mut hessian, penalty, lambda, n_outputs)?;
4782 let rhs = crate::pirls::dense_block_xtwy(design, fisher_w, y, Some(row_weights))?;
4783 let beta_vec =
4784 gam_linalg::utils::solve_dense_block_system(&hessian, &rhs, "dense Fisher Gaussian")
4785 .map_err(EstimationError::InvalidInput)?;
4786 let mut coefficients = Array2::<f64>::zeros((k, n_outputs));
4787 for output in 0..n_outputs {
4788 for col in 0..k {
4789 coefficients[[col, output]] = beta_vec[output * k + col];
4790 }
4791 }
4792 let fitted = design.dot(&coefficients);
4793 let mut sigma2 = Array1::<f64>::zeros(n_outputs);
4794 let mut objective = latent_prior_score;
4795 for row in 0..n_obs {
4796 for a in 0..n_outputs {
4797 let ra = y[[row, a]] - fitted[[row, a]];
4798 sigma2[a] += row_weights[row] * ra * ra;
4799 for b in 0..n_outputs {
4800 objective += 0.5
4801 * row_weights[row]
4802 * ra
4803 * fisher_w[[row, a, b]]
4804 * (y[[row, b]] - fitted[[row, b]]);
4805 }
4806 }
4807 }
4808 for output in 0..n_outputs {
4809 sigma2[output] /= (n_obs.saturating_sub(k).max(1)) as f64;
4810 let beta_col = coefficients.column(output);
4811 let s_beta = penalty.dot(&beta_col);
4812 objective += 0.5 * lambda * beta_col.dot(&s_beta);
4813 }
4814 Ok(DenseFisherGaussianFit {
4815 coefficients,
4816 fitted,
4817 sigma2,
4818 objective,
4819 })
4820}