1use super::*;
2use gam_linalg::matrix::symmetrize_in_place;
3use crate::mixture_link::fisher_weight_jet5_for_inverse_link;
4use gam_problem::InverseLink;
5
/// Row count (2^14 = 16 384) from which
/// `RemlState::parallelize_firth_derivative_rows` allows the per-row
/// Fisher-weight derivative fill to run on the rayon pool.
pub(crate) const FIRTH_DERIVATIVE_PARALLEL_MIN_N: usize = 1 << 14;

/// Relative conditioning floor for the reduced Fisher matrix: when the ratio
/// of its smallest positive eigenvalue to its largest eigenvalue (clamped to
/// at least 1) drops below this value, a near-singularity warning is logged.
pub(crate) const FIRTH_REDUCED_FISHER_RCOND_WARN: f64 = 1.0e-10;
14
15struct FirthReducedCore {
21 w: Array1<f64>,
22 w1: Array1<f64>,
23 w2: Array1<f64>,
24 w3: Array1<f64>,
25 w4: Array1<f64>,
26 k_reduced: Array2<f64>,
27 half_log_det: f64,
28 h_diag: Array1<f64>,
29}
30
31pub(crate) struct FirthSecondDirEyeCache {
36 eye: Array2<f64>,
39 eta_rhs: Array2<f64>,
41 p_b_rhs: Array2<f64>,
43 p_bx: Vec<Array2<f64>>,
45 pu_qv: Vec<Array2<f64>>,
47}
48
49impl<'a> RemlState<'a> {
50 pub(crate) fn xt_diag_x_dense_into(
51 x: &Array2<f64>,
52 diag: &Array1<f64>,
53 weighted: &mut Array2<f64>,
54 ) -> Array2<f64> {
55 super::assembly::xt_diag_x_dense_into(x, diag, weighted)
56 }
57
58 #[inline]
59 pub(crate) fn parallelize_firth_derivative_rows(n: usize) -> bool {
60 n >= FIRTH_DERIVATIVE_PARALLEL_MIN_N && rayon::current_num_threads() > 1
61 }
62
63 pub(crate) fn row_scale(x: &Array2<f64>, scale: &Array1<f64>) -> Array2<f64> {
64 let mut out = Array2::<f64>::zeros(x.raw_dim());
65 super::assembly::row_scale_dense_into(x, scale, &mut out);
66 out
67 }
68
69 #[inline]
70 pub(crate) fn dense_product_likely_uses_inner_parallelism(
71 m: usize,
72 n: usize,
73 k: usize,
74 ) -> bool {
75 const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
80 const PAR_MIN_LONG_DIM: usize = 256;
81 let flop_scale = m.saturating_mul(n).saturating_mul(k);
82 let long_dim = m.max(n).max(k);
83 flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM
84 }
85
86 #[inline]
87 pub(crate) fn should_join_independent_dense_products(
88 products: &[(usize, usize, usize)],
89 ) -> bool {
90 const JOIN_MIN_TOTAL_FLOP_SCALE: usize = 128 * 1024;
91 if rayon::current_num_threads() <= 1 {
92 return false;
93 }
94 let mut total_flop_scale = 0usize;
95 for &(m, n, k) in products {
96 if Self::dense_product_likely_uses_inner_parallelism(m, n, k) {
97 return false;
98 }
99 total_flop_scale =
100 total_flop_scale.saturating_add(m.saturating_mul(n).saturating_mul(k));
101 }
102 total_flop_scale >= JOIN_MIN_TOTAL_FLOP_SCALE
103 }
104
105 #[inline]
114 pub(crate) fn scale_rows_by_inverse_observation_weight_sqrt(
115 out: &mut Array2<f64>,
116 observation_weight_sqrt: Option<&Array1<f64>>,
117 ) {
118 let Some(scale) = observation_weight_sqrt else {
119 return;
120 };
121 super::assembly::row_scale_dense_in_place_by_inverse_positive_or_zero(out, scale);
122 }
123
124 #[inline]
129 pub(crate) fn fisher_weight_derivatives(
130 link: &InverseLink,
131 eta: f64,
132 ) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
133 fisher_weight_jet5_for_inverse_link(link, eta)
134 }
135
136 #[inline]
137 pub(crate) fn cholesky_pivots_are_numerically_resolved(chol_diag: &Array1<f64>) -> bool {
138 let mut min_pivot_sq = f64::INFINITY;
139 let mut max_pivot_sq = 0.0_f64;
140 for &pivot in chol_diag {
141 if !pivot.is_finite() || pivot <= 0.0 {
142 return false;
143 }
144 let pivot_sq = pivot * pivot;
145 min_pivot_sq = min_pivot_sq.min(pivot_sq);
146 max_pivot_sq = max_pivot_sq.max(pivot_sq);
147 }
148 if !min_pivot_sq.is_finite() {
149 return false;
150 }
151 let scale = max_pivot_sq.max(1.0);
152 let floor = (chol_diag.len().max(1) as f64) * f64::EPSILON * scale;
153 min_pivot_sq > floor
154 }
155
156 pub(crate) fn reduced_fisher_inverse_and_half_logdet(
157 fisher_reduced: &Array2<f64>,
158 ) -> Result<(Array2<f64>, f64), EstimationError> {
159 let r = fisher_reduced.nrows();
160 assert_eq!(r, fisher_reduced.ncols());
161 let mut k_reduced = Array2::<f64>::zeros((r, r));
162 if r == 0 {
163 return Ok((k_reduced, 0.0));
164 }
165
166 if let Ok(chol) = fisher_reduced.cholesky(Side::Lower) {
167 let chol_diag = chol.diag();
168 if Self::cholesky_pivots_are_numerically_resolved(&chol_diag) {
169 let half_log_det = chol_diag.iter().map(|d| d.ln()).sum::<f64>();
170 for col in 0..r {
171 let mut e_col = Array1::<f64>::zeros(r);
172 e_col[col] = 1.0;
173 let solved = chol.solvevec(&e_col);
174 k_reduced.column_mut(col).assign(&solved);
175 }
176 return Ok((k_reduced, half_log_det));
177 }
178 }
179
180 let (evals_ir, evecs_ir) = fisher_reduced
181 .eigh(Side::Lower)
182 .map_err(EstimationError::EigendecompositionFailed)?;
183 let max_eval = evals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
184 let tol = (r.max(1) as f64) * f64::EPSILON * max_eval;
185 let mut kept_positive_direction = false;
186 let mut half_log_det = 0.0_f64;
187 for (eig_idx, &eig) in evals_ir.iter().enumerate() {
188 if eig > tol {
189 kept_positive_direction = true;
190 half_log_det += 0.5 * eig.ln();
191 let inv = eig.recip();
192 let vec = evecs_ir.column(eig_idx).to_owned();
193 for row in 0..r {
194 for col in 0..r {
195 k_reduced[[row, col]] += inv * vec[row] * vec[col];
196 }
197 }
198 }
199 }
200 if !kept_positive_direction {
201 return Err(EstimationError::ModelIsIllConditioned {
202 condition_number: f64::INFINITY,
203 });
204 }
205 Ok((k_reduced, half_log_det))
206 }
207
208 pub(crate) fn fill_fisher_weight_derivative_arrays(
209 link: &InverseLink,
210 eta: &Array1<f64>,
211 w: &mut Array1<f64>,
212 w1: &mut Array1<f64>,
213 w2: &mut Array1<f64>,
214 w3: &mut Array1<f64>,
215 w4: &mut Array1<f64>,
216 ) -> Result<(), EstimationError> {
217 assert_eq!(eta.len(), w.len());
218 assert_eq!(eta.len(), w1.len());
219 assert_eq!(eta.len(), w2.len());
220 assert_eq!(eta.len(), w3.len());
221 assert_eq!(eta.len(), w4.len());
222
223 if Self::parallelize_firth_derivative_rows(eta.len()) {
224 let values: Result<Vec<_>, EstimationError> = eta
225 .par_iter()
226 .map(|&ei| Self::fisher_weight_derivatives(link, ei))
227 .collect();
228 for (i, (value, first, second, third, fourth)) in values?.into_iter().enumerate() {
229 w[i] = value;
230 w1[i] = first;
231 w2[i] = second;
232 w3[i] = third;
233 w4[i] = fourth;
234 }
235 return Ok(());
236 }
237 for i in 0..eta.len() {
238 let (value, first, second, third, fourth) =
239 Self::fisher_weight_derivatives(link, eta[i])?;
240 w[i] = value;
241 w1[i] = first;
242 w2[i] = second;
243 w3[i] = third;
244 w4[i] = fourth;
245 }
246 Ok(())
247 }
248
249 pub(crate) fn weighted_cross(
250 left: &Array2<f64>,
251 right: &Array2<f64>,
252 weights: &Array1<f64>,
253 ) -> Array2<f64> {
254 assert_eq!(left.nrows(), right.nrows());
255 assert_eq!(left.nrows(), weights.len());
256 super::assembly::weighted_cross_dense(left, right, weights)
257 }
258
259 pub(crate) fn trace_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
260 assert_eq!(a.nrows(), b.ncols());
261 assert_eq!(a.ncols(), b.nrows());
262 let elems = a.nrows().saturating_mul(a.ncols());
263 if elems >= 32 * 32 {
264 let aview = FaerArrayView::new(a);
265 let bview = FaerArrayView::new(b);
266 return faer_frob_inner(aview.as_ref(), bview.as_ref().transpose());
267 }
268 let m = a.nrows();
269 let n = a.ncols();
270 kahan_sum((0..m).map(|i| {
271 let mut acc = 0.0_f64;
272 for j in 0..n {
273 acc += a[[i, j]] * b[[j, i]];
274 }
275 acc
276 }))
277 }
278
279 pub(crate) fn reducedweighted_gram(z: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
280 let weighted = Self::row_scale(z, weights);
287 fast_atb(z, &weighted)
288 }
289
290 pub(crate) fn reduced_crossweighted_gram(
291 z_left: &Array2<f64>,
292 z_right: &Array2<f64>,
293 weights: &Array1<f64>,
294 ) -> Array2<f64> {
295 let weighted = Self::row_scale(z_right, weights);
300 fast_atb(z_left, &weighted)
301 }
302
303 pub(crate) fn reduced_diag_gram(z: &Array2<f64>, a: &Array2<f64>) -> Array1<f64> {
304 let za = fast_ab(z, a);
310 (z * &za).sum_axis(ndarray::Axis(1))
311 }
312
313 pub(crate) fn apply_hadamard_gram(
314 z: &Array2<f64>,
315 a_left: &Array2<f64>,
316 a_right: &Array2<f64>,
317 vec: &Array1<f64>,
318 ) -> Array1<f64> {
319 let s = Self::reducedweighted_gram(z, vec);
328 let left_s = a_left.dot(&s);
329 let t = left_s.dot(a_right);
330 Self::reduced_diag_gram(z, &t)
331 }
332
333 pub(crate) fn apply_hadamard_gram_to_matrix(
334 z: &Array2<f64>,
335 a_left: &Array2<f64>,
336 a_right: &Array2<f64>,
337 mat: &Array2<f64>,
338 ) -> Array2<f64> {
339 let mut out = Array2::<f64>::zeros(mat.raw_dim());
347 for col in 0..mat.ncols() {
348 let v = mat.column(col).to_owned();
349 let y = Self::apply_hadamard_gram(z, a_left, a_right, &v);
350 out.column_mut(col).assign(&y);
351 }
352 out
353 }
354
355 pub(super) fn build_firth_dense_operator_for_link(
360 link: &InverseLink,
361 x_dense: &Array2<f64>,
362 eta: &Array1<f64>,
363 observation_weights: ndarray::ArrayView1<'_, f64>,
364 ) -> Result<FirthDenseOperator, EstimationError> {
365 FirthDenseOperator::build_with_observation_weights_impl(
366 link,
367 x_dense,
368 eta,
369 Some(observation_weights),
370 )
371 }
372
373 pub(super) fn firth_exact_tau_kernel(
374 op: &FirthDenseOperator,
375 x_tau: &Array2<f64>,
376 beta: &Array1<f64>,
377 include_hphi_tau_kernel: bool,
378 ) -> FirthTauExactKernel {
379 op.exact_tau_kernel(x_tau, beta, include_hphi_tau_kernel)
380 }
381
382 pub(super) fn firth_hphi_tau_partial_apply(
383 op: &FirthDenseOperator,
384 x_tau: &Array2<f64>,
385 kernel: &FirthTauPartialKernel,
386 rhs: &Array2<f64>,
387 ) -> Array2<f64> {
388 op.hphi_tau_partial_apply(x_tau, kernel, rhs)
389 }
390}
391
392impl FirthDenseOperator {
393 pub(crate) fn canonicalize_basis_column_signs(q_basis: &mut Array2<f64>) {
394 for col in 0..q_basis.ncols() {
395 let mut pivot_row = 0usize;
396 let mut pivot_abs = 0.0_f64;
397 for row in 0..q_basis.nrows() {
398 let value = q_basis[[row, col]];
399 let abs_value = value.abs();
400 if abs_value > pivot_abs {
401 pivot_abs = abs_value;
402 pivot_row = row;
403 }
404 }
405 if pivot_abs > 0.0 && q_basis[[pivot_row, col]] < 0.0 {
406 q_basis.column_mut(col).mapv_inplace(|v| -v);
407 }
408 }
409 }
410
411 pub(crate) fn identifiable_subspace_basis_from_gram(
412 gram: &Array2<f64>,
413 ) -> Result<(Array2<f64>, Array1<f64>), EstimationError> {
414 let p = gram.nrows();
415 assert_eq!(p, gram.ncols());
416 if p == 0 {
417 return Ok((Array2::<f64>::eye(0), Array1::<f64>::zeros(0)));
418 }
419
420 let (evals, evecs) = gram
421 .eigh(Side::Lower)
422 .map_err(EstimationError::EigendecompositionFailed)?;
423 let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
424 let tol = (p.max(1) as f64) * f64::EPSILON * max_eval;
425 let mut keep: Vec<usize> = evals
426 .iter()
427 .enumerate()
428 .filter_map(|(i, &value)| if value > tol { Some(i) } else { None })
429 .collect();
430 if keep.is_empty() {
431 return Err(EstimationError::ModelIsIllConditioned {
432 condition_number: f64::INFINITY,
433 });
434 }
435
436 keep.sort_by(|&lhs, &rhs| evals[rhs].total_cmp(&evals[lhs]));
442 let r = keep.len();
443 let mut q_basis = Array2::<f64>::zeros((p, r));
444 let mut metric_spectrum = Array1::<f64>::zeros(r);
445 for (col_idx, eig_idx) in keep.into_iter().enumerate() {
446 q_basis.column_mut(col_idx).assign(&evecs.column(eig_idx));
447 metric_spectrum[col_idx] = evals[eig_idx];
448 }
449 Self::canonicalize_basis_column_signs(&mut q_basis);
450 Ok((q_basis, metric_spectrum))
451 }
452
453 #[inline]
454 pub(crate) fn trace_diag_product(diag: &Array1<f64>, matrix: &Array2<f64>) -> f64 {
455 assert_eq!(diag.len(), matrix.nrows());
456 assert_eq!(matrix.nrows(), matrix.ncols());
457 kahan_sum((0..diag.len()).map(|i| diag[i] * matrix[[i, i]]))
458 }
459
460 pub fn build_for_link(
461 link: &InverseLink,
462 x_dense: &Array2<f64>,
463 eta: &Array1<f64>,
464 ) -> Result<FirthDenseOperator, EstimationError> {
465 Self::build_with_observation_weights_impl(link, x_dense, eta, None)
466 }
467
468 pub fn build_with_observation_weights_for_link(
469 link: &InverseLink,
470 x_dense: &Array2<f64>,
471 eta: &Array1<f64>,
472 observation_weights: ndarray::ArrayView1<'_, f64>,
473 ) -> Result<FirthDenseOperator, EstimationError> {
474 Self::build_with_observation_weights_impl(link, x_dense, eta, Some(observation_weights))
475 }
476
477 pub(crate) fn build_design_factor_with_observation_weights(
487 x_dense: &Array2<f64>,
488 observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
489 ) -> Result<FirthDesignFactor, EstimationError> {
490 let n = x_dense.nrows();
491 let observation_weight_sqrt = if let Some(weights) = observation_weights {
492 if weights.len() != n {
493 crate::bail_invalid_estim!(
494 "Firth operator observation weight length {} != number of rows {}",
495 weights.len(),
496 n
497 );
498 }
499 let mut sqrt = Array1::<f64>::zeros(n);
500 for i in 0..n {
501 let weight = weights[i];
502 if !weight.is_finite() || weight < 0.0 {
503 crate::bail_invalid_estim!(
504 "Firth operator requires finite nonnegative observation weights, got {} at row {}",
505 weight,
506 i
507 );
508 }
509 sqrt[i] = weight.sqrt();
510 }
511 Some(sqrt)
512 } else {
513 None
514 };
515 let basis_design = if let Some(scale) = observation_weight_sqrt.as_ref() {
516 RemlState::row_scale(x_dense, scale)
517 } else {
518 x_dense.clone()
519 };
520 let gram = fast_atb(&basis_design, &basis_design);
522 let (q_basis, metric_spectrum) = Self::identifiable_subspace_basis_from_gram(&gram)?;
523 let x_reduced = fast_ab(&basis_design, &q_basis);
524 let r = q_basis.ncols();
525 let mut x_metric_reduced_inv_diag = Array1::<f64>::zeros(r);
526 for col in 0..r {
527 x_metric_reduced_inv_diag[col] = metric_spectrum[col].recip();
528 }
529 let x_dense_t = x_dense.t().to_owned();
530 Ok(FirthDesignFactor {
531 x_dense: x_dense.clone(),
532 x_dense_t,
533 q_basis,
534 x_reduced,
535 observation_weight_sqrt,
536 metric_spectrum,
537 x_metric_reduced_inv_diag,
538 r,
539 n,
540 })
541 }
542
543 fn firth_reduced_core(
551 factor: &FirthDesignFactor,
552 link: &InverseLink,
553 eta: &Array1<f64>,
554 ) -> Result<FirthReducedCore, EstimationError> {
555 let n = factor.n;
556 if eta.len() != n {
557 crate::bail_invalid_estim!(
558 "Firth operator shape mismatch: nrows={}, eta_len={}",
559 n,
560 eta.len()
561 );
562 }
563 let r = factor.r;
564 let mut w = Array1::<f64>::zeros(n);
565 let mut w1 = Array1::<f64>::zeros(n);
566 let mut w2 = Array1::<f64>::zeros(n);
567 let mut w3 = Array1::<f64>::zeros(n);
568 let mut w4 = Array1::<f64>::zeros(n);
569 RemlState::fill_fisher_weight_derivative_arrays(
570 link, eta, &mut w, &mut w1, &mut w2, &mut w3, &mut w4,
571 )?;
572
573 let fisher_reduced = gam_linalg::faer_ndarray::fast_xt_diag_x(&factor.x_reduced, &w);
575 if let Ok((eigvals_ir, _)) = fisher_reduced.eigh(Side::Lower) {
576 let max_ev = eigvals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
577 let min_ev = eigvals_ir
578 .iter()
579 .copied()
580 .filter(|v| v.is_finite() && *v > 0.0)
581 .fold(f64::INFINITY, f64::min);
582 if min_ev.is_finite() {
583 let rel = min_ev / max_ev;
584 if rel < FIRTH_REDUCED_FISHER_RCOND_WARN {
585 log::warn!(
586 "[REML/Firth] reduced Fisher I_r is near-singular (min/max={:.3e}/{:.3e}, rel={:.3e}); exact derivatives may be ill-conditioned near active-subspace boundaries.",
587 min_ev,
588 max_ev,
589 rel
590 );
591 }
592 }
593 }
594
595 let (k_reduced, mut half_log_det) = if r > 0 {
596 RemlState::reduced_fisher_inverse_and_half_logdet(&fisher_reduced)?
597 } else {
598 (Array2::<f64>::zeros((r, r)), 0.0)
599 };
600 if r > 0 {
601 for col in 0..r {
602 let metric_eig = factor.metric_spectrum[col];
603 half_log_det -= 0.5 * metric_eig.ln();
604 }
605 }
606 let h_diag = if r > 0 {
607 RemlState::reduced_diag_gram(&factor.x_reduced, &k_reduced)
608 } else {
609 Array1::<f64>::zeros(n)
610 };
611 Ok(FirthReducedCore {
612 w,
613 w1,
614 w2,
615 w3,
616 w4,
617 k_reduced,
618 half_log_det,
619 h_diag,
620 })
621 }
622
623 pub(crate) fn build_from_design_factor(
627 factor: &FirthDesignFactor,
628 link: &InverseLink,
629 eta: &Array1<f64>,
630 ) -> Result<FirthDenseOperator, EstimationError> {
631 let FirthReducedCore {
632 w,
633 w1,
634 w2,
635 w3,
636 w4,
637 k_reduced,
638 half_log_det,
639 h_diag,
640 } = Self::firth_reduced_core(factor, link, eta)?;
641 let b_base = RemlState::row_scale(&factor.x_dense, &w1);
642 let p_b_base = RemlState::apply_hadamard_gram_to_matrix(
643 &factor.x_reduced,
644 &k_reduced,
645 &k_reduced,
646 &b_base,
647 );
648 Ok(FirthDenseOperator {
649 x_dense: factor.x_dense.clone(),
650 x_dense_t: factor.x_dense_t.clone(),
651 q_basis: factor.q_basis.clone(),
652 x_reduced: factor.x_reduced.clone(),
653 observation_weight_sqrt: factor.observation_weight_sqrt.clone(),
654 k_reduced,
655 x_metric_reduced_inv_diag: factor.x_metric_reduced_inv_diag.clone(),
656 half_log_det,
657 h_diag,
658 w,
659 w1,
660 w2,
661 w3,
662 w4,
663 b_base,
664 p_b_base,
665 })
666 }
667
668 pub(crate) fn pirls_diagnostics_from_factor(
679 factor: &FirthDesignFactor,
680 link: &InverseLink,
681 eta: &Array1<f64>,
682 ) -> Result<(Array1<f64>, f64, Array1<f64>), EstimationError> {
683 let core = Self::firth_reduced_core(factor, link, eta)?;
684 let (w, w1, h_diag, half_log_det) =
685 (core.w, core.w1, core.h_diag, core.half_log_det);
686 let hat_diag = &w * &h_diag;
688 let mut score_shift = Array1::<f64>::zeros(w.len());
691 for i in 0..w.len() {
692 let wi = w[i];
693 if wi > 0.0 {
694 score_shift[i] = 0.5 * (w1[i] / wi) * h_diag[i];
695 }
696 }
697 Ok((hat_diag, half_log_det, score_shift))
698 }
699
700 pub(crate) fn build_with_observation_weights_impl(
701 link: &InverseLink,
702 x_dense: &Array2<f64>,
703 eta: &Array1<f64>,
704 observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
705 ) -> Result<FirthDenseOperator, EstimationError> {
706 let n = x_dense.nrows();
767 if eta.len() != n {
768 crate::bail_invalid_estim!(
769 "Firth operator shape mismatch: nrows={}, eta_len={}",
770 n,
771 eta.len()
772 );
773 }
774 let factor =
775 Self::build_design_factor_with_observation_weights(x_dense, observation_weights)?;
776 Self::build_from_design_factor(&factor, link, eta)
777 }
778
779 #[inline]
780 pub(crate) fn jeffreys_logdet(&self) -> f64 {
781 self.half_log_det
782 }
783
784 pub(crate) fn jeffreys_logdet_projected(&self, z: ndarray::ArrayView2<'_, f64>) -> f64 {
797 use gam_linalg::faer_ndarray::{fast_ab, fast_xt_diag_x};
798 let p = self.x_dense.ncols();
799 assert_eq!(
800 z.nrows(),
801 p,
802 "jeffreys_logdet_projected: Z must have {} rows (β-space dim), got {}",
803 p,
804 z.nrows()
805 );
806 let m = z.ncols();
807 if m == 0 {
808 return 0.0;
809 }
810 let z_owned = z.to_owned();
812 let xz = fast_ab(&self.x_dense, &z_owned);
813 let xtz = if let Some(scale) = self.observation_weight_sqrt.as_ref() {
814 RemlState::row_scale(&xz, scale)
815 } else {
816 xz
817 };
818 let mut j_t = fast_xt_diag_x(&xtz, &self.w);
820 symmetrize_in_place(&mut j_t);
821 let (evals, _) = match j_t.eigh(Side::Lower) {
822 Ok(pair) => pair,
823 Err(_) => return f64::NEG_INFINITY,
824 };
825 let Some(evals_slice) = evals.as_slice() else {
826 return f64::NEG_INFINITY;
827 };
828 let threshold = super::reml_outer_engine::positive_eigenvalue_threshold(evals_slice);
829 0.5 * super::reml_outer_engine::exact_pseudo_logdet(evals_slice, threshold)
830 }
831
832 #[inline]
833 pub(crate) fn jeffreys_beta_gradient(&self) -> Array1<f64> {
834 0.5 * gam_linalg::faer_ndarray::fast_av(&self.x_dense_t, &(&self.w1 * &self.h_diag))
839 }
840
841 #[inline]
842 pub fn jeffreys_logdet_and_beta_gradient(&self) -> (f64, Array1<f64>) {
843 (self.jeffreys_logdet(), self.jeffreys_beta_gradient())
844 }
845
846 #[inline]
847 pub(crate) fn reduce_explicit_design(&self, x: &Array2<f64>) -> Array2<f64> {
848 let mut reduced = fast_ab(x, &self.q_basis);
849 if let Some(scale) = self.observation_weight_sqrt.as_ref() {
850 reduced = RemlState::row_scale(&reduced, scale);
851 }
852 reduced
853 }
854
855 pub(crate) fn direction_from_deta(&self, deta: Array1<f64>) -> FirthDirection {
856 let s_u = &self.w1 * &deta;
872 let g_u_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_u);
877 let k_g_u = self.k_reduced.dot(&g_u_reduced);
878 let a_u_reduced = k_g_u.dot(&self.k_reduced);
879 let dh = -RemlState::reduced_diag_gram(&self.x_reduced, &a_u_reduced);
882 let b_uvec = &self.w2 * &deta;
883 FirthDirection {
884 deta,
885 g_u_reduced,
886 a_u_reduced,
887 dh,
888 b_uvec,
889 }
890 }
891
892 #[inline]
893 pub(crate) fn left_scaled_xt(&self, scale: &Array1<f64>, mat: &Array2<f64>) -> Array2<f64> {
894 fast_ab(&self.x_dense_t, &(mat * &scale.view().insert_axis(Axis(1))))
895 }
896
897 #[inline]
898 pub(crate) fn apply_p_u_to_matrix(
899 &self,
900 a_u_reduced: &Array2<f64>,
901 mat: &Array2<f64>,
902 ) -> Array2<f64> {
903 let mut out = RemlState::apply_hadamard_gram_to_matrix(
904 &self.x_reduced,
905 &self.k_reduced,
906 a_u_reduced,
907 mat,
908 );
909 out.mapv_inplace(|v| -2.0 * v);
910 out
911 }
912
913 pub(crate) fn hphi_direction_apply(
914 &self,
915 dir: &FirthDirection,
916 rhs: &Array2<f64>,
917 ) -> Array2<f64> {
918 let p = self.x_dense.ncols();
919 if rhs.nrows() != p {
920 return Array2::<f64>::zeros((p, rhs.ncols()));
921 }
922 if rhs.ncols() == 0 || p == 0 {
923 return Array2::<f64>::zeros((p, rhs.ncols()));
924 }
925 let etav = fast_ab(&self.x_dense, rhs);
932 let qv = &etav * &self.w1.view().insert_axis(Axis(1));
933 let m_qv = RemlState::apply_hadamard_gram_to_matrix(
934 &self.x_reduced,
935 &self.k_reduced,
936 &self.k_reduced,
937 &qv,
938 );
939 let buvec = &dir.b_uvec;
940 let m_buv = RemlState::apply_hadamard_gram_to_matrix(
941 &self.x_reduced,
942 &self.k_reduced,
943 &self.k_reduced,
944 &(&etav * &buvec.view().insert_axis(Axis(1))),
945 );
946 let p_u_qv = self.apply_p_u_to_matrix(&dir.a_u_reduced, &qv);
947 let c_u = &(&self.w3 * &dir.deta) * &self.h_diag + &(&self.w2 * &dir.dh);
948 let diag_term = self
949 .x_dense_t
950 .dot(&(&etav * &c_u.view().insert_axis(Axis(1))));
951 let term1 = self.left_scaled_xt(buvec, &m_qv);
952 let term2 = self.left_scaled_xt(&self.w1, &m_buv);
953 let term3 = self.left_scaled_xt(&self.w1, &p_u_qv);
954 0.5 * (diag_term - (term1 + term2 + term3))
955 }
956
957 pub(crate) fn hphi_direction(&self, dir: &FirthDirection) -> Array2<f64> {
958 let p = self.x_dense.ncols();
959 let eye = Array2::<f64>::eye(p);
960 let mut out = self.hphi_direction_apply(dir, &eye);
961 symmetrize_in_place(&mut out);
974 out
975 }
976
977 pub(crate) fn hphisecond_direction_apply(
978 &self,
979 u: &FirthDirection,
980 v: &FirthDirection,
981 rhs: &Array2<f64>,
982 ) -> Array2<f64> {
983 let p = self.x_dense.ncols();
984 if rhs.nrows() != p {
985 return Array2::<f64>::zeros((p, rhs.ncols()));
986 }
987 if rhs.ncols() == 0 || p == 0 {
988 return Array2::<f64>::zeros((p, rhs.ncols()));
989 }
990 let deta_uv = &u.deta * &v.deta;
1002 let s_uv = &self.w2 * &deta_uv;
1005 let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
1006 let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
1007 let k_gv = self.k_reduced.dot(&v.g_u_reduced);
1008 let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
1009 let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
1012 - k_gv.dot(&k_g_u).dot(&self.k_reduced)
1013 - k_g_u.dot(&k_gv).dot(&self.k_reduced);
1014 let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
1015 let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
1020 + &(&self.w3 * &(&u.deta * &v.dh))
1021 + &(&self.w3 * &(&v.deta * &u.dh))
1022 + &(&self.w2 * &d2h);
1023
1024 let eta_rhs = fast_ab(&self.x_dense, rhs);
1025 let diag_term = fast_ab(
1026 &self.x_dense_t,
1027 &(&eta_rhs * &c_uv.view().insert_axis(Axis(1))),
1028 );
1029
1030 let b_uvvec = &self.w3 * &deta_uv;
1031 let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
1032 let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));
1033
1034 let p_b_rhs = fast_ab(&self.p_b_base, rhs);
1039 let p_bu_rhs = RemlState::apply_hadamard_gram_to_matrix(
1040 &self.x_reduced,
1041 &self.k_reduced,
1042 &self.k_reduced,
1043 &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1044 );
1045 let p_bv_rhs = RemlState::apply_hadamard_gram_to_matrix(
1046 &self.x_reduced,
1047 &self.k_reduced,
1048 &self.k_reduced,
1049 &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1050 );
1051 let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
1052 &self.x_reduced,
1053 &self.k_reduced,
1054 &self.k_reduced,
1055 &b_uv_base,
1056 );
1057 let p_buv_rhs = fast_ab(&p_buv_base, rhs);
1058
1059 let pv_b_rhs = self.apply_p_u_to_matrix(&v.a_u_reduced, &qv);
1060 let pv_bu_rhs = self.apply_p_u_to_matrix(
1061 &v.a_u_reduced,
1062 &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1063 );
1064 let p_u_b_rhs = self.apply_p_u_to_matrix(&u.a_u_reduced, &qv);
1065 let p_u_bv_rhs = self.apply_p_u_to_matrix(
1066 &u.a_u_reduced,
1067 &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1068 );
1069
1070 let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
1071 &self.x_reduced,
1072 &u.a_u_reduced,
1073 &v.a_u_reduced,
1074 &self.b_base,
1075 );
1076 let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
1077 &self.x_reduced,
1078 &self.k_reduced,
1079 &a_uv_reduced,
1080 &self.b_base,
1081 );
1082 let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
1083 let p_uv_rhs = fast_ab(&p_uv_base, rhs);
1084
1085 let d2_terms = [
1087 self.left_scaled_xt(&b_uvvec, &p_b_rhs),
1088 self.left_scaled_xt(&self.w1, &p_buv_rhs),
1089 self.left_scaled_xt(&u.b_uvec, &p_bv_rhs),
1090 self.left_scaled_xt(&v.b_uvec, &p_bu_rhs),
1091 self.left_scaled_xt(&u.b_uvec, &pv_b_rhs),
1092 self.left_scaled_xt(&self.w1, &pv_bu_rhs),
1093 self.left_scaled_xt(&v.b_uvec, &p_u_b_rhs),
1094 self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
1095 self.left_scaled_xt(&self.w1, &p_uv_rhs),
1096 ];
1097 let mut d2_j2 = Array2::<f64>::zeros((p, rhs.ncols()));
1098 for term in d2_terms {
1099 d2_j2 += &term;
1100 }
1101
1102 0.5 * (diag_term - d2_j2)
1103 }
1104
1105 pub(crate) fn tk_second_direction_eye_cache(
1119 &self,
1120 dirs: &[FirthDirection],
1121 ) -> FirthSecondDirEyeCache {
1122 let p = self.x_dense.ncols();
1123 let eye = Array2::<f64>::eye(p);
1124 let eta_rhs = fast_ab(&self.x_dense, &eye);
1126 let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));
1127 let p_b_rhs = fast_ab(&self.p_b_base, &eye);
1129 let mut p_bx = Vec::with_capacity(dirs.len());
1130 let mut pu_qv = Vec::with_capacity(dirs.len());
1131 for d in dirs {
1132 p_bx.push(RemlState::apply_hadamard_gram_to_matrix(
1134 &self.x_reduced,
1135 &self.k_reduced,
1136 &self.k_reduced,
1137 &(&eta_rhs * &d.b_uvec.view().insert_axis(Axis(1))),
1138 ));
1139 pu_qv.push(self.apply_p_u_to_matrix(&d.a_u_reduced, &qv));
1141 }
1142 FirthSecondDirEyeCache {
1143 eye,
1144 eta_rhs,
1145 p_b_rhs,
1146 p_bx,
1147 pu_qv,
1148 }
1149 }
1150
1151 pub(crate) fn hphisecond_direction_apply_eye_cached(
1158 &self,
1159 cache: &FirthSecondDirEyeCache,
1160 dirs: &[FirthDirection],
1161 i: usize,
1162 j: usize,
1163 ) -> Array2<f64> {
1164 let u = &dirs[i];
1165 let v = &dirs[j];
1166 let p = self.x_dense.ncols();
1167 let cols = cache.eta_rhs.ncols();
1168 if p == 0 || cols == 0 {
1169 return Array2::<f64>::zeros((p, cols));
1170 }
1171 let deta_uv = &u.deta * &v.deta;
1172 let s_uv = &self.w2 * &deta_uv;
1173 let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
1174 let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
1175 let k_gv = self.k_reduced.dot(&v.g_u_reduced);
1176 let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
1177 let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
1178 - k_gv.dot(&k_g_u).dot(&self.k_reduced)
1179 - k_g_u.dot(&k_gv).dot(&self.k_reduced);
1180 let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
1181 let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
1182 + &(&self.w3 * &(&u.deta * &v.dh))
1183 + &(&self.w3 * &(&v.deta * &u.dh))
1184 + &(&self.w2 * &d2h);
1185
1186 let eta_rhs = &cache.eta_rhs;
1187 let diag_term = fast_ab(
1188 &self.x_dense_t,
1189 &(eta_rhs * &c_uv.view().insert_axis(Axis(1))),
1190 );
1191
1192 let b_uvvec = &self.w3 * &deta_uv;
1193 let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
1194
1195 let p_b_rhs = &cache.p_b_rhs;
1197 let p_bu_rhs = &cache.p_bx[i];
1198 let p_bv_rhs = &cache.p_bx[j];
1199 let p_u_b_rhs = &cache.pu_qv[i];
1200 let pv_b_rhs = &cache.pu_qv[j];
1201
1202 let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
1204 &self.x_reduced,
1205 &self.k_reduced,
1206 &self.k_reduced,
1207 &b_uv_base,
1208 );
1209 let p_buv_rhs = fast_ab(&p_buv_base, &cache.eye);
1210
1211 let pv_bu_rhs = self.apply_p_u_to_matrix(
1212 &v.a_u_reduced,
1213 &(eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1214 );
1215 let p_u_bv_rhs = self.apply_p_u_to_matrix(
1216 &u.a_u_reduced,
1217 &(eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1218 );
1219
1220 let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
1221 &self.x_reduced,
1222 &u.a_u_reduced,
1223 &v.a_u_reduced,
1224 &self.b_base,
1225 );
1226 let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
1227 &self.x_reduced,
1228 &self.k_reduced,
1229 &a_uv_reduced,
1230 &self.b_base,
1231 );
1232 let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
1233 let p_uv_rhs = fast_ab(&p_uv_base, &cache.eye);
1234
1235 let d2_terms = [
1236 self.left_scaled_xt(&b_uvvec, p_b_rhs),
1237 self.left_scaled_xt(&self.w1, &p_buv_rhs),
1238 self.left_scaled_xt(&u.b_uvec, p_bv_rhs),
1239 self.left_scaled_xt(&v.b_uvec, p_bu_rhs),
1240 self.left_scaled_xt(&u.b_uvec, pv_b_rhs),
1241 self.left_scaled_xt(&self.w1, &pv_bu_rhs),
1242 self.left_scaled_xt(&v.b_uvec, p_u_b_rhs),
1243 self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
1244 self.left_scaled_xt(&self.w1, &p_uv_rhs),
1245 ];
1246 let mut d2_j2 = Array2::<f64>::zeros((p, cols));
1247 for term in d2_terms {
1248 d2_j2 += &term;
1249 }
1250
1251 0.5 * (diag_term - d2_j2)
1252 }
1253
1254 pub(super) fn rowwise_dot(a: &Array2<f64>, b: &Array2<f64>) -> Array1<f64> {
1255 assert_eq!(a.nrows(), b.nrows());
1256 assert_eq!(a.ncols(), b.ncols());
1257 let mut out = Array1::<f64>::zeros(a.nrows());
1258 for i in 0..a.nrows() {
1259 let mut acc = 0.0_f64;
1260 for j in 0..a.ncols() {
1261 acc += a[[i, j]] * b[[i, j]];
1262 }
1263 out[i] = acc;
1264 }
1265 out
1266 }
1267
1268 pub(super) fn rowwise_bilinear(
1269 a: &Array2<f64>,
1270 m: &Array2<f64>,
1271 b: &Array2<f64>,
1272 ) -> Array1<f64> {
1273 assert_eq!(a.nrows(), b.nrows());
1275 assert_eq!(a.ncols(), m.nrows());
1276 assert_eq!(b.ncols(), m.ncols());
1277 let am = fast_ab(a, m);
1278 Self::rowwise_dot(&am, b)
1279 }
1280
1281 pub(crate) fn dot_i_and_h_from_reduced(
1282 &self,
1283 x_tau_reduced: &Array2<f64>,
1284 deta: &Array1<f64>,
1285 ) -> (Array2<f64>, Array1<f64>) {
1286 let dw = &self.w1 * deta;
1306 let dot_i = RemlState::weighted_cross(x_tau_reduced, &self.x_reduced, &self.w)
1307 + RemlState::weighted_cross(&self.x_reduced, x_tau_reduced, &self.w)
1308 + gam_linalg::faer_ndarray::fast_xt_diag_x(&self.x_reduced, &dw);
1309
1310 let dot_k = -self.k_reduced.dot(&dot_i).dot(&self.k_reduced);
1311 let x_tauk = fast_ab(x_tau_reduced, &self.k_reduced);
1312 let dot_h_explicit = 2.0 * Self::rowwise_dot(&x_tauk, &self.x_reduced);
1313 let dot_h_implicit = Self::rowwise_dot(&fast_ab(&self.x_reduced, &dot_k), &self.x_reduced);
1314 let dot_h = dot_h_explicit + dot_h_implicit;
1315 (dot_i, dot_h)
1316 }
1317
1318 pub(crate) fn exact_tau_kernel(
1319 &self,
1320 x_tau: &Array2<f64>,
1321 beta: &Array1<f64>,
1322 include_hphi_tau_kernel: bool,
1323 ) -> FirthTauExactKernel {
1324 let deta_partial = gam_linalg::faer_ndarray::fast_av(x_tau, beta);
1348 let x_tau_reduced = self.reduce_explicit_design(x_tau);
1349 let (dot_i_partial, dot_h_partial) =
1350 self.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
1351 let dot_s_partial =
1352 fast_atb(&x_tau_reduced, &self.x_reduced) + fast_atb(&self.x_reduced, &x_tau_reduced);
1353
1354 let first = 0.5 * gam_linalg::faer_ndarray::fast_atv(x_tau, &(&self.w1 * &self.h_diag));
1355 let secondvec =
1356 &(&(&self.w2 * &deta_partial) * &self.h_diag) + &(&self.w1 * &dot_h_partial);
1357 let second = 0.5 * gam_linalg::faer_ndarray::fast_atv(&self.x_dense, &secondvec);
1358 let gphi_tau = first + second;
1359 let phi_tau_partial = 0.5 * RemlState::trace_product(&self.k_reduced, &dot_i_partial)
1360 - 0.5 * Self::trace_diag_product(&self.x_metric_reduced_inv_diag, &dot_s_partial);
1361
1362 let tau_kernel = if include_hphi_tau_kernel {
1363 Some(self.hphi_tau_partial_prepare_from_partials(
1364 x_tau_reduced,
1365 &deta_partial,
1366 dot_h_partial,
1367 dot_i_partial,
1368 ))
1369 } else {
1370 None
1371 };
1372 FirthTauExactKernel {
1373 gphi_tau,
1374 phi_tau_partial,
1375 tau_kernel,
1376 }
1377 }
1378
1379 pub(crate) fn hphi_tau_partial_prepare_from_partials(
1380 &self,
1381 x_tau_reduced: Array2<f64>,
1382 deta_partial: &Array1<f64>,
1383 dot_h_partial: Array1<f64>,
1384 dot_i_partial: Array2<f64>,
1385 ) -> FirthTauPartialKernel {
1386 let dotw1 = &self.w2 * deta_partial;
1387 let dotw2 = &self.w3 * deta_partial;
1388 let dot_k = -self.k_reduced.dot(&dot_i_partial).dot(&self.k_reduced);
1389 FirthTauPartialKernel {
1390 deta_partial: deta_partial.clone(),
1391 dotw1,
1392 dotw2,
1393 dot_h_partial,
1394 x_tau_reduced,
1395 dot_i_partial,
1396 dot_k_reduced: dot_k,
1397 }
1398 }
1399
1400 pub(crate) fn d_beta_hphi_tau_partial_dense(
1401 &self,
1402 x_tau: &Array2<f64>,
1403 beta: &Array1<f64>,
1404 beta_direction: &Array1<f64>,
1405 ) -> Option<Array2<f64>> {
1406 if x_tau.nrows() != self.x_dense.nrows() || x_tau.ncols() != beta.len() {
1407 return None;
1408 }
1409 if !x_tau.iter().any(|value| *value != 0.0) {
1410 return None;
1411 }
1412 let tau_bundle = self.exact_tau_kernel(x_tau, beta, true);
1413 let tau_kernel = tau_bundle.tau_kernel?;
1414 let firth_direction =
1415 self.direction_from_deta(gam_linalg::faer_ndarray::fast_av(&self.x_dense, beta_direction));
1416 let x_tau_v = gam_linalg::faer_ndarray::fast_av(x_tau, beta_direction);
1417 let kernel = self.d_beta_hphi_tau_partial_prepare_from_partials(
1418 &tau_kernel,
1419 &tau_kernel.deta_partial,
1420 &tau_kernel.dot_i_partial,
1421 &firth_direction,
1422 &x_tau_v,
1423 );
1424 let eye = Array2::<f64>::eye(beta_direction.len());
1425 Some(self.d_beta_hphi_tau_partial_apply(x_tau, &kernel, &eye))
1426 }
1427
1428 pub(crate) fn apply_pbar_to_matrix(&self, mat: &Array2<f64>) -> Array2<f64> {
1429 RemlState::apply_hadamard_gram_to_matrix(
1431 &self.x_reduced,
1432 &self.k_reduced,
1433 &self.k_reduced,
1434 mat,
1435 )
1436 }
1437
1438 pub(crate) fn apply_mtau_to_matrix(
1439 &self,
1440 kernel: &FirthTauPartialKernel,
1441 mat: &Array2<f64>,
1442 ) -> Array2<f64> {
1443 if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
1455 return Array2::<f64>::zeros(mat.raw_dim());
1456 }
1457 let mut out = Array2::<f64>::zeros(mat.raw_dim());
1458 for col in 0..mat.ncols() {
1459 let v = mat.column(col).to_owned();
1460 let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
1461 let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
1462 let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, &kernel.x_tau_reduced);
1463
1464 let szt =
1465 RemlState::reduced_crossweighted_gram(&self.x_reduced, &kernel.x_tau_reduced, &v);
1466 let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
1467 let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
1468
1469 let t3 = RemlState::apply_hadamard_gram(
1470 &self.x_reduced,
1471 &self.k_reduced,
1472 &kernel.dot_k_reduced,
1473 &v,
1474 );
1475
1476 let y = 2.0 * (t1 + t2 + t3);
1477 out.column_mut(col).assign(&y);
1478 }
1479 out
1480 }
1481
1482 pub(crate) fn hphi_tau_partial_apply(
1483 &self,
1484 x_tau: &Array2<f64>,
1485 kernel: &FirthTauPartialKernel,
1486 rhs: &Array2<f64>,
1487 ) -> Array2<f64> {
1488 let p = self.x_dense.ncols();
1489 if rhs.nrows() != p {
1490 return Array2::<f64>::zeros((p, rhs.ncols()));
1491 }
1492 if rhs.ncols() == 0 || p == 0 {
1493 return Array2::<f64>::zeros((p, rhs.ncols()));
1494 }
1495 let etav = fast_ab(&self.x_dense, rhs);
1512 let etav_tau = fast_ab(x_tau, rhs);
1513 let qv = &etav * &self.w1.view().insert_axis(Axis(1));
1514 let qv_tau = &etav * &kernel.dotw1.view().insert_axis(Axis(1))
1515 + &etav_tau * &self.w1.view().insert_axis(Axis(1));
1516 let m_qv = self.apply_pbar_to_matrix(&qv);
1517 let m_qv_tau = self.apply_mtau_to_matrix(kernel, &qv) + self.apply_pbar_to_matrix(&qv_tau);
1518 let rv = &(&etav * &self.w2.view().insert_axis(Axis(1)))
1519 * &self.h_diag.view().insert_axis(Axis(1))
1520 - &(&m_qv * &self.w1.view().insert_axis(Axis(1)));
1521 let rv_tau = (&(&etav * &kernel.dotw2.view().insert_axis(Axis(1)))
1522 + &(&etav_tau * &self.w2.view().insert_axis(Axis(1))))
1523 * self.h_diag.view().insert_axis(Axis(1))
1524 + &(&etav * &self.w2.view().insert_axis(Axis(1)))
1525 * &kernel.dot_h_partial.view().insert_axis(Axis(1))
1526 - &(&m_qv * &kernel.dotw1.view().insert_axis(Axis(1))
1527 + &m_qv_tau * &self.w1.view().insert_axis(Axis(1)));
1528 0.5 * (fast_atb(x_tau, &rv) + fast_atb(&self.x_dense, &rv_tau))
1529 }
1530
1531 pub(crate) fn hphi_tau_tau_partial_prepare_from_partials(
1871 &self,
1872 x_tau_i_reduced: Array2<f64>,
1873 x_tau_j_reduced: Array2<f64>,
1874 deta_i_partial: &Array1<f64>,
1875 deta_j_partial: &Array1<f64>,
1876 dot_h_i_partial: Array1<f64>,
1877 dot_h_j_partial: Array1<f64>,
1878 dot_i_i_partial: Array2<f64>,
1879 dot_i_j_partial: Array2<f64>,
1880 x_tau_tau_reduced: Option<Array2<f64>>,
1881 deta_ij_partial: Option<Array1<f64>>,
1882 ) -> FirthTauTauPartialKernel {
1883 let dot_k_i_reduced = -self.k_reduced.dot(&dot_i_i_partial).dot(&self.k_reduced);
1885 let dot_k_j_reduced = -self.k_reduced.dot(&dot_i_j_partial).dot(&self.k_reduced);
1886 FirthTauTauPartialKernel {
1887 x_tau_i_reduced,
1888 x_tau_j_reduced,
1889 deta_i_partial: deta_i_partial.clone(),
1890 deta_j_partial: deta_j_partial.clone(),
1891 dot_h_i_partial,
1892 dot_h_j_partial,
1893 dot_k_i_reduced,
1894 dot_k_j_reduced,
1895 dot_i_i_partial,
1896 dot_i_j_partial,
1897 x_tau_tau_reduced,
1898 deta_ij_partial,
1899 }
1900 }
1901
    /// Applies the second-order (`tau_i`, `tau_j`) partial operator described
    /// by `kernel` to the columns of `rhs`.
    ///
    /// Returns a `p x rhs.ncols()` matrix, where `p` is the number of dense
    /// design columns. The result is all zeros when `rhs` does not have `p`
    /// rows, when `rhs` has no columns, or when `p == 0`.
    ///
    /// The result is `0.5 * (diag_term - d2_bpb)`:
    /// * `diag_term` collects the products of design matrices (base, `tau_i`,
    ///   `tau_j`, and mixed when present) with `rhs` images scaled by
    ///   `gamma = w2 ∘ h` and its first/second directional variations;
    /// * `d2_bpb` collects the nine `t*` terms built from Hadamard-gram
    ///   applications to `w1`-scaled images.
    ///
    /// When the kernel carries no mixed design derivative / mixed `deta`,
    /// zero placeholders are used and the terms guarded by
    /// `x_tau_tau_is_some` are skipped.
    pub(crate) fn hphi_tau_tau_partial_apply(
        &self,
        x_tau_i: &Array2<f64>,
        x_tau_j: &Array2<f64>,
        kernel: &FirthTauTauPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        let n = self.x_dense.nrows();
        let m = rhs.ncols();

        // `z` and `x_r` are two names for the same reduced design.
        let z = &self.x_reduced;
        let x_r = &self.x_reduced;
        let k = &self.k_reduced;
        let x_ri = &kernel.x_tau_i_reduced;
        let x_rj = &kernel.x_tau_j_reduced;
        let deta_i = &kernel.deta_i_partial;
        let deta_j = &kernel.deta_j_partial;
        let dh_i = &kernel.dot_h_i_partial;
        let dh_j = &kernel.dot_h_j_partial;
        let dot_k_i = &kernel.dot_k_i_reduced;
        let dot_k_j = &kernel.dot_k_j_reduced;
        let dot_i_i = &kernel.dot_i_i_partial;
        let dot_i_j = &kernel.dot_i_j_partial;

        // Zero placeholders stand in for the optional mixed pieces so the
        // expressions below can use plain references.
        let x_tau_tau_is_some = kernel.x_tau_tau_reduced.is_some();
        let x_rij_zero = Array2::<f64>::zeros(x_r.raw_dim());
        let x_rij: &Array2<f64> = kernel.x_tau_tau_reduced.as_ref().unwrap_or(&x_rij_zero);
        let zeros_n = Array1::<f64>::zeros(n);
        let deta_ij = kernel.deta_ij_partial.as_ref().unwrap_or(&zeros_n);

        // Images of `rhs` under the base and the two derivative designs; the
        // three products are independent, so they are joined on the rayon
        // pool when the size heuristic allows it.
        let (eta_v, eta_i_v, eta_j_v) = if RemlState::should_join_independent_dense_products(&[
            (n, m, p),
            (n, m, p),
            (n, m, p),
        ]) {
            let (eta_v, (eta_i_v, eta_j_v)) = rayon::join(
                || fast_ab(&self.x_dense, rhs),
                || rayon::join(|| fast_ab(x_tau_i, rhs), || fast_ab(x_tau_j, rhs)),
            );
            (eta_v, eta_i_v, eta_j_v)
        } else {
            (
                fast_ab(&self.x_dense, rhs),
                fast_ab(x_tau_i, rhs),
                fast_ab(x_tau_j, rhs),
            )
        };
        // Image of `rhs` under the mixed design derivative. Only its reduced
        // form is stored, so `rhs` is first mapped through `q_basisᵀ` and the
        // rows are then passed through the inverse-sqrt-weight row scaling.
        let eta_ij_v: Array2<f64> = if x_tau_tau_is_some {
            let qt_v = fast_atb(&self.q_basis, rhs);
            let mut out = fast_ab(x_rij, &qt_v);
            RemlState::scale_rows_by_inverse_observation_weight_sqrt(
                &mut out,
                self.observation_weight_sqrt.as_ref(),
            );
            out
        } else {
            Array2::<f64>::zeros((n, m))
        };

        // Negated first-order kernel derivatives.
        let a_i_reduced = -dot_k_i;
        let a_j_reduced = -dot_k_j;

        // Second-order counterpart of `dot_i`: every pairing of
        // {base, i, j, ij} reduced designs with the matching weight vector.
        let dw_i = &self.w1 * deta_i;
        let dw_j = &self.w1 * deta_j;
        let ddw_ij = &(&self.w2 * &(deta_i * deta_j)) + &(&self.w1 * deta_ij);
        let mut i_ddot = Array2::<f64>::zeros(k.raw_dim());
        if x_tau_tau_is_some {
            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
        }
        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_rj, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_ri, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_r, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_ri, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_r, &dw_i);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rj, &dw_i);
        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);

        // k_ddot = -K·i_ddot·K + (K·dot_i_i·K)·dot_i_j·K + (K·dot_i_j·K)·dot_i_i·K,
        // written with a_* = -dot_k_*.
        let k_ddot: Array2<f64> = -k.dot(&i_ddot).dot(k)
            + a_i_reduced.dot(dot_i_j).dot(k)
            + a_j_reduced.dot(dot_i_i).dot(k);

        // Second-order counterpart of `dot_h`, assembled from row-wise dots.
        let dh_ij: Array1<f64> = {
            let r = k.ncols();
            let can_join = RemlState::should_join_independent_dense_products(&[
                (n, r, r),
                (n, r, r),
                (n, r, r),
                (n, r, r),
            ]);
            let (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k) = if can_join {
                let ((xr_kddot, ri_kdot_j), (rj_kdot_i, ri_k)) = rayon::join(
                    || rayon::join(|| fast_ab(x_r, &k_ddot), || fast_ab(x_ri, dot_k_j)),
                    || rayon::join(|| fast_ab(x_rj, dot_k_i), || fast_ab(x_ri, k)),
                );
                (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k)
            } else {
                (
                    fast_ab(x_r, &k_ddot),
                    fast_ab(x_ri, dot_k_j),
                    fast_ab(x_rj, dot_k_i),
                    fast_ab(x_ri, k),
                )
            };

            let mut acc = Self::rowwise_dot(&xr_kddot, x_r);
            acc = acc + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
            acc = acc + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
            acc = acc + 2.0 * Self::rowwise_dot(&ri_k, x_rj);
            if x_tau_tau_is_some {
                let rij_k = fast_ab(x_rij, k);
                acc = acc + 2.0 * Self::rowwise_dot(&rij_k, x_r);
            }
            acc
        };

        // gamma = w2 ∘ h with its first (i, j) and second (ij) variations,
        // expanded by the product rule over (w-jet, deta, h).
        let gamma = &self.w2 * &self.h_diag;
        let gamma_dot_i = &(&(&self.w3 * deta_i) * &self.h_diag) + &(&self.w2 * dh_i);
        let gamma_dot_j = &(&(&self.w3 * deta_j) * &self.h_diag) + &(&self.w2 * dh_j);
        let gamma_ddot = &(&(&(&self.w4 * deta_i) * deta_j) * &self.h_diag)
            + &(&(&(&self.w3 * deta_ij) * &self.h_diag)
                + &(&(&self.w3 * deta_i) * dh_j)
                + &(&(&self.w3 * deta_j) * dh_i)
                + &(&self.w2 * &dh_ij));

        let mut diag_term = Array2::<f64>::zeros((p, m));
        let gamma_col = gamma.view().insert_axis(Axis(1));
        let gamma_i_col = gamma_dot_i.view().insert_axis(Axis(1));
        let gamma_j_col = gamma_dot_j.view().insert_axis(Axis(1));
        let gamma_ij_col = gamma_ddot.view().insert_axis(Axis(1));

        // Seven terms that do not involve the mixed design derivative.
        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_j_v * &gamma_col));
        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_i_v * &gamma_col));
        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_v * &gamma_j_col));
        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_v * &gamma_i_col));
        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_i_v * &gamma_j_col));
        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_j_v * &gamma_i_col));
        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_v * &gamma_ij_col));
        if x_tau_tau_is_some {
            // Transposed mixed design applied to `gamma ∘ eta_v`, again via
            // the reduced form (row scaling first, then `q_basis`).
            let y: Array2<f64> = &eta_v * &gamma_col;
            let xt_ij_y: Array2<f64> = if self.observation_weight_sqrt.is_some() {
                let mut y_scaled = y.clone();
                RemlState::scale_rows_by_inverse_observation_weight_sqrt(
                    &mut y_scaled,
                    self.observation_weight_sqrt.as_ref(),
                );
                self.q_basis.dot(&x_rij.t().dot(&y_scaled))
            } else {
                self.q_basis.dot(&x_rij.t().dot(&y))
            };
            diag_term = diag_term + xt_ij_y;
            diag_term = diag_term + self.x_dense_t.dot(&(&eta_ij_v * &gamma_col));
        }

        // `w1`-scaled image of `rhs` and its first/second variations.
        let w1_col = self.w1.view().insert_axis(Axis(1));
        let b_v = &eta_v * &w1_col;

        let w2_deta_i = &self.w2 * deta_i;
        let w2_deta_j = &self.w2 * deta_j;
        let w2_deta_i_col = w2_deta_i.view().insert_axis(Axis(1));
        let w2_deta_j_col = w2_deta_j.view().insert_axis(Axis(1));
        let bdot_i_v = &(&eta_v * &w2_deta_i_col) + &(&eta_i_v * &w1_col);
        let bdot_j_v = &(&eta_v * &w2_deta_j_col) + &(&eta_j_v * &w1_col);

        let w3_didj = &(&self.w3 * deta_i) * deta_j;
        let w2_dij = &self.w2 * deta_ij;
        let bddot_scale = &w3_didj + &w2_dij;
        let bddot_scale_col = bddot_scale.view().insert_axis(Axis(1));
        let mut bddot_ij_v = &eta_v * &bddot_scale_col;
        bddot_ij_v += &(&eta_j_v * &w2_deta_i_col);
        bddot_ij_v += &(&eta_i_v * &w2_deta_j_col);
        bddot_ij_v += &(&eta_ij_v * &w1_col);

        // Hadamard-gram operator (both kernels `k`) and its directional
        // variations applied to the images above.
        let p_bv = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &b_v);
        let p_bddot_ij_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bddot_ij_v);

        let pdot_i_bv = self.apply_mtau_from_reduced(x_ri, dot_k_i, &b_v);
        let pdot_j_bv = self.apply_mtau_from_reduced(x_rj, dot_k_j, &b_v);
        let pdot_i_bdot_j_v = self.apply_mtau_from_reduced(x_ri, dot_k_i, &bdot_j_v);
        let pdot_j_bdot_i_v = self.apply_mtau_from_reduced(x_rj, dot_k_j, &bdot_i_v);

        let p_bdot_j_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_j_v);
        let p_bdot_i_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_i_v);

        let p_ddot_b_v = self.apply_p_ddot_ij(
            x_r,
            x_ri,
            x_rj,
            x_rij,
            k,
            dot_k_i,
            dot_k_j,
            &k_ddot,
            x_tau_tau_is_some,
            &b_v,
        );

        // Transpose of a first-order variation of the `w1`-scaled design:
        // Xᵀ (scale ∘ q) + x_tauᵀ (w1 ∘ q).
        let apply_bdot_tau_t =
            |scale_deta: &Array1<f64>, x_tau_mat: &Array2<f64>, q_v: &Array2<f64>| {
                let scale_col = scale_deta.view().insert_axis(Axis(1));
                self.x_dense_t.dot(&(q_v * &scale_col)) + x_tau_mat.t().dot(&(q_v * &w1_col))
            };

        // Transpose of the second-order variation; the mixed-design piece is
        // added only when a mixed derivative is present.
        let apply_bddot_ij_t = |q_v: &Array2<f64>| -> Array2<f64> {
            let scale_col_full = bddot_scale.view().insert_axis(Axis(1));
            let mut out = self.x_dense_t.dot(&(q_v * &scale_col_full));
            out = out + x_tau_j.t().dot(&(q_v * &w2_deta_i_col));
            out = out + x_tau_i.t().dot(&(q_v * &w2_deta_j_col));
            if x_tau_tau_is_some {
                let y = q_v * &w1_col;
                let contrib: Array2<f64> = if self.observation_weight_sqrt.is_some() {
                    let mut y_scaled = y.clone();
                    RemlState::scale_rows_by_inverse_observation_weight_sqrt(
                        &mut y_scaled,
                        self.observation_weight_sqrt.as_ref(),
                    );
                    self.q_basis.dot(&x_rij.t().dot(&y_scaled))
                } else {
                    self.q_basis.dot(&x_rij.t().dot(&y))
                };
                out = out + contrib;
            }
            out
        };

        // Nine product-rule terms: each of the three factors (left transpose,
        // Hadamard-gram operator, right image) carries zero, one or two of
        // the (i, j) variations.
        let t1a = apply_bddot_ij_t(&p_bv);
        let t1b = self.left_scaled_xt(&self.w1, &p_bddot_ij_v);
        let t2a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &p_bdot_j_v);
        let t2b = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &p_bdot_i_v);
        let t3a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &pdot_j_bv);
        let t3b = self.left_scaled_xt(&self.w1, &pdot_j_bdot_i_v);
        let t4a = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &pdot_i_bv);
        let t4b = self.left_scaled_xt(&self.w1, &pdot_i_bdot_j_v);
        let t5 = self.left_scaled_xt(&self.w1, &p_ddot_b_v);

        let d2_bpb = t1a + t1b + t2a + t2b + t3a + t3b + t4a + t4b + t5;

        0.5 * (diag_term - d2_bpb)
    }
