Skip to main content

gam_solve/reml/
firth.rs

1use super::*;
2use gam_linalg::matrix::symmetrize_in_place;
3use crate::mixture_link::fisher_weight_jet5_for_inverse_link;
4use gam_problem::InverseLink;
5
/// Row-count threshold used by `RemlState::parallelize_firth_derivative_rows`:
/// the per-row Fisher-weight derivative fill is only dispatched to rayon when
/// the number of rows is at least this large (and more than one rayon thread
/// is available).
pub(crate) const FIRTH_DERIVATIVE_PARALLEL_MIN_N: usize = 16_384;

/// Reciprocal-condition-number floor below which the reduced Fisher information
/// `I_r` is flagged as near-singular (a diagnostic warning only, not a hard
/// gate). At `λ_min/λ_max < 1e-10` the SPD assumption on the identifiable
/// subspace is numerically fragile and the exact pseudodet derivatives may be
/// ill-conditioned near active-subspace boundaries.
///
/// The builder that consumes this constant floors the denominator at 1.0,
/// i.e. it compares `λ_min / max(λ_max, 1)` against this value.
pub(crate) const FIRTH_REDUCED_FISHER_RCOND_WARN: f64 = 1e-10;
14
15impl<'a> RemlState<'a> {
    /// Forwards to `super::assembly::xt_diag_x_dense_into` with the same
    /// arguments and returns its result unchanged.
    ///
    /// NOTE(review): going by the name this presumably forms
    /// `Xᵀ diag(diag) X`, using `weighted` as caller-provided scratch for the
    /// row-scaled design — confirm against the assembly helper, which owns the
    /// actual contract.
    pub(crate) fn xt_diag_x_dense_into(
        x: &Array2<f64>,
        diag: &Array1<f64>,
        weighted: &mut Array2<f64>,
    ) -> Array2<f64> {
        super::assembly::xt_diag_x_dense_into(x, diag, weighted)
    }
23
24    #[inline]
25    pub(crate) fn parallelize_firth_derivative_rows(n: usize) -> bool {
26        n >= FIRTH_DERIVATIVE_PARALLEL_MIN_N && rayon::current_num_threads() > 1
27    }
28
29    pub(crate) fn row_scale(x: &Array2<f64>, scale: &Array1<f64>) -> Array2<f64> {
30        let mut out = Array2::<f64>::zeros(x.raw_dim());
31        super::assembly::row_scale_dense_into(x, scale, &mut out);
32        out
33    }
34
35    #[inline]
36    pub(crate) fn dense_product_likely_uses_inner_parallelism(
37        m: usize,
38        n: usize,
39        k: usize,
40    ) -> bool {
41        // Keep this in sync with faer_ndarray::matmul_parallelism.  When a
42        // dense product is large enough for faer/BLAS-style internal
43        // parallelism, do not also wrap sibling products in rayon::join: that
44        // can oversubscribe CPU threads and slow down the REML hot path.
45        const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
46        const PAR_MIN_LONG_DIM: usize = 256;
47        let flop_scale = m.saturating_mul(n).saturating_mul(k);
48        let long_dim = m.max(n).max(k);
49        flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM
50    }
51
52    #[inline]
53    pub(crate) fn should_join_independent_dense_products(
54        products: &[(usize, usize, usize)],
55    ) -> bool {
56        const JOIN_MIN_TOTAL_FLOP_SCALE: usize = 128 * 1024;
57        if rayon::current_num_threads() <= 1 {
58            return false;
59        }
60        let mut total_flop_scale = 0usize;
61        for &(m, n, k) in products {
62            if Self::dense_product_likely_uses_inner_parallelism(m, n, k) {
63                return false;
64            }
65            total_flop_scale =
66                total_flop_scale.saturating_add(m.saturating_mul(n).saturating_mul(k));
67        }
68        total_flop_scale >= JOIN_MIN_TOTAL_FLOP_SCALE
69    }
70
71    /// Undo the fixed observation-weight row scale used by Firth reduced
72    /// designs.
73    ///
74    /// Keep this distinct from PIRLS's sparse-SpGEMM `sqrt_weights` cache in
75    /// `solver/pirls.rs`: PIRLS materializes roots of the current working
76    /// weights for a Gram factorization, while this uses stored fixed
77    /// case-weight roots and their reciprocals to map reduced design
78    /// derivatives back to raw design space.
79    #[inline]
80    pub(crate) fn scale_rows_by_inverse_observation_weight_sqrt(
81        out: &mut Array2<f64>,
82        observation_weight_sqrt: Option<&Array1<f64>>,
83    ) {
84        let Some(scale) = observation_weight_sqrt else {
85            return;
86        };
87        super::assembly::row_scale_dense_in_place_by_inverse_positive_or_zero(out, scale);
88    }
89
    /// GLM Fisher working-weight 5-jet for the requested inverse link. For
    /// standard Logit this is byte-identical to the historical
    /// `logit_inverse_link_jet5(eta).d1..d5` path that the Firth operator used
    /// before the weights were generalized to arbitrary inverse links.
    ///
    /// The returned tuple is consumed by
    /// `fill_fisher_weight_derivative_arrays` as
    /// `(w, w', w'', w''', w'''')`, all evaluated at `eta`.
    ///
    /// # Errors
    /// Returns whatever error `fisher_weight_jet5_for_inverse_link` reports
    /// for this `link` / `eta` pair; this wrapper adds no handling of its own.
    #[inline]
    pub(crate) fn fisher_weight_derivatives(
        link: &InverseLink,
        eta: f64,
    ) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
        fisher_weight_jet5_for_inverse_link(link, eta)
    }
101
102    #[inline]
103    pub(crate) fn cholesky_pivots_are_numerically_resolved(chol_diag: &Array1<f64>) -> bool {
104        let mut min_pivot_sq = f64::INFINITY;
105        let mut max_pivot_sq = 0.0_f64;
106        for &pivot in chol_diag {
107            if !pivot.is_finite() || pivot <= 0.0 {
108                return false;
109            }
110            let pivot_sq = pivot * pivot;
111            min_pivot_sq = min_pivot_sq.min(pivot_sq);
112            max_pivot_sq = max_pivot_sq.max(pivot_sq);
113        }
114        if !min_pivot_sq.is_finite() {
115            return false;
116        }
117        let scale = max_pivot_sq.max(1.0);
118        let floor = (chol_diag.len().max(1) as f64) * f64::EPSILON * scale;
119        min_pivot_sq > floor
120    }
121
122    pub(crate) fn reduced_fisher_inverse_and_half_logdet(
123        fisher_reduced: &Array2<f64>,
124    ) -> Result<(Array2<f64>, f64), EstimationError> {
125        let r = fisher_reduced.nrows();
126        assert_eq!(r, fisher_reduced.ncols());
127        let mut k_reduced = Array2::<f64>::zeros((r, r));
128        if r == 0 {
129            return Ok((k_reduced, 0.0));
130        }
131
132        if let Ok(chol) = fisher_reduced.cholesky(Side::Lower) {
133            let chol_diag = chol.diag();
134            if Self::cholesky_pivots_are_numerically_resolved(&chol_diag) {
135                let half_log_det = chol_diag.iter().map(|d| d.ln()).sum::<f64>();
136                for col in 0..r {
137                    let mut e_col = Array1::<f64>::zeros(r);
138                    e_col[col] = 1.0;
139                    let solved = chol.solvevec(&e_col);
140                    k_reduced.column_mut(col).assign(&solved);
141                }
142                return Ok((k_reduced, half_log_det));
143            }
144        }
145
146        let (evals_ir, evecs_ir) = fisher_reduced
147            .eigh(Side::Lower)
148            .map_err(EstimationError::EigendecompositionFailed)?;
149        let max_eval = evals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
150        let tol = (r.max(1) as f64) * f64::EPSILON * max_eval;
151        let mut kept_positive_direction = false;
152        let mut half_log_det = 0.0_f64;
153        for (eig_idx, &eig) in evals_ir.iter().enumerate() {
154            if eig > tol {
155                kept_positive_direction = true;
156                half_log_det += 0.5 * eig.ln();
157                let inv = eig.recip();
158                let vec = evecs_ir.column(eig_idx).to_owned();
159                for row in 0..r {
160                    for col in 0..r {
161                        k_reduced[[row, col]] += inv * vec[row] * vec[col];
162                    }
163                }
164            }
165        }
166        if !kept_positive_direction {
167            return Err(EstimationError::ModelIsIllConditioned {
168                condition_number: f64::INFINITY,
169            });
170        }
171        Ok((k_reduced, half_log_det))
172    }
173
174    pub(crate) fn fill_fisher_weight_derivative_arrays(
175        link: &InverseLink,
176        eta: &Array1<f64>,
177        w: &mut Array1<f64>,
178        w1: &mut Array1<f64>,
179        w2: &mut Array1<f64>,
180        w3: &mut Array1<f64>,
181        w4: &mut Array1<f64>,
182    ) -> Result<(), EstimationError> {
183        assert_eq!(eta.len(), w.len());
184        assert_eq!(eta.len(), w1.len());
185        assert_eq!(eta.len(), w2.len());
186        assert_eq!(eta.len(), w3.len());
187        assert_eq!(eta.len(), w4.len());
188
189        if Self::parallelize_firth_derivative_rows(eta.len()) {
190            let values: Result<Vec<_>, EstimationError> = eta
191                .par_iter()
192                .map(|&ei| Self::fisher_weight_derivatives(link, ei))
193                .collect();
194            for (i, (value, first, second, third, fourth)) in values?.into_iter().enumerate() {
195                w[i] = value;
196                w1[i] = first;
197                w2[i] = second;
198                w3[i] = third;
199                w4[i] = fourth;
200            }
201            return Ok(());
202        }
203        for i in 0..eta.len() {
204            let (value, first, second, third, fourth) =
205                Self::fisher_weight_derivatives(link, eta[i])?;
206            w[i] = value;
207            w1[i] = first;
208            w2[i] = second;
209            w3[i] = third;
210            w4[i] = fourth;
211        }
212        Ok(())
213    }
214
    /// Validates that `left`, `right` and `weights` all describe the same
    /// number of rows, then forwards to
    /// `super::assembly::weighted_cross_dense`.
    ///
    /// NOTE(review): the helper presumably returns
    /// `leftᵀ diag(weights) right` — confirm against the assembly module.
    ///
    /// # Panics
    /// Panics if `left.nrows()` differs from `right.nrows()` or from
    /// `weights.len()`.
    pub(crate) fn weighted_cross(
        left: &Array2<f64>,
        right: &Array2<f64>,
        weights: &Array1<f64>,
    ) -> Array2<f64> {
        assert_eq!(left.nrows(), right.nrows());
        assert_eq!(left.nrows(), weights.len());
        super::assembly::weighted_cross_dense(left, right, weights)
    }
224
225    pub(crate) fn trace_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
226        assert_eq!(a.nrows(), b.ncols());
227        assert_eq!(a.ncols(), b.nrows());
228        let elems = a.nrows().saturating_mul(a.ncols());
229        if elems >= 32 * 32 {
230            let aview = FaerArrayView::new(a);
231            let bview = FaerArrayView::new(b);
232            return faer_frob_inner(aview.as_ref(), bview.as_ref().transpose());
233        }
234        let m = a.nrows();
235        let n = a.ncols();
236        kahan_sum((0..m).map(|i| {
237            let mut acc = 0.0_f64;
238            for j in 0..n {
239                acc += a[[i, j]] * b[[j, i]];
240            }
241            acc
242        }))
243    }
244
245    pub(crate) fn reducedweighted_gram(z: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
246        // Returns Zᵀ diag(weights) Z (exact).
247        //
248        // Used for:
249        //   S = Zᵀ diag(v) Z               in Hadamard-Gram apply,
250        //   G_u = X_rᵀ diag(s_u) X_r       with s_u = w' ⊙ (Xu),
251        // both of which avoid constructing dense n×n intermediates.
252        let weighted = Self::row_scale(z, weights);
253        fast_atb(z, &weighted)
254    }
255
256    pub(crate) fn reduced_crossweighted_gram(
257        z_left: &Array2<f64>,
258        z_right: &Array2<f64>,
259        weights: &Array1<f64>,
260    ) -> Array2<f64> {
261        // Returns Z_leftᵀ diag(weights) Z_right (exact).
262        //
263        // This is used for explicit design-moving terms where left/right
264        // reduced designs differ (X_r vs X_{tau,r}) in Hadamard-Gram products.
265        let weighted = Self::row_scale(z_right, weights);
266        fast_atb(z_left, &weighted)
267    }
268
269    pub(crate) fn reduced_diag_gram(z: &Array2<f64>, a: &Array2<f64>) -> Array1<f64> {
270        // Returns diag(Z A Zᵀ), exact without forming dense n×n matrix.
271        //
272        // Identity:
273        //   [diag(Z A Zᵀ)]_i = z_iᵀ A z_i,
274        // where z_i is row i of Z. We compute this as rowwise dot(Z, Z A).
275        let za = fast_ab(z, a);
276        (z * &za).sum_axis(ndarray::Axis(1))
277    }
278
279    pub(crate) fn apply_hadamard_gram(
280        z: &Array2<f64>,
281        a_left: &Array2<f64>,
282        a_right: &Array2<f64>,
283        vec: &Array1<f64>,
284    ) -> Array1<f64> {
285        // Exact apply of:
286        //   y = ((Z A_left Z^T) ⊙ (Z A_right Z^T)) vec
287        // using Gram/Hadamard identity with S = Z^T diag(vec) Z:
288        //   y_i = z_i^T A_left S A_right z_i.
289        //
290        // This is the matrix-free kernel behind the Firth terms that would
291        // otherwise require dense P = M⊙M and M⊙N_u products.
292        // Complexity is O(n r^2) with r = rank(X), no n×n storage.
293        let s = Self::reducedweighted_gram(z, vec);
294        let left_s = a_left.dot(&s);
295        let t = left_s.dot(a_right);
296        Self::reduced_diag_gram(z, &t)
297    }
298
299    pub(crate) fn apply_hadamard_gram_to_matrix(
300        z: &Array2<f64>,
301        a_left: &Array2<f64>,
302        a_right: &Array2<f64>,
303        mat: &Array2<f64>,
304    ) -> Array2<f64> {
305        // Columnwise extension of apply_hadamard_gram:
306        //   out[:,j] = ((Z A_left Zᵀ) ⊙ (Z A_right Zᵀ)) mat[:,j].
307        //
308        // In the Firth derivatives this is used with:
309        //   A_left = K_r, A_right = K_r        for Hw⊙Hw actions,
310        //   A_left = K_r, A_right = A_u        for Hw⊙Nbar_u actions,
311        // and symmetric variants in mixed second-direction terms.
312        let mut out = Array2::<f64>::zeros(mat.raw_dim());
313        for col in 0..mat.ncols() {
314            let v = mat.column(col).to_owned();
315            let y = Self::apply_hadamard_gram(z, a_left, a_right, &v);
316            out.column_mut(col).assign(&y);
317        }
318        out
319    }
320
    /// Link-aware dense Firth/Jeffreys builder. The REML callsites resolve the
    /// Fisher-weight link via `reml_robust_jeffreys_link` and pass it here;
    /// standard Logit reproduces the historical logit-pinned build byte-for-byte,
    /// and stateful links flow through the same inverse-link derivative path.
    ///
    /// This entry point always supplies observation weights: it forwards to
    /// `FirthDenseOperator::build_with_observation_weights_impl` with
    /// `Some(observation_weights)`.
    ///
    /// # Errors
    /// Returns whatever error the underlying builder reports.
    pub(super) fn build_firth_dense_operator_for_link(
        link: &InverseLink,
        x_dense: &Array2<f64>,
        eta: &Array1<f64>,
        observation_weights: ndarray::ArrayView1<'_, f64>,
    ) -> Result<FirthDenseOperator, EstimationError> {
        FirthDenseOperator::build_with_observation_weights_impl(
            link,
            x_dense,
            eta,
            Some(observation_weights),
        )
    }
338
    /// Associated-function form of `FirthDenseOperator::exact_tau_kernel`:
    /// forwards `x_tau`, `beta` and `include_hphi_tau_kernel` to `op`
    /// unchanged and returns its result.
    pub(super) fn firth_exact_tau_kernel(
        op: &FirthDenseOperator,
        x_tau: &Array2<f64>,
        beta: &Array1<f64>,
        include_hphi_tau_kernel: bool,
    ) -> FirthTauExactKernel {
        op.exact_tau_kernel(x_tau, beta, include_hphi_tau_kernel)
    }
347
    /// Associated-function form of
    /// `FirthDenseOperator::hphi_tau_partial_apply`: forwards `x_tau`,
    /// `kernel` and `rhs` to `op` unchanged and returns its result.
    pub(super) fn firth_hphi_tau_partial_apply(
        op: &FirthDenseOperator,
        x_tau: &Array2<f64>,
        kernel: &FirthTauPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        op.hphi_tau_partial_apply(x_tau, kernel, rhs)
    }
