1use super::*;
2use gam_linalg::matrix::symmetrize_in_place;
3use crate::mixture_link::fisher_weight_jet5_for_inverse_link;
4use gam_problem::InverseLink;
5
/// Row count at or above which per-row Fisher-weight derivative evaluation is
/// spread across the rayon pool (see `RemlState::parallelize_firth_derivative_rows`).
pub(crate) const FIRTH_DERIVATIVE_PARALLEL_MIN_N: usize = 16 * 1024;

/// Threshold on `min_positive_eigenvalue / max(max_eigenvalue, 1)` of the reduced
/// Fisher matrix below which building a `FirthDenseOperator` logs a
/// near-singularity warning.
pub(crate) const FIRTH_REDUCED_FISHER_RCOND_WARN: f64 = 1.0e-10;
14
15impl<'a> RemlState<'a> {
16 pub(crate) fn xt_diag_x_dense_into(
17 x: &Array2<f64>,
18 diag: &Array1<f64>,
19 weighted: &mut Array2<f64>,
20 ) -> Array2<f64> {
21 super::assembly::xt_diag_x_dense_into(x, diag, weighted)
22 }
23
24 #[inline]
25 pub(crate) fn parallelize_firth_derivative_rows(n: usize) -> bool {
26 n >= FIRTH_DERIVATIVE_PARALLEL_MIN_N && rayon::current_num_threads() > 1
27 }
28
29 pub(crate) fn row_scale(x: &Array2<f64>, scale: &Array1<f64>) -> Array2<f64> {
30 let mut out = Array2::<f64>::zeros(x.raw_dim());
31 super::assembly::row_scale_dense_into(x, scale, &mut out);
32 out
33 }
34
35 #[inline]
36 pub(crate) fn dense_product_likely_uses_inner_parallelism(
37 m: usize,
38 n: usize,
39 k: usize,
40 ) -> bool {
41 const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
46 const PAR_MIN_LONG_DIM: usize = 256;
47 let flop_scale = m.saturating_mul(n).saturating_mul(k);
48 let long_dim = m.max(n).max(k);
49 flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM
50 }
51
52 #[inline]
53 pub(crate) fn should_join_independent_dense_products(
54 products: &[(usize, usize, usize)],
55 ) -> bool {
56 const JOIN_MIN_TOTAL_FLOP_SCALE: usize = 128 * 1024;
57 if rayon::current_num_threads() <= 1 {
58 return false;
59 }
60 let mut total_flop_scale = 0usize;
61 for &(m, n, k) in products {
62 if Self::dense_product_likely_uses_inner_parallelism(m, n, k) {
63 return false;
64 }
65 total_flop_scale =
66 total_flop_scale.saturating_add(m.saturating_mul(n).saturating_mul(k));
67 }
68 total_flop_scale >= JOIN_MIN_TOTAL_FLOP_SCALE
69 }
70
71 #[inline]
80 pub(crate) fn scale_rows_by_inverse_observation_weight_sqrt(
81 out: &mut Array2<f64>,
82 observation_weight_sqrt: Option<&Array1<f64>>,
83 ) {
84 let Some(scale) = observation_weight_sqrt else {
85 return;
86 };
87 super::assembly::row_scale_dense_in_place_by_inverse_positive_or_zero(out, scale);
88 }
89
90 #[inline]
95 pub(crate) fn fisher_weight_derivatives(
96 link: &InverseLink,
97 eta: f64,
98 ) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
99 fisher_weight_jet5_for_inverse_link(link, eta)
100 }
101
102 #[inline]
103 pub(crate) fn cholesky_pivots_are_numerically_resolved(chol_diag: &Array1<f64>) -> bool {
104 let mut min_pivot_sq = f64::INFINITY;
105 let mut max_pivot_sq = 0.0_f64;
106 for &pivot in chol_diag {
107 if !pivot.is_finite() || pivot <= 0.0 {
108 return false;
109 }
110 let pivot_sq = pivot * pivot;
111 min_pivot_sq = min_pivot_sq.min(pivot_sq);
112 max_pivot_sq = max_pivot_sq.max(pivot_sq);
113 }
114 if !min_pivot_sq.is_finite() {
115 return false;
116 }
117 let scale = max_pivot_sq.max(1.0);
118 let floor = (chol_diag.len().max(1) as f64) * f64::EPSILON * scale;
119 min_pivot_sq > floor
120 }
121
122 pub(crate) fn reduced_fisher_inverse_and_half_logdet(
123 fisher_reduced: &Array2<f64>,
124 ) -> Result<(Array2<f64>, f64), EstimationError> {
125 let r = fisher_reduced.nrows();
126 assert_eq!(r, fisher_reduced.ncols());
127 let mut k_reduced = Array2::<f64>::zeros((r, r));
128 if r == 0 {
129 return Ok((k_reduced, 0.0));
130 }
131
132 if let Ok(chol) = fisher_reduced.cholesky(Side::Lower) {
133 let chol_diag = chol.diag();
134 if Self::cholesky_pivots_are_numerically_resolved(&chol_diag) {
135 let half_log_det = chol_diag.iter().map(|d| d.ln()).sum::<f64>();
136 for col in 0..r {
137 let mut e_col = Array1::<f64>::zeros(r);
138 e_col[col] = 1.0;
139 let solved = chol.solvevec(&e_col);
140 k_reduced.column_mut(col).assign(&solved);
141 }
142 return Ok((k_reduced, half_log_det));
143 }
144 }
145
146 let (evals_ir, evecs_ir) = fisher_reduced
147 .eigh(Side::Lower)
148 .map_err(EstimationError::EigendecompositionFailed)?;
149 let max_eval = evals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
150 let tol = (r.max(1) as f64) * f64::EPSILON * max_eval;
151 let mut kept_positive_direction = false;
152 let mut half_log_det = 0.0_f64;
153 for (eig_idx, &eig) in evals_ir.iter().enumerate() {
154 if eig > tol {
155 kept_positive_direction = true;
156 half_log_det += 0.5 * eig.ln();
157 let inv = eig.recip();
158 let vec = evecs_ir.column(eig_idx).to_owned();
159 for row in 0..r {
160 for col in 0..r {
161 k_reduced[[row, col]] += inv * vec[row] * vec[col];
162 }
163 }
164 }
165 }
166 if !kept_positive_direction {
167 return Err(EstimationError::ModelIsIllConditioned {
168 condition_number: f64::INFINITY,
169 });
170 }
171 Ok((k_reduced, half_log_det))
172 }
173
174 pub(crate) fn fill_fisher_weight_derivative_arrays(
175 link: &InverseLink,
176 eta: &Array1<f64>,
177 w: &mut Array1<f64>,
178 w1: &mut Array1<f64>,
179 w2: &mut Array1<f64>,
180 w3: &mut Array1<f64>,
181 w4: &mut Array1<f64>,
182 ) -> Result<(), EstimationError> {
183 assert_eq!(eta.len(), w.len());
184 assert_eq!(eta.len(), w1.len());
185 assert_eq!(eta.len(), w2.len());
186 assert_eq!(eta.len(), w3.len());
187 assert_eq!(eta.len(), w4.len());
188
189 if Self::parallelize_firth_derivative_rows(eta.len()) {
190 let values: Result<Vec<_>, EstimationError> = eta
191 .par_iter()
192 .map(|&ei| Self::fisher_weight_derivatives(link, ei))
193 .collect();
194 for (i, (value, first, second, third, fourth)) in values?.into_iter().enumerate() {
195 w[i] = value;
196 w1[i] = first;
197 w2[i] = second;
198 w3[i] = third;
199 w4[i] = fourth;
200 }
201 return Ok(());
202 }
203 for i in 0..eta.len() {
204 let (value, first, second, third, fourth) =
205 Self::fisher_weight_derivatives(link, eta[i])?;
206 w[i] = value;
207 w1[i] = first;
208 w2[i] = second;
209 w3[i] = third;
210 w4[i] = fourth;
211 }
212 Ok(())
213 }
214
215 pub(crate) fn weighted_cross(
216 left: &Array2<f64>,
217 right: &Array2<f64>,
218 weights: &Array1<f64>,
219 ) -> Array2<f64> {
220 assert_eq!(left.nrows(), right.nrows());
221 assert_eq!(left.nrows(), weights.len());
222 super::assembly::weighted_cross_dense(left, right, weights)
223 }
224
225 pub(crate) fn trace_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
226 assert_eq!(a.nrows(), b.ncols());
227 assert_eq!(a.ncols(), b.nrows());
228 let elems = a.nrows().saturating_mul(a.ncols());
229 if elems >= 32 * 32 {
230 let aview = FaerArrayView::new(a);
231 let bview = FaerArrayView::new(b);
232 return faer_frob_inner(aview.as_ref(), bview.as_ref().transpose());
233 }
234 let m = a.nrows();
235 let n = a.ncols();
236 kahan_sum((0..m).map(|i| {
237 let mut acc = 0.0_f64;
238 for j in 0..n {
239 acc += a[[i, j]] * b[[j, i]];
240 }
241 acc
242 }))
243 }
244
245 pub(crate) fn reducedweighted_gram(z: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
246 let weighted = Self::row_scale(z, weights);
253 fast_atb(z, &weighted)
254 }
255
256 pub(crate) fn reduced_crossweighted_gram(
257 z_left: &Array2<f64>,
258 z_right: &Array2<f64>,
259 weights: &Array1<f64>,
260 ) -> Array2<f64> {
261 let weighted = Self::row_scale(z_right, weights);
266 fast_atb(z_left, &weighted)
267 }
268
269 pub(crate) fn reduced_diag_gram(z: &Array2<f64>, a: &Array2<f64>) -> Array1<f64> {
270 let za = fast_ab(z, a);
276 (z * &za).sum_axis(ndarray::Axis(1))
277 }
278
279 pub(crate) fn apply_hadamard_gram(
280 z: &Array2<f64>,
281 a_left: &Array2<f64>,
282 a_right: &Array2<f64>,
283 vec: &Array1<f64>,
284 ) -> Array1<f64> {
285 let s = Self::reducedweighted_gram(z, vec);
294 let left_s = a_left.dot(&s);
295 let t = left_s.dot(a_right);
296 Self::reduced_diag_gram(z, &t)
297 }
298
299 pub(crate) fn apply_hadamard_gram_to_matrix(
300 z: &Array2<f64>,
301 a_left: &Array2<f64>,
302 a_right: &Array2<f64>,
303 mat: &Array2<f64>,
304 ) -> Array2<f64> {
305 let mut out = Array2::<f64>::zeros(mat.raw_dim());
313 for col in 0..mat.ncols() {
314 let v = mat.column(col).to_owned();
315 let y = Self::apply_hadamard_gram(z, a_left, a_right, &v);
316 out.column_mut(col).assign(&y);
317 }
318 out
319 }
320
321 pub(super) fn build_firth_dense_operator_for_link(
326 link: &InverseLink,
327 x_dense: &Array2<f64>,
328 eta: &Array1<f64>,
329 observation_weights: ndarray::ArrayView1<'_, f64>,
330 ) -> Result<FirthDenseOperator, EstimationError> {
331 FirthDenseOperator::build_with_observation_weights_impl(
332 link,
333 x_dense,
334 eta,
335 Some(observation_weights),
336 )
337 }
338
339 pub(super) fn firth_exact_tau_kernel(
340 op: &FirthDenseOperator,
341 x_tau: &Array2<f64>,
342 beta: &Array1<f64>,
343 include_hphi_tau_kernel: bool,
344 ) -> FirthTauExactKernel {
345 op.exact_tau_kernel(x_tau, beta, include_hphi_tau_kernel)
346 }
347
348 pub(super) fn firth_hphi_tau_partial_apply(
349 op: &FirthDenseOperator,
350 x_tau: &Array2<f64>,
351 kernel: &FirthTauPartialKernel,
352 rhs: &Array2<f64>,
353 ) -> Array2<f64> {
354 op.hphi_tau_partial_apply(x_tau, kernel, rhs)
355 }
356}
357
358impl FirthDenseOperator {
359 pub(crate) fn canonicalize_basis_column_signs(q_basis: &mut Array2<f64>) {
360 for col in 0..q_basis.ncols() {
361 let mut pivot_row = 0usize;
362 let mut pivot_abs = 0.0_f64;
363 for row in 0..q_basis.nrows() {
364 let value = q_basis[[row, col]];
365 let abs_value = value.abs();
366 if abs_value > pivot_abs {
367 pivot_abs = abs_value;
368 pivot_row = row;
369 }
370 }
371 if pivot_abs > 0.0 && q_basis[[pivot_row, col]] < 0.0 {
372 q_basis.column_mut(col).mapv_inplace(|v| -v);
373 }
374 }
375 }
376
377 pub(crate) fn identifiable_subspace_basis_from_gram(
378 gram: &Array2<f64>,
379 ) -> Result<(Array2<f64>, Array1<f64>), EstimationError> {
380 let p = gram.nrows();
381 assert_eq!(p, gram.ncols());
382 if p == 0 {
383 return Ok((Array2::<f64>::eye(0), Array1::<f64>::zeros(0)));
384 }
385
386 let (evals, evecs) = gram
387 .eigh(Side::Lower)
388 .map_err(EstimationError::EigendecompositionFailed)?;
389 let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
390 let tol = (p.max(1) as f64) * f64::EPSILON * max_eval;
391 let mut keep: Vec<usize> = evals
392 .iter()
393 .enumerate()
394 .filter_map(|(i, &value)| if value > tol { Some(i) } else { None })
395 .collect();
396 if keep.is_empty() {
397 return Err(EstimationError::ModelIsIllConditioned {
398 condition_number: f64::INFINITY,
399 });
400 }
401
402 keep.sort_by(|&lhs, &rhs| evals[rhs].total_cmp(&evals[lhs]));
408 let r = keep.len();
409 let mut q_basis = Array2::<f64>::zeros((p, r));
410 let mut metric_spectrum = Array1::<f64>::zeros(r);
411 for (col_idx, eig_idx) in keep.into_iter().enumerate() {
412 q_basis.column_mut(col_idx).assign(&evecs.column(eig_idx));
413 metric_spectrum[col_idx] = evals[eig_idx];
414 }
415 Self::canonicalize_basis_column_signs(&mut q_basis);
416 Ok((q_basis, metric_spectrum))
417 }
418
419 #[inline]
420 pub(crate) fn trace_diag_product(diag: &Array1<f64>, matrix: &Array2<f64>) -> f64 {
421 assert_eq!(diag.len(), matrix.nrows());
422 assert_eq!(matrix.nrows(), matrix.ncols());
423 kahan_sum((0..diag.len()).map(|i| diag[i] * matrix[[i, i]]))
424 }
425
426 pub fn build_for_link(
427 link: &InverseLink,
428 x_dense: &Array2<f64>,
429 eta: &Array1<f64>,
430 ) -> Result<FirthDenseOperator, EstimationError> {
431 Self::build_with_observation_weights_impl(link, x_dense, eta, None)
432 }
433
434 pub fn build_with_observation_weights_for_link(
435 link: &InverseLink,
436 x_dense: &Array2<f64>,
437 eta: &Array1<f64>,
438 observation_weights: ndarray::ArrayView1<'_, f64>,
439 ) -> Result<FirthDenseOperator, EstimationError> {
440 Self::build_with_observation_weights_impl(link, x_dense, eta, Some(observation_weights))
441 }
442
443 #[inline]
444 pub(crate) fn pirls_hat_diag(&self) -> Array1<f64> {
445 &self.w * &self.h_diag
446 }
447
448 #[inline]
468 pub(crate) fn pirls_firth_score_shift(&self) -> Array1<f64> {
469 let mut shift = Array1::<f64>::zeros(self.w.len());
470 for i in 0..self.w.len() {
471 let wi = self.w[i];
472 if wi > 0.0 {
473 shift[i] = 0.5 * (self.w1[i] / wi) * self.h_diag[i];
474 }
475 }
476 shift
477 }
478
479 pub(crate) fn build_with_observation_weights_impl(
480 link: &InverseLink,
481 x_dense: &Array2<f64>,
482 eta: &Array1<f64>,
483 observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
484 ) -> Result<FirthDenseOperator, EstimationError> {
485 let n = x_dense.nrows();
536 if eta.len() != n {
537 crate::bail_invalid_estim!(
538 "Firth operator shape mismatch: nrows={}, eta_len={}",
539 n,
540 eta.len()
541 );
542 }
543 let observation_weight_sqrt = if let Some(weights) = observation_weights {
544 if weights.len() != n {
545 crate::bail_invalid_estim!(
546 "Firth operator observation weight length {} != number of rows {}",
547 weights.len(),
548 n
549 );
550 }
551 let mut sqrt = Array1::<f64>::zeros(n);
552 for i in 0..n {
553 let weight = weights[i];
554 if !weight.is_finite() || weight < 0.0 {
555 crate::bail_invalid_estim!(
556 "Firth operator requires finite nonnegative observation weights, got {} at row {}",
557 weight,
558 i
559 );
560 }
561 sqrt[i] = weight.sqrt();
562 }
563 Some(sqrt)
564 } else {
565 None
566 };
567 let mut w = Array1::<f64>::zeros(n);
568 let mut w1 = Array1::<f64>::zeros(n);
569 let mut w2 = Array1::<f64>::zeros(n);
570 let mut w3 = Array1::<f64>::zeros(n);
571 let mut w4 = Array1::<f64>::zeros(n);
572 RemlState::fill_fisher_weight_derivative_arrays(
573 link, eta, &mut w, &mut w1, &mut w2, &mut w3, &mut w4,
574 )?;
575 let basis_design = if let Some(scale) = observation_weight_sqrt.as_ref() {
576 RemlState::row_scale(x_dense, scale)
577 } else {
578 x_dense.clone()
579 };
580
581 let gram = fast_atb(&basis_design, &basis_design);
597 let (q_basis, metric_spectrum) = Self::identifiable_subspace_basis_from_gram(&gram)?;
598 let x_reduced = fast_ab(&basis_design, &q_basis);
599 let r = q_basis.ncols();
600
601 let fisher_reduced = gam_linalg::faer_ndarray::fast_xt_diag_x(&x_reduced, &w);
606 if let Ok((eigvals_ir, _)) = fisher_reduced.eigh(Side::Lower) {
612 let max_ev = eigvals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
613 let min_ev = eigvals_ir
614 .iter()
615 .copied()
616 .filter(|v| v.is_finite() && *v > 0.0)
617 .fold(f64::INFINITY, f64::min);
618 if min_ev.is_finite() {
619 let rel = min_ev / max_ev;
620 if rel < FIRTH_REDUCED_FISHER_RCOND_WARN {
621 log::warn!(
622 "[REML/Firth] reduced Fisher I_r is near-singular (min/max={:.3e}/{:.3e}, rel={:.3e}); exact derivatives may be ill-conditioned near active-subspace boundaries.",
623 min_ev,
624 max_ev,
625 rel
626 );
627 }
628 }
629 }
630
631 let mut x_metric_reduced_inv_diag = Array1::<f64>::zeros(r);
632 let (k_reduced, mut half_log_det) = if r > 0 {
633 RemlState::reduced_fisher_inverse_and_half_logdet(&fisher_reduced)?
644 } else {
645 (Array2::<f64>::zeros((r, r)), 0.0)
646 };
647 if r > 0 {
648 for col in 0..r {
649 let metric_eig = metric_spectrum[col];
650 half_log_det -= 0.5 * metric_eig.ln();
651 x_metric_reduced_inv_diag[col] = metric_eig.recip();
652 }
653 }
654 let h_diag = if r > 0 {
656 RemlState::reduced_diag_gram(&x_reduced, &k_reduced)
657 } else {
658 Array1::<f64>::zeros(n)
659 };
660 let x_dense_t = x_dense.t().to_owned();
661 let b_base = RemlState::row_scale(x_dense, &w1);
662 let p_b_base =
663 RemlState::apply_hadamard_gram_to_matrix(&x_reduced, &k_reduced, &k_reduced, &b_base);
664 Ok(FirthDenseOperator {
665 x_dense: x_dense.clone(),
666 x_dense_t,
667 q_basis,
668 x_reduced,
669 observation_weight_sqrt,
670 k_reduced,
671 x_metric_reduced_inv_diag,
672 half_log_det,
673 h_diag,
674 w,
675 w1,
676 w2,
677 w3,
678 w4,
679 b_base,
680 p_b_base,
681 })
682 }
683
684 #[inline]
685 pub(crate) fn jeffreys_logdet(&self) -> f64 {
686 self.half_log_det
687 }
688
689 pub(crate) fn jeffreys_logdet_projected(&self, z: ndarray::ArrayView2<'_, f64>) -> f64 {
702 use gam_linalg::faer_ndarray::{fast_ab, fast_xt_diag_x};
703 let p = self.x_dense.ncols();
704 assert_eq!(
705 z.nrows(),
706 p,
707 "jeffreys_logdet_projected: Z must have {} rows (β-space dim), got {}",
708 p,
709 z.nrows()
710 );
711 let m = z.ncols();
712 if m == 0 {
713 return 0.0;
714 }
715 let z_owned = z.to_owned();
717 let xz = fast_ab(&self.x_dense, &z_owned);
718 let xtz = if let Some(scale) = self.observation_weight_sqrt.as_ref() {
719 RemlState::row_scale(&xz, scale)
720 } else {
721 xz
722 };
723 let mut j_t = fast_xt_diag_x(&xtz, &self.w);
725 symmetrize_in_place(&mut j_t);
726 let (evals, _) = match j_t.eigh(Side::Lower) {
727 Ok(pair) => pair,
728 Err(_) => return f64::NEG_INFINITY,
729 };
730 let Some(evals_slice) = evals.as_slice() else {
731 return f64::NEG_INFINITY;
732 };
733 let threshold = super::reml_outer_engine::positive_eigenvalue_threshold(evals_slice);
734 0.5 * super::reml_outer_engine::exact_pseudo_logdet(evals_slice, threshold)
735 }
736
737 #[inline]
738 pub(crate) fn jeffreys_beta_gradient(&self) -> Array1<f64> {
739 0.5 * gam_linalg::faer_ndarray::fast_av(&self.x_dense_t, &(&self.w1 * &self.h_diag))
744 }
745
746 #[inline]
747 pub fn jeffreys_logdet_and_beta_gradient(&self) -> (f64, Array1<f64>) {
748 (self.jeffreys_logdet(), self.jeffreys_beta_gradient())
749 }
750
751 #[inline]
752 pub(crate) fn reduce_explicit_design(&self, x: &Array2<f64>) -> Array2<f64> {
753 let mut reduced = fast_ab(x, &self.q_basis);
754 if let Some(scale) = self.observation_weight_sqrt.as_ref() {
755 reduced = RemlState::row_scale(&reduced, scale);
756 }
757 reduced
758 }
759
760 pub(crate) fn direction_from_deta(&self, deta: Array1<f64>) -> FirthDirection {
761 let s_u = &self.w1 * &deta;
777 let g_u_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_u);
782 let k_g_u = self.k_reduced.dot(&g_u_reduced);
783 let a_u_reduced = k_g_u.dot(&self.k_reduced);
784 let dh = -RemlState::reduced_diag_gram(&self.x_reduced, &a_u_reduced);
787 let b_uvec = &self.w2 * &deta;
788 FirthDirection {
789 deta,
790 g_u_reduced,
791 a_u_reduced,
792 dh,
793 b_uvec,
794 }
795 }
796
797 #[inline]
798 pub(crate) fn left_scaled_xt(&self, scale: &Array1<f64>, mat: &Array2<f64>) -> Array2<f64> {
799 fast_ab(&self.x_dense_t, &(mat * &scale.view().insert_axis(Axis(1))))
800 }
801
802 #[inline]
803 pub(crate) fn apply_p_u_to_matrix(
804 &self,
805 a_u_reduced: &Array2<f64>,
806 mat: &Array2<f64>,
807 ) -> Array2<f64> {
808 let mut out = RemlState::apply_hadamard_gram_to_matrix(
809 &self.x_reduced,
810 &self.k_reduced,
811 a_u_reduced,
812 mat,
813 );
814 out.mapv_inplace(|v| -2.0 * v);
815 out
816 }
817
    /// Applies a first-order directional operator along `dir` to each column of
    /// `rhs`.
    ///
    /// Returns `0.5 * (diag_term - (term1 + term2 + term3))`, with shape
    /// `p × rhs.ncols()`. With `P̄[·]` the Hadamard-Gram map with sandwich `(K, K)`:
    /// - `diag_term = Xᵀ diag(c_u) X rhs`, where
    ///   `c_u = w3 ⊙ deta ⊙ h_diag + w2 ⊙ dh`;
    /// - `term1 = Xᵀ diag(b_u) P̄[diag(w1) X rhs]`;
    /// - `term2 = Xᵀ diag(w1) P̄[diag(b_u) X rhs]`;
    /// - `term3 = Xᵀ diag(w1) P_u[diag(w1) X rhs]`, with `P_u` from
    ///   `apply_p_u_to_matrix`.
    ///
    /// A `rhs` whose row count is not `p` yields zeros rather than a panic. So does
    /// an empty `rhs` or `p == 0`.
    ///
    /// NOTE(review): by name this is presumably the directional derivative of the
    /// Firth Hessian term (`Hphi`); confirm against the derivation.
    pub(crate) fn hphi_direction_apply(
        &self,
        dir: &FirthDirection,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        // η-space image of the right-hand side, and its w1-weighted version.
        let etav = fast_ab(&self.x_dense, rhs);
        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
        let m_qv = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &qv,
        );
        let buvec = &dir.b_uvec;
        let m_buv = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &(&etav * &buvec.view().insert_axis(Axis(1))),
        );
        let p_u_qv = self.apply_p_u_to_matrix(&dir.a_u_reduced, &qv);
        // Per-row coefficient of the diagonal (Xᵀ diag(·) X) part.
        let c_u = &(&self.w3 * &dir.deta) * &self.h_diag + &(&self.w2 * &dir.dh);
        let diag_term = self
            .x_dense_t
            .dot(&(&etav * &c_u.view().insert_axis(Axis(1))));
        let term1 = self.left_scaled_xt(buvec, &m_qv);
        let term2 = self.left_scaled_xt(&self.w1, &m_buv);
        let term3 = self.left_scaled_xt(&self.w1, &p_u_qv);
        0.5 * (diag_term - (term1 + term2 + term3))
    }
