1use super::*;
2use gam_linalg::matrix::symmetrize_in_place;
3use crate::mixture_link::fisher_weight_jet5_for_inverse_link;
4use gam_problem::InverseLink;
5
/// Row count at or above which the per-row Fisher-weight derivative evaluation is
/// fanned out over the rayon pool (see `parallelize_firth_derivative_rows`).
pub(crate) const FIRTH_DERIVATIVE_PARALLEL_MIN_N: usize = 16384;

/// Threshold on `min_eig / max_eig` of the reduced Fisher matrix `I_r`; below it a
/// near-singularity warning is logged while building the Firth operator.
pub(crate) const FIRTH_REDUCED_FISHER_RCOND_WARN: f64 = 1.0e-10;
14
15struct FirthReducedCore {
21 w: Array1<f64>,
22 w1: Array1<f64>,
23 w2: Array1<f64>,
24 w3: Array1<f64>,
25 w4: Array1<f64>,
26 k_reduced: Array2<f64>,
27 half_log_det: f64,
28 h_diag: Array1<f64>,
29}
30
31pub(crate) struct FirthSecondDirEyeCache {
36 eye: Array2<f64>,
39 eta_rhs: Array2<f64>,
41 p_b_rhs: Array2<f64>,
43 p_bx: Vec<Array2<f64>>,
45 pu_qv: Vec<Array2<f64>>,
47}
48
49impl<'a> RemlState<'a> {
50 pub(crate) fn xt_diag_x_dense_into(
51 x: &Array2<f64>,
52 diag: &Array1<f64>,
53 weighted: &mut Array2<f64>,
54 ) -> Array2<f64> {
55 super::assembly::xt_diag_x_dense_into(x, diag, weighted)
56 }
57
58 #[inline]
59 pub(crate) fn parallelize_firth_derivative_rows(n: usize) -> bool {
60 n >= FIRTH_DERIVATIVE_PARALLEL_MIN_N && rayon::current_num_threads() > 1
61 }
62
63 pub(crate) fn row_scale(x: &Array2<f64>, scale: &Array1<f64>) -> Array2<f64> {
64 let mut out = Array2::<f64>::zeros(x.raw_dim());
65 super::assembly::row_scale_dense_into(x, scale, &mut out);
66 out
67 }
68
69 #[inline]
70 pub(crate) fn dense_product_likely_uses_inner_parallelism(
71 m: usize,
72 n: usize,
73 k: usize,
74 ) -> bool {
75 const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
80 const PAR_MIN_LONG_DIM: usize = 256;
81 let flop_scale = m.saturating_mul(n).saturating_mul(k);
82 let long_dim = m.max(n).max(k);
83 flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM
84 }
85
86 #[inline]
87 pub(crate) fn should_join_independent_dense_products(
88 products: &[(usize, usize, usize)],
89 ) -> bool {
90 const JOIN_MIN_TOTAL_FLOP_SCALE: usize = 128 * 1024;
91 if rayon::current_num_threads() <= 1 {
92 return false;
93 }
94 let mut total_flop_scale = 0usize;
95 for &(m, n, k) in products {
96 if Self::dense_product_likely_uses_inner_parallelism(m, n, k) {
97 return false;
98 }
99 total_flop_scale =
100 total_flop_scale.saturating_add(m.saturating_mul(n).saturating_mul(k));
101 }
102 total_flop_scale >= JOIN_MIN_TOTAL_FLOP_SCALE
103 }
104
105 #[inline]
114 pub(crate) fn scale_rows_by_inverse_observation_weight_sqrt(
115 out: &mut Array2<f64>,
116 observation_weight_sqrt: Option<&Array1<f64>>,
117 ) {
118 let Some(scale) = observation_weight_sqrt else {
119 return;
120 };
121 super::assembly::row_scale_dense_in_place_by_inverse_positive_or_zero(out, scale);
122 }
123
124 #[inline]
129 pub(crate) fn fisher_weight_derivatives(
130 link: &InverseLink,
131 eta: f64,
132 ) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
133 fisher_weight_jet5_for_inverse_link(link, eta)
134 }
135
136 #[inline]
137 pub(crate) fn cholesky_pivots_are_numerically_resolved(chol_diag: &Array1<f64>) -> bool {
138 let mut min_pivot_sq = f64::INFINITY;
139 let mut max_pivot_sq = 0.0_f64;
140 for &pivot in chol_diag {
141 if !pivot.is_finite() || pivot <= 0.0 {
142 return false;
143 }
144 let pivot_sq = pivot * pivot;
145 min_pivot_sq = min_pivot_sq.min(pivot_sq);
146 max_pivot_sq = max_pivot_sq.max(pivot_sq);
147 }
148 if !min_pivot_sq.is_finite() {
149 return false;
150 }
151 let scale = max_pivot_sq.max(1.0);
152 let floor = (chol_diag.len().max(1) as f64) * f64::EPSILON * scale;
153 min_pivot_sq > floor
154 }
155
156 pub(crate) fn reduced_fisher_inverse_and_half_logdet(
157 fisher_reduced: &Array2<f64>,
158 ) -> Result<(Array2<f64>, f64), EstimationError> {
159 let r = fisher_reduced.nrows();
160 assert_eq!(r, fisher_reduced.ncols());
161 let mut k_reduced = Array2::<f64>::zeros((r, r));
162 if r == 0 {
163 return Ok((k_reduced, 0.0));
164 }
165
166 if let Ok(chol) = fisher_reduced.cholesky(Side::Lower) {
167 let chol_diag = chol.diag();
168 if Self::cholesky_pivots_are_numerically_resolved(&chol_diag) {
169 let half_log_det = chol_diag.iter().map(|d| d.ln()).sum::<f64>();
170 for col in 0..r {
171 let mut e_col = Array1::<f64>::zeros(r);
172 e_col[col] = 1.0;
173 let solved = chol.solvevec(&e_col);
174 k_reduced.column_mut(col).assign(&solved);
175 }
176 return Ok((k_reduced, half_log_det));
177 }
178 }
179
180 let (evals_ir, evecs_ir) = fisher_reduced
181 .eigh(Side::Lower)
182 .map_err(EstimationError::EigendecompositionFailed)?;
183 let max_eval = evals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
184 let tol = (r.max(1) as f64) * f64::EPSILON * max_eval;
185 let mut kept_positive_direction = false;
186 let mut half_log_det = 0.0_f64;
187 for (eig_idx, &eig) in evals_ir.iter().enumerate() {
188 if eig > tol {
189 kept_positive_direction = true;
190 half_log_det += 0.5 * eig.ln();
191 let inv = eig.recip();
192 let vec = evecs_ir.column(eig_idx).to_owned();
193 for row in 0..r {
194 for col in 0..r {
195 k_reduced[[row, col]] += inv * vec[row] * vec[col];
196 }
197 }
198 }
199 }
200 if !kept_positive_direction {
201 return Err(EstimationError::ModelIsIllConditioned {
202 condition_number: f64::INFINITY,
203 });
204 }
205 Ok((k_reduced, half_log_det))
206 }
207
208 pub(crate) fn fill_fisher_weight_derivative_arrays(
209 link: &InverseLink,
210 eta: &Array1<f64>,
211 w: &mut Array1<f64>,
212 w1: &mut Array1<f64>,
213 w2: &mut Array1<f64>,
214 w3: &mut Array1<f64>,
215 w4: &mut Array1<f64>,
216 ) -> Result<(), EstimationError> {
217 assert_eq!(eta.len(), w.len());
218 assert_eq!(eta.len(), w1.len());
219 assert_eq!(eta.len(), w2.len());
220 assert_eq!(eta.len(), w3.len());
221 assert_eq!(eta.len(), w4.len());
222
223 if Self::parallelize_firth_derivative_rows(eta.len()) {
224 let values: Result<Vec<_>, EstimationError> = eta
225 .par_iter()
226 .map(|&ei| Self::fisher_weight_derivatives(link, ei))
227 .collect();
228 for (i, (value, first, second, third, fourth)) in values?.into_iter().enumerate() {
229 w[i] = value;
230 w1[i] = first;
231 w2[i] = second;
232 w3[i] = third;
233 w4[i] = fourth;
234 }
235 return Ok(());
236 }
237 for i in 0..eta.len() {
238 let (value, first, second, third, fourth) =
239 Self::fisher_weight_derivatives(link, eta[i])?;
240 w[i] = value;
241 w1[i] = first;
242 w2[i] = second;
243 w3[i] = third;
244 w4[i] = fourth;
245 }
246 Ok(())
247 }
248
249 pub(crate) fn weighted_cross(
250 left: &Array2<f64>,
251 right: &Array2<f64>,
252 weights: &Array1<f64>,
253 ) -> Array2<f64> {
254 assert_eq!(left.nrows(), right.nrows());
255 assert_eq!(left.nrows(), weights.len());
256 super::assembly::weighted_cross_dense(left, right, weights)
257 }
258
259 pub(crate) fn trace_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
260 assert_eq!(a.nrows(), b.ncols());
261 assert_eq!(a.ncols(), b.nrows());
262 let elems = a.nrows().saturating_mul(a.ncols());
263 if elems >= 32 * 32 {
264 let aview = FaerArrayView::new(a);
265 let bview = FaerArrayView::new(b);
266 return faer_frob_inner(aview.as_ref(), bview.as_ref().transpose());
267 }
268 let m = a.nrows();
269 let n = a.ncols();
270 kahan_sum((0..m).map(|i| {
271 let mut acc = 0.0_f64;
272 for j in 0..n {
273 acc += a[[i, j]] * b[[j, i]];
274 }
275 acc
276 }))
277 }
278
279 pub(crate) fn reducedweighted_gram(z: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
280 let weighted = Self::row_scale(z, weights);
287 fast_atb(z, &weighted)
288 }
289
290 pub(crate) fn reduced_crossweighted_gram(
291 z_left: &Array2<f64>,
292 z_right: &Array2<f64>,
293 weights: &Array1<f64>,
294 ) -> Array2<f64> {
295 let weighted = Self::row_scale(z_right, weights);
300 fast_atb(z_left, &weighted)
301 }
302
303 pub(crate) fn reduced_diag_gram(z: &Array2<f64>, a: &Array2<f64>) -> Array1<f64> {
304 let za = fast_ab(z, a);
310 (z * &za).sum_axis(ndarray::Axis(1))
311 }
312
313 pub(crate) fn apply_hadamard_gram(
314 z: &Array2<f64>,
315 a_left: &Array2<f64>,
316 a_right: &Array2<f64>,
317 vec: &Array1<f64>,
318 ) -> Array1<f64> {
319 let s = Self::reducedweighted_gram(z, vec);
328 let left_s = a_left.dot(&s);
329 let t = left_s.dot(a_right);
330 Self::reduced_diag_gram(z, &t)
331 }
332
333 pub(crate) fn apply_hadamard_gram_to_matrix(
334 z: &Array2<f64>,
335 a_left: &Array2<f64>,
336 a_right: &Array2<f64>,
337 mat: &Array2<f64>,
338 ) -> Array2<f64> {
339 let mut out = Array2::<f64>::zeros(mat.raw_dim());
347 for col in 0..mat.ncols() {
348 let v = mat.column(col).to_owned();
349 let y = Self::apply_hadamard_gram(z, a_left, a_right, &v);
350 out.column_mut(col).assign(&y);
351 }
352 out
353 }
354
355 pub(super) fn build_firth_dense_operator_for_link(
360 link: &InverseLink,
361 x_dense: &Array2<f64>,
362 eta: &Array1<f64>,
363 observation_weights: ndarray::ArrayView1<'_, f64>,
364 ) -> Result<FirthDenseOperator, EstimationError> {
365 FirthDenseOperator::build_with_observation_weights_impl(
366 link,
367 x_dense,
368 eta,
369 Some(observation_weights),
370 )
371 }
372
373 pub(super) fn firth_exact_tau_kernel(
374 op: &FirthDenseOperator,
375 x_tau: &Array2<f64>,
376 beta: &Array1<f64>,
377 include_hphi_tau_kernel: bool,
378 ) -> FirthTauExactKernel {
379 op.exact_tau_kernel(x_tau, beta, include_hphi_tau_kernel)
380 }
381
382 pub(super) fn firth_hphi_tau_partial_apply(
383 op: &FirthDenseOperator,
384 x_tau: &Array2<f64>,
385 kernel: &FirthTauPartialKernel,
386 rhs: &Array2<f64>,
387 ) -> Array2<f64> {
388 op.hphi_tau_partial_apply(x_tau, kernel, rhs)
389 }
390}
391
392impl FirthDenseOperator {
393 pub(crate) fn canonicalize_basis_column_signs(q_basis: &mut Array2<f64>) {
394 for col in 0..q_basis.ncols() {
395 let mut pivot_row = 0usize;
396 let mut pivot_abs = 0.0_f64;
397 for row in 0..q_basis.nrows() {
398 let value = q_basis[[row, col]];
399 let abs_value = value.abs();
400 if abs_value > pivot_abs {
401 pivot_abs = abs_value;
402 pivot_row = row;
403 }
404 }
405 if pivot_abs > 0.0 && q_basis[[pivot_row, col]] < 0.0 {
406 q_basis.column_mut(col).mapv_inplace(|v| -v);
407 }
408 }
409 }
410
411 pub(crate) fn identifiable_subspace_basis_from_gram(
412 gram: &Array2<f64>,
413 ) -> Result<(Array2<f64>, Array1<f64>), EstimationError> {
414 let p = gram.nrows();
415 assert_eq!(p, gram.ncols());
416 if p == 0 {
417 return Ok((Array2::<f64>::eye(0), Array1::<f64>::zeros(0)));
418 }
419
420 let (evals, evecs) = gram
421 .eigh(Side::Lower)
422 .map_err(EstimationError::EigendecompositionFailed)?;
423 let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
424 let tol = (p.max(1) as f64) * f64::EPSILON * max_eval;
425 let mut keep: Vec<usize> = evals
426 .iter()
427 .enumerate()
428 .filter_map(|(i, &value)| if value > tol { Some(i) } else { None })
429 .collect();
430 if keep.is_empty() {
431 return Err(EstimationError::ModelIsIllConditioned {
432 condition_number: f64::INFINITY,
433 });
434 }
435
436 keep.sort_by(|&lhs, &rhs| evals[rhs].total_cmp(&evals[lhs]));
442 let r = keep.len();
443 let mut q_basis = Array2::<f64>::zeros((p, r));
444 let mut metric_spectrum = Array1::<f64>::zeros(r);
445 for (col_idx, eig_idx) in keep.into_iter().enumerate() {
446 q_basis.column_mut(col_idx).assign(&evecs.column(eig_idx));
447 metric_spectrum[col_idx] = evals[eig_idx];
448 }
449 Self::canonicalize_basis_column_signs(&mut q_basis);
450 Ok((q_basis, metric_spectrum))
451 }
452
453 #[inline]
454 pub(crate) fn trace_diag_product(diag: &Array1<f64>, matrix: &Array2<f64>) -> f64 {
455 assert_eq!(diag.len(), matrix.nrows());
456 assert_eq!(matrix.nrows(), matrix.ncols());
457 kahan_sum((0..diag.len()).map(|i| diag[i] * matrix[[i, i]]))
458 }
459
460 pub fn build_for_link(
461 link: &InverseLink,
462 x_dense: &Array2<f64>,
463 eta: &Array1<f64>,
464 ) -> Result<FirthDenseOperator, EstimationError> {
465 Self::build_with_observation_weights_impl(link, x_dense, eta, None)
466 }
467
468 pub fn build_with_observation_weights_for_link(
469 link: &InverseLink,
470 x_dense: &Array2<f64>,
471 eta: &Array1<f64>,
472 observation_weights: ndarray::ArrayView1<'_, f64>,
473 ) -> Result<FirthDenseOperator, EstimationError> {
474 Self::build_with_observation_weights_impl(link, x_dense, eta, Some(observation_weights))
475 }
476
477 pub(crate) fn build_design_factor_with_observation_weights(
487 x_dense: &Array2<f64>,
488 observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
489 ) -> Result<FirthDesignFactor, EstimationError> {
490 let n = x_dense.nrows();
491 let observation_weight_sqrt = if let Some(weights) = observation_weights {
492 if weights.len() != n {
493 crate::bail_invalid_estim!(
494 "Firth operator observation weight length {} != number of rows {}",
495 weights.len(),
496 n
497 );
498 }
499 let mut sqrt = Array1::<f64>::zeros(n);
500 for i in 0..n {
501 let weight = weights[i];
502 if !weight.is_finite() || weight < 0.0 {
503 crate::bail_invalid_estim!(
504 "Firth operator requires finite nonnegative observation weights, got {} at row {}",
505 weight,
506 i
507 );
508 }
509 sqrt[i] = weight.sqrt();
510 }
511 Some(sqrt)
512 } else {
513 None
514 };
515 let basis_design = if let Some(scale) = observation_weight_sqrt.as_ref() {
516 RemlState::row_scale(x_dense, scale)
517 } else {
518 x_dense.clone()
519 };
520 let gram = fast_atb(&basis_design, &basis_design);
522 let (q_basis, metric_spectrum) = Self::identifiable_subspace_basis_from_gram(&gram)?;
523 let x_reduced = fast_ab(&basis_design, &q_basis);
524 let r = q_basis.ncols();
525 let mut x_metric_reduced_inv_diag = Array1::<f64>::zeros(r);
526 for col in 0..r {
527 x_metric_reduced_inv_diag[col] = metric_spectrum[col].recip();
528 }
529 let x_dense_t = x_dense.t().to_owned();
530 Ok(FirthDesignFactor {
531 x_dense: x_dense.clone(),
532 x_dense_t,
533 q_basis,
534 x_reduced,
535 observation_weight_sqrt,
536 metric_spectrum,
537 x_metric_reduced_inv_diag,
538 r,
539 n,
540 })
541 }
542
543 fn firth_reduced_core(
551 factor: &FirthDesignFactor,
552 link: &InverseLink,
553 eta: &Array1<f64>,
554 ) -> Result<FirthReducedCore, EstimationError> {
555 let n = factor.n;
556 if eta.len() != n {
557 crate::bail_invalid_estim!(
558 "Firth operator shape mismatch: nrows={}, eta_len={}",
559 n,
560 eta.len()
561 );
562 }
563 let r = factor.r;
564 let mut w = Array1::<f64>::zeros(n);
565 let mut w1 = Array1::<f64>::zeros(n);
566 let mut w2 = Array1::<f64>::zeros(n);
567 let mut w3 = Array1::<f64>::zeros(n);
568 let mut w4 = Array1::<f64>::zeros(n);
569 RemlState::fill_fisher_weight_derivative_arrays(
570 link, eta, &mut w, &mut w1, &mut w2, &mut w3, &mut w4,
571 )?;
572
573 let fisher_reduced = gam_linalg::faer_ndarray::fast_xt_diag_x(&factor.x_reduced, &w);
575 if let Ok((eigvals_ir, _)) = fisher_reduced.eigh(Side::Lower) {
576 let max_ev = eigvals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
577 let min_ev = eigvals_ir
578 .iter()
579 .copied()
580 .filter(|v| v.is_finite() && *v > 0.0)
581 .fold(f64::INFINITY, f64::min);
582 if min_ev.is_finite() {
583 let rel = min_ev / max_ev;
584 if rel < FIRTH_REDUCED_FISHER_RCOND_WARN {
585 log::warn!(
586 "[REML/Firth] reduced Fisher I_r is near-singular (min/max={:.3e}/{:.3e}, rel={:.3e}); exact derivatives may be ill-conditioned near active-subspace boundaries.",
587 min_ev,
588 max_ev,
589 rel
590 );
591 }
592 }
593 }
594
595 let (k_reduced, mut half_log_det) = if r > 0 {
596 RemlState::reduced_fisher_inverse_and_half_logdet(&fisher_reduced)?
597 } else {
598 (Array2::<f64>::zeros((r, r)), 0.0)
599 };
600 if r > 0 {
601 for col in 0..r {
602 let metric_eig = factor.metric_spectrum[col];
603 half_log_det -= 0.5 * metric_eig.ln();
604 }
605 }
606 let h_diag = if r > 0 {
607 RemlState::reduced_diag_gram(&factor.x_reduced, &k_reduced)
608 } else {
609 Array1::<f64>::zeros(n)
610 };
611 Ok(FirthReducedCore {
612 w,
613 w1,
614 w2,
615 w3,
616 w4,
617 k_reduced,
618 half_log_det,
619 h_diag,
620 })
621 }
622
623 pub(crate) fn build_from_design_factor(
627 factor: &FirthDesignFactor,
628 link: &InverseLink,
629 eta: &Array1<f64>,
630 ) -> Result<FirthDenseOperator, EstimationError> {
631 let FirthReducedCore {
632 w,
633 w1,
634 w2,
635 w3,
636 w4,
637 k_reduced,
638 half_log_det,
639 h_diag,
640 } = Self::firth_reduced_core(factor, link, eta)?;
641 let b_base = RemlState::row_scale(&factor.x_dense, &w1);
642 let p_b_base = RemlState::apply_hadamard_gram_to_matrix(
643 &factor.x_reduced,
644 &k_reduced,
645 &k_reduced,
646 &b_base,
647 );
648 Ok(FirthDenseOperator {
649 x_dense: factor.x_dense.clone(),
650 x_dense_t: factor.x_dense_t.clone(),
651 q_basis: factor.q_basis.clone(),
652 x_reduced: factor.x_reduced.clone(),
653 observation_weight_sqrt: factor.observation_weight_sqrt.clone(),
654 k_reduced,
655 x_metric_reduced_inv_diag: factor.x_metric_reduced_inv_diag.clone(),
656 half_log_det,
657 h_diag,
658 w,
659 w1,
660 w2,
661 w3,
662 w4,
663 b_base,
664 p_b_base,
665 })
666 }
667
668 pub(crate) fn pirls_diagnostics_from_factor(
679 factor: &FirthDesignFactor,
680 link: &InverseLink,
681 eta: &Array1<f64>,
682 ) -> Result<(Array1<f64>, f64, Array1<f64>), EstimationError> {
683 let core = Self::firth_reduced_core(factor, link, eta)?;
684 let (w, w1, h_diag, half_log_det) =
685 (core.w, core.w1, core.h_diag, core.half_log_det);
686 let hat_diag = &w * &h_diag;
688 let mut score_shift = Array1::<f64>::zeros(w.len());
691 for i in 0..w.len() {
692 let wi = w[i];
693 if wi > 0.0 {
694 score_shift[i] = 0.5 * (w1[i] / wi) * h_diag[i];
695 }
696 }
697 Ok((hat_diag, half_log_det, score_shift))
698 }
699
700 pub(crate) fn build_with_observation_weights_impl(
701 link: &InverseLink,
702 x_dense: &Array2<f64>,
703 eta: &Array1<f64>,
704 observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
705 ) -> Result<FirthDenseOperator, EstimationError> {
706 let n = x_dense.nrows();
767 if eta.len() != n {
768 crate::bail_invalid_estim!(
769 "Firth operator shape mismatch: nrows={}, eta_len={}",
770 n,
771 eta.len()
772 );
773 }
774 let factor =
775 Self::build_design_factor_with_observation_weights(x_dense, observation_weights)?;
776 Self::build_from_design_factor(&factor, link, eta)
777 }
778
779 #[inline]
780 pub(crate) fn jeffreys_logdet(&self) -> f64 {
781 self.half_log_det
782 }
783
784 pub(crate) fn jeffreys_logdet_projected(&self, z: ndarray::ArrayView2<'_, f64>) -> f64 {
797 use gam_linalg::faer_ndarray::{fast_ab, fast_xt_diag_x};
798 let p = self.x_dense.ncols();
799 assert_eq!(
800 z.nrows(),
801 p,
802 "jeffreys_logdet_projected: Z must have {} rows (β-space dim), got {}",
803 p,
804 z.nrows()
805 );
806 let m = z.ncols();
807 if m == 0 {
808 return 0.0;
809 }
810 let z_owned = z.to_owned();
812 let xz = fast_ab(&self.x_dense, &z_owned);
813 let xtz = if let Some(scale) = self.observation_weight_sqrt.as_ref() {
814 RemlState::row_scale(&xz, scale)
815 } else {
816 xz
817 };
818 let mut j_t = fast_xt_diag_x(&xtz, &self.w);
820 symmetrize_in_place(&mut j_t);
821 let (evals, _) = match j_t.eigh(Side::Lower) {
822 Ok(pair) => pair,
823 Err(_) => return f64::NEG_INFINITY,
824 };
825 let Some(evals_slice) = evals.as_slice() else {
826 return f64::NEG_INFINITY;
827 };
828 let threshold = super::reml_outer_engine::positive_eigenvalue_threshold(evals_slice);
829 0.5 * super::reml_outer_engine::exact_pseudo_logdet(evals_slice, threshold)
830 }
831
832 #[inline]
833 pub(crate) fn jeffreys_beta_gradient(&self) -> Array1<f64> {
834 0.5 * gam_linalg::faer_ndarray::fast_av(&self.x_dense_t, &(&self.w1 * &self.h_diag))
839 }
840
841 #[inline]
842 pub fn jeffreys_logdet_and_beta_gradient(&self) -> (f64, Array1<f64>) {
843 (self.jeffreys_logdet(), self.jeffreys_beta_gradient())
844 }
845
846 #[inline]
847 pub(crate) fn reduce_explicit_design(&self, x: &Array2<f64>) -> Array2<f64> {
848 let mut reduced = fast_ab(x, &self.q_basis);
849 if let Some(scale) = self.observation_weight_sqrt.as_ref() {
850 reduced = RemlState::row_scale(&reduced, scale);
851 }
852 reduced
853 }
854
855 pub(crate) fn direction_from_deta(&self, deta: Array1<f64>) -> FirthDirection {
856 let s_u = &self.w1 * &deta;
872 let g_u_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_u);
877 let k_g_u = self.k_reduced.dot(&g_u_reduced);
878 let a_u_reduced = k_g_u.dot(&self.k_reduced);
879 let dh = -RemlState::reduced_diag_gram(&self.x_reduced, &a_u_reduced);
882 let b_uvec = &self.w2 * &deta;
883 FirthDirection {
884 deta,
885 g_u_reduced,
886 a_u_reduced,
887 dh,
888 b_uvec,
889 }
890 }
891
892 #[inline]
893 pub(crate) fn left_scaled_xt(&self, scale: &Array1<f64>, mat: &Array2<f64>) -> Array2<f64> {
894 fast_ab(&self.x_dense_t, &(mat * &scale.view().insert_axis(Axis(1))))
895 }
896
897 #[inline]
898 pub(crate) fn apply_p_u_to_matrix(
899 &self,
900 a_u_reduced: &Array2<f64>,
901 mat: &Array2<f64>,
902 ) -> Array2<f64> {
903 let mut out = RemlState::apply_hadamard_gram_to_matrix(
904 &self.x_reduced,
905 &self.k_reduced,
906 a_u_reduced,
907 mat,
908 );
909 out.mapv_inplace(|v| -2.0 * v);
910 out
911 }
912
913 pub(crate) fn hphi_direction_apply(
914 &self,
915 dir: &FirthDirection,
916 rhs: &Array2<f64>,
917 ) -> Array2<f64> {
918 let p = self.x_dense.ncols();
919 if rhs.nrows() != p {
920 return Array2::<f64>::zeros((p, rhs.ncols()));
921 }
922 if rhs.ncols() == 0 || p == 0 {
923 return Array2::<f64>::zeros((p, rhs.ncols()));
924 }
925 let etav = fast_ab(&self.x_dense, rhs);
932 let qv = &etav * &self.w1.view().insert_axis(Axis(1));
933 let m_qv = RemlState::apply_hadamard_gram_to_matrix(
934 &self.x_reduced,
935 &self.k_reduced,
936 &self.k_reduced,
937 &qv,
938 );
939 let buvec = &dir.b_uvec;
940 let m_buv = RemlState::apply_hadamard_gram_to_matrix(
941 &self.x_reduced,
942 &self.k_reduced,
943 &self.k_reduced,
944 &(&etav * &buvec.view().insert_axis(Axis(1))),
945 );
946 let p_u_qv = self.apply_p_u_to_matrix(&dir.a_u_reduced, &qv);
947 let c_u = &(&self.w3 * &dir.deta) * &self.h_diag + &(&self.w2 * &dir.dh);
948 let diag_term = self
949 .x_dense_t
950 .dot(&(&etav * &c_u.view().insert_axis(Axis(1))));
951 let term1 = self.left_scaled_xt(buvec, &m_qv);
952 let term2 = self.left_scaled_xt(&self.w1, &m_buv);
953 let term3 = self.left_scaled_xt(&self.w1, &p_u_qv);
954 0.5 * (diag_term - (term1 + term2 + term3))
955 }
956
957 pub(crate) fn hphi_direction(&self, dir: &FirthDirection) -> Array2<f64> {
958 let p = self.x_dense.ncols();
959 let eye = Array2::<f64>::eye(p);
960 let mut out = self.hphi_direction_apply(dir, &eye);
961 symmetrize_in_place(&mut out);
974 out
975 }
976
/// Applies the mixed second directional derivative of the `hphi` operator along
/// directions `u` and `v` to `rhs` (p × m), returning a p × m matrix.
///
/// Returns a zero matrix of shape `(p, rhs.ncols())` when `rhs` does not have `p` rows
/// or when either dimension is empty.
///
/// Notation in the comments below:
/// - `Z` = `x_reduced`, `K` = `k_reduced`, `X` = `x_dense`, `η = X · rhs`;
/// - `M(·)` = `apply_hadamard_gram_to_matrix` with `(K, K)`;
/// - `P_d(·)` = `apply_p_u_to_matrix` with direction `d`;
/// - `L(s, ·)` = `left_scaled_xt(s, ·)`.
pub(crate) fn hphisecond_direction_apply(
    &self,
    u: &FirthDirection,
    v: &FirthDirection,
    rhs: &Array2<f64>,
) -> Array2<f64> {
    let p = self.x_dense.ncols();
    if rhs.nrows() != p {
        return Array2::<f64>::zeros((p, rhs.ncols()));
    }
    if rhs.ncols() == 0 || p == 0 {
        return Array2::<f64>::zeros((p, rhs.ncols()));
    }
    // Mixed weight change s_uv = w2 ∘ deta_u ∘ deta_v and its reduced Gram
    // G_uv = Zᵀ diag(s_uv) Z.
    let deta_uv = &u.deta * &v.deta;
    let s_uv = &self.w2 * &deta_uv;
    let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
    let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
    let k_gv = self.k_reduced.dot(&v.g_u_reduced);
    let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
    // A_uv = K·G_uv·K − K·G_v·K·G_u·K − K·G_u·K·G_v·K
    let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
        - k_gv.dot(&k_g_u).dot(&self.k_reduced)
        - k_g_u.dot(&k_gv).dot(&self.k_reduced);
    // Mixed second derivative of the leverage diagonal: d2h = −diag(Z · A_uv · Zᵀ).
    let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
    // Row weights of the diagonal part:
    // c_uv = w4∘deta_uv∘h + w3∘deta_u∘dh_v + w3∘deta_v∘dh_u + w2∘d2h.
    let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
        + &(&self.w3 * &(&u.deta * &v.dh))
        + &(&self.w3 * &(&v.deta * &u.dh))
        + &(&self.w2 * &d2h);

    // Diagonal part: Xᵀ diag(c_uv) X · rhs.
    let eta_rhs = fast_ab(&self.x_dense, rhs);
    let diag_term = fast_ab(
        &self.x_dense_t,
        &(&eta_rhs * &c_uv.view().insert_axis(Axis(1))),
    );

    // b_uv = w3 ∘ deta_uv and its row-scaled design; qv = η row-scaled by w1.
    let b_uvvec = &self.w3 * &deta_uv;
    let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
    let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));

    // M applied to the base block (precomputed as p_b_base) and to each direction's
    // b-weighted η.
    let p_b_rhs = fast_ab(&self.p_b_base, rhs);
    let p_bu_rhs = RemlState::apply_hadamard_gram_to_matrix(
        &self.x_reduced,
        &self.k_reduced,
        &self.k_reduced,
        &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
    );
    let p_bv_rhs = RemlState::apply_hadamard_gram_to_matrix(
        &self.x_reduced,
        &self.k_reduced,
        &self.k_reduced,
        &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
    );
    // M applied to the b_uv-scaled design, then right-multiplied by rhs.
    let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
        &self.x_reduced,
        &self.k_reduced,
        &self.k_reduced,
        &b_uv_base,
    );
    let p_buv_rhs = fast_ab(&p_buv_base, rhs);

    // First-order map derivatives P_u / P_v applied to w1- and b-weighted η.
    let pv_b_rhs = self.apply_p_u_to_matrix(&v.a_u_reduced, &qv);
    let pv_bu_rhs = self.apply_p_u_to_matrix(
        &v.a_u_reduced,
        &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
    );
    let p_u_b_rhs = self.apply_p_u_to_matrix(&u.a_u_reduced, &qv);
    let p_u_bv_rhs = self.apply_p_u_to_matrix(
        &u.a_u_reduced,
        &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
    );

    // Mixed second derivative of the map on the base block:
    // P_uv = 2·HadamardGram(A_u, A_v) − 2·HadamardGram(K, A_uv), applied to b_base.
    let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
        &self.x_reduced,
        &u.a_u_reduced,
        &v.a_u_reduced,
        &self.b_base,
    );
    let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
        &self.x_reduced,
        &self.k_reduced,
        &a_uv_reduced,
        &self.b_base,
    );
    let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
    let p_uv_rhs = fast_ab(&p_uv_base, rhs);

    // The nine product-rule terms, summed in this fixed order.
    let d2_terms = [
        self.left_scaled_xt(&b_uvvec, &p_b_rhs),
        self.left_scaled_xt(&self.w1, &p_buv_rhs),
        self.left_scaled_xt(&u.b_uvec, &p_bv_rhs),
        self.left_scaled_xt(&v.b_uvec, &p_bu_rhs),
        self.left_scaled_xt(&u.b_uvec, &pv_b_rhs),
        self.left_scaled_xt(&self.w1, &pv_bu_rhs),
        self.left_scaled_xt(&v.b_uvec, &p_u_b_rhs),
        self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
        self.left_scaled_xt(&self.w1, &p_uv_rhs),
    ];
    let mut d2_j2 = Array2::<f64>::zeros((p, rhs.ncols()));
    for term in d2_terms {
        d2_j2 += &term;
    }

    0.5 * (diag_term - d2_j2)
}
1104
1105 pub(crate) fn tk_second_direction_eye_cache(
1119 &self,
1120 dirs: &[FirthDirection],
1121 ) -> FirthSecondDirEyeCache {
1122 let p = self.x_dense.ncols();
1123 let eye = Array2::<f64>::eye(p);
1124 let eta_rhs = fast_ab(&self.x_dense, &eye);
1126 let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));
1127 let p_b_rhs = fast_ab(&self.p_b_base, &eye);
1129 let compute_blocks = |d: &FirthDirection| -> (Array2<f64>, Array2<f64>) {
1139 let p_b = RemlState::apply_hadamard_gram_to_matrix(
1141 &self.x_reduced,
1142 &self.k_reduced,
1143 &self.k_reduced,
1144 &(&eta_rhs * &d.b_uvec.view().insert_axis(Axis(1))),
1145 );
1146 let pu = self.apply_p_u_to_matrix(&d.a_u_reduced, &qv);
1148 (p_b, pu)
1149 };
1150 let (p_bx, pu_qv): (Vec<Array2<f64>>, Vec<Array2<f64>>) =
1151 if dirs.len() > 1 && rayon::current_num_threads() > 1 {
1152 use rayon::prelude::*;
1153 dirs.par_iter()
1154 .map(|d| gam_problem::with_nested_parallel(|| compute_blocks(d)))
1155 .unzip()
1156 } else {
1157 dirs.iter().map(compute_blocks).unzip()
1158 };
1159 FirthSecondDirEyeCache {
1160 eye,
1161 eta_rhs,
1162 p_b_rhs,
1163 p_bx,
1164 pu_qv,
1165 }
1166 }
1167
1168 pub(crate) fn hphisecond_direction_apply_eye_cached(
1175 &self,
1176 cache: &FirthSecondDirEyeCache,
1177 dirs: &[FirthDirection],
1178 i: usize,
1179 j: usize,
1180 ) -> Array2<f64> {
1181 let u = &dirs[i];
1182 let v = &dirs[j];
1183 let p = self.x_dense.ncols();
1184 let cols = cache.eta_rhs.ncols();
1185 if p == 0 || cols == 0 {
1186 return Array2::<f64>::zeros((p, cols));
1187 }
1188 let deta_uv = &u.deta * &v.deta;
1189 let s_uv = &self.w2 * &deta_uv;
1190 let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
1191 let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
1192 let k_gv = self.k_reduced.dot(&v.g_u_reduced);
1193 let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
1194 let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
1195 - k_gv.dot(&k_g_u).dot(&self.k_reduced)
1196 - k_g_u.dot(&k_gv).dot(&self.k_reduced);
1197 let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
1198 let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
1199 + &(&self.w3 * &(&u.deta * &v.dh))
1200 + &(&self.w3 * &(&v.deta * &u.dh))
1201 + &(&self.w2 * &d2h);
1202
1203 let eta_rhs = &cache.eta_rhs;
1204 let diag_term = fast_ab(
1205 &self.x_dense_t,
1206 &(eta_rhs * &c_uv.view().insert_axis(Axis(1))),
1207 );
1208
1209 let b_uvvec = &self.w3 * &deta_uv;
1210 let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
1211
1212 let p_b_rhs = &cache.p_b_rhs;
1214 let p_bu_rhs = &cache.p_bx[i];
1215 let p_bv_rhs = &cache.p_bx[j];
1216 let p_u_b_rhs = &cache.pu_qv[i];
1217 let pv_b_rhs = &cache.pu_qv[j];
1218
1219 let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
1221 &self.x_reduced,
1222 &self.k_reduced,
1223 &self.k_reduced,
1224 &b_uv_base,
1225 );
1226 let p_buv_rhs = fast_ab(&p_buv_base, &cache.eye);
1227
1228 let pv_bu_rhs = self.apply_p_u_to_matrix(
1229 &v.a_u_reduced,
1230 &(eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1231 );
1232 let p_u_bv_rhs = self.apply_p_u_to_matrix(
1233 &u.a_u_reduced,
1234 &(eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1235 );
1236
1237 let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
1238 &self.x_reduced,
1239 &u.a_u_reduced,
1240 &v.a_u_reduced,
1241 &self.b_base,
1242 );
1243 let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
1244 &self.x_reduced,
1245 &self.k_reduced,
1246 &a_uv_reduced,
1247 &self.b_base,
1248 );
1249 let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
1250 let p_uv_rhs = fast_ab(&p_uv_base, &cache.eye);
1251
1252 let d2_terms = [
1253 self.left_scaled_xt(&b_uvvec, p_b_rhs),
1254 self.left_scaled_xt(&self.w1, &p_buv_rhs),
1255 self.left_scaled_xt(&u.b_uvec, p_bv_rhs),
1256 self.left_scaled_xt(&v.b_uvec, p_bu_rhs),
1257 self.left_scaled_xt(&u.b_uvec, pv_b_rhs),
1258 self.left_scaled_xt(&self.w1, &pv_bu_rhs),
1259 self.left_scaled_xt(&v.b_uvec, p_u_b_rhs),
1260 self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
1261 self.left_scaled_xt(&self.w1, &p_uv_rhs),
1262 ];
1263 let mut d2_j2 = Array2::<f64>::zeros((p, cols));
1264 for term in d2_terms {
1265 d2_j2 += &term;
1266 }
1267
1268 0.5 * (diag_term - d2_j2)
1269 }
1270
1271 pub(super) fn rowwise_dot(a: &Array2<f64>, b: &Array2<f64>) -> Array1<f64> {
1272 assert_eq!(a.nrows(), b.nrows());
1273 assert_eq!(a.ncols(), b.ncols());
1274 let mut out = Array1::<f64>::zeros(a.nrows());
1275 for i in 0..a.nrows() {
1276 let mut acc = 0.0_f64;
1277 for j in 0..a.ncols() {
1278 acc += a[[i, j]] * b[[i, j]];
1279 }
1280 out[i] = acc;
1281 }
1282 out
1283 }
1284
1285 pub(super) fn rowwise_bilinear(
1286 a: &Array2<f64>,
1287 m: &Array2<f64>,
1288 b: &Array2<f64>,
1289 ) -> Array1<f64> {
1290 assert_eq!(a.nrows(), b.nrows());
1292 assert_eq!(a.ncols(), m.nrows());
1293 assert_eq!(b.ncols(), m.ncols());
1294 let am = fast_ab(a, m);
1295 Self::rowwise_dot(&am, b)
1296 }
1297
1298 pub(crate) fn dot_i_and_h_from_reduced(
1299 &self,
1300 x_tau_reduced: &Array2<f64>,
1301 deta: &Array1<f64>,
1302 ) -> (Array2<f64>, Array1<f64>) {
1303 let dw = &self.w1 * deta;
1323 let dot_i = RemlState::weighted_cross(x_tau_reduced, &self.x_reduced, &self.w)
1324 + RemlState::weighted_cross(&self.x_reduced, x_tau_reduced, &self.w)
1325 + gam_linalg::faer_ndarray::fast_xt_diag_x(&self.x_reduced, &dw);
1326
1327 let dot_k = -self.k_reduced.dot(&dot_i).dot(&self.k_reduced);
1328 let x_tauk = fast_ab(x_tau_reduced, &self.k_reduced);
1329 let dot_h_explicit = 2.0 * Self::rowwise_dot(&x_tauk, &self.x_reduced);
1330 let dot_h_implicit = Self::rowwise_dot(&fast_ab(&self.x_reduced, &dot_k), &self.x_reduced);
1331 let dot_h = dot_h_explicit + dot_h_implicit;
1332 (dot_i, dot_h)
1333 }
1334
1335 pub(crate) fn exact_tau_kernel(
1336 &self,
1337 x_tau: &Array2<f64>,
1338 beta: &Array1<f64>,
1339 include_hphi_tau_kernel: bool,
1340 ) -> FirthTauExactKernel {
1341 let deta_partial = gam_linalg::faer_ndarray::fast_av(x_tau, beta);
1365 let x_tau_reduced = self.reduce_explicit_design(x_tau);
1366 let (dot_i_partial, dot_h_partial) =
1367 self.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
1368 let dot_s_partial =
1369 fast_atb(&x_tau_reduced, &self.x_reduced) + fast_atb(&self.x_reduced, &x_tau_reduced);
1370
1371 let first = 0.5 * gam_linalg::faer_ndarray::fast_atv(x_tau, &(&self.w1 * &self.h_diag));
1372 let secondvec =
1373 &(&(&self.w2 * &deta_partial) * &self.h_diag) + &(&self.w1 * &dot_h_partial);
1374 let second = 0.5 * gam_linalg::faer_ndarray::fast_atv(&self.x_dense, &secondvec);
1375 let gphi_tau = first + second;
1376 let phi_tau_partial = 0.5 * RemlState::trace_product(&self.k_reduced, &dot_i_partial)
1377 - 0.5 * Self::trace_diag_product(&self.x_metric_reduced_inv_diag, &dot_s_partial);
1378
1379 let tau_kernel = if include_hphi_tau_kernel {
1380 Some(self.hphi_tau_partial_prepare_from_partials(
1381 x_tau_reduced,
1382 &deta_partial,
1383 dot_h_partial,
1384 dot_i_partial,
1385 ))
1386 } else {
1387 None
1388 };
1389 FirthTauExactKernel {
1390 gphi_tau,
1391 phi_tau_partial,
1392 tau_kernel,
1393 }
1394 }
1395
1396 pub(crate) fn hphi_tau_partial_prepare_from_partials(
1397 &self,
1398 x_tau_reduced: Array2<f64>,
1399 deta_partial: &Array1<f64>,
1400 dot_h_partial: Array1<f64>,
1401 dot_i_partial: Array2<f64>,
1402 ) -> FirthTauPartialKernel {
1403 let dotw1 = &self.w2 * deta_partial;
1404 let dotw2 = &self.w3 * deta_partial;
1405 let dot_k = -self.k_reduced.dot(&dot_i_partial).dot(&self.k_reduced);
1406 FirthTauPartialKernel {
1407 deta_partial: deta_partial.clone(),
1408 dotw1,
1409 dotw2,
1410 dot_h_partial,
1411 x_tau_reduced,
1412 dot_i_partial,
1413 dot_k_reduced: dot_k,
1414 }
1415 }
1416
1417 pub(crate) fn d_beta_hphi_tau_partial_dense(
1418 &self,
1419 x_tau: &Array2<f64>,
1420 beta: &Array1<f64>,
1421 beta_direction: &Array1<f64>,
1422 ) -> Option<Array2<f64>> {
1423 if x_tau.nrows() != self.x_dense.nrows() || x_tau.ncols() != beta.len() {
1424 return None;
1425 }
1426 if !x_tau.iter().any(|value| *value != 0.0) {
1427 return None;
1428 }
1429 let tau_bundle = self.exact_tau_kernel(x_tau, beta, true);
1430 let tau_kernel = tau_bundle.tau_kernel?;
1431 let firth_direction =
1432 self.direction_from_deta(gam_linalg::faer_ndarray::fast_av(&self.x_dense, beta_direction));
1433 let x_tau_v = gam_linalg::faer_ndarray::fast_av(x_tau, beta_direction);
1434 let kernel = self.d_beta_hphi_tau_partial_prepare_from_partials(
1435 &tau_kernel,
1436 &tau_kernel.deta_partial,
1437 &tau_kernel.dot_i_partial,
1438 &firth_direction,
1439 &x_tau_v,
1440 );
1441 let eye = Array2::<f64>::eye(beta_direction.len());
1442 Some(self.d_beta_hphi_tau_partial_apply(x_tau, &kernel, &eye))
1443 }
1444
1445 pub(crate) fn apply_pbar_to_matrix(&self, mat: &Array2<f64>) -> Array2<f64> {
1446 RemlState::apply_hadamard_gram_to_matrix(
1448 &self.x_reduced,
1449 &self.k_reduced,
1450 &self.k_reduced,
1451 mat,
1452 )
1453 }
1454
1455 pub(crate) fn apply_mtau_to_matrix(
1456 &self,
1457 kernel: &FirthTauPartialKernel,
1458 mat: &Array2<f64>,
1459 ) -> Array2<f64> {
1460 if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
1472 return Array2::<f64>::zeros(mat.raw_dim());
1473 }
1474 let mut out = Array2::<f64>::zeros(mat.raw_dim());
1475 for col in 0..mat.ncols() {
1476 let v = mat.column(col).to_owned();
1477 let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
1478 let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
1479 let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, &kernel.x_tau_reduced);
1480
1481 let szt =
1482 RemlState::reduced_crossweighted_gram(&self.x_reduced, &kernel.x_tau_reduced, &v);
1483 let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
1484 let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
1485
1486 let t3 = RemlState::apply_hadamard_gram(
1487 &self.x_reduced,
1488 &self.k_reduced,
1489 &kernel.dot_k_reduced,
1490 &v,
1491 );
1492
1493 let y = 2.0 * (t1 + t2 + t3);
1494 out.column_mut(col).assign(&y);
1495 }
1496 out
1497 }
1498
    /// Applies the `tau`-partial of the Jeffreys Hessian to `rhs` (`p x m`) and
    /// returns a `p x m` matrix.
    ///
    /// The code builds the integrand `r` of the base Hessian action `0.5 X^T r`:
    /// `r = (w2 ∘ h) ∘ (X v) - w1 ∘ P̄(w1 ∘ X v)`.
    ///
    /// It then returns `0.5 (X_tau^T r + X^T r_tau)`, where `r_tau` differentiates:
    ///
    /// * the weights (`kernel.dotw1`, `kernel.dotw2`);
    /// * the leverage (`kernel.dot_h_partial`);
    /// * the projected operator (`M_tau` via `apply_mtau_to_matrix`);
    /// * the design product (`X_tau v`).
    ///
    /// Returns a `p x m` zero matrix when `rhs` does not have `p` rows, has no
    /// columns, or `p == 0`.
    pub(crate) fn hphi_tau_partial_apply(
        &self,
        x_tau: &Array2<f64>,
        kernel: &FirthTauPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        // Observation-space images of the right-hand side: X v and X_tau v.
        let etav = fast_ab(&self.x_dense, rhs);
        let etav_tau = fast_ab(x_tau, rhs);
        // q = w1 ∘ X v and its tau-derivative (weight change + design change).
        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
        let qv_tau = &etav * &kernel.dotw1.view().insert_axis(Axis(1))
            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
        let m_qv = self.apply_pbar_to_matrix(&qv);
        // Product rule: d(P̄ q) = M_tau q + P̄ dq.
        let m_qv_tau = self.apply_mtau_to_matrix(kernel, &qv) + self.apply_pbar_to_matrix(&qv_tau);
        // r = (w2 ∘ h) ∘ Xv - w1 ∘ P̄ q
        let rv = &(&etav * &self.w2.view().insert_axis(Axis(1)))
            * &self.h_diag.view().insert_axis(Axis(1))
            - &(&m_qv * &self.w1.view().insert_axis(Axis(1)));
        // r_tau: differentiate each factor of r (w2, Xv, h, w1, P̄ q) in turn.
        let rv_tau = (&(&etav * &kernel.dotw2.view().insert_axis(Axis(1)))
            + &(&etav_tau * &self.w2.view().insert_axis(Axis(1))))
            * self.h_diag.view().insert_axis(Axis(1))
            + &(&etav * &self.w2.view().insert_axis(Axis(1)))
                * &kernel.dot_h_partial.view().insert_axis(Axis(1))
            - &(&m_qv * &kernel.dotw1.view().insert_axis(Axis(1))
                + &m_qv_tau * &self.w1.view().insert_axis(Axis(1)));
        // Outer design derivative: X_tau^T r (explicit) + X^T r_tau (induced).
        0.5 * (fast_atb(x_tau, &rv) + fast_atb(&self.x_dense, &rv_tau))
    }