356}
357
358impl FirthDenseOperator {
359    pub(crate) fn canonicalize_basis_column_signs(q_basis: &mut Array2<f64>) {
360        for col in 0..q_basis.ncols() {
361            let mut pivot_row = 0usize;
362            let mut pivot_abs = 0.0_f64;
363            for row in 0..q_basis.nrows() {
364                let value = q_basis[[row, col]];
365                let abs_value = value.abs();
366                if abs_value > pivot_abs {
367                    pivot_abs = abs_value;
368                    pivot_row = row;
369                }
370            }
371            if pivot_abs > 0.0 && q_basis[[pivot_row, col]] < 0.0 {
372                q_basis.column_mut(col).mapv_inplace(|v| -v);
373            }
374        }
375    }
376
377    pub(crate) fn identifiable_subspace_basis_from_gram(
378        gram: &Array2<f64>,
379    ) -> Result<(Array2<f64>, Array1<f64>), EstimationError> {
380        let p = gram.nrows();
381        assert_eq!(p, gram.ncols());
382        if p == 0 {
383            return Ok((Array2::<f64>::eye(0), Array1::<f64>::zeros(0)));
384        }
385
386        let (evals, evecs) = gram
387            .eigh(Side::Lower)
388            .map_err(EstimationError::EigendecompositionFailed)?;
389        let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
390        let tol = (p.max(1) as f64) * f64::EPSILON * max_eval;
391        let mut keep: Vec<usize> = evals
392            .iter()
393            .enumerate()
394            .filter_map(|(i, &value)| if value > tol { Some(i) } else { None })
395            .collect();
396        if keep.is_empty() {
397            return Err(EstimationError::ModelIsIllConditioned {
398                condition_number: f64::INFINITY,
399            });
400        }
401
402        // Use one orthonormal identifiable basis for both the full-rank and
403        // rank-deficient cases. Sorting retained modes by descending design
404        // energy makes the representation deterministic up to eigenspace
405        // degeneracy; fixing the column signs removes the remaining trivial
406        // sign ambiguity.
407        keep.sort_by(|&lhs, &rhs| evals[rhs].total_cmp(&evals[lhs]));
408        let r = keep.len();
409        let mut q_basis = Array2::<f64>::zeros((p, r));
410        let mut metric_spectrum = Array1::<f64>::zeros(r);
411        for (col_idx, eig_idx) in keep.into_iter().enumerate() {
412            q_basis.column_mut(col_idx).assign(&evecs.column(eig_idx));
413            metric_spectrum[col_idx] = evals[eig_idx];
414        }
415        Self::canonicalize_basis_column_signs(&mut q_basis);
416        Ok((q_basis, metric_spectrum))
417    }
418
419    #[inline]
420    pub(crate) fn trace_diag_product(diag: &Array1<f64>, matrix: &Array2<f64>) -> f64 {
421        assert_eq!(diag.len(), matrix.nrows());
422        assert_eq!(matrix.nrows(), matrix.ncols());
423        kahan_sum((0..diag.len()).map(|i| diag[i] * matrix[[i, i]]))
424    }
425
    /// Builds the dense Firth operator for `link` at linear predictor `eta`
    /// without fixed observation weights (the shared builder is called with
    /// `None`).
    ///
    /// # Errors
    /// Returns whatever error `build_with_observation_weights_impl` reports,
    /// e.g. when `eta.len()` does not match the number of rows of `x_dense`.
    pub fn build_for_link(
        link: &InverseLink,
        x_dense: &Array2<f64>,
        eta: &Array1<f64>,
    ) -> Result<FirthDenseOperator, EstimationError> {
        Self::build_with_observation_weights_impl(link, x_dense, eta, None)
    }
433
    /// Builds the dense Firth operator for `link` at linear predictor `eta`
    /// with fixed per-row observation weights (the shared builder is called
    /// with `Some(observation_weights)`).
    ///
    /// # Errors
    /// Returns whatever error `build_with_observation_weights_impl` reports.
    /// That builder rejects an `eta` or weight vector whose length differs
    /// from the number of rows of `x_dense`, and any weight that is negative
    /// or non-finite.
    pub fn build_with_observation_weights_for_link(
        link: &InverseLink,
        x_dense: &Array2<f64>,
        eta: &Array1<f64>,
        observation_weights: ndarray::ArrayView1<'_, f64>,
    ) -> Result<FirthDenseOperator, EstimationError> {
        Self::build_with_observation_weights_impl(link, x_dense, eta, Some(observation_weights))
    }
442
    /// Elementwise product of the cached Fisher weights `w` and the cached
    /// reduced leverage `h_diag`, i.e. `w_i · h_diag_i` per observation,
    /// returned as a freshly allocated array.
    #[inline]
    pub(crate) fn pirls_hat_diag(&self) -> Array1<f64> {
        &self.w * &self.h_diag
    }
447
448    /// Per-observation Firth working-response shift `Δ_i` for the single-eta
449    /// PIRLS inner solve, so the inner stationary point is the SAME penalized
450    /// objective the outer REML differentiates through `jeffreys_beta_gradient`.
451    ///
452    /// PIRLS solves `Xᵀ W (z* − η) = 0` with `z*_i = z_i + Δ_i`, so the Firth
453    /// term it adds to the score is `Σ_i w_i Δ_i x_i`. The Jeffreys score is
454    /// `∂Φ/∂β = ½ Σ_i w'_i h_i x_i` (see `jeffreys_beta_gradient`, with
455    /// `h_i = [X_r K_r X_rᵀ]_ii = h_diag_i` and `w'_i = ∂w_i/∂η_i = w1`). Matching
456    /// the two gives `w_i Δ_i = ½ w'_i h_i`, i.e.
457    ///   `Δ_i = ½ · (w'_i / w_i) · h_diag_i`.
458    ///
459    /// For the canonical logit this collapses to `h_i (½ − μ_i) / w_i` because
460    /// `w'_i / w_i = (1 − 2μ_i)` and `h_i = w_i h_diag_i`; but for a NON-canonical
461    /// binomial link (probit, cloglog, …) `w'_i / w_i ≠ (1 − 2μ_i)`, so the
462    /// logit-pinned `(½ − μ_i)` shift solved a DIFFERENT objective than the one
463    /// the outer REML/`hphi_direction` machinery differentiates. This builds the
464    /// correct link-general shift from the same `w`, `w1`, `h_diag` the operator
465    /// already caches. `w_i ≤ 0` rows contribute no curvature and get a zero
466    /// shift (they cannot enter `Xᵀ W (z*−η)` anyway).
467    #[inline]
468    pub(crate) fn pirls_firth_score_shift(&self) -> Array1<f64> {
469        let mut shift = Array1::<f64>::zeros(self.w.len());
470        for i in 0..self.w.len() {
471            let wi = self.w[i];
472            if wi > 0.0 {
473                shift[i] = 0.5 * (self.w1[i] / wi) * self.h_diag[i];
474            }
475        }
476        shift
477    }
478
479    pub(crate) fn build_with_observation_weights_impl(
480        link: &InverseLink,
481        x_dense: &Array2<f64>,
482        eta: &Array1<f64>,
483        observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
484    ) -> Result<FirthDenseOperator, EstimationError> {
485        // Precompute dense Firth objects at current β̂ for:
486        //   Φ(β) = 0.5 log|Uᵀ W U|,
487        // where U is a canonical orthonormal basis of the identifiable
488        // subspace of A^{1/2} X.
489        //
490        // Identifiability note:
491        // For rank-deficient design matrices X, I = Xᵀ W X is singular for all β
492        // (assuming w_i > 0). The mathematically coherent Jeffreys/Firth term is
493        // therefore the identifiable-subspace form:
494        //   Φ(β) = 0.5 log|Uᵀ W U|
495        //        = 0.5 log|I_r(β)| - 0.5 log|S_r|,
496        // with
497        //   I_r = X_rᵀ W X_r,
498        //   S_r = X_rᵀ X_r,
499        //   U   = X_r S_r^{-1/2}.
500        // The beta differential is still
501        //   dΦ = 0.5 tr(I_+^† dI),
502        // because S_r is beta-independent for a fixed design.
503        //
504        // For binomial-logit with finite eta, 0 < w_i <= 1/4, so W is SPD and
505        // Null(X'WX)=Null(X). Therefore singular directions are structural
506        // (from X) and independent of beta, which is why fixed-Q reduced-space
507        // derivatives are exact in this regime.
508        //
509        // This implementation uses the fixed identifiable-basis route directly:
510        //   I_r = X_rᵀ W X_r,  S_r = X_rᵀ X_r,
511        //   I_+^† = Q I_r^{-1} Qᵀ,
512        //   Φ     = 0.5 (log|I_r| - log|S_r|).
513        //
514        // Why `eta` (not `mu`) enters here:
515        //   all logistic weight derivatives are functions of eta through
516        //   mu(eta)=sigmoid(eta), and exact derivative consistency is preserved by
517        //   generating (w, w', w'', w''', w'''') from one coherent eta source.
518        // Using eta avoids any mismatch from externally clamped/post-processed mu.
519        //
520        // We cache reduced operators so derivatives are exact but matrix-free in n:
521        //   K_r = I_r^{-1},  h = diag(X_r K_r X_rᵀ),  B = diag(w')X,
522        // along with logistic derivatives w', w'', w''', w'''' to evaluate:
523        //   H_φ      = ∇²_β Φ,
524        //   D H_φ[u],
525        //   D² H_φ[u,v]
526        // exactly via reduced-space products (no explicit high-order tensors).
527        //
528        // Fixed observation weights:
529        // When callers provide nonnegative case weights a_i that are constant in
530        // β, the Jeffreys information is
531        //   I(β) = Xᵀ diag(a_i w_i(η)) X.
532        // We fold those fixed a_i into the identifiable basis and reduced design
533        // via X̃ = diag(sqrt(a_i)) X, so all derivative formulas continue to use
534        // the same η-derivatives of the family Fisher weights w(η), w'(η), ....
535        let n = x_dense.nrows();
536        if eta.len() != n {
537            crate::bail_invalid_estim!(
538                "Firth operator shape mismatch: nrows={}, eta_len={}",
539                n,
540                eta.len()
541            );
542        }
543        let observation_weight_sqrt = if let Some(weights) = observation_weights {
544            if weights.len() != n {
545                crate::bail_invalid_estim!(
546                    "Firth operator observation weight length {} != number of rows {}",
547                    weights.len(),
548                    n
549                );
550            }
551            let mut sqrt = Array1::<f64>::zeros(n);
552            for i in 0..n {
553                let weight = weights[i];
554                if !weight.is_finite() || weight < 0.0 {
555                    crate::bail_invalid_estim!(
556                        "Firth operator requires finite nonnegative observation weights, got {} at row {}",
557                        weight,
558                        i
559                    );
560                }
561                sqrt[i] = weight.sqrt();
562            }
563            Some(sqrt)
564        } else {
565            None
566        };
567        let mut w = Array1::<f64>::zeros(n);
568        let mut w1 = Array1::<f64>::zeros(n);
569        let mut w2 = Array1::<f64>::zeros(n);
570        let mut w3 = Array1::<f64>::zeros(n);
571        let mut w4 = Array1::<f64>::zeros(n);
572        RemlState::fill_fisher_weight_derivative_arrays(
573            link, eta, &mut w, &mut w1, &mut w2, &mut w3, &mut w4,
574        )?;
575        let basis_design = if let Some(scale) = observation_weight_sqrt.as_ref() {
576            RemlState::row_scale(x_dense, scale)
577        } else {
578            x_dense.clone()
579        };
580
581        // Build one orthonormal coefficient-space basis Q for the identifiable
582        // subspace of A^{1/2} X (A = I without fixed case weights).
583        //
584        // This must happen even when the weighted design is full rank. The old
585        // Q = I shortcut left full-rank designs in raw coordinates while the
586        // singular branch switched to an orthonormal identifiable basis, so the
587        // reduced Fisher determinant depended on which branch we took rather
588        // than only on the identifiable subspace itself.
589        //
590        // Using the retained eigenspace of X̃ᵀ X̃ for every design keeps both
591        // branches on one representation:
592        //   X̃ = A^{1/2} X,
593        //   Qᵀ Q = I,
594        //   X_r = X̃ Q,
595        //   S_r = X_rᵀ X_r = diag(positive spectrum of X̃ᵀ X̃).
596        let gram = fast_atb(&basis_design, &basis_design);
597        let (q_basis, metric_spectrum) = Self::identifiable_subspace_basis_from_gram(&gram)?;
598        let x_reduced = fast_ab(&basis_design, &q_basis);
599        let r = q_basis.ncols();
600
601        // Reduced Fisher on the identifiable subspace:
602        //   I_r = X_rᵀ W X_r.
603        // Under finite-logit eta, W has strictly positive diagonal entries and
604        // X_r has full column rank by construction, so I_r is SPD.
605        let fisher_reduced = gam_linalg::faer_ndarray::fast_xt_diag_x(&x_reduced, &w);
606        // Smooth-regime diagnostic:
607        // for exact pseudodet derivatives we require I_r to stay SPD on the
608        // fixed identifiable subspace. Emit a warning when I_r appears close
609        // to singular, because this is where Jeffreys/Firth curvature becomes
610        // numerically fragile and active-subspace assumptions may fail.
611        if let Ok((eigvals_ir, _)) = fisher_reduced.eigh(Side::Lower) {
612            let max_ev = eigvals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
613            let min_ev = eigvals_ir
614                .iter()
615                .copied()
616                .filter(|v| v.is_finite() && *v > 0.0)
617                .fold(f64::INFINITY, f64::min);
618            if min_ev.is_finite() {
619                let rel = min_ev / max_ev;
620                if rel < FIRTH_REDUCED_FISHER_RCOND_WARN {
621                    log::warn!(
622                        "[REML/Firth] reduced Fisher I_r is near-singular (min/max={:.3e}/{:.3e}, rel={:.3e}); exact derivatives may be ill-conditioned near active-subspace boundaries.",
623                        min_ev,
624                        max_ev,
625                        rel
626                    );
627                }
628            }
629        }
630
631        let mut x_metric_reduced_inv_diag = Array1::<f64>::zeros(r);
632        let (k_reduced, mut half_log_det) = if r > 0 {
633            // The fixed-Q identifiable-space value is the generalized
634            // determinant 0.5(log|I_r| - log|S_r|), which is equivalent to
635            // evaluating W on the orthonormalized design U = X_r S_r^{-1/2}.
636            //
637            // Because Q diagonalizes X̃ᵀ X̃ by construction, S_r is exactly the
638            // retained positive spectrum of that Gram matrix in these reduced
639            // coordinates, so its inverse/logdet are diagonal operations here.
640            // Prefer the fast SPD path for I_r, but when I_r is only
641            // numerically semidefinite after projection, keep the exact
642            // positive eigenspace instead of failing outright.
643            RemlState::reduced_fisher_inverse_and_half_logdet(&fisher_reduced)?
644        } else {
645            (Array2::<f64>::zeros((r, r)), 0.0)
646        };
647        if r > 0 {
648            for col in 0..r {
649                let metric_eig = metric_spectrum[col];
650                half_log_det -= 0.5 * metric_eig.ln();
651                x_metric_reduced_inv_diag[col] = metric_eig.recip();
652            }
653        }
654        // Reduced design enters M = X_r K_r X_rᵀ and P = M⊙M.
655        let h_diag = if r > 0 {
656            RemlState::reduced_diag_gram(&x_reduced, &k_reduced)
657        } else {
658            Array1::<f64>::zeros(n)
659        };
660        let x_dense_t = x_dense.t().to_owned();
661        let b_base = RemlState::row_scale(x_dense, &w1);
662        let p_b_base =
663            RemlState::apply_hadamard_gram_to_matrix(&x_reduced, &k_reduced, &k_reduced, &b_base);
664        Ok(FirthDenseOperator {
665            x_dense: x_dense.clone(),
666            x_dense_t,
667            q_basis,
668            x_reduced,
669            observation_weight_sqrt,
670            k_reduced,
671            x_metric_reduced_inv_diag,
672            half_log_det,
673            h_diag,
674            w,
675            w1,
676            w2,
677            w3,
678            w4,
679            b_base,
680            p_b_base,
681        })
682    }
683
    /// Returns the cached Jeffreys/Firth value `half_log_det`, which the
    /// builder computes as `½ (log|I_r| − log|S_r|)` on the identifiable
    /// subspace.
    #[inline]
    pub(crate) fn jeffreys_logdet(&self) -> f64 {
        self.half_log_det
    }
688
689    /// Tangent-projected Jeffreys/Firth log-determinant `½ log|Zᵀ J Z|_+`,
690    /// where `J = Xᵀ A W(η) X` is the full p-space Fisher information at
691    /// the current `η` and `Z` is the `p × m` orthonormal basis of
692    /// `null(A_act)` produced by the active-constraint tangent projector.
693    ///
694    /// Identity: `Zᵀ J Z = (X̃ Z)ᵀ W (X̃ Z)` with `X̃ = A^{1/2} X` when
695    /// fixed observation weights are present, `X̃ = X` otherwise. The
696    /// projected log-pseudo-det uses the same positive-eigenvalue
697    /// threshold convention as the rest of the tangent-projected LAML
698    /// (`positive_eigenvalue_threshold` / `exact_pseudo_logdet`) so the
699    /// kernel that defines "active subspace" is consistent across the
700    /// objective, its gradient, and the Firth contribution.
701    pub(crate) fn jeffreys_logdet_projected(&self, z: ndarray::ArrayView2<'_, f64>) -> f64 {
702        use gam_linalg::faer_ndarray::{fast_ab, fast_xt_diag_x};
703        let p = self.x_dense.ncols();
704        assert_eq!(
705            z.nrows(),
706            p,
707            "jeffreys_logdet_projected: Z must have {} rows (β-space dim), got {}",
708            p,
709            z.nrows()
710        );
711        let m = z.ncols();
712        if m == 0 {
713            return 0.0;
714        }
715        // X·Z, then optional sqrt(A) row-scale → X̃·Z.
716        let z_owned = z.to_owned();
717        let xz = fast_ab(&self.x_dense, &z_owned);
718        let xtz = if let Some(scale) = self.observation_weight_sqrt.as_ref() {
719            RemlState::row_scale(&xz, scale)
720        } else {
721            xz
722        };
723        // J_T = (X̃ Z)ᵀ W (X̃ Z), symmetric m × m PSD.
724        let mut j_t = fast_xt_diag_x(&xtz, &self.w);
725        symmetrize_in_place(&mut j_t);
726        let (evals, _) = match j_t.eigh(Side::Lower) {
727            Ok(pair) => pair,
728            Err(_) => return f64::NEG_INFINITY,
729        };
730        let Some(evals_slice) = evals.as_slice() else {
731            return f64::NEG_INFINITY;
732        };
733        let threshold = super::reml_outer_engine::positive_eigenvalue_threshold(evals_slice);
734        0.5 * super::reml_outer_engine::exact_pseudo_logdet(evals_slice, threshold)
735    }
736
737    #[inline]
738    pub(crate) fn jeffreys_beta_gradient(&self) -> Array1<f64> {
739        // For I(β) = Xᵀ A W(η) X with fixed observation weights A,
740        //   ∂/∂β_j [0.5 log|I|]
741        //   = 0.5 Σ_i h_i w_i'(η_i) x_{ij},
742        // where h_i = [A^{1/2} X I^{-1} Xᵀ A^{1/2}]_{ii}.
743        0.5 * gam_linalg::faer_ndarray::fast_av(&self.x_dense_t, &(&self.w1 * &self.h_diag))
744    }
745
746    #[inline]
747    pub fn jeffreys_logdet_and_beta_gradient(&self) -> (f64, Array1<f64>) {
748        (self.jeffreys_logdet(), self.jeffreys_beta_gradient())
749    }
750
751    #[inline]
752    pub(crate) fn reduce_explicit_design(&self, x: &Array2<f64>) -> Array2<f64> {
753        let mut reduced = fast_ab(x, &self.q_basis);
754        if let Some(scale) = self.observation_weight_sqrt.as_ref() {
755            reduced = RemlState::row_scale(&reduced, scale);
756        }
757        reduced
758    }
759
760    pub(crate) fn direction_from_deta(&self, deta: Array1<f64>) -> FirthDirection {
761        // Directional building blocks for u:
762        //   δη_u = X u
763        //   I_u  = Xᵀ diag(w' ⊙ δη_u) X
764        //   T_u  = K I_u K
765        //   N_u  = X T_u Xᵀ
766        //   Dh[u] = -diag(N_u)
767        //   P_u   = D(M⊙M)[u] = -2(M⊙N_u)
768        // and B_u = diag(w'' ⊙ δη_u) X.
769        //
770        // In this implementation, active-subspace ambiguity is removed by fixed Q:
771        // K is represented by K_r = I_r^{-1} in reduced coordinates, so
772        //   A_u = K_r G_u K_r
773        // is exact for logit with finite eta (w_i > 0) and fixed rank(X).
774        // s_u is the diagonal weight for D I[u]:
775        //   D I[u] = Xᵀ diag(s_u) X,  s_u = w' ⊙ (X u).
776        let s_u = &self.w1 * &deta;
777        // G_u = X_rᵀ diag(s_u) X_r,  A_u = K_r G_u K_r.
778        // These are reduced-space forms of
779        //   I_u and T_u = K I_u K
780        // from the full-space derivation.
781        let g_u_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_u);
782        let k_g_u = self.k_reduced.dot(&g_u_reduced);
783        let a_u_reduced = k_g_u.dot(&self.k_reduced);
784        // Dh[u] = -diag(N_u),  N_u = X T_u Xᵀ, represented here as
785        //   N_u = Z A_u Zᵀ in weighted reduced coordinates.
786        let dh = -RemlState::reduced_diag_gram(&self.x_reduced, &a_u_reduced);
787        let b_uvec = &self.w2 * &deta;
788        FirthDirection {
789            deta,
790            g_u_reduced,
791            a_u_reduced,
792            dh,
793            b_uvec,
794        }
795    }
796
797    #[inline]
798    pub(crate) fn left_scaled_xt(&self, scale: &Array1<f64>, mat: &Array2<f64>) -> Array2<f64> {
799        fast_ab(&self.x_dense_t, &(mat * &scale.view().insert_axis(Axis(1))))
800    }
801
802    #[inline]
803    pub(crate) fn apply_p_u_to_matrix(
804        &self,
805        a_u_reduced: &Array2<f64>,
806        mat: &Array2<f64>,
807    ) -> Array2<f64> {
808        let mut out = RemlState::apply_hadamard_gram_to_matrix(
809            &self.x_reduced,
810            &self.k_reduced,
811            a_u_reduced,
812            mat,
813        );
814        out.mapv_inplace(|v| -2.0 * v);
815        out
816    }
817
818    pub(crate) fn hphi_direction_apply(
819        &self,
820        dir: &FirthDirection,
821        rhs: &Array2<f64>,
822    ) -> Array2<f64> {
823        let p = self.x_dense.ncols();
824        if rhs.nrows() != p {
825            return Array2::<f64>::zeros((p, rhs.ncols()));
826        }
827        if rhs.ncols() == 0 || p == 0 {
828            return Array2::<f64>::zeros((p, rhs.ncols()));
829        }
830        // Matrix-free apply of D(Hphi)[u] to a block V:
831        //   D(Hphi)[u] V
832        // = 0.5[ Xᵀ(c_u ⊙ (X V))
833        //       - B_uᵀ P (B V) - Bᵀ P (B_u V) - Bᵀ P_u (B V) ].
834        // This avoids dense p×p materialization and is used by sparse exact
835        // trace contractions through tr(H^{-1} ·).
836        let etav = fast_ab(&self.x_dense, rhs);
837        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
838        let m_qv = RemlState::apply_hadamard_gram_to_matrix(
839            &self.x_reduced,
840            &self.k_reduced,
841            &self.k_reduced,
842            &qv,
843        );
844        let buvec = &dir.b_uvec;
845        let m_buv = RemlState::apply_hadamard_gram_to_matrix(
846            &self.x_reduced,
847            &self.k_reduced,
848            &self.k_reduced,
849            &(&etav * &buvec.view().insert_axis(Axis(1))),
850        );
851        let p_u_qv = self.apply_p_u_to_matrix(&dir.a_u_reduced, &qv);
852        let c_u = &(&self.w3 * &dir.deta) * &self.h_diag + &(&self.w2 * &dir.dh);
853        let diag_term = self
854            .x_dense_t
855            .dot(&(&etav * &c_u.view().insert_axis(Axis(1))));
856        let term1 = self.left_scaled_xt(buvec, &m_qv);
857        let term2 = self.left_scaled_xt(&self.w1, &m_buv);
858        let term3 = self.left_scaled_xt(&self.w1, &p_u_qv);
859        0.5 * (diag_term - (term1 + term2 + term3))
860    }
861
862    pub(crate) fn hphi_direction(&self, dir: &FirthDirection) -> Array2<f64> {
863        let p = self.x_dense.ncols();
864        let eye = Array2::<f64>::eye(p);
865        let mut out = self.hphi_direction_apply(dir, &eye);
866        // Exact first directional derivative of H_φ:
867        //   D H_φ[u]
868        //   = 0.5 [ Xᵀ diag(c_u) X
869        //           - (B_uᵀ P B + Bᵀ P B_u + Bᵀ P_u B) ],
870        // where
871        //   c_u = w''' ⊙ δη_u ⊙ h + w'' ⊙ Dh[u],
872        //   B   = diag(w') X.
873        //
874        // Matrix-free contraction map used below:
875        //   Bᵀ P B      via apply_hadamard_gram_to_matrix(Z, K_r, K_r, B)
876        //   Bᵀ P_u B    via apply_hadamard_gram_to_matrix(Z, K_r, A_u, B), then *(-2)
877        // where P = M⊙M and P_u = -2(M⊙N_u), but M/N_u are never formed explicitly.
878        symmetrize_in_place(&mut out);
879        out
880    }
881
882    pub(crate) fn hphisecond_direction_apply(
883        &self,
884        u: &FirthDirection,
885        v: &FirthDirection,
886        rhs: &Array2<f64>,
887    ) -> Array2<f64> {
888        let p = self.x_dense.ncols();
889        if rhs.nrows() != p {
890            return Array2::<f64>::zeros((p, rhs.ncols()));
891        }
892        if rhs.ncols() == 0 || p == 0 {
893            return Array2::<f64>::zeros((p, rhs.ncols()));
894        }
895        // Exact mixed second directional derivative:
896        //   D² H_φ[u,v] = 0.5 [ Xᵀ diag(c_uv) X - D²J₂[u,v] ], J₂ = Bᵀ P B.
897        // Implemented with matrix identities for N_{u,v}, P_{u,v}, and the
898        // nine-term expansion of D²J₂[u,v].
899        //
900        // Because we parameterize Phi through fixed-rank I_r = X_rᵀ W X_r (SPD for
901        // finite-logit eta), this mixed derivative is evaluated on a smooth
902        // manifold without dynamic active-set switching in the Firth block.
903        //
904        // The nine contraction terms below are the explicit D²J₂[u,v] expansion,
905        // each computed through reduced Hadamard-Gram operators.
906        let deta_uv = &u.deta * &v.deta;
907        // Mixed reduced Gram:
908        //   G_uv = X_rᵀ diag(w'' ⊙ (Xu) ⊙ (Xv)) X_r.
909        let s_uv = &self.w2 * &deta_uv;
910        let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
911        let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
912        let k_gv = self.k_reduced.dot(&v.g_u_reduced);
913        let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
914        // Reduced form of:
915        //   T_{u,v} = K I_{u,v} K - K Iv K I_u K - K I_u K Iv K.
916        let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
917            - k_gv.dot(&k_g_u).dot(&self.k_reduced)
918            - k_g_u.dot(&k_gv).dot(&self.k_reduced);
919        let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
920        // Implements mixed diagonal coefficient:
921        //   c_uv = w'''' ⊙ (Xu) ⊙ (Xv) ⊙ h
922        //          + w''' ⊙ ((Xu) ⊙ Dh[v] + (Xv) ⊙ Dh[u])
923        //          + w'' ⊙ D²h[u,v].
924        let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
925            + &(&self.w3 * &(&u.deta * &v.dh))
926            + &(&self.w3 * &(&v.deta * &u.dh))
927            + &(&self.w2 * &d2h);
928
929        let eta_rhs = fast_ab(&self.x_dense, rhs);
930        let diag_term = fast_ab(
931            &self.x_dense_t,
932            &(&eta_rhs * &c_uv.view().insert_axis(Axis(1))),
933        );
934
935        let b_uvvec = &self.w3 * &deta_uv;
936        let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
937        let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));
938
939        // Linearity in the rhs argument lets us precompute the expensive
940        // Hadamard-Gram operator on the full base blocks B, B_u, Bv, B_uv once,
941        // then post-multiply by rhs. This preserves the exact operator while
942        // avoiding repeated O(n r^2 c) work for every rhs block.
943        let p_b_rhs = fast_ab(&self.p_b_base, rhs);
944        let p_bu_rhs = RemlState::apply_hadamard_gram_to_matrix(
945            &self.x_reduced,
946            &self.k_reduced,
947            &self.k_reduced,
948            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
949        );
950        let p_bv_rhs = RemlState::apply_hadamard_gram_to_matrix(
951            &self.x_reduced,
952            &self.k_reduced,
953            &self.k_reduced,
954            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
955        );
956        let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
957            &self.x_reduced,
958            &self.k_reduced,
959            &self.k_reduced,
960            &b_uv_base,
961        );
962        let p_buv_rhs = fast_ab(&p_buv_base, rhs);
963
964        let pv_b_rhs = self.apply_p_u_to_matrix(&v.a_u_reduced, &qv);
965        let pv_bu_rhs = self.apply_p_u_to_matrix(
966            &v.a_u_reduced,
967            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
968        );
969        let p_u_b_rhs = self.apply_p_u_to_matrix(&u.a_u_reduced, &qv);
970        let p_u_bv_rhs = self.apply_p_u_to_matrix(
971            &u.a_u_reduced,
972            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
973        );
974
975        let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
976            &self.x_reduced,
977            &u.a_u_reduced,
978            &v.a_u_reduced,
979            &self.b_base,
980        );
981        let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
982            &self.x_reduced,
983            &self.k_reduced,
984            &a_uv_reduced,
985            &self.b_base,
986        );
987        let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
988        let p_uv_rhs = fast_ab(&p_uv_base, rhs);
989
990        // Nine-term expansion of D²J₂[u,v] with J₂ = Bᵀ P B.
991        let d2_terms = [
992            self.left_scaled_xt(&b_uvvec, &p_b_rhs),
993            self.left_scaled_xt(&self.w1, &p_buv_rhs),
994            self.left_scaled_xt(&u.b_uvec, &p_bv_rhs),
995            self.left_scaled_xt(&v.b_uvec, &p_bu_rhs),
996            self.left_scaled_xt(&u.b_uvec, &pv_b_rhs),
997            self.left_scaled_xt(&self.w1, &pv_bu_rhs),
998            self.left_scaled_xt(&v.b_uvec, &p_u_b_rhs),
999            self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
1000            self.left_scaled_xt(&self.w1, &p_uv_rhs),
1001        ];
1002        let mut d2_j2 = Array2::<f64>::zeros((p, rhs.ncols()));
1003        for term in d2_terms {
1004            d2_j2 += &term;
1005        }
1006
1007        0.5 * (diag_term - d2_j2)
1008    }
1009
1010    pub(super) fn rowwise_dot(a: &Array2<f64>, b: &Array2<f64>) -> Array1<f64> {
1011        assert_eq!(a.nrows(), b.nrows());
1012        assert_eq!(a.ncols(), b.ncols());
1013        let mut out = Array1::<f64>::zeros(a.nrows());
1014        for i in 0..a.nrows() {
1015            let mut acc = 0.0_f64;
1016            for j in 0..a.ncols() {
1017                acc += a[[i, j]] * b[[i, j]];
1018            }
1019            out[i] = acc;
1020        }
1021        out
1022    }
1023
1024    pub(super) fn rowwise_bilinear(
1025        a: &Array2<f64>,
1026        m: &Array2<f64>,
1027        b: &Array2<f64>,
1028    ) -> Array1<f64> {
1029        // Returns vector with entries a_iᵀ M b_i for each row i.
1030        assert_eq!(a.nrows(), b.nrows());
1031        assert_eq!(a.ncols(), m.nrows());
1032        assert_eq!(b.ncols(), m.ncols());
1033        let am = fast_ab(a, m);
1034        Self::rowwise_dot(&am, b)
1035    }
1036
1037    pub(crate) fn dot_i_and_h_from_reduced(
1038        &self,
1039        x_tau_reduced: &Array2<f64>,
1040        deta: &Array1<f64>,
1041    ) -> (Array2<f64>, Array1<f64>) {
1042        // Reduced Fisher directional derivative under fixed identifiable basis:
1043        //   I_r = X_r' W X_r
1044        //   I_r,tau = X_{r,tau}' W X_r + X_r' W X_{r,tau} + X_r' W_tau X_r
1045        // with W_tau = diag(w' ⊙ eta_tau).
1046        //
1047        // Leverage derivative used by Firth score partial:
1048        //   h_i = x_{r,i}' K_r x_{r,i}, K_r = I_r^{-1}
1049        //   h_tau = 2*diag(X_{r,tau} K_r X_r') + diag(X_r K_{r,tau} X_r')
1050        //   K_{r,tau} = -K_r I_{r,tau} K_r.
1051        //
1052        // This is exactly the fixed-beta directional derivative required by
1053        //   (gphi)_tau and Phi_tau in the Jeffreys/Firth design-moving path:
1054        //   I_{r,tau}|beta = X_{r,tau}' W X_r + X_r' W X_{r,tau}
1055        //                    + X_r' diag(w' ⊙ eta_tau|beta) X_r,
1056        //   eta_tau|beta = X_tau beta.
1057        //
1058        // We return:
1059        //   dot_i  = I_{r,tau}|beta,
1060        //   dot_h  = h_tau|beta.
1061        let dw = &self.w1 * deta;
1062        let dot_i = RemlState::weighted_cross(x_tau_reduced, &self.x_reduced, &self.w)
1063            + RemlState::weighted_cross(&self.x_reduced, x_tau_reduced, &self.w)
1064            + gam_linalg::faer_ndarray::fast_xt_diag_x(&self.x_reduced, &dw);
1065
1066        let dot_k = -self.k_reduced.dot(&dot_i).dot(&self.k_reduced);
1067        let x_tauk = fast_ab(x_tau_reduced, &self.k_reduced);
1068        let dot_h_explicit = 2.0 * Self::rowwise_dot(&x_tauk, &self.x_reduced);
1069        let dot_h_implicit = Self::rowwise_dot(&fast_ab(&self.x_reduced, &dot_k), &self.x_reduced);
1070        let dot_h = dot_h_explicit + dot_h_implicit;
1071        (dot_i, dot_h)
1072    }
1073
1074    pub(crate) fn exact_tau_kernel(
1075        &self,
1076        x_tau: &Array2<f64>,
1077        beta: &Array1<f64>,
1078        include_hphi_tau_kernel: bool,
1079    ) -> FirthTauExactKernel {
1080        // Shared exact tau-partial bundle used by both dense and sparse paths:
1081        //   (gphi)_tau | beta-fixed,
1082        //   Phi_tau | beta-fixed,
1083        // and optional H_{phi,tau}|beta kernel for later matrix-free applies.
1084        //
1085        // Closed forms (reduced Fisher, fixed active subspace):
1086        //   Phi = 0.5 log|I_r| - 0.5 log|S_r|,
1087        //   I_r = X_r' W X_r, K_r = I_r^{-1},
1088        //   S_r = X_r' X_r,   diag(G_r) = diag(S_r^{-1}),
1089        //   Phi_tau|beta = 0.5 tr(K_r I_{r,tau}) - 0.5 tr(G_r S_{r,tau}).
1090        // In the canonical reduced basis used here, G_r is diagonal.
1091        //
1092        //   (gphi)_tau = Phi_beta,tau
1093        //               = 0.5 X_tau' (w1 .* h)
1094        //                 + 0.5 X'((w2 .* eta_tau) .* h + w1 .* h_tau),
1095        //   where
1096        //     h_i = x_{r,i}' K_r x_{r,i},
1097        //     h_tau = 2*diag(X_{r,tau} K_r X_r') + diag(X_r K_{r,tau} X_r'),
1098        //     K_{r,tau} = -K_r I_{r,tau} K_r.
1099        //
1100        // Phi_beta,tau is unchanged by the -0.5 log|S_r| term because S_r does
1101        // not depend on beta. Only Phi_tau gets the explicit basis-drift
1102        // subtraction.
1103        let deta_partial = gam_linalg::faer_ndarray::fast_av(x_tau, beta);
1104        let x_tau_reduced = self.reduce_explicit_design(x_tau);
1105        let (dot_i_partial, dot_h_partial) =
1106            self.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
1107        let dot_s_partial =
1108            fast_atb(&x_tau_reduced, &self.x_reduced) + fast_atb(&self.x_reduced, &x_tau_reduced);
1109
1110        let first = 0.5 * gam_linalg::faer_ndarray::fast_atv(x_tau, &(&self.w1 * &self.h_diag));
1111        let secondvec =
1112            &(&(&self.w2 * &deta_partial) * &self.h_diag) + &(&self.w1 * &dot_h_partial);
1113        let second = 0.5 * gam_linalg::faer_ndarray::fast_atv(&self.x_dense, &secondvec);
1114        let gphi_tau = first + second;
1115        let phi_tau_partial = 0.5 * RemlState::trace_product(&self.k_reduced, &dot_i_partial)
1116            - 0.5 * Self::trace_diag_product(&self.x_metric_reduced_inv_diag, &dot_s_partial);
1117
1118        let tau_kernel = if include_hphi_tau_kernel {
1119            Some(self.hphi_tau_partial_prepare_from_partials(
1120                x_tau_reduced,
1121                &deta_partial,
1122                dot_h_partial,
1123                dot_i_partial,
1124            ))
1125        } else {
1126            None
1127        };
1128        FirthTauExactKernel {
1129            gphi_tau,
1130            phi_tau_partial,
1131            tau_kernel,
1132        }
1133    }
1134
1135    pub(crate) fn hphi_tau_partial_prepare_from_partials(
1136        &self,
1137        x_tau_reduced: Array2<f64>,
1138        deta_partial: &Array1<f64>,
1139        dot_h_partial: Array1<f64>,
1140        dot_i_partial: Array2<f64>,
1141    ) -> FirthTauPartialKernel {
1142        let dotw1 = &self.w2 * deta_partial;
1143        let dotw2 = &self.w3 * deta_partial;
1144        let dot_k = -self.k_reduced.dot(&dot_i_partial).dot(&self.k_reduced);
1145        FirthTauPartialKernel {
1146            deta_partial: deta_partial.clone(),
1147            dotw1,
1148            dotw2,
1149            dot_h_partial,
1150            x_tau_reduced,
1151            dot_i_partial,
1152            dot_k_reduced: dot_k,
1153        }
1154    }
1155
1156    pub(crate) fn d_beta_hphi_tau_partial_dense(
1157        &self,
1158        x_tau: &Array2<f64>,
1159        beta: &Array1<f64>,
1160        beta_direction: &Array1<f64>,
1161    ) -> Option<Array2<f64>> {
1162        if x_tau.nrows() != self.x_dense.nrows() || x_tau.ncols() != beta.len() {
1163            return None;
1164        }
1165        if !x_tau.iter().any(|value| *value != 0.0) {
1166            return None;
1167        }
1168        let tau_bundle = self.exact_tau_kernel(x_tau, beta, true);
1169        let tau_kernel = tau_bundle.tau_kernel?;
1170        let firth_direction =
1171            self.direction_from_deta(gam_linalg::faer_ndarray::fast_av(&self.x_dense, beta_direction));
1172        let x_tau_v = gam_linalg::faer_ndarray::fast_av(x_tau, beta_direction);
1173        let kernel = self.d_beta_hphi_tau_partial_prepare_from_partials(
1174            &tau_kernel,
1175            &tau_kernel.deta_partial,
1176            &tau_kernel.dot_i_partial,
1177            &firth_direction,
1178            &x_tau_v,
1179        );
1180        let eye = Array2::<f64>::eye(beta_direction.len());
1181        Some(self.d_beta_hphi_tau_partial_apply(x_tau, &kernel, &eye))
1182    }
1183
1184    pub(crate) fn apply_pbar_to_matrix(&self, mat: &Array2<f64>) -> Array2<f64> {
1185        // Applies P̄ = (X_r K_r X_rᵀ)⊙(X_r K_r X_rᵀ) to each column of mat.
1186        RemlState::apply_hadamard_gram_to_matrix(
1187            &self.x_reduced,
1188            &self.k_reduced,
1189            &self.k_reduced,
1190            mat,
1191        )
1192    }
1193
1194    pub(crate) fn apply_mtau_to_matrix(
1195        &self,
1196        kernel: &FirthTauPartialKernel,
1197        mat: &Array2<f64>,
1198    ) -> Array2<f64> {
1199        // Exact apply of
1200        //   M_tau = d/dtau[(P⊙P)]|_{beta fixed} = 2(P⊙P_tau)
1201        // without building dense n×n objects.
1202        //
1203        // Decomposition:
1204        //   P = Z K Zᵀ, Z = X_r
1205        //   P_tau = Z_tau K Zᵀ + Z K Z_tauᵀ + Z dotK Zᵀ
1206        // and for each vector v:
1207        //   (P⊙(Z_tau K Zᵀ))v   : rowwise bilinear with K (Zᵀdiag(v)Z) K
1208        //   (P⊙(Z K Z_tauᵀ))v   : diag_Z( K (Zᵀdiag(v)Z_tau) K )
1209        //   (P⊙(Z dotK Zᵀ))v    : Hadamard-Gram apply with (K, dotK).
1210        if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
1211            return Array2::<f64>::zeros(mat.raw_dim());
1212        }
1213        let mut out = Array2::<f64>::zeros(mat.raw_dim());
1214        for col in 0..mat.ncols() {
1215            let v = mat.column(col).to_owned();
1216            let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
1217            let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
1218            let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, &kernel.x_tau_reduced);
1219
1220            let szt =
1221                RemlState::reduced_crossweighted_gram(&self.x_reduced, &kernel.x_tau_reduced, &v);
1222            let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
1223            let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
1224
1225            let t3 = RemlState::apply_hadamard_gram(
1226                &self.x_reduced,
1227                &self.k_reduced,
1228                &kernel.dot_k_reduced,
1229                &v,
1230            );
1231
1232            let y = 2.0 * (t1 + t2 + t3);
1233            out.column_mut(col).assign(&y);
1234        }
1235        out
1236    }
1237
1238    pub(crate) fn hphi_tau_partial_apply(
1239        &self,
1240        x_tau: &Array2<f64>,
1241        kernel: &FirthTauPartialKernel,
1242        rhs: &Array2<f64>,
1243    ) -> Array2<f64> {
1244        let p = self.x_dense.ncols();
1245        if rhs.nrows() != p {
1246            return Array2::<f64>::zeros((p, rhs.ncols()));
1247        }
1248        if rhs.ncols() == 0 || p == 0 {
1249            return Array2::<f64>::zeros((p, rhs.ncols()));
1250        }
1251        // Matrix-free block apply of Hphi,tau|beta:
1252        //   Hphi,tau|beta(V) = 0.5 [ X_tau' r(V) + X' r_tau(V) ].
1253        //
1254        // Tensor identity behind this apply:
1255        //   Hphi,tau|beta = Phi_beta,beta,tau
1256        // and for test vectors b1,b2 (matrix columns V are batched b2's):
1257        //   Phi_beta,beta,tau[b1,b2]
1258        //   = 0.5[
1259        //       tr(I^{-1} I_{b1,b2,tau})
1260        //       - tr(I^{-1} I_{b1,b2} I^{-1} I_tau)
1261        //       - tr(I^{-1} I_{b1,tau} I^{-1} I_{b2})
1262        //       - tr(I^{-1} I_{b2,tau} I^{-1} I_{b1})
1263        //       + 2 tr(I^{-1} I_{b1} I^{-1} I_{b2} I^{-1} I_tau)
1264        //     ].
1265        // This routine evaluates that form in reduced coordinates without forming
1266        // dense 3rd-order tensors explicitly.
1267        let etav = fast_ab(&self.x_dense, rhs);
1268        let etav_tau = fast_ab(x_tau, rhs);
1269        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
1270        let qv_tau = &etav * &kernel.dotw1.view().insert_axis(Axis(1))
1271            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
1272        let m_qv = self.apply_pbar_to_matrix(&qv);
1273        let m_qv_tau = self.apply_mtau_to_matrix(kernel, &qv) + self.apply_pbar_to_matrix(&qv_tau);
1274        let rv = &(&etav * &self.w2.view().insert_axis(Axis(1)))
1275            * &self.h_diag.view().insert_axis(Axis(1))
1276            - &(&m_qv * &self.w1.view().insert_axis(Axis(1)));
1277        let rv_tau = (&(&etav * &kernel.dotw2.view().insert_axis(Axis(1)))
1278            + &(&etav_tau * &self.w2.view().insert_axis(Axis(1))))
1279            * self.h_diag.view().insert_axis(Axis(1))
1280            + &(&etav * &self.w2.view().insert_axis(Axis(1)))
1281                * &kernel.dot_h_partial.view().insert_axis(Axis(1))
1282            - &(&m_qv * &kernel.dotw1.view().insert_axis(Axis(1))
1283                + &m_qv_tau * &self.w1.view().insert_axis(Axis(1)));
1284        0.5 * (fast_atb(x_tau, &rv) + fast_atb(&self.x_dense, &rv_tau))
1285    }
1286
1287    // ═════════════════════════════════════════════════════════════════════════
1288    //  Pair-term primitives for the Firth outer Hessian (Task #13a / #17)
1289    // ═════════════════════════════════════════════════════════════════════════
1290    //
1291    // The REML outer Hessian at a ψ=(ρ,τ) pair needs two Firth contributions
1292    // that are NOT covered by the existing single-τ primitives:
1293    //
1294    //   A.  the pure τ×τ second partial of H_φ at fixed β (pair drift inside
1295    //       the fixed-β second-derivative trace of B_i,j used by
1296    //       build_tau_tau_pair_callback).  This is the Firth analog of the
1297    //       penalty-logdet pair term in the outer-derivative cookbook.
1298    //
1299    //   B.  the β-derivative of (H_φ)_τ|_β in direction v.  This is the
1300    //       fixed-drift-derivative M_i[v] = D_β B_i[v] that
1301    //       compute_drift_deriv_traces uses through the fixed_drift_deriv
1302    //       callback in build_tau_hyper_coords.  It is currently always None
1303    //       in the Firth+Logit path, which is what makes the outer Hessian
1304    //       approximate for Firth-reweighted models.
1305    //
1306    // Both primitives operate in the reduced identifiable subspace (X_r, K_r,
1307    // S_r, Z=X_r) of the dense Firth operator and are matrix-free in n.  They
1308    // do NOT introduce any dense n×n or p×p×p object; every contraction is
1309    // routed through reduced-space Hadamard-Gram applies and rowwise
1310    // bilinear forms, in the same spirit as hphi_tau_partial_apply and
1311    // hphisecond_direction_apply.
1312    //
    // Both are exact in the smooth-regime operating point assumed by this
    // module: X_r of full column rank so that I_r is SPD on the identifiable
    // subspace, Q held fixed within one outer REML step (active-subspace
    // drift enters only between outer iterates), w_i(η) strictly positive,
    // and β at the P-IRLS solution for the current ψ so β_τ is supplied by
    // the unified evaluator via the IFT solve.
