Skip to main content

gam_solve/reml/
firth.rs

1use super::*;
2use gam_linalg::matrix::symmetrize_in_place;
3use crate::mixture_link::fisher_weight_jet5_for_inverse_link;
4use gam_problem::InverseLink;
5
/// Row-count threshold at or above which the per-row Firth derivative work
/// (the Fisher-weight 5-jet) is evaluated with rayon rather than serially;
/// see `RemlState::parallelize_firth_derivative_rows`.
pub(crate) const FIRTH_DERIVATIVE_PARALLEL_MIN_N: usize = 16_384;

/// Reciprocal-condition floor for the reduced Fisher information `I_r`.
///
/// When the smallest positive eigenvalue of `I_r` divided by
/// `max(λ_max, 1)` drops below this value, a warning is logged. This is a
/// diagnostic only and never rejects a fit. Below `1e-10` the SPD assumption
/// on the identifiable subspace is numerically fragile, and the exact
/// pseudo-determinant derivatives may be ill-conditioned near
/// active-subspace boundaries.
pub(crate) const FIRTH_REDUCED_FISHER_RCOND_WARN: f64 = 1e-10;
14
15/// β-dependent reduced-space pieces of the Firth/Jeffreys operator at the
16/// current `η`, produced by `FirthDenseOperator::firth_reduced_core` from a
17/// cached β-independent [`FirthDesignFactor`]. The full operator build consumes
18/// every field; the lightweight PIRLS-diagnostics path consumes only `w`, `w1`,
19/// `h_diag`, and `half_log_det`.
20struct FirthReducedCore {
21    w: Array1<f64>,
22    w1: Array1<f64>,
23    w2: Array1<f64>,
24    w3: Array1<f64>,
25    w4: Array1<f64>,
26    k_reduced: Array2<f64>,
27    half_log_det: f64,
28    h_diag: Array1<f64>,
29}
30
31/// Single-index sub-blocks of the exact mixed second directional derivative
32/// `D²H_φ[u,v]`, precomputed once against the fixed `eye` rhs used by the
33/// exact-Hessian TK outer loop. See
34/// [`FirthDenseOperator::tk_second_direction_eye_cache`] (#1575).
35pub(crate) struct FirthSecondDirEyeCache {
36    /// The fixed identity rhs (`p×p`), kept so per-pair `fast_ab(.., &eye)`
37    /// matmuls reproduce the original byte-for-byte.
38    eye: Array2<f64>,
39    /// `X·I` — index-independent.
40    eta_rhs: Array2<f64>,
41    /// `(Bᵀ P B-base)·I` — index-independent.
42    p_b_rhs: Array2<f64>,
43    /// Per-direction `apply_hadamard_gram(eta_rhs ⊙ b_uvec_i)`.
44    p_bx: Vec<Array2<f64>>,
45    /// Per-direction `apply_p_u(a_u_reduced_i, w' ⊙ eta_rhs)`.
46    pu_qv: Vec<Array2<f64>>,
47}
48
49impl<'a> RemlState<'a> {
50    pub(crate) fn xt_diag_x_dense_into(
51        x: &Array2<f64>,
52        diag: &Array1<f64>,
53        weighted: &mut Array2<f64>,
54    ) -> Array2<f64> {
55        super::assembly::xt_diag_x_dense_into(x, diag, weighted)
56    }
57
58    #[inline]
59    pub(crate) fn parallelize_firth_derivative_rows(n: usize) -> bool {
60        n >= FIRTH_DERIVATIVE_PARALLEL_MIN_N && rayon::current_num_threads() > 1
61    }
62
63    pub(crate) fn row_scale(x: &Array2<f64>, scale: &Array1<f64>) -> Array2<f64> {
64        let mut out = Array2::<f64>::zeros(x.raw_dim());
65        super::assembly::row_scale_dense_into(x, scale, &mut out);
66        out
67    }
68
69    #[inline]
70    pub(crate) fn dense_product_likely_uses_inner_parallelism(
71        m: usize,
72        n: usize,
73        k: usize,
74    ) -> bool {
75        // Keep this in sync with faer_ndarray::matmul_parallelism.  When a
76        // dense product is large enough for faer/BLAS-style internal
77        // parallelism, do not also wrap sibling products in rayon::join: that
78        // can oversubscribe CPU threads and slow down the REML hot path.
79        const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
80        const PAR_MIN_LONG_DIM: usize = 256;
81        let flop_scale = m.saturating_mul(n).saturating_mul(k);
82        let long_dim = m.max(n).max(k);
83        flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM
84    }
85
86    #[inline]
87    pub(crate) fn should_join_independent_dense_products(
88        products: &[(usize, usize, usize)],
89    ) -> bool {
90        const JOIN_MIN_TOTAL_FLOP_SCALE: usize = 128 * 1024;
91        if rayon::current_num_threads() <= 1 {
92            return false;
93        }
94        let mut total_flop_scale = 0usize;
95        for &(m, n, k) in products {
96            if Self::dense_product_likely_uses_inner_parallelism(m, n, k) {
97                return false;
98            }
99            total_flop_scale =
100                total_flop_scale.saturating_add(m.saturating_mul(n).saturating_mul(k));
101        }
102        total_flop_scale >= JOIN_MIN_TOTAL_FLOP_SCALE
103    }
104
105    /// Undo the fixed observation-weight row scale used by Firth reduced
106    /// designs.
107    ///
108    /// Keep this distinct from PIRLS's sparse-SpGEMM `sqrt_weights` cache in
109    /// `solver/pirls.rs`: PIRLS materializes roots of the current working
110    /// weights for a Gram factorization, while this uses stored fixed
111    /// case-weight roots and their reciprocals to map reduced design
112    /// derivatives back to raw design space.
113    #[inline]
114    pub(crate) fn scale_rows_by_inverse_observation_weight_sqrt(
115        out: &mut Array2<f64>,
116        observation_weight_sqrt: Option<&Array1<f64>>,
117    ) {
118        let Some(scale) = observation_weight_sqrt else {
119            return;
120        };
121        super::assembly::row_scale_dense_in_place_by_inverse_positive_or_zero(out, scale);
122    }
123
124    /// GLM Fisher working-weight 5-jet for the requested inverse link. For
125    /// standard Logit this is byte-identical to the historical
126    /// `logit_inverse_link_jet5(eta).d1..d5` path that the Firth operator used
127    /// before the weights were generalized to arbitrary inverse links.
128    #[inline]
129    pub(crate) fn fisher_weight_derivatives(
130        link: &InverseLink,
131        eta: f64,
132    ) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
133        fisher_weight_jet5_for_inverse_link(link, eta)
134    }
135
136    #[inline]
137    pub(crate) fn cholesky_pivots_are_numerically_resolved(chol_diag: &Array1<f64>) -> bool {
138        let mut min_pivot_sq = f64::INFINITY;
139        let mut max_pivot_sq = 0.0_f64;
140        for &pivot in chol_diag {
141            if !pivot.is_finite() || pivot <= 0.0 {
142                return false;
143            }
144            let pivot_sq = pivot * pivot;
145            min_pivot_sq = min_pivot_sq.min(pivot_sq);
146            max_pivot_sq = max_pivot_sq.max(pivot_sq);
147        }
148        if !min_pivot_sq.is_finite() {
149            return false;
150        }
151        let scale = max_pivot_sq.max(1.0);
152        let floor = (chol_diag.len().max(1) as f64) * f64::EPSILON * scale;
153        min_pivot_sq > floor
154    }
155
156    pub(crate) fn reduced_fisher_inverse_and_half_logdet(
157        fisher_reduced: &Array2<f64>,
158    ) -> Result<(Array2<f64>, f64), EstimationError> {
159        let r = fisher_reduced.nrows();
160        assert_eq!(r, fisher_reduced.ncols());
161        let mut k_reduced = Array2::<f64>::zeros((r, r));
162        if r == 0 {
163            return Ok((k_reduced, 0.0));
164        }
165
166        if let Ok(chol) = fisher_reduced.cholesky(Side::Lower) {
167            let chol_diag = chol.diag();
168            if Self::cholesky_pivots_are_numerically_resolved(&chol_diag) {
169                let half_log_det = chol_diag.iter().map(|d| d.ln()).sum::<f64>();
170                for col in 0..r {
171                    let mut e_col = Array1::<f64>::zeros(r);
172                    e_col[col] = 1.0;
173                    let solved = chol.solvevec(&e_col);
174                    k_reduced.column_mut(col).assign(&solved);
175                }
176                return Ok((k_reduced, half_log_det));
177            }
178        }
179
180        let (evals_ir, evecs_ir) = fisher_reduced
181            .eigh(Side::Lower)
182            .map_err(EstimationError::EigendecompositionFailed)?;
183        let max_eval = evals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
184        let tol = (r.max(1) as f64) * f64::EPSILON * max_eval;
185        let mut kept_positive_direction = false;
186        let mut half_log_det = 0.0_f64;
187        for (eig_idx, &eig) in evals_ir.iter().enumerate() {
188            if eig > tol {
189                kept_positive_direction = true;
190                half_log_det += 0.5 * eig.ln();
191                let inv = eig.recip();
192                let vec = evecs_ir.column(eig_idx).to_owned();
193                for row in 0..r {
194                    for col in 0..r {
195                        k_reduced[[row, col]] += inv * vec[row] * vec[col];
196                    }
197                }
198            }
199        }
200        if !kept_positive_direction {
201            return Err(EstimationError::ModelIsIllConditioned {
202                condition_number: f64::INFINITY,
203            });
204        }
205        Ok((k_reduced, half_log_det))
206    }
207
208    pub(crate) fn fill_fisher_weight_derivative_arrays(
209        link: &InverseLink,
210        eta: &Array1<f64>,
211        w: &mut Array1<f64>,
212        w1: &mut Array1<f64>,
213        w2: &mut Array1<f64>,
214        w3: &mut Array1<f64>,
215        w4: &mut Array1<f64>,
216    ) -> Result<(), EstimationError> {
217        assert_eq!(eta.len(), w.len());
218        assert_eq!(eta.len(), w1.len());
219        assert_eq!(eta.len(), w2.len());
220        assert_eq!(eta.len(), w3.len());
221        assert_eq!(eta.len(), w4.len());
222
223        if Self::parallelize_firth_derivative_rows(eta.len()) {
224            let values: Result<Vec<_>, EstimationError> = eta
225                .par_iter()
226                .map(|&ei| Self::fisher_weight_derivatives(link, ei))
227                .collect();
228            for (i, (value, first, second, third, fourth)) in values?.into_iter().enumerate() {
229                w[i] = value;
230                w1[i] = first;
231                w2[i] = second;
232                w3[i] = third;
233                w4[i] = fourth;
234            }
235            return Ok(());
236        }
237        for i in 0..eta.len() {
238            let (value, first, second, third, fourth) =
239                Self::fisher_weight_derivatives(link, eta[i])?;
240            w[i] = value;
241            w1[i] = first;
242            w2[i] = second;
243            w3[i] = third;
244            w4[i] = fourth;
245        }
246        Ok(())
247    }
248
249    pub(crate) fn weighted_cross(
250        left: &Array2<f64>,
251        right: &Array2<f64>,
252        weights: &Array1<f64>,
253    ) -> Array2<f64> {
254        assert_eq!(left.nrows(), right.nrows());
255        assert_eq!(left.nrows(), weights.len());
256        super::assembly::weighted_cross_dense(left, right, weights)
257    }
258
259    pub(crate) fn trace_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
260        assert_eq!(a.nrows(), b.ncols());
261        assert_eq!(a.ncols(), b.nrows());
262        let elems = a.nrows().saturating_mul(a.ncols());
263        if elems >= 32 * 32 {
264            let aview = FaerArrayView::new(a);
265            let bview = FaerArrayView::new(b);
266            return faer_frob_inner(aview.as_ref(), bview.as_ref().transpose());
267        }
268        let m = a.nrows();
269        let n = a.ncols();
270        kahan_sum((0..m).map(|i| {
271            let mut acc = 0.0_f64;
272            for j in 0..n {
273                acc += a[[i, j]] * b[[j, i]];
274            }
275            acc
276        }))
277    }
278
279    pub(crate) fn reducedweighted_gram(z: &Array2<f64>, weights: &Array1<f64>) -> Array2<f64> {
280        // Returns Zᵀ diag(weights) Z (exact).
281        //
282        // Used for:
283        //   S = Zᵀ diag(v) Z               in Hadamard-Gram apply,
284        //   G_u = X_rᵀ diag(s_u) X_r       with s_u = w' ⊙ (Xu),
285        // both of which avoid constructing dense n×n intermediates.
286        let weighted = Self::row_scale(z, weights);
287        fast_atb(z, &weighted)
288    }
289
290    pub(crate) fn reduced_crossweighted_gram(
291        z_left: &Array2<f64>,
292        z_right: &Array2<f64>,
293        weights: &Array1<f64>,
294    ) -> Array2<f64> {
295        // Returns Z_leftᵀ diag(weights) Z_right (exact).
296        //
297        // This is used for explicit design-moving terms where left/right
298        // reduced designs differ (X_r vs X_{tau,r}) in Hadamard-Gram products.
299        let weighted = Self::row_scale(z_right, weights);
300        fast_atb(z_left, &weighted)
301    }
302
303    pub(crate) fn reduced_diag_gram(z: &Array2<f64>, a: &Array2<f64>) -> Array1<f64> {
304        // Returns diag(Z A Zᵀ), exact without forming dense n×n matrix.
305        //
306        // Identity:
307        //   [diag(Z A Zᵀ)]_i = z_iᵀ A z_i,
308        // where z_i is row i of Z. We compute this as rowwise dot(Z, Z A).
309        let za = fast_ab(z, a);
310        (z * &za).sum_axis(ndarray::Axis(1))
311    }
312
313    pub(crate) fn apply_hadamard_gram(
314        z: &Array2<f64>,
315        a_left: &Array2<f64>,
316        a_right: &Array2<f64>,
317        vec: &Array1<f64>,
318    ) -> Array1<f64> {
319        // Exact apply of:
320        //   y = ((Z A_left Z^T) ⊙ (Z A_right Z^T)) vec
321        // using Gram/Hadamard identity with S = Z^T diag(vec) Z:
322        //   y_i = z_i^T A_left S A_right z_i.
323        //
324        // This is the matrix-free kernel behind the Firth terms that would
325        // otherwise require dense P = M⊙M and M⊙N_u products.
326        // Complexity is O(n r^2) with r = rank(X), no n×n storage.
327        let s = Self::reducedweighted_gram(z, vec);
328        let left_s = a_left.dot(&s);
329        let t = left_s.dot(a_right);
330        Self::reduced_diag_gram(z, &t)
331    }
332
333    pub(crate) fn apply_hadamard_gram_to_matrix(
334        z: &Array2<f64>,
335        a_left: &Array2<f64>,
336        a_right: &Array2<f64>,
337        mat: &Array2<f64>,
338    ) -> Array2<f64> {
339        // Columnwise extension of apply_hadamard_gram:
340        //   out[:,j] = ((Z A_left Zᵀ) ⊙ (Z A_right Zᵀ)) mat[:,j].
341        //
342        // In the Firth derivatives this is used with:
343        //   A_left = K_r, A_right = K_r        for Hw⊙Hw actions,
344        //   A_left = K_r, A_right = A_u        for Hw⊙Nbar_u actions,
345        // and symmetric variants in mixed second-direction terms.
346        let mut out = Array2::<f64>::zeros(mat.raw_dim());
347        for col in 0..mat.ncols() {
348            let v = mat.column(col).to_owned();
349            let y = Self::apply_hadamard_gram(z, a_left, a_right, &v);
350            out.column_mut(col).assign(&y);
351        }
352        out
353    }
354
355    /// Link-aware dense Firth/Jeffreys builder. The REML callsites resolve the
356    /// Fisher-weight link via `reml_robust_jeffreys_link` and pass it here;
357    /// standard Logit reproduces the historical logit-pinned build byte-for-byte,
358    /// and stateful links flow through the same inverse-link derivative path.
359    pub(super) fn build_firth_dense_operator_for_link(
360        link: &InverseLink,
361        x_dense: &Array2<f64>,
362        eta: &Array1<f64>,
363        observation_weights: ndarray::ArrayView1<'_, f64>,
364    ) -> Result<FirthDenseOperator, EstimationError> {
365        FirthDenseOperator::build_with_observation_weights_impl(
366            link,
367            x_dense,
368            eta,
369            Some(observation_weights),
370        )
371    }
372
373    pub(super) fn firth_exact_tau_kernel(
374        op: &FirthDenseOperator,
375        x_tau: &Array2<f64>,
376        beta: &Array1<f64>,
377        include_hphi_tau_kernel: bool,
378    ) -> FirthTauExactKernel {
379        op.exact_tau_kernel(x_tau, beta, include_hphi_tau_kernel)
380    }
381
382    pub(super) fn firth_hphi_tau_partial_apply(
383        op: &FirthDenseOperator,
384        x_tau: &Array2<f64>,
385        kernel: &FirthTauPartialKernel,
386        rhs: &Array2<f64>,
387    ) -> Array2<f64> {
388        op.hphi_tau_partial_apply(x_tau, kernel, rhs)
389    }
390}
391
392impl FirthDenseOperator {
393    pub(crate) fn canonicalize_basis_column_signs(q_basis: &mut Array2<f64>) {
394        for col in 0..q_basis.ncols() {
395            let mut pivot_row = 0usize;
396            let mut pivot_abs = 0.0_f64;
397            for row in 0..q_basis.nrows() {
398                let value = q_basis[[row, col]];
399                let abs_value = value.abs();
400                if abs_value > pivot_abs {
401                    pivot_abs = abs_value;
402                    pivot_row = row;
403                }
404            }
405            if pivot_abs > 0.0 && q_basis[[pivot_row, col]] < 0.0 {
406                q_basis.column_mut(col).mapv_inplace(|v| -v);
407            }
408        }
409    }
410
411    pub(crate) fn identifiable_subspace_basis_from_gram(
412        gram: &Array2<f64>,
413    ) -> Result<(Array2<f64>, Array1<f64>), EstimationError> {
414        let p = gram.nrows();
415        assert_eq!(p, gram.ncols());
416        if p == 0 {
417            return Ok((Array2::<f64>::eye(0), Array1::<f64>::zeros(0)));
418        }
419
420        let (evals, evecs) = gram
421            .eigh(Side::Lower)
422            .map_err(EstimationError::EigendecompositionFailed)?;
423        let max_eval = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
424        let tol = (p.max(1) as f64) * f64::EPSILON * max_eval;
425        let mut keep: Vec<usize> = evals
426            .iter()
427            .enumerate()
428            .filter_map(|(i, &value)| if value > tol { Some(i) } else { None })
429            .collect();
430        if keep.is_empty() {
431            return Err(EstimationError::ModelIsIllConditioned {
432                condition_number: f64::INFINITY,
433            });
434        }
435
436        // Use one orthonormal identifiable basis for both the full-rank and
437        // rank-deficient cases. Sorting retained modes by descending design
438        // energy makes the representation deterministic up to eigenspace
439        // degeneracy; fixing the column signs removes the remaining trivial
440        // sign ambiguity.
441        keep.sort_by(|&lhs, &rhs| evals[rhs].total_cmp(&evals[lhs]));
442        let r = keep.len();
443        let mut q_basis = Array2::<f64>::zeros((p, r));
444        let mut metric_spectrum = Array1::<f64>::zeros(r);
445        for (col_idx, eig_idx) in keep.into_iter().enumerate() {
446            q_basis.column_mut(col_idx).assign(&evecs.column(eig_idx));
447            metric_spectrum[col_idx] = evals[eig_idx];
448        }
449        Self::canonicalize_basis_column_signs(&mut q_basis);
450        Ok((q_basis, metric_spectrum))
451    }
452
453    #[inline]
454    pub(crate) fn trace_diag_product(diag: &Array1<f64>, matrix: &Array2<f64>) -> f64 {
455        assert_eq!(diag.len(), matrix.nrows());
456        assert_eq!(matrix.nrows(), matrix.ncols());
457        kahan_sum((0..diag.len()).map(|i| diag[i] * matrix[[i, i]]))
458    }
459
460    pub fn build_for_link(
461        link: &InverseLink,
462        x_dense: &Array2<f64>,
463        eta: &Array1<f64>,
464    ) -> Result<FirthDenseOperator, EstimationError> {
465        Self::build_with_observation_weights_impl(link, x_dense, eta, None)
466    }
467
468    pub fn build_with_observation_weights_for_link(
469        link: &InverseLink,
470        x_dense: &Array2<f64>,
471        eta: &Array1<f64>,
472        observation_weights: ndarray::ArrayView1<'_, f64>,
473    ) -> Result<FirthDenseOperator, EstimationError> {
474        Self::build_with_observation_weights_impl(link, x_dense, eta, Some(observation_weights))
475    }
476
477    /// Build the β-independent (design-only) factor of the Firth/Jeffreys
478    /// operator: the identifiable-subspace basis `Q`, the reduced design
479    /// `X_r = A^{1/2} X Q`, the retained design-Gram spectrum `S_r`, and the raw
480    /// design/transpose. This is the O(n·p²) Gram + O(p³) eigendecomposition +
481    /// the n×p design clones. None of it depends on `η`/β, so it is computed
482    /// ONCE per inner PIRLS solve and reused across Newton iterations (#1575).
483    ///
484    /// `build_with_observation_weights_impl` is exactly this factor followed by
485    /// the per-η remainder, so existing callers stay bit-for-bit identical.
486    pub(crate) fn build_design_factor_with_observation_weights(
487        x_dense: &Array2<f64>,
488        observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
489    ) -> Result<FirthDesignFactor, EstimationError> {
490        let n = x_dense.nrows();
491        let observation_weight_sqrt = if let Some(weights) = observation_weights {
492            if weights.len() != n {
493                crate::bail_invalid_estim!(
494                    "Firth operator observation weight length {} != number of rows {}",
495                    weights.len(),
496                    n
497                );
498            }
499            let mut sqrt = Array1::<f64>::zeros(n);
500            for i in 0..n {
501                let weight = weights[i];
502                if !weight.is_finite() || weight < 0.0 {
503                    crate::bail_invalid_estim!(
504                        "Firth operator requires finite nonnegative observation weights, got {} at row {}",
505                        weight,
506                        i
507                    );
508                }
509                sqrt[i] = weight.sqrt();
510            }
511            Some(sqrt)
512        } else {
513            None
514        };
515        let basis_design = if let Some(scale) = observation_weight_sqrt.as_ref() {
516            RemlState::row_scale(x_dense, scale)
517        } else {
518            x_dense.clone()
519        };
520        // X̃ᵀX̃ Gram → identifiable-subspace basis Q and retained spectrum S_r.
521        let gram = fast_atb(&basis_design, &basis_design);
522        let (q_basis, metric_spectrum) = Self::identifiable_subspace_basis_from_gram(&gram)?;
523        let x_reduced = fast_ab(&basis_design, &q_basis);
524        let r = q_basis.ncols();
525        let mut x_metric_reduced_inv_diag = Array1::<f64>::zeros(r);
526        for col in 0..r {
527            x_metric_reduced_inv_diag[col] = metric_spectrum[col].recip();
528        }
529        let x_dense_t = x_dense.t().to_owned();
530        Ok(FirthDesignFactor {
531            x_dense: x_dense.clone(),
532            x_dense_t,
533            q_basis,
534            x_reduced,
535            observation_weight_sqrt,
536            metric_spectrum,
537            x_metric_reduced_inv_diag,
538            r,
539            n,
540        })
541    }
542
543    /// β-dependent reduced core shared by [`Self::build_from_design_factor`] and
544    /// [`Self::pirls_diagnostics_from_factor`]: from the cached design factor and
545    /// the current `η`, compute the Fisher-weight 5-jet, the reduced Fisher
546    /// inverse `K_r`, the identifiable-subspace half-log-determinant, and the hat
547    /// diagonal `h`. The operations and their order match the un-hoisted
548    /// `build_with_observation_weights_impl` exactly, so every consumer stays
549    /// bit-for-bit identical.
550    fn firth_reduced_core(
551        factor: &FirthDesignFactor,
552        link: &InverseLink,
553        eta: &Array1<f64>,
554    ) -> Result<FirthReducedCore, EstimationError> {
555        let n = factor.n;
556        if eta.len() != n {
557            crate::bail_invalid_estim!(
558                "Firth operator shape mismatch: nrows={}, eta_len={}",
559                n,
560                eta.len()
561            );
562        }
563        let r = factor.r;
564        let mut w = Array1::<f64>::zeros(n);
565        let mut w1 = Array1::<f64>::zeros(n);
566        let mut w2 = Array1::<f64>::zeros(n);
567        let mut w3 = Array1::<f64>::zeros(n);
568        let mut w4 = Array1::<f64>::zeros(n);
569        RemlState::fill_fisher_weight_derivative_arrays(
570            link, eta, &mut w, &mut w1, &mut w2, &mut w3, &mut w4,
571        )?;
572
573        // Reduced Fisher I_r = X_rᵀ W X_r on the identifiable subspace.
574        let fisher_reduced = gam_linalg::faer_ndarray::fast_xt_diag_x(&factor.x_reduced, &w);
575        if let Ok((eigvals_ir, _)) = fisher_reduced.eigh(Side::Lower) {
576            let max_ev = eigvals_ir.iter().copied().fold(0.0_f64, f64::max).max(1.0);
577            let min_ev = eigvals_ir
578                .iter()
579                .copied()
580                .filter(|v| v.is_finite() && *v > 0.0)
581                .fold(f64::INFINITY, f64::min);
582            if min_ev.is_finite() {
583                let rel = min_ev / max_ev;
584                if rel < FIRTH_REDUCED_FISHER_RCOND_WARN {
585                    log::warn!(
586                        "[REML/Firth] reduced Fisher I_r is near-singular (min/max={:.3e}/{:.3e}, rel={:.3e}); exact derivatives may be ill-conditioned near active-subspace boundaries.",
587                        min_ev,
588                        max_ev,
589                        rel
590                    );
591                }
592            }
593        }
594
595        let (k_reduced, mut half_log_det) = if r > 0 {
596            RemlState::reduced_fisher_inverse_and_half_logdet(&fisher_reduced)?
597        } else {
598            (Array2::<f64>::zeros((r, r)), 0.0)
599        };
600        if r > 0 {
601            for col in 0..r {
602                let metric_eig = factor.metric_spectrum[col];
603                half_log_det -= 0.5 * metric_eig.ln();
604            }
605        }
606        let h_diag = if r > 0 {
607            RemlState::reduced_diag_gram(&factor.x_reduced, &k_reduced)
608        } else {
609            Array1::<f64>::zeros(n)
610        };
611        Ok(FirthReducedCore {
612            w,
613            w1,
614            w2,
615            w3,
616            w4,
617            k_reduced,
618            half_log_det,
619            h_diag,
620        })
621    }
622
623    /// Rebuild the full Firth operator at a new `η` from a cached design factor.
624    /// Pure memoization: byte-identical to `build_with_observation_weights_impl`
625    /// for the same `(link, design, weights, η)`.
626    pub(crate) fn build_from_design_factor(
627        factor: &FirthDesignFactor,
628        link: &InverseLink,
629        eta: &Array1<f64>,
630    ) -> Result<FirthDenseOperator, EstimationError> {
631        let FirthReducedCore {
632            w,
633            w1,
634            w2,
635            w3,
636            w4,
637            k_reduced,
638            half_log_det,
639            h_diag,
640        } = Self::firth_reduced_core(factor, link, eta)?;
641        let b_base = RemlState::row_scale(&factor.x_dense, &w1);
642        let p_b_base = RemlState::apply_hadamard_gram_to_matrix(
643            &factor.x_reduced,
644            &k_reduced,
645            &k_reduced,
646            &b_base,
647        );
648        Ok(FirthDenseOperator {
649            x_dense: factor.x_dense.clone(),
650            x_dense_t: factor.x_dense_t.clone(),
651            q_basis: factor.q_basis.clone(),
652            x_reduced: factor.x_reduced.clone(),
653            observation_weight_sqrt: factor.observation_weight_sqrt.clone(),
654            k_reduced,
655            x_metric_reduced_inv_diag: factor.x_metric_reduced_inv_diag.clone(),
656            half_log_det,
657            h_diag,
658            w,
659            w1,
660            w2,
661            w3,
662            w4,
663            b_base,
664            p_b_base,
665        })
666    }
667
668    /// Compute ONLY the three PIRLS Firth diagnostics — `(hat_diag,
669    /// jeffreys_logdet, firth_score_shift)` — from a cached design factor at a
670    /// new `η`, skipping the per-iteration `B = diag(w') X` and `P·B` Hadamard
671    /// blocks that the inner PIRLS solve never consumes. Each output is the same
672    /// closed form the full operator's accessors return:
673    ///   hat_diag         = w ⊙ h_diag             (`pirls_hat_diag`),
674    ///   jeffreys_logdet  = half_log_det           (`jeffreys_logdet`),
675    ///   firth_score_shift= ½ (w'/w) ⊙ h_diag      (`pirls_firth_score_shift`),
676    /// so the result is bit-for-bit identical to building the full operator and
677    /// calling those accessors, at a fraction of the cost (#1575).
678    pub(crate) fn pirls_diagnostics_from_factor(
679        factor: &FirthDesignFactor,
680        link: &InverseLink,
681        eta: &Array1<f64>,
682    ) -> Result<(Array1<f64>, f64, Array1<f64>), EstimationError> {
683        let core = Self::firth_reduced_core(factor, link, eta)?;
684        let (w, w1, h_diag, half_log_det) =
685            (core.w, core.w1, core.h_diag, core.half_log_det);
686        // hat_diag = w ⊙ h_diag (matches `pirls_hat_diag`).
687        let hat_diag = &w * &h_diag;
688        // firth_score_shift_i = ½ (w'_i / w_i) h_diag_i for w_i > 0, else 0
689        // (matches `pirls_firth_score_shift`).
690        let mut score_shift = Array1::<f64>::zeros(w.len());
691        for i in 0..w.len() {
692            let wi = w[i];
693            if wi > 0.0 {
694                score_shift[i] = 0.5 * (w1[i] / wi) * h_diag[i];
695            }
696        }
697        Ok((hat_diag, half_log_det, score_shift))
698    }
699
700    pub(crate) fn build_with_observation_weights_impl(
701        link: &InverseLink,
702        x_dense: &Array2<f64>,
703        eta: &Array1<f64>,
704        observation_weights: Option<ndarray::ArrayView1<'_, f64>>,
705    ) -> Result<FirthDenseOperator, EstimationError> {
706        // Precompute dense Firth objects at current β̂ for:
707        //   Φ(β) = 0.5 log|Uᵀ W U|,
708        // where U is a canonical orthonormal basis of the identifiable
709        // subspace of A^{1/2} X.
710        //
711        // Identifiability note:
712        // For rank-deficient design matrices X, I = Xᵀ W X is singular for all β
713        // (assuming w_i > 0). The mathematically coherent Jeffreys/Firth term is
714        // therefore the identifiable-subspace form:
715        //   Φ(β) = 0.5 log|Uᵀ W U|
716        //        = 0.5 log|I_r(β)| - 0.5 log|S_r|,
717        // with
718        //   I_r = X_rᵀ W X_r,
719        //   S_r = X_rᵀ X_r,
720        //   U   = X_r S_r^{-1/2}.
721        // The beta differential is still
722        //   dΦ = 0.5 tr(I_+^† dI),
723        // because S_r is beta-independent for a fixed design.
724        //
725        // For binomial-logit with finite eta, 0 < w_i <= 1/4, so W is SPD and
726        // Null(X'WX)=Null(X). Therefore singular directions are structural
727        // (from X) and independent of beta, which is why fixed-Q reduced-space
728        // derivatives are exact in this regime.
729        //
730        // This implementation uses the fixed identifiable-basis route directly:
731        //   I_r = X_rᵀ W X_r,  S_r = X_rᵀ X_r,
732        //   I_+^† = Q I_r^{-1} Qᵀ,
733        //   Φ     = 0.5 (log|I_r| - log|S_r|).
734        //
735        // Why `eta` (not `mu`) enters here:
736        //   all logistic weight derivatives are functions of eta through
737        //   mu(eta)=sigmoid(eta), and exact derivative consistency is preserved by
738        //   generating (w, w', w'', w''', w'''') from one coherent eta source.
739        // Using eta avoids any mismatch from externally clamped/post-processed mu.
740        //
741        // We cache reduced operators so derivatives are exact but matrix-free in n:
742        //   K_r = I_r^{-1},  h = diag(X_r K_r X_rᵀ),  B = diag(w')X,
743        // along with logistic derivatives w', w'', w''', w'''' to evaluate:
744        //   H_φ      = ∇²_β Φ,
745        //   D H_φ[u],
746        //   D² H_φ[u,v]
747        // exactly via reduced-space products (no explicit high-order tensors).
748        //
749        // Fixed observation weights:
750        // When callers provide nonnegative case weights a_i that are constant in
751        // β, the Jeffreys information is
752        //   I(β) = Xᵀ diag(a_i w_i(η)) X.
753        // We fold those fixed a_i into the identifiable basis and reduced design
754        // via X̃ = diag(sqrt(a_i)) X, so all derivative formulas continue to use
755        // the same η-derivatives of the family Fisher weights w(η), w'(η), ....
756        //
757        // This routine is now a thin wrapper: it builds the β-independent design
758        // factor (Gram, identifiable basis Q, reduced design X_r, retained
759        // spectrum S_r — the O(n·p²) + O(p³) work) and then the β-dependent
760        // remainder at `eta`. The two helpers are split out so a single inner
761        // PIRLS solve can hoist the factor out of the per-Newton-iteration hot
762        // path (#1575) while every output here stays bit-for-bit identical.
763        //
764        // The eta-length check is kept here (before the factor build) to
765        // preserve the original error ordering for existing callers.
766        let n = x_dense.nrows();
767        if eta.len() != n {
768            crate::bail_invalid_estim!(
769                "Firth operator shape mismatch: nrows={}, eta_len={}",
770                n,
771                eta.len()
772            );
773        }
774        let factor =
775            Self::build_design_factor_with_observation_weights(x_dense, observation_weights)?;
776        Self::build_from_design_factor(&factor, link, eta)
777    }
778
779    #[inline]
780    pub(crate) fn jeffreys_logdet(&self) -> f64 {
781        self.half_log_det
782    }
783
784    /// Tangent-projected Jeffreys/Firth log-determinant `½ log|Zᵀ J Z|_+`,
785    /// where `J = Xᵀ A W(η) X` is the full p-space Fisher information at
786    /// the current `η` and `Z` is the `p × m` orthonormal basis of
787    /// `null(A_act)` produced by the active-constraint tangent projector.
788    ///
789    /// Identity: `Zᵀ J Z = (X̃ Z)ᵀ W (X̃ Z)` with `X̃ = A^{1/2} X` when
790    /// fixed observation weights are present, `X̃ = X` otherwise. The
791    /// projected log-pseudo-det uses the same positive-eigenvalue
792    /// threshold convention as the rest of the tangent-projected LAML
793    /// (`positive_eigenvalue_threshold` / `exact_pseudo_logdet`) so the
794    /// kernel that defines "active subspace" is consistent across the
795    /// objective, its gradient, and the Firth contribution.
796    pub(crate) fn jeffreys_logdet_projected(&self, z: ndarray::ArrayView2<'_, f64>) -> f64 {
797        use gam_linalg::faer_ndarray::{fast_ab, fast_xt_diag_x};
798        let p = self.x_dense.ncols();
799        assert_eq!(
800            z.nrows(),
801            p,
802            "jeffreys_logdet_projected: Z must have {} rows (β-space dim), got {}",
803            p,
804            z.nrows()
805        );
806        let m = z.ncols();
807        if m == 0 {
808            return 0.0;
809        }
810        // X·Z, then optional sqrt(A) row-scale → X̃·Z.
811        let z_owned = z.to_owned();
812        let xz = fast_ab(&self.x_dense, &z_owned);
813        let xtz = if let Some(scale) = self.observation_weight_sqrt.as_ref() {
814            RemlState::row_scale(&xz, scale)
815        } else {
816            xz
817        };
818        // J_T = (X̃ Z)ᵀ W (X̃ Z), symmetric m × m PSD.
819        let mut j_t = fast_xt_diag_x(&xtz, &self.w);
820        symmetrize_in_place(&mut j_t);
821        let (evals, _) = match j_t.eigh(Side::Lower) {
822            Ok(pair) => pair,
823            Err(_) => return f64::NEG_INFINITY,
824        };
825        let Some(evals_slice) = evals.as_slice() else {
826            return f64::NEG_INFINITY;
827        };
828        let threshold = super::reml_outer_engine::positive_eigenvalue_threshold(evals_slice);
829        0.5 * super::reml_outer_engine::exact_pseudo_logdet(evals_slice, threshold)
830    }
831
832    #[inline]
833    pub(crate) fn jeffreys_beta_gradient(&self) -> Array1<f64> {
834        // For I(β) = Xᵀ A W(η) X with fixed observation weights A,
835        //   ∂/∂β_j [0.5 log|I|]
836        //   = 0.5 Σ_i h_i w_i'(η_i) x_{ij},
837        // where h_i = [A^{1/2} X I^{-1} Xᵀ A^{1/2}]_{ii}.
838        0.5 * gam_linalg::faer_ndarray::fast_av(&self.x_dense_t, &(&self.w1 * &self.h_diag))
839    }
840
841    #[inline]
842    pub fn jeffreys_logdet_and_beta_gradient(&self) -> (f64, Array1<f64>) {
843        (self.jeffreys_logdet(), self.jeffreys_beta_gradient())
844    }
845
846    #[inline]
847    pub(crate) fn reduce_explicit_design(&self, x: &Array2<f64>) -> Array2<f64> {
848        let mut reduced = fast_ab(x, &self.q_basis);
849        if let Some(scale) = self.observation_weight_sqrt.as_ref() {
850            reduced = RemlState::row_scale(&reduced, scale);
851        }
852        reduced
853    }
854
855    pub(crate) fn direction_from_deta(&self, deta: Array1<f64>) -> FirthDirection {
856        // Directional building blocks for u:
857        //   δη_u = X u
858        //   I_u  = Xᵀ diag(w' ⊙ δη_u) X
859        //   T_u  = K I_u K
860        //   N_u  = X T_u Xᵀ
861        //   Dh[u] = -diag(N_u)
862        //   P_u   = D(M⊙M)[u] = -2(M⊙N_u)
863        // and B_u = diag(w'' ⊙ δη_u) X.
864        //
865        // In this implementation, active-subspace ambiguity is removed by fixed Q:
866        // K is represented by K_r = I_r^{-1} in reduced coordinates, so
867        //   A_u = K_r G_u K_r
868        // is exact for logit with finite eta (w_i > 0) and fixed rank(X).
869        // s_u is the diagonal weight for D I[u]:
870        //   D I[u] = Xᵀ diag(s_u) X,  s_u = w' ⊙ (X u).
871        let s_u = &self.w1 * &deta;
872        // G_u = X_rᵀ diag(s_u) X_r,  A_u = K_r G_u K_r.
873        // These are reduced-space forms of
874        //   I_u and T_u = K I_u K
875        // from the full-space derivation.
876        let g_u_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_u);
877        let k_g_u = self.k_reduced.dot(&g_u_reduced);
878        let a_u_reduced = k_g_u.dot(&self.k_reduced);
879        // Dh[u] = -diag(N_u),  N_u = X T_u Xᵀ, represented here as
880        //   N_u = Z A_u Zᵀ in weighted reduced coordinates.
881        let dh = -RemlState::reduced_diag_gram(&self.x_reduced, &a_u_reduced);
882        let b_uvec = &self.w2 * &deta;
883        FirthDirection {
884            deta,
885            g_u_reduced,
886            a_u_reduced,
887            dh,
888            b_uvec,
889        }
890    }
891
892    #[inline]
893    pub(crate) fn left_scaled_xt(&self, scale: &Array1<f64>, mat: &Array2<f64>) -> Array2<f64> {
894        fast_ab(&self.x_dense_t, &(mat * &scale.view().insert_axis(Axis(1))))
895    }
896
897    #[inline]
898    pub(crate) fn apply_p_u_to_matrix(
899        &self,
900        a_u_reduced: &Array2<f64>,
901        mat: &Array2<f64>,
902    ) -> Array2<f64> {
903        let mut out = RemlState::apply_hadamard_gram_to_matrix(
904            &self.x_reduced,
905            &self.k_reduced,
906            a_u_reduced,
907            mat,
908        );
909        out.mapv_inplace(|v| -2.0 * v);
910        out
911    }
912
913    pub(crate) fn hphi_direction_apply(
914        &self,
915        dir: &FirthDirection,
916        rhs: &Array2<f64>,
917    ) -> Array2<f64> {
918        let p = self.x_dense.ncols();
919        if rhs.nrows() != p {
920            return Array2::<f64>::zeros((p, rhs.ncols()));
921        }
922        if rhs.ncols() == 0 || p == 0 {
923            return Array2::<f64>::zeros((p, rhs.ncols()));
924        }
925        // Matrix-free apply of D(Hphi)[u] to a block V:
926        //   D(Hphi)[u] V
927        // = 0.5[ Xᵀ(c_u ⊙ (X V))
928        //       - B_uᵀ P (B V) - Bᵀ P (B_u V) - Bᵀ P_u (B V) ].
929        // This avoids dense p×p materialization and is used by sparse exact
930        // trace contractions through tr(H^{-1} ·).
931        let etav = fast_ab(&self.x_dense, rhs);
932        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
933        let m_qv = RemlState::apply_hadamard_gram_to_matrix(
934            &self.x_reduced,
935            &self.k_reduced,
936            &self.k_reduced,
937            &qv,
938        );
939        let buvec = &dir.b_uvec;
940        let m_buv = RemlState::apply_hadamard_gram_to_matrix(
941            &self.x_reduced,
942            &self.k_reduced,
943            &self.k_reduced,
944            &(&etav * &buvec.view().insert_axis(Axis(1))),
945        );
946        let p_u_qv = self.apply_p_u_to_matrix(&dir.a_u_reduced, &qv);
947        let c_u = &(&self.w3 * &dir.deta) * &self.h_diag + &(&self.w2 * &dir.dh);
948        let diag_term = self
949            .x_dense_t
950            .dot(&(&etav * &c_u.view().insert_axis(Axis(1))));
951        let term1 = self.left_scaled_xt(buvec, &m_qv);
952        let term2 = self.left_scaled_xt(&self.w1, &m_buv);
953        let term3 = self.left_scaled_xt(&self.w1, &p_u_qv);
954        0.5 * (diag_term - (term1 + term2 + term3))
955    }
956
957    pub(crate) fn hphi_direction(&self, dir: &FirthDirection) -> Array2<f64> {
958        let p = self.x_dense.ncols();
959        let eye = Array2::<f64>::eye(p);
960        let mut out = self.hphi_direction_apply(dir, &eye);
961        // Exact first directional derivative of H_φ:
962        //   D H_φ[u]
963        //   = 0.5 [ Xᵀ diag(c_u) X
964        //           - (B_uᵀ P B + Bᵀ P B_u + Bᵀ P_u B) ],
965        // where
966        //   c_u = w''' ⊙ δη_u ⊙ h + w'' ⊙ Dh[u],
967        //   B   = diag(w') X.
968        //
969        // Matrix-free contraction map used below:
970        //   Bᵀ P B      via apply_hadamard_gram_to_matrix(Z, K_r, K_r, B)
971        //   Bᵀ P_u B    via apply_hadamard_gram_to_matrix(Z, K_r, A_u, B), then *(-2)
972        // where P = M⊙M and P_u = -2(M⊙N_u), but M/N_u are never formed explicitly.
973        symmetrize_in_place(&mut out);
974        out
975    }
976
    /// Matrix-free apply of the exact mixed second directional derivative
    /// `D²H_φ[u,v]` to a block `rhs` (`p × c`).
    ///
    /// `D²H_φ[u,v] = ½ [ Xᵀ diag(c_uv) X − D²J₂[u,v] ]` with `J₂ = Bᵀ P B`,
    /// `B = diag(w') X` and `P = M ⊙ M`. `D²J₂[u,v]` is expanded into nine
    /// contractions, each evaluated through reduced Hadamard-Gram applies.
    ///
    /// Returns a `p × c` zero block when `rhs` has the wrong row count, has no
    /// columns, or `p == 0`.
    ///
    /// `hphisecond_direction_apply_eye_cached` is documented as bit-identical
    /// to this function with an identity `rhs`. Keep the two term orders in
    /// sync.
    pub(crate) fn hphisecond_direction_apply(
        &self,
        u: &FirthDirection,
        v: &FirthDirection,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        // Exact mixed second directional derivative:
        //   D² H_φ[u,v] = 0.5 [ Xᵀ diag(c_uv) X - D²J₂[u,v] ], J₂ = Bᵀ P B.
        // Implemented with matrix identities for N_{u,v}, P_{u,v}, and the
        // nine-term expansion of D²J₂[u,v].
        //
        // Because we parameterize Phi through fixed-rank I_r = X_rᵀ W X_r (SPD for
        // finite-logit eta), this mixed derivative is evaluated on a smooth
        // manifold without dynamic active-set switching in the Firth block.
        //
        // The nine contraction terms below are the explicit D²J₂[u,v] expansion,
        // each computed through reduced Hadamard-Gram operators.
        let deta_uv = &u.deta * &v.deta;
        // Mixed reduced Gram:
        //   G_uv = X_rᵀ diag(w'' ⊙ (Xu) ⊙ (Xv)) X_r.
        let s_uv = &self.w2 * &deta_uv;
        let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
        let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
        let k_gv = self.k_reduced.dot(&v.g_u_reduced);
        let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
        // Reduced form of:
        //   T_{u,v} = K I_{u,v} K - K Iv K I_u K - K I_u K Iv K.
        let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
            - k_gv.dot(&k_g_u).dot(&self.k_reduced)
            - k_g_u.dot(&k_gv).dot(&self.k_reduced);
        // D²h[u,v] = -diag(X_r A_uv X_rᵀ).
        let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
        // Implements mixed diagonal coefficient:
        //   c_uv = w'''' ⊙ (Xu) ⊙ (Xv) ⊙ h
        //          + w''' ⊙ ((Xu) ⊙ Dh[v] + (Xv) ⊙ Dh[u])
        //          + w'' ⊙ D²h[u,v].
        let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
            + &(&self.w3 * &(&u.deta * &v.dh))
            + &(&self.w3 * &(&v.deta * &u.dh))
            + &(&self.w2 * &d2h);

        let eta_rhs = fast_ab(&self.x_dense, rhs);
        // Xᵀ diag(c_uv) X rhs, without forming diag(c_uv) X.
        let diag_term = fast_ab(
            &self.x_dense_t,
            &(&eta_rhs * &c_uv.view().insert_axis(Axis(1))),
        );

        // Row weights of B_uv = diag(w''' ⊙ (Xu) ⊙ (Xv)) X, and the dense B_uv itself.
        let b_uvvec = &self.w3 * &deta_uv;
        let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
        // B rhs (rows of X rhs scaled by w').
        let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));

        // Linearity in the rhs argument lets us precompute the expensive
        // Hadamard-Gram operator on the full base blocks B, B_u, Bv, B_uv once,
        // then post-multiply by rhs. This preserves the exact operator while
        // avoiding repeated O(n r^2 c) work for every rhs block.
        // NOTE(review): `p_b_base` is presumably the cached `P B`; confirm at construction.
        let p_b_rhs = fast_ab(&self.p_b_base, rhs);
        let p_bu_rhs = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
        );
        let p_bv_rhs = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
        );
        let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &self.k_reduced,
            &b_uv_base,
        );
        let p_buv_rhs = fast_ab(&p_buv_base, rhs);

        // First-order P-derivative blocks: P_v B, P_v B_u, P_u B, P_u B_v (applied to rhs).
        let pv_b_rhs = self.apply_p_u_to_matrix(&v.a_u_reduced, &qv);
        let pv_bu_rhs = self.apply_p_u_to_matrix(
            &v.a_u_reduced,
            &(&eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
        );
        let p_u_b_rhs = self.apply_p_u_to_matrix(&u.a_u_reduced, &qv);
        let p_u_bv_rhs = self.apply_p_u_to_matrix(
            &u.a_u_reduced,
            &(&eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
        );

        // Mixed second derivative of P = M ⊙ M:
        //   P_uv = 2 (N_u ⊙ N_v) - 2 (M ⊙ (X_r A_uv X_rᵀ)), applied to B.
        let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &u.a_u_reduced,
            &v.a_u_reduced,
            &self.b_base,
        );
        let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
            &self.x_reduced,
            &self.k_reduced,
            &a_uv_reduced,
            &self.b_base,
        );
        let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
        let p_uv_rhs = fast_ab(&p_uv_base, rhs);

        // Nine-term expansion of D²J₂[u,v] with J₂ = Bᵀ P B.
        let d2_terms = [
            // B_uvᵀ P B
            self.left_scaled_xt(&b_uvvec, &p_b_rhs),
            // Bᵀ P B_uv
            self.left_scaled_xt(&self.w1, &p_buv_rhs),
            // B_uᵀ P B_v
            self.left_scaled_xt(&u.b_uvec, &p_bv_rhs),
            // B_vᵀ P B_u
            self.left_scaled_xt(&v.b_uvec, &p_bu_rhs),
            // B_uᵀ P_v B
            self.left_scaled_xt(&u.b_uvec, &pv_b_rhs),
            // Bᵀ P_v B_u
            self.left_scaled_xt(&self.w1, &pv_bu_rhs),
            // B_vᵀ P_u B
            self.left_scaled_xt(&v.b_uvec, &p_u_b_rhs),
            // Bᵀ P_u B_v
            self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
            // Bᵀ P_uv B
            self.left_scaled_xt(&self.w1, &p_uv_rhs),
        ];
        // Summed in a fixed order so the cached variant can match bit-for-bit.
        let mut d2_j2 = Array2::<f64>::zeros((p, rhs.ncols()));
        for term in d2_terms {
            d2_j2 += &term;
        }

        0.5 * (diag_term - d2_j2)
    }