2265
    /// Builds the second-order (`tau_i`, `tau_j`) bundle.
    ///
    /// `x_tau_tau` is the optional mixed design derivative; when it is `None`
    /// every term that depends on it is skipped and the mixed `deta` is
    /// treated as zero.
    ///
    /// The returned [`FirthTauTauExactKernel`] holds:
    /// * `phi_tau_tau_partial = a_lik + a_pen` (the two trace-style scalars
    ///   computed below),
    /// * `gphi_tau_tau`: a vector with one entry per dense design column,
    /// * `tau_tau_kernel`: the prepared second-order partial kernel, present
    ///   only when `include_hphi_tau_tau_kernel` is `true`.
    pub(crate) fn exact_tau_tau_kernel(
        &self,
        x_tau_i: &Array2<f64>,
        x_tau_j: &Array2<f64>,
        x_tau_tau: Option<&Array2<f64>>,
        beta: &Array1<f64>,
        include_hphi_tau_tau_kernel: bool,
    ) -> FirthTauTauExactKernel {
        // Linear-predictor variations induced by each design derivative.
        let deta_i = x_tau_i.dot(beta);
        let deta_j = x_tau_j.dot(beta);
        let deta_ij = x_tau_tau.as_ref().map(|xij| xij.dot(beta));

        let x_tau_i_reduced = self.reduce_explicit_design(x_tau_i);
        let x_tau_j_reduced = self.reduce_explicit_design(x_tau_j);
        let x_tau_tau_reduced = x_tau_tau.map(|xij| self.reduce_explicit_design(xij));

        let (dot_i_i, dot_h_i) = self.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
        let (dot_i_j, dot_h_j) = self.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);

        // A zero vector stands in for the mixed `deta` when there is none.
        let zeros_n = Array1::<f64>::zeros(self.x_dense.nrows());
        let deta_ij_ref: &Array1<f64> = deta_ij.as_ref().unwrap_or(&zeros_n);
        let dw_i = &self.w1 * &deta_i;
        let dw_j = &self.w1 * &deta_j;
        let ddw_ij = &(&self.w2 * &(&deta_i * &deta_j)) + &(&self.w1 * deta_ij_ref);

        // Second-order counterpart of `dot_i`: every pairing of
        // {base, i, j, ij} reduced designs with the matching weight vector.
        let x_r = &self.x_reduced;
        let mut i_ddot = Array2::<f64>::zeros(self.k_reduced.raw_dim());
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
        }
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, &x_tau_j_reduced, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, &x_tau_i_reduced, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, x_r, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_i_reduced, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, x_r, &dw_i);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_j_reduced, &dw_i);
        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);

        // a_lik = 0.5 * trace_product(K, i_ddot)
        //       - 0.5 * trace_product(K·dot_i_j, K·dot_i_i).
        let k = &self.k_reduced;
        let k_dot_i_i = k.dot(&dot_i_i);
        let k_dot_i_j = k.dot(&dot_i_j);
        let a_lik = 0.5 * RemlState::trace_product(k, &i_ddot)
            - 0.5 * RemlState::trace_product(&k_dot_i_j, &k_dot_i_i);

        // Unweighted cross-product variations of the reduced design.
        let dot_s_i = fast_atb(&x_tau_i_reduced, x_r) + fast_atb(x_r, &x_tau_i_reduced);
        let dot_s_j = fast_atb(&x_tau_j_reduced, x_r) + fast_atb(x_r, &x_tau_j_reduced);
        let mut s_ddot = Array2::<f64>::zeros(k.raw_dim());
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            s_ddot = s_ddot + fast_atb(x_rij, x_r) + fast_atb(x_r, x_rij);
        }
        s_ddot = s_ddot
            + fast_atb(&x_tau_i_reduced, &x_tau_j_reduced)
            + fast_atb(&x_tau_j_reduced, &x_tau_i_reduced);
        // a_pen = 0.5 * Σ_{k,l} g_k g_l dot_s_j[k,l] dot_s_i[k,l]
        //       - 0.5 * Σ_k g_k s_ddot[k,k],
        // with g = x_metric_reduced_inv_diag.
        let g_inv = &self.x_metric_reduced_inv_diag;
        let rdim = k.nrows();
        let mut a_pen = 0.0_f64;
        for kk in 0..rdim {
            for ll in 0..rdim {
                a_pen += 0.5 * g_inv[kk] * g_inv[ll] * dot_s_j[[kk, ll]] * dot_s_i[[kk, ll]];
            }
            a_pen -= 0.5 * g_inv[kk] * s_ddot[[kk, kk]];
        }
        let phi_tau_tau_partial = a_lik + a_pen;

        // First-order kernel variations dot_k_* = -K·dot_i_*·K, and the
        // second-order variation expressed through a_* = -dot_k_*.
        let dot_k_i = -k.dot(&dot_i_i).dot(k);
        let dot_k_j = -k.dot(&dot_i_j).dot(k);
        let a_i_red = -&dot_k_i;
        let a_j_red = -&dot_k_j;
        let k_ddot: Array2<f64> =
            -k.dot(&i_ddot).dot(k) + a_i_red.dot(&dot_i_j).dot(k) + a_j_red.dot(&dot_i_i).dot(k);

        // Second-order counterpart of `dot_h`, assembled from row-wise dots.
        let n = self.x_dense.nrows();
        let mut dh_ij = Array1::<f64>::zeros(n);
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            let rij_k = x_rij.dot(k);
            dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rij_k, x_r);
        }
        let xr_kddot = x_r.dot(&k_ddot);
        dh_ij = dh_ij + Self::rowwise_dot(&xr_kddot, x_r);
        let ri_kdot_j = x_tau_i_reduced.dot(&dot_k_j);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
        let rj_kdot_i = x_tau_j_reduced.dot(&dot_k_i);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
        let ri_k = x_tau_i_reduced.dot(k);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_k, &x_tau_j_reduced);

        // gphi_tau_tau accumulates four groups, each scaled by 0.5:
        //   x_ijᵀ (w1 ∘ h)                      (only with a mixed derivative)
        //   x_tau_iᵀ (w2 ∘ deta_j ∘ h + w1 ∘ dot_h_j)
        //   x_tau_jᵀ (w2 ∘ deta_i ∘ h + w1 ∘ dot_h_i)
        //   Xᵀ v_dot_ij
        let w1_h = &self.w1 * &self.h_diag;
        let mut gphi_tau_tau = Array1::<f64>::zeros(self.x_dense.ncols());
        if let Some(x_ij) = x_tau_tau.as_ref() {
            gphi_tau_tau = gphi_tau_tau + 0.5 * x_ij.t().dot(&w1_h);
        }
        let inner_j = &(&(&self.w2 * &deta_j) * &self.h_diag) + &(&self.w1 * &dot_h_j);
        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_i.t().dot(&inner_j);

        let v_tau_i = &(&(&self.w2 * &deta_i) * &self.h_diag) + &(&self.w1 * &dot_h_i);
        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_j.t().dot(&v_tau_i);

        // Product-rule expansion of the variation of (w2 ∘ deta ∘ h + w1 ∘ dot_h).
        let mut v_dot_ij = &(&(&self.w3 * &deta_j) * &deta_i) * &self.h_diag;
        v_dot_ij += &(&(&self.w2 * deta_ij_ref) * &self.h_diag);
        v_dot_ij += &(&(&self.w2 * &deta_i) * &dot_h_j);
        v_dot_ij += &(&(&self.w2 * &deta_j) * &dot_h_i);
        v_dot_ij += &(&self.w1 * &dh_ij);
        gphi_tau_tau = gphi_tau_tau + 0.5 * self.x_dense.t().dot(&v_dot_ij);

        // The partials are moved into the kernel only when the caller asks for it.
        let tau_tau_kernel = if include_hphi_tau_tau_kernel {
            Some(self.hphi_tau_tau_partial_prepare_from_partials(
                x_tau_i_reduced,
                x_tau_j_reduced,
                &deta_i,
                &deta_j,
                dot_h_i,
                dot_h_j,
                dot_i_i,
                dot_i_j,
                x_tau_tau_reduced,
                deta_ij,
            ))
        } else {
            None
        };

        FirthTauTauExactKernel {
            phi_tau_tau_partial,
            gphi_tau_tau,
            tau_tau_kernel,
        }
    }
