Skip to main content

gam_models/gamlss/
dispersion_family.rs

1//! #913: dispersion-channel GAMLSS location-scale families.
2//!
3//! Extracted from `gamlss.rs` (issue #780); this module now owns the
4//! dispersion-channel joint-curvature corrections.
5
6use super::weighted_design_products::{mirror_upper_to_lower, xt_diag_x_design, xt_diag_y_design};
7// `Order2<K>::value()` is a `JetScalar` trait method; the `JetScalar` import further down
8// brings the trait into scope so the row-NLL value reads (`-tower.value()`) resolve (E0599 fix).
9use super::{
10    BlockwiseTermFitResult, GamlssLambdaLayout, LOCATION_SCALE_N_OUTPUTS,
11    LocationScaleFamilyBuilder, build_location_scale_block, fit_location_scale_terms,
12    identity_penalty, solve_penalizedweighted_projection,
13};
14use crate::block_layout::block_count::validate_block_count;
15use crate::custom_family::{
16    BlockWorkingSet, BlockwiseFitOptions, CustomFamily, CustomFamilyBlockPsiDerivative,
17    FamilyEvaluation, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
18};
19use crate::gamlss::GamlssError;
20use crate::model_types::UnifiedFitResult;
21use gam_linalg::matrix::LinearOperator;
22use gam_math::jet_scalar::JetScalar;
23use gam_terms::smooth::{
24    SpatialLengthScaleOptimizationOptions, TermCollectionDesign, TermCollectionSpec,
25};
26use ndarray::{Array1, Array2, s};
27use statrs::function::gamma::ln_gamma;
28
29// ============================================================================
30// #913: dispersion-channel GAMLSS location-scale families.
31//
32// `noise_formula` (a second linear predictor on the dispersion channel) was
33// wired only for Gaussian/Binomial location-scale and the survival families.
34// The genuine-dispersion mean families — NegativeBinomial, Gamma, Beta and
35// Tweedie — were mean-only with a single scalar dispersion. This module adds a
36// SINGLE generic two-block family that routes all four through the existing
37// blockwise REML engine and the shared `LocationScaleFamilyBuilder` /
38// `fit_location_scale_terms` plumbing, so the κ-coordinate assembly, warm
39// start, shrinkage-penalised scale block and result extraction are reused
40// verbatim. A family is added by supplying only its per-row log-likelihood and
41// the (mean, log-precision) working sets — everything else is shared.
42//
43// Block layout: block 0 = mean predictor (η_μ, log link for NB/Gamma/Tweedie,
44// logit for Beta); block 1 = log-precision predictor (η_d). The dispersion
45// channel models log(precision) uniformly — `θ` for NegativeBinomial, the
46// shape `ν` for Gamma, `φ` for Beta, and `1/φ` for Tweedie — so a larger η_d
47// always means *less* dispersion. This is the mirror image of the Gaussian/Binomial
48// log-scale channel (η_logσ smaller ⇒ tighter). With no `noise_formula` the log-precision
49// block is a single intercept and the fit reduces to the scalar-dispersion
50// model.
51//
52// NB2 with `(μ, θ)` and the exponential-dispersion members here with
53// `(μ, φ)` are Fisher-orthogonal in their standard mean/dispersion
54// parameterizations: Gamma uses shape `ν = 1/φ`, and Tweedie models
55// `log(1/φ)`, so those precision-channel transforms preserve zero expected
56// mean/dispersion cross information. Beta is the exception in this module's
57// mean/precision parameterization. For `Beta(μφ, (1−μ)φ)`,
58//
59//   I_{μ,φ} = φ · (μ ψ'(μφ) − (1−μ) ψ'((1−μ)φ)),
60//
61// so in predictor coordinates `(η_μ = logit μ, η_φ = log φ)` the Fisher cross
62// block is
63//
64//   I_{η_μ,η_φ} = μ(1−μ) φ² · (μ ψ'(μφ) − (1−μ) ψ'((1−μ)φ)),
65//
66// which is generically nonzero. Block-cyclic Fisher-scoring IRLS is still a
67// valid block coordinate solve for the point estimate, but joint-curvature
68// consumers (`log|H|`, coefficient covariance, posterior draws) must receive
69// Beta's off-diagonal coefficient block. Smoothing-parameter selection still
70// runs through the engine's first-order (gradient-only) outer path: the family
71// declines the dense outer Hessian capability because its working weights
72// couple the two blocks (`W_μ` depends on the precision and vice-versa), which
73// the block-local diagonal-drift hook cannot represent exactly.
74// ============================================================================
75
/// Mean families with a genuine dispersion parameter whose precision
/// (overdispersion) channel may carry its own `noise_formula` linear predictor
/// (issue #913).
///
/// For every variant the second predictor is the log of a precision-like
/// quantity, as spelled out per variant below.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DispersionFamilyKind {
    /// Negative binomial (NB2), `Var = μ + μ²/θ`. Precision channel: `log θ`.
    NegativeBinomial,
    /// Gamma, `Var = μ²/ν`. Precision channel: `log ν` (the shape).
    Gamma,
    /// `Beta(μφ, (1−μ)φ)` on a logit mean link. Precision channel: `log φ`.
    Beta,
    /// Tweedie compound Poisson–Gamma, `Var = φ μ^p`, with the power `p` held
    /// fixed. Precision channel: `log(1/φ)`.
    ///
    /// Rows with `y > 0` use the saddlepoint (Nelder–Pregibon) density
    /// approximation; rows with `y = 0` use the exact point mass. This is the
    /// standard tractable Tweedie ML surface (an exact-series φ-derivative is
    /// the remaining hard sub-item of #913).
    Tweedie { p: f64 },
}
95
96impl DispersionFamilyKind {
97    pub const fn family_tag(self) -> &'static str {
98        match self {
99            DispersionFamilyKind::NegativeBinomial => FAMILY_NEGBIN_LOCATION_SCALE,
100            DispersionFamilyKind::Gamma => FAMILY_GAMMA_LOCATION_SCALE,
101            DispersionFamilyKind::Beta => FAMILY_BETA_LOCATION_SCALE,
102            DispersionFamilyKind::Tweedie { .. } => FAMILY_TWEEDIE_LOCATION_SCALE,
103        }
104    }
105
106    /// The mean link is logit for Beta (a probability mean) and log otherwise.
107    pub(crate) const fn mean_is_logit(self) -> bool {
108        matches!(self, DispersionFamilyKind::Beta)
109    }
110
111    /// The mean inverse link this dispersion family fits on: log for
112    /// NegativeBinomial / Gamma / Tweedie, logit for Beta. Single source of
113    /// truth shared by the CLI and FFI save paths so the persisted
114    /// `base_link` never diverges from the fitted channel.
115    pub fn base_link(self) -> gam_problem::InverseLink {
116        use gam_problem::{InverseLink, StandardLink};
117        if self.mean_is_logit() {
118            InverseLink::Standard(StandardLink::Logit)
119        } else {
120            InverseLink::Standard(StandardLink::Log)
121        }
122    }
123
124    /// The family's canonical [`LikelihoodSpec`] (mean response × mean link).
125    /// The overdispersion parameter is estimated by the log-precision channel,
126    /// so the response-family placeholder parameters (`phi`, `theta`) mirror
127    /// the [`resolve_family`](crate::fit_orchestration::materialize::resolve_family) defaults
128    /// and are not consumed as fixed values at predict time. This is the single
129    /// source of truth for the persisted location-scale likelihood so the CLI
130    /// and FFI save paths cannot diverge.
131    pub fn likelihood_spec(self) -> gam_problem::LikelihoodSpec {
132        use gam_problem::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
133        let response = match self {
134            DispersionFamilyKind::NegativeBinomial => ResponseFamily::NegativeBinomial {
135                theta: 1.0,
136                theta_fixed: false,
137            },
138            DispersionFamilyKind::Gamma => ResponseFamily::Gamma,
139            DispersionFamilyKind::Beta => ResponseFamily::Beta { phi: 1.0 },
140            DispersionFamilyKind::Tweedie { p } => ResponseFamily::Tweedie { p },
141        };
142        let link = if self.mean_is_logit() {
143            InverseLink::Standard(StandardLink::Logit)
144        } else {
145            InverseLink::Standard(StandardLink::Log)
146        };
147        LikelihoodSpec::new(response, link)
148    }
149}
150
/// Family tag returned by `DispersionFamilyKind::family_tag` for NegativeBinomial.
pub const FAMILY_NEGBIN_LOCATION_SCALE: &str = "negbin-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for Gamma.
pub const FAMILY_GAMMA_LOCATION_SCALE: &str = "gamma-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for Beta.
pub const FAMILY_BETA_LOCATION_SCALE: &str = "beta-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for Tweedie.
pub const FAMILY_TWEEDIE_LOCATION_SCALE: &str = "tweedie-location-scale";
155
156/// `η` magnitude clamp shared by both channels (mirrors PIRLS `ETA_CLAMP`):
157/// keeps `exp(η)` and the logit jet away from overflow while staying in the
158/// smooth interior of every link.
159pub(super) const DISPERSION_ETA_CLAMP: f64 = 30.0;
160/// Floor for a per-row IRLS working weight / curvature so the block normal
161/// equations stay positive-definite. The working *response* always carries the
162/// exact score, so the stationary point (penalised score = 0) is independent
163/// of this floor; it only conditions the inner solve.
164pub(super) const DISPERSION_MIN_CURVATURE: f64 = 1e-12;
165
/// Number of rows beyond which the per-row dispersion-kernel map is spread
/// over rayon workers. The fan-out only happens when the caller is not already
/// on a worker (avoiding nested oversubscription); for fewer rows the serial
/// map is cheaper than the fork/join overhead. Same guard as the row-chunk
/// check in [`row_coeff_operator`](super::gaussian::row_coeff_operator).
const DISPERSION_PARALLEL_ROW_THRESHOLD: usize = 1024;
172
173/// Per-row working quantities for both channels at the current `(η_μ, η_d)`.
174pub(super) struct DispersionRowKernel {
175    pub(super) loglik: f64,
176    pub(super) mean_weight: f64,
177    pub(super) mean_response: f64,
178    pub(super) disp_weight: f64,
179    pub(super) disp_response: f64,
180}
181
#[cfg(test)]
mod test_support {
    use super::*;

    /// NB2 row NLL oracle, generic over any [`JetScalar<2>`]. The jet variables
    /// are the natural parameters: `μ` on axis 0 and `θ` on axis 1.
    #[inline]
    pub(super) fn dispersion_nb_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let theta = S::variable(theta_value, 1);
        let theta_plus_mu = theta.add(&mu);

