1use crate::estimate::EstimationError;
2use gam_linalg::faer_ndarray::{
3 FaerCholesky, FaerEigh, fast_ab, fast_atb, fast_xt_diag_x, fast_xt_diag_y,
4};
5use faer::Side;
6use ndarray::{
7 Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, ArrayViewMut1, ArrayViewMut2, Axis,
8 s,
9};
10use rayon::prelude::*;
11use std::sync::Once;
12
/// One-shot latch guarding the near-singular backward-pass warning emitted by
/// `warn_ill_conditioned_backward_once`, so the message is logged only once.
static ILL_CONDITIONED_BACKWARD_WARNED: Once = Once::new();
21
22fn warn_ill_conditioned_backward_once(p: usize, d: usize, condition_number: f64) {
23 ILL_CONDITIONED_BACKWARD_WARNED.call_once(|| {
24 log::warn!(
25 "gaussian_reml_fit_backward: K = XᵀWX + λS is near-singular \
26 (p={p}, d={d}, cond≈{condition_number:.2e}); returning zero gradients \
27 for this fit (λ has saturated, atom is effectively unused). \
28 Further occurrences are silent."
29 );
30 });
31}
32
33fn zero_backward_result(n: usize, p: usize, d: usize) -> GaussianRemlBackwardResult {
34 GaussianRemlBackwardResult {
35 grad_x: Array2::<f64>::zeros((n, p)),
36 grad_y: Array2::<f64>::zeros((n, d)),
37 grad_penalty: Array2::<f64>::zeros((p, p)),
38 grad_weights: Array1::<f64>::zeros(n),
39 }
40}
41
// NOTE(review): RHO_LOWER / RHO_UPPER are not referenced in this part of the
// file; presumably they bound rho = ln(lambda) during optimization — confirm.
const RHO_LOWER: f64 = -30.0;
const RHO_UPPER: f64 = 30.0;
// Relative eigenvalue cutoff: `block_penalty_rank_logdet` treats eigenvalues
// at or below EIGEN_REL_TOL * max|eig| (floored at 1e-14) as zero.
const EIGEN_REL_TOL: f64 = 1.0e-10;
// NOTE(review): not referenced in this part of the file; presumably the
// gradient convergence tolerance of the rho optimizer — confirm.
const GRAD_TOL: f64 = 1.0e-12;
// Positive floor applied to deviance-like quantities before they are used as
// a divisor or passed to `ln`.
const MIN_DEVIANCE: f64 = 1.0e-300;
47
48fn canonicalize_penalty(penalty: ArrayView2<'_, f64>) -> Array2<f64> {
59 let p = penalty.nrows();
60 let mut out = penalty.to_owned();
61 for i in 0..p {
62 for j in (i + 1)..p {
63 let avg = 0.5 * (out[[i, j]] + out[[j, i]]);
64 out[[i, j]] = avg;
65 out[[j, i]] = avg;
66 }
67 }
68 out
69}
70
/// Cached spectral quantities reused across Gaussian REML evaluations for a
/// fixed design/penalty pair.
#[derive(Clone, Debug)]
pub struct GaussianRemlEigenCache {
    /// Penalty eigenvalues `δ`; the REML terms in this file use them through
    /// `t = λ·δ` and `1 + t`.
    pub penalty_eigenvalues: Array1<f64>,
    /// NOTE(review): presumably the eigenvectors paired with
    /// `penalty_eigenvalues`; not used in this part of the file — confirm.
    pub eigenvectors: Array2<f64>,
    /// NOTE(review): presumably the basis mapping the eigen-coordinates back to
    /// coefficient space; not used in this part of the file — confirm.
    pub coefficient_basis: Array2<f64>,
    /// NOTE(review): presumably a fingerprint of `XᵀWX` for cache validation —
    /// confirm against the cache builder.
    pub xtwx_fingerprint: u64,
    /// Fingerprint of the penalty matrix; the no-alloc entry point compares it
    /// with `matrix_fingerprint(penalty)` and rejects a mismatch.
    pub penalty_fingerprint: u64,
    /// Base value to which `Σ ln(1 + λ·δ)` is added to obtain the log-determinant
    /// of the penalized Hessian.
    pub logdet_xtwx: f64,
    /// λ-independent part of the penalty log-determinant; `penalty_rank · ρ` is
    /// added to it when the REML score is evaluated.
    pub logdet_penalty_positive: f64,
    /// Rank of the penalty used in the REML log-determinant terms.
    pub penalty_rank: usize,
    /// Penalty null-space dimension; fits require `n > nullity`, and the
    /// residual degrees of freedom are `n - nullity`.
    pub nullity: usize,
}
83
/// Optional state carried from a previous fit to seed the next one.
#[derive(Clone, Debug, Default)]
pub struct GaussianRemlWarmStart {
    /// Initial smoothing parameter `λ` (not `ln λ`) handed to the optimizer.
    pub lambda: Option<f64>,
    /// Eigen cache to reuse instead of rebuilding it.
    pub eigen_cache: Option<GaussianRemlEigenCache>,
}
89
90impl GaussianRemlWarmStart {
91 pub fn from_multi_result(result: &GaussianRemlMultiResult) -> Self {
92 Self {
93 lambda: Some(result.lambda),
94 eigen_cache: Some(result.cache.clone()),
95 }
96 }
97}
98
/// Single-response Gaussian REML fit; built from column 0 of the corresponding
/// multi-output result.
#[derive(Clone, Debug)]
pub struct GaussianRemlResult {
    /// Selected smoothing parameter, `exp(rho)`.
    pub lambda: f64,
    /// Selected log smoothing parameter.
    pub rho: f64,
    /// Coefficient vector at the selected `lambda`.
    pub coefficients: Array1<f64>,
    /// Fitted values `X · coefficients`.
    pub fitted: Array1<f64>,
    /// REML objective value at `rho`.
    pub reml_score: f64,
    /// First derivative of the REML objective with respect to `lambda`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the REML objective with respect to `lambda`.
    pub reml_hess_lambda: f64,
    /// First derivative of the REML objective with respect to `rho`.
    pub reml_grad_rho: f64,
    /// Second derivative of the REML objective with respect to `rho`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom reported by the objective evaluation.
    pub edf: f64,
    /// Scale (variance) estimate for the response.
    pub sigma2: f64,
    /// Eigen cache used for the fit; can be reused for warm starts.
    pub cache: GaussianRemlEigenCache,
}
114
/// Multi-output Gaussian REML fit with one smoothing parameter shared by all
/// outputs.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiResult {
    /// Selected smoothing parameter, `exp(rho)`.
    pub lambda: f64,
    /// Selected log smoothing parameter.
    pub rho: f64,
    /// Coefficient matrix at the selected `lambda`, one column per output.
    pub coefficients: Array2<f64>,
    /// Fitted values `X · coefficients`.
    pub fitted: Array2<f64>,
    /// REML objective value at `rho`.
    pub reml_score: f64,
    /// First derivative of the REML objective with respect to `lambda`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the REML objective with respect to `lambda`.
    pub reml_hess_lambda: f64,
    /// First derivative of the REML objective with respect to `rho`.
    pub reml_grad_rho: f64,
    /// Second derivative of the REML objective with respect to `rho`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom reported by the objective evaluation.
    pub edf: f64,
    /// Per-output scale (variance) estimates.
    pub sigma2: Array1<f64>,
    /// Eigen cache used for the fit; can be reused for warm starts.
    pub cache: GaussianRemlEigenCache,
}
130
/// REML score and gradients evaluated at caller-supplied (not profiled)
/// coefficients, as produced by `gaussian_reml_free_b_score`.
#[derive(Clone, Debug)]
pub struct GaussianRemlFreeBScore {
    /// REML objective value at the supplied coefficients and `log_lambda`.
    pub reml_score: f64,
    /// Gradient of the score with respect to the coefficients, shape `(p, d)`.
    pub grad_coefficients: Array2<f64>,
    /// Symmetrized gradient of the score with respect to the penalty matrix,
    /// shape `(p, p)`.
    pub grad_penalty: Array2<f64>,
    /// Gradient of the score with respect to `log_lambda`.
    pub grad_log_lambda: f64,
    /// Fitted values `X · coefficients`.
    pub fitted: Array2<f64>,
    /// Per-output scale estimate: penalized deviance divided by `n - nullity`.
    pub sigma2: Array1<f64>,
    /// Effective degrees of freedom, `Σ 1 / (1 + λ·δ)` over penalty eigenvalues.
    pub edf: f64,
}
141
/// Gradients returned by the backward pass, one entry per differentiable input.
#[derive(Clone, Debug)]
pub struct GaussianRemlBackwardResult {
    /// Gradient with respect to the design matrix, shape `(n, p)`.
    pub grad_x: Array2<f64>,
    /// Gradient with respect to the responses, shape `(n, d)`.
    pub grad_y: Array2<f64>,
    /// Gradient with respect to the penalty matrix, shape `(p, p)`.
    pub grad_penalty: Array2<f64>,
    /// Gradient with respect to the observation weights, length `n`.
    pub grad_weights: Array1<f64>,
}
149
/// Bundle of inputs for one backward-pass problem.
///
/// NOTE(review): the consumer of this struct is not in this part of the file;
/// the `grad_*` fields are presumably upstream gradients (cotangents) of the
/// corresponding forward outputs — confirm against the backward routine.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBackwardProblem<'a> {
    /// Design matrix of the forward fit.
    pub x: ArrayView2<'a, f64>,
    /// Response matrix of the forward fit.
    pub y: ArrayView2<'a, f64>,
    /// Optional observation weights of the forward fit.
    pub weights: Option<ArrayView1<'a, f64>>,
    /// Forward fit to differentiate through.
    pub fit: &'a GaussianRemlMultiResult,
    /// Presumably the upstream gradient for `fit.lambda` — confirm.
    pub grad_lambda: f64,
    /// Presumably the upstream gradient for `fit.coefficients`, if any — confirm.
    pub grad_coefficients: Option<ArrayView2<'a, f64>>,
    /// Presumably the upstream gradient for `fit.fitted`, if any — confirm.
    pub grad_fitted: Option<ArrayView2<'a, f64>>,
    /// Presumably the upstream gradient for `fit.reml_score` — confirm.
    pub grad_reml_score: f64,
    /// Presumably the upstream gradient for `fit.edf` — confirm.
    pub grad_edf: f64,
}
162
/// Scratch buffers reused by the no-alloc fitting path. With `p` coefficients
/// and `d` outputs, every matrix is `(p, d)` and `ywy` has length `d`.
#[derive(Clone, Debug)]
pub struct GaussianRemlNoAllocWorkspace {
    /// Buffer named for `XᵀWY`; filled by `fill_weighted_rhs_no_alloc`
    /// (NOTE(review): exact contents defined by that helper — confirm).
    pub xtwy: Array2<f64>,
    /// Per-output buffer passed to the REML evaluators as `ywy`.
    pub ywy: Array1<f64>,
    /// Right-hand side projected by `project_rhs_no_alloc`
    /// (NOTE(review): projection basis defined by that helper — confirm).
    pub projected_rhs: Array2<f64>,
    /// Buffer passed to the REML evaluators as `projected_rhs_squared`.
    pub projected_rhs_squared: Array2<f64>,
    /// NOTE(review): presumably `projected_rhs` scaled per eigenvalue when
    /// coefficients are formed — confirm.
    pub scaled_projected_rhs: Array2<f64>,
}
171
172impl GaussianRemlNoAllocWorkspace {
173 pub fn new(n_coefficients: usize, n_outputs: usize) -> Self {
174 Self {
175 xtwy: Array2::zeros((n_coefficients, n_outputs)),
176 ywy: Array1::zeros(n_outputs),
177 projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
178 projected_rhs_squared: Array2::zeros((n_coefficients, n_outputs)),
179 scaled_projected_rhs: Array2::zeros((n_coefficients, n_outputs)),
180 }
181 }
182
183 fn validate(&self, p: usize, d: usize) -> Result<(), EstimationError> {
184 if self.xtwy.dim() != (p, d)
185 || self.ywy.len() != d
186 || self.projected_rhs.dim() != (p, d)
187 || self.projected_rhs_squared.dim() != (p, d)
188 || self.scaled_projected_rhs.dim() != (p, d)
189 {
190 crate::bail_invalid_estim!(
191 "Gaussian REML no-alloc workspace shape mismatch: expected p={p}, d={d}"
192 );
193 }
194 Ok::<(), _>(())
195 }
196}
197
/// Scalar summary returned by the no-alloc fit; array outputs are written into
/// caller-provided buffers instead of being returned.
#[derive(Clone, Copy, Debug)]
pub struct GaussianRemlNoAllocFit {
    /// Selected smoothing parameter, `exp(rho)`.
    pub lambda: f64,
    /// Selected log smoothing parameter.
    pub rho: f64,
    /// REML objective value at `rho`.
    pub reml_score: f64,
    /// First derivative of the REML objective with respect to `lambda`.
    pub reml_grad_lambda: f64,
    /// Second derivative of the REML objective with respect to `lambda`.
    pub reml_hess_lambda: f64,
    /// First derivative of the REML objective with respect to `rho`.
    pub reml_grad_rho: f64,
    /// Second derivative of the REML objective with respect to `rho`.
    pub reml_hess_rho: f64,
    /// Effective degrees of freedom reported by the objective evaluation.
    pub edf: f64,
}
209
/// One problem of a batched fit; the penalty is shared across the batch and
/// supplied separately.
#[derive(Clone, Debug)]
pub struct GaussianRemlMultiBatchProblem<'a> {
    /// Design matrix for this problem.
    pub x: ArrayView2<'a, f64>,
    /// Response matrix for this problem.
    pub y: ArrayView2<'a, f64>,
    /// Optional observation weights; the batch driver uses unit weights when
    /// absent.
    pub weights: Option<ArrayView1<'a, f64>>,
    /// Optional initial `rho = ln(lambda)`; exponentiated before use.
    pub init_rho: Option<f64>,
}
217
/// Result of the block-orthogonal shared-scale fit; per-block quantities are
/// indexed in the order the blocks were supplied.
#[derive(Clone, Debug)]
pub struct GaussianRemlBlockOrthogonalResult {
    /// Coefficient matrix for each block.
    pub coefficients: Vec<Array2<f64>>,
    /// Sum over blocks of `design · coefficients`, shape `(n, d)`.
    pub fitted: Array2<f64>,
    /// Per-block smoothing parameters, `exp(log_lambdas)`.
    pub lambdas: Array1<f64>,
    /// Per-block log smoothing parameters.
    pub log_lambdas: Array1<f64>,
    /// REML objective value at the returned smoothing parameters.
    pub reml_score: f64,
    /// Per-block effective degrees of freedom.
    pub edf: Array1<f64>,
}
227
/// Pre-computed quantities for a single-penalty fit, produced by
/// `prepare_gaussian_reml` and consumed by the optimizer and evaluators.
///
/// NOTE(review): the constructor is not in this part of the file; field
/// descriptions below follow the field names and should be confirmed.
#[derive(Clone)]
struct GaussianRemlPrepared {
    /// Eigen cache for the design/penalty pair; moved into the final result.
    cache: GaussianRemlEigenCache,
    /// Presumably the per-output weighted response sum of squares — confirm.
    ywy: Array1<f64>,
    /// Presumably the element-wise square of `projected_rhs` — confirm.
    projected_rhs_squared: Array2<f64>,
    /// Presumably the weighted right-hand side in the cached eigenbasis — confirm.
    projected_rhs: Array2<f64>,
    /// Presumably the number of rows of the design — confirm.
    n_observations: usize,
    /// Presumably the number of response columns — confirm.
    n_outputs: usize,
}
237
/// Accumulated REML objective at one `rho`: value, first and second
/// derivative, plus the effective degrees of freedom.
#[derive(Clone, Copy)]
struct ObjectiveEval {
    cost: f64,
    grad: f64,
    hess: f64,
    edf: f64,
}

/// Value and first/second derivative of a single additive term of the
/// objective.
#[derive(Clone, Copy)]
struct TermDerivs {
    value: f64,
    grad: f64,
    hess: f64,
}

