Skip to main content

gam_models/
marginal_slope_shared.rs

1//! Shared kernels and outer-evaluation infrastructure for the
2//! marginal-slope family of GAMs (BMS, survival, latent survival).
3//!
4//! # Outer-row subsampling
5//!
6//! At large scale (n ≥ tens of thousands) the outer rho-gradient is
7//! a sum-over-rows trace whose per-row cost is dominated by the cubic
8//! cell-moment kernel. The pieces here — [`AutoOuterSubsampleOptions`],
9//! [`auto_outer_score_subsample`], [`maybe_install_auto_outer_subsample`],
10//! and [`build_outer_score_subsample`] — implement a stratified
11//! Horvitz–Thompson estimator that replaces the full row sum with an
12//! unbiased sample, gated by
13//! [`crate::custom_family::BlockwiseFitOptions::auto_outer_subsample`]
14//! and enabled by default for large marginal-slope fits.
15//!
16//! `maybe_install_auto_outer_subsample` is the entry point family
17//! impls call: it consults the per-family phase counter and the
18//! per-family last-ρ mutex (used to detect distinct outer steps),
19//! installs a stratified mask for the first
20//! `BMS_AUTO_SUBSAMPLE_PHASE1_BUDGET` (or family analog) outer
21//! evaluations, and reverts to full data afterward so the BFGS/ARC
22//! convergence target `outer_tol` is reached on exact gradients
23//! rather than chasing the stochastic noise floor.
24//!
25//! This subsampling is **complementary** to the trace-estimator tier
26//! system documented at the top of `solver::reml::reml_outer_engine` (exact /
27//! Hutchinson multi-target / Hutch++ single-target). They operate on
28//! orthogonal axes — the trace estimators reduce work *within* the
29//! Hessian structure for a fixed row set; subsampling reduces the row
30//! set itself for the family-specific row-trace path.
31
32use crate::custom_family::{CustomFamilyBlockPsiDerivative, ParameterBlockSpec};
33use crate::cubic_cell_kernel::{self, DenestedPartitionCell, LocalSpanCubic};
34use crate::outer_subsample::{OuterScoreSubsample, WeightedOuterRow};
35use gam_math::jet_partitions::MultiDirJet;
36use ndarray::{Array1, Array2, Axis};
37use std::ops::Range;
38use std::sync::Arc;
39
40/// Canonical inner-cache `beta_seed` validator passed to the generic
41/// outer-engine (`optimize_spatial_length_scale_exact_joint`).
42///
43/// The outer solver hands back the converged inner `beta` at each accepted
44/// ρ-step so the next inner solve can warm-start from it. This guards that
45/// cached vector for non-finite entries (which would poison the warm start)
46/// and, when clean, stashes it into the caller's `pending` cell.
47///
48/// This is the single source of truth for the seed callback: every family
49/// that wires up the exact-joint outer engine (survival location-scale,
50/// bernoulli marginal-slope, survival marginal-slope) routes through here
51/// instead of re-deriving the identical closure, which previously drifted in
52/// error construction (`EstimationError::InvalidInput(...)` vs
53/// `bail_invalid_estim!`).
54pub fn make_beta_seed_validator(
55    pending: &std::cell::RefCell<Option<Array1<f64>>>,
56) -> impl FnMut(
57    &Array1<f64>,
58)
59    -> Result<gam_solve::rho_optimizer::SeedOutcome, crate::model_types::EstimationError>
60+ '_ {
61    move |beta: &Array1<f64>| {
62        bail_if_cached_beta_non_finite(beta)?;
63        // Stage the seed for promotion at the next eval, where the freshly
64        // built per-block widths are known. A width mismatch is reconciled
65        // there (the eval's `from_cached_beta` logs and falls back to a cold
66        // β for that step) — never an error that aborts the fit. Staging a
67        // finite β always succeeds, so the contract reply is `Installed`.
68        pending.replace(Some(beta.clone()));
69        Ok(gam_solve::rho_optimizer::SeedOutcome::Installed)
70    }
71}
72
73/// Canonical non-finite guard on a cached inner `beta`.
74///
75/// Single source of truth for the `"cached inner beta contains non-finite
76/// entries"` check + error: the full seed closure
77/// ([`make_beta_seed_validator`]) and the bare warm-start length-then-finite
78/// guards in `custom_family` all route through this so the predicate and the
79/// error construction (`EstimationError::InvalidInput`) never drift apart.
80pub use gam_problem::bail_if_cached_beta_non_finite;
81
/// Evaluate the cubic `c0 + c1·z + c2·z² + c3·z³` at `z`, where
/// `coefficients = [c0, c1, c2, c3]`.
///
/// Uses Horner's scheme `((c3·z + c2)·z + c1)·z + c0`. The operations run in exactly
/// that order, so the result is bit-reproducible.
#[inline]
pub const fn eval_coeff4_at(coefficients: &[f64; 4], z: f64) -> f64 {
    let [c0, c1, c2, c3] = *coefficients;
    let mut acc = c3;
    acc = acc * z + c2;
    acc = acc * z + c1;
    acc * z + c0
}
86
/// Axpy on 4-term coefficient vectors: `target[j] += scale * source[j]` for every `j`.
#[inline]
pub fn add_scaled_coeff4(target: &mut [f64; 4], source: &[f64; 4], scale: f64) {
    for (dst, &src) in target.iter_mut().zip(source.iter()) {
        *dst += scale * src;
    }
}
93
/// Dot product of two 4-term coefficient vectors.
///
/// Summed left to right (`l0·r0 + l1·r1 + l2·r2 + l3·r3`) so rounding matches the
/// straightforward expansion.
#[inline]
fn coeff4_dot(left: &[f64; 4], right: &[f64; 4]) -> f64 {
    let [l0, l1, l2, l3] = *left;
    let [r0, r1, r2, r3] = *right;
    l0 * r0 + l1 * r1 + l2 * r2 + l3 * r3
}
98
/// Return a copy of the 4-term coefficient vector `source` with each entry multiplied
/// by `scale`.
#[inline]
pub const fn scale_coeff4(source: [f64; 4], scale: f64) -> [f64; 4] {
    let [c0, c1, c2, c3] = source;
    [c0 * scale, c1 * scale, c2 * scale, c3 * scale]
}
108
109pub fn probit_frailty_scale(gaussian_frailty_sd: Option<f64>) -> f64 {
110    let sigma = gaussian_frailty_sd.unwrap_or(0.0);
111    if sigma <= 0.0 {
112        1.0
113    } else {
114        crate::survival::lognormal_kernel::ProbitFrailtyScaleJet::from_log_sigma(
115            sigma.ln(),
116        )
117        .s
118    }
119}
120
121pub(crate) fn probit_frailty_scale_multi_dir_jet(
122    gaussian_frailty_sd: Option<f64>,
123    missing_sigma_message: &str,
124    n_dirs: usize,
125    first_masks: &[usize],
126    second_masks: &[usize],
127) -> Result<MultiDirJet, String> {
128    let sigma = gaussian_frailty_sd.ok_or_else(|| missing_sigma_message.to_string())?;
129    let jet = crate::survival::lognormal_kernel::ProbitFrailtyScaleJet::from_log_sigma(
130        sigma.ln(),
131    );
132    let mut coeffs = Vec::with_capacity(1 + first_masks.len() + second_masks.len());
133    coeffs.push((0usize, jet.s));
134    coeffs.extend(first_masks.iter().copied().map(|mask| (mask, jet.ds)));
135    coeffs.extend(second_masks.iter().copied().map(|mask| (mask, jet.d2s)));
136    Ok(MultiDirJet::with_coeffs(n_dirs, &coeffs))
137}
138
139/// Per-sweep scale jets for the shared directional obj/grad/hess kernel.
140///
141/// Every marginal-slope family forms its exact-Newton primary terms by
142/// differentiating the same row negative-log directional jet
143/// (`row_neglog_directional_with_scale_jet`) along unit primary directions.
144/// The sweep appends one unit direction for the gradient pass and two for the
145/// Hessian pass on top of a fixed *leading* prefix of directions, scaling the
146/// frailty kernel with an order-matched [`MultiDirJet`] each time. The `obj`
147/// slot is `Some` only when the caller also wants the zeroth-order objective
148/// (the prefix-only evaluation); psi-Hessian directional sweeps leave it `None`.
149#[derive(Clone)]
150pub(crate) struct DirectionalScaleJets {
151    pub(crate) obj: Option<MultiDirJet>,
152    pub(crate) grad: MultiDirJet,
153    pub(crate) hess: MultiDirJet,
154}
155
156/// Output of [`directional_obj_grad_hess`]: the (optional) zeroth-order
157/// objective, the full primary gradient, and the symmetric primary Hessian.
158pub(crate) struct DirectionalPrimaryTerms {
159    pub(crate) objective: f64,
160    pub(crate) grad: Array1<f64>,
161    pub(crate) hess: Array2<f64>,
162}
163
164/// Shared exact-Newton primary directional sweep for the marginal-slope
165/// families (Bernoulli, survival, latent survival).
166///
167/// Given a fixed `leading` prefix of directions and a family-specific row jet
168/// evaluator `eval`, this builds the objective (when requested), the gradient
169/// `g_a = D[leading, e_a] φ`, and the symmetric Hessian
170/// `H_ab = D[leading, e_a, e_b] φ`, where `e_a` is the `a`-th unit primary
171/// direction (length `primary_dim`) and `D[..]` is the mixed directional
172/// derivative the row jet returns. `eval(dirs, scale)` must return the highest
173/// mixed-partial coefficient of the row negative-log jet for the supplied
174/// directions and scale jet — exactly what each family's
175/// `row_neglog_directional_with_scale_jet` produces.
176///
177/// Centralizing the sweep removes the per-family duplication of the
178/// obj/grad/hess loop nest, which is the single most drift-prone piece of the
179/// exact-Newton stack: a stray index or a missing symmetric assignment in one
180/// copy silently destabilizes only that family's optimizer.
181pub(crate) fn directional_obj_grad_hess<Eval>(
182    primary_dim: usize,
183    leading: &[&Array1<f64>],
184    scales: &DirectionalScaleJets,
185    eval: Eval,
186) -> Result<DirectionalPrimaryTerms, String>
187where
188    Eval: Fn(&[&Array1<f64>], &MultiDirJet) -> Result<f64, String>,
189{
190    let objective = if let Some(scale_obj) = scales.obj.as_ref() {
191        eval(leading, scale_obj)?
192    } else {
193        0.0
194    };
195
196    let unit = |a: usize| -> Array1<f64> {
197        let mut da = Array1::<f64>::zeros(primary_dim);
198        da[a] = 1.0;
199        da
200    };
201
202    let units: Vec<Array1<f64>> = (0..primary_dim).map(unit).collect();
203
204    let mut grad = Array1::<f64>::zeros(primary_dim);
205    let mut dirs: Vec<&Array1<f64>> = Vec::with_capacity(leading.len() + 2);
206    for a in 0..primary_dim {
207        dirs.clear();
208        dirs.extend_from_slice(leading);
209        dirs.push(&units[a]);
210        grad[a] = eval(&dirs, &scales.grad)?;
211    }
212
213    let mut hess = Array2::<f64>::zeros((primary_dim, primary_dim));
214    for a in 0..primary_dim {
215        for b in a..primary_dim {
216            dirs.clear();
217            dirs.extend_from_slice(leading);
218            dirs.push(&units[a]);
219            dirs.push(&units[b]);
220            let value = eval(&dirs, &scales.hess)?;
221            hess[[a, b]] = value;
222            hess[[b, a]] = value;
223        }
224    }
225
226    Ok(DirectionalPrimaryTerms {
227        objective,
228        grad,
229        hess,
230    })
231}
232
233fn zero_local_span_cubic() -> LocalSpanCubic {
234    LocalSpanCubic {
235        left: 0.0,
236        right: 1.0,
237        c0: 0.0,
238        c1: 0.0,
239        c2: 0.0,
240        c3: 0.0,
241    }
242}
243
244pub(crate) fn build_denested_partition_cells(
245    a: f64,
246    b: f64,
247    score_warp: Option<&crate::bms::DeviationRuntime>,
248    beta_h: Option<&Array1<f64>>,
249    link_dev: Option<&crate::bms::DeviationRuntime>,
250    beta_w: Option<&Array1<f64>>,
251    scale: f64,
252) -> Result<Vec<DenestedPartitionCell>, String> {
253    let score_breaks = score_warp
254        .map(|runtime| runtime.breakpoints().to_vec())
255        .unwrap_or_default();
256    let link_breaks = link_dev
257        .map(|runtime| runtime.breakpoints().to_vec())
258        .unwrap_or_default();
259
260    let mut cells = cubic_cell_kernel::build_denested_partition_cells_with_tails(
261        a,
262        b,
263        &score_breaks,
264        &link_breaks,
265        |z| {
266            if let (Some(runtime), Some(beta)) = (score_warp, beta_h) {
267                runtime.local_cubic_at(beta, z)
268            } else {
269                Ok(zero_local_span_cubic())
270            }
271        },
272        |u| {
273            if let (Some(runtime), Some(beta)) = (link_dev, beta_w) {
274                runtime.local_cubic_at(beta, u)
275            } else {
276                Ok(zero_local_span_cubic())
277            }
278        },
279    )?;
280    if scale != 1.0 {
281        for partition_cell in &mut cells {
282            partition_cell.cell.c0 *= scale;
283            partition_cell.cell.c1 *= scale;
284            partition_cell.cell.c2 *= scale;
285            partition_cell.cell.c3 *= scale;
286        }
287    }
288    Ok(cells)
289}
290
/// Cubic coefficients of the de-nested cell containing the observed point, together
/// with their partials in `(a, b)` up to third order.
///
/// Every field is already multiplied by the `scale` passed to
/// `observed_denested_cell_partials`.
pub(crate) struct ObservedDenestedCellPartials {
    /// Cell cubic coefficients `[c0, c1, c2, c3]`.
    pub(crate) coeff: [f64; 4],
    /// First partials of `coeff` with respect to `a` and `b`.
    pub(crate) dc_da: [f64; 4],
    pub(crate) dc_db: [f64; 4],
    /// Second partials: `aa`, `ab`, `bb`.
    pub(crate) dc_daa: [f64; 4],
    pub(crate) dc_dab: [f64; 4],
    pub(crate) dc_dbb: [f64; 4],
    /// Third partials: `aaa`, `aab`, `abb`, `bbb`.
    pub(crate) dc_daaa: [f64; 4],
    pub(crate) dc_daab: [f64; 4],
    pub(crate) dc_dabb: [f64; 4],
    pub(crate) dc_dbbb: [f64; 4],
}
303
304pub(crate) fn observed_denested_cell_partials(
305    z_obs: f64,
306    a: f64,
307    b: f64,
308    score_warp: Option<&crate::bms::DeviationRuntime>,
309    beta_h: Option<&Array1<f64>>,
310    link_dev: Option<&crate::bms::DeviationRuntime>,
311    beta_w: Option<&Array1<f64>>,
312    scale: f64,
313) -> Result<ObservedDenestedCellPartials, String> {
314    let zero_score_span = zero_local_span_cubic();
315    let zero_link_span = zero_local_span_cubic();
316    let u_obs = a + b * z_obs;
317    let score_span_obs = if let (Some(runtime), Some(beta_h)) = (score_warp, beta_h) {
318        runtime.local_cubic_at(beta_h, z_obs)?
319    } else {
320        zero_score_span
321    };
322    let link_span_obs = if let (Some(runtime), Some(beta_w)) = (link_dev, beta_w) {
323        runtime.local_cubic_at(beta_w, u_obs)?
324    } else {
325        zero_link_span
326    };
327    let coeff = scale_coeff4(
328        cubic_cell_kernel::denested_cell_coefficients(score_span_obs, link_span_obs, a, b),
329        scale,
330    );
331    let (dc_da_raw, dc_db_raw) =
332        cubic_cell_kernel::denested_cell_coefficient_partials(score_span_obs, link_span_obs, a, b);
333    let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) =
334        cubic_cell_kernel::denested_cell_second_partials(score_span_obs, link_span_obs, a, b);
335    let (dc_daaa, dc_daab, dc_dabb, dc_dbbb) =
336        cubic_cell_kernel::denested_cell_third_partials(link_span_obs);
337    Ok(ObservedDenestedCellPartials {
338        coeff,
339        dc_da: scale_coeff4(dc_da_raw, scale),
340        dc_db: scale_coeff4(dc_db_raw, scale),
341        dc_daa: scale_coeff4(dc_daa_raw, scale),
342        dc_dab: scale_coeff4(dc_dab_raw, scale),
343        dc_dbb: scale_coeff4(dc_dbb_raw, scale),
344        dc_daaa: scale_coeff4(dc_daaa, scale),
345        dc_daab: scale_coeff4(dc_daab, scale),
346        dc_dabb: scale_coeff4(dc_dabb, scale),
347        dc_dbbb: scale_coeff4(dc_dbbb, scale),
348    })
349}
350
351pub(crate) fn add_two_surface_psi_outer(
352    block_i: usize,
353    psi_row_i: &Array1<f64>,
354    block_j: usize,
355    psi_row_j: &Array1<f64>,
356    alpha: f64,
357    marginal_block: usize,
358    logslope_block: usize,
359    h_mm: &mut Array2<f64>,
360    h_gg: &mut Array2<f64>,
361    h_mg: &mut Array2<f64>,
362) {
363    if alpha == 0.0 {
364        return;
365    }
366    let col_i = psi_row_i.view().insert_axis(Axis(1));
367    let row_j = psi_row_j.view().insert_axis(Axis(0));
368
369    if block_i == block_j {
370        let col_j = psi_row_j.view().insert_axis(Axis(1));
371        let row_i = psi_row_i.view().insert_axis(Axis(0));
372        let target = match block_i {
373            b if b == marginal_block => h_mm,
374            b if b == logslope_block => h_gg,
375            _ => return,
376        };
377        ndarray::linalg::general_mat_mul(alpha, &col_i, &row_j, 1.0, target);
378        ndarray::linalg::general_mat_mul(alpha, &col_j, &row_i, 1.0, target);
379    } else {
380        let (marginal_row, logslope_row) = if block_i == marginal_block {
381            (psi_row_i, psi_row_j)
382        } else {
383            (psi_row_j, psi_row_i)
384        };
385        let m_col = marginal_row.view().insert_axis(Axis(1));
386        let g_row = logslope_row.view().insert_axis(Axis(0));
387        ndarray::linalg::general_mat_mul(alpha, &m_col, &g_row, 1.0, h_mg);
388    }
389}
390
391pub(crate) fn add_optional_vector(left: &mut Option<Array1<f64>>, right: &Option<Array1<f64>>) {
392    if let (Some(left), Some(right)) = (left.as_mut(), right.as_ref()) {
393        *left += right;
394    }
395}
396
397pub(crate) fn add_optional_matrix(left: &mut Option<Array2<f64>>, right: &Option<Array2<f64>>) {
398    if let (Some(left), Some(right)) = (left.as_mut(), right.as_ref()) {
399        *left += right;
400    }
401}
402
403pub(crate) fn psi_derivative_location(
404    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
405    psi_index: usize,
406) -> Option<(usize, usize)> {
407    let mut cursor = 0usize;
408    for (block_idx, block) in derivative_blocks.iter().enumerate() {
409        if psi_index < cursor + block.len() {
410            return Some((block_idx, psi_index - cursor));
411        }
412        cursor += block.len();
413    }
414    None
415}
416
417pub(crate) fn is_sigma_aux_index(
418    gaussian_frailty_sd: Option<f64>,
419    derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
420    psi_index: usize,
421) -> bool {
422    let total = derivative_blocks.iter().map(Vec::len).sum::<usize>();
423    if gaussian_frailty_sd.is_none() || total == 0 || psi_index != total - 1 {
424        return false;
425    }
426    let Some((block_idx, local_idx)) = psi_derivative_location(derivative_blocks, psi_index) else {
427        return false;
428    };
429    let deriv = &derivative_blocks[block_idx][local_idx];
430    deriv.penalty_index.is_none()
431        && deriv.x_psi.is_empty()
432        && deriv.s_psi.is_empty()
433        && deriv.s_psi_components.is_none()
434        && deriv.x_psi_psi.is_none()
435        && deriv.s_psi_psi.is_none()
436}
437
438/// Predicate used by every marginal-slope family's persistent-warm-start
439/// fingerprint guard: the caller's parameter blocks must each have row count
440/// matching the family's `n`, and the list must be non-empty.
441#[inline]
442pub(crate) fn parameter_block_specs_match_rows(
443    specs: &[ParameterBlockSpec],
444    expected_n: usize,
445) -> bool {
446    !specs.is_empty()
447        && specs
448            .iter()
449            .all(|spec| spec.design.nrows() == expected_n && spec.offset.len() == expected_n)
450}
451
/// Which coefficient groups a `SparsePrimaryCoeffJetView` contraction includes.
#[derive(Clone, Copy)]
pub(crate) struct CoeffSupport {
    /// Include the single primary coefficient.
    pub(crate) include_primary: bool,
    /// Include the coefficients in the view's `h` range.
    pub(crate) include_h: bool,
    /// Include the coefficients in the view's `w` range.
    pub(crate) include_w: bool,
}

impl CoeffSupport {
    /// Copy of `self` with the primary coefficient masked out. The `h`/`w` flags are
    /// kept as they are.
    #[inline]
    pub(crate) fn without_primary(self) -> Self {
        let mut masked = self;
        masked.include_primary = false;
        masked
    }
}
468
469pub(crate) struct SparsePrimaryCoeffJetView<'a> {
470    primary_index: usize,
471    h_range: Option<Range<usize>>,
472    w_range: Option<Range<usize>>,
473    pub(crate) first: &'a [[f64; 4]],
474    pub(crate) a_first: &'a [[f64; 4]],
475    pub(crate) b_first: &'a [[f64; 4]],
476    pub(crate) aa_first: &'a [[f64; 4]],
477    pub(crate) ab_first: &'a [[f64; 4]],
478    pub(crate) bb_first: &'a [[f64; 4]],
479    pub(crate) aaa_first: &'a [[f64; 4]],
480    pub(crate) aab_first: &'a [[f64; 4]],
481    pub(crate) abb_first: &'a [[f64; 4]],
482    pub(crate) bbb_first: &'a [[f64; 4]],
483}
484
485impl<'a> SparsePrimaryCoeffJetView<'a> {
486    pub(crate) fn new(
487        primary_index: usize,
488        h_range: Option<&Range<usize>>,
489        w_range: Option<&Range<usize>>,
490        first: &'a [[f64; 4]],
491        a_first: &'a [[f64; 4]],
492        b_first: &'a [[f64; 4]],
493        aa_first: &'a [[f64; 4]],
494        ab_first: &'a [[f64; 4]],
495        bb_first: &'a [[f64; 4]],
496        aaa_first: &'a [[f64; 4]],
497        aab_first: &'a [[f64; 4]],
498        abb_first: &'a [[f64; 4]],
499        bbb_first: &'a [[f64; 4]],
500    ) -> Self {
501        Self {
502            primary_index,
503            h_range: h_range.cloned(),
504            w_range: w_range.cloned(),
505            first,
506            a_first,
507            b_first,
508            aa_first,
509            ab_first,
510            bb_first,
511            aaa_first,
512            aab_first,
513            abb_first,
514            bbb_first,
515        }
516    }
517
518    #[inline]
519    fn in_h_range(&self, idx: usize) -> bool {
520        self.h_range
521            .as_ref()
522            .map(|range| range.contains(&idx))
523            .unwrap_or(false)
524    }
525
526    #[inline]
527    fn in_w_range(&self, idx: usize) -> bool {
528        self.w_range
529            .as_ref()
530            .map(|range| range.contains(&idx))
531            .unwrap_or(false)
532    }
533
534    #[inline]
535    fn param_supported(&self, idx: usize, support: CoeffSupport) -> bool {
536        (support.include_primary && idx == self.primary_index)
537            || (support.include_h && self.in_h_range(idx))
538            || (support.include_w && self.in_w_range(idx))
539    }
540
541    pub(crate) fn directional_family(
542        &self,
543        family: &[[f64; 4]],
544        dir: &Array1<f64>,
545        support: CoeffSupport,
546    ) -> [f64; 4] {
547        let mut out = [0.0; 4];
548        if support.include_primary {
549            add_scaled_coeff4(
550                &mut out,
551                &family[self.primary_index],
552                dir[self.primary_index],
553            );
554        }
555        if support.include_h
556            && let Some(h_range) = self.h_range.as_ref()
557        {
558            for idx in h_range.clone() {
559                add_scaled_coeff4(&mut out, &family[idx], dir[idx]);
560            }
561        }
562        if support.include_w
563            && let Some(w_range) = self.w_range.as_ref()
564        {
565            for idx in w_range.clone() {
566                add_scaled_coeff4(&mut out, &family[idx], dir[idx]);
567            }
568        }
569        out
570    }
571
572    pub(crate) fn add_directional_family_adjoint(
573        &self,
574        family: &[[f64; 4]],
575        coeff_adjoint: &[f64; 4],
576        support: CoeffSupport,
577        direction_adjoint: &mut [f64],
578    ) {
579        assert!(direction_adjoint.len() > self.primary_index);
580        if support.include_primary {
581            direction_adjoint[self.primary_index] +=
582                coeff4_dot(coeff_adjoint, &family[self.primary_index]);
583        }
584        if support.include_h
585            && let Some(h_range) = self.h_range.as_ref()
586        {
587            for idx in h_range.clone() {
588                direction_adjoint[idx] += coeff4_dot(coeff_adjoint, &family[idx]);
589            }
590        }
591        if support.include_w
592            && let Some(w_range) = self.w_range.as_ref()
593        {
594            for idx in w_range.clone() {
595                direction_adjoint[idx] += coeff4_dot(coeff_adjoint, &family[idx]);
596            }
597        }
598    }
599
600    pub(crate) fn mixed_directional_from_b_family(
601        &self,
602        family: &[[f64; 4]],
603        dir_u: &Array1<f64>,
604        dir_v: &Array1<f64>,
605        support: CoeffSupport,
606    ) -> [f64; 4] {
607        let mut out = [0.0; 4];
608        let dir_u_primary = dir_u[self.primary_index];
609        let dir_v_primary = dir_v[self.primary_index];
610        if support.include_primary {
611            add_scaled_coeff4(
612                &mut out,
613                &family[self.primary_index],
614                dir_u_primary * dir_v_primary,
615            );
616        }
617        if support.include_h
618            && let Some(h_range) = self.h_range.as_ref()
619        {
620            for idx in h_range.clone() {
621                add_scaled_coeff4(
622                    &mut out,
623                    &family[idx],
624                    dir_u_primary * dir_v[idx] + dir_v_primary * dir_u[idx],
625                );
626            }
627        }
628        if support.include_w
629            && let Some(w_range) = self.w_range.as_ref()
630        {
631            for idx in w_range.clone() {
632                add_scaled_coeff4(
633                    &mut out,
634                    &family[idx],
635                    dir_u_primary * dir_v[idx] + dir_v_primary * dir_u[idx],
636                );
637            }
638        }
639        out
640    }
641
642    pub(crate) fn param_directional_from_b_family(
643        &self,
644        family: &[[f64; 4]],
645        param: usize,
646        dir: &Array1<f64>,
647        support: CoeffSupport,
648    ) -> [f64; 4] {
649        if param == self.primary_index {
650            return self.directional_family(family, dir, support);
651        }
652        if self.param_supported(param, support.without_primary()) {
653            let mut out = [0.0; 4];
654            add_scaled_coeff4(&mut out, &family[param], dir[self.primary_index]);
655            return out;
656        }
657        [0.0; 4]
658    }
659
660    pub(crate) fn add_param_directional_from_b_family_adjoint(
661        &self,
662        family: &[[f64; 4]],
663        param: usize,
664        coeff_adjoint: &[f64; 4],
665        support: CoeffSupport,
666        direction_adjoint: &mut [f64],
667    ) {
668        assert!(direction_adjoint.len() > self.primary_index);
669        if param == self.primary_index {
670            self.add_directional_family_adjoint(family, coeff_adjoint, support, direction_adjoint);
671        } else if self.param_supported(param, support.without_primary()) {
672            direction_adjoint[self.primary_index] += coeff4_dot(coeff_adjoint, &family[param]);
673        }
674    }
675
676    pub(crate) fn param_mixed_from_bb_family(
677        &self,
678        family: &[[f64; 4]],
679        param: usize,
680        dir_u: &Array1<f64>,
681        dir_v: &Array1<f64>,
682        support: CoeffSupport,
683    ) -> [f64; 4] {
684        if param == self.primary_index {
685            return self.mixed_directional_from_b_family(family, dir_u, dir_v, support);
686        }
687        if self.param_supported(param, support.without_primary()) {
688            let mut out = [0.0; 4];
689            add_scaled_coeff4(
690                &mut out,
691                &family[param],
692                dir_u[self.primary_index] * dir_v[self.primary_index],
693            );
694            return out;
695        }
696        [0.0; 4]
697    }
698
699    pub(crate) fn pair_from_b_family(
700        &self,
701        family: &[[f64; 4]],
702        u: usize,
703        v: usize,
704        support: CoeffSupport,
705    ) -> [f64; 4] {
706        if u == self.primary_index && v == self.primary_index {
707            if support.include_primary {
708                return family[self.primary_index];
709            }
710            return [0.0; 4];
711        }
712        if u == self.primary_index && self.param_supported(v, support.without_primary()) {
713            return family[v];
714        }
715        if v == self.primary_index && self.param_supported(u, support.without_primary()) {
716            return family[u];
717        }
718        [0.0; 4]
719    }
720
721    pub(crate) fn pair_directional_from_bb_family(
722        &self,
723        family: &[[f64; 4]],
724        u: usize,
725        v: usize,
726        dir: &Array1<f64>,
727        support: CoeffSupport,
728    ) -> [f64; 4] {
729        if u == self.primary_index && v == self.primary_index {
730            return self.directional_family(family, dir, support);
731        }
732        if u == self.primary_index && self.param_supported(v, support.without_primary()) {
733            let mut out = [0.0; 4];
734            add_scaled_coeff4(&mut out, &family[v], dir[self.primary_index]);
735            return out;
736        }
737        if v == self.primary_index && self.param_supported(u, support.without_primary()) {
738            let mut out = [0.0; 4];
739            add_scaled_coeff4(&mut out, &family[u], dir[self.primary_index]);
740            return out;
741        }
742        [0.0; 4]
743    }
744
745    pub(crate) fn add_pair_directional_from_bb_family_adjoint(
746        &self,
747        family: &[[f64; 4]],
748        u: usize,
749        v: usize,
750        coeff_adjoint: &[f64; 4],
751        support: CoeffSupport,
752        direction_adjoint: &mut [f64],
753    ) {
754        assert!(direction_adjoint.len() > self.primary_index);
755        if u == self.primary_index && v == self.primary_index {
756            self.add_directional_family_adjoint(family, coeff_adjoint, support, direction_adjoint);
757        } else if u == self.primary_index && self.param_supported(v, support.without_primary()) {
758            direction_adjoint[self.primary_index] += coeff4_dot(coeff_adjoint, &family[v]);
759        } else if v == self.primary_index && self.param_supported(u, support.without_primary()) {
760            direction_adjoint[self.primary_index] += coeff4_dot(coeff_adjoint, &family[u]);
761        }
762    }
763
764    pub(crate) fn pair_mixed_from_bbb_family(
765        &self,
766        family: &[[f64; 4]],
767        u: usize,
768        v: usize,
769        dir_u: &Array1<f64>,
770        dir_v: &Array1<f64>,
771        support: CoeffSupport,
772    ) -> [f64; 4] {
773        if u == self.primary_index && v == self.primary_index {
774            return self.mixed_directional_from_b_family(family, dir_u, dir_v, support);
775        }
776        if u == self.primary_index && self.param_supported(v, support.without_primary()) {
777            let mut out = [0.0; 4];
778            add_scaled_coeff4(
779                &mut out,
780                &family[v],
781                dir_u[self.primary_index] * dir_v[self.primary_index],
782            );
783            return out;
784        }
785        if v == self.primary_index && self.param_supported(u, support.without_primary()) {
786            let mut out = [0.0; 4];
787            add_scaled_coeff4(
788                &mut out,
789                &family[u],
790                dir_u[self.primary_index] * dir_v[self.primary_index],
791            );
792            return out;
793        }
794        [0.0; 4]
795    }
796}
797
798// ---------------------------------------------------------------------------
799// Outer-only stratified row subsample (Phase 1 scaffolding).
800//
801// The large-scale outer-loop score/gradient passes do O(n) work per outer
802// evaluation, which dominates wall-clock once n grows past ~10^5. To keep
803// outer-loop iterations tractable while leaving the inner PIRLS solve and the
804// final covariance assembly untouched, outer-only hot loops can be redirected
// to iterate over a small stratified subsample, reweighted by per-stratum
// Horvitz–Thompson factors `N_h / K_h`, that is sampled once per fit and shared
// via `Arc`. The subsample is stratified by event/outcome × z-deciles (≤ 200
// strata) so that the reweighted estimator inherits the same support coverage
// as the full-data estimator.
809//
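// This section defines the types and helpers; Phase 2 wires them into the
// per-row hot loops. When no subsample is installed
// (`outer_score_subsample = None`), the outer passes keep the legacy full-data
// behavior bit-for-bit. Whether a subsample is installed automatically for
// large fits is governed by `BlockwiseFitOptions::auto_outer_subsample`.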
813
814/// Splitmix64: deterministic single-u64 expansion. Thin wrapper over the
815/// canonical implementation in [`gam_linalg::utils::splitmix64`].
816#[inline]
817const fn splitmix64(state: &mut u64) -> u64 {
818    gam_linalg::utils::splitmix64(state)
819}
820
/// Configuration for the automatic outer-score subsampler.
///
/// At large scale (n ≥ tens of thousands) the marginal-slope outer
/// rho-gradient computes a sum-over-rows trace
/// `tr(F Fᵀ M_k) = Σ_i row_i(k)`. Its per-row work is dominated by the
/// cell-moment kernel. Stratified Horvitz–Thompson subsampling replaces the
/// full sum with an unbiased estimator that uses `K` of `N` rows. The trace
/// cost then drops from `O(N · cell_work)` to `O(K · cell_work)`.
///
/// # Math
///
/// The estimator `T̂ = Σ_{i∈S} w_i · row_i`, with per-stratum HT weights
/// `w_i = N_h / K_h`, is unbiased: `E[T̂] = T`.
///
/// Variance under stratified SRS without replacement:
/// `Var(T̂) = Σ_h N_h² (1 − K_h/N_h) S_h² / K_h`
/// where `S_h²` is the within-stratum variance of per-row contributions.
/// With proportional allocation `K_h = K · N_h/N`, the standard deviation
/// of `T̂` relative to `T` is roughly
/// `σ(T̂)/T ≈ (1/√K) · √(1 − K/N) · cv_within`
/// where `cv_within` is the within-stratum coefficient of variation.
///
/// The noise-related defaults are tuned so that the relative gradient-noise
/// σ stays below ≈ 1 % across realistic `n` ∈ [30 000, 300 000+]. This
/// assumes `cv_within ≲ 1`, which holds for marginal-slope contributions
/// because the z-quantile stratification absorbs the dominant
/// inhomogeneity. A family-supplied `outer_work_per_k_unit` can
/// deliberately trade that noise target for speed.
#[derive(Clone, Debug)]
pub struct AutoOuterSubsampleOptions {
    /// Below this `n`, the auto-subsampler always returns `None` (use
    /// full data). Default 30 000.
    pub min_n_for_auto: usize,
    /// Floor on the *noise-only* subsample size:
    /// `K_noise = max(min_k, round(n · target_fraction))`. Default 10 000,
    /// which gives `σ/T ≤ 1 %` for `cv_within ≤ 1` at any
    /// `n ≥ min_n_for_auto`.
    ///
    /// This is not an absolute floor on the final `K`. The work-budget cap
    /// derived from `outer_work_per_k_unit` may push `K` below `min_k`, down
    /// to `min_k_floor`.
    pub min_k: usize,
    /// Target ratio `K / n` once `n ≫ min_k`. Default 0.10.
    pub target_fraction: f64,
    /// RNG seed for stratified mask construction. Default
    /// `0xA075_8A8B_1ED5_5B5C`.
    ///
    /// The mask is a deterministic function of the seed, the data and `K`.
    /// Repeated outer evaluations on the same data therefore reuse the same
    /// rows (common random numbers across BFGS iterations).
    pub seed: u64,
    /// Family-supplied outer-derivative work cost **per unit of `K`**.
    ///
    /// Despite the name, this is *not* a per-row cost over the full data.
    /// It is `predicted_outer_gradient_work / K`, evaluated at the family's
    /// reference operating point. In other words, it is the marginal work
    /// that each additional row in the `K`-subsample adds to one outer
    /// evaluation.
    ///
    /// The auto schedule caps `K` by
    /// `K_work = AUTO_OUTER_WORK_BUDGET / outer_work_per_k_unit`. A single
    /// outer evaluation therefore stays within [`AUTO_OUTER_WORK_BUDGET`]
    /// work units unless `min_k_floor` overrides the cap.
    ///
    /// Default `1`, which effectively means no work cap beyond `K ≤ n`.
    /// Families with a measurable per-K cost (survival marginal-slope, BMS)
    /// overwrite it at the call site. A value of `0` is treated as `1`.
    ///
    /// Calibration recipe: from a profiled run, take
    /// `outer_work_per_k_unit = predicted_gradient_work / K`. For the
    /// large-scale survival marginal-slope reference
    /// (`predicted_gradient_work ≈ 4.33×10⁹` at `K = 19_661`), this gives
    /// ~220 000. 250 000 is used as a conservative upper bound. With
    /// `AUTO_OUTER_WORK_BUDGET = 5×10⁸`, that caps `K` at 2 000.
    pub outer_work_per_k_unit: u64,
    /// Absolute floor on the chosen `K` after the noise and work caps are
    /// combined. Default [`AUTO_OUTER_MIN_K_FLOOR`].
    pub min_k_floor: usize,
}
890
/// Per-evaluation outer-derivative work budget: half a billion work units.
///
/// Chosen so that the rigid survival marginal-slope pilot Newton cycle
/// finishes in a minute or two on commodity hardware once `K` is capped by
/// this budget. That cycle previously ran ~57 min at n≈2e5 with `K=19_661`.
pub const AUTO_OUTER_WORK_BUDGET: u64 = 500 * 1_000_000;

/// Absolute floor on `K` chosen by the auto schedule.
///
/// Even when the work budget would drive `K` down to a handful of rows, the
/// stratified mask cannot usefully shrink below this floor without
/// collapsing whole `z` strata. The resulting gradient noise,
/// `1/√1000 ≈ 3 %`, is still usable for BFGS Phase 1 progress when the
/// family is very expensive.
pub const AUTO_OUTER_MIN_K_FLOOR: usize = 1_000;

/// L2 distance at or below which two outer ρ keys count as the *same* outer
/// step, i.e. a line-search retry rather than a fresh outer iteration.
///
/// The value is far below any meaningful BFGS step on log-scale ρ, and far
/// above the float noise introduced by cloning the ρ vector. It keeps the
/// phase-1 budget counting outer iterations instead of per-step function
/// evaluations.
const AUTO_OUTER_DISTINCT_STEP_L2_TOL: f64 = 1.0e-10;
910
/// Which constraint fixed the `K` reported by the auto schedule.
///
/// It appears as `cap_reason=…` in the `[family auto-subsample]` log line.
/// Operators can then see whether the noise model, the work budget, the
/// `MIN_K_FLOOR`, or `n` itself set the subsample size.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AutoOuterCapReason {
    /// The noise-only target `max(min_k, round(n · target_fraction))` won.
    Noise,
    /// The per-evaluation work budget capped `K` below the noise target.
    Work,
    /// `K` was raised to the absolute `min_k_floor`.
    Floor,
    /// `K` was clipped to the number of rows `n`.
    NFull,
}

impl AutoOuterCapReason {
    /// Short lowercase label used in log lines.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Noise => "noise",
            Self::Work => "work",
            Self::Floor => "floor",
            Self::NFull => "n",
        }
    }
}
933
934impl Default for AutoOuterSubsampleOptions {
935    fn default() -> Self {
936        Self {
937            min_n_for_auto: 30_000,
938            min_k: 10_000,
939            target_fraction: 0.10,
940            seed: 0xA075_8A8B_1ED5_5B5C,
941            outer_work_per_k_unit: 1,
942            min_k_floor: AUTO_OUTER_MIN_K_FLOOR,
943        }
944    }
945}
946
947/// Outcome of [`AutoOuterSubsampleOptions::target_k_detailed`]: the
948/// chosen `K`, the underlying noise-only choice, the work-budget cap,
949/// and which constraint won.
950#[derive(Clone, Copy, Debug)]
951pub struct AutoOuterKChoice {
952    pub k: usize,
953    pub k_noise: usize,
954    pub k_work: usize,
955    pub cap_reason: AutoOuterCapReason,
956}
957
958impl AutoOuterSubsampleOptions {
959    /// Compute the K that this configuration would pick for a given n.
960    /// Returns `None` if `n < min_n_for_auto` (caller should not subsample).
961    pub fn target_k(&self, n: usize) -> Option<usize> {
962        self.target_k_detailed(n).map(|choice| choice.k)
963    }
964
965    /// Same as [`target_k`] but also reports the noise-only `K`, the
966    /// work-budget cap, and which constraint set the final value. Used by
967    /// [`maybe_install_auto_outer_subsample`] to surface a `cap_reason`
968    /// in the auto-subsample log line.
969    pub fn target_k_detailed(&self, n: usize) -> Option<AutoOuterKChoice> {
970        if n < self.min_n_for_auto {
971            return None;
972        }
973        let k_noise_raw = ((n as f64) * self.target_fraction).round() as usize;
974        let k_noise = k_noise_raw.max(self.min_k);
975        // Work-budget cap. `outer_work_per_k_unit == 1` is the
976        // default-1-work-unit signal that the family has not measured
977        // its per-K cost, in which case the work cap is `WORK_BUDGET`
978        // and typically dominated by `n`.
979        let work_per_k = self.outer_work_per_k_unit.max(1);
980        let k_work_u64 = AUTO_OUTER_WORK_BUDGET / work_per_k;
981        let k_work = usize::try_from(k_work_u64).unwrap_or(usize::MAX);
982        // Combine noise + work + n + floor in a single comparison so we
983        // can attribute the binding constraint exactly once.
984        let mut k = k_noise.min(k_work);
985        let mut cap_reason = if k_work < k_noise {
986            AutoOuterCapReason::Work
987        } else {
988            AutoOuterCapReason::Noise
989        };
990        if k < self.min_k_floor {
991            k = self.min_k_floor;
992            cap_reason = AutoOuterCapReason::Floor;
993        }
994        if k > n {
995            k = n;
996            cap_reason = AutoOuterCapReason::NFull;
997        }
998        if k >= n {
999            // Borderline: the auto schedule would cover the whole
1000            // dataset. Subsampling buys nothing.
1001            return None;
1002        }
1003        Some(AutoOuterKChoice {
1004            k,
1005            k_noise,
1006            k_work,
1007            cap_reason,
1008        })
1009    }
1010}
1011
1012/// Build a stratified outer-score subsample automatically from problem
1013/// characteristics. Returns `None` for problems too small to benefit
1014/// (the caller should fall back to the full-data path).
1015///
1016/// Stratification matches `build_outer_score_subsample`: 100 z-deciles
1017/// × the supplied secondary stratum (typically the {0, 1} response
1018/// indicator). When `stratum_secondary` is `None` the secondary
1019/// dimension collapses to a single bin.
1020///
1021/// The returned mask carries proper Horvitz–Thompson weights so that
1022/// `Σ_{i ∈ mask} weight_i · row_i` is an unbiased estimate of the
1023/// full row sum.
1024pub fn auto_outer_score_subsample(
1025    z: &[f64],
1026    stratum_secondary: Option<&[u8]>,
1027    options: &AutoOuterSubsampleOptions,
1028) -> Option<OuterScoreSubsample> {
1029    let n = z.len();
1030    let k = options.target_k(n)?;
1031    let secondary_storage;
1032    let secondary: &[u8] = if let Some(s) = stratum_secondary {
1033        if s.len() != n {
1034            // Caller error; fall through to no-subsample rather than panic.
1035            return None;
1036        }
1037        s
1038    } else {
1039        secondary_storage = vec![0u8; n];
1040        &secondary_storage
1041    };
1042    Some(build_outer_score_subsample(z, secondary, k, options.seed))
1043}
1044
1045/// Two-phase auto-subsample guard shared across marginal-slope families.
1046///
1047/// Returns `Some(cloned_options)` carrying a freshly stratified
1048/// Horvitz-Thompson mask when `options.auto_outer_subsample` is enabled, the
1049/// caller has not already supplied a mask, and
1050/// the per-family phase counter is below `phase1_budget`. Returns `None`
1051/// when the caller's options should be used unchanged (either subsample
1052/// is disabled / pre-installed, the budget is exhausted, or the problem
1053/// is too small for `auto_outer_score_subsample` to find a benefit).
1054///
1055/// The `phase_counter` and `last_rho` pair together implement
1056/// distinct-step detection: line searches re-call the family at the
1057/// same ρ during step-size retries, but the budget is meant to count
1058/// outer iterations, not function evaluations. The counter only ticks
1059/// when the incoming ρ differs from the last observed ρ in L2 by
1060/// > 1e-10 — well below any meaningful BFGS step on log-scale ρ, well
1061/// > above float-noise from cloning. The mutex around `last_rho` is the
1062/// > minimal coordination needed: `(counter, last_rho)` must update
1063/// > together so two threads cannot both decide "new ρ" and double-bump.
1064///
1065/// The transition at `phase_idx == phase1_budget` is logged exactly
1066/// once via `log::info!` with the supplied `family_label`. Each phase-1
1067/// install also logs the planned mask size and predicted gradient
1068/// noise. Callers running with auto-subsample disabled see no logging.
1069pub fn maybe_install_auto_outer_subsample(
1070    options: &crate::custom_family::BlockwiseFitOptions,
1071    z: &[f64],
1072    stratum_secondary: Option<&[u8]>,
1073    outer_rho_key: &[f64],
1074    phase_counter: &Arc<std::sync::atomic::AtomicUsize>,
1075    last_rho: &Arc<std::sync::Mutex<Option<Array1<f64>>>>,
1076    phase1_budget: usize,
1077    family_label: &'static str,
1078    outer_work_per_k_unit: u64,
1079    min_n_for_auto: usize,
1080    min_k: usize,
1081    min_k_floor: usize,
1082) -> Option<crate::custom_family::BlockwiseFitOptions> {
1083    if options.outer_score_subsample.is_some() || !options.auto_outer_subsample {
1084        return None;
1085    }
1086    let phase_idx = {
1087        let mut guard = last_rho
1088            .lock()
1089            .expect("auto_subsample_last_rho mutex poisoned");
1090        let new_step = match guard.as_ref() {
1091            None => true,
1092            Some(prev) if prev.len() != outer_rho_key.len() => true,
1093            Some(prev) => {
1094                let mut sq = 0.0_f64;
1095                for (a, b) in outer_rho_key.iter().zip(prev.iter()) {
1096                    let d = a - b;
1097                    sq += d * d;
1098                }
1099                sq.sqrt() > AUTO_OUTER_DISTINCT_STEP_L2_TOL
1100            }
1101        };
1102        if new_step {
1103            *guard = Some(Array1::from(outer_rho_key.to_vec()));
1104            phase_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
1105        } else {
1106            phase_counter
1107                .load(std::sync::atomic::Ordering::SeqCst)
1108                .saturating_sub(1)
1109        }
1110    };
1111    if phase_idx >= phase1_budget {
1112        if phase_idx == phase1_budget {
1113            log::info!(
1114                "[{family_label} auto-subsample] Phase 1 budget exhausted after {} evals; \
1115                 Phase 2 (full data) for remaining iterations",
1116                phase1_budget
1117            );
1118        }
1119        return None;
1120    }
1121    // Honour the family's per-K-unit work cost. Constructing
1122    // `AutoOuterSubsampleOptions::default()` here would silently reset
1123    // `outer_work_per_k_unit` to 1, making the work-budget cap
1124    // (`K_work = AUTO_OUTER_WORK_BUDGET / outer_work_per_k_unit`)
1125    // never bind, and letting the noise-only rule pick K ≈ 0.10·n —
1126    // which at large-scale n=195_780 is K≈19_578 instead of the survival
1127    // family's intended K≈2_000. That ~9× inflation drove the
1128    // documented 8h large-scale hang (exit 137 from resource exhaustion).
1129    let auto_options = AutoOuterSubsampleOptions {
1130        min_n_for_auto,
1131        min_k,
1132        min_k_floor,
1133        outer_work_per_k_unit: outer_work_per_k_unit.max(1),
1134        ..AutoOuterSubsampleOptions::default()
1135    };
1136    // Compute the K choice up-front so the log surfaces both the
1137    // noise-only target and the work-cap target even when stratification
1138    // ceil overshoots the picked K.
1139    let choice = auto_options.target_k_detailed(z.len())?;
1140    let mask = auto_outer_score_subsample(z, stratum_secondary, &auto_options)?;
1141    let n_full = mask.n_full;
1142    let k = mask.len();
1143    log::info!(
1144        "[{family_label} auto-subsample] phase=1 eval={}/{} n={} K={} fraction={:.3} expected_grad_noise={:.2}% work_per_k_unit={} k_noise={} k_work={} cap_reason={}",
1145        phase_idx + 1,
1146        phase1_budget,
1147        n_full,
1148        k,
1149        k as f64 / n_full.max(1) as f64,
1150        100.0 * (1.0 / (k as f64).sqrt()) * (1.0 - k as f64 / n_full.max(1) as f64).sqrt(),
1151        outer_work_per_k_unit,
1152        choice.k_noise,
1153        choice.k_work,
1154        choice.cap_reason.as_str(),
1155    );
1156    let mut cloned = options.clone();
1157    cloned.outer_score_subsample = Some(Arc::new(mask));
1158    Some(cloned)
1159}
1160
1161/// Build a deterministic stratified row subsample of size ≥ `k` from
1162/// `(z, stratum_secondary)`.
1163///
1164/// Stratification: 100 z-deciles × distinct values of `stratum_secondary`
1165/// (typically the {0,1} event/outcome indicator, giving ≤ 200 strata).
1166/// Each non-empty stratum contributes `ceil(k * stratum_size / n)` rows
1167/// drawn via a splitmix64-keyed Fisher-Yates partial shuffle so the result
1168/// is reproducible from `(seed, stratum_id)`.
1169///
1170/// The returned mask is sorted, deduplicated, and never empty when `n > 0`.
1171/// Per-row weights `w_i = N_h / k_h` (Horvitz-Thompson inverse-inclusion
1172/// weights for the stratum the row came from) are assigned to
1173/// `OuterScoreSubsample::rows`, and `weight_scale` is reported as the mean
1174/// of those weights for diagnostics only.
1175///
1176/// Panics if `z.len() != stratum_secondary.len()`.
1177pub fn build_outer_score_subsample(
1178    z: &[f64],
1179    stratum_secondary: &[u8],
1180    k: usize,
1181    seed: u64,
1182) -> OuterScoreSubsample {
1183    let n = z.len();
1184    assert_eq!(
1185        n,
1186        stratum_secondary.len(),
1187        "build_outer_score_subsample: z and stratum_secondary must have equal length",
1188    );
1189
1190    if n == 0 {
1191        return OuterScoreSubsample::with_uniform_weight(Vec::new(), 0, seed, 1.0);
1192    }
1193
1194    // If the requested subsample covers the full dataset (or more), short-
1195    // circuit to the full row set with weight 1.0 — this is a no-op
1196    // relative to the legacy full-data path.
1197    if k >= n {
1198        let mask: Vec<usize> = (0..n).collect();
1199        return OuterScoreSubsample::with_uniform_weight(mask, n, seed, 1.0);
1200    }
1201
1202    // Q = 100 z-deciles. Sort indices by z and split into Q ~equal chunks.
1203    const Q: usize = 100;
1204    let mut z_order: Vec<usize> = (0..n).collect();
1205    z_order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap_or(std::cmp::Ordering::Equal));
1206    // decile[i] = bin index in 0..Q for row i
1207    let mut decile = vec![0u16; n];
1208    for (rank, &row) in z_order.iter().enumerate() {
1209        // Map rank in 0..n to bin in 0..Q. Using floor((rank * Q) / n)
1210        // keeps bin sizes within ±1 row of n/Q.
1211        let bin = (rank * Q) / n;
1212        let bin = bin.min(Q - 1);
1213        decile[row] = bin as u16;
1214    }
1215
1216    // Distinct secondary values (the canonical use case is {0,1}, but the
1217    // general u8 alphabet is supported transparently).
1218    let mut distinct_secondary: Vec<u8> = stratum_secondary.to_vec();
1219    distinct_secondary.sort_unstable();
1220    distinct_secondary.dedup();
1221    // stratum index = secondary_rank * Q + decile, where secondary_rank is
1222    // the position of the row's secondary value in `distinct_secondary`.
1223    let mut secondary_rank = vec![0u16; 256];
1224    for (rank, &val) in distinct_secondary.iter().enumerate() {
1225        secondary_rank[val as usize] = rank as u16;
1226    }
1227    let n_strata = distinct_secondary.len() * Q;
1228
1229    // Bucket rows by stratum.
1230    let mut strata: Vec<Vec<usize>> = vec![Vec::new(); n_strata];
1231    for i in 0..n {
1232        let s = secondary_rank[stratum_secondary[i] as usize] as usize * Q + decile[i] as usize;
1233        strata[s].push(i);
1234    }
1235
1236    // For each non-empty stratum, draw ceil(k * stratum_size / n) rows and
1237    // tag each retained row with its HT weight w_h = N_h / k_h.
1238    let mut picked: Vec<WeightedOuterRow> = Vec::with_capacity(k + n_strata);
1239    for (stratum_id, rows) in strata.iter().enumerate() {
1240        if rows.is_empty() {
1241            continue;
1242        }
1243        let take = (k as u128 * rows.len() as u128).div_ceil(n as u128) as usize;
1244        let take = take.max(1).min(rows.len());
1245        // HT inverse-inclusion weight for this stratum: w_h = N_h / k_h.
1246        // Identical for every row drawn from `stratum_id`.
1247        let w_h = rows.len() as f64 / take as f64;
1248        let stratum_tag = stratum_id as u32;
1249
1250        // Deterministic key from (seed, stratum_id).
1251        let mut state = seed ^ (stratum_id as u64).wrapping_mul(0x9E3779B97F4A7C15);
1252        // Mix once so even seed=0, stratum_id=0 produces a non-trivial state.
1253        splitmix64(&mut state);
1254
1255        if take == rows.len() {
1256            for &index in rows.iter() {
1257                picked.push(WeightedOuterRow {
1258                    index,
1259                    weight: w_h,
1260                    stratum: stratum_tag,
1261                });
1262            }
1263        } else {
1264            // Fisher-Yates partial shuffle: produce `take` distinct rows.
1265            let mut buf: Vec<usize> = rows.clone();
1266            let m = buf.len();
1267            for i in 0..take {
1268                let r = splitmix64(&mut state);
1269                let j = i + (r as usize) % (m - i);
1270                buf.swap(i, j);
1271            }
1272            for &index in &buf[..take] {
1273                picked.push(WeightedOuterRow {
1274                    index,
1275                    weight: w_h,
1276                    stratum: stratum_tag,
1277                });
1278            }
1279        }
1280    }
1281
1282    // `from_weighted_rows` sorts + dedups by index. Strata are disjoint by
1283    // construction so dedup is a no-op, but we route through the constructor
1284    // so the OuterScoreSubsample contract stays in one place.
1285    OuterScoreSubsample::from_weighted_rows(picked, n, seed)
1286}
1287
// ---------------------------------------------------------------------------
// Outer-row iteration helpers.
//
// These wrap the choice between "iterate 0..n" (default) and "iterate
// `subsample.mask`", so per-row hot loops in Phase 2 can call a single
// helper rather than branch by hand.
//
// Two forms are exposed. The first is an enum that callers can match on
// directly. It is the cheap path, holding either a row count `n` or a shared
// `Arc<Vec<usize>>`. The second is a `Vec<usize>`-returning convenience that
// satisfies `IntoParallelIterator<Item = usize>` via `Vec`'s rayon impl.
//
// Per-row Horvitz–Thompson weights are exposed separately by
// `outer_weighted_rows` and `outer_row_weights_by_index`.
1297
/// How an outer-only score/gradient pass walks the data rows.
#[derive(Debug, Clone)]
pub enum OuterRowIter {
    /// Full data: iterate `0..n`.
    All { n: usize },
    /// Subsample: iterate `subsample.mask`.
    Subset { mask: Arc<Vec<usize>> },
}

impl OuterRowIter {
    /// Number of row indices this choice visits.
    #[inline]
    pub fn len(&self) -> usize {
        if let Self::Subset { mask } = self {
            return mask.len();
        }
        let Self::All { n } = self else { unreachable!() };
        *n
    }

    /// `true` when no rows would be visited.
    #[inline]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::All { n } => *n == 0,
            Self::Subset { mask } => mask.is_empty(),
        }
    }

    /// Collect the visited row indices into an owned `Vec<usize>`.
    ///
    /// Handy as an `IntoParallelIterator<Item = usize>` source, which
    /// `Vec<usize>` satisfies through rayon's blanket impl.
    pub fn to_vec(&self) -> Vec<usize> {
        match self {
            Self::All { n } => (0..*n).collect(),
            Self::Subset { mask } => mask.iter().copied().collect(),
        }
    }
}
1332
1333/// Choose the row-iteration strategy for an outer-only pass. When
1334/// `opts.outer_score_subsample` is `Some`, returns the subsample mask;
1335/// otherwise returns the full range `0..n`.
1336///
1337/// Callers using this helper iterate over row indices and must additionally
1338/// consult [`outer_row_weights_by_index`] (or [`outer_weighted_rows`]) for
1339/// per-row HT weights — a single global rescale is biased under stratified
1340/// sampling and is no longer exposed.
1341pub fn outer_row_indices(
1342    opts: &crate::custom_family::BlockwiseFitOptions,
1343    n: usize,
1344) -> OuterRowIter {
1345    match opts.outer_score_subsample.as_ref() {
1346        Some(s) => OuterRowIter::Subset {
1347            mask: Arc::clone(&s.mask),
1348        },
1349        None => OuterRowIter::All { n },
1350    }
1351}
1352
1353/// Per-row HT-weighted iteration: returns one `WeightedOuterRow` per
1354/// retained row when a subsample is active; otherwise returns
1355/// `(index, weight = 1.0, stratum = 0)` for every row in `0..n`.
1356pub fn outer_weighted_rows(
1357    opts: &crate::custom_family::BlockwiseFitOptions,
1358    n: usize,
1359) -> Vec<WeightedOuterRow> {
1360    match opts.outer_score_subsample.as_ref() {
1361        Some(s) => s.rows.as_ref().clone(),
1362        None => (0..n)
1363            .map(|index| WeightedOuterRow {
1364                index,
1365                weight: 1.0,
1366                stratum: 0,
1367            })
1368            .collect(),
1369    }
1370}
1371
1372/// Dense-by-row HT weights of length `n`. Masked rows carry their HT
1373/// weight; unmasked rows default to 1.0 so that callers who index by row
1374/// regardless of subsampling still get a valid scalar (the consumer is
1375/// expected to iterate only over `outer_row_indices`).
1376pub fn outer_row_weights_by_index(
1377    opts: &crate::custom_family::BlockwiseFitOptions,
1378    n: usize,
1379) -> Vec<f64> {
1380    match opts.outer_score_subsample.as_ref() {
1381        Some(s) => {
1382            let mut weights = vec![1.0; n];
1383            for r in s.rows.iter() {
1384                if r.index < n {
1385                    weights[r.index] = r.weight;
1386                }
1387            }
1388            weights
1389        }
1390        None => vec![1.0; n],
1391    }
1392}
1393
1394/// Shared monotonicity line-search safeguard for time-block linear inequality
1395/// constraints `A·beta >= b`.
1396///
1397/// Both survival families (location-scale and marginal-slope) clamp a Newton
1398/// step `beta + alpha·delta` to the largest feasible fraction `alpha ∈ [0, 1]`
1399/// such that no constraint row is driven below its bound, then back off by the
1400/// fixed `0.995` safeguard whenever the boundary is reached. The slack/drift
1401/// arithmetic and the `0.995` factor live here once; each family supplies only
1402/// its own error type by mapping the dimension-mismatch and constraint-violation
1403/// conditions into `E` via the two closures (preserving family-specific message
1404/// text and error variants).
1405///
1406/// `map_dim_err` is called with `(beta_len, delta_len, expected_ncols)` when the
1407/// step dimensions disagree with the constraint matrix. `map_violation_err` is
1408/// called with `(row, slack)` when the current `beta` already violates a
1409/// constraint row (slack below `-1e-10`).
1410pub fn feasible_step_fraction<E>(
1411    constraints: &gam_problem::LinearInequalityConstraints,
1412    beta: &Array1<f64>,
1413    direction: &Array1<f64>,
1414    map_dim_err: impl Fn(usize, usize, usize) -> E,
1415    map_violation_err: impl Fn(usize, f64) -> E,
1416) -> Result<f64, E> {
1417    if beta.len() != constraints.a.ncols() || direction.len() != constraints.a.ncols() {
1418        return Err(map_dim_err(
1419            beta.len(),
1420            direction.len(),
1421            constraints.a.ncols(),
1422        ));
1423    }
1424    // Feasibility-violation tolerance for the *current* iterate, kept consistent
1425    // with the QP entry gate `check_linear_feasibility` (called at 1e-8) and the
1426    // residual left by `project_onto_linear_constraints` (per-row violation <= 1e-10
1427    // on the working vector, accumulating up to O(1e-9) on the final beta through
1428    // its sequential Dykstra corrections). Rejecting at -1e-10 here re-classified a
1429    // beta the QP had already accepted as feasible as a hard error (gam#797: the
1430    // projected survival time-block seed lands at slack ~ -1.1e-9 on a binding
1431    // derivative-guard row, so every trust-region attempt errored out before any
1432    // step). A slack within this band is numerically AT the boundary; treat it as
1433    // active (slack = 0) rather than a violation.
1434    const FEASIBLE_STEP_VIOLATION_TOL: f64 = 1e-8;
1435    // Multiplicative backoff applied when the step is clipped by a binding
1436    // constraint, keeping the new iterate strictly interior (slack > 0) so the
1437    // next iteration's feasibility gate cannot reject a point that landed
1438    // exactly on the boundary through round-off.
1439    const FEASIBLE_STEP_BOUNDARY_BACKOFF: f64 = 0.995;
1440    let mut alpha = 1.0f64;
1441    for row in 0..constraints.a.nrows() {
1442        let a_row = constraints.a.row(row);
1443        let raw_slack = a_row.dot(beta) - constraints.b[row];
1444        if raw_slack < -FEASIBLE_STEP_VIOLATION_TOL {
1445            return Err(map_violation_err(row, raw_slack));
1446        }
1447        // Clamp boundary round-off to the boundary so a tiny negative slack cannot
1448        // produce a spurious negative/zero step fraction below.
1449        let slack = raw_slack.max(0.0);
1450        let drift = a_row.dot(direction);
1451        if drift < 0.0 {
1452            alpha = alpha.min((slack / -drift).clamp(0.0, 1.0));
1453        }
1454    }
1455    if alpha >= 1.0 {
1456        Ok(1.0)
1457    } else {
1458        Ok((FEASIBLE_STEP_BOUNDARY_BACKOFF * alpha).clamp(0.0, 1.0))
1459    }
1460}
1461
1462/// Family-specific ψ-calculus hooks for the shared exact-Newton joint-ψ
1463/// workspace.
1464///
1465/// The two marginal-slope families (Bernoulli marginal-slope and survival
1466/// marginal-slope) build an [`ExactNewtonJointPsiWorkspace`] whose four methods
1467/// share a single skeleton: a σ-auxiliary (log-σ frailty) dispatch branch on
1468/// top of a family-specific non-σ row pass. The skeleton lives once in
1469/// [`MarginalSlopeExactNewtonPsiWorkspace`]; each family supplies only the
1470/// resolved per-call operations here, holding its own block states, specs,
1471/// derivative blocks, cache and outer-subsample options internally.
1472///
1473/// Implementors own all workspace state, so every hook takes only the ψ index /
1474/// pair / direction. The two genuine per-family policy differences in the
1475/// second-order σ-aux branch are encoded as
1476/// [`both_sigma_aux_second_order`](Self::both_sigma_aux_second_order) (which
1477/// pure-σ pairs are admissible) and
1478/// [`mixed_sigma_aux_second_order`](Self::mixed_sigma_aux_second_order) (how a
1479/// mixed σ / non-σ pair is handled) rather than being harmonized away.
1480pub trait MarginalSlopePsiFamily: Send + Sync {
1481    /// True when ψ index `psi_index` addresses the log-σ frailty auxiliary
1482    /// parameter rather than a spatial / spline derivative axis.
1483    fn is_sigma_aux(&self, psi_index: usize) -> bool;
1484
1485    /// First-order joint-ψ terms for the σ-auxiliary parameter.
1486    fn sigma_first_order_terms(
1487        &self,
1488    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiTerms>, String>;
1489
1490    /// First-order joint-ψ terms for a non-σ derivative axis `psi_index`.
1491    fn psi_first_order_terms(
1492        &self,
1493        psi_index: usize,
1494    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiTerms>, String>;
1495
1496    /// Batched first-order joint-ψ terms over all derivative axes (used by the
1497    /// outer score sweep). Returns `Ok(None)` when the batched fast path is
1498    /// unavailable for the current configuration so the caller falls back to
1499    /// per-axis evaluation.
1500    fn psi_first_order_terms_all(
1501        &self,
1502    ) -> Result<Option<Vec<crate::custom_family::ExactNewtonJointPsiTerms>>, String>;
1503
1504    /// Whether the σ-aux second-order branch should treat `(psi_i, psi_j)` as a
1505    /// pure-σ pair (dispatching to [`sigma_second_order_terms`](Self::sigma_second_order_terms)).
1506    /// Any σ-touching pair that is not pure-σ routes through
1507    /// [`mixed_sigma_aux_second_order`](Self::mixed_sigma_aux_second_order).
1508    fn both_sigma_aux_second_order(&self, psi_i: usize, psi_j: usize) -> bool;
1509
1510    /// Second-order joint-ψ terms for a pure σ / σ pair.
1511    fn sigma_second_order_terms(
1512        &self,
1513    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String>;
1514
1515    /// Per-family policy for a mixed σ / non-σ second-order pair: one family
1516    /// rejects it (no cross auxiliary terms available), the other returns
1517    /// `Ok(None)`.
1518    fn mixed_sigma_aux_second_order(
1519        &self,
1520    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String>;
1521
1522    /// Second-order joint-ψ terms for a non-σ derivative-axis pair.
1523    fn psi_second_order_terms(
1524        &self,
1525        psi_i: usize,
1526        psi_j: usize,
1527    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String>;
1528
1529    /// Direction-contracted second-order ψ terms over the non-σ derivative
1530    /// axes (#740). `alpha_psi` is the full ψ-block weight vector; the
1531    /// contraction is against the combined non-σ direction
1532    /// `ψ(α) = Σ_j alpha_psi[j] · ψ_j`, streaming the family's rows ONCE so the
1533    /// profiled θ-HVP operator applies one combined-direction n-pass per matvec
1534    /// instead of `K²` per-pair [`Self::psi_second_order_terms`] passes.
1535    ///
1536    /// Default `None` keeps the family on the exact per-pair path. The generic
1537    /// workspace only calls this when no σ-auxiliary axis carries weight (a σ
1538    /// term routes the whole direction back to the per-pair fallback), so an
1539    /// override only handles the pure non-σ derivative axes — the same domain
1540    /// as [`Self::psi_second_order_terms`].
1541    fn psi_second_order_terms_contracted(
1542        &self,
1543        _: &[f64],
1544    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderContracted>, String>
1545    {
1546        // Default implementation ignores this parameter.
1547        Ok(None)
1548    }
1549
1550    /// Hessian directional derivative for the σ-auxiliary parameter, returned
1551    /// as a dense matrix (the generic wraps it into
1552    /// [`DriftDerivResult::Dense`](gam_problem::DriftDerivResult::Dense)).
1553    fn sigma_hessian_directional_derivative(
1554        &self,
1555        d_beta_flat: &Array1<f64>,
1556    ) -> Result<Option<Array2<f64>>, String>;
1557
1558    /// Hessian directional derivative for a non-σ derivative axis, returned as
1559    /// a hyper-operator (the generic wraps it into
1560    /// [`DriftDerivResult::Operator`](gam_problem::DriftDerivResult::Operator)).
1561    fn psi_hessian_directional_derivative(
1562        &self,
1563        psi_index: usize,
1564        d_beta_flat: &Array1<f64>,
1565    ) -> Result<Option<Arc<dyn gam_problem::HyperOperator>>, String>;
1566}
1567
1568/// Generic exact-Newton joint-ψ workspace shared by the marginal-slope
1569/// families. Owns the σ-auxiliary dispatch skeleton and delegates every
1570/// family-specific operation to its [`MarginalSlopePsiFamily`] impl.
1571pub struct MarginalSlopeExactNewtonPsiWorkspace<F: MarginalSlopePsiFamily> {
1572    family: F,
1573}
1574
1575impl<F: MarginalSlopePsiFamily> MarginalSlopeExactNewtonPsiWorkspace<F> {
1576    pub fn new(family: F) -> Self {
1577        Self { family }
1578    }
1579}
1580
1581impl<F: MarginalSlopePsiFamily> crate::custom_family::ExactNewtonJointPsiWorkspace
1582    for MarginalSlopeExactNewtonPsiWorkspace<F>
1583{
1584    fn first_order_terms(
1585        &self,
1586        psi_index: usize,
1587    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiTerms>, String> {
1588        if self.family.is_sigma_aux(psi_index) {
1589            return self.family.sigma_first_order_terms();
1590        }
1591        self.family.psi_first_order_terms(psi_index)
1592    }
1593
1594    fn first_order_terms_all(
1595        &self,
1596    ) -> Result<Option<Vec<crate::custom_family::ExactNewtonJointPsiTerms>>, String> {
1597        self.family.psi_first_order_terms_all()
1598    }
1599
1600    fn second_order_terms(
1601        &self,
1602        psi_i: usize,
1603        psi_j: usize,
1604    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String> {
1605        if self.family.is_sigma_aux(psi_i) || self.family.is_sigma_aux(psi_j) {
1606            if self.family.both_sigma_aux_second_order(psi_i, psi_j) {
1607                return self.family.sigma_second_order_terms();
1608            }
1609            return self.family.mixed_sigma_aux_second_order();
1610        }
1611        self.family.psi_second_order_terms(psi_i, psi_j)
1612    }
1613
1614    fn second_order_terms_contracted(
1615        &self,
1616        alpha_psi: &[f64],
1617    ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderContracted>, String>
1618    {
1619        // The σ-auxiliary axes do not participate in the family's combined
1620        // non-σ row stream (their second-order terms come from a separate
1621        // σ/σ and mixed-σ path with no directional row kernel). If any
1622        // σ-aux axis carries weight in this applied direction, decline the
1623        // contracted fast path entirely so the caller keeps the exact
1624        // per-pair assembly — the contracted hook is a representation/cost
1625        // choice, never an approximation, so falling back is the correct
1626        // behaviour rather than dropping the σ contribution.
1627        for (j, &weight) in alpha_psi.iter().enumerate() {
1628            if weight != 0.0 && self.family.is_sigma_aux(j) {
1629                return Ok(None);
1630            }
1631        }
1632        self.family.psi_second_order_terms_contracted(alpha_psi)
1633    }
1634
1635    fn hessian_directional_derivative(
1636        &self,
1637        psi_index: usize,
1638        d_beta_flat: &Array1<f64>,
1639    ) -> Result<Option<gam_problem::DriftDerivResult>, String> {
1640        if self.family.is_sigma_aux(psi_index) {
1641            return self
1642                .family
1643                .sigma_hessian_directional_derivative(d_beta_flat)
1644                .map(|result| result.map(gam_problem::DriftDerivResult::Dense));
1645        }
1646        self.family
1647            .psi_hessian_directional_derivative(psi_index, d_beta_flat)
1648            .map(|result| result.map(gam_problem::DriftDerivResult::Operator))
1649    }
1650}
1651
1652/// Deterministic-order parallel reduction over a row-index slice.
1653///
1654/// Splits `rows` into contiguous chunks sized to saturate the rayon pool
1655/// (several chunks per worker, floored so small `n` stays coarse), processes
1656/// each chunk sequentially in parallel via `process_row`, and combines the
1657/// per-chunk accumulators in chunk-index order via `combine` on the calling
1658/// thread. The chunk count is a pure function of `(rows.len(), worker_count)`;
1659/// the worker count is fixed for the process (one global pool), so the
1660/// reduction tree is fixed across calls regardless of rayon's work-stealing
1661/// decisions.
1662///
1663/// `try_fold/try_reduce` over `rows.into_par_iter()` does **not** have
1664/// this property: rayon's adaptive splitter sets chunk boundaries based
1665/// on `current_num_threads()` and runtime work-stealing, so two calls
1666/// with identical inputs can return ULP-different floating-point sums
1667/// when the rayon pool has different concurrent activity. Tests that
1668/// compare two reductions and rely on bit-for-bit equality flake under
1669/// load with that pattern. This primitive is the per-family deterministic
1670/// row-reduction that the bernoulli / survival sigma-ψ paths funnel
1671/// through; their per-row contributions are the dominant non-deterministic
1672/// source in the marginal-slope outer-loop score / Hessian sums.
1673pub(crate) fn chunked_row_reduction<Item, Acc, Init, Process, Combine>(
1674    rows: &[Item],
1675    init: Init,
1676    process_row: Process,
1677    mut combine: Combine,
1678) -> Result<Acc, String>
1679where
1680    Item: Sync + Copy,
1681    Acc: Send,
1682    Init: Fn() -> Acc + Sync,
1683    Process: Fn(Item, &mut Acc) -> Result<(), String> + Sync,
1684    Combine: FnMut(&mut Acc, Acc),
1685{
1686    use rayon::iter::{IntoParallelIterator, ParallelIterator};
1687    let n = rows.len();
1688    if n == 0 {
1689        return Ok(init());
1690    }
1691    // The chunk count is sized so the heavy reduction phases actually saturate
1692    // the rayon pool: a fixed `32` left half of a 64-core box idle whenever the
1693    // pool had more than 32 workers, capping utilization at ~50% on the biobank
1694    // coord-corrections / row-stream phases. Targeting several chunks per worker
1695    // keeps load balanced across an uneven row-cost tail (work-stealing still
1696    // moves whole chunks, never partial sums) without flooding the sequential
1697    // `combine` with tiny partials. The count is a pure function of
1698    // `(rows.len(), worker_count)`; the worker count is fixed for the process
1699    // (one global pool, gam owns its threads), so chunk boundaries are stable
1700    // across calls and the ordered `Vec` collect + sequential `combine` keep the
1701    // reduction bit-for-bit deterministic regardless of work-stealing.
1702    const CHUNKS_PER_WORKER: usize = 4;
1703    const MIN_CHUNK_COUNT: usize = 32;
1704    const MIN_ROWS_PER_CHUNK: usize = 64;
1705    let workers = rayon::current_num_threads().max(1);
1706    let target_chunk_count = workers
1707        .saturating_mul(CHUNKS_PER_WORKER)
1708        .max(MIN_CHUNK_COUNT);
1709    // Never carve chunks below `MIN_ROWS_PER_CHUNK` rows: for small `n` the
1710    // scheduler/partial-accumulator overhead would dominate the row arithmetic.
1711    let chunk_count = target_chunk_count
1712        .min(n.div_ceil(MIN_ROWS_PER_CHUNK))
1713        .max(1);
1714    let chunk_size = n.div_ceil(chunk_count).max(1);
1715    let n_chunks = n.div_ceil(chunk_size);
1716    // `(0..n_chunks).into_par_iter()` is `IndexedParallelIterator`, so the
1717    // `.collect::<Vec<_>>()` below preserves chunk-index order regardless
1718    // of work-stealing. That ordered `Vec` is what makes the sequential
1719    // `combine` deterministic.
1720    let chunk_states: Vec<Acc> = (0..n_chunks)
1721        .into_par_iter()
1722        .map(|chunk_idx| -> Result<Acc, String> {
1723            let start = chunk_idx * chunk_size;
1724            let end = (start + chunk_size).min(n);
1725            let mut acc = init();
1726            for &item in &rows[start..end] {
1727                process_row(item, &mut acc)?;
1728            }
1729            Ok(acc)
1730        })
1731        .collect::<Result<Vec<Acc>, String>>()?;
1732    let mut total = init();
1733    for chunk in chunk_states {
1734        combine(&mut total, chunk);
1735    }
1736    Ok(total)
1737}
1738
1739#[cfg(test)]
1740mod tests {
1741    use super::*;
1742
1743    // ---------------------------------------------------------------------
1744    // Parity guard for the shared exact-Newton directional sweep.
1745    //
1746    // `directional_obj_grad_hess` is the single engine that both the
1747    // Bernoulli and survival marginal-slope families now route their
1748    // exact-Newton obj/grad/hess and psi-Hessian directional sweeps
1749    // through (replacing three hand-rolled, drift-prone loop nests). The
1750    // test below reconstructs the *reference* loop nest those families used
1751    // to carry inline and asserts the shared engine reproduces it
1752    // bit-for-bit on randomized fixtures across both sweep shapes
1753    // (objective present / suppressed, leading-prefix lengths 1 and 2).
1754    //
1755    // The synthetic `eval` mirrors the contract of every family's
1756    // `row_neglog_directional_with_scale_jet`: it builds one linear
1757    // `MultiDirJet` per supplied direction at a distinct base, multiplies
1758    // them together, scales by the per-sweep scale jet, composes through a
1759    // smooth nonlinearity, and returns the highest mixed-partial
1760    // coefficient. That makes the appended unit directions genuinely
1761    // interact (so a transposed Hessian index or a missing symmetric
1762    // assignment is caught) and makes the scale jet load-bearing (so a
1763    // mis-wired obj/grad/hess scale is caught).
1764
1765    use gam_math::jet_partitions::MultiDirJet;
1766
1767    /// Deterministic LCG so the fixture is reproducible without pulling in
1768    /// an RNG dependency.
1769    struct Lcg(u64);
1770    impl Lcg {
1771        fn next_f64(&mut self) -> f64 {
1772            self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1);
1773            ((self.0 >> 11) as f64) / ((1u64 << 53) as f64) * 2.0 - 1.0
1774        }
1775    }
1776
1777    /// Synthetic row-jet evaluator with the exact contract
1778    /// `directional_obj_grad_hess` expects: given `dirs` (the leading prefix
1779    /// plus appended unit directions) and a `scale` jet of matching order,
1780    /// return the top mixed-partial coefficient of a smooth multilinear
1781    /// functional of the directions.
1782    fn synthetic_row_eval(
1783        bases: &[f64],
1784        weight: f64,
1785        dirs: &[&Array1<f64>],
1786        scale: &MultiDirJet,
1787    ) -> Result<f64, String> {
1788        let k = dirs.len();
1789        if k > 4 {
1790            return Err(format!("synthetic eval expects 0..=4 directions, got {k}"));
1791        }
1792        if scale.coeffs.len() != (1usize << k) {
1793            return Err(format!(
1794                "synthetic eval scale jet dimension mismatch: coeffs={}, dirs={k}",
1795                scale.coeffs.len()
1796            ));
1797        }
1798        let primary_dim = bases.len();
1799        // One linear jet per direction, each at a distinct base coordinate,
1800        // mixing the direction's components across the primary dimensions so
1801        // every Hessian entry is exercised.
1802        let first = |dir: &Array1<f64>| -> Vec<f64> {
1803            (0..k).map(|j| dir[j % primary_dim]).collect::<Vec<f64>>()
1804        };
1805        let mut product = MultiDirJet::constant(k, 1.0);
1806        for (slot, dir) in dirs.iter().enumerate() {
1807            let base = bases[slot % primary_dim] + 0.25 * slot as f64;
1808            let comps: Vec<f64> = (0..primary_dim)
1809                .map(|p| dir[p] * (1.0 + 0.5 * p as f64))
1810                .collect();
1811            let lin = MultiDirJet::linear(k, base, &first(&Array1::from(comps)));
1812            product = product.mul(&lin);
1813        }
1814        let scaled = product.mul(scale);
1815        // Smooth nonlinearity φ(x) = weight·ln(1 + x²) composed through the
1816        // jet — derivs[0..=4] evaluated at the zeroth-order coefficient.
1817        let x = scaled.coeff(0);
1818        let denom = 1.0 + x * x;
1819        let d1 = weight * (2.0 * x) / denom;
1820        let d2 = weight * (2.0 * (1.0 - x * x)) / (denom * denom);
1821        let d3 = weight * (-4.0 * x * (3.0 - x * x)) / (denom * denom * denom);
1822        let d4 = weight * (-12.0 * (1.0 - 6.0 * x * x + x * x * x * x))
1823            / (denom * denom * denom * denom);
1824        let phi = weight * denom.ln();
1825        Ok(scaled
1826            .compose_unary([phi, d1, d2, d3, d4])
1827            .coeff((1usize << k) - 1))
1828    }
1829
1830    /// Hand-rolled reference sweep — the exact obj/grad/hess loop nest the
1831    /// families carried inline before the unification onto
1832    /// `directional_obj_grad_hess`. Kept here purely as the parity oracle.
1833    fn reference_obj_grad_hess<Eval>(
1834        primary_dim: usize,
1835        leading: &[&Array1<f64>],
1836        scales: &DirectionalScaleJets,
1837        eval: Eval,
1838    ) -> Result<(f64, Array1<f64>, Array2<f64>), String>
1839    where
1840        Eval: Fn(&[&Array1<f64>], &MultiDirJet) -> Result<f64, String>,
1841    {
1842        let unit = |a: usize| -> Array1<f64> {
1843            let mut da = Array1::<f64>::zeros(primary_dim);
1844            da[a] = 1.0;
1845            da
1846        };
1847        let objective = if let Some(scale_obj) = scales.obj.as_ref() {
1848            eval(leading, scale_obj)?
1849        } else {
1850            0.0
1851        };
1852        let mut grad = Array1::<f64>::zeros(primary_dim);
1853        for a in 0..primary_dim {
1854            let da = unit(a);
1855            let mut dirs: Vec<&Array1<f64>> = leading.to_vec();
1856            dirs.push(&da);
1857            grad[a] = eval(&dirs, &scales.grad)?;
1858        }
1859        let mut hess = Array2::<f64>::zeros((primary_dim, primary_dim));
1860        for a in 0..primary_dim {
1861            let da = unit(a);
1862            for b in a..primary_dim {
1863                let db = unit(b);
1864                let mut dirs: Vec<&Array1<f64>> = leading.to_vec();
1865                dirs.push(&da);
1866                dirs.push(&db);
1867                let value = eval(&dirs, &scales.hess)?;
1868                hess[[a, b]] = value;
1869                hess[[b, a]] = value;
1870            }
1871        }
1872        Ok((objective, grad, hess))
1873    }
1874
1875    /// Build a scale jet of the requested order with random first/second
1876    /// mixed coefficients on the supplied masks — mirrors the structure of
1877    /// each family's `sigma_scale_jet` (a base plus first-order entries on
1878    /// the leading log-sigma slots and a second-order entry on the pair).
1879    fn random_scale_jet(
1880        rng: &mut Lcg,
1881        n_dirs: usize,
1882        first_masks: &[usize],
1883        second_masks: &[usize],
1884    ) -> MultiDirJet {
1885        let mut coeffs: Vec<(usize, f64)> = vec![(0usize, 1.0 + 0.1 * rng.next_f64())];
1886        for &m in first_masks {
1887            coeffs.push((1usize << m, rng.next_f64()));
1888        }
1889        for &m in second_masks {
1890            coeffs.push(((1usize << m) | 1usize, rng.next_f64()));
1891        }
1892        MultiDirJet::with_coeffs(n_dirs, &coeffs)
1893    }
1894
1895    #[test]
1896    fn directional_obj_grad_hess_matches_reference_loop_nest() {
1897        let primary_dim = 4usize;
1898        let mut rng = Lcg(0x5EED_1234_ABCD_0001);
1899        // Sweep both family shapes: first-order log-sigma (leading=[zero],
1900        // obj present), second-order (leading=[zero,zero], obj present), and
1901        // the psi-Hessian directional (leading=[zero,row_dir], obj absent).
1902        for trial in 0..32 {
1903            let bases: Vec<f64> = (0..primary_dim).map(|_| rng.next_f64()).collect();
1904            let weight = 0.5 + 0.5 * (rng.next_f64() + 1.0);
1905            let eval = |dirs: &[&Array1<f64>], scale: &MultiDirJet| {
1906                synthetic_row_eval(&bases, weight, dirs, scale)
1907            };
1908
1909            let zero = Array1::<f64>::zeros(primary_dim);
1910            let row_dir: Array1<f64> =
1911                Array1::from((0..primary_dim).map(|_| rng.next_f64()).collect::<Vec<_>>());
1912
1913            let cases: Vec<(Vec<&Array1<f64>>, DirectionalScaleJets)> = vec![
1914                (
1915                    vec![&zero],
1916                    DirectionalScaleJets {
1917                        obj: Some(random_scale_jet(&mut rng, 1, &[], &[])),
1918                        grad: random_scale_jet(&mut rng, 2, &[0], &[]),
1919                        hess: random_scale_jet(&mut rng, 3, &[0], &[]),
1920                    },
1921                ),
1922                (
1923                    vec![&zero, &zero],
1924                    DirectionalScaleJets {
1925                        obj: Some(random_scale_jet(&mut rng, 2, &[0, 1], &[])),
1926                        grad: random_scale_jet(&mut rng, 3, &[0, 1], &[]),
1927                        hess: random_scale_jet(&mut rng, 4, &[0, 1], &[]),
1928                    },
1929                ),
1930                (
1931                    vec![&zero, &row_dir],
1932                    DirectionalScaleJets {
1933                        obj: None,
1934                        grad: random_scale_jet(&mut rng, 3, &[0], &[]),
1935                        hess: random_scale_jet(&mut rng, 4, &[0], &[]),
1936                    },
1937                ),
1938            ];
1939
1940            for (leading, scales) in &cases {
1941                let shared =
1942                    directional_obj_grad_hess(primary_dim, leading, scales, eval).expect("shared");
1943                let (ref_obj, ref_grad, ref_hess) =
1944                    reference_obj_grad_hess(primary_dim, leading, scales, eval).expect("reference");
1945
1946                assert_eq!(
1947                    shared.objective, ref_obj,
1948                    "trial {trial}: objective drift {} vs {}",
1949                    shared.objective, ref_obj
1950                );
1951                for a in 0..primary_dim {
1952                    assert_eq!(
1953                        shared.grad[a], ref_grad[a],
1954                        "trial {trial}: grad[{a}] drift {} vs {}",
1955                        shared.grad[a], ref_grad[a]
1956                    );
1957                    for b in 0..primary_dim {
1958                        assert_eq!(
1959                            shared.hess[[a, b]],
1960                            ref_hess[[a, b]],
1961                            "trial {trial}: hess[{a},{b}] drift {} vs {}",
1962                            shared.hess[[a, b]],
1963                            ref_hess[[a, b]]
1964                        );
1965                    }
1966                }
1967                // The Hessian the engine returns must be exactly symmetric —
1968                // a transposed write in the upper-triangle loop is the classic
1969                // exact-Newton drift bug.
1970                for a in 0..primary_dim {
1971                    for b in 0..primary_dim {
1972                        assert_eq!(
1973                            shared.hess[[a, b]],
1974                            shared.hess[[b, a]],
1975                            "trial {trial}: hess asymmetric at ({a},{b})"
1976                        );
1977                    }
1978                }
1979            }
1980        }
1981    }
1982
1983    #[test]
1984    fn auto_outer_score_subsample_skips_small_problems() {
1985        let n = 1000;
1986        let z: Vec<f64> = (0..n).map(|i| i as f64).collect();
1987        let opts = AutoOuterSubsampleOptions::default();
1988        assert!(
1989            auto_outer_score_subsample(&z, None, &opts).is_none(),
1990            "n={n} below default min_n_for_auto=30000 should not subsample"
1991        );
1992    }
1993
1994    #[test]
1995    fn auto_outer_score_subsample_returns_target_k_above_threshold() {
1996        let n = 60_000;
1997        let z: Vec<f64> = (0..n).map(|i| (i as f64).sin()).collect();
1998        let opts = AutoOuterSubsampleOptions::default();
1999        let mask = auto_outer_score_subsample(&z, None, &opts)
2000            .expect("n=60000 should auto-subsample with default options");
2001        // Default target_fraction=0.10 and min_k=10000 → K = max(10000, 6000) = 10000.
2002        assert_eq!(mask.n_full, n);
2003        assert!(
2004            mask.len() >= 9_900 && mask.len() <= 10_200,
2005            "expected K≈10_000, got {}",
2006            mask.len()
2007        );
2008        // HT weights should reconstruct n_full in expectation: sum of
2009        // per-row weights ≈ n_full (allowing for small allocation rounding).
2010        let weight_sum: f64 = mask.rows.iter().map(|r| r.weight).sum();
2011        let rel_err = (weight_sum - n as f64).abs() / n as f64;
2012        assert!(
2013            rel_err < 0.02,
2014            "HT weight sum {weight_sum:.3} should ≈ n_full={n}, rel_err={rel_err:.4}"
2015        );
2016    }
2017
2018    #[test]
2019    fn auto_outer_score_subsample_horvitz_thompson_unbiased() {
2020        // On a synthetic per-row contribution `t_i = z_i² + 1`, verify
2021        // the HT-weighted sum over the auto-mask matches the full sum
2022        // within 3 standard deviations of the predicted estimator
2023        // variance. This guards against silent regressions in either
2024        // the stratified mask construction or the weight assignment.
2025        let n = 50_000;
2026        let z: Vec<f64> = (0..n)
2027            .map(|i| ((i as f64) / n as f64) * 2.0 - 1.0)
2028            .collect();
2029        let stratum: Vec<u8> = (0..n).map(|i| if i % 3 == 0 { 1 } else { 0 }).collect();
2030        let opts = AutoOuterSubsampleOptions {
2031            seed: 0xC0FFEE,
2032            ..AutoOuterSubsampleOptions::default()
2033        };
2034        let t: Vec<f64> = z.iter().map(|zi| zi * zi + 1.0).collect();
2035        let exact: f64 = t.iter().sum();
2036        let mask = auto_outer_score_subsample(&z, Some(&stratum), &opts)
2037            .expect("n=50000 should auto-subsample");
2038        let estimate: f64 = mask.rows.iter().map(|r| r.weight * t[r.index]).sum();
2039        // Predicted standard error: σ ≈ (1/√K) · √(1 − K/N) · cv · |T|.
2040        // For t_i ∈ [1, 2], cv ≲ 0.4. Be generous (factor 5) to keep
2041        // the test robust against PRNG-dependent allocation jitter.
2042        let k = mask.len();
2043        let predicted_se =
2044            exact * 0.4 * (1.0 / (k as f64).sqrt()) * (1.0 - k as f64 / n as f64).sqrt();
2045        let observed_err = (estimate - exact).abs();
2046        assert!(
2047            observed_err < 5.0 * predicted_se.max(1.0),
2048            "HT estimate {estimate:.3} vs exact {exact:.3}: err={observed_err:.3} exceeds 5×predicted_se={:.3}",
2049            predicted_se
2050        );
2051    }
2052
2053    #[test]
2054    fn subsample_full_n_equals_no_subsample() {
2055        // mask = (0..n) — the all-rows subsample should have weight_scale 1.0
2056        // and outer_row_indices should yield the same sorted set in both
2057        // Some(mask=full) and None modes.
2058        let n: usize = 1024;
2059        let z: Vec<f64> = (0..n).map(|i| i as f64).collect();
2060        let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2061        let s = build_outer_score_subsample(&z, &secondary, n, 0xDEADBEEF);
2062        assert_eq!(s.len(), n);
2063        assert!((s.weight_scale - 1.0).abs() < 1e-12);
2064
2065        let mut full = crate::custom_family::BlockwiseFitOptions::default();
2066        let from_none = outer_row_indices(&full, n).to_vec();
2067        full.outer_score_subsample = Some(Arc::new(s));
2068        let from_some = outer_row_indices(&full, n).to_vec();
2069
2070        let mut a = from_none.clone();
2071        let mut b = from_some.clone();
2072        a.sort_unstable();
2073        b.sort_unstable();
2074        assert_eq!(a, b);
2075        assert_eq!(a, (0..n).collect::<Vec<_>>());
2076    }
2077
2078    #[test]
2079    fn stratification_covers_all_strata() {
2080        // Synthetic with 2 secondary classes × 100 z-deciles. Every
2081        // non-empty (secondary, decile) stratum must contribute ≥ 1 row.
2082        let n: usize = 20_000;
2083        let z: Vec<f64> = (0..n).map(|i| (i as f64) * 0.001).collect();
2084        let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2085        let k = 2_000;
2086        let s = build_outer_score_subsample(&z, &secondary, k, 12345);
2087        assert!(s.len() >= k, "subsample size {} < k {}", s.len(), k);
2088
2089        // Recompute deciles to label rows.
2090        let mut order: Vec<usize> = (0..n).collect();
2091        order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap());
2092        let mut decile = vec![0usize; n];
2093        for (rank, &row) in order.iter().enumerate() {
2094            decile[row] = ((rank * 100) / n).min(99);
2095        }
2096        // For each (sec, dec), is there at least one row in mask?
2097        let mut covered = [false; 200];
2098        for &row in s.mask.iter() {
2099            let stratum = secondary[row] as usize * 100 + decile[row];
2100            covered[stratum] = true;
2101        }
2102        // All 200 strata are non-empty in this synthetic, so all must be
2103        // covered.
2104        for (stratum, &c) in covered.iter().enumerate() {
2105            assert!(c, "stratum {} uncovered", stratum);
2106        }
2107    }
2108
2109    #[test]
2110    fn deterministic_seed() {
2111        // Same inputs + seed must produce identical masks; different seeds
2112        // produce different masks (with overwhelming probability for these
2113        // sizes).
2114        let n: usize = 5_000;
2115        let z: Vec<f64> = (0..n).map(|i| (i as f64).sin()).collect();
2116        let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2117        let k = 800;
2118        let a = build_outer_score_subsample(&z, &secondary, k, 0xABCDEF);
2119        let b = build_outer_score_subsample(&z, &secondary, k, 0xABCDEF);
2120        let c = build_outer_score_subsample(&z, &secondary, k, 0xFEDCBA);
2121        assert_eq!(a.mask.as_ref(), b.mask.as_ref());
2122        assert_ne!(a.mask.as_ref(), c.mask.as_ref());
2123    }
2124
2125    #[test]
2126    fn weight_scale_correct() {
2127        // n=10000, k=2000 → weight_scale ≈ 5.0 (allow small overshoot from
2128        // ceil(k * stratum_size / n) summed across strata).
2129        let n: usize = 10_000;
2130        let z: Vec<f64> = (0..n).map(|i| i as f64).collect();
2131        let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2132        let k = 2_000;
2133        let s = build_outer_score_subsample(&z, &secondary, k, 7);
2134        assert!(s.len() >= k);
2135        // overshoot bounded by number of strata (one extra row per stratum
2136        // from the ceil); for 2 × 100 = 200 strata, overshoot ≤ 200.
2137        assert!(
2138            s.len() <= k + 200,
2139            "subsample {} much larger than expected",
2140            s.len()
2141        );
2142        let scale = s.weight_scale;
2143        // expected ≈ 5.0; allow ±10% for the ceiling overshoot.
2144        assert!(
2145            (scale - 5.0).abs() < 0.5,
2146            "weight_scale {} not near 5.0",
2147            scale
2148        );
2149    }
2150}