        // ℓ = lnΓ(θ+y) − lnΓ(θ) − lnΓ(y+1) + θ·ln θ − θ·ln(θ+μ) + y·ln μ − y·ln(θ+μ)
        let gamma_part = theta
            .add(&S::constant(yi))
            .ln_gamma()
            .sub(&theta.ln_gamma())
            .sub(&S::constant(ln_gamma(yi + 1.0)));
        let with_theta_terms = gamma_part
            .add(&theta.mul(&theta.ln()))
            .sub(&theta.mul(&theta_plus_mu.ln()));
        let log_lik = with_theta_terms
            .add(&mu.ln().scale(yi))
            .sub(&theta_plus_mu.ln().scale(yi));

        // Weighted negative log-likelihood.
        log_lik.scale(-wi)
    }

    /// Gamma row NLL oracle, generic over any [`JetScalar<2>`]. The jet
    /// variables are `μ` on axis 0 and the shape `ν` on axis 1.
    #[inline]
    pub(super) fn dispersion_gamma_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let nu = S::variable(nu_value, 1);

        // ℓ = ν·ln ν − ν·ln μ − lnΓ(ν) + (ν−1)·ln y₊ − ν·((1/μ)·y)
        let shape_terms = nu
            .mul(&nu.ln())
            .sub(&nu.mul(&mu.ln()))
            .sub(&nu.ln_gamma());
        let log_y_term = nu.sub(&S::constant(1.0)).scale(y_pos.ln());
        let rate_term = nu.mul(&mu.recip().scale(yi));

        shape_terms.add(&log_y_term).sub(&rate_term).scale(-wi)
    }

    /// Beta row NLL oracle, generic over any [`JetScalar<2>`]. The jet
    /// variables are `μ` on axis 0 and `φ` on axis 1; `yi` is clamped into
    /// `[1e-12, 1 − 1e-12]` before its logs are taken.
    #[inline]
    pub(super) fn dispersion_beta_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        phi_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let phi = S::variable(phi_value, 1);
        let y_clamped = yi.clamp(1e-12, 1.0 - 1e-12);

        // Shape parameters a = μφ and b = (1−μ)φ.
        let shape_a = mu.mul(&phi);
        let shape_b = S::constant(1.0).sub(&mu).mul(&phi);

        // ℓ = lnΓ(φ) − lnΓ(a) − lnΓ(b) + (a−1)·ln y + (b−1)·ln(1−y)
        let log_normaliser = phi
            .ln_gamma()
            .sub(&shape_a.ln_gamma())
            .sub(&shape_b.ln_gamma());
        let ln_y_term = shape_a.sub(&S::constant(1.0)).scale(y_clamped.ln());
        let ln_one_minus_y_term = shape_b
            .sub(&S::constant(1.0))
            .scale((1.0 - y_clamped).ln());

        log_normaliser
            .add(&ln_y_term)
            .add(&ln_one_minus_y_term)
            .scale(-wi)
    }

    /// #1591 jet-prune oracle: the full two-axis `Order2<2>`
    /// (value/grad/Hessian) NB2 row NLL.
    ///
    /// Production no longer reads the mean (`μ`-axis) derivative channels of
    /// this tower: the NB mean block is Fisher-orthogonal and written out by
    /// hand in [`dispersion_row_kernel`], so the hot path calls the pruned
    /// single-axis [`dispersion_nb_disp_order2`]. This `K=2` form is retained
    /// solely as the dense-`Tower4<2>` oracle pin
    /// (`order2_matches_dense_tower_all_channels`).
    #[inline]
    pub(super) fn dispersion_nb_nll_order2(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type Jet = gam_math::jet_scalar::Order2<2>;

        let mu = Jet::variable(mu_value, 0);
        let theta = Jet::variable(theta_value, 1);
        let theta_plus_mu = theta.add(&mu);
        let theta_plus_y = theta.add(&Jet::constant(yi));

        let gamma_part = order2_ln_gamma(&theta_plus_y)
            .sub(&order2_ln_gamma(&theta))
            .sub(&Jet::constant(ln_gamma(yi + 1.0)));
        let with_theta_terms = gamma_part
            .add(&theta.mul(&theta.ln()))
            .sub(&theta.mul(&theta_plus_mu.ln()));
        let log_lik = with_theta_terms
            .add(&mu.ln().scale(yi))
            .sub(&theta_plus_mu.ln().scale(yi));

        log_lik.scale(-wi)
    }

    /// #1591 jet-prune oracle: the full two-axis `Order2<2>` Gamma row NLL.
    ///
    /// As for NB, production does not use the mean axis (it is hand-written
    /// and Fisher-orthogonal) and calls the single-axis
    /// [`dispersion_gamma_disp_order2`] instead. Retained only as the
    /// dense-tower oracle pin.
    #[inline]
    pub(super) fn dispersion_gamma_nll_order2(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type Jet = gam_math::jet_scalar::Order2<2>;

        let mu = Jet::variable(mu_value, 0);
        let nu = Jet::variable(nu_value, 1);

        let shape_terms = nu
            .mul(&nu.ln())
            .sub(&nu.mul(&mu.ln()))
            .sub(&order2_ln_gamma(&nu));
        let log_y_term = nu.sub(&Jet::constant(1.0)).scale(y_pos.ln());
        let rate_term = nu.mul(&mu.recip().scale(yi));

        shape_terms.add(&log_y_term).sub(&rate_term).scale(-wi)
    }
}
316
317/// Production `Order2<2>` Beta row NLL (value/grad/Hessian hot path; the cross
318/// channel `h()[0][1]` feeds the Beta observed cross weight).
319#[inline]
320pub(crate) fn dispersion_beta_nll_order2(
321    yi: f64,
322    mu_value: f64,
323    phi_value: f64,
324    wi: f64,
325) -> gam_math::jet_scalar::Order2<2> {
326    type O2 = gam_math::jet_scalar::Order2<2>;
327
328    let mu = O2::variable(mu_value, 0);
329    let phi = O2::variable(phi_value, 1);
330    let one_minus_mu = O2::constant(1.0).sub(&mu);
331    let yc = yi.clamp(1e-12, 1.0 - 1e-12);
332    let a = mu.mul(&phi);
333    let b = one_minus_mu.mul(&phi);
334    let loglik = order2_ln_gamma(&phi)
335        .sub(&order2_ln_gamma(&a))
336        .sub(&order2_ln_gamma(&b))
337        .add(&a.sub(&O2::constant(1.0)).scale(yc.ln()))
338        .add(&b.sub(&O2::constant(1.0)).scale((1.0 - yc).ln()));
339    loglik.scale(-wi)
340}
341
342#[inline]
343fn order2_ln_gamma<const K: usize>(
344    x: &gam_math::jet_scalar::Order2<K>,
345) -> gam_math::jet_scalar::Order2<K> {
346    gam_math::jet_scalar::Order2(
347        x.0.compose_unary(gam_math::jet_tower::ln_gamma_derivative_stack_order2(x.0.v)),
348    )
349}
350
351// ============================================================================
352// #1591 jet-prune: single-axis (`K=1`) dispersion-channel towers.
353//
354// For NegativeBinomial / Gamma / Tweedie the production row kernel consumes ONLY
355// the dispersion-axis derivatives (`g[disp]`, `h[disp][disp]`) and the value;
356// the mean block is Fisher-orthogonal and assembled in closed form. Seeding the
357// mean as a CONSTANT and the dispersion parameter as the SOLE jet variable
358// therefore yields a tower whose `(value, g[0], h[0][0])` are `to_bits`-
359// identical to the consumed `(value, g[1], h[1][1])` of the old `Order2<2>`
360// tower — the mean seed only ever populated the now-discarded `g[mean]` /
361// `h[mean][·]` channels (Leibniz/Faà-di-Bruno never read the dispersion-axis
362// channels off the mean seed). Collapsing `K=2 → K=1` quarters the Hessian
363// tensor (1 entry vs 4) and halves the gradient, with no change to any consumed
364// float bit. The `ln_gamma` derivative stacks are unchanged (the irreducible
365// transcendental cost), so this trims the rational composition, not the special
366// functions.
367// ============================================================================
368
369/// Pruned single-axis NB2 dispersion tower: `θ` is the sole jet variable
370/// (axis 0), `μ` a constant. `value`/`g[0]`/`h[0][0]` reproduce the consumed
371/// `value`/`g[1]`/`h[1][1]` of `dispersion_nb_nll_order2` bit-for-bit.
372#[inline]
373pub(crate) fn dispersion_nb_disp_order2(
374    yi: f64,
375    mu_value: f64,
376    theta_value: f64,
377    wi: f64,
378) -> gam_math::jet_scalar::Order2<1> {
379    type O1 = gam_math::jet_scalar::Order2<1>;
380
381    let mu = O1::constant(mu_value);
382    let theta = O1::variable(theta_value, 0);
383    let tpm = theta.add(&mu);
384    let theta_plus_y = theta.add(&O1::constant(yi));
385    let loglik = order2_ln_gamma(&theta_plus_y)
386        .sub(&order2_ln_gamma(&theta))
387        .sub(&O1::constant(ln_gamma(yi + 1.0)))
388        .add(&theta.mul(&theta.ln()))
389        .sub(&theta.mul(&tpm.ln()))
390        .add(&mu.ln().scale(yi))
391        .sub(&tpm.ln().scale(yi));
392    loglik.scale(-wi)
393}
394
395/// Pruned single-axis Gamma dispersion tower: `ν` is the sole jet variable
396/// (axis 0), `μ` a constant. Consumed channels match
397/// `dispersion_gamma_nll_order2` index-1 bit-for-bit.
398#[inline]
399pub(crate) fn dispersion_gamma_disp_order2(
400    yi: f64,
401    y_pos: f64,
402    mu_value: f64,
403    nu_value: f64,
404    wi: f64,
405) -> gam_math::jet_scalar::Order2<1> {
406    type O1 = gam_math::jet_scalar::Order2<1>;
407
408    let mu = O1::constant(mu_value);
409    let nu = O1::variable(nu_value, 0);
410    let loglik = nu
411        .mul(&nu.ln())
412        .sub(&nu.mul(&mu.ln()))
413        .sub(&order2_ln_gamma(&nu))
414        .add(&nu.sub(&O1::constant(1.0)).scale(y_pos.ln()))
415        .sub(&nu.mul(&mu.recip().scale(yi)));
416    loglik.scale(-wi)
417}
418
419/// Pruned single-axis Tweedie dispersion tower seeded on the predictor `η_d`
420/// (axis 0), with `η_μ` a constant (so `μ = exp(η_μ)` carries no jet). The
421/// `φ = exp(−η_d)` chain and its nonlinear `∂²φ/∂η_d²` curvature are carried
422/// exactly as in `dispersion_tweedie_nll_generic`; `value`/`g[0]`/`h[0][0]`
423/// match that program's `value`/`g[1]`/`h[1][1]` bit-for-bit.
424#[inline]
425pub(crate) fn dispersion_tweedie_disp_order2(
426    yi: f64,
427    eta_mu: f64,
428    eta_d: f64,
429    p: f64,
430    wi: f64,
431) -> gam_math::jet_scalar::Order2<1> {
432    type O1 = gam_math::jet_scalar::Order2<1>;
433
434    let one_minus_p = 1.0 - p;
435    let two_minus_p = 2.0 - p;
436    let mu = O1::constant(eta_mu).exp();
437    let phi = O1::variable(eta_d, 0).scale(-1.0).exp();
438    if yi > 0.0 {
439        let dev = mu
440            .powf(two_minus_p)
441            .scale(1.0 / two_minus_p)
442            .sub(&mu.powf(one_minus_p).scale(yi / one_minus_p))
443            .add(&O1::constant(
444                yi.powf(two_minus_p) / (one_minus_p * two_minus_p),
445            ))
446            .scale(2.0);
447        let loglik = dev
448            .mul(&phi.recip().scale(-0.5))
449            .sub(&phi.scale(2.0 * std::f64::consts::PI).ln().scale(0.5))
450            .sub(&O1::constant(0.5 * p * yi.ln()));
451        loglik.scale(-wi)
452    } else {
453        let c = mu.powf(two_minus_p).scale(1.0 / two_minus_p);
454        let loglik = c.mul(&phi.recip()).scale(-1.0);
455        loglik.scale(-wi)
456    }
457}
458
459// ============================================================================
460// #1591 jet-prune: value-only (`K=0`) row negative-log-likelihood.
461//
462// `log_likelihood_only` reads ONLY `row.loglik = -tower.value()`; the full row
463// kernel it used to call evaluated every dispersion tower's gradient AND Hessian
464// — including the digamma/trigamma derivative stacks — purely to discard them.
465// These functions evaluate the SAME value-channel program in plain `f64`, so
466// they are `to_bits`-identical to `-tower.value()` (the jet value channel is the
467// naive scalar evaluation: `mul.v = a.v*b.v`, `compose.v = stack[0]`), while
468// touching only `ln_gamma` (stack slot 0) and never the digamma/trigamma slots.
469// On a per-row loglik that is the dominant transcendental saving.
470// ============================================================================
471
472/// NB2 row NLL value, plain `f64`, bit-identical to
473/// `-dispersion_nb_disp_order2(..).value()`.
474#[inline]
475fn dispersion_nb_neg_loglik(yi: f64, mu: f64, theta: f64, wi: f64) -> f64 {
476    let tpm = theta + mu;
477    let s = ln_gamma(theta + yi) - ln_gamma(theta) - ln_gamma(yi + 1.0) + theta * theta.ln()
478        - theta * tpm.ln()
479        + mu.ln() * yi
480        - tpm.ln() * yi;
481    -(s * -wi)
482}
483
484/// Gamma row NLL value, plain `f64`, bit-identical to
485/// `-dispersion_gamma_disp_order2(..).value()`.
486#[inline]
487fn dispersion_gamma_neg_loglik(yi: f64, y_pos: f64, mu: f64, nu: f64, wi: f64) -> f64 {
488    // NB: the jet forms `μ.recip().scale(yi)` = `(1/μ)·yᵢ` (reciprocal then
489    // multiply), NOT `yᵢ/μ` (single divide) — these differ in the last bit, so
490    // the value path must reproduce the reciprocal-then-multiply exactly.
491    let s = nu * nu.ln() - nu * mu.ln() - ln_gamma(nu) + (nu - 1.0) * y_pos.ln()
492        - nu * ((1.0 / mu) * yi);
493    -(s * -wi)
494}
495
496/// Beta row NLL value, plain `f64`, bit-identical to
497/// `-dispersion_beta_nll_order2(..).value()`.
498#[inline]
499fn dispersion_beta_neg_loglik(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
500    let one_minus_mu = 1.0 - mu;
501    let yc = yi.clamp(1e-12, 1.0 - 1e-12);
502    let a = mu * phi;
503    let b = one_minus_mu * phi;
504    let s = ln_gamma(phi) - ln_gamma(a) - ln_gamma(b)
505        + (a - 1.0) * yc.ln()
506        + (b - 1.0) * (1.0 - yc).ln();
507    -(s * -wi)
508}
509
/// Value channel of the Tweedie row program in plain `f64`: the negation of
/// the weighted-NLL tower value, i.e. `wi · ℓ`. Bit-identical to
/// `-dispersion_tweedie_disp_order2(..).value()` on both density branches
/// (saddlepoint for `yi > 0`, point mass otherwise).
#[inline]
fn dispersion_tweedie_neg_loglik(yi: f64, eta_mu: f64, eta_d: f64, p: f64, wi: f64) -> f64 {
    let one_minus_p = 1.0 - p;
    let two_minus_p = 2.0 - p;
    let mu = eta_mu.exp();
    let phi = (-eta_d).exp();
    // Reciprocal-then-multiply, matching the jet's `phi.recip()`.
    let inv_phi = 1.0 / phi;
    // μ^(2−p) / (2−p): appears in both density branches.
    let mu_term = mu.powf(two_minus_p) * (1.0 / two_minus_p);

    let log_lik = if yi > 0.0 {
        let half_deviance = mu_term - mu.powf(one_minus_p) * (yi / one_minus_p)
            + yi.powf(two_minus_p) / (one_minus_p * two_minus_p);
        let deviance = half_deviance * 2.0;
        deviance * (inv_phi * -0.5)
            - (phi * (2.0 * std::f64::consts::PI)).ln() * 0.5
            - 0.5 * p * yi.ln()
    } else {
        (mu_term * inv_phi) * -1.0
    };

    // The tower holds `ℓ · (−wi)`; callers read its negation.
    -(log_lik * -wi)
}
532
533/// Value-only row negative log-likelihood for one observation — the pruned hot
534/// path for [`CustomFamily::log_likelihood_only`]. Mirrors the link/clamp
535/// preamble of [`dispersion_row_kernel`] exactly, then evaluates ONLY the value
536/// channel (no gradient/Hessian, no digamma/trigamma). Returns `row.loglik`
537/// `to_bits`-identically.
538#[inline]
539pub(crate) fn dispersion_row_loglik(
540    kind: DispersionFamilyKind,
541    yi: f64,
542    eta_mu: f64,
543    eta_d: f64,
544    prior_weight: f64,
545) -> f64 {
546    let wi = prior_weight.max(0.0);
547    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
548    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
549    match kind {
550        DispersionFamilyKind::NegativeBinomial => {
551            let mu = em.exp().max(1e-300);
552            let theta = ed.exp().max(1e-12);
553            dispersion_nb_neg_loglik(yi, mu, theta, wi)
554        }
555        DispersionFamilyKind::Gamma => {
556            let mu = em.exp().max(1e-300);
557            let nu = ed.exp().max(1e-12);
558            let y_pos = yi.max(1e-300);
559            dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)
560        }
561        DispersionFamilyKind::Beta => {
562            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
563            let phi = ed.exp().max(1e-12);
564            dispersion_beta_neg_loglik(yi, mu, phi, wi)
565        }
566        DispersionFamilyKind::Tweedie { p } => dispersion_tweedie_neg_loglik(yi, em, ed, p, wi),
567    }
568}
569
570#[inline]
571pub(crate) fn beta_observed_cross_weight_eta(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
572    let q = (mu * (1.0 - mu)).max(1e-12);
573    let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
574    q * phi * tower.h()[0][1]
575}
576
577#[inline]
578pub(crate) fn dispersion_row_cross_weight(
579    kind: DispersionFamilyKind,
580    yi: f64,
581    eta_mu: f64,
582    eta_d: f64,
583    prior_weight: f64,
584) -> f64 {
585    let wi = prior_weight.max(0.0);
586    if wi == 0.0 {
587        return 0.0;
588    }
589    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
590    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
591    match kind {
592        DispersionFamilyKind::Beta => {
593            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
594            let phi = ed.exp().max(1e-12);
595            beta_observed_cross_weight_eta(yi, mu, phi, wi)
596        }
597        DispersionFamilyKind::NegativeBinomial
598        | DispersionFamilyKind::Gamma
599        | DispersionFamilyKind::Tweedie { .. } => 0.0,
600    }
601}
602
603#[inline]
604pub(crate) fn tower_score_info<const K: usize>(
605    tower: &gam_math::jet_scalar::Order2<K>,
606    idx: usize,
607    wi: f64,
608) -> (f64, f64) {
609    if wi == 0.0 {
610        (0.0, 0.0)
611    } else {
612        (-tower.g()[idx] / wi, tower.h()[idx][idx] / wi)
613    }
614}
615
616/// Evaluate the row log-likelihood and the (mean, log-precision) Fisher-scoring
617/// working sets for one observation. `eta_mu`/`eta_d` already include any
618/// per-channel offset (they are the block predictors). `prior_weight` is the
619/// observation's prior weight.
620pub(super) fn dispersion_row_kernel(
621    kind: DispersionFamilyKind,
622    yi: f64,
623    eta_mu: f64,
624    eta_d: f64,
625    prior_weight: f64,
626) -> DispersionRowKernel {
627    let wi = prior_weight.max(0.0);
628    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
629    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
630    match kind {
631        DispersionFamilyKind::NegativeBinomial => {
632            let mu = em.exp().max(1e-300);
633            let theta = ed.exp().max(1e-12); // precision (size)
634            let tpm = theta + mu;
635            let tower = dispersion_nb_disp_order2(yi, mu, theta, wi);
636            // Only the exact θ-space SCORE is consumed from the tower; the
637            // observed-Hessian channel is discarded in favor of the expected
638            // (Fisher) information assembled below (see the dispersion-curvature
639            // note). Keeping the value channel for the row log-likelihood.
640            let (s_theta, _info_theta_observed) = tower_score_info(&tower, 0, wi);
641            let loglik = -tower.value();
642            let info_mu = if wi == 0.0 {
643                DISPERSION_MIN_CURVATURE
644            } else {
645                (theta / (mu * tpm)).max(DISPERSION_MIN_CURVATURE)
646            };
647            let score_mu = theta * (yi - mu) / (mu * tpm);
648            let mean_weight = wi * mu * mu * info_mu;
649            let mean_response = em + score_mu / (mu * info_mu);
650            // Dispersion (log-θ) IRLS curvature: use the EXPECTED (Fisher)
651            // information in θ, not the per-row OBSERVED Hessian channel
652            // (`_info_theta_observed`). The NB2 log-likelihood is strongly
653            // non-quadratic in θ: `−∂²ℓ/∂θ²` carries the row-specific term
654            // `ψ′(θ+y)` and goes NEGATIVE for every row whose count sits below
655            // its current fitted precision (overestimated size / underestimated
656            // overdispersion). Far from the optimum a majority of rows can be
657            // negative, so the assembled block curvature `Xᵀdiag(w)X` loses
658            // positive-definiteness; flooring each negative row at
659            // `DISPERSION_MIN_CURVATURE` (≈0) then divides the exact score by
660            // ~0 in the working response, producing O(1e12) IRLS targets that
661            // make the dispersion block step explode and the inner block-cyclic
662            // solve stall (never reaching KKT within the cycle budget — the
663            // `nb` location-scale `IntegrationError`, gam#1606). The mean block
664            // already uses its closed-form expected info `θ/(μ(θ+μ))`; the
665            // dispersion block must do the same.
666            //
667            // The Fisher information in θ has the closed form
668            //   I(θ) = ψ′(θ) − E[ψ′(θ+Y)] − 1/θ + 1/(θ+μ),
669            // whose only costly piece is the per-row infinite expectation
670            // `E[ψ′(θ+Y)]`. Replacing it with the Jensen plug-in `ψ′(θ+μ)`
671            // (valid because ψ′ is convex, so this is a tight lower bound on the
672            // expectation) gives a per-row, sum-free, STRICTLY POSITIVE
673            // curvature
674            //   I_θ ≈ ψ′(θ) − ψ′(θ+μ) − 1/θ + 1/(θ+μ) > 0  for all (μ,θ),
675            // since ψ′ is strictly decreasing. The working RESPONSE still
676            // carries the EXACT score `s_theta` (= ∂ℓ/∂θ from the tower), so the
677            // penalized stationary point (score = 0) is byte-unchanged — this is
678            // Fisher scoring, which only re-conditions the inner solve and never
679            // shifts the optimum (cf. the `DISPERSION_MIN_CURVATURE` contract
680            // note above). The observed channel `_info_theta_observed` is no
681            // longer consumed for the weight.
682            let trigamma_theta = gam_math::jet_tower::trigamma_derivative_stack(theta)[0];
683            let trigamma_tpm = gam_math::jet_tower::trigamma_derivative_stack(tpm)[0];
684            let info_theta_fisher = trigamma_theta - trigamma_tpm - 1.0 / theta + 1.0 / tpm;
685            let info_pos = info_theta_fisher.max(DISPERSION_MIN_CURVATURE);
686            let disp_weight = wi * theta * theta * info_pos;
687            let disp_response = ed + s_theta / (theta * info_pos);
688            DispersionRowKernel {
689                loglik,
690                mean_weight,
691                mean_response,
692                disp_weight,
693                disp_response,
694            }
695        }
696        DispersionFamilyKind::Gamma => {
697            let mu = em.exp().max(1e-300);
698            let nu = ed.exp().max(1e-12); // precision = shape ν
699            let y_pos = yi.max(1e-300);
700            let tower = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
701            let (s_nu, info_nu_raw) = tower_score_info(&tower, 0, wi);
702            let loglik = -tower.value();
703            let info_mu = if wi == 0.0 {
704                DISPERSION_MIN_CURVATURE
705            } else {
706                (nu / (mu * mu)).max(DISPERSION_MIN_CURVATURE)
707            };
708            let score_mu = nu * (yi - mu) / (mu * mu);
709            let mean_weight = wi * mu * mu * info_mu;
710            let mean_response = em + score_mu / (mu * info_mu);
711            let info_nu = info_nu_raw.max(DISPERSION_MIN_CURVATURE);
712            let disp_weight = wi * nu * nu * info_nu;
713            let disp_response = ed + s_nu / (nu * info_nu);
714            DispersionRowKernel {
715                loglik,
716                mean_weight,
717                mean_response,
718                disp_weight,
719                disp_response,
720            }
721        }
722        DispersionFamilyKind::Beta => {
723            // logit mean link.
724            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
725            let phi = ed.exp().max(1e-12); // precision
726            let q = (mu * (1.0 - mu)).max(1e-12); // dμ/dη
727            let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
728            let (score_mu, info_mu_raw) = tower_score_info(&tower, 0, wi);
729            let (s_phi, info_phi_raw) = tower_score_info(&tower, 1, wi);
730            let loglik = -tower.value();
731            let info_mu = info_mu_raw.max(DISPERSION_MIN_CURVATURE);
732            let mean_weight = wi * q * q * info_mu;
733            let mean_response = em + score_mu / (q * info_mu);
734            let info_phi = info_phi_raw.max(DISPERSION_MIN_CURVATURE);
735            let disp_weight = wi * phi * phi * info_phi;
736            let disp_response = ed + s_phi / (phi * info_phi);
737            DispersionRowKernel {
738                loglik,
739                mean_weight,
740                mean_response,
741                disp_weight,
742                disp_response,
743            }
744        }
745        DispersionFamilyKind::Tweedie { p } => {
746            let mu = em.exp().max(1e-300);
747            // Precision channel models log(1/φ) ⇒ φ = exp(−η_d).
748            let phi = (-ed).exp().max(1e-12);
749            let two_minus_p = 2.0 - p;
750            // Mean channel: the quasi-score `(y−μ)/μ` and Fisher weight
751            // `μ^{2−p}/φ` are simple closed forms (and the mean block is
752            // Fisher-orthogonal to the dispersion block in this
753            // parameterization), so they stay hand-written exactly as the
754            // NB/Gamma mean arms do.
755            let mean_weight = wi * mu.powf(two_minus_p) / phi;
756            let mean_response = em + (yi - mu) / mu;
757            // Dispersion channel: the η_d-space score and OBSERVED information
758            // come straight off the single-expression tower seeded on `η_d`
759            // (#932), so the saddlepoint/point-mass branch split, the
760            // `φ = exp(−η_d)` chain and its nonlinear `∂²φ/∂η_d²` curvature
761            // correction are all mechanically carried — no per-branch
762            // `s_phi`/`s_eta`/`curvature_eta` hand calculus. #1591: only the
763            // η_d axis is consumed, so the tower is the pruned single-axis
764            // `Order2<1>` (`η_μ` enters as a constant).
765            let tower = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
766            let loglik = -tower.value();
767            // η_d-space score and observed information off the tower, via the
768            // same helper the NB/Gamma/Beta arms use (returns `(0, 0)` when the
769            // prior weight is zero, so the row stays excluded below).
770            let (s_eta, info_eta_raw) = tower_score_info(&tower, 0, wi);
771            let curvature_eta = if wi == 0.0 {
772                DISPERSION_MIN_CURVATURE
773            } else {
774                info_eta_raw.max(DISPERSION_MIN_CURVATURE)
775            };
776            // The working response divides by this per-row curvature so the
777            // prior weight cancels (and a zero-prior-weight row stays excluded
778            // via `disp_weight = 0`).
779            let disp_weight = wi * curvature_eta;
780            let disp_response = ed + s_eta / curvature_eta;
781            DispersionRowKernel {
782                loglik,
783                mean_weight,
784                mean_response,
785                disp_weight,
786                disp_response,
787            }
788        }
789    }
790}
791
/// Two-block GAMLSS family for the genuine-dispersion mean families (#913).
///
/// Block 0 is the mean predictor and block 1 the dispersion predictor (see
/// `BLOCK_MEAN` / `BLOCK_DISP`); the struct itself only stores the per-row
/// data the row kernels read.
#[derive(Clone)]
pub(crate) struct DispersionGlmLocationScaleFamily {
    /// Which member family's per-row kernel is evaluated.
    pub(crate) kind: DispersionFamilyKind,
    /// Response values, one entry per row.
    pub(crate) y: Array1<f64>,
    /// Per-row prior weights; `evaluate` rejects a length different from `y`.
    pub(crate) weights: Array1<f64>,
}
799
impl DispersionGlmLocationScaleFamily {
    /// Position of the mean block in the `block_states` / `specs` slices.
    pub(crate) const BLOCK_MEAN: usize = 0;
    /// Position of the dispersion (log-precision) block in the same slices.
    pub(crate) const BLOCK_DISP: usize = 1;
}
804
805impl CustomFamily for DispersionGlmLocationScaleFamily {
    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
    // flat-prior exact-Newton objective carries no Jeffreys term), so families
    // that historically armed the term by default opt back in explicitly.
    // This family always requests the joint Jeffreys term, for every `kind`.
    fn joint_jeffreys_term_required(&self) -> bool {
        true
    }
812
813    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
814        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
815        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
816        let eta_d = &block_states[Self::BLOCK_DISP].eta;
817        let n = self.y.len();
818        if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
819            return Err(format!(
820                "{} row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
821                self.kind.family_tag(),
822                eta_mu.len(),
823                eta_d.len(),
824                self.weights.len()
825            ));
826        }
827        let mut mean_weights = Array1::<f64>::zeros(n);
828        let mut mean_response = Array1::<f64>::zeros(n);
829        let mut disp_weights = Array1::<f64>::zeros(n);
830        let mut disp_response = Array1::<f64>::zeros(n);
831
832        // `dispersion_row_kernel` is a pure, row-independent map — each row reads
833        // only `y[i]`/`eta_mu[i]`/`eta_d[i]`/`weights[i]` and writes nothing
834        // shared — and it is transcendental-heavy (per-row digamma/trigamma
835        // derivative stacks), so the per-row evaluation is embarrassingly
836        // row-parallel. Materialize the per-row kernels (in parallel for large
837        // `n` when not already on a rayon worker; mirrors the
838        // `row_coeff_operator` guard), then reduce SERIALLY in index order so
839        // the log-likelihood sum is bit-identical to the old serial loop — no
840        // float reassociation. The reduction touches no transcendentals, so the
841        // parallel kernel map captures essentially all the savings.
842        let kernels: Vec<DispersionRowKernel> = if rayon::current_thread_index().is_none()
843            && n > DISPERSION_PARALLEL_ROW_THRESHOLD
844        {
845            use rayon::iter::{IntoParallelIterator, ParallelIterator};
846            (0..n)
847                .into_par_iter()
848                .map(|i| {
849                    dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
850                })
851                .collect()
852        } else {
853            (0..n)
854                .map(|i| {
855                    dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
856                })
857                .collect()
858        };
859
860        let mut log_likelihood = 0.0;
861        for (i, row) in kernels.into_iter().enumerate() {
862            if row.loglik.is_finite() {
863                log_likelihood += row.loglik;
864            }
865            mean_weights[i] = row.mean_weight.max(0.0);
866            mean_response[i] = row.mean_response;
867            disp_weights[i] = row.disp_weight.max(0.0);
868            disp_response[i] = row.disp_response;
869        }
870        Ok(FamilyEvaluation {
871            log_likelihood,
872            blockworking_sets: vec![
873                BlockWorkingSet::diagonal_checked(mean_response, mean_weights)?,
874                BlockWorkingSet::diagonal_checked(disp_response, disp_weights)?,
875            ],
876        })
877    }
878
879    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
880        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
881        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
882        let eta_d = &block_states[Self::BLOCK_DISP].eta;
883        let n = self.y.len();
884        // #1591 prune: the objective needs only the row log-likelihood, so each
885        // row evaluates the value channel alone (`to_bits`-identical to
886        // `dispersion_row_kernel(..).loglik`), skipping every gradient/Hessian
887        // and digamma/trigamma derivative-stack evaluation. That value-only map
888        // is still a pure, row-independent per-row `ln_gamma` evaluation, so it
889        // is row-parallel; fan it out (large `n`, off a rayon worker) into a
890        // per-row buffer, then sum SERIALLY in index order to keep the objective
891        // bit-identical to the serial loop (no float reassociation).
892        let per_row: Vec<f64> = if rayon::current_thread_index().is_none()
893            && n > DISPERSION_PARALLEL_ROW_THRESHOLD
894        {
895            use rayon::iter::{IntoParallelIterator, ParallelIterator};
896            (0..n)
897                .into_par_iter()
898                .map(|i| {
899                    dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
900                })
901                .collect()
902        } else {
903            (0..n)
904                .map(|i| {
905                    dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
906                })
907                .collect()
908        };
909        let mut ll = 0.0;
910        for loglik in per_row {
911            if loglik.is_finite() {
912                ll += loglik;
913            }
914        }
915        Ok(ll)
916    }
917
918    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
919        crate::location_scale_engine::location_scale_coefficient_hessian_cost(
920            self.y.len() as u64,
921            specs,
922        )
923    }
924
925    /// Exact joint coefficient-space Hessian `H_L = -∇²log L` in flattened
926    /// `[mean | log-precision]` block order.
927    ///
928    /// All four members assemble the same `Xᵀ diag(W) X` blocks; the cross
929    /// block is the per-row mixed weight `dispersion_row_cross_weight`. Beta
930    /// carries a genuinely nonzero (η_μ, η_φ) cross weight; the Fisher-
931    /// orthogonal members (NegativeBinomial / Gamma / Tweedie) report a zero
932    /// cross weight, so this returns their exact *block-diagonal* joint
933    /// Hessian. Returning that block-diagonal `H_L` — rather than `None` —
934    /// is what lets the multi-block outer-REML path (`build_joint_hessian_
935    /// closures` → `joint_outer_evaluate`) and the joint posterior covariance
936    /// (`compute_joint_covariance`) run for these families instead of failing
937    /// the "multi-block families must provide a joint outer path" gate and
938    /// silently escalating to a degraded ρ-seed fit with no covariance/EDF
939    /// (gam#1119). The orthogonal members additionally declare
940    /// `likelihood_blocks_uncoupled() = true` so the directional-derivative
941    /// and Jeffreys dispatch route through the block-diagonal-exact fallback
942    /// rather than rejecting the structurally-uncoupled Hessian.
943    fn exact_newton_joint_hessian_with_specs(
944        &self,
945        block_states: &[ParameterBlockState],
946        specs: &[ParameterBlockSpec],
947    ) -> Result<Option<Array2<f64>>, String> {
948        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
949        if specs.len() != 2 {
950            return Err(format!(
951                "{} exact joint Hessian expects 2 specs, got {}",
952                self.kind.family_tag(),
953                specs.len()
954            ));
955        }
956        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
957        let eta_d = &block_states[Self::BLOCK_DISP].eta;
958        let n = self.y.len();
959        if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
960            return Err(format!(
961                "{} exact joint Hessian row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
962                self.kind.family_tag(),
963                eta_mu.len(),
964                eta_d.len(),
965                self.weights.len()
966            ));
967        }
968
969        let eval = self.evaluate(block_states)?;
970        let BlockWorkingSet::Diagonal {
971            working_weights: mean_weights,
972            ..
973        } = &eval.blockworking_sets[Self::BLOCK_MEAN]
974        else {
975            return Err(format!(
976                "{} dispersion mean block did not return diagonal weights",
977                self.kind.family_tag()
978            ));
979        };
980        let BlockWorkingSet::Diagonal {
981            working_weights: disp_weights,
982            ..
983        } = &eval.blockworking_sets[Self::BLOCK_DISP]
984        else {
985            return Err(format!(
986                "{} dispersion precision block did not return diagonal weights",
987                self.kind.family_tag()
988            ));
989        };
990
991        // Per-row mixed `(η_μ, η_d)` weight; for Beta this is a full `Order2<2>`
992        // tower per row (the orthogonal members return 0 cheaply). Row-
993        // independent, so fan it out for large `n` (off a rayon worker) into a
994        // per-row buffer — index-ordered, no reduction, so byte-identical to the
995        // serial `from_shape_fn`.
996        let cross_weights = if rayon::current_thread_index().is_none()
997            && n > DISPERSION_PARALLEL_ROW_THRESHOLD
998        {
999            use rayon::iter::{IntoParallelIterator, ParallelIterator};
1000            Array1::from_vec(
1001                (0..n)
1002                    .into_par_iter()
1003                    .map(|i| {
1004                        dispersion_row_cross_weight(
1005                            self.kind,
1006                            self.y[i],
1007                            eta_mu[i],
1008                            eta_d[i],
1009                            self.weights[i],
1010                        )
1011                    })
1012                    .collect::<Vec<f64>>(),
1013            )
1014        } else {
1015            Array1::from_shape_fn(n, |i| {
1016                dispersion_row_cross_weight(
1017                    self.kind,
1018                    self.y[i],
1019                    eta_mu[i],
1020                    eta_d[i],
1021                    self.weights[i],
1022                )
1023            })
1024        };
1025        let mean_spec = &specs[Self::BLOCK_MEAN];
1026        let disp_spec = &specs[Self::BLOCK_DISP];
1027        if mean_spec.design.nrows() != n || disp_spec.design.nrows() != n {
1028            return Err(format!(
1029                "{} exact joint Hessian design row mismatch: y={n}, mean rows={}, precision rows={}",
1030                self.kind.family_tag(),
1031                mean_spec.design.nrows(),
1032                disp_spec.design.nrows()
1033            ));
1034        }
1035        let p_mean = mean_spec.design.ncols();
1036        let p_disp = disp_spec.design.ncols();
1037        if block_states[Self::BLOCK_MEAN].beta.len() != p_mean
1038            || block_states[Self::BLOCK_DISP].beta.len() != p_disp
1039        {
1040            return Err(format!(
1041                "{} exact joint Hessian beta/design mismatch: mean beta {} vs cols {}, precision beta {} vs cols {}",
1042                self.kind.family_tag(),
1043                block_states[Self::BLOCK_MEAN].beta.len(),
1044                p_mean,
1045                block_states[Self::BLOCK_DISP].beta.len(),
1046                p_disp
1047            ));
1048        }
1049
1050        let h_mean = xt_diag_x_design(&mean_spec.design, mean_weights)?;
1051        let h_cross = xt_diag_y_design(&mean_spec.design, &cross_weights, &disp_spec.design)?;
1052        let h_disp = xt_diag_x_design(&disp_spec.design, disp_weights)?;
1053        let total = p_mean + p_disp;
1054        let mut h = Array2::<f64>::zeros((total, total));
1055        h.slice_mut(s![0..p_mean, 0..p_mean]).assign(&h_mean);
1056        h.slice_mut(s![0..p_mean, p_mean..total]).assign(&h_cross);
1057        h.slice_mut(s![p_mean..total, p_mean..total])
1058            .assign(&h_disp);
1059        mirror_upper_to_lower(&mut h);
1060        Ok(Some(h))
1061    }
1062
1063    /// Whether the joint likelihood Hessian is block-diagonal in the
1064    /// `[mean | log-precision]` coefficient vector.
1065    ///
1066    /// `Beta(μφ, (1−μ)φ)` carries a genuinely nonzero `(η_μ, η_φ)` Fisher
1067    /// cross block (see the module header), so its blocks are coupled. The
1068    /// remaining members are Fisher-orthogonal in their mean/precision
1069    /// parameterizations — NB2 `(μ, θ)`, Gamma shape `ν = 1/φ`, Tweedie
1070    /// `log(1/φ)` — so `∂²L/∂β_μ∂β_d = 0` and the joint Hessian is exactly
1071    /// block-diagonal. Declaring that here lets the trait's directional-
1072    /// derivative / Jeffreys dispatch accept the block-diagonal joint Hessian
1073    /// via the working-set-exact fallback instead of rejecting it as an
1074    /// untrusted structurally-uncoupled override (which would strand the
1075    /// outer-REML gradient with a "dH unavailable" error, gam#1119).
1076    fn likelihood_blocks_uncoupled(&self) -> bool {
1077        !matches!(self.kind, DispersionFamilyKind::Beta)
1078    }
1079
1080    /// The mean and precision working weights couple across both blocks, which
1081    /// the block-local diagonal drift hook cannot represent, so decline the
1082    /// dense outer Hessian capability whenever the actual two-block (or
1083    /// larger) geometry is in play; a degenerate single-block probe — there
1084    /// is no cross-block coupling to reject — keeps the trait default's
1085    /// availability verdict.
1086    ///
1087    /// The override still validates the block-spec slice it is handed (the
1088    /// same consistency check the trait default's assertion bottoms out in)
1089    /// so a malformed probe is reported here rather than downstream.
1090    fn outer_hyper_hessian_dense_available(&self, specs: &[ParameterBlockSpec]) -> bool {
1091        assert!(
1092            crate::custom_family::validate_blockspec_consistency(specs).is_ok(),
1093            "DispersionGlmLocationScale outer hyper-Hessian dense availability: \
1094             inconsistent parameter block specs"
1095        );
1096        specs.len() < 2
1097    }
1098}
1099
/// Term spec consumed by [`fit_dispersion_glm_location_scale_terms`]; mirrors
/// [`GaussianLocationScaleTermSpec`](super::GaussianLocationScaleTermSpec) with
/// the dispersion channel in place of the Gaussian log-σ channel.
pub struct DispersionGlmLocationScaleTermSpec {
    /// Member family to fit; a Tweedie `p` outside the open interval (1, 2)
    /// is rejected by the fit entry point.
    pub kind: DispersionFamilyKind,
    /// Response vector.
    pub y: Array1<f64>,
    /// Prior weights — presumably one per row of `y`; confirm against the caller.
    pub weights: Array1<f64>,
    /// Term collection for the mean predictor.
    pub meanspec: TermCollectionSpec,
    /// Term collection for the dispersion-channel predictor.
    pub log_dispspec: TermCollectionSpec,
    /// Offset for the mean predictor.
    pub mean_offset: Array1<f64>,
    /// Offset for the dispersion-channel predictor.
    pub log_disp_offset: Array1<f64>,
}
1112
/// Builder state behind the `LocationScaleFamilyBuilder` implementation for
/// the dispersion families: it carries the data and term specs from which the
/// parameter blocks and the [`DispersionGlmLocationScaleFamily`] are built.
pub(crate) struct DispersionGlmLocationScaleTermBuilder {
    /// Member family; copied into the family by `build_family`.
    pub(crate) kind: DispersionFamilyKind,
    /// Response vector; design row counts are asserted against its length.
    pub(crate) y: Array1<f64>,
    /// Prior weights, cloned into the family and used by the warm start.
    pub(crate) weights: Array1<f64>,
    /// Term collection of the mean block.
    pub(crate) meanspec: TermCollectionSpec,
    /// Term collection of the dispersion ("noise") block.
    pub(crate) noisespec: TermCollectionSpec,
    /// Offset attached to the `"mu"` block in `build_blocks`.
    pub(crate) mean_offset: Array1<f64>,
    /// Offset attached to the `"log_precision"` block in `build_blocks`.
    pub(crate) noise_offset: Array1<f64>,
}
1122
1123/// Warm start for a dispersion location-scale fit: project a link-transformed
1124/// response onto the mean block and seed the log-precision block at a constant
1125/// (precision ≈ 1) baseline. The block-cyclic IRLS then refines both jointly.
1126pub(crate) fn dispersion_location_scale_warm_start(
1127    kind: DispersionFamilyKind,
1128    y: &Array1<f64>,
1129    weights: &Array1<f64>,
1130    mean_block: &ParameterBlockSpec,
1131    disp_block: &ParameterBlockSpec,
1132    mean_beta_hint: Option<&Array1<f64>>,
1133    disp_beta_hint: Option<&Array1<f64>>,
1134) -> Result<(Array1<f64>, Array1<f64>), String> {
1135    let ridge_floor = 1e-10;
1136    let mean_beta = if let Some(beta) = mean_beta_hint {
1137        beta.clone()
1138    } else {
1139        let target = Array1::from_shape_fn(y.len(), |i| {
1140            if kind.mean_is_logit() {
1141                let yi = y[i].clamp(1e-3, 1.0 - 1e-3);
1142                (yi / (1.0 - yi)).ln()
1143            } else {
1144                // log mean link; the +0.1 keeps zero counts finite.
1145                (y[i].max(0.0) + 0.1).ln()
1146            }
1147        });
1148        solve_penalizedweighted_projection(
1149            &mean_block.design,
1150            &mean_block.offset,
1151            &target,
1152            weights,
1153            &mean_block.penalties,
1154            &mean_block.initial_log_lambdas,
1155            ridge_floor,
1156        )?
1157    };
1158    let disp_beta = if let Some(beta) = disp_beta_hint {
1159        beta.clone()
1160    } else {
1161        // Seed the precision block from a smoothed method-of-moments surface
1162        // rather than the old flat η_d=0 constant.  A single observation cannot
1163        // identify its own variance, but for the Fisher-orthogonal dispersion
1164        // members the residual-squared moment contains the correct first-order
1165        // signal:
1166        //
1167        //   Gamma:   Var(Y)=μ²/ν              ⇒ log ν     ≈ log(μ²/e²)
1168        //   NB2:     Var(Y)=μ+μ²/θ            ⇒ log θ     ≈ log(μ²/(e²-μ))
1169        //   Tweedie: Var(Y)=φ μ^p, η_d=log1/φ ⇒ η_d       ≈ log(μ^p/e²)
1170        //
1171        // The targets are deliberately conservative (finite residual floor,
1172        // precision cap, and no fixture-specific constants): they only give the
1173        // block-cyclic likelihood solve a correctly-signed non-flat starting
1174        // surface, while the final estimate is still the penalized joint MLE.
1175        let mean_eta = mean_block.design.apply(&mean_beta) + &mean_block.offset;
1176        let target = Array1::from_shape_fn(y.len(), |i| {
1177            dispersion_moment_log_precision_seed(kind, y[i], mean_eta[i])
1178        });
1179        solve_penalizedweighted_projection(
1180            &disp_block.design,
1181            &disp_block.offset,
1182            &target,
1183            weights,
1184            &disp_block.penalties,
1185            &disp_block.initial_log_lambdas,
1186            ridge_floor,
1187        )?
1188    };
1189    Ok((mean_beta, disp_beta))
1190}
1191
1192#[inline]
1193fn dispersion_moment_log_precision_seed(kind: DispersionFamilyKind, yi: f64, eta_mu: f64) -> f64 {
1194    const LOG_PRECISION_FLOOR: f64 = -10.0;
1195    const LOG_PRECISION_CEILING: f64 = 10.0;
1196    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
1197    let raw = match kind {
1198        DispersionFamilyKind::Beta => {
1199            // Beta's mean and precision scores are not Fisher-orthogonal in
1200            // the (logit μ, log φ) parameterization.  Per-row residual moments
1201            // therefore make a poor block-cyclic seed: an outlying y near 0/1
1202            // can imply a near-zero φ and pull the coupled mean block onto the
1203            // boundary before the joint likelihood has had a chance to settle.
1204            // Keep the neutral precision seed for this one coupled member; the
1205            // exact Beta cross-Hessian below still drives the joint solve and
1206            // covariance with the coherent two-block likelihood geometry.
1207            0.0
1208        }
1209        DispersionFamilyKind::Gamma => {
1210            let mu = em.exp().max(1e-12);
1211            let e2 = (yi - mu).powi(2).max(1e-8 * mu * mu);
1212            (mu * mu / e2).max(1e-6).ln()
1213        }
1214        DispersionFamilyKind::NegativeBinomial => {
1215            let mu = em.exp().max(1e-12);
1216            let e2 = (yi - mu).powi(2);
1217            let excess = (e2 - mu).max(1e-6 * (mu + mu * mu));
1218            (mu * mu / excess).max(1e-6).ln()
1219        }
1220        DispersionFamilyKind::Tweedie { p } => {
1221            let mu = em.exp().max(1e-12);
1222            let e2 = (yi - mu).powi(2).max(1e-8 * mu.powf(p));
1223            (mu.powf(p) / e2).max(1e-6).ln()
1224        }
1225    };
1226    raw.clamp(LOG_PRECISION_FLOOR, LOG_PRECISION_CEILING)
1227}
1228
1229impl LocationScaleFamilyBuilder for DispersionGlmLocationScaleTermBuilder {
1230    type Family = DispersionGlmLocationScaleFamily;
1231
    /// Borrows the term collection spec of the mean block.
    fn meanspec(&self) -> &TermCollectionSpec {
        &self.meanspec
    }
1235
    /// Borrows the term collection spec of the dispersion ("noise") block.
    fn noisespec(&self) -> &TermCollectionSpec {
        &self.noisespec
    }
1239
1240    fn noise_penalty_count(&self, noise_design: &TermCollectionDesign) -> usize {
1241        // Mirror the Gaussian/Binomial scale block: a full-span shrinkage
1242        // penalty pins the log-precision nullspace so REML does not optimise
1243        // the dispersion smoothing on a flat surface.
1244        noise_design.penalties.len() + 1
1245    }
1246
1247    fn build_blocks(
1248        &self,
1249        theta: &Array1<f64>,
1250        mean_design: &TermCollectionDesign,
1251        noise_design: &TermCollectionDesign,
1252        mean_beta_hint: Option<Array1<f64>>,
1253        noise_beta_hint: Option<Array1<f64>>,
1254    ) -> Result<Vec<ParameterBlockSpec>, String> {
1255        let layout = GamlssLambdaLayout::two_block(
1256            mean_design.penalties.len(),
1257            self.noise_penalty_count(noise_design),
1258        );
1259        layout.validate_theta_len(theta.len(), "dispersion location-scale")?;
1260
1261        let mut meanspec = build_location_scale_block(
1262            "mu",
1263            mean_design.design.clone(),
1264            self.mean_offset.clone(),
1265            mean_design.penalties_as_penalty_matrix(),
1266            mean_design.nullspace_dims.clone(),
1267            layout.mean_from(theta),
1268            mean_beta_hint,
1269            0,
1270            LOCATION_SCALE_N_OUTPUTS,
1271            "DispersionLocationScale::build_blocks: mu",
1272        )?;
1273
1274        let p_disp = noise_design.design.ncols();
1275        let mut disp_penalties = noise_design.penalties_as_penalty_matrix();
1276        disp_penalties.push(PenaltyMatrix::Dense(identity_penalty(p_disp)));
1277        let mut disp_nullspace = noise_design.nullspace_dims.clone();
1278        disp_nullspace.push(0);
1279        let mut dispspec = build_location_scale_block(
1280            "log_precision",
1281            noise_design.design.clone(),
1282            self.noise_offset.clone(),
1283            disp_penalties,
1284            disp_nullspace,
1285            layout.noise_from(theta),
1286            noise_beta_hint,
1287            1,
1288            LOCATION_SCALE_N_OUTPUTS,
1289            "DispersionLocationScale::build_blocks: log_precision",
1290        )?;
1291
1292        if meanspec.initial_beta.is_none() || dispspec.initial_beta.is_none() {
1293            let (mean_beta0, disp_beta0) = dispersion_location_scale_warm_start(
1294                self.kind,
1295                &self.y,
1296                &self.weights,
1297                &meanspec,
1298                &dispspec,
1299                meanspec.initial_beta.as_ref(),
1300                dispspec.initial_beta.as_ref(),
1301            )?;
1302            if meanspec.initial_beta.is_none() {
1303                meanspec.initial_beta = Some(mean_beta0);
1304            }
1305            if dispspec.initial_beta.is_none() {
1306                dispspec.initial_beta = Some(disp_beta0);
1307            }
1308        }
1309
1310        Ok(vec![meanspec, dispspec])
1311    }
1312
1313    fn build_family(
1314        &self,
1315        mean_design: &TermCollectionDesign,
1316        noise_design: &TermCollectionDesign,
1317    ) -> Self::Family {
1318        // The family stores y/weights/kind directly and does not need the
1319        // designs at construction time, but the row geometry of the offered
1320        // designs is the only cross-check that ties this family back to the
1321        // builder's data — assert it before handing the family to the engine
1322        // so a misaligned design surfaces here rather than downstream in the
1323        // inner solver.
1324        assert_eq!(
1325            mean_design.design.nrows(),
1326            self.y.len(),
1327            "DispersionGlmLocationScale::build_family: mean design row count must match y"
1328        );
1329        assert_eq!(
1330            noise_design.design.nrows(),
1331            self.y.len(),
1332            "DispersionGlmLocationScale::build_family: noise design row count must match y"
1333        );
1334        DispersionGlmLocationScaleFamily {
1335            kind: self.kind,
1336            y: self.y.clone(),
1337            weights: self.weights.clone(),
1338        }
1339    }
1340
1341    fn extract_primary_betas(
1342        &self,
1343        fit: &UnifiedFitResult,
1344    ) -> Result<(Array1<f64>, Array1<f64>), String> {
1345        let mean_beta = fit
1346            .block_states
1347            .get(DispersionGlmLocationScaleFamily::BLOCK_MEAN)
1348            .ok_or_else(|| "missing dispersion mean block state".to_string())?
1349            .beta
1350            .clone();
1351        let disp_beta = fit
1352            .block_states
1353            .get(DispersionGlmLocationScaleFamily::BLOCK_DISP)
1354            .ok_or_else(|| "missing dispersion log-precision block state".to_string())?
1355            .beta
1356            .clone();
1357        Ok((mean_beta, disp_beta))
1358    }
1359
1360    fn build_psiderivative_blocks(
1361        &self,
1362        data: ndarray::ArrayView2<'_, f64>,
1363        meanspec: &TermCollectionSpec,
1364        noisespec: &TermCollectionSpec,
1365        mean_design: &TermCollectionDesign,
1366        noise_design: &TermCollectionDesign,
1367    ) -> Result<Vec<Vec<CustomFamilyBlockPsiDerivative>>, String> {
1368        // The dispersion location-scale families have no closed-form analytic
1369        // spatial psi derivatives, and `fit_dispersion_glm_location_scale_terms`
1370        // disables the κ/ψ joint optimizer before the engine ever asks. If we
1371        // do get called (for example by a future caller that forgets the
1372        // disable), return a real diagnostic rather than a sentinel — emit the
1373        // exact data and design shape that was passed in so the bug is
1374        // diagnosable from the error string alone.
1375        Err(format!(
1376            "dispersion location-scale ({:?}) does not implement analytic spatial \
1377             psi derivatives; the κ/ψ joint optimizer must be disabled before \
1378             this builder is consulted. Called with data {n_rows}×{n_cols}, mean \
1379             spec (linear={mean_lin}, random={mean_re}, smooth={mean_sm}), noise \
1380             spec (linear={noise_lin}, random={noise_re}, smooth={noise_sm}), \
1381             mean design cols={mean_p}, noise design cols={noise_p}",
1382            self.kind,
1383            n_rows = data.nrows(),
1384            n_cols = data.ncols(),
1385            mean_lin = meanspec.linear_terms.len(),
1386            mean_re = meanspec.random_effect_terms.len(),
1387            mean_sm = meanspec.smooth_terms.len(),
1388            noise_lin = noisespec.linear_terms.len(),
1389            noise_re = noisespec.random_effect_terms.len(),
1390            noise_sm = noisespec.smooth_terms.len(),
1391            mean_p = mean_design.design.ncols(),
1392            noise_p = noise_design.design.ncols(),
1393        ))
1394    }
1395}
1396
1397/// Fit a dispersion-channel GAMLSS location-scale model (#913). All four
1398/// genuine-dispersion mean families share this single entry; the per-family
1399/// likelihood lives in [`dispersion_row_kernel`].
1400pub fn fit_dispersion_glm_location_scale_terms(
1401    data: ndarray::ArrayView2<'_, f64>,
1402    spec: DispersionGlmLocationScaleTermSpec,
1403    options: &BlockwiseFitOptions,
1404    kappa_options: &SpatialLengthScaleOptimizationOptions,
1405) -> Result<BlockwiseTermFitResult, String> {
1406    if let DispersionFamilyKind::Tweedie { p } = spec.kind {
1407        if !(p.is_finite() && p > 1.0 && p < 2.0) {
1408            return Err(format!(
1409                "Tweedie location-scale requires a variance power strictly in (1, 2); got p={p}"
1410            ));
1411        }
1412    }
1413    // The κ/ψ anisotropic-kernel joint optimizer needs analytic psi
1414    // derivatives this family does not provide; disable it so the engine runs
1415    // the full ρ REML directly via `fit_custom_family` (1-D and tensor smooth
1416    // penalties λ are still REML-selected).
1417    let mut kappa = kappa_options.clone();
1418    kappa.enabled = false;
1419    // A dispersion location-scale model is an inherently *predictable* model:
1420    // posterior-mean prediction (the response-scale predict path the CLI/FFI
1421    // drive) needs the joint `(β_μ, β_d)` posterior covariance, and so does the
1422    // reported total EDF / coefficient SEs. The block-diagonal joint Hessian is
1423    // always assembled here (`exact_newton_joint_hessian_with_specs` →
1424    // `compute_joint_covariance`, which for this family's `RidgedQuadraticReml`
1425    // outer objective uses the never-erroring SPD-retry → positive-part
1426    // pseudo-inverse), so we can — and must — request the covariance
1427    // unconditionally rather than leaving `covariance_conditional = None`
1428    // whenever the outer optimizer happens to *converge* (the only family-
1429    // independent reason NB sometimes populated covariance was that it escalated
1430    // into the never-fail posterior-sampling rung, while a cleanly-converged
1431    // Gamma/Tweedie fit took the `!options.compute_covariance ⇒ None` early
1432    // return and stranded its covariance/EDF — gam#1119). Forcing the flag here
1433    // makes all four genuine-dispersion mean families assemble the joint
1434    // covariance + EDF deterministically, exactly as a predictable model
1435    // requires.
1436    let mut options = options.clone();
1437    options.compute_covariance = true;
1438    fit_location_scale_terms(
1439        data,
1440        DispersionGlmLocationScaleTermBuilder {
1441            kind: spec.kind,
1442            y: spec.y,
1443            weights: spec.weights,
1444            meanspec: spec.meanspec,
1445            noisespec: spec.log_dispspec,
1446            mean_offset: spec.mean_offset,
1447            noise_offset: spec.log_disp_offset,
1448        },
1449        &options,
1450        &kappa,
1451    )
1452}
1453
1454#[cfg(test)]
1455mod tests {
1456    use super::*;
1457    use super::test_support::{dispersion_gamma_nll_order2, dispersion_nb_nll_order2};
1458    use crate::gamlss::test_support::dispersion_tweedie_nll_generic;
1459
1460    pub(crate) fn beta_fisher_cross_info_mu_phi(mu: f64, phi: f64) -> f64 {
1461        let a = mu * phi;
1462        let b = (1.0 - mu) * phi;
1463        phi * (mu * gam_math::jet_tower::trigamma_derivative_stack(a)[0]
1464            - (1.0 - mu) * gam_math::jet_tower::trigamma_derivative_stack(b)[0])
1465    }
1466
1467    pub(crate) fn assert_close(label: &str, got: f64, want: f64, tol: f64) {
1468        assert!(
1469            (got - want).abs() <= tol,
1470            "{label}: got {got:.12e}, want {want:.12e}, |diff|={:.3e}",
1471            (got - want).abs()
1472        );
1473    }
1474
1475    #[test]
1476    pub(crate) fn beta_tower_mixed_channel_matches_cross_information_formula() {
1477        let mu = 0.1;
1478        let phi = 10.0;
1479        let a = mu * phi;
1480        let b = (1.0 - mu) * phi;
1481        let digamma_a = gam_math::jet_tower::digamma_derivative_stack(a)[0];
1482        let digamma_b = gam_math::jet_tower::digamma_derivative_stack(b)[0];
1483        let score_neutral_y = 1.0 / (1.0 + (-(digamma_a - digamma_b)).exp());
1484
1485        let tower = dispersion_beta_nll_order2(score_neutral_y, mu, phi, 1.0);
1486        let trigamma_a = std::f64::consts::PI * std::f64::consts::PI / 6.0;
1487        let trigamma_b = gam_math::jet_tower::trigamma_derivative_stack(b)[0];
1488        let analytic = phi * (mu * trigamma_a - (1.0 - mu) * trigamma_b);
1489        let helper = beta_fisher_cross_info_mu_phi(mu, phi);
1490
1491        assert!(
1492            analytic > 0.58,
1493            "audit example should have visibly nonzero cross information, got {analytic}"
1494        );
1495        assert_close("helper cross information", helper, analytic, 1e-12);
1496        assert_close("tower mixed channel", tower.h()[0][1], analytic, 1e-8);
1497
1498        let q = mu * (1.0 - mu);
1499        let eta_cross = beta_observed_cross_weight_eta(score_neutral_y, mu, phi, 1.0);
1500        assert_close(
1501            "eta-scale cross weight",
1502            eta_cross,
1503            q * phi * analytic,
1504            1e-8,
1505        );
1506    }
1507
1508    /// #932 oracle: the production `Order2<2>` evaluation of each dispersion
1509    /// row NLL must reproduce, channel-for-channel (value/grad/Hessian), the
1510    /// dense `Tower4<2>` evaluation of the same row expression.
1511    #[test]
1512    pub(crate) fn order2_matches_dense_tower_all_channels() {
1513        use gam_math::jet_scalar::{JetScalar, Order2};
1514        use gam_math::jet_tower::Tower4;
1515
1516        fn check_o2_vs_tower4(label: &str, o2: Order2<2>, t4: Tower4<2>) {
1517            let band = |a: f64, b: f64| 1e-9 + 1e-9 * a.abs().max(b.abs());
1518            assert!(
1519                (o2.value() - t4.v).abs() <= band(o2.value(), t4.v),
1520                "{label} value: {} vs {}",
1521                o2.value(),
1522                t4.v
1523            );
1524            for a in 0..2 {
1525                assert!(
1526                    (o2.g()[a] - t4.g[a]).abs() <= band(o2.g()[a], t4.g[a]),
1527                    "{label} grad[{a}]: {} vs {}",
1528                    o2.g()[a],
1529                    t4.g[a]
1530                );
1531                for b in 0..2 {
1532                    assert!(
1533                        (o2.h()[a][b] - t4.h[a][b]).abs() <= band(o2.h()[a][b], t4.h[a][b]),
1534                        "{label} hess[{a}][{b}]: {} vs {}",
1535                        o2.h()[a][b],
1536                        t4.h[a][b]
1537                    );
1538                }
1539            }
1540        }
1541
1542        let wi = 1.7_f64;
1543        // NB2: (μ, θ).
1544        for &(yi, mu, theta) in &[(0.0, 1.2, 3.0), (4.0, 2.5, 0.7), (10.0, 0.6, 5.0)] {
1545            check_o2_vs_tower4(
1546                "nb",
1547                dispersion_nb_nll_order2(yi, mu, theta, wi),
1548                test_support::dispersion_nb_nll_generic::<Tower4<2>>(yi, mu, theta, wi),
1549            );
1550        }
1551        // Gamma: (μ, ν).
1552        for &(yi, mu, nu) in &[
1553            (0.5_f64, 1.1_f64, 2.0_f64),
1554            (3.0, 4.0, 0.9),
1555            (1.0, 0.3, 6.0),
1556        ] {
1557            let y_pos = yi.max(1e-300);
1558            check_o2_vs_tower4(
1559                "gamma",
1560                dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi),
1561                test_support::dispersion_gamma_nll_generic::<Tower4<2>>(yi, y_pos, mu, nu, wi),
1562            );
1563        }
1564        // Beta: (μ, φ).
1565        for &(yi, mu, phi) in &[(0.3, 0.4, 5.0), (0.9, 0.6, 12.0), (0.01, 0.2, 3.0)] {
1566            check_o2_vs_tower4(
1567                "beta",
1568                dispersion_beta_nll_order2(yi, mu, phi, wi),
1569                test_support::dispersion_beta_nll_generic::<Tower4<2>>(yi, mu, phi, wi),
1570            );
1571        }
1572        // Tweedie: (η_μ, η_d), both density branches.
1573        for &(yi, eta_mu, eta_d, p) in &[
1574            (0.0, 0.4, -0.3, 1.5),
1575            (2.5, -0.2, 0.5, 1.3),
1576            (0.0, 1.0, 0.1, 1.7),
1577            (5.0, 0.7, -0.6, 1.6),
1578        ] {
1579            check_o2_vs_tower4(
1580                "tweedie",
1581                dispersion_tweedie_nll_generic::<Order2<2>>(yi, eta_mu, eta_d, p, wi),
1582                dispersion_tweedie_nll_generic::<Tower4<2>>(yi, eta_mu, eta_d, p, wi),
1583            );
1584        }
1585    }
1586
1587    /// #1591 prune oracle: the pruned single-axis (`K=1`) dispersion towers
1588    /// reproduce, `to_bits`-exactly, the CONSUMED channels (`value`, dispersion-
1589    /// axis `g`/`h`) of the full `Order2<2>` towers — across ≥2000 randomized
1590    /// rows per family (both Tweedie density branches), including the η-clamp
1591    /// boundary. This is the bit-identity guarantee that the K-prune changes no
1592    /// observable float.
1593    #[test]
1594    pub(crate) fn pruned_disp_towers_bit_identical_to_full_order2() {
1595        use gam_math::jet_scalar::{JetScalar, Order2};
1596
1597        // Deterministic LCG so the sweep is reproducible without an rng dep.
1598        let mut state: u64 = 0x9E3779B97F4A7C15;
1599        let mut next = || {
1600            state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
1601            ((state >> 11) as f64) / ((1u64 << 53) as f64)
1602        };
1603        let bits = |x: f64| x.to_bits();
1604
1605        let n_per = 600; // 600 rows × 4 families (Tweedie ×2 branches) > 2000.
1606        for _ in 0..n_per {
1607            let wi = 0.25 + 3.0 * next();
1608            let yi_count = (next() * 12.0).floor();
1609
1610            // NB: full O2<2> seeds (μ, θ); pruned seeds θ only.
1611            {
1612                let mu = (0.05 + 4.0 * next()).max(1e-300);
1613                let theta = (0.05 + 6.0 * next()).max(1e-12);
1614                let full = dispersion_nb_nll_order2(yi_count, mu, theta, wi);
1615                let prn = dispersion_nb_disp_order2(yi_count, mu, theta, wi);
1616                assert_eq!(bits(full.value()), bits(prn.value()), "nb value");
1617                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "nb grad");
1618                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "nb hess");
1619                // value-only path == -tower.value(), bit-for-bit.
1620                assert_eq!(
1621                    bits(dispersion_nb_neg_loglik(yi_count, mu, theta, wi)),
1622                    bits(-prn.value()),
1623                    "nb value-only"
1624                );
1625            }
1626            // Gamma: seeds (μ, ν) / ν.
1627            {
1628                let mu = (0.05 + 4.0 * next()).max(1e-300);
1629                let nu = (0.05 + 6.0 * next()).max(1e-12);
1630                let yi = 0.01 + 8.0 * next();
1631                let y_pos = yi.max(1e-300);
1632                let full = dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi);
1633                let prn = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
1634                assert_eq!(bits(full.value()), bits(prn.value()), "gamma value");
1635                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "gamma grad");
1636                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "gamma hess");
1637                assert_eq!(
1638                    bits(dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)),
1639                    bits(-prn.value()),
1640                    "gamma value-only"
1641                );
1642            }
1643            // Beta value-only path vs full K=2 tower value.
1644            {
1645                let mu = (1e-6 + (1.0 - 2e-6) * next()).clamp(1e-12, 1.0 - 1e-12);
1646                let phi = (0.05 + 20.0 * next()).max(1e-12);
1647                let yi = next();
1648                let full = dispersion_beta_nll_order2(yi, mu, phi, wi);
1649                assert_eq!(
1650                    bits(dispersion_beta_neg_loglik(yi, mu, phi, wi)),
1651                    bits(-full.value()),
1652                    "beta value-only"
1653                );
1654            }
1655            // Tweedie: seeds (η_μ, η_d) / η_d, both density branches & clamp edge.
1656            for &(yi, eta_mu, eta_d, p) in &[
1657                (0.0_f64, -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
1658                (0.01 + 9.0 * next(), -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
1659                (3.0, -DISPERSION_ETA_CLAMP - 5.0, DISPERSION_ETA_CLAMP + 5.0, 1.5),
1660            ] {
1661                let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
1662                let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
1663                let full = dispersion_tweedie_nll_generic::<Order2<2>>(yi, em, ed, p, wi);
1664                let prn = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
1665                assert_eq!(bits(full.value()), bits(prn.value()), "tweedie value");
1666                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "tweedie grad");
1667                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "tweedie hess");
1668                assert_eq!(
1669                    bits(dispersion_tweedie_neg_loglik(yi, em, ed, p, wi)),
1670                    bits(-prn.value()),
1671                    "tweedie value-only"
1672                );
1673            }
1674        }
1675    }
1676
1677    #[test]
1678    pub(crate) fn orthogonal_dispersion_families_report_zero_cross_weight() {
1679        let cases = [
1680            DispersionFamilyKind::NegativeBinomial,
1681            DispersionFamilyKind::Gamma,
1682            DispersionFamilyKind::Tweedie { p: 1.5 },
1683        ];
1684        for kind in cases {
1685            let got = dispersion_row_cross_weight(kind, 1.25, 0.2, -0.3, 2.0);
1686            assert_close(kind.family_tag(), got, 0.0, 1e-12);
1687        }
1688    }
1689
1690    /// Speed-path guard (#932): `evaluate` / `log_likelihood_only` materialize
1691    /// the row-kernel map in parallel for large `n`, then reduce SERIALLY in
1692    /// index order. This pins the parallel output (log-likelihood + both
1693    /// blocks' working response/weight vectors) to a hand-rolled serial
1694    /// reference so CI catches any reassociation or row-misindex regression.
1695    /// `n` sits well above `DISPERSION_PARALLEL_ROW_THRESHOLD`, and the test
1696    /// runs on the main thread (not a rayon worker), so the parallel branch is
1697    /// the one exercised. Because the reduction order is preserved the match is
1698    /// in fact bit-exact; the `1e-9` band is the contract floor.
1699    #[test]
1700    pub(crate) fn parallel_evaluate_matches_serial_reference() {
1701        let n = DISPERSION_PARALLEL_ROW_THRESHOLD * 3 + 7;
1702        // Deterministic LCG row data (no rng dependency).
1703        let mut state: u64 = 0xD1B5_4A32_D192_ED03;
1704        let mut next = || {
1705            state = state
1706                .wrapping_mul(6364136223846793005)
1707                .wrapping_add(1442695040888963407);
1708            ((state >> 11) as f64) / ((1u64 << 53) as f64)
1709        };
1710
1711        for kind in [
1712            DispersionFamilyKind::NegativeBinomial,
1713            DispersionFamilyKind::Gamma,
1714            DispersionFamilyKind::Beta,
1715            DispersionFamilyKind::Tweedie { p: 1.5 },
1716        ] {
1717            let y = Array1::from_shape_fn(n, |_| match kind {
1718                DispersionFamilyKind::Beta => 1e-3 + (1.0 - 2e-3) * next(),
1719                DispersionFamilyKind::NegativeBinomial => (next() * 12.0).floor(),
1720                _ => 0.05 + 8.0 * next(),
1721            });
1722            let weights = Array1::from_shape_fn(n, |_| 0.25 + 2.0 * next());
1723            let eta_mu = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());
1724            let eta_d = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());
1725
1726            let family = DispersionGlmLocationScaleFamily {
1727                kind,
1728                y: y.clone(),
1729                weights: weights.clone(),
1730            };
1731            let states = vec![
1732                ParameterBlockState {
1733                    beta: Array1::zeros(0),
1734                    eta: eta_mu.clone(),
1735                },
1736                ParameterBlockState {
1737                    beta: Array1::zeros(0),
1738                    eta: eta_d.clone(),
1739                },
1740            ];
1741
1742            // Serial reference, computed exactly as the pre-parallel loop did.
1743            let mut ll_ref = 0.0;
1744            let mut mw_ref = Array1::<f64>::zeros(n);
1745            let mut mr_ref = Array1::<f64>::zeros(n);
1746            let mut dw_ref = Array1::<f64>::zeros(n);
1747            let mut dr_ref = Array1::<f64>::zeros(n);
1748            for i in 0..n {
1749                let row = dispersion_row_kernel(kind, y[i], eta_mu[i], eta_d[i], weights[i]);
1750                if row.loglik.is_finite() {
1751                    ll_ref += row.loglik;
1752                }
1753                mw_ref[i] = row.mean_weight.max(0.0);
1754                mr_ref[i] = row.mean_response;
1755                dw_ref[i] = row.disp_weight.max(0.0);
1756                dr_ref[i] = row.disp_response;
1757            }
1758
1759            let eval = family.evaluate(&states).expect("parallel evaluate");
1760            assert_close(
1761                &format!("{kind:?} evaluate log-likelihood"),
1762                eval.log_likelihood,
1763                ll_ref,
1764                1e-9,
1765            );
1766
1767            let BlockWorkingSet::Diagonal {
1768                working_response: mr,
1769                working_weights: mw,
1770            } = &eval.blockworking_sets[0]
1771            else {
1772                panic!("mean block not diagonal");
1773            };
1774            let BlockWorkingSet::Diagonal {
1775                working_response: dr,
1776                working_weights: dw,
1777            } = &eval.blockworking_sets[1]
1778            else {
1779                panic!("dispersion block not diagonal");
1780            };
1781            for i in 0..n {
1782                assert_close("mean weight", mw[i], mw_ref[i], 1e-9);
1783                assert_close("mean response", mr[i], mr_ref[i], 1e-9);
1784                assert_close("disp weight", dw[i], dw_ref[i], 1e-9);
1785                assert_close("disp response", dr[i], dr_ref[i], 1e-9);
1786            }
1787
1788            // `log_likelihood_only` takes the same parallel-then-serial-sum
1789            // path; its value-only kernel is bit-identical to evaluate's loglik.
1790            let ll_only = family
1791                .log_likelihood_only(&states)
1792                .expect("parallel log_likelihood_only");
1793            assert_close(
1794                &format!("{kind:?} log_likelihood_only"),
1795                ll_only,
1796                ll_ref,
1797                1e-9,
1798            );
1799        }
1800    }
1801}