impl std::ops::AddAssign<TermDerivs> for ObjectiveEval {
    /// Adds a term's value and derivatives into the running objective; `edf`
    /// is deliberately left untouched.
    fn add_assign(&mut self, rhs: TermDerivs) {
        let TermDerivs { value, grad, hess } = rhs;
        self.cost += value;
        self.grad += grad;
        self.hess += hess;
    }
}
272
273fn gaussian_reml_logdet_term(
279 cache: &GaussianRemlEigenCache,
280 rho: f64,
281 n_outputs: f64,
282) -> (TermDerivs, f64) {
283 let lambda = rho.exp();
284 let mut logdet_h = cache.logdet_xtwx;
285 let mut trace_h = 0.0;
286 let mut trace_h_deriv = 0.0;
287 let mut edf = 0.0;
288 for &delta in &cache.penalty_eigenvalues {
289 let t = lambda * delta;
290 logdet_h += (1.0 + t).ln();
291 if delta > 0.0 {
292 trace_h += t / (1.0 + t);
293 trace_h_deriv += t / ((1.0 + t) * (1.0 + t));
294 }
295 edf += 1.0 / (1.0 + t);
296 }
297 let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * rho;
298 let term = TermDerivs {
299 value: 0.5 * n_outputs * (logdet_h - logdet_s),
300 grad: 0.5 * n_outputs * (trace_h - cache.penalty_rank as f64),
301 hess: 0.5 * n_outputs * trace_h_deriv,
302 };
303 (term, edf)
304}
305
306fn gaussian_reml_dispersion_term(
313 cache: &GaussianRemlEigenCache,
314 ywy: ArrayView1<'_, f64>,
315 projected_rhs_squared: ArrayView2<'_, f64>,
316 output: usize,
317 nu: f64,
318 lambda: f64,
319) -> TermDerivs {
320 let mut fitted_quadratic = 0.0;
321 let mut dp_grad = 0.0;
322 let mut dp_hess = 0.0;
323 for eig in 0..cache.penalty_eigenvalues.len() {
324 let c2 = projected_rhs_squared[[eig, output]];
325 let t = lambda * cache.penalty_eigenvalues[eig];
326 let denom = 1.0 + t;
327 fitted_quadratic += c2 / denom;
328 dp_grad += c2 * t / (denom * denom);
329 dp_hess += c2 * t * (1.0 - t) / (denom * denom * denom);
330 }
331 let dp = (ywy[output] - fitted_quadratic).max(MIN_DEVIANCE);
332 TermDerivs {
333 value: 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln()),
334 grad: 0.5 * nu * dp_grad / dp,
335 hess: 0.5 * nu * (dp_hess / dp - (dp_grad * dp_grad) / (dp * dp)),
336 }
337}
338
339pub fn gaussian_reml_closed_form(
340 x: ArrayView2<'_, f64>,
341 y: ArrayView1<'_, f64>,
342 penalty: ArrayView2<'_, f64>,
343 weights: Option<ArrayView1<'_, f64>>,
344 init_rho: Option<f64>,
345) -> Result<GaussianRemlResult, EstimationError> {
346 gaussian_reml_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
347}
348
349pub fn gaussian_reml_closed_form_with_nullspace_dim(
350 x: ArrayView2<'_, f64>,
351 y: ArrayView1<'_, f64>,
352 penalty: ArrayView2<'_, f64>,
353 nullspace_dim: Option<usize>,
354 weights: Option<ArrayView1<'_, f64>>,
355 init_rho: Option<f64>,
356) -> Result<GaussianRemlResult, EstimationError> {
357 let y2 = y.insert_axis(Axis(1));
358 let result = gaussian_reml_multi_closed_form_with_nullspace_dim(
359 x,
360 y2,
361 penalty,
362 nullspace_dim,
363 weights,
364 init_rho,
365 )?;
366 scalar_result_from_multi(result)
367}
368
369fn scalar_result_from_multi(
370 result: GaussianRemlMultiResult,
371) -> Result<GaussianRemlResult, EstimationError> {
372 Ok(GaussianRemlResult {
373 lambda: result.lambda,
374 rho: result.rho,
375 coefficients: result.coefficients.column(0).to_owned(),
376 fitted: result.fitted.column(0).to_owned(),
377 reml_score: result.reml_score,
378 reml_grad_lambda: result.reml_grad_lambda,
379 reml_hess_lambda: result.reml_hess_lambda,
380 reml_grad_rho: result.reml_grad_rho,
381 reml_hess_rho: result.reml_hess_rho,
382 edf: result.edf,
383 sigma2: result.sigma2[0],
384 cache: result.cache,
385 })
386}
387
388pub fn gaussian_reml_multi_closed_form(
389 x: ArrayView2<'_, f64>,
390 y: ArrayView2<'_, f64>,
391 penalty: ArrayView2<'_, f64>,
392 weights: Option<ArrayView1<'_, f64>>,
393 init_rho: Option<f64>,
394) -> Result<GaussianRemlMultiResult, EstimationError> {
395 gaussian_reml_multi_closed_form_with_nullspace_dim(x, y, penalty, None, weights, init_rho)
396}
397
398pub fn gaussian_reml_multi_closed_form_with_nullspace_dim(
399 x: ArrayView2<'_, f64>,
400 y: ArrayView2<'_, f64>,
401 penalty: ArrayView2<'_, f64>,
402 nullspace_dim: Option<usize>,
403 weights: Option<ArrayView1<'_, f64>>,
404 init_rho: Option<f64>,
405) -> Result<GaussianRemlMultiResult, EstimationError> {
406 let init_lambda = init_rho.map(f64::exp);
407 gaussian_reml_multi_closed_form_from_parts(
408 x,
409 y,
410 penalty,
411 nullspace_dim,
412 weights,
413 init_lambda,
414 None,
415 )
416}
417
418pub fn gaussian_reml_multi_closed_form_warm_started(
419 x: ArrayView2<'_, f64>,
420 y: ArrayView2<'_, f64>,
421 penalty: ArrayView2<'_, f64>,
422 weights: Option<ArrayView1<'_, f64>>,
423 warm_start: Option<&GaussianRemlWarmStart>,
424) -> Result<GaussianRemlMultiResult, EstimationError> {
425 gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
426 x, y, penalty, None, weights, warm_start,
427 )
428}
429
430pub fn gaussian_reml_multi_closed_form_warm_started_with_nullspace_dim(
431 x: ArrayView2<'_, f64>,
432 y: ArrayView2<'_, f64>,
433 penalty: ArrayView2<'_, f64>,
434 nullspace_dim: Option<usize>,
435 weights: Option<ArrayView1<'_, f64>>,
436 warm_start: Option<&GaussianRemlWarmStart>,
437) -> Result<GaussianRemlMultiResult, EstimationError> {
438 let init_lambda = warm_start.and_then(|start| start.lambda);
439 let eigen_cache = warm_start.and_then(|start| start.eigen_cache.as_ref());
440 gaussian_reml_multi_closed_form_from_parts(
441 x,
442 y,
443 penalty,
444 nullspace_dim,
445 weights,
446 init_lambda,
447 eigen_cache,
448 )
449}
450
451pub fn gaussian_reml_multi_closed_form_with_cache(
452 x: ArrayView2<'_, f64>,
453 y: ArrayView2<'_, f64>,
454 penalty: ArrayView2<'_, f64>,
455 weights: Option<ArrayView1<'_, f64>>,
456 init_lambda: Option<f64>,
457 eigen_cache: Option<&GaussianRemlEigenCache>,
458) -> Result<GaussianRemlMultiResult, EstimationError> {
459 gaussian_reml_multi_closed_form_from_parts(
460 x,
461 y,
462 penalty,
463 None,
464 weights,
465 init_lambda,
466 eigen_cache,
467 )
468}
469
/// Fits a multi-output Gaussian REML model using a caller-supplied eigen cache
/// and workspace, writing array results into caller-provided buffers.
///
/// # Arguments
/// * `x`, `y`, `penalty`, `weights` — design, responses, penalty and optional
///   observation weights.
/// * `init_lambda` — optional starting smoothing parameter (validated, then
///   converted to `ln(lambda)`).
/// * `eigen_cache` — cache whose `penalty_fingerprint` must match `penalty`.
/// * `workspace` — scratch buffers sized for `(p, d)`.
/// * `coefficients` `(p, d)`, `fitted` `(n, d)`, `sigma2` `(d)` — outputs.
///
/// # Errors
/// Returns an invalid-input error on shape mismatches, non-finite `y`,
/// `n <= nullity`, a penalty fingerprint mismatch, or an invalid `init_lambda`;
/// errors from the helpers are propagated.
///
/// NOTE(review): `canonicalize_penalty` copies the penalty, so this path still
/// allocates one `p × p` matrix despite the "no_alloc" name.
pub fn gaussian_reml_multi_closed_form_with_cache_no_alloc(
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    penalty: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    init_lambda: Option<f64>,
    eigen_cache: &GaussianRemlEigenCache,
    workspace: &mut GaussianRemlNoAllocWorkspace,
    mut coefficients: ArrayViewMut2<'_, f64>,
    mut fitted: ArrayViewMut2<'_, f64>,
    mut sigma2: ArrayViewMut1<'_, f64>,
) -> Result<GaussianRemlNoAllocFit, EstimationError> {
    // Symmetrize the penalty before validation and fingerprinting.
    let penalty_owned = canonicalize_penalty(penalty);
    let penalty = penalty_owned.view();
    let n = x.nrows();
    let p = x.ncols();
    let d = y.ncols();
    validate_gaussian_reml_design(x, penalty, weights)?;
    validate_gaussian_reml_eigen_cache(eigen_cache, p)?;
    if y.nrows() != n {
        crate::bail_invalid_estim!(
            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
            y.nrows()
        );
    }
    if y.iter().any(|value| !value.is_finite()) {
        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
    }
    // The residual degrees of freedom n - nullity must be positive.
    if n <= eigen_cache.nullity {
        crate::bail_invalid_estim!(
            "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
            eigen_cache.nullity
        );
    }
    // Reject a cache that was built for a different penalty matrix.
    let penalty_fingerprint = matrix_fingerprint(penalty);
    if eigen_cache.penalty_fingerprint != penalty_fingerprint {
        crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
    }
    workspace.validate(p, d)?;
    if coefficients.dim() != (p, d) || fitted.dim() != (n, d) || sigma2.len() != d {
        crate::bail_invalid_estim!(
            "Gaussian REML no-alloc output shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
        );
    }
    if let Some(lambda) = init_lambda {
        validate_initial_lambda(lambda)?;
    }

    // Fill the workspace buffers that the optimizer and evaluators read.
    fill_weighted_rhs_no_alloc(x, y, weights, workspace)?;
    project_rhs_no_alloc(eigen_cache, workspace);

    // The optimizer works on rho = ln(lambda).
    let init_rho = init_lambda.map(f64::ln);
    let rho = optimize_rho_no_alloc(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n,
        d,
        init_rho,
    )?;
    // Re-evaluate the objective at the optimum to report score and derivatives.
    let eval = evaluate_reml_parts(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n,
        d,
        rho,
    );
    let lambda = rho.exp();
    // Write coefficients, fitted values and scale estimates into the outputs.
    fill_coefficients_no_alloc(eigen_cache, workspace, lambda, coefficients.view_mut());
    fill_fitted_no_alloc(x, coefficients.view(), fitted.view_mut());
    fill_sigma2_no_alloc(
        eigen_cache,
        workspace.ywy.view(),
        workspace.projected_rhs_squared.view(),
        n,
        d,
        lambda,
        sigma2.view_mut(),
    );
    // Convert the rho-scale derivatives to the lambda scale for reporting.
    let (reml_grad_lambda, reml_hess_lambda) =
        rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
    Ok(GaussianRemlNoAllocFit {
        lambda,
        rho,
        reml_score: eval.cost,
        reml_grad_lambda,
        reml_hess_lambda,
        reml_grad_rho: eval.grad,
        reml_hess_rho: eval.hess,
        edf: eval.edf,
    })
}
566
567
568pub fn gaussian_reml_multi_closed_form_batch<'a>(
569 problems: &[GaussianRemlMultiBatchProblem<'a>],
570 penalty: ArrayView2<'a, f64>,
571 nullspace_dim: Option<usize>,
572) -> Result<Vec<GaussianRemlMultiResult>, EstimationError> {
573 if problems.is_empty() {
574 return Ok(Vec::new());
575 }
576 let xtwx_per_problem: Vec<Array2<f64>> = problems
580 .par_iter()
581 .map(|problem| {
582 let weight = match problem.weights.as_ref() {
583 Some(w) => w.to_owned(),
584 None => Array1::ones(problem.x.nrows()),
585 };
586 dense_xt_diag_x(problem.x.view(), weight.view())
587 })
588 .collect();
589 let caches =
593 build_gaussian_reml_eigen_cache_batched(xtwx_per_problem, penalty.view(), nullspace_dim);
594 let fits: Vec<Result<GaussianRemlMultiResult, EstimationError>> = problems
598 .par_iter()
599 .zip(caches.into_par_iter())
600 .map(|(problem, cache_result)| {
601 let init_lambda = problem.init_rho.map(f64::exp);
602 let cache = cache_result?;
603 gaussian_reml_multi_closed_form_from_parts(
604 problem.x.view(),
605 problem.y.view(),
606 penalty.view(),
607 nullspace_dim,
608 problem.weights.as_ref().map(|weights| weights.view()),
609 init_lambda,
610 Some(&cache),
611 )
612 })
613 .collect();
614 fits.into_iter().collect()
615}
616
/// Quantities computed for one block at one `rho` by `block_orthogonal_eval`,
/// with `H = gram + λ·S`.
struct BlockOrthogonalEval {
    /// Coefficients `H⁻¹·rhs`.
    beta: Array2<f64>,
    /// `ln det H`, from the Cholesky diagonal.
    logdet: f64,
    /// `tr(H⁻¹·λS)`.
    trace: f64,
    /// Trace of the product of `H⁻¹·λS` with itself.
    trace_pair: f64,
    /// Per-output `Σ rhs ∘ beta` (column sums).
    fitted_energy: Array1<f64>,
    /// Per-output `betaᵀ·λS·beta`.
    penalty_energy: Array1<f64>,
    /// Per-output `(λS·beta)ᵀ·H⁻¹·(λS·beta)`.
    curvature_energy: Array1<f64>,
    /// Effective degrees of freedom, `p − trace`.
    edf: f64,
}
627
628fn block_penalty_rank_logdet(
629 penalty: ArrayView2<'_, f64>,
630) -> Result<(usize, f64), EstimationError> {
631 let eigs = penalty
632 .to_owned()
633 .eigh(Side::Lower)
634 .map_err(|_| EstimationError::ModelIsIllConditioned {
635 condition_number: f64::INFINITY,
636 })?
637 .0;
638 let max_abs = eigs.iter().fold(0.0_f64, |m, &v| m.max(v.abs()));
639 let tol = (EIGEN_REL_TOL * max_abs).max(1.0e-14);
640 let mut rank = 0_usize;
641 let mut logdet = 0.0;
642 for eig in eigs.iter().copied() {
643 if eig > tol {
644 rank += 1;
645 logdet += eig.ln();
646 }
647 }
648 Ok((rank, logdet))
649}
650
651fn block_orthogonal_eval(
652 gram: &Array2<f64>,
653 rhs: &Array2<f64>,
654 penalty: &Array2<f64>,
655 rho: f64,
656) -> Result<BlockOrthogonalEval, EstimationError> {
657 let lambda = rho.exp();
658 validate_initial_lambda(lambda)?;
659 let scaled_penalty = penalty * lambda;
660 let hessian = canonicalize_penalty((gram + &scaled_penalty).view());
661 let chol = gaussian_reml_cholesky_lower(hessian)?;
662 let beta = solve_spd_from_lower_factor(&chol, rhs)?;
663 let solved_penalty = solve_spd_from_lower_factor(&chol, &scaled_penalty)?;
664 let logdet = 2.0 * chol.diag().iter().map(|value| value.ln()).sum::<f64>();
665 let trace = (0..solved_penalty.nrows())
666 .map(|i| solved_penalty[[i, i]])
667 .sum::<f64>();
668 let trace_pair =
669 gam_linalg::utils::trace_of_product(solved_penalty.view(), solved_penalty.view());
670 let fitted_energy = (rhs * &beta).sum_axis(Axis(0));
671 let p_beta = scaled_penalty.dot(&beta);
672 let penalty_energy = (&beta * &p_beta).sum_axis(Axis(0));
673 let solved_p_beta = solve_spd_from_lower_factor(&chol, &p_beta)?;
674 let curvature_energy = (&p_beta * &solved_p_beta).sum_axis(Axis(0));
675 Ok(BlockOrthogonalEval {
676 beta,
677 logdet,
678 trace,
679 trace_pair,
680 fitted_energy,
681 penalty_energy,
682 curvature_energy,
683 edf: penalty.nrows() as f64 - trace,
684 })
685}
686
/// Value and first/second derivative in `rho` of the per-block objective
/// computed by `block_orthogonal_scale_objective`.
struct BlockOrthogonalScaleDerivs {
    value: f64,
    grad: f64,
    hess: f64,
}
702
703fn block_orthogonal_scale_objective(
704 eval: &BlockOrthogonalEval,
705 rho: f64,
706 scale_precision: ArrayView1<'_, f64>,
707 rank: usize,
708) -> BlockOrthogonalScaleDerivs {
709 let d = scale_precision.len() as f64;
710 let fit_term = scale_precision
711 .iter()
712 .zip(eval.fitted_energy.iter())
713 .map(|(scale, energy)| scale * energy)
714 .sum::<f64>();
715 let value = 0.5 * d * eval.logdet - 0.5 * fit_term - 0.5 * d * (rank as f64) * rho;
717 let grad = 0.5 * d * (eval.trace - rank as f64)
721 + 0.5
722 * scale_precision
723 .iter()
724 .zip(eval.penalty_energy.iter())
725 .map(|(scale, energy)| scale * energy)
726 .sum::<f64>();
727 let hess = 0.5 * d * (eval.trace - eval.trace_pair)
730 + 0.5
731 * scale_precision
732 .iter()
733 .zip(eval.penalty_energy.iter().zip(eval.curvature_energy.iter()))
734 .map(|(scale, (energy, curvature))| scale * (energy - 2.0 * curvature))
735 .sum::<f64>();
736 BlockOrthogonalScaleDerivs { value, grad, hess }
737}
738
739fn solve_block_orthogonal_rho(
740 gram: &Array2<f64>,
741 rhs: &Array2<f64>,
742 penalty: &Array2<f64>,
743 rho0: f64,
744 scale_precision: ArrayView1<'_, f64>,
745 rank: usize,
746 max_iter: usize,
747) -> Result<(f64, BlockOrthogonalEval), EstimationError> {
748 let mut rho = rho0;
749 let mut current = block_orthogonal_eval(gram, rhs, penalty, rho)?;
750 for _ in 0..max_iter {
751 let derivs = block_orthogonal_scale_objective(¤t, rho, scale_precision, rank);
754 let grad = derivs.grad;
755 let hess = derivs.hess;
756 if !(grad.is_finite() && hess.is_finite()) {
757 return Err(EstimationError::ModelIsIllConditioned {
758 condition_number: f64::INFINITY,
759 });
760 }
761 let descent = grad.signum();
772 let step = if hess > 1.0e-10 { grad / hess } else { descent };
773 let mut best_rho = rho;
774 let mut best_eval = current;
775 let mut best_phi =
776 block_orthogonal_scale_objective(&best_eval, best_rho, scale_precision, rank).value;
777 for candidate_rho in [
778 rho - step,
779 rho - 0.5 * step,
780 rho - 0.25 * step,
781 rho - descent,
782 rho - 0.25 * descent,
783 ] {
784 let Ok(candidate_eval) = block_orthogonal_eval(gram, rhs, penalty, candidate_rho)
789 else {
790 continue;
791 };
792 let candidate_phi = block_orthogonal_scale_objective(
793 &candidate_eval,
794 candidate_rho,
795 scale_precision,
796 rank,
797 )
798 .value;
799 if candidate_phi < best_phi {
800 best_rho = candidate_rho;
801 best_eval = candidate_eval;
802 best_phi = candidate_phi;
803 }
804 }
805 let delta = (best_rho - rho).abs();
806 rho = best_rho;
807 current = best_eval;
808 if delta < 1.0e-12 || step.abs() < 1.0e-7 {
809 break;
810 }
811 }
812 Ok((rho, current))
813}
814
/// Fits several penalized blocks against a shared response with one smoothing
/// parameter per block and one profiled scale per output column.
///
/// The solver alternates (at most 40 sweeps) between optimizing each block's
/// `rho` for fixed scale precisions and refreshing the scale precisions from
/// the unexplained weighted sum of squares, stopping once the largest change
/// in log precision is below `1e-7`. A final pass re-solves every block at the
/// converged precisions before results are assembled.
///
/// # Arguments
/// * `designs`, `penalties` — one design and one penalty per block.
/// * `y` — responses, `(n, d)`; must be finite and have `d > 0`.
/// * `weights` — optional observation weights.
/// * `init_rhos` — optional finite starting `rho` per block (default `0`).
///
/// # Errors
/// Invalid-input errors for empty/mismatched blocks, bad shapes, non-finite
/// inputs, or `n <= total penalty nullity`; `ModelIsIllConditioned` if the
/// unexplained sum of squares becomes non-finite or non-positive; helper
/// errors are propagated.
pub fn gaussian_reml_blocks_orthogonal_shared_scale(
    designs: &[Array2<f64>],
    penalties: &[Array2<f64>],
    y: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    init_rhos: Option<&[f64]>,
) -> Result<GaussianRemlBlockOrthogonalResult, EstimationError> {
    if designs.is_empty() {
        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one block");
    }
    if designs.len() != penalties.len() {
        crate::bail_invalid_estim!(
            "block-orthogonal Gaussian REML block mismatch: {} designs, {} penalties",
            designs.len(),
            penalties.len()
        );
    }
    let n = y.nrows();
    let d = y.ncols();
    if d == 0 {
        crate::bail_invalid_estim!("block-orthogonal Gaussian REML requires at least one output");
    }
    if y.iter().any(|value| !value.is_finite()) {
        crate::bail_invalid_estim!("block-orthogonal Gaussian REML response must be finite");
    }
    let weight = gaussian_reml_weights(n, weights)?;
    if let Some(rhos) = init_rhos {
        if rhos.len() != designs.len() {
            crate::bail_invalid_estim!(
                "block-orthogonal Gaussian REML init_rhos length mismatch: expected {}, got {}",
                designs.len(),
                rhos.len()
            );
        }
        if rhos.iter().any(|value| !value.is_finite()) {
            crate::bail_invalid_estim!("block-orthogonal Gaussian REML init_rhos must be finite");
        }
    }

    // Weighted total sum of squares per output: ywy[j] = Σ_i w_i · y_ij².
    let mut ywy = Array1::<f64>::zeros(d);
    for row in 0..n {
        for output in 0..d {
            ywy[output] += weight[row] * y[[row, output]] * y[[row, output]];
        }
    }
    // Per-block precomputation: Gram matrix, right-hand side, symmetrized
    // penalty, penalty rank and positive log-determinant.
    let mut grams = Vec::with_capacity(designs.len());
    let mut rhs_blocks = Vec::with_capacity(designs.len());
    let mut penalties_owned = Vec::with_capacity(penalties.len());
    let mut ranks = Vec::with_capacity(penalties.len());
    let mut penalty_logdets = Vec::with_capacity(penalties.len());
    let mut nullity_total = 0_usize;
    for (block, (design, penalty)) in designs.iter().zip(penalties.iter()).enumerate() {
        let penalty_owned = canonicalize_penalty(penalty.view());
        validate_gaussian_reml_design(design.view(), penalty_owned.view(), Some(weight.view()))?;
        if design.nrows() != n {
            crate::bail_invalid_estim!(
                "block-orthogonal Gaussian REML designs[{block}] has {} rows, expected {n}",
                design.nrows()
            );
        }
        let gram = dense_xt_diag_x(design.view(), weight.view());
        let rhs = dense_xt_diag_y(design.view(), weight.view(), y);
        let (rank, logdet) = block_penalty_rank_logdet(penalty_owned.view())?;
        nullity_total += penalty_owned.nrows().saturating_sub(rank);
        grams.push(canonicalize_penalty(gram.view()));
        rhs_blocks.push(rhs);
        penalties_owned.push(penalty_owned);
        ranks.push(rank);
        penalty_logdets.push(logdet);
    }
    if n <= nullity_total {
        crate::bail_invalid_estim!(
            "block-orthogonal Gaussian REML requires n > total penalty nullity; got n={n}, nullity={nullity_total}"
        );
    }
    // Residual degrees of freedom shared by every output.
    let nu = (n - nullity_total) as f64;
    let mut rhos = match init_rhos {
        Some(values) => Array1::from_vec(values.to_vec()),
        None => Array1::zeros(designs.len()),
    };
    // Initial scale precision assumes nothing is explained yet: nu / ywy.
    let mut scale_precision = ywy.mapv(|value| nu / value.max(MIN_DEVIANCE));
    let mut evals = Vec::new();
    // Alternate between per-block rho solves and scale-precision updates.
    for _ in 0..40 {
        evals.clear();
        for block in 0..designs.len() {
            let (rho, eval) = solve_block_orthogonal_rho(
                &grams[block],
                &rhs_blocks[block],
                &penalties_owned[block],
                rhos[block],
                scale_precision.view(),
                ranks[block],
                32,
            )?;
            rhos[block] = rho;
            evals.push(eval);
        }
        // Unexplained sum of squares: total minus the blocks' fitted energy.
        let mut explained = Array1::<f64>::zeros(d);
        for eval in evals.iter() {
            explained += &eval.fitted_energy;
        }
        let q = &ywy - &explained;
        if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
            return Err(EstimationError::ModelIsIllConditioned {
                condition_number: f64::INFINITY,
            });
        }
        let next_scale = q.mapv(|value| nu / value);
        // Convergence is measured as the largest absolute change in ln(precision).
        let scale_step = next_scale
            .iter()
            .zip(scale_precision.iter())
            .map(|(next, old)| (next.ln() - old.ln()).abs())
            .fold(0.0_f64, f64::max);
        scale_precision = next_scale;
        if scale_step < 1.0e-7 {
            break;
        }
    }
    // Final pass so rhos and evaluations match the last scale precisions.
    evals.clear();
    for block in 0..designs.len() {
        let (rho, eval) = solve_block_orthogonal_rho(
            &grams[block],
            &rhs_blocks[block],
            &penalties_owned[block],
            rhos[block],
            scale_precision.view(),
            ranks[block],
            16,
        )?;
        rhos[block] = rho;
        evals.push(eval);
    }

    let coefficients = evals
        .iter()
        .map(|eval| eval.beta.clone())
        .collect::<Vec<_>>();
    // Fitted values are the sum of every block's contribution.
    let mut fitted = Array2::<f64>::zeros((n, d));
    for (design, coef) in designs.iter().zip(coefficients.iter()) {
        fitted += &fast_ab(&design.view(), &coef.view());
    }
    let mut explained = Array1::<f64>::zeros(d);
    for eval in evals.iter() {
        explained += &eval.fitted_energy;
    }
    let q = &ywy - &explained;
    if q.iter().any(|value| !value.is_finite() || *value <= 0.0) {
        return Err(EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        });
    }
    let lambdas = rhos.mapv(f64::exp);
    let edf = Array1::from_iter(evals.iter().map(|eval| eval.edf));
    // Σ over blocks of ln det H − ln det⁺ S − rank·rho.
    let logdet_term = evals
        .iter()
        .enumerate()
        .map(|(block, eval)| {
            eval.logdet - penalty_logdets[block] - (ranks[block] as f64) * rhos[block]
        })
        .sum::<f64>();
    // Σ over outputs of nu·(1 + ln(2π·q/nu)).
    let scale_term = q
        .iter()
        .map(|value| nu * (1.0 + (2.0 * std::f64::consts::PI * value / nu).ln()))
        .sum::<f64>();
    Ok(GaussianRemlBlockOrthogonalResult {
        coefficients,
        fitted,
        lambdas,
        log_lambdas: rhos,
        reml_score: 0.5 * (d as f64) * logdet_term + 0.5 * scale_term,
        edf,
    })
}
988
989fn gaussian_reml_multi_closed_form_from_parts(
990 x: ArrayView2<'_, f64>,
991 y: ArrayView2<'_, f64>,
992 penalty: ArrayView2<'_, f64>,
993 nullspace_dim: Option<usize>,
994 weights: Option<ArrayView1<'_, f64>>,
995 init_lambda: Option<f64>,
996 eigen_cache: Option<&GaussianRemlEigenCache>,
997) -> Result<GaussianRemlMultiResult, EstimationError> {
998 let prepared = prepare_gaussian_reml(x, y, penalty, nullspace_dim, weights, eigen_cache)?;
999 let init_rho = init_lambda
1000 .map(validate_initial_lambda)
1001 .transpose()?
1002 .map(f64::ln);
1003 let rho = optimize_rho(&prepared, init_rho)?;
1004 let eval = prepared.evaluate(rho);
1005 let lambda = rho.exp();
1006 let coefficients = prepared.coefficients(lambda);
1007 let fitted = dense_ab(x, coefficients.view());
1008 let sigma2 = prepared.sigma2(lambda);
1009 let (reml_grad_lambda, reml_hess_lambda) =
1010 rho_derivatives_to_lambda(lambda, eval.grad, eval.hess);
1011 Ok(GaussianRemlMultiResult {
1012 lambda,
1013 rho,
1014 coefficients,
1015 fitted,
1016 reml_score: eval.cost,
1017 reml_grad_lambda,
1018 reml_hess_lambda,
1019 reml_grad_rho: eval.grad,
1020 reml_hess_rho: eval.hess,
1021 edf: eval.edf,
1022 sigma2,
1023 cache: prepared.cache,
1024 })
1025}
1026
/// Evaluates the Gaussian REML score at caller-supplied coefficients and
/// `log_lambda`, together with its gradients with respect to the
/// coefficients, the penalty matrix and `log_lambda`.
///
/// Unlike the closed-form fit, the coefficients are taken as given rather
/// than solved for, so the deviance uses the actual weighted residuals plus
/// `lambda · βᵀSβ`.
///
/// # Errors
/// Invalid-input errors for a non-finite `log_lambda`, shape mismatches,
/// non-finite `y`/`coefficients`, or `n <= nullity`;
/// `LinearSystemSolveFailed` if the penalized Hessian cannot be
/// Cholesky-factored; helper errors are propagated.
pub fn gaussian_reml_free_b_score(
    x: ArrayView2<'_, f64>,
    y: ArrayView2<'_, f64>,
    coefficients: ArrayView2<'_, f64>,
    log_lambda: f64,
    penalty: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<GaussianRemlFreeBScore, EstimationError> {
    if !log_lambda.is_finite() {
        crate::bail_invalid_estim!("Gaussian REML log_lambda must be finite; got {log_lambda}");
    }
    let lambda = log_lambda.exp();
    // Symmetrize the penalty before validation and use.
    let penalty_owned = canonicalize_penalty(penalty);
    let penalty = penalty_owned.view();
    let n = x.nrows();
    let p = x.ncols();
    let d = y.ncols();
    validate_gaussian_reml_design(x, penalty, weights)?;
    if y.nrows() != n {
        crate::bail_invalid_estim!(
            "Gaussian REML row mismatch: X has {n} rows but Y has {}",
            y.nrows()
        );
    }
    if coefficients.dim() != (p, d) {
        crate::bail_invalid_estim!(
            "Gaussian REML coefficient shape mismatch: expected {p}x{d}, got {}x{}",
            coefficients.nrows(),
            coefficients.ncols()
        );
    }
    if y.iter().chain(coefficients.iter()).any(|v| !v.is_finite()) {
        crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
    }

    let weight = gaussian_reml_weights(n, weights)?;
    let cache =
        build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, Some(weight.view()))?;
    if n <= cache.nullity {
        crate::bail_invalid_estim!(
            "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
            cache.nullity
        );
    }
    // Residual degrees of freedom.
    let nu = n as f64 - cache.nullity as f64;
    let fitted = dense_ab(x, coefficients);
    let residual = y.to_owned() - &fitted;
    // XᵀW(y − Xβ) and Sβ, both (p, d); reused in the coefficient gradient.
    let xtw_residual = dense_xt_diag_y(x, weight.view(), residual.view());
    let s_beta = dense_ab(penalty, coefficients);

    // Log-determinant pieces from the cached eigenvalues, with t = λ·δ.
    let mut logdet_h = cache.logdet_xtwx;
    let mut trace_h = 0.0;
    let mut edf = 0.0;
    for &delta in &cache.penalty_eigenvalues {
        let t = lambda * delta;
        logdet_h += (1.0 + t).ln();
        if delta > 0.0 {
            trace_h += t / (1.0 + t);
        }
        edf += 1.0 / (1.0 + t);
    }
    let logdet_s = cache.logdet_penalty_positive + (cache.penalty_rank as f64) * log_lambda;
    let mut reml_score = 0.5 * (d as f64) * (logdet_h - logdet_s);
    let mut grad_log_lambda = 0.5 * (d as f64) * (trace_h - cache.penalty_rank as f64);
    let mut grad_coefficients = Array2::<f64>::zeros((p, d));
    // Explicit inverse of H = XᵀWX + λS, obtained by solving against I.
    let inverse_hessian = {
        let xtwx = dense_xt_diag_x(x, weight.view());
        let mut hessian = xtwx;
        hessian += &(penalty.to_owned() * lambda);
        hessian
            .cholesky(Side::Lower)
            .map_err(EstimationError::LinearSystemSolveFailed)?
            .solve_mat(&Array2::<f64>::eye(p))
    };
    let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(&cache);
    // Log-determinant contribution to the penalty gradient:
    // 0.5·d·(λ·H⁻¹ − S⁺), written with transposed indices.
    let mut grad_penalty = Array2::<f64>::zeros((p, p));
    for row in 0..p {
        for col in 0..p {
            grad_penalty[[row, col]] += 0.5
                * (d as f64)
                * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
        }
    }
    let mut sigma2 = Array1::<f64>::zeros(d);

    // Per-output dispersion contribution to the score and all gradients.
    for output in 0..d {
        let mut weighted_rss = 0.0;
        for row in 0..n {
            let r = residual[[row, output]];
            weighted_rss += weight[row] * r * r;
        }
        let beta_col = coefficients.column(output);
        let s_beta_col = s_beta.column(output);
        let penalty_quadratic = beta_col.dot(&s_beta_col);
        // Penalized deviance, floored so the log and division stay finite.
        let dp = (weighted_rss + lambda * penalty_quadratic).max(MIN_DEVIANCE);
        sigma2[output] = dp / nu;
        reml_score += 0.5 * nu * (1.0 + (2.0 * std::f64::consts::PI * dp / nu).ln());
        grad_log_lambda += 0.5 * nu * lambda * penalty_quadratic / dp;
        let scale = nu / dp;
        for coeff in 0..p {
            grad_coefficients[[coeff, output]] =
                scale * (-xtw_residual[[coeff, output]] + lambda * s_beta[[coeff, output]]);
        }
        add_rank_one_penalty_vjp(0.5 * scale * lambda, beta_col, &mut grad_penalty);
    }
    // Report a symmetric penalty gradient by averaging mirrored entries.
    for i in 0..p {
        for j in (i + 1)..p {
            let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
            grad_penalty[[i, j]] = avg;
            grad_penalty[[j, i]] = avg;
        }
    }

    Ok(GaussianRemlFreeBScore {
        reml_score,
        grad_coefficients,
        grad_penalty,
        grad_log_lambda,
        fitted,
        sigma2,
        edf,
    })
}
1150
1151pub fn gaussian_reml_multi_closed_form_backward(
1152 x: ArrayView2<'_, f64>,
1153 y: ArrayView2<'_, f64>,
1154 penalty: ArrayView2<'_, f64>,
1155 weights: Option<ArrayView1<'_, f64>>,
1156 init_lambda: Option<f64>,
1157 upstream_lambda: f64,
1158 upstream_coefficients: Option<ArrayView2<'_, f64>>,
1159 upstream_fitted: Option<ArrayView2<'_, f64>>,
1160 upstream_reml_score: f64,
1161 upstream_edf: f64,
1162) -> Result<GaussianRemlBackwardResult, EstimationError> {
1163 let fit =
1164 gaussian_reml_multi_closed_form_with_cache(x, y, penalty, weights, init_lambda, None)?;
1165 gaussian_reml_multi_closed_form_backward_from_fit(
1166 x,
1167 y,
1168 penalty,
1169 weights,
1170 &fit,
1171 upstream_lambda,
1172 upstream_coefficients,
1173 upstream_fitted,
1174 upstream_reml_score,
1175 upstream_edf,
1176 )
1177}
1178
1179pub fn gaussian_reml_multi_closed_form_backward_from_fit(
1180 x: ArrayView2<'_, f64>,
1181 y: ArrayView2<'_, f64>,
1182 penalty: ArrayView2<'_, f64>,
1183 weights: Option<ArrayView1<'_, f64>>,
1184 fit: &GaussianRemlMultiResult,
1185 upstream_lambda: f64,
1186 upstream_coefficients: Option<ArrayView2<'_, f64>>,
1187 upstream_fitted: Option<ArrayView2<'_, f64>>,
1188 upstream_reml_score: f64,
1189 upstream_edf: f64,
1190) -> Result<GaussianRemlBackwardResult, EstimationError> {
1191 validate_gaussian_reml_backward_upstreams(
1192 x,
1193 y,
1194 penalty,
1195 upstream_lambda,
1196 upstream_coefficients,
1197 upstream_fitted,
1198 upstream_reml_score,
1199 upstream_edf,
1200 )?;
1201 validate_gaussian_reml_forward_fit(x, y, penalty, weights, fit)?;
1202 let lambda = fit.lambda;
1203 let n = x.nrows();
1204 let p = x.ncols();
1205 let d = y.ncols();
1206 if !(fit.reml_hess_rho.is_finite() && fit.reml_hess_rho.abs() > 1.0e-14) {
1207 warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1213 return Ok(zero_backward_result(n, p, d));
1214 }
1215 let weight = gaussian_reml_weights(n, weights)?;
1216 let inverse_hessian = match gaussian_reml_inverse_hessian_from_cache(&fit.cache, lambda) {
1217 Ok(inv) => inv,
1218 Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1219 warn_ill_conditioned_backward_once(p, d, condition_number);
1220 return Ok(zero_backward_result(n, p, d));
1221 }
1222 Err(err) => return Err(err),
1223 };
1224 gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1225 x,
1226 y,
1227 penalty,
1228 weight,
1229 fit,
1230 inverse_hessian,
1231 upstream_lambda,
1232 upstream_coefficients,
1233 upstream_fitted,
1234 upstream_reml_score,
1235 upstream_edf,
1236 n,
1237 p,
1238 d,
1239 )
1240}
1241
1242fn gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1243 x: ArrayView2<'_, f64>,
1244 y: ArrayView2<'_, f64>,
1245 penalty: ArrayView2<'_, f64>,
1246 weight: Array1<f64>,
1247 fit: &GaussianRemlMultiResult,
1248 inverse_hessian: Array2<f64>,
1249 upstream_lambda: f64,
1250 upstream_coefficients: Option<ArrayView2<'_, f64>>,
1251 upstream_fitted: Option<ArrayView2<'_, f64>>,
1252 upstream_reml_score: f64,
1253 upstream_edf: f64,
1254 n: usize,
1255 p: usize,
1256 d: usize,
1257) -> Result<GaussianRemlBackwardResult, EstimationError> {
1258 let penalty_owned = canonicalize_penalty(penalty);
1262 let penalty = penalty_owned.view();
1263 let lambda = fit.lambda;
1264 let beta = &fit.coefficients;
1265 let residual = y.to_owned() - &fit.fitted;
1266 let nu = n as f64 - fit.cache.nullity as f64;
1267
1268 let mut grad_x = Array2::<f64>::zeros((n, p));
1269 let mut grad_y = Array2::<f64>::zeros((n, d));
1270 let mut grad_penalty = Array2::<f64>::zeros((p, p));
1271 let mut grad_weights = Array1::<f64>::zeros(n);
1272
1273 let mut upstream_beta = Array2::<f64>::zeros((p, d));
1274 if let Some(upstream_coefficients) = upstream_coefficients {
1275 upstream_beta += &upstream_coefficients;
1276 }
1277 if let Some(upstream_fitted) = upstream_fitted {
1278 upstream_beta += &dense_atb(x, upstream_fitted);
1279 grad_x += &dense_ab(upstream_fitted, beta.t());
1280 }
1281
1282 let mut lambda_adjoint = upstream_lambda;
1283 if upstream_beta.iter().any(|value| *value != 0.0) {
1284 add_ridge_profile_vjp_with_lambda_grad(
1289 1.0,
1290 x,
1291 y,
1292 penalty,
1293 &weight,
1294 lambda,
1295 &inverse_hessian,
1296 beta,
1297 upstream_beta.view(),
1298 &mut grad_x,
1299 &mut grad_y,
1300 &mut grad_penalty,
1301 &mut grad_weights,
1302 &mut lambda_adjoint,
1303 );
1304 }
1305
1306 if upstream_reml_score != 0.0 {
1307 add_reml_score_vjp(
1308 upstream_reml_score,
1309 x,
1310 &weight,
1311 &inverse_hessian,
1312 beta,
1313 &residual,
1314 &fit.sigma2,
1315 nu,
1316 lambda,
1317 &fit.cache,
1318 &mut grad_x,
1319 &mut grad_y,
1320 &mut grad_penalty,
1321 &mut grad_weights,
1322 );
1323 lambda_adjoint += upstream_reml_score * fit.reml_grad_lambda;
1324 }
1325
1326 if upstream_edf != 0.0 {
1327 lambda_adjoint += add_edf_vjp(
1328 upstream_edf,
1329 x,
1330 penalty,
1331 &weight,
1332 lambda,
1333 &inverse_hessian,
1334 &mut grad_x,
1335 &mut grad_penalty,
1336 &mut grad_weights,
1337 );
1338 }
1339
1340 if lambda_adjoint != 0.0 {
1341 let root_scale = -lambda_adjoint * lambda / fit.reml_hess_rho;
1342 add_reml_rho_gradient_vjp(
1343 root_scale,
1344 x,
1345 y,
1346 penalty,
1347 &weight,
1348 lambda,
1349 &inverse_hessian,
1350 beta,
1351 &residual,
1352 &fit.sigma2,
1353 nu,
1354 &mut grad_x,
1355 &mut grad_y,
1356 &mut grad_penalty,
1357 &mut grad_weights,
1358 );
1359 }
1360
1361 let p = grad_penalty.nrows();
1370 for i in 0..p {
1371 for j in (i + 1)..p {
1372 let avg = 0.5 * (grad_penalty[[i, j]] + grad_penalty[[j, i]]);
1373 grad_penalty[[i, j]] = avg;
1374 grad_penalty[[j, i]] = avg;
1375 }
1376 }
1377 Ok(GaussianRemlBackwardResult {
1378 grad_x,
1379 grad_y,
1380 grad_penalty,
1381 grad_weights,
1382 })
1383}
1384
1385pub fn gaussian_reml_multi_closed_form_backward_batch<'a>(
1386 problems: &[GaussianRemlMultiBackwardProblem<'a>],
1387 penalty: ArrayView2<'a, f64>,
1388) -> Vec<Result<GaussianRemlBackwardResult, EstimationError>> {
1389 let inverse_hessians = batched_inverse_hessians_from_caches(problems);
1390 let results: Vec<Result<GaussianRemlBackwardResult, EstimationError>> = problems
1391 .par_iter()
1392 .zip(inverse_hessians.into_par_iter())
1393 .map(|(problem, inverse_hessian_result)| {
1394 validate_gaussian_reml_backward_upstreams(
1395 problem.x.view(),
1396 problem.y.view(),
1397 penalty,
1398 problem.grad_lambda,
1399 problem.grad_coefficients.as_ref().map(|g| g.view()),
1400 problem.grad_fitted.as_ref().map(|g| g.view()),
1401 problem.grad_reml_score,
1402 problem.grad_edf,
1403 )?;
1404 validate_gaussian_reml_forward_fit(
1405 problem.x.view(),
1406 problem.y.view(),
1407 penalty,
1408 problem.weights.as_ref().map(|w| w.view()),
1409 problem.fit,
1410 )?;
1411 let n = problem.x.nrows();
1412 let p = problem.x.ncols();
1413 let d = problem.y.ncols();
1414 if !(problem.fit.reml_hess_rho.is_finite() && problem.fit.reml_hess_rho.abs() > 1.0e-14)
1415 {
1416 warn_ill_conditioned_backward_once(p, d, f64::INFINITY);
1418 return Ok(zero_backward_result(n, p, d));
1419 }
1420 let weight = gaussian_reml_weights(n, problem.weights.as_ref().map(|w| w.view()))?;
1421 let inverse_hessian = match inverse_hessian_result {
1422 Ok(inv) => inv,
1423 Err(EstimationError::ModelIsIllConditioned { condition_number }) => {
1424 warn_ill_conditioned_backward_once(p, d, condition_number);
1425 return Ok(zero_backward_result(n, p, d));
1426 }
1427 Err(err) => return Err(err),
1428 };
1429 gaussian_reml_multi_closed_form_backward_from_fit_with_inverse_hessian_impl(
1430 problem.x.view(),
1431 problem.y.view(),
1432 penalty,
1433 weight,
1434 problem.fit,
1435 inverse_hessian,
1436 problem.grad_lambda,
1437 problem.grad_coefficients.as_ref().map(|g| g.view()),
1438 problem.grad_fitted.as_ref().map(|g| g.view()),
1439 problem.grad_reml_score,
1440 problem.grad_edf,
1441 n,
1442 p,
1443 d,
1444 )
1445 })
1446 .collect();
1447 results
1448}
1449
/// Converts first and second derivatives taken with respect to `rho = ln(lambda)`
/// into derivatives with respect to `lambda`.
///
/// Returns `(grad_rho / lambda, (hess_rho - grad_rho) / lambda^2)`.
fn rho_derivatives_to_lambda(lambda: f64, grad_rho: f64, hess_rho: f64) -> (f64, f64) {
    let lambda_squared = lambda * lambda;
    let grad_lambda = grad_rho / lambda;
    let hess_lambda = (hess_rho - grad_rho) / lambda_squared;
    (grad_lambda, hess_lambda)
}
1453
1454fn validate_gaussian_reml_backward_upstreams(
1455 x: ArrayView2<'_, f64>,
1456 y: ArrayView2<'_, f64>,
1457 penalty: ArrayView2<'_, f64>,
1458 upstream_lambda: f64,
1459 upstream_coefficients: Option<ArrayView2<'_, f64>>,
1460 upstream_fitted: Option<ArrayView2<'_, f64>>,
1461 upstream_reml_score: f64,
1462 upstream_edf: f64,
1463) -> Result<(), EstimationError> {
1464 if !(upstream_lambda.is_finite() && upstream_reml_score.is_finite() && upstream_edf.is_finite())
1465 {
1466 crate::bail_invalid_estim!("Gaussian REML backward upstream scalars must be finite");
1467 }
1468 if let Some(upstream_coefficients) = upstream_coefficients {
1469 if upstream_coefficients.dim() != (x.ncols(), y.ncols()) {
1470 crate::bail_invalid_estim!(
1471 "Gaussian REML backward coefficient upstream shape mismatch: expected {}x{}, got {}x{}",
1472 x.ncols(),
1473 y.ncols(),
1474 upstream_coefficients.nrows(),
1475 upstream_coefficients.ncols()
1476 );
1477 }
1478 if upstream_coefficients.iter().any(|value| !value.is_finite()) {
1479 crate::bail_invalid_estim!(
1480 "Gaussian REML backward coefficient upstream must be finite"
1481 );
1482 }
1483 }
1484 if let Some(upstream_fitted) = upstream_fitted {
1485 if upstream_fitted.dim() != y.dim() {
1486 crate::bail_invalid_estim!(
1487 "Gaussian REML backward fitted upstream shape mismatch: expected {}x{}, got {}x{}",
1488 y.nrows(),
1489 y.ncols(),
1490 upstream_fitted.nrows(),
1491 upstream_fitted.ncols()
1492 );
1493 }
1494 if upstream_fitted.iter().any(|value| !value.is_finite()) {
1495 crate::bail_invalid_estim!("Gaussian REML backward fitted upstream must be finite");
1496 }
1497 }
1498 validate_gaussian_reml_design(x, penalty, None)?;
1499 Ok(())
1500}
1501
1502fn validate_gaussian_reml_forward_fit(
1503 x: ArrayView2<'_, f64>,
1504 y: ArrayView2<'_, f64>,
1505 penalty: ArrayView2<'_, f64>,
1506 weights: Option<ArrayView1<'_, f64>>,
1507 fit: &GaussianRemlMultiResult,
1508) -> Result<(), EstimationError> {
1509 let penalty_owned = canonicalize_penalty(penalty);
1513 let penalty = penalty_owned.view();
1514 let n = x.nrows();
1515 let p = x.ncols();
1516 let d = y.ncols();
1517 validate_gaussian_reml_design(x, penalty, weights)?;
1518 validate_gaussian_reml_eigen_cache(&fit.cache, p)?;
1519 if y.nrows() != n
1520 || fit.coefficients.dim() != (p, d)
1521 || fit.fitted.dim() != (n, d)
1522 || fit.sigma2.len() != d
1523 {
1524 crate::bail_invalid_estim!(
1525 "Gaussian REML backward forward-state shape mismatch: expected coefficients=({p},{d}), fitted=({n},{d}), sigma2={d}"
1526 );
1527 }
1528 if !(fit.lambda.is_finite()
1529 && fit.lambda > 0.0
1530 && fit.rho.is_finite()
1531 && fit.reml_score.is_finite()
1532 && fit.reml_hess_rho.is_finite()
1533 && fit.edf.is_finite())
1534 || fit.coefficients.iter().any(|value| !value.is_finite())
1535 || fit.fitted.iter().any(|value| !value.is_finite())
1536 || fit.sigma2.iter().any(|value| !value.is_finite())
1537 {
1538 crate::bail_invalid_estim!("Gaussian REML backward forward state must be finite");
1539 }
1540 let penalty_fingerprint = matrix_fingerprint(penalty);
1541 if fit.cache.penalty_fingerprint != penalty_fingerprint {
1542 crate::bail_invalid_estim!("Gaussian REML backward forward-state penalty mismatch");
1543 }
1544 let weight = gaussian_reml_weights(n, weights)?;
1545 let xtwx = dense_xt_diag_x(x, weight.view());
1546 if fit.cache.xtwx_fingerprint != matrix_fingerprint(xtwx.view()) {
1547 crate::bail_invalid_estim!("Gaussian REML backward forward-state X'WX mismatch");
1548 }
1549 Ok(())
1550}
1551
1552fn gaussian_reml_inverse_hessian_from_cache(
1553 cache: &GaussianRemlEigenCache,
1554 lambda: f64,
1555) -> Result<Array2<f64>, EstimationError> {
1556 if !(lambda.is_finite() && lambda > 0.0) {
1557 crate::bail_invalid_estim!(
1558 "Gaussian REML lambda must be finite and positive; got {lambda}"
1559 );
1560 }
1561 let p = cache.penalty_eigenvalues.len();
1562 let mut scaled_basis = cache.coefficient_basis.clone();
1563 for eig in 0..p {
1564 let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1565 for row in 0..p {
1566 scaled_basis[[row, eig]] *= scale;
1567 }
1568 }
1569 let inverse = dense_ab(scaled_basis.view(), cache.coefficient_basis.t());
1570 if inverse.iter().any(|value| !value.is_finite()) {
1571 return Err(EstimationError::ModelIsIllConditioned {
1572 condition_number: f64::INFINITY,
1573 });
1574 }
1575 Ok(inverse)
1576}
1577
1578fn batched_inverse_hessians_from_caches(
1579 problems: &[GaussianRemlMultiBackwardProblem<'_>],
1580) -> Vec<Result<Array2<f64>, EstimationError>> {
1581 if problems.is_empty() {
1582 return Vec::new();
1583 }
1584 let p = problems[0].fit.cache.coefficient_basis.nrows();
1585 let uniform = p > 0
1586 && problems.iter().all(|problem| {
1587 let cache = &problem.fit.cache;
1588 cache.coefficient_basis.dim() == (p, p) && cache.penalty_eigenvalues.len() == p
1589 });
1590 if uniform && problems.len() > 1 {
1591 let mut scaled_basis = Array3::<f64>::zeros((problems.len(), p, p));
1592 let mut basis = Array3::<f64>::zeros((problems.len(), p, p));
1593 let mut valid = true;
1594 for (idx, problem) in problems.iter().enumerate() {
1595 let lambda = problem.fit.lambda;
1596 if !(lambda.is_finite() && lambda > 0.0) {
1597 valid = false;
1598 break;
1599 }
1600 let cache = &problem.fit.cache;
1601 basis
1602 .slice_mut(s![idx, .., ..])
1603 .assign(&cache.coefficient_basis);
1604 for eig in 0..p {
1605 let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
1606 for row in 0..p {
1607 scaled_basis[[idx, row, eig]] = cache.coefficient_basis[[row, eig]] * scale;
1608 }
1609 }
1610 }
1611 if valid
1612 && let Some(inverses) =
1613 gam_gpu::try_fast_abt_strided_batched(scaled_basis.view(), basis.view())
1614 {
1615 return inverses
1616 .axis_iter(Axis(0))
1617 .map(|inverse| Ok(inverse.to_owned()))
1618 .collect();
1619 }
1620 }
1621 problems
1622 .iter()
1623 .map(|problem| {
1624 gaussian_reml_inverse_hessian_from_cache(&problem.fit.cache, problem.fit.lambda)
1625 })
1626 .collect()
1627}
1628
1629fn ridge_profile_vjp_data_partials(
1636 scale: f64,
1637 x: ArrayView2<'_, f64>,
1638 y: ArrayView2<'_, f64>,
1639 penalty: ArrayView2<'_, f64>,
1640 weights: &Array1<f64>,
1641 lambda: f64,
1642 inverse_hessian: &Array2<f64>,
1643 beta: &Array2<f64>,
1644 upstream_beta: ArrayView2<'_, f64>,
1645 grad_x: &mut Array2<f64>,
1646 grad_y: &mut Array2<f64>,
1647 grad_penalty: &mut Array2<f64>,
1648 grad_weights: &mut Array1<f64>,
1649) -> Array2<f64> {
1650 let m = dense_ab(inverse_hessian.view(), upstream_beta);
1651 let c = dense_ab(m.view(), beta.t());
1652 let c_sym = &c + &c.t();
1653 let ymt = dense_ab(y, m.t());
1654 let xcs = dense_ab(x, c_sym.view());
1655 for i in 0..x.nrows() {
1656 let wi = weights[i] * scale;
1657 for k in 0..x.ncols() {
1658 grad_x[[i, k]] += wi * (ymt[[i, k]] - xcs[[i, k]]);
1659 }
1660 }
1661
1662 let xm = dense_ab(x, m.view());
1663 for i in 0..x.nrows() {
1664 let wi = weights[i] * scale;
1665 for j in 0..y.ncols() {
1666 grad_y[[i, j]] += wi * xm[[i, j]];
1667 }
1668 }
1669
1670 let xc = dense_ab(x, c.view());
1671 for i in 0..x.nrows() {
1672 let mut from_b = 0.0;
1673 for j in 0..y.ncols() {
1674 from_b += y[[i, j]] * xm[[i, j]];
1675 }
1676 let mut from_a = 0.0;
1677 for k in 0..x.ncols() {
1678 from_a += x[[i, k]] * xc[[i, k]];
1679 }
1680 grad_weights[i] += scale * (from_b - from_a);
1681 }
1682
1683 for row in 0..penalty.nrows() {
1684 for col in 0..penalty.ncols() {
1685 let mut value = 0.0;
1686 for output in 0..beta.ncols() {
1687 value += m[[row, output]] * beta[[col, output]];
1688 }
1689 grad_penalty[[row, col]] -= scale * lambda * value;
1690 }
1691 }
1692 m
1693}
1694
1695fn add_ridge_profile_vjp_with_lambda_grad(
1700 scale: f64,
1701 x: ArrayView2<'_, f64>,
1702 y: ArrayView2<'_, f64>,
1703 penalty: ArrayView2<'_, f64>,
1704 weights: &Array1<f64>,
1705 lambda: f64,
1706 inverse_hessian: &Array2<f64>,
1707 beta: &Array2<f64>,
1708 upstream_beta: ArrayView2<'_, f64>,
1709 grad_x: &mut Array2<f64>,
1710 grad_y: &mut Array2<f64>,
1711 grad_penalty: &mut Array2<f64>,
1712 grad_weights: &mut Array1<f64>,
1713 lambda_adjoint_out: &mut f64,
1714) {
1715 let m = ridge_profile_vjp_data_partials(
1716 scale,
1717 x,
1718 y,
1719 penalty,
1720 weights,
1721 lambda,
1722 inverse_hessian,
1723 beta,
1724 upstream_beta,
1725 grad_x,
1726 grad_y,
1727 grad_penalty,
1728 grad_weights,
1729 );
1730 let penalty_beta = dense_ab(penalty, beta.view());
1731 let dot = m
1732 .iter()
1733 .zip(penalty_beta.iter())
1734 .map(|(left, right)| left * right)
1735 .sum::<f64>();
1736 *lambda_adjoint_out += -scale * dot;
1737}
1738
1739fn add_ridge_profile_vjp_fixed_lambda(
1743 scale: f64,
1744 x: ArrayView2<'_, f64>,
1745 y: ArrayView2<'_, f64>,
1746 penalty: ArrayView2<'_, f64>,
1747 weights: &Array1<f64>,
1748 lambda: f64,
1749 inverse_hessian: &Array2<f64>,
1750 beta: &Array2<f64>,
1751 upstream_beta: ArrayView2<'_, f64>,
1752 grad_x: &mut Array2<f64>,
1753 grad_y: &mut Array2<f64>,
1754 grad_penalty: &mut Array2<f64>,
1755 grad_weights: &mut Array1<f64>,
1756) {
1757 ridge_profile_vjp_data_partials(
1758 scale,
1759 x,
1760 y,
1761 penalty,
1762 weights,
1763 lambda,
1764 inverse_hessian,
1765 beta,
1766 upstream_beta,
1767 grad_x,
1768 grad_y,
1769 grad_penalty,
1770 grad_weights,
1771 );
1772}
1773
1774fn add_reml_score_vjp(
1775 scale: f64,
1776 x: ArrayView2<'_, f64>,
1777 weights: &Array1<f64>,
1778 inverse_hessian: &Array2<f64>,
1779 beta: &Array2<f64>,
1780 residual: &Array2<f64>,
1781 sigma2: &Array1<f64>,
1782 nu: f64,
1783 lambda: f64,
1784 cache: &GaussianRemlEigenCache,
1785 grad_x: &mut Array2<f64>,
1786 grad_y: &mut Array2<f64>,
1787 grad_penalty: &mut Array2<f64>,
1788 grad_weights: &mut Array1<f64>,
1789) {
1790 let d = beta.ncols() as f64;
1791 let xp = dense_ab(x, inverse_hessian.view());
1792 let penalty_pinv = gaussian_reml_penalty_pseudoinverse_from_cache(cache);
1793 for row in 0..grad_penalty.nrows() {
1794 for col in 0..grad_penalty.ncols() {
1795 grad_penalty[[row, col]] +=
1796 scale * 0.5 * d * (lambda * inverse_hessian[[col, row]] - penalty_pinv[[col, row]]);
1797 }
1798 }
1799 for i in 0..x.nrows() {
1800 let wi = weights[i] * scale * d;
1801 for k in 0..x.ncols() {
1802 grad_x[[i, k]] += wi * xp[[i, k]];
1803 }
1804 let mut leverage = 0.0;
1805 for k in 0..x.ncols() {
1806 leverage += x[[i, k]] * xp[[i, k]];
1807 }
1808 grad_weights[i] += scale * 0.5 * d * leverage;
1809 }
1810
1811 for j in 0..beta.ncols() {
1812 let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
1813 let coef = scale * 0.5 * nu / dp;
1814 add_deviance_profile_vjp(
1815 coef,
1816 j,
1817 x,
1818 weights,
1819 beta,
1820 residual,
1821 grad_x,
1822 grad_y,
1823 grad_weights,
1824 );
1825 add_rank_one_penalty_vjp(coef * lambda, beta.column(j), grad_penalty);
1826 }
1827}
1828
1829fn add_edf_vjp(
1840 scale: f64,
1841 x: ArrayView2<'_, f64>,
1842 penalty: ArrayView2<'_, f64>,
1843 weights: &Array1<f64>,
1844 lambda: f64,
1845 inverse_hessian: &Array2<f64>,
1846 grad_x: &mut Array2<f64>,
1847 grad_penalty: &mut Array2<f64>,
1848 grad_weights: &mut Array1<f64>,
1849) -> f64 {
1850 let m_inv_s = dense_ab(inverse_hessian.view(), penalty);
1852 let mut g_a = dense_ab(m_inv_s.view(), inverse_hessian.view());
1853 g_a.mapv_inplace(|v| v * lambda);
1854
1855 let xg = dense_ab(x, g_a.view());
1859 let leading_scale = 2.0 * scale;
1863 for i in 0..xg.nrows() {
1864 let row_scale = leading_scale * weights[i];
1865 for k in 0..xg.ncols() {
1866 grad_x[[i, k]] += row_scale * xg[[i, k]];
1867 }
1868 }
1869 for i in 0..x.nrows() {
1870 let mut quad = 0.0;
1871 for k in 0..x.ncols() {
1872 quad += x[[i, k]] * xg[[i, k]];
1873 }
1874 grad_weights[i] += scale * quad;
1875 }
1876
1877 for row in 0..grad_penalty.nrows() {
1880 for col in 0..grad_penalty.ncols() {
1881 grad_penalty[[row, col]] +=
1882 scale * (-lambda * inverse_hessian[[row, col]] + lambda * g_a[[row, col]]);
1883 }
1884 }
1885
1886 let p_dim = m_inv_s.nrows();
1888 let mut tr_m_inv_s = 0.0;
1889 for i in 0..p_dim {
1890 tr_m_inv_s += m_inv_s[[i, i]];
1891 }
1892 let mut tr_squared = 0.0;
1893 for i in 0..p_dim {
1894 for j in 0..p_dim {
1895 tr_squared += m_inv_s[[i, j]] * m_inv_s[[j, i]];
1896 }
1897 }
1898 scale * (-tr_m_inv_s + lambda * tr_squared)
1899}
1900
1901fn add_reml_rho_gradient_vjp(
1902 scale: f64,
1903 x: ArrayView2<'_, f64>,
1904 y: ArrayView2<'_, f64>,
1905 penalty: ArrayView2<'_, f64>,
1906 weights: &Array1<f64>,
1907 lambda: f64,
1908 inverse_hessian: &Array2<f64>,
1909 beta: &Array2<f64>,
1910 residual: &Array2<f64>,
1911 sigma2: &Array1<f64>,
1912 nu: f64,
1913 grad_x: &mut Array2<f64>,
1914 grad_y: &mut Array2<f64>,
1915 grad_penalty: &mut Array2<f64>,
1916 grad_weights: &mut Array1<f64>,
1917) {
1918 let d = beta.ncols() as f64;
1919 let inverse_s = dense_ab(inverse_hessian.view(), penalty);
1920 let trace_kernel = dense_ab(inverse_s.view(), inverse_hessian.view());
1921 for row in 0..grad_penalty.nrows() {
1922 for col in 0..grad_penalty.ncols() {
1923 grad_penalty[[row, col]] += scale
1924 * 0.5
1925 * d
1926 * lambda
1927 * (inverse_hessian[[col, row]] - lambda * trace_kernel[[col, row]]);
1928 }
1929 }
1930 let xt = dense_ab(x, trace_kernel.view());
1931 for i in 0..x.nrows() {
1932 let wi = -scale * d * lambda * weights[i];
1933 for k in 0..x.ncols() {
1934 grad_x[[i, k]] += wi * xt[[i, k]];
1935 }
1936 let mut quad = 0.0;
1937 for k in 0..x.ncols() {
1938 quad += x[[i, k]] * xt[[i, k]];
1939 }
1940 grad_weights[i] -= scale * 0.5 * d * lambda * quad;
1941 }
1942
1943 let s_beta = dense_ab(penalty, beta.view());
1944 let mut upstream_beta = Array2::<f64>::zeros(beta.dim());
1945 for j in 0..beta.ncols() {
1946 let dp = (sigma2[j] * nu).max(MIN_DEVIANCE);
1947 let q = lambda * beta.column(j).dot(&s_beta.column(j));
1948 let q_coef = scale * nu / dp;
1949 for row in 0..beta.nrows() {
1950 upstream_beta[[row, j]] = q_coef * lambda * s_beta[[row, j]];
1951 }
1952 let dp_coef = -scale * 0.5 * nu * q / (dp * dp);
1953 add_rank_one_penalty_vjp(
1954 (0.5 * q_coef + dp_coef) * lambda,
1955 beta.column(j),
1956 grad_penalty,
1957 );
1958 add_deviance_profile_vjp(
1959 dp_coef,
1960 j,
1961 x,
1962 weights,
1963 beta,
1964 residual,
1965 grad_x,
1966 grad_y,
1967 grad_weights,
1968 );
1969 }
1970 add_ridge_profile_vjp_fixed_lambda(
1973 1.0,
1974 x,
1975 y,
1976 penalty,
1977 weights,
1978 lambda,
1979 inverse_hessian,
1980 beta,
1981 upstream_beta.view(),
1982 grad_x,
1983 grad_y,
1984 grad_penalty,
1985 grad_weights,
1986 );
1987}
1988
1989fn add_rank_one_penalty_vjp(
1990 scale: f64,
1991 beta_col: ArrayView1<'_, f64>,
1992 grad_penalty: &mut Array2<f64>,
1993) {
1994 for row in 0..beta_col.len() {
1995 for col in 0..beta_col.len() {
1996 grad_penalty[[row, col]] += scale * beta_col[row] * beta_col[col];
1997 }
1998 }
1999}
2000
2001fn gaussian_reml_penalty_pseudoinverse_from_cache(cache: &GaussianRemlEigenCache) -> Array2<f64> {
2002 let p = cache.penalty_eigenvalues.len();
2003 let mut scaled_basis = Array2::<f64>::zeros((p, p));
2004 for eig in 0..p {
2005 let delta = cache.penalty_eigenvalues[eig];
2006 if delta > 0.0 {
2007 for row in 0..p {
2008 scaled_basis[[row, eig]] = cache.coefficient_basis[[row, eig]] / delta;
2009 }
2010 }
2011 }
2012 dense_ab(scaled_basis.view(), cache.coefficient_basis.t())
2013}
2014
2015fn add_deviance_profile_vjp(
2016 scale: f64,
2017 output: usize,
2018 x: ArrayView2<'_, f64>,
2019 weights: &Array1<f64>,
2020 beta: &Array2<f64>,
2021 residual: &Array2<f64>,
2022 grad_x: &mut Array2<f64>,
2023 grad_y: &mut Array2<f64>,
2024 grad_weights: &mut Array1<f64>,
2025) {
2026 for i in 0..x.nrows() {
2027 let r = residual[[i, output]];
2028 let wr_scale = scale * weights[i] * r;
2029 grad_y[[i, output]] += 2.0 * wr_scale;
2030 for k in 0..x.ncols() {
2031 grad_x[[i, k]] -= 2.0 * wr_scale * beta[[k, output]];
2032 }
2033 grad_weights[i] += scale * r * r;
2034 }
2035}
2036
2037fn validate_initial_lambda(lambda: f64) -> Result<f64, EstimationError> {
2038 if lambda.is_finite() && lambda > 0.0 {
2039 Ok(lambda)
2040 } else {
2041 Err(EstimationError::InvalidInput(format!(
2042 "Gaussian REML initial lambda must be finite and positive; got {lambda}"
2043 )))
2044 }
2045}
2046
2047fn dense_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
2048 fast_ab(&a, &b)
2049}
2050
2051fn dense_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Array2<f64> {
2052 fast_atb(&a, &b)
2053}
2054
2055fn dense_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Array2<f64> {
2056 fast_xt_diag_x(&x, &w)
2057}
2058
2059fn dense_xt_diag_y(
2060 x: ArrayView2<'_, f64>,
2061 w: ArrayView1<'_, f64>,
2062 y: ArrayView2<'_, f64>,
2063) -> Array2<f64> {
2064 fast_xt_diag_y(&x, &w, &y)
2065}
2066
2067fn matrix_fingerprint(matrix: ArrayView2<'_, f64>) -> u64 {
2068 let mut hash = 0xcbf29ce484222325_u64;
2069 hash = fnv1a_mix(hash, matrix.nrows() as u64);
2070 hash = fnv1a_mix(hash, matrix.ncols() as u64);
2071 for &value in matrix {
2072 hash = fnv1a_mix(hash, value.to_bits());
2073 }
2074 hash
2075}
2076
/// One mixing step: XOR `value` into `hash`, then multiply (wrapping) by
/// `0x100000001b3`.
fn fnv1a_mix(hash: u64, value: u64) -> u64 {
    const MULTIPLIER: u64 = 0x100000001b3;
    let combined = hash ^ value;
    combined.wrapping_mul(MULTIPLIER)
}
2080
2081pub fn build_gaussian_reml_eigen_cache_batched(
2086 xtwx_matrices: Vec<Array2<f64>>,
2087 penalty: ArrayView2<'_, f64>,
2088 nullspace_dim: Option<usize>,
2089) -> Vec<Result<GaussianRemlEigenCache, EstimationError>> {
2090 let penalty_owned = canonicalize_penalty(penalty);
2091 let penalty = penalty_owned.view();
2092 let k = xtwx_matrices.len();
2093 if k == 0 {
2094 return Vec::new();
2095 }
2096 let fingerprints: Vec<u64> = xtwx_matrices
2097 .iter()
2098 .map(|m| matrix_fingerprint(m.view()))
2099 .collect();
2100
2101 let p = xtwx_matrices[0].nrows();
2102 let uniform_square = p > 0 && xtwx_matrices.iter().all(|matrix| matrix.dim() == (p, p));
2103 if uniform_square && k > 1 {
2104 let mut lower_matrices = xtwx_matrices.clone();
2105 if gam_gpu::try_cholesky_batched_lower_inplace(&mut lower_matrices).is_some() {
2106 let transforms = batched_whitened_penalty_transforms(&lower_matrices, penalty);
2114 return lower_matrices
2115 .into_iter()
2116 .enumerate()
2117 .map(|(b, lower)| {
2118 let precomputed_transform = transforms.as_ref().map(|t| t[b].clone());
2119 gaussian_reml_eigen_cache_from_lower_with_transform(
2120 lower,
2121 penalty,
2122 nullspace_dim,
2123 fingerprints[b],
2124 precomputed_transform,
2125 )
2126 })
2127 .collect();
2128 }
2129 }
2130
2131 let mut results = Vec::with_capacity(k);
2132 for (b, xtwx) in xtwx_matrices.into_iter().enumerate() {
2133 let lower = match gaussian_reml_cholesky_lower(xtwx) {
2134 Ok(l) => l,
2135 Err(err) => {
2136 results.push(Err(err));
2137 continue;
2138 }
2139 };
2140 results.push(gaussian_reml_eigen_cache_from_lower_with_transform(
2141 lower,
2142 penalty,
2143 nullspace_dim,
2144 fingerprints[b],
2145 None,
2146 ));
2147 }
2148 results
2149}
2150
2151fn batched_whitened_penalty_transforms(
2152 lowers: &[Array2<f64>],
2153 penalty: ArrayView2<'_, f64>,
2154) -> Option<Vec<Array2<f64>>> {
2155 let first = lowers.first()?;
2156 let p = first.nrows();
2157 if p == 0 || first.ncols() != p || lowers.iter().any(|lower| lower.dim() != (p, p)) {
2158 return None;
2159 }
2160 let mut linv_stack = Array3::<f64>::zeros((lowers.len(), p, p));
2161 for (idx, lower) in lowers.iter().enumerate() {
2162 let l_inv = invert_lower_triangular(lower).ok()?;
2163 linv_stack.slice_mut(s![idx, .., ..]).assign(&l_inv);
2164 }
2165 let penalty_in_metric =
2166 gam_gpu::try_fast_ab_broadcast_b_batched(linv_stack.view(), penalty)?;
2167 let transformed =
2168 gam_gpu::try_fast_abt_strided_batched(penalty_in_metric.view(), linv_stack.view())?;
2169 Some(
2170 transformed
2171 .axis_iter(Axis(0))
2172 .map(|matrix| matrix.to_owned())
2173 .collect(),
2174 )
2175}
2176
2177pub fn build_gaussian_reml_eigen_cache(
2178 x: ArrayView2<'_, f64>,
2179 penalty: ArrayView2<'_, f64>,
2180 weights: Option<ArrayView1<'_, f64>>,
2181) -> Result<GaussianRemlEigenCache, EstimationError> {
2182 build_gaussian_reml_eigen_cache_with_nullspace_dim(x, penalty, None, weights)
2183}
2184
2185pub fn build_gaussian_reml_eigen_cache_with_nullspace_dim(
2186 x: ArrayView2<'_, f64>,
2187 penalty: ArrayView2<'_, f64>,
2188 nullspace_dim: Option<usize>,
2189 weights: Option<ArrayView1<'_, f64>>,
2190) -> Result<GaussianRemlEigenCache, EstimationError> {
2191 let penalty_owned = canonicalize_penalty(penalty);
2192 let penalty = penalty_owned.view();
2193 let n = x.nrows();
2194 validate_gaussian_reml_design(x, penalty, weights)?;
2195 let weight = gaussian_reml_weights(n, weights)?;
2196
2197 let xtwx = dense_xt_diag_x(x, weight.view());
2198 gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)
2199}
2200
2201fn validate_gaussian_reml_design(
2202 x: ArrayView2<'_, f64>,
2203 penalty: ArrayView2<'_, f64>,
2204 weights: Option<ArrayView1<'_, f64>>,
2205) -> Result<(), EstimationError> {
2206 let n = x.nrows();
2207 let p = x.ncols();
2208 if penalty.nrows() != p || penalty.ncols() != p {
2209 crate::bail_invalid_estim!(
2210 "Gaussian REML penalty shape mismatch: expected {p}x{p}, got {}x{}",
2211 penalty.nrows(),
2212 penalty.ncols()
2213 );
2214 }
2215 if x.iter().chain(penalty.iter()).any(|v| !v.is_finite()) {
2216 crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2217 }
2218 if let Some(w) = weights {
2219 if w.len() != n {
2220 crate::bail_invalid_estim!(
2221 "Gaussian REML weights length mismatch: expected {n}, got {}",
2222 w.len()
2223 );
2224 }
2225 if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2226 crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2227 }
2228 }
2229 Ok(())
2230}
2231
2232fn gaussian_reml_weights(
2233 n: usize,
2234 weights: Option<ArrayView1<'_, f64>>,
2235) -> Result<Array1<f64>, EstimationError> {
2236 match weights {
2237 Some(w) => {
2238 if w.len() != n {
2239 crate::bail_invalid_estim!(
2240 "Gaussian REML weights length mismatch: expected {n}, got {}",
2241 w.len()
2242 );
2243 }
2244 if w.iter().any(|value| !value.is_finite() || *value < 0.0) {
2245 crate::bail_invalid_estim!("Gaussian REML weights must be finite and non-negative");
2246 }
2247 Ok(w.to_owned())
2248 }
2249 None => Ok(Array1::ones(n)),
2250 }
2251}
2252
2253fn gaussian_reml_eigen_cache_from_xtwx(
2254 xtwx: Array2<f64>,
2255 penalty: ArrayView2<'_, f64>,
2256 nullspace_dim: Option<usize>,
2257) -> Result<GaussianRemlEigenCache, EstimationError> {
2258 let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2259 let lower = gaussian_reml_cholesky_lower(xtwx)?;
2260 gaussian_reml_eigen_cache_from_lower(lower, penalty, nullspace_dim, xtwx_fingerprint)
2261}
2262
2263fn gaussian_reml_eigen_cache_from_lower(
2268 lower: Array2<f64>,
2269 penalty: ArrayView2<'_, f64>,
2270 nullspace_dim: Option<usize>,
2271 xtwx_fingerprint: u64,
2272) -> Result<GaussianRemlEigenCache, EstimationError> {
2273 gaussian_reml_eigen_cache_from_lower_with_transform(
2274 lower,
2275 penalty,
2276 nullspace_dim,
2277 xtwx_fingerprint,
2278 None,
2279 )
2280}
2281
2282fn gaussian_reml_eigen_cache_from_lower_with_transform(
2285 lower: Array2<f64>,
2286 penalty: ArrayView2<'_, f64>,
2287 nullspace_dim: Option<usize>,
2288 xtwx_fingerprint: u64,
2289 precomputed_transform: Option<Array2<f64>>,
2290) -> Result<GaussianRemlEigenCache, EstimationError> {
2291 let p = lower.nrows();
2292 if lower.ncols() != p {
2293 crate::bail_invalid_estim!("Gaussian REML Cholesky factor must be square");
2294 }
2295 let penalty_fingerprint = matrix_fingerprint(penalty);
2296 let logdet_xtwx = 2.0 * lower.diag().iter().map(|v| v.ln()).sum::<f64>();
2297 let transformed_penalty = match precomputed_transform {
2298 Some(transformed) => transformed,
2299 None => {
2300 let l_inv = invert_lower_triangular(&lower)?;
2301 let penalty_in_metric = dense_ab(l_inv.view(), penalty);
2302 dense_ab(penalty_in_metric.view(), l_inv.t())
2303 }
2304 };
2305 let (mut penalty_eigenvalues, eigenvectors) =
2306 transformed_penalty.eigh(Side::Lower).map_err(|_| {
2307 EstimationError::ModelIsIllConditioned {
2308 condition_number: f64::INFINITY,
2309 }
2310 })?;
2311 let max_abs_eig = penalty_eigenvalues
2321 .iter()
2322 .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2323 let eig_tol = max_abs_eig * EIGEN_REL_TOL;
2324 for value in &mut penalty_eigenvalues {
2325 if *value < 0.0 && value.abs() <= eig_tol {
2326 *value = 0.0;
2327 }
2328 if *value < 0.0 {
2329 crate::bail_invalid_estim!(
2330 "Gaussian REML penalty is not positive semidefinite; eigenvalue={value:.3e}"
2331 );
2332 }
2333 }
2334 let penalty_rank = penalty_eigenvalues
2335 .iter()
2336 .filter(|&&value| value > eig_tol)
2337 .count();
2338 let nullity = p - penalty_rank;
2339 if let Some(expected_nullity) = nullspace_dim
2340 && expected_nullity != nullity
2341 {
2342 crate::bail_invalid_estim!(
2343 "Gaussian REML penalty nullspace mismatch: expected {expected_nullity}, inferred {nullity}"
2344 );
2345 }
2346 let logdet_penalty_positive = gaussian_penalty_positive_logdet(penalty, penalty_rank)?;
2347 let coefficient_basis = solve_upper_triangular_matrix(&lower.t().to_owned(), &eigenvectors)?;
2348
2349 Ok(GaussianRemlEigenCache {
2350 penalty_eigenvalues,
2351 eigenvectors,
2352 coefficient_basis,
2353 xtwx_fingerprint,
2354 penalty_fingerprint,
2355 logdet_xtwx,
2356 logdet_penalty_positive,
2357 penalty_rank,
2358 nullity,
2359 })
2360}
2361
2362fn gaussian_reml_cholesky_lower(xtwx: Array2<f64>) -> Result<Array2<f64>, EstimationError> {
2363 let mut gpu_candidate = xtwx.clone();
2373 if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2374 return Ok(gpu_candidate);
2375 }
2376 if let Ok(chol) = xtwx.cholesky(Side::Lower) {
2377 return Ok(chol.lower_triangular());
2378 }
2379 let p = xtwx.nrows();
2380 let trace: f64 = (0..p).map(|i| xtwx[[i, i]]).sum();
2381 if !trace.is_finite() || trace <= 0.0 {
2382 return Err(EstimationError::ModelIsIllConditioned {
2383 condition_number: f64::INFINITY,
2384 });
2385 }
2386 let mut jitter = 1e-12 * trace / (p as f64);
2387 for _ in 0..6 {
2388 let mut jittered = xtwx.clone();
2389 for i in 0..p {
2390 jittered[[i, i]] += jitter;
2391 }
2392 let mut gpu_candidate = jittered.clone();
2393 if gam_gpu::try_cholesky_lower_inplace(&mut gpu_candidate).is_some() {
2394 return Ok(gpu_candidate);
2395 }
2396 if let Ok(chol) = jittered.cholesky(Side::Lower) {
2397 return Ok(chol.lower_triangular());
2398 }
2399 jitter *= 10.0;
2400 }
2401 Err(EstimationError::ModelIsIllConditioned {
2402 condition_number: f64::INFINITY,
2403 })
2404}
2405
2406fn gaussian_penalty_positive_logdet(
2407 penalty: ArrayView2<'_, f64>,
2408 penalty_rank: usize,
2409) -> Result<f64, EstimationError> {
2410 if penalty_rank == 0 {
2411 return Ok(0.0);
2412 }
2413 let (pen_eigs, _) = penalty.to_owned().eigh(Side::Lower).map_err(|_| {
2414 EstimationError::ModelIsIllConditioned {
2415 condition_number: f64::INFINITY,
2416 }
2417 })?;
2418 let pen_scale = pen_eigs
2422 .iter()
2423 .fold(0.0_f64, |acc, &value| acc.max(value.abs()));
2424 let pen_tol = pen_scale * EIGEN_REL_TOL;
2425 let mut positive_eigs: Vec<f64> = pen_eigs
2426 .iter()
2427 .copied()
2428 .filter(|&value| value > pen_tol)
2429 .collect();
2430 if positive_eigs.len() != penalty_rank {
2431 positive_eigs = pen_eigs
2432 .iter()
2433 .copied()
2434 .filter(|&value| value > 0.0)
2435 .collect();
2436 positive_eigs.sort_by(|a, b| b.total_cmp(a));
2437 if positive_eigs.len() < penalty_rank {
2438 return Err(EstimationError::ModelIsIllConditioned {
2439 condition_number: f64::INFINITY,
2440 });
2441 }
2442 positive_eigs.truncate(penalty_rank);
2443 }
2444 Ok(positive_eigs.iter().map(|value| value.ln()).sum())
2445}
2446
2447fn validate_gaussian_reml_eigen_cache(
2448 cache: &GaussianRemlEigenCache,
2449 p: usize,
2450) -> Result<(), EstimationError> {
2451 if cache.penalty_eigenvalues.len() != p
2452 || cache.eigenvectors.dim() != (p, p)
2453 || cache.coefficient_basis.dim() != (p, p)
2454 {
2455 crate::bail_invalid_estim!(
2456 "Gaussian REML eigen cache dimension mismatch: expected {p} coefficients"
2457 );
2458 }
2459 if cache.penalty_rank > p || cache.nullity > p || cache.penalty_rank + cache.nullity != p {
2460 crate::bail_invalid_estim!(
2461 "Gaussian REML eigen cache rank/nullity mismatch: rank={}, nullity={}, p={p}",
2462 cache.penalty_rank,
2463 cache.nullity
2464 );
2465 }
2466 if !(cache.logdet_xtwx.is_finite() && cache.logdet_penalty_positive.is_finite()) {
2467 crate::bail_invalid_estim!("Gaussian REML eigen cache log-determinants must be finite");
2468 }
2469 if cache
2470 .penalty_eigenvalues
2471 .iter()
2472 .any(|value| !value.is_finite() || *value < 0.0)
2473 || cache.eigenvectors.iter().any(|value| !value.is_finite())
2474 || cache
2475 .coefficient_basis
2476 .iter()
2477 .any(|value| !value.is_finite())
2478 {
2479 crate::bail_invalid_estim!(
2480 "Gaussian REML eigen cache entries must be finite with non-negative eigenvalues"
2481 .to_string(),
2482 );
2483 }
2484 Ok::<(), _>(())
2485}
2486
2487fn prepare_gaussian_reml(
2488 x: ArrayView2<'_, f64>,
2489 y: ArrayView2<'_, f64>,
2490 penalty: ArrayView2<'_, f64>,
2491 nullspace_dim: Option<usize>,
2492 weights: Option<ArrayView1<'_, f64>>,
2493 eigen_cache: Option<&GaussianRemlEigenCache>,
2494) -> Result<GaussianRemlPrepared, EstimationError> {
2495 let penalty_owned = canonicalize_penalty(penalty);
2498 let penalty = penalty_owned.view();
2499 let n = x.nrows();
2500 let p = x.ncols();
2501 let d = y.ncols();
2502 validate_gaussian_reml_design(x, penalty, weights)?;
2503 if y.nrows() != n {
2504 crate::bail_invalid_estim!(
2505 "Gaussian REML row mismatch: X has {n} rows but Y has {}",
2506 y.nrows()
2507 );
2508 }
2509 if y.iter().any(|v| !v.is_finite()) {
2510 crate::bail_invalid_estim!("Gaussian REML inputs must be finite");
2511 }
2512 let weight = gaussian_reml_weights(n, weights)?;
2513
2514 let xtwy = dense_xt_diag_y(x, weight.view(), y);
2515 let ywy = Array1::from_iter((0..d).map(|j| {
2516 let mut value = 0.0;
2517 for row in 0..n {
2518 value += weight[row] * y[[row, j]] * y[[row, j]];
2519 }
2520 value
2521 }));
2522 let xtwx = dense_xt_diag_x(x, weight.view());
2523
2524 if let Some(cache) = eigen_cache {
2525 validate_gaussian_reml_eigen_cache(cache, p)?;
2526 let xtwx_fingerprint = matrix_fingerprint(xtwx.view());
2527 if cache.xtwx_fingerprint != xtwx_fingerprint {
2528 crate::bail_invalid_estim!("Gaussian REML eigen cache X'WX mismatch");
2529 }
2530 let penalty_fingerprint = matrix_fingerprint(penalty);
2531 if cache.penalty_fingerprint != penalty_fingerprint {
2532 crate::bail_invalid_estim!("Gaussian REML eigen cache penalty mismatch");
2533 }
2534 if let Some(expected_nullity) = nullspace_dim
2535 && expected_nullity != cache.nullity
2536 {
2537 crate::bail_invalid_estim!(
2538 "Gaussian REML eigen cache nullspace mismatch: expected {expected_nullity}, got {}",
2539 cache.nullity
2540 );
2541 }
2542 if n <= cache.nullity {
2543 crate::bail_invalid_estim!(
2544 "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
2545 cache.nullity
2546 );
2547 }
2548 let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2549 let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2550 return Ok(GaussianRemlPrepared {
2551 cache: cache.clone(),
2552 ywy,
2553 projected_rhs_squared,
2554 projected_rhs,
2555 n_observations: n,
2556 n_outputs: d,
2557 });
2558 }
2559
2560 let cache = gaussian_reml_eigen_cache_from_xtwx(xtwx, penalty, nullspace_dim)?;
2561 if n <= cache.nullity {
2562 crate::bail_invalid_estim!(
2563 "Gaussian REML requires n > nullspace dimension; got n={n}, nullity={}",
2564 cache.nullity
2565 );
2566 }
2567 let projected_rhs = dense_atb(cache.coefficient_basis.view(), xtwy.view());
2568 let projected_rhs_squared = projected_rhs.mapv(|value| value * value);
2569
2570 Ok(GaussianRemlPrepared {
2571 cache,
2572 ywy,
2573 projected_rhs_squared,
2574 projected_rhs,
2575 n_observations: n,
2576 n_outputs: d,
2577 })
2578}
2579
2580impl GaussianRemlPrepared {
2581 fn nu(&self) -> f64 {
2582 self.n_observations as f64 - self.cache.nullity as f64
2583 }
2584
2585 fn evaluate(&self, rho: f64) -> ObjectiveEval {
2586 evaluate_reml_parts(
2587 &self.cache,
2588 self.ywy.view(),
2589 self.projected_rhs_squared.view(),
2590 self.n_observations,
2591 self.n_outputs,
2592 rho,
2593 )
2594 }
2595
2596 fn coefficients(&self, lambda: f64) -> Array2<f64> {
2597 let mut scaled = self.projected_rhs.clone();
2598 for i in 0..self.cache.penalty_eigenvalues.len() {
2599 let scale = 1.0 / (1.0 + lambda * self.cache.penalty_eigenvalues[i]);
2600 for value in scaled.row_mut(i) {
2601 *value *= scale;
2602 }
2603 }
2604 dense_ab(self.cache.coefficient_basis.view(), scaled.view())
2605 }
2606
2607 fn sigma2(&self, lambda: f64) -> Array1<f64> {
2608 let nu = self.nu();
2609 Array1::from_iter((0..self.n_outputs).map(|j| {
2610 let mut fitted_quadratic = 0.0;
2611 for i in 0..self.cache.penalty_eigenvalues.len() {
2612 let denom = 1.0 + lambda * self.cache.penalty_eigenvalues[i];
2613 fitted_quadratic += self.projected_rhs_squared[[i, j]] / denom;
2614 }
2615 ((self.ywy[j] - fitted_quadratic).max(MIN_DEVIANCE)) / nu
2616 }))
2617 }
2618}
2619
2620fn optimize_rho(
2621 prepared: &GaussianRemlPrepared,
2622 init_rho: Option<f64>,
2623) -> Result<f64, EstimationError> {
2624 if prepared.cache.penalty_rank == 0 {
2625 return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2626 }
2627
2628 const GRID_INTERVALS: usize = 96;
2629 let mut stationary = Vec::<f64>::new();
2630 let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2631 let mut prev_rho = RHO_LOWER;
2632 let mut prev_eval = prepared.evaluate(prev_rho);
2633 grid.push((prev_rho, prev_eval.cost));
2634 for i in 1..=GRID_INTERVALS {
2635 let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2636 let eval = prepared.evaluate(rho);
2637 grid.push((rho, eval.cost));
2638 if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2639 push_candidate(
2640 &mut stationary,
2641 refine_stationary_rho(prepared, prev_rho, rho, 0.5 * (prev_rho + rho)),
2642 );
2643 }
2644 prev_rho = rho;
2645 prev_eval = eval;
2646 }
2647
2648 let mut candidates = stationary;
2649 push_candidate(&mut candidates, RHO_LOWER);
2650 push_candidate(&mut candidates, RHO_UPPER);
2651 if let Some(rho0) = init_rho {
2652 push_candidate(&mut candidates, rho0);
2653 }
2654 if let Some(rho) = refine_best_grid_cell(prepared, &grid) {
2655 push_candidate(&mut candidates, rho);
2656 }
2657
2658 candidates
2662 .into_iter()
2663 .map(|rho| (rho, prepared.evaluate(rho).cost))
2664 .min_by(|(_, a), (_, b)| a.total_cmp(b))
2665 .map(|(rho, _)| rho)
2666 .ok_or_else(|| {
2667 EstimationError::InvalidInput(
2668 "Gaussian REML optimizer produced no candidates".to_string(),
2669 )
2670 })
2671}
2672
2673fn refine_best_grid_cell(prepared: &GaussianRemlPrepared, grid: &[(f64, f64)]) -> Option<f64> {
2674 let best_idx = grid
2675 .iter()
2676 .enumerate()
2677 .filter(|(_, (_, cost))| cost.is_finite())
2678 .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2679 .map(|(idx, _)| idx)?;
2680 if best_idx == 0 || best_idx + 1 == grid.len() {
2681 return Some(grid[best_idx].0);
2682 }
2683 Some(refine_stationary_rho(
2699 prepared,
2700 grid[best_idx - 1].0,
2701 grid[best_idx + 1].0,
2702 grid[best_idx].0,
2703 ))
2704}
2705
2706fn fill_weighted_rhs_no_alloc(
2707 x: ArrayView2<'_, f64>,
2708 y: ArrayView2<'_, f64>,
2709 weights: Option<ArrayView1<'_, f64>>,
2710 workspace: &mut GaussianRemlNoAllocWorkspace,
2711) -> Result<(), EstimationError> {
2712 let d = y.ncols();
2713
2714 let (xtwy, ywy_full) = match weights {
2720 Some(w) => (fast_xt_diag_y(&x, &w, &y), fast_xt_diag_y(&y, &w, &y)),
2721 None => (fast_atb(&x, &y), fast_atb(&y, &y)),
2722 };
2723 workspace.xtwy.assign(&xtwy);
2724 for output in 0..d {
2725 workspace.ywy[output] = ywy_full[[output, output]];
2726 }
2727
2728 if workspace
2729 .xtwy
2730 .iter()
2731 .chain(workspace.ywy.iter())
2732 .any(|value| !value.is_finite())
2733 {
2734 crate::bail_invalid_estim!("Gaussian REML weighted cross-products must be finite");
2735 }
2736 Ok(())
2737}
2738
2739fn project_rhs_no_alloc(
2740 cache: &GaussianRemlEigenCache,
2741 workspace: &mut GaussianRemlNoAllocWorkspace,
2742) {
2743 let projected = fast_atb(&cache.coefficient_basis, &workspace.xtwy);
2746 workspace.projected_rhs.assign(&projected);
2747 let p = cache.penalty_eigenvalues.len();
2748 let d = workspace.ywy.len();
2749 for eig in 0..p {
2750 for output in 0..d {
2751 let value = workspace.projected_rhs[[eig, output]];
2752 workspace.projected_rhs_squared[[eig, output]] = value * value;
2753 }
2754 }
2755}
2756
2757fn evaluate_reml_parts(
2758 cache: &GaussianRemlEigenCache,
2759 ywy: ArrayView1<'_, f64>,
2760 projected_rhs_squared: ArrayView2<'_, f64>,
2761 n_observations: usize,
2762 n_outputs: usize,
2763 rho: f64,
2764) -> ObjectiveEval {
2765 let lambda = rho.exp();
2766 let nu = n_observations as f64 - cache.nullity as f64;
2767 let d = n_outputs as f64;
2768
2769 let (logdet_term, edf) = gaussian_reml_logdet_term(cache, rho, d);
2772 let mut eval = ObjectiveEval {
2773 cost: 0.0,
2774 grad: 0.0,
2775 hess: 0.0,
2776 edf,
2777 };
2778 eval += logdet_term;
2779 for output in 0..n_outputs {
2780 eval +=
2781 gaussian_reml_dispersion_term(cache, ywy, projected_rhs_squared, output, nu, lambda);
2782 }
2783 eval
2784}
2785
2786fn optimize_rho_no_alloc(
2787 cache: &GaussianRemlEigenCache,
2788 ywy: ArrayView1<'_, f64>,
2789 projected_rhs_squared: ArrayView2<'_, f64>,
2790 n_observations: usize,
2791 n_outputs: usize,
2792 init_rho: Option<f64>,
2793) -> Result<f64, EstimationError> {
2794 if cache.penalty_rank == 0 {
2795 return Ok(init_rho.unwrap_or(0.0).clamp(RHO_LOWER, RHO_UPPER));
2796 }
2797
2798 let lower_eval = evaluate_reml_parts(
2799 cache,
2800 ywy,
2801 projected_rhs_squared,
2802 n_observations,
2803 n_outputs,
2804 RHO_LOWER,
2805 );
2806
2807 let mut best_rho = RHO_LOWER;
2808 let mut best_cost = lower_eval.cost;
2809
2810 const GRID_INTERVALS: usize = 96;
2811 let mut grid = Vec::<(f64, f64)>::with_capacity(GRID_INTERVALS + 1);
2812 let mut prev_rho = RHO_LOWER;
2813 let mut prev_eval = lower_eval;
2814 grid.push((prev_rho, prev_eval.cost));
2815 for i in 1..=GRID_INTERVALS {
2816 let rho = RHO_LOWER + (RHO_UPPER - RHO_LOWER) * (i as f64) / (GRID_INTERVALS as f64);
2817 let eval = evaluate_reml_parts(
2818 cache,
2819 ywy,
2820 projected_rhs_squared,
2821 n_observations,
2822 n_outputs,
2823 rho,
2824 );
2825 grid.push((rho, eval.cost));
2826 if prev_eval.grad <= 0.0 && eval.grad >= 0.0 {
2827 let stationary_rho = refine_stationary_rho_no_alloc(
2828 cache,
2829 ywy,
2830 projected_rhs_squared,
2831 n_observations,
2832 n_outputs,
2833 prev_rho,
2834 rho,
2835 0.5 * (prev_rho + rho),
2836 );
2837 consider_rho_no_alloc(
2838 cache,
2839 ywy,
2840 projected_rhs_squared,
2841 n_observations,
2842 n_outputs,
2843 stationary_rho,
2844 &mut best_rho,
2845 &mut best_cost,
2846 );
2847 }
2848 prev_rho = rho;
2849 prev_eval = eval;
2850 }
2851 if let Some(best_idx) = grid
2852 .iter()
2853 .enumerate()
2854 .filter(|(_, (_, cost))| cost.is_finite())
2855 .min_by(|(_, (_, a)), (_, (_, b))| a.total_cmp(b))
2856 .map(|(idx, _)| idx)
2857 {
2858 let refined = if best_idx == 0 || best_idx + 1 == grid.len() {
2859 grid[best_idx].0
2860 } else {
2861 refine_stationary_rho_no_alloc(
2871 cache,
2872 ywy,
2873 projected_rhs_squared,
2874 n_observations,
2875 n_outputs,
2876 grid[best_idx - 1].0,
2877 grid[best_idx + 1].0,
2878 grid[best_idx].0,
2879 )
2880 };
2881 consider_rho_no_alloc(
2882 cache,
2883 ywy,
2884 projected_rhs_squared,
2885 n_observations,
2886 n_outputs,
2887 refined,
2888 &mut best_rho,
2889 &mut best_cost,
2890 );
2891 }
2892
2893 consider_rho_no_alloc(
2894 cache,
2895 ywy,
2896 projected_rhs_squared,
2897 n_observations,
2898 n_outputs,
2899 RHO_UPPER,
2900 &mut best_rho,
2901 &mut best_cost,
2902 );
2903 if let Some(rho0) = init_rho {
2904 consider_rho_no_alloc(
2905 cache,
2906 ywy,
2907 projected_rhs_squared,
2908 n_observations,
2909 n_outputs,
2910 rho0,
2911 &mut best_rho,
2912 &mut best_cost,
2913 );
2914 }
2915
2916 if best_cost.is_finite() {
2917 Ok(best_rho)
2918 } else {
2919 Err(EstimationError::InvalidInput(
2920 "Gaussian REML optimizer produced no finite candidates".to_string(),
2921 ))
2922 }
2923}
2924
2925fn consider_rho_no_alloc(
2926 cache: &GaussianRemlEigenCache,
2927 ywy: ArrayView1<'_, f64>,
2928 projected_rhs_squared: ArrayView2<'_, f64>,
2929 n_observations: usize,
2930 n_outputs: usize,
2931 rho: f64,
2932 best_rho: &mut f64,
2933 best_cost: &mut f64,
2934) {
2935 if !rho.is_finite() {
2936 return;
2937 }
2938 let candidate = rho.clamp(RHO_LOWER, RHO_UPPER);
2939 let eval = evaluate_reml_parts(
2940 cache,
2941 ywy,
2942 projected_rhs_squared,
2943 n_observations,
2944 n_outputs,
2945 candidate,
2946 );
2947 if eval.cost < *best_cost {
2948 *best_rho = candidate;
2949 *best_cost = eval.cost;
2950 }
2951}
2952
2953fn refine_stationary_rho_no_alloc(
2954 cache: &GaussianRemlEigenCache,
2955 ywy: ArrayView1<'_, f64>,
2956 projected_rhs_squared: ArrayView2<'_, f64>,
2957 n_observations: usize,
2958 n_outputs: usize,
2959 mut lo: f64,
2960 mut hi: f64,
2961 mut rho: f64,
2962) -> f64 {
2963 for _ in 0..80 {
2964 let eval = evaluate_reml_parts(
2965 cache,
2966 ywy,
2967 projected_rhs_squared,
2968 n_observations,
2969 n_outputs,
2970 rho,
2971 );
2972 if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
2973 return rho;
2974 }
2975 if eval.grad >= 0.0 {
2976 hi = rho;
2977 } else {
2978 lo = rho;
2979 }
2980 let newton = if eval.hess > 0.0 {
2981 let candidate = rho - eval.grad / eval.hess;
2982 (candidate > lo && candidate < hi).then_some(candidate)
2983 } else {
2984 None
2985 };
2986 if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
2987 break;
2988 }
2989 rho = newton.unwrap_or(0.5 * (lo + hi));
2990 }
2991 0.5 * (lo + hi)
2992}
2993
2994fn fill_coefficients_no_alloc(
2995 cache: &GaussianRemlEigenCache,
2996 workspace: &mut GaussianRemlNoAllocWorkspace,
2997 lambda: f64,
2998 mut coefficients: ArrayViewMut2<'_, f64>,
2999) {
3000 let p = cache.penalty_eigenvalues.len();
3001 let d = workspace.ywy.len();
3002 for eig in 0..p {
3003 let scale = 1.0 / (1.0 + lambda * cache.penalty_eigenvalues[eig]);
3004 for output in 0..d {
3005 workspace.scaled_projected_rhs[[eig, output]] =
3006 workspace.projected_rhs[[eig, output]] * scale;
3007 }
3008 }
3009
3010 for col in 0..p {
3011 for output in 0..d {
3012 let mut value = 0.0;
3013 for eig in 0..p {
3014 value += cache.coefficient_basis[[col, eig]]
3015 * workspace.scaled_projected_rhs[[eig, output]];
3016 }
3017 coefficients[[col, output]] = value;
3018 }
3019 }
3020}
3021
3022fn fill_fitted_no_alloc(
3023 x: ArrayView2<'_, f64>,
3024 coefficients: ArrayView2<'_, f64>,
3025 mut fitted: ArrayViewMut2<'_, f64>,
3026) {
3027 let n = x.nrows();
3028 let p = x.ncols();
3029 let d = coefficients.ncols();
3030 for row in 0..n {
3031 for output in 0..d {
3032 let mut value = 0.0;
3033 for col in 0..p {
3034 value += x[[row, col]] * coefficients[[col, output]];
3035 }
3036 fitted[[row, output]] = value;
3037 }
3038 }
3039}
3040
3041fn fill_sigma2_no_alloc(
3042 cache: &GaussianRemlEigenCache,
3043 ywy: ArrayView1<'_, f64>,
3044 projected_rhs_squared: ArrayView2<'_, f64>,
3045 n_observations: usize,
3046 n_outputs: usize,
3047 lambda: f64,
3048 mut sigma2: ArrayViewMut1<'_, f64>,
3049) {
3050 let nu = n_observations as f64 - cache.nullity as f64;
3051 for output in 0..n_outputs {
3052 let mut fitted_quadratic = 0.0;
3053 for eig in 0..cache.penalty_eigenvalues.len() {
3054 let denom = 1.0 + lambda * cache.penalty_eigenvalues[eig];
3055 fitted_quadratic += projected_rhs_squared[[eig, output]] / denom;
3056 }
3057 sigma2[output] = ((ywy[output] - fitted_quadratic).max(MIN_DEVIANCE)) / nu;
3058 }
3059}
3060
3061fn push_candidate(candidates: &mut Vec<f64>, rho: f64) {
3062 if rho.is_finite() {
3063 candidates.push(rho.clamp(RHO_LOWER, RHO_UPPER));
3064 }
3065}
3066
3067fn refine_stationary_rho(
3068 prepared: &GaussianRemlPrepared,
3069 mut lo: f64,
3070 mut hi: f64,
3071 mut rho: f64,
3072) -> f64 {
3073 for _ in 0..80 {
3074 let eval = prepared.evaluate(rho);
3075 if eval.grad.abs() <= GRAD_TOL * (1.0 + eval.cost.abs()) {
3076 return rho;
3077 }
3078 if eval.grad >= 0.0 {
3079 hi = rho;
3080 } else {
3081 lo = rho;
3082 }
3083 let newton = if eval.hess > 0.0 {
3084 let candidate = rho - eval.grad / eval.hess;
3085 (candidate > lo && candidate < hi).then_some(candidate)
3086 } else {
3087 None
3088 };
3089 if (hi - lo).abs() <= 1e-12 * (1.0 + rho.abs()) {
3090 break;
3091 }
3092 rho = newton.unwrap_or(0.5 * (lo + hi));
3093 }
3094 0.5 * (lo + hi)
3095}
3096
3097fn invert_lower_triangular(lower: &Array2<f64>) -> Result<Array2<f64>, EstimationError> {
3098 let n = lower.nrows();
3099 if lower.ncols() != n {
3100 crate::bail_invalid_estim!("lower-triangular solve requires a square matrix");
3101 }
3102 let eye = Array2::eye(n);
3103 solve_lower_triangular_matrix(lower, &eye)
3104}
3105
3106fn solve_lower_triangular_matrix(
3107 lower: &Array2<f64>,
3108 rhs: &Array2<f64>,
3109) -> Result<Array2<f64>, EstimationError> {
3110 let n = lower.nrows();
3111 if lower.ncols() != n || rhs.nrows() != n {
3112 crate::bail_invalid_estim!("lower-triangular solve dimension mismatch");
3113 }
3114 if let Some(out) = gam_gpu::try_solve_lower_triangular_matrix(lower.view(), rhs.view()) {
3115 return Ok(out);
3116 }
3117 let mut out = Array2::<f64>::zeros(rhs.dim());
3118 for col in 0..rhs.ncols() {
3119 for i in 0..n {
3120 let mut value = rhs[[i, col]];
3121 for k in 0..i {
3122 value -= lower[[i, k]] * out[[k, col]];
3123 }
3124 let diag = lower[[i, i]];
3125 if !(diag.is_finite() && diag.abs() > 0.0) {
3126 return Err(EstimationError::ModelIsIllConditioned {
3127 condition_number: f64::INFINITY,
3128 });
3129 }
3130 out[[i, col]] = value / diag;
3131 }
3132 }
3133 Ok(out)
3134}
3135
3136fn solve_spd_from_lower_factor(
3140 lower: &Array2<f64>,
3141 rhs: &Array2<f64>,
3142) -> Result<Array2<f64>, EstimationError> {
3143 let forward = solve_lower_triangular_matrix(lower, rhs)?;
3144 solve_upper_triangular_matrix(&lower.t().to_owned(), &forward)
3145}
3146
3147fn solve_upper_triangular_matrix(
3148 upper: &Array2<f64>,
3149 rhs: &Array2<f64>,
3150) -> Result<Array2<f64>, EstimationError> {
3151 let n = upper.nrows();
3152 if upper.ncols() != n || rhs.nrows() != n {
3153 crate::bail_invalid_estim!("upper-triangular solve dimension mismatch");
3154 }
3155 if let Some(out) = gam_gpu::try_solve_upper_triangular_matrix(upper.view(), rhs.view()) {
3156 return Ok(out);
3157 }
3158 let mut out = Array2::<f64>::zeros(rhs.dim());
3159 for col in 0..rhs.ncols() {
3160 for i_rev in 0..n {
3161 let i = n - 1 - i_rev;
3162 let mut value = rhs[[i, col]];
3163 for k in (i + 1)..n {
3164 value -= upper[[i, k]] * out[[k, col]];
3165 }
3166 let diag = upper[[i, i]];
3167 if !(diag.is_finite() && diag.abs() > 0.0) {
3168 return Err(EstimationError::ModelIsIllConditioned {
3169 condition_number: f64::INFINITY,
3170 });
3171 }
3172 out[[i, col]] = value / diag;
3173 }
3174 }
3175 Ok(out)
3176}
3177
3178#[cfg(test)]
3179mod tests {
3180 use super::*;
3181 use ndarray::array;
3182
3183 #[test]
3184 fn edf_does_not_double_count_penalty_nullspace() {
3185 let x = array![[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0],];
3186 let y = array![[0.0], [1.0], [1.8], [3.2], [4.1]];
3187 let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3188 let result =
3189 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3190 .expect("small full-rank Gaussian REML fit");
3191
3192 assert!(result.edf >= result.cache.nullity as f64);
3193 assert!(result.edf <= x.ncols() as f64 + 1.0e-10);
3194 }
3195
3196 #[test]
3197 fn multi_output_duplicate_columns_match_scalar_fit() {
3198 let x = array![
3199 [1.0, -1.0],
3200 [1.0, -0.5],
3201 [1.0, 0.0],
3202 [1.0, 0.5],
3203 [1.0, 1.0],
3204 [1.0, 1.5],
3205 ];
3206 let y1 = array![0.5, 0.2, 0.0, 0.3, 1.1, 2.0];
3207 let y = Array2::from_shape_fn(
3208 (y1.len(), 2),
3209 |(i, j)| if j == 0 { y1[i] } else { 2.0 * y1[i] },
3210 );
3211 let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3212
3213 let scalar =
3214 gaussian_reml_closed_form(x.view(), y1.view(), penalty.view(), None, Some(0.0))
3215 .expect("scalar Gaussian REML fit");
3216 let multi =
3217 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3218 .expect("multi-output Gaussian REML fit");
3219
3220 assert!((multi.rho - scalar.rho).abs() <= 1.0e-8);
3221 for i in 0..x.ncols() {
3222 assert!((multi.coefficients[[i, 0]] - scalar.coefficients[i]).abs() <= 1.0e-8);
3223 assert!((multi.coefficients[[i, 1]] - 2.0 * scalar.coefficients[i]).abs() <= 1.0e-8);
3224 }
3225 }
3226
3227 #[test]
3228 fn warm_start_reuses_cache_and_lambda_seed() {
3229 let x = array![
3230 [1.0, -1.0],
3231 [1.0, -0.25],
3232 [1.0, 0.5],
3233 [1.0, 1.25],
3234 [1.0, 2.0],
3235 ];
3236 let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3237 let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3238
3239 let cold =
3240 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0))
3241 .expect("cold fit");
3242 let warm_start = GaussianRemlWarmStart::from_multi_result(&cold);
3243 let warm = gaussian_reml_multi_closed_form_warm_started(
3244 x.view(),
3245 y.view(),
3246 penalty.view(),
3247 None,
3248 Some(&warm_start),
3249 )
3250 .expect("warm-started fit");
3251
3252 assert!((cold.lambda - warm.lambda).abs() <= 1.0e-10);
3253 assert_eq!(cold.cache.xtwx_fingerprint, warm.cache.xtwx_fingerprint);
3254 for i in 0..x.ncols() {
3255 assert!((cold.coefficients[[i, 0]] - warm.coefficients[[i, 0]]).abs() <= 1.0e-10);
3256 }
3257 }
3258
3259 #[test]
3260 fn warm_start_cache_rejects_different_penalty_geometry() {
3261 let x = array![
3262 [1.0, -1.0],
3263 [1.0, -0.25],
3264 [1.0, 0.5],
3265 [1.0, 1.25],
3266 [1.0, 2.0],
3267 ];
3268 let y = array![[0.1], [0.4], [0.7], [1.4], [2.2]];
3269 let penalty_a = array![[0.0, 0.0], [0.0, 1.0]];
3270 let penalty_b = array![[1.0, -1.0], [-1.0, 1.0]];
3271
3272 let first =
3273 gaussian_reml_multi_closed_form(x.view(), y.view(), penalty_a.view(), None, Some(0.0))
3274 .expect("first fit");
3275 let warm_start = GaussianRemlWarmStart::from_multi_result(&first);
3276 let err = gaussian_reml_multi_closed_form_warm_started(
3277 x.view(),
3278 y.view(),
3279 penalty_b.view(),
3280 None,
3281 Some(&warm_start),
3282 )
3283 .expect_err("penalty-mismatched cache must be rejected");
3284
3285 assert!(err.to_string().contains("penalty mismatch"));
3286 }
3287
3288 #[test]
3289 fn no_alloc_cache_path_matches_allocating_fit() {
3290 let x = array![
3291 [1.0, -1.0, 0.25],
3292 [1.0, -0.5, 0.10],
3293 [1.0, 0.0, -0.20],
3294 [1.0, 0.5, -0.05],
3295 [1.0, 1.0, 0.30],
3296 [1.0, 1.5, 0.60],
3297 ];
3298 let y = array![
3299 [0.0, 0.2],
3300 [0.3, 0.1],
3301 [0.4, -0.1],
3302 [0.9, 0.3],
3303 [1.6, 0.8],
3304 [2.2, 1.2],
3305 ];
3306 let weights = array![1.0, 0.8, 1.2, 1.1, 0.9, 1.3];
3307 let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 4.0]];
3308
3309 let allocating = gaussian_reml_multi_closed_form_with_cache(
3310 x.view(),
3311 y.view(),
3312 penalty.view(),
3313 Some(weights.view()),
3314 Some(1.0),
3315 None,
3316 )
3317 .expect("allocating fit");
3318 let mut workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols());
3319 let mut coefficients = Array2::zeros((x.ncols(), y.ncols()));
3320 let mut fitted = Array2::zeros(y.dim());
3321 let mut sigma2 = Array1::zeros(y.ncols());
3322
3323 let no_alloc = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3324 x.view(),
3325 y.view(),
3326 penalty.view(),
3327 Some(weights.view()),
3328 Some(allocating.lambda),
3329 &allocating.cache,
3330 &mut workspace,
3331 coefficients.view_mut(),
3332 fitted.view_mut(),
3333 sigma2.view_mut(),
3334 )
3335 .expect("no-alloc cached fit");
3336
3337 assert!((no_alloc.lambda - allocating.lambda).abs() <= 1.0e-10);
3338 assert!((no_alloc.reml_score - allocating.reml_score).abs() <= 1.0e-8);
3339 assert!((no_alloc.reml_grad_rho - allocating.reml_grad_rho).abs() <= 1.0e-8);
3340 assert!((no_alloc.reml_hess_rho - allocating.reml_hess_rho).abs() <= 1.0e-8);
3341 assert!((no_alloc.edf - allocating.edf).abs() <= 1.0e-10);
3342 for i in 0..x.ncols() {
3343 for j in 0..y.ncols() {
3344 assert!((coefficients[[i, j]] - allocating.coefficients[[i, j]]).abs() <= 1.0e-8);
3345 }
3346 }
3347 for i in 0..x.nrows() {
3348 for j in 0..y.ncols() {
3349 assert!((fitted[[i, j]] - allocating.fitted[[i, j]]).abs() <= 1.0e-8);
3350 }
3351 }
3352 for j in 0..y.ncols() {
3353 assert!((sigma2[j] - allocating.sigma2[j]).abs() <= 1.0e-10);
3354 }
3355 }
3356
3357 #[test]
3358 fn no_alloc_cache_path_rejects_bad_shapes_and_penalty_mismatch() {
3359 let x = array![[1.0, -1.0], [1.0, 0.0], [1.0, 1.0], [1.0, 2.0]];
3360 let y = array![[0.0], [0.2], [0.9], [1.8]];
3361 let penalty = array![[0.0, 0.0], [0.0, 1.0]];
3362 let cache = build_gaussian_reml_eigen_cache(x.view(), penalty.view(), None)
3363 .expect("Gaussian REML cache");
3364
3365 let mut bad_workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols() + 1);
3366 let mut coefficients = Array2::zeros((x.ncols(), y.ncols()));
3367 let mut fitted = Array2::zeros(y.dim());
3368 let mut sigma2 = Array1::zeros(y.ncols());
3369 let err = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3370 x.view(),
3371 y.view(),
3372 penalty.view(),
3373 None,
3374 Some(1.0),
3375 &cache,
3376 &mut bad_workspace,
3377 coefficients.view_mut(),
3378 fitted.view_mut(),
3379 sigma2.view_mut(),
3380 )
3381 .expect_err("workspace shape mismatch must be rejected");
3382 assert!(err.to_string().contains("workspace shape mismatch"));
3383
3384 let penalty_mismatch = array![[1.0, -1.0], [-1.0, 1.0]];
3385 let mut workspace = GaussianRemlNoAllocWorkspace::new(x.ncols(), y.ncols());
3386 let err = gaussian_reml_multi_closed_form_with_cache_no_alloc(
3387 x.view(),
3388 y.view(),
3389 penalty_mismatch.view(),
3390 None,
3391 Some(1.0),
3392 &cache,
3393 &mut workspace,
3394 coefficients.view_mut(),
3395 fitted.view_mut(),
3396 sigma2.view_mut(),
3397 )
3398 .expect_err("penalty mismatch must be rejected");
3399 assert!(err.to_string().contains("penalty mismatch"));
3400 }
3401
    /// Selects which scalar of a forward fit a test helper extracts.
    #[derive(Clone, Copy, Debug)]
    enum ForwardScalar {
        /// The fit's `lambda` field.
        Lambda,
        /// The fit's `reml_score` field.
        RemlScore,
        /// One entry of the fit's `coefficients`, indexed `(row, col)`.
        Coefficient(usize, usize),
        /// One entry of the fit's `fitted` values, indexed `(row, col)`.
        Fitted(usize, usize),
        /// The fit's `edf` field.
        Edf,
    }
