Skip to main content

gam_solve/reml/
firth.rs

1use super::*;
2use gam_linalg::matrix::symmetrize_in_place;
3use crate::mixture_link::fisher_weight_jet5_for_inverse_link;
4use gam_problem::InverseLink;
5
/// Row-count threshold for parallelizing Firth derivative row loops.
///
/// `RemlState::parallelize_firth_derivative_rows` returns `true` only when the
/// row count is at least this value and rayon reports more than one thread.
pub(crate) const FIRTH_DERIVATIVE_PARALLEL_MIN_N: usize = 16_384;
7
/// Reciprocal-condition-number floor below which the reduced Fisher information
/// `I_r` is flagged as near-singular (a diagnostic warning only, not a hard
/// gate). At `λ_min/λ_max < 1e-10` the SPD assumption on the identifiable
/// subspace is numerically fragile and the exact pseudodet derivatives may be
/// ill-conditioned near active-subspace boundaries.
///
/// In `FirthDenseOperator::firth_reduced_core` the ratio compared against this
/// floor is the smallest finite, strictly positive eigenvalue of `I_r` divided
/// by `max(λ_max, 1)`; non-positive eigenvalues do not enter `λ_min`.
pub(crate) const FIRTH_REDUCED_FISHER_RCOND_WARN: f64 = 1e-10;
14
/// β-dependent reduced-space pieces of the Firth/Jeffreys operator at the
/// current `η`, produced by `FirthDenseOperator::firth_reduced_core` from a
/// cached β-independent [`FirthDesignFactor`]. The full operator build consumes
/// every field; the lightweight PIRLS-diagnostics path consumes only `w`, `w1`,
/// `h_diag`, and `half_log_det`.
struct FirthReducedCore {
    // Per-row Fisher working weight: first component of the tuple returned by
    // `RemlState::fisher_weight_derivatives`.
    w: Array1<f64>,
    // Second through fifth components of that tuple, referred to in this file's
    // comments as w', w'', w''' and w''''.
    w1: Array1<f64>,
    w2: Array1<f64>,
    w3: Array1<f64>,
    w4: Array1<f64>,
    // `K_r`: (pseudo-)inverse of the reduced Fisher matrix `X_rᵀ W X_r`, as
    // returned by `RemlState::reduced_fisher_inverse_and_half_logdet`.
    k_reduced: Array2<f64>,
    // Half log-determinant returned alongside `K_r`, minus `½ Σ ln` of the
    // factor's retained design-Gram spectrum (`metric_spectrum`).
    half_log_det: f64,
    // `diag(X_r K_r X_rᵀ)`, one entry per row of the reduced design (all zeros
    // when the reduced rank is 0).
    h_diag: Array1<f64>,
}
30
/// Single-index sub-blocks of the exact mixed second directional derivative
/// `D²H_φ[u,v]`, precomputed once against the fixed `eye` rhs used by the
/// exact-Hessian TK outer loop. See
/// [`FirthDenseOperator::tk_second_direction_eye_cache`] (#1575).
///
/// The struct is `pub(crate)` but every field is private to this module, so
/// other modules can hold and pass the cache without reading its contents.
pub(crate) struct FirthSecondDirEyeCache {
    /// The fixed identity rhs (`p×p`), kept so per-pair `fast_ab(.., &eye)`
    /// matmuls reproduce the original byte-for-byte.
    eye: Array2<f64>,
    /// `X·I` — index-independent.
    eta_rhs: Array2<f64>,
    /// `(Bᵀ P B-base)·I` — index-independent.
    p_b_rhs: Array2<f64>,
    /// Per-direction `apply_hadamard_gram(eta_rhs ⊙ b_uvec_i)`.
    p_bx: Vec<Array2<f64>>,
    /// Per-direction `apply_p_u(a_u_reduced_i, w' ⊙ eta_rhs)`.
    pu_qv: Vec<Array2<f64>>,
}
48
49impl<'a> RemlState<'a> {
    /// Forwards to `super::assembly::xt_diag_x_dense_into` with the same
    /// arguments and returns its result unchanged.
    ///
    /// NOTE(review): from the name and signature this presumably forms
    /// `Xᵀ diag(diag) X`, using `weighted` as a caller-provided scratch buffer
    /// for the row-scaled design — confirm against the `assembly` module.
    pub(crate) fn xt_diag_x_dense_into(
        x: &Array2<f64>,
        diag: &Array1<f64>,
        weighted: &mut Array2<f64>,
    ) -> Array2<f64> {
        super::assembly::xt_diag_x_dense_into(x, diag, weighted)
    }
57
58    #[inline]
59    pub(crate) fn parallelize_firth_derivative_rows(n: usize) -> bool {
60        n >= FIRTH_DERIVATIVE_PARALLEL_MIN_N && rayon::current_num_threads() > 1
61    }
62
63    pub(crate) fn row_scale(x: &Array2<f64>, scale: &Array1<f64>) -> Array2<f64> {
64        let mut out = Array2::<f64>::zeros(x.raw_dim());
65        super::assembly::row_scale_dense_into(x, scale, &mut out);
66        out
67    }
68
69    #[inline]
70    pub(crate) fn dense_product_likely_uses_inner_parallelism(
71        m: usize,
72        n: usize,
73        k: usize,
74    ) -> bool {
75        // Keep this in sync with faer_ndarray::matmul_parallelism.  When a
76        // dense product is large enough for faer/BLAS-style internal
77        // parallelism, do not also wrap sibling products in rayon::join: that
78        // can oversubscribe CPU threads and slow down the REML hot path.
79        const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
80        const PAR_MIN_LONG_DIM: usize = 256;
81        let flop_scale = m.saturating_mul(n).saturating_mul(k);
82        let long_dim = m.max(n).max(k);
83        flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM
84    }
85
86    #[inline]
87    pub(crate) fn should_join_independent_dense_products(
88        products: &[(usize, usize, usize)],
89    ) -> bool {
90        const JOIN_MIN_TOTAL_FLOP_SCALE: usize = 128 * 1024;
91        if rayon::current_num_threads() <= 1 {
92            return false;
93        }
94        let mut total_flop_scale = 0usize;
95        for &(m, n, k) in products {
96            if Self::dense_product_likely_uses_inner_parallelism(m, n, k) {
97                return false;
98            }
99            total_flop_scale =
100                total_flop_scale.saturating_add(m.saturating_mul(n).saturating_mul(k));
101        }
102        total_flop_scale >= JOIN_MIN_TOTAL_FLOP_SCALE
103    }
104
105    /// Undo the fixed observation-weight row scale used by Firth reduced
106    /// designs.
107    ///
108    /// Keep this distinct from PIRLS's sparse-SpGEMM `sqrt_weights` cache in
109    /// `solver/pirls.rs`: PIRLS materializes roots of the current working
110    /// weights for a Gram factorization, while this uses stored fixed
111    /// case-weight roots and their reciprocals to map reduced design
112    /// derivatives back to raw design space.
113    #[inline]
114    pub(crate) fn scale_rows_by_inverse_observation_weight_sqrt(
115        out: &mut Array2<f64>,
116        observation_weight_sqrt: Option<&Array1<f64>>,
117    ) {
118        let Some(scale) = observation_weight_sqrt else {
119            return;
120        };
121        super::assembly::row_scale_dense_in_place_by_inverse_positive_or_zero(out, scale);
122    }
123
    /// GLM Fisher working-weight 5-jet for the requested inverse link. For
    /// standard Logit this is byte-identical to the historical
    /// `logit_inverse_link_jet5(eta).d1..d5` path that the Firth operator used
    /// before the weights were generalized to arbitrary inverse links.
    ///
    /// Thin forwarder to `fisher_weight_jet5_for_inverse_link`. Callers in this
    /// file bind the returned five-tuple, in order, as the weight and its first
    /// four `η`-derivatives `(w, w', w'', w''', w'''')`.
    ///
    /// # Errors
    /// Propagates any `EstimationError` from the underlying jet evaluation.
    #[inline]
    pub(crate) fn fisher_weight_derivatives(
        link: &InverseLink,
        eta: f64,
    ) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
        fisher_weight_jet5_for_inverse_link(link, eta)
    }
135
136    #[inline]
137    pub(crate) fn cholesky_pivots_are_numerically_resolved(chol_diag: &Array1<f64>) -> bool {
138        let mut min_pivot_sq = f64::INFINITY;
139        let mut max_pivot_sq = 0.0_f64;
140        for &pivot in chol_diag {
141            if !pivot.is_finite() || pivot <= 0.0 {
142                return false;
143            }
144            let pivot_sq = pivot * pivot;
145            min_pivot_sq = min_pivot_sq.min(pivot_sq);
146            max_pivot_sq = max_pivot_sq.max(pivot_sq);
147        }
148        if !min_pivot_sq.is_finite() {
149            return false;
150        }
151        let scale = max_pivot_sq.max(1.0);
152        let floor = (chol_diag.len().max(1) as f64) * f64::EPSILON * scale;
153        min_pivot_sq > floor
154    }
155
156    pub(crate) fn reduced_fisher_inverse_and_half_logdet(
157        fisher_reduced: &Array2<f64>,
158    ) -> Result<(Array2<f64>, f64), EstimationError> {
159        let r = fisher_reduced.nrows();
160        assert_eq!(r, fisher_reduced.ncols());
161        let mut k_reduced = Array2::<f64>::zeros((r, r));
162        if r == 0 {
163            return Ok((k_reduced, 0.0));
164        }
165
166        if let Ok(chol) = fisher_reduced.cholesky(Side::Lower) {
167            let chol_diag = chol.diag();
168            if Self::cholesky_pivots_are_numerically_resolved(&chol_diag) {
169                let half_log_det = chol_diag.iter().map(|d| d.ln()).sum::<f64>();
170                for col in 0..r {
171                    let mut e_col = Array1::<f64>::zeros(r);
172                    e_col[col] = 1.0;
173                    let solved = chol.solvevec(&e_col);
174                    k_reduced.column_mut(col).assign(&solved);
175                }
176                return Ok((k_reduced, half_log_det));
177            }
178        }
179
180        let (evals_ir, evecs_ir) = fisher_reduced
181            .eigh(Side::Lower)
182            .map_err(EstimationError::EigendecompositionFailed)?;
183        let max_eval = evals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
184        let tol = (r.max(1) as f64) * f64::EPSILON * max_eval;
185        let mut kept_positive_direction = false;
186        let mut half_log_det = 0.0_f64;
187        for (eig_idx, &eig) in evals_ir.iter().enumerate() {
188            if eig > tol {
189                kept_positive_direction = true;
190                half_log_det += 0.5 * eig.ln();
191                let inv = eig.recip();
192                let vec = evecs_ir.column(eig_idx).to_owned();
193                for row in 0..r {
194                    for col in 0..r {
195                        k_reduced[[row, col]] += inv * vec[row] * vec[col];
196                    }
197                }
198            }
199        }
200        if !kept_positive_direction {
201            return Err(EstimationError::ModelIsIllConditioned {
202                condition_number: f64::INFINITY,
203            });
204        }
205        Ok((k_reduced, half_log_det))
206    }
207
208    pub(crate) fn fill_fisher_weight_derivative_arrays(
209        link: &InverseLink,
210        eta: &Array1<f64>,
211        w: &mut Array1<f64>,
212        w1: &mut Array1<f64>,
213        w2: &mut Array1<f64>,
214        w3: &mut Array1<f64>,
215        w4: &mut Array1<f64>,
216    ) -> Result<(), EstimationError> {
217        assert_eq!(eta.len(), w.len());
218        assert_eq!(eta.len(), w1.len());
219        assert_eq!(eta.len(), w2.len());
220        assert_eq!(eta.len(), w3.len());
221        assert_eq!(eta.len(), w4.len());
222
223        if Self::parallelize_firth_derivative_rows(eta.len()) {
224            let values: Result<Vec<_>, EstimationError> = eta
225                .par_iter()
226                .map(|&ei| Self::fisher_weight_derivatives(link, ei))
227                .collect();
228            for (i, (value, first, second, third, fourth)) in values?.into_iter().enumerate() {
229                w[i] = value;
230                w1[i] = first;
231                w2[i] = second;
232                w3[i] = third;
233                w4[i] = fourth;
234            }
235            return Ok(());
236        }
237        for i in 0..eta.len() {
238            let (value, first, second, third, fourth) =
239                Self::fisher_weight_derivatives(link, eta[i])?;
240            w[i] = value;
241            w1[i] = first;
242            w2[i] = second;
243            w3[i] = third;
244            w4[i] = fourth;
245        }
246        Ok(())
247    }
248
    /// Shape-checked forwarder to `super::assembly::weighted_cross_dense`.
    ///
    /// NOTE(review): presumably returns `leftᵀ diag(weights) right` — confirm
    /// against the `assembly` implementation.
    ///
    /// # Panics
    /// Panics if `left` and `right` have different row counts, or if
    /// `weights.len()` differs from that row count.
    pub(crate) fn weighted_cross(
        left: &Array2<f64>,
        right: &Array2<f64>,
        weights: &Array1<f64>,
    ) -> Array2<f64> {
        // Both operands and the weight vector must share the row dimension.
        assert_eq!(left.nrows(), right.nrows());
        assert_eq!(left.nrows(), weights.len());
        super::assembly::weighted_cross_dense(left, right, weights)
    }
258
259    pub(crate) fn trace_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
260        assert_eq!(a.nrows(), b.ncols());
261        assert_eq!(a.ncols(), b.nrows());
262        let elems = a.nrows().saturating_mul(a.ncols());
263        if elems >= 32 * 32 {
264            let aview = FaerArrayView::new(a);
265            let bview = FaerArrayView::new(b);
266            return faer_frob_inner(aview.as_ref(), bview.as_ref().transpose());
267        }
268        let m = a.nrows();
269        let n = a.ncols();
270        kahan_sum((0..m).map(|i| {
271            let mut acc = 0.0_f64;
272            for j in 0..n {
273                acc += a[[i, j]] * b[[j, i]];
274            }
275            acc
276        }))
277    }
278
279    pub(crate) fn reducedweighted_gram(z: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
280        // Returns Zᵀ diag(weights) Z (exact).
281        //
282        // Used for:
283        //   S = Zᵀ diag(v) Z               in Hadamard-Gram apply,
284        //   G_u = X_rᵀ diag(s_u) X_r       with s_u = w' ⊙ (Xu),
285        // both of which avoid constructing dense n×n intermediates.
286        let weighted = Self::row_scale(z, weights);
287        fast_atb(z, &weighted)
288    }
289
290    pub(crate) fn reduced_crossweighted_gram(
291        z_left: &Array2<f64>,
292        z_right: &Array2<f64>,
293        weights: &Array1<f64>,
294    ) -> Array2<f64> {
295        // Returns Z_leftᵀ diag(weights) Z_right (exact).
296        //
297        // This is used for explicit design-moving terms where left/right
298        // reduced designs differ (X_r vs X_{tau,r}) in Hadamard-Gram products.
299        let weighted = Self::row_scale(z_right, weights);
300        fast_atb(z_left, &weighted)
301    }
302
303    pub(crate) fn reduced_diag_gram(z: &Array2<f64>, a: &Array2<f64>) -> Array1<f64> {
304        // Returns diag(Z A Zᵀ), exact without forming dense n×n matrix.
305        //
306        // Identity:
307        //   [diag(Z A Zᵀ)]_i = z_iᵀ A z_i,
308        // where z_i is row i of Z. We compute this as rowwise dot(Z, Z A).
309        let za = fast_ab(z, a);
310        (z * &za).sum_axis(ndarray::Axis(1))
311    }
312
313    pub(crate) fn apply_hadamard_gram(
314        z: &Array2<f64>,
315        a_left: &Array2<f64>,
316        a_right: &Array2<f64>,
317        vec: &Array1<f64>,
318    ) -> Array1<f64> {
319        // Exact apply of:
320        //   y = ((Z A_left Z^T) ⊙ (Z A_right Z^T)) vec
321        // using Gram/Hadamard identity with S = Z^T diag(vec) Z:
322        //   y_i = z_i^T A_left S A_right z_i.
323        //
324        // This is the matrix-free kernel behind the Firth terms that would
325        // otherwise require dense P = M⊙M and M⊙N_u products.
326        // Complexity is O(n r^2) with r = rank(X), no n×n storage.
327        let s = Self::reducedweighted_gram(z, vec);
328        let left_s = a_left.dot(&s);
329        let t = left_s.dot(a_right);
330        Self::reduced_diag_gram(z, &t)
331    }
332
333    pub(crate) fn apply_hadamard_gram_to_matrix(
334        z: &Array2<f64>,
335        a_left: &Array2<f64>,
336        a_right: &Array2<f64>,
337        mat: &Array2<f64>,
338    ) -> Array2<f64> {
339        // Columnwise extension of apply_hadamard_gram:
340        //   out[:,j] = ((Z A_left Zᵀ) ⊙ (Z A_right Zᵀ)) mat[:,j].
341        //
342        // In the Firth derivatives this is used with:
343        //   A_left = K_r, A_right = K_r        for Hw⊙Hw actions,
344        //   A_left = K_r, A_right = A_u        for Hw⊙Nbar_u actions,
345        // and symmetric variants in mixed second-direction terms.
346        let mut out = Array2::<f64>::zeros(mat.raw_dim());
347        for col in 0..mat.ncols() {
348            let v = mat.column(col).to_owned();
349            let y = Self::apply_hadamard_gram(z, a_left, a_right, &v);
350            out.column_mut(col).assign(&y);
351        }
352        out
353    }
354
    /// Link-aware dense Firth/Jeffreys builder. The REML callsites resolve the
    /// Fisher-weight link via `reml_robust_jeffreys_link` and pass it here;
    /// standard Logit reproduces the historical logit-pinned build byte-for-byte,
    /// and stateful links flow through the same inverse-link derivative path.
    ///
    /// Always supplies the observation weights (as `Some(..)`) to
    /// `FirthDenseOperator::build_with_observation_weights_impl`.
    ///
    /// # Errors
    /// Propagates any `EstimationError` from the underlying build.
    pub(super) fn build_firth_dense_operator_for_link(
        link: &InverseLink,
        x_dense: &Array2<f64>,
        eta: &Array1<f64>,
        observation_weights: ndarray::ArrayView1<'_, f64>,
    ) -> Result<FirthDenseOperator, EstimationError> {
        FirthDenseOperator::build_with_observation_weights_impl(
            link,
            x_dense,
            eta,
            Some(observation_weights),
        )
    }
372
    /// Associated-function form of `FirthDenseOperator::exact_tau_kernel`:
    /// forwards `x_tau`, `beta`, and `include_hphi_tau_kernel` to `op`
    /// unchanged and returns the resulting kernel.
    pub(super) fn firth_exact_tau_kernel(
        op: &FirthDenseOperator,
        x_tau: &Array2<f64>,
        beta: &Array1<f64>,
        include_hphi_tau_kernel: bool,
    ) -> FirthTauExactKernel {
        op.exact_tau_kernel(x_tau, beta, include_hphi_tau_kernel)
    }
381
    /// Associated-function form of `FirthDenseOperator::hphi_tau_partial_apply`:
    /// forwards `x_tau`, `kernel`, and `rhs` to `op` unchanged and returns the
    /// resulting matrix.
    pub(super) fn firth_hphi_tau_partial_apply(
        op: &FirthDenseOperator,
        x_tau: &Array2<f64>,
        kernel: &FirthTauPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        op.hphi_tau_partial_apply(x_tau, kernel, rhs)
    }