1547
1548 pub(crate) fn hphi_tau_tau_partial_prepare_from_partials(
1888 &self,
1889 x_tau_i_reduced: Array2<f64>,
1890 x_tau_j_reduced: Array2<f64>,
1891 deta_i_partial: &Array1<f64>,
1892 deta_j_partial: &Array1<f64>,
1893 dot_h_i_partial: Array1<f64>,
1894 dot_h_j_partial: Array1<f64>,
1895 dot_i_i_partial: Array2<f64>,
1896 dot_i_j_partial: Array2<f64>,
1897 x_tau_tau_reduced: Option<Array2<f64>>,
1898 deta_ij_partial: Option<Array1<f64>>,
1899 ) -> FirthTauTauPartialKernel {
1900 let dot_k_i_reduced = -self.k_reduced.dot(&dot_i_i_partial).dot(&self.k_reduced);
1902 let dot_k_j_reduced = -self.k_reduced.dot(&dot_i_j_partial).dot(&self.k_reduced);
1903 FirthTauTauPartialKernel {
1904 x_tau_i_reduced,
1905 x_tau_j_reduced,
1906 deta_i_partial: deta_i_partial.clone(),
1907 deta_j_partial: deta_j_partial.clone(),
1908 dot_h_i_partial,
1909 dot_h_j_partial,
1910 dot_k_i_reduced,
1911 dot_k_j_reduced,
1912 dot_i_i_partial,
1913 dot_i_j_partial,
1914 x_tau_tau_reduced,
1915 deta_ij_partial,
1916 }
1917 }
1918
    /// Applies the mixed second partial `d^2/(dtau_i dtau_j)` of the Jeffreys
    /// Hessian to `rhs` (`p x m`) and returns a `p x m` matrix.
    ///
    /// The Hessian action has the shape `0.5 (A^T Γ A - B^T P B) v`, with:
    ///
    /// * `A` the design;
    /// * `Γ = diag(w2 ∘ h)`;
    /// * `B = diag(w1) A`;
    /// * `P` the Hadamard-Gram operator.
    ///
    /// This routine expands the second derivative of each product with the
    /// product rule:
    ///
    /// * `diag_term` holds the 9 terms of `∂i∂j (A^T Γ A)`;
    /// * `d2_bpb` holds the 9 terms of `∂i∂j (B^T P B)`.
    ///
    /// When `kernel.x_tau_tau_reduced` is `None`, the second design derivative is
    /// treated as zero. Zero placeholders are allocated for `x_rij` and `deta_ij`,
    /// and the terms that use it are skipped.
    ///
    /// Returns a `p x m` zero matrix when `rhs` does not have `p` rows, has no
    /// columns, or `p == 0`.
    pub(crate) fn hphi_tau_tau_partial_apply(
        &self,
        x_tau_i: &Array2<f64>,
        x_tau_j: &Array2<f64>,
        kernel: &FirthTauTauPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        let n = self.x_dense.nrows();
        let m = rhs.ncols();

        // `z` and `x_r` alias the same reduced design.
        let z = &self.x_reduced;
        let x_r = &self.x_reduced;
        let k = &self.k_reduced;
        let x_ri = &kernel.x_tau_i_reduced;
        let x_rj = &kernel.x_tau_j_reduced;
        let deta_i = &kernel.deta_i_partial;
        let deta_j = &kernel.deta_j_partial;
        let dh_i = &kernel.dot_h_i_partial;
        let dh_j = &kernel.dot_h_j_partial;
        let dot_k_i = &kernel.dot_k_i_reduced;
        let dot_k_j = &kernel.dot_k_j_reduced;
        let dot_i_i = &kernel.dot_i_i_partial;
        let dot_i_j = &kernel.dot_i_j_partial;

        // Missing second design derivative => zero stand-ins.
        let x_tau_tau_is_some = kernel.x_tau_tau_reduced.is_some();
        let x_rij_zero = Array2::<f64>::zeros(x_r.raw_dim());
        let x_rij: &Array2<f64> = kernel.x_tau_tau_reduced.as_ref().unwrap_or(&x_rij_zero);
        let zeros_n = Array1::<f64>::zeros(n);
        let deta_ij = kernel.deta_ij_partial.as_ref().unwrap_or(&zeros_n);

        // X v, X_i v, X_j v. Run in parallel only when the products are large enough.
        let (eta_v, eta_i_v, eta_j_v) = if RemlState::should_join_independent_dense_products(&[
            (n, m, p),
            (n, m, p),
            (n, m, p),
        ]) {
            let (eta_v, (eta_i_v, eta_j_v)) = rayon::join(
                || fast_ab(&self.x_dense, rhs),
                || rayon::join(|| fast_ab(x_tau_i, rhs), || fast_ab(x_tau_j, rhs)),
            );
            (eta_v, eta_i_v, eta_j_v)
        } else {
            (
                fast_ab(&self.x_dense, rhs),
                fast_ab(x_tau_i, rhs),
                fast_ab(x_tau_j, rhs),
            )
        };
        // X_ij v, reconstructed from the reduced representation via q_basis and the
        // inverse observation-weight scaling.
        let eta_ij_v: Array2<f64> = if x_tau_tau_is_some {
            let qt_v = fast_atb(&self.q_basis, rhs);
            let mut out = fast_ab(x_rij, &qt_v);
            RemlState::scale_rows_by_inverse_observation_weight_sqrt(
                &mut out,
                self.observation_weight_sqrt.as_ref(),
            );
            out
        } else {
            Array2::<f64>::zeros((n, m))
        };

        // a_* = -dot_K_* = K dot_I_* K
        let a_i_reduced = -dot_k_i;
        let a_j_reduced = -dot_k_j;

        // Second derivative of the reduced information I = X_r^T W X_r.
        let dw_i = &self.w1 * deta_i;
        let dw_j = &self.w1 * deta_j;
        let ddw_ij = &(&self.w2 * &(deta_i * deta_j)) + &(&self.w1 * deta_ij);
        let mut i_ddot = Array2::<f64>::zeros(k.raw_dim());
        if x_tau_tau_is_some {
            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
        }
        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_rj, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_ri, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_r, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_ri, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_r, &dw_i);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rj, &dw_i);
        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);

        // d2(I^{-1}) = -K Iddot K + K dI_i K dI_j K + K dI_j K dI_i K
        let k_ddot: Array2<f64> = -k.dot(&i_ddot).dot(k)
            + a_i_reduced.dot(dot_i_j).dot(k)
            + a_j_reduced.dot(dot_i_i).dot(k);

        // Second derivative of the leverage diagonal h = diag(X K X^T).
        let dh_ij: Array1<f64> = {
            let r = k.ncols();
            let can_join = RemlState::should_join_independent_dense_products(&[
                (n, r, r),
                (n, r, r),
                (n, r, r),
                (n, r, r),
            ]);
            let (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k) = if can_join {
                let ((xr_kddot, ri_kdot_j), (rj_kdot_i, ri_k)) = rayon::join(
                    || rayon::join(|| fast_ab(x_r, &k_ddot), || fast_ab(x_ri, dot_k_j)),
                    || rayon::join(|| fast_ab(x_rj, dot_k_i), || fast_ab(x_ri, k)),
                );
                (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k)
            } else {
                (
                    fast_ab(x_r, &k_ddot),
                    fast_ab(x_ri, dot_k_j),
                    fast_ab(x_rj, dot_k_i),
                    fast_ab(x_ri, k),
                )
            };

            let mut acc = Self::rowwise_dot(&xr_kddot, x_r);
            acc = acc + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
            acc = acc + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
            acc = acc + 2.0 * Self::rowwise_dot(&ri_k, x_rj);
            if x_tau_tau_is_some {
                let rij_k = fast_ab(x_rij, k);
                acc = acc + 2.0 * Self::rowwise_dot(&rij_k, x_r);
            }
            acc
        };

        // Γ = w2 ∘ h and its first and second partials.
        let gamma = &self.w2 * &self.h_diag;
        let gamma_dot_i = &(&(&self.w3 * deta_i) * &self.h_diag) + &(&self.w2 * dh_i);
        let gamma_dot_j = &(&(&self.w3 * deta_j) * &self.h_diag) + &(&self.w2 * dh_j);
        let gamma_ddot = &(&(&(&self.w4 * deta_i) * deta_j) * &self.h_diag)
            + &(&(&(&self.w3 * deta_ij) * &self.h_diag)
                + &(&(&self.w3 * deta_i) * dh_j)
                + &(&(&self.w3 * deta_j) * dh_i)
                + &(&self.w2 * &dh_ij));

        // ∂i∂j (A^T Γ A) v, term by term.
        let mut diag_term = Array2::<f64>::zeros((p, m));
        let gamma_col = gamma.view().insert_axis(Axis(1));
        let gamma_i_col = gamma_dot_i.view().insert_axis(Axis(1));
        let gamma_j_col = gamma_dot_j.view().insert_axis(Axis(1));
        let gamma_ij_col = gamma_ddot.view().insert_axis(Axis(1));

        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_j_v * &gamma_col));
        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_i_v * &gamma_col));
        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_v * &gamma_j_col));
        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_v * &gamma_i_col));
        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_i_v * &gamma_j_col));
        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_j_v * &gamma_i_col));
        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_v * &gamma_ij_col));
        if x_tau_tau_is_some {
            // X_ij^T Γ X v, via the reduced X_ij and q_basis.
            let y: Array2<f64> = &eta_v * &gamma_col;
            let xt_ij_y: Array2<f64> = if self.observation_weight_sqrt.is_some() {
                let mut y_scaled = y.clone();
                RemlState::scale_rows_by_inverse_observation_weight_sqrt(
                    &mut y_scaled,
                    self.observation_weight_sqrt.as_ref(),
                );
                self.q_basis.dot(&x_rij.t().dot(&y_scaled))
            } else {
                self.q_basis.dot(&x_rij.t().dot(&y))
            };
            diag_term = diag_term + xt_ij_y;
            // X^T Γ X_ij v
            diag_term = diag_term + self.x_dense_t.dot(&(&eta_ij_v * &gamma_col));
        }

        // B v = w1 ∘ X v and its first/second partials in observation space.
        let w1_col = self.w1.view().insert_axis(Axis(1));
        let b_v = &eta_v * &w1_col;

        let w2_deta_i = &self.w2 * deta_i;
        let w2_deta_j = &self.w2 * deta_j;
        let w2_deta_i_col = w2_deta_i.view().insert_axis(Axis(1));
        let w2_deta_j_col = w2_deta_j.view().insert_axis(Axis(1));
        let bdot_i_v = &(&eta_v * &w2_deta_i_col) + &(&eta_i_v * &w1_col);
        let bdot_j_v = &(&eta_v * &w2_deta_j_col) + &(&eta_j_v * &w1_col);

        let w3_didj = &(&self.w3 * deta_i) * deta_j;
        let w2_dij = &self.w2 * deta_ij;
        let bddot_scale = &w3_didj + &w2_dij;
        let bddot_scale_col = bddot_scale.view().insert_axis(Axis(1));
        let mut bddot_ij_v = &eta_v * &bddot_scale_col;
        bddot_ij_v += &(&eta_j_v * &w2_deta_i_col);
        bddot_ij_v += &(&eta_i_v * &w2_deta_j_col);
        bddot_ij_v += &(&eta_ij_v * &w1_col);

        // P, P_i, P_j and P_ij applied to the B-vectors needed below.
        let p_bv = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &b_v);
        let p_bddot_ij_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bddot_ij_v);

        let pdot_i_bv = self.apply_mtau_from_reduced(x_ri, dot_k_i, &b_v);
        let pdot_j_bv = self.apply_mtau_from_reduced(x_rj, dot_k_j, &b_v);
        let pdot_i_bdot_j_v = self.apply_mtau_from_reduced(x_ri, dot_k_i, &bdot_j_v);
        let pdot_j_bdot_i_v = self.apply_mtau_from_reduced(x_rj, dot_k_j, &bdot_i_v);

        let p_bdot_j_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_j_v);
        let p_bdot_i_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_i_v);

        let p_ddot_b_v = self.apply_p_ddot_ij(
            x_r,
            x_ri,
            x_rj,
            x_rij,
            k,
            dot_k_i,
            dot_k_j,
            &k_ddot,
            x_tau_tau_is_some,
            &b_v,
        );

        // B_tau^T q = X^T diag(w2 ∘ deta_tau) q + X_tau^T diag(w1) q
        let apply_bdot_tau_t =
            |scale_deta: &Array1<f64>, x_tau_mat: &Array2<f64>, q_v: &Array2<f64>| {
                let scale_col = scale_deta.view().insert_axis(Axis(1));
                self.x_dense_t.dot(&(q_v * &scale_col)) + x_tau_mat.t().dot(&(q_v * &w1_col))
            };

        // B_ij^T q, the transpose of `bddot_ij_v`'s construction.
        let apply_bddot_ij_t = |q_v: &Array2<f64>| -> Array2<f64> {
            let scale_col_full = bddot_scale.view().insert_axis(Axis(1));
            let mut out = self.x_dense_t.dot(&(q_v * &scale_col_full));
            out = out + x_tau_j.t().dot(&(q_v * &w2_deta_i_col));
            out = out + x_tau_i.t().dot(&(q_v * &w2_deta_j_col));
            if x_tau_tau_is_some {
                let y = q_v * &w1_col;
                let contrib: Array2<f64> = if self.observation_weight_sqrt.is_some() {
                    let mut y_scaled = y.clone();
                    RemlState::scale_rows_by_inverse_observation_weight_sqrt(
                        &mut y_scaled,
                        self.observation_weight_sqrt.as_ref(),
                    );
                    self.q_basis.dot(&x_rij.t().dot(&y_scaled))
                } else {
                    self.q_basis.dot(&x_rij.t().dot(&y))
                };
                out = out + contrib;
            }
            out
        };

        // ∂i∂j (B^T P B) v, term by term.
        let t1a = apply_bddot_ij_t(&p_bv); // B_ij^T P B
        let t1b = self.left_scaled_xt(&self.w1, &p_bddot_ij_v); // B^T P B_ij
        let t2a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &p_bdot_j_v); // B_i^T P B_j
        let t2b = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &p_bdot_i_v); // B_j^T P B_i
        let t3a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &pdot_j_bv); // B_i^T P_j B
        let t3b = self.left_scaled_xt(&self.w1, &pdot_j_bdot_i_v); // B^T P_j B_i
        let t4a = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &pdot_i_bv); // B_j^T P_i B
        let t4b = self.left_scaled_xt(&self.w1, &pdot_i_bdot_j_v); // B^T P_i B_j
        let t5 = self.left_scaled_xt(&self.w1, &p_ddot_b_v); // B^T P_ij B

        let d2_bpb = t1a + t1b + t2a + t2b + t3a + t3b + t4a + t4b + t5;

        0.5 * (diag_term - d2_bpb)
    }
