// gam_solve/reml/assembly.rs

1//! Canonical `InnerSolution` assembler.
2//!
3//! No production code outside this module may construct
4//! `InnerSolutionBuilder::new(...)` or call `reml_laml_evaluate(...)`.
5//! Tests are exempt.
6//!
7//! All families and runtime paths provide ingredients and call
8//! [`InnerAssembly::evaluate`] or [`InnerAssembly::build`].
9
10use super::reml_outer_engine::{
11    BarrierConfig, ContractedPsiSecondOrderFn, DispersionHandling, EvalMode, FixedDriftDerivFn,
12    HessianDerivativeProvider, HessianOperator, HyperCoord, HyperCoordPair, InnerSolution,
13    InnerSolutionBuilder, PenaltyCoordinate, PenaltyLogdetDerivs, PenaltySubspaceTrace,
14    RemlLamlResult, penalty_matrix_root, reml_laml_evaluate,
15};
16use gam_linalg::faer_ndarray::fast_xt_diag_y;
17use crate::model_types::ProjectedKktResidual;
18use ndarray::{Array1, Array2};
19use rayon::iter::{
20    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
21};
22use rayon::slice::ParallelSliceMut;
23use std::sync::Arc;
24
25// ═══════════════════════════════════════════════════════════════════════════
26//  Streaming weighted dense-design products
27// ═══════════════════════════════════════════════════════════════════════════
28
/// Dense weighted-product work below this approximate flop count stays on the
/// caller thread and uses the existing faer GEMM path. Above the threshold we
/// stream rows through rayon-local accumulation buffers to avoid materializing
/// weighted n×p design copies at large scale.
///
/// The work estimate compared against this constant is `n · p · q` for
/// `weighted_cross_dense` and `n · p · p` for `xt_diag_x_dense_into`.
pub(crate) const DENSE_WEIGHTED_PRODUCT_PAR_FLOPS: usize = 8_000_000;
/// Minimum number of matrix cells (`rows · cols`) at which
/// `row_scale_dense_in_place` considers its rayon row-parallel path; smaller
/// matrices are always scaled on the caller thread.
pub(crate) const DENSE_ROW_SCALE_PAR_CELLS: usize = 64 * 1024;
35
/// How a per-row scale factor is applied by the dense row-scaling helpers.
#[derive(Clone, Copy)]
pub(crate) enum DenseRowScaleMode {
    /// Multiply every entry of the row by the scale factor itself.
    Direct,
    /// Multiply every entry of the row by `1 / scale` when `scale > 0.0`;
    /// otherwise overwrite the whole row with zeros.
    InversePositiveOrZero,
}
41
/// Pick a row-chunk length for streaming a dense design with `cols` columns.
///
/// The chunk is sized so that one chunk of `f64` rows occupies roughly
/// `TARGET_BYTES`, then clamped to `[MIN_ROWS, MAX_ROWS]`. A `cols` of zero is
/// treated as one column so the division is always well defined.
///
/// The per-row byte count uses a saturating multiply: an absurdly wide `cols`
/// therefore degrades to `MIN_ROWS` instead of overflowing (which would panic
/// in debug builds, or wrap to zero and divide by zero in release builds).
#[inline]
pub(crate) fn dense_weighted_chunk_rows(cols: usize) -> usize {
    const TARGET_BYTES: usize = 2 * 1024 * 1024;
    const MIN_ROWS: usize = 256;
    const MAX_ROWS: usize = 4096;
    // `cols.max(1)` keeps the divisor non-zero; saturation keeps it from wrapping.
    let bytes_per_row = cols.max(1).saturating_mul(std::mem::size_of::<f64>());
    (TARGET_BYTES / bytes_per_row).clamp(MIN_ROWS, MAX_ROWS)
}
50
51/// Write `diag(scale) · x` into `out`, preserving `out`'s allocation when its
52/// shape already matches `x`.
53///
54/// This replaces the former clone-and-row-scale pattern used by REML assembly
55/// tests and Firth kernels. It is intentionally simple and deterministic for a
56/// fixed row order.
57pub(crate) fn row_scale_dense_into(x: &Array2<f64>, scale: &Array1<f64>, out: &mut Array2<f64>) {
58    assert_eq!(x.nrows(), scale.len(), "scale length must match row count");
59    if out.raw_dim() != x.raw_dim() {
60        *out = Array2::<f64>::zeros(x.raw_dim());
61    }
62    out.assign(x);
63    row_scale_dense_in_place(out, scale, DenseRowScaleMode::Direct);
64}
65
66/// Scale each row of `out` by `1 / scale[row]`, writing zero rows where
67/// `scale[row] <= 0`.
68pub(crate) fn row_scale_dense_in_place_by_inverse_positive_or_zero(
69    out: &mut Array2<f64>,
70    scale: &Array1<f64>,
71) {
72    row_scale_dense_in_place(out, scale, DenseRowScaleMode::InversePositiveOrZero);
73}
74
75pub(crate) fn row_scale_dense_in_place(
76    out: &mut Array2<f64>,
77    scale: &Array1<f64>,
78    mode: DenseRowScaleMode,
79) {
80    assert_eq!(
81        out.nrows(),
82        scale.len(),
83        "scale length must match row count"
84    );
85    let ncols = out.ncols();
86    if ncols == 0 {
87        return;
88    }
89
90    let cells = out.nrows().saturating_mul(ncols);
91    if cells >= DENSE_ROW_SCALE_PAR_CELLS
92        && rayon::current_num_threads() > 1
93        && out.is_standard_layout()
94        && let Some(slice) = out.as_slice_memory_order_mut()
95    {
96        slice
97            .par_chunks_mut(ncols)
98            .zip(
99                scale
100                    .as_slice()
101                    .expect("Array1 must be contiguous")
102                    .par_iter(),
103            )
104            .for_each(|(row_values, &w)| scale_dense_row_values(row_values, w, mode));
105        return;
106    }
107
108    ndarray::Zip::from(out.rows_mut())
109        .and(scale.view())
110        .for_each(|mut row, &w| {
111            if let Some(row_values) = row.as_slice_mut() {
112                scale_dense_row_values(row_values, w, mode);
113            } else {
114                match mode {
115                    DenseRowScaleMode::Direct => row *= w,
116                    DenseRowScaleMode::InversePositiveOrZero => {
117                        if w > 0.0 {
118                            row *= w.recip();
119                        } else {
120                            row.fill(0.0);
121                        }
122                    }
123                }
124            }
125        });
126}
127
128#[inline]
129pub(crate) fn scale_dense_row_values(row_values: &mut [f64], scale: f64, mode: DenseRowScaleMode) {
130    match mode {
131        DenseRowScaleMode::Direct => {
132            for value in row_values {
133                *value *= scale;
134            }
135        }
136        DenseRowScaleMode::InversePositiveOrZero => {
137            if scale > 0.0 {
138                let inv = scale.recip();
139                for value in row_values {
140                    *value *= inv;
141                }
142            } else {
143                for value in row_values {
144                    *value = 0.0;
145                }
146            }
147        }
148    }
149}
150
151pub(crate) fn accumulate_weighted_cross_rows(
152    out: &mut Array2<f64>,
153    left: &Array2<f64>,
154    right: &Array2<f64>,
155    weights: &Array1<f64>,
156    row_start: usize,
157    row_end: usize,
158) {
159    let p = left.ncols();
160    let q = right.ncols();
161    for i in row_start..row_end {
162        let wi = weights[i];
163        if wi == 0.0 {
164            continue;
165        }
166        for a in 0..p {
167            let scaled = wi * left[[i, a]];
168            if scaled == 0.0 {
169                continue;
170            }
171            for b in 0..q {
172                out[[a, b]] += scaled * right[[i, b]];
173            }
174        }
175    }
176}
177
178pub(crate) fn accumulate_xt_diag_x_upper_rows(
179    out: &mut Array2<f64>,
180    x: &Array2<f64>,
181    diag: &Array1<f64>,
182    row_start: usize,
183    row_end: usize,
184) {
185    let p = x.ncols();
186    for i in row_start..row_end {
187        let wi = diag[i];
188        if wi == 0.0 {
189            continue;
190        }
191        for a in 0..p {
192            let scaled = wi * x[[i, a]];
193            if scaled == 0.0 {
194                continue;
195            }
196            for b in a..p {
197                out[[a, b]] += scaled * x[[i, b]];
198            }
199        }
200    }
201}
202
203/// Compute `leftᵀ diag(weights) right` using streamed row-block
204/// accumulation for large products. The parallel path allocates one dense
205/// p×q accumulator per rayon worker/task instead of allocating an n×q weighted
206/// design matrix.
207pub(crate) fn weighted_cross_dense(
208    left: &Array2<f64>,
209    right: &Array2<f64>,
210    weights: &Array1<f64>,
211) -> Array2<f64> {
212    assert_eq!(left.nrows(), right.nrows());
213    assert_eq!(left.nrows(), weights.len());
214    let n = weights.len();
215    let p = left.ncols();
216    let q = right.ncols();
217    if n == 0 || p == 0 || q == 0 {
218        return Array2::<f64>::zeros((p, q));
219    }
220
221    let work = n.saturating_mul(p).saturating_mul(q);
222    if rayon::current_num_threads() <= 1 || work < DENSE_WEIGHTED_PRODUCT_PAR_FLOPS {
223        return fast_xt_diag_y(left, weights, right);
224    }
225
226    let chunk_rows = crate::parallel_strategy::row_reduction_chunk_rows(
227        n,
228        p.saturating_mul(q),
229        p.saturating_mul(q),
230        DENSE_WEIGHTED_PRODUCT_PAR_FLOPS,
231    )
232    .unwrap_or_else(|| dense_weighted_chunk_rows(p + q).min(n));
233    let chunks = n.div_ceil(chunk_rows);
234    (0..chunks)
235        .into_par_iter()
236        .fold(
237            || Array2::<f64>::zeros((p, q)),
238            |mut local, chunk| {
239                let start = chunk * chunk_rows;
240                let end = (start + chunk_rows).min(n);
241                accumulate_weighted_cross_rows(&mut local, left, right, weights, start, end);
242                local
243            },
244        )
245        .reduce(
246            || Array2::<f64>::zeros((p, q)),
247            |mut a, b| {
248                a += &b;
249                a
250            },
251        )
252}
253
254/// Compute `xᵀ diag(diag) x`. For small products this reuses `weighted` as an
255/// n×p row-scaled scratch and dispatches to faer GEMM. For large products it
256/// streams rows into rayon-local p×p buffers and mirrors the accumulated upper
257/// triangle, avoiding weighted design materialization.
258pub(crate) fn xt_diag_x_dense_into(
259    x: &Array2<f64>,
260    diag: &Array1<f64>,
261    weighted: &mut Array2<f64>,
262) -> Array2<f64> {
263    let (n, p) = x.dim();
264    assert_eq!(diag.len(), n, "diag length must match row count");
265    if n == 0 || p == 0 {
266        return Array2::<f64>::zeros((p, p));
267    }
268
269    let work = n.saturating_mul(p).saturating_mul(p);
270    if rayon::current_num_threads() <= 1 || work < DENSE_WEIGHTED_PRODUCT_PAR_FLOPS {
271        row_scale_dense_into(x, diag, weighted);
272        return gam_linalg::faer_ndarray::fast_atb(x, weighted);
273    }
274
275    let chunk_rows = crate::parallel_strategy::row_reduction_chunk_rows(
276        n,
277        p.saturating_mul(p),
278        p.saturating_mul(p),
279        DENSE_WEIGHTED_PRODUCT_PAR_FLOPS,
280    )
281    .unwrap_or_else(|| dense_weighted_chunk_rows(p).min(n));
282    let chunks = n.div_ceil(chunk_rows);
283    let mut out = (0..chunks)
284        .into_par_iter()
285        .fold(
286            || Array2::<f64>::zeros((p, p)),
287            |mut local, chunk| {
288                let start = chunk * chunk_rows;
289                let end = (start + chunk_rows).min(n);
290                accumulate_xt_diag_x_upper_rows(&mut local, x, diag, start, end);
291                local
292            },
293        )
294        .reduce(
295            || Array2::<f64>::zeros((p, p)),
296            |mut a, b| {
297                a += &b;
298                a
299            },
300        );
301    for a in 0..p {
302        for b in 0..a {
303            out[[a, b]] = out[[b, a]];
304        }
305    }
306    out
307}
308
309// ═══════════════════════════════════════════════════════════════════════════
310//  InnerAssembly — the single entry point for InnerSolution construction
311// ═══════════════════════════════════════════════════════════════════════════
312
313/// All ingredients needed to assemble an `InnerSolution`.
314///
315/// Callers fill in the required fields and override optional ones as needed.
316/// The assembler builds the `InnerSolution` via `InnerSolutionBuilder` and
317/// calls `reml_laml_evaluate` — the only production code path that does so.
318pub struct InnerAssembly<'dp> {
319    // === Required core ===
320    pub log_likelihood: f64,
321    pub penalty_quadratic: f64,
322    pub beta: Array1<f64>,
323    pub n_observations: usize,
324    pub hessian_op: std::sync::Arc<dyn HessianOperator>,
325    pub penalty_coords: Vec<PenaltyCoordinate>,
326    pub penalty_logdet: PenaltyLogdetDerivs,
327    pub dispersion: DispersionHandling,
328    pub rho_curvature_scale: f64,
329    pub rho_prior: gam_problem::RhoPrior,
330    pub hessian_logdet_correction: f64,
331    pub penalty_subspace_trace: Option<Arc<PenaltySubspaceTrace>>,
332
333    // === Optional decorations (sensible defaults when None/zero) ===
334    pub deriv_provider: Option<Box<dyn HessianDerivativeProvider + 'dp>>,
335    /// Jeffreys/Firth scalar contribution to the LAML cost. Tier-A GLM callers
336    /// construct it from the dense operator (`ExactJeffreysTerm::new`); the
337    /// Tier-B coupled joint path installs the value-only carrier
338    /// (`ExactJeffreysTerm::value_only`) so the cost subtracts the same gated
339    /// `Φ(β̂)` its inner Newton optimized (gam#979).
340    pub firth: Option<crate::estimate::reml::reml_outer_engine::ExactJeffreysTerm>,
341    pub nullspace_dim: Option<f64>,
342    pub barrier_config: Option<BarrierConfig>,
343    pub kkt_residual: Option<ProjectedKktResidual>,
344    /// Active linear-inequality constraint rows at the converged inner
345    /// iterate. When `Some`, the unified evaluator builds the
346    /// constraint-aware kernel `K_T = K_S − K_S Aᵀ (A K_S Aᵀ)⁻¹ A K_S`
347    /// for per-coordinate mode responses `v_k = ∂β/∂ρ_k`.
348    pub active_constraints: Option<Arc<crate::model_types::ActiveLinearConstraintBlock>>,
349
350    // === Extended hyperparameter coordinates ===
351    pub ext_coords: Vec<HyperCoord>,
352    pub ext_coord_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
353    pub rho_ext_pair_fn: Option<Box<dyn Fn(usize, usize) -> HyperCoordPair + Send + Sync>>,
354    pub fixed_drift_deriv: Option<FixedDriftDerivFn>,
355    /// Direction-contracted ψψ second-order hook (#740). When set, the
356    /// outer-Hessian operator builder skips the `K²` per-pair ψψ assembly and
357    /// applies this once per matvec.
358    pub contracted_psi_second_order: Option<ContractedPsiSecondOrderFn>,
359}
360
361impl<'dp> InnerAssembly<'dp> {
362    /// Build the `InnerSolution` from these ingredients.
363    pub fn build(self) -> InnerSolution<'dp> {
364        let mut builder = InnerSolutionBuilder::new(
365            self.log_likelihood,
366            self.penalty_quadratic,
367            self.beta,
368            self.n_observations,
369            self.hessian_op,
370            self.penalty_coords,
371            self.penalty_logdet,
372            self.dispersion,
373        );
374        builder = builder.rho_curvature_scale(self.rho_curvature_scale);
375        builder = builder.rho_prior(self.rho_prior);
376        builder = builder.hessian_logdet_correction(self.hessian_logdet_correction);
377        builder = builder.penalty_subspace_trace(self.penalty_subspace_trace);
378
379        if let Some(dp) = self.deriv_provider {
380            builder = builder.deriv_provider(dp);
381        }
382        builder = builder.firth_term(self.firth);
383        if let Some(nd) = self.nullspace_dim {
384            builder = builder.nullspace_dim_override(nd);
385        }
386        builder = builder.barrier_config(self.barrier_config);
387        builder = builder.kkt_residual(self.kkt_residual);
388        builder = builder.active_constraints(self.active_constraints);
389
390        if !self.ext_coords.is_empty() {
391            builder = builder.ext_coords(self.ext_coords);
392        }
393        if let Some(f) = self.ext_coord_pair_fn {
394            builder = builder.ext_coord_pair_fn(f);
395        }
396        if let Some(f) = self.rho_ext_pair_fn {
397            builder = builder.rho_ext_pair_fn(f);
398        }
399        if let Some(f) = self.fixed_drift_deriv {
400            builder = builder.fixed_drift_deriv(f);
401        }
402        builder = builder.contracted_psi_second_order(self.contracted_psi_second_order);
403
404        builder.build()
405    }
406
407    /// Build and evaluate in one step.
408    pub fn evaluate(
409        self,
410        rho: &[f64],
411        mode: EvalMode,
412        prior: Option<(f64, Array1<f64>, Option<Array2<f64>>)>,
413    ) -> Result<RemlLamlResult, String> {
414        let solution = self.build();
415        reml_laml_evaluate(&solution, rho, mode, prior)
416    }
417}
418
419/// Evaluate a pre-built `InnerSolution` through the unified evaluator.
420///
421/// Use this when the caller needs the `InnerSolution` to outlive the evaluation
422/// (e.g., for EFS step computation after evaluation). Prefer
423/// [`InnerAssembly::evaluate`] when the solution is not needed afterwards.
424pub fn evaluate_solution(
425    solution: &InnerSolution<'_>,
426    rho: &[f64],
427    mode: EvalMode,
428    prior: Option<(f64, Array1<f64>, Option<Array2<f64>>)>,
429) -> Result<RemlLamlResult, String> {
430    reml_laml_evaluate(solution, rho, mode, prior)
431}
432
433// ═══════════════════════════════════════════════════════════════════════════
434//  Penalty coordinate helpers for family modules
435// ═══════════════════════════════════════════════════════════════════════════
436
/// Descriptor for a single penalty block within the parameter vector.
///
/// `penalty_coords_from_blocks` takes the root of `matrix` via
/// `penalty_matrix_root` and passes it, together with the two range bounds, to
/// `PenaltyCoordinate::from_block_root`.
pub struct PenaltyBlockDesc<'a> {
    /// Borrowed penalty matrix for this block.
    pub matrix: &'a Array2<f64>,
    /// First bound of the block's coefficient range.
    // NOTE(review): presumably a half-open `range_start..range_end` index range
    // into the full parameter vector — confirm against `from_block_root`.
    pub range_start: usize,
    /// Second bound of the block's coefficient range.
    pub range_end: usize,
}
443
444/// Build `PenaltyCoordinate`s from block descriptors.
445///
446/// Replaces the manual `penalty_matrix_root` + `from_block_root` loops
447/// in `survival.rs` and `custom_family.rs`.
448pub fn penalty_coords_from_blocks(
449    blocks: &[PenaltyBlockDesc],
450    total_dim: usize,
451) -> Result<Vec<PenaltyCoordinate>, String> {
452    blocks
453        .iter()
454        .map(|b| {
455            let root = penalty_matrix_root(b.matrix)?;
456            Ok(PenaltyCoordinate::from_block_root(
457                root,
458                b.range_start,
459                b.range_end,
460                total_dim,
461            ))
462        })
463        .collect()
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array2;

    /// Run `op` inside a dedicated two-thread rayon pool.
    ///
    /// The streamed (row-chunked) code paths are only selected when
    /// `rayon::current_num_threads() > 1`. Relying on the ambient pool would
    /// make the "large scale" tests silently exercise the GEMM fallback on a
    /// single-threaded runner, so those tests pin the thread count here.
    fn on_two_thread_pool<R: Send>(op: impl FnOnce() -> R + Send) -> R {
        rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build()
            .expect("building a two-thread rayon pool for tests should succeed")
            .install(op)
    }

    /// Assert that two matrices have the same shape and agree entry-wise
    /// within the given absolute (`epsilon`) / relative tolerances.
    pub(crate) fn assert_matrix_close(
        got: &Array2<f64>,
        expected: &Array2<f64>,
        epsilon: f64,
        max_relative: f64,
    ) {
        assert_eq!(got.dim(), expected.dim());
        for ((i, j), &value) in got.indexed_iter() {
            assert_relative_eq!(
                value,
                expected[[i, j]],
                epsilon = epsilon,
                max_relative = max_relative
            );
        }
    }

    /// Reproducible n×p test matrix built from sines and cosines of the indices.
    pub(crate) fn deterministic_matrix(n: usize, p: usize, phase: f64) -> Array2<f64> {
        Array2::from_shape_fn((n, p), |(i, j)| {
            let a = ((i as f64 + 1.0) * (j as f64 + 3.0) + phase).sin();
            let b = ((i as f64 + 5.0) / (j as f64 + 2.0) + phase).cos();
            0.25 * a + 0.75 * b
        })
    }

    /// Reproducible weights: every 17th entry is exactly zero (to hit the
    /// zero-weight skip), all others are at least 0.2.
    pub(crate) fn deterministic_weights(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            if i % 17 == 0 {
                0.0
            } else {
                0.2 + ((i as f64 + 1.0) * 0.013).sin().abs()
            }
        })
    }

    /// Plain triple-loop reference for `leftᵀ diag(weights) right`, with no
    /// zero skipping and no chunking.
    pub(crate) fn weighted_cross_reference(
        left: &Array2<f64>,
        right: &Array2<f64>,
        weights: &Array1<f64>,
    ) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((left.ncols(), right.ncols()));
        for i in 0..weights.len() {
            for a in 0..left.ncols() {
                let scaled = weights[i] * left[[i, a]];
                for b in 0..right.ncols() {
                    out[[a, b]] += scaled * right[[i, b]];
                }
            }
        }
        out
    }

    #[test]
    pub(crate) fn row_scale_dense_into_reuses_buffer_and_matches_reference() {
        let x = deterministic_matrix(37, 11, 0.3);
        let weights = deterministic_weights(x.nrows());
        let mut out = Array2::<f64>::zeros(x.raw_dim());
        let ptr = out.as_ptr();
        row_scale_dense_into(&x, &weights, &mut out);
        // Same shape going in, so the allocation must not have been replaced.
        assert_eq!(out.as_ptr(), ptr);
        for i in 0..x.nrows() {
            for j in 0..x.ncols() {
                assert_relative_eq!(out[[i, j]], x[[i, j]] * weights[i], epsilon = 0.0);
            }
        }
    }

    #[test]
    pub(crate) fn weighted_cross_dense_matches_rowwise_reference_at_large_scale_block_size() {
        let left = deterministic_matrix(2048, 96, 0.1);
        let right = deterministic_matrix(2048, 64, 0.7);
        let weights = deterministic_weights(left.nrows());
        // 2048 · 96 · 64 flops exceeds the streaming threshold; the pinned
        // pool makes sure the streamed path is the one under test.
        let got = on_two_thread_pool(|| weighted_cross_dense(&left, &right, &weights));
        let expected = weighted_cross_reference(&left, &right, &weights);
        assert_matrix_close(&got, &expected, 5e-10, 5e-12);
    }

    #[test]
    pub(crate) fn xt_diag_x_dense_into_matches_symmetric_reference_at_large_scale_block_size() {
        let x = deterministic_matrix(1024, 96, 1.1);
        let weights = deterministic_weights(x.nrows());
        let mut scratch = Array2::<f64>::zeros((0, 0));
        // 1024 · 96 · 96 flops exceeds the streaming threshold; the pinned
        // pool makes sure the streamed path is the one under test.
        let got = on_two_thread_pool(|| xt_diag_x_dense_into(&x, &weights, &mut scratch));
        let expected = weighted_cross_reference(&x, &x, &weights);
        assert_matrix_close(&got, &expected, 3e-10, 5e-12);
        // The streamed path mirrors the accumulated upper triangle, so the
        // result must be symmetric bit-for-bit.
        for i in 0..got.nrows() {
            for j in 0..got.ncols() {
                assert_relative_eq!(got[[i, j]], got[[j, i]], epsilon = 0.0);
            }
        }
    }
}