3410
3411 fn finite_difference_design() -> Array2<f64> {
3412 Array2::from_shape_fn((20, 5), |(row, col)| {
3413 let t = (row as f64 - 9.5) / 10.0;
3414 match col {
3415 0 => 1.0,
3416 1 => t,
3417 2 => 0.5 * (3.0 * t * t - 1.0),
3418 3 => 0.5 * (5.0 * t * t * t - 3.0 * t),
3419 4 => (35.0 * t.powi(4) - 30.0 * t * t + 3.0) / 8.0,
3420 _ => unreachable!(),
3421 }
3422 })
3423 }
3424
3425 fn finite_difference_response(outputs: usize) -> Array2<f64> {
3426 Array2::from_shape_fn((20, outputs), |(row, output)| {
3437 let t = (row as f64 - 9.5) / 10.0;
3438 let phase = output as f64 + 1.0;
3439 0.2 + 0.25 * phase * t - 0.12 * t * t
3440 + (0.08 + 0.03 * phase) * (1.1 * t + 0.3 * phase).sin()
3441 + 0.05 * (7.0 * t + 0.5 * phase).sin()
3442 })
3443 }
3444
3445 fn finite_difference_penalty() -> Array2<f64> {
3446 Array2::from_diag(&array![0.0, 0.8, 1.2, 1.7, 2.3])
3447 }
3448
3449 fn finite_difference_weights() -> Array1<f64> {
3450 Array1::from_shape_fn(20, |row| {
3451 let t = (row as f64 - 9.5) / 10.0;
3452 1.0 + 0.025 * (1.1 * t).sin() + 0.01 * t
3453 })
3454 }
3455
3456 fn one_hot_objective_try(
3463 x: ArrayView2<'_, f64>,
3464 y: ArrayView2<'_, f64>,
3465 penalty: ArrayView2<'_, f64>,
3466 weights: ArrayView1<'_, f64>,
3467 target: ForwardScalar,
3468 ) -> Option<f64> {
3469 let fit = gaussian_reml_multi_closed_form_with_cache(
3470 x,
3471 y,
3472 penalty,
3473 Some(weights),
3474 Some(0.85),
3475 None,
3476 )
3477 .ok()?;
3478 Some(match target {
3479 ForwardScalar::Lambda => fit.lambda,
3480 ForwardScalar::RemlScore => fit.reml_score,
3481 ForwardScalar::Coefficient(row, col) => fit.coefficients[[row, col]],
3482 ForwardScalar::Fitted(row, col) => fit.fitted[[row, col]],
3483 ForwardScalar::Edf => fit.edf,
3484 })
3485 }
3486
3487 fn one_hot_objective(
3488 x: ArrayView2<'_, f64>,
3489 y: ArrayView2<'_, f64>,
3490 penalty: ArrayView2<'_, f64>,
3491 weights: ArrayView1<'_, f64>,
3492 target: ForwardScalar,
3493 ) -> f64 {
3494 one_hot_objective_try(x, y, penalty, weights, target)
3495 .expect("finite-difference forward fit")
3496 }
3497
3498 fn one_hot_backward(
3499 x: ArrayView2<'_, f64>,
3500 y: ArrayView2<'_, f64>,
3501 penalty: ArrayView2<'_, f64>,
3502 weights: ArrayView1<'_, f64>,
3503 target: ForwardScalar,
3504 ) -> GaussianRemlBackwardResult {
3505 let mut grad_coefficients = Array2::<f64>::zeros((x.ncols(), y.ncols()));
3506 let mut grad_fitted = Array2::<f64>::zeros(y.dim());
3507 let (grad_lambda, grad_score, grad_edf, coefficient_upstream, fitted_upstream) =
3508 match target {
3509 ForwardScalar::Lambda => (1.0, 0.0, 0.0, None, None),
3510 ForwardScalar::RemlScore => (0.0, 1.0, 0.0, None, None),
3511 ForwardScalar::Coefficient(row, col) => {
3512 grad_coefficients[[row, col]] = 1.0;
3513 (0.0, 0.0, 0.0, Some(grad_coefficients.view()), None)
3514 }
3515 ForwardScalar::Fitted(row, col) => {
3516 grad_fitted[[row, col]] = 1.0;
3517 (0.0, 0.0, 0.0, None, Some(grad_fitted.view()))
3518 }
3519 ForwardScalar::Edf => (0.0, 0.0, 1.0, None, None),
3520 };
3521 gaussian_reml_multi_closed_form_backward(
3522 x,
3523 y,
3524 penalty,
3525 Some(weights),
3526 Some(0.85),
3527 grad_lambda,
3528 coefficient_upstream,
3529 fitted_upstream,
3530 grad_score,
3531 grad_edf,
3532 )
3533 .expect("analytic backward VJP")
3534 }
3535
/// Asserts that an analytic derivative agrees with its finite-difference estimate.
///
/// The tolerance is the larger of an absolute floor (`1e-6`) and a relative bound (`1e-6` times
/// the larger magnitude of the two values). A NaN on either side fails the comparison.
///
/// # Panics
///
/// Panics with a message prefixed by `label` when the two values differ by more than the
/// tolerance.
fn assert_fd_close(label: &str, analytic: f64, finite_difference: f64) {
    const REL_TOL: f64 = 1.0e-6;
    const ABS_TOL: f64 = 1.0e-6;

    let magnitude = analytic.abs().max(finite_difference.abs());
    let tol = ABS_TOL.max(REL_TOL * magnitude);
    let diff = (analytic - finite_difference).abs();
    assert!(
        diff <= tol,
        "{label}: analytic={analytic:.12e}, finite_difference={finite_difference:.12e}, diff={diff:.3e}, tol={tol:.3e}"
    );
}
3546
/// Estimates the derivative at zero of `eval` using Richardson-extrapolated central differences.
///
/// Central differences are taken at the step sizes `1e-3, 5e-4, ..., 3.125e-5` (each half the
/// previous one). Every adjacent pair `(h, h/2)` is combined into the fourth-order estimate
/// `d(h/2) + (d(h/2) - d(h)) / 3`, giving five estimates. The estimate that changed least
/// from its predecessor is returned, which guards against both truncation error at large
/// steps and round-off at small ones.
///
/// Each central difference is evaluated once and shared by the two Richardson pairs that use
/// it, so `eval` is called 12 times in total (two calls per step size).
fn adaptive_central_difference(mut eval: impl FnMut(f64) -> f64) -> f64 {
    const COARSEST_STEP: f64 = 1.0e-3;
    const RICHARDSON_LEVELS: usize = 5;

    let mut central = |h: f64| (eval(h) - eval(-h)) / (2.0 * h);

    let mut h = COARSEST_STEP;
    let mut coarse = central(h);

    let mut best = f64::NAN;
    let mut best_delta = f64::INFINITY;
    let mut previous: Option<f64> = None;
    for _ in 0..RICHARDSON_LEVELS {
        // Halving is exact in binary floating point, so these are the same step sizes as the
        // decimal sequence 5e-4, 2.5e-4, 1.25e-4, 6.25e-5, 3.125e-5.
        h *= 0.5;
        let fine = central(h);
        let estimate = fine + (fine - coarse) / 3.0;
        match previous {
            Some(prev) => {
                let delta = (estimate - prev).abs();
                if delta < best_delta {
                    best_delta = delta;
                    best = estimate;
                }
            }
            // The first estimate has no predecessor; keep it as the provisional answer.
            None => best = estimate,
        }
        previous = Some(estimate);
        // The fine difference of this level is the coarse difference of the next.
        coarse = fine;
    }
    best
}
3570
3571 fn assert_backward_matches_forward_finite_difference(outputs: usize) {
3572 let x = finite_difference_design();
3573 let y = finite_difference_response(outputs);
3574 let penalty = finite_difference_penalty();
3575 let weights = finite_difference_weights();
3576 let targets = [
3577 ForwardScalar::Lambda,
3578 ForwardScalar::RemlScore,
3579 ForwardScalar::Coefficient(3, outputs - 1),
3580 ForwardScalar::Fitted(12, outputs - 1),
3581 ForwardScalar::Edf,
3582 ];
3583 for target in targets {
3584 let backward =
3585 one_hot_backward(x.view(), y.view(), penalty.view(), weights.view(), target);
3586
3587 for row in 0..x.nrows() {
3588 for col in 0..x.ncols() {
3589 let eval = |delta: f64| {
3590 let mut candidate = x.clone();
3591 candidate[[row, col]] += delta;
3592 one_hot_objective(
3593 candidate.view(),
3594 y.view(),
3595 penalty.view(),
3596 weights.view(),
3597 target,
3598 )
3599 };
3600 let fd = adaptive_central_difference(eval);
3601 assert_fd_close(
3602 &format!("target={target:?} x[{row},{col}]"),
3603 backward.grad_x[[row, col]],
3604 fd,
3605 );
3606 }
3607 }
3608
3609 for row in 0..y.nrows() {
3610 for col in 0..y.ncols() {
3611 let eval = |delta: f64| {
3612 let mut candidate = y.clone();
3613 candidate[[row, col]] += delta;
3614 one_hot_objective(
3615 x.view(),
3616 candidate.view(),
3617 penalty.view(),
3618 weights.view(),
3619 target,
3620 )
3621 };
3622 let fd = adaptive_central_difference(eval);
3623 assert_fd_close(
3624 &format!("target={target:?} y[{row},{col}]"),
3625 backward.grad_y[[row, col]],
3626 fd,
3627 );
3628 }
3629 }
3630
3631 for row in 0..weights.len() {
3632 let eval = |delta: f64| {
3633 let mut candidate = weights.clone();
3634 candidate[row] += delta;
3635 one_hot_objective(x.view(), y.view(), penalty.view(), candidate.view(), target)
3636 };
3637 let fd = adaptive_central_difference(eval);
3638 assert_fd_close(
3639 &format!("target={target:?} weights[{row}]"),
3640 backward.grad_weights[row],
3641 fd,
3642 );
3643 }
3644
3645 let null_index = 0usize; let probe_h = 1.0e-3_f64; for r in 0..penalty.nrows() {
3669 for c in 0..penalty.ncols() {
3670 if r == null_index || c == null_index {
3671 continue;
3672 }
3673 let eval = |delta: f64| {
3674 let mut candidate = penalty.clone();
3675 candidate[[r, c]] += delta;
3676 one_hot_objective(
3677 x.view(),
3678 y.view(),
3679 candidate.view(),
3680 weights.view(),
3681 target,
3682 )
3683 };
3684 let cone_safe = {
3685 let mut s_plus = penalty.clone();
3686 let mut s_minus = penalty.clone();
3687 s_plus[[r, c]] += probe_h;
3688 s_minus[[r, c]] -= probe_h;
3689 one_hot_objective_try(
3690 x.view(),
3691 y.view(),
3692 s_plus.view(),
3693 weights.view(),
3694 target,
3695 )
3696 .is_some()
3697 && one_hot_objective_try(
3698 x.view(),
3699 y.view(),
3700 s_minus.view(),
3701 weights.view(),
3702 target,
3703 )
3704 .is_some()
3705 };
3706 if !cone_safe {
3707 continue;
3708 }
3709 let fd = adaptive_central_difference(eval);
3710 assert_fd_close(
3711 &format!("target={target:?} penalty[{r},{c}]"),
3712 backward.grad_penalty[[r, c]],
3713 fd,
3714 );
3715 }
3716 }
3717 }
3718 }
3719
/// Exhaustive finite-difference gradient check for a single-column response.
#[test]
fn scalar_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
    const OUTPUTS: usize = 1;
    assert_backward_matches_forward_finite_difference(OUTPUTS);
}
3724
/// Exhaustive finite-difference gradient check for a three-column response.
#[test]
fn multi_output_backward_matches_forward_finite_difference_for_all_x_y_and_weight_entries() {
    const OUTPUTS: usize = 3;
    assert_backward_matches_forward_finite_difference(OUTPUTS);
}
3729
/// Spot-checks the analytic VJP against plain central finite differences of the forward fit.
///
/// The upstream gradient mixes lambda, REML score, coefficients and fitted values; one entry
/// each of `x`, `y` and the weights is checked, plus three entries of the penalty outside
/// row/column 0.
#[test]
fn backward_vjp_matches_finite_difference() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.3, -0.1],
        [1.0, 0.2, 0.4],
        [1.0, 0.8, 0.1],
        [1.0, 1.4, 0.5],
        [1.0, 2.0, 0.9],
    ];
    let y = array![
        [0.1, -0.2],
        [0.2, 0.1],
        [0.7, 0.0],
        [1.1, 0.3],
        [1.8, 0.9],
        [2.4, 1.4],
    ];
    let weights = array![1.0, 0.9, 1.1, 1.2, 0.8, 1.3];
    let penalty = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.2], [0.0, 0.2, 1.7]];

    let upstream_lambda = 0.17;
    let upstream_score = -0.11;
    let upstream_coefficients = array![[0.2, -0.1], [0.05, 0.03], [-0.04, 0.07]];
    let upstream_fitted = array![
        [0.01, -0.02],
        [0.03, 0.01],
        [-0.01, 0.02],
        [0.04, -0.03],
        [0.02, 0.05],
        [-0.02, 0.01],
    ];

    let backward = gaussian_reml_multi_closed_form_backward(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        Some(0.8),
        upstream_lambda,
        Some(upstream_coefficients.view()),
        Some(upstream_fitted.view()),
        upstream_score,
        0.0,
    )
    .expect("backward VJP");

    // Scalar whose gradient the VJP represents: the upstream-weighted sum of the fit outputs.
    let scalarized_fit = |x_eval: &Array2<f64>,
                          y_eval: &Array2<f64>,
                          s_eval: &Array2<f64>,
                          w_eval: &Array1<f64>,
                          context: &str| {
        let fit = gaussian_reml_multi_closed_form_with_cache(
            x_eval.view(),
            y_eval.view(),
            s_eval.view(),
            Some(w_eval.view()),
            Some(0.8),
            None,
        )
        .expect(context);
        upstream_lambda * fit.lambda
            + upstream_score * fit.reml_score
            + (&fit.coefficients * &upstream_coefficients).sum()
            + (&fit.fitted * &upstream_fitted).sum()
    };
    let objective = |x_eval: &Array2<f64>, y_eval: &Array2<f64>, w_eval: &Array1<f64>| {
        scalarized_fit(x_eval, y_eval, &penalty, w_eval, "fit for objective")
    };
    let objective_s = |s_eval: &Array2<f64>| {
        scalarized_fit(&x, &y, s_eval, &weights, "fit for penalty objective")
    };

    let eps = 1.0e-6;
    let tolerance = 2.0e-4;
    let central = |plus: f64, minus: f64| (plus - minus) / (2.0 * eps);
    // Copy of `base` with one matrix entry shifted by `delta`.
    let shifted_matrix = |base: &Array2<f64>, index: (usize, usize), delta: f64| {
        let mut moved = base.clone();
        moved[[index.0, index.1]] += delta;
        moved
    };

    assert!(objective(&x, &y, &weights).is_finite());

    // Design entry (3, 2).
    let x_plus = shifted_matrix(&x, (3, 2), eps);
    let x_minus = shifted_matrix(&x, (3, 2), -eps);
    let fd_x = central(
        objective(&x_plus, &y, &weights),
        objective(&x_minus, &y, &weights),
    );
    assert!(
        (fd_x - backward.grad_x[[3, 2]]).abs() <= tolerance,
        "grad_x mismatch: analytic={} fd={}",
        backward.grad_x[[3, 2]],
        fd_x
    );

    // Response entry (4, 1).
    let y_plus = shifted_matrix(&y, (4, 1), eps);
    let y_minus = shifted_matrix(&y, (4, 1), -eps);
    let fd_y = central(
        objective(&x, &y_plus, &weights),
        objective(&x, &y_minus, &weights),
    );
    assert!(
        (fd_y - backward.grad_y[[4, 1]]).abs() <= tolerance,
        "grad_y mismatch: analytic={} fd={}",
        backward.grad_y[[4, 1]],
        fd_y
    );

    // Weight entry 2.
    let mut w_plus = weights.clone();
    w_plus[2] += eps;
    let mut w_minus = weights.clone();
    w_minus[2] -= eps;
    let fd_w = central(objective(&x, &y, &w_plus), objective(&x, &y, &w_minus));
    assert!(
        (fd_w - backward.grad_weights[2]).abs() <= tolerance,
        "grad_weight mismatch: analytic={} fd={}",
        backward.grad_weights[2],
        fd_w
    );

    // Penalty entries outside row/column 0.
    for (r, c) in [(1usize, 1usize), (1, 2), (2, 2)] {
        let s_plus = shifted_matrix(&penalty, (r, c), eps);
        let s_minus = shifted_matrix(&penalty, (r, c), -eps);
        let fd_s = central(objective_s(&s_plus), objective_s(&s_minus));
        assert!(
            (fd_s - backward.grad_penalty[[r, c]]).abs() <= tolerance,
            "grad_penalty[{r},{c}] mismatch: analytic={} fd={}",
            backward.grad_penalty[[r, c]],
            fd_s
        );
    }
}
3875
/// The batched eigen-cache builder must agree with building each cache individually.
///
/// Ranks, fingerprints, log-determinants, penalty eigenvalues and the coefficient basis are
/// compared per fit. Array lengths and shapes are asserted explicitly before the element-wise
/// comparison, because `zip` stops at the shorter side and would otherwise let a shape
/// mismatch pass unnoticed.
#[test]
fn batched_eigen_cache_matches_per_fit_build() {
    let penalty = array![[0.0, 0.0], [0.0, 1.0]];
    let gram_matrices = vec![
        array![[4.0, 1.0], [1.0, 3.0]],
        array![[2.5, -0.5], [-0.5, 1.7]],
        array![[7.2, 0.3], [0.3, 5.1]],
    ];

    let batched =
        build_gaussian_reml_eigen_cache_batched(gram_matrices.clone(), penalty.view(), None);
    assert_eq!(batched.len(), 3);
    assert_eq!(batched.len(), gram_matrices.len());

    for (xtwx, batched_cache) in gram_matrices.iter().zip(batched.iter()) {
        let single = gaussian_reml_eigen_cache_from_xtwx(xtwx.clone(), penalty.view(), None)
            .expect("per-fit cache");
        let batched_cache = batched_cache.as_ref().expect("batched cache");

        assert_eq!(batched_cache.penalty_rank, single.penalty_rank);
        assert_eq!(batched_cache.nullity, single.nullity);
        assert_eq!(batched_cache.xtwx_fingerprint, single.xtwx_fingerprint);
        assert_eq!(
            batched_cache.penalty_fingerprint,
            single.penalty_fingerprint
        );
        assert!((batched_cache.logdet_xtwx - single.logdet_xtwx).abs() <= 1.0e-12);
        assert!(
            (batched_cache.logdet_penalty_positive - single.logdet_penalty_positive).abs()
                <= 1.0e-12
        );

        // Guard against `zip` silently truncating a longer side.
        assert_eq!(
            batched_cache.penalty_eigenvalues.len(),
            single.penalty_eigenvalues.len()
        );
        for (a, b) in batched_cache
            .penalty_eigenvalues
            .iter()
            .zip(single.penalty_eigenvalues.iter())
        {
            assert!((a - b).abs() <= 1.0e-12);
        }

        assert_eq!(
            batched_cache.coefficient_basis.dim(),
            single.coefficient_basis.dim()
        );
        for (a, b) in batched_cache
            .coefficient_basis
            .iter()
            .zip(single.coefficient_basis.iter())
        {
            assert!((a - b).abs() <= 1.0e-12);
        }
    }
}
3928
/// Regression fixture for the scalar rho optimizers.
///
/// Both the allocating and the no-alloc optimizer must land on `rho ~= 4.3251`, and the cost
/// there must be lower than at `rho = -18.9277503549` (presumably the competing, earlier
/// stationary point of this fixture).
#[test]
fn scalar_rho_optimizer_chooses_lowest_cost_stationary_point() {
    let expected_rho = 4.3251059890;
    let competing_rho = -18.9277503549;
    let rhs_squared = [0.361060218768292_f64, 0.01014486085547482_f64];

    let cache = GaussianRemlEigenCache {
        penalty_eigenvalues: array![5.2430192311066924e-05, 81734184.18548436],
        eigenvectors: Array2::eye(2),
        coefficient_basis: Array2::eye(2),
        xtwx_fingerprint: 0,
        penalty_fingerprint: 0,
        logdet_xtwx: 0.0,
        logdet_penalty_positive: 0.0,
        penalty_rank: 2,
        nullity: 0,
    };
    let prepared = GaussianRemlPrepared {
        cache: cache.clone(),
        ywy: array![0.5021347226586624],
        projected_rhs_squared: array![[rhs_squared[0]], [rhs_squared[1]]],
        projected_rhs: array![[rhs_squared[0].sqrt()], [rhs_squared[1].sqrt()]],
        n_observations: 100,
        n_outputs: 1,
    };

    let rho = optimize_rho(&prepared, None).expect("allocating rho optimizer");
    let no_alloc_rho = optimize_rho_no_alloc(
        &cache,
        prepared.ywy.view(),
        prepared.projected_rhs_squared.view(),
        prepared.n_observations,
        prepared.n_outputs,
        None,
    )
    .expect("no-alloc rho optimizer");

    let distance_from_expected = (rho - expected_rho).abs();
    assert!(
        distance_from_expected < 1.0e-6,
        "rho optimizer selected {rho}, expected the lower-cost later stationary point"
    );
    let optimizer_disagreement = (no_alloc_rho - rho).abs();
    assert!(
        optimizer_disagreement < 1.0e-8,
        "no-alloc optimizer selected {no_alloc_rho}, allocating selected {rho}"
    );

    let selected_cost = prepared.evaluate(rho).cost;
    let competing_cost = prepared.evaluate(competing_rho).cost;
    assert!(selected_cost < competing_cost);
}
3975
/// The backward pass that reuses an existing forward fit must reproduce the backward pass
/// that takes the raw inputs, for every returned gradient.
///
/// All four gradients (`grad_x`, `grad_y`, `grad_penalty`, `grad_weights`) are compared.
/// Shapes are asserted first, because `zip` stops at the shorter side and would otherwise
/// hide a shape mismatch.
#[test]
fn backward_from_fit_matches_backward_with_refit() {
    let x = array![[1.0, -0.9], [1.0, -0.4], [1.0, 0.1], [1.0, 0.6], [1.0, 1.1],];
    let y = array![[0.2, -0.1], [0.4, 0.1], [0.7, 0.3], [1.0, 0.5], [1.5, 0.8]];
    let penalty = array![[0.0, 0.0], [0.0, 1.5]];
    let weights = array![1.05, 0.95, 1.01, 0.99, 1.03];

    let refit = gaussian_reml_multi_closed_form_backward(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        Some(0.85),
        0.2,
        None,
        None,
        -0.1,
        0.0,
    )
    .expect("refit backward");

    let fit = gaussian_reml_multi_closed_form_with_cache(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        Some(0.85),
        None,
    )
    .expect("forward fit");
    let from_fit = gaussian_reml_multi_closed_form_backward_from_fit(
        x.view(),
        y.view(),
        penalty.view(),
        Some(weights.view()),
        &fit,
        0.2,
        None,
        None,
        -0.1,
        0.0,
    )
    .expect("from_fit backward");

    assert_eq!(refit.grad_x.dim(), from_fit.grad_x.dim());
    for (a, b) in refit.grad_x.iter().zip(from_fit.grad_x.iter()) {
        assert!((a - b).abs() <= 1.0e-12);
    }

    assert_eq!(refit.grad_y.dim(), from_fit.grad_y.dim());
    for (a, b) in refit.grad_y.iter().zip(from_fit.grad_y.iter()) {
        assert!((a - b).abs() <= 1.0e-12);
    }

    // The penalty gradient is part of the result as well and must agree too.
    assert_eq!(refit.grad_penalty.dim(), from_fit.grad_penalty.dim());
    for (a, b) in refit.grad_penalty.iter().zip(from_fit.grad_penalty.iter()) {
        assert!((a - b).abs() <= 1.0e-12);
    }

    assert_eq!(refit.grad_weights.dim(), from_fit.grad_weights.dim());
    for (a, b) in refit.grad_weights.iter().zip(from_fit.grad_weights.iter()) {
        assert!((a - b).abs() <= 1.0e-12);
    }
}
4034
/// The backward pass must return finite, correctly shaped gradients instead of an error when
/// the fit it is given is degenerate.
///
/// The forward fit itself is well-posed; the test then overwrites the stored `reml_hess_rho`
/// with zero before calling the backward pass (presumably the trigger for the degraded code
/// path) and only checks shapes and finiteness of the result.
#[test]
fn backward_degrades_gracefully_when_k_is_near_singular() {
    let x = array![
        [1.0, -1.0, 0.5],
        [1.0, -0.5, 0.2],
        [1.0, 0.0, -0.1],
        [1.0, 0.5, 0.3],
        [1.0, 1.0, 0.8],
        [1.0, 1.5, 1.1],
        [1.0, 2.0, 1.5],
        [1.0, 2.5, 2.0],
        [1.0, 3.0, 2.6],
        [1.0, 3.5, 3.1],
    ];
    let y = array![
        [0.1],
        [0.3],
        [0.4],
        [0.7],
        [1.0],
        [1.5],
        [2.0],
        [2.7],
        [3.3],
        [4.0]
    ];
    // 3 x 3 identity penalty.
    let penalty = Array2::<f64>::eye(3);

    let forward =
        gaussian_reml_multi_closed_form(x.view(), y.view(), penalty.view(), None, Some(0.0));
    let mut fit = forward.expect("forward fit must succeed for well-posed input");
    fit.reml_hess_rho = 0.0;

    let result = gaussian_reml_multi_closed_form_backward_from_fit(
        x.view(),
        y.view(),
        penalty.view(),
        None,
        &fit,
        1.0,
        None,
        None,
        1.0,
        1.0,
    )
    .expect("backward must NOT error on near-singular K");

    let (n, p) = (x.nrows(), x.ncols());
    assert_eq!(result.grad_x.dim(), (n, p));
    assert_eq!(result.grad_y.dim(), (y.nrows(), y.ncols()));
    assert_eq!(result.grad_penalty.dim(), (p, p));
    assert_eq!(result.grad_weights.dim(), n);

    result
        .grad_x
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_x must be finite, got {v}"));
    result
        .grad_y
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_y must be finite, got {v}"));
    result
        .grad_penalty
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_penalty must be finite, got {v}"));
    result
        .grad_weights
        .iter()
        .for_each(|v| assert!(v.is_finite(), "grad_weights must be finite, got {v}"));
}
4117}
4118
4119pub struct GaussianRemlBlocksBackwardAnalytic {
4123 pub grad_designs: Vec<Array2<f64>>,
4124 pub grad_penalties: Vec<Array2<f64>>,
4125 pub grad_y: Array2<f64>,
4126 pub grad_weights: Array1<f64>,
4127}
4128
4129pub fn gaussian_reml_fit_blocks_backward_analytic(
4138 designs: &[Array2<f64>],
4139 penalties_raw: &[Array2<f64>],
4140 y: ArrayView1<'_, f64>,
4141 weights: ArrayView1<'_, f64>,
4142 rhos: &[f64],
4143 grad_coefficients: Option<ArrayView2<'_, f64>>,
4144 grad_fitted: Option<ArrayView2<'_, f64>>,
4145 grad_lambdas: Option<ArrayView1<'_, f64>>,
4146 grad_log_lambdas: Option<ArrayView1<'_, f64>>,
4147 grad_reml_score: f64,
4148 grad_edf: Option<ArrayView1<'_, f64>>,
4149) -> Result<GaussianRemlBlocksBackwardAnalytic, EstimationError> {
4150 let n = y.len();
4151 let f_blocks = designs.len();
4152 let mut offsets = Vec::with_capacity(f_blocks + 1);
4153 offsets.push(0_usize);
4154 for design in designs {
4155 offsets.push(offsets.last().copied().unwrap() + design.ncols());
4156 }
4157 let p_total = *offsets.last().unwrap();
4158 if n == 0 || p_total == 0 {
4159 return Err(EstimationError::InvalidInput(
4160 "gaussian_reml_fit_blocks_backward requires non-empty rows and at least one coefficient column"
4161 .to_string(),
4162 ));
4163 }
4164
4165 if rhos.len() != f_blocks {
4166 return Err(EstimationError::InvalidInput(format!(
4167 "log_lambdas length mismatch: expected {f_blocks}, got {}",
4168 rhos.len()
4169 )));
4170 }
4171 if let Some(gc) = grad_coefficients {
4172 if gc.dim() != (p_total, 1) {
4173 return Err(EstimationError::InvalidInput(format!(
4174 "grad_coefficients shape mismatch: expected {}x1, got {}x{}",
4175 p_total,
4176 gc.nrows(),
4177 gc.ncols()
4178 )));
4179 }
4180 }
4181 if let Some(gf) = grad_fitted {
4182 if gf.dim() != (n, 1) {
4183 return Err(EstimationError::InvalidInput(format!(
4184 "grad_fitted shape mismatch: expected {}x1, got {}x{}",
4185 n,
4186 gf.nrows(),
4187 gf.ncols()
4188 )));
4189 }
4190 }
4191 if !grad_reml_score.is_finite() {
4192 return Err(EstimationError::InvalidInput(format!(
4193 "grad_reml_score must be finite; got {grad_reml_score}"
4194 )));
4195 }
4196 if let Some(vec) = grad_lambdas {
4197 if vec.len() != f_blocks {
4198 return Err(EstimationError::InvalidInput(format!(
4199 "grad_lambdas length mismatch: expected {f_blocks}, got {}",
4200 vec.len()
4201 )));
4202 }
4203 }
4204 if let Some(vec) = grad_log_lambdas {
4205 if vec.len() != f_blocks {
4206 return Err(EstimationError::InvalidInput(format!(
4207 "grad_log_lambdas length mismatch: expected {f_blocks}, got {}",
4208 vec.len()
4209 )));
4210 }
4211 }
4212 if let Some(vec) = grad_edf {
4213 if vec.len() != f_blocks {
4214 return Err(EstimationError::InvalidInput(format!(
4215 "grad_edf length mismatch: expected {f_blocks}, got {}",
4216 vec.len()
4217 )));
4218 }
4219 }
4220 if let Some(gc) = grad_coefficients {
4221 if let Some(((row, col), value)) = gc.indexed_iter().find(|(_, value)| !value.is_finite()) {
4222 return Err(EstimationError::InvalidInput(format!(
4223 "grad_coefficients[{row},{col}] must be finite; got {value}"
4224 )));
4225 }
4226 }
4227 if let Some(gf) = grad_fitted {
4228 if let Some(((row, col), value)) = gf.indexed_iter().find(|(_, value)| !value.is_finite()) {
4229 return Err(EstimationError::InvalidInput(format!(
4230 "grad_fitted[{row},{col}] must be finite; got {value}"
4231 )));
4232 }
4233 }
4234 if let Some(vec) = grad_lambdas {
4235 if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4236 return Err(EstimationError::InvalidInput(format!(
4237 "grad_lambdas[{block}] must be finite; got {value}"
4238 )));
4239 }
4240 }
4241 if let Some(vec) = grad_log_lambdas {
4242 if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4243 return Err(EstimationError::InvalidInput(format!(
4244 "grad_log_lambdas[{block}] must be finite; got {value}"
4245 )));
4246 }
4247 }
4248 if let Some(vec) = grad_edf {
4249 if let Some((block, value)) = vec.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4250 return Err(EstimationError::InvalidInput(format!(
4251 "grad_edf[{block}] must be finite; got {value}"
4252 )));
4253 }
4254 }
4255 for (block, design) in designs.iter().enumerate() {
4256 if let Some(((row, col), value)) =
4257 design.indexed_iter().find(|(_, value)| !value.is_finite())
4258 {
4259 return Err(EstimationError::InvalidInput(format!(
4260 "designs[{block}][{row},{col}] must be finite; got {value}"
4261 )));
4262 }
4263 }
4264 for (block, penalty) in penalties_raw.iter().enumerate() {
4265 if let Some(((row, col), value)) =
4266 penalty.indexed_iter().find(|(_, value)| !value.is_finite())
4267 {
4268 return Err(EstimationError::InvalidInput(format!(
4269 "penalties[{block}][{row},{col}] must be finite; got {value}"
4270 )));
4271 }
4272 }
4273 if let Some((row, value)) = y.iter().enumerate().find(|(_, value)| !value.is_finite()) {
4274 return Err(EstimationError::InvalidInput(format!(
4275 "y[{row}] must be finite; got {value}"
4276 )));
4277 }
4278 if let Some((row, value)) = weights
4279 .iter()
4280 .enumerate()
4281 .find(|(_, value)| !value.is_finite() || **value < 0.0)
4282 {
4283 return Err(EstimationError::InvalidInput(format!(
4284 "weights[{row}] must be finite and non-negative; got {value}"
4285 )));
4286 }
4287
4288 let mut z = Array2::<f64>::zeros((n, p_total));
4289 for k in 0..f_blocks {
4290 z.slice_mut(s![.., offsets[k]..offsets[k + 1]])
4291 .assign(&designs[k]);
4292 }
4293
4294 let penalties: Vec<Array2<f64>> = penalties_raw
4295 .iter()
4296 .map(|p| {
4297 let mut out = p.clone();
4298 gam_linalg::matrix::symmetrize_in_place(&mut out);
4299 out
4300 })
4301 .collect();
4302 let mut ranks = Vec::with_capacity(f_blocks);
4303 let mut pinvs = Vec::with_capacity(f_blocks);
4304 for penalty in &penalties {
4305 let (rank, pinv) = gam_linalg::utils::block_penalty_rank_and_pinv(penalty)?;
4306 ranks.push(rank);
4307 pinvs.push(pinv);
4308 }
4309
4310 let lambdas = Array1::from_iter(rhos.iter().map(|rho| rho.exp()));
4311 if let Some((block, lambda)) = lambdas
4312 .iter()
4313 .enumerate()
4314 .find(|(_, lambda)| !lambda.is_finite() || **lambda <= 0.0)
4315 {
4316 return Err(EstimationError::InvalidInput(format!(
4317 "exp(log_lambdas[{block}]) must be finite and positive; got {lambda}"
4318 )));
4319 }
4320 let mut k_matrix = fast_xt_diag_x(&z.view(), &weights);
4321 for block in 0..f_blocks {
4322 let lambda = lambdas[block];
4323 for local_i in 0..penalties[block].nrows() {
4324 let global_i = offsets[block] + local_i;
4325 for local_j in 0..penalties[block].ncols() {
4326 let global_j = offsets[block] + local_j;
4327 k_matrix[[global_i, global_j]] += lambda * penalties[block][[local_i, local_j]];
4328 }
4329 }
4330 }
4331 let r = gam_linalg::utils::invert_spd_with_ridge(&k_matrix, 0.0)?;
4332
4333 let mut xtwy = Array1::<f64>::zeros(p_total);
4334 for row in 0..n {
4335 let wy = weights[row] * y[row];
4336 for col in 0..p_total {
4337 xtwy[col] += z[[row, col]] * wy;
4338 }
4339 }
4340 let beta = r.dot(&xtwy);
4341 let fitted = z.dot(&beta);
4342 if let Some((col, value)) = beta
4343 .iter()
4344 .enumerate()
4345 .find(|(_, value)| !value.is_finite())
4346 {
4347 return Err(EstimationError::InvalidInput(format!(
4348 "solved coefficient {col} is non-finite: {value}"
4349 )));
4350 }
4351 let residual = &y.to_owned() - &fitted;
4352 let weighted_residual = &residual * &weights.to_owned();
4353 let ywy = y
4354 .iter()
4355 .zip(weights.iter())
4356 .map(|(&yi, &wi)| wi * yi * yi)
4357 .sum::<f64>();
4358 let q_raw = ywy - xtwy.dot(&beta);
4359 if !q_raw.is_finite() {
4360 return Err(EstimationError::InvalidInput(format!(
4361 "Gaussian REML residual quadratic form must be finite; got {q_raw}"
4362 )));
4363 }
4364 let q = q_raw.max(1.0e-300);
4365 let nullity = penalties
4366 .iter()
4367 .zip(ranks.iter())
4368 .map(|(penalty, rank)| penalty.nrows().saturating_sub(*rank))
4369 .sum::<usize>();
4370 let nu = n as f64 - nullity as f64;
4371 if !(nu.is_finite() && nu > 0.0) {
4372 return Err(EstimationError::InvalidInput(format!(
4373 "Gaussian REML residual degrees of freedom must be positive; got {nu}"
4374 )));
4375 }
4376 let tau = nu / q;
4377 let tau_q = -nu / (q * q);
4378 if !(tau.is_finite() && tau_q.is_finite()) {
4379 return Err(EstimationError::InvalidInput(format!(
4380 "Gaussian REML scale derivatives are non-finite: tau={tau}, tau_q={tau_q}"
4381 )));
4382 }
4383
4384 let mut grad_z = Array2::<f64>::zeros((n, p_total));
4385 let mut g_kernel = Array2::<f64>::zeros((p_total, p_total));
4386 let mut h_kernel = Array1::<f64>::zeros(p_total);
4387 let mut q_kernel = 0.0_f64;
4388 let mut j_blocks: Vec<Array2<f64>> = penalties
4389 .iter()
4390 .map(|p| Array2::<f64>::zeros(p.dim()))
4391 .collect();
4392
4393 let mut beta_tilde = Array1::<f64>::zeros(p_total);
4394 if let Some(gc) = grad_coefficients {
4395 beta_tilde += &gc.column(0).to_owned();
4396 }
4397 if let Some(gf) = grad_fitted {
4398 let gf_col = gf.column(0).to_owned();
4399 beta_tilde += &z.t().dot(&gf_col);
4400 for row in 0..n {
4401 for col in 0..p_total {
4402 grad_z[[row, col]] += gf_col[row] * beta[col];
4403 }
4404 }
4405 }
4406
4407 let u = r.dot(&beta_tilde);
4412 h_kernel += &u;
4413 for i in 0..p_total {
4414 for j in 0..p_total {
4415 g_kernel[[i, j]] -= 0.5 * (beta[i] * u[j] + u[i] * beta[j]);
4416 }
4417 }
4418
4419 let mut alpha = Array1::<f64>::zeros(f_blocks);
4420 if let Some(gl) = grad_lambdas {
4421 for block in 0..f_blocks {
4422 alpha[block] += gl[block] * lambdas[block];
4423 }
4424 }
4425 if let Some(grho) = grad_log_lambdas {
4426 alpha += &grho.to_owned();
4427 }
4428
4429 let mut p_betas = Vec::with_capacity(f_blocks);
4430 let mut m_vectors = Vec::with_capacity(f_blocks);
4431 let mut rp_matrices = Vec::with_capacity(f_blocks);
4432 let mut rpr_matrices = Vec::with_capacity(f_blocks);
4433 let mut b_values = Array1::<f64>::zeros(f_blocks);
4434 let mut t_values = Array1::<f64>::zeros(f_blocks);
4435
4436 for block in 0..f_blocks {
4437 let start = offsets[block];
4438 let end = offsets[block + 1];
4439 let beta_k = beta.slice(s![start..end]).to_owned();
4440 let s_beta = penalties[block].dot(&beta_k);
4441 let lambda = lambdas[block];
4442 let lambda_s_beta = s_beta.mapv(|value| lambda * value);
4443 let mut p_beta = Array1::<f64>::zeros(p_total);
4444 for local_i in 0..(end - start) {
4445 p_beta[start + local_i] = lambda_s_beta[local_i];
4446 }
4447 let weighted_penalty = penalties[block].mapv(|value| lambda * value);
4448 let rp_block = r.slice(s![.., start..end]).dot(&weighted_penalty);
4449 let mut rp = Array2::<f64>::zeros((p_total, p_total));
4450 rp.slice_mut(s![.., start..end]).assign(&rp_block);
4451 let rpr = rp_block.dot(&r.slice(s![start..end, ..]));
4452 let m = r.slice(s![.., start..end]).dot(&lambda_s_beta);
4453 b_values[block] = beta.dot(&p_beta);
4454 t_values[block] = (0..(end - start))
4455 .map(|local_i| rp_block[[start + local_i, local_i]])
4456 .sum::<f64>();
4457 alpha[block] -= u.dot(&p_beta);
4458 p_betas.push(p_beta);
4459 m_vectors.push(m);
4460 rp_matrices.push(rp);
4461 rpr_matrices.push(rpr);
4462 }
4463
4464 if grad_reml_score != 0.0 {
4465 q_kernel += 0.5 * grad_reml_score * tau;
4466 g_kernel += &(r.clone() * (0.5 * grad_reml_score));
4467 for block in 0..f_blocks {
4468 j_blocks[block] -= &(pinvs[block].clone() * (0.5 * grad_reml_score / lambdas[block]));
4469 }
4470 }
4471
4472 let mut trace_pairs = Array2::<f64>::zeros((f_blocks, f_blocks));
4473 for i in 0..f_blocks {
4474 for j in 0..f_blocks {
4475 trace_pairs[[i, j]] = gam_linalg::utils::trace_of_product(
4476 rp_matrices[i].view(),
4477 rp_matrices[j].view(),
4478 );
4479 }
4480 }
4481
4482 if let Some(ge) = grad_edf {
4483 for edf_block in 0..f_blocks {
4484 let scale = ge[edf_block];
4485 if scale == 0.0 {
4486 continue;
4487 }
4488 let start = offsets[edf_block];
4489 let end = offsets[edf_block + 1];
4490 g_kernel += &(rpr_matrices[edf_block].clone() * scale);
4491 j_blocks[edf_block] -= &(r.slice(s![start..end, start..end]).to_owned() * scale);
4492 for rho_block in 0..f_blocks {
4493 alpha[rho_block] += scale * trace_pairs[[edf_block, rho_block]];
4494 if rho_block == edf_block {
4495 alpha[rho_block] -= scale * t_values[edf_block];
4496 }
4497 }
4498 }
4499 }
4500
4501 if let Some((block, value)) = alpha
4502 .iter()
4503 .enumerate()
4504 .find(|(_, value)| !value.is_finite())
4505 {
4506 return Err(EstimationError::InvalidInput(format!(
4507 "rho adjoint seed for block {block} is non-finite: {value}"
4508 )));
4509 }
4510
4511 if alpha.iter().any(|value| *value != 0.0) {
4512 let mut outer_h = Array2::<f64>::zeros((f_blocks, f_blocks));
4513 for k in 0..f_blocks {
4514 for j in 0..f_blocks {
4515 let beta_pk_r_pj_beta = p_betas[k].dot(&m_vectors[j]);
4516 outer_h[[k, j]] = 0.5 * trace_pairs[[k, j]] + tau * beta_pk_r_pj_beta
4517 - if k == j {
4518 0.5 * (t_values[k] + tau * b_values[k])
4519 } else {
4520 0.0
4521 }
4522 - 0.5 * tau_q * b_values[k] * b_values[j];
4523 }
4524 }
4525 gam_linalg::matrix::symmetrize_in_place(&mut outer_h);
4529 if let Some(((row, col), value)) =
4530 outer_h.indexed_iter().find(|(_, value)| !value.is_finite())
4531 {
4532 return Err(EstimationError::InvalidInput(format!(
4533 "outer rho curvature entry ({row},{col}) is non-finite: {value}"
4534 )));
4535 }
4536 let rho_adj =
4537 gam_linalg::utils::solve_symmetric_vector_with_floor(&outer_h, &alpha, 1.0e-10)?;
4538 if let Some((block, value)) = rho_adj
4539 .iter()
4540 .enumerate()
4541 .find(|(_, value)| !value.is_finite())
4542 {
4543 return Err(EstimationError::InvalidInput(format!(
4544 "outer rho adjoint for block {block} is non-finite: {value}"
4545 )));
4546 }
4547 let weighted_b_sum = rho_adj
4548 .iter()
4549 .zip(b_values.iter())
4550 .map(|(&zk, &bk)| zk * bk)
4551 .sum::<f64>();
4552 q_kernel += 0.5 * tau_q * weighted_b_sum;
4553 for block in 0..f_blocks {
4554 let zk = rho_adj[block];
4555 if zk == 0.0 {
4556 continue;
4557 }
4558 g_kernel -= &(rpr_matrices[block].clone() * (0.5 * zk));
4559 let m = &m_vectors[block];
4560 for i in 0..p_total {
4561 h_kernel[i] += tau * zk * m[i];
4562 for j in 0..p_total {
4563 g_kernel[[i, j]] -= 0.5 * tau * zk * (beta[i] * m[j] + m[i] * beta[j]);
4564 }
4565 }
4566 let start = offsets[block];
4567 let end = offsets[block + 1];
4568 j_blocks[block] += &(r.slice(s![start..end, start..end]).to_owned() * (0.5 * zk));
4569 for i in 0..(end - start) {
4570 for j in 0..(end - start) {
4571 j_blocks[block][[i, j]] += 0.5 * tau * zk * beta[start + i] * beta[start + j];
4572 }
4573 }
4574 }
4575 }
4576
4577 for row in 0..n {
4578 for col in 0..p_total {
4579 grad_z[[row, col]] += -2.0 * q_kernel * weighted_residual[row] * beta[col];
4580 }
4581 }
4582 let zg = z.dot(&g_kernel);
4583 for row in 0..n {
4584 for col in 0..p_total {
4585 grad_z[[row, col]] += 2.0 * weights[row] * zg[[row, col]];
4586 }
4587 }
4588 let wy = y.to_owned() * &weights.to_owned();
4589 for row in 0..n {
4590 for col in 0..p_total {
4591 grad_z[[row, col]] += wy[row] * h_kernel[col];
4592 }
4593 }
4594
4595 let mut grad_y = Array2::<f64>::zeros((n, 1));
4596 let zh = z.dot(&h_kernel);
4597 for row in 0..n {
4598 grad_y[[row, 0]] = 2.0 * q_kernel * weighted_residual[row] + weights[row] * zh[row];
4599 }
4600
4601 let mut grad_weights = Array1::<f64>::zeros(n);
4602 for row in 0..n {
4603 let diag_zgz = (0..p_total)
4604 .map(|col| z[[row, col]] * zg[[row, col]])
4605 .sum::<f64>();
4606 grad_weights[row] = q_kernel * residual[row] * residual[row] + diag_zgz + y[row] * zh[row];
4607 }
4608
4609 let mut grad_penalties = Vec::with_capacity(f_blocks);
4610 for block in 0..f_blocks {
4611 let start = offsets[block];
4612 let end = offsets[block + 1];
4613 let mut local = g_kernel.slice(s![start..end, start..end]).to_owned();
4614 for i in 0..(end - start) {
4615 for j in 0..(end - start) {
4616 local[[i, j]] += q_kernel * beta[start + i] * beta[start + j];
4617 }
4618 }
4619 local += &j_blocks[block];
4620 local *= lambdas[block];
4621 gam_linalg::matrix::symmetrize_in_place(&mut local);
4622 grad_penalties.push(local);
4623 }
4624
4625 let mut grad_designs = Vec::with_capacity(f_blocks);
4626 for block in 0..f_blocks {
4627 grad_designs.push(
4628 grad_z
4629 .slice(s![.., offsets[block]..offsets[block + 1]])
4630 .to_owned(),
4631 );
4632 }
4633
4634 Ok(GaussianRemlBlocksBackwardAnalytic {
4635 grad_designs,
4636 grad_penalties,
4637 grad_y,
4638 grad_weights,
4639 })
4640}
4641
4642pub struct DenseFisherGaussianFit {
4646 pub coefficients: Array2<f64>,
4647 pub fitted: Array2<f64>,
4648 pub sigma2: Array1<f64>,
4649 pub objective: f64,
4650}
4651
4652pub fn add_block_diagonal_penalty(
4655 hessian: &mut Array2<f64>,
4656 penalty: ArrayView2<'_, f64>,
4657 lambda: f64,
4658 n_outputs: usize,
4659) -> Result<(), EstimationError> {
4660 let k = penalty.ncols();
4661 if penalty.nrows() != k {
4662 return Err(EstimationError::InvalidInput(format!(
4663 "penalty must be square for dense Fisher fit; got {}x{}",
4664 penalty.nrows(),
4665 penalty.ncols()
4666 )));
4667 }
4668 if hessian.dim() != (k * n_outputs, k * n_outputs) {
4669 return Err(EstimationError::InvalidInput(
4670 "dense Fisher Hessian shape mismatch while adding penalty".to_string(),
4671 ));
4672 }
4673 for output in 0..n_outputs {
4674 let offset = output * k;
4675 for row in 0..k {
4676 for col in 0..k {
4677 let s_sym = 0.5 * (penalty[[row, col]] + penalty[[col, row]]);
4678 hessian[[offset + row, offset + col]] += lambda * s_sym;
4679 }
4680 }
4681 }
4682 Ok(())
4683}
4684
4685pub fn dense_fisher_gaussian_fit(
4692 design: ArrayView2<'_, f64>,
4693 y: ArrayView2<'_, f64>,
4694 penalty: ArrayView2<'_, f64>,
4695 row_weights: ArrayView1<'_, f64>,
4696 fisher_w: ArrayView3<'_, f64>,
4697 lambda: f64,
4698 latent_prior_score: f64,
4699) -> Result<DenseFisherGaussianFit, EstimationError> {
4700 let n_obs = design.nrows();
4701 let k = design.ncols();
4702 let n_outputs = y.ncols();
4703 let mut hessian = crate::pirls::dense_block_xtwx(design, fisher_w, Some(row_weights))?;
4704 add_block_diagonal_penalty(&mut hessian, penalty, lambda, n_outputs)?;
4705 let rhs = crate::pirls::dense_block_xtwy(design, fisher_w, y, Some(row_weights))?;
4706 let beta_vec =
4707 gam_linalg::utils::solve_dense_block_system(&hessian, &rhs, "dense Fisher Gaussian")
4708 .map_err(EstimationError::InvalidInput)?;
4709 let mut coefficients = Array2::<f64>::zeros((k, n_outputs));
4710 for output in 0..n_outputs {
4711 for col in 0..k {
4712 coefficients[[col, output]] = beta_vec[output * k + col];
4713 }
4714 }
4715 let fitted = design.dot(&coefficients);
4716 let mut sigma2 = Array1::<f64>::zeros(n_outputs);
4717 let mut objective = latent_prior_score;
4718 for row in 0..n_obs {
4719 for a in 0..n_outputs {
4720 let ra = y[[row, a]] - fitted[[row, a]];
4721 sigma2[a] += row_weights[row] * ra * ra;
4722 for b in 0..n_outputs {
4723 objective += 0.5
4724 * row_weights[row]
4725 * ra
4726 * fisher_w[[row, a, b]]
4727 * (y[[row, b]] - fitted[[row, b]]);
4728 }
4729 }
4730 }
4731 for output in 0..n_outputs {
4732 sigma2[output] /= (n_obs.saturating_sub(k).max(1)) as f64;
4733 let beta_col = coefficients.column(output);
4734 let s_beta = penalty.dot(&beta_col);
4735 objective += 0.5 * lambda * beta_col.dot(&s_beta);
4736 }
4737 Ok(DenseFisherGaussianFit {
4738 coefficients,
4739 fitted,
4740 sigma2,
4741 objective,
4742 })
4743}