1104
    /// Precompute, for a FIXED identity rhs, every sub-block of the mixed second
    /// directional derivative `D²H_φ[u,v]` that depends on a SINGLE direction
    /// index (or on nothing but the operator). The exact-Hessian TK outer loop
    /// (`tk_hessian_rho_canonical_logit`) evaluates `hphisecond_direction_apply`
    /// for every one of the `k(k+1)/2` penalty pairs against the same `eye` rhs.
    /// The four heavy single-index reduced Hadamard-Gram applies inside it
    /// (`p_bu_rhs`/`p_bv_rhs` and `p_u_b_rhs`/`pv_b_rhs`) therefore have only `k`
    /// distinct values but were rebuilt `O(k²)` times. Caching them once per
    /// index turns that into `O(k)` of those O(n·r²·p) applies, and limits the
    /// per-pair work to the genuinely mixed (`u`,`v`) blocks. This is exact:
    /// each cached block is a pure function of `(operator, direction[i])` for
    /// the fixed `eye` rhs, so the contraction it feeds is bit-identical to
    /// `hphisecond_direction_apply(.., &eye)` (#1575).
    ///
    /// The returned cache is indexed by position in `dirs` (`p_bx[i]`,
    /// `pu_qv[i]`). It is only valid when passed, together with this same
    /// operator and the same `dirs` slice, to
    /// [`Self::hphisecond_direction_apply_eye_cached`].
    pub(crate) fn tk_second_direction_eye_cache(
        &self,
        dirs: &[FirthDirection],
    ) -> FirthSecondDirEyeCache {
        let p = self.x_dense.ncols();
        let eye = Array2::<f64>::eye(p);
        // eta_rhs = X·I and qv = w' ⊙ eta_rhs are rhs-only (index-independent).
        let eta_rhs = fast_ab(&self.x_dense, &eye);
        let qv = &eta_rhs * &self.w1.view().insert_axis(Axis(1));
        // p_b_rhs = (Bᵀ P B-base)·I is rhs-only; precompute it once.
        let p_b_rhs = fast_ab(&self.p_b_base, &eye);
        // Each direction's two single-index blocks are independent O(n·r²·p)
        // reduced Hadamard-Gram applies. Fan them across Rayon with the nested-BLAS
        // guard (inner faer GEMMs pinned to `Par::Seq`, no oversubscription) when
        // there are several directions AND more than one thread; with a single
        // direction (k=1) run serially so the inner GEMMs keep the global faer
        // pool instead of being pinned to `Par::Seq`. The result is collected in
        // direction order either way, so the cached blocks are identical to the
        // serial build — bit-for-bit at fixture scale, where the inner GEMMs are
        // already `Par::Seq` (#1575).
        let compute_blocks = |d: &FirthDirection| -> (Array2<f64>, Array2<f64>) {
            // p_b{u,v}_rhs: depends only on this direction's b_uvec.
            let p_b = RemlState::apply_hadamard_gram_to_matrix(
                &self.x_reduced,
                &self.k_reduced,
                &self.k_reduced,
                &(&eta_rhs * &d.b_uvec.view().insert_axis(Axis(1))),
            );
            // p_u_b_rhs / pv_b_rhs: depends only on a_u_reduced.
            let pu = self.apply_p_u_to_matrix(&d.a_u_reduced, &qv);
            (p_b, pu)
        };
        // `par_iter().map(..).unzip()` preserves input order, so p_bx[i] / pu_qv[i]
        // always correspond to dirs[i].
        let (p_bx, pu_qv): (Vec<Array2<f64>>, Vec<Array2<f64>>) =
            if dirs.len() > 1 && rayon::current_num_threads() > 1 {
                use rayon::prelude::*;
                dirs.par_iter()
                    .map(|d| gam_problem::with_nested_parallel(|| compute_blocks(d)))
                    .unzip()
            } else {
                dirs.iter().map(compute_blocks).unzip()
            };
        FirthSecondDirEyeCache {
            eye,
            eta_rhs,
            p_b_rhs,
            p_bx,
            pu_qv,
        }
    }