390}
391
392impl FirthDenseOperator {
393    pub(crate) fn canonicalize_basis_column_signs(q_basis: &mut Array2<f64>) {
394        for col in 0..q_basis.ncols() {
395            let mut pivot_row = 0usize;
396            let mut pivot_abs = 0.0_f64;
397            for row in 0..q_basis.nrows() {
398                let value = q_basis[[row, col]];
399                let abs_value = value.abs();
400                if abs_value > pivot_abs {
401                    pivot_abs = abs_value;
402                    pivot_row = row;
403                }
404            }
405            if pivot_abs > 0.0 && q_basis[[pivot_row, col]] < 0.0 {
406                q_basis.column_mut(col).mapv_inplace(|v| -v);
407            }
408        }
409    }
410
411    pub(crate) fn identifiable_subspace_basis_from_gram(
412        gram: &Array2<f64>,
413    ) -> Result<(Array2<f64>, Array1<f64>), EstimationError> {
414        let p = gram.nrows();
415        assert_eq!(p, gram.ncols());
416        if p == 0 {
417            return Ok((Array2::<f64>::eye(0), Array1::<f64>::zeros(0)));
418        }
419
420        let (evals, evecs) = gram
421            .eigh(Side::Lower)
422            .map_err(EstimationError::EigendecompositionFailed)?;
423        let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
424        let tol = (p.max(1) as f64) * f64::EPSILON * max_eval;
425        let mut keep: Vec<usize> = evals
426            .iter()
427            .enumerate()
428            .filter_map(|(i, &value)| if value > tol { Some(i) } else { None })
429            .collect();
430        if keep.is_empty() {
431            return Err(EstimationError::ModelIsIllConditioned {
432                condition_number: f64::INFINITY,
433            });
434        }
435
436        // Use one orthonormal identifiable basis for both the full-rank and
437        // rank-deficient cases. Sorting retained modes by descending design
438        // energy makes the representation deterministic up to eigenspace
439        // degeneracy; fixing the column signs removes the remaining trivial
440        // sign ambiguity.
441        keep.sort_by(|&lhs, &rhs| evals[rhs].total_cmp(&evals[lhs]));
442        let r = keep.len();
443        let mut q_basis = Array2::<f64>::zeros((p, r));
444        let mut metric_spectrum = Array1::<f64>::zeros(r);
445        for (col_idx, eig_idx) in keep.into_iter().enumerate() {
446            q_basis.column_mut(col_idx).assign(&evecs.column(eig_idx));
447            metric_spectrum[col_idx] = evals[eig_idx];
448        }
449        Self::canonicalize_basis_column_signs(&mut q_basis);
450        Ok((q_basis, metric_spectrum))
451    }
452
    /// Compute `tr(diag(d) · M) = Σ_i d[i] · M[i, i]`, accumulating the terms
    /// through `kahan_sum`.
    ///
    /// # Panics
    /// Panics if `matrix` is not square or if `diag.len()` differs from its
    /// row count.
    #[inline]
    pub(crate) fn trace_diag_product(diag: &Array1<f64>, matrix: &Array2<f64>) -> f64 {
        assert_eq!(diag.len(), matrix.nrows());
        assert_eq!(matrix.nrows(), matrix.ncols());
        // Only the diagonal of `matrix` contributes to the trace.
        kahan_sum((0..diag.len()).map(|i| diag[i] * matrix[[i, i]]))
    }
459
    /// Build the dense Firth operator for `link` at `eta` with no observation
    /// weights: forwards to `build_with_observation_weights_impl` with `None`.
    ///
    /// # Errors
    /// Propagates any `EstimationError` from the underlying build.
    pub fn build_for_link(
        link: &InverseLink,
        x_dense: &Array2<f64>,
        eta: &Array1<f64>,
    ) -> Result<FirthDenseOperator, EstimationError> {
        Self::build_with_observation_weights_impl(link, x_dense, eta, None)
    }
467
    /// Build the dense Firth operator for `link` at `eta` using the supplied
    /// observation weights: forwards to `build_with_observation_weights_impl`
    /// with `Some(observation_weights)`.
    ///
    /// # Errors
    /// Propagates any `EstimationError` from the underlying build, including
    /// its rejection of weights whose length differs from the row count or
    /// that are negative or non-finite.
    pub fn build_with_observation_weights_for_link(
        link: &InverseLink,
        x_dense: &Array2<f64>,
        eta: &Array1<f64>,
        observation_weights: ndarray::ArrayView1<'_, f64>,
    ) -> Result<FirthDenseOperator, EstimationError> {
        Self::build_with_observation_weights_impl(link, x_dense, eta, Some(observation_weights))
    }
476
477    /// Build the β-independent (design-only) factor of the Firth/Jeffreys
478    /// operator: the identifiable-subspace basis `Q`, the reduced design
479    /// `X_r = A^{1/2} X Q`, the retained design-Gram spectrum `S_r`, and the raw
480    /// design/transpose. This is the O(n·p²) Gram + O(p³) eigendecomposition +
481    /// the n×p design clones. None of it depends on `η`/β, so it is computed
482    /// ONCE per inner PIRLS solve and reused across Newton iterations (#1575).
483    ///
484    /// `build_with_observation_weights_impl` is exactly this factor followed by
485    /// the per-η remainder, so existing callers stay bit-for-bit identical.
486    pub(crate) fn build_design_factor_with_observation_weights(
487        x_dense: &Array2<f64>,
488        observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
489    ) -> Result<FirthDesignFactor, EstimationError> {
490        let n = x_dense.nrows();
491        let observation_weight_sqrt = if let Some(weights) = observation_weights {
492            if weights.len() != n {
493                crate::bail_invalid_estim!(
494                    "Firth operator observation weight length {} != number of rows {}",
495                    weights.len(),
496                    n
497                );
498            }
499            let mut sqrt = Array1::<f64>::zeros(n);
500            for i in 0..n {
501                let weight = weights[i];
502                if !weight.is_finite() || weight < 0.0 {
503                    crate::bail_invalid_estim!(
504                        "Firth operator requires finite nonnegative observation weights, got {} at row {}",
505                        weight,
506                        i
507                    );
508                }
509                sqrt[i] = weight.sqrt();
510            }
511            Some(sqrt)
512        } else {
513            None
514        };
515        let basis_design = if let Some(scale) = observation_weight_sqrt.as_ref() {
516            RemlState::row_scale(x_dense, scale)
517        } else {
518            x_dense.clone()
519        };
520        // X̃ᵀX̃ Gram → identifiable-subspace basis Q and retained spectrum S_r.
521        let gram = fast_atb(&basis_design, &basis_design);
522        let (q_basis, metric_spectrum) = Self::identifiable_subspace_basis_from_gram(&gram)?;
523        let x_reduced = fast_ab(&basis_design, &q_basis);
524        let r = q_basis.ncols();
525        let mut x_metric_reduced_inv_diag = Array1::<f64>::zeros(r);
526        for col in 0..r {
527            x_metric_reduced_inv_diag[col] = metric_spectrum[col].recip();
528        }
529        let x_dense_t = x_dense.t().to_owned();
530        Ok(FirthDesignFactor {
531            x_dense: x_dense.clone(),
532            x_dense_t,
533            q_basis,
534            x_reduced,
535            observation_weight_sqrt,
536            metric_spectrum,
537            x_metric_reduced_inv_diag,
538            r,
539            n,
540        })
541    }
542
543    /// β-dependent reduced core shared by [`Self::build_from_design_factor`] and
544    /// [`Self::pirls_diagnostics_from_factor`]: from the cached design factor and
545    /// the current `η`, compute the Fisher-weight 5-jet, the reduced Fisher
546    /// inverse `K_r`, the identifiable-subspace half-log-determinant, and the hat
547    /// diagonal `h`. The operations and their order match the un-hoisted
548    /// `build_with_observation_weights_impl` exactly, so every consumer stays
549    /// bit-for-bit identical.
550    fn firth_reduced_core(
551        factor: &FirthDesignFactor,
552        link: &InverseLink,
553        eta: &Array1<f64>,
554    ) -> Result<FirthReducedCore, EstimationError> {
555        let n = factor.n;
556        if eta.len() != n {
557            crate::bail_invalid_estim!(
558                "Firth operator shape mismatch: nrows={}, eta_len={}",
559                n,
560                eta.len()
561            );
562        }
563        let r = factor.r;
564        let mut w = Array1::<f64>::zeros(n);
565        let mut w1 = Array1::<f64>::zeros(n);
566        let mut w2 = Array1::<f64>::zeros(n);
567        let mut w3 = Array1::<f64>::zeros(n);
568        let mut w4 = Array1::<f64>::zeros(n);
569        RemlState::fill_fisher_weight_derivative_arrays(
570            link, eta, &mut w, &mut w1, &mut w2, &mut w3, &mut w4,
571        )?;
572
573        // Reduced Fisher I_r = X_rᵀ W X_r on the identifiable subspace.
574        let fisher_reduced = gam_linalg::faer_ndarray::fast_xt_diag_x(&factor.x_reduced, &w);
575        if let Ok((eigvals_ir, _)) = fisher_reduced.eigh(Side::Lower) {
576            let max_ev = eigvals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
577            let min_ev = eigvals_ir
578                .iter()
579                .copied()
580                .filter(|v| v.is_finite() && *v > 0.0)
581                .fold(f64::INFINITY, f64::min);
582            if min_ev.is_finite() {
583                let rel = min_ev / max_ev;
584                if rel < FIRTH_REDUCED_FISHER_RCOND_WARN {
585                    log::warn!(
586                        "[REML/Firth] reduced Fisher I_r is near-singular (min/max={:.3e}/{:.3e}, rel={:.3e}); exact derivatives may be ill-conditioned near active-subspace boundaries.",
587                        min_ev,
588                        max_ev,
589                        rel
590                    );
591                }
592            }
593        }
594
595        let (k_reduced, mut half_log_det) = if r > 0 {
596            RemlState::reduced_fisher_inverse_and_half_logdet(&fisher_reduced)?
597        } else {
598            (Array2::<f64>::zeros((r, r)), 0.0)
599        };
600        if r > 0 {
601            for col in 0..r {
602                let metric_eig = factor.metric_spectrum[col];
603                half_log_det -= 0.5 * metric_eig.ln();
604            }
605        }
606        let h_diag = if r > 0 {
607            RemlState::reduced_diag_gram(&factor.x_reduced, &k_reduced)
608        } else {
609            Array1::<f64>::zeros(n)
610        };
611        Ok(FirthReducedCore {
612            w,
613            w1,
614            w2,
615            w3,
616            w4,
617            k_reduced,
618            half_log_det,
619            h_diag,
620        })
621    }
622
623    /// Rebuild the full Firth operator at a new `η` from a cached design factor.
624    /// Pure memoization: byte-identical to `build_with_observation_weights_impl`
625    /// for the same `(link, design, weights, η)`.
626    pub(crate) fn build_from_design_factor(
627        factor: &FirthDesignFactor,
628        link: &InverseLink,
629        eta: &Array1<f64>,
630    ) -> Result<FirthDenseOperator, EstimationError> {
631        let FirthReducedCore {
632            w,
633            w1,
634            w2,
635            w3,
636            w4,
637            k_reduced,
638            half_log_det,
639            h_diag,
640        } = Self::firth_reduced_core(factor, link, eta)?;
641        let b_base = RemlState::row_scale(&factor.x_dense, &w1);
642        let p_b_base = RemlState::apply_hadamard_gram_to_matrix(
643            &factor.x_reduced,
644            &k_reduced,
645            &k_reduced,
646            &b_base,
647        );
648        Ok(FirthDenseOperator {
649            x_dense: factor.x_dense.clone(),
650            x_dense_t: factor.x_dense_t.clone(),
651            q_basis: factor.q_basis.clone(),
652            x_reduced: factor.x_reduced.clone(),
653            observation_weight_sqrt: factor.observation_weight_sqrt.clone(),
654            k_reduced,
655            x_metric_reduced_inv_diag: factor.x_metric_reduced_inv_diag.clone(),
656            half_log_det,
657            h_diag,
658            w,
659            w1,
660            w2,
661            w3,
662            w4,
663            b_base,
664            p_b_base,
665        })
666    }
667
668    /// Compute ONLY the three PIRLS Firth diagnostics — `(hat_diag,
669    /// jeffreys_logdet, firth_score_shift)` — from a cached design factor at a
670    /// new `η`, skipping the per-iteration `B = diag(w') X` and `P·B` Hadamard
671    /// blocks that the inner PIRLS solve never consumes. Each output is the same
672    /// closed form the full operator's accessors return:
673    ///   hat_diag         = w ⊙ h_diag             (`pirls_hat_diag`),
674    ///   jeffreys_logdet  = half_log_det           (`jeffreys_logdet`),
675    ///   firth_score_shift= ½ (w'/w) ⊙ h_diag      (`pirls_firth_score_shift`),
676    /// so the result is bit-for-bit identical to building the full operator and
677    /// calling those accessors, at a fraction of the cost (#1575).
678    pub(crate) fn pirls_diagnostics_from_factor(
679        factor: &FirthDesignFactor,
680        link: &InverseLink,
681        eta: &Array1<f64>,
682    ) -> Result<(Array1<f64>, f64, Array1<f64>), EstimationError> {
683        let core = Self::firth_reduced_core(factor, link, eta)?;
684        let (w, w1, h_diag, half_log_det) =
685            (core.w, core.w1, core.h_diag, core.half_log_det);
686        // hat_diag = w ⊙ h_diag (matches `pirls_hat_diag`).
687        let hat_diag = &w * &h_diag;
688        // firth_score_shift_i = ½ (w'_i / w_i) h_diag_i for w_i > 0, else 0
689        // (matches `pirls_firth_score_shift`).
690        let mut score_shift = Array1::<f64>::zeros(w.len());
691        for i in 0..w.len() {
692            let wi = w[i];
693            if wi > 0.0 {
694                score_shift[i] = 0.5 * (w1[i] / wi) * h_diag[i];
695            }
696        }
697        Ok((hat_diag, half_log_det, score_shift))
698    }
699
    /// Build the full dense Firth/Jeffreys operator for `link` at `eta`, with
    /// optional fixed observation weights.
    ///
    /// Validates that `eta` has one entry per design row, builds the
    /// β-independent design factor, and then completes it at `eta` through
    /// `build_from_design_factor`.
    ///
    /// # Errors
    /// Returns an invalid-input error when `eta.len()` differs from
    /// `x_dense.nrows()`, and propagates errors from the factor build and the
    /// per-η build.
    pub(crate) fn build_with_observation_weights_impl(
        link: &InverseLink,
        x_dense: &Array2<f64>,
        eta: &Array1<f64>,
        observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
    ) -> Result<FirthDenseOperator, EstimationError> {
        // Precompute dense Firth objects at current β̂ for:
        //   Φ(β) = 0.5 log|Uᵀ W U|,
        // where U is a canonical orthonormal basis of the identifiable
        // subspace of A^{1/2} X.
        //
        // Identifiability note:
        // For rank-deficient design matrices X, I = Xᵀ W X is singular for all β
        // (assuming w_i > 0). The mathematically coherent Jeffreys/Firth term is
        // therefore the identifiable-subspace form:
        //   Φ(β) = 0.5 log|Uᵀ W U|
        //        = 0.5 log|I_r(β)| - 0.5 log|S_r|,
        // with
        //   I_r = X_rᵀ W X_r,
        //   S_r = X_rᵀ X_r,
        //   U   = X_r S_r^{-1/2}.
        // The beta differential is still
        //   dΦ = 0.5 tr(I_+^† dI),
        // because S_r is beta-independent for a fixed design.
        //
        // For binomial-logit with finite eta, 0 < w_i <= 1/4, so W is SPD and
        // Null(X'WX)=Null(X). Therefore singular directions are structural
        // (from X) and independent of beta, which is why fixed-Q reduced-space
        // derivatives are exact in this regime.
        //
        // This implementation uses the fixed identifiable-basis route directly:
        //   I_r = X_rᵀ W X_r,  S_r = X_rᵀ X_r,
        //   I_+^† = Q I_r^{-1} Qᵀ,
        //   Φ     = 0.5 (log|I_r| - log|S_r|).
        //
        // Why `eta` (not `mu`) enters here:
        //   all logistic weight derivatives are functions of eta through
        //   mu(eta)=sigmoid(eta), and exact derivative consistency is preserved by
        //   generating (w, w', w'', w''', w'''') from one coherent eta source.
        // Using eta avoids any mismatch from externally clamped/post-processed mu.
        //
        // We cache reduced operators so derivatives are exact but matrix-free in n:
        //   K_r = I_r^{-1},  h = diag(X_r K_r X_rᵀ),  B = diag(w')X,
        // along with logistic derivatives w', w'', w''', w'''' to evaluate:
        //   H_φ      = ∇²_β Φ,
        //   D H_φ[u],
        //   D² H_φ[u,v]
        // exactly via reduced-space products (no explicit high-order tensors).
        //
        // Fixed observation weights:
        // When callers provide nonnegative case weights a_i that are constant in
        // β, the Jeffreys information is
        //   I(β) = Xᵀ diag(a_i w_i(η)) X.
        // We fold those fixed a_i into the identifiable basis and reduced design
        // via X̃ = diag(sqrt(a_i)) X, so all derivative formulas continue to use
        // the same η-derivatives of the family Fisher weights w(η), w'(η), ....
        //
        // This routine is now a thin wrapper: it builds the β-independent design
        // factor (Gram, identifiable basis Q, reduced design X_r, retained
        // spectrum S_r — the O(n·p²) + O(p³) work) and then the β-dependent
        // remainder at `eta`. The two helpers are split out so a single inner
        // PIRLS solve can hoist the factor out of the per-Newton-iteration hot
        // path (#1575) while every output here stays bit-for-bit identical.
        //
        // The eta-length check is kept here (before the factor build) to
        // preserve the original error ordering for existing callers.
        let n = x_dense.nrows();
        if eta.len() != n {
            crate::bail_invalid_estim!(
                "Firth operator shape mismatch: nrows={}, eta_len={}",
                n,
                eta.len()
            );
        }
        let factor =
            Self::build_design_factor_with_observation_weights(x_dense, observation_weights)?;
        Self::build_from_design_factor(&factor, link, eta)
    }
778
779    #[inline]
780    pub(crate) fn jeffreys_logdet(&self) -> f64 {
781        self.half_log_det
782    }
783
784    /// Tangent-projected Jeffreys/Firth log-determinant `½ log|Zᵀ J Z|_+`,
785    /// where `J = Xᵀ A W(η) X` is the full p-space Fisher information at
786    /// the current `η` and `Z` is the `p × m` orthonormal basis of
787    /// `null(A_act)` produced by the active-constraint tangent projector.
788    ///
789    /// Identity: `Zᵀ J Z = (X̃ Z)ᵀ W (X̃ Z)` with `X̃ = A^{1/2} X` when
790    /// fixed observation weights are present, `X̃ = X` otherwise. The
791    /// projected log-pseudo-det uses the same positive-eigenvalue
792    /// threshold convention as the rest of the tangent-projected LAML
793    /// (`positive_eigenvalue_threshold` / `exact_pseudo_logdet`) so the
794    /// kernel that defines "active subspace" is consistent across the
795    /// objective, its gradient, and the Firth contribution.
796    pub(crate) fn jeffreys_logdet_projected(&self, z: ndarray::ArrayView2<'_, f64>) -> f64 {
797        use gam_linalg::faer_ndarray::{fast_ab, fast_xt_diag_x};
798        let p = self.x_dense.ncols();
799        assert_eq!(
800            z.nrows(),
801            p,
802            "jeffreys_logdet_projected: Z must have {} rows (β-space dim), got {}",
803            p,
804            z.nrows()
805        );
806        let m = z.ncols();
807        if m == 0 {
808            return 0.0;
809        }
810        // X·Z, then optional sqrt(A) row-scale → X̃·Z.
811        let z_owned = z.to_owned();
812        let xz = fast_ab(&self.x_dense, &z_owned);
813        let xtz = if let Some(scale) = self.observation_weight_sqrt.as_ref() {
814            RemlState::row_scale(&xz, scale)
815        } else {
816            xz
817        };
818        // J_T = (X̃ Z)ᵀ W (X̃ Z), symmetric m × m PSD.
819        let mut j_t = fast_xt_diag_x(&xtz, &self.w);
820        symmetrize_in_place(&mut j_t);
821        let (evals, _) = match j_t.eigh(Side::Lower) {
822            Ok(pair) => pair,
823            Err(_) => return f64::NEG_INFINITY,
824        };
825        let Some(evals_slice) = evals.as_slice() else {
826            return f64::NEG_INFINITY;
827        };
828        let threshold = super::reml_outer_engine::positive_eigenvalue_threshold(evals_slice);
829        0.5 * super::reml_outer_engine::exact_pseudo_logdet(evals_slice, threshold)
830    }
831
832    #[inline]
833    pub(crate) fn jeffreys_beta_gradient(&self) -> Array1<f64> {
834        // For I(β) = Xᵀ A W(η) X with fixed observation weights A,
835        //   ∂/∂β_j [0.5 log|I|]
836        //   = 0.5 Σ_i h_i w_i'(η_i) x_{ij},
837        // where h_i = [A^{1/2} X I^{-1} Xᵀ A^{1/2}]_{ii}.
838        0.5 * gam_linalg::faer_ndarray::fast_av(&self.x_dense_t, &(&self.w1 * &self.h_diag))
839    }
840
841    #[inline]
842    pub fn jeffreys_logdet_and_beta_gradient(&self) -> (f64, Array1<f64>) {
843        (self.jeffreys_logdet(), self.jeffreys_beta_gradient())
844    }
845
846    #[inline]
847    pub(crate) fn reduce_explicit_design(&self, x: &Array2<f64>) -> Array2<f64> {
848        let mut reduced = fast_ab(x, &self.q_basis);
849        if let Some(scale) = self.observation_weight_sqrt.as_ref() {
850            reduced = RemlState::row_scale(&reduced, scale);
851        }
852        reduced
853    }
854
855    pub(crate) fn direction_from_deta(&self, deta: Array1<f64>) -> FirthDirection {
856        // Directional building blocks for u:
857        //   δη_u = X u
858        //   I_u  = Xᵀ diag(w' ⊙ δη_u) X
859        //   T_u  = K I_u K
860        //   N_u  = X T_u Xᵀ
861        //   Dh[u] = -diag(N_u)
862        //   P_u   = D(M⊙M)[u] = -2(M⊙N_u)
863        // and B_u = diag(w'' ⊙ δη_u) X.
864        //
865        // In this implementation, active-subspace ambiguity is removed by fixed Q:
866        // K is represented by K_r = I_r^{-1} in reduced coordinates, so
867        //   A_u = K_r G_u K_r
868        // is exact for logit with finite eta (w_i > 0) and fixed rank(X).
869        // s_u is the diagonal weight for D I[u]:
870        //   D I[u] = Xᵀ diag(s_u) X,  s_u = w' ⊙ (X u).
871        let s_u = &self.w1 * &deta;
872        // G_u = X_rᵀ diag(s_u) X_r,  A_u = K_r G_u K_r.
873        // These are reduced-space forms of
874        //   I_u and T_u = K I_u K
875        // from the full-space derivation.
876        let g_u_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_u);
877        let k_g_u = self.k_reduced.dot(&g_u_reduced);
878        let a_u_reduced = k_g_u.dot(&self.k_reduced);
879        // Dh[u] = -diag(N_u),  N_u = X T_u Xᵀ, represented here as
880        //   N_u = Z A_u Zᵀ in weighted reduced coordinates.
881        let dh = -RemlState::reduced_diag_gram(&self.x_reduced, &a_u_reduced);
882        let b_uvec = &self.w2 * &deta;
883        FirthDirection {
884            deta,
885            g_u_reduced,
886            a_u_reduced,
887            dh,
888            b_uvec,
889        }
890    }
891
892    #[inline]
893    pub(crate) fn left_scaled_xt(&self, scale: &Array1<f64>, mat: &Array2<f64>) -> Array2<f64> {
894        fast_ab(&self.x_dense_t, &(mat * &scale.view().insert_axis(Axis(1))))
895    }
896
897    #[inline]
898    pub(crate) fn apply_p_u_to_matrix(
899        &self,
900        a_u_reduced: &Array2<f64>,
901        mat: &Array2<f64>,
902    ) -> Array2<f64> {
903        let mut out = RemlState::apply_hadamard_gram_to_matrix(
904            &self.x_reduced,
905            &self.k_reduced,
906            a_u_reduced,
907            mat,
908        );
909        out.mapv_inplace(|v| -2.0 * v);
910        out
911    }
912
913    pub(crate) fn hphi_direction_apply(
914        &self,
915        dir: &FirthDirection,
916        rhs: &Array2<f64>,
917    ) -> Array2<f64> {
918        let p = self.x_dense.ncols();
919        if rhs.nrows() != p {
920            return Array2::<f64>::zeros((p, rhs.ncols()));
921        }
922        if rhs.ncols() == 0 || p == 0 {
923            return Array2::<f64>::zeros((p, rhs.ncols()));
924        }
925        // Matrix-free apply of D(Hphi)[u] to a block V:
926        //   D(Hphi)[u] V
927        // = 0.5[ Xᵀ(c_u ⊙ (X V))
928        //       - B_uᵀ P (B V) - Bᵀ P (B_u V) - Bᵀ P_u (B V) ].
929        // This avoids dense p×p materialization and is used by sparse exact
930        // trace contractions through tr(H^{-1} ·).
931        let etav = fast_ab(&self.x_dense, rhs);
932        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
933        let m_qv = RemlState::apply_hadamard_gram_to_matrix(
934            &self.x_reduced,
935            &self.k_reduced,
936            &self.k_reduced,
937            &qv,
938        );
939        let buvec = &dir.b_uvec;
940        let m_buv = RemlState::apply_hadamard_gram_to_matrix(
941            &self.x_reduced,
942            &self.k_reduced,
943            &self.k_reduced,
944            &(&etav * &buvec.view().insert_axis(Axis(1))),
945        );
946        let p_u_qv = self.apply_p_u_to_matrix(&dir.a_u_reduced, &qv);
947        let c_u = &(&self.w3 * &dir.deta) * &self.h_diag + &(&self.w2 * &dir.dh);
948        let diag_term = self
949            .x_dense_t
950            .dot(&(&etav * &c_u.view().insert_axis(Axis(1))));
951        let term1 = self.left_scaled_xt(buvec, &m_qv);
952        let term2 = self.left_scaled_xt(&self.w1, &m_buv);
953        let term3 = self.left_scaled_xt(&self.w1, &p_u_qv);
954        0.5 * (diag_term - (term1 + term2 + term3))
955    }
956
957    pub(crate) fn hphi_direction(&self, dir: &FirthDirection) -> Array2<f64> {
958        let p = self.x_dense.ncols();
959        let eye = Array2::<f64>::eye(p);
960        let mut out = self.hphi_direction_apply(dir, &eye);
961        // Exact first directional derivative of H_φ:
962        //   D H_φ[u]
963        //   = 0.5 [ Xᵀ diag(c_u) X
964        //           - (B_uᵀ P B + Bᵀ P B_u + Bᵀ P_u B) ],
965        // where
966        //   c_u = w''' ⊙ δη_u ⊙ h + w'' ⊙ Dh[u],
967        //   B   = diag(w') X.
968        //
969        // Matrix-free contraction map used below:
970        //   Bᵀ P B      via apply_hadamard_gram_to_matrix(Z, K_r, K_r, B)
971        //   Bᵀ P_u B    via apply_hadamard_gram_to_matrix(Z, K_r, A_u, B), then *(-2)
972        // where P = M⊙M and P_u = -2(M⊙N_u), but M/N_u are never formed explicitly.
973        symmetrize_in_place(&mut out);
974        out
975    }
976
977    pub(crate) fn hphisecond_direction_apply(
978        &self,
979        u: &FirthDirection,
980        v: &FirthDirection,
981        rhs: &Array2<f64>,
982    ) -> Array2<f64> {
983        let p = self.x_dense.ncols();
984        if rhs.nrows() != p {
985            return Array2::<f64>::zeros((p, rhs.ncols()));
986        }
987        if rhs.ncols() == 0 || p == 0 {
988            return Array2::<f64>::zeros((p, rhs.ncols()));
989        }
990        // Exact mixed second directional derivative:
991        //   D² H_φ[u,v] = 0.5 [ Xᵀ diag(c_uv) X - D²J₂[u,v] ], J₂ = Bᵀ P B.
992        // Implemented with matrix identities for N_{u,v}, P_{u,v}, and the
993        // nine-term expansion of D²J₂[u,v].
994        //
995        // Because we parameterize Phi through fixed-rank I_r = X_rᵀ W X_r (SPD for
996        // finite-logit eta), this mixed derivative is evaluated on a smooth
997        // manifold without dynamic active-set switching in the Firth block.
998        //
999        // The nine contraction terms below are the explicit D²J₂[u,v] expansion,
1000        // each computed through reduced Hadamard-Gram operators.
1001        let deta_uv = &u.deta * &v.deta;
1002        // Mixed reduced Gram:
1003        //   G_uv = X_rᵀ diag(w'' ⊙ (Xu) ⊙ (Xv)) X_r.
1004        let s_uv = &self.w2 * &deta_uv;
1005        let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
1006        let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
1007        let k_gv = self.k_reduced.dot(&v.g_u_reduced);
1008        let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
1009        // Reduced form of:
1010        //   T_{u,v} = K I_{u,v} K - K Iv K I_u K - K I_u K Iv K.
1011        let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
1012            - k_gv.dot(&k_g_u).dot(&self.k_reduced)
1013            - k_g_u.dot(&k_gv).dot(&self.k_reduced);
1014        let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
1015        // Implements mixed diagonal coefficient:
1016        //   c_uv = w'''' ⊙ (Xu) ⊙ (Xv) ⊙ h
1017        //          + w''' ⊙ ((Xu) ⊙ Dh[v] + (Xv) ⊙ Dh[u])
1018        //          + w'' ⊙ D²h[u,v].
1019        let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
1020            + &(&self.w3 * &(&u.deta * &v.dh))
1021            + &(&self.w3 * &(&v.deta * &u.dh))
1022            + &(&self.w2 * &d2h);
1023
1024        let eta_rhs = fast_ab(&self.x_dense, rhs);
1025        let diag_term = fast_ab(
1026            &self.x_dense_t,
1027            &(&eta_rhs * &c_uv.view().insert_axis(Axis(1))),
1028        );
1029
1030        let b_uvvec = &self.w3 * &deta_uv;
1031        let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
1032        let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));
1033
1034        // Linearity in the rhs argument lets us precompute the expensive
1035        // Hadamard-Gram operator on the full base blocks B, B_u, Bv, B_uv once,
1036        // then post-multiply by rhs. This preserves the exact operator while
1037        // avoiding repeated O(n r^2 c) work for every rhs block.
1038        let p_b_rhs = fast_ab(&self.p_b_base, rhs);
1039        let p_bu_rhs = RemlState::apply_hadamard_gram_to_matrix(
1040            &self.x_reduced,
1041            &self.k_reduced,
1042            &self.k_reduced,
1043            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1044        );
1045        let p_bv_rhs = RemlState::apply_hadamard_gram_to_matrix(
1046            &self.x_reduced,
1047            &self.k_reduced,
1048            &self.k_reduced,
1049            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1050        );
1051        let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
1052            &self.x_reduced,
1053            &self.k_reduced,
1054            &self.k_reduced,
1055            &b_uv_base,
1056        );
1057        let p_buv_rhs = fast_ab(&p_buv_base, rhs);
1058
1059        let pv_b_rhs = self.apply_p_u_to_matrix(&v.a_u_reduced, &qv);
1060        let pv_bu_rhs = self.apply_p_u_to_matrix(
1061            &v.a_u_reduced,
1062            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1063        );
1064        let p_u_b_rhs = self.apply_p_u_to_matrix(&u.a_u_reduced, &qv);
1065        let p_u_bv_rhs = self.apply_p_u_to_matrix(
1066            &u.a_u_reduced,
1067            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1068        );
1069
1070        let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
1071            &self.x_reduced,
1072            &u.a_u_reduced,
1073            &v.a_u_reduced,
1074            &self.b_base,
1075        );
1076        let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
1077            &self.x_reduced,
1078            &self.k_reduced,
1079            &a_uv_reduced,
1080            &self.b_base,
1081        );
1082        let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
1083        let p_uv_rhs = fast_ab(&p_uv_base, rhs);
1084
1085        // Nine-term expansion of D²J₂[u,v] with J₂ = Bᵀ P B.
1086        let d2_terms = [
1087            self.left_scaled_xt(&b_uvvec, &p_b_rhs),
1088            self.left_scaled_xt(&self.w1, &p_buv_rhs),
1089            self.left_scaled_xt(&u.b_uvec, &p_bv_rhs),
1090            self.left_scaled_xt(&v.b_uvec, &p_bu_rhs),
1091            self.left_scaled_xt(&u.b_uvec, &pv_b_rhs),
1092            self.left_scaled_xt(&self.w1, &pv_bu_rhs),
1093            self.left_scaled_xt(&v.b_uvec, &p_u_b_rhs),
1094            self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
1095            self.left_scaled_xt(&self.w1, &p_uv_rhs),
1096        ];
1097        let mut d2_j2 = Array2::<f64>::zeros((p, rhs.ncols()));
1098        for term in d2_terms {
1099            d2_j2 += &term;
1100        }
1101
1102        0.5 * (diag_term - d2_j2)
1103    }
1104
1105    /// Precompute, for a FIXED identity rhs, every sub-block of the mixed second
1106    /// directional derivative `D²H_φ[u,v]` that depends on a SINGLE direction
1107    /// index (or on nothing but the operator). The exact-Hessian TK outer loop
1108    /// (`tk_hessian_rho_canonical_logit`) evaluates `hphisecond_direction_apply`
1109    /// for every one of the `k(k+1)/2` penalty pairs against the same `eye` rhs;
1110    /// the four heavy single-index reduced Hadamard-Gram applies inside it
1111    /// (`p_bu_rhs`/`p_bv_rhs` and `p_u_b_rhs`/`pv_b_rhs`) therefore have only `k`
1112    /// distinct values but were rebuilt `O(k²)` times. Caching them once per
1113    /// index here turns that into `O(k)` of those O(n·r²·p) applies, with the
1114    /// per-pair work limited to the genuinely mixed (`u`,`v`) blocks. This is
1115    /// exact: each cached block is a pure function of `(operator, direction[i])`
1116    /// for the fixed `eye` rhs, so the contraction it feeds is bit-identical to
1117    /// `hphisecond_direction_apply(.., &eye)` (#1575).
1118    pub(crate) fn tk_second_direction_eye_cache(
1119        &self,
1120        dirs: &[FirthDirection],
1121    ) -> FirthSecondDirEyeCache {
1122        let p = self.x_dense.ncols();
1123        let eye = Array2::<f64>::eye(p);
1124        // eta_rhs = X·I and qv = w' ⊙ eta_rhs are rhs-only (index-independent).
1125        let eta_rhs = fast_ab(&self.x_dense, &eye);
1126        let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));
1127        // p_b_rhs = (Bᵀ P B-base)·I is rhs-only; precompute it once.
1128        let p_b_rhs = fast_ab(&self.p_b_base, &eye);
1129        let mut p_bx = Vec::with_capacity(dirs.len());
1130        let mut pu_qv = Vec::with_capacity(dirs.len());
1131        for d in dirs {
1132            // p_b{u,v}_rhs: depends only on this direction's b_uvec (and eta_rhs).
1133            p_bx.push(RemlState::apply_hadamard_gram_to_matrix(
1134                &self.x_reduced,
1135                &self.k_reduced,
1136                &self.k_reduced,
1137                &(&eta_rhs * &d.b_uvec.view().insert_axis(Axis(1))),
1138            ));
1139            // p_u_b_rhs / pv_b_rhs: depends only on this direction's a_u_reduced.
1140            pu_qv.push(self.apply_p_u_to_matrix(&d.a_u_reduced, &qv));
1141        }
1142        FirthSecondDirEyeCache {
1143            eye,
1144            eta_rhs,
1145            p_b_rhs,
1146            p_bx,
1147            pu_qv,
1148        }
1149    }
1150
1151    /// Exact mixed second directional derivative `D²H_φ[u,v]` against the fixed
1152    /// `eye` rhs, reusing the single-index sub-blocks precomputed once by
1153    /// [`Self::tk_second_direction_eye_cache`]. Bit-identical to
1154    /// `hphisecond_direction_apply(&dirs[i], &dirs[j], &Array2::eye(p))`; only
1155    /// the redundant per-pair recomputation of the single-index blocks is
1156    /// removed (#1575).
1157    pub(crate) fn hphisecond_direction_apply_eye_cached(
1158        &self,
1159        cache: &FirthSecondDirEyeCache,
1160        dirs: &[FirthDirection],
1161        i: usize,
1162        j: usize,
1163    ) -> Array2<f64> {
1164        let u = &dirs[i];
1165        let v = &dirs[j];
1166        let p = self.x_dense.ncols();
1167        let cols = cache.eta_rhs.ncols();
1168        if p == 0 || cols == 0 {
1169            return Array2::<f64>::zeros((p, cols));
1170        }
1171        let deta_uv = &u.deta * &v.deta;
1172        let s_uv = &self.w2 * &deta_uv;
1173        let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
1174        let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
1175        let k_gv = self.k_reduced.dot(&v.g_u_reduced);
1176        let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
1177        let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
1178            - k_gv.dot(&k_g_u).dot(&self.k_reduced)
1179            - k_g_u.dot(&k_gv).dot(&self.k_reduced);
1180        let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
1181        let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
1182            + &(&self.w3 * &(&u.deta * &v.dh))
1183            + &(&self.w3 * &(&v.deta * &u.dh))
1184            + &(&self.w2 * &d2h);
1185
1186        let eta_rhs = &cache.eta_rhs;
1187        let diag_term = fast_ab(
1188            &self.x_dense_t,
1189            &(eta_rhs * &c_uv.view().insert_axis(Axis(1))),
1190        );
1191
1192        let b_uvvec = &self.w3 * &deta_uv;
1193        let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
1194
1195        // Single-index blocks reused from the cache (the O(k²)→O(k) win).
1196        let p_b_rhs = &cache.p_b_rhs;
1197        let p_bu_rhs = &cache.p_bx[i];
1198        let p_bv_rhs = &cache.p_bx[j];
1199        let p_u_b_rhs = &cache.pu_qv[i];
1200        let pv_b_rhs = &cache.pu_qv[j];
1201
1202        // Genuinely mixed (u,v) blocks — must be rebuilt per pair.
1203        let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
1204            &self.x_reduced,
1205            &self.k_reduced,
1206            &self.k_reduced,
1207            &b_uv_base,
1208        );
1209        let p_buv_rhs = fast_ab(&p_buv_base, &cache.eye);
1210
1211        let pv_bu_rhs = self.apply_p_u_to_matrix(
1212            &v.a_u_reduced,
1213            &(eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1214        );
1215        let p_u_bv_rhs = self.apply_p_u_to_matrix(
1216            &u.a_u_reduced,
1217            &(eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1218        );
1219
1220        let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
1221            &self.x_reduced,
1222            &u.a_u_reduced,
1223            &v.a_u_reduced,
1224            &self.b_base,
1225        );
1226        let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
1227            &self.x_reduced,
1228            &self.k_reduced,
1229            &a_uv_reduced,
1230            &self.b_base,
1231        );
1232        let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
1233        let p_uv_rhs = fast_ab(&p_uv_base, &cache.eye);
1234
1235        let d2_terms = [
1236            self.left_scaled_xt(&b_uvvec, p_b_rhs),
1237            self.left_scaled_xt(&self.w1, &p_buv_rhs),
1238            self.left_scaled_xt(&u.b_uvec, p_bv_rhs),
1239            self.left_scaled_xt(&v.b_uvec, p_bu_rhs),
1240            self.left_scaled_xt(&u.b_uvec, pv_b_rhs),
1241            self.left_scaled_xt(&self.w1, &pv_bu_rhs),
1242            self.left_scaled_xt(&v.b_uvec, p_u_b_rhs),
1243            self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
1244            self.left_scaled_xt(&self.w1, &p_uv_rhs),
1245        ];
1246        let mut d2_j2 = Array2::<f64>::zeros((p, cols));
1247        for term in d2_terms {
1248            d2_j2 += &term;
1249        }
1250
1251        0.5 * (diag_term - d2_j2)
1252    }
1253
1254    pub(super) fn rowwise_dot(a: &Array2<f64>, b: &Array2<f64>) -> Array1<f64> {
1255        assert_eq!(a.nrows(), b.nrows());
1256        assert_eq!(a.ncols(), b.ncols());
1257        let mut out = Array1::<f64>::zeros(a.nrows());
1258        for i in 0..a.nrows() {
1259            let mut acc = 0.0_f64;
1260            for j in 0..a.ncols() {
1261                acc += a[[i, j]] * b[[i, j]];
1262            }
1263            out[i] = acc;
1264        }
1265        out
1266    }
1267
1268    pub(super) fn rowwise_bilinear(
1269        a: &Array2<f64>,
1270        m: &Array2<f64>,
1271        b: &Array2<f64>,
1272    ) -> Array1<f64> {
1273        // Returns vector with entries a_iᵀ M b_i for each row i.
1274        assert_eq!(a.nrows(), b.nrows());
1275        assert_eq!(a.ncols(), m.nrows());
1276        assert_eq!(b.ncols(), m.ncols());
1277        let am = fast_ab(a, m);
1278        Self::rowwise_dot(&am, b)
1279    }
1280
1281    pub(crate) fn dot_i_and_h_from_reduced(
1282        &self,
1283        x_tau_reduced: &Array2<f64>,
1284        deta: &Array1<f64>,
1285    ) -> (Array2<f64>, Array1<f64>) {
1286        // Reduced Fisher directional derivative under fixed identifiable basis:
1287        //   I_r = X_r' W X_r
1288        //   I_r,tau = X_{r,tau}' W X_r + X_r' W X_{r,tau} + X_r' W_tau X_r
1289        // with W_tau = diag(w' ⊙ eta_tau).
1290        //
1291        // Leverage derivative used by Firth score partial:
1292        //   h_i = x_{r,i}' K_r x_{r,i}, K_r = I_r^{-1}
1293        //   h_tau = 2*diag(X_{r,tau} K_r X_r') + diag(X_r K_{r,tau} X_r')
1294        //   K_{r,tau} = -K_r I_{r,tau} K_r.
1295        //
1296        // This is exactly the fixed-beta directional derivative required by
1297        //   (gphi)_tau and Phi_tau in the Jeffreys/Firth design-moving path:
1298        //   I_{r,tau}|beta = X_{r,tau}' W X_r + X_r' W X_{r,tau}
1299        //                    + X_r' diag(w' ⊙ eta_tau|beta) X_r,
1300        //   eta_tau|beta = X_tau beta.
1301        //
1302        // We return:
1303        //   dot_i  = I_{r,tau}|beta,
1304        //   dot_h  = h_tau|beta.
1305        let dw = &self.w1 * deta;
1306        let dot_i = RemlState::weighted_cross(x_tau_reduced, &self.x_reduced, &self.w)
1307            + RemlState::weighted_cross(&self.x_reduced, x_tau_reduced, &self.w)
1308            + gam_linalg::faer_ndarray::fast_xt_diag_x(&self.x_reduced, &dw);
1309
1310        let dot_k = -self.k_reduced.dot(&dot_i).dot(&self.k_reduced);
1311        let x_tauk = fast_ab(x_tau_reduced, &self.k_reduced);
1312        let dot_h_explicit = 2.0 * Self::rowwise_dot(&x_tauk, &self.x_reduced);
1313        let dot_h_implicit = Self::rowwise_dot(&fast_ab(&self.x_reduced, &dot_k), &self.x_reduced);
1314        let dot_h = dot_h_explicit + dot_h_implicit;
1315        (dot_i, dot_h)
1316    }
1317
1318    pub(crate) fn exact_tau_kernel(
1319        &self,
1320        x_tau: &Array2<f64>,
1321        beta: &Array1<f64>,
1322        include_hphi_tau_kernel: bool,
1323    ) -> FirthTauExactKernel {
1324        // Shared exact tau-partial bundle used by both dense and sparse paths:
1325        //   (gphi)_tau | beta-fixed,
1326        //   Phi_tau | beta-fixed,
1327        // and optional H_{phi,tau}|beta kernel for later matrix-free applies.
1328        //
1329        // Closed forms (reduced Fisher, fixed active subspace):
1330        //   Phi = 0.5 log|I_r| - 0.5 log|S_r|,
1331        //   I_r = X_r' W X_r, K_r = I_r^{-1},
1332        //   S_r = X_r' X_r,   diag(G_r) = diag(S_r^{-1}),
1333        //   Phi_tau|beta = 0.5 tr(K_r I_{r,tau}) - 0.5 tr(G_r S_{r,tau}).
1334        // In the canonical reduced basis used here, G_r is diagonal.
1335        //
1336        //   (gphi)_tau = Phi_beta,tau
1337        //               = 0.5 X_tau' (w1 .* h)
1338        //                 + 0.5 X'((w2 .* eta_tau) .* h + w1 .* h_tau),
1339        //   where
1340        //     h_i = x_{r,i}' K_r x_{r,i},
1341        //     h_tau = 2*diag(X_{r,tau} K_r X_r') + diag(X_r K_{r,tau} X_r'),
1342        //     K_{r,tau} = -K_r I_{r,tau} K_r.
1343        //
1344        // Phi_beta,tau is unchanged by the -0.5 log|S_r| term because S_r does
1345        // not depend on beta. Only Phi_tau gets the explicit basis-drift
1346        // subtraction.
1347        let deta_partial = gam_linalg::faer_ndarray::fast_av(x_tau, beta);
1348        let x_tau_reduced = self.reduce_explicit_design(x_tau);
1349        let (dot_i_partial, dot_h_partial) =
1350            self.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
1351        let dot_s_partial =
1352            fast_atb(&x_tau_reduced, &self.x_reduced) + fast_atb(&self.x_reduced, &x_tau_reduced);
1353
1354        let first = 0.5 * gam_linalg::faer_ndarray::fast_atv(x_tau, &(&self.w1 * &self.h_diag));
1355        let secondvec =
1356            &(&(&self.w2 * &deta_partial) * &self.h_diag) + &(&self.w1 * &dot_h_partial);
1357        let second = 0.5 * gam_linalg::faer_ndarray::fast_atv(&self.x_dense, &secondvec);
1358        let gphi_tau = first + second;
1359        let phi_tau_partial = 0.5 * RemlState::trace_product(&self.k_reduced, &dot_i_partial)
1360            - 0.5 * Self::trace_diag_product(&self.x_metric_reduced_inv_diag, &dot_s_partial);
1361
1362        let tau_kernel = if include_hphi_tau_kernel {
1363            Some(self.hphi_tau_partial_prepare_from_partials(
1364                x_tau_reduced,
1365                &deta_partial,
1366                dot_h_partial,
1367                dot_i_partial,
1368            ))
1369        } else {
1370            None
1371        };
1372        FirthTauExactKernel {
1373            gphi_tau,
1374            phi_tau_partial,
1375            tau_kernel,
1376        }
1377    }
1378
1379    pub(crate) fn hphi_tau_partial_prepare_from_partials(
1380        &self,
1381        x_tau_reduced: Array2<f64>,
1382        deta_partial: &Array1<f64>,
1383        dot_h_partial: Array1<f64>,
1384        dot_i_partial: Array2<f64>,
1385    ) -> FirthTauPartialKernel {
1386        let dotw1 = &self.w2 * deta_partial;
1387        let dotw2 = &self.w3 * deta_partial;
1388        let dot_k = -self.k_reduced.dot(&dot_i_partial).dot(&self.k_reduced);
1389        FirthTauPartialKernel {
1390            deta_partial: deta_partial.clone(),
1391            dotw1,
1392            dotw2,
1393            dot_h_partial,
1394            x_tau_reduced,
1395            dot_i_partial,
1396            dot_k_reduced: dot_k,
1397        }
1398    }
1399
1400    pub(crate) fn d_beta_hphi_tau_partial_dense(
1401        &self,
1402        x_tau: &Array2<f64>,
1403        beta: &Array1<f64>,
1404        beta_direction: &Array1<f64>,
1405    ) -> Option<Array2<f64>> {
1406        if x_tau.nrows() != self.x_dense.nrows() || x_tau.ncols() != beta.len() {
1407            return None;
1408        }
1409        if !x_tau.iter().any(|value| *value != 0.0) {
1410            return None;
1411        }
1412        let tau_bundle = self.exact_tau_kernel(x_tau, beta, true);
1413        let tau_kernel = tau_bundle.tau_kernel?;
1414        let firth_direction =
1415            self.direction_from_deta(gam_linalg::faer_ndarray::fast_av(&self.x_dense, beta_direction));
1416        let x_tau_v = gam_linalg::faer_ndarray::fast_av(x_tau, beta_direction);
1417        let kernel = self.d_beta_hphi_tau_partial_prepare_from_partials(
1418            &tau_kernel,
1419            &tau_kernel.deta_partial,
1420            &tau_kernel.dot_i_partial,
1421            &firth_direction,
1422            &x_tau_v,
1423        );
1424        let eye = Array2::<f64>::eye(beta_direction.len());
1425        Some(self.d_beta_hphi_tau_partial_apply(x_tau, &kernel, &eye))
1426    }
1427
1428    pub(crate) fn apply_pbar_to_matrix(&self, mat: &Array2<f64>) -> Array2<f64> {
1429        // Applies P̄ = (X_r K_r X_rᵀ)⊙(X_r K_r X_rᵀ) to each column of mat.
1430        RemlState::apply_hadamard_gram_to_matrix(
1431            &self.x_reduced,
1432            &self.k_reduced,
1433            &self.k_reduced,
1434            mat,
1435        )
1436    }
1437
1438    pub(crate) fn apply_mtau_to_matrix(
1439        &self,
1440        kernel: &FirthTauPartialKernel,
1441        mat: &Array2<f64>,
1442    ) -> Array2<f64> {
1443        // Exact apply of
1444        //   M_tau = d/dtau[(P⊙P)]|_{beta fixed} = 2(P⊙P_tau)
1445        // without building dense n×n objects.
1446        //
1447        // Decomposition:
1448        //   P = Z K Zᵀ, Z = X_r
1449        //   P_tau = Z_tau K Zᵀ + Z K Z_tauᵀ + Z dotK Zᵀ
1450        // and for each vector v:
1451        //   (P⊙(Z_tau K Zᵀ))v   : rowwise bilinear with K (Zᵀdiag(v)Z) K
1452        //   (P⊙(Z K Z_tauᵀ))v   : diag_Z( K (Zᵀdiag(v)Z_tau) K )
1453        //   (P⊙(Z dotK Zᵀ))v    : Hadamard-Gram apply with (K, dotK).
1454        if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
1455            return Array2::<f64>::zeros(mat.raw_dim());
1456        }
1457        let mut out = Array2::<f64>::zeros(mat.raw_dim());
1458        for col in 0..mat.ncols() {
1459            let v = mat.column(col).to_owned();
1460            let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
1461            let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
1462            let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, &kernel.x_tau_reduced);
1463
1464            let szt =
1465                RemlState::reduced_crossweighted_gram(&self.x_reduced, &kernel.x_tau_reduced, &v);
1466            let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
1467            let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
1468
1469            let t3 = RemlState::apply_hadamard_gram(
1470                &self.x_reduced,
1471                &self.k_reduced,
1472                &kernel.dot_k_reduced,
1473                &v,
1474            );
1475
1476            let y = 2.0 * (t1 + t2 + t3);
1477            out.column_mut(col).assign(&y);
1478        }
1479        out
1480    }
1481
1482    pub(crate) fn hphi_tau_partial_apply(
1483        &self,
1484        x_tau: &Array2<f64>,
1485        kernel: &FirthTauPartialKernel,
1486        rhs: &Array2<f64>,
1487    ) -> Array2<f64> {
1488        let p = self.x_dense.ncols();
1489        if rhs.nrows() != p {
1490            return Array2::<f64>::zeros((p, rhs.ncols()));
1491        }
1492        if rhs.ncols() == 0 || p == 0 {
1493            return Array2::<f64>::zeros((p, rhs.ncols()));
1494        }
1495        // Matrix-free block apply of Hphi,tau|beta:
1496        //   Hphi,tau|beta(V) = 0.5 [ X_tau' r(V) + X' r_tau(V) ].
1497        //
1498        // Tensor identity behind this apply:
1499        //   Hphi,tau|beta = Phi_beta,beta,tau
1500        // and for test vectors b1,b2 (matrix columns V are batched b2's):
1501        //   Phi_beta,beta,tau[b1,b2]
1502        //   = 0.5[
1503        //       tr(I^{-1} I_{b1,b2,tau})
1504        //       - tr(I^{-1} I_{b1,b2} I^{-1} I_tau)
1505        //       - tr(I^{-1} I_{b1,tau} I^{-1} I_{b2})
1506        //       - tr(I^{-1} I_{b2,tau} I^{-1} I_{b1})
1507        //       + 2 tr(I^{-1} I_{b1} I^{-1} I_{b2} I^{-1} I_tau)
1508        //     ].
1509        // This routine evaluates that form in reduced coordinates without forming
1510        // dense 3rd-order tensors explicitly.
1511        let etav = fast_ab(&self.x_dense, rhs);
1512        let etav_tau = fast_ab(x_tau, rhs);
1513        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
1514        let qv_tau = &etav * &kernel.dotw1.view().insert_axis(Axis(1))
1515            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
1516        let m_qv = self.apply_pbar_to_matrix(&qv);
1517        let m_qv_tau = self.apply_mtau_to_matrix(kernel, &qv) + self.apply_pbar_to_matrix(&qv_tau);
1518        let rv = &(&etav * &self.w2.view().insert_axis(Axis(1)))
1519            * &self.h_diag.view().insert_axis(Axis(1))
1520            - &(&m_qv * &self.w1.view().insert_axis(Axis(1)));
1521        let rv_tau = (&(&etav * &kernel.dotw2.view().insert_axis(Axis(1)))
1522            + &(&etav_tau * &self.w2.view().insert_axis(Axis(1))))
1523            * self.h_diag.view().insert_axis(Axis(1))
1524            + &(&etav * &self.w2.view().insert_axis(Axis(1)))
1525                * &kernel.dot_h_partial.view().insert_axis(Axis(1))
1526            - &(&m_qv * &kernel.dotw1.view().insert_axis(Axis(1))
1527                + &m_qv_tau * &self.w1.view().insert_axis(Axis(1)));
1528        0.5 * (fast_atb(x_tau, &rv) + fast_atb(&self.x_dense, &rv_tau))
1529    }
1530
1531    // ═════════════════════════════════════════════════════════════════════════
1532    //  Pair-term primitives for the Firth outer Hessian (Task #13a / #17)
1533    // ═════════════════════════════════════════════════════════════════════════
1534    //
1535    // The REML outer Hessian at a ψ=(ρ,τ) pair needs two Firth contributions
1536    // that are NOT covered by the existing single-τ primitives:
1537    //
1538    //   A.  the pure τ×τ second partial of H_φ at fixed β (pair drift inside
1539    //       the fixed-β second-derivative trace of B_i,j used by
1540    //       build_tau_tau_pair_callback).  This is the Firth analog of the
1541    //       penalty-logdet pair term in the outer-derivative cookbook.
1542    //
1543    //   B.  the β-derivative of (H_φ)_τ|_β in direction v.  This is the
1544    //       fixed-drift-derivative M_i[v] = D_β B_i[v] that
1545    //       compute_drift_deriv_traces uses through the fixed_drift_deriv
1546    //       callback in build_tau_hyper_coords.  It is currently always None
1547    //       in the Firth+Logit path, which is what makes the outer Hessian
1548    //       approximate for Firth-reweighted models.
1549    //
1550    // Both primitives operate in the reduced identifiable subspace (X_r, K_r,
1551    // S_r, Z=X_r) of the dense Firth operator and are matrix-free in n.  They
1552    // do NOT introduce any dense n×n or p×p×p object; every contraction is
1553    // routed through reduced-space Hadamard-Gram applies and rowwise
1554    // bilinear forms, in the same spirit as hphi_tau_partial_apply and
1555    // hphisecond_direction_apply.
1556    //
    // Both are exact at the smooth-regime operating point assumed by this
    // module: X full rank on its identifiable subspace (so I_r is SPD), Q held fixed