1319    //
1320    // Symbol conventions shared with the existing operator code:
1321    //   X_r   := X Q            (reduced identifiable design)
1322    //   W     := diag(w(η))     (Fisher weights), w', w'', w''', w''''
1323    //   I_r   := X_rᵀ W X_r,  K_r := I_r^{-1},  S_r := X_rᵀ X_r
1324    //   M     := X_r K_r X_rᵀ,  P := M ⊙ M,  B := diag(w') X
1325    //   h     := diag(M)
1326    //   X_i   := ∂X/∂τ_i,  X_{r,i} := X_i Q,  η̇_i := X_i β,  etc.
1327    //   δη_v  := X v           (β-direction v),  δη_{τ,v} := X_τ v
1328    //
1329    // H_φ structural form (reduced form of ∇²_β Φ_F):
1330    //   H_φ  =  ½ [ Xᵀ diag(w'' ⊙ h) X  −  Bᵀ P B ]
1331    // where the first term arises from differentiating the Jeffreys gradient
1332    // ½ Xᵀ (w' ⊙ h) once more in β, and the second collects the IFT-mediated
1333    // β-derivative of h = diag(X_r K_r X_rᵀ) through K_r = I_r^{-1}.  This is
1334    // the same form the existing hphi_direction code implements in directional
1335    // form (cf. firth.rs hphi_direction / hphisecond_direction_apply, the
1336    // 9-term D²J₂ expansion with J₂ = BᵀPB).
1337    //
1338    // ─────────────────────────────────────────────────────────────────────────
1339    //  Primitive A — ∂²H_φ/∂τ_i ∂τ_j |_β
1340    // ─────────────────────────────────────────────────────────────────────────
1341    //
1342    // WHAT IT COMPUTES
1343    //   Given a pair of τ-drift designs (X_τ_i, X_τ_j) and optional second
1344    //   design derivative X_{τ_i τ_j}, evaluates
1345    //
1346    //     ∂²H_φ/∂τ_i ∂τ_j |_β  =  ½ [ ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j
1347    //                                − ∂²(Bᵀ P B)/∂τ_i ∂τ_j ],
1348    //   with Γ := diag(w'' ⊙ h).  Acts on a p×m rhs and returns a p×m block
1349    //   (exact same contract as hphi_tau_partial_apply, but for the *second*
1350    //   mixed τ-derivative at fixed β).
1351    //
1352    // WHY (REML callsite)
    //   In the outer Hessian entry Ḧ_{i,j} for τ×τ pair (i,j), the fixed-β
    //   τ_j-drift of B_i is exactly this primitive (with the Firth sign:
    //   B_i = −(H_φ)_τ_i|_β + other likelihood pieces, so ∂B_i/∂τ_j|_β
    //   contributes −∂²H_φ/∂τ_i∂τ_j|_β to the outer Hessian trace).  The
1357    //   existing Firth pair callback at build_tau_tau_pair_callback currently
1358    //   carries zero for this Firth contribution; wiring this primitive into
1359    //   the TauTauPairHyperOperator is the remaining step to make the τ×τ
1360    //   outer Hessian exact in the Firth-reweighted Logit path.
1361    //
1362    // DERIVATION (full chain-rule expansion, at fixed β)
1363    //
1364    //   Building blocks at fixed β (single-τ):
1365    //     İ_i      := ∂I_r/∂τ_i |_β
1366    //                = X_{r,i}ᵀ W X_r + X_rᵀ W X_{r,i} + X_rᵀ Ẇ_i X_r,
1367    //       Ẇ_i    := diag(w' ⊙ η̇_i),   η̇_i := X_i β.
1368    //     K̇_i      := ∂K_r/∂τ_i = −K_r İ_i K_r.
1369    //     ḣ_i      := ∂h/∂τ_i |_β
1370    //                = 2·diag(X_{r,i} K_r X_rᵀ) + diag(X_r K̇_i X_rᵀ).
1371    //     Ṁ_i      := ∂M/∂τ_i |_β
1372    //                = X_{r,i} K_r X_rᵀ + X_r K̇_i X_rᵀ + X_r K_r X_{r,i}ᵀ.
1373    //     Ḃ_i      := ∂B/∂τ_i |_β
1374    //                = diag(w'' ⊙ η̇_i) X + diag(w') X_i.
1375    //     Ṗ_i      := ∂P/∂τ_i = 2 (M ⊙ Ṁ_i).
1376    //     Γ̇_i     := ∂Γ/∂τ_i |_β = diag(w''' ⊙ η̇_i ⊙ h + w'' ⊙ ḣ_i).
1377    //
1378    //   Second-order building blocks:
1379    //     η̈_{ij}  := X_{ij} β                   (0 for design linear in τ)
1380    //     Ẅ_{ij} := diag(w'' ⊙ η̇_i ⊙ η̇_j
1381    //                     + w' ⊙ η̈_{ij})
1382    //
1383    //     Ï_{ij}   := ∂²I_r/∂τ_i ∂τ_j |_β
1384    //                = X_{r,ij}ᵀ W X_r  +  X_rᵀ W X_{r,ij}
1385    //                 + X_{r,i}ᵀ W X_{r,j}  +  X_{r,j}ᵀ W X_{r,i}
1386    //                 + X_{r,i}ᵀ Ẇ_j X_r  +  X_rᵀ Ẇ_j X_{r,i}
1387    //                 + X_{r,j}ᵀ Ẇ_i X_r  +  X_rᵀ Ẇ_i X_{r,j}
1388    //                 + X_rᵀ Ẅ_{ij} X_r.
1389    //
1390    //     K̈_{ij}  := ∂²K_r/∂τ_i ∂τ_j
1391    //                = −K_r Ï_{ij} K_r
1392    //                  + K_r İ_i K_r İ_j K_r
1393    //                  + K_r İ_j K_r İ_i K_r.
1394    //
1395    //     M̈_{ij}  := X_{r,ij} K_r X_rᵀ + X_r K_r X_{r,ij}ᵀ
1396    //                 + X_{r,i} K̇_j X_rᵀ + X_r K̇_j X_{r,i}ᵀ
1397    //                 + X_{r,j} K̇_i X_rᵀ + X_r K̇_i X_{r,j}ᵀ
1398    //                 + X_{r,i} K_r X_{r,j}ᵀ + X_{r,j} K_r X_{r,i}ᵀ
1399    //                 + X_r K̈_{ij} X_rᵀ.
1400    //
    //     P̈_{ij}  := ∂²P/∂τ_i ∂τ_j
    //                = (Ṁ_i ⊙ Ṁ_j) + (Ṁ_j ⊙ Ṁ_i) + 2 (M ⊙ M̈_{ij})
    //                = 2 (Ṁ_i ⊙ Ṁ_j) + 2 (M ⊙ M̈_{ij}).