1167
1168    /// Exact mixed second directional derivative `D²H_φ[u,v]` against the fixed
1169    /// `eye` rhs, reusing the single-index sub-blocks precomputed once by
1170    /// [`Self::tk_second_direction_eye_cache`]. Bit-identical to
1171    /// `hphisecond_direction_apply(&dirs[i], &dirs[j], &Array2::eye(p))`; only
1172    /// the redundant per-pair recomputation of the single-index blocks is
1173    /// removed (#1575).
1174    pub(crate) fn hphisecond_direction_apply_eye_cached(
1175        &self,
1176        cache: &FirthSecondDirEyeCache,
1177        dirs: &[FirthDirection],
1178        i: usize,
1179        j: usize,
1180    ) -> Array2<f64> {
1181        let u = &dirs[i];
1182        let v = &dirs[j];
1183        let p = self.x_dense.ncols();
1184        let cols = cache.eta_rhs.ncols();
1185        if p == 0 || cols == 0 {
1186            return Array2::<f64>::zeros((p, cols));
1187        }
1188        let deta_uv = &u.deta * &v.deta;
1189        let s_uv = &self.w2 * &deta_uv;
1190        let g_uv_reduced = RemlState::reducedweighted_gram(&self.x_reduced, &s_uv);
1191        let k_g_uv = self.k_reduced.dot(&g_uv_reduced);
1192        let k_gv = self.k_reduced.dot(&v.g_u_reduced);
1193        let k_g_u = self.k_reduced.dot(&u.g_u_reduced);
1194        let a_uv_reduced = k_g_uv.dot(&self.k_reduced)
1195            - k_gv.dot(&k_g_u).dot(&self.k_reduced)
1196            - k_g_u.dot(&k_gv).dot(&self.k_reduced);
1197        let d2h = -RemlState::reduced_diag_gram(&self.x_reduced, &a_uv_reduced);
1198        let c_uv = &(&(&self.w4 * &deta_uv) * &self.h_diag)
1199            + &(&self.w3 * &(&u.deta * &v.dh))
1200            + &(&self.w3 * &(&v.deta * &u.dh))
1201            + &(&self.w2 * &d2h);
1202
1203        let eta_rhs = &cache.eta_rhs;
1204        let diag_term = fast_ab(
1205            &self.x_dense_t,
1206            &(eta_rhs * &c_uv.view().insert_axis(Axis(1))),
1207        );
1208
1209        let b_uvvec = &self.w3 * &deta_uv;
1210        let b_uv_base = &self.x_dense * &b_uvvec.view().insert_axis(Axis(1));
1211
1212        // Single-index blocks reused from the cache (the O(k²)→O(k) win).
1213        let p_b_rhs = &cache.p_b_rhs;
1214        let p_bu_rhs = &cache.p_bx[i];
1215        let p_bv_rhs = &cache.p_bx[j];
1216        let p_u_b_rhs = &cache.pu_qv[i];
1217        let pv_b_rhs = &cache.pu_qv[j];
1218
1219        // Genuinely mixed (u,v) blocks — must be rebuilt per pair.
1220        let p_buv_base = RemlState::apply_hadamard_gram_to_matrix(
1221            &self.x_reduced,
1222            &self.k_reduced,
1223            &self.k_reduced,
1224            &b_uv_base,
1225        );
1226        let p_buv_rhs = fast_ab(&p_buv_base, &cache.eye);
1227
1228        let pv_bu_rhs = self.apply_p_u_to_matrix(
1229            &v.a_u_reduced,
1230            &(eta_rhs * &u.b_uvec.view().insert_axis(Axis(1))),
1231        );
1232        let p_u_bv_rhs = self.apply_p_u_to_matrix(
1233            &u.a_u_reduced,
1234            &(eta_rhs * &v.b_uvec.view().insert_axis(Axis(1))),
1235        );
1236
1237        let p_nu_nv_base = RemlState::apply_hadamard_gram_to_matrix(
1238            &self.x_reduced,
1239            &u.a_u_reduced,
1240            &v.a_u_reduced,
1241            &self.b_base,
1242        );
1243        let p_hw_nuv_base = RemlState::apply_hadamard_gram_to_matrix(
1244            &self.x_reduced,
1245            &self.k_reduced,
1246            &a_uv_reduced,
1247            &self.b_base,
1248        );
1249        let p_uv_base = 2.0 * p_nu_nv_base - 2.0 * p_hw_nuv_base;
1250        let p_uv_rhs = fast_ab(&p_uv_base, &cache.eye);
1251
1252        let d2_terms = [
1253            self.left_scaled_xt(&b_uvvec, p_b_rhs),
1254            self.left_scaled_xt(&self.w1, &p_buv_rhs),
1255            self.left_scaled_xt(&u.b_uvec, p_bv_rhs),
1256            self.left_scaled_xt(&v.b_uvec, p_bu_rhs),
1257            self.left_scaled_xt(&u.b_uvec, pv_b_rhs),
1258            self.left_scaled_xt(&self.w1, &pv_bu_rhs),
1259            self.left_scaled_xt(&v.b_uvec, p_u_b_rhs),
1260            self.left_scaled_xt(&self.w1, &p_u_bv_rhs),
1261            self.left_scaled_xt(&self.w1, &p_uv_rhs),
1262        ];
1263        let mut d2_j2 = Array2::<f64>::zeros((p, cols));
1264        for term in d2_terms {
1265            d2_j2 += &term;
1266        }
1267
1268        0.5 * (diag_term - d2_j2)
1269    }
1270
1271    pub(super) fn rowwise_dot(a: &Array2<f64>, b: &Array2<f64>) -> Array1<f64> {
1272        assert_eq!(a.nrows(), b.nrows());
1273        assert_eq!(a.ncols(), b.ncols());
1274        let mut out = Array1::<f64>::zeros(a.nrows());
1275        for i in 0..a.nrows() {
1276            let mut acc = 0.0_f64;
1277            for j in 0..a.ncols() {
1278                acc += a[[i, j]] * b[[i, j]];
1279            }
1280            out[i] = acc;
1281        }
1282        out
1283    }
1284
1285    pub(super) fn rowwise_bilinear(
1286        a: &Array2<f64>,
1287        m: &Array2<f64>,
1288        b: &Array2<f64>,
1289    ) -> Array1<f64> {
1290        // Returns vector with entries a_iᵀ M b_i for each row i.
1291        assert_eq!(a.nrows(), b.nrows());
1292        assert_eq!(a.ncols(), m.nrows());
1293        assert_eq!(b.ncols(), m.ncols());
1294        let am = fast_ab(a, m);
1295        Self::rowwise_dot(&am, b)
1296    }
1297
1298    pub(crate) fn dot_i_and_h_from_reduced(
1299        &self,
1300        x_tau_reduced: &Array2<f64>,
1301        deta: &Array1<f64>,
1302    ) -> (Array2<f64>, Array1<f64>) {
1303        // Reduced Fisher directional derivative under fixed identifiable basis:
1304        //   I_r = X_r' W X_r
1305        //   I_r,tau = X_{r,tau}' W X_r + X_r' W X_{r,tau} + X_r' W_tau X_r
1306        // with W_tau = diag(w' ⊙ eta_tau).
1307        //
1308        // Leverage derivative used by Firth score partial:
1309        //   h_i = x_{r,i}' K_r x_{r,i}, K_r = I_r^{-1}
1310        //   h_tau = 2*diag(X_{r,tau} K_r X_r') + diag(X_r K_{r,tau} X_r')
1311        //   K_{r,tau} = -K_r I_{r,tau} K_r.
1312        //
1313        // This is exactly the fixed-beta directional derivative required by
1314        //   (gphi)_tau and Phi_tau in the Jeffreys/Firth design-moving path:
1315        //   I_{r,tau}|beta = X_{r,tau}' W X_r + X_r' W X_{r,tau}
1316        //                    + X_r' diag(w' ⊙ eta_tau|beta) X_r,
1317        //   eta_tau|beta = X_tau beta.
1318        //
1319        // We return:
1320        //   dot_i  = I_{r,tau}|beta,
1321        //   dot_h  = h_tau|beta.
1322        let dw = &self.w1 * deta;
1323        let dot_i = RemlState::weighted_cross(x_tau_reduced, &self.x_reduced, &self.w)
1324            + RemlState::weighted_cross(&self.x_reduced, x_tau_reduced, &self.w)
1325            + gam_linalg::faer_ndarray::fast_xt_diag_x(&self.x_reduced, &dw);
1326
1327        let dot_k = -self.k_reduced.dot(&dot_i).dot(&self.k_reduced);
1328        let x_tauk = fast_ab(x_tau_reduced, &self.k_reduced);
1329        let dot_h_explicit = 2.0 * Self::rowwise_dot(&x_tauk, &self.x_reduced);
1330        let dot_h_implicit = Self::rowwise_dot(&fast_ab(&self.x_reduced, &dot_k), &self.x_reduced);
1331        let dot_h = dot_h_explicit + dot_h_implicit;
1332        (dot_i, dot_h)
1333    }
1334
1335    pub(crate) fn exact_tau_kernel(
1336        &self,
1337        x_tau: &Array2<f64>,
1338        beta: &Array1<f64>,
1339        include_hphi_tau_kernel: bool,
1340    ) -> FirthTauExactKernel {
1341        // Shared exact tau-partial bundle used by both dense and sparse paths:
1342        //   (gphi)_tau | beta-fixed,
1343        //   Phi_tau | beta-fixed,
1344        // and optional H_{phi,tau}|beta kernel for later matrix-free applies.
1345        //
1346        // Closed forms (reduced Fisher, fixed active subspace):
1347        //   Phi = 0.5 log|I_r| - 0.5 log|S_r|,
1348        //   I_r = X_r' W X_r, K_r = I_r^{-1},
1349        //   S_r = X_r' X_r,   diag(G_r) = diag(S_r^{-1}),
1350        //   Phi_tau|beta = 0.5 tr(K_r I_{r,tau}) - 0.5 tr(G_r S_{r,tau}).
1351        // In the canonical reduced basis used here, G_r is diagonal.
1352        //
1353        //   (gphi)_tau = Phi_beta,tau
1354        //               = 0.5 X_tau' (w1 .* h)
1355        //                 + 0.5 X'((w2 .* eta_tau) .* h + w1 .* h_tau),
1356        //   where
1357        //     h_i = x_{r,i}' K_r x_{r,i},
1358        //     h_tau = 2*diag(X_{r,tau} K_r X_r') + diag(X_r K_{r,tau} X_r'),
1359        //     K_{r,tau} = -K_r I_{r,tau} K_r.
1360        //
1361        // Phi_beta,tau is unchanged by the -0.5 log|S_r| term because S_r does
1362        // not depend on beta. Only Phi_tau gets the explicit basis-drift
1363        // subtraction.
1364        let deta_partial = gam_linalg::faer_ndarray::fast_av(x_tau, beta);
1365        let x_tau_reduced = self.reduce_explicit_design(x_tau);
1366        let (dot_i_partial, dot_h_partial) =
1367            self.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
1368        let dot_s_partial =
1369            fast_atb(&x_tau_reduced, &self.x_reduced) + fast_atb(&self.x_reduced, &x_tau_reduced);
1370
1371        let first = 0.5 * gam_linalg::faer_ndarray::fast_atv(x_tau, &(&self.w1 * &self.h_diag));
1372        let secondvec =
1373            &(&(&self.w2 * &deta_partial) * &self.h_diag) + &(&self.w1 * &dot_h_partial);
1374        let second = 0.5 * gam_linalg::faer_ndarray::fast_atv(&self.x_dense, &secondvec);
1375        let gphi_tau = first + second;
1376        let phi_tau_partial = 0.5 * RemlState::trace_product(&self.k_reduced, &dot_i_partial)
1377            - 0.5 * Self::trace_diag_product(&self.x_metric_reduced_inv_diag, &dot_s_partial);
1378
1379        let tau_kernel = if include_hphi_tau_kernel {
1380            Some(self.hphi_tau_partial_prepare_from_partials(
1381                x_tau_reduced,
1382                &deta_partial,
1383                dot_h_partial,
1384                dot_i_partial,
1385            ))
1386        } else {
1387            None
1388        };
1389        FirthTauExactKernel {
1390            gphi_tau,
1391            phi_tau_partial,
1392            tau_kernel,
1393        }
1394    }
1395
1396    pub(crate) fn hphi_tau_partial_prepare_from_partials(
1397        &self,
1398        x_tau_reduced: Array2<f64>,
1399        deta_partial: &Array1<f64>,
1400        dot_h_partial: Array1<f64>,
1401        dot_i_partial: Array2<f64>,
1402    ) -> FirthTauPartialKernel {
1403        let dotw1 = &self.w2 * deta_partial;
1404        let dotw2 = &self.w3 * deta_partial;
1405        let dot_k = -self.k_reduced.dot(&dot_i_partial).dot(&self.k_reduced);
1406        FirthTauPartialKernel {
1407            deta_partial: deta_partial.clone(),
1408            dotw1,
1409            dotw2,
1410            dot_h_partial,
1411            x_tau_reduced,
1412            dot_i_partial,
1413            dot_k_reduced: dot_k,
1414        }
1415    }
1416
1417    pub(crate) fn d_beta_hphi_tau_partial_dense(
1418        &self,
1419        x_tau: &Array2<f64>,
1420        beta: &Array1<f64>,
1421        beta_direction: &Array1<f64>,
1422    ) -> Option<Array2<f64>> {
1423        if x_tau.nrows() != self.x_dense.nrows() || x_tau.ncols() != beta.len() {
1424            return None;
1425        }
1426        if !x_tau.iter().any(|value| *value != 0.0) {
1427            return None;
1428        }
1429        let tau_bundle = self.exact_tau_kernel(x_tau, beta, true);
1430        let tau_kernel = tau_bundle.tau_kernel?;
1431        let firth_direction =
1432            self.direction_from_deta(gam_linalg::faer_ndarray::fast_av(&self.x_dense, beta_direction));
1433        let x_tau_v = gam_linalg::faer_ndarray::fast_av(x_tau, beta_direction);
1434        let kernel = self.d_beta_hphi_tau_partial_prepare_from_partials(
1435            &tau_kernel,
1436            &tau_kernel.deta_partial,
1437            &tau_kernel.dot_i_partial,
1438            &firth_direction,
1439            &x_tau_v,
1440        );
1441        let eye = Array2::<f64>::eye(beta_direction.len());
1442        Some(self.d_beta_hphi_tau_partial_apply(x_tau, &kernel, &eye))
1443    }
1444
1445    pub(crate) fn apply_pbar_to_matrix(&self, mat: &Array2<f64>) -> Array2<f64> {
1446        // Applies P̄ = (X_r K_r X_rᵀ)⊙(X_r K_r X_rᵀ) to each column of mat.
1447        RemlState::apply_hadamard_gram_to_matrix(
1448            &self.x_reduced,
1449            &self.k_reduced,
1450            &self.k_reduced,
1451            mat,
1452        )
1453    }
1454
1455    pub(crate) fn apply_mtau_to_matrix(
1456        &self,
1457        kernel: &FirthTauPartialKernel,
1458        mat: &Array2<f64>,
1459    ) -> Array2<f64> {
1460        // Exact apply of
1461        //   M_tau = d/dtau[(P⊙P)]|_{beta fixed} = 2(P⊙P_tau)
1462        // without building dense n×n objects.
1463        //
1464        // Decomposition:
1465        //   P = Z K Zᵀ, Z = X_r
1466        //   P_tau = Z_tau K Zᵀ + Z K Z_tauᵀ + Z dotK Zᵀ
1467        // and for each vector v:
1468        //   (P⊙(Z_tau K Zᵀ))v   : rowwise bilinear with K (Zᵀdiag(v)Z) K
1469        //   (P⊙(Z K Z_tauᵀ))v   : diag_Z( K (Zᵀdiag(v)Z_tau) K )
1470        //   (P⊙(Z dotK Zᵀ))v    : Hadamard-Gram apply with (K, dotK).
1471        if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
1472            return Array2::<f64>::zeros(mat.raw_dim());
1473        }
1474        let mut out = Array2::<f64>::zeros(mat.raw_dim());
1475        for col in 0..mat.ncols() {
1476            let v = mat.column(col).to_owned();
1477            let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
1478            let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
1479            let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, &kernel.x_tau_reduced);
1480
1481            let szt =
1482                RemlState::reduced_crossweighted_gram(&self.x_reduced, &kernel.x_tau_reduced, &v);
1483            let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
1484            let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
1485
1486            let t3 = RemlState::apply_hadamard_gram(
1487                &self.x_reduced,
1488                &self.k_reduced,
1489                &kernel.dot_k_reduced,
1490                &v,
1491            );
1492
1493            let y = 2.0 * (t1 + t2 + t3);
1494            out.column_mut(col).assign(&y);
1495        }
1496        out
1497    }
1498
1499    pub(crate) fn hphi_tau_partial_apply(
1500        &self,
1501        x_tau: &Array2<f64>,
1502        kernel: &FirthTauPartialKernel,
1503        rhs: &Array2<f64>,
1504    ) -> Array2<f64> {
1505        let p = self.x_dense.ncols();
1506        if rhs.nrows() != p {
1507            return Array2::<f64>::zeros((p, rhs.ncols()));
1508        }
1509        if rhs.ncols() == 0 || p == 0 {
1510            return Array2::<f64>::zeros((p, rhs.ncols()));
1511        }
1512        // Matrix-free block apply of Hphi,tau|beta:
1513        //   Hphi,tau|beta(V) = 0.5 [ X_tau' r(V) + X' r_tau(V) ].
1514        //
1515        // Tensor identity behind this apply:
1516        //   Hphi,tau|beta = Phi_beta,beta,tau
1517        // and for test vectors b1,b2 (matrix columns V are batched b2's):
1518        //   Phi_beta,beta,tau[b1,b2]
1519        //   = 0.5[
1520        //       tr(I^{-1} I_{b1,b2,tau})
1521        //       - tr(I^{-1} I_{b1,b2} I^{-1} I_tau)
1522        //       - tr(I^{-1} I_{b1,tau} I^{-1} I_{b2})
1523        //       - tr(I^{-1} I_{b2,tau} I^{-1} I_{b1})
1524        //       + 2 tr(I^{-1} I_{b1} I^{-1} I_{b2} I^{-1} I_tau)
1525        //     ].
1526        // This routine evaluates that form in reduced coordinates without forming
1527        // dense 3rd-order tensors explicitly.
1528        let etav = fast_ab(&self.x_dense, rhs);
1529        let etav_tau = fast_ab(x_tau, rhs);
1530        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
1531        let qv_tau = &etav * &kernel.dotw1.view().insert_axis(Axis(1))
1532            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
1533        let m_qv = self.apply_pbar_to_matrix(&qv);
1534        let m_qv_tau = self.apply_mtau_to_matrix(kernel, &qv) + self.apply_pbar_to_matrix(&qv_tau);
1535        let rv = &(&etav * &self.w2.view().insert_axis(Axis(1)))
1536            * &self.h_diag.view().insert_axis(Axis(1))
1537            - &(&m_qv * &self.w1.view().insert_axis(Axis(1)));
1538        let rv_tau = (&(&etav * &kernel.dotw2.view().insert_axis(Axis(1)))
1539            + &(&etav_tau * &self.w2.view().insert_axis(Axis(1))))
1540            * self.h_diag.view().insert_axis(Axis(1))
1541            + &(&etav * &self.w2.view().insert_axis(Axis(1)))
1542                * &kernel.dot_h_partial.view().insert_axis(Axis(1))
1543            - &(&m_qv * &kernel.dotw1.view().insert_axis(Axis(1))
1544                + &m_qv_tau * &self.w1.view().insert_axis(Axis(1)));
1545        0.5 * (fast_atb(x_tau, &rv) + fast_atb(&self.x_dense, &rv_tau))
1546    }
1547
1548    // ═════════════════════════════════════════════════════════════════════════
1549    //  Pair-term primitives for the Firth outer Hessian (Task #13a / #17)
1550    // ═════════════════════════════════════════════════════════════════════════
1551    //
1552    // The REML outer Hessian at a ψ=(ρ,τ) pair needs two Firth contributions
1553    // that are NOT covered by the existing single-τ primitives:
1554    //
1555    //   A.  the pure τ×τ second partial of H_φ at fixed β (pair drift inside
1556    //       the fixed-β second-derivative trace of B_i,j used by
1557    //       build_tau_tau_pair_callback).  This is the Firth analog of the
1558    //       penalty-logdet pair term in the outer-derivative cookbook.
1559    //
1560    //   B.  the β-derivative of (H_φ)_τ|_β in direction v.  This is the
1561    //       fixed-drift-derivative M_i[v] = D_β B_i[v] that
1562    //       compute_drift_deriv_traces uses through the fixed_drift_deriv
1563    //       callback in build_tau_hyper_coords.  It is currently always None
1564    //       in the Firth+Logit path, which is what makes the outer Hessian
1565    //       approximate for Firth-reweighted models.
1566    //
1567    // Both primitives operate in the reduced identifiable subspace (X_r, K_r,
1568    // S_r, Z=X_r) of the dense Firth operator and are matrix-free in n.  They
1569    // do NOT introduce any dense n×n or p×p×p object; every contraction is
1570    // routed through reduced-space Hadamard-Gram applies and rowwise
1571    // bilinear forms, in the same spirit as hphi_tau_partial_apply and
1572    // hphisecond_direction_apply.
1573    //
1574    // Both are exact in the smooth-regime operating point assumed by this
1575    // module: X SPD-full-rank on its identifiable subspace, Q held fixed
1576    // within one outer REML step (active-subspace drift enters only between
1577    // outer iterates), w_i(η) > 0 strictly positive, and β is at the P-IRLS
1578    // solution for the current ψ so β_τ is supplied by the unified evaluator
1579    // via the IFT solve.
1580    //
1581    // Symbol conventions shared with the existing operator code:
1582    //   X_r   := X Q            (reduced identifiable design)
1583    //   W     := diag(w(η))     (Fisher weights), w', w'', w''', w''''
1584    //   I_r   := X_rᵀ W X_r,  K_r := I_r^{-1},  S_r := X_rᵀ X_r
1585    //   M     := X_r K_r X_rᵀ,  P := M ⊙ M,  B := diag(w') X
1586    //   h     := diag(M)
1587    //   X_i   := ∂X/∂τ_i,  X_{r,i} := X_i Q,  η̇_i := X_i β,  etc.
1588    //   δη_v  := X v           (β-direction v),  δη_{τ,v} := X_τ v
1589    //
1590    // H_φ structural form (reduced form of ∇²_β Φ_F):
1591    //   H_φ  =  ½ [ Xᵀ diag(w'' ⊙ h) X  −  Bᵀ P B ]
1592    // where the first term arises from differentiating the Jeffreys gradient
1593    // ½ Xᵀ (w' ⊙ h) once more in β, and the second collects the IFT-mediated
1594    // β-derivative of h = diag(X_r K_r X_rᵀ) through K_r = I_r^{-1}.  This is
1595    // the same form the existing hphi_direction code implements in directional
1596    // form (cf. firth.rs hphi_direction / hphisecond_direction_apply, the
1597    // 9-term D²J₂ expansion with J₂ = BᵀPB).
1598    //
1599    // ─────────────────────────────────────────────────────────────────────────
1600    //  Primitive A — ∂²H_φ/∂τ_i ∂τ_j |_β
1601    // ─────────────────────────────────────────────────────────────────────────
1602    //
1603    // WHAT IT COMPUTES
1604    //   Given a pair of τ-drift designs (X_τ_i, X_τ_j) and optional second
1605    //   design derivative X_{τ_i τ_j}, evaluates
1606    //
1607    //     ∂²H_φ/∂τ_i ∂τ_j |_β  =  ½ [ ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j
1608    //                                − ∂²(Bᵀ P B)/∂τ_i ∂τ_j ],
1609    //   with Γ := diag(w'' ⊙ h).  Acts on a p×m rhs and returns a p×m block
1610    //   (exact same contract as hphi_tau_partial_apply, but for the *second*
1611    //   mixed τ-derivative at fixed β).
1612    //
1613    // WHY (REML callsite)
1614    //   In the outer Hessian entry Ḧ_{i,j} for τ×τ pair (i,j), the fixed-β
1615    //   second drift of B_i is exactly this primitive (with the Firth sign:
1616    //   B_i = −(H_φ)_τ_i|_β + other likelihood pieces, so ∂²B_i/∂τ_j|_β
1617    //   contributes −∂²H_φ/∂τ_i∂τ_j|_β to the outer Hessian trace).  The
1618    //   existing Firth pair callback at build_tau_tau_pair_callback currently
1619    //   carries zero for this Firth contribution; wiring this primitive into
1620    //   the TauTauPairHyperOperator is the remaining step to make the τ×τ
1621    //   outer Hessian exact in the Firth-reweighted Logit path.
1622    //
1623    // DERIVATION (full chain-rule expansion, at fixed β)
1624    //
1625    //   Building blocks at fixed β (single-τ):
1626    //     İ_i      := ∂I_r/∂τ_i |_β
1627    //                = X_{r,i}ᵀ W X_r + X_rᵀ W X_{r,i} + X_rᵀ Ẇ_i X_r,
1628    //       Ẇ_i    := diag(w' ⊙ η̇_i),   η̇_i := X_i β.
1629    //     K̇_i      := ∂K_r/∂τ_i = −K_r İ_i K_r.
1630    //     ḣ_i      := ∂h/∂τ_i |_β
1631    //                = 2·diag(X_{r,i} K_r X_rᵀ) + diag(X_r K̇_i X_rᵀ).
1632    //     Ṁ_i      := ∂M/∂τ_i |_β
1633    //                = X_{r,i} K_r X_rᵀ + X_r K̇_i X_rᵀ + X_r K_r X_{r,i}ᵀ.
1634    //     Ḃ_i      := ∂B/∂τ_i |_β
1635    //                = diag(w'' ⊙ η̇_i) X + diag(w') X_i.
1636    //     Ṗ_i      := ∂P/∂τ_i = 2 (M ⊙ Ṁ_i).
1637    //     Γ̇_i     := ∂Γ/∂τ_i |_β = diag(w''' ⊙ η̇_i ⊙ h + w'' ⊙ ḣ_i).
1638    //
1639    //   Second-order building blocks:
1640    //     η̈_{ij}  := X_{ij} β                   (0 for design linear in τ)
1641    //     Ẅ_{ij} := diag(w'' ⊙ η̇_i ⊙ η̇_j
1642    //                     + w' ⊙ η̈_{ij})
1643    //
1644    //     Ï_{ij}   := ∂²I_r/∂τ_i ∂τ_j |_β
1645    //                = X_{r,ij}ᵀ W X_r  +  X_rᵀ W X_{r,ij}
1646    //                 + X_{r,i}ᵀ W X_{r,j}  +  X_{r,j}ᵀ W X_{r,i}
1647    //                 + X_{r,i}ᵀ Ẇ_j X_r  +  X_rᵀ Ẇ_j X_{r,i}
1648    //                 + X_{r,j}ᵀ Ẇ_i X_r  +  X_rᵀ Ẇ_i X_{r,j}
1649    //                 + X_rᵀ Ẅ_{ij} X_r.
1650    //
1651    //     K̈_{ij}  := ∂²K_r/∂τ_i ∂τ_j
1652    //                = −K_r Ï_{ij} K_r
1653    //                  + K_r İ_i K_r İ_j K_r
1654    //                  + K_r İ_j K_r İ_i K_r.
1655    //
1656    //     M̈_{ij}  := X_{r,ij} K_r X_rᵀ + X_r K_r X_{r,ij}ᵀ
1657    //                 + X_{r,i} K̇_j X_rᵀ + X_r K̇_j X_{r,i}ᵀ
1658    //                 + X_{r,j} K̇_i X_rᵀ + X_r K̇_i X_{r,j}ᵀ
1659    //                 + X_{r,i} K_r X_{r,j}ᵀ + X_{r,j} K_r X_{r,i}ᵀ
1660    //                 + X_r K̈_{ij} X_rᵀ.
1661    //
1662    //     P̈_{ij}  := ∂²P/∂τ_i ∂τ_j
1663    //                = 2 (Ṁ_i ⊙ Ṁ_j) + 2 (Ṁ_j ⊙ Ṁ_i) + 2 (M ⊙ M̈_{ij})
1664    //                = 4 (Ṁ_i ⊙ Ṁ_j) + 2 (M ⊙ M̈_{ij}).
1665    //
1666    //     ḧ_{ij}  := ∂²h/∂τ_i ∂τ_j |_β
1667    //                = 2·diag(X_{r,ij} K_r X_rᵀ)
1668    //                 + diag(X_r K̈_{ij} X_rᵀ)
1669    //                 + 2·diag(X_{r,i} K̇_j X_rᵀ)
1670    //                 + 2·diag(X_{r,j} K̇_i X_rᵀ)
1671    //                 + 2·diag(X_{r,i} K_r X_{r,j}ᵀ).
1672    //
1673    //     B̈_{ij}  := ∂²B/∂τ_i ∂τ_j |_β
1674    //                = diag(w''' ⊙ η̇_i ⊙ η̇_j + w'' ⊙ η̈_{ij}) X
1675    //                 + diag(w'' ⊙ η̇_i) X_j
1676    //                 + diag(w'' ⊙ η̇_j) X_i
1677    //                 + diag(w') X_{ij}.
1678    //
1679    //     Γ̈_{ij} := ∂²Γ/∂τ_i ∂τ_j |_β
1680    //                = diag( w'''' ⊙ η̇_i ⊙ η̇_j ⊙ h
1681    //                       + w''' ⊙ η̈_{ij} ⊙ h
1682    //                       + w''' ⊙ η̇_i ⊙ ḣ_j
1683    //                       + w''' ⊙ η̇_j ⊙ ḣ_i
1684    //                       + w'' ⊙ ḧ_{ij} ).
1685    //
1686    //   Diagonal-term expansion (the Xᵀ Γ X branch):
1687    //
1688    //     ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j  =
1689    //         X_{ij}ᵀ Γ X  + Xᵀ Γ X_{ij}
1690    //       + X_iᵀ Γ X_j  + X_jᵀ Γ X_i
1691    //       + X_iᵀ Γ̇_j X  + Xᵀ Γ̇_j X_i
1692    //       + X_jᵀ Γ̇_i X  + Xᵀ Γ̇_i X_j
1693    //       + Xᵀ Γ̈_{ij} X.
1694    //
1695    //   9-term expansion for the BᵀPB branch (structurally identical to
1696    //   the existing β×β D²J₂[u,v] at firth.rs:~820-830 with (u,v)
1697    //   substituted by (τ_i, τ_j) and the appropriate Ḃ, B̈, Ṗ, P̈):
1698    //
1699    //     D²(BᵀPB)[τ_i,τ_j]  =
1700    //         B̈_{ij}ᵀ  P    B      +  Bᵀ       P    B̈_{ij}
1701    //       + Ḃ_iᵀ    P    Ḃ_j   +  Ḃ_jᵀ    P    Ḃ_i
1702    //       + Ḃ_iᵀ    Ṗ_j  B      +  Bᵀ       Ṗ_j  Ḃ_i
1703    //       + Ḃ_jᵀ    Ṗ_i  B      +  Bᵀ       Ṗ_i  Ḃ_j
1704    //       + Bᵀ       P̈_{ij} B.
1705    //
1706    //   Combining,
1707    //
1708    //     ∂²H_φ/∂τ_i ∂τ_j |_β  =  ½ [
1709    //         ∂²(Xᵀ Γ X)/∂τ_i ∂τ_j  −  D²(BᵀPB)[τ_i, τ_j]
1710    //     ].
1711    //
1712    // IMPLEMENTATION SKETCH (for 13b)
1713    //   • Build per-direction reduced quantities for τ_i and τ_j:
1714    //       (x_tau_reduced, η̇, İ, K̇, Ṁ operator pieces, ḣ, b_uvec = w''⊙η̇).
1715    //     The existing `dot_i_and_h_from_reduced` yields İ and ḣ already;
1716    //     the per-direction "A_u" analog is A_τ = K_r İ K_r, matching the
1717    //     FirthDirection form used by hphisecond_direction_apply.
1718    //   • Use apply_hadamard_gram_to_matrix with
1719    //       (A_left, A_right) ∈ { (K_r, K_r), (K_r, A_τ_i), (K_r, A_τ_j),
1720    //                             (A_τ_i, A_τ_j) }
1721    //     to realize P-products, Ṗ_τ-products, and the (Ṁ_i ⊙ Ṁ_j) piece of
1722    //     P̈_{ij} without forming any n×n dense intermediate.
1723    //   • The pure-second piece `X_r K̈_{ij} X_rᵀ` decomposes into three
1724    //     reduced triple products (K_r Ï_{ij} K_r, K_r İ_i K_r İ_j K_r, and
1725    //     its transpose).  All are size-r×r in reduced coordinates.
1726    //   • For design-linear-in-τ smooths, X_{ij}=0 and η̈_{ij}=0, which
1727    //     prunes many sub-terms; callers who have X_{τ_i τ_j} available
1728    //     should pass it so the primitive remains exact on curved designs.
1729    //
1730    // ─────────────────────────────────────────────────────────────────────────
1731    //  Primitive B — D_β((H_φ)_τ|_β)[v]
1732    // ─────────────────────────────────────────────────────────────────────────
1733    //
1734    // WHAT IT COMPUTES
1735    //   Given a single τ-drift design X_τ, the β-fixed Firth partial
1736    //   (H_φ)_τ|_β encoded by FirthTauPartialKernel, and a β-direction
1737    //   vector v (of length p), returns the β-derivative of (H_φ)_τ|_β
1738    //   applied to an rhs block (so output is p×m, matching the pair's
1739    //   fixed_drift_deriv callback signature DriftDerivResult).  In
1740    //   symbols:
1741    //
1742    //     D_β((H_φ)_τ|_β)[v]  =  ½ [ D_β{(∂(XᵀΓX)/∂τ)|_β}[v]
1743    //                                 −  D_β{(∂(BᵀPB)/∂τ)|_β}[v] ].
1744    //
1745    // WHY (REML callsite)
1746    //   In the exact outer Hessian assembly (compute_drift_deriv_traces in
1747    //   unified.rs), the Ḧ_{ij} entry picks up
1748    //     tr(G_ε · D_β B_i[v_j])  +  tr(G_ε · D_β B_j[v_i]).
1749    //   For τ coordinates in the Firth+Logit path, B_τ = (penalty / design
1750    //   pieces) − (H_φ)_τ|_β, so the Firth share of D_β B_τ[v] is
1751    //     − D_β((H_φ)_τ|_β)[v].
1752    //   Hooking this primitive up through a FixedDriftDerivFn (returning
1753    //   DriftDerivResult::Dense of this p×p β-v action) is exactly what
1754    //   lets build_tau_hyper_coords pass a non-None fixed_drift_deriv
1755    //   closure into the unified evaluator, closing the approximation gap
1756    //   that firth_pair_terms_unavailable currently tracks.
1757    //
1758    // DERIVATION (β-derivative of each τ-partial term in direction v)
1759    //
1760    //   β enters only through η=Xβ, so designs X, X_τ, Q, X_r are all
1761    //   β-independent; D_β acts on w(η) and its derivatives, on I_r, K_r,
1762    //   M, h, and on η̇_τ = X_τ β.
1763    //
1764    //   Primary β-derivative building blocks (matches FirthDirection with
1765    //   deta := δη_v = X v):
1766    //     I'_v  := D_β I_r[v] = X_rᵀ diag(w' ⊙ δη_v) X_r      (g_u_reduced)
1767    //     A_v   := D_β K_r[v] = −K_r I'_v K_r                  (a_u_reduced)
1768    //     dh_v  := D_β h[v]    = −diag(X_r K_r I'_v K_r X_rᵀ)
1769    //                          = diag(X_r A_v X_rᵀ)            (dh)
1770    //     (w')_v  := D_β w'[v]  = w''  ⊙ δη_v
1771    //     (w'')_v := D_β w''[v] = w''' ⊙ δη_v
1772    //     (w''')_v:= D_β w'''[v]= w''''⊙ δη_v
1773    //     δη_{τ,v} := D_β(η̇_τ)[v] = X_τ v
1774    //
1775    //   Mixed τ-β pieces:
1776    //     D_β(İ_τ)[v]
1777    //       = X_{r,τ}ᵀ diag(w'' ⊙ δη_v) X_r
1778    //        + X_rᵀ diag(w'' ⊙ δη_v) X_{r,τ}
1779    //        + X_rᵀ diag(w'' ⊙ η̇_τ ⊙ δη_v
1780    //                     + w' ⊙ δη_{τ,v}) X_r.
1781    //     D_β(K̇_τ)[v]
1782    //       = −( A_v İ_τ K_r  +  K_r D_β(İ_τ)[v] K_r
1783    //             +  K_r İ_τ A_v ).
1784    //     D_β(Ṁ_τ)[v]
1785    //       = X_{r,τ} A_v X_rᵀ
1786    //        + X_r D_β(K̇_τ)[v] X_rᵀ
1787    //        + X_r A_v X_{r,τ}ᵀ.
1788    //     D_β(ḣ_τ)[v]
1789    //       = 2·diag(X_{r,τ} A_v X_rᵀ)
1790    //        + diag(X_r D_β(K̇_τ)[v] X_rᵀ).
1791    //
1792    //   Diagonal-term β-derivative ( (X_τᵀΓX + XᵀΓX_τ + XᵀΓ̇_τ X) branch ):
1793    //     D_β(X_τᵀ Γ X + Xᵀ Γ X_τ)[v]
1794    //       = X_τᵀ Γ_v X + Xᵀ Γ_v X_τ,
1795    //       Γ_v  := D_β Γ[v] = diag((w'')_v ⊙ h + w'' ⊙ dh_v)
1796    //                        = diag(w''' ⊙ δη_v ⊙ h + w'' ⊙ dh_v).
1797    //     D_β(Xᵀ Γ̇_τ X)[v]
1798    //       = Xᵀ Γ̇_{τ,v} X,
1799    //       Γ̇_{τ,v}
1800    //        := D_β Γ̇_τ[v]
1801    //         = diag( (w''')_v ⊙ η̇_τ ⊙ h
1802    //                 + w''' ⊙ δη_{τ,v} ⊙ h
1803    //                 + w''' ⊙ η̇_τ ⊙ dh_v
1804    //                 + (w'')_v ⊙ ḣ_τ
1805    //                 + w'' ⊙ D_β(ḣ_τ)[v] )
1806    //         = diag( w'''' ⊙ η̇_τ ⊙ δη_v ⊙ h
1807    //                 + w''' ⊙ δη_{τ,v} ⊙ h
1808    //                 + w''' ⊙ η̇_τ ⊙ dh_v
1809    //                 + w''' ⊙ δη_v ⊙ ḣ_τ
1810    //                 + w'' ⊙ D_β(ḣ_τ)[v] ).
1811    //
1812    //   Cross-coupling τ-β pieces for B:
1813    //     B_v  := D_β B[v]   = diag(w'' ⊙ δη_v) X               (b_uvec)
1814    //     B_τ  := ∂B/∂τ|_β   = diag(w'' ⊙ η̇_τ) X
1815    //                         + diag(w') X_τ.
1816    //     B_{τ,v}
1817    //         := D_β B_τ[v]  = diag( w''' ⊙ η̇_τ ⊙ δη_v
1818    //                                 + w'' ⊙ δη_{τ,v} ) X
1819    //                         + diag(w'' ⊙ δη_v) X_τ.
1820    //
1821    //   BᵀPB branch — 9 terms, obtained by applying the product rule to
1822    //   ∂(BᵀPB)/∂τ = Ḃ_τᵀ P B + Bᵀ Ṗ_τ B + Bᵀ P Ḃ_τ and then taking
1823    //   D_β(·)[v] of each factor:
1824    //
1825    //     D_β(Ḃ_τᵀ P B)[v]   = B_{τ,v}ᵀ P B + Ḃ_τᵀ P_v B + Ḃ_τᵀ P B_v,
1826    //     D_β(Bᵀ Ṗ_τ B)[v]   = B_vᵀ Ṗ_τ B  + Bᵀ P_{τ,v} B + Bᵀ Ṗ_τ B_v,
1827    //     D_β(Bᵀ P Ḃ_τ)[v]   = B_vᵀ P Ḃ_τ + Bᵀ P_v Ḃ_τ + Bᵀ P B_{τ,v}.
1828    //
1829    //   Here Ḃ_τ = B_τ above, and
1830    //     P_v := D_β P[v]         = 2 (M ⊙ M_v),   M_v = X_r A_v X_rᵀ.
1831    //     Ṗ_τ := ∂P/∂τ|_β         = 2 (M ⊙ M_τ),
1832    //       M_τ = X_{r,τ} K_r X_rᵀ + X_r K̇_τ X_rᵀ + X_r K_r X_{r,τ}ᵀ.
1833    //     P_{τ,v} := D_β(Ṗ_τ)[v]  = 2 (M_v ⊙ M_τ) + 2 (M ⊙ M_{τ,v}),
1834    //       M_{τ,v} = X_{r,τ} A_v X_rᵀ + X_r D_β(K̇_τ)[v] X_rᵀ + X_r A_v X_{r,τ}ᵀ.
1835    //
1836    //   Final primitive:
1837    //
1838    //     D_β((H_φ)_τ|_β)[v]  =  ½ [
1839    //           X_τᵀ Γ_v X  + Xᵀ Γ_v X_τ  + Xᵀ Γ̇_{τ,v} X
1840    //         −  (9-term BᵀPB β-τ expansion above)
1841    //     ].
1842    //
1843    //   Applied to an rhs block `R ∈ ℝ^{p × m}`, each Xᵀ(…) X R collapses
1844    //   to n-length row scalings of (X R) followed by Xᵀ; each Bᵀ P B
1845    //   variant uses apply_hadamard_gram_to_matrix with the correct
1846    //   (A_left, A_right) ∈ { (K_r, K_r), (K_r, A_v), (K_r, K̇_τ),
1847    //     (K_r, D_β(K̇_τ)[v]), (A_v, K̇_τ), (K_r, K̇_τ) } to realize
1848    //   P, P_v, Ṗ_τ, P_{τ,v} actions.  All operators are r×r in reduced
1849    //   coordinates, matching the existing apply cost profile.
1850    //
1851    // IMPLEMENTATION SKETCH (for 13c)
1852    //   • Build `FirthDirection` from deta = X v (reuses existing
1853    //     direction_from_deta, giving I'_v, A_v, dh_v, b_uvec).
1854    //   • Build β-derivatives of the τ-specific fields of
1855    //     FirthTauPartialKernel (dotw1, dotw2, dot_h_partial, dot_k_reduced,
1856    //     and the implicit M_τ reduced-coords operator).  These become a
1857    //     new FirthTauBetaPartialKernel attached to the prepared state.
1858    //   • The apply step is then algebraically identical to
1859    //     hphi_tau_partial_apply but with every W-tensor weight replaced by
1860    //     its β-derivative in v, and every (M, K_r)-Gram replaced by the
1861    //     appropriate β-derivative Gram above.  The structure is regular
1862    //     enough that a single helper, shared with Primitive A, can absorb
1863    //     both pair dispatches.
1864    //
1865    // NOTE ON DESIGN-LINEAR SMOOTHS
1866    //   For the common case of design-linear-in-τ smooths (scale-moving
1867    //   anisotropic bases), X_i and X_τ are constant in τ, so X_{ij}=0 and
1868    //   η̈_{ij}=0.  The primitives collapse to their W-reweighted cores but
1869    //   remain matrix-free; no special fast path is needed because the
1870    //   zeroed terms simply drop out of the Hadamard-Gram assembly.
1871    //
1872    // ═════════════════════════════════════════════════════════════════════════
1873
1874    /// Primitive A — prepare step: assemble the τ_i × τ_j reduced kernel.
1875    ///
1876    /// Consumes the per-direction partial quantities produced by
1877    /// `dot_i_and_h_from_reduced` for τ_i and τ_j (plus an optional second
1878    /// design derivative X_{τ_i τ_j}), and returns a cached kernel carrying
1879    /// the M̈_{ij}, K̈_{ij}, ḧ_{ij}, Γ̈_{ij}, and B̈_{ij}-related reduced
1880    /// coordinates needed by `hphi_tau_tau_partial_apply`.
1881    ///
1882    /// This signature mirrors `hphi_tau_partial_prepare_from_partials` for
1883    /// consistency; the pair version needs both directions simultaneously
1884    /// (to realize the 9-term D² expansion) and therefore owns both
1885    /// `x_tau_{i,j}_reduced` and their η̇_i / η̇_j.
1886    ///
1887    pub(crate) fn hphi_tau_tau_partial_prepare_from_partials(
1888        &self,
1889        x_tau_i_reduced: Array2<f64>,
1890        x_tau_j_reduced: Array2<f64>,
1891        deta_i_partial: &Array1<f64>,
1892        deta_j_partial: &Array1<f64>,
1893        dot_h_i_partial: Array1<f64>,
1894        dot_h_j_partial: Array1<f64>,
1895        dot_i_i_partial: Array2<f64>,
1896        dot_i_j_partial: Array2<f64>,
1897        x_tau_tau_reduced: Option<Array2<f64>>,
1898        deta_ij_partial: Option<Array1<f64>>,
1899    ) -> FirthTauTauPartialKernel {
1900        // K̇_i = -K_r İ_i K_r;  K̇_j = -K_r İ_j K_r.
1901        let dot_k_i_reduced = -self.k_reduced.dot(&dot_i_i_partial).dot(&self.k_reduced);
1902        let dot_k_j_reduced = -self.k_reduced.dot(&dot_i_j_partial).dot(&self.k_reduced);
1903        FirthTauTauPartialKernel {
1904            x_tau_i_reduced,
1905            x_tau_j_reduced,
1906            deta_i_partial: deta_i_partial.clone(),
1907            deta_j_partial: deta_j_partial.clone(),
1908            dot_h_i_partial,
1909            dot_h_j_partial,
1910            dot_k_i_reduced,
1911            dot_k_j_reduced,
1912            dot_i_i_partial,
1913            dot_i_j_partial,
1914            x_tau_tau_reduced,
1915            deta_ij_partial,
1916        }
1917    }
1918
1919    /// Primitive A — apply step: evaluate ∂²H_φ/∂τ_i ∂τ_j |_β against a p×m
1920    /// rhs block, returning a p×m block.
1921    ///
1922    /// Contract mirrors `hphi_tau_partial_apply`: the caller passes the two
1923    /// τ-drift designs and the prepared kernel, and receives the fixed-β
1924    /// second-τ Firth drift as a dense p×m action.  Matrix-free in n.
1925    ///
1926    pub(crate) fn hphi_tau_tau_partial_apply(
1927        &self,
1928        x_tau_i: &Array2<f64>,
1929        x_tau_j: &Array2<f64>,
1930        kernel: &FirthTauTauPartialKernel,
1931        rhs: &Array2<f64>,
1932    ) -> Array2<f64> {
1933        let p = self.x_dense.ncols();
1934        if rhs.nrows() != p {
1935            return Array2::<f64>::zeros((p, rhs.ncols()));
1936        }
1937        if rhs.ncols() == 0 || p == 0 {
1938            return Array2::<f64>::zeros((p, rhs.ncols()));
1939        }
1940        let n = self.x_dense.nrows();
1941        let m = rhs.ncols();
1942
1943        // Short aliases.
1944        let z = &self.x_reduced;
1945        let x_r = &self.x_reduced;
1946        let k = &self.k_reduced;
1947        let x_ri = &kernel.x_tau_i_reduced;
1948        let x_rj = &kernel.x_tau_j_reduced;
1949        let deta_i = &kernel.deta_i_partial;
1950        let deta_j = &kernel.deta_j_partial;
1951        let dh_i = &kernel.dot_h_i_partial;
1952        let dh_j = &kernel.dot_h_j_partial;
1953        let dot_k_i = &kernel.dot_k_i_reduced;
1954        let dot_k_j = &kernel.dot_k_j_reduced;
1955        let dot_i_i = &kernel.dot_i_i_partial;
1956        let dot_i_j = &kernel.dot_i_j_partial;
1957
1958        // Optional second-design pieces: default to zero when the design is
1959        // τ-linear (η̈_{ij} = 0, X_{ij} = 0).
1960        let x_tau_tau_is_some = kernel.x_tau_tau_reduced.is_some();
1961        let x_rij_zero = Array2::<f64>::zeros(x_r.raw_dim());
1962        let x_rij: &Array2<f64> = kernel.x_tau_tau_reduced.as_ref().unwrap_or(&x_rij_zero);
1963        let zeros_n = Array1::<f64>::zeros(n);
1964        let deta_ij = kernel.deta_ij_partial.as_ref().unwrap_or(&zeros_n);
1965
1966        // ─────────────────────────────────────────────────────────────────
1967        //  η̇ vectors in β-rhs space (η_V := X V, η_{i,V} := X_i V, etc.)
1968        // ─────────────────────────────────────────────────────────────────
1969        let (eta_v, eta_i_v, eta_j_v) = if RemlState::should_join_independent_dense_products(&[
1970            (n, m, p),
1971            (n, m, p),
1972            (n, m, p),
1973        ]) {
1974            let (eta_v, (eta_i_v, eta_j_v)) = rayon::join(
1975                || fast_ab(&self.x_dense, rhs),
1976                || rayon::join(|| fast_ab(x_tau_i, rhs), || fast_ab(x_tau_j, rhs)),
1977            );
1978            (eta_v, eta_i_v, eta_j_v)
1979        } else {
1980            (
1981                fast_ab(&self.x_dense, rhs),
1982                fast_ab(x_tau_i, rhs),
1983                fast_ab(x_tau_j, rhs),
1984            )
1985        }; // n×m blocks
1986        // X_{ij} V from the reduced second-derivative design:
1987        //   reduce_explicit_design: X_{r,τ} = diag(√a) X_τ Q,
1988        //   invert:  X_{ij} = diag(1/√a) X_{r,ij} Qᵀ.
1989        let eta_ij_v: Array2<f64> = if x_tau_tau_is_some {
1990            let qt_v = fast_atb(&self.q_basis, rhs); // r×m
1991            let mut out = fast_ab(x_rij, &qt_v); // n×m in sqrt(a)-scaled space
1992            RemlState::scale_rows_by_inverse_observation_weight_sqrt(
1993                &mut out,
1994                self.observation_weight_sqrt.as_ref(),
1995            );
1996            out
1997        } else {
1998            Array2::<f64>::zeros((n, m))
1999        };
2000
2001        // ─────────────────────────────────────────────────────────────────
2002        //  Shared per-direction reduced operators
2003        //    A_τ = K İ K   (reduced analog of T_τ = K I_τ K)
2004        //    K̇_τ = -A_τ  (already cached)
2005        // ─────────────────────────────────────────────────────────────────
2006        let a_i_reduced = -dot_k_i; // K İ_i K = -K̇_i
2007        let a_j_reduced = -dot_k_j;
2008
2009        // ─────────────────────────────────────────────────────────────────
2010        //  Ï_{ij}  — second cross derivative of reduced Fisher
2011        // ─────────────────────────────────────────────────────────────────
2012        //   Ï_{ij} = X_{r,ij}ᵀ W X_r + X_rᵀ W X_{r,ij}
2013        //          + X_{r,i}ᵀ W X_{r,j} + X_{r,j}ᵀ W X_{r,i}
2014        //          + X_{r,i}ᵀ Ẇ_j X_r + X_rᵀ Ẇ_j X_{r,i}
2015        //          + X_{r,j}ᵀ Ẇ_i X_r + X_rᵀ Ẇ_i X_{r,j}
2016        //          + X_rᵀ Ẅ_{ij} X_r.
2017        //   Ẇ_α   = diag(w' ⊙ η̇_α),  Ẅ_{ij} = diag(w'' ⊙ η̇_i ⊙ η̇_j + w' ⊙ η̈_ij).
2018        let dw_i = &self.w1 * deta_i;
2019        let dw_j = &self.w1 * deta_j;
2020        let ddw_ij = &(&self.w2 * &(deta_i * deta_j)) + &(&self.w1 * deta_ij);
2021        let mut i_ddot = Array2::<f64>::zeros(k.raw_dim());
2022        if x_tau_tau_is_some {
2023            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
2024            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
2025        }
2026        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_rj, &self.w);
2027        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_ri, &self.w);
2028        i_ddot = i_ddot + RemlState::weighted_cross(x_ri, x_r, &dw_j);
2029        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_ri, &dw_j);
2030        i_ddot = i_ddot + RemlState::weighted_cross(x_rj, x_r, &dw_i);
2031        i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rj, &dw_i);
2032        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);
2033
2034        // K̈_{ij} = −K Ï K + K İ_i K İ_j K + K İ_j K İ_i K.
2035        //   Using K İ_α K = −K̇_α = a_α_reduced, the two product terms collapse to
2036        //   K İ_i K İ_j K = a_i_reduced · İ_j · K,
2037        //   K İ_j K İ_i K = a_j_reduced · İ_i · K.
2038        let k_ddot: Array2<f64> = -k.dot(&i_ddot).dot(k)
2039            + a_i_reduced.dot(dot_i_j).dot(k)
2040            + a_j_reduced.dot(dot_i_i).dot(k);
2041
2042        // ─────────────────────────────────────────────────────────────────
2043        //  ḧ_{ij}
2044        // ─────────────────────────────────────────────────────────────────
2045        //   ḧ_ij = 2 diag(X_{r,ij} K X_rᵀ)
2046        //        + diag(X_r K̈_ij X_rᵀ)
2047        //        + 2 diag(X_{r,i} K̇_j X_rᵀ)
2048        //        + 2 diag(X_{r,j} K̇_i X_rᵀ)
2049        //        + 2 diag(X_{r,i} K X_{r,j}ᵀ).
2050        // Using diag(A Bᵀ) = rowwise_dot(A, B):
2051        let dh_ij: Array1<f64> = {
2052            let r = k.ncols();
2053            let can_join = RemlState::should_join_independent_dense_products(&[
2054                (n, r, r),
2055                (n, r, r),
2056                (n, r, r),
2057                (n, r, r),
2058            ]);
2059            let (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k) = if can_join {
2060                let ((xr_kddot, ri_kdot_j), (rj_kdot_i, ri_k)) = rayon::join(
2061                    || rayon::join(|| fast_ab(x_r, &k_ddot), || fast_ab(x_ri, dot_k_j)),
2062                    || rayon::join(|| fast_ab(x_rj, dot_k_i), || fast_ab(x_ri, k)),
2063                );
2064                (xr_kddot, ri_kdot_j, rj_kdot_i, ri_k)
2065            } else {
2066                (
2067                    fast_ab(x_r, &k_ddot),
2068                    fast_ab(x_ri, dot_k_j),
2069                    fast_ab(x_rj, dot_k_i),
2070                    fast_ab(x_ri, k),
2071                )
2072            };
2073
2074            let mut acc = Self::rowwise_dot(&xr_kddot, x_r);
2075            acc = acc + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
2076            acc = acc + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
2077            acc = acc + 2.0 * Self::rowwise_dot(&ri_k, x_rj);
2078            if x_tau_tau_is_some {
2079                let rij_k = fast_ab(x_rij, k);
2080                acc = acc + 2.0 * Self::rowwise_dot(&rij_k, x_r);
2081            }
2082            acc
2083        };
2084
2085        // ─────────────────────────────────────────────────────────────────
2086        //  Γ, Γ̇_i, Γ̇_j, Γ̈_ij  (diagonal row-weight n-vectors)
2087        // ─────────────────────────────────────────────────────────────────
2088        //   γ        = w'' ⊙ h
2089        //   γ̇_i     = w''' ⊙ η̇_i ⊙ h + w'' ⊙ ḣ_i
2090        //   γ̈_ij   = w'''' ⊙ η̇_i ⊙ η̇_j ⊙ h
2091        //            + w''' ⊙ η̈_ij ⊙ h
2092        //            + w''' ⊙ η̇_i ⊙ ḣ_j
2093        //            + w''' ⊙ η̇_j ⊙ ḣ_i
2094        //            + w'' ⊙ ḧ_ij
2095        let gamma = &self.w2 * &self.h_diag;
2096        let gamma_dot_i = &(&(&self.w3 * deta_i) * &self.h_diag) + &(&self.w2 * dh_i);
2097        let gamma_dot_j = &(&(&self.w3 * deta_j) * &self.h_diag) + &(&self.w2 * dh_j);
2098        let gamma_ddot = &(&(&(&self.w4 * deta_i) * deta_j) * &self.h_diag)
2099            + &(&(&(&self.w3 * deta_ij) * &self.h_diag)
2100                + &(&(&self.w3 * deta_i) * dh_j)
2101                + &(&(&self.w3 * deta_j) * dh_i)
2102                + &(&self.w2 * &dh_ij));
2103
2104        // ─────────────────────────────────────────────────────────────────
2105        //  Diagonal-term β-rhs contributions:
2106        //    ∂²(XᵀΓX)/∂τ_i∂τ_j · V
2107        //  = X_{ij}ᵀ (γ ⊙ η_V)       + Xᵀ (γ ⊙ η_{ij,V})         [if X_ij]
2108        //    + X_iᵀ (γ ⊙ η_{j,V})    + X_jᵀ (γ ⊙ η_{i,V})
2109        //    + X_iᵀ (γ̇_j ⊙ η_V)     + Xᵀ (γ̇_j ⊙ η_{i,V})
2110        //    + X_jᵀ (γ̇_i ⊙ η_V)     + Xᵀ (γ̇_i ⊙ η_{j,V})
2111        //    + Xᵀ (γ̈_ij ⊙ η_V).
2112        // ─────────────────────────────────────────────────────────────────
2113        let mut diag_term = Array2::<f64>::zeros((p, m));
2114        let gamma_col = gamma.view().insert_axis(Axis(1));
2115        let gamma_i_col = gamma_dot_i.view().insert_axis(Axis(1));
2116        let gamma_j_col = gamma_dot_j.view().insert_axis(Axis(1));
2117        let gamma_ij_col = gamma_ddot.view().insert_axis(Axis(1));
2118
2119        // X_iᵀ (γ ⊙ η_{j,V}) + X_jᵀ (γ ⊙ η_{i,V})
2120        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_j_v * &gamma_col));
2121        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_i_v * &gamma_col));
2122        // X_iᵀ (γ̇_j ⊙ η_V) + X_jᵀ (γ̇_i ⊙ η_V)
2123        diag_term = diag_term + fast_atb(x_tau_i, &(&eta_v * &gamma_j_col));
2124        diag_term = diag_term + fast_atb(x_tau_j, &(&eta_v * &gamma_i_col));
2125        // Xᵀ (γ̇_j ⊙ η_{i,V}) + Xᵀ (γ̇_i ⊙ η_{j,V})
2126        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_i_v * &gamma_j_col));
2127        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_j_v * &gamma_i_col));
2128        // Xᵀ (γ̈_ij ⊙ η_V)
2129        diag_term = diag_term + fast_ab(&self.x_dense_t, &(&eta_v * &gamma_ij_col));
2130        // X_{ij}ᵀ (γ ⊙ η_V) + Xᵀ (γ ⊙ η_{ij,V})
2131        if x_tau_tau_is_some {
2132            // X_{ij}ᵀ = Q X_{r,ij}ᵀ · diag(1/√a)  (inverse of the reduce shim),
2133            // but caller supplies the reduced second-derivative design.  We
2134            // form X_{ij}ᵀ Y as q_basis · (X_{r,ij}ᵀ · diag(1/√a)·Y) = Q · X_{r,ij}ᵀ (Y unscaled).
2135            // When no observation weights, X_{r,ij} = X_{ij} Q and
2136            //   X_{ij}ᵀ Y = Q X_{r,ij}ᵀ Y.
2137            let y: Array2<f64> = &eta_v * &gamma_col;
2138            let xt_ij_y: Array2<f64> = if self.observation_weight_sqrt.is_some() {
2139                let mut y_scaled = y.clone();
2140                RemlState::scale_rows_by_inverse_observation_weight_sqrt(
2141                    &mut y_scaled,
2142                    self.observation_weight_sqrt.as_ref(),
2143                );
2144                self.q_basis.dot(&x_rij.t().dot(&y_scaled))
2145            } else {
2146                self.q_basis.dot(&x_rij.t().dot(&y))
2147            };
2148            diag_term = diag_term + xt_ij_y;
2149            diag_term = diag_term + self.x_dense_t.dot(&(&eta_ij_v * &gamma_col));
2150        }
2151
2152        // ─────────────────────────────────────────────────────────────────
2153        //  BᵀPB branch — 9-term expansion.
2154        //
2155        //  Represent each B-like operator as an "n-row scaling vector for the
2156        //  X part plus tails along X_τ and X_{ij}".  For rhs V, define the
2157        //  row-scaled η-space blocks R(B) = diag(scale) X V + tails.  Then
2158        //  Bᵀ (P action) R is assembled by row-scaling and left-multiplying
2159        //  the appropriate full designs.
2160        // ─────────────────────────────────────────────────────────────────
2161
2162        // B V row-block (eta-space):  B V = diag(w') X V.
2163        let w1_col = self.w1.view().insert_axis(Axis(1));
2164        let b_v = &eta_v * &w1_col;
2165
2166        // Ḃ_i V = diag(w'' ⊙ η̇_i) X V + diag(w') X_i V.
2167        let w2_deta_i = &self.w2 * deta_i;
2168        let w2_deta_j = &self.w2 * deta_j;
2169        let w2_deta_i_col = w2_deta_i.view().insert_axis(Axis(1));
2170        let w2_deta_j_col = w2_deta_j.view().insert_axis(Axis(1));
2171        let bdot_i_v = &(&eta_v * &w2_deta_i_col) + &(&eta_i_v * &w1_col);
2172        let bdot_j_v = &(&eta_v * &w2_deta_j_col) + &(&eta_j_v * &w1_col);
2173
2174        // B̈_{ij} V =
2175        //   diag(w''' ⊙ η̇_i ⊙ η̇_j + w'' ⊙ η̈_ij) X V
2176        //   + diag(w'' ⊙ η̇_i) X_j V
2177        //   + diag(w'' ⊙ η̇_j) X_i V
2178        //   + diag(w') X_{ij} V.
2179        let w3_didj = &(&self.w3 * deta_i) * deta_j;
2180        let w2_dij = &self.w2 * deta_ij;
2181        let bddot_scale = &w3_didj + &w2_dij;
2182        let bddot_scale_col = bddot_scale.view().insert_axis(Axis(1));
2183        let mut bddot_ij_v = &eta_v * &bddot_scale_col;
2184        bddot_ij_v += &(&eta_j_v * &w2_deta_i_col);
2185        bddot_ij_v += &(&eta_i_v * &w2_deta_j_col);
2186        bddot_ij_v += &(&eta_ij_v * &w1_col);
2187
2188        // P V  (columnwise, using K ⊙ K Hadamard gram on Z = X_r).
2189        let p_bv = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &b_v);
2190        let p_bddot_ij_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bddot_ij_v);
2191
2192        // Ṗ_i, Ṗ_j applied to B V, Ḃ_j V, Ḃ_i V — use the existing
2193        // apply_mtau_to_matrix helper, which computes 2(M ⊙ Ṁ_τ) · mat.
2194        //
2195        // Construct a lightweight "FirthTauPartialKernel"-shaped tuple only for
2196        // apply_mtau_to_matrix; we mirror its input contract inline to avoid
2197        // owning a FirthTauPartialKernel copy here.
2198        let pdot_i_bv = self.apply_mtau_from_reduced(x_ri, dot_k_i, &b_v);
2199        let pdot_j_bv = self.apply_mtau_from_reduced(x_rj, dot_k_j, &b_v);
2200        let pdot_i_bdot_j_v = self.apply_mtau_from_reduced(x_ri, dot_k_i, &bdot_j_v);
2201        let pdot_j_bdot_i_v = self.apply_mtau_from_reduced(x_rj, dot_k_j, &bdot_i_v);
2202
2203        // P Ḃ_j V and P Ḃ_i V.
2204        let p_bdot_j_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_j_v);
2205        let p_bdot_i_v = RemlState::apply_hadamard_gram_to_matrix(z, k, k, &bdot_i_v);
2206
2207        // P̈_{ij} V = 4 (Ṁ_i ⊙ Ṁ_j) V  + 2 (M ⊙ M̈_{ij}) V.
2208        let p_ddot_b_v = self.apply_p_ddot_ij(
2209            x_r,
2210            x_ri,
2211            x_rj,
2212            x_rij,
2213            k,
2214            dot_k_i,
2215            dot_k_j,
2216            &k_ddot,
2217            x_tau_tau_is_some,
2218            &b_v,
2219        );
2220
2221        // Assemble 9 terms of D²(BᵀPB)[τ_i, τ_j] · V.
2222        //   term1 = B̈_ijᵀ P B V + Bᵀ P B̈_ij V
2223        //   term2 = Ḃ_iᵀ P Ḃ_j V + Ḃ_jᵀ P Ḃ_i V
2224        //   term3 = Ḃ_iᵀ Ṗ_j B V + Bᵀ Ṗ_j Ḃ_i V
2225        //   term4 = Ḃ_jᵀ Ṗ_i B V + Bᵀ Ṗ_i Ḃ_j V
2226        //   term5 = Bᵀ P̈_ij B V
2227        //
2228        // "Bᵀ Q V" with B = diag(w') X equals left_scaled_xt(w1, Q V).
2229        // "Ḃ_iᵀ Q V" = diag(w'' ⊙ η̇_i) X acting on the left, plus
2230        //              diag(w') X_i on the left.  In transpose:
2231        //   Ḃ_iᵀ Q V = Xᵀ (diag(w'' ⊙ η̇_i) Q V) + X_iᵀ (diag(w') Q V).
2232        // "B̈_ijᵀ Q V" mirrors B̈_ij above in transpose.
2233
2234        let apply_bdot_tau_t =
2235            |scale_deta: &Array1<f64>, x_tau_mat: &Array2<f64>, q_v: &Array2<f64>| {
2236                let scale_col = scale_deta.view().insert_axis(Axis(1));
2237                self.x_dense_t.dot(&(q_v * &scale_col)) + x_tau_mat.t().dot(&(q_v * &w1_col))
2238            };
2239
2240        let apply_bddot_ij_t = |q_v: &Array2<f64>| -> Array2<f64> {
2241            let scale_col_full = bddot_scale.view().insert_axis(Axis(1));
2242            let mut out = self.x_dense_t.dot(&(q_v * &scale_col_full));
2243            out = out + x_tau_j.t().dot(&(q_v * &w2_deta_i_col));
2244            out = out + x_tau_i.t().dot(&(q_v * &w2_deta_j_col));
2245            if x_tau_tau_is_some {
2246                // X_{ij}ᵀ (w1 ⊙ Q V)
2247                let y = q_v * &w1_col;
2248                let contrib: Array2<f64> = if self.observation_weight_sqrt.is_some() {
2249                    let mut y_scaled = y.clone();
2250                    RemlState::scale_rows_by_inverse_observation_weight_sqrt(
2251                        &mut y_scaled,
2252                        self.observation_weight_sqrt.as_ref(),
2253                    );
2254                    self.q_basis.dot(&x_rij.t().dot(&y_scaled))
2255                } else {
2256                    self.q_basis.dot(&x_rij.t().dot(&y))
2257                };
2258                out = out + contrib;
2259            }
2260            out
2261        };
2262
2263        // term1
2264        let t1a = apply_bddot_ij_t(&p_bv);
2265        let t1b = self.left_scaled_xt(&self.w1, &p_bddot_ij_v);
2266        // term2
2267        let t2a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &p_bdot_j_v);
2268        let t2b = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &p_bdot_i_v);
2269        // term3: Ḃ_iᵀ Ṗ_j B V + Bᵀ Ṗ_j Ḃ_i V
2270        let t3a = apply_bdot_tau_t(&w2_deta_i, x_tau_i, &pdot_j_bv);
2271        let t3b = self.left_scaled_xt(&self.w1, &pdot_j_bdot_i_v);
2272        // term4: Ḃ_jᵀ Ṗ_i B V + Bᵀ Ṗ_i Ḃ_j V
2273        let t4a = apply_bdot_tau_t(&w2_deta_j, x_tau_j, &pdot_i_bv);
2274        let t4b = self.left_scaled_xt(&self.w1, &pdot_i_bdot_j_v);
2275        // term5
2276        let t5 = self.left_scaled_xt(&self.w1, &p_ddot_b_v);
2277
2278        let d2_bpb = t1a + t1b + t2a + t2b + t3a + t3b + t4a + t4b + t5;
2279
2280        0.5 * (diag_term - d2_bpb)
2281    }
2282
2283    /// Pair-level exact Firth kernel at fixed β for a (τ_i, τ_j) outer
2284    /// coordinate pair.
2285    ///
2286    /// Returns the two SCALAR- and P-VECTOR-valued second-derivative
2287    /// objects that the unified REML evaluator threads into
2288    /// `HyperCoordPair::{a,g}` as additive Firth contributions, plus an
2289    /// optional prepared Primitive-A `FirthTauTauPartialKernel` that the
2290    /// pair-callback can reuse for the `b_operator` action.
2291    ///
2292    /// ═════════════════════════════════════════════════════════════════════
2293    ///  DERIVATIONS (fixed β, reduced-basis identifiable coords).
2294    ///
2295    ///  Φ = 0.5 log|I_r| − 0.5 log|S_r|,   K_r = I_r⁻¹,   G_r = diag(S_r⁻¹)
2296    ///  Φ_{τ_i}|β = 0.5 tr(K_r İ_{r,i}) − 0.5 tr(G_r Ṡ_{r,i}).
2297    ///
2298    /// ┌── pair.a scalar Φ_{τ_i τ_j}|β ────────────────────────────────────┐
2299    ///  ∂/∂τ_j [0.5 tr(K_r İ_{r,i})]
2300    ///    = 0.5 tr(K̇_{r,j} İ_{r,i}) + 0.5 tr(K_r Ï_{r,ij})
2301    ///    = −0.5 tr(K_r İ_{r,j} K_r İ_{r,i}) + 0.5 tr(K_r Ï_{r,ij})
2302    ///
2303    ///  ∂/∂τ_j [−0.5 tr(G_r Ṡ_{r,i})]
2304    ///    = −0.5 tr(Ġ_{r,j} Ṡ_{r,i}) − 0.5 tr(G_r S̈_{r,ij})
2305    ///  (G_r diagonal in canonical basis →
2306    ///   Ġ_{r,j}_kk = −G_r_kk² · diag(Ṡ_{r,j})_kk.)
2307    ///
2308    ///  Ï_{r,ij} is the same 9-term Fisher cross used by Primitive A
2309    ///  (see `hphi_tau_tau_partial_apply`:i_ddot block).
2310    ///
2311    ///  S̈_{r,ij} = X_{r,ij}^T X_r + X_r^T X_{r,ij}
2312    ///            + X_{r,i}^T X_{r,j} + X_{r,j}^T X_{r,i}.
2313    /// └─────────────────────────────────────────────────────────────────┘
2314    ///
2315    /// ┌── pair.g p-vector (gΦ)_{τ_i τ_j}|β ───────────────────────────────┐
2316    ///  (gΦ)_{τ_i} = 0.5 X_{τ_i}^T (w1 ⊙ h)
2317    ///              + 0.5 X^T [ (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i ]
2318    ///
2319    ///  Differentiating wrt τ_j at fixed β, using η̇_α = X_α β, η̈_{ij} =
2320    ///  X_{ij} β (when x_tau_tau is provided, else 0), and ḣ_α, ḧ_{ij}
2321    ///  from Primitive A:
2322    ///
2323    ///  term_A = 0.5 ∂/∂τ_j [X_{τ_i}^T (w1 ⊙ h)]
2324    ///        = 0.5 X_{τ_i τ_j}^T (w1 ⊙ h)          [if X_{ij} present]
2325    ///        + 0.5 X_{τ_i}^T [ (w2 ⊙ η̇_j) ⊙ h + w1 ⊙ ḣ_j ]
2326    ///
2327    ///  term_B = 0.5 ∂/∂τ_j [X^T · v_{τ_i}] with
2328    ///            v_{τ_i} = (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i
2329    ///        = 0.5 X_{τ_j}^T v_{τ_i}
2330    ///        + 0.5 X^T · v̇_{τ_i,τ_j}
2331    ///
2332    ///  where the inner derivative
2333    ///  v̇_{τ_i,τ_j} = (w3 ⊙ η̇_j ⊙ η̇_i) ⊙ h   (from ∂w2 = w3 ⊙ η̇_j)
2334    ///              + (w2 ⊙ η̈_ij) ⊙ h          (from ∂η̇_i = η̈_{ij})
2335    ///              + (w2 ⊙ η̇_i) ⊙ ḣ_j        (from ∂h = ḣ_j)
2336    ///              + (w2 ⊙ η̇_j) ⊙ ḣ_i        (from ∂w1 = w2 ⊙ η̇_j, ⊙ ḣ_i)
2337    ///              +  w1 ⊙ ḧ_{ij}             (from ∂ḣ_i = ḧ_{ij}).
2338    /// └─────────────────────────────────────────────────────────────────┘
2339    ///
2340    /// ALL Ï_{r,ij}, η̈_{ij}, ḣ_i, ḧ_{ij} computations are identical to
2341    /// those already computed inside Primitive A's `hphi_tau_tau_partial_apply`.
2342    /// We replicate only the pieces needed to yield the scalar and p-vector
2343    /// outputs to avoid computing the full p×m action when unnecessary.
2344    ///
2345    pub(crate) fn exact_tau_tau_kernel(
2346        &self,
2347        x_tau_i: &Array2<f64>,
2348        x_tau_j: &Array2<f64>,
2349        x_tau_tau: Option<&Array2<f64>>,
2350        beta: &Array1<f64>,
2351        include_hphi_tau_tau_kernel: bool,
2352    ) -> FirthTauTauExactKernel {
2353        let deta_i = x_tau_i.dot(beta);
2354        let deta_j = x_tau_j.dot(beta);
2355        let deta_ij = x_tau_tau.as_ref().map(|xij| xij.dot(beta));
2356
2357        let x_tau_i_reduced = self.reduce_explicit_design(x_tau_i);
2358        let x_tau_j_reduced = self.reduce_explicit_design(x_tau_j);
2359        let x_tau_tau_reduced = x_tau_tau.map(|xij| self.reduce_explicit_design(xij));
2360
2361        let (dot_i_i, dot_h_i) = self.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
2362        let (dot_i_j, dot_h_j) = self.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
2363
2364        // Ï_{r,ij} = X_{r,ij}^T W X_r + X_r^T W X_{r,ij}
2365        //            + X_{r,i}^T W X_{r,j} + X_{r,j}^T W X_{r,i}
2366        //            + X_{r,i}^T Ẇ_j X_r + X_r^T Ẇ_j X_{r,i}
2367        //            + X_{r,j}^T Ẇ_i X_r + X_r^T Ẇ_i X_{r,j}
2368        //            + X_r^T Ẅ_{ij} X_r
2369        // Ẇ_α = diag(w' ⊙ η̇_α);  Ẅ_{ij} = diag(w'' ⊙ η̇_i ⊙ η̇_j + w' ⊙ η̈_{ij}).
2370        let zeros_n = Array1::<f64>::zeros(self.x_dense.nrows());
2371        let deta_ij_ref: &Array1<f64> = deta_ij.as_ref().unwrap_or(&zeros_n);
2372        let dw_i = &self.w1 * &deta_i;
2373        let dw_j = &self.w1 * &deta_j;
2374        let ddw_ij = &(&self.w2 * &(&deta_i * &deta_j)) + &(&self.w1 * deta_ij_ref);
2375
2376        let x_r = &self.x_reduced;
2377        let mut i_ddot = Array2::<f64>::zeros(self.k_reduced.raw_dim());
2378        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2379            i_ddot = i_ddot + RemlState::weighted_cross(x_rij, x_r, &self.w);
2380            i_ddot = i_ddot + RemlState::weighted_cross(x_r, x_rij, &self.w);
2381        }
2382        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, &x_tau_j_reduced, &self.w);
2383        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, &x_tau_i_reduced, &self.w);
2384        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_i_reduced, x_r, &dw_j);
2385        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_i_reduced, &dw_j);
2386        i_ddot = i_ddot + RemlState::weighted_cross(&x_tau_j_reduced, x_r, &dw_i);
2387        i_ddot = i_ddot + RemlState::weighted_cross(x_r, &x_tau_j_reduced, &dw_i);
2388        i_ddot = i_ddot + gam_linalg::faer_ndarray::fast_xt_diag_x(x_r, &ddw_ij);
2389
2390        // pair.a likelihood contribution:
2391        //   0.5 tr(K_r Ï_{r,ij}) − 0.5 tr(K_r İ_{r,j} K_r İ_{r,i}).
2392        // K_r İ_{r,α} — reuse inline dot().
2393        let k = &self.k_reduced;
2394        let k_dot_i_i = k.dot(&dot_i_i);
2395        let k_dot_i_j = k.dot(&dot_i_j);
2396        let a_lik = 0.5 * RemlState::trace_product(k, &i_ddot)
2397            - 0.5 * RemlState::trace_product(&k_dot_i_j, &k_dot_i_i);
2398
2399        // pair.a penalty-basis contribution:
2400        //   Ṡ_{r,α} = X_{r,α}^T X_r + X_r^T X_{r,α}
2401        //   S̈_{r,ij} = X_{r,ij}^T X_r + X_r^T X_{r,ij}
2402        //            + X_{r,i}^T X_{r,j} + X_{r,j}^T X_{r,i}
2403        //   tr(G_r Ṡ_{r,i}) = Σ_k G_r_kk · diag(Ṡ_{r,i})_kk
2404        //   tr(Ġ_{r,j} Ṡ_{r,i}) = −Σ_k G_r_kk² · diag(Ṡ_{r,j})_kk · diag(Ṡ_{r,i})_kk
2405        let dot_s_i = fast_atb(&x_tau_i_reduced, x_r) + fast_atb(x_r, &x_tau_i_reduced);
2406        let dot_s_j = fast_atb(&x_tau_j_reduced, x_r) + fast_atb(x_r, &x_tau_j_reduced);
2407        let mut s_ddot = Array2::<f64>::zeros(k.raw_dim());
2408        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2409            s_ddot = s_ddot + fast_atb(x_rij, x_r) + fast_atb(x_r, x_rij);
2410        }
2411        s_ddot = s_ddot
2412            + fast_atb(&x_tau_i_reduced, &x_tau_j_reduced)
2413            + fast_atb(&x_tau_j_reduced, &x_tau_i_reduced);
2414        // With G_r = diag(g) in the canonical reduced basis (where
2415        // S_r is diagonal), S_r + τ·Ṡ is generally non-diagonal under
2416        // perturbation, so Ġ_j = −G Ṡ_j G picks up OFF-DIAGONAL terms:
2417        //     (Ġ_j)_{kl} = −G_k · (Ṡ_j)_{kl} · G_l.
2418        // Hence tr(Ġ_j Ṡ_i) = −Σ_{k,l} G_k G_l (Ṡ_j)_{kl} (Ṡ_i)_{lk}.
2419        // Using symmetry of Ṡ_i (and Ṡ_j):
2420        //     −0.5 tr(Ġ_j Ṡ_i) = +0.5 Σ_{k,l} G_k G_l (Ṡ_j)_{kl} (Ṡ_i)_{kl}.
2421        // The S̈_{ij} trace against diagonal G_r picks only the diagonal.
2422        let g_inv = &self.x_metric_reduced_inv_diag;
2423        let rdim = k.nrows();
2424        let mut a_pen = 0.0_f64;
2425        for kk in 0..rdim {
2426            for ll in 0..rdim {
2427                a_pen += 0.5 * g_inv[kk] * g_inv[ll] * dot_s_j[[kk, ll]] * dot_s_i[[kk, ll]];
2428            }
2429            a_pen -= 0.5 * g_inv[kk] * s_ddot[[kk, kk]];
2430        }
2431        let phi_tau_tau_partial = a_lik + a_pen;
2432
2433        // ─── pair.g p-vector: (gΦ)_{τ_i τ_j}|β ──────────────────────────
2434        //
2435        // Assemble ḧ_{ij} identically to Primitive A's body.  We need:
2436        //   K̇_{r,α} = −K_r İ_{r,α} K_r,
2437        //   K̈_{r,ij} = −K_r Ï_{r,ij} K_r + K_r İ_{r,i} K_r İ_{r,j} K_r
2438        //                                 + K_r İ_{r,j} K_r İ_{r,i} K_r.
2439        let dot_k_i = -k.dot(&dot_i_i).dot(k);
2440        let dot_k_j = -k.dot(&dot_i_j).dot(k);
2441        let a_i_red = -&dot_k_i; // K İ_i K
2442        let a_j_red = -&dot_k_j; // K İ_j K
2443        let k_ddot: Array2<f64> =
2444            -k.dot(&i_ddot).dot(k) + a_i_red.dot(&dot_i_j).dot(k) + a_j_red.dot(&dot_i_i).dot(k);
2445
2446        // ḧ_{ij} = 2 diag(X_{r,ij} K X_r^T)
2447        //        + diag(X_r K̈_{ij} X_r^T)
2448        //        + 2 diag(X_{r,i} K̇_j X_r^T)
2449        //        + 2 diag(X_{r,j} K̇_i X_r^T)
2450        //        + 2 diag(X_{r,i} K X_{r,j}^T).
2451        let n = self.x_dense.nrows();
2452        let mut dh_ij = Array1::<f64>::zeros(n);
2453        if let Some(x_rij) = x_tau_tau_reduced.as_ref() {
2454            let rij_k = x_rij.dot(k);
2455            dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rij_k, x_r);
2456        }
2457        let xr_kddot = x_r.dot(&k_ddot);
2458        dh_ij = dh_ij + Self::rowwise_dot(&xr_kddot, x_r);
2459        let ri_kdot_j = x_tau_i_reduced.dot(&dot_k_j);
2460        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_kdot_j, x_r);
2461        let rj_kdot_i = x_tau_j_reduced.dot(&dot_k_i);
2462        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&rj_kdot_i, x_r);
2463        let ri_k = x_tau_i_reduced.dot(k);
2464        dh_ij = dh_ij + 2.0 * Self::rowwise_dot(&ri_k, &x_tau_j_reduced);
2465
2466        // term_A = 0.5 X_{τ_i τ_j}^T (w1 ⊙ h)
2467        //        + 0.5 X_{τ_i}^T [ (w2 ⊙ η̇_j) ⊙ h + w1 ⊙ ḣ_j ]
2468        let w1_h = &self.w1 * &self.h_diag;
2469        let mut gphi_tau_tau = Array1::<f64>::zeros(self.x_dense.ncols());
2470        if let Some(x_ij) = x_tau_tau.as_ref() {
2471            gphi_tau_tau = gphi_tau_tau + 0.5 * x_ij.t().dot(&w1_h);
2472        }
2473        let inner_j = &(&(&self.w2 * &deta_j) * &self.h_diag) + &(&self.w1 * &dot_h_j);
2474        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_i.t().dot(&inner_j);
2475
2476        // term_B pieces:  v_{τ_i} = (w2 ⊙ η̇_i) ⊙ h + w1 ⊙ ḣ_i
2477        let v_tau_i = &(&(&self.w2 * &deta_i) * &self.h_diag) + &(&self.w1 * &dot_h_i);
2478        gphi_tau_tau = gphi_tau_tau + 0.5 * x_tau_j.t().dot(&v_tau_i);
2479
2480        // v̇_{τ_i,τ_j} =
2481        //    (w3 ⊙ η̇_j ⊙ η̇_i) ⊙ h
2482        //  + (w2 ⊙ η̈_{ij}) ⊙ h
2483        //  + (w2 ⊙ η̇_i) ⊙ ḣ_j
2484        //  + (w2 ⊙ η̇_j) ⊙ ḣ_i
2485        //  +  w1 ⊙ ḧ_{ij}.
2486        let mut v_dot_ij = &(&(&self.w3 * &deta_j) * &deta_i) * &self.h_diag;
2487        v_dot_ij += &(&(&self.w2 * deta_ij_ref) * &self.h_diag);
2488        v_dot_ij += &(&(&self.w2 * &deta_i) * &dot_h_j);
2489        v_dot_ij += &(&(&self.w2 * &deta_j) * &dot_h_i);
2490        v_dot_ij += &(&self.w1 * &dh_ij);
2491        gphi_tau_tau = gphi_tau_tau + 0.5 * self.x_dense.t().dot(&v_dot_ij);
2492
2493        let tau_tau_kernel = if include_hphi_tau_tau_kernel {
2494            Some(self.hphi_tau_tau_partial_prepare_from_partials(
2495                x_tau_i_reduced,
2496                x_tau_j_reduced,
2497                &deta_i,
2498                &deta_j,
2499                dot_h_i,
2500                dot_h_j,
2501                dot_i_i,
2502                dot_i_j,
2503                x_tau_tau_reduced,
2504                deta_ij,
2505            ))
2506        } else {
2507            None
2508        };
2509
2510        FirthTauTauExactKernel {
2511            phi_tau_tau_partial,
2512            gphi_tau_tau,
2513            tau_tau_kernel,
2514        }
2515    }
2516
2517    /// Apply `Ṗ_τ V = 2 (M ⊙ Ṁ_τ) V` given the reduced τ-drift design
2518    /// `x_tau_reduced` and the reduced Fisher-inverse drift `dot_k_reduced`.
2519    ///
2520    /// This mirrors the body of `apply_mtau_to_matrix` but accepts the
2521    /// x_tau/dot_k pieces directly, letting Primitive A reuse the same
2522    /// matrix-free Ṗ_τ applies without owning a `FirthTauPartialKernel`.
2523    pub(crate) fn apply_mtau_from_reduced(
2524        &self,
2525        x_tau_reduced: &Array2<f64>,
2526        dot_k_reduced: &Array2<f64>,
2527        mat: &Array2<f64>,
2528    ) -> Array2<f64> {
2529        if mat.nrows() != self.x_dense.nrows() || mat.ncols() == 0 {
2530            return Array2::<f64>::zeros(mat.raw_dim());
2531        }
2532        let mut out = Array2::<f64>::zeros(mat.raw_dim());
2533        for col in 0..mat.ncols() {
2534            let v = mat.column(col).to_owned();
2535            let szz = RemlState::reducedweighted_gram(&self.x_reduced, &v);
2536            let mzz = self.k_reduced.dot(&szz).dot(&self.k_reduced);
2537            let t1 = Self::rowwise_bilinear(&self.x_reduced, &mzz, x_tau_reduced);
2538
2539            let szt = RemlState::reduced_crossweighted_gram(&self.x_reduced, x_tau_reduced, &v);
2540            let mzt = self.k_reduced.dot(&szt).dot(&self.k_reduced);
2541            let t2 = RemlState::reduced_diag_gram(&self.x_reduced, &mzt);
2542
2543            let t3 =
2544                RemlState::apply_hadamard_gram(&self.x_reduced, &self.k_reduced, dot_k_reduced, &v);
2545
2546            let y = 2.0 * (t1 + t2 + t3);
2547            out.column_mut(col).assign(&y);
2548        }
2549        out
2550    }
2551
2552    /// Apply `P̈_{ij} V = 4 (Ṁ_i ⊙ Ṁ_j) V + 2 (M ⊙ M̈_{ij}) V` columnwise.
2553    ///
2554    /// `M̈_{ij}` expands into 9 pieces `Y_α C Y_βᵀ`; `Ṁ_i ⊙ Ṁ_j` into 9 cross
2555    /// pieces `(Y_{1,α} B_{1,α} W_{1,α}ᵀ) ⊙ (Y_{2,β} B_{2,β} W_{2,β}ᵀ)`.  Both
2556    /// are evaluated via the matrix-free identities:
2557    ///
2558    ///   [(ZAZᵀ) ⊙ (YBWᵀ) v]_i   = rowwise_bilinear(Y, B · (Wᵀdiag(v)Z) · A, Z)_i,
2559    ///   [(YBWᵀ) ⊙ (Y'B'W'ᵀ) v]_i= rowwise_bilinear(Y, B · (Wᵀdiag(v)W') · B'ᵀ, Y')_i,
2560    ///
2561    /// with S := row-wise reducedweighted Gram.
2562    pub(crate) fn apply_p_ddot_ij(
2563        &self,
2564        x_r: &Array2<f64>,
2565        x_ri: &Array2<f64>,
2566        x_rj: &Array2<f64>,
2567        x_rij: &Array2<f64>,
2568        k: &Array2<f64>,
2569        dot_k_i: &Array2<f64>,
2570        dot_k_j: &Array2<f64>,
2571        k_ddot: &Array2<f64>,
2572        x_tau_tau_is_some: bool,
2573        mat: &Array2<f64>,
2574    ) -> Array2<f64> {
2575        let n = self.x_dense.nrows();
2576        let m = mat.ncols();
2577        if mat.nrows() != n || m == 0 {
2578            return Array2::<f64>::zeros(mat.raw_dim());
2579        }
2580        let mut out = Array2::<f64>::zeros((n, m));
2581        for col in 0..m {
2582            let v = mat.column(col).to_owned();
2583            // Shared reducedweighted Grams for this column.  Only the Grams
2584            // actually appearing in the 18 pieces below are computed.
2585            let s_zz = RemlState::reducedweighted_gram(x_r, &v); // Z'diag(v)Z
2586            let s_zj = RemlState::reduced_crossweighted_gram(x_r, x_rj, &v); // Z'diag(v)Y_j
2587            let s_iz = RemlState::reduced_crossweighted_gram(x_ri, x_r, &v); // Y_i'diag(v)Z
2588            let s_jz = RemlState::reduced_crossweighted_gram(x_rj, x_r, &v); // Y_j'diag(v)Z
2589            let s_ij = RemlState::reduced_crossweighted_gram(x_ri, x_rj, &v); // Y_i'diag(v)Y_j
2590
2591            // ── 4 (Ṁ_i ⊙ Ṁ_j) v ──
2592            // Ṁ_i has three pieces:
2593            //   P_i,a = Y_i K Zᵀ          — Y=Y_i, B=K, W=Z
2594            //   P_i,b = Z K̇_i Zᵀ         — Y=Z,   B=K̇_i, W=Z
2595            //   P_i,c = Z K Y_iᵀ          — Y=Z,   B=K,  W=Y_i
2596            // And symmetrically for Ṁ_j with (i→j).
2597            //
2598            // For each cross pair (α, β), compute
2599            //   core = B_α · (W_αᵀ diag(v) W_β) · B_βᵀ,
2600            //   y_piece = rowwise_bilinear(Y_α, core, Y_β),
2601            // then sum all 9 and scale by 4.
2602            let mut mdot_mdot = Array1::<f64>::zeros(n);
2603            // (a_i, a_j): Y_i, K, Z  ×  Y_j, K, Z  → W_α=Z, W_β=Z, S = s_zz
2604            {
2605                let core = k.dot(&s_zz).dot(&k.t());
2606                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_rj);
2607            }
2608            // (a_i, b_j): Y_i, K, Z  ×  Z, K̇_j, Z  → S = s_zz; core = K · s_zz · K̇_jᵀ
2609            {
2610                let core = k.dot(&s_zz).dot(&dot_k_j.t());
2611                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2612            }
2613            // (a_i, c_j): Y_i, K, Z  ×  Z, K, Y_j  → S = s_zj; core = K · s_zj · Kᵀ
2614            {
2615                let core = k.dot(&s_zj).dot(&k.t());
2616                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_ri, &core, x_r);
2617            }
2618            // (b_i, a_j): Z, K̇_i, Z  ×  Y_j, K, Z  → S = s_zz; core = K̇_i · s_zz · Kᵀ
2619            {
2620                let core = dot_k_i.dot(&s_zz).dot(&k.t());
2621                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2622            }
2623            // (b_i, b_j): Z, K̇_i, Z  ×  Z, K̇_j, Z  → S = s_zz; core = K̇_i · s_zz · K̇_jᵀ
2624            {
2625                let core = dot_k_i.dot(&s_zz).dot(&dot_k_j.t());
2626                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2627            }
2628            // (b_i, c_j): Z, K̇_i, Z  ×  Z, K, Y_j  → S = s_zj; core = K̇_i · s_zj · Kᵀ
2629            {
2630                let core = dot_k_i.dot(&s_zj).dot(&k.t());
2631                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2632            }
2633            // (c_i, a_j): Z, K, Y_i  ×  Y_j, K, Z  → S = Y_iᵀ diag(v) Z = s_iz;
2634            //   core = K · s_iz · Kᵀ
2635            {
2636                let core = k.dot(&s_iz).dot(&k.t());
2637                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_rj);
2638            }
2639            // (c_i, b_j): Z, K, Y_i  ×  Z, K̇_j, Z  → S = Y_iᵀ diag(v) Z = s_iz;
2640            //   core = K · s_iz · K̇_jᵀ
2641            {
2642                let core = k.dot(&s_iz).dot(&dot_k_j.t());
2643                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2644            }
2645            // (c_i, c_j): Z, K, Y_i  ×  Z, K, Y_j  → S = Y_iᵀ diag(v) Y_j = s_ij;
2646            //   core = K · s_ij · Kᵀ
2647            {
2648                let core = k.dot(&s_ij).dot(&k.t());
2649                mdot_mdot = mdot_mdot + Self::rowwise_bilinear(x_r, &core, x_r);
2650            }
2651
2652            // ── 2 (M ⊙ M̈_{ij}) v ──
2653            // Each piece has the form Y_α C W_βᵀ; M = Z K Zᵀ with A=K.
2654            // Identity:  [(ZAZᵀ) ⊙ (Y_α C W_βᵀ) v]_i
2655            //          = rowwise_bilinear(Y_α, C · (W_βᵀ diag(v) Z) · A, Z).
2656            let mut m_mddot = Array1::<f64>::zeros(n);
2657            // (a) Y_α = X_{r,ij}, C = K, W_β = X_r  → W_βᵀ diag(v) Z = s_zz
2658            if x_tau_tau_is_some {
2659                let core = k.dot(&s_zz).dot(k);
2660                m_mddot = m_mddot + Self::rowwise_bilinear(x_rij, &core, x_r);
2661            }
2662            // (b) Y_α = X_r, C = K, W_β = X_{r,ij} → W_βᵀ diag(v) Z = X_{r,ij}ᵀ diag(v) Z
2663            if x_tau_tau_is_some {
2664                let s_ijz = RemlState::reduced_crossweighted_gram(x_rij, x_r, &v);
2665                let core = k.dot(&s_ijz).dot(k);
2666                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2667            }
2668            // (c) Y_α = X_{r,i}, C = K̇_j, W_β = X_r → S = s_zz
2669            {
2670                let core = dot_k_j.dot(&s_zz).dot(k);
2671                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2672            }
2673            // (d) Y_α = X_r, C = K̇_j, W_β = X_{r,i} → W_βᵀ diag(v) Z = s_iz
2674            {
2675                let core = dot_k_j.dot(&s_iz).dot(k);
2676                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2677            }
2678            // (e) Y_α = X_{r,j}, C = K̇_i, W_β = X_r → S = s_zz
2679            {
2680                let core = dot_k_i.dot(&s_zz).dot(k);
2681                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2682            }
2683            // (f) Y_α = X_r, C = K̇_i, W_β = X_{r,j} → W_βᵀ diag(v) Z = s_jz
2684            {
2685                let core = dot_k_i.dot(&s_jz).dot(k);
2686                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2687            }
2688            // (g) Y_α = X_{r,i}, C = K, W_β = X_{r,j} → W_βᵀ diag(v) Z = s_jz
2689            {
2690                let core = k.dot(&s_jz).dot(k);
2691                m_mddot = m_mddot + Self::rowwise_bilinear(x_ri, &core, x_r);
2692            }
2693            // (h) Y_α = X_{r,j}, C = K, W_β = X_{r,i} → W_βᵀ diag(v) Z = s_iz
2694            {
2695                let core = k.dot(&s_iz).dot(k);
2696                m_mddot = m_mddot + Self::rowwise_bilinear(x_rj, &core, x_r);
2697            }
2698            // (i) Y_α = X_r, C = K̈_ij, W_β = X_r → S = s_zz
2699            {
2700                let core = k_ddot.dot(&s_zz).dot(k);
2701                m_mddot = m_mddot + Self::rowwise_bilinear(x_r, &core, x_r);
2702            }
2703
2704            // P̈_{ij} = ∂²(M⊙M)/∂τ_i∂τ_j = 2(Ṁ_i ⊙ Ṁ_j) + 2(M ⊙ M̈_{ij}),
2705            // with Ṁ_τ = ∂M/∂τ (NOT the pair-squared derivative).  Factor is 2,
2706            // not 4 — the earlier "4·Ṁ_i⊙Ṁ_j" was a sign-of-the-derivative
2707            // confusion between Ṁ and ∂(M⊙M)/∂τ = 2(M⊙Ṁ).
2708            let col_out = 2.0 * mdot_mdot + 2.0 * m_mddot;
2709            out.column_mut(col).assign(&col_out);
2710        }
2711        out
2712    }
2713
2714    /// Primitive B — prepare step: assemble the reduced kernel for
2715    /// D_β((H_φ)_τ|_β)[v].
2716    ///
2717    /// Consumes the existing `FirthTauPartialKernel`, the τ-drift partials
2718    /// (`deta_partial = η̇_τ = X_τ β` and `dot_i_partial = İ_τ`), the
2719    /// β-direction `FirthDirection` built from `deta = X v`, and
2720    /// `x_tau_v = X_τ v`, and returns a cached kernel carrying the mixed
2721    /// β-τ reduced quantities A_v, dh_v, D_β(İ_τ)[v], D_β(K̇_τ)[v],
2722    /// D_β(ḣ_τ)[v], and the w-chain derivatives needed by
2723    /// `d_beta_hphi_tau_partial_apply`.
2724    pub(crate) fn d_beta_hphi_tau_partial_prepare_from_partials(
2725        &self,
2726        tau_kernel: &FirthTauPartialKernel,
2727        deta_partial: &Array1<f64>,
2728        dot_i_partial: &Array2<f64>,
2729        beta_direction: &FirthDirection,
2730        x_tau_v: &Array1<f64>,
2731    ) -> FirthTauBetaPartialKernel {
2732        // D_β(İ_τ)[v] — three-piece symmetric form from the product rule on
2733        //   İ_τ = X_{r,τ}ᵀ W X_r + X_rᵀ W X_{r,τ} + X_rᵀ diag(w' ⊙ η̇_τ) X_r,
2734        // where W = diag(w(η)) is the Fisher weight (not its derivative).
2735        // The β-differential hits w (through η=Xβ) and η̇_τ (through X_τ β):
2736        //   D_β(X_{r,τ}ᵀ W X_r)[v] = X_{r,τ}ᵀ diag(w' ⊙ δη_v) X_r,
2737        //   D_β(X_rᵀ diag(w' ⊙ η̇_τ) X_r)[v]
2738        //     = X_rᵀ diag(w'' ⊙ η̇_τ ⊙ δη_v + w' ⊙ δη_{τ,v}) X_r,
2739        // where δη_v = beta_direction.deta, δη_{τ,v} = x_tau_v.
2740        // s_v := w' ⊙ δη_v (same weight the FirthDirection uses to build
2741        // g_u_reduced); b_vvec := w'' ⊙ δη_v = beta_direction.b_uvec is the
2742        // weight for the third-term product-rule piece.
2743        let s_v = &self.w1 * &beta_direction.deta;
2744        let mixed_diag_weight = &(&tau_kernel.dotw1 * &beta_direction.deta) + &(&self.w1 * x_tau_v);
2745        let cross1 =
2746            RemlState::reduced_crossweighted_gram(&tau_kernel.x_tau_reduced, &self.x_reduced, &s_v);
2747        let cross2 =
2748            RemlState::reduced_crossweighted_gram(&self.x_reduced, &tau_kernel.x_tau_reduced, &s_v);
2749        let diag_piece = RemlState::reducedweighted_gram(&self.x_reduced, &mixed_diag_weight);
2750        let d_beta_dot_i = &cross1 + &cross2 + &diag_piece;
2751
2752        // D_β(K̇_τ)[v] — direct Leibniz on K̇_τ = -K_r İ_τ K_r with
2753        //   D_β K_r[v] = -K_r I'_v K_r = -beta_direction.a_u_reduced.
2754        // Expanding yields
2755        //   D_β K̇_τ[v] = +A_v İ_τ K_r − K_r D_β(İ_τ)[v] K_r + K_r İ_τ A_v,
2756        // where A_v := beta_direction.a_u_reduced = +K_r I'_v K_r.  The
2757        // FirthDirection carries a_u with the opposite sign convention to the
2758        // derivation block's "A_v"; we keep the direction convention and
2759        // compose signs correctly here.
2760        let term_a = beta_direction
2761            .a_u_reduced
2762            .dot(dot_i_partial)
2763            .dot(&self.k_reduced);
2764        let term_b = self.k_reduced.dot(&d_beta_dot_i).dot(&self.k_reduced);
2765        let term_c = self
2766            .k_reduced
2767            .dot(dot_i_partial)
2768            .dot(&beta_direction.a_u_reduced);
2769        let d_beta_dot_k = &term_a - &term_b + &term_c;
2770
2771        // D_β(ḣ_τ)[v] — β-differential of
2772        //   ḣ_τ = 2·diag(X_{r,τ} K_r X_rᵀ) + diag(X_r K̇_τ X_rᵀ):
2773        //   D_β ḣ_τ[v]
2774        //     = 2·diag(X_{r,τ} D_β K_r[v] X_rᵀ) + diag(X_r D_β K̇_τ[v] X_rᵀ)
2775        //     = -2·diag(X_{r,τ} A_v X_rᵀ) + diag(X_r (D_β K̇_τ[v]) X_rᵀ).
2776        let cross_diag = Self::rowwise_bilinear(
2777            &tau_kernel.x_tau_reduced,
2778            &beta_direction.a_u_reduced,
2779            &self.x_reduced,
2780        );
2781        let inner_diag = RemlState::reduced_diag_gram(&self.x_reduced, &d_beta_dot_k);
2782        let d_beta_dot_h = -2.0 * &cross_diag + &inner_diag;
2783
2784        FirthTauBetaPartialKernel {
2785            x_tau_reduced: tau_kernel.x_tau_reduced.clone(),
2786            deta_partial: deta_partial.clone(),
2787            dot_h_partial: tau_kernel.dot_h_partial.clone(),
2788            dot_i_partial: dot_i_partial.clone(),
2789            dot_k_reduced: tau_kernel.dot_k_reduced.clone(),
2790            deta_v: beta_direction.deta.clone(),
2791            deta_tau_v: x_tau_v.clone(),
2792            a_v_reduced: beta_direction.a_u_reduced.clone(),
2793            dh_v: beta_direction.dh.clone(),
2794            b_vvec: beta_direction.b_uvec.clone(),
2795            d_beta_dot_k,
2796            d_beta_dot_h,
2797        }
2798    }
2799
2800    /// Apply the mixed β-τ P-action `P_{τ,v} · mat` to an n×m column block.
2801    ///
2802    /// Expansion:
2803    ///   P_{τ,v} = 2 (M_v ⊙ M_τ) + 2 (M ⊙ M_{τ,v}),
2804    ///     M_v     = X_r K̇_v X_rᵀ,  K̇_v = -A_v (A_v = a_v_reduced),
2805    ///     M_τ     = X_{r,τ} K_r X_rᵀ + X_r K_r X_{r,τ}ᵀ + X_r K̇_τ X_rᵀ,
2806    ///     M_{τ,v} = X_{r,τ} K̇_v X_rᵀ + X_r K̇_v X_{r,τ}ᵀ + X_r D_β K̇_τ[v] X_rᵀ.
2807    /// Hadamard-Gram pieces are evaluated column-wise via
2808    ///   ((Z M_A Wᵀ) ⊙ (Y M_B Xᵀ)) v row-i
2809    ///       = z_iᵀ M_A (Wᵀ diag(v) X) M_Bᵀ y_i.
2810    pub(crate) fn apply_p_tau_v_to_matrix(
2811        &self,
2812        kernel: &FirthTauBetaPartialKernel,
2813        mat: &Array2<f64>,
2814    ) -> Array2<f64> {
2815        let n = self.x_dense.nrows();
2816        if mat.nrows() != n || mat.ncols() == 0 {
2817            return Array2::<f64>::zeros(mat.raw_dim());
2818        }
2819        let z = &self.x_reduced;
2820        let z_tau = &kernel.x_tau_reduced;
2821        let k_r = &self.k_reduced;
2822        let a_v = &kernel.a_v_reduced; // = +K_r I'_v K_r  (so K̇_v = -a_v)
2823        let dot_k_tau = &kernel.dot_k_reduced; // K̇_τ = -K_r İ_τ K_r
2824        let d_beta_dot_k = &kernel.d_beta_dot_k; // D_β K̇_τ[v]
2825        let mut out = Array2::<f64>::zeros(mat.raw_dim());
2826        for col in 0..mat.ncols() {
2827            let v = mat.column(col).to_owned();
2828            let s_zz = RemlState::reducedweighted_gram(z, &v);
2829            let s_z_ztau = RemlState::reduced_crossweighted_gram(z, z_tau, &v);
2830
2831            // Piece 1: (X_r K̇_v X_rᵀ ⊙ X_{r,τ} K_r X_rᵀ) · v
2832            //   = -rowwise_bilinear(Z, a_v S_zz K_r, Z_τ).
2833            let mid_1 = a_v.dot(&s_zz).dot(k_r);
2834            let t1 = -Self::rowwise_bilinear(z, &mid_1, z_tau);
2835            // Piece 2: (X_r K̇_v X_rᵀ ⊙ X_r K_r X_{r,τ}ᵀ) · v
2836            //   = -reduced_diag_gram(Z, a_v S_z_ztau K_r).
2837            let mid_2 = a_v.dot(&s_z_ztau).dot(k_r);
2838            let t2 = -RemlState::reduced_diag_gram(z, &mid_2);
2839            // Piece 3: (X_r K̇_v X_rᵀ ⊙ X_r K̇_τ X_rᵀ) · v
2840            //   = -reduced_diag_gram(Z, a_v S_zz K̇_τ).
2841            let mid_3 = a_v.dot(&s_zz).dot(dot_k_tau);
2842            let t3 = -RemlState::reduced_diag_gram(z, &mid_3);
2843            // Piece 4: (M ⊙ X_{r,τ} K̇_v X_rᵀ) · v
2844            //   = -rowwise_bilinear(Z, K_r S_zz a_v, Z_τ).
2845            let mid_4 = k_r.dot(&s_zz).dot(a_v);
2846            let t4 = -Self::rowwise_bilinear(z, &mid_4, z_tau);
2847            // Piece 5: (M ⊙ X_r K̇_v X_{r,τ}ᵀ) · v
2848            //   = -reduced_diag_gram(Z, K_r S_z_ztau a_v).
2849            let mid_5 = k_r.dot(&s_z_ztau).dot(a_v);
2850            let t5 = -RemlState::reduced_diag_gram(z, &mid_5);
2851            // Piece 6: (M ⊙ X_r D_β K̇_τ[v] X_rᵀ) · v.
2852            let t6 = RemlState::apply_hadamard_gram(z, k_r, d_beta_dot_k, &v);
2853
2854            // P_{τ,v} = 2·(pieces 1-3) + 2·(pieces 4-6); each group contributes
2855            // with the same outer factor 2.
2856            let y = 2.0 * (t1 + t2 + t3 + t4 + t5 + t6);
2857            out.column_mut(col).assign(&y);
2858        }
2859        out
2860    }
2861
    /// Primitive B (apply step): evaluate `D_β((H_φ)_τ|_β)[v] · rhs` for a
    /// `p×m` block `rhs` without materialising any `n×n` matrix.
    ///
    /// - `x_tau`: the τ-derivative of the dense design, X_τ. It is multiplied
    ///   against `rhs` and transposed against an `n`-row block, so it must be
    ///   `n×p` like `x_dense`. The shape is not validated here.
    /// - `kernel`: the output of `d_beta_hphi_tau_partial_prepare_from_partials`
    ///   for the same β-direction `v`.
    /// - `rhs`: the `p×m` block to apply the operator to.
    ///
    /// Returns a `p×m` block. If `rhs` does not have `p` rows, has no columns,
    /// or the design has zero columns, a zero block of shape
    /// `(p, rhs.ncols())` is returned instead of an error.
    pub(crate) fn d_beta_hphi_tau_partial_apply(
        &self,
        x_tau: &Array2<f64>,
        kernel: &FirthTauBetaPartialKernel,
        rhs: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.x_dense.ncols();
        // Shape mismatch: silently yield a zero block (no error is raised).
        if rhs.nrows() != p {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        // Degenerate sizes: nothing to compute.
        if rhs.ncols() == 0 || p == 0 {
            return Array2::<f64>::zeros((p, rhs.ncols()));
        }
        // Matrix-free block apply of D_β((H_φ)_τ|_β)[v] evaluated on a rhs V.
        // Structure follows hphi_tau_partial_apply but replaces every weight
        // and every reduced Gram with its β-derivative in direction v:
        //
        //   (H_φ)_τ|_β (V) = 0.5 [X_τᵀ r(V) + Xᵀ r_τ(V)].
        //
        // D_β[v] leaves X, X_τ fixed and acts on r, r_τ:
        //   D_β((H_φ)_τ|_β)[v](V) = 0.5 [X_τᵀ D_β r(V)[v] + Xᵀ D_β r_τ(V)[v]].
        let etav = fast_ab(&self.x_dense, rhs);
        let etav_tau = fast_ab(x_tau, rhs);
        let deta_v = &kernel.deta_v;
        let deta_tau_v = &kernel.deta_tau_v;
        let eta_tau = &kernel.deta_partial;
        let dot_h = &kernel.dot_h_partial;

        // Reuse τ-kernel weights.  dotw1 = w'' ⊙ η̇_τ, dotw2 = w''' ⊙ η̇_τ.
        let dotw1 = &self.w2 * eta_tau;
        let dotw2 = &self.w3 * eta_tau;

        // β-derivative scaling vectors in direction v:
        //   c_v              = D_β(w''·h)[v]    = w'''·δη_v·h + w''·dh_v
        //   b_vvec           = D_β(w')[v]       = w''·δη_v   (= kernel.b_vvec)
        //   d_beta_dotw1_vec = D_β(w''·η̇_τ)[v]  = w'''·δη_v·η̇_τ + w''·δη_{τ,v}
        //   d_beta_dotw2_vec = D_β(w'''·η̇_τ)[v] = w''''·δη_v·η̇_τ + w'''·δη_{τ,v}
        let c_v = &(&(&self.w3 * deta_v) * &self.h_diag) + &(&self.w2 * &kernel.dh_v);
        let b_vvec = &kernel.b_vvec;
        let d_beta_dotw1_vec = &(&(&self.w3 * deta_v) * eta_tau) + &(&self.w2 * deta_tau_v);
        let d_beta_dotw2_vec = &(&(&self.w4 * deta_v) * eta_tau) + &(&self.w3 * deta_tau_v);

        // Single-τ pieces (identical to hphi_tau_partial_apply).
        let qv = &etav * &self.w1.view().insert_axis(Axis(1));
        let qv_tau = &etav * &dotw1.view().insert_axis(Axis(1))
            + &etav_tau * &self.w1.view().insert_axis(Axis(1));
        let m_qv = self.apply_pbar_to_matrix(&qv);
        // apply_mtau_to_matrix only reads x_tau_reduced and dot_k_reduced off
        // the τ-kernel, but owning the full struct is cheap.
        let tau_kernel_view = FirthTauPartialKernel {
            deta_partial: eta_tau.clone(),
            dotw1: dotw1.clone(),
            dotw2: dotw2.clone(),
            dot_h_partial: dot_h.clone(),
            x_tau_reduced: kernel.x_tau_reduced.clone(),
            dot_i_partial: kernel.dot_i_partial.clone(),
            dot_k_reduced: kernel.dot_k_reduced.clone(),
        };
        // m_qv_tau = P_τ·qv + P·qv_tau  (τ-derivative of m_qv = P·qv).
        let m_qv_tau =
            self.apply_mtau_to_matrix(&tau_kernel_view, &qv) + self.apply_pbar_to_matrix(&qv_tau);

        // β-derivatives of the single-τ pieces:
        //   D_β qv     = etav · D_β w'[v]       = etav · b_vvec
        //   D_β qv_tau = etav · D_β dotw1[v] + etav_tau · D_β w'[v]
        let d_beta_qv = &etav * &b_vvec.view().insert_axis(Axis(1));
        let d_beta_qv_tau = &etav * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
            + &etav_tau * &b_vvec.view().insert_axis(Axis(1));

        //   D_β m_qv = P_v · qv + P · D_β qv
        let d_beta_m_qv = self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv)
            + self.apply_pbar_to_matrix(&d_beta_qv);

        //   D_β m_qv_tau = P_{τ,v}·qv + P_τ·D_β qv + P_v·qv_tau + P·D_β qv_tau
        let d_beta_m_qv_tau = self.apply_p_tau_v_to_matrix(kernel, &qv)
            + self.apply_mtau_to_matrix(&tau_kernel_view, &d_beta_qv)
            + self.apply_p_u_to_matrix(&kernel.a_v_reduced, &qv_tau)
            + self.apply_pbar_to_matrix(&d_beta_qv_tau);

        // D_β rv[v] where rv = etav·(w''·h) − w'·m_qv:
        //   D_β rv[v] = etav·c_v − b_vvec·m_qv − w'·D_β m_qv.
        let d_beta_rv = &etav * &c_v.view().insert_axis(Axis(1))
            - &m_qv * &b_vvec.view().insert_axis(Axis(1))
            - &d_beta_m_qv * &self.w1.view().insert_axis(Axis(1));

        // D_β rv_tau[v] where
        //   rv_tau = etav·dotw2·h + etav_tau·w''·h + etav·w''·dot_h
        //            − m_qv·dotw1 − m_qv_tau·w'.
        //
        //   D_β(dotw2·h)[v]   = (w''''·δη_v·η̇_τ + w'''·δη_{τ,v})·h
        //                        + dotw2·dh_v,
        //   D_β(w''·h)[v]     = c_v,
        //   D_β(w''·dot_h)[v] = w'''·δη_v·dot_h + w''·D_β dot_h[v],
        //   D_β dotw1[v]      = d_beta_dotw1_vec,
        //   D_β w'[v]         = b_vvec.
        let d_beta_dotw2_h = &(&d_beta_dotw2_vec * &self.h_diag) + &(&dotw2 * &kernel.dh_v);
        let d_beta_w2_doth = &(&(&self.w3 * deta_v) * dot_h) + &(&self.w2 * &kernel.d_beta_dot_h);

        // Product rule term-by-term on rv_tau (each subtracted m-term yields
        // two pieces: derivative of the m-block and derivative of its weight).
        let d_beta_rv_tau = &etav * &d_beta_dotw2_h.view().insert_axis(Axis(1))
            + &etav_tau * &c_v.view().insert_axis(Axis(1))
            + &etav * &d_beta_w2_doth.view().insert_axis(Axis(1))
            - &d_beta_m_qv * &dotw1.view().insert_axis(Axis(1))
            - &m_qv * &d_beta_dotw1_vec.view().insert_axis(Axis(1))
            - &d_beta_m_qv_tau * &self.w1.view().insert_axis(Axis(1))
            - &m_qv_tau * &b_vvec.view().insert_axis(Axis(1));

        // Outer contraction: 0.5 [X_τᵀ D_β r + Xᵀ D_β r_τ]  (p×m).
        0.5 * (x_tau.t().dot(&d_beta_rv) + self.x_dense.t().dot(&d_beta_rv_tau))
    }