1559    // within one outer REML step (active-subspace drift enters only between
1560    // outer iterates), w_i(η) > 0 strictly positive, and β is at the P-IRLS
1561    // solution for the current ψ so β_τ is supplied by the unified evaluator
1562    // via the IFT solve.
1563    //
1564    // Symbol conventions shared with the existing operator code:
1565    //   X_r   := X Q            (reduced identifiable design)
1566    //   W     := diag(w(η))     (Fisher weights), w', w'', w''', w''''
1567    //   I_r   := X_rᵀ W X_r,  K_r := I_r^{-1},  S_r := X_rᵀ X_r
1568    //   M     := X_r K_r X_rᵀ,  P := M ⊙ M,  B := diag(w') X
1569    //   h     := diag(M)
1570    //   X_i   := ∂X/∂τ_i,  X_{r,i} := X_i Q,  η̇_i := X_i β,  etc.
1571    //   δη_v  := X v           (β-direction v),  δη_{τ,v} := X_τ v
1572    //
1573    // H_φ structural form (reduced form of ∇²_β Φ_F):
1574    //   H_φ  =  ½ [ Xᵀ diag(w'' ⊙ h) X  −  Bᵀ P B ]
1575    // where the first term arises from differentiating the Jeffreys gradient
1576    // ½ Xᵀ (w' ⊙ h) once more in β, and the second collects the IFT-mediated
1577    // β-derivative of h = diag(X_r K_r X_rᵀ) through K_r = I_r^{-1}.  This is
1578    // the same form the existing hphi_direction code implements in directional
1579    // form (cf. firth.rs hphi_direction / hphisecond_direction_apply, the
1580    // 9-term D²J₂ expansion with J₂ = BᵀPB).
1581    //
1582    // ─────────────────────────────────────────────────────────────────────────
1583    //  Primitive A — ∂²H_φ/∂τ_i ∂τ_j |_β
1584    // ─────────────────────────────────────────────────────────────────────────
1585    //
1586    // WHAT IT COMPUTES
1587    //   Given a pair of τ-drift designs (X_τ_i, X_τ_j) and optional second
1588    //   design derivative X_{τ_i τ_j}, evaluates
1589    //
1590    //     ∂²H_φ/∂τ_i ∂τ_j |_β  =  ½ [ ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j
1591    //                                − ∂²(Bᵀ P B)/∂τ_i ∂τ_j ],
1592    //   with Γ := diag(w'' ⊙ h).  Acts on a p×m rhs and returns a p×m block
1593    //   (exact same contract as hphi_tau_partial_apply, but for the *second*
1594    //   mixed τ-derivative at fixed β).
1595    //
1596    // WHY (REML callsite)
1597    //   In the outer Hessian entry Ḧ_{i,j} for τ×τ pair (i,j), the fixed-β
1598    //   second drift of B_i is exactly this primitive (with the Firth sign:
1599    //   B_i = −(H_φ)_τ_i|_β + other likelihood pieces, so ∂²B_i/∂τ_j|_β
1600    //   contributes −∂²H_φ/∂τ_i∂τ_j|_β to the outer Hessian trace).  The
1601    //   existing Firth pair callback at build_tau_tau_pair_callback currently
1602    //   carries zero for this Firth contribution; wiring this primitive into
1603    //   the TauTauPairHyperOperator is the remaining step to make the τ×τ
1604    //   outer Hessian exact in the Firth-reweighted Logit path.
1605    //
1606    // DERIVATION (full chain-rule expansion, at fixed β)
1607    //
1608    //   Building blocks at fixed β (single-τ):
1609    //     İ_i      := ∂I_r/∂τ_i |_β
1610    //                = X_{r,i}ᵀ W X_r + X_rᵀ W X_{r,i} + X_rᵀ Ẇ_i X_r,
1611    //       Ẇ_i    := diag(w' ⊙ η̇_i),   η̇_i := X_i β.
1612    //     K̇_i      := ∂K_r/∂τ_i = −K_r İ_i K_r.
1613    //     ḣ_i      := ∂h/∂τ_i |_β
1614    //                = 2·diag(X_{r,i} K_r X_rᵀ) + diag(X_r K̇_i X_rᵀ).
1615    //     Ṁ_i      := ∂M/∂τ_i |_β
1616    //                = X_{r,i} K_r X_rᵀ + X_r K̇_i X_rᵀ + X_r K_r X_{r,i}ᵀ.
1617    //     Ḃ_i      := ∂B/∂τ_i |_β
1618    //                = diag(w'' ⊙ η̇_i) X + diag(w') X_i.
1619    //     Ṗ_i      := ∂P/∂τ_i = 2 (M ⊙ Ṁ_i).
1620    //     Γ̇_i     := ∂Γ/∂τ_i |_β = diag(w''' ⊙ η̇_i ⊙ h + w'' ⊙ ḣ_i).
1621    //
1622    //   Second-order building blocks:
1623    //     η̈_{ij}  := X_{ij} β                   (0 for design linear in τ)
1624    //     Ẅ_{ij} := diag(w'' ⊙ η̇_i ⊙ η̇_j
1625    //                     + w' ⊙ η̈_{ij})
1626    //
1627    //     Ï_{ij}   := ∂²I_r/∂τ_i ∂τ_j |_β
1628    //                = X_{r,ij}ᵀ W X_r  +  X_rᵀ W X_{r,ij}
1629    //                 + X_{r,i}ᵀ W X_{r,j}  +  X_{r,j}ᵀ W X_{r,i}
1630    //                 + X_{r,i}ᵀ Ẇ_j X_r  +  X_rᵀ Ẇ_j X_{r,i}
1631    //                 + X_{r,j}ᵀ Ẇ_i X_r  +  X_rᵀ Ẇ_i X_{r,j}
1632    //                 + X_rᵀ Ẅ_{ij} X_r.
1633    //
1634    //     K̈_{ij}  := ∂²K_r/∂τ_i ∂τ_j
1635    //                = −K_r Ï_{ij} K_r
1636    //                  + K_r İ_i K_r İ_j K_r
1637    //                  + K_r İ_j K_r İ_i K_r.
1638    //
1639    //     M̈_{ij}  := X_{r,ij} K_r X_rᵀ + X_r K_r X_{r,ij}ᵀ
1640    //                 + X_{r,i} K̇_j X_rᵀ + X_r K̇_j X_{r,i}ᵀ
1641    //                 + X_{r,j} K̇_i X_rᵀ + X_r K̇_i X_{r,j}ᵀ
1642    //                 + X_{r,i} K_r X_{r,j}ᵀ + X_{r,j} K_r X_{r,i}ᵀ
1643    //                 + X_r K̈_{ij} X_rᵀ.
1644    //
1645    //     P̈_{ij}  := ∂²P/∂τ_i ∂τ_j
    //                = (Ṁ_i ⊙ Ṁ_j) + (Ṁ_j ⊙ Ṁ_i) + 2 (M ⊙ M̈_{ij})
    //                = 2 (Ṁ_i ⊙ Ṁ_j) + 2 (M ⊙ M̈_{ij}).