861
862 pub(crate) fn hphi_direction(&self, dir: &FirthDirection) -> Array2<f64> {
863 let p = self.x_dense.ncols();
864 let eye = Array2::<f64>::eye(p);
865 let mut out = self.hphi_direction_apply(dir, &eye);
866 symmetrize_in_place(&mut out);
879 out
880 }
881
    /// Applies the two-direction (`u`, `v`) counterpart of `hphi_direction_apply` to
    /// each column of `rhs`.
    ///
    /// Returns `0.5 * (diag_term - d2_j2)`, with shape `p × rhs.ncols()`:
    /// - `diag_term = Xᵀ diag(c_uv) X rhs`;
    /// - `d2_j2` is the sum of the nine `left_scaled_xt` terms collected in
    ///   `d2_terms`.
    ///
    /// A `rhs` whose row count is not `p` yields zeros rather than a panic. So does
    /// an empty `rhs` or `p == 0`.
    ///
    /// NOTE(review): by name this is presumably the second directional derivative
    /// of the Firth Hessian term (`Hphi`); confirm against the derivation.
    pub(crate) fn hphisecond_direction_apply(
        &self,
        u: &FirthDirection,
        v: &FirthDirection,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        let deta_uv = &u.deta * &v.deta;
        let s_uv = &self.w2 * &deta_uv;
        let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
        let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
        let k_gv = self.k_reduced.dot(&v.g_u_reduced);
        let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
        // Second-order reduced sandwich:
        // K G_uv K − K G_v K G_u K − K G_u K G_v K.
        let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
            - k_gv.dot(&k_g_u).dot(&self.k_reduced)
            - k_g_u.dot(&k_gv).dot(&self.k_reduced);
        let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
        // Per-row coefficient of the diagonal (Xᵀ diag(·) X) part.
        let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
            + &(&self.w3 * &(&u.deta * &v.dh))
            + &(&self.w3 * &(&v.deta * &u.dh))
            + &(&self.w2 * &d2h);

        let eta_rhs = fast_ab(&self.x_dense, rhs);
        let diag_term = fast_ab(
            &self.x_dense_t,
            &(&eta_rhs * &c_uv.view().insert_axis(Axis(1))),
        );

        let b_uvvec = &self.w3 * &deta_uv;
        let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
        let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));

        // `p_b_base` was precomputed at build time as P̄[diag(w1) X].
        let p_b_rhs = fast_ab(&self.p_b_base, rhs);
        let p_bu_rhs = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
        );
        let p_bv_rhs = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
        );
        let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &b_uv_base,
        );
        let p_buv_rhs = fast_ab(&p_buv_base, rhs);

        // First-order sandwich operators (P_u / P_v) applied to the weighted images.
        let pv_b_rhs = self.apply_p_u_to_matrix(&v.a_u_reduced, &qv);
        let pv_bu_rhs = self.apply_p_u_to_matrix(
            &v.a_u_reduced,
            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
        );
        let p_u_b_rhs = self.apply_p_u_to_matrix(&u.a_u_reduced, &qv);
        let p_u_bv_rhs = self.apply_p_u_to_matrix(
            &u.a_u_reduced,
            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
        );

        // Mixed term: 2·H(A_u, A_v)[b_base] − 2·H(K, A_uv)[b_base].
        let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &u.a_u_reduced,
            &v.a_u_reduced,
            &self.b_base,
        );
        let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &a_uv_reduced,
            &self.b_base,
        );
        let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
        let p_uv_rhs = fast_ab(&p_uv_base, rhs);

        let d2_terms = [
            self.left_scaled_xt(&b_uvvec, &p_b_rhs),
            self.left_scaled_xt(&self.w1, &p_buv_rhs),
            self.left_scaled_xt(&u.b_uvec, &p_bv_rhs),
            self.left_scaled_xt(&v.b_uvec, &p_bu_rhs),
            self.left_scaled_xt(&u.b_uvec, &pv_b_rhs),
            self.left_scaled_xt(&self.w1, &pv_bu_rhs),
            self.left_scaled_xt(&v.b_uvec, &p_u_b_rhs),
            self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
            self.left_scaled_xt(&self.w1, &p_uv_rhs),
        ];
        // Accumulate in array order; the summation order is kept stable on purpose.
        let mut d2_j2 = Array2::<f64>::zeros((p, rhs.ncols()));
        for term in d2_terms {
            d2_j2 += &term;
        }

        0.5 * (diag_term - d2_j2)
    }