2282
    /// Exact second-order Jeffreys/Firth quantities for a pair of design
    /// directions `(tau_i, tau_j)`.
    ///
    /// Inputs are the first design derivatives `x_tau_i`, `x_tau_j` and the
    /// optional mixed derivative `x_tau_tau`; `None` is treated as zero.
    ///
    /// Computes:
    ///
    /// * `phi_tau_tau_partial`, the sum of two parts:
    ///   * `0.5 tr(K Iddot) - 0.5 tr(K dI_j K dI_i)`, the likelihood/information part;
    ///   * `0.5 tr(G dS_j G dS_i) - 0.5 tr(G Sddot)`, the metric part, with
    ///     `G = diag(x_metric_reduced_inv_diag)`.
    /// * `gphi_tau_tau`, the second partial of `0.5 X^T (w1 ∘ h)`.
    /// * The prepared Hessian kernel, only when `include_hphi_tau_tau_kernel` is true.
    pub(crate) fn exact_tau_tau_kernel(
        &self,
        x_tau_i: &Array2<f64>,
        x_tau_j: &Array2<f64>,
        x_tau_tau: Option<&Array2<f64>>,
        beta: &Array1<f64>,
        include_hphi_tau_tau_kernel: bool,
    ) -> FirthTauTauExactKernel {
        // Linear-predictor motion along each direction (and the mixed one).
        let deta_i = x_tau_i.dot(beta);
        let deta_j = x_tau_j.dot(beta);
        let deta_ij = x_tau_tau.as_ref().map(|xij| xij.dot(beta));

        let x_tau_i_reduced = self.reduce_explicit_design(x_tau_i);
        let x_tau_j_reduced = self.reduce_explicit_design(x_tau_j);
        let x_tau_tau_reduced = x_tau_tau.map(|xij| self.reduce_explicit_design(xij));

        let (dot_i_i, dot_h_i) = self.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
        let (dot_i_j, dot_h_j) = self.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);

        // First and second weight changes; missing mixed motion counts as zero.
        let zeros_n = Array1::<f64>::zeros(self.x_dense.nrows());
        let deta_ij_ref: &Array1<f64> = deta_ij.as_ref().unwrap_or(&zeros_n);
        let dw_i = &self.w1 * &deta_i;
        let dw_j = &self.w1 * &deta_j;
        let ddw_ij = &(&self.w2 * &(&deta_i * &deta_j)) + &(&self.w1 * deta_ij_ref);

        // Second derivative of the reduced information I = X_r^T W X_r.
        let x_r = &self.x_reduced;
        let mut i_ddot = Array2::<f64>::zeros(self.k_reduced.raw_dim());
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
        }
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, &x_tau_j_reduced, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, &x_tau_i_reduced, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, x_r, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_i_reduced, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, x_r, &dw_i);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_j_reduced, &dw_i);
        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);

        // Likelihood part: 0.5 tr(K Iddot) - 0.5 tr(K dI_j K dI_i).
        let k = &self.k_reduced;
        let k_dot_i_i = k.dot(&dot_i_i);
        let k_dot_i_j = k.dot(&dot_i_j);
        let a_lik = 0.5 * RemlState::trace_product(k, &i_ddot)
            - 0.5 * RemlState::trace_product(&k_dot_i_j, &k_dot_i_i);

        // Metric part from the unweighted Gram S = X_r^T X_r, using the diagonal
        // inverse metric G; the code assumes G stands in for S^{-1} here.
        let dot_s_i = fast_atb(&x_tau_i_reduced, x_r) + fast_atb(x_r, &x_tau_i_reduced);
        let dot_s_j = fast_atb(&x_tau_j_reduced, x_r) + fast_atb(x_r, &x_tau_j_reduced);
        let mut s_ddot = Array2::<f64>::zeros(k.raw_dim());
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            s_ddot = s_ddot + fast_atb(x_rij, x_r) + fast_atb(x_r, x_rij);
        }
        s_ddot = s_ddot
            + fast_atb(&x_tau_i_reduced, &x_tau_j_reduced)
            + fast_atb(&x_tau_j_reduced, &x_tau_i_reduced);
        let g_inv = &self.x_metric_reduced_inv_diag;
        let rdim = k.nrows();
        let mut a_pen = 0.0_f64;
        for kk in 0..rdim {
            for ll in 0..rdim {
                // 0.5 tr(G dS_j G dS_i), written with the elementwise pairing.
                a_pen += 0.5 * g_inv[kk] * g_inv[ll] * dot_s_j[[kk, ll]] * dot_s_i[[kk, ll]];
            }
            // -0.5 tr(G Sddot)
            a_pen -= 0.5 * g_inv[kk] * s_ddot[[kk, kk]];
        }
        let phi_tau_tau_partial = a_lik + a_pen;

        // d2(I^{-1}) = -K Iddot K + K dI_i K dI_j K + K dI_j K dI_i K
        let dot_k_i = -k.dot(&dot_i_i).dot(k);
        let dot_k_j = -k.dot(&dot_i_j).dot(k);
        let a_i_red = -&dot_k_i;
        let a_j_red = -&dot_k_j;
        let k_ddot: Array2<f64> =
            -k.dot(&i_ddot).dot(k) + a_i_red.dot(&dot_i_j).dot(k) + a_j_red.dot(&dot_i_i).dot(k);

        // Second derivative of the leverage diagonal h = diag(X K X^T).
        let n = self.x_dense.nrows();
        let mut dh_ij = Array1::<f64>::zeros(n);
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            let rij_k = x_rij.dot(k);
            dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rij_k, x_r);
        }
        let xr_kddot = x_r.dot(&k_ddot);
        dh_ij = dh_ij + Self::rowwise_dot(&xr_kddot, x_r);
        let ri_kdot_j = x_tau_i_reduced.dot(&dot_k_j);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
        let rj_kdot_i = x_tau_j_reduced.dot(&dot_k_i);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
        let ri_k = x_tau_i_reduced.dot(k);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_k, &x_tau_j_reduced);

        // ∂i∂j [0.5 A^T u], u = w1 ∘ h:
        // 0.5 (A_ij^T u + A_i^T u_j + A_j^T u_i + A^T u_ij).
        let w1_h = &self.w1 * &self.h_diag;
        let mut gphi_tau_tau = Array1::<f64>::zeros(self.x_dense.ncols());
        if let Some(x_ij) = x_tau_tau.as_ref() {
            gphi_tau_tau = gphi_tau_tau + 0.5 * x_ij.t().dot(&w1_h);
        }
        let inner_j = &(&(&self.w2 * &deta_j) * &self.h_diag) + &(&self.w1 * &dot_h_j);
        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_i.t().dot(&inner_j);

        let v_tau_i = &(&(&self.w2 * &deta_i) * &self.h_diag) + &(&self.w1 * &dot_h_i);
        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_j.t().dot(&v_tau_i);

        let mut v_dot_ij = &(&(&self.w3 * &deta_j) * &deta_i) * &self.h_diag;
        v_dot_ij += &(&(&self.w2 * deta_ij_ref) * &self.h_diag);
        v_dot_ij += &(&(&self.w2 * &deta_i) * &dot_h_j);
        v_dot_ij += &(&(&self.w2 * &deta_j) * &dot_h_i);
        v_dot_ij += &(&self.w1 * &dh_ij);
        gphi_tau_tau = gphi_tau_tau + 0.5 * self.x_dense.t().dot(&v_dot_ij);

        let tau_tau_kernel = if include_hphi_tau_tau_kernel {
            Some(self.hphi_tau_tau_partial_prepare_from_partials(
                x_tau_i_reduced,
                x_tau_j_reduced,
                &deta_i,
                &deta_j,
                dot_h_i,
                dot_h_j,
                dot_i_i,
                dot_i_j,
                x_tau_tau_reduced,
                deta_ij,
            ))
        } else {
            None
        };

        FirthTauTauExactKernel {
            phi_tau_tau_partial,
            gphi_tau_tau,
            tau_tau_kernel,
        }
    }