1648    //
1649    //     ḧ_{ij}  := ∂²h/∂τ_i ∂τ_j |_β
1650    //                = 2·diag(X_{r,ij} K_r X_rᵀ)
1651    //                 + diag(X_r K̈_{ij} X_rᵀ)
1652    //                 + 2·diag(X_{r,i} K̇_j X_rᵀ)
1653    //                 + 2·diag(X_{r,j} K̇_i X_rᵀ)
1654    //                 + 2·diag(X_{r,i} K_r X_{r,j}ᵀ).
1655    //
1656    //     B̈_{ij}  := ∂²B/∂τ_i ∂τ_j |_β
1657    //                = diag(w''' ⊙ η̇_i ⊙ η̇_j + w'' ⊙ η̈_{ij}) X
1658    //                 + diag(w'' ⊙ η̇_i) X_j
1659    //                 + diag(w'' ⊙ η̇_j) X_i
1660    //                 + diag(w') X_{ij}.
1661    //
1662    //     Γ̈_{ij} := ∂²Γ/∂τ_i ∂τ_j |_β
1663    //                = diag( w'''' ⊙ η̇_i ⊙ η̇_j ⊙ h
1664    //                       + w''' ⊙ η̈_{ij} ⊙ h
1665    //                       + w''' ⊙ η̇_i ⊙ ḣ_j
1666    //                       + w''' ⊙ η̇_j ⊙ ḣ_i
1667    //                       + w'' ⊙ ḧ_{ij} ).
1668    //
1669    //   Diagonal-term expansion (the Xᵀ Γ X branch):
1670    //
1671    //     ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j  =
1672    //         X_{ij}ᵀ Γ X  + Xᵀ Γ X_{ij}
1673    //       + X_iᵀ Γ X_j  + X_jᵀ Γ X_i
1674    //       + X_iᵀ Γ̇_j X  + Xᵀ Γ̇_j X_i
1675    //       + X_jᵀ Γ̇_i X  + Xᵀ Γ̇_i X_j
1676    //       + Xᵀ Γ̈_{ij} X.
1677    //
1678    //   9-term expansion for the BᵀPB branch (structurally identical to
1679    //   the existing β×β D²J₂[u,v] at firth.rs:~820-830 with (u,v)
1680    //   substituted by (τ_i, τ_j) and the appropriate Ḃ, B̈, Ṗ, P̈):
1681    //
1682    //     D²(BᵀPB)[τ_i,τ_j]  =
1683    //         B̈_{ij}ᵀ  P    B      +  Bᵀ       P    B̈_{ij}
1684    //       + Ḃ_iᵀ    P    Ḃ_j   +  Ḃ_jᵀ    P    Ḃ_i
1685    //       + Ḃ_iᵀ    Ṗ_j  B      +  Bᵀ       Ṗ_j  Ḃ_i
1686    //       + Ḃ_jᵀ    Ṗ_i  B      +  Bᵀ       Ṗ_i  Ḃ_j
1687    //       + Bᵀ       P̈_{ij} B.
1688    //
1689    //   Combining,
1690    //
1691    //     ∂²H_φ/∂τ_i ∂τ_j |_β  =  ½ [
1692    //         ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j  −  D²(BᵀPB)[τ_i, τ_j]
1693    //     ].
1694    //
1695    // IMPLEMENTATION SKETCH (for 13b)
1696    //   • Build per-direction reduced quantities for τ_i and τ_j:
1697    //       (x_tau_reduced, η̇, İ, K̇, Ṁ operator pieces, ḣ, b_uvec = w''⊙η̇).
1698    //     The existing `dot_i_and_h_from_reduced` yields İ and ḣ already;
1699    //     the per-direction "A_u" analog is A_τ = K_r İ K_r, matching the
1700    //     FirthDirection form used by hphisecond_direction_apply.
1701    //   • Use apply_hadamard_gram_to_matrix with
1702    //       (A_left, A_right) ∈ { (K_r, K_r), (K_r, A_τ_i), (K_r, A_τ_j),
1703    //                             (A_τ_i, A_τ_j) }
1704    //     to realize P-products, Ṗ_τ-products, and the (Ṁ_i ⊙ Ṁ_j) piece of
1705    //     P̈_{ij} without forming any n×n dense intermediate.
1706    //   • The pure-second piece `X_r K̈_{ij} X_rᵀ` decomposes into three
1707    //     reduced triple products (K_r Ï_{ij} K_r, K_r İ_i K_r İ_j K_r, and
1708    //     its transpose).  All are size-r×r in reduced coordinates.
1709    //   • For design-linear-in-τ smooths, X_{ij}=0 and η̈_{ij}=0, which
1710    //     prunes many sub-terms; callers who have X_{τ_i τ_j} available
1711    //     should pass it so the primitive remains exact on curved designs.
1712    //
1713    // ─────────────────────────────────────────────────────────────────────────
1714    //  Primitive B — D_β((H_φ)_τ|_β)[v]
1715    // ─────────────────────────────────────────────────────────────────────────
1716    //
1717    // WHAT IT COMPUTES
1718    //   Given a single τ-drift design X_τ, the β-fixed Firth partial
1719    //   (H_φ)_τ|_β encoded by FirthTauPartialKernel, and a β-direction
1720    //   vector v (of length p), returns the β-derivative of (H_φ)_τ|_β
1721    //   applied to an rhs block (so output is p×m, matching the pair's
1722    //   fixed_drift_deriv callback signature DriftDerivResult).  In
1723    //   symbols:
1724    //
1725    //     D_β((H_φ)_τ|_β)[v]  =  ½ [ D_β{(∂(XᵀΓX)/∂τ)|_β}[v]
1726    //                                 −  D_β{(∂(BᵀPB)/∂τ)|_β}[v] ].
1727    //
1728    // WHY (REML callsite)
1729    //   In the exact outer Hessian assembly (compute_drift_deriv_traces in
1730    //   unified.rs), the Ḧ_{ij} entry picks up
1731    //     tr(G_ε · D_β B_i[v_j])  +  tr(G_ε · D_β B_j[v_i]).
1732    //   For τ coordinates in the Firth+Logit path, B_τ = (penalty / design
1733    //   pieces) − (H_φ)_τ|_β, so the Firth share of D_β B_τ[v] is
1734    //     − D_β((H_φ)_τ|_β)[v].
1735    //   Hooking this primitive up through a FixedDriftDerivFn (returning
1736    //   DriftDerivResult::Dense of this p×p β-v action) is exactly what
1737    //   lets build_tau_hyper_coords pass a non-None fixed_drift_deriv
1738    //   closure into the unified evaluator, closing the approximation gap
1739    //   that firth_pair_terms_unavailable currently tracks.
1740    //
1741    // DERIVATION (β-derivative of each τ-partial term in direction v)
1742    //
1743    //   β enters only through η=Xβ, so designs X, X_τ, Q, X_r are all
1744    //   β-independent; D_β acts on w(η) and its derivatives, on I_r, K_r,
1745    //   M, h, and on η̇_τ = X_τ β.
1746    //
1747    //   Primary β-derivative building blocks (matches FirthDirection with
1748    //   deta := δη_v = X v):
1749    //     I'_v  := D_β I_r[v] = X_rᵀ diag(w' ⊙ δη_v) X_r      (g_u_reduced)
1750    //     A_v   := D_β K_r[v] = −K_r I'_v K_r                  (a_u_reduced)
1751    //     dh_v  := D_β h[v]    = −diag(X_r K_r I'_v K_r X_rᵀ)
1752    //                          = diag(X_r A_v X_rᵀ)            (dh)
1753    //     (w')_v  := D_β w'[v]  = w''  ⊙ δη_v
1754    //     (w'')_v := D_β w''[v] = w''' ⊙ δη_v
1755    //     (w''')_v:= D_β w'''[v]= w''''⊙ δη_v
1756    //     δη_{τ,v} := D_β(η̇_τ)[v] = X_τ v
1757    //
1758    //   Mixed τ-β pieces:
1759    //     D_β(İ_τ)[v]
    //       = X_{r,τ}ᵀ diag(w' ⊙ δη_v) X_r
    //        + X_rᵀ diag(w' ⊙ δη_v) X_{r,τ}
1762    //        + X_rᵀ diag(w'' ⊙ η̇_τ ⊙ δη_v
1763    //                     + w' ⊙ δη_{τ,v}) X_r.
1764    //     D_β(K̇_τ)[v]
1765    //       = −( A_v İ_τ K_r  +  K_r D_β(İ_τ)[v] K_r
1766    //             +  K_r İ_τ A_v ).
1767    //     D_β(Ṁ_τ)[v]
1768    //       = X_{r,τ} A_v X_rᵀ
1769    //        + X_r D_β(K̇_τ)[v] X_rᵀ
1770    //        + X_r A_v X_{r,τ}ᵀ.
1771    //     D_β(ḣ_τ)[v]
1772    //       = 2·diag(X_{r,τ} A_v X_rᵀ)
1773    //        + diag(X_r D_β(K̇_τ)[v] X_rᵀ).
1774    //
1775    //   Diagonal-term β-derivative ( (X_τᵀΓX + XᵀΓX_τ + XᵀΓ̇_τ X) branch ):
1776    //     D_β(X_τᵀ Γ X + Xᵀ Γ X_τ)[v]
1777    //       = X_τᵀ Γ_v X + Xᵀ Γ_v X_τ,
1778    //       Γ_v  := D_β Γ[v] = diag((w'')_v ⊙ h + w'' ⊙ dh_v)
1779    //                        = diag(w''' ⊙ δη_v ⊙ h + w'' ⊙ dh_v).
1780    //     D_β(Xᵀ Γ̇_τ X)[v]
1781    //       = Xᵀ Γ̇_{τ,v} X,
1782    //       Γ̇_{τ,v}
1783    //        := D_β Γ̇_τ[v]
1784    //         = diag( (w''')_v ⊙ η̇_τ ⊙ h
1785    //                 + w''' ⊙ δη_{τ,v} ⊙ h
1786    //                 + w''' ⊙ η̇_τ ⊙ dh_v
1787    //                 + (w'')_v ⊙ ḣ_τ
1788    //                 + w'' ⊙ D_β(ḣ_τ)[v] )
1789    //         = diag( w'''' ⊙ η̇_τ ⊙ δη_v ⊙ h
1790    //                 + w''' ⊙ δη_{τ,v} ⊙ h
1791    //                 + w''' ⊙ η̇_τ ⊙ dh_v
1792    //                 + w''' ⊙ δη_v ⊙ ḣ_τ
1793    //                 + w'' ⊙ D_β(ḣ_τ)[v] ).
1794    //
1795    //   Cross-coupling τ-β pieces for B:
1796    //     B_v  := D_β B[v]   = diag(w'' ⊙ δη_v) X               (b_uvec)
1797    //     B_τ  := ∂B/∂τ|_β   = diag(w'' ⊙ η̇_τ) X
1798    //                         + diag(w') X_τ.
1799    //     B_{τ,v}
1800    //         := D_β B_τ[v]  = diag( w''' ⊙ η̇_τ ⊙ δη_v
1801    //                                 + w'' ⊙ δη_{τ,v} ) X
1802    //                         + diag(w'' ⊙ δη_v) X_τ.
1803    //
1804    //   BᵀPB branch — 9 terms, obtained by applying the product rule to
1805    //   ∂(BᵀPB)/∂τ = Ḃ_τᵀ P B + Bᵀ Ṗ_τ B + Bᵀ P Ḃ_τ and then taking
1806    //   D_β(·)[v] of each factor:
1807    //
1808    //     D_β(Ḃ_τᵀ P B)[v]   = B_{τ,v}ᵀ P B + Ḃ_τᵀ P_v B + Ḃ_τᵀ P B_v,
1809    //     D_β(Bᵀ Ṗ_τ B)[v]   = B_vᵀ Ṗ_τ B  + Bᵀ P_{τ,v} B + Bᵀ Ṗ_τ B_v,
1810    //     D_β(Bᵀ P Ḃ_τ)[v]   = B_vᵀ P Ḃ_τ + Bᵀ P_v Ḃ_τ + Bᵀ P B_{τ,v}.
1811    //
1812    //   Here Ḃ_τ = B_τ above, and
1813    //     P_v := D_β P[v]         = 2 (M ⊙ M_v),   M_v = X_r A_v X_rᵀ.
1814    //     Ṗ_τ := ∂P/∂τ|_β         = 2 (M ⊙ M_τ),
1815    //       M_τ = X_{r,τ} K_r X_rᵀ + X_r K̇_τ X_rᵀ + X_r K_r X_{r,τ}ᵀ.
1816    //     P_{τ,v} := D_β(Ṗ_τ)[v]  = 2 (M_v ⊙ M_τ) + 2 (M ⊙ M_{τ,v}),
1817    //       M_{τ,v} = X_{r,τ} A_v X_rᵀ + X_r D_β(K̇_τ)[v] X_rᵀ + X_r A_v X_{r,τ}ᵀ.
1818    //
1819    //   Final primitive:
1820    //
1821    //     D_β((H_φ)_τ|_β)[v]  =  ½ [
1822    //           X_τᵀ Γ_v X  + Xᵀ Γ_v X_τ  + Xᵀ Γ̇_{τ,v} X
1823    //         −  (9-term BᵀPB β-τ expansion above)
1824    //     ].
1825    //
1826    //   Applied to an rhs block `R ∈ ℝ^{p × m}`, each Xᵀ(…) X R collapses
1827    //   to n-length row scalings of (X R) followed by Xᵀ; each Bᵀ P B
1828    //   variant uses apply_hadamard_gram_to_matrix with the correct
    //   (A_left, A_right) ∈ { (K_r, K_r), (K_r, A_v), (K_r, K̇_τ),
    //     (K_r, D_β(K̇_τ)[v]), (A_v, K̇_τ) } to realize
1831    //   P, P_v, Ṗ_τ, P_{τ,v} actions.  All operators are r×r in reduced
1832    //   coordinates, matching the existing apply cost profile.
1833    //
1834    // IMPLEMENTATION SKETCH (for 13c)
1835    //   • Build `FirthDirection` from deta = X v (reuses existing
1836    //     direction_from_deta, giving I'_v, A_v, dh_v, b_uvec).
1837    //   • Build β-derivatives of the τ-specific fields of
1838    //     FirthTauPartialKernel (dotw1, dotw2, dot_h_partial, dot_k_reduced,
1839    //     and the implicit M_τ reduced-coords operator).  These become a
1840    //     new FirthTauBetaPartialKernel attached to the prepared state.
1841    //   • The apply step is then algebraically identical to
1842    //     hphi_tau_partial_apply but with every W-tensor weight replaced by
1843    //     its β-derivative in v, and every (M, K_r)-Gram replaced by the
1844    //     appropriate β-derivative Gram above.  The structure is regular
1845    //     enough that a single helper, shared with Primitive A, can absorb
1846    //     both pair dispatches.
1847    //
1848    // NOTE ON DESIGN-LINEAR SMOOTHS
1849    //   For the common case of design-linear-in-τ smooths (scale-moving
1850    //   anisotropic bases), X_i and X_τ are constant in τ, so X_{ij}=0 and
1851    //   η̈_{ij}=0.  The primitives collapse to their W-reweighted cores but
1852    //   remain matrix-free; no special fast path is needed because the
1853    //   zeroed terms simply drop out of the Hadamard-Gram assembly.
1854    //
1855    // ═════════════════════════════════════════════════════════════════════════
1856
1857    /// Primitive A — prepare step: assemble the τ_i × τ_j reduced kernel.
1858    ///
1859    /// Consumes the per-direction partial quantities produced by
1860    /// `dot_i_and_h_from_reduced` for τ_i and τ_j (plus an optional second
1861    /// design derivative X_{τ_i τ_j}), and returns a cached kernel carrying
1862    /// the M̈_{ij}, K̈_{ij}, ḧ_{ij}, Γ̈_{ij}, and B̈_{ij}-related reduced
1863    /// coordinates needed by `hphi_tau_tau_partial_apply`.
1864    ///
1865    /// This signature mirrors `hphi_tau_partial_prepare_from_partials` for
1866    /// consistency; the pair version needs both directions simultaneously
1867    /// (to realize the 9-term D² expansion) and therefore owns both
1868    /// `x_tau_{i,j}_reduced` and their η̇_i / η̇_j.
1869    ///
1870    pub(crate) fn hphi_tau_tau_partial_prepare_from_partials(
1871        &self,
1872        x_tau_i_reduced: Array2<f64>,
1873        x_tau_j_reduced: Array2<f64>,
1874        deta_i_partial: &Array1<f64>,
1875        deta_j_partial: &Array1<f64>,
1876        dot_h_i_partial: Array1<f64>,
1877        dot_h_j_partial: Array1<f64>,
1878        dot_i_i_partial: Array2<f64>,
1879        dot_i_j_partial: Array2<f64>,
1880        x_tau_tau_reduced: Option<Array2<f64>>,
1881        deta_ij_partial: Option<Array1<f64>>,
1882    ) -> FirthTauTauPartialKernel {
1883        // K̇_i = -K_r İ_i K_r;  K̇_j = -K_r İ_j K_r.
1884        let dot_k_i_reduced = -self.k_reduced.dot(&dot_i_i_partial).dot(&self.k_reduced);
1885        let dot_k_j_reduced = -self.k_reduced.dot(&dot_i_j_partial).dot(&self.k_reduced);
1886        FirthTauTauPartialKernel {
1887            x_tau_i_reduced,
1888            x_tau_j_reduced,
1889            deta_i_partial: deta_i_partial.clone(),
1890            deta_j_partial: deta_j_partial.clone(),
1891            dot_h_i_partial,
1892            dot_h_j_partial,
1893            dot_k_i_reduced,
1894            dot_k_j_reduced,
1895            dot_i_i_partial,
1896            dot_i_j_partial,
1897            x_tau_tau_reduced,
1898            deta_ij_partial,
1899        }
1900    }
1901
1902    /// Primitive A — apply step: evaluate ∂²H_φ/∂τ_i ∂τ_j |_β against a p×m
1903    /// rhs block, returning a p×m block.
1904    ///
1905    /// Contract mirrors `hphi_tau_partial_apply`: the caller passes the two
1906    /// τ-drift designs and the prepared kernel, and receives the fixed-β
1907    /// second-τ Firth drift as a dense p×m action.  Matrix-free in n.
1908    ///
1909    pub(crate) fn hphi_tau_tau_partial_apply(
1910        &self,
1911        x_tau_i: &Array2<f64>,
1912        x_tau_j: &Array2<f64>,
1913        kernel: &FirthTauTauPartialKernel,
1914        rhs: &Array2<f64>,
1915    ) -> Array2<f64> {
1916        let p = self.x_dense.ncols();
1917        if rhs.nrows() != p {
1918            return Array2::<f64>::zeros((p, rhs.ncols()));
1919        }
1920        if rhs.ncols() == 0 || p == 0 {
1921            return Array2::<f64>::zeros((p, rhs.ncols()));
1922        }
1923        let n = self.x_dense.nrows();
1924        let m = rhs.ncols();
1925
1926        // Short aliases.
1927        let z = &self.x_reduced;
1928        let x_r = &self.x_reduced;
1929        let k = &self.k_reduced;
1930        let x_ri = &kernel.x_tau_i_reduced;
1931        let x_rj = &kernel.x_tau_j_reduced;
1932        let deta_i = &kernel.deta_i_partial;
1933        let deta_j = &kernel.deta_j_partial;
1934        let dh_i = &kernel.dot_h_i_partial;
1935        let dh_j = &kernel.dot_h_j_partial;
1936        let dot_k_i = &kernel.dot_k_i_reduced;
1937        let dot_k_j = &kernel.dot_k_j_reduced;
1938        let dot_i_i = &kernel.dot_i_i_partial;
1939        let dot_i_j = &kernel.dot_i_j_partial;
1940
1941        // Optional second-design pieces: default to zero when the design is
1942        // τ-linear (η̈_{ij} = 0, X_{ij} = 0).
1943        let x_tau_tau_is_some = kernel.x_tau_tau_reduced.is_some();
1944        let x_rij_zero = Array2::<f64>::zeros(x_r.raw_dim());
1945        let x_rij: &Array2<f64> = kernel.x_tau_tau_reduced.as_ref().unwrap_or(&x_rij_zero);
1946        let zeros_n = Array1::<f64>::zeros(n);
1947        let deta_ij = kernel.deta_ij_partial.as_ref().unwrap_or(&zeros_n);
1948
1949        // ─────────────────────────────────────────────────────────────────
1950        //  η̇ vectors in β-rhs space (η_V := X V, η_{i,V} := X_i V, etc.)
1951        // ─────────────────────────────────────────────────────────────────
1952        let (eta_v, eta_i_v, eta_j_v) = if RemlState::should_join_independent_dense_products(&[
1953            (n, m, p),
1954            (n, m, p),
1955            (n, m, p),
1956        ]) {
1957            let (eta_v, (eta_i_v, eta_j_v)) = rayon::join(
1958                || fast_ab(&self.x_dense, rhs),
1959                || rayon::join(|| fast_ab(x_tau_i, rhs), || fast_ab(x_tau_j, rhs)),
1960            );
1961            (eta_v, eta_i_v, eta_j_v)
1962        } else {
1963            (
1964                fast_ab(&self.x_dense, rhs),
1965                fast_ab(x_tau_i, rhs),
1966                fast_ab(x_tau_j, rhs),
1967            )
1968        }; // n×m blocks
1969        // X_{ij} V from the reduced second-derivative design:
1970        //   reduce_explicit_design: X_{r,τ} = diag(√a) X_τ Q,
1971        //   invert:  X_{ij} = diag(1/√a) X_{r,ij} Qᵀ.
1972        let eta_ij_v: Array2<f64> = if x_tau_tau_is_some {
1973            let qt_v = fast_atb(&self.q_basis, rhs); // r×m
1974            let mut out = fast_ab(x_rij, &qt_v); // n×m in sqrt(a)-scaled space
1975            RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1976                &mut out,
1977                self.observation_weight_sqrt.as_ref(),
1978            );
1979            out
1980        } else {
1981            Array2::<f64>::zeros((n, m))
1982        };
1983
1984        // ─────────────────────────────────────────────────────────────────
1985        //  Shared per-direction reduced operators
1986        //    A_τ = K İ K   (reduced analog of T_τ = K I_τ K)
1987        //    K̇_τ = -A_τ  (already cached)
1988        // ─────────────────────────────────────────────────────────────────
1989        let a_i_reduced = -dot_k_i; // K İ_i K = -K̇_i
1990        let a_j_reduced = -dot_k_j;
1991
1992        // ─────────────────────────────────────────────────────────────────
1993        //  Ï_{ij}  — second cross derivative of reduced Fisher
1994        // ─────────────────────────────────────────────────────────────────
1995        //   Ï_{ij} = X_{r,ij}ᵀ W X_r + X_rᵀ W X_{r,ij}
1996        //          + X_{r,i}ᵀ W X_{r,j} + X_{r,j}ᵀ W X_{r,i}
1997        //          + X_{r,i}ᵀ Ẇ_j X_r + X_rᵀ Ẇ_j X_{r,i}
1998        //          + X_{r,j}ᵀ Ẇ_i X_r + X_rᵀ Ẇ_i X_{r,j}
1999        //          + X_rᵀ Ẅ_{ij} X_r.
2000        //   Ẇ_α   = diag(w' ⊙ η̇_α),  Ẅ_{ij} = diag(w'' ⊙ η̇_i ⊙ η̇_j + w' ⊙ η̈_ij).
2001        let dw_i = &self.w1 * deta_i;
2002        let dw_j = &self.w1 * deta_j;
2003        let ddw_ij = &(&self.w2 * &(deta_i * deta_j)) + &(&self.w1 * deta_ij);
2004        let mut i_ddot = Array2::<f64>::zeros(k.raw_dim());
2005        if x_tau_tau_is_some {
2006            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
2007            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
2008        }
2009        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_rj, &self.w);
2010        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_ri, &self.w);
2011        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_r, &dw_j);
2012        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_ri, &dw_j);
2013        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_r, &dw_i);
2014        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rj, &dw_i);
2015        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);
2016
2017        // K̈_{ij} = −K Ï K + K İ_i K İ_j K + K İ_j K İ_i K.
2018        //   Using K İ_α K = −K̇_α = a_α_reduced, the two product terms collapse to
2019        //   K İ_i K İ_j K = a_i_reduced · İ_j · K,
2020        //   K İ_j K İ_i K = a_j_reduced · İ_i · K.
2021        let k_ddot: Array2<f64> = -k.dot(&i_ddot).dot(k)
2022            + a_i_reduced.dot(dot_i_j).dot(k)
2023            + a_j_reduced.dot(dot_i_i).dot(k);
2024
2025        // ─────────────────────────────────────────────────────────────────
2026        //  ḧ_{ij}
2027        // ─────────────────────────────────────────────────────────────────
2028        //   ḧ_ij = 2 diag(X_{r,ij} K X_rᵀ)
2029        //        + diag(X_r K̈_ij X_rᵀ)
2030        //        + 2 diag(X_{r,i} K̇_j X_rᵀ)
2031        //        + 2 diag(X_{r,j} K̇_i X_rᵀ)
2032        //        + 2 diag(X_{r,i} K X_{r,j}ᵀ).
2033        // Using diag(A Bᵀ) = rowwise_dot(A, B):
2034        let dh_ij: Array1<f64> = {
2035            let r = k.ncols();
2036            let can_join = RemlState::should_join_independent_dense_products(&[
2037                (n, r, r),
2038                (n, r, r),
2039                (n, r, r),
2040                (n, r, r),
2041            ]);
2042            let (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k) = if can_join {
2043                let ((xr_kddot, ri_kdot_j), (rj_kdot_i, ri_k)) = rayon::join(
2044                    || rayon::join(|| fast_ab(x_r, &k_ddot), || fast_ab(x_ri, dot_k_j)),
2045                    || rayon::join(|| fast_ab(x_rj, dot_k_i), || fast_ab(x_ri, k)),
2046                );
2047                (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k)
2048            } else {
2049                (
2050                    fast_ab(x_r, &k_ddot),
2051                    fast_ab(x_ri, dot_k_j),
2052                    fast_ab(x_rj, dot_k_i),
2053                    fast_ab(x_ri, k),
2054                )
2055            };
2056
2057            let mut acc = Self::rowwise_dot(&xr_kddot, x_r);
2058            acc = acc + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
2059            acc = acc + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
2060            acc = acc + 2.0 * Self::rowwise_dot(&ri_k, x_rj);
2061            if x_tau_tau_is_some {
2062                let rij_k = fast_ab(x_rij, k);
2063                acc = acc + 2.0 * Self::rowwise_dot(&rij_k, x_r);
2064            }
2065            acc
2066        };
2067
2068        // ─────────────────────────────────────────────────────────────────
2069        //  Γ, Γ̇_i, Γ̇_j, Γ̈_ij  (diagonal row-weight n-vectors)
2070        // ─────────────────────────────────────────────────────────────────
2071        //   γ        = w'' ⊙ h
2072        //   γ̇_i     = w''' ⊙ η̇_i ⊙ h + w'' ⊙ ḣ_i
2073        //   γ̈_ij   = w'''' ⊙ η̇_i ⊙ η̇_j ⊙ h
2074        //            + w''' ⊙ η̈_ij ⊙ h
2075        //            + w''' ⊙ η̇_i ⊙ ḣ_j
2076        //            + w''' ⊙ η̇_j ⊙ ḣ_i
2077        //            + w'' ⊙ ḧ_ij
2078        let gamma = &self.w2 * &self.h_diag;
2079        let gamma_dot_i = &(&(&self.w3 * deta_i) * &self.h_diag) + &(&self.w2 * dh_i);
2080        let gamma_dot_j = &(&(&self.w3 * deta_j) * &self.h_diag) + &(&self.w2 * dh_j);
2081        let gamma_ddot = &(&(&(&self.w4 * deta_i) * deta_j) * &self.h_diag)
2082            + &(&(&(&self.w3 * deta_ij) * &self.h_diag)
2083                + &(&(&self.w3 * deta_i) * dh_j)
2084                + &(&(&self.w3 * deta_j) * dh_i)
2085                + &(&self.w2 * &dh_ij));
2086
2087        // ─────────────────────────────────────────────────────────────────
2088        //  Diagonal-term β-rhs contributions:
2089        //    ∂²(XᵀΓX)/∂τ_i∂τ_j · V
2090        //  = X_{ij}ᵀ (γ ⊙ η_V)       + Xᵀ (γ ⊙ η_{ij,V})         [if X_ij]
2091        //    + X_iᵀ (γ ⊙ η_{j,V})    + X_jᵀ (γ ⊙ η_{i,V})
2092        //    + X_iᵀ (γ̇_j ⊙ η_V)     + Xᵀ (γ̇_j ⊙ η_{i,V})
2093        //    + X_jᵀ (γ̇_i ⊙ η_V)     + Xᵀ (γ̇_i ⊙ η_{j,V})
2094        //    + Xᵀ (γ̈_ij ⊙ η_V).
2095        // ─────────────────────────────────────────────────────────────────
2096        let mut diag_term = Array2::<f64>::zeros((p, m));
2097        let gamma_col = gamma.view().insert_axis(Axis(1));
2098        let gamma_i_col = gamma_dot_i.view().insert_axis(Axis(1));
2099        let gamma_j_col = gamma_dot_j.view().insert_axis(Axis(1));
2100        let gamma_ij_col = gamma_ddot.view().insert_axis(Axis(1));
2101
2102        // X_iᵀ (γ ⊙ η_{j,V}) + X_jᵀ (γ ⊙ η_{i,V})
2103        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_j_v * &gamma_col));
2104        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_i_v * &gamma_col));
2105        // X_iᵀ (γ̇_j ⊙ η_V) + X_jᵀ (γ̇_i ⊙ η_V)
2106        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_v * &gamma_j_col));
2107        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_v * &gamma_i_col));
2108        // Xᵀ (γ̇_j ⊙ η_{i,V}) + Xᵀ (γ̇_i ⊙ η_{j,V})
2109        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_i_v * &gamma_j_col));
2110        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_j_v * &gamma_i_col));
2111        // Xᵀ (γ̈_ij ⊙ η_V)
2112        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_v * &gamma_ij_col));
2113        // X_{ij}ᵀ (γ ⊙ η_V) + Xᵀ (γ ⊙ η_{ij,V})
2114        if x_tau_tau_is_some {
2115            // X_{ij}ᵀ = Q X_{r,ij}ᵀ · diag(1/√a)  (inverse of the reduce shim),
2116            // but caller supplies the reduced second-derivative design.  We
2117            // form X_{ij}ᵀ Y as q_basis · (X_{r,ij}ᵀ · diag(1/√a)·Y) = Q · X_{r,ij}ᵀ (Y unscaled).
2118            // When no observation weights, X_{r,ij} = X_{ij} Q and
2119            //   X_{ij}ᵀ Y = Q X_{r,ij}ᵀ Y.
2120            let y: Array2<f64> = &eta_v * &gamma_col;
2121            let xt_ij_y: Array2<f64> = if self.observation_weight_sqrt.is_some() {
2122                let mut y_scaled = y.clone();
2123                RemlState::scale_rows_by_inverse_observation_weight_sqrt(
2124                    &mut y_scaled,
2125                    self.observation_weight_sqrt.as_ref(),
2126                );
2127                self.q_basis.dot(&x_rij.t().dot(&y_scaled))
2128            } else {
2129                self.q_basis.dot(&x_rij.t().dot(&y))
2130            };
2131            diag_term = diag_term + xt_ij_y;
2132            diag_term = diag_term + self.x_dense_t.dot(&(&eta_ij_v * &gamma_col));
2133        }
2134
2135        // ─────────────────────────────────────────────────────────────────
2136        //  BᵀPB branch — 9-term expansion.
2137        //
2138        //  Represent each B-like operator as an "n-row scaling vector for the
2139        //  X part plus tails along X_τ and X_{ij}".  For rhs V, define the
2140        //  row-scaled η-space blocks R(B) = diag(scale) X V + tails.  Then
2141        //  Bᵀ (P action) R is assembled by row-scaling and left-multiplying
2142        //  the appropriate full designs.
2143        // ─────────────────────────────────────────────────────────────────
2144
2145        // B V row-block (eta-space):  B V = diag(w') X V.
2146        let w1_col = self.w1.view().insert_axis(Axis(1));
2147        let b_v = &eta_v * &w1_col;
2148
2149        // Ḃ_i V = diag(w'' ⊙ η̇_i) X V + diag(w') X_i V.
2150        let w2_deta_i = &self.w2 * deta_i;
2151        let w2_deta_j = &self.w2 * deta_j;
2152        let w2_deta_i_col = w2_deta_i.view().insert_axis(Axis(1));
2153        let w2_deta_j_col = w2_deta_j.view().insert_axis(Axis(1));
2154        let bdot_i_v = &(&eta_v * &w2_deta_i_col) + &(&eta_i_v * &w1_col);
2155        let bdot_j_v = &(&eta_v * &w2_deta_j_col) + &(&eta_j_v * &w1_col);
2156
2157        // B̈_{ij} V =
2158        //   diag(w''' ⊙ η̇_i ⊙ η̇_j + w'' ⊙ η̈_ij) X V
2159        //   + diag(w'' ⊙ η̇_i) X_j V
2160        //   + diag(w'' ⊙ η̇_j) X_i V
2161        //   + diag(w') X_{ij} V.
2162        let w3_didj = &(&self.w3 * deta_i) * deta_j;
2163        let w2_dij = &self.w2 * deta_ij;
2164        let bddot_scale = &w3_didj + &w2_dij;
2165        let bddot_scale_col = bddot_scale.view().insert_axis(Axis(1));
2166        let mut bddot_ij_v = &eta_v * &bddot_scale_col;
2167        bddot_ij_v += &(&eta_j_v * &w2_deta_i_col);
2168        bddot_ij_v += &(&eta_i_v * &w2_deta_j_col);
2169        bddot_ij_v += &(&eta_ij_v * &w1_col);
2170
2171        // P V  (columnwise, using K ⊙ K Hadamard gram on Z = X_r).
2172        let p_bv = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &b_v);
2173        let p_bddot_ij_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bddot_ij_v);
2174
2175        // Ṗ_i, Ṗ_j applied to B V, Ḃ_j V, Ḃ_i V — use the existing
2176        // apply_mtau_to_matrix helper, which computes 2(M ⊙ Ṁ_τ) · mat.
2177        //
2178        // Construct a lightweight "FirthTauPartialKernel"-shaped tuple only for
2179        // apply_mtau_to_matrix; we mirror its input contract inline to avoid
2180        // owning a FirthTauPartialKernel copy here.
2181        let pdot_i_bv = self.apply_mtau_from_reduced(x_ri, dot_k_i, &b_v);
2182        let pdot_j_bv = self.apply_mtau_from_reduced(x_rj, dot_k_j, &b_v);
2183        let pdot_i_bdot_j_v = self.apply_mtau_from_reduced(x_ri, dot_k_i, &bdot_j_v);
2184        let pdot_j_bdot_i_v = self.apply_mtau_from_reduced(x_rj, dot_k_j, &bdot_i_v);
2185
2186        // P Ḃ_j V and P Ḃ_i V.
2187        let p_bdot_j_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_j_v);
2188        let p_bdot_i_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_i_v);
2189
2190        // P̈_{ij} V = 4 (Ṁ_i ⊙ Ṁ_j) V  + 2 (M ⊙ M̈_{ij}) V.
2191        let p_ddot_b_v = self.apply_p_ddot_ij(
2192            x_r,
2193            x_ri,
2194            x_rj,
2195            x_rij,
2196            k,
2197            dot_k_i,
2198            dot_k_j,
2199            &k_ddot,
2200            x_tau_tau_is_some,
2201            &b_v,
2202        );
2203
2204        // Assemble 9 terms of D²(BᵀPB)[τ_i, τ_j] · V.
2205        //   term1 = B̈_ijᵀ P B V + Bᵀ P B̈_ij V
2206        //   term2 = Ḃ_iᵀ P Ḃ_j V + Ḃ_jᵀ P Ḃ_i V
2207        //   term3 = Ḃ_iᵀ Ṗ_j B V + Bᵀ Ṗ_j Ḃ_i V
2208        //   term4 = Ḃ_jᵀ Ṗ_i B V + Bᵀ Ṗ_i Ḃ_j V
2209        //   term5 = Bᵀ P̈_ij B V
2210        //
2211        // "Bᵀ Q V" with B = diag(w') X equals left_scaled_xt(w1, Q V).
2212        // "Ḃ_iᵀ Q V" = diag(w'' ⊙ η̇_i) X acting on the left, plus
2213        //              diag(w') X_i on the left.  In transpose:
2214        //   Ḃ_iᵀ Q V = Xᵀ (diag(w'' ⊙ η̇_i) Q V) + X_iᵀ (diag(w') Q V).
2215        // "B̈_ijᵀ Q V" mirrors B̈_ij above in transpose.
2216
2217        let apply_bdot_tau_t =
2218            |scale_deta: &Array1<f64>, x_tau_mat: &Array2<f64>, q_v: &Array2<f64>| {
2219                let scale_col = scale_deta.view().insert_axis(Axis(1));
2220                self.x_dense_t.dot(&(q_v * &scale_col)) + x_tau_mat.t().dot(&(q_v * &w1_col))
2221            };
2222
2223        let apply_bddot_ij_t = |q_v: &Array2<f64>| -> Array2<f64> {
2224            let scale_col_full = bddot_scale.view().insert_axis(Axis(1));
2225            let mut out = self.x_dense_t.dot(&(q_v * &scale_col_full));
2226            out = out + x_tau_j.t().dot(&(q_v * &w2_deta_i_col));
2227            out = out + x_tau_i.t().dot(&(q_v * &w2_deta_j_col));
2228            if x_tau_tau_is_some {
2229                // X_{ij}ᵀ (w1 ⊙ Q V)
2230                let y = q_v * &w1_col;
2231                let contrib: Array2<f64> = if self.observation_weight_sqrt.is_some() {
2232                    let mut y_scaled = y.clone();
2233                    RemlState::scale_rows_by_inverse_observation_weight_sqrt(
2234                        &mut y_scaled,
2235                        self.observation_weight_sqrt.as_ref(),
2236                    );
2237                    self.q_basis.dot(&x_rij.t().dot(&y_scaled))
2238                } else {
2239                    self.q_basis.dot(&x_rij.t().dot(&y))
2240                };
2241                out = out + contrib;
2242            }
2243            out
2244        };
2245
2246        // term1
2247        let t1a = apply_bddot_ij_t(&p_bv);
2248        let t1b = self.left_scaled_xt(&self.w1, &p_bddot_ij_v);
2249        // term2
2250        let t2a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &p_bdot_j_v);
2251        let t2b = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &p_bdot_i_v);
2252        // term3: Ḃ_iᵀ Ṗ_j B V + Bᵀ Ṗ_j Ḃ_i V
2253        let t3a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &pdot_j_bv);
2254        let t3b = self.left_scaled_xt(&self.w1, &pdot_j_bdot_i_v);
2255        // term4: Ḃ_jᵀ Ṗ_i B V + Bᵀ Ṗ_i Ḃ_j V
2256        let t4a = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &pdot_i_bv);
2257        let t4b = self.left_scaled_xt(&self.w1, &pdot_i_bdot_j_v);
2258        // term5
2259        let t5 = self.left_scaled_xt(&self.w1, &p_ddot_b_v);
2260
2261        let d2_bpb = t1a + t1b + t2a + t2b + t3a + t3b + t4a + t4b + t5;
2262
2263        0.5 * (diag_term - d2_bpb)
2264    }
2265
2266    /// Pair-level exact Firth kernel at fixed β for a (τ_i, τ_j) outer
2267    /// coordinate pair.
2268    ///
2269    /// Returns the two SCALAR- and P-VECTOR-valued second-derivative
2270    /// objects that the unified REML evaluator threads into
2271    /// `HyperCoordPair::{a,g}` as additive Firth contributions, plus an
2272    /// optional prepared Primitive-A `FirthTauTauPartialKernel` that the
2273    /// pair-callback can reuse for the `b_operator` action.
2274    ///
        /// # Arguments
        /// * `x_tau_i`, `x_tau_j` — first τ-derivatives of the design. Each is
        ///   contracted with `beta` to give η̇_i, η̇_j and passed through
        ///   `reduce_explicit_design` to give X_{r,i}, X_{r,j}.
        /// * `x_tau_tau` — optional mixed second τ-derivative of the design.
        ///   When `None`, η̈_{ij} is replaced by a zero vector and every
        ///   X_{ij} / X_{r,ij} term below is skipped.
        /// * `beta` — coefficient vector held fixed while differentiating in τ.
        /// * `include_hphi_tau_tau_kernel` — when `true`, the intermediates are
        ///   additionally handed to `hphi_tau_tau_partial_prepare_from_partials`
        ///   and the result is returned in `tau_tau_kernel`.
        ///
        /// # Returns
        /// A `FirthTauTauExactKernel` with `phi_tau_tau_partial` (scalar),
        /// `gphi_tau_tau` (vector of length `x_dense.ncols()`), and
        /// `tau_tau_kernel` (`Some` only when requested).
        ///