1009
1010 pub(super) fn rowwise_dot(a: &Array2<f64>, b: &Array2<f64>) -> Array1<f64> {
1011 assert_eq!(a.nrows(), b.nrows());
1012 assert_eq!(a.ncols(), b.ncols());
1013 let mut out = Array1::<f64>::zeros(a.nrows());
1014 for i in 0..a.nrows() {
1015 let mut acc = 0.0_f64;
1016 for j in 0..a.ncols() {
1017 acc += a[[i, j]] * b[[i, j]];
1018 }
1019 out[i] = acc;
1020 }
1021 out
1022 }
1023
1024 pub(super) fn rowwise_bilinear(
1025 a: &Array2<f64>,
1026 m: &Array2<f64>,
1027 b: &Array2<f64>,
1028 ) -> Array1<f64> {
1029 assert_eq!(a.nrows(), b.nrows());
1031 assert_eq!(a.ncols(), m.nrows());
1032 assert_eq!(b.ncols(), m.ncols());
1033 let am = fast_ab(a, m);
1034 Self::rowwise_dot(&am, b)
1035 }
1036
1037 pub(crate) fn dot_i_and_h_from_reduced(
1038 &self,
1039 x_tau_reduced: &Array2<f64>,
1040 deta: &Array1<f64>,
1041 ) -> (Array2<f64>, Array1<f64>) {
1042 let dw = &self.w1 * deta;
1062 let dot_i = RemlState::weighted_cross(x_tau_reduced, &self.x_reduced, &self.w)
1063 + RemlState::weighted_cross(&self.x_reduced, x_tau_reduced, &self.w)
1064 + gam_linalg::faer_ndarray::fast_xt_diag_x(&self.x_reduced, &dw);
1065
1066 let dot_k = -self.k_reduced.dot(&dot_i).dot(&self.k_reduced);
1067 let x_tauk = fast_ab(x_tau_reduced, &self.k_reduced);
1068 let dot_h_explicit = 2.0 * Self::rowwise_dot(&x_tauk, &self.x_reduced);
1069 let dot_h_implicit = Self::rowwise_dot(&fast_ab(&self.x_reduced, &dot_k), &self.x_reduced);
1070 let dot_h = dot_h_explicit + dot_h_implicit;
1071 (dot_i, dot_h)
1072 }
1073
1074 pub(crate) fn exact_tau_kernel(
1075 &self,
1076 x_tau: &Array2<f64>,
1077 beta: &Array1<f64>,
1078 include_hphi_tau_kernel: bool,
1079 ) -> FirthTauExactKernel {
1080 let deta_partial = gam_linalg::faer_ndarray::fast_av(x_tau, beta);
1104 let x_tau_reduced = self.reduce_explicit_design(x_tau);
1105 let (dot_i_partial, dot_h_partial) =
1106 self.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
1107 let dot_s_partial =
1108 fast_atb(&x_tau_reduced, &self.x_reduced) + fast_atb(&self.x_reduced, &x_tau_reduced);
1109
1110 let first = 0.5 * gam_linalg::faer_ndarray::fast_atv(x_tau, &(&self.w1 * &self.h_diag));
1111 let secondvec =
1112 &(&(&self.w2 * &deta_partial) * &self.h_diag) + &(&self.w1 * &dot_h_partial);
1113 let second = 0.5 * gam_linalg::faer_ndarray::fast_atv(&self.x_dense, &secondvec);
1114 let gphi_tau = first + second;
1115 let phi_tau_partial = 0.5 * RemlState::trace_product(&self.k_reduced, &dot_i_partial)
1116 - 0.5 * Self::trace_diag_product(&self.x_metric_reduced_inv_diag, &dot_s_partial);
1117
1118 let tau_kernel = if include_hphi_tau_kernel {
1119 Some(self.hphi_tau_partial_prepare_from_partials(
1120 x_tau_reduced,
1121 &deta_partial,
1122 dot_h_partial,
1123 dot_i_partial,
1124 ))
1125 } else {
1126 None
1127 };
1128 FirthTauExactKernel {
1129 gphi_tau,
1130 phi_tau_partial,
1131 tau_kernel,
1132 }
1133 }
1134
1135 pub(crate) fn hphi_tau_partial_prepare_from_partials(
1136 &self,
1137 x_tau_reduced: Array2<f64>,
1138 deta_partial: &Array1<f64>,
1139 dot_h_partial: Array1<f64>,
1140 dot_i_partial: Array2<f64>,
1141 ) -> FirthTauPartialKernel {
1142 let dotw1 = &self.w2 * deta_partial;
1143 let dotw2 = &self.w3 * deta_partial;
1144 let dot_k = -self.k_reduced.dot(&dot_i_partial).dot(&self.k_reduced);
1145 FirthTauPartialKernel {
1146 deta_partial: deta_partial.clone(),
1147 dotw1,
1148 dotw2,
1149 dot_h_partial,
1150 x_tau_reduced,
1151 dot_i_partial,
1152 dot_k_reduced: dot_k,
1153 }
1154 }
1155
1156 pub(crate) fn d_beta_hphi_tau_partial_dense(
1157 &self,
1158 x_tau: &Array2<f64>,
1159 beta: &Array1<f64>,
1160 beta_direction: &Array1<f64>,
1161 ) -> Option<Array2<f64>> {
1162 if x_tau.nrows() != self.x_dense.nrows() || x_tau.ncols() != beta.len() {
1163 return None;
1164 }
1165 if !x_tau.iter().any(|value| *value != 0.0) {
1166 return None;
1167 }
1168 let tau_bundle = self.exact_tau_kernel(x_tau, beta, true);
1169 let tau_kernel = tau_bundle.tau_kernel?;
1170 let firth_direction =
1171 self.direction_from_deta(gam_linalg::faer_ndarray::fast_av(&self.x_dense, beta_direction));
1172 let x_tau_v = gam_linalg::faer_ndarray::fast_av(x_tau, beta_direction);
1173 let kernel = self.d_beta_hphi_tau_partial_prepare_from_partials(
1174 &tau_kernel,
1175 &tau_kernel.deta_partial,
1176 &tau_kernel.dot_i_partial,
1177 &firth_direction,
1178 &x_tau_v,
1179 );
1180 let eye = Array2::<f64>::eye(beta_direction.len());
1181 Some(self.d_beta_hphi_tau_partial_apply(x_tau, &kernel, &eye))
1182 }
1183
1184 pub(crate) fn apply_pbar_to_matrix(&self, mat: &Array2<f64>) -> Array2<f64> {
1185 RemlState::apply_hadamard_gram_to_matrix(
1187 &self.x_reduced,
1188 &self.k_reduced,
1189 &self.k_reduced,
1190 mat,
1191 )
1192 }
1193
1194 pub(crate) fn apply_mtau_to_matrix(
1195 &self,
1196 kernel: &FirthTauPartialKernel,
1197 mat: &Array2<f64>,
1198 ) -> Array2<f64> {
1199 if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
1211 return Array2::<f64>::zeros(mat.raw_dim());
1212 }
1213 let mut out = Array2::<f64>::zeros(mat.raw_dim());
1214 for col in 0..mat.ncols() {
1215 let v = mat.column(col).to_owned();
1216 let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
1217 let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
1218 let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, &kernel.x_tau_reduced);
1219
1220 let szt =
1221 RemlState::reduced_crossweighted_gram(&self.x_reduced, &kernel.x_tau_reduced, &v);
1222 let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
1223 let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
1224
1225 let t3 = RemlState::apply_hadamard_gram(
1226 &self.x_reduced,
1227 &self.k_reduced,
1228 &kernel.dot_k_reduced,
1229 &v,
1230 );
1231
1232 let y = 2.0 * (t1 + t2 + t3);
1233 out.column_mut(col).assign(&y);
1234 }
1235 out
1236 }
1237
    /// Applies the τ-partial operator built from `x_tau` and the prepared `kernel`
    /// (see `hphi_tau_partial_prepare_from_partials`) to each column of `rhs`.
    ///
    /// Returns `0.5 * (x_tauᵀ rv + Xᵀ rv_tau)`, with shape `p × rhs.ncols()`.
    /// - `rv = diag(w2 ⊙ h_diag) X rhs − diag(w1) P̄[diag(w1) X rhs]`.
    /// - `rv_tau` mirrors `rv`, with each factor replaced in turn by its τ-partial:
    ///   `etav_tau`, `dotw1`, `dotw2`, `dot_h_partial` and `m_qv_tau`.
    ///
    /// A `rhs` whose row count is not `p` yields zeros rather than a panic. So does
    /// an empty `rhs` or `p == 0`.
    pub(crate) fn hphi_tau_partial_apply(
        &self,
        x_tau: &Array2<f64>,
        kernel: &FirthTauPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        // η-space images of rhs through the base design and through x_tau.
        let etav = fast_ab(&self.x_dense, rhs);
        let etav_tau = fast_ab(x_tau, rhs);
        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
        // τ-partial of qv: both the weight change and the design change contribute.
        let qv_tau = &etav * &kernel.dotw1.view().insert_axis(Axis(1))
            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
        let m_qv = self.apply_pbar_to_matrix(&qv);
        // τ-partial of P̄[qv]: the operator's own change plus P̄ applied to qv_tau.
        let m_qv_tau = self.apply_mtau_to_matrix(kernel, &qv) + self.apply_pbar_to_matrix(&qv_tau);
        let rv = &(&etav * &self.w2.view().insert_axis(Axis(1)))
            * &self.h_diag.view().insert_axis(Axis(1))
            - &(&m_qv * &self.w1.view().insert_axis(Axis(1)));
        let rv_tau = (&(&etav * &kernel.dotw2.view().insert_axis(Axis(1)))
            + &(&etav_tau * &self.w2.view().insert_axis(Axis(1))))
            * self.h_diag.view().insert_axis(Axis(1))
            + &(&etav * &self.w2.view().insert_axis(Axis(1)))
                * &kernel.dot_h_partial.view().insert_axis(Axis(1))
            - &(&m_qv * &kernel.dotw1.view().insert_axis(Axis(1))
                + &m_qv_tau * &self.w1.view().insert_axis(Axis(1)));
        0.5 * (fast_atb(x_tau, &rv) + fast_atb(&self.x_dense, &rv_tau))
    }
1286
    /// Packages the first-order τ partials for a pair `(i, j)` into a
    /// `FirthTauTauPartialKernel`.
    ///
    /// It also forms `dot_k_{i,j}_reduced = −K · dot_i_{i,j}_partial · K`.
    ///
    /// The optional `x_tau_tau_reduced` / `deta_ij_partial` carry a mixed second
    /// design derivative and are stored unchanged. The apply step substitutes zeros
    /// when they are `None`.
    pub(crate) fn hphi_tau_tau_partial_prepare_from_partials(
        &self,
        x_tau_i_reduced: Array2<f64>,
        x_tau_j_reduced: Array2<f64>,
        deta_i_partial: &Array1<f64>,
        deta_j_partial: &Array1<f64>,
        dot_h_i_partial: Array1<f64>,
        dot_h_j_partial: Array1<f64>,
        dot_i_i_partial: Array2<f64>,
        dot_i_j_partial: Array2<f64>,
        x_tau_tau_reduced: Option<Array2<f64>>,
        deta_ij_partial: Option<Array1<f64>>,
    ) -> FirthTauTauPartialKernel {
        // First-order change of the reduced inverse along each τ coordinate.
        let dot_k_i_reduced = -self.k_reduced.dot(&dot_i_i_partial).dot(&self.k_reduced);
        let dot_k_j_reduced = -self.k_reduced.dot(&dot_i_j_partial).dot(&self.k_reduced);
        FirthTauTauPartialKernel {
            x_tau_i_reduced,
            x_tau_j_reduced,
            deta_i_partial: deta_i_partial.clone(),
            deta_j_partial: deta_j_partial.clone(),
            dot_h_i_partial,
            dot_h_j_partial,
            dot_k_i_reduced,
            dot_k_j_reduced,
            dot_i_i_partial,
            dot_i_j_partial,
            x_tau_tau_reduced,
            deta_ij_partial,
        }
    }