2516
2517 pub(crate) fn apply_mtau_from_reduced(
2524 &self,
2525 x_tau_reduced: &Array2<f64>,
2526 dot_k_reduced: &Array2<f64>,
2527 mat: &Array2<f64>,
2528 ) -> Array2<f64> {
2529 if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
2530 return Array2::<f64>::zeros(mat.raw_dim());
2531 }
2532 let mut out = Array2::<f64>::zeros(mat.raw_dim());
2533 for col in 0..mat.ncols() {
2534 let v = mat.column(col).to_owned();
2535 let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
2536 let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
2537 let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, x_tau_reduced);
2538
2539 let szt = RemlState::reduced_crossweighted_gram(&self.x_reduced, x_tau_reduced, &v);
2540 let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
2541 let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
2542
2543 let t3 =
2544 RemlState::apply_hadamard_gram(&self.x_reduced, &self.k_reduced, dot_k_reduced, &v);
2545
2546 let y = 2.0 * (t1 + t2 + t3);
2547 out.column_mut(col).assign(&y);
2548 }
2549 out
2550 }
2551
    /// Applies `P_ij`, the mixed second `tau`-derivative of the Hadamard-Gram
    /// operator `P = H ∘ H` (`H = X K X^T`), to each column of `mat`.
    ///
    /// Per column `v` the output is `2 (dH_i ∘ dH_j) v + 2 (H ∘ ddH_ij) v`:
    ///
    /// * `mdot_mdot` collects the 9 cross terms of `(dH_i ∘ dH_j) v`, with
    ///   `dH_i = X_i K X^T + X K X_i^T + X dK_i X^T`.
    /// * `m_mddot` collects the 9 terms of `(H ∘ ddH_ij) v`.
    ///
    /// Comments below name each Hadamard pair. They assume
    /// `reduced_crossweighted_gram(a, b, v) = a^T diag(v) b`; confirm in `RemlState`.
    ///
    /// The terms involving `x_rij` are skipped when `x_tau_tau_is_some` is false.
    /// Returns a zero matrix shaped like `mat` when `mat` does not have `n` rows
    /// or has no columns.
    pub(crate) fn apply_p_ddot_ij(
        &self,
        x_r: &Array2<f64>,
        x_ri: &Array2<f64>,
        x_rj: &Array2<f64>,
        x_rij: &Array2<f64>,
        k: &Array2<f64>,
        dot_k_i: &Array2<f64>,
        dot_k_j: &Array2<f64>,
        k_ddot: &Array2<f64>,
        x_tau_tau_is_some: bool,
        mat: &Array2<f64>,
    ) -> Array2<f64> {
        let n = self.x_dense.nrows();
        let m = mat.ncols();
        if mat.nrows() != n || m == 0 {
            return Array2::<f64>::zeros(mat.raw_dim());
        }
        let mut out = Array2::<f64>::zeros((n, m));
        for col in 0..m {
            let v = mat.column(col).to_owned();
            // v-weighted Grams of every design pair used below.
            let s_zz = RemlState::reducedweighted_gram(x_r, &v);
            let s_zj = RemlState::reduced_crossweighted_gram(x_r, x_rj, &v);
            let s_iz = RemlState::reduced_crossweighted_gram(x_ri, x_r, &v);
            let s_jz = RemlState::reduced_crossweighted_gram(x_rj, x_r, &v);
            let s_ij = RemlState::reduced_crossweighted_gram(x_ri, x_rj, &v);
            let mut mdot_mdot = Array1::<f64>::zeros(n);
            {
                // (X_i K X^T) ∘ (X_j K X^T)
                let core = k.dot(&s_zz).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_rj);
            }
            {
                // (X_i K X^T) ∘ (X dK_j X^T)
                let core = k.dot(&s_zz).dot(&dot_k_j.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // (X_i K X^T) ∘ (X K X_j^T)
                let core = k.dot(&s_zj).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // (X dK_i X^T) ∘ (X_j K X^T)
                let core = dot_k_i.dot(&s_zz).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
            }
            {
                // (X dK_i X^T) ∘ (X dK_j X^T)
                let core = dot_k_i.dot(&s_zz).dot(&dot_k_j.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X dK_i X^T) ∘ (X K X_j^T)
                let core = dot_k_i.dot(&s_zj).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X K X_i^T) ∘ (X_j K X^T)
                let core = k.dot(&s_iz).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
            }
            {
                // (X K X_i^T) ∘ (X dK_j X^T)
                let core = k.dot(&s_iz).dot(&dot_k_j.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X K X_i^T) ∘ (X K X_j^T)
                let core = k.dot(&s_ij).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }

            let mut m_mddot = Array1::<f64>::zeros(n);
            if x_tau_tau_is_some {
                // (X_ij K X^T) ∘ H
                let core = k.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_rij, &core, x_r);
            }
            if x_tau_tau_is_some {
                // (X K X_ij^T) ∘ H
                let s_ijz = RemlState::reduced_crossweighted_gram(x_rij, x_r, &v);
                let core = k.dot(&s_ijz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X_i dK_j X^T) ∘ H
                let core = dot_k_j.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // (X dK_j X_i^T) ∘ H
                let core = dot_k_j.dot(&s_iz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X_j dK_i X^T) ∘ H
                let core = dot_k_i.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
            }
            {
                // (X dK_i X_j^T) ∘ H
                let core = dot_k_i.dot(&s_jz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X_i K X_j^T) ∘ H
                let core = k.dot(&s_jz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // (X_j K X_i^T) ∘ H
                let core = k.dot(&s_iz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
            }
            {
                // (X Kddot X^T) ∘ H
                let core = k_ddot.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }

            // d2(H ∘ H) = 2 (dH_i ∘ dH_j) + 2 (H ∘ ddH)
            let col_out = 2.0 * mdot_mdot + 2.0 * m_mddot;
            out.column_mut(col).assign(&col_out);
        }
        out
    }
