// gam_solve/reml/reml_outer_engine/hyper_operator.rs

1use super::*;
2
3pub(crate) fn as_implicit(op: &dyn HyperOperator) -> Option<&ImplicitHyperOperator> {
4    op.as_any().downcast_ref::<ImplicitHyperOperator>()
5}
6
7pub(crate) fn as_composite(op: &dyn HyperOperator) -> Option<&CompositeHyperOperator> {
8    op.as_any().downcast_ref::<CompositeHyperOperator>()
9}
10
11pub(crate) fn as_weighted(op: &dyn HyperOperator) -> Option<&WeightedHyperOperator> {
12    op.as_any().downcast_ref::<WeightedHyperOperator>()
13}
14
15pub(crate) trait DriftDerivTraceExt {
16    fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64;
17
18    fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64;
19}
20
21impl DriftDerivTraceExt for DriftDerivResult {
22    fn trace_logdet(&self, hop: &dyn HessianOperator) -> f64 {
23        match self {
24            Self::Dense(matrix) => hop.trace_logdet_gradient(matrix),
25            Self::Operator(operator) => hop.trace_logdet_operator(operator.as_ref()),
26        }
27    }
28
29    fn trace_logdet_hessian_cross(&self, rhs: &Self, hop: &dyn HessianOperator) -> f64 {
30        match (self, rhs) {
31            (Self::Dense(left), Self::Dense(right)) => hop.trace_logdet_hessian_cross(left, right),
32            (Self::Dense(left), Self::Operator(right)) => {
33                hop.trace_logdet_hessian_cross_matrix_operator(left, right.as_ref())
34            }
35            (Self::Operator(left), Self::Dense(right)) => {
36                hop.trace_logdet_hessian_cross_matrix_operator(right, left.as_ref())
37            }
38            (Self::Operator(left), Self::Operator(right)) => {
39                hop.trace_logdet_hessian_cross_operator(left.as_ref(), right.as_ref())
40            }
41        }
42    }
43}
44
45#[derive(Clone)]
46pub struct CompositeHyperOperator {
47    pub dense: Option<Array2<f64>>,
48    pub operators: Vec<Arc<dyn HyperOperator>>,
49    pub dim_hint: usize,
50}
51
52/// Group composite operators by shared `(implicit_deriv, x_design, w_diag)`
53/// so every Duchon ψ-axis built atop the same implicit derivative runs
54/// through a single row-kernel sweep via
55/// `trace_projected_factor_all_axes_with_xf`. Per-axis `s_psi` and
56/// `c_x_psi_beta` are threaded in individually so the batched path matches
57/// the per-axis path exactly. Non-implicit operators and singleton groups
58/// fall through to the original per-op trace path.
59pub(crate) fn composite_trace_implicit_batched(
60    operators: &[Arc<dyn HyperOperator>],
61    factor: &Array2<f64>,
62    cache: Option<&ProjectedFactorCache>,
63) -> f64 {
64    let mut trace = 0.0;
65    let mut group_starts: Vec<Vec<usize>> = Vec::new();
66    let mut handled = vec![false; operators.len()];
67
68    for (i, op) in operators.iter().enumerate() {
69        if handled[i] {
70            continue;
71        }
72        let Some(impl_i) = as_implicit(op.as_ref()) else {
73            continue;
74        };
75        let mut group = vec![i];
76        handled[i] = true;
77        for j in (i + 1)..operators.len() {
78            if handled[j] {
79                continue;
80            }
81            if let Some(impl_j) = as_implicit(operators[j].as_ref())
82                && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
83                && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
84                && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
85                && impl_i.p == impl_j.p
86            {
87                group.push(j);
88                handled[j] = true;
89            }
90        }
91        group_starts.push(group);
92    }
93
94    for group in &group_starts {
95        if group.len() >= 2 {
96            let lead = as_implicit(operators[group[0]].as_ref()).unwrap();
97            let xf = match cache {
98                Some(c) => lead.cached_xf(factor, c),
99                None => Arc::new(lead.compute_xf(factor)),
100            };
101            let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
102                .iter()
103                .map(|&k| {
104                    let op = as_implicit(operators[k].as_ref()).unwrap();
105                    (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
106                })
107                .collect();
108            let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
109            trace += values.iter().sum::<f64>();
110        } else {
111            let op = &operators[group[0]];
112            trace += match cache {
113                Some(c) => op.trace_projected_factor_cached(factor, c),
114                None => op.trace_projected_factor(factor),
115            };
116        }
117    }
118
119    for (i, op) in operators.iter().enumerate() {
120        if handled[i] {
121            continue;
122        }
123        trace += match cache {
124            Some(c) => op.trace_projected_factor_cached(factor, c),
125            None => op.trace_projected_factor(factor),
126        };
127    }
128
129    trace
130}
131
132/// Vector form of the implicit-axis trace batching used by
133/// [`CompositeHyperOperator`].  It returns one exact `tr(Fᵀ B_i F)` value per
134/// input operator while sharing the expensive `X·F` projection and Duchon
135/// row-kernel sweeps across sibling implicit ψ/ρ axes.
136pub(crate) fn trace_projected_factors_batched(
137    operators: &[Arc<dyn HyperOperator>],
138    factor: &Array2<f64>,
139    cache: &ProjectedFactorCache,
140) -> Vec<f64> {
141    let mut out = vec![0.0; operators.len()];
142    let mut handled = vec![false; operators.len()];
143
144    for i in 0..operators.len() {
145        if handled[i] {
146            continue;
147        }
148        let Some(impl_i) = as_implicit(operators[i].as_ref()) else {
149            out[i] = operators[i].trace_projected_factor_cached(factor, cache);
150            handled[i] = true;
151            continue;
152        };
153
154        let mut group = vec![i];
155        handled[i] = true;
156        for j in (i + 1)..operators.len() {
157            if handled[j] {
158                continue;
159            }
160            if let Some(impl_j) = as_implicit(operators[j].as_ref())
161                && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
162                && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
163                && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
164                && impl_i.p == impl_j.p
165            {
166                group.push(j);
167                handled[j] = true;
168            }
169        }
170
171        if group.len() >= 2 {
172            let xf = impl_i.cached_xf(factor, cache);
173            let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
174                .iter()
175                .map(|&idx| {
176                    let op = as_implicit(operators[idx].as_ref()).unwrap();
177                    (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
178                })
179                .collect();
180            let values = impl_i.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
181            for (&idx, value) in group.iter().zip(values) {
182                out[idx] = value;
183            }
184        } else {
185            out[i] = operators[i].trace_projected_factor_cached(factor, cache);
186        }
187    }
188
189    out
190}
191
192pub(crate) fn collect_projected_trace_terms<'a>(
193    out_idx: usize,
194    weight: f64,
195    op: &'a dyn HyperOperator,
196    factor: &Array2<f64>,
197    dense_acc: &mut [f64],
198    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
199) {
200    if weight == 0.0 {
201        return;
202    }
203    if let Some(composite) = as_composite(op) {
204        if let Some(dense) = composite.dense.as_ref() {
205            dense_acc[out_idx] += weight * dense_trace_projected_factor(dense, factor);
206        }
207        for inner in &composite.operators {
208            collect_projected_trace_terms(
209                out_idx,
210                weight,
211                inner.as_ref(),
212                factor,
213                dense_acc,
214                terms,
215            );
216        }
217    } else if let Some(weighted) = as_weighted(op) {
218        for (term_weight, inner) in &weighted.terms {
219            collect_projected_trace_terms(
220                out_idx,
221                weight * *term_weight,
222                inner.as_ref(),
223                factor,
224                dense_acc,
225                terms,
226            );
227        }
228    } else {
229        terms.push((out_idx, weight, op));
230    }
231}
232
233pub(crate) fn collect_projected_matrix_terms<'a>(
234    out_idx: usize,
235    weight: f64,
236    op: &'a dyn HyperOperator,
237    factor: &Array2<f64>,
238    dense_acc: &mut [Array2<f64>],
239    terms: &mut Vec<(usize, f64, &'a dyn HyperOperator)>,
240) {
241    if weight == 0.0 {
242        return;
243    }
244    if let Some(composite) = as_composite(op) {
245        if let Some(dense) = composite.dense.as_ref() {
246            dense_acc[out_idx].scaled_add(weight, &dense_projected_matrix(dense, factor));
247        }
248        for inner in &composite.operators {
249            collect_projected_matrix_terms(
250                out_idx,
251                weight,
252                inner.as_ref(),
253                factor,
254                dense_acc,
255                terms,
256            );
257        }
258    } else if let Some(weighted) = as_weighted(op) {
259        for (term_weight, inner) in &weighted.terms {
260            collect_projected_matrix_terms(
261                out_idx,
262                weight * *term_weight,
263                inner.as_ref(),
264                factor,
265                dense_acc,
266                terms,
267            );
268        }
269    } else {
270        terms.push((out_idx, weight, op));
271    }
272}
273
274pub(crate) fn trace_projected_operator_terms_batched(
275    n_out: usize,
276    terms: &[(usize, f64, &dyn HyperOperator)],
277    factor: &Array2<f64>,
278    cache: &ProjectedFactorCache,
279) -> Vec<f64> {
280    let mut out = vec![0.0_f64; n_out];
281    let mut handled = vec![false; terms.len()];
282
283    for i in 0..terms.len() {
284        if handled[i] {
285            continue;
286        }
287        let Some(impl_i) = as_implicit(terms[i].2) else {
288            continue;
289        };
290        let mut group = vec![i];
291        handled[i] = true;
292        for j in (i + 1)..terms.len() {
293            if handled[j] {
294                continue;
295            }
296            if let Some(impl_j) = as_implicit(terms[j].2)
297                && Arc::ptr_eq(&impl_i.implicit_deriv, &impl_j.implicit_deriv)
298                && Arc::ptr_eq(&impl_i.x_design, &impl_j.x_design)
299                && Arc::ptr_eq(impl_i.w_diag.as_arc(), impl_j.w_diag.as_arc())
300                && impl_i.p == impl_j.p
301            {
302                group.push(j);
303                handled[j] = true;
304            }
305        }
306
307        let lead = as_implicit(terms[group[0]].2).unwrap();
308        let xf = lead.cached_xf(factor, cache);
309        let axes: Vec<(usize, &Array2<f64>, Option<&Array1<f64>>)> = group
310            .iter()
311            .map(|&term_idx| {
312                let op = as_implicit(terms[term_idx].2).unwrap();
313                (op.axis, &op.s_psi, op.c_x_psi_beta.as_deref())
314            })
315            .collect();
316        let values = lead.trace_projected_factor_all_axes_with_xf(factor, xf.view(), &axes);
317        for (&term_idx, value) in group.iter().zip(values.iter()) {
318            let (out_idx, weight, _) = terms[term_idx];
319            out[out_idx] += weight * *value;
320        }
321    }
322
323    for (i, (out_idx, weight, op)) in terms.iter().enumerate() {
324        if handled[i] {
325            continue;
326        }
327        out[*out_idx] += *weight * op.trace_projected_factor_cached(factor, cache);
328    }
329
330    out
331}
332
333pub(crate) fn projected_operator_terms_batched(
334    n_out: usize,
335    terms: &[(usize, f64, &dyn HyperOperator)],
336    factor: &Array2<f64>,
337    cache: &ProjectedFactorCache,
338) -> Vec<Array2<f64>> {
339    let rank = factor.ncols();
340    let mut out: Vec<Array2<f64>> = (0..n_out)
341        .map(|_| Array2::<f64>::zeros((rank, rank)))
342        .collect();
343    for (out_idx, weight, op) in terms.iter() {
344        let projected = op.projected_matrix_cached(factor, cache);
345        out[*out_idx].scaled_add(*weight, &projected);
346    }
347    out
348}
349
350pub(crate) fn project_hyper_operators_batched(
351    n_out: usize,
352    terms: &[(usize, f64, &dyn HyperOperator)],
353    factor: &Array2<f64>,
354    cache: &ProjectedFactorCache,
355) -> Vec<Array2<f64>> {
356    projected_operator_terms_batched(n_out, terms, factor, cache)
357}
358
359pub(crate) fn trace_logdet_drifts_projected_factor_batched(
360    drifts: &[DriftDerivResult],
361    factor: &Array2<f64>,
362    cache: &ProjectedFactorCache,
363) -> Vec<f64> {
364    let mut out = vec![0.0_f64; drifts.len()];
365    let mut terms: Vec<(usize, f64, &dyn HyperOperator)> = Vec::new();
366    for (idx, drift) in drifts.iter().enumerate() {
367        match drift {
368            DriftDerivResult::Dense(matrix) => {
369                out[idx] += dense_trace_projected_factor(matrix, factor);
370            }
371            DriftDerivResult::Operator(op) => {
372                collect_projected_trace_terms(idx, 1.0, op.as_ref(), factor, &mut out, &mut terms);
373            }
374        }
375    }
376    let batched = trace_projected_operator_terms_batched(drifts.len(), &terms, factor, cache);
377    for (dst, value) in out.iter_mut().zip(batched) {
378        *dst += value;
379    }
380    out
381}
382
383pub(crate) fn dense_spectral_trace_logdet_drifts_batched(
384    ds: &DenseSpectralOperator,
385    drifts: &[DriftDerivResult],
386) -> Vec<f64> {
387    trace_logdet_drifts_projected_factor_batched(drifts, &ds.g_factor, &ds.projected_factor_cache)
388}
389
390pub(crate) fn penalty_subspace_trace_factor(kernel: &PenaltySubspaceTrace) -> Array2<f64> {
391    let (evals, evecs) = kernel
392        .h_proj_inverse
393        .eigh(faer::Side::Lower)
394        .expect("PenaltySubspaceTrace kernel factor eigendecomposition failed");
395    let r = evals.len();
396    // F must satisfy F·Fᵀ = K exactly: the batched `tr(FᵀAF)` is consumed as
397    // the gradient of the SAME pseudo-logdet criterion whose exact kernel the
398    // per-coordinate path contracts via `h_proj_inverse` directly. The kernel
399    // eigenvalues are `1/σ_a` over the kept Hessian spectrum, so their
400    // dynamic range is the Hessian condition number — clamp ONLY the
401    // roundoff-negative tail to zero (K is PSD by construction; a negative
402    // eigenvalue is O(ε)·‖K‖ eigensolver noise, and √(max(λ,0)) is the
403    // honest PSD square root). A relative floor here is NOT a stabilization:
404    // raising `1/σ_max` to `√ε·r·(1/σ_min)` rewrites the criterion's
405    // sensitivity along exactly the stiffest directions — where the ρ-drifts
406    // `λ_k·S_k` live — inflating the analytic trace by up to `√ε·r·κ(H_pen)`
407    // (O(1) once κ ≳ 1e7) while FD differentiates the true criterion. That
408    // desync red-lined every iso-κ Duchon probit/logit FD test and starved
409    // the spatial κ-optimizer of descent directions; Gaussian was immune
410    // because the intrinsic kernel is only installed for c-nontrivial
411    // families (#901).
412    let mut root = evecs.clone();
413    for col in 0..r {
414        let scale = evals[col].max(0.0).sqrt();
415        for row in 0..r {
416            root[[row, col]] *= scale;
417        }
418    }
419    gam_linalg::faer_ndarray::fast_ab(&kernel.u_s, &root)
420}
421
422pub(crate) fn penalty_subspace_trace_drifts_batched(
423    kernel: &PenaltySubspaceTrace,
424    drifts: &[DriftDerivResult],
425) -> Vec<f64> {
426    let factor = penalty_subspace_trace_factor(kernel);
427    let cache = ProjectedFactorCache::default();
428    trace_logdet_drifts_projected_factor_batched(drifts, &factor, &cache)
429}
430
431pub(crate) fn penalty_subspace_reduce_drifts_batched(
432    kernel: &PenaltySubspaceTrace,
433    drifts: &[DriftDerivResult],
434) -> Vec<Array2<f64>> {
435    drifts
436        .iter()
437        .map(|drift| match drift {
438            DriftDerivResult::Dense(matrix) => kernel.reduce(matrix),
439            // #901 layer-2 (outer-Hessian path): reduce the operator via
440            // `U_Sᵀ·A·U_S = U_Sᵀ·A.mul_mat(U_S)` — NOT `op.to_dense()` then
441            // reduce. For the GLM cubic correction `C[v] = Xᵀdiag(c⊙Xv)X` the
442            // dense materialization computes near-null quadratic forms by
443            // cancelling O(‖C‖) entries, and the spectral kernel's `1/σ_min`
444            // then amplifies the roundoff (the +39-vs-−0.30 / ~−7.7e5 blow-up).
445            // `reduce_operator` probes through the `X·U_S` matvecs instead, so
446            // tiny² stays tiny — the same stability cure as the first-order
447            // `trace_operator` path.
448            DriftDerivResult::Operator(op) => kernel.reduce_operator(op.as_ref()),
449        })
450        .collect()
451}
452
453pub(crate) fn dense_spectral_trace_logdet_operators_batched(
454    ds: &DenseSpectralOperator,
455    operators: &[Arc<dyn HyperOperator>],
456) -> Vec<f64> {
457    if operators.is_empty() {
458        return Vec::new();
459    }
460    if log::log_enabled!(log::Level::Info) {
461        let start = std::time::Instant::now();
462        let out =
463            trace_projected_factors_batched(operators, &ds.g_factor, &ds.projected_factor_cache);
464        let implicit_count = operators.iter().filter(|op| op.is_implicit()).count();
465        dense_spectral_stage_log(
466            &format!(
467                "DenseSpectralOperator::trace_logdet_operators_batched dim={} rank={} ops={} implicit_ops={}",
468                ds.n_dim,
469                ds.g_factor.ncols(),
470                operators.len(),
471                implicit_count,
472            ),
473            start.elapsed().as_secs_f64(),
474        );
475        out
476    } else {
477        trace_projected_factors_batched(operators, &ds.g_factor, &ds.projected_factor_cache)
478    }
479}
480
481impl HyperOperator for CompositeHyperOperator {
482    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
483        self
484    }
485
486    fn dim(&self) -> usize {
487        self.dim_hint
488    }
489
490    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
491        let mut out = Array1::<f64>::zeros(v.len());
492        self.mul_vec_into(v.view(), out.view_mut());
493        out
494    }
495
496    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
497        let mut out = Array1::<f64>::zeros(v.len());
498        self.mul_vec_into(v, out.view_mut());
499        out
500    }
501
502    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, mut out: ArrayViewMut1<'_, f64>) {
503        if self.dense.is_none() && self.operators.len() == 1 {
504            self.operators[0].mul_vec_into(v, out);
505            return;
506        }
507
508        out.fill(0.0);
509        if let Some(dense) = self.dense.as_ref() {
510            dense_matvec_into(dense, v, out.view_mut());
511        }
512        for op in &self.operators {
513            op.scaled_add_mul_vec(v, 1.0, out.view_mut());
514        }
515    }
516
517    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
518        if self.dense.is_none() && self.operators.len() == 1 {
519            self.operators[0].mul_basis_columns_into(start, out);
520            return;
521        }
522
523        out.fill(0.0);
524        let cols = out.ncols();
525        let end = start + cols;
526        if let Some(dense) = self.dense.as_ref() {
527            out += &dense.slice(ndarray::s![.., start..end]);
528        }
529        let mut work = Array2::<f64>::zeros((out.nrows(), cols));
530        for op in &self.operators {
531            op.mul_basis_columns_into(start, work.view_mut());
532            out += &work;
533        }
534    }
535
536    fn scaled_add_mul_vec(
537        &self,
538        v: ArrayView1<'_, f64>,
539        scale: f64,
540        mut out: ArrayViewMut1<'_, f64>,
541    ) {
542        if scale == 0.0 {
543            return;
544        }
545        if self.dense.is_none() && self.operators.len() == 1 {
546            self.operators[0].scaled_add_mul_vec(v, scale, out);
547            return;
548        }
549
550        if let Some(dense) = self.dense.as_ref() {
551            dense_matvec_scaled_add_into(dense, v, scale, out.view_mut());
552        }
553        for op in &self.operators {
554            op.scaled_add_mul_vec(v, scale, out.view_mut());
555        }
556    }
557
558    /// Forward batched apply to inner operators so their `mul_mat` overrides
559    /// (matrix-free Khatri–Rao BLAS3 fuses) fire instead of the default
560    /// per-column parallel matvec — which would triple-nest rayon when an
561    /// inner op already parallelizes internally.
562    fn mul_mat(&self, factor: &Array2<f64>) -> Array2<f64> {
563        if self.dense.is_none() && self.operators.len() == 1 {
564            return self.operators[0].mul_mat(factor);
565        }
566        let p = factor.nrows();
567        let k = factor.ncols();
568        let mut out = Array2::<f64>::zeros((p, k));
569        if let Some(dense) = self.dense.as_ref() {
570            out += &dense.dot(factor);
571        }
572        for op in &self.operators {
573            out += &op.mul_mat(factor);
574        }
575        out
576    }
577
578    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
579        if self.dense.is_none() && self.operators.len() == 1 {
580            return self.operators[0].trace_projected_factor(factor);
581        }
582
583        let mut trace = 0.0;
584        if let Some(dense) = self.dense.as_ref() {
585            let dense_factor = dense.dot(factor);
586            trace += factor
587                .iter()
588                .zip(dense_factor.iter())
589                .map(|(&f, &bf)| f * bf)
590                .sum::<f64>();
591        }
592        trace += composite_trace_implicit_batched(&self.operators, factor, None);
593        trace
594    }
595
596    fn trace_projected_factor_cached(
597        &self,
598        factor: &Array2<f64>,
599        cache: &ProjectedFactorCache,
600    ) -> f64 {
601        if self.dense.is_none() && self.operators.len() == 1 {
602            return self.operators[0].trace_projected_factor_cached(factor, cache);
603        }
604
605        let mut trace = 0.0;
606        if let Some(dense) = self.dense.as_ref() {
607            let dense_factor = dense.dot(factor);
608            trace += factor
609                .iter()
610                .zip(dense_factor.iter())
611                .map(|(&f, &bf)| f * bf)
612                .sum::<f64>();
613        }
614        trace += composite_trace_implicit_batched(&self.operators, factor, Some(cache));
615        trace
616    }
617
618    fn projected_matrix(&self, factor: &Array2<f64>) -> Array2<f64> {
619        if self.dense.is_none() && self.operators.len() == 1 {
620            return self.operators[0].projected_matrix(factor);
621        }
622
623        let rank = factor.ncols();
624        let mut projected = Array2::<f64>::zeros((rank, rank));
625        if let Some(dense) = self.dense.as_ref() {
626            let mf = gam_linalg::faer_ndarray::fast_ab(dense, factor);
627            projected += &gam_linalg::faer_ndarray::fast_atb(factor, &mf);
628        }
629        for op in &self.operators {
630            projected += &op.projected_matrix(factor);
631        }
632        projected
633    }
634
635    fn projected_matrix_cached(
636        &self,
637        factor: &Array2<f64>,
638        cache: &ProjectedFactorCache,
639    ) -> Array2<f64> {
640        if self.dense.is_none() && self.operators.len() == 1 {
641            return self.operators[0].projected_matrix_cached(factor, cache);
642        }
643
644        let rank = factor.ncols();
645        let mut projected = Array2::<f64>::zeros((rank, rank));
646        if let Some(dense) = self.dense.as_ref() {
647            let mf = gam_linalg::faer_ndarray::fast_ab(dense, factor);
648            projected += &gam_linalg::faer_ndarray::fast_atb(factor, &mf);
649        }
650        for op in &self.operators {
651            projected += &op.projected_matrix_cached(factor, cache);
652        }
653        projected
654    }
655
656    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
657        let mut total = 0.0;
658        if let Some(dense) = self.dense.as_ref() {
659            total += dense_bilinear(dense, v.view(), u.view());
660        }
661        for op in &self.operators {
662            total += op.bilinear(v, u);
663        }
664        total
665    }
666
667    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
668        let mut total = 0.0;
669        if let Some(dense) = self.dense.as_ref() {
670            total += dense_bilinear(dense, v, u);
671        }
672        for op in &self.operators {
673            total += op.bilinear_view(v, u);
674        }
675        total
676    }
677
678    fn to_dense(&self) -> Array2<f64> {
679        let mut out = self
680            .dense
681            .clone()
682            .unwrap_or_else(|| Array2::<f64>::zeros((self.dim_hint, self.dim_hint)));
683        for op in &self.operators {
684            out += &op.to_dense();
685        }
686        out
687    }
688
689    fn is_implicit(&self) -> bool {
690        self.operators.iter().any(|op| op.is_implicit())
691    }
692}
693
694/// Implicit Hessian-drift operator for a single anisotropic ψ_d coordinate.
695///
696/// Computes B_d · v on the fly:
697///   B_d · v = (∂X/∂ψ_d)^T (W · (X · v)) + X^T (W · ((∂X/∂ψ_d) · v)) + S_{ψ_d} · v
698///
699/// The first two terms use the implicit design-derivative operator (no dense
700/// (n × p) matrices), and S_{ψ_d} is a dense (p × p) penalty matrix (manageable).
701///
702/// Storage: the implicit operator holds O(n·k·D) radial jets, plus references
703/// to an active-basis X design operator and W (the working weights). The
704/// penalty matrix S_{ψ_d} is stored as a dense (p × p) matrix.
705/// Thread-local scratch buffers for `ImplicitHyperOperator::mul_vec_into`.
706/// Reused across PCG iterations and basis-column sweeps so each matvec
707/// avoids three fresh O(n)/O(p) allocations.
708mod implicit_matvec_scratch {
709    use std::cell::RefCell;
710
711    pub(super) struct Scratch {
712        pub x_v: Vec<f64>,
713        pub n_work: Vec<f64>,
714        pub p_work: Vec<f64>,
715    }
716
717    impl Scratch {
718        pub(crate) const fn new() -> Self {
719            Self {
720                x_v: Vec::new(),
721                n_work: Vec::new(),
722                p_work: Vec::new(),
723            }
724        }
725    }
726
727    thread_local! {
728        static SCRATCH: RefCell<Scratch> = const { RefCell::new(Scratch::new()) };
729    }
730
731    pub(super) fn with<R>(f: impl FnOnce(&mut Scratch) -> R) -> R {
732        SCRATCH.with(|cell| f(&mut cell.borrow_mut()))
733    }
734}
735
736pub struct ImplicitHyperOperator {
737    /// The implicit design-derivative operator (shared across all axes).
738    pub implicit_deriv: std::sync::Arc<gam_terms::basis::ImplicitDesignPsiDerivative>,
739    /// Which axis this operator is for.
740    pub axis: usize,
741    /// The active-basis design matrix X. This may be lazy / operator-backed.
742    pub(crate) x_design: std::sync::Arc<DesignMatrix>,
743    /// Working weights W (diagonal, length n) — observed-information curvature,
744    /// signed for non-canonical links. Carried as the owned [`gam_linalg::matrix::SignedWeightsArc`]
745    /// newtype so the sign character is construction-enforced at the operator
746    /// struct boundary; the function-boundary contract from `linalg/matrix.rs`
747    /// is no longer reconstructable accidentally inside `mul_vec`.
748    pub(crate) w_diag: gam_linalg::matrix::SignedWeightsArc,
749    /// Penalty derivative matrix S_{ψ_d} (p × p), dense.
750    pub s_psi: Array2<f64>,
751    /// Total basis dimension p.
752    pub(crate) p: usize,
753    /// Non-Gaussian fixed-β third-derivative correction: c ⊙ (X_{ψ_d} β̂),
754    /// length n. When present, the operator additionally applies
755    /// `Xᵀ diag(c_x_psi_beta) X v` so that the full B_d formula
756    /// `B_d v = (∂X/∂ψ_d)ᵀ W X v + Xᵀ W (∂X/∂ψ_d) v + Xᵀ diag(c ⊙ X_{ψ_d} β̂) X v + S_{ψ_d} v`
757    /// is matrix-free for non-Gaussian likelihoods. `None` for Gaussian
758    /// identity (c ≡ 0 there).
759    pub c_x_psi_beta: Option<std::sync::Arc<Array1<f64>>>,
760}
761
762impl HyperOperator for ImplicitHyperOperator {
763    fn dim(&self) -> usize {
764        self.p
765    }
766
767    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
768        // Single canonical path: route every matvec through `mul_vec_into`,
769        // which routes through `matvec_with_shared_xz_into`. The four terms of
770        // B_d are assembled there, with the third-derivative correction added
771        // by `accumulate_c_correction_xt_into` so the four matvec entry points
772        // share one inner kernel.
773        let mut out = Array1::<f64>::zeros(self.p);
774        self.mul_vec_into(v.view(), out.view_mut());
775        out
776    }
777
778    fn mul_vec_view(&self, v: ArrayView1<'_, f64>) -> Array1<f64> {
779        let mut out = Array1::<f64>::zeros(self.p);
780        self.mul_vec_into(v, out.view_mut());
781        out
782    }
783
784    fn mul_vec_into(&self, v: ArrayView1<'_, f64>, out: ArrayViewMut1<'_, f64>) {
785        assert_eq!(v.len(), self.p);
786        let n_obs = self.w_diag.len();
787        // Reuse thread-local scratch across repeated matvec calls (e.g.
788        // PCG iterations, basis-column sweeps) instead of allocating
789        // (2 n_obs + p) f64s every time.
790        implicit_matvec_scratch::with(|s| {
791            s.x_v.clear();
792            s.x_v.resize(n_obs, 0.0);
793            s.n_work.clear();
794            s.n_work.resize(n_obs, 0.0);
795            s.p_work.clear();
796            s.p_work.resize(self.p, 0.0);
797            let mut x_v_view = ndarray::ArrayViewMut1::from(s.x_v.as_mut_slice());
798            let n_work_view = ndarray::ArrayViewMut1::from(s.n_work.as_mut_slice());
799            let p_work_view = ndarray::ArrayViewMut1::from(s.p_work.as_mut_slice());
800            design_matrix_apply_view_into(&self.x_design, v, x_v_view.view_mut());
801            self.matvec_with_shared_xz_into(x_v_view.view(), v, out, n_work_view, p_work_view);
802        });
803    }
804
805    fn mul_basis_columns_into(&self, start: usize, mut out: ArrayViewMut2<'_, f64>) {
806        let cols = out.ncols();
807        assert!(start + cols <= self.p);
808
809        let n_obs = self.w_diag.len();
810        let mut basis = Array1::<f64>::zeros(self.p);
811        let mut x_col = Array1::<f64>::zeros(n_obs);
812        let mut dx_col = Array1::<f64>::zeros(n_obs);
813        let mut weighted = Array1::<f64>::zeros(n_obs);
814        let mut term = Array1::<f64>::zeros(self.p);
815
816        for local_col in 0..cols {
817            let global_col = start + local_col;
818            let mut out_col = out.column_mut(local_col);
819            out_col.assign(&self.s_psi.column(global_col));
820
821            design_matrix_column_into(&self.x_design, global_col, x_col.view_mut());
822            Zip::from(weighted.view_mut())
823                .and(self.w_diag.view())
824                .and(x_col.view())
825                .par_for_each(|dst, &w, &x| *dst = w * x);
826            term.assign(
827                &self
828                    .implicit_deriv
829                    .transpose_mul(self.axis, &weighted.view())
830                    .expect("radial scalar evaluation failed during implicit hyper transpose_mul"),
831            );
832            out_col += &term;
833
834            basis[global_col] = 1.0;
835            dx_col.assign(
836                &self
837                    .implicit_deriv
838                    .forward_mul(self.axis, &basis.view())
839                    .expect("radial scalar evaluation failed during implicit hyper forward_mul"),
840            );
841            basis[global_col] = 0.0;
842
843            Zip::from(weighted.view_mut())
844                .and(self.w_diag.view())
845                .and(dx_col.view())
846                .par_for_each(|dst, &w, &dx| *dst = w * dx);
847            design_matrix_transpose_apply_view_into(
848                &self.x_design,
849                weighted.view(),
850                term.view_mut(),
851            );
852            out_col += &term;
853
854            // Non-Gaussian third-derivative correction column j: shared kernel.
855            self.accumulate_c_correction_xt_into(
856                x_col.view(),
857                weighted.view_mut(),
858                term.view_mut(),
859                out_col,
860            );
861        }
862    }
863
864    fn bilinear(&self, v: &Array1<f64>, u: &Array1<f64>) -> f64 {
865        self.bilinear_view(v.view(), u.view())
866    }
867
868    fn bilinear_view(&self, v: ArrayView1<'_, f64>, u: ArrayView1<'_, f64>) -> f64 {
869        assert_eq!(v.len(), self.p);
870        assert_eq!(u.len(), self.p);
871
872        let x_v = design_matrix_apply_view(&self.x_design, v);
873        let x_u = design_matrix_apply_view(&self.x_design, u);
874        let dx_v = self
875            .implicit_deriv
876            .forward_mul(self.axis, &v)
877            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
878        let dx_u = self
879            .implicit_deriv
880            .forward_mul(self.axis, &u)
881            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
882
883        let w = &*self.w_diag;
884        let mut design = 0.0;
885        for i in 0..w.len() {
886            design += dx_v[i] * w[i] * x_u[i];
887            design += dx_u[i] * w[i] * x_v[i];
888        }
889
890        design += self.c_correction_bilinear(&x_v, &x_u);
891
892        let penalty = dense_bilinear(&self.s_psi, v, u);
893
894        design + penalty
895    }
896
897    fn is_implicit(&self) -> bool {
898        true
899    }
900
901    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
902        self
903    }
904
905    /// Compute `tr(F^T B F)` directly via fused chunked BLAS3 GEMMs on the
906    /// shared X and the shared raw kernel matrix, bypassing the rank-many
907    /// separate matvecs the default impl would run through the lazy /
908    /// operator-backed design.
909    ///
910    /// **Why this matters:** the default trait impl is
911    ///   `let bf = self.mul_mat(F); (F ⊙ bf).sum()`
912    /// which calls `mul_vec_into` per column of `F` (rank columns). On a
913    /// lazy Duchon / Matérn / CTN design each `mul_vec_into` triggers a
914    /// full `O(n · p · kernel_eval)` row-streamed matvec — and with rank ≈ p
915    /// at large-scale shape (16D-Duchon-aniso 32 ψ-axes, p ≈ 95, n = 320 K)
916    /// the per-axis trace landed at ~30 s. With 32 axes per outer Hessian
917    /// eval and ~5 outer iters that's the ~1 hr large-scale timeout.
918    ///
919    /// Algebra:
920    /// ```text
921    ///   B_d = D_d^T W X + X^T W D_d  + X^T diag(c) X  + S_psi
922    ///   D_d = (∂X/∂ψ_d) = K_d · Z_unproject       (raw kernel · unproject)
923    ///   tr(F^T B_d F) = 2 · ⟨W ⊙ DXF, XF⟩ + ⟨c ⊙ XF, XF⟩ + tr(F^T S_psi F)
924    /// ```
925    /// where `K_d` is the raw (n × n_knots) per-pair kernel scalar matrix
926    /// for axis `d` (`q · s_combo + c · coeff_sum · φ` per (i, j) pair) and
927    /// `Z_unproject` is the identifiability/padding back-projection.
928    ///
929    /// We compute `U_knot = unproject_matrix(F)` once at (n_knots × rank),
930    /// then for each row chunk do a fused pass:
931    ///   * `XF_chunk  = X_chunk · F`        (chunk × rank)  — shared-X GEMM
932    ///   * `Kd_chunk  = row_chunk_first_raw`(chunk × n_knots) — raw kernel
933    ///   * `DXF_chunk = Kd_chunk · U_knot`  (chunk × rank)  — single GEMM
934    /// and immediately accumulate `⟨W ⊙ DXF, XF⟩` and `⟨c ⊙ XF, XF⟩` over
935    /// the chunk, never materialising full XF or DXF.
936    ///
937    /// This replaces the previous `rank`-many `forward_mul` apply loop. On
938    /// the large-scale margslope-aniso-duchon16d shard each per-axis trace
939    /// drops from ~30 s to a single chunked-GEMM cost.
940    fn trace_projected_factor(&self, factor: &Array2<f64>) -> f64 {
941        assert_eq!(factor.nrows(), self.p);
942        let n_obs = self.w_diag.len();
943        let rank = factor.ncols();
944        if rank == 0 || n_obs == 0 {
945            return 0.0;
946        }
947        let xf = self.compute_xf(factor);
948        self.trace_projected_factor_with_xf(factor, xf.view())
949    }
950
951    /// Cached variant — *the* hot-path optimisation for large-scale outer
952    /// gradient/Hessian sweeps. Every ψ-axis built atop the same `x_design`
953    /// (e.g. all 32 ψ-axes of a marginal-slope model, or the same axis hit
954    /// from `g_factor` and `w_factor` traces) shares one chunked
955    /// `X · F` design GEMM per `(x_design, factor)` pair via
956    /// [`ProjectedFactorCache`]. With 32 axes per outer-gradient sweep and
957    /// O(rank) more cross-axis traces inside the outer-Hessian build, the
958    /// cache turns 32× redundant `O(n · p · rank)` GEMMs into a single one
959    /// per outer iter. At large-scale shape (`n = 320 K`, `p = rank = 95`) that
960    /// is the difference between minutes and seconds of design-GEMM work.
961    fn trace_projected_factor_cached(
962        &self,
963        factor: &Array2<f64>,
964        cache: &ProjectedFactorCache,
965    ) -> f64 {
966        assert_eq!(factor.nrows(), self.p);
967        let n_obs = self.w_diag.len();
968        let rank = factor.ncols();
969        if rank == 0 || n_obs == 0 {
970            return 0.0;
971        }
972        let xf = self.cached_xf(factor, cache);
973        self.trace_projected_factor_with_xf(factor, xf.view())
974    }
975}
976
/// Row-block size for streaming an `n_rows × cols` f64 matrix in chunks.
///
/// The block targets an ~8 MiB working set. A 512-row floor keeps a wide
/// design making useful BLAS-3 progress per block, and the result is capped
/// at the total row count. The returned size is always at least 1, so callers
/// may use it as a divisor or `step_by` stride even when `n_rows == 0`. An
/// absurdly large `cols` saturates to the row floor instead of overflowing.
///
/// Shared by the implicit operator's row-streaming kernels so their sizing
/// cannot drift apart.
pub(crate) fn byte_balanced_row_chunk(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    // `saturating_mul` keeps a pathological column count from overflowing
    // (a debug-build panic / release-build wraparound to a tiny divisor).
    let bytes_per_row = cols.max(1).saturating_mul(std::mem::size_of::<f64>());
    (TARGET_BYTES / bytes_per_row)
        .max(MIN_CHUNK_ROWS)
        .min(n_rows)
        // Never hand back a zero-row chunk, even for an empty matrix.
        .max(1)
}
989
990impl ImplicitHyperOperator {
991    /// Chunked `X · F` via faer SIMD-parallel GEMM. The chunk-row sizing
992    /// targets ~8 MiB live blocks so the (chunk_n × p) row slice and
993    /// (chunk_n × rank) result both stay in L2/L3 across realistic large-scale
994    /// shapes; the kernel mirrors `xt_logdet_kernel_x_diagonal`'s sizing
995    /// rule. Caller wraps this in [`Self::cached_xf`] when invariance
996    /// across ψ-axes lets one matrix serve every axis at this `(x_design,
997    /// factor)` pair.
998    pub(crate) fn compute_xf(&self, factor: &Array2<f64>) -> Array2<f64> {
999        let n_obs = self.w_diag.len();
1000        let rank = factor.ncols();
1001        let mut xf = Array2::<f64>::zeros((n_obs, rank));
1002        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
1003        let mut start = 0usize;
1004        while start < n_obs {
1005            let end = (start + chunk_rows).min(n_obs);
1006            let rows = self
1007                .x_design
1008                .try_row_chunk(start..end)
1009                // SAFETY: `try_row_chunk` only fails on operator
1010                // implementation bugs — `start..end` is built from
1011                // `0..n_obs = 0..x_design.nrows()` with
1012                // `end = (start+chunk_rows).min(n_obs)`, so the range is
1013                // always a valid sub-range of `x_design`. Failure means the
1014                // operator broke its row-chunk contract.
1015                .unwrap_or_else(|err| {
1016                    // SAFETY: row range is a valid sub-range of x_design; failure means operator broke contract.
1017                    reml_contract_panic(format!(
1018                        "ImplicitHyperOperator::compute_xf row chunk failed: {err}"
1019                    ))
1020                });
1021            let block = gam_linalg::faer_ndarray::fast_ab(&rows, factor);
1022            xf.slice_mut(ndarray::s![start..end, ..]).assign(&block);
1023            start = end;
1024        }
1025        xf
1026    }
1027
1028    /// Look up `X · F` from the [`ProjectedFactorCache`] (compute-on-miss).
1029    /// Cache key combines the shared `x_design` Arc pointer and the
1030    /// factor's value fingerprint, so two `ImplicitHyperOperator` instances
1031    /// built atop the same `x_design` (e.g. axis-0 and axis-1 of a 32-axis
1032    /// ψ-block) consult the same cache slot and hit after the first
1033    /// computes.
1034    pub(crate) fn cached_xf(
1035        &self,
1036        factor: &Array2<f64>,
1037        cache: &ProjectedFactorCache,
1038    ) -> Arc<Array2<f64>> {
1039        let design_id = Arc::as_ptr(&self.x_design) as usize;
1040        let key = ProjectedFactorKey::from_factor_view(design_id, factor.view());
1041        cache.get_or_insert_with(key, || self.compute_xf(factor))
1042    }
1043
1044    /// Evaluate `tr(Fᵀ B_d F)` given a precomputed `X · F`. Pulls every
1045    /// per-axis-redundant `X · F` out of the inner loop so the cache (or
1046    /// caller-supplied matrix) covers every ψ-axis at once. The remaining
1047    /// per-axis work is the row-kernel build (`row_chunk_first_raw`),
1048    /// the `K_d · U_knot` GEMM, the fused `⟨W ⊙ DXF, XF⟩` inner products,
1049    /// and the small dense penalty contraction.
1050    pub(crate) fn trace_projected_factor_with_xf(
1051        &self,
1052        factor: &Array2<f64>,
1053        xf: ArrayView2<'_, f64>,
1054    ) -> f64 {
1055        let rank = factor.ncols();
1056        let n_obs = self.w_diag.len();
1057        assert_eq!(xf.dim(), (n_obs, rank));
1058
1059        // Once: unproject F to raw knot space → (n_knots × rank).
1060        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
1061
1062        // Match the chunk sizing `xt_logdet_kernel_x_diagonal` uses so the
1063        // live block stays in L2/L3 across realistic large-scale shapes.
1064        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs);
1065
1066        let w = self.w_diag.as_ref();
1067        let c_opt = self.c_x_psi_beta.as_ref().map(|arc| arc.as_ref());
1068        let mut design_total = 0.0_f64;
1069        let mut correction_total = 0.0_f64;
1070        let mut start = 0usize;
1071        while start < n_obs {
1072            let end = (start + chunk_rows).min(n_obs);
1073            let chunk_n = end - start;
1074
1075            // Cached-or-precomputed X·F slice for this chunk.
1076            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);
1077
1078            // Raw kernel scalars for axis d on this chunk, then a single
1079            // (chunk × n_knots) · (n_knots × rank) GEMM gives DXF_chunk.
1080            let kd_chunk = self
1081                .implicit_deriv
1082                .row_chunk_first_raw(self.axis, start..end)
1083                .expect("radial scalar evaluation failed during implicit hyper forward_mul_matrix");
1084            let dxf_chunk = gam_linalg::faer_ndarray::fast_ab(&kd_chunk, &u_knot);
1085
1086            // Fused inner-product accumulation.
1087            for i_local in 0..chunk_n {
1088                let i = start + i_local;
1089                let w_i = w[i];
1090                let dxf_row = dxf_chunk.row(i_local);
1091                let xf_row = xf_chunk.row(i_local);
1092                for k in 0..rank {
1093                    design_total += dxf_row[k] * w_i * xf_row[k];
1094                }
1095                if let Some(c) = c_opt {
1096                    let c_i = c[i];
1097                    for k in 0..rank {
1098                        let v = xf_row[k];
1099                        correction_total += c_i * v * v;
1100                    }
1101                }
1102            }
1103            start = end;
1104        }
1105
1106        // Penalty trace: tr(F^T S_psi F) via dense BLAS3.
1107        let s_f = self.s_psi.dot(factor);
1108        let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();
1109
1110        2.0 * design_total + correction_total + penalty
1111    }
1112
1113    /// Batched-axis sibling of [`Self::trace_projected_factor_with_xf`].
1114    /// Returns `tr(Fᵀ B_d F)` for every `(axis, s_psi, c_x_psi_beta)` triple
1115    /// in `axes`, sharing the unproject-and-row-sweep work across axes that
1116    /// only differ in their axis index / penalty matrix / correction vector.
1117    pub(crate) fn trace_projected_factor_all_axes_with_xf(
1118        &self,
1119        factor: &Array2<f64>,
1120        xf: ArrayView2<'_, f64>,
1121        axes: &[(usize, &Array2<f64>, Option<&Array1<f64>>)],
1122    ) -> Vec<f64> {
1123        let rank = factor.ncols();
1124        let n_obs = self.w_diag.len();
1125        assert_eq!(xf.dim(), (n_obs, rank));
1126
1127        let u_knot = self.implicit_deriv.unproject_matrix(&factor.view());
1128
1129        let chunk_rows = byte_balanced_row_chunk(self.p + rank, n_obs.max(1));
1130
1131        let w = self.w_diag.as_ref();
1132        let mut design_totals = vec![0.0_f64; axes.len()];
1133        let mut correction_totals = vec![0.0_f64; axes.len()];
1134
1135        let mut start = 0usize;
1136        while start < n_obs {
1137            let end = (start + chunk_rows).min(n_obs);
1138            let chunk_n = end - start;
1139            let xf_chunk = xf.slice(ndarray::s![start..end, ..]);
1140
1141            for (axis_idx, (axis, _s_psi, c_opt_axis)) in axes.iter().enumerate() {
1142                let kd_chunk = self
1143                    .implicit_deriv
1144                    .row_chunk_first_raw(*axis, start..end)
1145                    .expect(
1146                        "radial scalar evaluation failed during \
1147                         trace_projected_factor_all_axes_with_xf",
1148                    );
1149                let dxf_chunk = gam_linalg::faer_ndarray::fast_ab(&kd_chunk, &u_knot);
1150
1151                for i_local in 0..chunk_n {
1152                    let i = start + i_local;
1153                    let w_i = w[i];
1154                    let dxf_row = dxf_chunk.row(i_local);
1155                    let xf_row = xf_chunk.row(i_local);
1156                    for k in 0..rank {
1157                        design_totals[axis_idx] += dxf_row[k] * w_i * xf_row[k];
1158                    }
1159                    if let Some(c) = c_opt_axis {
1160                        let c_i = c[i];
1161                        for k in 0..rank {
1162                            let v = xf_row[k];
1163                            correction_totals[axis_idx] += c_i * v * v;
1164                        }
1165                    }
1166                }
1167            }
1168            start = end;
1169        }
1170
1171        axes.iter()
1172            .enumerate()
1173            .map(|(idx, (_axis, s_psi, _c_opt_axis))| {
1174                let s_f = s_psi.dot(factor);
1175                let penalty: f64 = factor.iter().zip(s_f.iter()).map(|(&f, &s)| f * s).sum();
1176                2.0 * design_totals[idx] + correction_totals[idx] + penalty
1177            })
1178            .collect()
1179    }
1180
1181    pub(crate) fn accumulate_c_correction_xt_into(
1182        &self,
1183        x_col: ArrayView1<'_, f64>,
1184        mut n_work: ArrayViewMut1<'_, f64>,
1185        mut p_work: ArrayViewMut1<'_, f64>,
1186        mut out_col: ArrayViewMut1<'_, f64>,
1187    ) {
1188        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
1189            return;
1190        };
1191        let c = c_x_psi_beta.as_ref();
1192        assert_eq!(x_col.len(), c.len());
1193        assert_eq!(n_work.len(), c.len());
1194        assert_eq!(p_work.len(), self.p);
1195
1196        for i in 0..c.len() {
1197            n_work[i] = c[i] * x_col[i];
1198        }
1199        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
1200        out_col += &p_work;
1201    }
1202
1203    pub(crate) fn c_correction_bilinear(&self, x_v: &Array1<f64>, x_u: &Array1<f64>) -> f64 {
1204        let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() else {
1205            return 0.0;
1206        };
1207        x_v.iter()
1208            .zip(x_u.iter())
1209            .zip(c_x_psi_beta.iter())
1210            .map(|((&xv, &xu), &c)| xv * c * xu)
1211            .sum()
1212    }
1213
1214    /// Compute the design-part bilinear form u^T (X^T C_d X) z using precomputed
1215    /// shared X-multiplies, avoiding the full B_d matvec.
1216    ///
1217    /// The design part of B_d is:
1218    ///   (∂X/∂ψ_d)^T W X + X^T W (∂X/∂ψ_d)
1219    ///
1220    /// For vectors z and u, the bilinear form u^T [design_part] z equals:
1221    ///   ((∂X/∂ψ_d) u)^T (W (Xz)) + (Xu)^T (W ((∂X/∂ψ_d) z))
1222    ///   = 2 * (w ⊙ y_vec)^T dx_z       [when u = u, z = z]
1223    ///
1224    /// where y_vec = X u, dx_z = (∂X/∂ψ_d) z.
1225    ///
1226    /// But the full bilinear form is NOT symmetric in its dependence on z vs u
1227    /// through the design derivative, so we compute both cross-terms:
1228    ///   dx_z^T (w ⊙ y_vec) + dx_u^T (w ⊙ x_vec)
1229    ///
1230    /// # Arguments
1231    /// - `x_vec`: X z (precomputed, shared across axes)
1232    /// - `y_vec`: X u (precomputed, shared across axes)
1233    /// - `z`: the probe vector (needed for forward_mul and penalty)
1234    /// - `u`: H⁻¹ z (needed for forward_mul and penalty)
1235    ///
1236    /// # Returns
1237    /// The full bilinear form u^T B_d z = design_part + penalty_part.
1238    pub fn bilinear_with_shared_x(
1239        &self,
1240        x_vec: &Array1<f64>,
1241        y_vec: &Array1<f64>,
1242        z: &Array1<f64>,
1243        u: &Array1<f64>,
1244    ) -> f64 {
1245        // Design part: dx_z^T (w ⊙ y_vec) + dx_u^T (w ⊙ x_vec)
1246        let dx_z = self
1247            .implicit_deriv
1248            .forward_mul(self.axis, &z.view())
1249            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
1250        let dx_u = self
1251            .implicit_deriv
1252            .forward_mul(self.axis, &u.view())
1253            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
1254
1255        let mut design = 0.0f64;
1256        let w = &*self.w_diag;
1257        for i in 0..x_vec.len() {
1258            let wi = w[i];
1259            design += dx_z[i] * wi * y_vec[i];
1260            design += dx_u[i] * wi * x_vec[i];
1261        }
1262
1263        // Non-Gaussian fixed-β third-derivative correction:
1264        //   uᵀ Xᵀ diag(c ⊙ X_{ψ_d} β̂) X z = Σ_i (X u)_i · c_x_psi_beta_i · (X z)_i
1265        //   = Σ_i y_vec[i] · c_x_psi_beta[i] · x_vec[i]
1266        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
1267            let c = c_x_psi_beta.as_ref();
1268            for i in 0..x_vec.len() {
1269                design += y_vec[i] * c[i] * x_vec[i];
1270            }
1271        }
1272
1273        // Penalty part: u^T S_psi z
1274        let penalty = dense_bilinear(&self.s_psi, z.view(), u.view());
1275
1276        design + penalty
1277    }
1278
1279    /// Compute the design-part contribution to A_d z without the X^T step.
1280    ///
1281    /// Returns the n-vector C_d (X z) where C_d encodes the diagonal weighting.
1282    /// Specifically: (∂X/∂ψ_d)^T maps FROM n-space, but for stochastic trace
1283    /// estimation we need q_d = A_d z = X^T (C_d x_vec) + P_d z.
1284    ///
1285    /// This method computes q_d = A_d z using the shared x_vec = X z:
1286    ///   q_d = (∂X/∂ψ_d)^T (W (X z)) + X^T (W ((∂X/∂ψ_d) z)) + S_psi z
1287    /// which is the standard mul_vec but we can share x_vec across axes.
1288    pub fn matvec_with_shared_xz_into(
1289        &self,
1290        x_vec: ArrayView1<'_, f64>,
1291        z: ArrayView1<'_, f64>,
1292        mut out: ArrayViewMut1<'_, f64>,
1293        mut n_work: ArrayViewMut1<'_, f64>,
1294        mut p_work: ArrayViewMut1<'_, f64>,
1295    ) {
1296        assert_eq!(z.len(), self.p);
1297        assert_eq!(out.len(), self.p);
1298        assert_eq!(n_work.len(), self.w_diag.len());
1299        assert_eq!(p_work.len(), self.p);
1300
1301        let w = &*self.w_diag;
1302        for i in 0..w.len() {
1303            n_work[i] = w[i] * x_vec[i];
1304        }
1305        let term1 = self
1306            .implicit_deriv
1307            .transpose_mul(self.axis, &n_work.view())
1308            .expect("radial scalar evaluation failed during implicit hyper transpose_mul");
1309        out.assign(&term1);
1310
1311        let dx_z = self
1312            .implicit_deriv
1313            .forward_mul(self.axis, &z)
1314            .expect("radial scalar evaluation failed during implicit hyper forward_mul");
1315        for i in 0..w.len() {
1316            n_work[i] = w[i] * dx_z[i];
1317        }
1318        design_matrix_transpose_apply_view_into(&self.x_design, n_work.view(), p_work.view_mut());
1319        out += &p_work;
1320
1321        dense_matvec_into(&self.s_psi, z, p_work.view_mut());
1322        out += &p_work;
1323
1324        // Non-Gaussian fixed-β third-derivative correction.
1325        if let Some(c_x_psi_beta) = self.c_x_psi_beta.as_ref() {
1326            let c = c_x_psi_beta.as_ref();
1327            for i in 0..w.len() {
1328                n_work[i] = c[i] * x_vec[i];
1329            }
1330            design_matrix_transpose_apply_view_into(
1331                &self.x_design,
1332                n_work.view(),
1333                p_work.view_mut(),
1334            );
1335            out += &p_work;
1336        }
1337    }
1338}
1339
/// Operator-backed fixed-β Hessian drift for sparse-exact τ coordinates.
///
/// This stays in the original sparse/native coefficient basis and computes the
/// exact first-order τ Hessian drift
///   B_τ = X_τᵀ W X + Xᵀ W X_τ + Xᵀ diag(c ⊙ X_τ β̂) X + S_τ − (H_φ)_{τ}|_β
/// without materializing the full dense matrix up front. Each term is
/// applied on demand by the `HyperOperator::mul_vec` impl. The two optional
/// terms are skipped when their field is `None`.
pub struct SparseDirectionalHyperOperator {
    /// Original-basis design derivative X_τ.
    pub(crate) x_tau: super::super::HyperDesignDerivative,
    /// Design matrix X in the sparse-native basis.
    pub(crate) x_design: DesignMatrix,
    /// Working weights W (diagonal) — observed-information curvature, signed
    /// for non-canonical links.  Carried as the owned [`gam_linalg::matrix::SignedWeightsArc`]
    /// newtype so the sign character is construction-enforced at the operator
    /// struct boundary.
    pub(crate) w_diag: gam_linalg::matrix::SignedWeightsArc,
    /// Penalty derivative S_τ (applied as a dense `S_τ · v` product).
    pub(crate) s_tau: Array2<f64>,
    /// Fixed-β non-Gaussian curvature term c ⊙ (X_τ β̂), if applicable.
    /// It is multiplied elementwise with `X v`, so its length must match the
    /// row count of `x_design`.
    pub(crate) c_x_tau_beta: Option<Array1<f64>>,
    /// Fixed-β Firth partial Hessian drift (H_φ)_{τ}|_β, if applicable.
    /// It is *subtracted* from the drift.
    pub(crate) firth_hphi_tau_partial: Option<Array2<f64>>,
    /// Total coefficient dimension (`mul_vec` asserts `v.len() == p`).
    pub(crate) p: usize,
}
1365
1366impl HyperOperator for SparseDirectionalHyperOperator {
1367    fn dim(&self) -> usize {
1368        self.p
1369    }
1370
1371    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
1372        assert_eq!(v.len(), self.p);
1373
1374        // X v
1375        let x_v = self.x_design.matrixvectormultiply(v);
1376
1377        // X_tauᵀ (W (X v))
1378        let w_x_v = &*self.w_diag * &x_v;
1379        let term1 = self
1380            .x_tau
1381            .transpose_mul_original(&w_x_v)
1382            .expect("SparseDirectionalHyperOperator transpose product should be shape-consistent");
1383
1384        // Xᵀ (W (X_tau v))
1385        let x_tau_v = self
1386            .x_tau
1387            .forward_mul_original(v)
1388            .expect("SparseDirectionalHyperOperator forward product should be shape-consistent");
1389        let w_x_tau_v = &*self.w_diag * &x_tau_v;
1390        let term2 = self.x_design.transpose_vector_multiply(&w_x_tau_v);
1391
1392        // S_tau v
1393        let term3 = self.s_tau.dot(v);
1394
1395        let mut out = term1 + term2 + term3;
1396
1397        // Non-Gaussian fixed-beta curvature: Xᵀ diag(c ⊙ X_tau β̂) X v
1398        if let Some(c_x_tau_beta) = self.c_x_tau_beta.as_ref() {
1399            let weighted = c_x_tau_beta * &x_v;
1400            out += &self.x_design.transpose_vector_multiply(&weighted);
1401        }
1402
1403        // Firth fixed-beta partial: subtract (H_φ)_{τ}|_β v
1404        if let Some(hphi_tau_partial) = self.firth_hphi_tau_partial.as_ref() {
1405            out -= &hphi_tau_partial.dot(v);
1406        }
1407
1408        out
1409    }
1410
1411    fn is_implicit(&self) -> bool {
1412        false
1413    }
1414    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
1415        self
1416    }
1417}
1418
/// Matrix-free GLM cubic-correction drift `C[v] = −Xᵀ diag(c ⊙ X v) X`
/// (rows masked to the active Hessian-curvature surface, sign folded into
/// the stored diagonal).
///
/// # Why this must stay an operator (#901 layer 2)
///
/// The spectral logdet kernel evaluates `tr(H⁺ · C)` as
/// `Σ_a (1/σ_a) · u_aᵀ C u_a` over the eigenpairs of `H_pen`. For a
/// near-null eigenvector (`σ_min ~ 1e−4` on the Duchon fixtures) the true
/// quadratic form is tiny — `‖X u_a‖² ≲ σ_a / w_min` — but a DENSE
/// materialization of `C` computes it as a cancellation across entries of
/// magnitude `‖C‖`, leaving roundoff `~ ε‖C‖p` that the kernel then
/// amplifies by `1/σ_min`. On the iso-κ Duchon binomial FD drivers this
/// turned a true cubic trace of `−0.30` into `+39.0`, and `~−7.7e5` on the
/// κ-scaled ψ arms where `‖C‖ ~ λ · ∂S/∂ψ` — the dominant #901 blow-up.
///
/// In operator form the kernel probes `C · u_a = −Xᵀ(d ⊙ (X u_a))`: the
/// cancellation happens inside the `X u_a` matvec (error `~ ε‖X‖‖u_a‖`),
/// and the quadratic form is the *square* of that already-small vector —
/// tiny² stays tiny, so the `1/σ_a` amplification acts on a relatively
/// accurate value. This is the same stability argument as evaluating
/// leverages via `(X u)ᵀ d (X u)` instead of `uᵀ (XᵀdX) u`.
pub struct GlmCurvatureCorrectionOperator {
    /// Design matrix X in the transformed basis (matrix-free capable).
    pub(crate) x_design: DesignMatrix,
    /// Pre-masked, sign-folded diagonal `−(c ⊙ X v)` over active rows.
    /// `mul_vec` multiplies it elementwise with `X v`, so its length must
    /// equal the row count of `x_design`. Inactive rows are expected to hold
    /// zeros; NOTE(review): masking is done by the constructor, which is not
    /// shown here.
    pub(crate) neg_c_xv: Array1<f64>,
    /// Total coefficient dimension (`mul_vec` asserts `v.len() == p`).
    pub(crate) p: usize,
}
1449
1450impl HyperOperator for GlmCurvatureCorrectionOperator {
1451    fn dim(&self) -> usize {
1452        self.p
1453    }
1454
1455    fn mul_vec(&self, v: &Array1<f64>) -> Array1<f64> {
1456        assert_eq!(v.len(), self.p);
1457        let x_v = self.x_design.matrixvectormultiply(v);
1458        let weighted = &self.neg_c_xv * &x_v;
1459        self.x_design.transpose_vector_multiply(&weighted)
1460    }
1461
1462    fn as_any(&self) -> &(dyn std::any::Any + 'static) {
1463        self
1464    }
1465
1466    fn is_implicit(&self) -> bool {
1467        false
1468    }
1469}
1470
1471// ═══════════════════════════════════════════════════════════════════════════
1472//  Data structures
1473// ═══════════════════════════════════════════════════════════════════════════