2275    /// ═════════════════════════════════════════════════════════════════════
2276    ///  DERIVATIONS (fixed β, reduced-basis identifiable coords).
2277    ///
2278    ///  Φ = 0.5 log|I_r| − 0.5 log|S_r|,   K_r = I_r⁻¹,   G_r = diag(S_r⁻¹)
2279    ///  Φ_{τ_i}|β = 0.5 tr(K_r İ_{r,i}) − 0.5 tr(G_r Ṡ_{r,i}).
2280    ///
2281    /// ┌── pair.a scalar Φ_{τ_i τ_j}|β ────────────────────────────────────┐
2282    ///  ∂/∂τ_j [0.5 tr(K_r İ_{r,i})]
2283    ///    = 0.5 tr(K̇_{r,j} İ_{r,i}) + 0.5 tr(K_r Ï_{r,ij})
2284    ///    = −0.5 tr(K_r İ_{r,j} K_r İ_{r,i}) + 0.5 tr(K_r Ï_{r,ij})
2285    ///
2286    ///  ∂/∂τ_j [−0.5 tr(G_r Ṡ_{r,i})]
2287    ///    = −0.5 tr(Ġ_{r,j} Ṡ_{r,i}) − 0.5 tr(G_r S̈_{r,ij})
2288    ///  (G_r is diagonal in the canonical basis, but Ṡ_{r,j} is not, so
2289    ///   Ġ_{r,j} = −G_r Ṡ_{r,j} G_r, i.e. (Ġ_{r,j})_kl = −G_k (Ṡ_{r,j})_kl G_l.)
2290    ///
2291    ///  Ï_{r,ij} is the same 9-term Fisher cross used by Primitive A
2292    ///  (see `hphi_tau_tau_partial_apply`:i_ddot block).
2293    ///
2294    ///  S̈_{r,ij} = X_{r,ij}^T X_r + X_r^T X_{r,ij}
2295    ///            + X_{r,i}^T X_{r,j} + X_{r,j}^T X_{r,i}.
2296    /// └─────────────────────────────────────────────────────────────────┘
2297    ///
2298    /// ┌── pair.g p-vector (gΦ)_{τ_i τ_j}|β ───────────────────────────────┐
2299    ///  (gΦ)_{τ_i} = 0.5 X_{τ_i}^T (w1 ⊙ h)
2300    ///              + 0.5 X^T [ (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i ]
2301    ///
2302    ///  Differentiating wrt τ_j at fixed β, using η̇_α = X_α β, η̈_{ij} =
2303    ///  X_{ij} β (when x_tau_tau is provided, else 0), and ḣ_α, ḧ_{ij}
2304    ///  from Primitive A:
2305    ///
2306    ///  term_A = 0.5 ∂/∂τ_j [X_{τ_i}^T (w1 ⊙ h)]
2307    ///        = 0.5 X_{τ_i τ_j}^T (w1 ⊙ h)          [if X_{ij} present]
2308    ///        + 0.5 X_{τ_i}^T [ (w2 ⊙ η̇_j) ⊙ h + w1 ⊙ ḣ_j ]
2309    ///
2310    ///  term_B = 0.5 ∂/∂τ_j [X^T · v_{τ_i}] with
2311    ///            v_{τ_i} = (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i
2312    ///        = 0.5 X_{τ_j}^T v_{τ_i}
2313    ///        + 0.5 X^T · v̇_{τ_i,τ_j}
2314    ///
2315    ///  where the inner derivative
2316    ///  v̇_{τ_i,τ_j} = (w3 ⊙ η̇_j ⊙ η̇_i) ⊙ h   (from ∂w2 = w3 ⊙ η̇_j)
2317    ///              + (w2 ⊙ η̈_ij) ⊙ h          (from ∂η̇_i = η̈_{ij})
2318    ///              + (w2 ⊙ η̇_i) ⊙ ḣ_j        (from ∂h = ḣ_j)
2319    ///              + (w2 ⊙ η̇_j) ⊙ ḣ_i        (from ∂w1 = w2 ⊙ η̇_j, ⊙ ḣ_i)
2320    ///              +  w1 ⊙ ḧ_{ij}             (from ∂ḣ_i = ḧ_{ij}).
2321    /// └─────────────────────────────────────────────────────────────────┘
2322    ///
2323    /// ALL Ï_{r,ij}, η̈_{ij}, ḣ_i, ḧ_{ij} computations are identical to
2324    /// those already computed inside Primitive A's `hphi_tau_tau_partial_apply`.
2325    /// We replicate only the pieces needed to yield the scalar and p-vector
2326    /// outputs to avoid computing the full p×m action when unnecessary.
2327    ///
2328    pub(crate) fn exact_tau_tau_kernel(
2329        &self,
2330        x_tau_i: &Array2<f64>,
2331        x_tau_j: &Array2<f64>,
2332        x_tau_tau: Option<&Array2<f64>>,
2333        beta: &Array1<f64>,
2334        include_hphi_tau_tau_kernel: bool,
2335    ) -> FirthTauTauExactKernel {
            // η̇_i = X_{τ_i} β, η̇_j = X_{τ_j} β and, when supplied, η̈_{ij} = X_{τ_i τ_j} β.
2336        let deta_i = x_tau_i.dot(beta);
2337        let deta_j = x_tau_j.dot(beta);
2338        let deta_ij = x_tau_tau.as_ref().map(|xij| xij.dot(beta));
2339
            // Reduced-basis images X_{r,i}, X_{r,j}, X_{r,ij} of the design derivatives.
2340        let x_tau_i_reduced = self.reduce_explicit_design(x_tau_i);
2341        let x_tau_j_reduced = self.reduce_explicit_design(x_tau_j);
2342        let x_tau_tau_reduced = x_tau_tau.map(|xij| self.reduce_explicit_design(xij));
2343
            // Per-direction first-order partials (İ_{r,α}, ḣ_α), named after the
            // tuple returned by `dot_i_and_h_from_reduced`.
2344        let (dot_i_i, dot_h_i) = self.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
2345        let (dot_i_j, dot_h_j) = self.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
2346
2347        // Ï_{r,ij} = X_{r,ij}^T W X_r + X_r^T W X_{r,ij}
2348        //            + X_{r,i}^T W X_{r,j} + X_{r,j}^T W X_{r,i}
2349        //            + X_{r,i}^T Ẇ_j X_r + X_r^T Ẇ_j X_{r,i}
2350        //            + X_{r,j}^T Ẇ_i X_r + X_r^T Ẇ_i X_{r,j}
2351        //            + X_r^T Ẅ_{ij} X_r
2352        // Ẇ_α = diag(w' ⊙ η̇_α);  Ẅ_{ij} = diag(w'' ⊙ η̇_i ⊙ η̇_j + w' ⊙ η̈_{ij}).
            // η̈_{ij} falls back to a zero n-vector when no X_{ij} was supplied, so
            // the formulas below need no branching on its presence.
2353        let zeros_n = Array1::<f64>::zeros(self.x_dense.nrows());
2354        let deta_ij_ref: &Array1<f64> = deta_ij.as_ref().unwrap_or(&zeros_n);
2355        let dw_i = &self.w1 * &deta_i;
2356        let dw_j = &self.w1 * &deta_j;
2357        let ddw_ij = &(&self.w2 * &(&deta_i * &deta_j)) + &(&self.w1 * deta_ij_ref);
2358
            // Accumulate the nine Ï_{r,ij} terms in the order listed above; the
            // two X_{r,ij} terms are present only when X_{ij} was supplied.
2359        let x_r = &self.x_reduced;
2360        let mut i_ddot = Array2::<f64>::zeros(self.k_reduced.raw_dim());
2361        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2362            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
2363            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
2364        }
2365        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, &x_tau_j_reduced, &self.w);
2366        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, &x_tau_i_reduced, &self.w);
2367        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, x_r, &dw_j);
2368        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_i_reduced, &dw_j);
2369        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, x_r, &dw_i);
2370        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_j_reduced, &dw_i);
2371        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);
2372
2373        // pair.a likelihood contribution:
2374        //   0.5 tr(K_r Ï_{r,ij}) − 0.5 tr(K_r İ_{r,j} K_r İ_{r,i}).
2375        // K_r İ_{r,α} are formed here with plain `dot()`.  NOTE(review): the same
            // `k.dot(&dot_i_*)` products are recomputed below when K̇_{r,α} is built.
2376        let k = &self.k_reduced;
2377        let k_dot_i_i = k.dot(&dot_i_i);
2378        let k_dot_i_j = k.dot(&dot_i_j);
2379        let a_lik = 0.5 * RemlState::trace_product(k, &i_ddot)
2380            - 0.5 * RemlState::trace_product(&k_dot_i_j, &k_dot_i_i);
2381
2382        // pair.a penalty-basis contribution:
2383        //   Ṡ_{r,α} = X_{r,α}^T X_r + X_r^T X_{r,α}
2384        //   S̈_{r,ij} = X_{r,ij}^T X_r + X_r^T X_{r,ij}
2385        //            + X_{r,i}^T X_{r,j} + X_{r,j}^T X_{r,i}
2386        //   tr(G_r S̈_{r,ij}) = Σ_k G_r_kk · (S̈_{r,ij})_kk
2387        //   tr(Ġ_{r,j} Ṡ_{r,i}) = −Σ_{k,l} G_r_kk G_r_ll (Ṡ_{r,j})_kl (Ṡ_{r,i})_lk
2388        let dot_s_i = fast_atb(&x_tau_i_reduced, x_r) + fast_atb(x_r, &x_tau_i_reduced);
2389        let dot_s_j = fast_atb(&x_tau_j_reduced, x_r) + fast_atb(x_r, &x_tau_j_reduced);
2390        let mut s_ddot = Array2::<f64>::zeros(k.raw_dim());
2391        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2392            s_ddot = s_ddot + fast_atb(x_rij, x_r) + fast_atb(x_r, x_rij);
2393        }
2394        s_ddot = s_ddot
2395            + fast_atb(&x_tau_i_reduced, &x_tau_j_reduced)
2396            + fast_atb(&x_tau_j_reduced, &x_tau_i_reduced);
2397        // With G_r = diag(g) in the canonical reduced basis (where
2398        // S_r is diagonal), S_r + τ·Ṡ is generally non-diagonal under
2399        // perturbation, so Ġ_j = −G Ṡ_j G picks up OFF-DIAGONAL terms:
2400        //     (Ġ_j)_{kl} = −G_k · (Ṡ_j)_{kl} · G_l.
2401        // Hence tr(Ġ_j Ṡ_i) = −Σ_{k,l} G_k G_l (Ṡ_j)_{kl} (Ṡ_i)_{lk}.
2402        // Using symmetry of Ṡ_i (and Ṡ_j):
2403        //     −0.5 tr(Ġ_j Ṡ_i) = +0.5 Σ_{k,l} G_k G_l (Ṡ_j)_{kl} (Ṡ_i)_{kl}.
2404        // The S̈_{ij} trace against diagonal G_r picks only the diagonal.
            // `x_metric_reduced_inv_diag` is used here as the diagonal g of G_r.
2405        let g_inv = &self.x_metric_reduced_inv_diag;
2406        let rdim = k.nrows();
2407        let mut a_pen = 0.0_f64;
2408        for kk in 0..rdim {
2409            for ll in 0..rdim {
2410                a_pen += 0.5 * g_inv[kk] * g_inv[ll] * dot_s_j[[kk, ll]] * dot_s_i[[kk, ll]];
2411            }
2412            a_pen -= 0.5 * g_inv[kk] * s_ddot[[kk, kk]];
2413        }
2414        let phi_tau_tau_partial = a_lik + a_pen;
2415
2416        // ─── pair.g p-vector: (gΦ)_{τ_i τ_j}|β ──────────────────────────
2417        //
2418        // Assemble ḧ_{ij} identically to Primitive A's body.  We need:
2419        //   K̇_{r,α} = −K_r İ_{r,α} K_r,
2420        //   K̈_{r,ij} = −K_r Ï_{r,ij} K_r + K_r İ_{r,i} K_r İ_{r,j} K_r
2421        //                                 + K_r İ_{r,j} K_r İ_{r,i} K_r.
2422        let dot_k_i = -k.dot(&dot_i_i).dot(k);
2423        let dot_k_j = -k.dot(&dot_i_j).dot(k);
2424        let a_i_red = -&dot_k_i; // K İ_i K
2425        let a_j_red = -&dot_k_j; // K İ_j K
2426        let k_ddot: Array2<f64> =
2427            -k.dot(&i_ddot).dot(k) + a_i_red.dot(&dot_i_j).dot(k) + a_j_red.dot(&dot_i_i).dot(k);
2428
2429        // ḧ_{ij} = 2 diag(X_{r,ij} K X_r^T)
2430        //        + diag(X_r K̈_{ij} X_r^T)
2431        //        + 2 diag(X_{r,i} K̇_j X_r^T)
2432        //        + 2 diag(X_{r,j} K̇_i X_r^T)
2433        //        + 2 diag(X_{r,i} K X_{r,j}^T).
            // Each diag(A C Bᵀ) is evaluated as rowwise_dot(A·C, B), so no n×n
            // matrix is formed in this block.
2434        let n = self.x_dense.nrows();
2435        let mut dh_ij = Array1::<f64>::zeros(n);
2436        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2437            let rij_k = x_rij.dot(k);
2438            dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rij_k, x_r);
2439        }
2440        let xr_kddot = x_r.dot(&k_ddot);
2441        dh_ij = dh_ij + Self::rowwise_dot(&xr_kddot, x_r);
2442        let ri_kdot_j = x_tau_i_reduced.dot(&dot_k_j);
2443        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
2444        let rj_kdot_i = x_tau_j_reduced.dot(&dot_k_i);
2445        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
2446        let ri_k = x_tau_i_reduced.dot(k);
2447        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_k, &x_tau_j_reduced);
2448
2449        // term_A = 0.5 X_{τ_i τ_j}^T (w1 ⊙ h)
2450        //        + 0.5 X_{τ_i}^T [ (w2 ⊙ η̇_j) ⊙ h + w1 ⊙ ḣ_j ]
2451        let w1_h = &self.w1 * &self.h_diag;
2452        let mut gphi_tau_tau = Array1::<f64>::zeros(self.x_dense.ncols());
2453        if let Some(x_ij) = x_tau_tau.as_ref() {
2454            gphi_tau_tau = gphi_tau_tau + 0.5 * x_ij.t().dot(&w1_h);
2455        }
2456        let inner_j = &(&(&self.w2 * &deta_j) * &self.h_diag) + &(&self.w1 * &dot_h_j);
2457        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_i.t().dot(&inner_j);
2458
2459        // term_B pieces:  v_{τ_i} = (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i
2460        let v_tau_i = &(&(&self.w2 * &deta_i) * &self.h_diag) + &(&self.w1 * &dot_h_i);
2461        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_j.t().dot(&v_tau_i);
2462
2463        // v̇_{τ_i,τ_j} =
2464        //    (w3 ⊙ η̇_j ⊙ η̇_i) ⊙ h
2465        //  + (w2 ⊙ η̈_{ij}) ⊙ h
2466        //  + (w2 ⊙ η̇_i) ⊙ ḣ_j
2467        //  + (w2 ⊙ η̇_j) ⊙ ḣ_i
2468        //  +  w1 ⊙ ḧ_{ij}.
2469        let mut v_dot_ij = &(&(&self.w3 * &deta_j) * &deta_i) * &self.h_diag;
2470        v_dot_ij += &(&(&self.w2 * deta_ij_ref) * &self.h_diag);
2471        v_dot_ij += &(&(&self.w2 * &deta_i) * &dot_h_j);
2472        v_dot_ij += &(&(&self.w2 * &deta_j) * &dot_h_i);
2473        v_dot_ij += &(&self.w1 * &dh_ij);
2474        gphi_tau_tau = gphi_tau_tau + 0.5 * self.x_dense.t().dot(&v_dot_ij);
2475
            // The reduced designs and first-order partials are moved (not cloned)
            // into the prepared kernel, so this must stay their last use.
2476        let tau_tau_kernel = if include_hphi_tau_tau_kernel {
2477            Some(self.hphi_tau_tau_partial_prepare_from_partials(
2478                x_tau_i_reduced,
2479                x_tau_j_reduced,
2480                &deta_i,
2481                &deta_j,
2482                dot_h_i,
2483                dot_h_j,
2484                dot_i_i,
2485                dot_i_j,
2486                x_tau_tau_reduced,
2487                deta_ij,
2488            ))
2489        } else {
2490            None
2491        };
2492
2493        FirthTauTauExactKernel {
2494            phi_tau_tau_partial,
2495            gphi_tau_tau,
2496            tau_tau_kernel,
2497        }
2498    }
2499
2500    /// Apply `Ṗ_τ V = 2 (M ⊙ Ṁ_τ) V` given the reduced τ-drift design
2501    /// `x_tau_reduced` and the reduced Fisher-inverse drift `dot_k_reduced`.
2502    ///
2503    /// This mirrors the body of `apply_mtau_to_matrix` but accepts the
2504    /// x_tau/dot_k pieces directly, letting Primitive A reuse the same
2505    /// matrix-free Ṗ_τ applies without owning a `FirthTauPartialKernel`.
        ///
        /// Each column `v` of `mat` is handled independently and written back as
        /// `2 · (t1 + t2 + t3)`, one term per piece of
        /// `Ṁ_τ = X_{r,τ} K X_rᵀ + X_r K X_{r,τ}ᵀ + X_r K̇_τ X_rᵀ`.
        ///
        /// Returns a matrix with the same shape as `mat`.  If `mat` has no
        /// columns, or its row count differs from `x_dense.nrows()`, the result
        /// is all zeros and no work is done.
2506    pub(crate) fn apply_mtau_from_reduced(
2507        &self,
2508        x_tau_reduced: &Array2<f64>,
2509        dot_k_reduced: &Array2<f64>,
2510        mat: &Array2<f64>,
2511    ) -> Array2<f64> {
            // Shape guard: a mismatched or empty rhs yields zeros instead of a panic.
2512        if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
2513            return Array2::<f64>::zeros(mat.raw_dim());
2514        }
2515        let mut out = Array2::<f64>::zeros(mat.raw_dim());
2516        for col in 0..mat.ncols() {
2517            let v = mat.column(col).to_owned();
                // t1: K · (X_rᵀ diag(v) X_r) · K contracted row-wise between X_r and
                // X_{r,τ} (presumably rowwise_bilinear(A, C, B)_i = a_iᵀ C b_i — TODO confirm).
2518            let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
2519            let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
2520            let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, x_tau_reduced);
2521
                // t2: same sandwich with the cross Gram of X_r and X_{r,τ}, read off
                // as a row-wise quadratic form in X_r.
2522            let szt = RemlState::reduced_crossweighted_gram(&self.x_reduced, x_tau_reduced, &v);
2523            let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
2524            let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
2525
                // t3: the K̇_τ piece, delegated to the Hadamard-Gram helper with
                // kernels (K, K̇_τ) on Z = X_r.
2526            let t3 =
2527                RemlState::apply_hadamard_gram(&self.x_reduced, &self.k_reduced, dot_k_reduced, &v);
2528
                // Leading factor 2 from ∂(M ⊙ M)/∂τ = 2 (M ⊙ Ṁ_τ).
2529            let y = 2.0 * (t1 + t2 + t3);
2530            out.column_mut(col).assign(&y);
2531        }
2532        out
2533    }
2534
2535    /// Apply `P̈_{ij} V = 2 (Ṁ_i ⊙ Ṁ_j) V + 2 (M ⊙ M̈_{ij}) V` columnwise.
2536    ///
2537    /// `M̈_{ij}` expands into 9 pieces `Y_α C Y_βᵀ`; `Ṁ_i ⊙ Ṁ_j` into 9 cross
2538    /// pieces `(Y_{1,α} B_{1,α} W_{1,α}ᵀ) ⊙ (Y_{2,β} B_{2,β} W_{2,β}ᵀ)`.  Both
2539    /// are evaluated via the matrix-free identities:
2540    ///
2541    ///   [(ZAZᵀ) ⊙ (YBWᵀ) v]_i   = rowwise_bilinear(Y, B · (Wᵀdiag(v)Z) · A, Z)_i,
2542    ///   [(YBWᵀ) ⊙ (Y'B'W'ᵀ) v]_i= rowwise_bilinear(Y, B · (Wᵀdiag(v)W') · B'ᵀ, Y')_i,
2543    ///
2544    /// with the `Wᵀ diag(v) ·` factors supplied by the reduced (cross-)weighted Grams.
        ///
        /// `x_r`, `x_ri`, `x_rj`, `x_rij` are the reduced design and its τ-derivatives
        /// (Z, Y_i, Y_j, Y_ij in the comments below); `k`, `dot_k_i`, `dot_k_j`,
        /// `k_ddot` are K and its τ-derivatives.  `x_rij` is read only when
        /// `x_tau_tau_is_some` is `true`.
        ///
        /// Returns an `n × m` matrix (`n = x_dense.nrows()`, `m = mat.ncols()`).
        /// If `mat` has no columns or a different row count, the result is an
        /// all-zero matrix of `mat`'s shape.