2713
2714 pub(crate) fn d_beta_hphi_tau_partial_prepare_from_partials(
2725 &self,
2726 tau_kernel: &FirthTauPartialKernel,
2727 deta_partial: &Array1<f64>,
2728 dot_i_partial: &Array2<f64>,
2729 beta_direction: &FirthDirection,
2730 x_tau_v: &Array1<f64>,
2731 ) -> FirthTauBetaPartialKernel {
2732 let s_v = &self.w1 * &beta_direction.deta;
2744 let mixed_diag_weight = &(&tau_kernel.dotw1 * &beta_direction.deta) + &(&self.w1 * x_tau_v);
2745 let cross1 =
2746 RemlState::reduced_crossweighted_gram(&tau_kernel.x_tau_reduced, &self.x_reduced, &s_v);
2747 let cross2 =
2748 RemlState::reduced_crossweighted_gram(&self.x_reduced, &tau_kernel.x_tau_reduced, &s_v);
2749 let diag_piece = RemlState::reducedweighted_gram(&self.x_reduced, &mixed_diag_weight);
2750 let d_beta_dot_i = &cross1 + &cross2 + &diag_piece;
2751
2752 let term_a = beta_direction
2761 .a_u_reduced
2762 .dot(dot_i_partial)
2763 .dot(&self.k_reduced);
2764 let term_b = self.k_reduced.dot(&d_beta_dot_i).dot(&self.k_reduced);
2765 let term_c = self
2766 .k_reduced
2767 .dot(dot_i_partial)
2768 .dot(&beta_direction.a_u_reduced);
2769 let d_beta_dot_k = &term_a - &term_b + &term_c;
2770
2771 let cross_diag = Self::rowwise_bilinear(
2777 &tau_kernel.x_tau_reduced,
2778 &beta_direction.a_u_reduced,
2779 &self.x_reduced,
2780 );
2781 let inner_diag = RemlState::reduced_diag_gram(&self.x_reduced, &d_beta_dot_k);
2782 let d_beta_dot_h = -2.0 * &cross_diag + &inner_diag;
2783
2784 FirthTauBetaPartialKernel {
2785 x_tau_reduced: tau_kernel.x_tau_reduced.clone(),
2786 deta_partial: deta_partial.clone(),
2787 dot_h_partial: tau_kernel.dot_h_partial.clone(),
2788 dot_i_partial: dot_i_partial.clone(),
2789 dot_k_reduced: tau_kernel.dot_k_reduced.clone(),
2790 deta_v: beta_direction.deta.clone(),
2791 deta_tau_v: x_tau_v.clone(),
2792 a_v_reduced: beta_direction.a_u_reduced.clone(),
2793 dh_v: beta_direction.dh.clone(),
2794 b_vvec: beta_direction.b_uvec.clone(),
2795 d_beta_dot_k,
2796 d_beta_dot_h,
2797 }
2798 }
2799
2800 pub(crate) fn apply_p_tau_v_to_matrix(
2811 &self,
2812 kernel: &FirthTauBetaPartialKernel,
2813 mat: &Array2<f64>,
2814 ) -> Array2<f64> {
2815 let n = self.x_dense.nrows();
2816 if mat.nrows() != n || mat.ncols() == 0 {
2817 return Array2::<f64>::zeros(mat.raw_dim());
2818 }
2819 let z = &self.x_reduced;
2820 let z_tau = &kernel.x_tau_reduced;
2821 let k_r = &self.k_reduced;
2822 let a_v = &kernel.a_v_reduced; let dot_k_tau = &kernel.dot_k_reduced; let d_beta_dot_k = &kernel.d_beta_dot_k; let mut out = Array2::<f64>::zeros(mat.raw_dim());
2826 for col in 0..mat.ncols() {
2827 let v = mat.column(col).to_owned();
2828 let s_zz = RemlState::reducedweighted_gram(z, &v);
2829 let s_z_ztau = RemlState::reduced_crossweighted_gram(z, z_tau, &v);
2830
2831 let mid_1 = a_v.dot(&s_zz).dot(k_r);
2834 let t1 = -Self::rowwise_bilinear(z, &mid_1, z_tau);
2835 let mid_2 = a_v.dot(&s_z_ztau).dot(k_r);
2838 let t2 = -RemlState::reduced_diag_gram(z, &mid_2);
2839 let mid_3 = a_v.dot(&s_zz).dot(dot_k_tau);
2842 let t3 = -RemlState::reduced_diag_gram(z, &mid_3);
2843 let mid_4 = k_r.dot(&s_zz).dot(a_v);
2846 let t4 = -Self::rowwise_bilinear(z, &mid_4, z_tau);
2847 let mid_5 = k_r.dot(&s_z_ztau).dot(a_v);
2850 let t5 = -RemlState::reduced_diag_gram(z, &mid_5);
2851 let t6 = RemlState::apply_hadamard_gram(z, k_r, d_beta_dot_k, &v);
2853
2854 let y = 2.0 * (t1 + t2 + t3 + t4 + t5 + t6);
2857 out.column_mut(col).assign(&y);
2858 }
2859 out
2860 }
2861
2862 pub(crate) fn d_beta_hphi_tau_partial_apply(
2863 &self,
2864 x_tau: &Array2<f64>,
2865 kernel: &FirthTauBetaPartialKernel,
2866 rhs: &Array2<f64>,
2867 ) -> Array2<f64> {
2868 let p = self.x_dense.ncols();
2869 if rhs.nrows() != p {
2870 return Array2::<f64>::zeros((p, rhs.ncols()));
2871 }
2872 if rhs.ncols() == 0 || p == 0 {
2873 return Array2::<f64>::zeros((p, rhs.ncols()));
2874 }
2875 let etav = fast_ab(&self.x_dense, rhs);
2884 let etav_tau = fast_ab(x_tau, rhs);
2885 let deta_v = &kernel.deta_v;
2886 let deta_tau_v = &kernel.deta_tau_v;
2887 let eta_tau = &kernel.deta_partial;
2888 let dot_h = &kernel.dot_h_partial;
2889
2890 let dotw1 = &self.w2 * eta_tau;
2892 let dotw2 = &self.w3 * eta_tau;
2893
2894 let c_v = &(&(&self.w3 * deta_v) * &self.h_diag) + &(&self.w2 * &kernel.dh_v);
2900 let b_vvec = &kernel.b_vvec;
2901 let d_beta_dotw1_vec = &(&(&self.w3 * deta_v) * eta_tau) + &(&self.w2 * deta_tau_v);
2902 let d_beta_dotw2_vec = &(&(&self.w4 * deta_v) * eta_tau) + &(&self.w3 * deta_tau_v);
2903
2904 let qv = &etav * &self.w1.view().insert_axis(Axis(1));
2906 let qv_tau = &etav * &dotw1.view().insert_axis(Axis(1))
2907 + &etav_tau * &self.w1.view().insert_axis(Axis(1));
2908 let m_qv = self.apply_pbar_to_matrix(&qv);
2909 let tau_kernel_view = FirthTauPartialKernel {
2912 deta_partial: eta_tau.clone(),
2913 dotw1: dotw1.clone(),
2914 dotw2: dotw2.clone(),
2915 dot_h_partial: dot_h.clone(),
2916 x_tau_reduced: kernel.x_tau_reduced.clone(),
2917 dot_i_partial: kernel.dot_i_partial.clone(),
2918 dot_k_reduced: kernel.dot_k_reduced.clone(),
2919 };
2920 let m_qv_tau =
2921 self.apply_mtau_to_matrix(&tau_kernel_view, &qv) + self.apply_pbar_to_matrix(&qv_tau);
2922
2923 let d_beta_qv = &etav * &b_vvec.view().insert_axis(Axis(1));
2927 let d_beta_qv_tau = &etav * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2928 + &etav_tau * &b_vvec.view().insert_axis(Axis(1));
2929
2930 let d_beta_m_qv = self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv)
2932 + self.apply_pbar_to_matrix(&d_beta_qv);
2933
2934 let d_beta_m_qv_tau = self.apply_p_tau_v_to_matrix(kernel, &qv)
2936 + self.apply_mtau_to_matrix(&tau_kernel_view, &d_beta_qv)
2937 + self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv_tau)
2938 + self.apply_pbar_to_matrix(&d_beta_qv_tau);
2939
2940 let d_beta_rv = &etav * &c_v.view().insert_axis(Axis(1))
2943 - &m_qv * &b_vvec.view().insert_axis(Axis(1))
2944 - &d_beta_m_qv * &self.w1.view().insert_axis(Axis(1));
2945
2946 let d_beta_dotw2_h = &(&d_beta_dotw2_vec * &self.h_diag) + &(&dotw2 * &kernel.dh_v);
2957 let d_beta_w2_doth = &(&(&self.w3 * deta_v) * dot_h) + &(&self.w2 * &kernel.d_beta_dot_h);
2958
2959 let d_beta_rv_tau = &etav * &d_beta_dotw2_h.view().insert_axis(Axis(1))
2960 + &etav_tau * &c_v.view().insert_axis(Axis(1))
2961 + &etav * &d_beta_w2_doth.view().insert_axis(Axis(1))
2962 - &d_beta_m_qv * &dotw1.view().insert_axis(Axis(1))
2963 - &m_qv * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2964 - &d_beta_m_qv_tau * &self.w1.view().insert_axis(Axis(1))
2965 - &m_qv_tau * &b_vvec.view().insert_axis(Axis(1));
2966
2967 0.5 * (x_tau.t().dot(&d_beta_rv) + self.x_dense.t().dot(&d_beta_rv_tau))
2968 }
2969}
2970
2971#[cfg(test)]
2972mod tests {
2973 use super::*;
2974 use crate::mixture_link::logit_inverse_link_jet5;
2975 use gam_problem::StandardLink;
2976 use ndarray::{Array1, Array2, array};
2977
2978 impl FirthDenseOperator {
2984 pub(crate) fn pirls_hat_diag(&self) -> Array1<f64> {
2985 &self.w * &self.h_diag
2986 }
2987
2988 pub(crate) fn pirls_firth_score_shift(&self) -> Array1<f64> {
2992 let mut shift = Array1::<f64>::zeros(self.w.len());
2993 for i in 0..self.w.len() {
2994 let wi = self.w[i];
2995 if wi > 0.0 {
2996 shift[i] = 0.5 * (self.w1[i] / wi) * self.h_diag[i];
2997 }
2998 }
2999 shift
3000 }
3001 }
3002
3003 pub(crate) fn build_logit_firth_dense_operator(
3004 x_dense: &Array2<f64>,
3005 eta: &Array1<f64>,
3006 ) -> Result<FirthDenseOperator, EstimationError> {
3007 FirthDenseOperator::build_with_observation_weights_impl(
3008 &InverseLink::Standard(StandardLink::Logit),
3009 x_dense,
3010 eta,
3011 None,
3012 )
3013 }
3014
3015 pub(crate) fn build_weighted_logit_firth_dense_operator(
3016 x_dense: &Array2<f64>,
3017 eta: &Array1<f64>,
3018 observation_weights: ndarray::ArrayView1<'_, f64>,
3019 ) -> Result<FirthDenseOperator, EstimationError> {
3020 FirthDenseOperator::build_with_observation_weights_impl(
3021 &InverseLink::Standard(StandardLink::Logit),
3022 x_dense,
3023 eta,
3024 Some(observation_weights),
3025 )
3026 }
3027
3028 pub(crate) fn logisticweight(eta: f64) -> f64 {
3029 logit_inverse_link_jet5(eta).d1
3030 }
3031
3032 pub(crate) fn firthphivalue(x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3033 let eta = x.dot(beta);
3034 let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3035 op.jeffreys_logdet()
3036 }
3037
3038 pub(crate) fn firthgradphi(x: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
3039 let eta = x.dot(beta);
3040 let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3041 op.jeffreys_beta_gradient()
3042 }
3043
3044 pub(crate) fn weighted_firthphivalue(
3045 x: &Array2<f64>,
3046 beta: &Array1<f64>,
3047 observation_weights: &Array1<f64>,
3048 ) -> f64 {
3049 let eta = x.dot(beta);
3050 let op = build_weighted_logit_firth_dense_operator(x, &eta, observation_weights.view())
3051 .expect("weighted firth operator");
3052 op.jeffreys_logdet()
3053 }
3054
3055 #[test]
3056 pub(crate) fn firth_reduced_fisher_logdet_is_finite_for_barely_pd_matrix() {
3057 let fisher = array![[16.0, 0.0], [0.0, 1e-15]];
3058 let (k_reduced, half_log_det) = RemlState::reduced_fisher_inverse_and_half_logdet(&fisher)
3059 .expect("barely positive-definite reduced fisher");
3060 let expected = 0.5 * 16.0_f64.ln();
3061
3062 assert!(
3063 half_log_det.is_finite(),
3064 "barely positive-definite reduced fisher produced non-finite half logdet: {half_log_det}"
3065 );
3066 assert!(
3067 (half_log_det - expected).abs() < 1e-12,
3068 "near-null Fisher direction should be excluded from pseudo-logdet: got {half_log_det}, expected {expected}"
3069 );
3070 assert!(
3071 k_reduced.iter().all(|value| value.is_finite()),
3072 "barely positive-definite reduced fisher produced non-finite inverse entries: {k_reduced:?}"
3073 );
3074 assert!(
3075 k_reduced[[1, 1]].abs() < f64::EPSILON,
3076 "near-null Fisher direction should be excluded from pseudo-inverse: {k_reduced:?}"
3077 );
3078 }
3079
3080 #[test]
3081 pub(crate) fn firth_logisticweight_derivatives_match_finite_difference() {
3082 let x = array![
3101 [1.0, -1.1, 0.2],
3102 [1.0, -0.5, -0.6],
3103 [1.0, 0.0, 0.3],
3104 [1.0, 0.8, -0.4],
3105 [1.0, 1.2, 0.7],
3106 ];
3107 let beta = array![0.15, -0.6, 0.35];
3108 let eta = x.dot(&beta);
3109 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3110
3111 let h = 1e-2_f64;
3112 let w = |z: f64| logisticweight(z);
3113 let d1direct = |z: f64| (w(z + h) - w(z - h)) / (2.0 * h);
3114 let d2direct = |z: f64| (w(z + h) - 2.0 * w(z) + w(z - h)) / (h * h);
3115 let d3direct = |z: f64| {
3116 (-w(z - 2.0 * h) + 2.0 * w(z - h) - 2.0 * w(z + h) + w(z + 2.0 * h)) / (2.0 * h.powi(3))
3117 };
3118 let d4direct = |z: f64| {
3119 (w(z - 2.0 * h) - 4.0 * w(z - h) + 6.0 * w(z) - 4.0 * w(z + h) + w(z + 2.0 * h))
3120 / h.powi(4)
3121 };
3122 for i in 0..eta.len() {
3123 let z = eta[i];
3124 let wfd = w(z);
3125 let w1fd = d1direct(z);
3126 let w2fd = d2direct(z);
3127 let w3fd = d3direct(z);
3128 let w4fd = d4direct(z);
3129
3130 assert!((op.w[i] - wfd).abs() < 1e-12);
3131 assert_eq!(op.w1[i].signum(), w1fd.signum());
3132 assert_eq!(op.w2[i].signum(), w2fd.signum());
3133 assert_eq!(op.w3[i].signum(), w3fd.signum());
3134 assert_eq!(op.w4[i].signum(), w4fd.signum());
3135 assert!((op.w1[i] - w1fd).abs() < 1e-5);
3136 assert!((op.w2[i] - w2fd).abs() < 1e-4);
3137 assert!((op.w3[i] - w3fd).abs() < 1e-4);
3138 assert!((op.w4[i] - w4fd).abs() < 1e-3);
3139 }
3140 }
3141
3142 #[test]
3143 pub(crate) fn weighted_firth_jeffreys_gradient_matches_finite_difference() {
3144 let x = array![
3145 [1.0, -0.7, 0.3],
3146 [1.0, -0.2, -0.4],
3147 [1.0, 0.5, 0.1],
3148 [1.0, 1.1, -0.6],
3149 [1.0, 1.6, 0.8],
3150 ];
3151 let beta = array![0.2, -0.45, 0.25];
3152 let observation_weights = array![1.0, 0.5, 2.0, 1.5, 0.75];
3153 let eta = x.dot(&beta);
3154 let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3155 .expect("weighted firth operator");
3156 let grad = op.jeffreys_beta_gradient();
3157 let h = 1e-6;
3158
3159 for j in 0..beta.len() {
3160 let mut beta_plus = beta.clone();
3161 beta_plus[j] += h;
3162 let mut beta_minus = beta.clone();
3163 beta_minus[j] -= h;
3164 let fd = (weighted_firthphivalue(&x, &beta_plus, &observation_weights)
3165 - weighted_firthphivalue(&x, &beta_minus, &observation_weights))
3166 / (2.0 * h);
3167 assert!(
3168 (grad[j] - fd).abs() < 1e-5,
3169 "weighted Firth gradient mismatch at {}: analytic={}, fd={}",
3170 j,
3171 grad[j],
3172 fd
3173 );
3174 }
3175 }
3176
3177 pub(crate) fn build_link_firth_op(
3185 link: StandardLink,
3186 x: &Array2<f64>,
3187 beta: &Array1<f64>,
3188 ) -> FirthDenseOperator {
3189 let eta = x.dot(beta);
3190 FirthDenseOperator::build_with_observation_weights_impl(
3191 &InverseLink::Standard(link),
3192 x,
3193 &eta,
3194 None,
3195 )
3196 .expect("link-general firth operator")
3197 }
3198
3199 pub(crate) fn link_firth_phi(link: StandardLink, x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3200 build_link_firth_op(link, x, beta).jeffreys_logdet()
3201 }
3202
3203 pub(crate) fn link_firth_grad(
3204 link: StandardLink,
3205 x: &Array2<f64>,
3206 beta: &Array1<f64>,
3207 ) -> Array1<f64> {
3208 build_link_firth_op(link, x, beta).jeffreys_beta_gradient()
3209 }
3210
3211 pub(crate) fn numeric_firth_hessian(
3216 link: StandardLink,
3217 x: &Array2<f64>,
3218 beta: &Array1<f64>,
3219 h: f64,
3220 ) -> Array2<f64> {
3221 let p = beta.len();
3222 let mut hess = Array2::<f64>::zeros((p, p));
3223 for j in 0..p {
3224 let mut bp = beta.clone();
3225 bp[j] += h;
3226 let mut bm = beta.clone();
3227 bm[j] -= h;
3228 let gp = link_firth_grad(link, x, &bp);
3229 let gm = link_firth_grad(link, x, &bm);
3230 let col = (&gp - &gm) / (2.0 * h);
3231 hess.column_mut(j).assign(&col);
3232 }
3233 hess
3234 }
3235
3236 #[test]
3243 fn hphisecond_eye_cached_matches_per_pair_bit_identical_1575() {
3244 let x = array![
3247 [1.0, -1.10, 0.35],
3248 [1.0, -0.40, -0.65],
3249 [1.0, 0.15, 0.20],
3250 [1.0, 0.80, -0.45],
3251 [1.0, 1.25, 0.70],
3252 [1.0, -0.55, 0.95],
3253 ];
3254 let beta = array![0.20, -0.55, 0.30];
3255 let op = build_link_firth_op(StandardLink::Logit, &x, &beta);
3256 let p = x.ncols();
3257
3258 let deta_list = [
3260 x.dot(&array![0.9, -0.3, 0.2]),
3261 x.dot(&array![-0.4, 0.7, 0.1]),
3262 x.dot(&array![0.1, 0.2, -0.8]),
3263 ];
3264 let dirs: Vec<FirthDirection> = deta_list
3265 .iter()
3266 .map(|d| op.direction_from_deta(d.clone()))
3267 .collect();
3268
3269 let eye = Array2::<f64>::eye(p);
3270 let cache = op.tk_second_direction_eye_cache(&dirs);
3271 for i in 0..dirs.len() {
3272 for j in 0..=i {
3273 let reference = op.hphisecond_direction_apply(&dirs[i], &dirs[j], &eye);
3274 let cached = op.hphisecond_direction_apply_eye_cached(&cache, &dirs, i, j);
3275 assert_eq!(
3276 reference.dim(),
3277 cached.dim(),
3278 "shape mismatch at pair ({i},{j})"
3279 );
3280 for (a, b) in reference.iter().zip(cached.iter()) {
3281 assert_eq!(
3282 a.to_bits(),
3283 b.to_bits(),
3284 "cached D²H_φ[{i},{j}] is not bit-identical to per-pair: \
3285 reference={a}, cached={b}"
3286 );
3287 }
3288 }
3289 }
3290 }
3291
3292 pub(crate) fn fixed_design_5x3() -> Array2<f64> {
3294 array![
3295 [1.0, -1.10, 0.35],
3296 [1.0, -0.40, -0.65],
3297 [1.0, 0.15, 0.20],
3298 [1.0, 0.80, -0.45],
3299 [1.0, 1.25, 0.70],
3300 ]
3301 }
3302
3303 #[test]
3304 pub(crate) fn link_general_logit_path_reproduces_historical_logit_build() {
3305 let x = fixed_design_5x3();
3310 let beta = array![0.20, -0.55, 0.30];
3311 let eta = x.dot(&beta);
3312
3313 let historical = build_logit_firth_dense_operator(&x, &eta).expect("historical logit");
3314 let link_general = FirthDenseOperator::build_with_observation_weights_impl(
3315 &InverseLink::Standard(StandardLink::Logit),
3316 &x,
3317 &eta,
3318 None,
3319 )
3320 .expect("link-general logit");
3321
3322 assert_eq!(
3323 historical.jeffreys_logdet(),
3324 link_general.jeffreys_logdet(),
3325 "logit Φ must be bit-identical through the link-general path"
3326 );
3327 let g_hist = historical.jeffreys_beta_gradient();
3328 let g_link = link_general.jeffreys_beta_gradient();
3329 for j in 0..g_hist.len() {
3330 assert_eq!(
3331 g_hist[j], g_link[j],
3332 "logit gradient component {j} must be bit-identical"
3333 );
3334 }
3335 let hat_hist = historical.pirls_hat_diag();
3336 let hat_link = link_general.pirls_hat_diag();
3337 for i in 0..hat_hist.len() {
3338 assert_eq!(
3339 hat_hist[i], hat_link[i],
3340 "logit PIRLS hat diagonal {i} must be bit-identical"
3341 );
3342 }
3343 for i in 0..eta.len() {
3344 assert_eq!(historical.w[i], link_general.w[i]);
3345 assert_eq!(historical.w1[i], link_general.w1[i]);
3346 assert_eq!(historical.w2[i], link_general.w2[i]);
3347 assert_eq!(historical.w3[i], link_general.w3[i]);
3348 assert_eq!(historical.w4[i], link_general.w4[i]);
3349 }
3350 }
3351
3352 #[test]
3353 pub(crate) fn link_general_probit_jeffreys_gradient_matches_finite_difference() {
3354 let x = fixed_design_5x3();
3357 let beta = array![0.10, -0.40, 0.25];
3358 let grad = link_firth_grad(StandardLink::Probit, &x, &beta);
3359 let h = 1e-6_f64;
3360 let mut max_rel = 0.0_f64;
3361 for j in 0..beta.len() {
3362 let mut bp = beta.clone();
3363 bp[j] += h;
3364 let mut bm = beta.clone();
3365 bm[j] -= h;
3366 let fd = (link_firth_phi(StandardLink::Probit, &x, &bp)
3367 - link_firth_phi(StandardLink::Probit, &x, &bm))
3368 / (2.0 * h);
3369 let denom = grad[j].abs().max(fd.abs()).max(1e-8);
3370 let rel = (grad[j] - fd).abs() / denom;
3371 max_rel = max_rel.max(rel);
3372 assert!(
3373 rel < 1e-6,
3374 "probit Firth gradient mismatch at {j}: analytic={}, fd={}, rel={:e}",
3375 grad[j],
3376 fd,
3377 rel
3378 );
3379 }
3380 assert!(
3381 max_rel < 1e-6,
3382 "probit gradient worst relative error {max_rel:e} exceeds 1e-6"
3383 );
3384 }
3385
3386 #[test]
3387 pub(crate) fn link_general_probit_hphi_direction_matches_finite_difference_of_hessian() {
3388 let x = fixed_design_5x3();
3396 let beta = array![0.10, -0.40, 0.25];
3397 let p = beta.len();
3398
3399 let directions = [
3401 array![1.0, 0.0, 0.0],
3402 array![0.0, 1.0, 0.0],
3403 array![0.0, 0.0, 1.0],
3404 array![0.7, -0.5, 0.3],
3405 ];
3406
3407 let h_inner = 1e-4_f64; let h_dir = 1e-4_f64; let mut worst = 0.0_f64;
3410 for u in directions.iter() {
3411 let op = build_link_firth_op(StandardLink::Probit, &x, &beta);
3412 let deta = x.dot(u);
3413 let dir = op.direction_from_deta(deta);
3414 let analytic = op.hphi_direction(&dir);
3415
3416 let beta_plus = &beta + &(u * h_dir);
3417 let beta_minus = &beta - &(u * h_dir);
3418 let hess_plus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_plus, h_inner);
3419 let hess_minus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_minus, h_inner);
3420 let fd = (&hess_plus - &hess_minus) / (2.0 * h_dir);
3421
3422 let mut scale = 1e-6_f64;
3423 for r in 0..p {
3424 for c in 0..p {
3425 scale = scale.max(analytic[[r, c]].abs()).max(fd[[r, c]].abs());
3426 }
3427 }
3428 for r in 0..p {
3429 for c in 0..p {
3430 let rel = (analytic[[r, c]] - fd[[r, c]]).abs() / scale;
3431 worst = worst.max(rel);
3432 assert!(
3433 rel < 5e-3,
3434 "probit D H_φ[u] mismatch at ({r},{c}) for u={u:?}: analytic={}, fd={}, rel={:e}",
3435 analytic[[r, c]],
3436 fd[[r, c]],
3437 rel
3438 );
3439 }
3440 }
3441 }
3442 assert!(
3443 worst < 5e-3,
3444 "probit Hessian-derivative worst relative error {worst:e} exceeds 5e-3"
3445 );
3446 }
3447
3448 #[test]
3449 pub(crate) fn link_general_probit_jeffreys_finite_on_rank_deficient_design() {
3450 let x_full = array![
3454 [1.0, -1.20, -0.20],
3455 [1.0, -0.40, 0.60],
3456 [1.0, 0.10, 1.10],
3457 [1.0, 0.70, 1.70],
3458 [1.0, 1.30, 2.30],
3459 ];
3460 let x_reduced = array![
3461 [1.0, -1.20],
3462 [1.0, -0.40],
3463 [1.0, 0.10],
3464 [1.0, 0.70],
3465 [1.0, 1.30],
3466 ];
3467 let beta_full = array![0.25, -0.50, 0.15];
3468 let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3469
3470 let phi_full = link_firth_phi(StandardLink::Probit, &x_full, &beta_full);
3471 let phi_reduced = link_firth_phi(StandardLink::Probit, &x_reduced, &beta_reduced);
3472 assert!(
3473 phi_full.is_finite(),
3474 "probit Φ on rank-deficient design must be finite, got {phi_full}"
3475 );
3476 assert!(
3477 (phi_full - phi_reduced).abs() < 1e-12,
3478 "probit reduced |Uᵀ W U| form mismatch: full={phi_full}, reduced={phi_reduced}"
3479 );
3480
3481 let op_full = build_link_firth_op(StandardLink::Probit, &x_full, &beta_full);
3482 let grad_full = op_full.jeffreys_beta_gradient();
3483 assert!(
3484 grad_full.iter().all(|v| v.is_finite()),
3485 "probit gradient on rank-deficient design must be finite: {grad_full:?}"
3486 );
3487 let hat_full = op_full.pirls_hat_diag();
3488 let hat_reduced =
3489 build_link_firth_op(StandardLink::Probit, &x_reduced, &beta_reduced).pirls_hat_diag();
3490 for i in 0..hat_full.len() {
3491 assert!(
3492 (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3493 "probit hat diagonal {i} mismatch on rank-deficient design: full={}, reduced={}",
3494 hat_full[i],
3495 hat_reduced[i]
3496 );
3497 }
3498 }
3499
3500 #[test]
3501 pub(crate) fn rank_deficient_and_explicit_reduced_designs_share_same_jeffreys_objective() {
3502 let x_full = array![
3506 [1.0, -1.2, -0.2],
3507 [1.0, -0.4, 0.6],
3508 [1.0, 0.1, 1.1],
3509 [1.0, 0.7, 1.7],
3510 [1.0, 1.3, 2.3],
3511 ];
3512 let x_reduced = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3513 let beta_full: ndarray::Array1<f64> = array![0.25, -0.5, 0.15];
3514 let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3515 let eta_full = x_full.dot(&beta_full);
3516 let eta_reduced = x_reduced.dot(&beta_reduced);
3517 let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3518
3519 for i in 0..eta_full.len() {
3520 assert!(
3521 (eta_full[i] - eta_reduced[i]).abs() < 1e-12,
3522 "eta mismatch at row {i}: full={} reduced={}",
3523 eta_full[i],
3524 eta_reduced[i]
3525 );
3526 }
3527
3528 let op_full = build_weighted_logit_firth_dense_operator(
3529 &x_full,
3530 &eta_full,
3531 observation_weights.view(),
3532 )
3533 .expect("full firth operator");
3534 let op_reduced = build_weighted_logit_firth_dense_operator(
3535 &x_reduced,
3536 &eta_reduced,
3537 observation_weights.view(),
3538 )
3539 .expect("reduced firth operator");
3540
3541 assert!(
3542 (op_full.jeffreys_logdet() - op_reduced.jeffreys_logdet()).abs() < 1e-12,
3543 "Jeffreys logdet mismatch between rank-deficient full design and its explicit reduced identifiable basis: full={} reduced={}",
3544 op_full.jeffreys_logdet(),
3545 op_reduced.jeffreys_logdet()
3546 );
3547
3548 let hat_full = op_full.pirls_hat_diag();
3549 let hat_reduced = op_reduced.pirls_hat_diag();
3550 for i in 0..hat_full.len() {
3551 assert!(
3552 (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3553 "PIRLS hat-diagonal mismatch at row {i}: full={} reduced={}",
3554 hat_full[i],
3555 hat_reduced[i]
3556 );
3557 }
3558 }
3559
3560 #[test]
3561 pub(crate) fn full_rank_reparameterizations_share_same_jeffreys_objective() {
3562 let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3563 let basis = array![[1.4, -0.3], [0.6, 1.1]];
3564 let x_reparameterized = x.dot(&basis);
3565 let beta = array![0.25, -0.5];
3566 let basis_det: f64 = basis[[0, 0]] * basis[[1, 1]] - basis[[0, 1]] * basis[[1, 0]];
3567 assert!(
3568 basis_det.abs() > 1e-12,
3569 "basis transform must be invertible"
3570 );
3571 let basis_inv = array![
3572 [basis[[1, 1]] / basis_det, -basis[[0, 1]] / basis_det],
3573 [-basis[[1, 0]] / basis_det, basis[[0, 0]] / basis_det],
3574 ];
3575 let beta_reparameterized = basis_inv.dot(&beta);
3576 let eta = x.dot(&beta);
3577 let eta_reparameterized = x_reparameterized.dot(&beta_reparameterized);
3578 let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3579
3580 for i in 0..eta.len() {
3581 assert!(
3582 (eta[i] - eta_reparameterized[i]).abs() < 1e-12,
3583 "eta mismatch at row {i}: original={} reparameterized={}",
3584 eta[i],
3585 eta_reparameterized[i]
3586 );
3587 }
3588
3589 let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3590 .expect("original firth operator");
3591 let op_reparameterized = build_weighted_logit_firth_dense_operator(
3592 &x_reparameterized,
3593 &eta_reparameterized,
3594 observation_weights.view(),
3595 )
3596 .expect("reparameterized firth operator");
3597
3598 assert!(
3599 (op.jeffreys_logdet() - op_reparameterized.jeffreys_logdet()).abs() < 1e-12,
3600 "Jeffreys logdet mismatch under invertible reparameterization: original={} reparameterized={}",
3601 op.jeffreys_logdet(),
3602 op_reparameterized.jeffreys_logdet()
3603 );
3604
3605 let hat = op.pirls_hat_diag();
3606 let hat_reparameterized = op_reparameterized.pirls_hat_diag();
3607 for i in 0..hat.len() {
3608 assert!(
3609 (hat[i] - hat_reparameterized[i]).abs() < 1e-12,
3610 "PIRLS hat-diagonal mismatch at row {i}: original={} reparameterized={}",
3611 hat[i],
3612 hat_reparameterized[i]
3613 );
3614 }
3615 }
3616
3617 #[test]
3618 pub(crate) fn full_rank_identifiable_basis_diagonalizes_design_metric() {
3619 let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3620 let beta = array![0.25, -0.5];
3621 let eta = x.dot(&beta);
3622 let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3623 let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3624 .expect("firth operator");
3625
3626 let reduced_metric = fast_atb(&op.x_reduced, &op.x_reduced);
3627 for i in 0..reduced_metric.nrows() {
3628 for j in 0..reduced_metric.ncols() {
3629 if i == j {
3630 continue;
3631 }
3632 assert!(
3633 reduced_metric[[i, j]].abs() < 1e-10,
3634 "full-rank identifiable basis should diagonalize X_r'X_r: metric[{i},{j}]={}",
3635 reduced_metric[[i, j]]
3636 );
3637 }
3638 }
3639 }
3640
3641 #[test]
3642 pub(crate) fn firth_mixedsecond_direction_apply_is_symmetric_in_direction_order() {
3643 let x = array![
3644 [1.0, -1.0, 0.2],
3645 [1.0, -0.6, -0.3],
3646 [1.0, -0.1, 0.5],
3647 [1.0, 0.3, -0.7],
3648 [1.0, 0.8, 0.1],
3649 [1.0, 1.2, -0.4],
3650 ];
3651 let beta = array![0.1, -0.25, 0.2];
3652 let eta = x.dot(&beta);
3653 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3654
3655 let u = array![0.3, -0.2, 0.4];
3656 let v = array![-0.5, 0.1, 0.25];
3657 let du = op.direction_from_deta(x.dot(&u));
3658 let dv = op.direction_from_deta(x.dot(&v));
3659
3660 let eye = Array2::<f64>::eye(x.ncols());
3661 let uv = op.hphisecond_direction_apply(&du, &dv, &eye);
3662 let vu = op.hphisecond_direction_apply(&dv, &du, &eye);
3663
3664 for i in 0..uv.nrows() {
3665 for j in 0..uv.ncols() {
3666 let a = uv[[i, j]];
3667 let b = vu[[i, j]];
3668 assert_eq!(
3669 a.signum(),
3670 b.signum(),
3671 "mixed direction sign mismatch at ({i},{j}): uv={a} vu={b}"
3672 );
3673 assert!(
3674 (a - b).abs() < 2e-7,
3675 "mixed direction mismatch at ({i},{j}): uv={a} vu={b}"
3676 );
3677 }
3678 }
3679 }
3680
3681 #[test]
3682 pub(crate) fn firth_direction_matrix_form_matches_apply_identity_form() {
3683 let x = array![
3684 [1.0, -1.1, 0.2],
3685 [1.0, -0.6, -0.3],
3686 [1.0, -0.1, 0.5],
3687 [1.0, 0.3, -0.7],
3688 [1.0, 0.8, 0.1],
3689 [1.0, 1.2, -0.4],
3690 ];
3691 let beta = array![0.08, -0.22, 0.27];
3692 let eta = x.dot(&beta);
3693 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3694 let u = Array1::from_vec(vec![0.25, -0.4, 0.35]);
3695 let dir = op.direction_from_deta(x.dot(&u));
3696
3697 let p = x.ncols();
3698 let eye = Array2::<f64>::eye(p);
3699 let mut via_apply = op.hphi_direction_apply(&dir, &eye);
3700 for i in 0..p {
3701 for j in 0..i {
3702 let sym = 0.5 * (via_apply[[i, j]] + via_apply[[j, i]]);
3703 via_apply[[i, j]] = sym;
3704 via_apply[[j, i]] = sym;
3705 }
3706 }
3707 let direct = op.hphi_direction(&dir);
3708 let diff = &direct - &via_apply;
3709 let err = diff.iter().map(|v| v * v).sum::<f64>().sqrt();
3710 assert!(err < 1e-10, "direction/apply mismatch: {err:e}");
3711 }
3712
3713 #[test]
3714 pub(crate) fn firthphi_tau_partial_matches_finite_difference_logdet() {
3715 let x = array![
3716 [1.0, -1.0, 0.2],
3717 [1.0, -0.6, -0.3],
3718 [1.0, -0.1, 0.5],
3719 [1.0, 0.3, -0.7],
3720 [1.0, 0.8, 0.1],
3721 [1.0, 1.2, -0.4],
3722 ];
3723 let x_tau = array![
3724 [0.0, 0.15, -0.05],
3725 [0.0, -0.10, 0.02],
3726 [0.0, 0.08, 0.04],
3727 [0.0, -0.06, -0.03],
3728 [0.0, 0.05, 0.01],
3729 [0.0, -0.12, 0.06],
3730 ];
3731 let beta = array![0.1, -0.25, 0.2];
3732 let eta = x.dot(&beta);
3733 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3734 let analytic = op.exact_tau_kernel(&x_tau, &beta, false).phi_tau_partial;
3735
3736 let h = 1e-6;
3737 let x_plus = &x + &(h * &x_tau);
3738 let x_minus = &x - &(h * &x_tau);
3739 let fd = (firthphivalue(&x_plus, &beta) - firthphivalue(&x_minus, &beta)) / (2.0 * h);
3740
3741 assert!(
3742 (analytic - fd).abs() < 1e-6,
3743 "Phi_tau mismatch: analytic={analytic:.12e}, fd={fd:.12e}"
3744 );
3745 }
3746
3747 #[test]
3748 pub(crate) fn firth_gphi_tau_matches_finite_differencegradphi() {
3749 let x = array![
3750 [1.0, -1.0, 0.2],
3751 [1.0, -0.6, -0.3],
3752 [1.0, -0.1, 0.5],
3753 [1.0, 0.3, -0.7],
3754 [1.0, 0.8, 0.1],
3755 [1.0, 1.2, -0.4],
3756 ];
3757 let x_tau = array![
3758 [0.0, 0.15, -0.05],
3759 [0.0, -0.10, 0.02],
3760 [0.0, 0.08, 0.04],
3761 [0.0, -0.06, -0.03],
3762 [0.0, 0.05, 0.01],
3763 [0.0, -0.12, 0.06],
3764 ];
3765 let beta = array![0.1, -0.25, 0.2];
3766 let eta = x.dot(&beta);
3767 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3768 let analytic = op.exact_tau_kernel(&x_tau, &beta, false).gphi_tau;
3769
3770 let h = 1e-6;
3771 let x_plus = &x + &(h * &x_tau);
3772 let x_minus = &x - &(h * &x_tau);
3773 let fd = (firthgradphi(&x_plus, &beta) - firthgradphi(&x_minus, &beta)) / (2.0 * h);
3774
3775 let err = (&analytic - &fd).iter().map(|v| v * v).sum::<f64>().sqrt();
3776 assert!(
3777 err < 1e-6,
3778 "gphi_tau mismatch: analytic={analytic:?}, fd={fd:?}, err={err:e}"
3779 );
3780 }
3781
3782 #[test]
3787 pub(crate) fn firthphi_tau_tau_pair_scalar_matches_finite_difference() {
3788 let x = array![
3789 [1.0, -1.0, 0.2],
3790 [1.0, -0.6, -0.3],
3791 [1.0, -0.1, 0.5],
3792 [1.0, 0.3, -0.7],
3793 [1.0, 0.8, 0.1],
3794 [1.0, 1.2, -0.4],
3795 ];
3796 let x_tau_i = array![
3797 [0.0, 0.15, -0.05],
3798 [0.0, -0.10, 0.02],
3799 [0.0, 0.08, 0.04],
3800 [0.0, -0.06, -0.03],
3801 [0.0, 0.05, 0.01],
3802 [0.0, -0.12, 0.06],
3803 ];
3804 let x_tau_j = array![
3805 [0.0, -0.04, 0.11],
3806 [0.0, 0.09, -0.02],
3807 [0.0, -0.06, 0.07],
3808 [0.0, 0.10, -0.05],
3809 [0.0, -0.03, 0.08],
3810 [0.0, 0.07, -0.09],
3811 ];
3812 let beta = array![0.1, -0.25, 0.2];
3813 let eta = x.dot(&beta);
3814 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3815
3816 let analytic = op
3817 .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3818 .phi_tau_tau_partial;
3819
3820 let h = 1e-5_f64;
3821 let eval_phi_tau_i = |x_eval: &Array2<f64>| -> f64 {
3822 let eta_e = x_eval.dot(&beta);
3823 let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3824 op_e.exact_tau_kernel(&x_tau_i, &beta, false)
3825 .phi_tau_partial
3826 };
3827 let x_plus = &x + &(h * &x_tau_j);
3828 let x_minus = &x - &(h * &x_tau_j);
3829 let fd = (eval_phi_tau_i(&x_plus) - eval_phi_tau_i(&x_minus)) / (2.0 * h);
3830
3831 let rel = (analytic - fd).abs() / fd.abs().max(1.0);
3832 assert!(
3833 rel < 1e-7,
3834 "pair.a scalar mismatch: analytic={analytic:.6e}, fd={fd:.6e}, rel={rel:.3e}"
3835 );
3836 }
3837
3838 #[test]
3843 pub(crate) fn firthphi_tau_tau_pair_g_vector_matches_finite_difference() {
3844 let x = array![
3845 [1.0, -1.0, 0.2],
3846 [1.0, -0.6, -0.3],
3847 [1.0, -0.1, 0.5],
3848 [1.0, 0.3, -0.7],
3849 [1.0, 0.8, 0.1],
3850 [1.0, 1.2, -0.4],
3851 ];
3852 let x_tau_i = array![
3853 [0.0, 0.15, -0.05],
3854 [0.0, -0.10, 0.02],
3855 [0.0, 0.08, 0.04],
3856 [0.0, -0.06, -0.03],
3857 [0.0, 0.05, 0.01],
3858 [0.0, -0.12, 0.06],
3859 ];
3860 let x_tau_j = array![
3861 [0.0, -0.04, 0.11],
3862 [0.0, 0.09, -0.02],
3863 [0.0, -0.06, 0.07],
3864 [0.0, 0.10, -0.05],
3865 [0.0, -0.03, 0.08],
3866 [0.0, 0.07, -0.09],
3867 ];
3868 let beta = array![0.1, -0.25, 0.2];
3869 let eta = x.dot(&beta);
3870 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3871
3872 let analytic = op
3873 .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3874 .gphi_tau_tau;
3875
3876 let h = 1e-5_f64;
3877 let eval_gphi_tau_i = |x_eval: &Array2<f64>| -> Array1<f64> {
3878 let eta_e = x_eval.dot(&beta);
3879 let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3880 op_e.exact_tau_kernel(&x_tau_i, &beta, false).gphi_tau
3881 };
3882 let x_plus = &x + &(h * &x_tau_j);
3883 let x_minus = &x - &(h * &x_tau_j);
3884 let fd = (&eval_gphi_tau_i(&x_plus) - &eval_gphi_tau_i(&x_minus)) / (2.0 * h);
3885
3886 let scale = analytic
3887 .iter()
3888 .chain(fd.iter())
3889 .map(|v| v.abs())
3890 .fold(0.0_f64, f64::max)
3891 .max(1.0);
3892 let err_max = (&analytic - &fd)
3893 .iter()
3894 .map(|v| v.abs())
3895 .fold(0.0_f64, f64::max);
3896 let rel = err_max / scale;
3897 assert!(
3898 rel < 1e-7,
3899 "pair.g p-vector mismatch: rel={rel:.3e}\nanalytic={analytic:?}\nfd={fd:?}"
3900 );
3901 }
3902
3903 #[test]
3923 pub(crate) fn firthphi_tau_tau_partial_matches_finite_difference() {
3924 let x = array![
3925 [1.0, -1.0, 0.2],
3926 [1.0, -0.6, -0.3],
3927 [1.0, -0.1, 0.5],
3928 [1.0, 0.3, -0.7],
3929 [1.0, 0.8, 0.1],
3930 [1.0, 1.2, -0.4],
3931 ];
3932 let x_tau_i = array![
3933 [0.0, 0.15, -0.05],
3934 [0.0, -0.10, 0.02],
3935 [0.0, 0.08, 0.04],
3936 [0.0, -0.06, -0.03],
3937 [0.0, 0.05, 0.01],
3938 [0.0, -0.12, 0.06],
3939 ];
3940 let x_tau_j = array![
3941 [0.0, -0.04, 0.11],
3942 [0.0, 0.09, -0.02],
3943 [0.0, -0.06, 0.07],
3944 [0.0, 0.10, -0.05],
3945 [0.0, -0.03, 0.08],
3946 [0.0, 0.07, -0.09],
3947 ];
3948 let beta = array![0.1, -0.25, 0.2];
3949 let eta = x.dot(&beta);
3950 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3951 let p = x.ncols();
3952
3953 let m = 3usize;
3955 let mut rhs = Array2::<f64>::zeros((p, m));
3956 let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
3957 for r in 0..p {
3958 for c in 0..m {
3959 rhs[[r, c]] = vals[(r * m + c) % vals.len()];
3960 }
3961 }
3962
3963 let x_tau_i_reduced = op.reduce_explicit_design(&x_tau_i);
3966 let x_tau_j_reduced = op.reduce_explicit_design(&x_tau_j);
3967 let deta_i = x_tau_i.dot(&beta);
3968 let deta_j = x_tau_j.dot(&beta);
3969 let (dot_i_i, dot_h_i) = op.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
3970 let (dot_i_j, dot_h_j) = op.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
3971
3972 let kernel_ij = op.hphi_tau_tau_partial_prepare_from_partials(
3973 x_tau_i_reduced.clone(),
3974 x_tau_j_reduced.clone(),
3975 &deta_i,
3976 &deta_j,
3977 dot_h_i.clone(),
3978 dot_h_j.clone(),
3979 dot_i_i.clone(),
3980 dot_i_j.clone(),
3981 None,
3982 None,
3983 );
3984 let kernel_ji = op.hphi_tau_tau_partial_prepare_from_partials(
3985 x_tau_j_reduced,
3986 x_tau_i_reduced,
3987 &deta_j,
3988 &deta_i,
3989 dot_h_j,
3990 dot_h_i,
3991 dot_i_j,
3992 dot_i_i,
3993 None,
3994 None,
3995 );
3996 let analytic_ij = op.hphi_tau_tau_partial_apply(&x_tau_i, &x_tau_j, &kernel_ij, &rhs);
3997 let analytic_ji = op.hphi_tau_tau_partial_apply(&x_tau_j, &x_tau_i, &kernel_ji, &rhs);
3998
3999 let sym_diff: f64 = (&analytic_ij - &analytic_ji)
4001 .iter()
4002 .map(|v| v.abs())
4003 .fold(0.0_f64, f64::max);
4004 let sym_scale: f64 = analytic_ij
4005 .iter()
4006 .chain(analytic_ji.iter())
4007 .map(|v| v.abs())
4008 .fold(0.0_f64, f64::max)
4009 .max(1.0);
4010 assert!(
4011 sym_diff / sym_scale < 1e-10,
4012 "τ×τ primitive not symmetric in direction order: sym_diff={sym_diff:.3e}"
4013 );
4014
4015 let h = 1e-5_f64;
4018 let fd_block = |x_eval: &Array2<f64>| -> Array2<f64> {
4019 let eta_e = x_eval.dot(&beta);
4020 let op_e =
4021 build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
4022 let x_tau_i_r = op_e.reduce_explicit_design(&x_tau_i);
4023 let deta_i_e = x_tau_i.dot(&beta);
4024 let (dot_i_i_e, dot_h_i_e) = op_e.dot_i_and_h_from_reduced(&x_tau_i_r, &deta_i_e);
4025 let kernel_i_e = op_e
4026 .hphi_tau_partial_prepare_from_partials(x_tau_i_r, &deta_i_e, dot_h_i_e, dot_i_i_e);
4027 op_e.hphi_tau_partial_apply(&x_tau_i, &kernel_i_e, &rhs)
4028 };
4029 let x_plus = &x + &(h * &x_tau_j);
4030 let x_minus = &x - &(h * &x_tau_j);
4031 let fd_ij = (&fd_block(&x_plus) - &fd_block(&x_minus)) / (2.0 * h);
4032
4033 let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
4036 let scale = a
4037 .iter()
4038 .chain(b.iter())
4039 .map(|v| v.abs())
4040 .fold(0.0_f64, f64::max)
4041 .max(1.0);
4042 let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4043 max_diff / scale
4044 };
4045 let err_ij = rel_max_abs_diff(&analytic_ij, &fd_ij);
4046
4047 let fd_block_j = |x_eval: &Array2<f64>| -> Array2<f64> {
4050 let eta_e = x_eval.dot(&beta);
4051 let op_e =
4052 build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
4053 let x_tau_j_r = op_e.reduce_explicit_design(&x_tau_j);
4054 let deta_j_e = x_tau_j.dot(&beta);
4055 let (dot_i_j_e, dot_h_j_e) = op_e.dot_i_and_h_from_reduced(&x_tau_j_r, &deta_j_e);
4056 let kernel_j_e = op_e
4057 .hphi_tau_partial_prepare_from_partials(x_tau_j_r, &deta_j_e, dot_h_j_e, dot_i_j_e);
4058 op_e.hphi_tau_partial_apply(&x_tau_j, &kernel_j_e, &rhs)
4059 };
4060 let x_plus_i = &x + &(h * &x_tau_i);
4061 let x_minus_i = &x - &(h * &x_tau_i);
4062 let fd_ji = (&fd_block_j(&x_plus_i) - &fd_block_j(&x_minus_i)) / (2.0 * h);
4063 let err_ji = rel_max_abs_diff(&analytic_ji, &fd_ji);
4064
4065 let tol = 1e-7_f64;
4066 assert!(
4067 err_ij < tol,
4068 "∂²H_φ/∂τ_i∂τ_j apply mismatch (i,j): rel_max_abs_diff={err_ij:.3e} > {tol:.1e}\n\
4069 analytic=\n{analytic_ij:?}\n\
4070 fd=\n{fd_ij:?}"
4071 );
4072 assert!(
4073 err_ji < tol,
4074 "∂²H_φ/∂τ_j∂τ_i apply mismatch (j,i): rel_max_abs_diff={err_ji:.3e} > {tol:.1e}\n\
4075 analytic=\n{analytic_ji:?}\n\
4076 fd=\n{fd_ji:?}"
4077 );
4078 }
4079
/// Checks the analytic directional β-derivative of the τ-partial Firth Hessian
/// apply against a central finite difference.
///
/// Analytic side: the mixed kernel for direction `v` is prepared with
/// `d_beta_hphi_tau_partial_prepare_from_partials` and applied to a fixed
/// right-hand side via `d_beta_hphi_tau_partial_apply`.
///
/// Finite-difference side: the dense logit Firth operator is rebuilt at
/// `β ± h·v`. The designs `x` and `x_tau` are held fixed while this happens.
/// `hphi_tau_partial_apply` is then differenced on the same right-hand side.
#[test]
pub(crate) fn firth_d_beta_hphi_tau_partial_matches_finite_difference() {
    // Design with an intercept column and two covariates.
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    // τ-derivative of the design; the intercept column does not move with τ.
    let x_tau = array![
        [0.0, 0.15, -0.05],
        [0.0, -0.10, 0.02],
        [0.0, 0.08, 0.04],
        [0.0, -0.06, -0.03],
        [0.0, 0.05, 0.01],
        [0.0, -0.12, 0.06],
    ];
    let beta = array![0.1, -0.25, 0.2];
    // β-direction along which the derivative is taken.
    let v = array![0.3, 0.2, -0.15];

    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
    let p = x.ncols();

    // Deterministic dense p×m right-hand side, cycling through `vals` in row-major order.
    let m = 3usize;
    let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
    let rhs = Array2::<f64>::from_shape_fn((p, m), |(r, c)| vals[(r * m + c) % vals.len()]);

    // τ-partial kernel at the base β.
    let x_tau_reduced = op.reduce_explicit_design(&x_tau);
    let deta_partial = x_tau.dot(&beta);
    let (dot_i_partial, dot_h_partial) =
        op.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
    // `x_tau_reduced` and `dot_h_partial` are not used again, so they are moved in.
    // `dot_i_partial` is still needed by the mixed kernel below, so it is cloned.
    let tau_kernel = op.hphi_tau_partial_prepare_from_partials(
        x_tau_reduced,
        &deta_partial,
        dot_h_partial,
        dot_i_partial.clone(),
    );

    // Mixed (β, τ) kernel for direction v: the direction is built from Xv,
    // and the τ-partial of the predictor change is X_τ v.
    let deta_v = x.dot(&v);
    let direction = op.direction_from_deta(deta_v);
    let x_tau_v = x_tau.dot(&v);
    let pair_kernel = op.d_beta_hphi_tau_partial_prepare_from_partials(
        &tau_kernel,
        &deta_partial,
        &dot_i_partial,
        &direction,
        &x_tau_v,
    );
    let analytic = op.d_beta_hphi_tau_partial_apply(&x_tau, &pair_kernel, &rhs);

    let h = 1e-5_f64;
    // Single-τ apply with the operator rebuilt from scratch at `beta_eval`.
    let single_tau_apply = |beta_eval: &Array1<f64>| -> Array2<f64> {
        let eta_e = x.dot(beta_eval);
        let op_e =
            build_logit_firth_dense_operator(&x, &eta_e).expect("perturbed firth operator");
        let x_tau_r = op_e.reduce_explicit_design(&x_tau);
        let deta_e = x_tau.dot(beta_eval);
        let (dot_i_e, dot_h_e) = op_e.dot_i_and_h_from_reduced(&x_tau_r, &deta_e);
        let ker_e =
            op_e.hphi_tau_partial_prepare_from_partials(x_tau_r, &deta_e, dot_h_e, dot_i_e);
        op_e.hphi_tau_partial_apply(&x_tau, &ker_e, &rhs)
    };
    let beta_plus = &beta + &(h * &v);
    let beta_minus = &beta - &(h * &v);
    let fd = (&single_tau_apply(&beta_plus) - &single_tau_apply(&beta_minus)) / (2.0 * h);

    // Max absolute entrywise difference, divided by the largest entry magnitude
    // in either matrix. The divisor is floored at 1, so near-zero blocks are
    // compared absolutely.
    let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
        let scale = a
            .iter()
            .chain(b.iter())
            .map(|e| e.abs())
            .fold(0.0_f64, f64::max)
            .max(1.0);
        let max_diff = (a - b).iter().map(|e| e.abs()).fold(0.0_f64, f64::max);
        max_diff / scale
    };
    let err = rel_max_abs_diff(&analytic, &fd);

    let tol = 1e-7_f64;
    assert!(
        err < tol,
        "D_β (H_φ)_τ|_β apply mismatch: rel_max_abs_diff={err:.3e} > {tol:.1e}\n\
         analytic=\n{analytic:?}\n\
         fd=\n{fd:?}"
    );
}
4195
/// Regression check on the far positive tail of the logistic Fisher weight.
///
/// At `η = 50`, `logisticweight` must agree with the overflow-free form
/// `z / (1 + z)²`, where `z = e^{-η}`. The reference value is asserted to be
/// positive. The absolute tolerance is far below that value, so a result that
/// collapsed to zero would fail.
#[test]
pub(crate) fn logisticweight_loses_positive_tail_mass() {
    let eta = 50.0_f64;

    // Reference weight from the numerically stable tail formula.
    let tail = (-eta).exp();
    let stable = tail / (1.0_f64 + tail).powi(2);
    assert!(stable > 0.0);

    let got = logisticweight(eta);
    let abs_err = (got - stable).abs();
    assert!(
        abs_err < 1e-30,
        "Firth logisticweight should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
        got,
        stable
    );
}
4210
/// The logit Fisher-weight jet must reproduce the logit inverse-link jet.
///
/// For the canonical logit link the Fisher weight is `dμ/dη`. The tuple
/// `(W, W', W'', W''', W'''')` from `fisher_weight_jet5` must therefore equal
/// the inverse-link derivatives `(d1, d2, d3, d4, d5)`.
///
/// Identity is checked on IEEE-754 bit patterns rather than with `==`. `==`
/// treats `-0.0` and `+0.0` as equal, so it cannot enforce the byte-identical
/// property this test is named for.
#[test]
pub(crate) fn fisher_weight_jet5_logit_is_byte_identical_to_inverse_link_jet() {
    let etas = [
        -40.0_f64, -8.0, -3.0, -1.0, -0.25, 0.0, 0.25, 1.0, 3.0, 8.0, 40.0,
    ];
    for &eta in etas.iter() {
        let jet = logit_inverse_link_jet5(eta);
        let (w, w1, w2, w3, w4) =
            crate::mixture_link::fisher_weight_jet5(StandardLink::Logit, eta);
        let got = [w, w1, w2, w3, w4];
        let want = [jet.d1, jet.d2, jet.d3, jet.d4, jet.d5];
        let byte_identical = got
            .iter()
            .zip(want.iter())
            .all(|(g, e)| g.to_bits() == e.to_bits());
        assert!(
            byte_identical,
            "logit Fisher-weight jet must equal inverse-link jet derivatives at eta={eta}: \
             got ({w}, {w1}, {w2}, {w3}, {w4}) vs ({}, {}, {}, {}, {})",
            jet.d1,
            jet.d2,
            jet.d3,
            jet.d4,
            jet.d5
        );
    }
}
4234
/// Validates the probit Fisher-weight jet `(W, W', W'', W''', W'''')`.
///
/// - `W` is compared with the closed form `φ(η)² / (Φ(η) Φ(−η))`.
/// - `W'` and `W''` are compared with central differences of that closed form.
/// - `W'''` and `W''''` are compared with central differences of the jet's own
///   `W''` and `W'''`. This is a self-consistency check: differencing the
///   closed form three or four times would amplify rounding error too much.
#[test]
pub(crate) fn fisher_weight_jet5_probit_matches_finite_difference() {
    /// Closed-form probit Fisher weight; zero where either tail probability underflows.
    fn reference_probit_weight(eta: f64) -> f64 {
        let p = gam_math::probability::normal_cdf(eta);
        // Φ(−η) instead of 1 − Φ(η): avoids catastrophic cancellation for positive η.
        let q = gam_math::probability::normal_cdf(-eta);
        let phi = gam_math::probability::normal_pdf(eta);
        if p <= 0.0 || q <= 0.0 {
            return 0.0;
        }
        phi * phi / (p * q)
    }
    let jet = |eta: f64| crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
    let h = 1e-4_f64;
    for &eta in &[-3.0, -1.5, -0.5, 0.0, 0.3, 1.5, 3.0] {
        let (w, w1, w2, w3, w4) = jet(eta);
        let ref_w = reference_probit_weight(eta);
        let fd1 =
            (reference_probit_weight(eta + h) - reference_probit_weight(eta - h)) / (2.0 * h);
        let fd2 = (reference_probit_weight(eta + h) - 2.0 * reference_probit_weight(eta)
            + reference_probit_weight(eta - h))
            / (h * h);
        assert!(
            (w - ref_w).abs() < 1e-10,
            "probit W mismatch at eta={eta}: jet {w} vs ref {ref_w}"
        );
        assert!(
            (w1 - fd1).abs() < 1e-5,
            "probit W' mismatch at eta={eta}: jet {w1} vs fd {fd1}"
        );
        assert!(
            (w2 - fd2).abs() < 1e-3,
            "probit W'' mismatch at eta={eta}: jet {w2} vs fd {fd2}"
        );

        // Higher orders: first differences of analytic jet values. Truncation
        // error is O(h²) and rounding error is O(ε/h), so a tight tolerance
        // (scaled by the derivative's own size) is safe.
        let (_, _, w2_plus, w3_plus, _) = jet(eta + h);
        let (_, _, w2_minus, w3_minus, _) = jet(eta - h);
        let fd3 = (w2_plus - w2_minus) / (2.0 * h);
        let fd4 = (w3_plus - w3_minus) / (2.0 * h);
        assert!(
            (w3 - fd3).abs() < 1e-5 * (1.0 + w3.abs()),
            "probit W''' mismatch at eta={eta}: jet {w3} vs fd {fd3}"
        );
        assert!(
            (w4 - fd4).abs() < 1e-5 * (1.0 + w4.abs()),
            "probit W'''' mismatch at eta={eta}: jet {w4} vs fd {fd4}"
        );
    }
}
4273
/// Probit Fisher-weight jet behaviour far out in the tails.
///
/// - At `|η| ∈ {40, 80}` every component of `(W, W', W'', W''', W'''')` must be
///   exactly zero.
/// - At `|η| = 12` the weight must still be strictly positive, and every
///   component must be finite.
#[test]
pub(crate) fn fisher_weight_jet5_probit_saturates_to_zero_in_tails() {
    let saturated_etas = [40.0_f64, -40.0, 80.0, -80.0];
    for &eta in saturated_etas.iter() {
        let (w, w1, w2, w3, w4) =
            crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
        let all_zero = [w, w1, w2, w3, w4].iter().all(|&d| d == 0.0);
        assert!(
            all_zero,
            "probit Fisher weight jet must saturate to zero at eta={eta}; got \
             ({w}, {w1}, {w2}, {w3}, {w4})"
        );
    }

    let near_tail_etas = [12.0_f64, -12.0];
    for &eta in near_tail_etas.iter() {
        let (w, w1, w2, w3, w4) =
            crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
        let all_finite = [w, w1, w2, w3, w4].iter().all(|d| d.is_finite());
        assert!(
            w > 0.0 && all_finite,
            "probit Fisher weight jet must be tiny-positive and finite at eta={eta}; got \
             ({w}, {w1}, {w2}, {w3}, {w4})"
        );
    }
}
4307}