2499
2500 pub(crate) fn apply_mtau_from_reduced(
2507 &self,
2508 x_tau_reduced: &Array2<f64>,
2509 dot_k_reduced: &Array2<f64>,
2510 mat: &Array2<f64>,
2511 ) -> Array2<f64> {
2512 if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
2513 return Array2::<f64>::zeros(mat.raw_dim());
2514 }
2515 let mut out = Array2::<f64>::zeros(mat.raw_dim());
2516 for col in 0..mat.ncols() {
2517 let v = mat.column(col).to_owned();
2518 let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
2519 let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
2520 let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, x_tau_reduced);
2521
2522 let szt = RemlState::reduced_crossweighted_gram(&self.x_reduced, x_tau_reduced, &v);
2523 let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
2524 let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
2525
2526 let t3 =
2527 RemlState::apply_hadamard_gram(&self.x_reduced, &self.k_reduced, dot_k_reduced, &v);
2528
2529 let y = 2.0 * (t1 + t2 + t3);
2530 out.column_mut(col).assign(&y);
2531 }
2532 out
2533 }
2534
2535 pub(crate) fn apply_p_ddot_ij(
2546 &self,
2547 x_r: &Array2<f64>,
2548 x_ri: &Array2<f64>,
2549 x_rj: &Array2<f64>,
2550 x_rij: &Array2<f64>,
2551 k: &Array2<f64>,
2552 dot_k_i: &Array2<f64>,
2553 dot_k_j: &Array2<f64>,
2554 k_ddot: &Array2<f64>,
2555 x_tau_tau_is_some: bool,
2556 mat: &Array2<f64>,
2557 ) -> Array2<f64> {
2558 let n = self.x_dense.nrows();
2559 let m = mat.ncols();
2560 if mat.nrows() != n || m == 0 {
2561 return Array2::<f64>::zeros(mat.raw_dim());
2562 }
2563 let mut out = Array2::<f64>::zeros((n, m));
2564 for col in 0..m {
2565 let v = mat.column(col).to_owned();
2566 let s_zz = RemlState::reducedweighted_gram(x_r, &v); let s_zj = RemlState::reduced_crossweighted_gram(x_r, x_rj, &v); let s_iz = RemlState::reduced_crossweighted_gram(x_ri, x_r, &v); let s_jz = RemlState::reduced_crossweighted_gram(x_rj, x_r, &v); let s_ij = RemlState::reduced_crossweighted_gram(x_ri, x_rj, &v); let mut mdot_mdot = Array1::<f64>::zeros(n);
2586 {
2588 let core = k.dot(&s_zz).dot(&k.t());
2589 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_rj);
2590 }
2591 {
2593 let core = k.dot(&s_zz).dot(&dot_k_j.t());
2594 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2595 }
2596 {
2598 let core = k.dot(&s_zj).dot(&k.t());
2599 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2600 }
2601 {
2603 let core = dot_k_i.dot(&s_zz).dot(&k.t());
2604 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2605 }
2606 {
2608 let core = dot_k_i.dot(&s_zz).dot(&dot_k_j.t());
2609 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2610 }
2611 {
2613 let core = dot_k_i.dot(&s_zj).dot(&k.t());
2614 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2615 }
2616 {
2619 let core = k.dot(&s_iz).dot(&k.t());
2620 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2621 }
2622 {
2625 let core = k.dot(&s_iz).dot(&dot_k_j.t());
2626 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2627 }
2628 {
2631 let core = k.dot(&s_ij).dot(&k.t());
2632 mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2633 }
2634
2635 let mut m_mddot = Array1::<f64>::zeros(n);
2640 if x_tau_tau_is_some {
2642 let core = k.dot(&s_zz).dot(k);
2643 m_mddot = m_mddot + Self::rowwise_bilinear(x_rij, &core, x_r);
2644 }
2645 if x_tau_tau_is_some {
2647 let s_ijz = RemlState::reduced_crossweighted_gram(x_rij, x_r, &v);
2648 let core = k.dot(&s_ijz).dot(k);
2649 m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2650 }
2651 {
2653 let core = dot_k_j.dot(&s_zz).dot(k);
2654 m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2655 }
2656 {
2658 let core = dot_k_j.dot(&s_iz).dot(k);
2659 m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2660 }
2661 {
2663 let core = dot_k_i.dot(&s_zz).dot(k);
2664 m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2665 }
2666 {
2668 let core = dot_k_i.dot(&s_jz).dot(k);
2669 m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2670 }
2671 {
2673 let core = k.dot(&s_jz).dot(k);
2674 m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2675 }
2676 {
2678 let core = k.dot(&s_iz).dot(k);
2679 m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2680 }
2681 {
2683 let core = k_ddot.dot(&s_zz).dot(k);
2684 m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2685 }
2686
2687 let col_out = 2.0 * mdot_mdot + 2.0 * m_mddot;
2692 out.column_mut(col).assign(&col_out);
2693 }
2694 out
2695 }
2696
2697 pub(crate) fn d_beta_hphi_tau_partial_prepare_from_partials(
2708 &self,
2709 tau_kernel: &FirthTauPartialKernel,
2710 deta_partial: &Array1<f64>,
2711 dot_i_partial: &Array2<f64>,
2712 beta_direction: &FirthDirection,
2713 x_tau_v: &Array1<f64>,
2714 ) -> FirthTauBetaPartialKernel {
2715 let s_v = &self.w1 * &beta_direction.deta;
2727 let mixed_diag_weight = &(&tau_kernel.dotw1 * &beta_direction.deta) + &(&self.w1 * x_tau_v);
2728 let cross1 =
2729 RemlState::reduced_crossweighted_gram(&tau_kernel.x_tau_reduced, &self.x_reduced, &s_v);
2730 let cross2 =
2731 RemlState::reduced_crossweighted_gram(&self.x_reduced, &tau_kernel.x_tau_reduced, &s_v);
2732 let diag_piece = RemlState::reducedweighted_gram(&self.x_reduced, &mixed_diag_weight);
2733 let d_beta_dot_i = &cross1 + &cross2 + &diag_piece;
2734
2735 let term_a = beta_direction
2744 .a_u_reduced
2745 .dot(dot_i_partial)
2746 .dot(&self.k_reduced);
2747 let term_b = self.k_reduced.dot(&d_beta_dot_i).dot(&self.k_reduced);
2748 let term_c = self
2749 .k_reduced
2750 .dot(dot_i_partial)
2751 .dot(&beta_direction.a_u_reduced);
2752 let d_beta_dot_k = &term_a - &term_b + &term_c;
2753
2754 let cross_diag = Self::rowwise_bilinear(
2760 &tau_kernel.x_tau_reduced,
2761 &beta_direction.a_u_reduced,
2762 &self.x_reduced,
2763 );
2764 let inner_diag = RemlState::reduced_diag_gram(&self.x_reduced, &d_beta_dot_k);
2765 let d_beta_dot_h = -2.0 * &cross_diag + &inner_diag;
2766
2767 FirthTauBetaPartialKernel {
2768 x_tau_reduced: tau_kernel.x_tau_reduced.clone(),
2769 deta_partial: deta_partial.clone(),
2770 dot_h_partial: tau_kernel.dot_h_partial.clone(),
2771 dot_i_partial: dot_i_partial.clone(),
2772 dot_k_reduced: tau_kernel.dot_k_reduced.clone(),
2773 deta_v: beta_direction.deta.clone(),
2774 deta_tau_v: x_tau_v.clone(),
2775 a_v_reduced: beta_direction.a_u_reduced.clone(),
2776 dh_v: beta_direction.dh.clone(),
2777 b_vvec: beta_direction.b_uvec.clone(),
2778 d_beta_dot_k,
2779 d_beta_dot_h,
2780 }
2781 }
2782
2783 pub(crate) fn apply_p_tau_v_to_matrix(
2794 &self,
2795 kernel: &FirthTauBetaPartialKernel,
2796 mat: &Array2<f64>,
2797 ) -> Array2<f64> {
2798 let n = self.x_dense.nrows();
2799 if mat.nrows() != n || mat.ncols() == 0 {
2800 return Array2::<f64>::zeros(mat.raw_dim());
2801 }
2802 let z = &self.x_reduced;
2803 let z_tau = &kernel.x_tau_reduced;
2804 let k_r = &self.k_reduced;
2805 let a_v = &kernel.a_v_reduced; let dot_k_tau = &kernel.dot_k_reduced; let d_beta_dot_k = &kernel.d_beta_dot_k; let mut out = Array2::<f64>::zeros(mat.raw_dim());
2809 for col in 0..mat.ncols() {
2810 let v = mat.column(col).to_owned();
2811 let s_zz = RemlState::reducedweighted_gram(z, &v);
2812 let s_z_ztau = RemlState::reduced_crossweighted_gram(z, z_tau, &v);
2813
2814 let mid_1 = a_v.dot(&s_zz).dot(k_r);
2817 let t1 = -Self::rowwise_bilinear(z, &mid_1, z_tau);
2818 let mid_2 = a_v.dot(&s_z_ztau).dot(k_r);
2821 let t2 = -RemlState::reduced_diag_gram(z, &mid_2);
2822 let mid_3 = a_v.dot(&s_zz).dot(dot_k_tau);
2825 let t3 = -RemlState::reduced_diag_gram(z, &mid_3);
2826 let mid_4 = k_r.dot(&s_zz).dot(a_v);
2829 let t4 = -Self::rowwise_bilinear(z, &mid_4, z_tau);
2830 let mid_5 = k_r.dot(&s_z_ztau).dot(a_v);
2833 let t5 = -RemlState::reduced_diag_gram(z, &mid_5);
2834 let t6 = RemlState::apply_hadamard_gram(z, k_r, d_beta_dot_k, &v);
2836
2837 let y = 2.0 * (t1 + t2 + t3 + t4 + t5 + t6);
2840 out.column_mut(col).assign(&y);
2841 }
2842 out
2843 }
2844
2845 pub(crate) fn d_beta_hphi_tau_partial_apply(
2846 &self,
2847 x_tau: &Array2<f64>,
2848 kernel: &FirthTauBetaPartialKernel,
2849 rhs: &Array2<f64>,
2850 ) -> Array2<f64> {
2851 let p = self.x_dense.ncols();
2852 if rhs.nrows() != p {
2853 return Array2::<f64>::zeros((p, rhs.ncols()));
2854 }
2855 if rhs.ncols() == 0 || p == 0 {
2856 return Array2::<f64>::zeros((p, rhs.ncols()));
2857 }
2858 let etav = fast_ab(&self.x_dense, rhs);
2867 let etav_tau = fast_ab(x_tau, rhs);
2868 let deta_v = &kernel.deta_v;
2869 let deta_tau_v = &kernel.deta_tau_v;
2870 let eta_tau = &kernel.deta_partial;
2871 let dot_h = &kernel.dot_h_partial;
2872
2873 let dotw1 = &self.w2 * eta_tau;
2875 let dotw2 = &self.w3 * eta_tau;
2876
2877 let c_v = &(&(&self.w3 * deta_v) * &self.h_diag) + &(&self.w2 * &kernel.dh_v);
2883 let b_vvec = &kernel.b_vvec;
2884 let d_beta_dotw1_vec = &(&(&self.w3 * deta_v) * eta_tau) + &(&self.w2 * deta_tau_v);
2885 let d_beta_dotw2_vec = &(&(&self.w4 * deta_v) * eta_tau) + &(&self.w3 * deta_tau_v);
2886
2887 let qv = &etav * &self.w1.view().insert_axis(Axis(1));
2889 let qv_tau = &etav * &dotw1.view().insert_axis(Axis(1))
2890 + &etav_tau * &self.w1.view().insert_axis(Axis(1));
2891 let m_qv = self.apply_pbar_to_matrix(&qv);
2892 let tau_kernel_view = FirthTauPartialKernel {
2895 deta_partial: eta_tau.clone(),
2896 dotw1: dotw1.clone(),
2897 dotw2: dotw2.clone(),
2898 dot_h_partial: dot_h.clone(),
2899 x_tau_reduced: kernel.x_tau_reduced.clone(),
2900 dot_i_partial: kernel.dot_i_partial.clone(),
2901 dot_k_reduced: kernel.dot_k_reduced.clone(),
2902 };
2903 let m_qv_tau =
2904 self.apply_mtau_to_matrix(&tau_kernel_view, &qv) + self.apply_pbar_to_matrix(&qv_tau);
2905
2906 let d_beta_qv = &etav * &b_vvec.view().insert_axis(Axis(1));
2910 let d_beta_qv_tau = &etav * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2911 + &etav_tau * &b_vvec.view().insert_axis(Axis(1));
2912
2913 let d_beta_m_qv = self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv)
2915 + self.apply_pbar_to_matrix(&d_beta_qv);
2916
2917 let d_beta_m_qv_tau = self.apply_p_tau_v_to_matrix(kernel, &qv)
2919 + self.apply_mtau_to_matrix(&tau_kernel_view, &d_beta_qv)
2920 + self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv_tau)
2921 + self.apply_pbar_to_matrix(&d_beta_qv_tau);
2922
2923 let d_beta_rv = &etav * &c_v.view().insert_axis(Axis(1))
2926 - &m_qv * &b_vvec.view().insert_axis(Axis(1))
2927 - &d_beta_m_qv * &self.w1.view().insert_axis(Axis(1));
2928
2929 let d_beta_dotw2_h = &(&d_beta_dotw2_vec * &self.h_diag) + &(&dotw2 * &kernel.dh_v);
2940 let d_beta_w2_doth = &(&(&self.w3 * deta_v) * dot_h) + &(&self.w2 * &kernel.d_beta_dot_h);
2941
2942 let d_beta_rv_tau = &etav * &d_beta_dotw2_h.view().insert_axis(Axis(1))
2943 + &etav_tau * &c_v.view().insert_axis(Axis(1))
2944 + &etav * &d_beta_w2_doth.view().insert_axis(Axis(1))
2945 - &d_beta_m_qv * &dotw1.view().insert_axis(Axis(1))
2946 - &m_qv * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2947 - &d_beta_m_qv_tau * &self.w1.view().insert_axis(Axis(1))
2948 - &m_qv_tau * &b_vvec.view().insert_axis(Axis(1));
2949
2950 0.5 * (x_tau.t().dot(&d_beta_rv) + self.x_dense.t().dot(&d_beta_rv_tau))
2951 }
2952}
2953
2954#[cfg(test)]
2955mod tests {
2956 use super::*;
2957 use crate::mixture_link::logit_inverse_link_jet5;
2958 use gam_problem::StandardLink;
2959 use ndarray::{Array1, Array2, array};
2960
2961 impl FirthDenseOperator {
2967 pub(crate) fn pirls_hat_diag(&self) -> Array1<f64> {
2968 &self.w * &self.h_diag
2969 }
2970
2971 pub(crate) fn pirls_firth_score_shift(&self) -> Array1<f64> {
2975 let mut shift = Array1::<f64>::zeros(self.w.len());
2976 for i in 0..self.w.len() {
2977 let wi = self.w[i];
2978 if wi > 0.0 {
2979 shift[i] = 0.5 * (self.w1[i] / wi) * self.h_diag[i];
2980 }
2981 }
2982 shift
2983 }
2984 }
2985
2986 pub(crate) fn build_logit_firth_dense_operator(
2987 x_dense: &Array2<f64>,
2988 eta: &Array1<f64>,
2989 ) -> Result<FirthDenseOperator, EstimationError> {
2990 FirthDenseOperator::build_with_observation_weights_impl(
2991 &InverseLink::Standard(StandardLink::Logit),
2992 x_dense,
2993 eta,
2994 None,
2995 )
2996 }
2997
2998 pub(crate) fn build_weighted_logit_firth_dense_operator(
2999 x_dense: &Array2<f64>,
3000 eta: &Array1<f64>,
3001 observation_weights: ndarray::ArrayView1<'_, f64>,
3002 ) -> Result<FirthDenseOperator, EstimationError> {
3003 FirthDenseOperator::build_with_observation_weights_impl(
3004 &InverseLink::Standard(StandardLink::Logit),
3005 x_dense,
3006 eta,
3007 Some(observation_weights),
3008 )
3009 }
3010
3011 pub(crate) fn logisticweight(eta: f64) -> f64 {
3012 logit_inverse_link_jet5(eta).d1
3013 }
3014
3015 pub(crate) fn firthphivalue(x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3016 let eta = x.dot(beta);
3017 let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3018 op.jeffreys_logdet()
3019 }
3020
3021 pub(crate) fn firthgradphi(x: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
3022 let eta = x.dot(beta);
3023 let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3024 op.jeffreys_beta_gradient()
3025 }
3026
3027 pub(crate) fn weighted_firthphivalue(
3028 x: &Array2<f64>,
3029 beta: &Array1<f64>,
3030 observation_weights: &Array1<f64>,
3031 ) -> f64 {
3032 let eta = x.dot(beta);
3033 let op = build_weighted_logit_firth_dense_operator(x, &eta, observation_weights.view())
3034 .expect("weighted firth operator");
3035 op.jeffreys_logdet()
3036 }
3037
/// A diagonal Fisher matrix with one entry of `1e-15` must yield a finite
/// half log-determinant equal to `0.5 * ln(16)` and a finite inverse whose
/// `[1, 1]` entry is (numerically) zero.
#[test]
pub(crate) fn firth_reduced_fisher_logdet_is_finite_for_barely_pd_matrix() {
    // Only the well-scaled direction (16.0) is expected to contribute.
    let expected = 0.5 * 16.0_f64.ln();
    let fisher = array![[16.0, 0.0], [0.0, 1e-15]];
    let (k_reduced, half_log_det) = RemlState::reduced_fisher_inverse_and_half_logdet(&fisher)
        .expect("barely positive-definite reduced fisher");

    assert!(
        half_log_det.is_finite(),
        "barely positive-definite reduced fisher produced non-finite half logdet: {half_log_det}"
    );

    let logdet_gap = (half_log_det - expected).abs();
    assert!(
        logdet_gap < 1e-12,
        "near-null Fisher direction should be excluded from pseudo-logdet: got {half_log_det}, expected {expected}"
    );

    let has_non_finite_entry = k_reduced.iter().any(|entry| !entry.is_finite());
    assert!(
        !has_non_finite_entry,
        "barely positive-definite reduced fisher produced non-finite inverse entries: {k_reduced:?}"
    );

    let near_null_entry = k_reduced[[1, 1]].abs();
    assert!(
        near_null_entry < f64::EPSILON,
        "near-null Fisher direction should be excluded from pseudo-inverse: {k_reduced:?}"
    );
}
3062
/// Compares the operator's stored `w`, `w1`..`w4` vectors against
/// `logisticweight` and its central finite-difference derivatives (step
/// `1e-2`), checking both sign agreement and absolute closeness.
#[test]
pub(crate) fn firth_logisticweight_derivatives_match_finite_difference() {
    let x = array![
        [1.0, -1.1, 0.2],
        [1.0, -0.5, -0.6],
        [1.0, 0.0, 0.3],
        [1.0, 0.8, -0.4],
        [1.0, 1.2, 0.7],
    ];
    let beta = array![0.15, -0.6, 0.35];
    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");

    let h = 1e-2_f64;
    let w = logisticweight;
    // Central finite-difference stencils for derivatives 1 through 4.
    let fd_first = |z: f64| (w(z + h) - w(z - h)) / (2.0 * h);
    let fd_second = |z: f64| (w(z + h) - 2.0 * w(z) + w(z - h)) / (h * h);
    let fd_third = |z: f64| {
        (-w(z - 2.0 * h) + 2.0 * w(z - h) - 2.0 * w(z + h) + w(z + 2.0 * h)) / (2.0 * h.powi(3))
    };
    let fd_fourth = |z: f64| {
        (w(z - 2.0 * h) - 4.0 * w(z - h) + 6.0 * w(z) - 4.0 * w(z + h) + w(z + 2.0 * h))
            / h.powi(4)
    };

    // Absolute tolerances for derivative orders 1..=4, in that order.
    let tolerances = [1e-5_f64, 1e-4, 1e-4, 1e-3];

    for (i, &z) in eta.iter().enumerate() {
        let stored = [op.w1[i], op.w2[i], op.w3[i], op.w4[i]];
        let numeric = [fd_first(z), fd_second(z), fd_third(z), fd_fourth(z)];

        assert!((op.w[i] - w(z)).abs() < 1e-12);
        for (analytic, fd) in stored.iter().zip(numeric.iter()) {
            assert_eq!(analytic.signum(), fd.signum());
        }
        for ((analytic, fd), tol) in stored.iter().zip(numeric.iter()).zip(tolerances.iter()) {
            assert!((analytic - fd).abs() < *tol);
        }
    }
}
3124
/// Checks each component of the weighted operator's `jeffreys_beta_gradient()`
/// against a central finite difference (step `1e-6`) of
/// `weighted_firthphivalue`.
#[test]
pub(crate) fn weighted_firth_jeffreys_gradient_matches_finite_difference() {
    let x = array![
        [1.0, -0.7, 0.3],
        [1.0, -0.2, -0.4],
        [1.0, 0.5, 0.1],
        [1.0, 1.1, -0.6],
        [1.0, 1.6, 0.8],
    ];
    let beta = array![0.2, -0.45, 0.25];
    let observation_weights = array![1.0, 0.5, 2.0, 1.5, 0.75];
    let h = 1e-6;

    let eta = x.dot(&beta);
    let grad = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
        .expect("weighted firth operator")
        .jeffreys_beta_gradient();

    for j in 0..beta.len() {
        let beta_plus = {
            let mut shifted = beta.clone();
            shifted[j] += h;
            shifted
        };
        let beta_minus = {
            let mut shifted = beta.clone();
            shifted[j] -= h;
            shifted
        };
        let phi_plus = weighted_firthphivalue(&x, &beta_plus, &observation_weights);
        let phi_minus = weighted_firthphivalue(&x, &beta_minus, &observation_weights);
        let fd = (phi_plus - phi_minus) / (2.0 * h);
        assert!(
            (grad[j] - fd).abs() < 1e-5,
            "weighted Firth gradient mismatch at {}: analytic={}, fd={}",
            j,
            grad[j],
            fd
        );
    }
}
3159
3160 pub(crate) fn build_link_firth_op(
3168 link: StandardLink,
3169 x: &Array2<f64>,
3170 beta: &Array1<f64>,
3171 ) -> FirthDenseOperator {
3172 let eta = x.dot(beta);
3173 FirthDenseOperator::build_with_observation_weights_impl(
3174 &InverseLink::Standard(link),
3175 x,
3176 &eta,
3177 None,
3178 )
3179 .expect("link-general firth operator")
3180 }
3181
3182 pub(crate) fn link_firth_phi(link: StandardLink, x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3183 build_link_firth_op(link, x, beta).jeffreys_logdet()
3184 }
3185
3186 pub(crate) fn link_firth_grad(
3187 link: StandardLink,
3188 x: &Array2<f64>,
3189 beta: &Array1<f64>,
3190 ) -> Array1<f64> {
3191 build_link_firth_op(link, x, beta).jeffreys_beta_gradient()
3192 }
3193
3194 pub(crate) fn numeric_firth_hessian(
3199 link: StandardLink,
3200 x: &Array2<f64>,
3201 beta: &Array1<f64>,
3202 h: f64,
3203 ) -> Array2<f64> {
3204 let p = beta.len();
3205 let mut hess = Array2::<f64>::zeros((p, p));
3206 for j in 0..p {
3207 let mut bp = beta.clone();
3208 bp[j] += h;
3209 let mut bm = beta.clone();
3210 bm[j] -= h;
3211 let gp = link_firth_grad(link, x, &bp);
3212 let gm = link_firth_grad(link, x, &bm);
3213 let col = (&gp - &gm) / (2.0 * h);
3214 hess.column_mut(j).assign(&col);
3215 }
3216 hess
3217 }
3218
/// For every ordered pair `(i, j)` with `j <= i` of three directions, the
/// eye-cached second-direction apply must reproduce the per-pair apply on the
/// identity matrix bit for bit.
#[test]
fn hphisecond_eye_cached_matches_per_pair_bit_identical_1575() {
    let x = array![
        [1.0, -1.10, 0.35],
        [1.0, -0.40, -0.65],
        [1.0, 0.15, 0.20],
        [1.0, 0.80, -0.45],
        [1.0, 1.25, 0.70],
        [1.0, -0.55, 0.95],
    ];
    let beta = array![0.20, -0.55, 0.30];
    let op = build_link_firth_op(StandardLink::Logit, &x, &beta);

    // Coefficient-space directions; each is mapped to eta-space via `x`.
    let coefficient_dirs: [Array1<f64>; 3] = [
        array![0.9, -0.3, 0.2],
        array![-0.4, 0.7, 0.1],
        array![0.1, 0.2, -0.8],
    ];
    let dirs: Vec<FirthDirection> = coefficient_dirs
        .iter()
        .map(|u| op.direction_from_deta(x.dot(u)))
        .collect();

    let cache = op.tk_second_direction_eye_cache(&dirs);
    let eye = Array2::<f64>::eye(x.ncols());

    for i in 0..dirs.len() {
        for j in 0..=i {
            let per_pair = op.hphisecond_direction_apply(&dirs[i], &dirs[j], &eye);
            let from_cache = op.hphisecond_direction_apply_eye_cached(&cache, &dirs, i, j);
            assert_eq!(
                per_pair.dim(),
                from_cache.dim(),
                "shape mismatch at pair ({i},{j})"
            );
            for (a, b) in per_pair.iter().zip(from_cache.iter()) {
                assert_eq!(
                    a.to_bits(),
                    b.to_bits(),
                    "cached D²H_φ[{i},{j}] is not bit-identical to per-pair: \
                     reference={a}, cached={b}"
                );
            }
        }
    }
}
3274
3275 pub(crate) fn fixed_design_5x3() -> Array2<f64> {
3277 array![
3278 [1.0, -1.10, 0.35],
3279 [1.0, -0.40, -0.65],
3280 [1.0, 0.15, 0.20],
3281 [1.0, 0.80, -0.45],
3282 [1.0, 1.25, 0.70],
3283 ]
3284 }
3285
/// Builds the logit operator through `build_logit_firth_dense_operator` and
/// directly through `build_with_observation_weights_impl`, then requires exact
/// equality of the log-determinant, the gradient, the PIRLS hat diagonal and
/// the stored `w`, `w1`..`w4` vectors.
#[test]
pub(crate) fn link_general_logit_path_reproduces_historical_logit_build() {
    let x = fixed_design_5x3();
    let beta = array![0.20, -0.55, 0.30];
    let eta = x.dot(&beta);

    let logit = InverseLink::Standard(StandardLink::Logit);
    let historical = build_logit_firth_dense_operator(&x, &eta).expect("historical logit");
    let link_general =
        FirthDenseOperator::build_with_observation_weights_impl(&logit, &x, &eta, None)
            .expect("link-general logit");

    assert_eq!(
        historical.jeffreys_logdet(),
        link_general.jeffreys_logdet(),
        "logit Φ must be bit-identical through the link-general path"
    );

    let g_hist = historical.jeffreys_beta_gradient();
    let g_link = link_general.jeffreys_beta_gradient();
    for (j, expected) in g_hist.iter().enumerate() {
        assert_eq!(
            *expected, g_link[j],
            "logit gradient component {j} must be bit-identical"
        );
    }

    let hat_hist = historical.pirls_hat_diag();
    let hat_link = link_general.pirls_hat_diag();
    for (i, expected) in hat_hist.iter().enumerate() {
        assert_eq!(
            *expected, hat_link[i],
            "logit PIRLS hat diagonal {i} must be bit-identical"
        );
    }

    // Row-by-row comparison of the weight vector and its four derivatives.
    let n_rows = eta.len();
    for row in 0..n_rows {
        assert_eq!(historical.w[row], link_general.w[row]);
        assert_eq!(historical.w1[row], link_general.w1[row]);
        assert_eq!(historical.w2[row], link_general.w2[row]);
        assert_eq!(historical.w3[row], link_general.w3[row]);
        assert_eq!(historical.w4[row], link_general.w4[row]);
    }
}
3334
/// Checks the probit `jeffreys_beta_gradient()` against a central finite
/// difference (step `1e-6`) of `link_firth_phi`, component by component, using
/// a relative error floored at `1e-8` in the denominator.
#[test]
pub(crate) fn link_general_probit_jeffreys_gradient_matches_finite_difference() {
    let x = fixed_design_5x3();
    let beta = array![0.10, -0.40, 0.25];
    let h = 1e-6_f64;
    let grad = link_firth_grad(StandardLink::Probit, &x, &beta);

    // Objective evaluated after shifting one coordinate of `beta`.
    let phi_shifted = |coordinate: usize, delta: f64| -> f64 {
        let mut shifted = beta.clone();
        shifted[coordinate] += delta;
        link_firth_phi(StandardLink::Probit, &x, &shifted)
    };

    let mut max_rel = 0.0_f64;
    for j in 0..beta.len() {
        let fd = (phi_shifted(j, h) - phi_shifted(j, -h)) / (2.0 * h);
        let denom = grad[j].abs().max(fd.abs()).max(1e-8);
        let rel = (grad[j] - fd).abs() / denom;
        max_rel = max_rel.max(rel);
        assert!(
            rel < 1e-6,
            "probit Firth gradient mismatch at {j}: analytic={}, fd={}, rel={:e}",
            grad[j],
            fd,
            rel
        );
    }
    assert!(
        max_rel < 1e-6,
        "probit gradient worst relative error {max_rel:e} exceeds 1e-6"
    );
}
3368
/// For four coefficient directions `u`, compares the probit `hphi_direction`
/// matrix against a central difference (step `1e-4`) of
/// `numeric_firth_hessian` along `u`, entry by entry, relative to the largest
/// magnitude found in either matrix (floored at `1e-6`).
#[test]
pub(crate) fn link_general_probit_hphi_direction_matches_finite_difference_of_hessian() {
    let x = fixed_design_5x3();
    let beta = array![0.10, -0.40, 0.25];
    let p = beta.len();

    let directions = [
        array![1.0, 0.0, 0.0],
        array![0.0, 1.0, 0.0],
        array![0.0, 0.0, 1.0],
        array![0.7, -0.5, 0.3],
    ];

    // Step used inside the numerical Hessian, and step along the direction.
    let h_inner = 1e-4_f64;
    let h_dir = 1e-4_f64;
    let mut worst = 0.0_f64;

    for u in directions.iter() {
        let op = build_link_firth_op(StandardLink::Probit, &x, &beta);
        let dir = op.direction_from_deta(x.dot(u));
        let analytic = op.hphi_direction(&dir);

        let step = u * h_dir;
        let beta_plus = &beta + &step;
        let beta_minus = &beta - &step;
        let hess_plus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_plus, h_inner);
        let hess_minus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_minus, h_inner);
        let fd = (&hess_plus - &hess_minus) / (2.0 * h_dir);

        let entries = || (0..p).flat_map(move |r| (0..p).map(move |c| (r, c)));

        // Common normalisation: largest absolute entry across both matrices.
        let scale = entries().fold(1e-6_f64, |acc, (r, c)| {
            acc.max(analytic[[r, c]].abs()).max(fd[[r, c]].abs())
        });

        for (r, c) in entries() {
            let rel = (analytic[[r, c]] - fd[[r, c]]).abs() / scale;
            worst = worst.max(rel);
            assert!(
                rel < 5e-3,
                "probit D H_φ[u] mismatch at ({r},{c}) for u={u:?}: analytic={}, fd={}, rel={:e}",
                analytic[[r, c]],
                fd[[r, c]],
                rel
            );
        }
    }
    assert!(
        worst < 5e-3,
        "probit Hessian-derivative worst relative error {worst:e} exceeds 5e-3"
    );
}
3430
/// Uses a 3-column design whose third column equals the sum of the first two,
/// together with the equivalent 2-column design, and requires the probit
/// objective and PIRLS hat diagonal to agree to `1e-12` and the full-design
/// gradient to be finite.
#[test]
pub(crate) fn link_general_probit_jeffreys_finite_on_rank_deficient_design() {
    let x_full = array![
        [1.0, -1.20, -0.20],
        [1.0, -0.40, 0.60],
        [1.0, 0.10, 1.10],
        [1.0, 0.70, 1.70],
        [1.0, 1.30, 2.30],
    ];
    let x_reduced = array![
        [1.0, -1.20],
        [1.0, -0.40],
        [1.0, 0.10],
        [1.0, 0.70],
        [1.0, 1.30],
    ];
    let beta_full = array![0.25, -0.50, 0.15];
    // Fold the third coefficient into the first two.
    let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
    let probit = StandardLink::Probit;

    let phi_full = link_firth_phi(probit, &x_full, &beta_full);
    let phi_reduced = link_firth_phi(probit, &x_reduced, &beta_reduced);
    assert!(
        phi_full.is_finite(),
        "probit Φ on rank-deficient design must be finite, got {phi_full}"
    );
    assert!(
        (phi_full - phi_reduced).abs() < 1e-12,
        "probit reduced |Uᵀ W U| form mismatch: full={phi_full}, reduced={phi_reduced}"
    );

    let op_full = build_link_firth_op(probit, &x_full, &beta_full);
    let grad_full = op_full.jeffreys_beta_gradient();
    assert!(
        !grad_full.iter().any(|v| !v.is_finite()),
        "probit gradient on rank-deficient design must be finite: {grad_full:?}"
    );

    let hat_full = op_full.pirls_hat_diag();
    let op_reduced = build_link_firth_op(probit, &x_reduced, &beta_reduced);
    let hat_reduced = op_reduced.pirls_hat_diag();
    for (i, full_value) in hat_full.iter().enumerate() {
        assert!(
            (full_value - hat_reduced[i]).abs() < 1e-12,
            "probit hat diagonal {i} mismatch on rank-deficient design: full={}, reduced={}",
            hat_full[i],
            hat_reduced[i]
        );
    }
}
3482
/// Weighted-logit counterpart of the rank-deficiency check: a 3-column design
/// whose third column is the sum of the first two and its 2-column equivalent
/// must give the same linear predictor, Jeffreys log-determinant and PIRLS hat
/// diagonal, each to `1e-12`.
#[test]
pub(crate) fn rank_deficient_and_explicit_reduced_designs_share_same_jeffreys_objective() {
    let x_full = array![
        [1.0, -1.2, -0.2],
        [1.0, -0.4, 0.6],
        [1.0, 0.1, 1.1],
        [1.0, 0.7, 1.7],
        [1.0, 1.3, 2.3],
    ];
    let x_reduced = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
    let beta_full: ndarray::Array1<f64> = array![0.25, -0.5, 0.15];
    let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
    let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];

    // Sanity check: both parameterisations give the same linear predictor.
    let eta_full = x_full.dot(&beta_full);
    let eta_reduced = x_reduced.dot(&beta_reduced);
    for (i, full_value) in eta_full.iter().enumerate() {
        assert!(
            (full_value - eta_reduced[i]).abs() < 1e-12,
            "eta mismatch at row {i}: full={} reduced={}",
            eta_full[i],
            eta_reduced[i]
        );
    }

    let weights = observation_weights.view();
    let op_full = build_weighted_logit_firth_dense_operator(&x_full, &eta_full, weights)
        .expect("full firth operator");
    let op_reduced = build_weighted_logit_firth_dense_operator(&x_reduced, &eta_reduced, weights)
        .expect("reduced firth operator");

    let logdet_gap = (op_full.jeffreys_logdet() - op_reduced.jeffreys_logdet()).abs();
    assert!(
        logdet_gap < 1e-12,
        "Jeffreys logdet mismatch between rank-deficient full design and its explicit reduced identifiable basis: full={} reduced={}",
        op_full.jeffreys_logdet(),
        op_reduced.jeffreys_logdet()
    );

    let hat_full = op_full.pirls_hat_diag();
    let hat_reduced = op_reduced.pirls_hat_diag();
    for (i, full_value) in hat_full.iter().enumerate() {
        assert!(
            (full_value - hat_reduced[i]).abs() < 1e-12,
            "PIRLS hat-diagonal mismatch at row {i}: full={} reduced={}",
            hat_full[i],
            hat_reduced[i]
        );
    }
}
3542
/// Right-multiplies a full-rank design by an invertible 2×2 basis and maps the
/// coefficients through its inverse; the linear predictor, Jeffreys
/// log-determinant and PIRLS hat diagonal must all agree to `1e-12`.
#[test]
pub(crate) fn full_rank_reparameterizations_share_same_jeffreys_objective() {
    let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
    let beta = array![0.25, -0.5];
    let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];

    let basis = array![[1.4, -0.3], [0.6, 1.1]];
    let basis_det: f64 = basis[[0, 0]] * basis[[1, 1]] - basis[[0, 1]] * basis[[1, 0]];
    assert!(
        basis_det.abs() > 1e-12,
        "basis transform must be invertible"
    );
    // Closed-form inverse of the 2×2 basis (adjugate over determinant).
    let basis_inv = array![
        [basis[[1, 1]] / basis_det, -basis[[0, 1]] / basis_det],
        [-basis[[1, 0]] / basis_det, basis[[0, 0]] / basis_det],
    ];

    let x_reparameterized = x.dot(&basis);
    let beta_reparameterized = basis_inv.dot(&beta);
    let eta = x.dot(&beta);
    let eta_reparameterized = x_reparameterized.dot(&beta_reparameterized);

    for (i, original) in eta.iter().enumerate() {
        assert!(
            (original - eta_reparameterized[i]).abs() < 1e-12,
            "eta mismatch at row {i}: original={} reparameterized={}",
            eta[i],
            eta_reparameterized[i]
        );
    }

    let weights = observation_weights.view();
    let op = build_weighted_logit_firth_dense_operator(&x, &eta, weights)
        .expect("original firth operator");
    let op_reparameterized = build_weighted_logit_firth_dense_operator(
        &x_reparameterized,
        &eta_reparameterized,
        weights,
    )
    .expect("reparameterized firth operator");

    let logdet_gap = (op.jeffreys_logdet() - op_reparameterized.jeffreys_logdet()).abs();
    assert!(
        logdet_gap < 1e-12,
        "Jeffreys logdet mismatch under invertible reparameterization: original={} reparameterized={}",
        op.jeffreys_logdet(),
        op_reparameterized.jeffreys_logdet()
    );

    let hat = op.pirls_hat_diag();
    let hat_reparameterized = op_reparameterized.pirls_hat_diag();
    for (i, original) in hat.iter().enumerate() {
        assert!(
            (original - hat_reparameterized[i]).abs() < 1e-12,
            "PIRLS hat-diagonal mismatch at row {i}: original={} reparameterized={}",
            hat[i],
            hat_reparameterized[i]
        );
    }
}
3599
/// Forms `fast_atb(x_reduced, x_reduced)` from the operator's reduced design
/// and requires every off-diagonal entry to be below `1e-10` in magnitude.
#[test]
pub(crate) fn full_rank_identifiable_basis_diagonalizes_design_metric() {
    let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
    let beta = array![0.25, -0.5];
    let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
    let eta = x.dot(&beta);
    let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
        .expect("firth operator");

    let reduced_metric = fast_atb(&op.x_reduced, &op.x_reduced);
    let (n_rows, n_cols) = (reduced_metric.nrows(), reduced_metric.ncols());
    for i in 0..n_rows {
        for j in 0..n_cols {
            // Diagonal entries are unconstrained; only off-diagonals are checked.
            if i != j {
                assert!(
                    reduced_metric[[i, j]].abs() < 1e-10,
                    "full-rank identifiable basis should diagonalize X_r'X_r: metric[{i},{j}]={}",
                    reduced_metric[[i, j]]
                );
            }
        }
    }
}
3623
/// Applies `hphisecond_direction_apply` to the identity with the two
/// directions in both orders and requires matching signs and an absolute
/// difference below `2e-7` for every entry.
#[test]
pub(crate) fn firth_mixedsecond_direction_apply_is_symmetric_in_direction_order() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let beta = array![0.1, -0.25, 0.2];
    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");

    let u = array![0.3, -0.2, 0.4];
    let v = array![-0.5, 0.1, 0.25];
    let du = op.direction_from_deta(x.dot(&u));
    let dv = op.direction_from_deta(x.dot(&v));

    let eye = Array2::<f64>::eye(x.ncols());
    let uv = op.hphisecond_direction_apply(&du, &dv, &eye);
    let vu = op.hphisecond_direction_apply(&dv, &du, &eye);

    let (n_rows, n_cols) = (uv.nrows(), uv.ncols());
    for i in 0..n_rows {
        for j in 0..n_cols {
            let (a, b) = (uv[[i, j]], vu[[i, j]]);
            assert_eq!(
                a.signum(),
                b.signum(),
                "mixed direction sign mismatch at ({i},{j}): uv={a} vu={b}"
            );
            assert!(
                (a - b).abs() < 2e-7,
                "mixed direction mismatch at ({i},{j}): uv={a} vu={b}"
            );
        }
    }
}
3663
/// Symmetrizes `hphi_direction_apply(dir, I)` by averaging mirrored entries
/// and requires it to match `hphi_direction(dir)` with a Frobenius-norm error
/// below `1e-10`.
#[test]
pub(crate) fn firth_direction_matrix_form_matches_apply_identity_form() {
    let x = array![
        [1.0, -1.1, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let beta = array![0.08, -0.22, 0.27];
    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
    let u = Array1::from_vec(vec![0.25, -0.4, 0.35]);
    let dir = op.direction_from_deta(x.dot(&u));

    let p = x.ncols();
    let identity = Array2::<f64>::eye(p);
    let mut via_apply = op.hphi_direction_apply(&dir, &identity);
    // Average each strictly-lower entry with its mirror image.
    for (i, j) in (0..p).flat_map(|i| (0..i).map(move |j| (i, j))) {
        let averaged = 0.5 * (via_apply[[i, j]] + via_apply[[j, i]]);
        via_apply[[i, j]] = averaged;
        via_apply[[j, i]] = averaged;
    }

    let direct = op.hphi_direction(&dir);
    let residual = &direct - &via_apply;
    let err = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(err < 1e-10, "direction/apply mismatch: {err:e}");
}
3695
/// Compares `exact_tau_kernel(..).phi_tau_partial` against a central finite
/// difference (step `1e-6`) of `firthphivalue` under the design perturbation
/// `x ± h · x_tau`, with `beta` held fixed.
#[test]
pub(crate) fn firthphi_tau_partial_matches_finite_difference_logdet() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let x_tau = array![
        [0.0, 0.15, -0.05],
        [0.0, -0.10, 0.02],
        [0.0, 0.08, 0.04],
        [0.0, -0.06, -0.03],
        [0.0, 0.05, 0.01],
        [0.0, -0.12, 0.06],
    ];
    let beta = array![0.1, -0.25, 0.2];

    let eta = x.dot(&beta);
    let analytic = build_logit_firth_dense_operator(&x, &eta)
        .expect("firth operator")
        .exact_tau_kernel(&x_tau, &beta, false)
        .phi_tau_partial;

    let h = 1e-6;
    let design_step = h * &x_tau;
    let x_plus = &x + &design_step;
    let x_minus = &x - &design_step;
    let phi_plus = firthphivalue(&x_plus, &beta);
    let phi_minus = firthphivalue(&x_minus, &beta);
    let fd = (phi_plus - phi_minus) / (2.0 * h);

    assert!(
        (analytic - fd).abs() < 1e-6,
        "Phi_tau mismatch: analytic={analytic:.12e}, fd={fd:.12e}"
    );
}
3729
/// Compares `exact_tau_kernel(..).gphi_tau` against a central finite
/// difference (step `1e-6`) of `firthgradphi` under the design perturbation
/// `x ± h · x_tau`, using the Euclidean norm of the difference.
#[test]
pub(crate) fn firth_gphi_tau_matches_finite_differencegradphi() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let x_tau = array![
        [0.0, 0.15, -0.05],
        [0.0, -0.10, 0.02],
        [0.0, 0.08, 0.04],
        [0.0, -0.06, -0.03],
        [0.0, 0.05, 0.01],
        [0.0, -0.12, 0.06],
    ];
    let beta = array![0.1, -0.25, 0.2];

    let eta = x.dot(&beta);
    let analytic = build_logit_firth_dense_operator(&x, &eta)
        .expect("firth operator")
        .exact_tau_kernel(&x_tau, &beta, false)
        .gphi_tau;

    let h = 1e-6;
    let design_step = h * &x_tau;
    let x_plus = &x + &design_step;
    let x_minus = &x - &design_step;
    let grad_plus = firthgradphi(&x_plus, &beta);
    let grad_minus = firthgradphi(&x_minus, &beta);
    let fd = (grad_plus - grad_minus) / (2.0 * h);

    let residual = &analytic - &fd;
    let err = residual.iter().map(|v| v * v).sum::<f64>().sqrt();
    assert!(
        err < 1e-6,
        "gphi_tau mismatch: analytic={analytic:?}, fd={fd:?}, err={err:e}"
    );
}
3764
/// Compares `exact_tau_tau_kernel(..).phi_tau_tau_partial` against a central
/// finite difference (step `1e-5`) of the single-direction `phi_tau_partial`
/// for `x_tau_i`, taken along the design perturbation `x ± h · x_tau_j`.
#[test]
pub(crate) fn firthphi_tau_tau_pair_scalar_matches_finite_difference() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let x_tau_i = array![
        [0.0, 0.15, -0.05],
        [0.0, -0.10, 0.02],
        [0.0, 0.08, 0.04],
        [0.0, -0.06, -0.03],
        [0.0, 0.05, 0.01],
        [0.0, -0.12, 0.06],
    ];
    let x_tau_j = array![
        [0.0, -0.04, 0.11],
        [0.0, 0.09, -0.02],
        [0.0, -0.06, 0.07],
        [0.0, 0.10, -0.05],
        [0.0, -0.03, 0.08],
        [0.0, 0.07, -0.09],
    ];
    let beta = array![0.1, -0.25, 0.2];
    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");

    let pair_kernel = op.exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false);
    let analytic = pair_kernel.phi_tau_tau_partial;

    // First-order quantity for direction i, re-evaluated at a perturbed design.
    let phi_tau_i_at = |design: &Array2<f64>| -> f64 {
        let eta_at = design.dot(&beta);
        let op_at = build_logit_firth_dense_operator(design, &eta_at).expect("perturbed op");
        let single_kernel = op_at.exact_tau_kernel(&x_tau_i, &beta, false);
        single_kernel.phi_tau_partial
    };

    let h = 1e-5_f64;
    let design_step = h * &x_tau_j;
    let x_plus = &x + &design_step;
    let x_minus = &x - &design_step;
    let fd = (phi_tau_i_at(&x_plus) - phi_tau_i_at(&x_minus)) / (2.0 * h);

    let rel = (analytic - fd).abs() / fd.abs().max(1.0);
    assert!(
        rel < 1e-7,
        "pair.a scalar mismatch: analytic={analytic:.6e}, fd={fd:.6e}, rel={rel:.3e}"
    );
}
3820
/// Compares `exact_tau_tau_kernel(..).gphi_tau_tau` against a central finite
/// difference (step `1e-5`) of the single-direction `gphi_tau` for `x_tau_i`,
/// taken along `x ± h · x_tau_j`. The error is the max-abs difference divided
/// by the largest magnitude in either vector (floored at `1.0`).
#[test]
pub(crate) fn firthphi_tau_tau_pair_g_vector_matches_finite_difference() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let x_tau_i = array![
        [0.0, 0.15, -0.05],
        [0.0, -0.10, 0.02],
        [0.0, 0.08, 0.04],
        [0.0, -0.06, -0.03],
        [0.0, 0.05, 0.01],
        [0.0, -0.12, 0.06],
    ];
    let x_tau_j = array![
        [0.0, -0.04, 0.11],
        [0.0, 0.09, -0.02],
        [0.0, -0.06, 0.07],
        [0.0, 0.10, -0.05],
        [0.0, -0.03, 0.08],
        [0.0, 0.07, -0.09],
    ];
    let beta = array![0.1, -0.25, 0.2];
    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");

    let pair_kernel = op.exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false);
    let analytic = pair_kernel.gphi_tau_tau;

    // First-order vector for direction i, re-evaluated at a perturbed design.
    let gphi_tau_i_at = |design: &Array2<f64>| -> Array1<f64> {
        let eta_at = design.dot(&beta);
        let op_at = build_logit_firth_dense_operator(design, &eta_at).expect("perturbed op");
        let single_kernel = op_at.exact_tau_kernel(&x_tau_i, &beta, false);
        single_kernel.gphi_tau
    };

    let h = 1e-5_f64;
    let design_step = h * &x_tau_j;
    let x_plus = &x + &design_step;
    let x_minus = &x - &design_step;
    let fd = (&gphi_tau_i_at(&x_plus) - &gphi_tau_i_at(&x_minus)) / (2.0 * h);

    let scale = analytic
        .iter()
        .chain(fd.iter())
        .fold(0.0_f64, |acc, v| acc.max(v.abs()))
        .max(1.0);
    let residual = &analytic - &fd;
    let err_max = residual.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
    let rel = err_max / scale;
    assert!(
        rel < 1e-7,
        "pair.g p-vector mismatch: rel={rel:.3e}\nanalytic={analytic:?}\nfd={fd:?}"
    );
}
3885
/// Exercises `hphi_tau_tau_partial_apply` on a fixed `p × 3` right-hand side:
///
/// 1. the (i, j) and (j, i) orderings must agree to a relative `1e-10`;
/// 2. each ordering must match a central finite difference (step `1e-5`) of
///    the single-direction `hphi_tau_partial_apply`, perturbing the design
///    along the *other* direction, to a relative `1e-7`.
///
/// Relative errors are max-abs differences divided by the largest magnitude in
/// either operand (floored at `1.0`).
#[test]
pub(crate) fn firthphi_tau_tau_partial_matches_finite_difference() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let x_tau_i = array![
        [0.0, 0.15, -0.05],
        [0.0, -0.10, 0.02],
        [0.0, 0.08, 0.04],
        [0.0, -0.06, -0.03],
        [0.0, 0.05, 0.01],
        [0.0, -0.12, 0.06],
    ];
    let x_tau_j = array![
        [0.0, -0.04, 0.11],
        [0.0, 0.09, -0.02],
        [0.0, -0.06, 0.07],
        [0.0, 0.10, -0.05],
        [0.0, -0.03, 0.08],
        [0.0, 0.07, -0.09],
    ];
    let beta = array![0.1, -0.25, 0.2];
    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
    let p = x.ncols();

    // Deterministic dense right-hand side, filled row-major from a fixed table.
    let m = 3usize;
    let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
    let rhs = Array2::<f64>::from_shape_fn((p, m), |(r, c)| vals[(r * m + c) % vals.len()]);

    let x_tau_i_reduced = op.reduce_explicit_design(&x_tau_i);
    let x_tau_j_reduced = op.reduce_explicit_design(&x_tau_j);
    let deta_i = x_tau_i.dot(&beta);
    let deta_j = x_tau_j.dot(&beta);
    let (dot_i_i, dot_h_i) = op.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
    let (dot_i_j, dot_h_j) = op.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);

    // The (i, j) kernel takes clones so the originals can be moved into (j, i).
    let kernel_ij = op.hphi_tau_tau_partial_prepare_from_partials(
        x_tau_i_reduced.clone(),
        x_tau_j_reduced.clone(),
        &deta_i,
        &deta_j,
        dot_h_i.clone(),
        dot_h_j.clone(),
        dot_i_i.clone(),
        dot_i_j.clone(),
        None,
        None,
    );
    let kernel_ji = op.hphi_tau_tau_partial_prepare_from_partials(
        x_tau_j_reduced,
        x_tau_i_reduced,
        &deta_j,
        &deta_i,
        dot_h_j,
        dot_h_i,
        dot_i_j,
        dot_i_i,
        None,
        None,
    );
    let analytic_ij = op.hphi_tau_tau_partial_apply(&x_tau_i, &x_tau_j, &kernel_ij, &rhs);
    let analytic_ji = op.hphi_tau_tau_partial_apply(&x_tau_j, &x_tau_i, &kernel_ji, &rhs);

    let max_abs = |values: &Array2<f64>| -> f64 {
        values.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()))
    };

    // Part 1: symmetry in the order of the two directions.
    let sym_diff: f64 = max_abs(&(&analytic_ij - &analytic_ji));
    let sym_scale: f64 = max_abs(&analytic_ij).max(max_abs(&analytic_ji)).max(1.0);
    assert!(
        sym_diff / sym_scale < 1e-10,
        "τ×τ primitive not symmetric in direction order: sym_diff={sym_diff:.3e}"
    );

    // Part 2: finite-difference reference for each ordering.
    let h = 1e-5_f64;

    // Single-direction apply for `tau_dir`, evaluated at the design `x_eval`.
    let single_tau_block = |x_eval: &Array2<f64>, tau_dir: &Array2<f64>| -> Array2<f64> {
        let eta_e = x_eval.dot(&beta);
        let op_e =
            build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
        let tau_reduced = op_e.reduce_explicit_design(tau_dir);
        let deta_e = tau_dir.dot(&beta);
        let (dot_i_e, dot_h_e) = op_e.dot_i_and_h_from_reduced(&tau_reduced, &deta_e);
        let kernel_e =
            op_e.hphi_tau_partial_prepare_from_partials(tau_reduced, &deta_e, dot_h_e, dot_i_e);
        op_e.hphi_tau_partial_apply(tau_dir, &kernel_e, &rhs)
    };
    // Central difference of the `tau_dir` block along the design shift `step_dir`.
    let central_difference = |tau_dir: &Array2<f64>, step_dir: &Array2<f64>| -> Array2<f64> {
        let x_forward = &x + &(h * step_dir);
        let x_backward = &x - &(h * step_dir);
        (&single_tau_block(&x_forward, tau_dir) - &single_tau_block(&x_backward, tau_dir))
            / (2.0 * h)
    };
    let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
        let scale = max_abs(a).max(max_abs(b)).max(1.0);
        max_abs(&(a - b)) / scale
    };

    let fd_ij = central_difference(&x_tau_i, &x_tau_j);
    let err_ij = rel_max_abs_diff(&analytic_ij, &fd_ij);

    let fd_ji = central_difference(&x_tau_j, &x_tau_i);
    let err_ji = rel_max_abs_diff(&analytic_ji, &fd_ji);

    let tol = 1e-7_f64;
    assert!(
        err_ij < tol,
        "∂²H_φ/∂τ_i∂τ_j apply mismatch (i,j): rel_max_abs_diff={err_ij:.3e} > {tol:.1e}\n\
         analytic=\n{analytic_ij:?}\n\
         fd=\n{fd_ij:?}"
    );
    assert!(
        err_ji < tol,
        "∂²H_φ/∂τ_j∂τ_i apply mismatch (j,i): rel_max_abs_diff={err_ji:.3e} > {tol:.1e}\n\
         analytic=\n{analytic_ji:?}\n\
         fd=\n{fd_ji:?}"
    );
}
4062
/// Verifies `d_beta_hphi_tau_partial_apply` against a central finite
/// difference of `hphi_tau_partial_apply` along the coefficient direction `v`.
///
/// The analytic block is prepared once at `beta`. The reference block rebuilds
/// the logit Firth operator at `beta ± h·v` (with `x` and `x_tau` held fixed),
/// re-applies the single-τ kernel, and forms the centred quotient. Both sides
/// act on the same deterministic `p × m` right-hand side.
///
/// Non-finite entries on either side fail the test explicitly: the max-abs
/// folds in the error metric use `f64::max`, which discards NaN operands, so a
/// NaN-filled result would otherwise report a zero error and pass vacuously.
#[test]
pub(crate) fn firth_d_beta_hphi_tau_partial_matches_finite_difference() {
    let x = array![
        [1.0, -1.0, 0.2],
        [1.0, -0.6, -0.3],
        [1.0, -0.1, 0.5],
        [1.0, 0.3, -0.7],
        [1.0, 0.8, 0.1],
        [1.0, 1.2, -0.4],
    ];
    let x_tau = array![
        [0.0, 0.15, -0.05],
        [0.0, -0.10, 0.02],
        [0.0, 0.08, 0.04],
        [0.0, -0.06, -0.03],
        [0.0, 0.05, 0.01],
        [0.0, -0.12, 0.06],
    ];
    let beta = array![0.1, -0.25, 0.2];
    let v = array![0.3, 0.2, -0.15];

    let eta = x.dot(&beta);
    let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
    let p = x.ncols();

    // Deterministic p × m right-hand side, filled by cycling through `vals`.
    let m = 3usize;
    let mut rhs = Array2::<f64>::zeros((p, m));
    let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
    for r in 0..p {
        for c in 0..m {
            rhs[[r, c]] = vals[(r * m + c) % vals.len()];
        }
    }

    // Single-τ kernel at the base point. `x_tau_reduced` and `dot_h_partial`
    // are not needed afterwards, so they are moved in rather than cloned;
    // `dot_i_partial` is cloned because the pair kernel below borrows it.
    let x_tau_reduced = op.reduce_explicit_design(&x_tau);
    let deta_partial = x_tau.dot(&beta);
    let (dot_i_partial, dot_h_partial) =
        op.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
    let tau_kernel = op.hphi_tau_partial_prepare_from_partials(
        x_tau_reduced,
        &deta_partial,
        dot_h_partial,
        dot_i_partial.clone(),
    );

    // Analytic β-directional derivative of the τ-partial apply.
    let deta_v = x.dot(&v);
    let direction = op.direction_from_deta(deta_v);
    let x_tau_v = x_tau.dot(&v);
    let pair_kernel = op.d_beta_hphi_tau_partial_prepare_from_partials(
        &tau_kernel,
        &deta_partial,
        &dot_i_partial,
        &direction,
        &x_tau_v,
    );
    let analytic = op.d_beta_hphi_tau_partial_apply(&x_tau, &pair_kernel, &rhs);

    // Central finite difference in β along `v`: every β-dependent quantity is
    // rebuilt from scratch at the perturbed coefficient vector.
    let h = 1e-5_f64;
    let single_tau_apply = |beta_eval: &Array1<f64>| -> Array2<f64> {
        let eta_e = x.dot(beta_eval);
        let op_e =
            build_logit_firth_dense_operator(&x, &eta_e).expect("perturbed firth operator");
        let x_tau_r = op_e.reduce_explicit_design(&x_tau);
        let deta_e = x_tau.dot(beta_eval);
        let (dot_i_e, dot_h_e) = op_e.dot_i_and_h_from_reduced(&x_tau_r, &deta_e);
        let ker_e =
            op_e.hphi_tau_partial_prepare_from_partials(x_tau_r, &deta_e, dot_h_e, dot_i_e);
        op_e.hphi_tau_partial_apply(&x_tau, &ker_e, &rhs)
    };
    let beta_plus = &beta + &(h * &v);
    let beta_minus = &beta - &(h * &v);
    let fd = (&single_tau_apply(&beta_plus) - &single_tau_apply(&beta_minus)) / (2.0 * h);

    // Guard the metric below: `f64::max` drops NaN operands, so NaN/inf must
    // be rejected up front instead of silently producing a passing error.
    assert!(
        analytic
            .iter()
            .chain(fd.iter())
            .all(|entry| entry.is_finite()),
        "D_β (H_φ)_τ|_β apply produced non-finite entries\n\
         analytic=\n{analytic:?}\n\
         fd=\n{fd:?}"
    );

    // Max-abs difference normalised by the largest magnitude on either side,
    // floored at 1.0 so small blocks are compared on an absolute scale.
    let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
        let scale = a
            .iter()
            .chain(b.iter())
            .map(|entry| entry.abs())
            .fold(0.0_f64, f64::max)
            .max(1.0);
        let max_diff = (a - b)
            .iter()
            .map(|entry| entry.abs())
            .fold(0.0_f64, f64::max);
        max_diff / scale
    };
    let err = rel_max_abs_diff(&analytic, &fd);

    let tol = 1e-7_f64;
    assert!(
        err < tol,
        "D_β (H_φ)_τ|_β apply mismatch: rel_max_abs_diff={err:.3e} > {tol:.1e}\n\
         analytic=\n{analytic:?}\n\
         fd=\n{fd:?}"
    );
}
4178
/// At `eta = 50`, `logisticweight` must agree with the tail formula
/// `z / (1 + z)^2`, `z = exp(-eta)`, to within an absolute `1e-30`.
///
/// The reference value is first asserted to be strictly positive, so the
/// comparison cannot be satisfied by both sides being zero.
#[test]
pub(crate) fn logisticweight_loses_positive_tail_mass() {
    let eta = 50.0_f64;

    // Reference value built from the decaying exponential.
    let tail = (-eta).exp();
    let expected = tail / (1.0_f64 + tail).powi(2);
    assert!(expected > 0.0);

    let actual = logisticweight(eta);
    let gap = (actual - expected).abs();
    assert!(
        gap < 1e-30,
        "Firth logisticweight should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
        actual,
        expected
    );
}
4193
/// For the logit link, the five-component Fisher-weight jet must coincide
/// (under `==`) with derivatives `d1..d5` of the logit inverse-link jet at
/// every probed `eta`, including the far tails at ±40.
#[test]
pub(crate) fn fisher_weight_jet5_logit_is_byte_identical_to_inverse_link_jet() {
    let probes = [
        -40.0, -8.0, -3.0, -1.0, -0.25, 0.0, 0.25, 1.0, 3.0, 8.0, 40.0,
    ];
    for eta in probes.iter().copied() {
        let jet = logit_inverse_link_jet5(eta);
        let (w, w1, w2, w3, w4) =
            crate::mixture_link::fisher_weight_jet5(StandardLink::Logit, eta);

        // Component-wise exact equality; no tolerance is allowed here.
        let weights = [w, w1, w2, w3, w4];
        let link_derivs = [jet.d1, jet.d2, jet.d3, jet.d4, jet.d5];
        assert!(
            weights == link_derivs,
            "logit Fisher-weight jet must equal inverse-link jet derivatives at eta={eta}: \
             got ({w}, {w1}, {w2}, {w3}, {w4}) vs ({}, {}, {}, {}, {})",
            jet.d1,
            jet.d2,
            jet.d3,
            jet.d4,
            jet.d5
        );
    }
}
4217
/// Compares the first three components of the probit Fisher-weight jet with a
/// directly evaluated reference weight `phi^2 / (p * (1 - p))` and its central
/// finite differences (step `1e-4`) over moderate `eta` values.
///
/// Tolerances loosen with derivative order: `1e-10` for the value, `1e-5` for
/// the first derivative, `1e-3` for the second.
#[test]
pub(crate) fn fisher_weight_jet5_probit_matches_finite_difference() {
    /// Reference probit weight; returns 0.0 once either tail probability is
    /// non-positive.
    fn reference_probit_weight(eta: f64) -> f64 {
        let p = gam_math::probability::normal_cdf(eta);
        let q = 1.0 - p;
        let phi = gam_math::probability::normal_pdf(eta);
        let degenerate = p <= 0.0 || q <= 0.0;
        if degenerate {
            0.0
        } else {
            phi * phi / (p * q)
        }
    }

    let h = 1e-4_f64;
    let probes = [-3.0, -1.5, -0.5, 0.0, 0.3, 1.5, 3.0];
    for eta in probes.iter().copied() {
        let (w, w1, w2, _, _) = crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);

        // Evaluate the three stencil points once and reuse them for both
        // difference quotients.
        let ref_w = reference_probit_weight(eta);
        let ahead = reference_probit_weight(eta + h);
        let behind = reference_probit_weight(eta - h);
        let fd1 = (ahead - behind) / (2.0 * h);
        let fd2 = (ahead - 2.0 * ref_w + behind) / (h * h);

        assert!(
            (w - ref_w).abs() < 1e-10,
            "probit W mismatch at eta={eta}: jet {w} vs ref {ref_w}"
        );
        assert!(
            (w1 - fd1).abs() < 1e-5,
            "probit W' mismatch at eta={eta}: jet {w1} vs fd {fd1}"
        );
        assert!(
            (w2 - fd2).abs() < 1e-3,
            "probit W'' mismatch at eta={eta}: jet {w2} vs fd {fd2}"
        );
    }
}
4256
/// Pins the tail behaviour of the probit Fisher-weight jet.
///
/// * At `eta = ±40` and `±80`, all five components must be exactly zero.
/// * At `eta = ±12`, the weight must still be strictly positive and every
///   component must be finite.
#[test]
pub(crate) fn fisher_weight_jet5_probit_saturates_to_zero_in_tails() {
    // Deep tails: full saturation.
    let saturated = [40.0_f64, -40.0, 80.0, -80.0];
    for eta in saturated.iter().copied() {
        let (w, w1, w2, w3, w4) =
            crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
        let all_zero = [w, w1, w2, w3, w4].iter().all(|&d| d == 0.0);
        assert!(
            all_zero,
            "probit Fisher weight jet must saturate to zero at eta={eta}; got \
             ({w}, {w1}, {w2}, {w3}, {w4})"
        );
    }

    // Moderate tails: tiny but still representable.
    let moderate = [12.0_f64, -12.0];
    for eta in moderate.iter().copied() {
        let (w, w1, w2, w3, w4) =
            crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
        let all_finite = [w, w1, w2, w3, w4].iter().all(|d| d.is_finite());
        assert!(
            w > 0.0 && all_finite,
            "probit Fisher weight jet must be tiny-positive and finite at eta={eta}; got \
             ({w}, {w1}, {w2}, {w3}, {w4})"
        );
    }
}
4290}