2545    pub(crate) fn apply_p_ddot_ij(
2546        &self,
2547        x_r: &Array2<f64>,
2548        x_ri: &Array2<f64>,
2549        x_rj: &Array2<f64>,
2550        x_rij: &Array2<f64>,
2551        k: &Array2<f64>,
2552        dot_k_i: &Array2<f64>,
2553        dot_k_j: &Array2<f64>,
2554        k_ddot: &Array2<f64>,
2555        x_tau_tau_is_some: bool,
2556        mat: &Array2<f64>,
2557    ) -> Array2<f64> {
2558        let n = self.x_dense.nrows();
2559        let m = mat.ncols();
            // Shape guard: a mismatched or empty rhs yields zeros instead of a panic.
2560        if mat.nrows() != n || m == 0 {
2561            return Array2::<f64>::zeros(mat.raw_dim());
2562        }
2563        let mut out = Array2::<f64>::zeros((n, m));
2564        for col in 0..m {
2565            let v = mat.column(col).to_owned();
2566            // Shared reducedweighted Grams for this column.  Only the Grams
2567            // actually appearing in the 18 pieces below are computed.
2568            let s_zz = RemlState::reducedweighted_gram(x_r, &v); // Z'diag(v)Z
2569            let s_zj = RemlState::reduced_crossweighted_gram(x_r, x_rj, &v); // Z'diag(v)Y_j
2570            let s_iz = RemlState::reduced_crossweighted_gram(x_ri, x_r, &v); // Y_i'diag(v)Z
2571            let s_jz = RemlState::reduced_crossweighted_gram(x_rj, x_r, &v); // Y_j'diag(v)Z
2572            let s_ij = RemlState::reduced_crossweighted_gram(x_ri, x_rj, &v); // Y_i'diag(v)Y_j
2573
2574            // ── (Ṁ_i ⊙ Ṁ_j) v  (scaled by 2 at assembly) ──
2575            // Ṁ_i has three pieces:
2576            //   P_i,a = Y_i K Zᵀ          — Y=Y_i, B=K, W=Z
2577            //   P_i,b = Z K̇_i Zᵀ         — Y=Z,   B=K̇_i, W=Z
2578            //   P_i,c = Z K Y_iᵀ          — Y=Z,   B=K,  W=Y_i
2579            // And symmetrically for Ṁ_j with (i→j).
2580            //
2581            // For each cross pair (α, β), compute
2582            //   core = B_α · (W_αᵀ diag(v) W_β) · B_βᵀ,
2583            //   y_piece = rowwise_bilinear(Y_α, core, Y_β),
2584            // then sum all 9; the factor 2 is applied when the column is assembled.
2585            let mut mdot_mdot = Array1::<f64>::zeros(n);
2586            // (a_i, a_j): Y_i, K, Z  ×  Y_j, K, Z  → W_α=Z, W_β=Z, S = s_zz
2587            {
2588                let core = k.dot(&s_zz).dot(&k.t());
2589                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_rj);
2590            }
2591            // (a_i, b_j): Y_i, K, Z  ×  Z, K̇_j, Z  → S = s_zz; core = K · s_zz · K̇_jᵀ
2592            {
2593                let core = k.dot(&s_zz).dot(&dot_k_j.t());
2594                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2595            }
2596            // (a_i, c_j): Y_i, K, Z  ×  Z, K, Y_j  → S = s_zj; core = K · s_zj · Kᵀ
2597            {
2598                let core = k.dot(&s_zj).dot(&k.t());
2599                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2600            }
2601            // (b_i, a_j): Z, K̇_i, Z  ×  Y_j, K, Z  → S = s_zz; core = K̇_i · s_zz · Kᵀ
2602            {
2603                let core = dot_k_i.dot(&s_zz).dot(&k.t());
2604                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2605            }
2606            // (b_i, b_j): Z, K̇_i, Z  ×  Z, K̇_j, Z  → S = s_zz; core = K̇_i · s_zz · K̇_jᵀ
2607            {
2608                let core = dot_k_i.dot(&s_zz).dot(&dot_k_j.t());
2609                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2610            }
2611            // (b_i, c_j): Z, K̇_i, Z  ×  Z, K, Y_j  → S = s_zj; core = K̇_i · s_zj · Kᵀ
2612            {
2613                let core = dot_k_i.dot(&s_zj).dot(&k.t());
2614                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2615            }
2616            // (c_i, a_j): Z, K, Y_i  ×  Y_j, K, Z  → S = Y_iᵀ diag(v) Z = s_iz;
2617            //   core = K · s_iz · Kᵀ
2618            {
2619                let core = k.dot(&s_iz).dot(&k.t());
2620                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2621            }
2622            // (c_i, b_j): Z, K, Y_i  ×  Z, K̇_j, Z  → S = Y_iᵀ diag(v) Z = s_iz;
2623            //   core = K · s_iz · K̇_jᵀ
2624            {
2625                let core = k.dot(&s_iz).dot(&dot_k_j.t());
2626                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2627            }
2628            // (c_i, c_j): Z, K, Y_i  ×  Z, K, Y_j  → S = Y_iᵀ diag(v) Y_j = s_ij;
2629            //   core = K · s_ij · Kᵀ
2630            {
2631                let core = k.dot(&s_ij).dot(&k.t());
2632                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2633            }
2634
2635            // ── (M ⊙ M̈_{ij}) v  (scaled by 2 at assembly) ──
2636            // Each piece has the form Y_α C W_βᵀ; M = Z K Zᵀ with A=K.
2637            // Identity:  [(ZAZᵀ) ⊙ (Y_α C W_βᵀ) v]_i
2638            //          = rowwise_bilinear(Y_α, C · (W_βᵀ diag(v) Z) · A, Z).
2639            let mut m_mddot = Array1::<f64>::zeros(n);
2640            // (a) Y_α = X_{r,ij}, C = K, W_β = X_r  → W_βᵀ diag(v) Z = s_zz
2641            if x_tau_tau_is_some {
2642                let core = k.dot(&s_zz).dot(k);
2643                m_mddot = m_mddot + Self::rowwise_bilinear(x_rij, &core, x_r);
2644            }
2645            // (b) Y_α = X_r, C = K, W_β = X_{r,ij} → W_βᵀ diag(v) Z = X_{r,ij}ᵀ diag(v) Z
2646            if x_tau_tau_is_some {
2647                let s_ijz = RemlState::reduced_crossweighted_gram(x_rij, x_r, &v);
2648                let core = k.dot(&s_ijz).dot(k);
2649                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2650            }
2651            // (c) Y_α = X_{r,i}, C = K̇_j, W_β = X_r → S = s_zz
2652            {
2653                let core = dot_k_j.dot(&s_zz).dot(k);
2654                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2655            }
2656            // (d) Y_α = X_r, C = K̇_j, W_β = X_{r,i} → W_βᵀ diag(v) Z = s_iz
2657            {
2658                let core = dot_k_j.dot(&s_iz).dot(k);
2659                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2660            }
2661            // (e) Y_α = X_{r,j}, C = K̇_i, W_β = X_r → S = s_zz
2662            {
2663                let core = dot_k_i.dot(&s_zz).dot(k);
2664                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2665            }
2666            // (f) Y_α = X_r, C = K̇_i, W_β = X_{r,j} → W_βᵀ diag(v) Z = s_jz
2667            {
2668                let core = dot_k_i.dot(&s_jz).dot(k);
2669                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2670            }
2671            // (g) Y_α = X_{r,i}, C = K, W_β = X_{r,j} → W_βᵀ diag(v) Z = s_jz
2672            {
2673                let core = k.dot(&s_jz).dot(k);
2674                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2675            }
2676            // (h) Y_α = X_{r,j}, C = K, W_β = X_{r,i} → W_βᵀ diag(v) Z = s_iz
2677            {
2678                let core = k.dot(&s_iz).dot(k);
2679                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2680            }
2681            // (i) Y_α = X_r, C = K̈_ij, W_β = X_r → S = s_zz
2682            {
2683                let core = k_ddot.dot(&s_zz).dot(k);
2684                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2685            }
2686
2687            // P̈_{ij} = ∂²(M⊙M)/∂τ_i∂τ_j = 2(Ṁ_i ⊙ Ṁ_j) + 2(M ⊙ M̈_{ij}),
2688            // with Ṁ_τ = ∂M/∂τ (NOT the pair-squared derivative).  Factor is 2,
2689            // not 4 — a factor of 4 would only arise from substituting
2690            // ∂(M⊙M)/∂τ = 2(M⊙Ṁ_τ) for Ṁ_τ in the first term.
2691            let col_out = 2.0 * mdot_mdot + 2.0 * m_mddot;
2692            out.column_mut(col).assign(&col_out);
2693        }
2694        out
2695    }
2696
2697    /// Primitive B — prepare step: assemble the reduced kernel for
2698    /// D_β((H_φ)_τ|_β)[v].
2699    ///
2700    /// Consumes the existing `FirthTauPartialKernel`, the τ-drift partials
2701    /// (`deta_partial = η̇_τ = X_τ β` and `dot_i_partial = İ_τ`), the
2702    /// β-direction `FirthDirection` built from `deta = X v`, and
2703    /// `x_tau_v = X_τ v`, and returns a cached kernel carrying the mixed
2704    /// β-τ reduced quantities A_v, dh_v, D_β(İ_τ)[v], D_β(K̇_τ)[v],
2705    /// D_β(ḣ_τ)[v], and the w-chain derivatives needed by
2706    /// `d_beta_hphi_tau_partial_apply`.
2707    pub(crate) fn d_beta_hphi_tau_partial_prepare_from_partials(
2708        &self,
2709        tau_kernel: &FirthTauPartialKernel,
2710        deta_partial: &Array1<f64>,
2711        dot_i_partial: &Array2<f64>,
2712        beta_direction: &FirthDirection,
2713        x_tau_v: &Array1<f64>,
2714    ) -> FirthTauBetaPartialKernel {
2715        // D_β(İ_τ)[v] — three-piece symmetric form from the product rule on
2716        //   İ_τ = X_{r,τ}ᵀ W X_r + X_rᵀ W X_{r,τ} + X_rᵀ diag(w' ⊙ η̇_τ) X_r,
2717        // where W = diag(w(η)) is the Fisher weight (not its derivative).
2718        // The β-differential hits w (through η=Xβ) and η̇_τ (through X_τ β):
2719        //   D_β(X_{r,τ}ᵀ W X_r)[v] = X_{r,τ}ᵀ diag(w' ⊙ δη_v) X_r,
2720        //   D_β(X_rᵀ diag(w' ⊙ η̇_τ) X_r)[v]
2721        //     = X_rᵀ diag(w'' ⊙ η̇_τ ⊙ δη_v + w' ⊙ δη_{τ,v}) X_r,
2722        // where δη_v = beta_direction.deta, δη_{τ,v} = x_tau_v.
2723        // s_v := w' ⊙ δη_v (same weight the FirthDirection uses to build
2724        // g_u_reduced); b_vvec := w'' ⊙ δη_v = beta_direction.b_uvec is the
2725        // weight for the third-term product-rule piece.
2726        let s_v = &self.w1 * &beta_direction.deta;
2727        let mixed_diag_weight = &(&tau_kernel.dotw1 * &beta_direction.deta) + &(&self.w1 * x_tau_v);
2728        let cross1 =
2729            RemlState::reduced_crossweighted_gram(&tau_kernel.x_tau_reduced, &self.x_reduced, &s_v);
2730        let cross2 =
2731            RemlState::reduced_crossweighted_gram(&self.x_reduced, &tau_kernel.x_tau_reduced, &s_v);
2732        let diag_piece = RemlState::reducedweighted_gram(&self.x_reduced, &mixed_diag_weight);
2733        let d_beta_dot_i = &cross1 + &cross2 + &diag_piece;
2734
2735        // D_β(K̇_τ)[v] — direct Leibniz on K̇_τ = -K_r İ_τ K_r with
2736        //   D_β K_r[v] = -K_r I'_v K_r = -beta_direction.a_u_reduced.
2737        // Expanding yields
2738        //   D_β K̇_τ[v] = +A_v İ_τ K_r − K_r D_β(İ_τ)[v] K_r + K_r İ_τ A_v,
2739        // where A_v := beta_direction.a_u_reduced = +K_r I'_v K_r.  The
2740        // FirthDirection carries a_u with the opposite sign convention to the
2741        // derivation block's "A_v"; we keep the direction convention and
2742        // compose signs correctly here.
2743        let term_a = beta_direction
2744            .a_u_reduced
2745            .dot(dot_i_partial)
2746            .dot(&self.k_reduced);
2747        let term_b = self.k_reduced.dot(&d_beta_dot_i).dot(&self.k_reduced);
2748        let term_c = self
2749            .k_reduced
2750            .dot(dot_i_partial)
2751            .dot(&beta_direction.a_u_reduced);
2752        let d_beta_dot_k = &term_a - &term_b + &term_c;
2753
2754        // D_β(ḣ_τ)[v] — β-differential of
2755        //   ḣ_τ = 2·diag(X_{r,τ} K_r X_rᵀ) + diag(X_r K̇_τ X_rᵀ):
2756        //   D_β ḣ_τ[v]
2757        //     = 2·diag(X_{r,τ} D_β K_r[v] X_rᵀ) + diag(X_r D_β K̇_τ[v] X_rᵀ)
2758        //     = -2·diag(X_{r,τ} A_v X_rᵀ) + diag(X_r (D_β K̇_τ[v]) X_rᵀ).
2759        let cross_diag = Self::rowwise_bilinear(
2760            &tau_kernel.x_tau_reduced,
2761            &beta_direction.a_u_reduced,
2762            &self.x_reduced,
2763        );
2764        let inner_diag = RemlState::reduced_diag_gram(&self.x_reduced, &d_beta_dot_k);
2765        let d_beta_dot_h = -2.0 * &cross_diag + &inner_diag;
2766
2767        FirthTauBetaPartialKernel {
2768            x_tau_reduced: tau_kernel.x_tau_reduced.clone(),
2769            deta_partial: deta_partial.clone(),
2770            dot_h_partial: tau_kernel.dot_h_partial.clone(),
2771            dot_i_partial: dot_i_partial.clone(),
2772            dot_k_reduced: tau_kernel.dot_k_reduced.clone(),
2773            deta_v: beta_direction.deta.clone(),
2774            deta_tau_v: x_tau_v.clone(),
2775            a_v_reduced: beta_direction.a_u_reduced.clone(),
2776            dh_v: beta_direction.dh.clone(),
2777            b_vvec: beta_direction.b_uvec.clone(),
2778            d_beta_dot_k,
2779            d_beta_dot_h,
2780        }
2781    }
2782
2783    /// Apply the mixed β-τ P-action `P_{τ,v} · mat` to an n×m column block.
2784    ///
2785    /// Expansion:
2786    ///   P_{τ,v} = 2 (M_v ⊙ M_τ) + 2 (M ⊙ M_{τ,v}),
2787    ///     M_v     = X_r K̇_v X_rᵀ,  K̇_v = -A_v (A_v = a_v_reduced),
2788    ///     M_τ     = X_{r,τ} K_r X_rᵀ + X_r K_r X_{r,τ}ᵀ + X_r K̇_τ X_rᵀ,
2789    ///     M_{τ,v} = X_{r,τ} K̇_v X_rᵀ + X_r K̇_v X_{r,τ}ᵀ + X_r D_β K̇_τ[v] X_rᵀ.
2790    /// Hadamard-Gram pieces are evaluated column-wise via
2791    ///   ((Z M_A Wᵀ) ⊙ (Y M_B Xᵀ)) v row-i
2792    ///       = z_iᵀ M_A (Wᵀ diag(v) X) M_Bᵀ y_i.
2793    pub(crate) fn apply_p_tau_v_to_matrix(
2794        &self,
2795        kernel: &FirthTauBetaPartialKernel,
2796        mat: &Array2<f64>,
2797    ) -> Array2<f64> {
2798        let n = self.x_dense.nrows();
2799        if mat.nrows() != n || mat.ncols() == 0 {
2800            return Array2::<f64>::zeros(mat.raw_dim());
2801        }
2802        let z = &self.x_reduced;
2803        let z_tau = &kernel.x_tau_reduced;
2804        let k_r = &self.k_reduced;
2805        let a_v = &kernel.a_v_reduced; // = +K_r I'_v K_r  (so K̇_v = -a_v)
2806        let dot_k_tau = &kernel.dot_k_reduced; // K̇_τ = -K_r İ_τ K_r
2807        let d_beta_dot_k = &kernel.d_beta_dot_k; // D_β K̇_τ[v]
2808        let mut out = Array2::<f64>::zeros(mat.raw_dim());
2809        for col in 0..mat.ncols() {
2810            let v = mat.column(col).to_owned();
2811            let s_zz = RemlState::reducedweighted_gram(z, &v);
2812            let s_z_ztau = RemlState::reduced_crossweighted_gram(z, z_tau, &v);
2813
2814            // Piece 1: (X_r K̇_v X_rᵀ ⊙ X_{r,τ} K_r X_rᵀ) · v
2815            //   = -rowwise_bilinear(Z, a_v S_zz K_r, Z_τ).
2816            let mid_1 = a_v.dot(&s_zz).dot(k_r);
2817            let t1 = -Self::rowwise_bilinear(z, &mid_1, z_tau);
2818            // Piece 2: (X_r K̇_v X_rᵀ ⊙ X_r K_r X_{r,τ}ᵀ) · v
2819            //   = -reduced_diag_gram(Z, a_v S_z_ztau K_r).
2820            let mid_2 = a_v.dot(&s_z_ztau).dot(k_r);
2821            let t2 = -RemlState::reduced_diag_gram(z, &mid_2);
2822            // Piece 3: (X_r K̇_v X_rᵀ ⊙ X_r K̇_τ X_rᵀ) · v
2823            //   = -reduced_diag_gram(Z, a_v S_zz K̇_τ).
2824            let mid_3 = a_v.dot(&s_zz).dot(dot_k_tau);
2825            let t3 = -RemlState::reduced_diag_gram(z, &mid_3);
2826            // Piece 4: (M ⊙ X_{r,τ} K̇_v X_rᵀ) · v
2827            //   = -rowwise_bilinear(Z, K_r S_zz a_v, Z_τ).
2828            let mid_4 = k_r.dot(&s_zz).dot(a_v);
2829            let t4 = -Self::rowwise_bilinear(z, &mid_4, z_tau);
2830            // Piece 5: (M ⊙ X_r K̇_v X_{r,τ}ᵀ) · v
2831            //   = -reduced_diag_gram(Z, K_r S_z_ztau a_v).
2832            let mid_5 = k_r.dot(&s_z_ztau).dot(a_v);
2833            let t5 = -RemlState::reduced_diag_gram(z, &mid_5);
2834            // Piece 6: (M ⊙ X_r D_β K̇_τ[v] X_rᵀ) · v.
2835            let t6 = RemlState::apply_hadamard_gram(z, k_r, d_beta_dot_k, &v);
2836
2837            // P_{τ,v} = 2·(pieces 1-3) + 2·(pieces 4-6); each group contributes
2838            // with the same outer factor 2.
2839            let y = 2.0 * (t1 + t2 + t3 + t4 + t5 + t6);
2840            out.column_mut(col).assign(&y);
2841        }
2842        out
2843    }
2844
2845    pub(crate) fn d_beta_hphi_tau_partial_apply(
2846        &self,
2847        x_tau: &Array2<f64>,
2848        kernel: &FirthTauBetaPartialKernel,
2849        rhs: &Array2<f64>,
2850    ) -> Array2<f64> {
2851        let p = self.x_dense.ncols();
2852        if rhs.nrows() != p {
2853            return Array2::<f64>::zeros((p, rhs.ncols()));
2854        }
2855        if rhs.ncols() == 0 || p == 0 {
2856            return Array2::<f64>::zeros((p, rhs.ncols()));
2857        }
2858        // Matrix-free block apply of D_β((H_φ)_τ|_β)[v] evaluated on a rhs V.
2859        // Structure follows hphi_tau_partial_apply but replaces every weight
2860        // and every reduced Gram with its β-derivative in direction v:
2861        //
2862        //   (H_φ)_τ|_β (V) = 0.5 [X_τᵀ r(V) + Xᵀ r_τ(V)].
2863        //
2864        // D_β[v] leaves X, X_τ fixed and acts on r, r_τ:
2865        //   D_β((H_φ)_τ|_β)[v](V) = 0.5 [X_τᵀ D_β r(V)[v] + Xᵀ D_β r_τ(V)[v]].
2866        let etav = fast_ab(&self.x_dense, rhs);
2867        let etav_tau = fast_ab(x_tau, rhs);
2868        let deta_v = &kernel.deta_v;
2869        let deta_tau_v = &kernel.deta_tau_v;
2870        let eta_tau = &kernel.deta_partial;
2871        let dot_h = &kernel.dot_h_partial;
2872
2873        // Reuse τ-kernel weights.  dotw1 = w'' ⊙ η̇_τ, dotw2 = w''' ⊙ η̇_τ.
2874        let dotw1 = &self.w2 * eta_tau;
2875        let dotw2 = &self.w3 * eta_tau;
2876
2877        // β-derivative scaling vectors in direction v:
2878        //   c_v              = D_β(w''·h)[v]    = w'''·δη_v·h + w''·dh_v
2879        //   b_vvec           = D_β(w')[v]       = w''·δη_v   (= kernel.b_vvec)
2880        //   d_beta_dotw1_vec = D_β(w''·η̇_τ)[v]  = w'''·δη_v·η̇_τ + w''·δη_{τ,v}
2881        //   d_beta_dotw2_vec = D_β(w'''·η̇_τ)[v] = w''''·δη_v·η̇_τ + w'''·δη_{τ,v}
2882        let c_v = &(&(&self.w3 * deta_v) * &self.h_diag) + &(&self.w2 * &kernel.dh_v);
2883        let b_vvec = &kernel.b_vvec;
2884        let d_beta_dotw1_vec = &(&(&self.w3 * deta_v) * eta_tau) + &(&self.w2 * deta_tau_v);
2885        let d_beta_dotw2_vec = &(&(&self.w4 * deta_v) * eta_tau) + &(&self.w3 * deta_tau_v);
2886
2887        // Single-τ pieces (identical to hphi_tau_partial_apply).
2888        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
2889        let qv_tau = &etav * &dotw1.view().insert_axis(Axis(1))
2890            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
2891        let m_qv = self.apply_pbar_to_matrix(&qv);
2892        // apply_mtau_to_matrix only reads x_tau_reduced and dot_k_reduced off
2893        // the τ-kernel, but owning the full struct is cheap.
2894        let tau_kernel_view = FirthTauPartialKernel {
2895            deta_partial: eta_tau.clone(),
2896            dotw1: dotw1.clone(),
2897            dotw2: dotw2.clone(),
2898            dot_h_partial: dot_h.clone(),
2899            x_tau_reduced: kernel.x_tau_reduced.clone(),
2900            dot_i_partial: kernel.dot_i_partial.clone(),
2901            dot_k_reduced: kernel.dot_k_reduced.clone(),
2902        };
2903        let m_qv_tau =
2904            self.apply_mtau_to_matrix(&tau_kernel_view, &qv) + self.apply_pbar_to_matrix(&qv_tau);
2905
2906        // β-derivatives of the single-τ pieces:
2907        //   D_β qv     = etav · D_β w'[v]       = etav · b_vvec
2908        //   D_β qv_tau = etav · D_β dotw1[v] + etav_tau · D_β w'[v]
2909        let d_beta_qv = &etav * &b_vvec.view().insert_axis(Axis(1));
2910        let d_beta_qv_tau = &etav * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2911            + &etav_tau * &b_vvec.view().insert_axis(Axis(1));
2912
2913        //   D_β m_qv = P_v · qv + P · D_β qv
2914        let d_beta_m_qv = self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv)
2915            + self.apply_pbar_to_matrix(&d_beta_qv);
2916
2917        //   D_β m_qv_tau = P_{τ,v}·qv + P_τ·D_β qv + P_v·qv_tau + P·D_β qv_tau
2918        let d_beta_m_qv_tau = self.apply_p_tau_v_to_matrix(kernel, &qv)
2919            + self.apply_mtau_to_matrix(&tau_kernel_view, &d_beta_qv)
2920            + self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv_tau)
2921            + self.apply_pbar_to_matrix(&d_beta_qv_tau);
2922
2923        // D_β rv[v] where rv = etav·(w''·h) − w'·m_qv:
2924        //   D_β rv[v] = etav·c_v − b_vvec·m_qv − w'·D_β m_qv.
2925        let d_beta_rv = &etav * &c_v.view().insert_axis(Axis(1))
2926            - &m_qv * &b_vvec.view().insert_axis(Axis(1))
2927            - &d_beta_m_qv * &self.w1.view().insert_axis(Axis(1));
2928
2929        // D_β rv_tau[v] where
2930        //   rv_tau = etav·dotw2·h + etav_tau·w''·h + etav·w''·dot_h
2931        //            − m_qv·dotw1 − m_qv_tau·w'.
2932        //
2933        //   D_β(dotw2·h)[v]   = (w''''·δη_v·η̇_τ + w'''·δη_{τ,v})·h
2934        //                        + dotw2·dh_v,
2935        //   D_β(w''·h)[v]     = c_v,
2936        //   D_β(w''·dot_h)[v] = w'''·δη_v·dot_h + w''·D_β dot_h[v],
2937        //   D_β dotw1[v]      = d_beta_dotw1_vec,
2938        //   D_β w'[v]         = b_vvec.
2939        let d_beta_dotw2_h = &(&d_beta_dotw2_vec * &self.h_diag) + &(&dotw2 * &kernel.dh_v);
2940        let d_beta_w2_doth = &(&(&self.w3 * deta_v) * dot_h) + &(&self.w2 * &kernel.d_beta_dot_h);
2941
2942        let d_beta_rv_tau = &etav * &d_beta_dotw2_h.view().insert_axis(Axis(1))
2943            + &etav_tau * &c_v.view().insert_axis(Axis(1))
2944            + &etav * &d_beta_w2_doth.view().insert_axis(Axis(1))
2945            - &d_beta_m_qv * &dotw1.view().insert_axis(Axis(1))
2946            - &m_qv * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
2947            - &d_beta_m_qv_tau * &self.w1.view().insert_axis(Axis(1))
2948            - &m_qv_tau * &b_vvec.view().insert_axis(Axis(1));
2949
2950        0.5 * (x_tau.t().dot(&d_beta_rv) + self.x_dense.t().dot(&d_beta_rv_tau))
2951    }
2952}
2953
2954#[cfg(test)]
2955mod tests {
2956    use super::*;
2957    use crate::mixture_link::logit_inverse_link_jet5;
2958    use gam_problem::StandardLink;
2959    use ndarray::{Array1, Array2, array};
2960
2961    // Operator-equivalence oracle accessors (#1575). The production inner-PIRLS
2962    // path memoizes the β-independent design factor and reads diagnostics through
2963    // `pirls_diagnostics_from_factor`; these full-operator accessors are retained
2964    // ONLY for the equivalence unit tests, so they live in this `#[cfg(test)]`
2965    // module rather than gating individual production methods with `#[cfg(test)]`.
2966    impl FirthDenseOperator {
2967        pub(crate) fn pirls_hat_diag(&self) -> Array1<f64> {
2968            &self.w * &self.h_diag
2969        }
2970
2971        /// Per-observation Firth working-response shift `Δ_i = ½·(w'_i/w_i)·h_diag_i`
2972        /// (the link-general form; `w_i ≤ 0` rows get a zero shift). Matches the
2973        /// Jeffreys score `½ Σ_i w'_i h_i x_i` the outer REML differentiates.
2974        pub(crate) fn pirls_firth_score_shift(&self) -> Array1<f64> {
2975            let mut shift = Array1::<f64>::zeros(self.w.len());
2976            for i in 0..self.w.len() {
2977                let wi = self.w[i];
2978                if wi > 0.0 {
2979                    shift[i] = 0.5 * (self.w1[i] / wi) * self.h_diag[i];
2980                }
2981            }
2982            shift
2983        }
2984    }
2985
2986    pub(crate) fn build_logit_firth_dense_operator(
2987        x_dense: &Array2<f64>,
2988        eta: &Array1<f64>,
2989    ) -> Result<FirthDenseOperator, EstimationError> {
2990        FirthDenseOperator::build_with_observation_weights_impl(
2991            &InverseLink::Standard(StandardLink::Logit),
2992            x_dense,
2993            eta,
2994            None,
2995        )
2996    }
2997
2998    pub(crate) fn build_weighted_logit_firth_dense_operator(
2999        x_dense: &Array2<f64>,
3000        eta: &Array1<f64>,
3001        observation_weights: ndarray::ArrayView1<'_, f64>,
3002    ) -> Result<FirthDenseOperator, EstimationError> {
3003        FirthDenseOperator::build_with_observation_weights_impl(
3004            &InverseLink::Standard(StandardLink::Logit),
3005            x_dense,
3006            eta,
3007            Some(observation_weights),
3008        )
3009    }
3010
3011    pub(crate) fn logisticweight(eta: f64) -> f64 {
3012        logit_inverse_link_jet5(eta).d1
3013    }
3014
3015    pub(crate) fn firthphivalue(x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3016        let eta = x.dot(beta);
3017        let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3018        op.jeffreys_logdet()
3019    }
3020
3021    pub(crate) fn firthgradphi(x: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
3022        let eta = x.dot(beta);
3023        let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3024        op.jeffreys_beta_gradient()
3025    }
3026
3027    pub(crate) fn weighted_firthphivalue(
3028        x: &Array2<f64>,
3029        beta: &Array1<f64>,
3030        observation_weights: &Array1<f64>,
3031    ) -> f64 {
3032        let eta = x.dot(beta);
3033        let op = build_weighted_logit_firth_dense_operator(x, &eta, observation_weights.view())
3034            .expect("weighted firth operator");
3035        op.jeffreys_logdet()
3036    }
3037
3038    #[test]
3039    pub(crate) fn firth_reduced_fisher_logdet_is_finite_for_barely_pd_matrix() {
3040        let fisher = array![[16.0, 0.0], [0.0, 1e-15]];
3041        let (k_reduced, half_log_det) = RemlState::reduced_fisher_inverse_and_half_logdet(&fisher)
3042            .expect("barely positive-definite reduced fisher");
3043        let expected = 0.5 * 16.0_f64.ln();
3044
3045        assert!(
3046            half_log_det.is_finite(),
3047            "barely positive-definite reduced fisher produced non-finite half logdet: {half_log_det}"
3048        );
3049        assert!(
3050            (half_log_det - expected).abs() < 1e-12,
3051            "near-null Fisher direction should be excluded from pseudo-logdet: got {half_log_det}, expected {expected}"
3052        );
3053        assert!(
3054            k_reduced.iter().all(|value| value.is_finite()),
3055            "barely positive-definite reduced fisher produced non-finite inverse entries: {k_reduced:?}"
3056        );
3057        assert!(
3058            k_reduced[[1, 1]].abs() < f64::EPSILON,
3059            "near-null Fisher direction should be excluded from pseudo-inverse: {k_reduced:?}"
3060        );
3061    }
3062
3063    #[test]
3064    pub(crate) fn firth_logisticweight_derivatives_match_finite_difference() {
3065        // Validates op.w[i] (= jet.d1) and op.w1..w4[i] (= jet.d2..jet.d5)
3066        // against direct central finite differences of the logistic inverse
3067        // link pdf w(η) = μ(η)(1−μ(η)).
3068        //
3069        // Nested central differences amplify roundoff by 1/h per nesting
3070        // level, so a d1fd-of-d1fd-of-d2fd cannot deliver the tolerances
3071        // that 4th-order agreement requires. The principled replacement is
3072        // a direct higher-order stencil whose truncation and roundoff are
3073        // both controlled by a single step h:
3074        //
3075        //   d1  (2-pt):  (f(z+h) − f(z−h)) / (2h)                       O(h²) trunc
3076        //   d2  (3-pt):  (f(z+h) − 2f(z) + f(z−h)) / h²                 O(h²) trunc
3077        //   d3  (4-pt):  (−f(z−2h)+2f(z−h)−2f(z+h)+f(z+2h)) / (2h³)     O(h²) trunc
3078        //   d4  (5-pt):  (f(z−2h)−4f(z−h)+6f(z)−4f(z+h)+f(z+2h)) / h⁴   O(h²) trunc
3079        //
3080        // At h = 1e-2 the logistic pdf and its higher derivatives stay of
3081        // order ≤ 1, so truncation O(h²·M) ≲ 1e-4 and roundoff O(ε/h^n)
3082        // is well below any asserted tolerance through the 4th order.
3083        let x = array![
3084            [1.0, -1.1, 0.2],
3085            [1.0, -0.5, -0.6],
3086            [1.0, 0.0, 0.3],
3087            [1.0, 0.8, -0.4],
3088            [1.0, 1.2, 0.7],
3089        ];
3090        let beta = array![0.15, -0.6, 0.35];
3091        let eta = x.dot(&beta);
3092        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3093
3094        let h = 1e-2_f64;
3095        let w = |z: f64| logisticweight(z);
3096        let d1direct = |z: f64| (w(z + h) - w(z - h)) / (2.0 * h);
3097        let d2direct = |z: f64| (w(z + h) - 2.0 * w(z) + w(z - h)) / (h * h);
3098        let d3direct = |z: f64| {
3099            (-w(z - 2.0 * h) + 2.0 * w(z - h) - 2.0 * w(z + h) + w(z + 2.0 * h)) / (2.0 * h.powi(3))
3100        };
3101        let d4direct = |z: f64| {
3102            (w(z - 2.0 * h) - 4.0 * w(z - h) + 6.0 * w(z) - 4.0 * w(z + h) + w(z + 2.0 * h))
3103                / h.powi(4)
3104        };
3105        for i in 0..eta.len() {
3106            let z = eta[i];
3107            let wfd = w(z);
3108            let w1fd = d1direct(z);
3109            let w2fd = d2direct(z);
3110            let w3fd = d3direct(z);
3111            let w4fd = d4direct(z);
3112
3113            assert!((op.w[i] - wfd).abs() < 1e-12);
3114            assert_eq!(op.w1[i].signum(), w1fd.signum());
3115            assert_eq!(op.w2[i].signum(), w2fd.signum());
3116            assert_eq!(op.w3[i].signum(), w3fd.signum());
3117            assert_eq!(op.w4[i].signum(), w4fd.signum());
3118            assert!((op.w1[i] - w1fd).abs() < 1e-5);
3119            assert!((op.w2[i] - w2fd).abs() < 1e-4);
3120            assert!((op.w3[i] - w3fd).abs() < 1e-4);
3121            assert!((op.w4[i] - w4fd).abs() < 1e-3);
3122        }
3123    }
3124
3125    #[test]
3126    pub(crate) fn weighted_firth_jeffreys_gradient_matches_finite_difference() {
3127        let x = array![
3128            [1.0, -0.7, 0.3],
3129            [1.0, -0.2, -0.4],
3130            [1.0, 0.5, 0.1],
3131            [1.0, 1.1, -0.6],
3132            [1.0, 1.6, 0.8],
3133        ];
3134        let beta = array![0.2, -0.45, 0.25];
3135        let observation_weights = array![1.0, 0.5, 2.0, 1.5, 0.75];
3136        let eta = x.dot(&beta);
3137        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3138            .expect("weighted firth operator");
3139        let grad = op.jeffreys_beta_gradient();
3140        let h = 1e-6;
3141
3142        for j in 0..beta.len() {
3143            let mut beta_plus = beta.clone();
3144            beta_plus[j] += h;
3145            let mut beta_minus = beta.clone();
3146            beta_minus[j] -= h;
3147            let fd = (weighted_firthphivalue(&x, &beta_plus, &observation_weights)
3148                - weighted_firthphivalue(&x, &beta_minus, &observation_weights))
3149                / (2.0 * h);
3150            assert!(
3151                (grad[j] - fd).abs() < 1e-5,
3152                "weighted Firth gradient mismatch at {}: analytic={}, fd={}",
3153                j,
3154                grad[j],
3155                fd
3156            );
3157        }
3158    }
3159
3160    // ----------------------------------------------------------------------
3161    // Link-general (probit) finite-difference proof of the Jeffreys/Firth
3162    // Φ(β) = ½ log|I_r(β)|, its β-gradient ∂Φ/∂β, and the β-Hessian
3163    // derivative D H_φ[u] exposed via `hphi_direction`. Logit is used as a
3164    // regression guard against the historical logit-pinned build.
3165    // ----------------------------------------------------------------------
3166
3167    pub(crate) fn build_link_firth_op(
3168        link: StandardLink,
3169        x: &Array2<f64>,
3170        beta: &Array1<f64>,
3171    ) -> FirthDenseOperator {
3172        let eta = x.dot(beta);
3173        FirthDenseOperator::build_with_observation_weights_impl(
3174            &InverseLink::Standard(link),
3175            x,
3176            &eta,
3177            None,
3178        )
3179        .expect("link-general firth operator")
3180    }
3181
3182    pub(crate) fn link_firth_phi(link: StandardLink, x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3183        build_link_firth_op(link, x, beta).jeffreys_logdet()
3184    }
3185
3186    pub(crate) fn link_firth_grad(
3187        link: StandardLink,
3188        x: &Array2<f64>,
3189        beta: &Array1<f64>,
3190    ) -> Array1<f64> {
3191        build_link_firth_op(link, x, beta).jeffreys_beta_gradient()
3192    }
3193
3194    /// Central-difference Jacobian of the *analytic* Firth gradient, i.e. a
3195    /// numerical realization of the β-Hessian H_φ = ∂g/∂β. The Newton/REML
3196    /// path consumes H_φ (and its directional derivative) so this is the
3197    /// matrix the analytic curvature must reproduce.
3198    pub(crate) fn numeric_firth_hessian(
3199        link: StandardLink,
3200        x: &Array2<f64>,
3201        beta: &Array1<f64>,
3202        h: f64,
3203    ) -> Array2<f64> {
3204        let p = beta.len();
3205        let mut hess = Array2::<f64>::zeros((p, p));
3206        for j in 0..p {
3207            let mut bp = beta.clone();
3208            bp[j] += h;
3209            let mut bm = beta.clone();
3210            bm[j] -= h;
3211            let gp = link_firth_grad(link, x, &bp);
3212            let gm = link_firth_grad(link, x, &bm);
3213            let col = (&gp - &gm) / (2.0 * h);
3214            hess.column_mut(j).assign(&col);
3215        }
3216        hess
3217    }
3218
3219    /// #1575: the cached single-index second-direction path
3220    /// (`tk_second_direction_eye_cache` + `hphisecond_direction_apply_eye_cached`)
3221    /// must be BIT-IDENTICAL to the per-pair `hphisecond_direction_apply(.., &eye)`
3222    /// it replaces in the exact-Hessian TK outer loop. This locks the work-elision
3223    /// invariant: it removes redundant O(n·r²·p) reduced Hadamard-Gram applies, it
3224    /// must NOT change a single bit of the resulting Hessian contribution.
3225    #[test]
3226    fn hphisecond_eye_cached_matches_per_pair_bit_identical_1575() {
3227        // A 6×3 logit design with a few distinct η directions (mirrors the
3228        // multi-smooth penalty directions the TK loop contracts over).
3229        let x = array![
3230            [1.0, -1.10, 0.35],
3231            [1.0, -0.40, -0.65],
3232            [1.0, 0.15, 0.20],
3233            [1.0, 0.80, -0.45],
3234            [1.0, 1.25, 0.70],
3235            [1.0, -0.55, 0.95],
3236        ];
3237        let beta = array![0.20, -0.55, 0.30];
3238        let op = build_link_firth_op(StandardLink::Logit, &x, &beta);
3239        let p = x.ncols();
3240
3241        // Three β-direction δη vectors playing the role of eta_i[idx].
3242        let deta_list = [
3243            x.dot(&array![0.9, -0.3, 0.2]),
3244            x.dot(&array![-0.4, 0.7, 0.1]),
3245            x.dot(&array![0.1, 0.2, -0.8]),
3246        ];
3247        let dirs: Vec<FirthDirection> = deta_list
3248            .iter()
3249            .map(|d| op.direction_from_deta(d.clone()))
3250            .collect();
3251
3252        let eye = Array2::<f64>::eye(p);
3253        let cache = op.tk_second_direction_eye_cache(&dirs);
3254        for i in 0..dirs.len() {
3255            for j in 0..=i {
3256                let reference = op.hphisecond_direction_apply(&dirs[i], &dirs[j], &eye);
3257                let cached = op.hphisecond_direction_apply_eye_cached(&cache, &dirs, i, j);
3258                assert_eq!(
3259                    reference.dim(),
3260                    cached.dim(),
3261                    "shape mismatch at pair ({i},{j})"
3262                );
3263                for (a, b) in reference.iter().zip(cached.iter()) {
3264                    assert_eq!(
3265                        a.to_bits(),
3266                        b.to_bits(),
3267                        "cached D²H_φ[{i},{j}] is not bit-identical to per-pair: \
3268                         reference={a}, cached={b}"
3269                    );
3270                }
3271            }
3272        }
3273    }
3274
3275    /// A fixed, well-conditioned full-rank design (deterministic, no RNG).
3276    pub(crate) fn fixed_design_5x3() -> Array2<f64> {
3277        array![
3278            [1.0, -1.10, 0.35],
3279            [1.0, -0.40, -0.65],
3280            [1.0, 0.15, 0.20],
3281            [1.0, 0.80, -0.45],
3282            [1.0, 1.25, 0.70],
3283        ]
3284    }
3285
3286    #[test]
3287    pub(crate) fn link_general_logit_path_reproduces_historical_logit_build() {
3288        // Guard: the StandardLink::Logit path through the link-general builder
3289        // must be byte-identical to the historical logit-pinned operator for
3290        // Φ, the β-gradient, the PIRLS hat diagonal, and the cached weight
3291        // jets w, w'..w''''.
3292        let x = fixed_design_5x3();
3293        let beta = array![0.20, -0.55, 0.30];
3294        let eta = x.dot(&beta);
3295
3296        let historical = build_logit_firth_dense_operator(&x, &eta).expect("historical logit");
3297        let link_general = FirthDenseOperator::build_with_observation_weights_impl(
3298            &InverseLink::Standard(StandardLink::Logit),
3299            &x,
3300            &eta,
3301            None,
3302        )
3303        .expect("link-general logit");
3304
3305        assert_eq!(
3306            historical.jeffreys_logdet(),
3307            link_general.jeffreys_logdet(),
3308            "logit Φ must be bit-identical through the link-general path"
3309        );
3310        let g_hist = historical.jeffreys_beta_gradient();
3311        let g_link = link_general.jeffreys_beta_gradient();
3312        for j in 0..g_hist.len() {
3313            assert_eq!(
3314                g_hist[j], g_link[j],
3315                "logit gradient component {j} must be bit-identical"
3316            );
3317        }
3318        let hat_hist = historical.pirls_hat_diag();
3319        let hat_link = link_general.pirls_hat_diag();
3320        for i in 0..hat_hist.len() {
3321            assert_eq!(
3322                hat_hist[i], hat_link[i],
3323                "logit PIRLS hat diagonal {i} must be bit-identical"
3324            );
3325        }
3326        for i in 0..eta.len() {
3327            assert_eq!(historical.w[i], link_general.w[i]);
3328            assert_eq!(historical.w1[i], link_general.w1[i]);
3329            assert_eq!(historical.w2[i], link_general.w2[i]);
3330            assert_eq!(historical.w3[i], link_general.w3[i]);
3331            assert_eq!(historical.w4[i], link_general.w4[i]);
3332        }
3333    }
3334
3335    #[test]
3336    pub(crate) fn link_general_probit_jeffreys_gradient_matches_finite_difference() {
3337        // PROBIT correctness: ∂Φ/∂β from `jeffreys_beta_gradient` must match a
3338        // central finite difference of Φ(β) on a well-conditioned design.
3339        let x = fixed_design_5x3();
3340        let beta = array![0.10, -0.40, 0.25];
3341        let grad = link_firth_grad(StandardLink::Probit, &x, &beta);
3342        let h = 1e-6_f64;
3343        let mut max_rel = 0.0_f64;
3344        for j in 0..beta.len() {
3345            let mut bp = beta.clone();
3346            bp[j] += h;
3347            let mut bm = beta.clone();
3348            bm[j] -= h;
3349            let fd = (link_firth_phi(StandardLink::Probit, &x, &bp)
3350                - link_firth_phi(StandardLink::Probit, &x, &bm))
3351                / (2.0 * h);
3352            let denom = grad[j].abs().max(fd.abs()).max(1e-8);
3353            let rel = (grad[j] - fd).abs() / denom;
3354            max_rel = max_rel.max(rel);
3355            assert!(
3356                rel < 1e-6,
3357                "probit Firth gradient mismatch at {j}: analytic={}, fd={}, rel={:e}",
3358                grad[j],
3359                fd,
3360                rel
3361            );
3362        }
3363        assert!(
3364            max_rel < 1e-6,
3365            "probit gradient worst relative error {max_rel:e} exceeds 1e-6"
3366        );
3367    }
3368
3369    #[test]
3370    pub(crate) fn link_general_probit_hphi_direction_matches_finite_difference_of_hessian() {
3371        // PROBIT Hessian: `hphi_direction(direction_from_deta(X·u))` is the
3372        // analytic directional derivative D H_φ[u] of the β-Hessian. Verify it
3373        // against the central finite difference of the (numerically realized)
3374        // β-Hessian H_φ along u. The numeric H_φ at each shifted β is itself a
3375        // finite difference of the *analytic* gradient, so the base operand is
3376        // analytic at first order; only the directional step is differenced
3377        // here.
3378        let x = fixed_design_5x3();
3379        let beta = array![0.10, -0.40, 0.25];
3380        let p = beta.len();
3381
3382        // Probe several directions, including non-axis-aligned ones.
3383        let directions = [
3384            array![1.0, 0.0, 0.0],
3385            array![0.0, 1.0, 0.0],
3386            array![0.0, 0.0, 1.0],
3387            array![0.7, -0.5, 0.3],
3388        ];
3389
3390        let h_inner = 1e-4_f64; // step for the numeric Hessian (FD of analytic grad)
3391        let h_dir = 1e-4_f64; // step for the directional derivative of the Hessian
3392        let mut worst = 0.0_f64;
3393        for u in directions.iter() {
3394            let op = build_link_firth_op(StandardLink::Probit, &x, &beta);
3395            let deta = x.dot(u);
3396            let dir = op.direction_from_deta(deta);
3397            let analytic = op.hphi_direction(&dir);
3398
3399            let beta_plus = &beta + &(u * h_dir);
3400            let beta_minus = &beta - &(u * h_dir);
3401            let hess_plus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_plus, h_inner);
3402            let hess_minus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_minus, h_inner);
3403            let fd = (&hess_plus - &hess_minus) / (2.0 * h_dir);
3404
3405            let mut scale = 1e-6_f64;
3406            for r in 0..p {
3407                for c in 0..p {
3408                    scale = scale.max(analytic[[r, c]].abs()).max(fd[[r, c]].abs());
3409                }
3410            }
3411            for r in 0..p {
3412                for c in 0..p {
3413                    let rel = (analytic[[r, c]] - fd[[r, c]]).abs() / scale;
3414                    worst = worst.max(rel);
3415                    assert!(
3416                        rel < 5e-3,
3417                        "probit D H_φ[u] mismatch at ({r},{c}) for u={u:?}: analytic={}, fd={}, rel={:e}",
3418                        analytic[[r, c]],
3419                        fd[[r, c]],
3420                        rel
3421                    );
3422                }
3423            }
3424        }
3425        assert!(
3426            worst < 5e-3,
3427            "probit Hessian-derivative worst relative error {worst:e} exceeds 5e-3"
3428        );
3429    }
3430
3431    #[test]
3432    pub(crate) fn link_general_probit_jeffreys_finite_on_rank_deficient_design() {
3433        // Identifiable-subspace behavior: a rank-deficient design (column 3 =
3434        // column 1 + column 2) must yield a finite Φ = ½ log|Uᵀ W U|, a finite
3435        // gradient, and agree with the explicit reduced two-column design.
3436        let x_full = array![
3437            [1.0, -1.20, -0.20],
3438            [1.0, -0.40, 0.60],
3439            [1.0, 0.10, 1.10],
3440            [1.0, 0.70, 1.70],
3441            [1.0, 1.30, 2.30],
3442        ];
3443        let x_reduced = array![
3444            [1.0, -1.20],
3445            [1.0, -0.40],
3446            [1.0, 0.10],
3447            [1.0, 0.70],
3448            [1.0, 1.30],
3449        ];
3450        let beta_full = array![0.25, -0.50, 0.15];
3451        let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3452
3453        let phi_full = link_firth_phi(StandardLink::Probit, &x_full, &beta_full);
3454        let phi_reduced = link_firth_phi(StandardLink::Probit, &x_reduced, &beta_reduced);
3455        assert!(
3456            phi_full.is_finite(),
3457            "probit Φ on rank-deficient design must be finite, got {phi_full}"
3458        );
3459        assert!(
3460            (phi_full - phi_reduced).abs() < 1e-12,
3461            "probit reduced |Uᵀ W U| form mismatch: full={phi_full}, reduced={phi_reduced}"
3462        );
3463
3464        let op_full = build_link_firth_op(StandardLink::Probit, &x_full, &beta_full);
3465        let grad_full = op_full.jeffreys_beta_gradient();
3466        assert!(
3467            grad_full.iter().all(|v| v.is_finite()),
3468            "probit gradient on rank-deficient design must be finite: {grad_full:?}"
3469        );
3470        let hat_full = op_full.pirls_hat_diag();
3471        let hat_reduced =
3472            build_link_firth_op(StandardLink::Probit, &x_reduced, &beta_reduced).pirls_hat_diag();
3473        for i in 0..hat_full.len() {
3474            assert!(
3475                (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3476                "probit hat diagonal {i} mismatch on rank-deficient design: full={}, reduced={}",
3477                hat_full[i],
3478                hat_reduced[i]
3479            );
3480        }
3481    }
3482
3483    #[test]
3484    pub(crate) fn rank_deficient_and_explicit_reduced_designs_share_same_jeffreys_objective() {
3485        // Column 3 is exactly column 1 + column 2, so the original design is
3486        // rank-deficient but its identifiable subspace is represented exactly by
3487        // the explicit two-column reduced design below.
3488        let x_full = array![
3489            [1.0, -1.2, -0.2],
3490            [1.0, -0.4, 0.6],
3491            [1.0, 0.1, 1.1],
3492            [1.0, 0.7, 1.7],
3493            [1.0, 1.3, 2.3],
3494        ];
3495        let x_reduced = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3496        let beta_full: ndarray::Array1<f64> = array![0.25, -0.5, 0.15];
3497        let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3498        let eta_full = x_full.dot(&beta_full);
3499        let eta_reduced = x_reduced.dot(&beta_reduced);
3500        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3501
3502        for i in 0..eta_full.len() {
3503            assert!(
3504                (eta_full[i] - eta_reduced[i]).abs() < 1e-12,
3505                "eta mismatch at row {i}: full={} reduced={}",
3506                eta_full[i],
3507                eta_reduced[i]
3508            );
3509        }
3510
3511        let op_full = build_weighted_logit_firth_dense_operator(
3512            &x_full,
3513            &eta_full,
3514            observation_weights.view(),
3515        )
3516        .expect("full firth operator");
3517        let op_reduced = build_weighted_logit_firth_dense_operator(
3518            &x_reduced,
3519            &eta_reduced,
3520            observation_weights.view(),
3521        )
3522        .expect("reduced firth operator");
3523
3524        assert!(
3525            (op_full.jeffreys_logdet() - op_reduced.jeffreys_logdet()).abs() < 1e-12,
3526            "Jeffreys logdet mismatch between rank-deficient full design and its explicit reduced identifiable basis: full={} reduced={}",
3527            op_full.jeffreys_logdet(),
3528            op_reduced.jeffreys_logdet()
3529        );
3530
3531        let hat_full = op_full.pirls_hat_diag();
3532        let hat_reduced = op_reduced.pirls_hat_diag();
3533        for i in 0..hat_full.len() {
3534            assert!(
3535                (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3536                "PIRLS hat-diagonal mismatch at row {i}: full={} reduced={}",
3537                hat_full[i],
3538                hat_reduced[i]
3539            );
3540        }
3541    }
3542
3543    #[test]
3544    pub(crate) fn full_rank_reparameterizations_share_same_jeffreys_objective() {
3545        let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3546        let basis = array![[1.4, -0.3], [0.6, 1.1]];
3547        let x_reparameterized = x.dot(&basis);
3548        let beta = array![0.25, -0.5];
3549        let basis_det: f64 = basis[[0, 0]] * basis[[1, 1]] - basis[[0, 1]] * basis[[1, 0]];
3550        assert!(
3551            basis_det.abs() > 1e-12,
3552            "basis transform must be invertible"
3553        );
3554        let basis_inv = array![
3555            [basis[[1, 1]] / basis_det, -basis[[0, 1]] / basis_det],
3556            [-basis[[1, 0]] / basis_det, basis[[0, 0]] / basis_det],
3557        ];
3558        let beta_reparameterized = basis_inv.dot(&beta);
3559        let eta = x.dot(&beta);
3560        let eta_reparameterized = x_reparameterized.dot(&beta_reparameterized);
3561        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3562
3563        for i in 0..eta.len() {
3564            assert!(
3565                (eta[i] - eta_reparameterized[i]).abs() < 1e-12,
3566                "eta mismatch at row {i}: original={} reparameterized={}",
3567                eta[i],
3568                eta_reparameterized[i]
3569            );
3570        }
3571
3572        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3573            .expect("original firth operator");
3574        let op_reparameterized = build_weighted_logit_firth_dense_operator(
3575            &x_reparameterized,
3576            &eta_reparameterized,
3577            observation_weights.view(),
3578        )
3579        .expect("reparameterized firth operator");
3580
3581        assert!(
3582            (op.jeffreys_logdet() - op_reparameterized.jeffreys_logdet()).abs() < 1e-12,
3583            "Jeffreys logdet mismatch under invertible reparameterization: original={} reparameterized={}",
3584            op.jeffreys_logdet(),
3585            op_reparameterized.jeffreys_logdet()
3586        );
3587
3588        let hat = op.pirls_hat_diag();
3589        let hat_reparameterized = op_reparameterized.pirls_hat_diag();
3590        for i in 0..hat.len() {
3591            assert!(
3592                (hat[i] - hat_reparameterized[i]).abs() < 1e-12,
3593                "PIRLS hat-diagonal mismatch at row {i}: original={} reparameterized={}",
3594                hat[i],
3595                hat_reparameterized[i]
3596            );
3597        }
3598    }
3599
3600    #[test]
3601    pub(crate) fn full_rank_identifiable_basis_diagonalizes_design_metric() {
3602        let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3603        let beta = array![0.25, -0.5];
3604        let eta = x.dot(&beta);
3605        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3606        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3607            .expect("firth operator");
3608
3609        let reduced_metric = fast_atb(&op.x_reduced, &op.x_reduced);
3610        for i in 0..reduced_metric.nrows() {
3611            for j in 0..reduced_metric.ncols() {
3612                if i == j {
3613                    continue;
3614                }
3615                assert!(
3616                    reduced_metric[[i, j]].abs() < 1e-10,
3617                    "full-rank identifiable basis should diagonalize X_r'X_r: metric[{i},{j}]={}",
3618                    reduced_metric[[i, j]]
3619                );
3620            }
3621        }
3622    }
3623
3624    #[test]
3625    pub(crate) fn firth_mixedsecond_direction_apply_is_symmetric_in_direction_order() {
3626        let x = array![
3627            [1.0, -1.0, 0.2],
3628            [1.0, -0.6, -0.3],
3629            [1.0, -0.1, 0.5],
3630            [1.0, 0.3, -0.7],
3631            [1.0, 0.8, 0.1],
3632            [1.0, 1.2, -0.4],
3633        ];
3634        let beta = array![0.1, -0.25, 0.2];
3635        let eta = x.dot(&beta);
3636        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3637
3638        let u = array![0.3, -0.2, 0.4];
3639        let v = array![-0.5, 0.1, 0.25];
3640        let du = op.direction_from_deta(x.dot(&u));
3641        let dv = op.direction_from_deta(x.dot(&v));
3642
3643        let eye = Array2::<f64>::eye(x.ncols());
3644        let uv = op.hphisecond_direction_apply(&du, &dv, &eye);
3645        let vu = op.hphisecond_direction_apply(&dv, &du, &eye);
3646
3647        for i in 0..uv.nrows() {
3648            for j in 0..uv.ncols() {
3649                let a = uv[[i, j]];
3650                let b = vu[[i, j]];
3651                assert_eq!(
3652                    a.signum(),
3653                    b.signum(),
3654                    "mixed direction sign mismatch at ({i},{j}): uv={a} vu={b}"
3655                );
3656                assert!(
3657                    (a - b).abs() < 2e-7,
3658                    "mixed direction mismatch at ({i},{j}): uv={a} vu={b}"
3659                );
3660            }
3661        }
3662    }
3663
3664    #[test]
3665    pub(crate) fn firth_direction_matrix_form_matches_apply_identity_form() {
3666        let x = array![
3667            [1.0, -1.1, 0.2],
3668            [1.0, -0.6, -0.3],
3669            [1.0, -0.1, 0.5],
3670            [1.0, 0.3, -0.7],
3671            [1.0, 0.8, 0.1],
3672            [1.0, 1.2, -0.4],
3673        ];
3674        let beta = array![0.08, -0.22, 0.27];
3675        let eta = x.dot(&beta);
3676        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3677        let u = Array1::from_vec(vec![0.25, -0.4, 0.35]);
3678        let dir = op.direction_from_deta(x.dot(&u));
3679
3680        let p = x.ncols();
3681        let eye = Array2::<f64>::eye(p);
3682        let mut via_apply = op.hphi_direction_apply(&dir, &eye);
3683        for i in 0..p {
3684            for j in 0..i {
3685                let sym = 0.5 * (via_apply[[i, j]] + via_apply[[j, i]]);
3686                via_apply[[i, j]] = sym;
3687                via_apply[[j, i]] = sym;
3688            }
3689        }
3690        let direct = op.hphi_direction(&dir);
3691        let diff = &direct - &via_apply;
3692        let err = diff.iter().map(|v| v * v).sum::<f64>().sqrt();
3693        assert!(err < 1e-10, "direction/apply mismatch: {err:e}");
3694    }
3695
3696    #[test]
3697    pub(crate) fn firthphi_tau_partial_matches_finite_difference_logdet() {
3698        let x = array![
3699            [1.0, -1.0, 0.2],
3700            [1.0, -0.6, -0.3],
3701            [1.0, -0.1, 0.5],
3702            [1.0, 0.3, -0.7],
3703            [1.0, 0.8, 0.1],
3704            [1.0, 1.2, -0.4],
3705        ];
3706        let x_tau = array![
3707            [0.0, 0.15, -0.05],
3708            [0.0, -0.10, 0.02],
3709            [0.0, 0.08, 0.04],
3710            [0.0, -0.06, -0.03],
3711            [0.0, 0.05, 0.01],
3712            [0.0, -0.12, 0.06],
3713        ];
3714        let beta = array![0.1, -0.25, 0.2];
3715        let eta = x.dot(&beta);
3716        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3717        let analytic = op.exact_tau_kernel(&x_tau, &beta, false).phi_tau_partial;
3718
3719        let h = 1e-6;
3720        let x_plus = &x + &(h * &x_tau);
3721        let x_minus = &x - &(h * &x_tau);
3722        let fd = (firthphivalue(&x_plus, &beta) - firthphivalue(&x_minus, &beta)) / (2.0 * h);
3723
3724        assert!(
3725            (analytic - fd).abs() < 1e-6,
3726            "Phi_tau mismatch: analytic={analytic:.12e}, fd={fd:.12e}"
3727        );
3728    }
3729
3730    #[test]
3731    pub(crate) fn firth_gphi_tau_matches_finite_differencegradphi() {
3732        let x = array![
3733            [1.0, -1.0, 0.2],
3734            [1.0, -0.6, -0.3],
3735            [1.0, -0.1, 0.5],
3736            [1.0, 0.3, -0.7],
3737            [1.0, 0.8, 0.1],
3738            [1.0, 1.2, -0.4],
3739        ];
3740        let x_tau = array![
3741            [0.0, 0.15, -0.05],
3742            [0.0, -0.10, 0.02],
3743            [0.0, 0.08, 0.04],
3744            [0.0, -0.06, -0.03],
3745            [0.0, 0.05, 0.01],
3746            [0.0, -0.12, 0.06],
3747        ];
3748        let beta = array![0.1, -0.25, 0.2];
3749        let eta = x.dot(&beta);
3750        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3751        let analytic = op.exact_tau_kernel(&x_tau, &beta, false).gphi_tau;
3752
3753        let h = 1e-6;
3754        let x_plus = &x + &(h * &x_tau);
3755        let x_minus = &x - &(h * &x_tau);
3756        let fd = (firthgradphi(&x_plus, &beta) - firthgradphi(&x_minus, &beta)) / (2.0 * h);
3757
3758        let err = (&analytic - &fd).iter().map(|v| v * v).sum::<f64>().sqrt();
3759        assert!(
3760            err < 1e-6,
3761            "gphi_tau mismatch: analytic={analytic:?}, fd={fd:?}, err={err:e}"
3762        );
3763    }
3764
3765    /// Verify pair.a scalar (`phi_tau_tau_partial`) by central-FD'ing the
3766    /// single-τ scalar `phi_tau_partial` along τ_j at fixed β.
3767    /// Identity: ∂/∂τ_j [Φ_{τ_i}|β] = Φ_{τ_iτ_j}|β.
3768    /// Tolerance 1e-7 relative.
3769    #[test]
3770    pub(crate) fn firthphi_tau_tau_pair_scalar_matches_finite_difference() {
3771        let x = array![
3772            [1.0, -1.0, 0.2],
3773            [1.0, -0.6, -0.3],
3774            [1.0, -0.1, 0.5],
3775            [1.0, 0.3, -0.7],
3776            [1.0, 0.8, 0.1],
3777            [1.0, 1.2, -0.4],
3778        ];
3779        let x_tau_i = array![
3780            [0.0, 0.15, -0.05],
3781            [0.0, -0.10, 0.02],
3782            [0.0, 0.08, 0.04],
3783            [0.0, -0.06, -0.03],
3784            [0.0, 0.05, 0.01],
3785            [0.0, -0.12, 0.06],
3786        ];
3787        let x_tau_j = array![
3788            [0.0, -0.04, 0.11],
3789            [0.0, 0.09, -0.02],
3790            [0.0, -0.06, 0.07],
3791            [0.0, 0.10, -0.05],
3792            [0.0, -0.03, 0.08],
3793            [0.0, 0.07, -0.09],
3794        ];
3795        let beta = array![0.1, -0.25, 0.2];
3796        let eta = x.dot(&beta);
3797        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3798
3799        let analytic = op
3800            .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3801            .phi_tau_tau_partial;
3802
3803        let h = 1e-5_f64;
3804        let eval_phi_tau_i = |x_eval: &Array2<f64>| -> f64 {
3805            let eta_e = x_eval.dot(&beta);
3806            let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3807            op_e.exact_tau_kernel(&x_tau_i, &beta, false)
3808                .phi_tau_partial
3809        };
3810        let x_plus = &x + &(h * &x_tau_j);
3811        let x_minus = &x - &(h * &x_tau_j);
3812        let fd = (eval_phi_tau_i(&x_plus) - eval_phi_tau_i(&x_minus)) / (2.0 * h);
3813
3814        let rel = (analytic - fd).abs() / fd.abs().max(1.0);
3815        assert!(
3816            rel < 1e-7,
3817            "pair.a scalar mismatch: analytic={analytic:.6e}, fd={fd:.6e}, rel={rel:.3e}"
3818        );
3819    }
3820
3821    /// Verify pair.g p-vector (`gphi_tau_tau`) by central-FD'ing the single-τ
3822    /// `gphi_tau` along τ_j at fixed β.
3823    /// Identity: ∂/∂τ_j [(gΦ)_{τ_i}|β] = (gΦ)_{τ_iτ_j}|β.
3824    /// Tolerance 1e-7 relative max-abs.
3825    #[test]
3826    pub(crate) fn firthphi_tau_tau_pair_g_vector_matches_finite_difference() {
3827        let x = array![
3828            [1.0, -1.0, 0.2],
3829            [1.0, -0.6, -0.3],
3830            [1.0, -0.1, 0.5],
3831            [1.0, 0.3, -0.7],
3832            [1.0, 0.8, 0.1],
3833            [1.0, 1.2, -0.4],
3834        ];
3835        let x_tau_i = array![
3836            [0.0, 0.15, -0.05],
3837            [0.0, -0.10, 0.02],
3838            [0.0, 0.08, 0.04],
3839            [0.0, -0.06, -0.03],
3840            [0.0, 0.05, 0.01],
3841            [0.0, -0.12, 0.06],
3842        ];
3843        let x_tau_j = array![
3844            [0.0, -0.04, 0.11],
3845            [0.0, 0.09, -0.02],
3846            [0.0, -0.06, 0.07],
3847            [0.0, 0.10, -0.05],
3848            [0.0, -0.03, 0.08],
3849            [0.0, 0.07, -0.09],
3850        ];
3851        let beta = array![0.1, -0.25, 0.2];
3852        let eta = x.dot(&beta);
3853        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3854
3855        let analytic = op
3856            .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3857            .gphi_tau_tau;
3858
3859        let h = 1e-5_f64;
3860        let eval_gphi_tau_i = |x_eval: &Array2<f64>| -> Array1<f64> {
3861            let eta_e = x_eval.dot(&beta);
3862            let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3863            op_e.exact_tau_kernel(&x_tau_i, &beta, false).gphi_tau
3864        };
3865        let x_plus = &x + &(h * &x_tau_j);
3866        let x_minus = &x - &(h * &x_tau_j);
3867        let fd = (&eval_gphi_tau_i(&x_plus) - &eval_gphi_tau_i(&x_minus)) / (2.0 * h);
3868
3869        let scale = analytic
3870            .iter()
3871            .chain(fd.iter())
3872            .map(|v| v.abs())
3873            .fold(0.0_f64, f64::max)
3874            .max(1.0);
3875        let err_max = (&analytic - &fd)
3876            .iter()
3877            .map(|v| v.abs())
3878            .fold(0.0_f64, f64::max);
3879        let rel = err_max / scale;
3880        assert!(
3881            rel < 1e-7,
3882            "pair.g p-vector mismatch: rel={rel:.3e}\nanalytic={analytic:?}\nfd={fd:?}"
3883        );
3884    }
3885
3886    /// Verify the Primitive A body (`hphi_tau_tau_partial_apply`) against a
3887    /// finite-difference reference of the single-τ Primitive (
3888    /// `hphi_tau_partial_apply`).
3889    ///
3890    /// Identity under test:
3891    ///     ∂/∂τ_j  { (H_φ)_τ_i |_β · V }   =   ∂²H_φ/∂τ_i ∂τ_j |_β · V.
3892    ///
3893    /// Central-difference reference:
3894    ///   1. Evaluate the single-τ primitive at x, and at x ± h·X_τ_j
3895    ///      — rebuild the FirthDenseOperator (with fresh identifiable Q)
3896    ///      at each perturbed design; H_φ applied to a p-space rhs is
3897    ///      basis-invariant in unreduced β-coords, so Q rotation does not
3898    ///      contaminate the comparison.
3899    ///   2. FD_{i,j} = (T_{plus} − T_{minus}) / (2h) with T = hphi_tau_i_apply(V).
3900    ///   3. Contract both (i,j) and (j,i) directions and verify symmetry
3901    ///      of the analytic as a cross-check.
3902    ///
3903    /// Tolerance: 1e-7 relative max-abs (h chosen to balance truncation
3904    /// error at ~h² and evaluator roundoff at ~ε/h).
3905    #[test]
3906    pub(crate) fn firthphi_tau_tau_partial_matches_finite_difference() {
3907        let x = array![
3908            [1.0, -1.0, 0.2],
3909            [1.0, -0.6, -0.3],
3910            [1.0, -0.1, 0.5],
3911            [1.0, 0.3, -0.7],
3912            [1.0, 0.8, 0.1],
3913            [1.0, 1.2, -0.4],
3914        ];
3915        let x_tau_i = array![
3916            [0.0, 0.15, -0.05],
3917            [0.0, -0.10, 0.02],
3918            [0.0, 0.08, 0.04],
3919            [0.0, -0.06, -0.03],
3920            [0.0, 0.05, 0.01],
3921            [0.0, -0.12, 0.06],
3922        ];
3923        let x_tau_j = array![
3924            [0.0, -0.04, 0.11],
3925            [0.0, 0.09, -0.02],
3926            [0.0, -0.06, 0.07],
3927            [0.0, 0.10, -0.05],
3928            [0.0, -0.03, 0.08],
3929            [0.0, 0.07, -0.09],
3930        ];
3931        let beta = array![0.1, -0.25, 0.2];
3932        let eta = x.dot(&beta);
3933        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3934        let p = x.ncols();
3935
3936        // Reproducible small rhs block (p × m).
3937        let m = 3usize;
3938        let mut rhs = Array2::<f64>::zeros((p, m));
3939        let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
3940        for r in 0..p {
3941            for c in 0..m {
3942                rhs[[r, c]] = vals[(r * m + c) % vals.len()];
3943            }
3944        }
3945
3946        // ── Analytic τ×τ pair apply at base design (x_tau_tau = None,
3947        //    deta_ij = None, i.e. design is linear in τ).
3948        let x_tau_i_reduced = op.reduce_explicit_design(&x_tau_i);
3949        let x_tau_j_reduced = op.reduce_explicit_design(&x_tau_j);
3950        let deta_i = x_tau_i.dot(&beta);
3951        let deta_j = x_tau_j.dot(&beta);
3952        let (dot_i_i, dot_h_i) = op.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
3953        let (dot_i_j, dot_h_j) = op.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
3954
3955        let kernel_ij = op.hphi_tau_tau_partial_prepare_from_partials(
3956            x_tau_i_reduced.clone(),
3957            x_tau_j_reduced.clone(),
3958            &deta_i,
3959            &deta_j,
3960            dot_h_i.clone(),
3961            dot_h_j.clone(),
3962            dot_i_i.clone(),
3963            dot_i_j.clone(),
3964            None,
3965            None,
3966        );
3967        let kernel_ji = op.hphi_tau_tau_partial_prepare_from_partials(
3968            x_tau_j_reduced,
3969            x_tau_i_reduced,
3970            &deta_j,
3971            &deta_i,
3972            dot_h_j,
3973            dot_h_i,
3974            dot_i_j,
3975            dot_i_i,
3976            None,
3977            None,
3978        );
3979        let analytic_ij = op.hphi_tau_tau_partial_apply(&x_tau_i, &x_tau_j, &kernel_ij, &rhs);
3980        let analytic_ji = op.hphi_tau_tau_partial_apply(&x_tau_j, &x_tau_i, &kernel_ji, &rhs);
3981
3982        // Symmetry cross-check (Clairaut): ∂²H/∂τ_i∂τ_j = ∂²H/∂τ_j∂τ_i.
3983        let sym_diff: f64 = (&analytic_ij - &analytic_ji)
3984            .iter()
3985            .map(|v| v.abs())
3986            .fold(0.0_f64, f64::max);
3987        let sym_scale: f64 = analytic_ij
3988            .iter()
3989            .chain(analytic_ji.iter())
3990            .map(|v| v.abs())
3991            .fold(0.0_f64, f64::max)
3992            .max(1.0);
3993        assert!(
3994            sym_diff / sym_scale < 1e-10,
3995            "τ×τ primitive not symmetric in direction order: sym_diff={sym_diff:.3e}"
3996        );
3997
3998        // ── FD reference: central difference of single-τ primitive in
3999        //    τ_j direction, evaluated along τ_i.
4000        let h = 1e-5_f64;
4001        let fd_block = |x_eval: &Array2<f64>| -> Array2<f64> {
4002            let eta_e = x_eval.dot(&beta);
4003            let op_e =
4004                build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
4005            let x_tau_i_r = op_e.reduce_explicit_design(&x_tau_i);
4006            let deta_i_e = x_tau_i.dot(&beta);
4007            let (dot_i_i_e, dot_h_i_e) = op_e.dot_i_and_h_from_reduced(&x_tau_i_r, &deta_i_e);
4008            let kernel_i_e = op_e
4009                .hphi_tau_partial_prepare_from_partials(x_tau_i_r, &deta_i_e, dot_h_i_e, dot_i_i_e);
4010            op_e.hphi_tau_partial_apply(&x_tau_i, &kernel_i_e, &rhs)
4011        };
4012        let x_plus = &x + &(h * &x_tau_j);
4013        let x_minus = &x - &(h * &x_tau_j);
4014        let fd_ij = (&fd_block(&x_plus) - &fd_block(&x_minus)) / (2.0 * h);
4015
4016        // ── Compare analytic_ij (contracted against V along τ_j→analytic's
4017        //    second index) to fd_ij (FD of T_i in τ_j direction).
4018        let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
4019            let scale = a
4020                .iter()
4021                .chain(b.iter())
4022                .map(|v| v.abs())
4023                .fold(0.0_f64, f64::max)
4024                .max(1.0);
4025            let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4026            max_diff / scale
4027        };
4028        let err_ij = rel_max_abs_diff(&analytic_ij, &fd_ij);
4029
4030        // Also FD the other direction and compare to analytic_ji, to
4031        // double-cover the primitive.
4032        let fd_block_j = |x_eval: &Array2<f64>| -> Array2<f64> {
4033            let eta_e = x_eval.dot(&beta);
4034            let op_e =
4035                build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
4036            let x_tau_j_r = op_e.reduce_explicit_design(&x_tau_j);
4037            let deta_j_e = x_tau_j.dot(&beta);
4038            let (dot_i_j_e, dot_h_j_e) = op_e.dot_i_and_h_from_reduced(&x_tau_j_r, &deta_j_e);
4039            let kernel_j_e = op_e
4040                .hphi_tau_partial_prepare_from_partials(x_tau_j_r, &deta_j_e, dot_h_j_e, dot_i_j_e);
4041            op_e.hphi_tau_partial_apply(&x_tau_j, &kernel_j_e, &rhs)
4042        };
4043        let x_plus_i = &x + &(h * &x_tau_i);
4044        let x_minus_i = &x - &(h * &x_tau_i);
4045        let fd_ji = (&fd_block_j(&x_plus_i) - &fd_block_j(&x_minus_i)) / (2.0 * h);
4046        let err_ji = rel_max_abs_diff(&analytic_ji, &fd_ji);
4047
4048        let tol = 1e-7_f64;
4049        assert!(
4050            err_ij < tol,
4051            "∂²H_φ/∂τ_i∂τ_j apply mismatch (i,j): rel_max_abs_diff={err_ij:.3e} > {tol:.1e}\n\
4052             analytic=\n{analytic_ij:?}\n\
4053             fd=\n{fd_ij:?}"
4054        );
4055        assert!(
4056            err_ji < tol,
4057            "∂²H_φ/∂τ_j∂τ_i apply mismatch (j,i): rel_max_abs_diff={err_ji:.3e} > {tol:.1e}\n\
4058             analytic=\n{analytic_ji:?}\n\
4059             fd=\n{fd_ji:?}"
4060        );
4061    }
4062
4063    /// Verify the Primitive B body (`d_beta_hphi_tau_partial_apply`) against a
4064    /// finite-difference reference of the single-τ Primitive
4065    /// (`hphi_tau_partial_apply`).
4066    ///
4067    /// Identity under test (β held in the unreduced ambient; the design X is
4068    /// fixed so only w, η̇_τ = X_τ β, and their β-derivatives move):
4069    ///     D_β [ (H_φ)_τ|_β (β) · V ] [v]
4070    ///       = d_beta_hphi_tau_partial_apply(v, V).
4071    ///
4072    /// Central-difference reference:
4073    ///   1. Evaluate T(t) := hphi_tau_partial_apply(V) at β_t = β + t v,
4074    ///      rebuilding FirthDenseOperator at each β (so η = X β_t and the
4075    ///      w-chain are re-derived cleanly).  X is unchanged; Q is rebuilt
4076    ///      but H_φ applied to a p-space rhs is basis-invariant.
4077    ///   2. FD = (T(+h) − T(−h)) / (2h).
4078    ///   3. Tolerance 1e-7 relative max-abs (h chosen to balance truncation
4079    ///      error at ~h² and evaluator roundoff at ~ε/h).
4080    #[test]
4081    pub(crate) fn firth_d_beta_hphi_tau_partial_matches_finite_difference() {
4082        let x = array![
4083            [1.0, -1.0, 0.2],
4084            [1.0, -0.6, -0.3],
4085            [1.0, -0.1, 0.5],
4086            [1.0, 0.3, -0.7],
4087            [1.0, 0.8, 0.1],
4088            [1.0, 1.2, -0.4],
4089        ];
4090        let x_tau = array![
4091            [0.0, 0.15, -0.05],
4092            [0.0, -0.10, 0.02],
4093            [0.0, 0.08, 0.04],
4094            [0.0, -0.06, -0.03],
4095            [0.0, 0.05, 0.01],
4096            [0.0, -0.12, 0.06],
4097        ];
4098        let beta = array![0.1, -0.25, 0.2];
4099        // β-direction v for the D_β[·][v] test.
4100        let v = array![0.3, 0.2, -0.15];
4101
4102        let eta = x.dot(&beta);
4103        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
4104        let p = x.ncols();
4105
4106        // Reproducible small rhs block (p × m).
4107        let m = 3usize;
4108        let mut rhs = Array2::<f64>::zeros((p, m));
4109        let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
4110        for r in 0..p {
4111            for c in 0..m {
4112                rhs[[r, c]] = vals[(r * m + c) % vals.len()];
4113            }
4114        }
4115
4116        // ── Analytic apply at (x, β).
4117        let x_tau_reduced = op.reduce_explicit_design(&x_tau);
4118        let deta_partial = x_tau.dot(&beta);
4119        let (dot_i_partial, dot_h_partial) =
4120            op.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
4121        let tau_kernel = op.hphi_tau_partial_prepare_from_partials(
4122            x_tau_reduced.clone(),
4123            &deta_partial,
4124            dot_h_partial.clone(),
4125            dot_i_partial.clone(),
4126        );
4127
4128        let deta_v = x.dot(&v);
4129        let direction = op.direction_from_deta(deta_v);
4130        let x_tau_v = x_tau.dot(&v);
4131        let pair_kernel = op.d_beta_hphi_tau_partial_prepare_from_partials(
4132            &tau_kernel,
4133            &deta_partial,
4134            &dot_i_partial,
4135            &direction,
4136            &x_tau_v,
4137        );
4138        let analytic = op.d_beta_hphi_tau_partial_apply(&x_tau, &pair_kernel, &rhs);
4139
4140        // ── FD reference: central difference of single-τ primitive under
4141        //    β → β ± h v.  X stays fixed; η, w, η̇_τ are re-derived.
4142        let h = 1e-5_f64;
4143        let single_tau_apply = |beta_eval: &Array1<f64>| -> Array2<f64> {
4144            let eta_e = x.dot(beta_eval);
4145            let op_e =
4146                build_logit_firth_dense_operator(&x, &eta_e).expect("perturbed firth operator");
4147            let x_tau_r = op_e.reduce_explicit_design(&x_tau);
4148            let deta_e = x_tau.dot(beta_eval);
4149            let (dot_i_e, dot_h_e) = op_e.dot_i_and_h_from_reduced(&x_tau_r, &deta_e);
4150            let ker_e =
4151                op_e.hphi_tau_partial_prepare_from_partials(x_tau_r, &deta_e, dot_h_e, dot_i_e);
4152            op_e.hphi_tau_partial_apply(&x_tau, &ker_e, &rhs)
4153        };
4154        let beta_plus = &beta + &(h * &v);
4155        let beta_minus = &beta - &(h * &v);
4156        let fd = (&single_tau_apply(&beta_plus) - &single_tau_apply(&beta_minus)) / (2.0 * h);
4157
4158        let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
4159            let scale = a
4160                .iter()
4161                .chain(b.iter())
4162                .map(|v| v.abs())
4163                .fold(0.0_f64, f64::max)
4164                .max(1.0);
4165            let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4166            max_diff / scale
4167        };
4168        let err = rel_max_abs_diff(&analytic, &fd);
4169
4170        let tol = 1e-7_f64;
4171        assert!(
4172            err < tol,
4173            "D_β (H_φ)_τ|_β apply mismatch: rel_max_abs_diff={err:.3e} > {tol:.1e}\n\
4174             analytic=\n{analytic:?}\n\
4175             fd=\n{fd:?}"
4176        );
4177    }
4178
4179    #[test]
4180    pub(crate) fn logisticweight_loses_positive_tail_mass() {
4181        let eta = 50.0_f64;
4182        let z = (-eta).exp();
4183        let stable = z / (1.0_f64 + z).powi(2);
4184        assert!(stable > 0.0);
4185        let got = logisticweight(eta);
4186        assert!(
4187            (got - stable).abs() < 1e-30,
4188            "Firth logisticweight should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
4189            got,
4190            stable
4191        );
4192    }
4193
4194    #[test]
4195    pub(crate) fn fisher_weight_jet5_logit_is_byte_identical_to_inverse_link_jet() {
4196        // The generalized Firth weight jet for the canonical logit link must
4197        // reproduce the historical `logit_inverse_link_jet5().d1..d5` path
4198        // exactly so the released logit Firth fits stay numerically unchanged.
4199        for &eta in &[
4200            -40.0, -8.0, -3.0, -1.0, -0.25, 0.0, 0.25, 1.0, 3.0, 8.0, 40.0,
4201        ] {
4202            let jet = logit_inverse_link_jet5(eta);
4203            let (w, w1, w2, w3, w4) =
4204                crate::mixture_link::fisher_weight_jet5(StandardLink::Logit, eta);
4205            assert!(
4206                w == jet.d1 && w1 == jet.d2 && w2 == jet.d3 && w3 == jet.d4 && w4 == jet.d5,
4207                "logit Fisher-weight jet must equal inverse-link jet derivatives at eta={eta}: \
4208                 got ({w}, {w1}, {w2}, {w3}, {w4}) vs ({}, {}, {}, {}, {})",
4209                jet.d1,
4210                jet.d2,
4211                jet.d3,
4212                jet.d4,
4213                jet.d5
4214            );
4215        }
4216    }
4217
4218    #[test]
4219    pub(crate) fn fisher_weight_jet5_probit_matches_finite_difference() {
4220        // Probit Bernoulli Fisher weight W(eta) = phi^2 / (Phi (1 - Phi)).
4221        // Validate the closed-form jet against central finite differences of
4222        // the reference scalar weight.
4223        fn reference_probit_weight(eta: f64) -> f64 {
4224            let p = gam_math::probability::normal_cdf(eta);
4225            let q = 1.0 - p;
4226            let phi = gam_math::probability::normal_pdf(eta);
4227            if p <= 0.0 || q <= 0.0 {
4228                return 0.0;
4229            }
4230            phi * phi / (p * q)
4231        }
4232        let h = 1e-4_f64;
4233        for &eta in &[-3.0, -1.5, -0.5, 0.0, 0.3, 1.5, 3.0] {
4234            let (w, w1, w2, _w3, _w4) =
4235                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
4236            let ref_w = reference_probit_weight(eta);
4237            let fd1 =
4238                (reference_probit_weight(eta + h) - reference_probit_weight(eta - h)) / (2.0 * h);
4239            let fd2 = (reference_probit_weight(eta + h) - 2.0 * reference_probit_weight(eta)
4240                + reference_probit_weight(eta - h))
4241                / (h * h);
4242            assert!(
4243                (w - ref_w).abs() < 1e-10,
4244                "probit W mismatch at eta={eta}: jet {w} vs ref {ref_w}"
4245            );
4246            assert!(
4247                (w1 - fd1).abs() < 1e-5,
4248                "probit W' mismatch at eta={eta}: jet {w1} vs fd {fd1}"
4249            );
4250            assert!(
4251                (w2 - fd2).abs() < 1e-3,
4252                "probit W'' mismatch at eta={eta}: jet {w2} vs fd {fd2}"
4253            );
4254        }
4255    }
4256
4257    #[test]
4258    pub(crate) fn fisher_weight_jet5_probit_saturates_to_zero_in_tails() {
4259        // Past the point where the denominator Phi(1-Phi) underflows to zero,
4260        // the weight and all derivatives are exactly zero (the saturated-tail
4261        // convention shared with the inverse-link jet).
4262        for &eta in &[40.0_f64, -40.0, 80.0, -80.0] {
4263            let (w, w1, w2, w3, w4) =
4264                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
4265            assert!(
4266                w == 0.0 && w1 == 0.0 && w2 == 0.0 && w3 == 0.0 && w4 == 0.0,
4267                "probit Fisher weight jet must saturate to zero at eta={eta}; got \
4268                 ({w}, {w1}, {w2}, {w3}, {w4})"
4269            );
4270        }
4271        // In the moderate tail the denominator is still representable (the
4272        // complement is taken as Phi(-eta), not the cancellation-prone
4273        // `1 - Phi(eta)`), so the weight is a tiny strictly-positive finite
4274        // number with finite derivatives. It must NOT prematurely round to zero.
4275        for &eta in &[12.0_f64, -12.0] {
4276            let (w, w1, w2, w3, w4) =
4277                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
4278            assert!(
4279                w > 0.0
4280                    && w.is_finite()
4281                    && w1.is_finite()
4282                    && w2.is_finite()
4283                    && w3.is_finite()
4284                    && w4.is_finite(),
4285                "probit Fisher weight jet must be tiny-positive and finite at eta={eta}; got \
4286                 ({w}, {w1}, {w2}, {w3}, {w4})"
4287            );
4288        }
4289    }
4290}