2969}
2970
2971#[cfg(test)]
2972mod tests {
2973    use super::*;
2974    use crate::mixture_link::logit_inverse_link_jet5;
2975    use gam_problem::StandardLink;
2976    use ndarray::{Array1, Array2, array};
2977
2978    // Operator-equivalence oracle accessors (#1575). The production inner-PIRLS
2979    // path memoizes the β-independent design factor and reads diagnostics through
2980    // `pirls_diagnostics_from_factor`; these full-operator accessors are retained
2981    // ONLY for the equivalence unit tests, so they live in this `#[cfg(test)]`
2982    // module rather than gating individual production methods with `#[cfg(test)]`.
2983    impl FirthDenseOperator {
2984        pub(crate) fn pirls_hat_diag(&self) -> Array1<f64> {
2985            &self.w * &self.h_diag
2986        }
2987
2988        /// Per-observation Firth working-response shift `Δ_i = ½·(w'_i/w_i)·h_diag_i`
2989        /// (the link-general form; `w_i ≤ 0` rows get a zero shift). Matches the
2990        /// Jeffreys score `½ Σ_i w'_i h_i x_i` the outer REML differentiates.
2991        pub(crate) fn pirls_firth_score_shift(&self) -> Array1<f64> {
2992            let mut shift = Array1::<f64>::zeros(self.w.len());
2993            for i in 0..self.w.len() {
2994                let wi = self.w[i];
2995                if wi > 0.0 {
2996                    shift[i] = 0.5 * (self.w1[i] / wi) * self.h_diag[i];
2997                }
2998            }
2999            shift
3000        }
3001    }
3002
3003    pub(crate) fn build_logit_firth_dense_operator(
3004        x_dense: &Array2<f64>,
3005        eta: &Array1<f64>,
3006    ) -> Result<FirthDenseOperator, EstimationError> {
3007        FirthDenseOperator::build_with_observation_weights_impl(
3008            &InverseLink::Standard(StandardLink::Logit),
3009            x_dense,
3010            eta,
3011            None,
3012        )
3013    }
3014
3015    pub(crate) fn build_weighted_logit_firth_dense_operator(
3016        x_dense: &Array2<f64>,
3017        eta: &Array1<f64>,
3018        observation_weights: ndarray::ArrayView1<'_, f64>,
3019    ) -> Result<FirthDenseOperator, EstimationError> {
3020        FirthDenseOperator::build_with_observation_weights_impl(
3021            &InverseLink::Standard(StandardLink::Logit),
3022            x_dense,
3023            eta,
3024            Some(observation_weights),
3025        )
3026    }
3027
3028    pub(crate) fn logisticweight(eta: f64) -> f64 {
3029        logit_inverse_link_jet5(eta).d1
3030    }
3031
3032    pub(crate) fn firthphivalue(x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3033        let eta = x.dot(beta);
3034        let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3035        op.jeffreys_logdet()
3036    }
3037
3038    pub(crate) fn firthgradphi(x: &Array2<f64>, beta: &Array1<f64>) -> Array1<f64> {
3039        let eta = x.dot(beta);
3040        let op = build_logit_firth_dense_operator(x, &eta).expect("firth operator");
3041        op.jeffreys_beta_gradient()
3042    }
3043
3044    pub(crate) fn weighted_firthphivalue(
3045        x: &Array2<f64>,
3046        beta: &Array1<f64>,
3047        observation_weights: &Array1<f64>,
3048    ) -> f64 {
3049        let eta = x.dot(beta);
3050        let op = build_weighted_logit_firth_dense_operator(x, &eta, observation_weights.view())
3051            .expect("weighted firth operator");
3052        op.jeffreys_logdet()
3053    }
3054
3055    #[test]
3056    pub(crate) fn firth_reduced_fisher_logdet_is_finite_for_barely_pd_matrix() {
3057        let fisher = array![[16.0, 0.0], [0.0, 1e-15]];
3058        let (k_reduced, half_log_det) = RemlState::reduced_fisher_inverse_and_half_logdet(&fisher)
3059            .expect("barely positive-definite reduced fisher");
3060        let expected = 0.5 * 16.0_f64.ln();
3061
3062        assert!(
3063            half_log_det.is_finite(),
3064            "barely positive-definite reduced fisher produced non-finite half logdet: {half_log_det}"
3065        );
3066        assert!(
3067            (half_log_det - expected).abs() < 1e-12,
3068            "near-null Fisher direction should be excluded from pseudo-logdet: got {half_log_det}, expected {expected}"
3069        );
3070        assert!(
3071            k_reduced.iter().all(|value| value.is_finite()),
3072            "barely positive-definite reduced fisher produced non-finite inverse entries: {k_reduced:?}"
3073        );
3074        assert!(
3075            k_reduced[[1, 1]].abs() < f64::EPSILON,
3076            "near-null Fisher direction should be excluded from pseudo-inverse: {k_reduced:?}"
3077        );
3078    }
3079
3080    #[test]
3081    pub(crate) fn firth_logisticweight_derivatives_match_finite_difference() {
3082        // Validates op.w[i] (= jet.d1) and op.w1..w4[i] (= jet.d2..jet.d5)
3083        // against direct central finite differences of the logistic inverse
3084        // link pdf w(η) = μ(η)(1−μ(η)).
3085        //
3086        // Nested central differences amplify roundoff by 1/h per nesting
3087        // level, so a d1fd-of-d1fd-of-d2fd cannot deliver the tolerances
3088        // that 4th-order agreement requires. The principled replacement is
3089        // a direct higher-order stencil whose truncation and roundoff are
3090        // both controlled by a single step h:
3091        //
3092        //   d1  (2-pt):  (f(z+h) − f(z−h)) / (2h)                       O(h²) trunc
3093        //   d2  (3-pt):  (f(z+h) − 2f(z) + f(z−h)) / h²                 O(h²) trunc
3094        //   d3  (4-pt):  (−f(z−2h)+2f(z−h)−2f(z+h)+f(z+2h)) / (2h³)     O(h²) trunc
3095        //   d4  (5-pt):  (f(z−2h)−4f(z−h)+6f(z)−4f(z+h)+f(z+2h)) / h⁴   O(h²) trunc
3096        //
3097        // At h = 1e-2 the logistic pdf and its higher derivatives stay of
3098        // order ≤ 1, so truncation O(h²·M) ≲ 1e-4 and roundoff O(ε/h^n)
3099        // is well below any asserted tolerance through the 4th order.
3100        let x = array![
3101            [1.0, -1.1, 0.2],
3102            [1.0, -0.5, -0.6],
3103            [1.0, 0.0, 0.3],
3104            [1.0, 0.8, -0.4],
3105            [1.0, 1.2, 0.7],
3106        ];
3107        let beta = array![0.15, -0.6, 0.35];
3108        let eta = x.dot(&beta);
3109        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3110
3111        let h = 1e-2_f64;
3112        let w = |z: f64| logisticweight(z);
3113        let d1direct = |z: f64| (w(z + h) - w(z - h)) / (2.0 * h);
3114        let d2direct = |z: f64| (w(z + h) - 2.0 * w(z) + w(z - h)) / (h * h);
3115        let d3direct = |z: f64| {
3116            (-w(z - 2.0 * h) + 2.0 * w(z - h) - 2.0 * w(z + h) + w(z + 2.0 * h)) / (2.0 * h.powi(3))
3117        };
3118        let d4direct = |z: f64| {
3119            (w(z - 2.0 * h) - 4.0 * w(z - h) + 6.0 * w(z) - 4.0 * w(z + h) + w(z + 2.0 * h))
3120                / h.powi(4)
3121        };
3122        for i in 0..eta.len() {
3123            let z = eta[i];
3124            let wfd = w(z);
3125            let w1fd = d1direct(z);
3126            let w2fd = d2direct(z);
3127            let w3fd = d3direct(z);
3128            let w4fd = d4direct(z);
3129
3130            assert!((op.w[i] - wfd).abs() < 1e-12);
3131            assert_eq!(op.w1[i].signum(), w1fd.signum());
3132            assert_eq!(op.w2[i].signum(), w2fd.signum());
3133            assert_eq!(op.w3[i].signum(), w3fd.signum());
3134            assert_eq!(op.w4[i].signum(), w4fd.signum());
3135            assert!((op.w1[i] - w1fd).abs() < 1e-5);
3136            assert!((op.w2[i] - w2fd).abs() < 1e-4);
3137            assert!((op.w3[i] - w3fd).abs() < 1e-4);
3138            assert!((op.w4[i] - w4fd).abs() < 1e-3);
3139        }
3140    }
3141
3142    #[test]
3143    pub(crate) fn weighted_firth_jeffreys_gradient_matches_finite_difference() {
3144        let x = array![
3145            [1.0, -0.7, 0.3],
3146            [1.0, -0.2, -0.4],
3147            [1.0, 0.5, 0.1],
3148            [1.0, 1.1, -0.6],
3149            [1.0, 1.6, 0.8],
3150        ];
3151        let beta = array![0.2, -0.45, 0.25];
3152        let observation_weights = array![1.0, 0.5, 2.0, 1.5, 0.75];
3153        let eta = x.dot(&beta);
3154        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3155            .expect("weighted firth operator");
3156        let grad = op.jeffreys_beta_gradient();
3157        let h = 1e-6;
3158
3159        for j in 0..beta.len() {
3160            let mut beta_plus = beta.clone();
3161            beta_plus[j] += h;
3162            let mut beta_minus = beta.clone();
3163            beta_minus[j] -= h;
3164            let fd = (weighted_firthphivalue(&x, &beta_plus, &observation_weights)
3165                - weighted_firthphivalue(&x, &beta_minus, &observation_weights))
3166                / (2.0 * h);
3167            assert!(
3168                (grad[j] - fd).abs() < 1e-5,
3169                "weighted Firth gradient mismatch at {}: analytic={}, fd={}",
3170                j,
3171                grad[j],
3172                fd
3173            );
3174        }
3175    }
3176
3177    // ----------------------------------------------------------------------
3178    // Link-general (probit) finite-difference proof of the Jeffreys/Firth
3179    // Φ(β) = ½ log|I_r(β)|, its β-gradient ∂Φ/∂β, and the β-Hessian
3180    // derivative D H_φ[u] exposed via `hphi_direction`. Logit is used as a
3181    // regression guard against the historical logit-pinned build.
3182    // ----------------------------------------------------------------------
3183
3184    pub(crate) fn build_link_firth_op(
3185        link: StandardLink,
3186        x: &Array2<f64>,
3187        beta: &Array1<f64>,
3188    ) -> FirthDenseOperator {
3189        let eta = x.dot(beta);
3190        FirthDenseOperator::build_with_observation_weights_impl(
3191            &InverseLink::Standard(link),
3192            x,
3193            &eta,
3194            None,
3195        )
3196        .expect("link-general firth operator")
3197    }
3198
3199    pub(crate) fn link_firth_phi(link: StandardLink, x: &Array2<f64>, beta: &Array1<f64>) -> f64 {
3200        build_link_firth_op(link, x, beta).jeffreys_logdet()
3201    }
3202
3203    pub(crate) fn link_firth_grad(
3204        link: StandardLink,
3205        x: &Array2<f64>,
3206        beta: &Array1<f64>,
3207    ) -> Array1<f64> {
3208        build_link_firth_op(link, x, beta).jeffreys_beta_gradient()
3209    }
3210
3211    /// Central-difference Jacobian of the *analytic* Firth gradient, i.e. a
3212    /// numerical realization of the β-Hessian H_φ = ∂g/∂β. The Newton/REML
3213    /// path consumes H_φ (and its directional derivative) so this is the
3214    /// matrix the analytic curvature must reproduce.
3215    pub(crate) fn numeric_firth_hessian(
3216        link: StandardLink,
3217        x: &Array2<f64>,
3218        beta: &Array1<f64>,
3219        h: f64,
3220    ) -> Array2<f64> {
3221        let p = beta.len();
3222        let mut hess = Array2::<f64>::zeros((p, p));
3223        for j in 0..p {
3224            let mut bp = beta.clone();
3225            bp[j] += h;
3226            let mut bm = beta.clone();
3227            bm[j] -= h;
3228            let gp = link_firth_grad(link, x, &bp);
3229            let gm = link_firth_grad(link, x, &bm);
3230            let col = (&gp - &gm) / (2.0 * h);
3231            hess.column_mut(j).assign(&col);
3232        }
3233        hess
3234    }
3235
3236    /// #1575: the cached single-index second-direction path
3237    /// (`tk_second_direction_eye_cache` + `hphisecond_direction_apply_eye_cached`)
3238    /// must be BIT-IDENTICAL to the per-pair `hphisecond_direction_apply(.., &eye)`
3239    /// it replaces in the exact-Hessian TK outer loop. This locks the work-elision
3240    /// invariant: it removes redundant O(n·r²·p) reduced Hadamard-Gram applies, it
3241    /// must NOT change a single bit of the resulting Hessian contribution.
3242    #[test]
3243    fn hphisecond_eye_cached_matches_per_pair_bit_identical_1575() {
3244        // A 6×3 logit design with a few distinct η directions (mirrors the
3245        // multi-smooth penalty directions the TK loop contracts over).
3246        let x = array![
3247            [1.0, -1.10, 0.35],
3248            [1.0, -0.40, -0.65],
3249            [1.0, 0.15, 0.20],
3250            [1.0, 0.80, -0.45],
3251            [1.0, 1.25, 0.70],
3252            [1.0, -0.55, 0.95],
3253        ];
3254        let beta = array![0.20, -0.55, 0.30];
3255        let op = build_link_firth_op(StandardLink::Logit, &x, &beta);
3256        let p = x.ncols();
3257
3258        // Three β-direction δη vectors playing the role of eta_i[idx].
3259        let deta_list = [
3260            x.dot(&array![0.9, -0.3, 0.2]),
3261            x.dot(&array![-0.4, 0.7, 0.1]),
3262            x.dot(&array![0.1, 0.2, -0.8]),
3263        ];
3264        let dirs: Vec<FirthDirection> = deta_list
3265            .iter()
3266            .map(|d| op.direction_from_deta(d.clone()))
3267            .collect();
3268
3269        let eye = Array2::<f64>::eye(p);
3270        let cache = op.tk_second_direction_eye_cache(&dirs);
3271        for i in 0..dirs.len() {
3272            for j in 0..=i {
3273                let reference = op.hphisecond_direction_apply(&dirs[i], &dirs[j], &eye);
3274                let cached = op.hphisecond_direction_apply_eye_cached(&cache, &dirs, i, j);
3275                assert_eq!(
3276                    reference.dim(),
3277                    cached.dim(),
3278                    "shape mismatch at pair ({i},{j})"
3279                );
3280                for (a, b) in reference.iter().zip(cached.iter()) {
3281                    assert_eq!(
3282                        a.to_bits(),
3283                        b.to_bits(),
3284                        "cached D²H_φ[{i},{j}] is not bit-identical to per-pair: \
3285                         reference={a}, cached={b}"
3286                    );
3287                }
3288            }
3289        }
3290    }
3291
    /// A fixed, well-conditioned, full-column-rank 5×3 design: an intercept
    /// column of ones plus two covariates (deterministic, no RNG).
    ///
    /// Shared by the link-general (logit/probit) Firth tests so that they all
    /// probe the same well-posed geometry.
    pub(crate) fn fixed_design_5x3() -> Array2<f64> {
        array![
            [1.0, -1.10, 0.35],
            [1.0, -0.40, -0.65],
            [1.0, 0.15, 0.20],
            [1.0, 0.80, -0.45],
            [1.0, 1.25, 0.70],
        ]
    }
3302
3303    #[test]
3304    pub(crate) fn link_general_logit_path_reproduces_historical_logit_build() {
3305        // Guard: the StandardLink::Logit path through the link-general builder
3306        // must be byte-identical to the historical logit-pinned operator for
3307        // Φ, the β-gradient, the PIRLS hat diagonal, and the cached weight
3308        // jets w, w'..w''''.
3309        let x = fixed_design_5x3();
3310        let beta = array![0.20, -0.55, 0.30];
3311        let eta = x.dot(&beta);
3312
3313        let historical = build_logit_firth_dense_operator(&x, &eta).expect("historical logit");
3314        let link_general = FirthDenseOperator::build_with_observation_weights_impl(
3315            &InverseLink::Standard(StandardLink::Logit),
3316            &x,
3317            &eta,
3318            None,
3319        )
3320        .expect("link-general logit");
3321
3322        assert_eq!(
3323            historical.jeffreys_logdet(),
3324            link_general.jeffreys_logdet(),
3325            "logit Φ must be bit-identical through the link-general path"
3326        );
3327        let g_hist = historical.jeffreys_beta_gradient();
3328        let g_link = link_general.jeffreys_beta_gradient();
3329        for j in 0..g_hist.len() {
3330            assert_eq!(
3331                g_hist[j], g_link[j],
3332                "logit gradient component {j} must be bit-identical"
3333            );
3334        }
3335        let hat_hist = historical.pirls_hat_diag();
3336        let hat_link = link_general.pirls_hat_diag();
3337        for i in 0..hat_hist.len() {
3338            assert_eq!(
3339                hat_hist[i], hat_link[i],
3340                "logit PIRLS hat diagonal {i} must be bit-identical"
3341            );
3342        }
3343        for i in 0..eta.len() {
3344            assert_eq!(historical.w[i], link_general.w[i]);
3345            assert_eq!(historical.w1[i], link_general.w1[i]);
3346            assert_eq!(historical.w2[i], link_general.w2[i]);
3347            assert_eq!(historical.w3[i], link_general.w3[i]);
3348            assert_eq!(historical.w4[i], link_general.w4[i]);
3349        }
3350    }
3351
3352    #[test]
3353    pub(crate) fn link_general_probit_jeffreys_gradient_matches_finite_difference() {
3354        // PROBIT correctness: ∂Φ/∂β from `jeffreys_beta_gradient` must match a
3355        // central finite difference of Φ(β) on a well-conditioned design.
3356        let x = fixed_design_5x3();
3357        let beta = array![0.10, -0.40, 0.25];
3358        let grad = link_firth_grad(StandardLink::Probit, &x, &beta);
3359        let h = 1e-6_f64;
3360        let mut max_rel = 0.0_f64;
3361        for j in 0..beta.len() {
3362            let mut bp = beta.clone();
3363            bp[j] += h;
3364            let mut bm = beta.clone();
3365            bm[j] -= h;
3366            let fd = (link_firth_phi(StandardLink::Probit, &x, &bp)
3367                - link_firth_phi(StandardLink::Probit, &x, &bm))
3368                / (2.0 * h);
3369            let denom = grad[j].abs().max(fd.abs()).max(1e-8);
3370            let rel = (grad[j] - fd).abs() / denom;
3371            max_rel = max_rel.max(rel);
3372            assert!(
3373                rel < 1e-6,
3374                "probit Firth gradient mismatch at {j}: analytic={}, fd={}, rel={:e}",
3375                grad[j],
3376                fd,
3377                rel
3378            );
3379        }
3380        assert!(
3381            max_rel < 1e-6,
3382            "probit gradient worst relative error {max_rel:e} exceeds 1e-6"
3383        );
3384    }
3385
3386    #[test]
3387    pub(crate) fn link_general_probit_hphi_direction_matches_finite_difference_of_hessian() {
3388        // PROBIT Hessian: `hphi_direction(direction_from_deta(X·u))` is the
3389        // analytic directional derivative D H_φ[u] of the β-Hessian. Verify it
3390        // against the central finite difference of the (numerically realized)
3391        // β-Hessian H_φ along u. The numeric H_φ at each shifted β is itself a
3392        // finite difference of the *analytic* gradient, so the base operand is
3393        // analytic at first order; only the directional step is differenced
3394        // here.
3395        let x = fixed_design_5x3();
3396        let beta = array![0.10, -0.40, 0.25];
3397        let p = beta.len();
3398
3399        // Probe several directions, including non-axis-aligned ones.
3400        let directions = [
3401            array![1.0, 0.0, 0.0],
3402            array![0.0, 1.0, 0.0],
3403            array![0.0, 0.0, 1.0],
3404            array![0.7, -0.5, 0.3],
3405        ];
3406
3407        let h_inner = 1e-4_f64; // step for the numeric Hessian (FD of analytic grad)
3408        let h_dir = 1e-4_f64; // step for the directional derivative of the Hessian
3409        let mut worst = 0.0_f64;
3410        for u in directions.iter() {
3411            let op = build_link_firth_op(StandardLink::Probit, &x, &beta);
3412            let deta = x.dot(u);
3413            let dir = op.direction_from_deta(deta);
3414            let analytic = op.hphi_direction(&dir);
3415
3416            let beta_plus = &beta + &(u * h_dir);
3417            let beta_minus = &beta - &(u * h_dir);
3418            let hess_plus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_plus, h_inner);
3419            let hess_minus = numeric_firth_hessian(StandardLink::Probit, &x, &beta_minus, h_inner);
3420            let fd = (&hess_plus - &hess_minus) / (2.0 * h_dir);
3421
3422            let mut scale = 1e-6_f64;
3423            for r in 0..p {
3424                for c in 0..p {
3425                    scale = scale.max(analytic[[r, c]].abs()).max(fd[[r, c]].abs());
3426                }
3427            }
3428            for r in 0..p {
3429                for c in 0..p {
3430                    let rel = (analytic[[r, c]] - fd[[r, c]]).abs() / scale;
3431                    worst = worst.max(rel);
3432                    assert!(
3433                        rel < 5e-3,
3434                        "probit D H_φ[u] mismatch at ({r},{c}) for u={u:?}: analytic={}, fd={}, rel={:e}",
3435                        analytic[[r, c]],
3436                        fd[[r, c]],
3437                        rel
3438                    );
3439                }
3440            }
3441        }
3442        assert!(
3443            worst < 5e-3,
3444            "probit Hessian-derivative worst relative error {worst:e} exceeds 5e-3"
3445        );
3446    }
3447
3448    #[test]
3449    pub(crate) fn link_general_probit_jeffreys_finite_on_rank_deficient_design() {
3450        // Identifiable-subspace behavior: a rank-deficient design (column 3 =
3451        // column 1 + column 2) must yield a finite Φ = ½ log|Uᵀ W U|, a finite
3452        // gradient, and agree with the explicit reduced two-column design.
3453        let x_full = array![
3454            [1.0, -1.20, -0.20],
3455            [1.0, -0.40, 0.60],
3456            [1.0, 0.10, 1.10],
3457            [1.0, 0.70, 1.70],
3458            [1.0, 1.30, 2.30],
3459        ];
3460        let x_reduced = array![
3461            [1.0, -1.20],
3462            [1.0, -0.40],
3463            [1.0, 0.10],
3464            [1.0, 0.70],
3465            [1.0, 1.30],
3466        ];
3467        let beta_full = array![0.25, -0.50, 0.15];
3468        let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3469
3470        let phi_full = link_firth_phi(StandardLink::Probit, &x_full, &beta_full);
3471        let phi_reduced = link_firth_phi(StandardLink::Probit, &x_reduced, &beta_reduced);
3472        assert!(
3473            phi_full.is_finite(),
3474            "probit Φ on rank-deficient design must be finite, got {phi_full}"
3475        );
3476        assert!(
3477            (phi_full - phi_reduced).abs() < 1e-12,
3478            "probit reduced |Uᵀ W U| form mismatch: full={phi_full}, reduced={phi_reduced}"
3479        );
3480
3481        let op_full = build_link_firth_op(StandardLink::Probit, &x_full, &beta_full);
3482        let grad_full = op_full.jeffreys_beta_gradient();
3483        assert!(
3484            grad_full.iter().all(|v| v.is_finite()),
3485            "probit gradient on rank-deficient design must be finite: {grad_full:?}"
3486        );
3487        let hat_full = op_full.pirls_hat_diag();
3488        let hat_reduced =
3489            build_link_firth_op(StandardLink::Probit, &x_reduced, &beta_reduced).pirls_hat_diag();
3490        for i in 0..hat_full.len() {
3491            assert!(
3492                (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3493                "probit hat diagonal {i} mismatch on rank-deficient design: full={}, reduced={}",
3494                hat_full[i],
3495                hat_reduced[i]
3496            );
3497        }
3498    }
3499
3500    #[test]
3501    pub(crate) fn rank_deficient_and_explicit_reduced_designs_share_same_jeffreys_objective() {
3502        // Column 3 is exactly column 1 + column 2, so the original design is
3503        // rank-deficient but its identifiable subspace is represented exactly by
3504        // the explicit two-column reduced design below.
3505        let x_full = array![
3506            [1.0, -1.2, -0.2],
3507            [1.0, -0.4, 0.6],
3508            [1.0, 0.1, 1.1],
3509            [1.0, 0.7, 1.7],
3510            [1.0, 1.3, 2.3],
3511        ];
3512        let x_reduced = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3513        let beta_full: ndarray::Array1<f64> = array![0.25, -0.5, 0.15];
3514        let beta_reduced = array![beta_full[0] + beta_full[2], beta_full[1] + beta_full[2]];
3515        let eta_full = x_full.dot(&beta_full);
3516        let eta_reduced = x_reduced.dot(&beta_reduced);
3517        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3518
3519        for i in 0..eta_full.len() {
3520            assert!(
3521                (eta_full[i] - eta_reduced[i]).abs() < 1e-12,
3522                "eta mismatch at row {i}: full={} reduced={}",
3523                eta_full[i],
3524                eta_reduced[i]
3525            );
3526        }
3527
3528        let op_full = build_weighted_logit_firth_dense_operator(
3529            &x_full,
3530            &eta_full,
3531            observation_weights.view(),
3532        )
3533        .expect("full firth operator");
3534        let op_reduced = build_weighted_logit_firth_dense_operator(
3535            &x_reduced,
3536            &eta_reduced,
3537            observation_weights.view(),
3538        )
3539        .expect("reduced firth operator");
3540
3541        assert!(
3542            (op_full.jeffreys_logdet() - op_reduced.jeffreys_logdet()).abs() < 1e-12,
3543            "Jeffreys logdet mismatch between rank-deficient full design and its explicit reduced identifiable basis: full={} reduced={}",
3544            op_full.jeffreys_logdet(),
3545            op_reduced.jeffreys_logdet()
3546        );
3547
3548        let hat_full = op_full.pirls_hat_diag();
3549        let hat_reduced = op_reduced.pirls_hat_diag();
3550        for i in 0..hat_full.len() {
3551            assert!(
3552                (hat_full[i] - hat_reduced[i]).abs() < 1e-12,
3553                "PIRLS hat-diagonal mismatch at row {i}: full={} reduced={}",
3554                hat_full[i],
3555                hat_reduced[i]
3556            );
3557        }
3558    }
3559
3560    #[test]
3561    pub(crate) fn full_rank_reparameterizations_share_same_jeffreys_objective() {
3562        let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3563        let basis = array![[1.4, -0.3], [0.6, 1.1]];
3564        let x_reparameterized = x.dot(&basis);
3565        let beta = array![0.25, -0.5];
3566        let basis_det: f64 = basis[[0, 0]] * basis[[1, 1]] - basis[[0, 1]] * basis[[1, 0]];
3567        assert!(
3568            basis_det.abs() > 1e-12,
3569            "basis transform must be invertible"
3570        );
3571        let basis_inv = array![
3572            [basis[[1, 1]] / basis_det, -basis[[0, 1]] / basis_det],
3573            [-basis[[1, 0]] / basis_det, basis[[0, 0]] / basis_det],
3574        ];
3575        let beta_reparameterized = basis_inv.dot(&beta);
3576        let eta = x.dot(&beta);
3577        let eta_reparameterized = x_reparameterized.dot(&beta_reparameterized);
3578        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3579
3580        for i in 0..eta.len() {
3581            assert!(
3582                (eta[i] - eta_reparameterized[i]).abs() < 1e-12,
3583                "eta mismatch at row {i}: original={} reparameterized={}",
3584                eta[i],
3585                eta_reparameterized[i]
3586            );
3587        }
3588
3589        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3590            .expect("original firth operator");
3591        let op_reparameterized = build_weighted_logit_firth_dense_operator(
3592            &x_reparameterized,
3593            &eta_reparameterized,
3594            observation_weights.view(),
3595        )
3596        .expect("reparameterized firth operator");
3597
3598        assert!(
3599            (op.jeffreys_logdet() - op_reparameterized.jeffreys_logdet()).abs() < 1e-12,
3600            "Jeffreys logdet mismatch under invertible reparameterization: original={} reparameterized={}",
3601            op.jeffreys_logdet(),
3602            op_reparameterized.jeffreys_logdet()
3603        );
3604
3605        let hat = op.pirls_hat_diag();
3606        let hat_reparameterized = op_reparameterized.pirls_hat_diag();
3607        for i in 0..hat.len() {
3608            assert!(
3609                (hat[i] - hat_reparameterized[i]).abs() < 1e-12,
3610                "PIRLS hat-diagonal mismatch at row {i}: original={} reparameterized={}",
3611                hat[i],
3612                hat_reparameterized[i]
3613            );
3614        }
3615    }
3616
3617    #[test]
3618    pub(crate) fn full_rank_identifiable_basis_diagonalizes_design_metric() {
3619        let x = array![[1.0, -1.2], [1.0, -0.4], [1.0, 0.1], [1.0, 0.7], [1.0, 1.3],];
3620        let beta = array![0.25, -0.5];
3621        let eta = x.dot(&beta);
3622        let observation_weights = array![1.0, 0.5, 1.75, 0.9, 1.2];
3623        let op = build_weighted_logit_firth_dense_operator(&x, &eta, observation_weights.view())
3624            .expect("firth operator");
3625
3626        let reduced_metric = fast_atb(&op.x_reduced, &op.x_reduced);
3627        for i in 0..reduced_metric.nrows() {
3628            for j in 0..reduced_metric.ncols() {
3629                if i == j {
3630                    continue;
3631                }
3632                assert!(
3633                    reduced_metric[[i, j]].abs() < 1e-10,
3634                    "full-rank identifiable basis should diagonalize X_r'X_r: metric[{i},{j}]={}",
3635                    reduced_metric[[i, j]]
3636                );
3637            }
3638        }
3639    }
3640
3641    #[test]
3642    pub(crate) fn firth_mixedsecond_direction_apply_is_symmetric_in_direction_order() {
3643        let x = array![
3644            [1.0, -1.0, 0.2],
3645            [1.0, -0.6, -0.3],
3646            [1.0, -0.1, 0.5],
3647            [1.0, 0.3, -0.7],
3648            [1.0, 0.8, 0.1],
3649            [1.0, 1.2, -0.4],
3650        ];
3651        let beta = array![0.1, -0.25, 0.2];
3652        let eta = x.dot(&beta);
3653        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3654
3655        let u = array![0.3, -0.2, 0.4];
3656        let v = array![-0.5, 0.1, 0.25];
3657        let du = op.direction_from_deta(x.dot(&u));
3658        let dv = op.direction_from_deta(x.dot(&v));
3659
3660        let eye = Array2::<f64>::eye(x.ncols());
3661        let uv = op.hphisecond_direction_apply(&du, &dv, &eye);
3662        let vu = op.hphisecond_direction_apply(&dv, &du, &eye);
3663
3664        for i in 0..uv.nrows() {
3665            for j in 0..uv.ncols() {
3666                let a = uv[[i, j]];
3667                let b = vu[[i, j]];
3668                assert_eq!(
3669                    a.signum(),
3670                    b.signum(),
3671                    "mixed direction sign mismatch at ({i},{j}): uv={a} vu={b}"
3672                );
3673                assert!(
3674                    (a - b).abs() < 2e-7,
3675                    "mixed direction mismatch at ({i},{j}): uv={a} vu={b}"
3676                );
3677            }
3678        }
3679    }
3680
3681    #[test]
3682    pub(crate) fn firth_direction_matrix_form_matches_apply_identity_form() {
3683        let x = array![
3684            [1.0, -1.1, 0.2],
3685            [1.0, -0.6, -0.3],
3686            [1.0, -0.1, 0.5],
3687            [1.0, 0.3, -0.7],
3688            [1.0, 0.8, 0.1],
3689            [1.0, 1.2, -0.4],
3690        ];
3691        let beta = array![0.08, -0.22, 0.27];
3692        let eta = x.dot(&beta);
3693        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3694        let u = Array1::from_vec(vec![0.25, -0.4, 0.35]);
3695        let dir = op.direction_from_deta(x.dot(&u));
3696
3697        let p = x.ncols();
3698        let eye = Array2::<f64>::eye(p);
3699        let mut via_apply = op.hphi_direction_apply(&dir, &eye);
3700        for i in 0..p {
3701            for j in 0..i {
3702                let sym = 0.5 * (via_apply[[i, j]] + via_apply[[j, i]]);
3703                via_apply[[i, j]] = sym;
3704                via_apply[[j, i]] = sym;
3705            }
3706        }
3707        let direct = op.hphi_direction(&dir);
3708        let diff = &direct - &via_apply;
3709        let err = diff.iter().map(|v| v * v).sum::<f64>().sqrt();
3710        assert!(err < 1e-10, "direction/apply mismatch: {err:e}");
3711    }
3712
3713    #[test]
3714    pub(crate) fn firthphi_tau_partial_matches_finite_difference_logdet() {
3715        let x = array![
3716            [1.0, -1.0, 0.2],
3717            [1.0, -0.6, -0.3],
3718            [1.0, -0.1, 0.5],
3719            [1.0, 0.3, -0.7],
3720            [1.0, 0.8, 0.1],
3721            [1.0, 1.2, -0.4],
3722        ];
3723        let x_tau = array![
3724            [0.0, 0.15, -0.05],
3725            [0.0, -0.10, 0.02],
3726            [0.0, 0.08, 0.04],
3727            [0.0, -0.06, -0.03],
3728            [0.0, 0.05, 0.01],
3729            [0.0, -0.12, 0.06],
3730        ];
3731        let beta = array![0.1, -0.25, 0.2];
3732        let eta = x.dot(&beta);
3733        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3734        let analytic = op.exact_tau_kernel(&x_tau, &beta, false).phi_tau_partial;
3735
3736        let h = 1e-6;
3737        let x_plus = &x + &(h * &x_tau);
3738        let x_minus = &x - &(h * &x_tau);
3739        let fd = (firthphivalue(&x_plus, &beta) - firthphivalue(&x_minus, &beta)) / (2.0 * h);
3740
3741        assert!(
3742            (analytic - fd).abs() < 1e-6,
3743            "Phi_tau mismatch: analytic={analytic:.12e}, fd={fd:.12e}"
3744        );
3745    }
3746
3747    #[test]
3748    pub(crate) fn firth_gphi_tau_matches_finite_differencegradphi() {
3749        let x = array![
3750            [1.0, -1.0, 0.2],
3751            [1.0, -0.6, -0.3],
3752            [1.0, -0.1, 0.5],
3753            [1.0, 0.3, -0.7],
3754            [1.0, 0.8, 0.1],
3755            [1.0, 1.2, -0.4],
3756        ];
3757        let x_tau = array![
3758            [0.0, 0.15, -0.05],
3759            [0.0, -0.10, 0.02],
3760            [0.0, 0.08, 0.04],
3761            [0.0, -0.06, -0.03],
3762            [0.0, 0.05, 0.01],
3763            [0.0, -0.12, 0.06],
3764        ];
3765        let beta = array![0.1, -0.25, 0.2];
3766        let eta = x.dot(&beta);
3767        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3768        let analytic = op.exact_tau_kernel(&x_tau, &beta, false).gphi_tau;
3769
3770        let h = 1e-6;
3771        let x_plus = &x + &(h * &x_tau);
3772        let x_minus = &x - &(h * &x_tau);
3773        let fd = (firthgradphi(&x_plus, &beta) - firthgradphi(&x_minus, &beta)) / (2.0 * h);
3774
3775        let err = (&analytic - &fd).iter().map(|v| v * v).sum::<f64>().sqrt();
3776        assert!(
3777            err < 1e-6,
3778            "gphi_tau mismatch: analytic={analytic:?}, fd={fd:?}, err={err:e}"
3779        );
3780    }
3781
3782    /// Verify pair.a scalar (`phi_tau_tau_partial`) by central-FD'ing the
3783    /// single-τ scalar `phi_tau_partial` along τ_j at fixed β.
3784    /// Identity: ∂/∂τ_j [Φ_{τ_i}|β] = Φ_{τ_iτ_j}|β.
3785    /// Tolerance 1e-7 relative.
3786    #[test]
3787    pub(crate) fn firthphi_tau_tau_pair_scalar_matches_finite_difference() {
3788        let x = array![
3789            [1.0, -1.0, 0.2],
3790            [1.0, -0.6, -0.3],
3791            [1.0, -0.1, 0.5],
3792            [1.0, 0.3, -0.7],
3793            [1.0, 0.8, 0.1],
3794            [1.0, 1.2, -0.4],
3795        ];
3796        let x_tau_i = array![
3797            [0.0, 0.15, -0.05],
3798            [0.0, -0.10, 0.02],
3799            [0.0, 0.08, 0.04],
3800            [0.0, -0.06, -0.03],
3801            [0.0, 0.05, 0.01],
3802            [0.0, -0.12, 0.06],
3803        ];
3804        let x_tau_j = array![
3805            [0.0, -0.04, 0.11],
3806            [0.0, 0.09, -0.02],
3807            [0.0, -0.06, 0.07],
3808            [0.0, 0.10, -0.05],
3809            [0.0, -0.03, 0.08],
3810            [0.0, 0.07, -0.09],
3811        ];
3812        let beta = array![0.1, -0.25, 0.2];
3813        let eta = x.dot(&beta);
3814        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3815
3816        let analytic = op
3817            .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3818            .phi_tau_tau_partial;
3819
3820        let h = 1e-5_f64;
3821        let eval_phi_tau_i = |x_eval: &Array2<f64>| -> f64 {
3822            let eta_e = x_eval.dot(&beta);
3823            let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3824            op_e.exact_tau_kernel(&x_tau_i, &beta, false)
3825                .phi_tau_partial
3826        };
3827        let x_plus = &x + &(h * &x_tau_j);
3828        let x_minus = &x - &(h * &x_tau_j);
3829        let fd = (eval_phi_tau_i(&x_plus) - eval_phi_tau_i(&x_minus)) / (2.0 * h);
3830
3831        let rel = (analytic - fd).abs() / fd.abs().max(1.0);
3832        assert!(
3833            rel < 1e-7,
3834            "pair.a scalar mismatch: analytic={analytic:.6e}, fd={fd:.6e}, rel={rel:.3e}"
3835        );
3836    }
3837
3838    /// Verify pair.g p-vector (`gphi_tau_tau`) by central-FD'ing the single-τ
3839    /// `gphi_tau` along τ_j at fixed β.
3840    /// Identity: ∂/∂τ_j [(gΦ)_{τ_i}|β] = (gΦ)_{τ_iτ_j}|β.
3841    /// Tolerance 1e-7 relative max-abs.
3842    #[test]
3843    pub(crate) fn firthphi_tau_tau_pair_g_vector_matches_finite_difference() {
3844        let x = array![
3845            [1.0, -1.0, 0.2],
3846            [1.0, -0.6, -0.3],
3847            [1.0, -0.1, 0.5],
3848            [1.0, 0.3, -0.7],
3849            [1.0, 0.8, 0.1],
3850            [1.0, 1.2, -0.4],
3851        ];
3852        let x_tau_i = array![
3853            [0.0, 0.15, -0.05],
3854            [0.0, -0.10, 0.02],
3855            [0.0, 0.08, 0.04],
3856            [0.0, -0.06, -0.03],
3857            [0.0, 0.05, 0.01],
3858            [0.0, -0.12, 0.06],
3859        ];
3860        let x_tau_j = array![
3861            [0.0, -0.04, 0.11],
3862            [0.0, 0.09, -0.02],
3863            [0.0, -0.06, 0.07],
3864            [0.0, 0.10, -0.05],
3865            [0.0, -0.03, 0.08],
3866            [0.0, 0.07, -0.09],
3867        ];
3868        let beta = array![0.1, -0.25, 0.2];
3869        let eta = x.dot(&beta);
3870        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3871
3872        let analytic = op
3873            .exact_tau_tau_kernel(&x_tau_i, &x_tau_j, None, &beta, false)
3874            .gphi_tau_tau;
3875
3876        let h = 1e-5_f64;
3877        let eval_gphi_tau_i = |x_eval: &Array2<f64>| -> Array1<f64> {
3878            let eta_e = x_eval.dot(&beta);
3879            let op_e = build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed op");
3880            op_e.exact_tau_kernel(&x_tau_i, &beta, false).gphi_tau
3881        };
3882        let x_plus = &x + &(h * &x_tau_j);
3883        let x_minus = &x - &(h * &x_tau_j);
3884        let fd = (&eval_gphi_tau_i(&x_plus) - &eval_gphi_tau_i(&x_minus)) / (2.0 * h);
3885
3886        let scale = analytic
3887            .iter()
3888            .chain(fd.iter())
3889            .map(|v| v.abs())
3890            .fold(0.0_f64, f64::max)
3891            .max(1.0);
3892        let err_max = (&analytic - &fd)
3893            .iter()
3894            .map(|v| v.abs())
3895            .fold(0.0_f64, f64::max);
3896        let rel = err_max / scale;
3897        assert!(
3898            rel < 1e-7,
3899            "pair.g p-vector mismatch: rel={rel:.3e}\nanalytic={analytic:?}\nfd={fd:?}"
3900        );
3901    }
3902
3903    /// Verify the Primitive A body (`hphi_tau_tau_partial_apply`) against a
3904    /// finite-difference reference of the single-τ Primitive (
3905    /// `hphi_tau_partial_apply`).
3906    ///
3907    /// Identity under test:
3908    ///     ∂/∂τ_j  { (H_φ)_τ_i |_β · V }   =   ∂²H_φ/∂τ_i ∂τ_j |_β · V.
3909    ///
3910    /// Central-difference reference:
3911    ///   1. Evaluate the single-τ primitive at x, and at x ± h·X_τ_j
3912    ///      — rebuild the FirthDenseOperator (with fresh identifiable Q)
3913    ///      at each perturbed design; H_φ applied to a p-space rhs is
3914    ///      basis-invariant in unreduced β-coords, so Q rotation does not
3915    ///      contaminate the comparison.
3916    ///   2. FD_{i,j} = (T_{plus} − T_{minus}) / (2h) with T = hphi_tau_i_apply(V).
3917    ///   3. Contract both (i,j) and (j,i) directions and verify symmetry
3918    ///      of the analytic as a cross-check.
3919    ///
3920    /// Tolerance: 1e-7 relative max-abs (h chosen to balance truncation
3921    /// error at ~h² and evaluator roundoff at ~ε/h).
3922    #[test]
3923    pub(crate) fn firthphi_tau_tau_partial_matches_finite_difference() {
3924        let x = array![
3925            [1.0, -1.0, 0.2],
3926            [1.0, -0.6, -0.3],
3927            [1.0, -0.1, 0.5],
3928            [1.0, 0.3, -0.7],
3929            [1.0, 0.8, 0.1],
3930            [1.0, 1.2, -0.4],
3931        ];
3932        let x_tau_i = array![
3933            [0.0, 0.15, -0.05],
3934            [0.0, -0.10, 0.02],
3935            [0.0, 0.08, 0.04],
3936            [0.0, -0.06, -0.03],
3937            [0.0, 0.05, 0.01],
3938            [0.0, -0.12, 0.06],
3939        ];
3940        let x_tau_j = array![
3941            [0.0, -0.04, 0.11],
3942            [0.0, 0.09, -0.02],
3943            [0.0, -0.06, 0.07],
3944            [0.0, 0.10, -0.05],
3945            [0.0, -0.03, 0.08],
3946            [0.0, 0.07, -0.09],
3947        ];
3948        let beta = array![0.1, -0.25, 0.2];
3949        let eta = x.dot(&beta);
3950        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
3951        let p = x.ncols();
3952
3953        // Reproducible small rhs block (p × m).
3954        let m = 3usize;
3955        let mut rhs = Array2::<f64>::zeros((p, m));
3956        let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
3957        for r in 0..p {
3958            for c in 0..m {
3959                rhs[[r, c]] = vals[(r * m + c) % vals.len()];
3960            }
3961        }
3962
3963        // ── Analytic τ×τ pair apply at base design (x_tau_tau = None,
3964        //    deta_ij = None, i.e. design is linear in τ).
3965        let x_tau_i_reduced = op.reduce_explicit_design(&x_tau_i);
3966        let x_tau_j_reduced = op.reduce_explicit_design(&x_tau_j);
3967        let deta_i = x_tau_i.dot(&beta);
3968        let deta_j = x_tau_j.dot(&beta);
3969        let (dot_i_i, dot_h_i) = op.dot_i_and_h_from_reduced(&x_tau_i_reduced, &deta_i);
3970        let (dot_i_j, dot_h_j) = op.dot_i_and_h_from_reduced(&x_tau_j_reduced, &deta_j);
3971
3972        let kernel_ij = op.hphi_tau_tau_partial_prepare_from_partials(
3973            x_tau_i_reduced.clone(),
3974            x_tau_j_reduced.clone(),
3975            &deta_i,
3976            &deta_j,
3977            dot_h_i.clone(),
3978            dot_h_j.clone(),
3979            dot_i_i.clone(),
3980            dot_i_j.clone(),
3981            None,
3982            None,
3983        );
3984        let kernel_ji = op.hphi_tau_tau_partial_prepare_from_partials(
3985            x_tau_j_reduced,
3986            x_tau_i_reduced,
3987            &deta_j,
3988            &deta_i,
3989            dot_h_j,
3990            dot_h_i,
3991            dot_i_j,
3992            dot_i_i,
3993            None,
3994            None,
3995        );
3996        let analytic_ij = op.hphi_tau_tau_partial_apply(&x_tau_i, &x_tau_j, &kernel_ij, &rhs);
3997        let analytic_ji = op.hphi_tau_tau_partial_apply(&x_tau_j, &x_tau_i, &kernel_ji, &rhs);
3998
3999        // Symmetry cross-check (Clairaut): ∂²H/∂τ_i∂τ_j = ∂²H/∂τ_j∂τ_i.
4000        let sym_diff: f64 = (&analytic_ij - &analytic_ji)
4001            .iter()
4002            .map(|v| v.abs())
4003            .fold(0.0_f64, f64::max);
4004        let sym_scale: f64 = analytic_ij
4005            .iter()
4006            .chain(analytic_ji.iter())
4007            .map(|v| v.abs())
4008            .fold(0.0_f64, f64::max)
4009            .max(1.0);
4010        assert!(
4011            sym_diff / sym_scale < 1e-10,
4012            "τ×τ primitive not symmetric in direction order: sym_diff={sym_diff:.3e}"
4013        );
4014
4015        // ── FD reference: central difference of single-τ primitive in
4016        //    τ_j direction, evaluated along τ_i.
4017        let h = 1e-5_f64;
4018        let fd_block = |x_eval: &Array2<f64>| -> Array2<f64> {
4019            let eta_e = x_eval.dot(&beta);
4020            let op_e =
4021                build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
4022            let x_tau_i_r = op_e.reduce_explicit_design(&x_tau_i);
4023            let deta_i_e = x_tau_i.dot(&beta);
4024            let (dot_i_i_e, dot_h_i_e) = op_e.dot_i_and_h_from_reduced(&x_tau_i_r, &deta_i_e);
4025            let kernel_i_e = op_e
4026                .hphi_tau_partial_prepare_from_partials(x_tau_i_r, &deta_i_e, dot_h_i_e, dot_i_i_e);
4027            op_e.hphi_tau_partial_apply(&x_tau_i, &kernel_i_e, &rhs)
4028        };
4029        let x_plus = &x + &(h * &x_tau_j);
4030        let x_minus = &x - &(h * &x_tau_j);
4031        let fd_ij = (&fd_block(&x_plus) - &fd_block(&x_minus)) / (2.0 * h);
4032
4033        // ── Compare analytic_ij (contracted against V along τ_j→analytic's
4034        //    second index) to fd_ij (FD of T_i in τ_j direction).
4035        let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
4036            let scale = a
4037                .iter()
4038                .chain(b.iter())
4039                .map(|v| v.abs())
4040                .fold(0.0_f64, f64::max)
4041                .max(1.0);
4042            let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4043            max_diff / scale
4044        };
4045        let err_ij = rel_max_abs_diff(&analytic_ij, &fd_ij);
4046
4047        // Also FD the other direction and compare to analytic_ji, to
4048        // double-cover the primitive.
4049        let fd_block_j = |x_eval: &Array2<f64>| -> Array2<f64> {
4050            let eta_e = x_eval.dot(&beta);
4051            let op_e =
4052                build_logit_firth_dense_operator(x_eval, &eta_e).expect("perturbed firth operator");
4053            let x_tau_j_r = op_e.reduce_explicit_design(&x_tau_j);
4054            let deta_j_e = x_tau_j.dot(&beta);
4055            let (dot_i_j_e, dot_h_j_e) = op_e.dot_i_and_h_from_reduced(&x_tau_j_r, &deta_j_e);
4056            let kernel_j_e = op_e
4057                .hphi_tau_partial_prepare_from_partials(x_tau_j_r, &deta_j_e, dot_h_j_e, dot_i_j_e);
4058            op_e.hphi_tau_partial_apply(&x_tau_j, &kernel_j_e, &rhs)
4059        };
4060        let x_plus_i = &x + &(h * &x_tau_i);
4061        let x_minus_i = &x - &(h * &x_tau_i);
4062        let fd_ji = (&fd_block_j(&x_plus_i) - &fd_block_j(&x_minus_i)) / (2.0 * h);
4063        let err_ji = rel_max_abs_diff(&analytic_ji, &fd_ji);
4064
4065        let tol = 1e-7_f64;
4066        assert!(
4067            err_ij < tol,
4068            "∂²H_φ/∂τ_i∂τ_j apply mismatch (i,j): rel_max_abs_diff={err_ij:.3e} > {tol:.1e}\n\
4069             analytic=\n{analytic_ij:?}\n\
4070             fd=\n{fd_ij:?}"
4071        );
4072        assert!(
4073            err_ji < tol,
4074            "∂²H_φ/∂τ_j∂τ_i apply mismatch (j,i): rel_max_abs_diff={err_ji:.3e} > {tol:.1e}\n\
4075             analytic=\n{analytic_ji:?}\n\
4076             fd=\n{fd_ji:?}"
4077        );
4078    }
4079
4080    /// Verify the Primitive B body (`d_beta_hphi_tau_partial_apply`) against a
4081    /// finite-difference reference of the single-τ Primitive
4082    /// (`hphi_tau_partial_apply`).
4083    ///
4084    /// Identity under test (β held in the unreduced ambient; the design X is
4085    /// fixed so only w, η̇_τ = X_τ β, and their β-derivatives move):
4086    ///     D_β [ (H_φ)_τ|_β (β) · V ] [v]
4087    ///       = d_beta_hphi_tau_partial_apply(v, V).
4088    ///
4089    /// Central-difference reference:
4090    ///   1. Evaluate T(t) := hphi_tau_partial_apply(V) at β_t = β + t v,
4091    ///      rebuilding FirthDenseOperator at each β (so η = X β_t and the
4092    ///      w-chain are re-derived cleanly).  X is unchanged; Q is rebuilt
4093    ///      but H_φ applied to a p-space rhs is basis-invariant.
4094    ///   2. FD = (T(+h) − T(−h)) / (2h).
4095    ///   3. Tolerance 1e-7 relative max-abs (h chosen to balance truncation
4096    ///      error at ~h² and evaluator roundoff at ~ε/h).
4097    #[test]
4098    pub(crate) fn firth_d_beta_hphi_tau_partial_matches_finite_difference() {
4099        let x = array![
4100            [1.0, -1.0, 0.2],
4101            [1.0, -0.6, -0.3],
4102            [1.0, -0.1, 0.5],
4103            [1.0, 0.3, -0.7],
4104            [1.0, 0.8, 0.1],
4105            [1.0, 1.2, -0.4],
4106        ];
4107        let x_tau = array![
4108            [0.0, 0.15, -0.05],
4109            [0.0, -0.10, 0.02],
4110            [0.0, 0.08, 0.04],
4111            [0.0, -0.06, -0.03],
4112            [0.0, 0.05, 0.01],
4113            [0.0, -0.12, 0.06],
4114        ];
4115        let beta = array![0.1, -0.25, 0.2];
4116        // β-direction v for the D_β[·][v] test.
4117        let v = array![0.3, 0.2, -0.15];
4118
4119        let eta = x.dot(&beta);
4120        let op = build_logit_firth_dense_operator(&x, &eta).expect("firth operator");
4121        let p = x.ncols();
4122
4123        // Reproducible small rhs block (p × m).
4124        let m = 3usize;
4125        let mut rhs = Array2::<f64>::zeros((p, m));
4126        let vals = [0.21, -0.44, 0.17, 0.38, 0.05, -0.22, -0.11, 0.27, 0.31];
4127        for r in 0..p {
4128            for c in 0..m {
4129                rhs[[r, c]] = vals[(r * m + c) % vals.len()];
4130            }
4131        }
4132
4133        // ── Analytic apply at (x, β).
4134        let x_tau_reduced = op.reduce_explicit_design(&x_tau);
4135        let deta_partial = x_tau.dot(&beta);
4136        let (dot_i_partial, dot_h_partial) =
4137            op.dot_i_and_h_from_reduced(&x_tau_reduced, &deta_partial);
4138        let tau_kernel = op.hphi_tau_partial_prepare_from_partials(
4139            x_tau_reduced.clone(),
4140            &deta_partial,
4141            dot_h_partial.clone(),
4142            dot_i_partial.clone(),
4143        );
4144
4145        let deta_v = x.dot(&v);
4146        let direction = op.direction_from_deta(deta_v);
4147        let x_tau_v = x_tau.dot(&v);
4148        let pair_kernel = op.d_beta_hphi_tau_partial_prepare_from_partials(
4149            &tau_kernel,
4150            &deta_partial,
4151            &dot_i_partial,
4152            &direction,
4153            &x_tau_v,
4154        );
4155        let analytic = op.d_beta_hphi_tau_partial_apply(&x_tau, &pair_kernel, &rhs);
4156
4157        // ── FD reference: central difference of single-τ primitive under
4158        //    β → β ± h v.  X stays fixed; η, w, η̇_τ are re-derived.
4159        let h = 1e-5_f64;
4160        let single_tau_apply = |beta_eval: &Array1<f64>| -> Array2<f64> {
4161            let eta_e = x.dot(beta_eval);
4162            let op_e =
4163                build_logit_firth_dense_operator(&x, &eta_e).expect("perturbed firth operator");
4164            let x_tau_r = op_e.reduce_explicit_design(&x_tau);
4165            let deta_e = x_tau.dot(beta_eval);
4166            let (dot_i_e, dot_h_e) = op_e.dot_i_and_h_from_reduced(&x_tau_r, &deta_e);
4167            let ker_e =
4168                op_e.hphi_tau_partial_prepare_from_partials(x_tau_r, &deta_e, dot_h_e, dot_i_e);
4169            op_e.hphi_tau_partial_apply(&x_tau, &ker_e, &rhs)
4170        };
4171        let beta_plus = &beta + &(h * &v);
4172        let beta_minus = &beta - &(h * &v);
4173        let fd = (&single_tau_apply(&beta_plus) - &single_tau_apply(&beta_minus)) / (2.0 * h);
4174
4175        let rel_max_abs_diff = |a: &Array2<f64>, b: &Array2<f64>| -> f64 {
4176            let scale = a
4177                .iter()
4178                .chain(b.iter())
4179                .map(|v| v.abs())
4180                .fold(0.0_f64, f64::max)
4181                .max(1.0);
4182            let max_diff = (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
4183            max_diff / scale
4184        };
4185        let err = rel_max_abs_diff(&analytic, &fd);
4186
4187        let tol = 1e-7_f64;
4188        assert!(
4189            err < tol,
4190            "D_β (H_φ)_τ|_β apply mismatch: rel_max_abs_diff={err:.3e} > {tol:.1e}\n\
4191             analytic=\n{analytic:?}\n\
4192             fd=\n{fd:?}"
4193        );
4194    }
4195
4196    #[test]
4197    pub(crate) fn logisticweight_loses_positive_tail_mass() {
4198        let eta = 50.0_f64;
4199        let z = (-eta).exp();
4200        let stable = z / (1.0_f64 + z).powi(2);
4201        assert!(stable > 0.0);
4202        let got = logisticweight(eta);
4203        assert!(
4204            (got - stable).abs() < 1e-30,
4205            "Firth logisticweight should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
4206            got,
4207            stable
4208        );
4209    }
4210
4211    #[test]
4212    pub(crate) fn fisher_weight_jet5_logit_is_byte_identical_to_inverse_link_jet() {
4213        // The generalized Firth weight jet for the canonical logit link must
4214        // reproduce the historical `logit_inverse_link_jet5().d1..d5` path
4215        // exactly so the released logit Firth fits stay numerically unchanged.
4216        for &eta in &[
4217            -40.0, -8.0, -3.0, -1.0, -0.25, 0.0, 0.25, 1.0, 3.0, 8.0, 40.0,
4218        ] {
4219            let jet = logit_inverse_link_jet5(eta);
4220            let (w, w1, w2, w3, w4) =
4221                crate::mixture_link::fisher_weight_jet5(StandardLink::Logit, eta);
4222            assert!(
4223                w == jet.d1 && w1 == jet.d2 && w2 == jet.d3 && w3 == jet.d4 && w4 == jet.d5,
4224                "logit Fisher-weight jet must equal inverse-link jet derivatives at eta={eta}: \
4225                 got ({w}, {w1}, {w2}, {w3}, {w4}) vs ({}, {}, {}, {}, {})",
4226                jet.d1,
4227                jet.d2,
4228                jet.d3,
4229                jet.d4,
4230                jet.d5
4231            );
4232        }
4233    }
4234
4235    #[test]
4236    pub(crate) fn fisher_weight_jet5_probit_matches_finite_difference() {
4237        // Probit Bernoulli Fisher weight W(eta) = phi^2 / (Phi (1 - Phi)).
4238        // Validate the closed-form jet against central finite differences of
4239        // the reference scalar weight.
4240        fn reference_probit_weight(eta: f64) -> f64 {
4241            let p = gam_math::probability::normal_cdf(eta);
4242            let q = 1.0 - p;
4243            let phi = gam_math::probability::normal_pdf(eta);
4244            if p <= 0.0 || q <= 0.0 {
4245                return 0.0;
4246            }
4247            phi * phi / (p * q)
4248        }
4249        let h = 1e-4_f64;
4250        for &eta in &[-3.0, -1.5, -0.5, 0.0, 0.3, 1.5, 3.0] {
4251            let (w, w1, w2, _w3, _w4) =
4252                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
4253            let ref_w = reference_probit_weight(eta);
4254            let fd1 =
4255                (reference_probit_weight(eta + h) - reference_probit_weight(eta - h)) / (2.0 * h);
4256            let fd2 = (reference_probit_weight(eta + h) - 2.0 * reference_probit_weight(eta)
4257                + reference_probit_weight(eta - h))
4258                / (h * h);
4259            assert!(
4260                (w - ref_w).abs() < 1e-10,
4261                "probit W mismatch at eta={eta}: jet {w} vs ref {ref_w}"
4262            );
4263            assert!(
4264                (w1 - fd1).abs() < 1e-5,
4265                "probit W' mismatch at eta={eta}: jet {w1} vs fd {fd1}"
4266            );
4267            assert!(
4268                (w2 - fd2).abs() < 1e-3,
4269                "probit W'' mismatch at eta={eta}: jet {w2} vs fd {fd2}"
4270            );
4271        }
4272    }
4273
4274    #[test]
4275    pub(crate) fn fisher_weight_jet5_probit_saturates_to_zero_in_tails() {
4276        // Past the point where the denominator Phi(1-Phi) underflows to zero,
4277        // the weight and all derivatives are exactly zero (the saturated-tail
4278        // convention shared with the inverse-link jet).
4279        for &eta in &[40.0_f64, -40.0, 80.0, -80.0] {
4280            let (w, w1, w2, w3, w4) =
4281                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
4282            assert!(
4283                w == 0.0 && w1 == 0.0 && w2 == 0.0 && w3 == 0.0 && w4 == 0.0,
4284                "probit Fisher weight jet must saturate to zero at eta={eta}; got \
4285                 ({w}, {w1}, {w2}, {w3}, {w4})"
4286            );
4287        }
4288        // In the moderate tail the denominator is still representable (the
4289        // complement is taken as Phi(-eta), not the cancellation-prone
4290        // `1 - Phi(eta)`), so the weight is a tiny strictly-positive finite
4291        // number with finite derivatives. It must NOT prematurely round to zero.
4292        for &eta in &[12.0_f64, -12.0] {
4293            let (w, w1, w2, w3, w4) =
4294                crate::mixture_link::fisher_weight_jet5(StandardLink::Probit, eta);
4295            assert!(
4296                w > 0.0
4297                    && w.is_finite()
4298                    && w1.is_finite()
4299                    && w2.is_finite()
4300                    && w3.is_finite()
4301                    && w4.is_finite(),
4302                "probit Fisher weight jet must be tiny-positive and finite at eta={eta}; got \
4303                 ({w}, {w1}, {w2}, {w3}, {w4})"
4304            );
4305        }
4306    }
4307}