1404    //
1405    //     ḧ_{ij}  := ∂²h/∂τ_i ∂τ_j |_β
1406    //                = 2·diag(X_{r,ij} K_r X_rᵀ)
1407    //                 + diag(X_r K̈_{ij} X_rᵀ)
1408    //                 + 2·diag(X_{r,i} K̇_j X_rᵀ)
1409    //                 + 2·diag(X_{r,j} K̇_i X_rᵀ)
1410    //                 + 2·diag(X_{r,i} K_r X_{r,j}ᵀ).
1411    //
1412    //     B̈_{ij}  := ∂²B/∂τ_i ∂τ_j |_β
1413    //                = diag(w''' ⊙ η̇_i ⊙ η̇_j + w'' ⊙ η̈_{ij}) X
1414    //                 + diag(w'' ⊙ η̇_i) X_j
1415    //                 + diag(w'' ⊙ η̇_j) X_i
1416    //                 + diag(w') X_{ij}.
1417    //
1418    //     Γ̈_{ij} := ∂²Γ/∂τ_i ∂τ_j |_β
1419    //                = diag( w'''' ⊙ η̇_i ⊙ η̇_j ⊙ h
1420    //                       + w''' ⊙ η̈_{ij} ⊙ h
1421    //                       + w''' ⊙ η̇_i ⊙ ḣ_j
1422    //                       + w''' ⊙ η̇_j ⊙ ḣ_i
1423    //                       + w'' ⊙ ḧ_{ij} ).
1424    //
1425    //   Diagonal-term expansion (the Xᵀ Γ X branch):
1426    //
1427    //     ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j  =
1428    //         X_{ij}ᵀ Γ X  + Xᵀ Γ X_{ij}
1429    //       + X_iᵀ Γ X_j  + X_jᵀ Γ X_i
1430    //       + X_iᵀ Γ̇_j X  + Xᵀ Γ̇_j X_i
1431    //       + X_jᵀ Γ̇_i X  + Xᵀ Γ̇_i X_j
1432    //       + Xᵀ Γ̈_{ij} X.
1433    //
1434    //   9-term expansion for the BᵀPB branch (structurally identical to
1435    //   the existing β×β D²J₂[u,v] at firth.rs:~820-830 with (u,v)
1436    //   substituted by (τ_i, τ_j) and the appropriate Ḃ, B̈, Ṗ, P̈):
1437    //
1438    //     D²(BᵀPB)[τ_i,τ_j]  =
1439    //         B̈_{ij}ᵀ  P    B      +  Bᵀ       P    B̈_{ij}
1440    //       + Ḃ_iᵀ    P    Ḃ_j   +  Ḃ_jᵀ    P    Ḃ_i
1441    //       + Ḃ_iᵀ    Ṗ_j  B      +  Bᵀ       Ṗ_j  Ḃ_i
1442    //       + Ḃ_jᵀ    Ṗ_i  B      +  Bᵀ       Ṗ_i  Ḃ_j
1443    //       + Bᵀ       P̈_{ij} B.
1444    //
1445    //   Combining,
1446    //
1447    //     ∂²H_φ/∂τ_i ∂τ_j |_β  =  ½ [
1448    //         ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j  −  D²(BᵀPB)[τ_i, τ_j]
1449    //     ].
1450    //
1451    // IMPLEMENTATION SKETCH (for 13b)
1452    //   • Build per-direction reduced quantities for τ_i and τ_j:
1453    //       (x_tau_reduced, η̇, İ, K̇, Ṁ operator pieces, ḣ, b_uvec = w''⊙η̇).
1454    //     The existing `dot_i_and_h_from_reduced` yields İ and ḣ already;
1455    //     the per-direction "A_u" analog is A_τ = K_r İ K_r, matching the
1456    //     FirthDirection form used by hphisecond_direction_apply.
1457    //   • Use apply_hadamard_gram_to_matrix with
1458    //       (A_left, A_right) ∈ { (K_r, K_r), (K_r, A_τ_i), (K_r, A_τ_j),
1459    //                             (A_τ_i, A_τ_j) }
1460    //     to realize P-products, Ṗ_τ-products, and the (Ṁ_i ⊙ Ṁ_j) piece of
1461    //     P̈_{ij} without forming any n×n dense intermediate.
1462    //   • The pure-second piece `X_r K̈_{ij} X_rᵀ` decomposes into three
1463    //     reduced triple products (K_r Ï_{ij} K_r, K_r İ_i K_r İ_j K_r, and
1464    //     its transpose).  All are size-r×r in reduced coordinates.
1465    //   • For design-linear-in-τ smooths, X_{ij}=0 and η̈_{ij}=0, which
1466    //     prunes many sub-terms; callers who have X_{τ_i τ_j} available
1467    //     should pass it so the primitive remains exact on curved designs.
1468    //
1469    // ─────────────────────────────────────────────────────────────────────────
1470    //  Primitive B — D_β((H_φ)_τ|_β)[v]
1471    // ─────────────────────────────────────────────────────────────────────────
1472    //
1473    // WHAT IT COMPUTES
1474    //   Given a single τ-drift design X_τ, the β-fixed Firth partial
1475    //   (H_φ)_τ|_β encoded by FirthTauPartialKernel, and a β-direction
1476    //   vector v (of length p), returns the β-derivative of (H_φ)_τ|_β
1477    //   applied to an rhs block (so output is p×m, matching the pair's
1478    //   fixed_drift_deriv callback signature DriftDerivResult).  In
1479    //   symbols:
1480    //
1481    //     D_β((H_φ)_τ|_β)[v]  =  ½ [ D_β{(∂(XᵀΓX)/∂τ)|_β}[v]
1482    //                                 −  D_β{(∂(BᵀPB)/∂τ)|_β}[v] ].
1483    //
1484    // WHY (REML callsite)
1485    //   In the exact outer Hessian assembly (compute_drift_deriv_traces in
1486    //   unified.rs), the Ḧ_{ij} entry picks up
1487    //     tr(G_ε · D_β B_i[v_j])  +  tr(G_ε · D_β B_j[v_i]).
1488    //   For τ coordinates in the Firth+Logit path, B_τ = (penalty / design
1489    //   pieces) − (H_φ)_τ|_β, so the Firth share of D_β B_τ[v] is
1490    //     − D_β((H_φ)_τ|_β)[v].
1491    //   Hooking this primitive up through a FixedDriftDerivFn (returning
1492    //   DriftDerivResult::Dense of this p×p β-v action) is exactly what
1493    //   lets build_tau_hyper_coords pass a non-None fixed_drift_deriv
1494    //   closure into the unified evaluator, closing the approximation gap
1495    //   that firth_pair_terms_unavailable currently tracks.
1496    //
1497    // DERIVATION (β-derivative of each τ-partial term in direction v)
1498    //
1499    //   β enters only through η=Xβ, so designs X, X_τ, Q, X_r are all
1500    //   β-independent; D_β acts on w(η) and its derivatives, on I_r, K_r,
1501    //   M, h, and on η̇_τ = X_τ β.
1502    //
1503    //   Primary β-derivative building blocks (matches FirthDirection with
1504    //   deta := δη_v = X v):
1505    //     I'_v  := D_β I_r[v] = X_rᵀ diag(w' ⊙ δη_v) X_r      (g_u_reduced)
1506    //     A_v   := D_β K_r[v] = −K_r I'_v K_r                  (a_u_reduced)
1507    //     dh_v  := D_β h[v]    = −diag(X_r K_r I'_v K_r X_rᵀ)
1508    //                          = diag(X_r A_v X_rᵀ)            (dh)
1509    //     (w')_v  := D_β w'[v]  = w''  ⊙ δη_v
1510    //     (w'')_v := D_β w''[v] = w''' ⊙ δη_v
1511    //     (w''')_v:= D_β w'''[v]= w''''⊙ δη_v
1512    //     δη_{τ,v} := D_β(η̇_τ)[v] = X_τ v
1513    //
1514    //   Mixed τ-β pieces:
1515    //     D_β(İ_τ)[v]
1516    //       = X_{r,τ}ᵀ diag(w' ⊙ δη_v) X_r
1517    //        + X_rᵀ diag(w' ⊙ δη_v) X_{r,τ}
1518    //        + X_rᵀ diag(w'' ⊙ η̇_τ ⊙ δη_v
1519    //                     + w' ⊙ δη_{τ,v}) X_r.
1520    //     D_β(K̇_τ)[v]
1521    //       = −( A_v İ_τ K_r  +  K_r D_β(İ_τ)[v] K_r
1522    //             +  K_r İ_τ A_v ).
1523    //     D_β(Ṁ_τ)[v]
1524    //       = X_{r,τ} A_v X_rᵀ
1525    //        + X_r D_β(K̇_τ)[v] X_rᵀ
1526    //        + X_r A_v X_{r,τ}ᵀ.
1527    //     D_β(ḣ_τ)[v]
1528    //       = 2·diag(X_{r,τ} A_v X_rᵀ)
1529    //        + diag(X_r D_β(K̇_τ)[v] X_rᵀ).
1530    //
1531    //   Diagonal-term β-derivative ( (X_τᵀΓX + XᵀΓX_τ + XᵀΓ̇_τ X) branch ):
1532    //     D_β(X_τᵀ Γ X + Xᵀ Γ X_τ)[v]
1533    //       = X_τᵀ Γ_v X + Xᵀ Γ_v X_τ,
1534    //       Γ_v  := D_β Γ[v] = diag((w'')_v ⊙ h + w'' ⊙ dh_v)
1535    //                        = diag(w''' ⊙ δη_v ⊙ h + w'' ⊙ dh_v).
1536    //     D_β(Xᵀ Γ̇_τ X)[v]
1537    //       = Xᵀ Γ̇_{τ,v} X,
1538    //       Γ̇_{τ,v}
1539    //        := D_β Γ̇_τ[v]
1540    //         = diag( (w''')_v ⊙ η̇_τ ⊙ h
1541    //                 + w''' ⊙ δη_{τ,v} ⊙ h
1542    //                 + w''' ⊙ η̇_τ ⊙ dh_v
1543    //                 + (w'')_v ⊙ ḣ_τ
1544    //                 + w'' ⊙ D_β(ḣ_τ)[v] )
1545    //         = diag( w'''' ⊙ η̇_τ ⊙ δη_v ⊙ h
1546    //                 + w''' ⊙ δη_{τ,v} ⊙ h
1547    //                 + w''' ⊙ η̇_τ ⊙ dh_v
1548    //                 + w''' ⊙ δη_v ⊙ ḣ_τ
1549    //                 + w'' ⊙ D_β(ḣ_τ)[v] ).
1550    //
1551    //   Cross-coupling τ-β pieces for B:
1552    //     B_v  := D_β B[v]   = diag(w'' ⊙ δη_v) X               (b_uvec)
1553    //     B_τ  := ∂B/∂τ|_β   = diag(w'' ⊙ η̇_τ) X
1554    //                         + diag(w') X_τ.
1555    //     B_{τ,v}
1556    //         := D_β B_τ[v]  = diag( w''' ⊙ η̇_τ ⊙ δη_v
1557    //                                 + w'' ⊙ δη_{τ,v} ) X
1558    //                         + diag(w'' ⊙ δη_v) X_τ.
1559    //
1560    //   BᵀPB branch — 9 terms, obtained by applying the product rule to
1561    //   ∂(BᵀPB)/∂τ = Ḃ_τᵀ P B + Bᵀ Ṗ_τ B + Bᵀ P Ḃ_τ and then taking
1562    //   D_β(·)[v] of each factor:
1563    //
1564    //     D_β(Ḃ_τᵀ P B)[v]   = B_{τ,v}ᵀ P B + Ḃ_τᵀ P_v B + Ḃ_τᵀ P B_v,
1565    //     D_β(Bᵀ Ṗ_τ B)[v]   = B_vᵀ Ṗ_τ B  + Bᵀ P_{τ,v} B + Bᵀ Ṗ_τ B_v,
1566    //     D_β(Bᵀ P Ḃ_τ)[v]   = B_vᵀ P Ḃ_τ + Bᵀ P_v Ḃ_τ + Bᵀ P B_{τ,v}.
1567    //
1568    //   Here Ḃ_τ = B_τ above, and
1569    //     P_v := D_β P[v]         = 2 (M ⊙ M_v),   M_v = X_r A_v X_rᵀ.
1570    //     Ṗ_τ := ∂P/∂τ|_β         = 2 (M ⊙ M_τ),
1571    //       M_τ = X_{r,τ} K_r X_rᵀ + X_r K̇_τ X_rᵀ + X_r K_r X_{r,τ}ᵀ.
1572    //     P_{τ,v} := D_β(Ṗ_τ)[v]  = 2 (M_v ⊙ M_τ) + 2 (M ⊙ M_{τ,v}),
1573    //       M_{τ,v} = X_{r,τ} A_v X_rᵀ + X_r D_β(K̇_τ)[v] X_rᵀ + X_r A_v X_{r,τ}ᵀ.
1574    //
1575    //   Final primitive:
1576    //
1577    //     D_β((H_φ)_τ|_β)[v]  =  ½ [
1578    //           X_τᵀ Γ_v X  + Xᵀ Γ_v X_τ  + Xᵀ Γ̇_{τ,v} X
1579    //         −  (9-term BᵀPB β-τ expansion above)
1580    //     ].
1581    //
1582    //   Applied to an rhs block `R ∈ ℝ^{p × m}`, each Xᵀ(…) X R collapses
1583    //   to n-length row scalings of (X R) followed by Xᵀ; each Bᵀ P B
1584    //   variant uses apply_hadamard_gram_to_matrix with the correct
1585    //   (A_left, A_right) ∈ { (K_r, K_r), (K_r, A_v), (K_r, K̇_τ),
1586    //     (K_r, D_β(K̇_τ)[v]), (A_v, K̇_τ) } to realize
1587    //   P, P_v, Ṗ_τ, P_{τ,v} actions.  All operators are r×r in reduced
1588    //   coordinates, matching the existing apply cost profile.
1589    //
1590    // IMPLEMENTATION SKETCH (for 13c)
1591    //   • Build `FirthDirection` from deta = X v (reuses existing
1592    //     direction_from_deta, giving I'_v, A_v, dh_v, b_uvec).
1593    //   • Build β-derivatives of the τ-specific fields of
1594    //     FirthTauPartialKernel (dotw1, dotw2, dot_h_partial, dot_k_reduced,
1595    //     and the implicit M_τ reduced-coords operator).  These become a
1596    //     new FirthTauBetaPartialKernel attached to the prepared state.
1597    //   • The apply step is then algebraically identical to
1598    //     hphi_tau_partial_apply but with every W-tensor weight replaced by
1599    //     its β-derivative in v, and every (M, K_r)-Gram replaced by the
1600    //     appropriate β-derivative Gram above.  The structure is regular
1601    //     enough that a single helper, shared with Primitive A, can absorb
1602    //     both pair dispatches.
1603    //
1604    // NOTE ON DESIGN-LINEAR SMOOTHS
1605    //   For the common case of design-linear-in-τ smooths (scale-moving
1606    //   anisotropic bases), X_i and X_τ are constant in τ, so X_{ij}=0 and
1607    //   η̈_{ij}=0.  The primitives collapse to their W-reweighted cores but
1608    //   remain matrix-free; no special fast path is needed because the
1609    //   zeroed terms simply drop out of the Hadamard-Gram assembly.
1610    //
1611    // ═════════════════════════════════════════════════════════════════════════
1612
1613    /// Primitive A — prepare step: assemble the τ_i × τ_j reduced kernel.
1614    ///
1615    /// Consumes the per-direction partial quantities produced by
1616    /// `dot_i_and_h_from_reduced` for τ_i and τ_j (plus an optional second
1617    /// design derivative X_{τ_i τ_j}), and returns a cached kernel carrying
1618    /// the M̈_{ij}, K̈_{ij}, ḧ_{ij}, Γ̈_{ij}, and B̈_{ij}-related reduced
1619    /// coordinates needed by `hphi_tau_tau_partial_apply`.
1620    ///
1621    /// This signature mirrors `hphi_tau_partial_prepare_from_partials` for
1622    /// consistency; the pair version needs both directions simultaneously
1623    /// (to realize the 9-term D² expansion) and therefore owns both
1624    /// `x_tau_{i,j}_reduced` and their η̇_i / η̇_j.
1625    ///
1626    pub(crate) fn hphi_tau_tau_partial_prepare_from_partials(
1627        &self,
1628        x_tau_i_reduced: Array2<f64>,
1629        x_tau_j_reduced: Array2<f64>,
1630        deta_i_partial: &Array1<f64>,
1631        deta_j_partial: &Array1<f64>,
1632        dot_h_i_partial: Array1<f64>,
1633        dot_h_j_partial: Array1<f64>,
1634        dot_i_i_partial: Array2<f64>,
1635        dot_i_j_partial: Array2<f64>,
1636        x_tau_tau_reduced: Option<Array2<f64>>,
1637        deta_ij_partial: Option<Array1<f64>>,
1638    ) -> FirthTauTauPartialKernel {
1639        // K̇_i = -K_r İ_i K_r;  K̇_j = -K_r İ_j K_r.
1640        let dot_k_i_reduced = -self.k_reduced.dot(&dot_i_i_partial).dot(&self.k_reduced);
1641        let dot_k_j_reduced = -self.k_reduced.dot(&dot_i_j_partial).dot(&self.k_reduced);
1642        FirthTauTauPartialKernel {
1643            x_tau_i_reduced,
1644            x_tau_j_reduced,
1645            deta_i_partial: deta_i_partial.clone(),
1646            deta_j_partial: deta_j_partial.clone(),
1647            dot_h_i_partial,
1648            dot_h_j_partial,
1649            dot_k_i_reduced,
1650            dot_k_j_reduced,
1651            dot_i_i_partial,
1652            dot_i_j_partial,
1653            x_tau_tau_reduced,
1654            deta_ij_partial,
1655        }
1656    }
1657
1658    /// Primitive A — apply step: evaluate ∂²H_φ/∂τ_i ∂τ_j |_β against a p×m
1659    /// rhs block, returning a p×m block.
1660    ///
1661    /// Contract mirrors `hphi_tau_partial_apply`: the caller passes the two
1662    /// τ-drift designs and the prepared kernel, and receives the fixed-β
1663    /// second-τ Firth drift as a dense p×m action.  Matrix-free in n.
1664    ///
1665    pub(crate) fn hphi_tau_tau_partial_apply(
1666        &self,
1667        x_tau_i: &Array2<f64>,
1668        x_tau_j: &Array2<f64>,
1669        kernel: &FirthTauTauPartialKernel,
1670        rhs: &Array2<f64>,
1671    ) -> Array2<f64> {
1672        let p = self.x_dense.ncols();
1673        if rhs.nrows() != p {
1674            return Array2::<f64>::zeros((p, rhs.ncols()));
1675        }
1676        if rhs.ncols() == 0 || p == 0 {
1677            return Array2::<f64>::zeros((p, rhs.ncols()));
1678        }
1679        let n = self.x_dense.nrows();
1680        let m = rhs.ncols();
1681
1682        // Short aliases.
1683        let z = &self.x_reduced;
1684        let x_r = &self.x_reduced;
1685        let k = &self.k_reduced;
1686        let x_ri = &kernel.x_tau_i_reduced;
1687        let x_rj = &kernel.x_tau_j_reduced;
1688        let deta_i = &kernel.deta_i_partial;
1689        let deta_j = &kernel.deta_j_partial;
1690        let dh_i = &kernel.dot_h_i_partial;
1691        let dh_j = &kernel.dot_h_j_partial;
1692        let dot_k_i = &kernel.dot_k_i_reduced;
1693        let dot_k_j = &kernel.dot_k_j_reduced;
1694        let dot_i_i = &kernel.dot_i_i_partial;
1695        let dot_i_j = &kernel.dot_i_j_partial;
1696
1697        // Optional second-design pieces: default to zero when the design is
1698        // τ-linear (η̈_{ij} = 0, X_{ij} = 0).
1699        let x_tau_tau_is_some = kernel.x_tau_tau_reduced.is_some();
1700        let x_rij_zero = Array2::<f64>::zeros(x_r.raw_dim());
1701        let x_rij: &Array2<f64> = kernel.x_tau_tau_reduced.as_ref().unwrap_or(&x_rij_zero);
1702        let zeros_n = Array1::<f64>::zeros(n);
1703        let deta_ij = kernel.deta_ij_partial.as_ref().unwrap_or(&zeros_n);
1704
1705        // ─────────────────────────────────────────────────────────────────
1706        //  η̇ vectors in β-rhs space (η_V := X V, η_{i,V} := X_i V, etc.)
1707        // ─────────────────────────────────────────────────────────────────
1708        let (eta_v, eta_i_v, eta_j_v) = if RemlState::should_join_independent_dense_products(&[
1709            (n, m, p),
1710            (n, m, p),
1711            (n, m, p),
1712        ]) {
1713            let (eta_v, (eta_i_v, eta_j_v)) = rayon::join(
1714                || fast_ab(&self.x_dense, rhs),
1715                || rayon::join(|| fast_ab(x_tau_i, rhs), || fast_ab(x_tau_j, rhs)),
1716            );
1717            (eta_v, eta_i_v, eta_j_v)
1718        } else {
1719            (
1720                fast_ab(&self.x_dense, rhs),
1721                fast_ab(x_tau_i, rhs),
1722                fast_ab(x_tau_j, rhs),
1723            )
1724        }; // n×m blocks
1725        // X_{ij} V from the reduced second-derivative design:
1726        //   reduce_explicit_design: X_{r,τ} = diag(√a) X_τ Q,
1727        //   invert:  X_{ij} = diag(1/√a) X_{r,ij} Qᵀ.
1728        let eta_ij_v: Array2<f64> = if x_tau_tau_is_some {
1729            let qt_v = fast_atb(&self.q_basis, rhs); // r×m
1730            let mut out = fast_ab(x_rij, &qt_v); // n×m in sqrt(a)-scaled space
1731            RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1732                &mut out,
1733                self.observation_weight_sqrt.as_ref(),
1734            );
1735            out
1736        } else {
1737            Array2::<f64>::zeros((n, m))
1738        };
1739
1740        // ─────────────────────────────────────────────────────────────────
1741        //  Shared per-direction reduced operators
1742        //    A_τ = K İ K   (reduced analog of T_τ = K I_τ K)
1743        //    K̇_τ = -A_τ  (already cached)
1744        // ─────────────────────────────────────────────────────────────────
1745        let a_i_reduced = -dot_k_i; // K İ_i K = -K̇_i
1746        let a_j_reduced = -dot_k_j;
1747
1748        // ─────────────────────────────────────────────────────────────────
1749        //  Ï_{ij}  — second cross derivative of reduced Fisher
1750        // ─────────────────────────────────────────────────────────────────
1751        //   Ï_{ij} = X_{r,ij}ᵀ W X_r + X_rᵀ W X_{r,ij}
1752        //          + X_{r,i}ᵀ W X_{r,j} + X_{r,j}ᵀ W X_{r,i}
1753        //          + X_{r,i}ᵀ Ẇ_j X_r + X_rᵀ Ẇ_j X_{r,i}
1754        //          + X_{r,j}ᵀ Ẇ_i X_r + X_rᵀ Ẇ_i X_{r,j}
1755        //          + X_rᵀ Ẅ_{ij} X_r.
1756        //   Ẇ_α   = diag(w' ⊙ η̇_α),  Ẅ_{ij} = diag(w'' ⊙ η̇_i ⊙ η̇_j + w' ⊙ η̈_ij).
1757        let dw_i = &self.w1 * deta_i;
1758        let dw_j = &self.w1 * deta_j;
1759        let ddw_ij = &(&self.w2 * &(deta_i * deta_j)) + &(&self.w1 * deta_ij);
1760        let mut i_ddot = Array2::<f64>::zeros(k.raw_dim());
1761        if x_tau_tau_is_some {
1762            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
1763            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
1764        }
1765        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_rj, &self.w);
1766        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_ri, &self.w);
1767        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_r, &dw_j);
1768        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_ri, &dw_j);
1769        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_r, &dw_i);
1770        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rj, &dw_i);
1771        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);
1772
1773        // K̈_{ij} = −K Ï K + K İ_i K İ_j K + K İ_j K İ_i K.
1774        //   Using K İ_α K = −K̇_α = a_α_reduced, the two product terms collapse to
1775        //   K İ_i K İ_j K = a_i_reduced · İ_j · K,
1776        //   K İ_j K İ_i K = a_j_reduced · İ_i · K.
1777        let k_ddot: Array2<f64> = -k.dot(&i_ddot).dot(k)
1778            + a_i_reduced.dot(dot_i_j).dot(k)
1779            + a_j_reduced.dot(dot_i_i).dot(k);
1780
1781        // ─────────────────────────────────────────────────────────────────
1782        //  ḧ_{ij}
1783        // ─────────────────────────────────────────────────────────────────
1784        //   ḧ_ij = 2 diag(X_{r,ij} K X_rᵀ)
1785        //        + diag(X_r K̈_ij X_rᵀ)
1786        //        + 2 diag(X_{r,i} K̇_j X_rᵀ)
1787        //        + 2 diag(X_{r,j} K̇_i X_rᵀ)
1788        //        + 2 diag(X_{r,i} K X_{r,j}ᵀ).
1789        // Using diag(A Bᵀ) = rowwise_dot(A, B):
1790        let dh_ij: Array1<f64> = {
1791            let r = k.ncols();
1792            let can_join = RemlState::should_join_independent_dense_products(&[
1793                (n, r, r),
1794                (n, r, r),
1795                (n, r, r),
1796                (n, r, r),
1797            ]);
1798            let (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k) = if can_join {
1799                let ((xr_kddot, ri_kdot_j), (rj_kdot_i, ri_k)) = rayon::join(
1800                    || rayon::join(|| fast_ab(x_r, &k_ddot), || fast_ab(x_ri, dot_k_j)),
1801                    || rayon::join(|| fast_ab(x_rj, dot_k_i), || fast_ab(x_ri, k)),
1802                );
1803                (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k)
1804            } else {
1805                (
1806                    fast_ab(x_r, &k_ddot),
1807                    fast_ab(x_ri, dot_k_j),
1808                    fast_ab(x_rj, dot_k_i),
1809                    fast_ab(x_ri, k),
1810                )
1811            };
1812
1813            let mut acc = Self::rowwise_dot(&xr_kddot, x_r);
1814            acc = acc + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
1815            acc = acc + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
1816            acc = acc + 2.0 * Self::rowwise_dot(&ri_k, x_rj);
1817            if x_tau_tau_is_some {
1818                let rij_k = fast_ab(x_rij, k);
1819                acc = acc + 2.0 * Self::rowwise_dot(&rij_k, x_r);
1820            }
1821            acc
1822        };
1823
1824        // ─────────────────────────────────────────────────────────────────
1825        //  Γ, Γ̇_i, Γ̇_j, Γ̈_ij  (diagonal row-weight n-vectors)
1826        // ─────────────────────────────────────────────────────────────────
1827        //   γ        = w'' ⊙ h
1828        //   γ̇_i     = w''' ⊙ η̇_i ⊙ h + w'' ⊙ ḣ_i
1829        //   γ̈_ij   = w'''' ⊙ η̇_i ⊙ η̇_j ⊙ h
1830        //            + w''' ⊙ η̈_ij ⊙ h
1831        //            + w''' ⊙ η̇_i ⊙ ḣ_j
1832        //            + w''' ⊙ η̇_j ⊙ ḣ_i
1833        //            + w'' ⊙ ḧ_ij
1834        let gamma = &self.w2 * &self.h_diag;
1835        let gamma_dot_i = &(&(&self.w3 * deta_i) * &self.h_diag) + &(&self.w2 * dh_i);
1836        let gamma_dot_j = &(&(&self.w3 * deta_j) * &self.h_diag) + &(&self.w2 * dh_j);
1837        let gamma_ddot = &(&(&(&self.w4 * deta_i) * deta_j) * &self.h_diag)
1838            + &(&(&(&self.w3 * deta_ij) * &self.h_diag)
1839                + &(&(&self.w3 * deta_i) * dh_j)
1840                + &(&(&self.w3 * deta_j) * dh_i)
1841                + &(&self.w2 * &dh_ij));
1842
1843        // ─────────────────────────────────────────────────────────────────
1844        //  Diagonal-term β-rhs contributions:
1845        //    ∂²(XᵀΓX)/∂τ_i∂τ_j · V
1846        //  = X_{ij}ᵀ (γ ⊙ η_V)       + Xᵀ (γ ⊙ η_{ij,V})         [if X_ij]
1847        //    + X_iᵀ (γ ⊙ η_{j,V})    + X_jᵀ (γ ⊙ η_{i,V})
1848        //    + X_iᵀ (γ̇_j ⊙ η_V)     + Xᵀ (γ̇_j ⊙ η_{i,V})
1849        //    + X_jᵀ (γ̇_i ⊙ η_V)     + Xᵀ (γ̇_i ⊙ η_{j,V})
1850        //    + Xᵀ (γ̈_ij ⊙ η_V).
1851        // ─────────────────────────────────────────────────────────────────
1852        let mut diag_term = Array2::<f64>::zeros((p, m));
1853        let gamma_col = gamma.view().insert_axis(Axis(1));
1854        let gamma_i_col = gamma_dot_i.view().insert_axis(Axis(1));
1855        let gamma_j_col = gamma_dot_j.view().insert_axis(Axis(1));
1856        let gamma_ij_col = gamma_ddot.view().insert_axis(Axis(1));
1857
1858        // X_iᵀ (γ ⊙ η_{j,V}) + X_jᵀ (γ ⊙ η_{i,V})
1859        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_j_v * &gamma_col));
1860        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_i_v * &gamma_col));
1861        // X_iᵀ (γ̇_j ⊙ η_V) + X_jᵀ (γ̇_i ⊙ η_V)
1862        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_v * &gamma_j_col));
1863        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_v * &gamma_i_col));
1864        // Xᵀ (γ̇_j ⊙ η_{i,V}) + Xᵀ (γ̇_i ⊙ η_{j,V})
1865        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_i_v * &gamma_j_col));
1866        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_j_v * &gamma_i_col));
1867        // Xᵀ (γ̈_ij ⊙ η_V)
1868        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_v * &gamma_ij_col));
1869        // X_{ij}ᵀ (γ ⊙ η_V) + Xᵀ (γ ⊙ η_{ij,V})
1870        if x_tau_tau_is_some {
1871            // X_{ij}ᵀ = Q X_{r,ij}ᵀ · diag(1/√a)  (inverse of the reduce shim),
1872            // but caller supplies the reduced second-derivative design.  We
1873            // form X_{ij}ᵀ Y as q_basis · (X_{r,ij}ᵀ · diag(1/√a)·Y) = Q · X_{r,ij}ᵀ (Y unscaled).
1874            // When no observation weights, X_{r,ij} = X_{ij} Q and
1875            //   X_{ij}ᵀ Y = Q X_{r,ij}ᵀ Y.
1876            let y: Array2<f64> = &eta_v * &gamma_col;
1877            let xt_ij_y: Array2<f64> = if self.observation_weight_sqrt.is_some() {
1878                let mut y_scaled = y.clone();
1879                RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1880                    &mut y_scaled,
1881                    self.observation_weight_sqrt.as_ref(),
1882                );
1883                self.q_basis.dot(&x_rij.t().dot(&y_scaled))
1884            } else {
1885                self.q_basis.dot(&x_rij.t().dot(&y))
1886            };
1887            diag_term = diag_term + xt_ij_y;
1888            diag_term = diag_term + self.x_dense_t.dot(&(&eta_ij_v * &gamma_col));
1889        }
1890
1891        // ─────────────────────────────────────────────────────────────────
1892        //  BᵀPB branch — 9-term expansion.
1893        //
1894        //  Represent each B-like operator as an "n-row scaling vector for the
1895        //  X part plus tails along X_τ and X_{ij}".  For rhs V, define the
1896        //  row-scaled η-space blocks R(B) = diag(scale) X V + tails.  Then
1897        //  Bᵀ (P action) R is assembled by row-scaling and left-multiplying
1898        //  the appropriate full designs.
1899        // ─────────────────────────────────────────────────────────────────
1900
1901        // B V row-block (eta-space):  B V = diag(w') X V.
1902        let w1_col = self.w1.view().insert_axis(Axis(1));
1903        let b_v = &eta_v * &w1_col;
1904
1905        // Ḃ_i V = diag(w'' ⊙ η̇_i) X V + diag(w') X_i V.
1906        let w2_deta_i = &self.w2 * deta_i;
1907        let w2_deta_j = &self.w2 * deta_j;
1908        let w2_deta_i_col = w2_deta_i.view().insert_axis(Axis(1));
1909        let w2_deta_j_col = w2_deta_j.view().insert_axis(Axis(1));
1910        let bdot_i_v = &(&eta_v * &w2_deta_i_col) + &(&eta_i_v * &w1_col);
1911        let bdot_j_v = &(&eta_v * &w2_deta_j_col) + &(&eta_j_v * &w1_col);
1912
1913        // B̈_{ij} V =
1914        //   diag(w''' ⊙ η̇_i ⊙ η̇_j + w'' ⊙ η̈_ij) X V
1915        //   + diag(w'' ⊙ η̇_i) X_j V
1916        //   + diag(w'' ⊙ η̇_j) X_i V
1917        //   + diag(w') X_{ij} V.
1918        let w3_didj = &(&self.w3 * deta_i) * deta_j;
1919        let w2_dij = &self.w2 * deta_ij;
1920        let bddot_scale = &w3_didj + &w2_dij;
1921        let bddot_scale_col = bddot_scale.view().insert_axis(Axis(1));
1922        let mut bddot_ij_v = &eta_v * &bddot_scale_col;
1923        bddot_ij_v += &(&eta_j_v * &w2_deta_i_col);
1924        bddot_ij_v += &(&eta_i_v * &w2_deta_j_col);
1925        bddot_ij_v += &(&eta_ij_v * &w1_col);
1926
1927        // P V  (columnwise, using K ⊙ K Hadamard gram on Z = X_r).
1928        let p_bv = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &b_v);
1929        let p_bddot_ij_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bddot_ij_v);
1930
1931        // Ṗ_i, Ṗ_j applied to B V, Ḃ_j V, Ḃ_i V — use the existing
1932        // apply_mtau_to_matrix helper, which computes 2(M ⊙ Ṁ_τ) · mat.
1933        //
1934        // Construct a lightweight "FirthTauPartialKernel"-shaped tuple only for
1935        // apply_mtau_to_matrix; we mirror its input contract inline to avoid
1936        // owning a FirthTauPartialKernel copy here.
1937        let pdot_i_bv = self.apply_mtau_from_reduced(x_ri, dot_k_i, &b_v);
1938        let pdot_j_bv = self.apply_mtau_from_reduced(x_rj, dot_k_j, &b_v);
1939        let pdot_i_bdot_j_v = self.apply_mtau_from_reduced(x_ri, dot_k_i, &bdot_j_v);
1940        let pdot_j_bdot_i_v = self.apply_mtau_from_reduced(x_rj, dot_k_j, &bdot_i_v);
1941
1942        // P Ḃ_j V and P Ḃ_i V.
1943        let p_bdot_j_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_j_v);
1944        let p_bdot_i_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_i_v);
1945
1946        // P̈_{ij} V = 4 (Ṁ_i ⊙ Ṁ_j) V  + 2 (M ⊙ M̈_{ij}) V.
1947        let p_ddot_b_v = self.apply_p_ddot_ij(
1948            x_r,
1949            x_ri,
1950            x_rj,
1951            x_rij,
1952            k,
1953            dot_k_i,
1954            dot_k_j,
1955            &k_ddot,
1956            x_tau_tau_is_some,
1957            &b_v,
1958        );
1959
1960        // Assemble 9 terms of D²(BᵀPB)[τ_i, τ_j] · V.
1961        //   term1 = B̈_ijᵀ P B V + Bᵀ P B̈_ij V
1962        //   term2 = Ḃ_iᵀ P Ḃ_j V + Ḃ_jᵀ P Ḃ_i V
1963        //   term3 = Ḃ_iᵀ Ṗ_j B V + Bᵀ Ṗ_j Ḃ_i V
1964        //   term4 = Ḃ_jᵀ Ṗ_i B V + Bᵀ Ṗ_i Ḃ_j V
1965        //   term5 = Bᵀ P̈_ij B V
1966        //
1967        // "Bᵀ Q V" with B = diag(w') X equals left_scaled_xt(w1, Q V).
1968        // "Ḃ_iᵀ Q V" = diag(w'' ⊙ η̇_i) X acting on the left, plus
1969        //              diag(w') X_i on the left.  In transpose:
1970        //   Ḃ_iᵀ Q V = Xᵀ (diag(w'' ⊙ η̇_i) Q V) + X_iᵀ (diag(w') Q V).
1971        // "B̈_ijᵀ Q V" mirrors B̈_ij above in transpose.
1972
1973        let apply_bdot_tau_t =
1974            |scale_deta: &Array1<f64>, x_tau_mat: &Array2<f64>, q_v: &Array2<f64>| {
1975                let scale_col = scale_deta.view().insert_axis(Axis(1));
1976                self.x_dense_t.dot(&(q_v * &scale_col)) + x_tau_mat.t().dot(&(q_v * &w1_col))
1977            };
1978
1979        let apply_bddot_ij_t = |q_v: &Array2<f64>| -> Array2<f64> {
1980            let scale_col_full = bddot_scale.view().insert_axis(Axis(1));
1981            let mut out = self.x_dense_t.dot(&(q_v * &scale_col_full));
1982            out = out + x_tau_j.t().dot(&(q_v * &w2_deta_i_col));
1983            out = out + x_tau_i.t().dot(&(q_v * &w2_deta_j_col));
1984            if x_tau_tau_is_some {
1985                // X_{ij}ᵀ (w1 ⊙ Q V)
1986                let y = q_v * &w1_col;
1987                let contrib: Array2<f64> = if self.observation_weight_sqrt.is_some() {
1988                    let mut y_scaled = y.clone();
1989                    RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1990                        &mut y_scaled,
1991                        self.observation_weight_sqrt.as_ref(),
1992                    );
1993                    self.q_basis.dot(&x_rij.t().dot(&y_scaled))
1994                } else {
1995                    self.q_basis.dot(&x_rij.t().dot(&y))
1996                };
1997                out = out + contrib;
1998            }
1999            out
2000        };
2001
2002        // term1
2003        let t1a = apply_bddot_ij_t(&p_bv);
2004        let t1b = self.left_scaled_xt(&self.w1, &p_bddot_ij_v);
2005        // term2
2006        let t2a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &p_bdot_j_v);
2007        let t2b = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &p_bdot_i_v);
2008        // term3: Ḃ_iᵀ Ṗ_j B V + Bᵀ Ṗ_j Ḃ_i V
2009        let t3a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &pdot_j_bv);
2010        let t3b = self.left_scaled_xt(&self.w1, &pdot_j_bdot_i_v);
2011        // term4: Ḃ_jᵀ Ṗ_i B V + Bᵀ Ṗ_i Ḃ_j V
2012        let t4a = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &pdot_i_bv);
2013        let t4b = self.left_scaled_xt(&self.w1, &pdot_i_bdot_j_v);
2014        // term5
2015        let t5 = self.left_scaled_xt(&self.w1, &p_ddot_b_v);
2016
2017        let d2_bpb = t1a + t1b + t2a + t2b + t3a + t3b + t4a + t4b + t5;
2018
2019        0.5 * (diag_term - d2_bpb)
2020    }
2021
2022    /// Pair-level exact Firth kernel at fixed β for a (τ_i, τ_j) outer
2023    /// coordinate pair.
2024    ///
2025    /// Returns the two SCALAR- and P-VECTOR-valued second-derivative
2026    /// objects that the unified REML evaluator threads into
2027    /// `HyperCoordPair::{a,g}` as additive Firth contributions, plus an
2028    /// optional prepared Primitive-A `FirthTauTauPartialKernel` that the
2029    /// pair-callback can reuse for the `b_operator` action.
2030    ///
2031    /// ═════════════════════════════════════════════════════════════════════
2032    ///  DERIVATIONS (fixed β, reduced-basis identifiable coords).
2033    ///
2034    ///  Φ = 0.5 log|I_r| − 0.5 log|S_r|,   K_r = I_r⁻¹,   G_r = diag(S_r⁻¹)
2035    ///  Φ_{τ_i}|β = 0.5 tr(K_r İ_{r,i}) − 0.5 tr(G_r Ṡ_{r,i}).
2036    ///
2037    /// ┌── pair.a scalar Φ_{τ_i τ_j}|β ────────────────────────────────────┐
2038    ///  ∂/∂τ_j [0.5 tr(K_r İ_{r,i})]
2039    ///    = 0.5 tr(K̇_{r,j} İ_{r,i}) + 0.5 tr(K_r Ï_{r,ij})
2040    ///    = −0.5 tr(K_r İ_{r,j} K_r İ_{r,i}) + 0.5 tr(K_r Ï_{r,ij})
2041    ///
2042    ///  ∂/∂τ_j [−0.5 tr(G_r Ṡ_{r,i})]
2043    ///    = −0.5 tr(Ġ_{r,j} Ṡ_{r,i}) − 0.5 tr(G_r S̈_{r,ij})
2044    ///  (G_r diagonal in canonical basis →
2045    ///   Ġ_{r,j}_kk = −G_r_kk² · diag(Ṡ_{r,j})_kk.)
2046    ///
2047    ///  Ï_{r,ij} is the same 9-term Fisher cross used by Primitive A
2048    ///  (see `hphi_tau_tau_partial_apply`:i_ddot block).
2049    ///
2050    ///  S̈_{r,ij} = X_{r,ij}^T X_r + X_r^T X_{r,ij}
2051    ///            + X_{r,i}^T X_{r,j} + X_{r,j}^T X_{r,i}.
2052    /// └─────────────────────────────────────────────────────────────────┘
2053    ///
2054    /// ┌── pair.g p-vector (gΦ)_{τ_i τ_j}|β ───────────────────────────────┐
2055    ///  (gΦ)_{τ_i} = 0.5 X_{τ_i}^T (w1 ⊙ h)
2056    ///              + 0.5 X^T [ (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i ]
2057    ///
2058    ///  Differentiating wrt τ_j at fixed β, using η̇_α = X_α β, η̈_{ij} =
2059    ///  X_{ij} β (when x_tau_tau is provided, else 0), and ḣ_α, ḧ_{ij}
2060    ///  from Primitive A:
2061    ///
2062    ///  term_A = 0.5 ∂/∂τ_j [X_{τ_i}^T (w1 ⊙ h)]
2063    ///        = 0.5 X_{τ_i τ_j}^T (w1 ⊙ h)          [if X_{ij} present]
2064    ///        + 0.5 X_{τ_i}^T [ (w2 ⊙ η̇_j) ⊙ h + w1 ⊙ ḣ_j ]
2065    ///
2066    ///  term_B = 0.5 ∂/∂τ_j [X^T · v_{τ_i}] with
2067    ///            v_{τ_i} = (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i
2068    ///        = 0.5 X_{τ_j}^T v_{τ_i}
2069    ///        + 0.5 X^T · v̇_{τ_i,τ_j}
2070    ///
2071    ///  where the inner derivative
2072    ///  v̇_{τ_i,τ_j} = (w3 ⊙ η̇_j ⊙ η̇_i) ⊙ h   (from ∂w2 = w3 ⊙ η̇_j)
2073    ///              + (w2 ⊙ η̈_ij) ⊙ h          (from ∂η̇_i = η̈_{ij})
2074    ///              + (w2 ⊙ η̇_i) ⊙ ḣ_j        (from ∂h = ḣ_j)
2075    ///              + (w2 ⊙ η̇_j) ⊙ ḣ_i        (from ∂w1 = w2 ⊙ η̇_j, ⊙ ḣ_i)
2076    ///              +  w1 ⊙ ḧ_{ij}             (from ∂ḣ_i = ḧ_{ij}).
2077    /// └─────────────────────────────────────────────────────────────────┘
2078    ///
2079    /// ALL Ï_{r,ij}, η̈_{ij}, ḣ_i, ḧ_{ij} computations are identical to
2080    /// those already computed inside Primitive A's `hphi_tau_tau_partial_apply`.
2081    /// We replicate only the pieces needed to yield the scalar and p-vector
2082    /// outputs to avoid computing the full p×m action when unnecessary.
2083    ///
2084    pub(crate) fn exact_tau_tau_kernel(
2085        &self,
2086        x_tau_i: &Array2<f64>,
2087        x_tau_j: &Array2<f64>,
2088        x_tau_tau: Option<&Array2<f64>>,
2089        beta: &Array1<f64>,
2090        include_hphi_tau_tau_kernel: bool,
2091    ) -> FirthTauTauExactKernel {
2092        let deta_i = x_tau_i.dot(beta);
2093        let deta_j = x_tau_j.dot(beta);
2094        let deta_ij = x_tau_tau.as_ref().map(|xij| xij.dot(beta));
2095
2096        let x_tau_i_reduced = self.reduce_explicit_design(x_tau_i);
2097        let x_tau_j_reduced = self.reduce_explicit_design(x_tau_j);
2098        let x_tau_tau_reduced = x_tau_tau.map(|xij| self.reduce_explicit_design(xij));
2099
2100        let (dot_i_i, dot_h_i) = self.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
2101        let (dot_i_j, dot_h_j) = self.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
2102
2103        // Ï_{r,ij} = X_{r,ij}^T W X_r + X_r^T W X_{r,ij}
2104        //            + X_{r,i}^T W X_{r,j} + X_{r,j}^T W X_{r,i}
2105        //            + X_{r,i}^T Ẇ_j X_r + X_r^T Ẇ_j X_{r,i}
2106        //            + X_{r,j}^T Ẇ_i X_r + X_r^T Ẇ_i X_{r,j}
2107        //            + X_r^T Ẅ_{ij} X_r
2108        // Ẇ_α = diag(w' ⊙ η̇_α);  Ẅ_{ij} = diag(w'' ⊙ η̇_i ⊙ η̇_j + w' ⊙ η̈_{ij}).
2109        let zeros_n = Array1::<f64>::zeros(self.x_dense.nrows());
2110        let deta_ij_ref: &Array1<f64> = deta_ij.as_ref().unwrap_or(&zeros_n);
2111        let dw_i = &self.w1 * &deta_i;
2112        let dw_j = &self.w1 * &deta_j;
2113        let ddw_ij = &(&self.w2 * &(&deta_i * &deta_j)) + &(&self.w1 * deta_ij_ref);
2114
2115        let x_r = &self.x_reduced;
2116        let mut i_ddot = Array2::<f64>::zeros(self.k_reduced.raw_dim());
2117        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2118            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
2119            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
2120        }
2121        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, &x_tau_j_reduced, &self.w);
2122        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, &x_tau_i_reduced, &self.w);
2123        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, x_r, &dw_j);
2124        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_i_reduced, &dw_j);
2125        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, x_r, &dw_i);
2126        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_j_reduced, &dw_i);
2127        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);
2128
2129        // pair.a likelihood contribution:
2130        //   0.5 tr(K_r Ï_{r,ij}) − 0.5 tr(K_r İ_{r,j} K_r İ_{r,i}).
2131        // K_r İ_{r,α} — reuse inline dot().
2132        let k = &self.k_reduced;
2133        let k_dot_i_i = k.dot(&dot_i_i);
2134        let k_dot_i_j = k.dot(&dot_i_j);
2135        let a_lik = 0.5 * RemlState::trace_product(k, &i_ddot)
2136            - 0.5 * RemlState::trace_product(&k_dot_i_j, &k_dot_i_i);
2137
2138        // pair.a penalty-basis contribution:
2139        //   Ṡ_{r,α} = X_{r,α}^T X_r + X_r^T X_{r,α}
2140        //   S̈_{r,ij} = X_{r,ij}^T X_r + X_r^T X_{r,ij}
2141        //            + X_{r,i}^T X_{r,j} + X_{r,j}^T X_{r,i}
2142        //   tr(G_r Ṡ_{r,i}) = Σ_k G_r_kk · diag(Ṡ_{r,i})_kk
2143        //   tr(Ġ_{r,j} Ṡ_{r,i}) = −Σ_k G_r_kk² · diag(Ṡ_{r,j})_kk · diag(Ṡ_{r,i})_kk
2144        let dot_s_i = fast_atb(&x_tau_i_reduced, x_r) + fast_atb(x_r, &x_tau_i_reduced);
2145        let dot_s_j = fast_atb(&x_tau_j_reduced, x_r) + fast_atb(x_r, &x_tau_j_reduced);
2146        let mut s_ddot = Array2::<f64>::zeros(k.raw_dim());
2147        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2148            s_ddot = s_ddot + fast_atb(x_rij, x_r) + fast_atb(x_r, x_rij);
2149        }
2150        s_ddot = s_ddot
2151            + fast_atb(&x_tau_i_reduced, &x_tau_j_reduced)
2152            + fast_atb(&x_tau_j_reduced, &x_tau_i_reduced);
2153        // With G_r = diag(g) in the canonical reduced basis (where
2154        // S_r is diagonal), S_r + τ·Ṡ is generally non-diagonal under
2155        // perturbation, so Ġ_j = −G Ṡ_j G picks up OFF-DIAGONAL terms:
2156        //     (Ġ_j)_{kl} = −G_k · (Ṡ_j)_{kl} · G_l.
2157        // Hence tr(Ġ_j Ṡ_i) = −Σ_{k,l} G_k G_l (Ṡ_j)_{kl} (Ṡ_i)_{lk}.
2158        // Using symmetry of Ṡ_i (and Ṡ_j):
2159        //     −0.5 tr(Ġ_j Ṡ_i) = +0.5 Σ_{k,l} G_k G_l (Ṡ_j)_{kl} (Ṡ_i)_{kl}.
2160        // The S̈_{ij} trace against diagonal G_r picks only the diagonal.
2161        let g_inv = &self.x_metric_reduced_inv_diag;
2162        let rdim = k.nrows();
2163        let mut a_pen = 0.0_f64;
2164        for kk in 0..rdim {
2165            for ll in 0..rdim {
2166                a_pen += 0.5 * g_inv[kk] * g_inv[ll] * dot_s_j[[kk, ll]] * dot_s_i[[kk, ll]];
2167            }
2168            a_pen -= 0.5 * g_inv[kk] * s_ddot[[kk, kk]];
2169        }
2170        let phi_tau_tau_partial = a_lik + a_pen;
2171
2172        // ─── pair.g p-vector: (gΦ)_{τ_i τ_j}|β ──────────────────────────
2173        //
2174        // Assemble ḧ_{ij} identically to Primitive A's body.  We need:
2175        //   K̇_{r,α} = −K_r İ_{r,α} K_r,
2176        //   K̈_{r,ij} = −K_r Ï_{r,ij} K_r + K_r İ_{r,i} K_r İ_{r,j} K_r
2177        //                                 + K_r İ_{r,j} K_r İ_{r,i} K_r.
2178        let dot_k_i = -k.dot(&dot_i_i).dot(k);
2179        let dot_k_j = -k.dot(&dot_i_j).dot(k);
2180        let a_i_red = -&dot_k_i; // K İ_i K
2181        let a_j_red = -&dot_k_j; // K İ_j K
2182        let k_ddot: Array2<f64> =
2183            -k.dot(&i_ddot).dot(k) + a_i_red.dot(&dot_i_j).dot(k) + a_j_red.dot(&dot_i_i).dot(k);
2184
2185        // ḧ_{ij} = 2 diag(X_{r,ij} K X_r^T)
2186        //        + diag(X_r K̈_{ij} X_r^T)
2187        //        + 2 diag(X_{r,i} K̇_j X_r^T)
2188        //        + 2 diag(X_{r,j} K̇_i X_r^T)
2189        //        + 2 diag(X_{r,i} K X_{r,j}^T).
2190        let n = self.x_dense.nrows();
2191        let mut dh_ij = Array1::<f64>::zeros(n);
2192        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2193            let rij_k = x_rij.dot(k);
2194            dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rij_k, x_r);
2195        }
2196        let xr_kddot = x_r.dot(&k_ddot);
2197        dh_ij = dh_ij + Self::rowwise_dot(&xr_kddot, x_r);
2198        let ri_kdot_j = x_tau_i_reduced.dot(&dot_k_j);
2199        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
2200        let rj_kdot_i = x_tau_j_reduced.dot(&dot_k_i);
2201        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
2202        let ri_k = x_tau_i_reduced.dot(k);
2203        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_k, &x_tau_j_reduced);
2204
2205        // term_A = 0.5 X_{τ_i τ_j}^T (w1 ⊙ h)
2206        //        + 0.5 X_{τ_i}^T [ (w2 ⊙ η̇_j) ⊙ h + w1 ⊙ ḣ_j ]
2207        let w1_h = &self.w1 * &self.h_diag;
2208        let mut gphi_tau_tau = Array1::<f64>::zeros(self.x_dense.ncols());
2209        if let Some(x_ij) = x_tau_tau.as_ref() {
2210            gphi_tau_tau = gphi_tau_tau + 0.5 * x_ij.t().dot(&w1_h);
2211        }
2212        let inner_j = &(&(&self.w2 * &deta_j) * &self.h_diag) + &(&self.w1 * &dot_h_j);
2213        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_i.t().dot(&inner_j);
2214
2215        // term_B pieces:  v_{τ_i} = (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i
2216        let v_tau_i = &(&(&self.w2 * &deta_i) * &self.h_diag) + &(&self.w1 * &dot_h_i);
2217        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_j.t().dot(&v_tau_i);
2218
2219        // v̇_{τ_i,τ_j} =
2220        //    (w3 ⊙ η̇_j ⊙ η̇_i) ⊙ h
2221        //  + (w2 ⊙ η̈_{ij}) ⊙ h
2222        //  + (w2 ⊙ η̇_i) ⊙ ḣ_j
2223        //  + (w2 ⊙ η̇_j) ⊙ ḣ_i
2224        //  +  w1 ⊙ ḧ_{ij}.
2225        let mut v_dot_ij = &(&(&self.w3 * &deta_j) * &deta_i) * &self.h_diag;
2226        v_dot_ij += &(&(&self.w2 * deta_ij_ref) * &self.h_diag);
2227        v_dot_ij += &(&(&self.w2 * &deta_i) * &dot_h_j);
2228        v_dot_ij += &(&(&self.w2 * &deta_j) * &dot_h_i);
2229        v_dot_ij += &(&self.w1 * &dh_ij);
2230        gphi_tau_tau = gphi_tau_tau + 0.5 * self.x_dense.t().dot(&v_dot_ij);
2231
2232        let tau_tau_kernel = if include_hphi_tau_tau_kernel {
2233            Some(self.hphi_tau_tau_partial_prepare_from_partials(
2234                x_tau_i_reduced,
2235                x_tau_j_reduced,
2236                &deta_i,
2237                &deta_j,
2238                dot_h_i,
2239                dot_h_j,
2240                dot_i_i,
2241                dot_i_j,
2242                x_tau_tau_reduced,
2243                deta_ij,
2244            ))
2245        } else {
2246            None
2247        };
2248
2249        FirthTauTauExactKernel {
2250            phi_tau_tau_partial,
2251            gphi_tau_tau,
2252            tau_tau_kernel,
2253        }
2254    }
2255
2256    /// Apply `Ṗ_τ V = 2 (M ⊙ Ṁ_τ) V` given the reduced τ-drift design
2257    /// `x_tau_reduced` and the reduced Fisher-inverse drift `dot_k_reduced`.
2258    ///
2259    /// This mirrors the body of `apply_mtau_to_matrix` but accepts the
2260    /// x_tau/dot_k pieces directly, letting Primitive A reuse the same
2261    /// matrix-free Ṗ_τ applies without owning a `FirthTauPartialKernel`.
2262    pub(crate) fn apply_mtau_from_reduced(
2263        &self,
2264        x_tau_reduced: &Array2<f64>,
2265        dot_k_reduced: &Array2<f64>,
2266        mat: &Array2<f64>,
2267    ) -> Array2<f64> {
2268        if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
2269            return Array2::<f64>::zeros(mat.raw_dim());
2270        }
2271        let mut out = Array2::<f64>::zeros(mat.raw_dim());
2272        for col in 0..mat.ncols() {
2273            let v = mat.column(col).to_owned();
2274            let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
2275            let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
2276            let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, x_tau_reduced);
2277
2278            let szt = RemlState::reduced_crossweighted_gram(&self.x_reduced, x_tau_reduced, &v);
2279            let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
2280            let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
2281
2282            let t3 =
2283                RemlState::apply_hadamard_gram(&self.x_reduced, &self.k_reduced, dot_k_reduced, &v);
2284
2285            let y = 2.0 * (t1 + t2 + t3);
2286            out.column_mut(col).assign(&y);
2287        }
2288        out
2289    }
2290
2291    /// Apply `P̈_{ij} V = 4 (Ṁ_i ⊙ Ṁ_j) V + 2 (M ⊙ M̈_{ij}) V` columnwise.
2292    ///
2293    /// `M̈_{ij}` expands into 9 pieces `Y_α C Y_βᵀ`; `Ṁ_i ⊙ Ṁ_j` into 9 cross
2294    /// pieces `(Y_{1,α} B_{1,α} W_{1,α}ᵀ) ⊙ (Y_{2,β} B_{2,β} W_{2,β}ᵀ)`.  Both
2295    /// are evaluated via the matrix-free identities:
2296    ///
2297    ///   [(ZAZᵀ) ⊙ (YBWᵀ) v]_i   = rowwise_bilinear(Y, B · (Wᵀdiag(v)Z) · A, Z)_i,
2298    ///   [(YBWᵀ) ⊙ (Y'B'W'ᵀ) v]_i= rowwise_bilinear(Y, B · (Wᵀdiag(v)W') · B'ᵀ, Y')_i,
2299    ///
2300    /// with S := row-wise reducedweighted Gram.
2301    pub(crate) fn apply_p_ddot_ij(
2302        &self,
2303        x_r: &Array2<f64>,
2304        x_ri: &Array2<f64>,
2305        x_rj: &Array2<f64>,
2306        x_rij: &Array2<f64>,
2307        k: &Array2<f64>,
2308        dot_k_i: &Array2<f64>,
2309        dot_k_j: &Array2<f64>,
2310        k_ddot: &Array2<f64>,
2311        x_tau_tau_is_some: bool,
2312        mat: &Array2<f64>,
2313    ) -> Array2<f64> {
2314        let n = self.x_dense.nrows();
2315        let m = mat.ncols();
2316        if mat.nrows() != n || m == 0 {
2317            return Array2::<f64>::zeros(mat.raw_dim());
2318        }
2319        let mut out = Array2::<f64>::zeros((n, m));
2320        for col in 0..m {
2321            let v = mat.column(col).to_owned();
2322            // Shared reducedweighted Grams for this column.  Only the Grams
2323            // actually appearing in the 18 pieces below are computed.
2324            let s_zz = RemlState::reducedweighted_gram(x_r, &v); // Z'diag(v)Z
2325            let s_zj = RemlState::reduced_crossweighted_gram(x_r, x_rj, &v); // Z'diag(v)Y_j
2326            let s_iz = RemlState::reduced_crossweighted_gram(x_ri, x_r, &v); // Y_i'diag(v)Z
2327            let s_jz = RemlState::reduced_crossweighted_gram(x_rj, x_r, &v); // Y_j'diag(v)Z
2328            let s_ij = RemlState::reduced_crossweighted_gram(x_ri, x_rj, &v); // Y_i'diag(v)Y_j
2329
2330            // ── 4 (Ṁ_i ⊙ Ṁ_j) v ──
2331            // Ṁ_i has three pieces:
2332            //   P_i,a = Y_i K Zᵀ          — Y=Y_i, B=K, W=Z
2333            //   P_i,b = Z K̇_i Zᵀ         — Y=Z,   B=K̇_i, W=Z
2334            //   P_i,c = Z K Y_iᵀ          — Y=Z,   B=K,  W=Y_i
2335            // And symmetrically for Ṁ_j with (i→j).
2336            //
2337            // For each cross pair (α, β), compute
2338            //   core = B_α · (W_αᵀ diag(v) W_β) · B_βᵀ,
2339            //   y_piece = rowwise_bilinear(Y_α, core, Y_β),
2340            // then sum all 9 and scale by 4.
2341            let mut mdot_mdot = Array1::<f64>::zeros(n);
2342            // (a_i, a_j): Y_i, K, Z  ×  Y_j, K, Z  → W_α=Z, W_β=Z, S = s_zz
2343            {
2344                let core = k.dot(&s_zz).dot(&k.t());
2345                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_rj);
2346            }
2347            // (a_i, b_j): Y_i, K, Z  ×  Z, K̇_j, Z  → S = s_zz; core = K · s_zz · K̇_jᵀ
2348            {
2349                let core = k.dot(&s_zz).dot(&dot_k_j.t());
2350                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2351            }
2352            // (a_i, c_j): Y_i, K, Z  ×  Z, K, Y_j  → S = s_zj; core = K · s_zj · Kᵀ
2353            {
2354                let core = k.dot(&s_zj).dot(&k.t());
2355                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2356            }
2357            // (b_i, a_j): Z, K̇_i, Z  ×  Y_j, K, Z  → S = s_zz; core = K̇_i · s_zz · Kᵀ
2358            {
2359                let core = dot_k_i.dot(&s_zz).dot(&k.t());
2360                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2361            }
2362            // (b_i, b_j): Z, K̇_i, Z  ×  Z, K̇_j, Z  → S = s_zz; core = K̇_i · s_zz · K̇_jᵀ
2363            {
2364                let core = dot_k_i.dot(&s_zz).dot(&dot_k_j.t());
2365                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2366            }
2367            // (b_i, c_j): Z, K̇_i, Z  ×  Z, K, Y_j  → S = s_zj; core = K̇_i · s_zj · Kᵀ
2368            {
2369                let core = dot_k_i.dot(&s_zj).dot(&k.t());
2370                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2371            }
2372            // (c_i, a_j): Z, K, Y_i  ×  Y_j, K, Z  → S = Y_iᵀ diag(v) Z = s_iz;
2373            //   core = K · s_iz · Kᵀ
2374            {
2375                let core = k.dot(&s_iz).dot(&k.t());
2376                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2377            }
2378            // (c_i, b_j): Z, K, Y_i  ×  Z, K̇_j, Z  → S = Y_iᵀ diag(v) Z = s_iz;
2379            //   core = K · s_iz · K̇_jᵀ
2380            {
2381                let core = k.dot(&s_iz).dot(&dot_k_j.t());
2382                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2383            }
2384            // (c_i, c_j): Z, K, Y_i  ×  Z, K, Y_j  → S = Y_iᵀ diag(v) Y_j = s_ij;
2385            //   core = K · s_ij · Kᵀ
2386            {
2387                let core = k.dot(&s_ij).dot(&k.t());
2388                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2389            }
2390
2391            // ── 2 (M ⊙ M̈_{ij}) v ──
2392            // Each piece has the form Y_α C W_βᵀ; M = Z K Zᵀ with A=K.
2393            // Identity:  [(ZAZᵀ) ⊙ (Y_α C W_βᵀ) v]_i
2394            //          = rowwise_bilinear(Y_α, C · (W_βᵀ diag(v) Z) · A, Z).
2395            let mut m_mddot = Array1::<f64>::zeros(n);
2396            // (a) Y_α = X_{r,ij}, C = K, W_β = X_r  → W_βᵀ diag(v) Z = s_zz
2397            if x_tau_tau_is_some {
2398                let core = k.dot(&s_zz).dot(k);
2399                m_mddot = m_mddot + Self::rowwise_bilinear(x_rij, &core, x_r);
2400            }
2401            // (b) Y_α = X_r, C = K, W_β = X_{r,ij} → W_βᵀ diag(v) Z = X_{r,ij}ᵀ diag(v) Z
2402            if x_tau_tau_is_some {
2403                let s_ijz = RemlState::reduced_crossweighted_gram(x_rij, x_r, &v);
2404                let core = k.dot(&s_ijz).dot(k);
2405                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2406            }
2407            // (c) Y_α = X_{r,i}, C = K̇_j, W_β = X_r → S = s_zz
2408            {
2409                let core = dot_k_j.dot(&s_zz).dot(k);
2410                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2411            }
2412            // (d) Y_α = X_r, C = K̇_j, W_β = X_{r,i} → W_βᵀ diag(v) Z = s_iz
2413            {
2414                let core = dot_k_j.dot(&s_iz).dot(k);
2415                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2416            }
2417            // (e) Y_α = X_{r,j}, C = K̇_i, W_β = X_r → S = s_zz
2418            {
2419                let core = dot_k_i.dot(&s_zz).dot(k);
2420                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2421            }
2422            // (f) Y_α = X_r, C = K̇_i, W_β = X_{r,j} → W_βᵀ diag(v) Z = s_jz
2423            {
2424                let core = dot_k_i.dot(&s_jz).dot(k);
2425                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2426            }
2427            // (g) Y_α = X_{r,i}, C = K, W_β = X_{r,j} → W_βᵀ diag(v) Z = s_jz
2428            {
2429                let core = k.dot(&s_jz).dot(k);
2430                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2431            }
2432            // (h) Y_α = X_{r,j}, C = K, W_β = X_{r,i} → W_βᵀ diag(v) Z = s_iz
2433            {
2434                let core = k.dot(&s_iz).dot(k);
2435                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2436            }
2437            // (i) Y_α = X_r, C = K̈_ij, W_β = X_r → S = s_zz
2438            {
2439                let core = k_ddot.dot(&s_zz).dot(k);
2440                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2441            }
2442
2443            // P̈_{ij} = ∂²(M⊙M)/∂τ_i∂τ_j = 2(Ṁ_i ⊙ Ṁ_j) + 2(M ⊙ M̈_{ij}),
2444            // with Ṁ_τ = ∂M/∂τ (NOT the pair-squared derivative).  Factor is 2,
2445            // not 4 — the earlier "4·Ṁ_i⊙Ṁ_j" was a sign-of-the-derivative
2446            // confusion between Ṁ and ∂(M⊙M)/∂τ = 2(M⊙Ṁ).
2447            let col_out = 2.0 * mdot_mdot + 2.0 * m_mddot;
2448            out.column_mut(col).assign(&col_out);
2449        }
2450        out
2451    }
2452
2453    /// Primitive B — prepare step: assemble the reduced kernel for
2454    /// D_β((H_φ)_τ|_β)[v].
2455    ///
2456    /// Consumes the existing `FirthTauPartialKernel`, the τ-drift partials
2457    /// (`deta_partial = η̇_τ = X_τ β` and `dot_i_partial = İ_τ`), the
2458    /// β-direction `FirthDirection` built from `deta = X v`, and
2459    /// `x_tau_v = X_τ v`, and returns a cached kernel carrying the mixed
2460    /// β-τ reduced quantities A_v, dh_v, D_β(İ_τ)[v], D_β(K̇_τ)[v],
2461    /// D_β(ḣ_τ)[v], and the w-chain derivatives needed by
2462    /// `d_beta_hphi_tau_partial_apply`.
2463    pub(crate) fn d_beta_hphi_tau_partial_prepare_from_partials(
2464        &self,
2465        tau_kernel: &FirthTauPartialKernel,
2466        deta_partial: &Array1<f64>,
2467        dot_i_partial: &Array2<f64>,
2468        beta_direction: &FirthDirection,
2469        x_tau_v: &Array1<f64>,
2470    ) -> FirthTauBetaPartialKernel {
2471        // D_β(İ_τ)[v] — three-piece symmetric form from the product rule on
2472        //   İ_τ = X_{r,τ}ᵀ W X_r + X_rᵀ W X_{r,τ} + X_rᵀ diag(w' ⊙ η̇_τ) X_r,
2473        // where W = diag(w(η)) is the Fisher weight (not its derivative).
2474        // The β-differential hits w (through η=Xβ) and η̇_τ (through X_τ β):
2475        //   D_β(X_{r,τ}ᵀ W X_r)[v] = X_{r,τ}ᵀ diag(w' ⊙ δη_v) X_r,
2476        //   D_β(X_rᵀ diag(w' ⊙ η̇_τ) X_r)[v]
2477        //     = X_rᵀ diag(w'' ⊙ η̇_τ ⊙ δη_v + w' ⊙ δη_{τ,v}) X_r,
2478        // where δη_v = beta_direction.deta, δη_{τ,v} = x_tau_v.
2479        // s_v := w' ⊙ δη_v (same weight the FirthDirection uses to build
2480        // g_u_reduced); b_vvec := w'' ⊙ δη_v = beta_direction.b_uvec is the
2481        // weight for the third-term product-rule piece.
2482        let s_v = &self.w1 * &beta_direction.deta;
2483        let mixed_diag_weight = &(&tau_kernel.dotw1 * &beta_direction.deta) + &(&self.w1 * x_tau_v);
2484        let cross1 =
2485            RemlState::reduced_crossweighted_gram(&tau_kernel.x_tau_reduced, &self.x_reduced, &s_v);
2486        let cross2 =
2487            RemlState::reduced_crossweighted_gram(&self.x_reduced, &tau_kernel.x_tau_reduced, &s_v);
2488        let diag_piece = RemlState::reducedweighted_gram(&self.x_reduced, &mixed_diag_weight);
2489        let d_beta_dot_i = &cross1 + &cross2 + &diag_piece;
2490
2491        // D_β(K̇_τ)[v] — direct Leibniz on K̇_τ = -K_r İ_τ K_r with
2492        //   D_β K_r[v] = -K_r I'_v K_r = -beta_direction.a_u_reduced.
2493        // Expanding yields
2494        //   D_β K̇_τ[v] = +A_v İ_τ K_r − K_r D_β(İ_τ)[v] K_r + K_r İ_τ A_v,
2495        // where A_v := beta_direction.a_u_reduced = +K_r I'_v K_r.  The
2496        // FirthDirection carries a_u with the opposite sign convention to the
2497        // derivation block's "A_v"; we keep the direction convention and
2498        // compose signs correctly here.
2499        let term_a = beta_direction
2500            .a_u_reduced
2501            .dot(dot_i_partial)
2502            .dot(&self.k_reduced);
2503        let term_b = self.k_reduced.dot(&d_beta_dot_i).dot(&self.k_reduced);
2504        let term_c = self
2505            .k_reduced
2506            .dot(dot_i_partial)
2507            .dot(&beta_direction.a_u_reduced);
2508        let d_beta_dot_k = &term_a - &term_b + &term_c;
2509
2510        // D_β(ḣ_τ)[v] — β-differential of
2511        //   ḣ_τ = 2·diag(X_{r,τ} K_r X_rᵀ) + diag(X_r K̇_τ X_rᵀ):
2512        //   D_β ḣ_τ[v]
2513        //     = 2·diag(X_{r,τ} D_β K_r[v] X_rᵀ) + diag(X_r D_β K̇_τ[v] X_rᵀ)
2514        //     = -2·diag(X_{r,τ} A_v X_rᵀ) + diag(X_r (D_β K̇_τ[v]) X_rᵀ).
2515        let cross_diag = Self::rowwise_bilinear(
2516            &tau_kernel.x_tau_reduced,
2517            &beta_direction.a_u_reduced,
2518            &self.x_reduced,
2519        );
2520        let inner_diag = RemlState::reduced_diag_gram(&self.x_reduced, &d_beta_dot_k);
2521        let d_beta_dot_h = -2.0 * &cross_diag + &inner_diag;
2522
2523        FirthTauBetaPartialKernel {
2524            x_tau_reduced: tau_kernel.x_tau_reduced.clone(),
2525            deta_partial: deta_partial.clone(),
2526            dot_h_partial: tau_kernel.dot_h_partial.clone(),
2527            dot_i_partial: dot_i_partial.clone(),
2528            dot_k_reduced: tau_kernel.dot_k_reduced.clone(),
2529            deta_v: beta_direction.deta.clone(),
2530            deta_tau_v: x_tau_v.clone(),
2531            a_v_reduced: beta_direction.a_u_reduced.clone(),
2532            dh_v: beta_direction.dh.clone(),
2533            b_vvec: beta_direction.b_uvec.clone(),
2534            d_beta_dot_k,
2535            d_beta_dot_h,
2536        }
2537    }
2538
2539    /// Apply the mixed β-τ P-action `P_{τ,v} · mat` to an n×m column block.
2540    ///
2541    /// Expansion:
2542    ///   P_{τ,v} = 2 (M_v ⊙ M_τ) + 2 (M ⊙ M_{τ,v}),
2543    ///     M_v     = X_r K̇_v X_rᵀ,  K̇_v = -A_v (A_v = a_v_reduced),
2544    ///     M_τ     = X_{r,τ} K_r X_rᵀ + X_r K_r X_{r,τ}ᵀ + X_r K̇_τ X_rᵀ,
2545    ///     M_{τ,v} = X_{r,τ} K̇_v X_rᵀ + X_r K̇_v X_{r,τ}ᵀ + X_r D_β K̇_τ[v] X_rᵀ.
2546    /// Hadamard-Gram pieces are evaluated column-wise via
2547    ///   ((Z M_A Wᵀ) ⊙ (Y M_B Xᵀ)) v row-i
2548    ///       = z_iᵀ M_A (Wᵀ diag(v) X) M_Bᵀ y_i.
2549    pub(crate) fn apply_p_tau_v_to_matrix(
2550        &self,
2551        kernel: &FirthTauBetaPartialKernel,
2552        mat: &Array2<f64>,
2553    ) -> Array2<f64> {
2554        let n = self.x_dense.nrows();
2555        if mat.nrows() != n || mat.ncols() == 0 {
2556            return Array2::<f64>::zeros(mat.raw_dim());
2557        }
2558        let z = &self.x_reduced;
2559        let z_tau = &kernel.x_tau_reduced;
2560        let k_r = &self.k_reduced;
2561        let a_v = &kernel.a_v_reduced; // = +K_r I'_v K_r  (so K̇_v = -a_v)
2562        let dot_k_tau = &kernel.dot_k_reduced; // K̇_τ = -K_r İ_τ K_r
2563        let d_beta_dot_k = &kernel.d_beta_dot_k; // D_β K̇_τ[v]
2564        let mut out = Array2::<f64>::zeros(mat.raw_dim());
2565        for col in 0..mat.ncols() {
2566            let v = mat.column(col).to_owned();
2567            let s_zz = RemlState::reducedweighted_gram(z, &v);
2568            let s_z_ztau = RemlState::reduced_crossweighted_gram(z, z_tau, &v);
2569
2570            // Piece 1: (X_r K̇_v X_rᵀ ⊙ X_{r,τ} K_r X_rᵀ) · v
2571            //   = -rowwise_bilinear(Z, a_v S_zz K_r, Z_τ).
2572            let mid_1 = a_v.dot(&s_zz).dot(k_r);
2573            let t1 = -Self::rowwise_bilinear(z, &mid_1, z_tau);
2574            // Piece 2: (X_r K̇_v X_rᵀ ⊙ X_r K_r X_{r,τ}ᵀ) · v
2575            //   = -reduced_diag_gram(Z, a_v S_z_ztau K_r).
2576            let mid_2 = a_v.dot(&s_z_ztau).dot(k_r);
2577            let t2 = -RemlState::reduced_diag_gram(z, &mid_2);
2578            // Piece 3: (X_r K̇_v X_rᵀ ⊙ X_r K̇_τ X_rᵀ) · v
2579            //   = -reduced_diag_gram(Z, a_v S_zz K̇_τ).
2580            let mid_3 = a_v.dot(&s_zz).dot(dot_k_tau);
2581            let t3 = -RemlState::reduced_diag_gram(z, &mid_3);
2582            // Piece 4: (M ⊙ X_{r,τ} K̇_v X_rᵀ) · v
2583            //   = -rowwise_bilinear(Z, K_r S_zz a_v, Z_τ).
2584            let mid_4 = k_r.dot(&s_zz).dot(a_v);
2585            let t4 = -Self::rowwise_bilinear(z, &mid_4, z_tau);
2586            // Piece 5: (M ⊙ X_r K̇_v X_{r,τ}ᵀ) · v
2587            //   = -reduced_diag_gram(Z, K_r S_z_ztau a_v).
2588            let mid_5 = k_r.dot(&s_z_ztau).dot(a_v);
2589            let t5 = -RemlState::reduced_diag_gram(z, &mid_5);
2590            // Piece 6: (M ⊙ X_r D_β K̇_τ[v] X_rᵀ) · v.
2591            let t6 = RemlState::apply_hadamard_gram(z, k_r, d_beta_dot_k, &v);
2592
2593            // P_{τ,v} = 2·(pieces 1-3) + 2·(pieces 4-6); each group contributes
2594            // with the same outer factor 2.
2595            let y = 2.0 * (t1 + t2 + t3 + t4 + t5 + t6);
2596            out.column_mut(col).assign(&y);
2597        }
2598        out
2599    }
2600
2601    pub(crate) fn d_beta_hphi_tau_partial_apply(
2602        &self,
2603        x_tau: &Array2<f64>,
2604        kernel: &FirthTauBetaPartialKernel,
2605        rhs: &Array2<f64>,
2606    ) -> Array2<f64> {
2607        let p = self.x_dense.ncols();
2608        if rhs.nrows() != p {
2609            return Array2::<f64>::zeros((p, rhs.ncols()));
2610        }
2611        if rhs.ncols() == 0 || p == 0 {
2612            return Array2::<f64>::zeros((p, rhs.ncols()));
2613        }
2614        // Matrix-free block apply of D_β((H_φ)_τ|_β)[v] evaluated on a rhs V.
2615        // Structure follows hphi_tau_partial_apply but replaces every weight
2616        // and every reduced Gram with its β-derivative in direction v:
2617        //
2618        //   (H_φ)_τ|_β (V) = 0.5 [X_τᵀ r(V) + Xᵀ r_τ(V)].
2619        //
2620        // D_β[v] leaves X, X_τ fixed and acts on r, r_τ:
2621        //   D_β((H_φ)_τ|_β)[v](V) = 0.5 [X_τᵀ D_β r(V)[v] + Xᵀ D_β r_τ(V)[v]].
2622        let etav = fast_ab(&self.x_dense, rhs);
2623        let etav_tau = fast_ab(x_tau, rhs);
2624        let deta_v = &kernel.deta_v;
2625        let deta_tau_v = &kernel.deta_tau_v;
2626        let eta_tau = &kernel.deta_partial;
2627        let dot_h = &kernel.dot_h_partial;
2628
2629        // Reuse τ-kernel weights.  dotw1 = w'' ⊙ η̇_τ, dotw2 = w''' ⊙ η̇_τ.
2630        let dotw1 = &self.w2 * eta_tau;
2631        let dotw2 = &self.w3 * eta_tau;
2632
2633        // β-derivative scaling vectors in direction v:
2634        //   c_v              = D_β(w''·h)[v]    = w'''·δη_v·h + w''·dh_v
2635        //   b_vvec           = D_β(w')[v]       = w''·δη_v   (= kernel.b_vvec)
2636        //   d_beta_dotw1_vec = D_β(w''·η̇_τ)[v]  = w'''·δη_v·η̇_τ + w''·δη_{τ,v}
2637        //   d_beta_dotw2_vec = D_β(w'''·η̇_τ)[v] = w''''·δη_v·η̇_τ + w'''·δη_{τ,v}
2638        let c_v = &(&(&self.w3 * deta_v) * &self.h_diag) + &(&self.w2 * &kernel.dh_v);
2639        let b_vvec = &kernel.b_vvec;
2640        let d_beta_dotw1_vec = &(&(&self.w3 * deta_v) * eta_tau) + &(&self.w2 * deta_tau_v);
2641        let d_beta_dotw2_vec = &(&(&self.w4 * deta_v) * eta_tau) + &(&self.w3 * deta_tau_v);
2642
2643        // Single-τ pieces (identical to hphi_tau_partial_apply).
2644        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
2645        let qv_tau = &etav * &dotw1.view().insert_axis(Axis(1))
2646            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
2647        let m_qv = self.apply_pbar_to_matrix(&qv);
2648        // apply_mtau_to_matrix only reads x_tau_reduced and dot_k_reduced off
2649        // the τ-kernel, but owning the full struct is cheap.
2650        let tau_kernel_view = FirthTauPartialKernel {
2651            deta_partial: eta_tau.clone(),
2652            dotw1: dotw1.clone(),
2653            dotw2: dotw2.clone(),
2654            dot_h_partial: dot_h.clone(),
2655            x_tau_reduced: kernel.x_tau_reduced.clone(),
2656            dot_i_partial: kernel.dot_i_partial.clone(),
2657            dot_k_reduced: kernel.dot_k_reduced.clone(),
2658        };
2659        let m_qv_tau =
2660            self.apply_mtau_to_matrix(&tau_kernel_view, &qv) + self.apply_pbar_to_matrix(&qv_tau);
2661
2662        // β-derivatives of the single-τ pieces:
2663        //   D_β qv     = etav · D_β w'[v]       = etav · b_vvec
2664        //   D_β qv_tau = etav · D_β dotw1[v] + etav_tau · D_β w'[v]
2665        let d_beta_qv = &etav * &b_vvec.view().insert_axis(Axis(1));
2666        let d_beta_qv_tau = &etav * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2667            + &etav_tau * &b_vvec.view().insert_axis(Axis(1));
2668
2669        //   D_β m_qv = P_v · qv + P · D_β qv
2670        let d_beta_m_qv = self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv)
2671            + self.apply_pbar_to_matrix(&d_beta_qv);
2672
2673        //   D_β m_qv_tau = P_{τ,v}·qv + P_τ·D_β qv + P_v·qv_tau + P·D_β qv_tau
2674        let d_beta_m_qv_tau = self.apply_p_tau_v_to_matrix(kernel, &qv)
2675            + self.apply_mtau_to_matrix(&tau_kernel_view, &d_beta_qv)
2676            + self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv_tau)
2677            + self.apply_pbar_to_matrix(&d_beta_qv_tau);
2678
2679        // D_β rv[v] where rv = etav·(w''·h) − w'·m_qv:
2680        //   D_β rv[v] = etav·c_v − b_vvec·m_qv − w'·D_β m_qv.
2681        let d_beta_rv = &etav * &c_v.view().insert_axis(Axis(1))
2682            - &m_qv * &b_vvec.view().insert_axis(Axis(1))
2683            - &d_beta_m_qv * &self.w1.view().insert_axis(Axis(1));
2684
2685        // D_β rv_tau[v] where
2686        //   rv_tau = etav·dotw2·h + etav_tau·w''·h + etav·w''·dot_h
2687        //            − m_qv·dotw1 − m_qv_tau·w'.
2688        //
2689        //   D_β(dotw2·h)[v]   = (w''''·δη_v·η̇_τ + w'''·δη_{τ,v})·h
2690        //                        + dotw2·dh_v,
2691        //   D_β(w''·h)[v]     = c_v,
2692        //   D_β(w''·dot_h)[v] = w'''·δη_v·dot_h + w''·D_β dot_h[v],
2693        //   D_β dotw1[v]      = d_beta_dotw1_vec,
2694        //   D_β w'[v]         = b_vvec.
2695        let d_beta_dotw2_h = &(&d_beta_dotw2_vec * &self.h_diag) + &(&dotw2 * &kernel.dh_v);
2696        let d_beta_w2_doth = &(&(&self.w3 * deta_v) * dot_h) + &(&self.w2 * &kernel.d_beta_dot_h);
2697
2698        let d_beta_rv_tau = &etav * &d_beta_dotw2_h.view().insert_axis(Axis(1))
2699            + &etav_tau * &c_v.view().insert_axis(Axis(1))
2700            + &etav * &d_beta_w2_doth.view().insert_axis(Axis(1))
2701            - &d_beta_m_qv * &dotw1.view().insert_axis(Axis(1))
2702            - &m_qv * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2703            - &d_beta_m_qv_tau * &self.w1.view().insert_axis(Axis(1))
2704            - &m_qv_tau * &b_vvec.view().insert_axis(Axis(1));
2705
2706        0.5 * (x_tau.t().dot(&d_beta_rv) + self.x_dense.t().dot(&d_beta_rv_tau))
2707    }
2708}
2709
2710#[cfg(test)]
2711mod tests {
2712    use super::*;
2713    use crate::mixture_link::logit_inverse_link_jet5;
2714    use gam_problem::StandardLink;
2715    use ndarray::{Array1, Array2, array};
2716
2717    pub(crate) fn build_logit_firth_dense_operator(
2718        x_dense: &Array2<f64>,
2719        eta: &Array1<f64>,
2720    ) -> Result<FirthDenseOperator, EstimationError> {
2721        FirthDenseOperator::build_with_observation_weights_impl(
2722            &InverseLink::Standard(StandardLink::Logit),
2723            x_dense,
2724            eta,
2725            None,
2726        )
2727    }
2728
2729    pub(crate) fn build_weighted_logit_firth_dense_operator(
2730        x_dense: &Array2<f64>,
2731        eta: &Array1<f64>,
2732        observation_weights: ndarray::ArrayView1<'_, f64>,
2733    ) -> Result<FirthDenseOperator, EstimationError> {
2734        FirthDenseOperator::build_with_observation_weights_impl(
2735            &InverseLink::Standard(StandardLink::Logit),
2736            x_dense,
2737            eta,
2738            Some(observation_weights),
2739        )
2740    }
2741
2742    pub(crate) fn logisticweight(eta: f64) -> f64 {
2743        logit_inverse_link_jet5(eta).d1
2744    }
2745
2746    pub(crate) fn firthphivalue(x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
2747        let eta = x.dot(beta);
2748        let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
2749        op.jeffreys_logdet()
2750    }
2751
2752    pub(crate) fn firthgradphi(x: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
2753        let eta = x.dot(beta);
2754        let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
2755        op.jeffreys_beta_gradient()
2756    }
2757
2758    pub(crate) fn weighted_firthphivalue(
2759        x: &Array2<f64>,
2760        beta: &Array1<f64>,
2761        observation_weights: &Array1<f64>,
2762    ) -> f64 {
2763        let eta = x.dot(beta);
2764        let op = build_weighted_logit_firth_dense_operator(x, &eta, observation_weights.view())
2765            .expect("weighted firth operator");
2766        op.jeffreys_logdet()
2767    }
2768
2769    #[test]
2770    pub(crate) fn firth_reduced_fisher_logdet_is_finite_for_barely_pd_matrix() {
2771        let fisher = array![[16.0, 0.0], [0.0, 1e-15]];
2772        let (k_reduced, half_log_det) = RemlState::reduced_fisher_inverse_and_half_logdet(&fisher)
2773            .expect("barely positive-definite reduced fisher");
2774        let expected = 0.5 * 16.0_f64.ln();
2775
2776        assert!(
2777            half_log_det.is_finite(),
2778            "barely positive-definite reduced fisher produced non-finite half logdet: {half_log_det}"
2779        );
2780        assert!(
2781            (half_log_det - expected).abs() < 1e-12,
2782            "near-null Fisher direction should be excluded from pseudo-logdet: got {half_log_det}, expected {expected}"
2783        );
2784        assert!(
2785            k_reduced.iter().all(|value| value.is_finite()),
2786            "barely positive-definite reduced fisher produced non-finite inverse entries: {k_reduced:?}"
2787        );
2788        assert!(
2789            k_reduced[[1, 1]].abs() < f64::EPSILON,
2790            "near-null Fisher direction should be excluded from pseudo-inverse: {k_reduced:?}"
2791        );
2792    }
2793
2794    #[test]
2795    pub(crate) fn firth_logisticweight_derivatives_match_finite_difference() {
2796        // Validates op.w[i] (= jet.d1) and op.w1..w4[i] (= jet.d2..jet.d5)
2797        // against direct central finite differences of the logistic inverse
2798        // link pdf w(η) = μ(η)(1−μ(η)).
2799        //
2800        // Nested central differences amplify roundoff by 1/h per nesting
2801        // level, so a d1fd-of-d1fd-of-d2fd cannot deliver the tolerances
2802        // that 4th-order agreement requires. The principled replacement is
2803        // a direct higher-order stencil whose truncation and roundoff are
2804        // both controlled by a single step h:
2805        //
2806        //   d1  (2-pt):  (f(z+h) − f(z−h)) / (2h)                       O(h²) trunc
2807        //   d2  (3-pt):  (f(z+h) − 2f(z) + f(z−h)) / h²                 O(h²) trunc
2808        //   d3  (4-pt):  (−f(z−2h)+2f(z−h)−2f(z+h)+f(z+2h)) / (2h³)     O(h²) trunc
2809        //   d4  (5-pt):  (f(z−2h)−4f(z−h)+6f(z)−4f(z+h)+f(z+2h)) / h⁴   O(h²) trunc
2810        //
2811        // At h = 1e-2 the logistic pdf and its higher derivatives stay of
2812        // order ≤ 1, so truncation O(h²·M) ≲ 1e-4 and roundoff O(ε/h^n)
2813        // is well below any asserted tolerance through the 4th order.
2814        let x = array![
2815            [1.0, -1.1, 0.2],
2816            [1.0, -0.5, -0.6],
2817            [1.0, 0.0, 0.3],
2818            [1.0, 0.8, -0.4],
2819            [1.0, 1.2, 0.7],
2820        ];
2821        let beta = array![0.15, -0.6, 0.35];
2822        let eta = x.dot(&beta);
2823        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
2824
2825        let h = 1e-2_f64;
2826        let w = |z: f64| logisticweight(z);
2827        let d1direct = |z: f64| (w(z + h) - w(z - h)) / (2.0 * h);
2828        let d2direct = |z: f64| (w(z + h) - 2.0 * w(z) + w(z - h)) / (h * h);
2829        let d3direct = |z: f64| {
2830            (-w(z - 2.0 * h) + 2.0 * w(z - h) - 2.0 * w(z + h) + w(z + 2.0 * h)) / (2.0 * h.powi(3))
2831        };
2832        let d4direct = |z: f64| {
2833            (w(z - 2.0 * h) - 4.0 * w(z - h) + 6.0 * w(z) - 4.0 * w(z + h) + w(z + 2.0 * h))
2834                / h.powi(4)
2835        };
2836        for i in 0..eta.len() {
2837            let z = eta[i];
2838            let wfd = w(z);
2839            let w1fd = d1direct(z);
2840            let w2fd = d2direct(z);
2841            let w3fd = d3direct(z);
2842            let w4fd = d4direct(z);
2843
2844            assert!((op.w[i] - wfd).abs() < 1e-12);
2845            assert_eq!(op.w1[i].signum(), w1fd.signum());
2846            assert_eq!(op.w2[i].signum(), w2fd.signum());
2847            assert_eq!(op.w3[i].signum(), w3fd.signum());
2848            assert_eq!(op.w4[i].signum(), w4fd.signum());
2849            assert!((op.w1[i] - w1fd).abs() < 1e-5);
2850            assert!((op.w2[i] - w2fd).abs() < 1e-4);
2851            assert!((op.w3[i] - w3fd).abs() < 1e-4);
2852            assert!((op.w4[i] - w4fd).abs() < 1e-3);
2853        }
2854    }
2855
2856    #[test]
2857    pub(crate) fn weighted_firth_jeffreys_gradient_matches_finite_difference() {
2858        let x = array![
2859            [1.0, -0.7, 0.3],
2860            [1.0, -0.2, -0.4],
2861            [1.0, 0.5, 0.1],
2862            [1.0, 1.1, -0.6],
2863            [1.0, 1.6, 0.8],
2864        ];
2865        let beta = array![0.2, -0.45, 0.25];
2866        let observation_weights = array![1.0, 0.5, 2.0, 1.5, 0.75];
2867        let eta = x.dot(&beta);
2868        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
2869            .expect("weighted firth operator");
2870        let grad = op.jeffreys_beta_gradient();
2871        let h = 1e-6;
2872
2873        for j in 0..beta.len() {
2874            let mut beta_plus = beta.clone();
2875            beta_plus[j] += h;
2876            let mut beta_minus = beta.clone();
2877            beta_minus[j] -= h;
2878            let fd = (weighted_firthphivalue(&x, &beta_plus, &observation_weights)
2879                - weighted_firthphivalue(&x, &beta_minus, &observation_weights))
2880                / (2.0 * h);
2881            assert!(
2882                (grad[j] - fd).abs() < 1e-5,
2883                "weighted Firth gradient mismatch at {}: analytic={}, fd={}",
2884                j,
2885                grad[j],
2886                fd
2887            );
2888        }
2889    }
2890
2891    // ----------------------------------------------------------------------
2892    // Link-general (probit) finite-difference proof of the Jeffreys/Firth
2893    // Φ(β) = ½ log|I_r(β)|, its β-gradient ∂Φ/∂β, and the β-Hessian
2894    // derivative D H_φ[u] exposed via `hphi_direction`. Logit is used as a
2895    // regression guard against the historical logit-pinned build.
2896    // ----------------------------------------------------------------------
2897
2898    pub(crate) fn build_link_firth_op(
2899        link: StandardLink,
2900        x: &Array2<f64>,
2901        beta: &Array1<f64>,
2902    ) -> FirthDenseOperator {
2903        let eta = x.dot(beta);
2904        FirthDenseOperator::build_with_observation_weights_impl(
2905            &InverseLink::Standard(link),
2906            x,
2907            &eta,
2908            None,
2909        )
2910        .expect("link-general firth operator")
2911    }
2912
2913    pub(crate) fn link_firth_phi(link: StandardLink, x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
2914        build_link_firth_op(link, x, beta).jeffreys_logdet()
2915    }
2916
2917    pub(crate) fn link_firth_grad(
2918        link: StandardLink,
2919        x: &Array2<f64>,
2920        beta: &Array1<f64>,
2921    ) -> Array1<f64> {
2922        build_link_firth_op(link, x, beta).jeffreys_beta_gradient()
2923    }
2924
2925    /// Central-difference Jacobian of the *analytic* Firth gradient, i.e. a
2926    /// numerical realization of the β-Hessian H_φ = ∂g/∂β. The Newton/REML
2927    /// path consumes H_φ (and its directional derivative) so this is the
2928    /// matrix the analytic curvature must reproduce.
2929    pub(crate) fn numeric_firth_hessian(
2930        link: StandardLink,
2931        x: &Array2<f64>,
2932        beta: &Array1<f64>,
2933        h: f64,
2934    ) -> Array2<f64> {
2935        let p = beta.len();
2936        let mut hess = Array2::<f64>::zeros((p, p));
2937        for j in 0..p {
2938            let mut bp = beta.clone();
2939            bp[j] += h;
2940            let mut bm = beta.clone();
2941            bm[j] -= h;
2942            let gp = link_firth_grad(link, x, &bp);
2943            let gm = link_firth_grad(link, x, &bm);
2944            let col = (&gp - &gm) / (2.0 * h);
2945            hess.column_mut(j).assign(&col);
2946        }
2947        hess
2948    }
2949
2950    /// A fixed, well-conditioned full-rank design (deterministic, no RNG).
2951    pub(crate) fn fixed_design_5x3() -> Array2<f64> {
2952        array![
2953            [1.0, -1.10, 0.35],
2954            [1.0, -0.40, -0.65],
2955            [1.0, 0.15, 0.20],
2956            [1.0, 0.80, -0.45],
2957            [1.0, 1.25, 0.70],
2958        ]
2959    }
2960
2961    #[test]
2962    pub(crate) fn link_general_logit_path_reproduces_historical_logit_build() {
2963        // Guard: the StandardLink::Logit path through the link-general builder
2964        // must be byte-identical to the historical logit-pinned operator for
2965        // Φ, the β-gradient, the PIRLS hat diagonal, and the cached weight
2966        // jets w, w'..w''''.
2967        let x = fixed_design_5x3();
2968        let beta = array![0.20, -0.55, 0.30];
2969        let eta = x.dot(&beta);
2970
2971        let historical = build_logit_firth_dense_operator(&x, &eta).expect("historical logit");
2972        let link_general = FirthDenseOperator::build_with_observation_weights_impl(
2973            &InverseLink::Standard(StandardLink::Logit),
2974            &x,
2975            &eta,
2976            None,
2977        )
2978        .expect("link-general logit");
2979
2980        assert_eq!(
2981            historical.jeffreys_logdet(),
2982            link_general.jeffreys_logdet(),
2983            "logit Φ must be bit-identical through the link-general path"
2984        );
2985        let g_hist = historical.jeffreys_beta_gradient();
2986        let g_link = link_general.jeffreys_beta_gradient();
2987        for j in 0..g_hist.len() {
2988            assert_eq!(
2989                g_hist[j], g_link[j],
2990                "logit gradient component {j} must be bit-identical"
2991            );
2992        }
2993        let hat_hist = historical.pirls_hat_diag();
2994        let hat_link = link_general.pirls_hat_diag();
2995        for i in 0..hat_hist.len() {
2996            assert_eq!(
2997                hat_hist[i], hat_link[i],
2998                "logit PIRLS hat diagonal {i} must be bit-identical"
2999            );
3000        }
3001        for i in 0..eta.len() {
3002            assert_eq!(historical.w[i], link_general.w[i]);
3003            assert_eq!(historical.w1[i], link_general.w1[i]);
3004            assert_eq!(historical.w2[i], link_general.w2[i]);
3005            assert_eq!(historical.w3[i], link_general.w3[i]);
3006            assert_eq!(historical.w4[i], link_general.w4[i]);
3007        }
3008    }
3009
3010    #[test]
3011    pub(crate) fn link_general_probit_jeffreys_gradient_matches_finite_difference() {
3012        // PROBIT correctness: ∂Φ/∂β from `jeffreys_beta_gradient` must match a
3013        // central finite difference of Φ(β) on a well-conditioned design.
3014        let x = fixed_design_5x3();
3015        let beta = array![0.10, -0.40, 0.25];
3016        let grad = link_firth_grad(StandardLink::Probit, &x, &beta);
3017        let h = 1e-6_f64;
3018        let mut max_rel = 0.0_f64;
3019        for j in 0..beta.len() {
3020            let mut bp = beta.clone();
3021            bp[j] += h;
3022            let mut bm = beta.clone();
3023            bm[j] -= h;
3024            let fd = (link_firth_phi(StandardLink::Probit, &x, &bp)
3025                - link_firth_phi(StandardLink::Probit, &x, &bm))
3026                / (2.0 * h);
3027            let denom = grad[j].abs().max(fd.abs()).max(1e-8);
3028            let rel = (grad[j] - fd).abs() / denom;
3029            max_rel = max_rel.max(rel);
3030            assert!(
3031                rel < 1e-6,
3032                "probit Firth gradient mismatch at {j}: analytic={}, fd={}, rel={:e}",
3033                grad[j],
3034                fd,
3035                rel
3036            );
3037        }
3038        assert!(
3039            max_rel < 1e-6,
3040            "probit gradient worst relative error {max_rel:e} exceeds 1e-6"
3041        );
3042    }
3043
3044    #[test]
3045    pub(crate) fn link_general_probit_hphi_direction_matches_finite_difference_of_hessian() {
3046        // PROBIT Hessian: `hphi_direction(direction_from_deta(X·u))` is the
3047        // analytic directional derivative D H_φ[u] of the β-Hessian. Verify it
3048        // against the central finite difference of the (numerically realized)
3049        // β-Hessian H_φ along u. The numeric H_φ at each shifted β is itself a
3050        // finite difference of the *analytic* gradient, so the base operand is
3051        // analytic at first order; only the directional step is differenced
3052        // here.
3053        let x = fixed_design_5x3();
3054        let beta = array![0.10, -0.40, 0.25];
3055        let p = beta.len();
3056
3057        // Probe several directions, including non-axis-aligned ones.
3058        let directions = [
3059            array![1.0, 0.0, 0.0],
3060            array![0.0, 1.0, 0.0],
3061            array![0.0, 0.0, 1.0],
3062            array![0.7, -0.5, 0.3],
3063        ];
3064
3065        let h_inner = 1e-4_f64; // step for the numeric Hessian (FD of analytic grad)
3066        let h_dir = 1e-4_f64; // step for the directional derivative of the Hessian
3067        let mut worst = 0.0_f64;
3068        for u in directions.iter() {
3069            let op = build_link_firth_op(StandardLink::Probit, &x, &beta);
3070            let deta = x.dot(u);
3071            let dir = op.direction_from_deta(deta);
3072            let analytic = op.hphi_direction(&dir);
3073
3074            let beta_plus = &beta + &(u * h_dir);
3075            let beta_minus = &beta - &(u * h_dir);
3076            let hess_plus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_plus, h_inner);
3077            let hess_minus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_minus, h_inner);
3078            let fd = (&hess_plus - &hess_minus) / (2.0 * h_dir);
3079
3080            let mut scale = 1e-6_f64;
3081            for r in 0..p {
3082                for c in 0..p {
3083                    scale = scale.max(analytic[[r, c]].abs()).max(fd[[r, c]].abs());
3084                }
3085            }
3086            for r in 0..p {
3087                for c in 0..p {
3088                    let rel = (analytic[[r, c]] - fd[[r, c]]).abs() / scale;
3089                    worst = worst.max(rel);
3090                    assert!(
3091                        rel < 5e-3,
3092                        "probit D H_φ[u] mismatch at ({r},{c}) for u={u:?}: analytic={}, fd={}, rel={:e}",
3093                        analytic[[r, c]],
3094                        fd[[r, c]],
3095                        rel
3096                    );
3097                }
3098            }
3099        }
3100        assert!(
3101            worst < 5e-3,
3102            "probit Hessian-derivative worst relative error {worst:e} exceeds 5e-3"
3103        );
3104    }
3105
3106    #[test]
3107    pub(crate) fn link_general_probit_jeffreys_finite_on_rank_deficient_design() {
3108        // Identifiable-subspace behavior: a rank-deficient design (column 3 =
3109        // column 1 + column 2) must yield a finite Φ = ½ log|Uᵀ W U|, a finite
3110        // gradient, and agree with the explicit reduced two-column design.
3111        let x_full = array![
3112            [1.0, -1.20, -0.20],
3113            [1.0, -0.40, 0.60],
3114            [1.0, 0.10, 1.10],
3115            [1.0, 0.70, 1.70],
3116            [1.0, 1.30, 2.30],
3117        ];
3118        let x_reduced = array![
3119            [1.0, -1.20],
3120            [1.0, -0.40],
3121            [1.0, 0.10],
3122            [1.0, 0.70],
3123            [1.0, 1.30],
3124        ];
3125        let beta_full = array![0.25, -0.50, 0.15];
3126        let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3127
3128        let phi_full = link_firth_phi(StandardLink::Probit, &x_full, &beta_full);
3129        let phi_reduced = link_firth_phi(StandardLink::Probit, &x_reduced, &beta_reduced);
3130        assert!(
3131            phi_full.is_finite(),
3132            "probit Φ on rank-deficient design must be finite, got {phi_full}"
3133        );
3134        assert!(
3135            (phi_full - phi_reduced).abs() < 1e-12,
3136            "probit reduced |Uᵀ W U| form mismatch: full={phi_full}, reduced={phi_reduced}"
3137        );
3138
3139        let op_full = build_link_firth_op(StandardLink::Probit, &x_full, &beta_full);
3140        let grad_full = op_full.jeffreys_beta_gradient();
3141        assert!(
3142            grad_full.iter().all(|v| v.is_finite()),
3143            "probit gradient on rank-deficient design must be finite: {grad_full:?}"
3144        );
3145        let hat_full = op_full.pirls_hat_diag();
3146        let hat_reduced =
3147            build_link_firth_op(StandardLink::Probit, &x_reduced, &beta_reduced).pirls_hat_diag();
3148        for i in 0..hat_full.len() {
3149            assert!(
3150                (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3151                "probit hat diagonal {i} mismatch on rank-deficient design: full={}, reduced={}",
3152                hat_full[i],
3153                hat_reduced[i]
3154            );
3155        }
3156    }
3157
3158    #[test]
3159    pub(crate) fn rank_deficient_and_explicit_reduced_designs_share_same_jeffreys_objective() {
3160        // Column 3 is exactly column 1 + column 2, so the original design is
3161        // rank-deficient but its identifiable subspace is represented exactly by
3162        // the explicit two-column reduced design below.
3163        let x_full = array![
3164            [1.0, -1.2, -0.2],
3165            [1.0, -0.4, 0.6],
3166            [1.0, 0.1, 1.1],
3167            [1.0, 0.7, 1.7],
3168            [1.0, 1.3, 2.3],
3169        ];
3170        let x_reduced = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3171        let beta_full: ndarray::Array1<f64> = array![0.25, -0.5, 0.15];
3172        let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3173        let eta_full = x_full.dot(&beta_full);
3174        let eta_reduced = x_reduced.dot(&beta_reduced);
3175        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3176
3177        for i in 0..eta_full.len() {
3178            assert!(
3179                (eta_full[i] - eta_reduced[i]).abs() < 1e-12,
3180                "eta mismatch at row {i}: full={} reduced={}",
3181                eta_full[i],
3182                eta_reduced[i]
3183            );
3184        }
3185
3186        let op_full = build_weighted_logit_firth_dense_operator(
3187            &x_full,
3188            &eta_full,
3189            observation_weights.view(),
3190        )
3191        .expect("full firth operator");
3192        let op_reduced = build_weighted_logit_firth_dense_operator(
3193            &x_reduced,
3194            &eta_reduced,
3195            observation_weights.view(),
3196        )
3197        .expect("reduced firth operator");
3198
3199        assert!(
3200            (op_full.jeffreys_logdet() - op_reduced.jeffreys_logdet()).abs() < 1e-12,
3201            "Jeffreys logdet mismatch between rank-deficient full design and its explicit reduced identifiable basis: full={} reduced={}",
3202            op_full.jeffreys_logdet(),
3203            op_reduced.jeffreys_logdet()
3204        );
3205
3206        let hat_full = op_full.pirls_hat_diag();
3207        let hat_reduced = op_reduced.pirls_hat_diag();
3208        for i in 0..hat_full.len() {
3209            assert!(
3210                (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3211                "PIRLS hat-diagonal mismatch at row {i}: full={} reduced={}",
3212                hat_full[i],
3213                hat_reduced[i]
3214            );
3215        }
3216    }
3217
3218    #[test]
3219    pub(crate) fn full_rank_reparameterizations_share_same_jeffreys_objective() {
3220        let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3221        let basis = array![[1.4, -0.3], [0.6, 1.1]];
3222        let x_reparameterized = x.dot(&basis);
3223        let beta = array![0.25, -0.5];
3224        let basis_det: f64 = basis[[0, 0]] * basis[[1, 1]] - basis[[0, 1]] * basis[[1, 0]];
3225        assert!(
3226            basis_det.abs() > 1e-12,
3227            "basis transform must be invertible"
3228        );
3229        let basis_inv = array![
3230            [basis[[1, 1]] / basis_det, -basis[[0, 1]] / basis_det],
3231            [-basis[[1, 0]] / basis_det, basis[[0, 0]] / basis_det],
3232        ];
3233        let beta_reparameterized = basis_inv.dot(&beta);
3234        let eta = x.dot(&beta);
3235        let eta_reparameterized = x_reparameterized.dot(&beta_reparameterized);
3236        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3237
3238        for i in 0..eta.len() {
3239            assert!(
3240                (eta[i] - eta_reparameterized[i]).abs() < 1e-12,
3241                "eta mismatch at row {i}: original={} reparameterized={}",
3242                eta[i],
3243                eta_reparameterized[i]
3244            );
3245        }
3246
3247        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3248            .expect("original firth operator");
3249        let op_reparameterized = build_weighted_logit_firth_dense_operator(
3250            &x_reparameterized,
3251            &eta_reparameterized,
3252            observation_weights.view(),
3253        )
3254        .expect("reparameterized firth operator");
3255
3256        assert!(
3257            (op.jeffreys_logdet() - op_reparameterized.jeffreys_logdet()).abs() < 1e-12,
3258            "Jeffreys logdet mismatch under invertible reparameterization: original={} reparameterized={}",
3259            op.jeffreys_logdet(),
3260            op_reparameterized.jeffreys_logdet()
3261        );
3262
3263        let hat = op.pirls_hat_diag();
3264        let hat_reparameterized = op_reparameterized.pirls_hat_diag();
3265        for i in 0..hat.len() {
3266            assert!(
3267                (hat[i] - hat_reparameterized[i]).abs() < 1e-12,
3268                "PIRLS hat-diagonal mismatch at row {i}: original={} reparameterized={}",
3269                hat[i],
3270                hat_reparameterized[i]
3271            );
3272        }
3273    }
3274
3275    #[test]
3276    pub(crate) fn full_rank_identifiable_basis_diagonalizes_design_metric() {
3277        let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3278        let beta = array![0.25, -0.5];
3279        let eta = x.dot(&beta);
3280        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3281        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3282            .expect("firth operator");
3283
3284        let reduced_metric = fast_atb(&op.x_reduced, &op.x_reduced);
3285        for i in 0..reduced_metric.nrows() {
3286            for j in 0..reduced_metric.ncols() {
3287                if i == j {
3288                    continue;
3289                }
3290                assert!(
3291                    reduced_metric[[i, j]].abs() < 1e-10,
3292                    "full-rank identifiable basis should diagonalize X_r'X_r: metric[{i},{j}]={}",
3293                    reduced_metric[[i, j]]
3294                );
3295            }
3296        }
3297    }
3298
3299    #[test]
3300    pub(crate) fn firth_mixedsecond_direction_apply_is_symmetric_in_direction_order() {
3301        let x = array![
3302            [1.0, -1.0, 0.2],
3303            [1.0, -0.6, -0.3],
3304            [1.0, -0.1, 0.5],
3305            [1.0, 0.3, -0.7],
3306            [1.0, 0.8, 0.1],
3307            [1.0, 1.2, -0.4],
3308        ];
3309        let beta = array![0.1, -0.25, 0.2];
3310        let eta = x.dot(&beta);
3311        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3312
3313        let u = array![0.3, -0.2, 0.4];
3314        let v = array![-0.5, 0.1, 0.25];
3315        let du = op.direction_from_deta(x.dot(&u));
3316        let dv = op.direction_from_deta(x.dot(&v));
3317
3318        let eye = Array2::<f64>::eye(x.ncols());
3319        let uv = op.hphisecond_direction_apply(&du, &dv, &eye);
3320        let vu = op.hphisecond_direction_apply(&dv, &du, &eye);
3321
3322        for i in 0..uv.nrows() {
3323            for j in 0..uv.ncols() {
3324                let a = uv[[i, j]];
3325                let b = vu[[i, j]];
3326                assert_eq!(
3327                    a.signum(),
3328                    b.signum(),
3329                    "mixed direction sign mismatch at ({i},{j}): uv={a} vu={b}"
3330                );
3331                assert!(
3332                    (a - b).abs() < 2e-7,
3333                    "mixed direction mismatch at ({i},{j}): uv={a} vu={b}"
3334                );
3335            }
3336        }
3337    }
3338
3339    #[test]
3340    pub(crate) fn firth_direction_matrix_form_matches_apply_identity_form() {
3341        let x = array![
3342            [1.0, -1.1, 0.2],
3343            [1.0, -0.6, -0.3],
3344            [1.0, -0.1, 0.5],
3345            [1.0, 0.3, -0.7],
3346            [1.0, 0.8, 0.1],
3347            [1.0, 1.2, -0.4],
3348        ];
3349        let beta = array![0.08, -0.22, 0.27];
3350        let eta = x.dot(&beta);
3351        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3352        let u = Array1::from_vec(vec![0.25, -0.4, 0.35]);
3353        let dir = op.direction_from_deta(x.dot(&u));
3354
3355        let p = x.ncols();
3356        let eye = Array2::<f64>::eye(p);
3357        let mut via_apply = op.hphi_direction_apply(&dir, &eye);
3358        for i in 0..p {
3359            for j in 0..i {
3360                let sym = 0.5 * (via_apply[[i, j]] + via_apply[[j, i]]);
3361                via_apply[[i, j]] = sym;
3362                via_apply[[j, i]] = sym;
3363            }
3364        }
3365        let direct = op.hphi_direction(&dir);
3366        let diff = &direct - &via_apply;
3367        let err = diff.iter().map(|v| v * v).sum::<f64>().sqrt();
3368        assert!(err < 1e-10, "direction/apply mismatch: {err:e}");
3369    }
3370
3371    #[test]
3372    pub(crate) fn firthphi_tau_partial_matches_finite_difference_logdet() {
3373        let x = array![
3374            [1.0, -1.0, 0.2],
3375            [1.0, -0.6, -0.3],
3376            [1.0, -0.1, 0.5],
3377            [1.0, 0.3, -0.7],
3378            [1.0, 0.8, 0.1],
3379            [1.0, 1.2, -0.4],
3380        ];
3381        let x_tau = array![
3382            [0.0, 0.15, -0.05],
3383            [0.0, -0.10, 0.02],
3384            [0.0, 0.08, 0.04],
3385            [0.0, -0.06, -0.03],
3386            [0.0, 0.05, 0.01],
3387            [0.0, -0.12, 0.06],
3388        ];
3389        let beta = array![0.1, -0.25, 0.2];
3390        let eta = x.dot(&beta);
3391        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3392        let analytic = op.exact_tau_kernel(&x_tau, &beta, false).phi_tau_partial;
3393
3394        let h = 1e-6;
3395        let x_plus = &x + &(h * &x_tau);
3396        let x_minus = &x - &(h * &x_tau);
3397        let fd = (firthphivalue(&x_plus, &beta) - firthphivalue(&x_minus, &beta)) / (2.0 * h);
3398
3399        assert!(
3400            (analytic - fd).abs() < 1e-6,
3401            "Phi_tau mismatch: analytic={analytic:.12e}, fd={fd:.12e}"
3402        );
3403    }
3404
3405    #[test]
3406    pub(crate) fn firth_gphi_tau_matches_finite_differencegradphi() {
3407        let x = array![
3408            [1.0, -1.0, 0.2],
3409            [1.0, -0.6, -0.3],
3410            [1.0, -0.1, 0.5],
3411            [1.0, 0.3, -0.7],
3412            [1.0, 0.8, 0.1],
3413            [1.0, 1.2, -0.4],
3414        ];
3415        let x_tau = array![
3416            [0.0, 0.15, -0.05],
3417            [0.0, -0.10, 0.02],
3418            [0.0, 0.08, 0.04],
3419            [0.0, -0.06, -0.03],
3420            [0.0, 0.05, 0.01],
3421            [0.0, -0.12, 0.06],
3422        ];
3423        let beta = array![0.1, -0.25, 0.2];
3424        let eta = x.dot(&beta);
3425        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3426        let analytic = op.exact_tau_kernel(&x_tau, &beta, false).gphi_tau;
3427
3428        let h = 1e-6;
3429        let x_plus = &x + &(h * &x_tau);
3430        let x_minus = &x - &(h * &x_tau);
3431        let fd = (firthgradphi(&x_plus, &beta) - firthgradphi(&x_minus, &beta)) / (2.0 * h);
3432
3433        let err = (&analytic - &fd).iter().map(|v| v * v).sum::<f64>().sqrt();
3434        assert!(
3435            err < 1e-6,
3436            "gphi_tau mismatch: analytic={analytic:?}, fd={fd:?}, err={err:e}"
3437        );
3438    }
3439
3440    /// Verify pair.a scalar (`phi_tau_tau_partial`) by central-FD'ing the
3441    /// single-τ scalar `phi_tau_partial` along τ_j at fixed β.
3442    /// Identity: ∂/∂τ_j [Φ_{τ_i}|β] = Φ_{τ_iτ_j}|β.
3443    /// Tolerance 1e-7 relative.
3444    #[test]
3445    pub(crate) fn firthphi_tau_tau_pair_scalar_matches_finite_difference() {
3446        let x = array![
3447            [1.0, -1.0, 0.2],
3448            [1.0, -0.6, -0.3],
3449            [1.0, -0.1, 0.5],
3450            [1.0, 0.3, -0.7],
3451            [1.0, 0.8, 0.1],
3452            [1.0, 1.2, -0.4],
3453        ];
3454        let x_tau_i = array![
3455            [0.0, 0.15, -0.05],
3456            [0.0, -0.10, 0.02],
3457            [0.0, 0.08, 0.04],
3458            [0.0, -0.06, -0.03],
3459            [0.0, 0.05, 0.01],
3460            [0.0, -0.12, 0.06],
3461        ];
3462        let x_tau_j = array![
3463            [0.0, -0.04, 0.11],
3464            [0.0, 0.09, -0.02],
3465            [0.0, -0.06, 0.07],
3466            [0.0, 0.10, -0.05],
3467            [0.0, -0.03, 0.08],
3468            [0.0, 0.07, -0.09],
3469        ];
3470        let beta = array![0.1, -0.25, 0.2];
3471        let eta = x.dot(&beta);
3472        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3473
3474        let analytic = op
3475            .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3476            .phi_tau_tau_partial;
3477
3478        let h = 1e-5_f64;
3479        let eval_phi_tau_i = |x_eval: &Array2<f64>| -> f64 {
3480            let eta_e = x_eval.dot(&beta);
3481            let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3482            op_e.exact_tau_kernel(&x_tau_i, &beta, false)
3483                .phi_tau_partial
3484        };
3485        let x_plus = &x + &(h * &x_tau_j);
3486        let x_minus = &x - &(h * &x_tau_j);
3487        let fd = (eval_phi_tau_i(&x_plus) - eval_phi_tau_i(&x_minus)) / (2.0 * h);
3488
3489        let rel = (analytic - fd).abs() / fd.abs().max(1.0);
3490        assert!(
3491            rel < 1e-7,
3492            "pair.a scalar mismatch: analytic={analytic:.6e}, fd={fd:.6e}, rel={rel:.3e}"
3493        );
3494    }
3495
3496    /// Verify pair.g p-vector (`gphi_tau_tau`) by central-FD'ing the single-τ
3497    /// `gphi_tau` along τ_j at fixed β.
3498    /// Identity: ∂/∂τ_j [(gΦ)_{τ_i}|β] = (gΦ)_{τ_iτ_j}|β.
3499    /// Tolerance 1e-7 relative max-abs.
3500    #[test]
3501    pub(crate) fn firthphi_tau_tau_pair_g_vector_matches_finite_difference() {
3502        let x = array![
3503            [1.0, -1.0, 0.2],
3504            [1.0, -0.6, -0.3],
3505            [1.0, -0.1, 0.5],
3506            [1.0, 0.3, -0.7],
3507            [1.0, 0.8, 0.1],
3508            [1.0, 1.2, -0.4],
3509        ];
3510        let x_tau_i = array![
3511            [0.0, 0.15, -0.05],
3512            [0.0, -0.10, 0.02],
3513            [0.0, 0.08, 0.04],
3514            [0.0, -0.06, -0.03],
3515            [0.0, 0.05, 0.01],
3516            [0.0, -0.12, 0.06],
3517        ];
3518        let x_tau_j = array![
3519            [0.0, -0.04, 0.11],
3520            [0.0, 0.09, -0.02],
3521            [0.0, -0.06, 0.07],
3522            [0.0, 0.10, -0.05],
3523            [0.0, -0.03, 0.08],
3524            [0.0, 0.07, -0.09],
3525        ];
3526        let beta = array![0.1, -0.25, 0.2];
3527        let eta = x.dot(&beta);
3528        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3529
3530        let analytic = op
3531            .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3532            .gphi_tau_tau;
3533
3534        let h = 1e-5_f64;
3535        let eval_gphi_tau_i = |x_eval: &Array2<f64>| -> Array1<f64> {
3536            let eta_e = x_eval.dot(&beta);
3537            let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3538            op_e.exact_tau_kernel(&x_tau_i, &beta, false).gphi_tau
3539        };
3540        let x_plus = &x + &(h * &x_tau_j);
3541        let x_minus = &x - &(h * &x_tau_j);
3542        let fd = (&eval_gphi_tau_i(&x_plus) - &eval_gphi_tau_i(&x_minus)) / (2.0 * h);
3543
3544        let scale = analytic
3545            .iter()
3546            .chain(fd.iter())
3547            .map(|v| v.abs())
3548            .fold(0.0_f64, f64::max)
3549            .max(1.0);
3550        let err_max = (&analytic - &fd)
3551            .iter()
3552            .map(|v| v.abs())
3553            .fold(0.0_f64, f64::max);
3554        let rel = err_max / scale;
3555        assert!(
3556            rel < 1e-7,
3557            "pair.g p-vector mismatch: rel={rel:.3e}\nanalytic={analytic:?}\nfd={fd:?}"
3558        );
3559    }
3560
3561    /// Verify the Primitive A body (`hphi_tau_tau_partial_apply`) against a
3562    /// finite-difference reference of the single-τ Primitive (
3563    /// `hphi_tau_partial_apply`).
3564    ///
3565    /// Identity under test:
3566    ///     ∂/∂τ_j  { (H_φ)_τ_i |_β · V }   =   ∂²H_φ/∂τ_i ∂τ_j |_β · V.
3567    ///
3568    /// Central-difference reference:
3569    ///   1. Evaluate the single-τ primitive at x, and at x ± h·X_τ_j
3570    ///      — rebuild the FirthDenseOperator (with fresh identifiable Q)
3571    ///      at each perturbed design; H_φ applied to a p-space rhs is
3572    ///      basis-invariant in unreduced β-coords, so Q rotation does not
3573    ///      contaminate the comparison.
3574    ///   2. FD_{i,j} = (T_{plus} − T_{minus}) / (2h) with T = hphi_tau_i_apply(V).
3575    ///   3. Contract both (i,j) and (j,i) directions and verify symmetry
3576    ///      of the analytic as a cross-check.
3577    ///
3578    /// Tolerance: 1e-7 relative max-abs (h chosen to balance truncation
3579    /// error at ~h² and evaluator roundoff at ~ε/h).
3580    #[test]
3581    pub(crate) fn firthphi_tau_tau_partial_matches_finite_difference() {
3582        let x = array![
3583            [1.0, -1.0, 0.2],
3584            [1.0, -0.6, -0.3],
3585            [1.0, -0.1, 0.5],
3586            [1.0, 0.3, -0.7],
3587            [1.0, 0.8, 0.1],
3588            [1.0, 1.2, -0.4],
3589        ];
3590        let x_tau_i = array![
3591            [0.0, 0.15, -0.05],
3592            [0.0, -0.10, 0.02],
3593            [0.0, 0.08, 0.04],
3594            [0.0, -0.06, -0.03],
3595            [0.0, 0.05, 0.01],
3596            [0.0, -0.12, 0.06],
3597        ];
3598        let x_tau_j = array![
3599            [0.0, -0.04, 0.11],
3600            [0.0, 0.09, -0.02],
3601            [0.0, -0.06, 0.07],
3602            [0.0, 0.10, -0.05],
3603            [0.0, -0.03, 0.08],
3604            [0.0, 0.07, -0.09],
3605        ];
3606        let beta = array![0.1, -0.25, 0.2];
3607        let eta = x.dot(&beta);
3608        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3609        let p = x.ncols();
3610
3611        // Reproducible small rhs block (p × m).
3612        let m = 3usize;
3613        let mut rhs = Array2::<f64>::zeros((p, m));
3614        let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
3615        for r in 0..p {
3616            for c in 0..m {
3617                rhs[[r, c]] = vals[(r * m + c) % vals.len()];
3618            }
3619        }
3620
3621        // ── Analytic τ×τ pair apply at base design (x_tau_tau = None,
3622        //    deta_ij = None, i.e. design is linear in τ).
3623        let x_tau_i_reduced = op.reduce_explicit_design(&x_tau_i);
3624        let x_tau_j_reduced = op.reduce_explicit_design(&x_tau_j);
3625        let deta_i = x_tau_i.dot(&beta);
3626        let deta_j = x_tau_j.dot(&beta);
3627        let (dot_i_i, dot_h_i) = op.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
3628        let (dot_i_j, dot_h_j) = op.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
3629
3630        let kernel_ij = op.hphi_tau_tau_partial_prepare_from_partials(
3631            x_tau_i_reduced.clone(),
3632            x_tau_j_reduced.clone(),
3633            &deta_i,
3634            &deta_j,
3635            dot_h_i.clone(),
3636            dot_h_j.clone(),
3637            dot_i_i.clone(),
3638            dot_i_j.clone(),
3639            None,
3640            None,
3641        );
3642        let kernel_ji = op.hphi_tau_tau_partial_prepare_from_partials(
3643            x_tau_j_reduced,
3644            x_tau_i_reduced,
3645            &deta_j,
3646            &deta_i,
3647            dot_h_j,
3648            dot_h_i,
3649            dot_i_j,
3650            dot_i_i,
3651            None,
3652            None,
3653        );
3654        let analytic_ij = op.hphi_tau_tau_partial_apply(&x_tau_i, &x_tau_j, &kernel_ij, &rhs);
3655        let analytic_ji = op.hphi_tau_tau_partial_apply(&x_tau_j, &x_tau_i, &kernel_ji, &rhs);
3656
3657        // Symmetry cross-check (Clairaut): ∂²H/∂τ_i∂τ_j = ∂²H/∂τ_j∂τ_i.
3658        let sym_diff: f64 = (&analytic_ij - &analytic_ji)
3659            .iter()
3660            .map(|v| v.abs())
3661            .fold(0.0_f64, f64::max);
3662        let sym_scale: f64 = analytic_ij
3663            .iter()
3664            .chain(analytic_ji.iter())
3665            .map(|v| v.abs())
3666            .fold(0.0_f64, f64::max)
3667            .max(1.0);
3668        assert!(
3669            sym_diff / sym_scale < 1e-10,
3670            "τ×τ primitive not symmetric in direction order: sym_diff={sym_diff:.3e}"
3671        );
3672
3673        // ── FD reference: central difference of single-τ primitive in
3674        //    τ_j direction, evaluated along τ_i.
3675        let h = 1e-5_f64;
3676        let fd_block = |x_eval: &Array2<f64>| -> Array2<f64> {
3677            let eta_e = x_eval.dot(&beta);
3678            let op_e =
3679                build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
3680            let x_tau_i_r = op_e.reduce_explicit_design(&x_tau_i);
3681            let deta_i_e = x_tau_i.dot(&beta);
3682            let (dot_i_i_e, dot_h_i_e) = op_e.dot_i_and_h_from_reduced(&x_tau_i_r, &deta_i_e);
3683            let kernel_i_e = op_e
3684                .hphi_tau_partial_prepare_from_partials(x_tau_i_r, &deta_i_e, dot_h_i_e, dot_i_i_e);
3685            op_e.hphi_tau_partial_apply(&x_tau_i, &kernel_i_e, &rhs)
3686        };
3687        let x_plus = &x + &(h * &x_tau_j);
3688        let x_minus = &x - &(h * &x_tau_j);
3689        let fd_ij = (&fd_block(&x_plus) - &fd_block(&x_minus)) / (2.0 * h);
3690
3691        // ── Compare analytic_ij (contracted against V along τ_j→analytic's
3692        //    second index) to fd_ij (FD of T_i in τ_j direction).
3693        let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
3694            let scale = a
3695                .iter()
3696                .chain(b.iter())
3697                .map(|v| v.abs())
3698                .fold(0.0_f64, f64::max)
3699                .max(1.0);
3700            let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
3701            max_diff / scale
3702        };
3703        let err_ij = rel_max_abs_diff(&analytic_ij, &fd_ij);
3704
3705        // Also FD the other direction and compare to analytic_ji, to
3706        // double-cover the primitive.
3707        let fd_block_j = |x_eval: &Array2<f64>| -> Array2<f64> {
3708            let eta_e = x_eval.dot(&beta);
3709            let op_e =
3710                build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
3711            let x_tau_j_r = op_e.reduce_explicit_design(&x_tau_j);
3712            let deta_j_e = x_tau_j.dot(&beta);
3713            let (dot_i_j_e, dot_h_j_e) = op_e.dot_i_and_h_from_reduced(&x_tau_j_r, &deta_j_e);
3714            let kernel_j_e = op_e
3715                .hphi_tau_partial_prepare_from_partials(x_tau_j_r, &deta_j_e, dot_h_j_e, dot_i_j_e);
3716            op_e.hphi_tau_partial_apply(&x_tau_j, &kernel_j_e, &rhs)
3717        };
3718        let x_plus_i = &x + &(h * &x_tau_i);
3719        let x_minus_i = &x - &(h * &x_tau_i);
3720        let fd_ji = (&fd_block_j(&x_plus_i) - &fd_block_j(&x_minus_i)) / (2.0 * h);
3721        let err_ji = rel_max_abs_diff(&analytic_ji, &fd_ji);
3722
3723        let tol = 1e-7_f64;
3724        assert!(
3725            err_ij < tol,
3726            "∂²H_φ/∂τ_i∂τ_j apply mismatch (i,j): rel_max_abs_diff={err_ij:.3e} > {tol:.1e}\n\
3727             analytic=\n{analytic_ij:?}\n\
3728             fd=\n{fd_ij:?}"
3729        );
3730        assert!(
3731            err_ji < tol,
3732            "∂²H_φ/∂τ_j∂τ_i apply mismatch (j,i): rel_max_abs_diff={err_ji:.3e} > {tol:.1e}\n\
3733             analytic=\n{analytic_ji:?}\n\
3734             fd=\n{fd_ji:?}"
3735        );
3736    }
3737
3738    /// Verify the Primitive B body (`d_beta_hphi_tau_partial_apply`) against a
3739    /// finite-difference reference of the single-τ Primitive
3740    /// (`hphi_tau_partial_apply`).
3741    ///
3742    /// Identity under test (β held in the unreduced ambient; the design X is
3743    /// fixed so only w, η̇_τ = X_τ β, and their β-derivatives move):
3744    ///     D_β [ (H_φ)_τ|_β (β) · V ] [v]
3745    ///       = d_beta_hphi_tau_partial_apply(v, V).
3746    ///
3747    /// Central-difference reference:
3748    ///   1. Evaluate T(t) := hphi_tau_partial_apply(V) at β_t = β + t v,
3749    ///      rebuilding FirthDenseOperator at each β (so η = X β_t and the
3750    ///      w-chain are re-derived cleanly).  X is unchanged; Q is rebuilt
3751    ///      but H_φ applied to a p-space rhs is basis-invariant.
3752    ///   2. FD = (T(+h) − T(−h)) / (2h).
3753    ///   3. Tolerance 1e-7 relative max-abs (h chosen to balance truncation
3754    ///      error at ~h² and evaluator roundoff at ~ε/h).
3755    #[test]
3756    pub(crate) fn firth_d_beta_hphi_tau_partial_matches_finite_difference() {
3757        let x = array![
3758            [1.0, -1.0, 0.2],
3759            [1.0, -0.6, -0.3],
3760            [1.0, -0.1, 0.5],
3761            [1.0, 0.3, -0.7],
3762            [1.0, 0.8, 0.1],
3763            [1.0, 1.2, -0.4],
3764        ];
3765        let x_tau = array![
3766            [0.0, 0.15, -0.05],
3767            [0.0, -0.10, 0.02],
3768            [0.0, 0.08, 0.04],
3769            [0.0, -0.06, -0.03],
3770            [0.0, 0.05, 0.01],
3771            [0.0, -0.12, 0.06],
3772        ];
3773        let beta = array![0.1, -0.25, 0.2];
3774        // β-direction v for the D_β[·][v] test.
3775        let v = array![0.3, 0.2, -0.15];
3776
3777        let eta = x.dot(&beta);
3778        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3779        let p = x.ncols();
3780
3781        // Reproducible small rhs block (p × m).
3782        let m = 3usize;
3783        let mut rhs = Array2::<f64>::zeros((p, m));
3784        let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
3785        for r in 0..p {
3786            for c in 0..m {
3787                rhs[[r, c]] = vals[(r * m + c) % vals.len()];
3788            }
3789        }
3790
3791        // ── Analytic apply at (x, β).
3792        let x_tau_reduced = op.reduce_explicit_design(&x_tau);
3793        let deta_partial = x_tau.dot(&beta);
3794        let (dot_i_partial, dot_h_partial) =
3795            op.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
3796        let tau_kernel = op.hphi_tau_partial_prepare_from_partials(
3797            x_tau_reduced.clone(),
3798            &deta_partial,
3799            dot_h_partial.clone(),
3800            dot_i_partial.clone(),
3801        );
3802
3803        let deta_v = x.dot(&v);
3804        let direction = op.direction_from_deta(deta_v);
3805        let x_tau_v = x_tau.dot(&v);
3806        let pair_kernel = op.d_beta_hphi_tau_partial_prepare_from_partials(
3807            &tau_kernel,
3808            &deta_partial,
3809            &dot_i_partial,
3810            &direction,
3811            &x_tau_v,
3812        );
3813        let analytic = op.d_beta_hphi_tau_partial_apply(&x_tau, &pair_kernel, &rhs);
3814
3815        // ── FD reference: central difference of single-τ primitive under
3816        //    β → β ± h v.  X stays fixed; η, w, η̇_τ are re-derived.
3817        let h = 1e-5_f64;
3818        let single_tau_apply = |beta_eval: &Array1<f64>| -> Array2<f64> {
3819            let eta_e = x.dot(beta_eval);
3820            let op_e =
3821                build_logit_firth_dense_operator(&x, &eta_e).expect("perturbed firth operator");
3822            let x_tau_r = op_e.reduce_explicit_design(&x_tau);
3823            let deta_e = x_tau.dot(beta_eval);
3824            let (dot_i_e, dot_h_e) = op_e.dot_i_and_h_from_reduced(&x_tau_r, &deta_e);
3825            let ker_e =
3826                op_e.hphi_tau_partial_prepare_from_partials(x_tau_r, &deta_e, dot_h_e, dot_i_e);
3827            op_e.hphi_tau_partial_apply(&x_tau, &ker_e, &rhs)
3828        };
3829        let beta_plus = &beta + &(h * &v);
3830        let beta_minus = &beta - &(h * &v);
3831        let fd = (&single_tau_apply(&beta_plus) - &single_tau_apply(&beta_minus)) / (2.0 * h);
3832
3833        let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
3834            let scale = a
3835                .iter()
3836                .chain(b.iter())
3837                .map(|v| v.abs())
3838                .fold(0.0_f64, f64::max)
3839                .max(1.0);
3840            let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
3841            max_diff / scale
3842        };
3843        let err = rel_max_abs_diff(&analytic, &fd);
3844
3845        let tol = 1e-7_f64;
3846        assert!(
3847            err < tol,
3848            "D_β (H_φ)_τ|_β apply mismatch: rel_max_abs_diff={err:.3e} > {tol:.1e}\n\
3849             analytic=\n{analytic:?}\n\
3850             fd=\n{fd:?}"
3851        );
3852    }
3853
3854    #[test]
3855    pub(crate) fn logisticweight_loses_positive_tail_mass() {
3856        let eta = 50.0_f64;
3857        let z = (-eta).exp();
3858        let stable = z / (1.0_f64 + z).powi(2);
3859        assert!(stable > 0.0);
3860        let got = logisticweight(eta);
3861        assert!(
3862            (got - stable).abs() < 1e-30,
3863            "Firth logisticweight should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
3864            got,
3865            stable
3866        );
3867    }
3868
3869    #[test]
3870    pub(crate) fn fisher_weight_jet5_logit_is_byte_identical_to_inverse_link_jet() {
3871        // The generalized Firth weight jet for the canonical logit link must
3872        // reproduce the historical `logit_inverse_link_jet5().d1..d5` path
3873        // exactly so the released logit Firth fits stay numerically unchanged.
3874        for &eta in &[
3875            -40.0, -8.0, -3.0, -1.0, -0.25, 0.0, 0.25, 1.0, 3.0, 8.0, 40.0,
3876        ] {
3877            let jet = logit_inverse_link_jet5(eta);
3878            let (w, w1, w2, w3, w4) =
3879                crate::mixture_link::fisher_weight_jet5(StandardLink::Logit, eta);
3880            assert!(
3881                w == jet.d1 && w1 == jet.d2 && w2 == jet.d3 && w3 == jet.d4 && w4 == jet.d5,
3882                "logit Fisher-weight jet must equal inverse-link jet derivatives at eta={eta}: \
3883                 got ({w}, {w1}, {w2}, {w3}, {w4}) vs ({}, {}, {}, {}, {})",
3884                jet.d1,
3885                jet.d2,
3886                jet.d3,
3887                jet.d4,
3888                jet.d5
3889            );
3890        }
3891    }
3892
3893    #[test]
3894    pub(crate) fn fisher_weight_jet5_probit_matches_finite_difference() {
3895        // Probit Bernoulli Fisher weight W(eta) = phi^2 / (Phi (1 - Phi)).
3896        // Validate the closed-form jet against central finite differences of
3897        // the reference scalar weight.
3898        fn reference_probit_weight(eta: f64) -> f64 {
3899            let p = gam_math::probability::normal_cdf(eta);
3900            let q = 1.0 - p;
3901            let phi = gam_math::probability::normal_pdf(eta);
3902            if p <= 0.0 || q <= 0.0 {
3903                return 0.0;
3904            }
3905            phi * phi / (p * q)
3906        }
3907        let h = 1e-4_f64;
3908        for &eta in &[-3.0, -1.5, -0.5, 0.0, 0.3, 1.5, 3.0] {
3909            let (w, w1, w2, _w3, _w4) =
3910                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
3911            let ref_w = reference_probit_weight(eta);
3912            let fd1 =
3913                (reference_probit_weight(eta + h) - reference_probit_weight(eta - h)) / (2.0 * h);
3914            let fd2 = (reference_probit_weight(eta + h) - 2.0 * reference_probit_weight(eta)
3915                + reference_probit_weight(eta - h))
3916                / (h * h);
3917            assert!(
3918                (w - ref_w).abs() < 1e-10,
3919                "probit W mismatch at eta={eta}: jet {w} vs ref {ref_w}"
3920            );
3921            assert!(
3922                (w1 - fd1).abs() < 1e-5,
3923                "probit W' mismatch at eta={eta}: jet {w1} vs fd {fd1}"
3924            );
3925            assert!(
3926                (w2 - fd2).abs() < 1e-3,
3927                "probit W'' mismatch at eta={eta}: jet {w2} vs fd {fd2}"
3928            );
3929        }
3930    }
3931
3932    #[test]
3933    pub(crate) fn fisher_weight_jet5_probit_saturates_to_zero_in_tails() {
3934        // Past the point where the denominator Phi(1-Phi) underflows to zero,
3935        // the weight and all derivatives are exactly zero (the saturated-tail
3936        // convention shared with the inverse-link jet).
3937        for &eta in &[40.0_f64, -40.0, 80.0, -80.0] {
3938            let (w, w1, w2, w3, w4) =
3939                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
3940            assert!(
3941                w == 0.0 && w1 == 0.0 && w2 == 0.0 && w3 == 0.0 && w4 == 0.0,
3942                "probit Fisher weight jet must saturate to zero at eta={eta}; got \
3943                 ({w}, {w1}, {w2}, {w3}, {w4})"
3944            );
3945        }
3946        // In the moderate tail the denominator is still representable (the
3947        // complement is taken as Phi(-eta), not the cancellation-prone
3948        // `1 - Phi(eta)`), so the weight is a tiny strictly-positive finite
3949        // number with finite derivatives. It must NOT prematurely round to zero.
3950        for &eta in &[12.0_f64, -12.0] {
3951            let (w, w1, w2, w3, w4) =
3952                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
3953            assert!(
3954                w > 0.0
3955                    && w.is_finite()
3956                    && w1.is_finite()
3957                    && w2.is_finite()
3958                    && w3.is_finite()
3959                    && w4.is_finite(),
3960                "probit Fisher weight jet must be tiny-positive and finite at eta={eta}; got \
3961                 ({w}, {w1}, {w2}, {w3}, {w4})"
3962            );
3963        }
3964    }
3965}