1657
1658 pub(crate) fn hphi_tau_tau_partial_apply(
1666 &self,
1667 x_tau_i: &Array2<f64>,
1668 x_tau_j: &Array2<f64>,
1669 kernel: &FirthTauTauPartialKernel,
1670 rhs: &Array2<f64>,
1671 ) -> Array2<f64> {
1672 let p = self.x_dense.ncols();
1673 if rhs.nrows() != p {
1674 return Array2::<f64>::zeros((p, rhs.ncols()));
1675 }
1676 if rhs.ncols() == 0 || p == 0 {
1677 return Array2::<f64>::zeros((p, rhs.ncols()));
1678 }
1679 let n = self.x_dense.nrows();
1680 let m = rhs.ncols();
1681
1682 let z = &self.x_reduced;
1684 let x_r = &self.x_reduced;
1685 let k = &self.k_reduced;
1686 let x_ri = &kernel.x_tau_i_reduced;
1687 let x_rj = &kernel.x_tau_j_reduced;
1688 let deta_i = &kernel.deta_i_partial;
1689 let deta_j = &kernel.deta_j_partial;
1690 let dh_i = &kernel.dot_h_i_partial;
1691 let dh_j = &kernel.dot_h_j_partial;
1692 let dot_k_i = &kernel.dot_k_i_reduced;
1693 let dot_k_j = &kernel.dot_k_j_reduced;
1694 let dot_i_i = &kernel.dot_i_i_partial;
1695 let dot_i_j = &kernel.dot_i_j_partial;
1696
1697 let x_tau_tau_is_some = kernel.x_tau_tau_reduced.is_some();
1700 let x_rij_zero = Array2::<f64>::zeros(x_r.raw_dim());
1701 let x_rij: &Array2<f64> = kernel.x_tau_tau_reduced.as_ref().unwrap_or(&x_rij_zero);
1702 let zeros_n = Array1::<f64>::zeros(n);
1703 let deta_ij = kernel.deta_ij_partial.as_ref().unwrap_or(&zeros_n);
1704
1705 let (eta_v, eta_i_v, eta_j_v) = if RemlState::should_join_independent_dense_products(&[
1709 (n, m, p),
1710 (n, m, p),
1711 (n, m, p),
1712 ]) {
1713 let (eta_v, (eta_i_v, eta_j_v)) = rayon::join(
1714 || fast_ab(&self.x_dense, rhs),
1715 || rayon::join(|| fast_ab(x_tau_i, rhs), || fast_ab(x_tau_j, rhs)),
1716 );
1717 (eta_v, eta_i_v, eta_j_v)
1718 } else {
1719 (
1720 fast_ab(&self.x_dense, rhs),
1721 fast_ab(x_tau_i, rhs),
1722 fast_ab(x_tau_j, rhs),
1723 )
1724 }; let eta_ij_v: Array2<f64> = if x_tau_tau_is_some {
1729 let qt_v = fast_atb(&self.q_basis, rhs); let mut out = fast_ab(x_rij, &qt_v); RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1732 &mut out,
1733 self.observation_weight_sqrt.as_ref(),
1734 );
1735 out
1736 } else {
1737 Array2::<f64>::zeros((n, m))
1738 };
1739
1740 let a_i_reduced = -dot_k_i; let a_j_reduced = -dot_k_j;
1747
1748 let dw_i = &self.w1 * deta_i;
1758 let dw_j = &self.w1 * deta_j;
1759 let ddw_ij = &(&self.w2 * &(deta_i * deta_j)) + &(&self.w1 * deta_ij);
1760 let mut i_ddot = Array2::<f64>::zeros(k.raw_dim());
1761 if x_tau_tau_is_some {
1762 i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
1763 i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
1764 }
1765 i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_rj, &self.w);
1766 i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_ri, &self.w);
1767 i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_r, &dw_j);
1768 i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_ri, &dw_j);
1769 i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_r, &dw_i);
1770 i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rj, &dw_i);
1771 i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);
1772
1773 let k_ddot: Array2<f64> = -k.dot(&i_ddot).dot(k)
1778 + a_i_reduced.dot(dot_i_j).dot(k)
1779 + a_j_reduced.dot(dot_i_i).dot(k);
1780
1781 let dh_ij: Array1<f64> = {
1791 let r = k.ncols();
1792 let can_join = RemlState::should_join_independent_dense_products(&[
1793 (n, r, r),
1794 (n, r, r),
1795 (n, r, r),
1796 (n, r, r),
1797 ]);
1798 let (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k) = if can_join {
1799 let ((xr_kddot, ri_kdot_j), (rj_kdot_i, ri_k)) = rayon::join(
1800 || rayon::join(|| fast_ab(x_r, &k_ddot), || fast_ab(x_ri, dot_k_j)),
1801 || rayon::join(|| fast_ab(x_rj, dot_k_i), || fast_ab(x_ri, k)),
1802 );
1803 (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k)
1804 } else {
1805 (
1806 fast_ab(x_r, &k_ddot),
1807 fast_ab(x_ri, dot_k_j),
1808 fast_ab(x_rj, dot_k_i),
1809 fast_ab(x_ri, k),
1810 )
1811 };
1812
1813 let mut acc = Self::rowwise_dot(&xr_kddot, x_r);
1814 acc = acc + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
1815 acc = acc + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
1816 acc = acc + 2.0 * Self::rowwise_dot(&ri_k, x_rj);
1817 if x_tau_tau_is_some {
1818 let rij_k = fast_ab(x_rij, k);
1819 acc = acc + 2.0 * Self::rowwise_dot(&rij_k, x_r);
1820 }
1821 acc
1822 };
1823
1824 let gamma = &self.w2 * &self.h_diag;
1835 let gamma_dot_i = &(&(&self.w3 * deta_i) * &self.h_diag) + &(&self.w2 * dh_i);
1836 let gamma_dot_j = &(&(&self.w3 * deta_j) * &self.h_diag) + &(&self.w2 * dh_j);
1837 let gamma_ddot = &(&(&(&self.w4 * deta_i) * deta_j) * &self.h_diag)
1838 + &(&(&(&self.w3 * deta_ij) * &self.h_diag)
1839 + &(&(&self.w3 * deta_i) * dh_j)
1840 + &(&(&self.w3 * deta_j) * dh_i)
1841 + &(&self.w2 * &dh_ij));
1842
1843 let mut diag_term = Array2::<f64>::zeros((p, m));
1853 let gamma_col = gamma.view().insert_axis(Axis(1));
1854 let gamma_i_col = gamma_dot_i.view().insert_axis(Axis(1));
1855 let gamma_j_col = gamma_dot_j.view().insert_axis(Axis(1));
1856 let gamma_ij_col = gamma_ddot.view().insert_axis(Axis(1));
1857
1858 diag_term = diag_term + fast_atb(x_tau_i, &(&eta_j_v * &gamma_col));
1860 diag_term = diag_term + fast_atb(x_tau_j, &(&eta_i_v * &gamma_col));
1861 diag_term = diag_term + fast_atb(x_tau_i, &(&eta_v * &gamma_j_col));
1863 diag_term = diag_term + fast_atb(x_tau_j, &(&eta_v * &gamma_i_col));
1864 diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_i_v * &gamma_j_col));
1866 diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_j_v * &gamma_i_col));
1867 diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_v * &gamma_ij_col));
1869 if x_tau_tau_is_some {
1871 let y: Array2<f64> = &eta_v * &gamma_col;
1877 let xt_ij_y: Array2<f64> = if self.observation_weight_sqrt.is_some() {
1878 let mut y_scaled = y.clone();
1879 RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1880 &mut y_scaled,
1881 self.observation_weight_sqrt.as_ref(),
1882 );
1883 self.q_basis.dot(&x_rij.t().dot(&y_scaled))
1884 } else {
1885 self.q_basis.dot(&x_rij.t().dot(&y))
1886 };
1887 diag_term = diag_term + xt_ij_y;
1888 diag_term = diag_term + self.x_dense_t.dot(&(&eta_ij_v * &gamma_col));
1889 }
1890
1891 let w1_col = self.w1.view().insert_axis(Axis(1));
1903 let b_v = &eta_v * &w1_col;
1904
1905 let w2_deta_i = &self.w2 * deta_i;
1907 let w2_deta_j = &self.w2 * deta_j;
1908 let w2_deta_i_col = w2_deta_i.view().insert_axis(Axis(1));
1909 let w2_deta_j_col = w2_deta_j.view().insert_axis(Axis(1));
1910 let bdot_i_v = &(&eta_v * &w2_deta_i_col) + &(&eta_i_v * &w1_col);
1911 let bdot_j_v = &(&eta_v * &w2_deta_j_col) + &(&eta_j_v * &w1_col);
1912
1913 let w3_didj = &(&self.w3 * deta_i) * deta_j;
1919 let w2_dij = &self.w2 * deta_ij;
1920 let bddot_scale = &w3_didj + &w2_dij;
1921 let bddot_scale_col = bddot_scale.view().insert_axis(Axis(1));
1922 let mut bddot_ij_v = &eta_v * &bddot_scale_col;
1923 bddot_ij_v += &(&eta_j_v * &w2_deta_i_col);
1924 bddot_ij_v += &(&eta_i_v * &w2_deta_j_col);
1925 bddot_ij_v += &(&eta_ij_v * &w1_col);
1926
1927 let p_bv = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &b_v);
1929 let p_bddot_ij_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bddot_ij_v);
1930
1931 let pdot_i_bv = self.apply_mtau_from_reduced(x_ri, dot_k_i, &b_v);
1938 let pdot_j_bv = self.apply_mtau_from_reduced(x_rj, dot_k_j, &b_v);
1939 let pdot_i_bdot_j_v = self.apply_mtau_from_reduced(x_ri, dot_k_i, &bdot_j_v);
1940 let pdot_j_bdot_i_v = self.apply_mtau_from_reduced(x_rj, dot_k_j, &bdot_i_v);
1941
1942 let p_bdot_j_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_j_v);
1944 let p_bdot_i_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_i_v);
1945
1946 let p_ddot_b_v = self.apply_p_ddot_ij(
1948 x_r,
1949 x_ri,
1950 x_rj,
1951 x_rij,
1952 k,
1953 dot_k_i,
1954 dot_k_j,
1955 &k_ddot,
1956 x_tau_tau_is_some,
1957 &b_v,
1958 );
1959
1960 let apply_bdot_tau_t =
1974 |scale_deta: &Array1<f64>, x_tau_mat: &Array2<f64>, q_v: &Array2<f64>| {
1975 let scale_col = scale_deta.view().insert_axis(Axis(1));
1976 self.x_dense_t.dot(&(q_v * &scale_col)) + x_tau_mat.t().dot(&(q_v * &w1_col))
1977 };
1978
1979 let apply_bddot_ij_t = |q_v: &Array2<f64>| -> Array2<f64> {
1980 let scale_col_full = bddot_scale.view().insert_axis(Axis(1));
1981 let mut out = self.x_dense_t.dot(&(q_v * &scale_col_full));
1982 out = out + x_tau_j.t().dot(&(q_v * &w2_deta_i_col));
1983 out = out + x_tau_i.t().dot(&(q_v * &w2_deta_j_col));
1984 if x_tau_tau_is_some {
1985 let y = q_v * &w1_col;
1987 let contrib: Array2<f64> = if self.observation_weight_sqrt.is_some() {
1988 let mut y_scaled = y.clone();
1989 RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1990 &mut y_scaled,
1991 self.observation_weight_sqrt.as_ref(),
1992 );
1993 self.q_basis.dot(&x_rij.t().dot(&y_scaled))
1994 } else {
1995 self.q_basis.dot(&x_rij.t().dot(&y))
1996 };
1997 out = out + contrib;
1998 }
1999 out
2000 };
2001
2002 let t1a = apply_bddot_ij_t(&p_bv);
2004 let t1b = self.left_scaled_xt(&self.w1, &p_bddot_ij_v);
2005 let t2a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &p_bdot_j_v);
2007 let t2b = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &p_bdot_i_v);
2008 let t3a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &pdot_j_bv);
2010 let t3b = self.left_scaled_xt(&self.w1, &pdot_j_bdot_i_v);
2011 let t4a = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &pdot_i_bv);
2013 let t4b = self.left_scaled_xt(&self.w1, &pdot_i_bdot_j_v);
2014 let t5 = self.left_scaled_xt(&self.w1, &p_ddot_b_v);
2016
2017 let d2_bpb = t1a + t1b + t2a + t2b + t3a + t3b + t4a + t4b + t5;
2018
2019 0.5 * (diag_term - d2_bpb)
2020 }
2021
    /// Builds the exact mixed second τ-partial pieces (β held fixed) of the Firth
    /// penalty for design derivatives `x_tau_i = ∂X/∂τ_i`, `x_tau_j = ∂X/∂τ_j` and,
    /// optionally, `x_tau_tau = ∂²X/∂τ_i∂τ_j`.
    ///
    /// Returns:
    /// * `phi_tau_tau_partial`: the scalar `a_lik + a_pen`, where
    ///   - `a_lik = ½ tr(K Ï) − ½ tr(K İ_j K İ_i)` is the second partial of
    ///     `½ log|I|`, with `K = self.k_reduced`;
    ///   - `a_pen` uses the τ-derivatives of the reduced metric `X_rᵀ X_r` together
    ///     with `self.x_metric_reduced_inv_diag`. NOTE(review): this has the form of
    ///     `∂²(−½ log|S|)` with `S⁻¹ ≈ diag(g_inv)`; confirm against how the inverse
    ///     diagonal is defined.
    /// * `gphi_tau_tau`: the second τ-partial of the β-gradient
    ///   `½ Xᵀ (w1 ⊙ h)`, expanded by the product rule.
    /// * `tau_tau_kernel`: the kernel for `hphi_tau_tau_partial_apply`, built only
    ///   when `include_hphi_tau_tau_kernel` is true.
    ///
    /// A missing `x_tau_tau` is treated as a zero second-order design.
    pub(crate) fn exact_tau_tau_kernel(
        &self,
        x_tau_i: &Array2<f64>,
        x_tau_j: &Array2<f64>,
        x_tau_tau: Option<&Array2<f64>>,
        beta: &Array1<f64>,
        include_hphi_tau_tau_kernel: bool,
    ) -> FirthTauTauExactKernel {
        // τ-partials of the linear predictor η = X β.
        let deta_i = x_tau_i.dot(beta);
        let deta_j = x_tau_j.dot(beta);
        let deta_ij = x_tau_tau.as_ref().map(|xij| xij.dot(beta));

        // Project the design derivatives into the reduced basis.
        let x_tau_i_reduced = self.reduce_explicit_design(x_tau_i);
        let x_tau_j_reduced = self.reduce_explicit_design(x_tau_j);
        let x_tau_tau_reduced = x_tau_tau.map(|xij| self.reduce_explicit_design(xij));

        // First partials of the reduced information (İ) and hat diagonal (ḣ).
        let (dot_i_i, dot_h_i) = self.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
        let (dot_i_j, dot_h_j) = self.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);

        // Absent η_ij is treated as zero.
        let zeros_n = Array1::<f64>::zeros(self.x_dense.nrows());
        let deta_ij_ref: &Array1<f64> = deta_ij.as_ref().unwrap_or(&zeros_n);
        // Partials of the working weights w(η): ẇ_i, ẇ_j, ẅ_ij.
        let dw_i = &self.w1 * &deta_i;
        let dw_j = &self.w1 * &deta_j;
        let ddw_ij = &(&self.w2 * &(&deta_i * &deta_j)) + &(&self.w1 * deta_ij_ref);

        // Second τ-partial of the reduced information, Ï_ij.
        let x_r = &self.x_reduced;
        let mut i_ddot = Array2::<f64>::zeros(self.k_reduced.raw_dim());
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
        }
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, &x_tau_j_reduced, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, &x_tau_i_reduced, &self.w);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, x_r, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_i_reduced, &dw_j);
        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, x_r, &dw_i);
        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_j_reduced, &dw_i);
        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);

        // Likelihood part: ½ tr(K Ï) − ½ tr(K İ_j K İ_i).
        let k = &self.k_reduced;
        let k_dot_i_i = k.dot(&dot_i_i);
        let k_dot_i_j = k.dot(&dot_i_j);
        let a_lik = 0.5 * RemlState::trace_product(k, &i_ddot)
            - 0.5 * RemlState::trace_product(&k_dot_i_j, &k_dot_i_i);

        // First and second τ-partials of the unweighted reduced metric X_rᵀ X_r
        // (assuming fast_atb(A, B) = Aᵀ B).
        let dot_s_i = fast_atb(&x_tau_i_reduced, x_r) + fast_atb(x_r, &x_tau_i_reduced);
        let dot_s_j = fast_atb(&x_tau_j_reduced, x_r) + fast_atb(x_r, &x_tau_j_reduced);
        let mut s_ddot = Array2::<f64>::zeros(k.raw_dim());
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            s_ddot = s_ddot + fast_atb(x_rij, x_r) + fast_atb(x_r, x_rij);
        }
        s_ddot = s_ddot
            + fast_atb(&x_tau_i_reduced, &x_tau_j_reduced)
            + fast_atb(&x_tau_j_reduced, &x_tau_i_reduced);
        // Metric part, using only the stored inverse diagonal g_inv:
        // ½ Σ_kl g_k g_l Ṡ_j[k,l] Ṡ_i[k,l] − ½ Σ_k g_k S̈[k,k].
        let g_inv = &self.x_metric_reduced_inv_diag;
        let rdim = k.nrows();
        let mut a_pen = 0.0_f64;
        for kk in 0..rdim {
            for ll in 0..rdim {
                a_pen += 0.5 * g_inv[kk] * g_inv[ll] * dot_s_j[[kk, ll]] * dot_s_i[[kk, ll]];
            }
            a_pen -= 0.5 * g_inv[kk] * s_ddot[[kk, kk]];
        }
        let phi_tau_tau_partial = a_lik + a_pen;

        // K̇ = −K İ K and K̈_ij = −K Ï K + K İ_i K İ_j K + K İ_j K İ_i K.
        let dot_k_i = -k.dot(&dot_i_i).dot(k);
        let dot_k_j = -k.dot(&dot_i_j).dot(k);
        let a_i_red = -&dot_k_i; let a_j_red = -&dot_k_j; let k_ddot: Array2<f64> =
            -k.dot(&i_ddot).dot(k) + a_i_red.dot(&dot_i_j).dot(k) + a_j_red.dot(&dot_i_i).dot(k);

        // ḧ_ij: second τ-partial of the hat diagonal h = diag(X_r K X_rᵀ).
        let n = self.x_dense.nrows();
        let mut dh_ij = Array1::<f64>::zeros(n);
        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
            let rij_k = x_rij.dot(k);
            dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rij_k, x_r);
        }
        let xr_kddot = x_r.dot(&k_ddot);
        dh_ij = dh_ij + Self::rowwise_dot(&xr_kddot, x_r);
        let ri_kdot_j = x_tau_i_reduced.dot(&dot_k_j);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
        let rj_kdot_i = x_tau_j_reduced.dot(&dot_k_i);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
        let ri_k = x_tau_i_reduced.dot(k);
        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_k, &x_tau_j_reduced);

        // Second τ-partial of the gradient ½ Xᵀ (w1 ⊙ h), term by term:
        // ½ X_ijᵀ(w1 h) + ½ X_iᵀ ∂_j(w1 h) + ½ X_jᵀ ∂_i(w1 h) + ½ Xᵀ ∂²_ij(w1 h).
        let w1_h = &self.w1 * &self.h_diag;
        let mut gphi_tau_tau = Array1::<f64>::zeros(self.x_dense.ncols());
        if let Some(x_ij) = x_tau_tau.as_ref() {
            gphi_tau_tau = gphi_tau_tau + 0.5 * x_ij.t().dot(&w1_h);
        }
        let inner_j = &(&(&self.w2 * &deta_j) * &self.h_diag) + &(&self.w1 * &dot_h_j);
        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_i.t().dot(&inner_j);

        let v_tau_i = &(&(&self.w2 * &deta_i) * &self.h_diag) + &(&self.w1 * &dot_h_i);
        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_j.t().dot(&v_tau_i);

        // ∂²_ij(w1 ⊙ h).
        let mut v_dot_ij = &(&(&self.w3 * &deta_j) * &deta_i) * &self.h_diag;
        v_dot_ij += &(&(&self.w2 * deta_ij_ref) * &self.h_diag);
        v_dot_ij += &(&(&self.w2 * &deta_i) * &dot_h_j);
        v_dot_ij += &(&(&self.w2 * &deta_j) * &dot_h_i);
        v_dot_ij += &(&self.w1 * &dh_ij);
        gphi_tau_tau = gphi_tau_tau + 0.5 * self.x_dense.t().dot(&v_dot_ij);

        // Optionally hand the first-order pieces to the Hessian-apply kernel;
        // the reduced designs and İ/ḣ partials are moved into it.
        let tau_tau_kernel = if include_hphi_tau_tau_kernel {
            Some(self.hphi_tau_tau_partial_prepare_from_partials(
                x_tau_i_reduced,
                x_tau_j_reduced,
                &deta_i,
                &deta_j,
                dot_h_i,
                dot_h_j,
                dot_i_i,
                dot_i_j,
                x_tau_tau_reduced,
                deta_ij,
            ))
        } else {
            None
        };

        FirthTauTauExactKernel {
            phi_tau_tau_partial,
            gphi_tau_tau,
            tau_tau_kernel,
        }
    }
2255
2256 pub(crate) fn apply_mtau_from_reduced(
2263 &self,
2264 x_tau_reduced: &Array2<f64>,
2265 dot_k_reduced: &Array2<f64>,
2266 mat: &Array2<f64>,
2267 ) -> Array2<f64> {
2268 if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
2269 return Array2::<f64>::zeros(mat.raw_dim());
2270 }
2271 let mut out = Array2::<f64>::zeros(mat.raw_dim());
2272 for col in 0..mat.ncols() {
2273 let v = mat.column(col).to_owned();
2274 let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
2275 let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
2276 let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, x_tau_reduced);
2277
2278 let szt = RemlState::reduced_crossweighted_gram(&self.x_reduced, x_tau_reduced, &v);
2279 let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
2280 let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
2281
2282 let t3 =
2283 RemlState::apply_hadamard_gram(&self.x_reduced, &self.k_reduced, dot_k_reduced, &v);
2284
2285 let y = 2.0 * (t1 + t2 + t3);
2286 out.column_mut(col).assign(&y);
2287 }
2288 out
2289 }
2290
    /// Applies the mixed second τ-partial `𝒫̈_ij` of the Hadamard-Gram operator to
    /// every column of `mat`.
    ///
    /// For each column `v` the output is `2·mdot_mdot + 2·m_mddot`. The
    /// interpretation below holds under these assumptions (the helpers are not
    /// visible here — confirm):
    /// * `reducedweighted_gram(A, v) = Aᵀ diag(v) A`;
    /// * `reduced_crossweighted_gram(A, B, v) = Aᵀ diag(v) B`;
    /// * `rowwise_bilinear(A, M, B)_r = a_rᵀ M b_r`.
    ///
    /// With `H = X_r K X_rᵀ`, the two pieces are then
    /// * `mdot_mdot = diag(Ḣ_i diag(v) Ḣ_j)`: nine blocks, one per pair of the
    ///   three terms of `Ḣ_i` and `Ḣ_j`;
    /// * `m_mddot = diag(H diag(v) Ḧ_ij)`: nine blocks, one per term of `Ḧ_ij`.
    ///
    /// The two `X_ij` terms are skipped when `x_tau_tau_is_some` is false, so
    /// `x_rij` is read only in that case.
    ///
    /// Returns an all-zero matrix shaped like `mat` when it has no columns or its
    /// row count differs from the number of observations.
    pub(crate) fn apply_p_ddot_ij(
        &self,
        x_r: &Array2<f64>,
        x_ri: &Array2<f64>,
        x_rj: &Array2<f64>,
        x_rij: &Array2<f64>,
        k: &Array2<f64>,
        dot_k_i: &Array2<f64>,
        dot_k_j: &Array2<f64>,
        k_ddot: &Array2<f64>,
        x_tau_tau_is_some: bool,
        mat: &Array2<f64>,
    ) -> Array2<f64> {
        let n = self.x_dense.nrows();
        let m = mat.ncols();
        if mat.nrows() != n || m == 0 {
            return Array2::<f64>::zeros(mat.raw_dim());
        }
        let mut out = Array2::<f64>::zeros((n, m));
        for col in 0..m {
            let v = mat.column(col).to_owned();
            // v-weighted (cross-)Grams S_zz, S_zj, S_iz, S_jz, S_ij, then the
            // mdot_mdot accumulator.
            let s_zz = RemlState::reducedweighted_gram(x_r, &v); let s_zj = RemlState::reduced_crossweighted_gram(x_r, x_rj, &v); let s_iz = RemlState::reduced_crossweighted_gram(x_ri, x_r, &v); let s_jz = RemlState::reduced_crossweighted_gram(x_rj, x_r, &v); let s_ij = RemlState::reduced_crossweighted_gram(x_ri, x_rj, &v); let mut mdot_mdot = Array1::<f64>::zeros(n);
            // --- diag(Ḣ_i V Ḣ_j) ---
            {
                // (X_i K Xᵀ) V (X K X_jᵀ)
                let core = k.dot(&s_zz).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_rj);
            }
            {
                // (X_i K Xᵀ) V (X K̇_j Xᵀ)
                let core = k.dot(&s_zz).dot(&dot_k_j.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // (X_i K Xᵀ) V (X_j K Xᵀ)
                let core = k.dot(&s_zj).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // (X K̇_i Xᵀ) V (X K X_jᵀ)
                let core = dot_k_i.dot(&s_zz).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
            }
            {
                // (X K̇_i Xᵀ) V (X K̇_j Xᵀ)
                let core = dot_k_i.dot(&s_zz).dot(&dot_k_j.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X K̇_i Xᵀ) V (X_j K Xᵀ)
                let core = dot_k_i.dot(&s_zj).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X K X_iᵀ) V (X K X_jᵀ)
                let core = k.dot(&s_iz).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
            }
            {
                // (X K X_iᵀ) V (X K̇_j Xᵀ)
                let core = k.dot(&s_iz).dot(&dot_k_j.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // (X K X_iᵀ) V (X_j K Xᵀ)
                let core = k.dot(&s_ij).dot(&k.t());
                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
            }

            // --- diag(H V Ḧ_ij) ---
            let mut m_mddot = Array1::<f64>::zeros(n);
            if x_tau_tau_is_some {
                // Ḧ term X K X_ijᵀ
                let core = k.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_rij, &core, x_r);
            }
            if x_tau_tau_is_some {
                // Ḧ term X_ij K Xᵀ
                let s_ijz = RemlState::reduced_crossweighted_gram(x_rij, x_r, &v);
                let core = k.dot(&s_ijz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // Ḧ term X K̇_j X_iᵀ
                let core = dot_k_j.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // Ḧ term X_i K̇_j Xᵀ
                let core = dot_k_j.dot(&s_iz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // Ḧ term X K̇_i X_jᵀ
                let core = dot_k_i.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
            }
            {
                // Ḧ term X_j K̇_i Xᵀ
                let core = dot_k_i.dot(&s_jz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }
            {
                // Ḧ term X_j K X_iᵀ
                let core = k.dot(&s_jz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
            }
            {
                // Ḧ term X_i K X_jᵀ
                let core = k.dot(&s_iz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
            }
            {
                // Ḧ term X K̈_ij Xᵀ
                let core = k_ddot.dot(&s_zz).dot(k);
                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
            }

            // Second derivative of the element-wise square: 2 Ḣ_i∘Ḣ_j + 2 H∘Ḧ_ij.
            let col_out = 2.0 * mdot_mdot + 2.0 * m_mddot;
            out.column_mut(col).assign(&col_out);
        }
        out
    }
2452
    /// Builds the kernel for the β-directional derivative (along `beta_direction`)
    /// of the τ-partial Firth Hessian term.
    ///
    /// It computes:
    /// * `d_β İ_τ`: two cross terms weighted by `w1 ⊙ ∂η_v`, plus a diagonal term
    ///   weighted by `dotw1 ⊙ ∂η_v + w1 ⊙ x_tau_v`;
    /// * `d_β K̇_τ = A_v İ_τ K − K (d_β İ_τ) K + K İ_τ A_v`, with
    ///   `A_v = beta_direction.a_u_reduced`;
    /// * `d_β ḣ_τ = −2·rowwise_bilinear(X_τr, A_v, X_r) + diag(X_r (d_β K̇_τ) X_rᵀ)`.
    ///
    /// NOTE(review): the signs are consistent with `A_v = −d_β K`, i.e.
    /// `A_v = K İ_v K`; confirm where `FirthDirection` is built.
    ///
    /// All other kernel fields are cloned from the inputs.
    pub(crate) fn d_beta_hphi_tau_partial_prepare_from_partials(
        &self,
        tau_kernel: &FirthTauPartialKernel,
        deta_partial: &Array1<f64>,
        dot_i_partial: &Array2<f64>,
        beta_direction: &FirthDirection,
        x_tau_v: &Array1<f64>,
    ) -> FirthTauBetaPartialKernel {
        // d_β of İ_τ = X_τrᵀ W X_r + X_rᵀ W X_τr + X_rᵀ diag(ẇ_τ) X_r.
        let s_v = &self.w1 * &beta_direction.deta;
        let mixed_diag_weight = &(&tau_kernel.dotw1 * &beta_direction.deta) + &(&self.w1 * x_tau_v);
        let cross1 =
            RemlState::reduced_crossweighted_gram(&tau_kernel.x_tau_reduced, &self.x_reduced, &s_v);
        let cross2 =
            RemlState::reduced_crossweighted_gram(&self.x_reduced, &tau_kernel.x_tau_reduced, &s_v);
        let diag_piece = RemlState::reducedweighted_gram(&self.x_reduced, &mixed_diag_weight);
        let d_beta_dot_i = &cross1 + &cross2 + &diag_piece;

        // Product rule on K̇_τ = −K İ_τ K.
        let term_a = beta_direction
            .a_u_reduced
            .dot(dot_i_partial)
            .dot(&self.k_reduced);
        let term_b = self.k_reduced.dot(&d_beta_dot_i).dot(&self.k_reduced);
        let term_c = self
            .k_reduced
            .dot(dot_i_partial)
            .dot(&beta_direction.a_u_reduced);
        let d_beta_dot_k = &term_a - &term_b + &term_c;

        // Product rule on ḣ_τ = 2 rowdot(X_τr K, X_r) + diag(X_r K̇_τ X_rᵀ).
        let cross_diag = Self::rowwise_bilinear(
            &tau_kernel.x_tau_reduced,
            &beta_direction.a_u_reduced,
            &self.x_reduced,
        );
        let inner_diag = RemlState::reduced_diag_gram(&self.x_reduced, &d_beta_dot_k);
        let d_beta_dot_h = -2.0 * &cross_diag + &inner_diag;

        FirthTauBetaPartialKernel {
            x_tau_reduced: tau_kernel.x_tau_reduced.clone(),
            deta_partial: deta_partial.clone(),
            dot_h_partial: tau_kernel.dot_h_partial.clone(),
            dot_i_partial: dot_i_partial.clone(),
            dot_k_reduced: tau_kernel.dot_k_reduced.clone(),
            deta_v: beta_direction.deta.clone(),
            deta_tau_v: x_tau_v.clone(),
            a_v_reduced: beta_direction.a_u_reduced.clone(),
            dh_v: beta_direction.dh.clone(),
            b_vvec: beta_direction.b_uvec.clone(),
            d_beta_dot_k,
            d_beta_dot_h,
        }
    }
2538
    /// Applies the β-directional derivative of the τ-partial Hadamard-Gram operator
    /// (see `apply_mtau_from_reduced`) to every column of `mat`.
    ///
    /// Each of that operator's three per-column terms is differentiated through
    /// `K`, using `d_β K = −A_v` (`A_v = kernel.a_v_reduced`):
    /// * `t1`, `t4`: the row-wise bilinear `(X_r, K S_zz K, X_τr)` term;
    /// * `t2`, `t5`: the diagonal `(X_r, K S_zτ K)` term;
    /// * `t3`, `t6`: the Hadamard term `(K, K̇_τ)`. Here `t6` carries the
    ///   `d_β K̇_τ` contribution.
    ///
    /// The output column is `2 · (t1 + … + t6)`. Returns an all-zero matrix shaped
    /// like `mat` when it has no columns or its row count differs from the number
    /// of observations.
    pub(crate) fn apply_p_tau_v_to_matrix(
        &self,
        kernel: &FirthTauBetaPartialKernel,
        mat: &Array2<f64>,
    ) -> Array2<f64> {
        let n = self.x_dense.nrows();
        if mat.nrows() != n || mat.ncols() == 0 {
            return Array2::<f64>::zeros(mat.raw_dim());
        }
        let z = &self.x_reduced;
        let z_tau = &kernel.x_tau_reduced;
        let k_r = &self.k_reduced;
        // A_v, K̇_τ, d_β K̇_τ, then the output buffer.
        let a_v = &kernel.a_v_reduced; let dot_k_tau = &kernel.dot_k_reduced; let d_beta_dot_k = &kernel.d_beta_dot_k; let mut out = Array2::<f64>::zeros(mat.raw_dim());
        for col in 0..mat.ncols() {
            let v = mat.column(col).to_owned();
            let s_zz = RemlState::reducedweighted_gram(z, &v);
            let s_z_ztau = RemlState::reduced_crossweighted_gram(z, z_tau, &v);

            // d_β on the left K of each K·S·K sandwich.
            let mid_1 = a_v.dot(&s_zz).dot(k_r);
            let t1 = -Self::rowwise_bilinear(z, &mid_1, z_tau);
            let mid_2 = a_v.dot(&s_z_ztau).dot(k_r);
            let t2 = -RemlState::reduced_diag_gram(z, &mid_2);
            let mid_3 = a_v.dot(&s_zz).dot(dot_k_tau);
            let t3 = -RemlState::reduced_diag_gram(z, &mid_3);
            // d_β on the right K.
            let mid_4 = k_r.dot(&s_zz).dot(a_v);
            let t4 = -Self::rowwise_bilinear(z, &mid_4, z_tau);
            let mid_5 = k_r.dot(&s_z_ztau).dot(a_v);
            let t5 = -RemlState::reduced_diag_gram(z, &mid_5);
            // d_β on K̇_τ itself.
            let t6 = RemlState::apply_hadamard_gram(z, k_r, d_beta_dot_k, &v);

            let y = 2.0 * (t1 + t2 + t3 + t4 + t5 + t6);
            out.column_mut(col).assign(&y);
        }
        out
    }
2600
    /// Applies the β-directional derivative (direction data in `kernel`) of the
    /// τ-partial Firth Hessian term to the columns of `rhs`.
    ///
    /// This mirrors the τ-partial apply, which returns `½ (X_τᵀ r_V + Xᵀ r_{τ,V})`.
    /// Each quantity feeding `r_V` and `r_{τ,V}` (weights, hat diagonal, the
    /// projected blocks `M q`) is differentiated by the product rule along the
    /// β-direction, and the result is `½ (X_τᵀ d_β r_V + Xᵀ d_β r_{τ,V})`.
    ///
    /// NOTE(review): `kernel.b_vvec` is used as `d_β w1`, i.e. presumably
    /// `w2 ⊙ ∂η_v`; confirm where `FirthDirection::b_uvec` is built.
    ///
    /// Returns a `p × m` zero matrix when `rhs` does not have `p` rows, has no
    /// columns, or `p == 0`.
    pub(crate) fn d_beta_hphi_tau_partial_apply(
        &self,
        x_tau: &Array2<f64>,
        kernel: &FirthTauBetaPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        // X V and X_τ V.
        let etav = fast_ab(&self.x_dense, rhs);
        let etav_tau = fast_ab(x_tau, rhs);
        let deta_v = &kernel.deta_v;
        let deta_tau_v = &kernel.deta_tau_v;
        let eta_tau = &kernel.deta_partial;
        let dot_h = &kernel.dot_h_partial;

        // τ-partials of w1 and w2.
        let dotw1 = &self.w2 * eta_tau;
        let dotw2 = &self.w3 * eta_tau;

        // β-directional derivatives: c_v = d_β(w2 ⊙ h), then d_β of dotw1/dotw2.
        let c_v = &(&(&self.w3 * deta_v) * &self.h_diag) + &(&self.w2 * &kernel.dh_v);
        let b_vvec = &kernel.b_vvec;
        let d_beta_dotw1_vec = &(&(&self.w3 * deta_v) * eta_tau) + &(&self.w2 * deta_tau_v);
        let d_beta_dotw2_vec = &(&(&self.w4 * deta_v) * eta_tau) + &(&self.w3 * deta_tau_v);

        // q = diag(w1) X V and its τ-partial; M q and its τ-partial.
        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
        let qv_tau = &etav * &dotw1.view().insert_axis(Axis(1))
            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
        let m_qv = self.apply_pbar_to_matrix(&qv);
        // A temporary τ-kernel rebuilt from the β-kernel fields (cloned) so that
        // apply_mtau_to_matrix can be reused.
        let tau_kernel_view = FirthTauPartialKernel {
            deta_partial: eta_tau.clone(),
            dotw1: dotw1.clone(),
            dotw2: dotw2.clone(),
            dot_h_partial: dot_h.clone(),
            x_tau_reduced: kernel.x_tau_reduced.clone(),
            dot_i_partial: kernel.dot_i_partial.clone(),
            dot_k_reduced: kernel.dot_k_reduced.clone(),
        };
        let m_qv_tau =
            self.apply_mtau_to_matrix(&tau_kernel_view, &qv) + self.apply_pbar_to_matrix(&qv_tau);

        // d_β of q and q_τ.
        let d_beta_qv = &etav * &b_vvec.view().insert_axis(Axis(1));
        let d_beta_qv_tau = &etav * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
            + &etav_tau * &b_vvec.view().insert_axis(Axis(1));

        // d_β(M q) = (d_β M) q + M (d_β q).
        let d_beta_m_qv = self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv)
            + self.apply_pbar_to_matrix(&d_beta_qv);

        // d_β(Ṁ_τ q + M q_τ), four product-rule terms.
        let d_beta_m_qv_tau = self.apply_p_tau_v_to_matrix(kernel, &qv)
            + self.apply_mtau_to_matrix(&tau_kernel_view, &d_beta_qv)
            + self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv_tau)
            + self.apply_pbar_to_matrix(&d_beta_qv_tau);

        // d_β r_V, where r_V = XV ⊙ (w2 h) − (M q) ⊙ w1.
        let d_beta_rv = &etav * &c_v.view().insert_axis(Axis(1))
            - &m_qv * &b_vvec.view().insert_axis(Axis(1))
            - &d_beta_m_qv * &self.w1.view().insert_axis(Axis(1));

        // d_β(dotw2 ⊙ h) and d_β(w2 ⊙ ḣ_τ).
        let d_beta_dotw2_h = &(&d_beta_dotw2_vec * &self.h_diag) + &(&dotw2 * &kernel.dh_v);
        let d_beta_w2_doth = &(&(&self.w3 * deta_v) * dot_h) + &(&self.w2 * &kernel.d_beta_dot_h);

        // d_β r_{τ,V}, term by term against the τ-partial residual.
        let d_beta_rv_tau = &etav * &d_beta_dotw2_h.view().insert_axis(Axis(1))
            + &etav_tau * &c_v.view().insert_axis(Axis(1))
            + &etav * &d_beta_w2_doth.view().insert_axis(Axis(1))
            - &d_beta_m_qv * &dotw1.view().insert_axis(Axis(1))
            - &m_qv * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
            - &d_beta_m_qv_tau * &self.w1.view().insert_axis(Axis(1))
            - &m_qv_tau * &b_vvec.view().insert_axis(Axis(1));

        0.5 * (x_tau.t().dot(&d_beta_rv) + self.x_dense.t().dot(&d_beta_rv_tau))
    }
2708}
2709
2710#[cfg(test)]
2711mod tests {
2712 use super::*;
2713 use crate::mixture_link::logit_inverse_link_jet5;
2714 use gam_problem::StandardLink;
2715 use ndarray::{Array1, Array2, array};
2716
2717 pub(crate) fn build_logit_firth_dense_operator(
2718 x_dense: &Array2<f64>,
2719 eta: &Array1<f64>,
2720 ) -> Result<FirthDenseOperator, EstimationError> {
2721 FirthDenseOperator::build_with_observation_weights_impl(
2722 &InverseLink::Standard(StandardLink::Logit),
2723 x_dense,
2724 eta,
2725 None,
2726 )
2727 }
2728
2729 pub(crate) fn build_weighted_logit_firth_dense_operator(
2730 x_dense: &Array2<f64>,
2731 eta: &Array1<f64>,
2732 observation_weights: ndarray::ArrayView1<'_, f64>,
2733 ) -> Result<FirthDenseOperator, EstimationError> {
2734 FirthDenseOperator::build_with_observation_weights_impl(
2735 &InverseLink::Standard(StandardLink::Logit),
2736 x_dense,
2737 eta,
2738 Some(observation_weights),
2739 )
2740 }
2741
2742 pub(crate) fn logisticweight(eta: f64) -> f64 {
2743 logit_inverse_link_jet5(eta).d1
2744 }
2745
2746 pub(crate) fn firthphivalue(x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
2747 let eta = x.dot(beta);
2748 let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
2749 op.jeffreys_logdet()
2750 }
2751
2752 pub(crate) fn firthgradphi(x: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
2753 let eta = x.dot(beta);
2754 let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
2755 op.jeffreys_beta_gradient()
2756 }
2757
2758 pub(crate) fn weighted_firthphivalue(
2759 x: &Array2<f64>,
2760 beta: &Array1<f64>,
2761 observation_weights: &Array1<f64>,
2762 ) -> f64 {
2763 let eta = x.dot(beta);
2764 let op = build_weighted_logit_firth_dense_operator(x, &eta, observation_weights.view())
2765 .expect("weighted firth operator");
2766 op.jeffreys_logdet()
2767 }
2768
2769 #[test]
2770 pub(crate) fn firth_reduced_fisher_logdet_is_finite_for_barely_pd_matrix() {
2771 let fisher = array![[16.0, 0.0], [0.0, 1e-15]];
2772 let (k_reduced, half_log_det) = RemlState::reduced_fisher_inverse_and_half_logdet(&fisher)
2773 .expect("barely positive-definite reduced fisher");
2774 let expected = 0.5 * 16.0_f64.ln();
2775
2776 assert!(
2777 half_log_det.is_finite(),
2778 "barely positive-definite reduced fisher produced non-finite half logdet: {half_log_det}"
2779 );
2780 assert!(
2781 (half_log_det - expected).abs() < 1e-12,
2782 "near-null Fisher direction should be excluded from pseudo-logdet: got {half_log_det}, expected {expected}"
2783 );
2784 assert!(
2785 k_reduced.iter().all(|value| value.is_finite()),
2786 "barely positive-definite reduced fisher produced non-finite inverse entries: {k_reduced:?}"
2787 );
2788 assert!(
2789 k_reduced[[1, 1]].abs() < f64::EPSILON,
2790 "near-null Fisher direction should be excluded from pseudo-inverse: {k_reduced:?}"
2791 );
2792 }
2793
2794 #[test]
2795 pub(crate) fn firth_logisticweight_derivatives_match_finite_difference() {
2796 let x = array![
2815 [1.0, -1.1, 0.2],
2816 [1.0, -0.5, -0.6],
2817 [1.0, 0.0, 0.3],
2818 [1.0, 0.8, -0.4],
2819 [1.0, 1.2, 0.7],
2820 ];
2821 let beta = array![0.15, -0.6, 0.35];
2822 let eta = x.dot(&beta);
2823 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
2824
2825 let h = 1e-2_f64;
2826 let w = |z: f64| logisticweight(z);
2827 let d1direct = |z: f64| (w(z + h) - w(z - h)) / (2.0 * h);
2828 let d2direct = |z: f64| (w(z + h) - 2.0 * w(z) + w(z - h)) / (h * h);
2829 let d3direct = |z: f64| {
2830 (-w(z - 2.0 * h) + 2.0 * w(z - h) - 2.0 * w(z + h) + w(z + 2.0 * h)) / (2.0 * h.powi(3))
2831 };
2832 let d4direct = |z: f64| {
2833 (w(z - 2.0 * h) - 4.0 * w(z - h) + 6.0 * w(z) - 4.0 * w(z + h) + w(z + 2.0 * h))
2834 / h.powi(4)
2835 };
2836 for i in 0..eta.len() {
2837 let z = eta[i];
2838 let wfd = w(z);
2839 let w1fd = d1direct(z);
2840 let w2fd = d2direct(z);
2841 let w3fd = d3direct(z);
2842 let w4fd = d4direct(z);
2843
2844 assert!((op.w[i] - wfd).abs() < 1e-12);
2845 assert_eq!(op.w1[i].signum(), w1fd.signum());
2846 assert_eq!(op.w2[i].signum(), w2fd.signum());
2847 assert_eq!(op.w3[i].signum(), w3fd.signum());
2848 assert_eq!(op.w4[i].signum(), w4fd.signum());
2849 assert!((op.w1[i] - w1fd).abs() < 1e-5);
2850 assert!((op.w2[i] - w2fd).abs() < 1e-4);
2851 assert!((op.w3[i] - w3fd).abs() < 1e-4);
2852 assert!((op.w4[i] - w4fd).abs() < 1e-3);
2853 }
2854 }
2855
2856 #[test]
2857 pub(crate) fn weighted_firth_jeffreys_gradient_matches_finite_difference() {
2858 let x = array![
2859 [1.0, -0.7, 0.3],
2860 [1.0, -0.2, -0.4],
2861 [1.0, 0.5, 0.1],
2862 [1.0, 1.1, -0.6],
2863 [1.0, 1.6, 0.8],
2864 ];
2865 let beta = array![0.2, -0.45, 0.25];
2866 let observation_weights = array![1.0, 0.5, 2.0, 1.5, 0.75];
2867 let eta = x.dot(&beta);
2868 let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
2869 .expect("weighted firth operator");
2870 let grad = op.jeffreys_beta_gradient();
2871 let h = 1e-6;
2872
2873 for j in 0..beta.len() {
2874 let mut beta_plus = beta.clone();
2875 beta_plus[j] += h;
2876 let mut beta_minus = beta.clone();
2877 beta_minus[j] -= h;
2878 let fd = (weighted_firthphivalue(&x, &beta_plus, &observation_weights)
2879 - weighted_firthphivalue(&x, &beta_minus, &observation_weights))
2880 / (2.0 * h);
2881 assert!(
2882 (grad[j] - fd).abs() < 1e-5,
2883 "weighted Firth gradient mismatch at {}: analytic={}, fd={}",
2884 j,
2885 grad[j],
2886 fd
2887 );
2888 }
2889 }
2890
2891 pub(crate) fn build_link_firth_op(
2899 link: StandardLink,
2900 x: &Array2<f64>,
2901 beta: &Array1<f64>,
2902 ) -> FirthDenseOperator {
2903 let eta = x.dot(beta);
2904 FirthDenseOperator::build_with_observation_weights_impl(
2905 &InverseLink::Standard(link),
2906 x,
2907 &eta,
2908 None,
2909 )
2910 .expect("link-general firth operator")
2911 }
2912
2913 pub(crate) fn link_firth_phi(link: StandardLink, x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
2914 build_link_firth_op(link, x, beta).jeffreys_logdet()
2915 }
2916
2917 pub(crate) fn link_firth_grad(
2918 link: StandardLink,
2919 x: &Array2<f64>,
2920 beta: &Array1<f64>,
2921 ) -> Array1<f64> {
2922 build_link_firth_op(link, x, beta).jeffreys_beta_gradient()
2923 }
2924
2925 pub(crate) fn numeric_firth_hessian(
2930 link: StandardLink,
2931 x: &Array2<f64>,
2932 beta: &Array1<f64>,
2933 h: f64,
2934 ) -> Array2<f64> {
2935 let p = beta.len();
2936 let mut hess = Array2::<f64>::zeros((p, p));
2937 for j in 0..p {
2938 let mut bp = beta.clone();
2939 bp[j] += h;
2940 let mut bm = beta.clone();
2941 bm[j] -= h;
2942 let gp = link_firth_grad(link, x, &bp);
2943 let gm = link_firth_grad(link, x, &bm);
2944 let col = (&gp - &gm) / (2.0 * h);
2945 hess.column_mut(j).assign(&col);
2946 }
2947 hess
2948 }
2949
2950 pub(crate) fn fixed_design_5x3() -> Array2<f64> {
2952 array![
2953 [1.0, -1.10, 0.35],
2954 [1.0, -0.40, -0.65],
2955 [1.0, 0.15, 0.20],
2956 [1.0, 0.80, -0.45],
2957 [1.0, 1.25, 0.70],
2958 ]
2959 }
2960
2961 #[test]
2962 pub(crate) fn link_general_logit_path_reproduces_historical_logit_build() {
2963 let x = fixed_design_5x3();
2968 let beta = array![0.20, -0.55, 0.30];
2969 let eta = x.dot(&beta);
2970
2971 let historical = build_logit_firth_dense_operator(&x, &eta).expect("historical logit");
2972 let link_general = FirthDenseOperator::build_with_observation_weights_impl(
2973 &InverseLink::Standard(StandardLink::Logit),
2974 &x,
2975 &eta,
2976 None,
2977 )
2978 .expect("link-general logit");
2979
2980 assert_eq!(
2981 historical.jeffreys_logdet(),
2982 link_general.jeffreys_logdet(),
2983 "logit Φ must be bit-identical through the link-general path"
2984 );
2985 let g_hist = historical.jeffreys_beta_gradient();
2986 let g_link = link_general.jeffreys_beta_gradient();
2987 for j in 0..g_hist.len() {
2988 assert_eq!(
2989 g_hist[j], g_link[j],
2990 "logit gradient component {j} must be bit-identical"
2991 );
2992 }
2993 let hat_hist = historical.pirls_hat_diag();
2994 let hat_link = link_general.pirls_hat_diag();
2995 for i in 0..hat_hist.len() {
2996 assert_eq!(
2997 hat_hist[i], hat_link[i],
2998 "logit PIRLS hat diagonal {i} must be bit-identical"
2999 );
3000 }
3001 for i in 0..eta.len() {
3002 assert_eq!(historical.w[i], link_general.w[i]);
3003 assert_eq!(historical.w1[i], link_general.w1[i]);
3004 assert_eq!(historical.w2[i], link_general.w2[i]);
3005 assert_eq!(historical.w3[i], link_general.w3[i]);
3006 assert_eq!(historical.w4[i], link_general.w4[i]);
3007 }
3008 }
3009
3010 #[test]
3011 pub(crate) fn link_general_probit_jeffreys_gradient_matches_finite_difference() {
3012 let x = fixed_design_5x3();
3015 let beta = array![0.10, -0.40, 0.25];
3016 let grad = link_firth_grad(StandardLink::Probit, &x, &beta);
3017 let h = 1e-6_f64;
3018 let mut max_rel = 0.0_f64;
3019 for j in 0..beta.len() {
3020 let mut bp = beta.clone();
3021 bp[j] += h;
3022 let mut bm = beta.clone();
3023 bm[j] -= h;
3024 let fd = (link_firth_phi(StandardLink::Probit, &x, &bp)
3025 - link_firth_phi(StandardLink::Probit, &x, &bm))
3026 / (2.0 * h);
3027 let denom = grad[j].abs().max(fd.abs()).max(1e-8);
3028 let rel = (grad[j] - fd).abs() / denom;
3029 max_rel = max_rel.max(rel);
3030 assert!(
3031 rel < 1e-6,
3032 "probit Firth gradient mismatch at {j}: analytic={}, fd={}, rel={:e}",
3033 grad[j],
3034 fd,
3035 rel
3036 );
3037 }
3038 assert!(
3039 max_rel < 1e-6,
3040 "probit gradient worst relative error {max_rel:e} exceeds 1e-6"
3041 );
3042 }
3043
3044 #[test]
3045 pub(crate) fn link_general_probit_hphi_direction_matches_finite_difference_of_hessian() {
3046 let x = fixed_design_5x3();
3054 let beta = array![0.10, -0.40, 0.25];
3055 let p = beta.len();
3056
3057 let directions = [
3059 array![1.0, 0.0, 0.0],
3060 array![0.0, 1.0, 0.0],
3061 array![0.0, 0.0, 1.0],
3062 array![0.7, -0.5, 0.3],
3063 ];
3064
3065 let h_inner = 1e-4_f64; let h_dir = 1e-4_f64; let mut worst = 0.0_f64;
3068 for u in directions.iter() {
3069 let op = build_link_firth_op(StandardLink::Probit, &x, &beta);
3070 let deta = x.dot(u);
3071 let dir = op.direction_from_deta(deta);
3072 let analytic = op.hphi_direction(&dir);
3073
3074 let beta_plus = &beta + &(u * h_dir);
3075 let beta_minus = &beta - &(u * h_dir);
3076 let hess_plus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_plus, h_inner);
3077 let hess_minus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_minus, h_inner);
3078 let fd = (&hess_plus - &hess_minus) / (2.0 * h_dir);
3079
3080 let mut scale = 1e-6_f64;
3081 for r in 0..p {
3082 for c in 0..p {
3083 scale = scale.max(analytic[[r, c]].abs()).max(fd[[r, c]].abs());
3084 }
3085 }
3086 for r in 0..p {
3087 for c in 0..p {
3088 let rel = (analytic[[r, c]] - fd[[r, c]]).abs() / scale;
3089 worst = worst.max(rel);
3090 assert!(
3091 rel < 5e-3,
3092 "probit D H_φ[u] mismatch at ({r},{c}) for u={u:?}: analytic={}, fd={}, rel={:e}",
3093 analytic[[r, c]],
3094 fd[[r, c]],
3095 rel
3096 );
3097 }
3098 }
3099 }
3100 assert!(
3101 worst < 5e-3,
3102 "probit Hessian-derivative worst relative error {worst:e} exceeds 5e-3"
3103 );
3104 }
3105
3106 #[test]
3107 pub(crate) fn link_general_probit_jeffreys_finite_on_rank_deficient_design() {
3108 let x_full = array![
3112 [1.0, -1.20, -0.20],
3113 [1.0, -0.40, 0.60],
3114 [1.0, 0.10, 1.10],
3115 [1.0, 0.70, 1.70],
3116 [1.0, 1.30, 2.30],
3117 ];
3118 let x_reduced = array![
3119 [1.0, -1.20],
3120 [1.0, -0.40],
3121 [1.0, 0.10],
3122 [1.0, 0.70],
3123 [1.0, 1.30],
3124 ];
3125 let beta_full = array![0.25, -0.50, 0.15];
3126 let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3127
3128 let phi_full = link_firth_phi(StandardLink::Probit, &x_full, &beta_full);
3129 let phi_reduced = link_firth_phi(StandardLink::Probit, &x_reduced, &beta_reduced);
3130 assert!(
3131 phi_full.is_finite(),
3132 "probit Φ on rank-deficient design must be finite, got {phi_full}"
3133 );
3134 assert!(
3135 (phi_full - phi_reduced).abs() < 1e-12,
3136 "probit reduced |Uᵀ W U| form mismatch: full={phi_full}, reduced={phi_reduced}"
3137 );
3138
3139 let op_full = build_link_firth_op(StandardLink::Probit, &x_full, &beta_full);
3140 let grad_full = op_full.jeffreys_beta_gradient();
3141 assert!(
3142 grad_full.iter().all(|v| v.is_finite()),
3143 "probit gradient on rank-deficient design must be finite: {grad_full:?}"
3144 );
3145 let hat_full = op_full.pirls_hat_diag();
3146 let hat_reduced =
3147 build_link_firth_op(StandardLink::Probit, &x_reduced, &beta_reduced).pirls_hat_diag();
3148 for i in 0..hat_full.len() {
3149 assert!(
3150 (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3151 "probit hat diagonal {i} mismatch on rank-deficient design: full={}, reduced={}",
3152 hat_full[i],
3153 hat_reduced[i]
3154 );
3155 }
3156 }
3157
3158 #[test]
3159 pub(crate) fn rank_deficient_and_explicit_reduced_designs_share_same_jeffreys_objective() {
3160 let x_full = array![
3164 [1.0, -1.2, -0.2],
3165 [1.0, -0.4, 0.6],
3166 [1.0, 0.1, 1.1],
3167 [1.0, 0.7, 1.7],
3168 [1.0, 1.3, 2.3],
3169 ];
3170 let x_reduced = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3171 let beta_full: ndarray::Array1<f64> = array![0.25, -0.5, 0.15];
3172 let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3173 let eta_full = x_full.dot(&beta_full);
3174 let eta_reduced = x_reduced.dot(&beta_reduced);
3175 let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3176
3177 for i in 0..eta_full.len() {
3178 assert!(
3179 (eta_full[i] - eta_reduced[i]).abs() < 1e-12,
3180 "eta mismatch at row {i}: full={} reduced={}",
3181 eta_full[i],
3182 eta_reduced[i]
3183 );
3184 }
3185
3186 let op_full = build_weighted_logit_firth_dense_operator(
3187 &x_full,
3188 &eta_full,
3189 observation_weights.view(),
3190 )
3191 .expect("full firth operator");
3192 let op_reduced = build_weighted_logit_firth_dense_operator(
3193 &x_reduced,
3194 &eta_reduced,
3195 observation_weights.view(),
3196 )
3197 .expect("reduced firth operator");
3198
3199 assert!(
3200 (op_full.jeffreys_logdet() - op_reduced.jeffreys_logdet()).abs() < 1e-12,
3201 "Jeffreys logdet mismatch between rank-deficient full design and its explicit reduced identifiable basis: full={} reduced={}",
3202 op_full.jeffreys_logdet(),
3203 op_reduced.jeffreys_logdet()
3204 );
3205
3206 let hat_full = op_full.pirls_hat_diag();
3207 let hat_reduced = op_reduced.pirls_hat_diag();
3208 for i in 0..hat_full.len() {
3209 assert!(
3210 (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3211 "PIRLS hat-diagonal mismatch at row {i}: full={} reduced={}",
3212 hat_full[i],
3213 hat_reduced[i]
3214 );
3215 }
3216 }
3217
3218 #[test]
3219 pub(crate) fn full_rank_reparameterizations_share_same_jeffreys_objective() {
3220 let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3221 let basis = array![[1.4, -0.3], [0.6, 1.1]];
3222 let x_reparameterized = x.dot(&basis);
3223 let beta = array![0.25, -0.5];
3224 let basis_det: f64 = basis[[0, 0]] * basis[[1, 1]] - basis[[0, 1]] * basis[[1, 0]];
3225 assert!(
3226 basis_det.abs() > 1e-12,
3227 "basis transform must be invertible"
3228 );
3229 let basis_inv = array![
3230 [basis[[1, 1]] / basis_det, -basis[[0, 1]] / basis_det],
3231 [-basis[[1, 0]] / basis_det, basis[[0, 0]] / basis_det],
3232 ];
3233 let beta_reparameterized = basis_inv.dot(&beta);
3234 let eta = x.dot(&beta);
3235 let eta_reparameterized = x_reparameterized.dot(&beta_reparameterized);
3236 let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3237
3238 for i in 0..eta.len() {
3239 assert!(
3240 (eta[i] - eta_reparameterized[i]).abs() < 1e-12,
3241 "eta mismatch at row {i}: original={} reparameterized={}",
3242 eta[i],
3243 eta_reparameterized[i]
3244 );
3245 }
3246
3247 let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3248 .expect("original firth operator");
3249 let op_reparameterized = build_weighted_logit_firth_dense_operator(
3250 &x_reparameterized,
3251 &eta_reparameterized,
3252 observation_weights.view(),
3253 )
3254 .expect("reparameterized firth operator");
3255
3256 assert!(
3257 (op.jeffreys_logdet() - op_reparameterized.jeffreys_logdet()).abs() < 1e-12,
3258 "Jeffreys logdet mismatch under invertible reparameterization: original={} reparameterized={}",
3259 op.jeffreys_logdet(),
3260 op_reparameterized.jeffreys_logdet()
3261 );
3262
3263 let hat = op.pirls_hat_diag();
3264 let hat_reparameterized = op_reparameterized.pirls_hat_diag();
3265 for i in 0..hat.len() {
3266 assert!(
3267 (hat[i] - hat_reparameterized[i]).abs() < 1e-12,
3268 "PIRLS hat-diagonal mismatch at row {i}: original={} reparameterized={}",
3269 hat[i],
3270 hat_reparameterized[i]
3271 );
3272 }
3273 }
3274
3275 #[test]
3276 pub(crate) fn full_rank_identifiable_basis_diagonalizes_design_metric() {
3277 let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3278 let beta = array![0.25, -0.5];
3279 let eta = x.dot(&beta);
3280 let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3281 let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3282 .expect("firth operator");
3283
3284 let reduced_metric = fast_atb(&op.x_reduced, &op.x_reduced);
3285 for i in 0..reduced_metric.nrows() {
3286 for j in 0..reduced_metric.ncols() {
3287 if i == j {
3288 continue;
3289 }
3290 assert!(
3291 reduced_metric[[i, j]].abs() < 1e-10,
3292 "full-rank identifiable basis should diagonalize X_r'X_r: metric[{i},{j}]={}",
3293 reduced_metric[[i, j]]
3294 );
3295 }
3296 }
3297 }
3298
3299 #[test]
3300 pub(crate) fn firth_mixedsecond_direction_apply_is_symmetric_in_direction_order() {
3301 let x = array![
3302 [1.0, -1.0, 0.2],
3303 [1.0, -0.6, -0.3],
3304 [1.0, -0.1, 0.5],
3305 [1.0, 0.3, -0.7],
3306 [1.0, 0.8, 0.1],
3307 [1.0, 1.2, -0.4],
3308 ];
3309 let beta = array![0.1, -0.25, 0.2];
3310 let eta = x.dot(&beta);
3311 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3312
3313 let u = array![0.3, -0.2, 0.4];
3314 let v = array![-0.5, 0.1, 0.25];
3315 let du = op.direction_from_deta(x.dot(&u));
3316 let dv = op.direction_from_deta(x.dot(&v));
3317
3318 let eye = Array2::<f64>::eye(x.ncols());
3319 let uv = op.hphisecond_direction_apply(&du, &dv, &eye);
3320 let vu = op.hphisecond_direction_apply(&dv, &du, &eye);
3321
3322 for i in 0..uv.nrows() {
3323 for j in 0..uv.ncols() {
3324 let a = uv[[i, j]];
3325 let b = vu[[i, j]];
3326 assert_eq!(
3327 a.signum(),
3328 b.signum(),
3329 "mixed direction sign mismatch at ({i},{j}): uv={a} vu={b}"
3330 );
3331 assert!(
3332 (a - b).abs() < 2e-7,
3333 "mixed direction mismatch at ({i},{j}): uv={a} vu={b}"
3334 );
3335 }
3336 }
3337 }
3338
3339 #[test]
3340 pub(crate) fn firth_direction_matrix_form_matches_apply_identity_form() {
3341 let x = array![
3342 [1.0, -1.1, 0.2],
3343 [1.0, -0.6, -0.3],
3344 [1.0, -0.1, 0.5],
3345 [1.0, 0.3, -0.7],
3346 [1.0, 0.8, 0.1],
3347 [1.0, 1.2, -0.4],
3348 ];
3349 let beta = array![0.08, -0.22, 0.27];
3350 let eta = x.dot(&beta);
3351 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3352 let u = Array1::from_vec(vec![0.25, -0.4, 0.35]);
3353 let dir = op.direction_from_deta(x.dot(&u));
3354
3355 let p = x.ncols();
3356 let eye = Array2::<f64>::eye(p);
3357 let mut via_apply = op.hphi_direction_apply(&dir, &eye);
3358 for i in 0..p {
3359 for j in 0..i {
3360 let sym = 0.5 * (via_apply[[i, j]] + via_apply[[j, i]]);
3361 via_apply[[i, j]] = sym;
3362 via_apply[[j, i]] = sym;
3363 }
3364 }
3365 let direct = op.hphi_direction(&dir);
3366 let diff = &direct - &via_apply;
3367 let err = diff.iter().map(|v| v * v).sum::<f64>().sqrt();
3368 assert!(err < 1e-10, "direction/apply mismatch: {err:e}");
3369 }
3370
3371 #[test]
3372 pub(crate) fn firthphi_tau_partial_matches_finite_difference_logdet() {
3373 let x = array![
3374 [1.0, -1.0, 0.2],
3375 [1.0, -0.6, -0.3],
3376 [1.0, -0.1, 0.5],
3377 [1.0, 0.3, -0.7],
3378 [1.0, 0.8, 0.1],
3379 [1.0, 1.2, -0.4],
3380 ];
3381 let x_tau = array![
3382 [0.0, 0.15, -0.05],
3383 [0.0, -0.10, 0.02],
3384 [0.0, 0.08, 0.04],
3385 [0.0, -0.06, -0.03],
3386 [0.0, 0.05, 0.01],
3387 [0.0, -0.12, 0.06],
3388 ];
3389 let beta = array![0.1, -0.25, 0.2];
3390 let eta = x.dot(&beta);
3391 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3392 let analytic = op.exact_tau_kernel(&x_tau, &beta, false).phi_tau_partial;
3393
3394 let h = 1e-6;
3395 let x_plus = &x + &(h * &x_tau);
3396 let x_minus = &x - &(h * &x_tau);
3397 let fd = (firthphivalue(&x_plus, &beta) - firthphivalue(&x_minus, &beta)) / (2.0 * h);
3398
3399 assert!(
3400 (analytic - fd).abs() < 1e-6,
3401 "Phi_tau mismatch: analytic={analytic:.12e}, fd={fd:.12e}"
3402 );
3403 }
3404
3405 #[test]
3406 pub(crate) fn firth_gphi_tau_matches_finite_differencegradphi() {
3407 let x = array![
3408 [1.0, -1.0, 0.2],
3409 [1.0, -0.6, -0.3],
3410 [1.0, -0.1, 0.5],
3411 [1.0, 0.3, -0.7],
3412 [1.0, 0.8, 0.1],
3413 [1.0, 1.2, -0.4],
3414 ];
3415 let x_tau = array![
3416 [0.0, 0.15, -0.05],
3417 [0.0, -0.10, 0.02],
3418 [0.0, 0.08, 0.04],
3419 [0.0, -0.06, -0.03],
3420 [0.0, 0.05, 0.01],
3421 [0.0, -0.12, 0.06],
3422 ];
3423 let beta = array![0.1, -0.25, 0.2];
3424 let eta = x.dot(&beta);
3425 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3426 let analytic = op.exact_tau_kernel(&x_tau, &beta, false).gphi_tau;
3427
3428 let h = 1e-6;
3429 let x_plus = &x + &(h * &x_tau);
3430 let x_minus = &x - &(h * &x_tau);
3431 let fd = (firthgradphi(&x_plus, &beta) - firthgradphi(&x_minus, &beta)) / (2.0 * h);
3432
3433 let err = (&analytic - &fd).iter().map(|v| v * v).sum::<f64>().sqrt();
3434 assert!(
3435 err < 1e-6,
3436 "gphi_tau mismatch: analytic={analytic:?}, fd={fd:?}, err={err:e}"
3437 );
3438 }
3439
3440 #[test]
3445 pub(crate) fn firthphi_tau_tau_pair_scalar_matches_finite_difference() {
3446 let x = array![
3447 [1.0, -1.0, 0.2],
3448 [1.0, -0.6, -0.3],
3449 [1.0, -0.1, 0.5],
3450 [1.0, 0.3, -0.7],
3451 [1.0, 0.8, 0.1],
3452 [1.0, 1.2, -0.4],
3453 ];
3454 let x_tau_i = array![
3455 [0.0, 0.15, -0.05],
3456 [0.0, -0.10, 0.02],
3457 [0.0, 0.08, 0.04],
3458 [0.0, -0.06, -0.03],
3459 [0.0, 0.05, 0.01],
3460 [0.0, -0.12, 0.06],
3461 ];
3462 let x_tau_j = array![
3463 [0.0, -0.04, 0.11],
3464 [0.0, 0.09, -0.02],
3465 [0.0, -0.06, 0.07],
3466 [0.0, 0.10, -0.05],
3467 [0.0, -0.03, 0.08],
3468 [0.0, 0.07, -0.09],
3469 ];
3470 let beta = array![0.1, -0.25, 0.2];
3471 let eta = x.dot(&beta);
3472 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3473
3474 let analytic = op
3475 .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3476 .phi_tau_tau_partial;
3477
3478 let h = 1e-5_f64;
3479 let eval_phi_tau_i = |x_eval: &Array2<f64>| -> f64 {
3480 let eta_e = x_eval.dot(&beta);
3481 let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3482 op_e.exact_tau_kernel(&x_tau_i, &beta, false)
3483 .phi_tau_partial
3484 };
3485 let x_plus = &x + &(h * &x_tau_j);
3486 let x_minus = &x - &(h * &x_tau_j);
3487 let fd = (eval_phi_tau_i(&x_plus) - eval_phi_tau_i(&x_minus)) / (2.0 * h);
3488
3489 let rel = (analytic - fd).abs() / fd.abs().max(1.0);
3490 assert!(
3491 rel < 1e-7,
3492 "pair.a scalar mismatch: analytic={analytic:.6e}, fd={fd:.6e}, rel={rel:.3e}"
3493 );
3494 }
3495
3496 #[test]
3501 pub(crate) fn firthphi_tau_tau_pair_g_vector_matches_finite_difference() {
3502 let x = array![
3503 [1.0, -1.0, 0.2],
3504 [1.0, -0.6, -0.3],
3505 [1.0, -0.1, 0.5],
3506 [1.0, 0.3, -0.7],
3507 [1.0, 0.8, 0.1],
3508 [1.0, 1.2, -0.4],
3509 ];
3510 let x_tau_i = array![
3511 [0.0, 0.15, -0.05],
3512 [0.0, -0.10, 0.02],
3513 [0.0, 0.08, 0.04],
3514 [0.0, -0.06, -0.03],
3515 [0.0, 0.05, 0.01],
3516 [0.0, -0.12, 0.06],
3517 ];
3518 let x_tau_j = array![
3519 [0.0, -0.04, 0.11],
3520 [0.0, 0.09, -0.02],
3521 [0.0, -0.06, 0.07],
3522 [0.0, 0.10, -0.05],
3523 [0.0, -0.03, 0.08],
3524 [0.0, 0.07, -0.09],
3525 ];
3526 let beta = array![0.1, -0.25, 0.2];
3527 let eta = x.dot(&beta);
3528 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3529
3530 let analytic = op
3531 .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3532 .gphi_tau_tau;
3533
3534 let h = 1e-5_f64;
3535 let eval_gphi_tau_i = |x_eval: &Array2<f64>| -> Array1<f64> {
3536 let eta_e = x_eval.dot(&beta);
3537 let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3538 op_e.exact_tau_kernel(&x_tau_i, &beta, false).gphi_tau
3539 };
3540 let x_plus = &x + &(h * &x_tau_j);
3541 let x_minus = &x - &(h * &x_tau_j);
3542 let fd = (&eval_gphi_tau_i(&x_plus) - &eval_gphi_tau_i(&x_minus)) / (2.0 * h);
3543
3544 let scale = analytic
3545 .iter()
3546 .chain(fd.iter())
3547 .map(|v| v.abs())
3548 .fold(0.0_f64, f64::max)
3549 .max(1.0);
3550 let err_max = (&analytic - &fd)
3551 .iter()
3552 .map(|v| v.abs())
3553 .fold(0.0_f64, f64::max);
3554 let rel = err_max / scale;
3555 assert!(
3556 rel < 1e-7,
3557 "pair.g p-vector mismatch: rel={rel:.3e}\nanalytic={analytic:?}\nfd={fd:?}"
3558 );
3559 }
3560
3561 #[test]
3581 pub(crate) fn firthphi_tau_tau_partial_matches_finite_difference() {
3582 let x = array![
3583 [1.0, -1.0, 0.2],
3584 [1.0, -0.6, -0.3],
3585 [1.0, -0.1, 0.5],
3586 [1.0, 0.3, -0.7],
3587 [1.0, 0.8, 0.1],
3588 [1.0, 1.2, -0.4],
3589 ];
3590 let x_tau_i = array![
3591 [0.0, 0.15, -0.05],
3592 [0.0, -0.10, 0.02],
3593 [0.0, 0.08, 0.04],
3594 [0.0, -0.06, -0.03],
3595 [0.0, 0.05, 0.01],
3596 [0.0, -0.12, 0.06],
3597 ];
3598 let x_tau_j = array![
3599 [0.0, -0.04, 0.11],
3600 [0.0, 0.09, -0.02],
3601 [0.0, -0.06, 0.07],
3602 [0.0, 0.10, -0.05],
3603 [0.0, -0.03, 0.08],
3604 [0.0, 0.07, -0.09],
3605 ];
3606 let beta = array![0.1, -0.25, 0.2];
3607 let eta = x.dot(&beta);
3608 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3609 let p = x.ncols();
3610
3611 let m = 3usize;
3613 let mut rhs = Array2::<f64>::zeros((p, m));
3614 let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
3615 for r in 0..p {
3616 for c in 0..m {
3617 rhs[[r, c]] = vals[(r * m + c) % vals.len()];
3618 }
3619 }
3620
3621 let x_tau_i_reduced = op.reduce_explicit_design(&x_tau_i);
3624 let x_tau_j_reduced = op.reduce_explicit_design(&x_tau_j);
3625 let deta_i = x_tau_i.dot(&beta);
3626 let deta_j = x_tau_j.dot(&beta);
3627 let (dot_i_i, dot_h_i) = op.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
3628 let (dot_i_j, dot_h_j) = op.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
3629
3630 let kernel_ij = op.hphi_tau_tau_partial_prepare_from_partials(
3631 x_tau_i_reduced.clone(),
3632 x_tau_j_reduced.clone(),
3633 &deta_i,
3634 &deta_j,
3635 dot_h_i.clone(),
3636 dot_h_j.clone(),
3637 dot_i_i.clone(),
3638 dot_i_j.clone(),
3639 None,
3640 None,
3641 );
3642 let kernel_ji = op.hphi_tau_tau_partial_prepare_from_partials(
3643 x_tau_j_reduced,
3644 x_tau_i_reduced,
3645 &deta_j,
3646 &deta_i,
3647 dot_h_j,
3648 dot_h_i,
3649 dot_i_j,
3650 dot_i_i,
3651 None,
3652 None,
3653 );
3654 let analytic_ij = op.hphi_tau_tau_partial_apply(&x_tau_i, &x_tau_j, &kernel_ij, &rhs);
3655 let analytic_ji = op.hphi_tau_tau_partial_apply(&x_tau_j, &x_tau_i, &kernel_ji, &rhs);
3656
3657 let sym_diff: f64 = (&analytic_ij - &analytic_ji)
3659 .iter()
3660 .map(|v| v.abs())
3661 .fold(0.0_f64, f64::max);
3662 let sym_scale: f64 = analytic_ij
3663 .iter()
3664 .chain(analytic_ji.iter())
3665 .map(|v| v.abs())
3666 .fold(0.0_f64, f64::max)
3667 .max(1.0);
3668 assert!(
3669 sym_diff / sym_scale < 1e-10,
3670 "τ×τ primitive not symmetric in direction order: sym_diff={sym_diff:.3e}"
3671 );
3672
3673 let h = 1e-5_f64;
3676 let fd_block = |x_eval: &Array2<f64>| -> Array2<f64> {
3677 let eta_e = x_eval.dot(&beta);
3678 let op_e =
3679 build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
3680 let x_tau_i_r = op_e.reduce_explicit_design(&x_tau_i);
3681 let deta_i_e = x_tau_i.dot(&beta);
3682 let (dot_i_i_e, dot_h_i_e) = op_e.dot_i_and_h_from_reduced(&x_tau_i_r, &deta_i_e);
3683 let kernel_i_e = op_e
3684 .hphi_tau_partial_prepare_from_partials(x_tau_i_r, &deta_i_e, dot_h_i_e, dot_i_i_e);
3685 op_e.hphi_tau_partial_apply(&x_tau_i, &kernel_i_e, &rhs)
3686 };
3687 let x_plus = &x + &(h * &x_tau_j);
3688 let x_minus = &x - &(h * &x_tau_j);
3689 let fd_ij = (&fd_block(&x_plus) - &fd_block(&x_minus)) / (2.0 * h);
3690
3691 let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
3694 let scale = a
3695 .iter()
3696 .chain(b.iter())
3697 .map(|v| v.abs())
3698 .fold(0.0_f64, f64::max)
3699 .max(1.0);
3700 let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
3701 max_diff / scale
3702 };
3703 let err_ij = rel_max_abs_diff(&analytic_ij, &fd_ij);
3704
3705 let fd_block_j = |x_eval: &Array2<f64>| -> Array2<f64> {
3708 let eta_e = x_eval.dot(&beta);
3709 let op_e =
3710 build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
3711 let x_tau_j_r = op_e.reduce_explicit_design(&x_tau_j);
3712 let deta_j_e = x_tau_j.dot(&beta);
3713 let (dot_i_j_e, dot_h_j_e) = op_e.dot_i_and_h_from_reduced(&x_tau_j_r, &deta_j_e);
3714 let kernel_j_e = op_e
3715 .hphi_tau_partial_prepare_from_partials(x_tau_j_r, &deta_j_e, dot_h_j_e, dot_i_j_e);
3716 op_e.hphi_tau_partial_apply(&x_tau_j, &kernel_j_e, &rhs)
3717 };
3718 let x_plus_i = &x + &(h * &x_tau_i);
3719 let x_minus_i = &x - &(h * &x_tau_i);
3720 let fd_ji = (&fd_block_j(&x_plus_i) - &fd_block_j(&x_minus_i)) / (2.0 * h);
3721 let err_ji = rel_max_abs_diff(&analytic_ji, &fd_ji);
3722
3723 let tol = 1e-7_f64;
3724 assert!(
3725 err_ij < tol,
3726 "∂²H_φ/∂τ_i∂τ_j apply mismatch (i,j): rel_max_abs_diff={err_ij:.3e} > {tol:.1e}\n\
3727 analytic=\n{analytic_ij:?}\n\
3728 fd=\n{fd_ij:?}"
3729 );
3730 assert!(
3731 err_ji < tol,
3732 "∂²H_φ/∂τ_j∂τ_i apply mismatch (j,i): rel_max_abs_diff={err_ji:.3e} > {tol:.1e}\n\
3733 analytic=\n{analytic_ji:?}\n\
3734 fd=\n{fd_ji:?}"
3735 );
3736 }
3737
3738 #[test]
3756 pub(crate) fn firth_d_beta_hphi_tau_partial_matches_finite_difference() {
3757 let x = array![
3758 [1.0, -1.0, 0.2],
3759 [1.0, -0.6, -0.3],
3760 [1.0, -0.1, 0.5],
3761 [1.0, 0.3, -0.7],
3762 [1.0, 0.8, 0.1],
3763 [1.0, 1.2, -0.4],
3764 ];
3765 let x_tau = array![
3766 [0.0, 0.15, -0.05],
3767 [0.0, -0.10, 0.02],
3768 [0.0, 0.08, 0.04],
3769 [0.0, -0.06, -0.03],
3770 [0.0, 0.05, 0.01],
3771 [0.0, -0.12, 0.06],
3772 ];
3773 let beta = array![0.1, -0.25, 0.2];
3774 let v = array![0.3, 0.2, -0.15];
3776
3777 let eta = x.dot(&beta);
3778 let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3779 let p = x.ncols();
3780
3781 let m = 3usize;
3783 let mut rhs = Array2::<f64>::zeros((p, m));
3784 let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
3785 for r in 0..p {
3786 for c in 0..m {
3787 rhs[[r, c]] = vals[(r * m + c) % vals.len()];
3788 }
3789 }
3790
3791 let x_tau_reduced = op.reduce_explicit_design(&x_tau);
3793 let deta_partial = x_tau.dot(&beta);
3794 let (dot_i_partial, dot_h_partial) =
3795 op.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
3796 let tau_kernel = op.hphi_tau_partial_prepare_from_partials(
3797 x_tau_reduced.clone(),
3798 &deta_partial,
3799 dot_h_partial.clone(),
3800 dot_i_partial.clone(),
3801 );
3802
3803 let deta_v = x.dot(&v);
3804 let direction = op.direction_from_deta(deta_v);
3805 let x_tau_v = x_tau.dot(&v);
3806 let pair_kernel = op.d_beta_hphi_tau_partial_prepare_from_partials(
3807 &tau_kernel,
3808 &deta_partial,
3809 &dot_i_partial,
3810 &direction,
3811 &x_tau_v,
3812 );
3813 let analytic = op.d_beta_hphi_tau_partial_apply(&x_tau, &pair_kernel, &rhs);
3814
3815 let h = 1e-5_f64;
3818 let single_tau_apply = |beta_eval: &Array1<f64>| -> Array2<f64> {
3819 let eta_e = x.dot(beta_eval);
3820 let op_e =
3821 build_logit_firth_dense_operator(&x, &eta_e).expect("perturbed firth operator");
3822 let x_tau_r = op_e.reduce_explicit_design(&x_tau);
3823 let deta_e = x_tau.dot(beta_eval);
3824 let (dot_i_e, dot_h_e) = op_e.dot_i_and_h_from_reduced(&x_tau_r, &deta_e);
3825 let ker_e =
3826 op_e.hphi_tau_partial_prepare_from_partials(x_tau_r, &deta_e, dot_h_e, dot_i_e);
3827 op_e.hphi_tau_partial_apply(&x_tau, &ker_e, &rhs)
3828 };
3829 let beta_plus = &beta + &(h * &v);
3830 let beta_minus = &beta - &(h * &v);
3831 let fd = (&single_tau_apply(&beta_plus) - &single_tau_apply(&beta_minus)) / (2.0 * h);
3832
3833 let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
3834 let scale = a
3835 .iter()
3836 .chain(b.iter())
3837 .map(|v| v.abs())
3838 .fold(0.0_f64, f64::max)
3839 .max(1.0);
3840 let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
3841 max_diff / scale
3842 };
3843 let err = rel_max_abs_diff(&analytic, &fd);
3844
3845 let tol = 1e-7_f64;
3846 assert!(
3847 err < tol,
3848 "D_β (H_φ)_τ|_β apply mismatch: rel_max_abs_diff={err:.3e} > {tol:.1e}\n\
3849 analytic=\n{analytic:?}\n\
3850 fd=\n{fd:?}"
3851 );
3852 }
3853
3854 #[test]
3855 pub(crate) fn logisticweight_loses_positive_tail_mass() {
3856 let eta = 50.0_f64;
3857 let z = (-eta).exp();
3858 let stable = z / (1.0_f64 + z).powi(2);
3859 assert!(stable > 0.0);
3860 let got = logisticweight(eta);
3861 assert!(
3862 (got - stable).abs() < 1e-30,
3863 "Firth logisticweight should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
3864 got,
3865 stable
3866 );
3867 }
3868
3869 #[test]
3870 pub(crate) fn fisher_weight_jet5_logit_is_byte_identical_to_inverse_link_jet() {
3871 for &eta in &[
3875 -40.0, -8.0, -3.0, -1.0, -0.25, 0.0, 0.25, 1.0, 3.0, 8.0, 40.0,
3876 ] {
3877 let jet = logit_inverse_link_jet5(eta);
3878 let (w, w1, w2, w3, w4) =
3879 crate::mixture_link::fisher_weight_jet5(StandardLink::Logit, eta);
3880 assert!(
3881 w == jet.d1 && w1 == jet.d2 && w2 == jet.d3 && w3 == jet.d4 && w4 == jet.d5,
3882 "logit Fisher-weight jet must equal inverse-link jet derivatives at eta={eta}: \
3883 got ({w}, {w1}, {w2}, {w3}, {w4}) vs ({}, {}, {}, {}, {})",
3884 jet.d1,
3885 jet.d2,
3886 jet.d3,
3887 jet.d4,
3888 jet.d5
3889 );
3890 }
3891 }
3892
3893 #[test]
3894 pub(crate) fn fisher_weight_jet5_probit_matches_finite_difference() {
3895 fn reference_probit_weight(eta: f64) -> f64 {
3899 let p = gam_math::probability::normal_cdf(eta);
3900 let q = 1.0 - p;
3901 let phi = gam_math::probability::normal_pdf(eta);
3902 if p <= 0.0 || q <= 0.0 {
3903 return 0.0;
3904 }
3905 phi * phi / (p * q)
3906 }
3907 let h = 1e-4_f64;
3908 for &eta in &[-3.0, -1.5, -0.5, 0.0, 0.3, 1.5, 3.0] {
3909 let (w, w1, w2, _w3, _w4) =
3910 crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
3911 let ref_w = reference_probit_weight(eta);
3912 let fd1 =
3913 (reference_probit_weight(eta + h) - reference_probit_weight(eta - h)) / (2.0 * h);
3914 let fd2 = (reference_probit_weight(eta + h) - 2.0 * reference_probit_weight(eta)
3915 + reference_probit_weight(eta - h))
3916 / (h * h);
3917 assert!(
3918 (w - ref_w).abs() < 1e-10,
3919 "probit W mismatch at eta={eta}: jet {w} vs ref {ref_w}"
3920 );
3921 assert!(
3922 (w1 - fd1).abs() < 1e-5,
3923 "probit W' mismatch at eta={eta}: jet {w1} vs fd {fd1}"
3924 );
3925 assert!(
3926 (w2 - fd2).abs() < 1e-3,
3927 "probit W'' mismatch at eta={eta}: jet {w2} vs fd {fd2}"
3928 );
3929 }
3930 }
3931
3932 #[test]
3933 pub(crate) fn fisher_weight_jet5_probit_saturates_to_zero_in_tails() {
3934 for &eta in &[40.0_f64, -40.0, 80.0, -80.0] {
3938 let (w, w1, w2, w3, w4) =
3939 crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
3940 assert!(
3941 w == 0.0 && w1 == 0.0 && w2 == 0.0 && w3 == 0.0 && w4 == 0.0,
3942 "probit Fisher weight jet must saturate to zero at eta={eta}; got \
3943 ({w}, {w1}, {w2}, {w3}, {w4})"
3944 );
3945 }
3946 for &eta in &[12.0_f64, -12.0] {
3951 let (w, w1, w2, w3, w4) =
3952 crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
3953 assert!(
3954 w > 0.0
3955 && w.is_finite()
3956 && w1.is_finite()
3957 && w2.is_finite()
3958 && w3.is_finite()
3959 && w4.is_finite(),
3960 "probit Fisher weight jet must be tiny-positive and finite at eta={eta}; got \
3961 ({w}, {w1}, {w2}, {w3}, {w4})"
3962 );
3963 }
3964 }
3965}