Skip to main content

gam_models/gamlss/
dispersion_family.rs

1//! #913: dispersion-channel GAMLSS location-scale families.
2//!
3//! Extracted from `gamlss.rs` (issue #780); this module now owns the
4//! dispersion-channel joint-curvature corrections.
5
6use super::weighted_design_products::{mirror_upper_to_lower, xt_diag_x_design, xt_diag_y_design};
// `Order2<K>::value()` is a `JetScalar` trait method; the
// `use gam_math::jet_scalar::JetScalar;` import further down brings the trait
// into scope so the row-NLL value reads (`-tower.value()`) resolve (E0599 fix).
9use super::{
10    BlockwiseTermFitResult, GamlssLambdaLayout, LOCATION_SCALE_N_OUTPUTS,
11    LocationScaleFamilyBuilder, build_location_scale_block, fit_location_scale_terms,
12    identity_penalty, solve_penalizedweighted_projection,
13};
14use crate::block_layout::block_count::validate_block_count;
15use crate::custom_family::{
16    BlockWorkingSet, BlockwiseFitOptions, CustomFamily, CustomFamilyBlockPsiDerivative,
17    FamilyEvaluation, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
18};
19use crate::gamlss::GamlssError;
20use crate::model_types::UnifiedFitResult;
21use gam_linalg::matrix::LinearOperator;
22use gam_math::jet_scalar::JetScalar;
23use gam_terms::smooth::{
24    SpatialLengthScaleOptimizationOptions, TermCollectionDesign, TermCollectionSpec,
25};
26use ndarray::{Array1, Array2, s};
27use statrs::function::gamma::ln_gamma;
28
29// ============================================================================
30// #913: dispersion-channel GAMLSS location-scale families.
31//
32// `noise_formula` (a second linear predictor on the dispersion channel) was
33// wired only for Gaussian/Binomial location-scale and the survival families.
34// The genuine-dispersion mean families — NegativeBinomial, Gamma, Beta and
35// Tweedie — were mean-only with a single scalar dispersion. This module adds a
36// SINGLE generic two-block family that routes all four through the existing
37// blockwise REML engine and the shared `LocationScaleFamilyBuilder` /
38// `fit_location_scale_terms` plumbing, so the κ-coordinate assembly, warm
39// start, shrinkage-penalised scale block and result extraction are reused
40// verbatim. A family is added by supplying only its per-row log-likelihood and
41// the (mean, log-precision) working sets — everything else is shared.
42//
43// Block layout: block 0 = mean predictor (η_μ, log link for NB/Gamma/Tweedie,
44// logit for Beta); block 1 = log-precision predictor (η_d). The dispersion
45// channel models log(precision) uniformly — `θ` for NegativeBinomial, the
46// shape `ν` for Gamma, `φ` for Beta, and `1/φ` for Tweedie — so a larger η_d
47// always means *less* dispersion, matching the Gaussian/Binomial convention
48// where η_logσ smaller ⇒ tighter. With no `noise_formula` the log-precision
49// block is a single intercept and the fit reduces to the scalar-dispersion
50// model.
51//
52// NB2 with `(μ, θ)` and the exponential-dispersion members here with
53// `(μ, φ)` are Fisher-orthogonal in their standard mean/dispersion
54// parameterizations: Gamma uses shape `ν = 1/φ`, and Tweedie models
55// `log(1/φ)`, so those precision-channel transforms preserve zero expected
56// mean/dispersion cross information. Beta is the exception in this module's
57// mean/precision parameterization. For `Beta(μφ, (1−μ)φ)`,
58//
59//   I_{μ,φ} = φ · (μ ψ'(μφ) − (1−μ) ψ'((1−μ)φ)),
60//
61// so in predictor coordinates `(η_μ = logit μ, η_φ = log φ)` the Fisher cross
62// block is
63//
64//   I_{η_μ,η_φ} = μ(1−μ) φ² · (μ ψ'(μφ) − (1−μ) ψ'((1−μ)φ)),
65//
66// which is generically nonzero. Block-cyclic Fisher-scoring IRLS is still a
67// valid block coordinate solve for the point estimate, but joint-curvature
68// consumers (`log|H|`, coefficient covariance, posterior draws) must receive
69// Beta's off-diagonal coefficient block. Smoothing-parameter selection still
70// runs through the engine's first-order (gradient-only) outer path: the family
71// declines the dense outer Hessian capability because its working weights
72// couple the two blocks (`W_μ` depends on the precision and vice-versa), which
73// the block-local diagonal-drift hook cannot represent exactly.
74// ============================================================================
75
/// A mean family with a genuine dispersion parameter whose precision
/// (overdispersion) channel may carry its own `noise_formula` linear
/// predictor (issue #913).
///
/// In every variant the second block models the *log precision*, so a larger
/// dispersion predictor always means less dispersion.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DispersionFamilyKind {
    /// Negative binomial (NB2), `Var = μ + μ²/θ`; precision channel = `log θ`.
    NegativeBinomial,
    /// Gamma, `Var = μ²/ν`; precision channel = `log ν` (the shape).
    Gamma,
    /// Beta(μφ, (1−μ)φ) on a logit mean link; precision channel = `log φ`.
    Beta,
    /// Tweedie compound Poisson–Gamma, `Var = φ μ^p` with a fixed power `p`;
    /// precision channel = `log(1/φ)`. Rows with `y > 0` use the saddlepoint
    /// (Nelder–Pregibon) density and rows with `y = 0` the exact point mass —
    /// the standard tractable Tweedie ML surface. An exact-series
    /// φ-derivative remains the open hard sub-item of #913.
    Tweedie { p: f64 },
}
95
96impl DispersionFamilyKind {
97    pub const fn family_tag(self) -> &'static str {
98        match self {
99            DispersionFamilyKind::NegativeBinomial => FAMILY_NEGBIN_LOCATION_SCALE,
100            DispersionFamilyKind::Gamma => FAMILY_GAMMA_LOCATION_SCALE,
101            DispersionFamilyKind::Beta => FAMILY_BETA_LOCATION_SCALE,
102            DispersionFamilyKind::Tweedie { .. } => FAMILY_TWEEDIE_LOCATION_SCALE,
103        }
104    }
105
106    /// The mean link is logit for Beta (a probability mean) and log otherwise.
107    pub(crate) const fn mean_is_logit(self) -> bool {
108        matches!(self, DispersionFamilyKind::Beta)
109    }
110
111    /// The mean inverse link this dispersion family fits on: log for
112    /// NegativeBinomial / Gamma / Tweedie, logit for Beta. Single source of
113    /// truth shared by the CLI and FFI save paths so the persisted
114    /// `base_link` never diverges from the fitted channel.
115    pub fn base_link(self) -> gam_problem::InverseLink {
116        use gam_problem::{InverseLink, StandardLink};
117        if self.mean_is_logit() {
118            InverseLink::Standard(StandardLink::Logit)
119        } else {
120            InverseLink::Standard(StandardLink::Log)
121        }
122    }
123
124    /// The family's canonical [`LikelihoodSpec`] (mean response × mean link).
125    /// The overdispersion parameter is estimated by the log-precision channel,
126    /// so the response-family placeholder parameters (`phi`, `theta`) mirror
127    /// the [`resolve_family`](crate::fit_orchestration::materialize::resolve_family) defaults
128    /// and are not consumed as fixed values at predict time. This is the single
129    /// source of truth for the persisted location-scale likelihood so the CLI
130    /// and FFI save paths cannot diverge.
131    pub fn likelihood_spec(self) -> gam_problem::LikelihoodSpec {
132        use gam_problem::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
133        let response = match self {
134            DispersionFamilyKind::NegativeBinomial => ResponseFamily::NegativeBinomial {
135                theta: 1.0,
136                theta_fixed: false,
137            },
138            DispersionFamilyKind::Gamma => ResponseFamily::Gamma,
139            DispersionFamilyKind::Beta => ResponseFamily::Beta { phi: 1.0 },
140            DispersionFamilyKind::Tweedie { p } => ResponseFamily::Tweedie { p },
141        };
142        let link = if self.mean_is_logit() {
143            InverseLink::Standard(StandardLink::Logit)
144        } else {
145            InverseLink::Standard(StandardLink::Log)
146        };
147        LikelihoodSpec::new(response, link)
148    }
149}
150
/// Family tag returned by `DispersionFamilyKind::family_tag` for NB2.
pub const FAMILY_NEGBIN_LOCATION_SCALE: &str = "negbin-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for Gamma.
pub const FAMILY_GAMMA_LOCATION_SCALE: &str = "gamma-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for Beta.
pub const FAMILY_BETA_LOCATION_SCALE: &str = "beta-location-scale";
/// Family tag returned by `DispersionFamilyKind::family_tag` for Tweedie.
pub const FAMILY_TWEEDIE_LOCATION_SCALE: &str = "tweedie-location-scale";
155
156/// `η` magnitude clamp shared by both channels (mirrors PIRLS `ETA_CLAMP`):
157/// keeps `exp(η)` and the logit jet away from overflow while staying in the
158/// smooth interior of every link.
159pub(super) const DISPERSION_ETA_CLAMP: f64 = 30.0;
160/// Floor for a per-row IRLS working weight / curvature so the block normal
161/// equations stay positive-definite. The working *response* always carries the
162/// exact score, so the stationary point (penalised score = 0) is independent
163/// of this floor; it only conditions the inner solve.
164pub(super) const DISPERSION_MIN_CURVATURE: f64 = 1e-12;
165
166/// Per-row working quantities for both channels at the current `(η_μ, η_d)`.
167pub(super) struct DispersionRowKernel {
168    pub(super) loglik: f64,
169    pub(super) mean_weight: f64,
170    pub(super) mean_response: f64,
171    pub(super) disp_weight: f64,
172    pub(super) disp_response: f64,
173}
174
#[cfg(test)]
mod test_support {
    use super::*;

    /// Test oracle: prior-weighted NB2 row NLL `−wᵢ·ℓ` over any
    /// [`JetScalar<2>`], with `μ` seeded on axis 0 and `θ` on axis 1.
    #[inline]
    pub(super) fn dispersion_nb_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let theta = S::variable(theta_value, 1);
        let theta_plus_mu = theta.add(&mu);
        // ℓ = lnΓ(θ+y) − lnΓ(θ) − lnΓ(y+1) + θ ln θ − θ ln(θ+μ) + y ln μ − y ln(θ+μ),
        // combined left-to-right in exactly this order.
        let lgamma_theta_plus_y = theta.add(&S::constant(yi)).ln_gamma();
        let lgamma_theta = theta.ln_gamma();
        let log_y_factorial = S::constant(ln_gamma(yi + 1.0));
        let theta_ln_theta = theta.mul(&theta.ln());
        let theta_ln_sum = theta.mul(&theta_plus_mu.ln());
        let y_ln_mu = mu.ln().scale(yi);
        let y_ln_sum = theta_plus_mu.ln().scale(yi);
        let loglik = lgamma_theta_plus_y
            .sub(&lgamma_theta)
            .sub(&log_y_factorial)
            .add(&theta_ln_theta)
            .sub(&theta_ln_sum)
            .add(&y_ln_mu)
            .sub(&y_ln_sum);
        loglik.scale(-wi)
    }

    /// Test oracle: prior-weighted Gamma row NLL `−wᵢ·ℓ` over any
    /// [`JetScalar<2>`], with `μ` on axis 0 and the shape `ν` on axis 1.
    #[inline]
    pub(super) fn dispersion_gamma_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let nu = S::variable(nu_value, 1);
        // ℓ = ν ln ν − ν ln μ − lnΓ(ν) + (ν−1) ln y − ν·((1/μ)·y), in this order.
        let nu_ln_nu = nu.mul(&nu.ln());
        let nu_ln_mu = nu.mul(&mu.ln());
        let lgamma_nu = nu.ln_gamma();
        let shape_log_y = nu.sub(&S::constant(1.0)).scale(y_pos.ln());
        let nu_y_over_mu = nu.mul(&mu.recip().scale(yi));
        let loglik = nu_ln_nu
            .sub(&nu_ln_mu)
            .sub(&lgamma_nu)
            .add(&shape_log_y)
            .sub(&nu_y_over_mu);
        loglik.scale(-wi)
    }

    /// Test oracle: prior-weighted Beta row NLL `−wᵢ·ℓ` over any
    /// [`JetScalar<2>`], with `μ` on axis 0 and `φ` on axis 1.
    #[inline]
    pub(super) fn dispersion_beta_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        phi_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let phi = S::variable(phi_value, 1);
        // Keep y strictly inside (0, 1) so both log terms stay finite.
        let y_inner = yi.clamp(1e-12, 1.0 - 1e-12);
        let shape_a = mu.mul(&phi);
        let shape_b = S::constant(1.0).sub(&mu).mul(&phi);
        // ℓ = lnΓ(φ) − lnΓ(a) − lnΓ(b) + (a−1) ln y + (b−1) ln(1−y), in this order.
        let a_log_term = shape_a.sub(&S::constant(1.0)).scale(y_inner.ln());
        let b_log_term = shape_b.sub(&S::constant(1.0)).scale((1.0 - y_inner).ln());
        let loglik = phi
            .ln_gamma()
            .sub(&shape_a.ln_gamma())
            .sub(&shape_b.ln_gamma())
            .add(&a_log_term)
            .add(&b_log_term);
        loglik.scale(-wi)
    }

    /// #1591 jet-prune oracle: the full two-axis `Order2<2>` NB2 row NLL
    /// (value, gradient and Hessian), `μ` on axis 0 and `θ` on axis 1.
    ///
    /// Production no longer reads this tower's mean (`μ`-axis) channels. The NB
    /// mean block is Fisher-orthogonal and written in closed form in
    /// [`dispersion_row_kernel`], so the hot path uses the pruned single-axis
    /// [`dispersion_nb_disp_order2`]. This form survives only as the
    /// dense-`Tower4<2>` oracle pin (`order2_matches_dense_tower_all_channels`).
    #[inline]
    pub(super) fn dispersion_nb_nll_order2(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type O2 = gam_math::jet_scalar::Order2<2>;

        let mu = O2::variable(mu_value, 0);
        let theta = O2::variable(theta_value, 1);
        let theta_plus_mu = theta.add(&mu);
        let lgamma_theta_plus_y = order2_ln_gamma(&theta.add(&O2::constant(yi)));
        let lgamma_theta = order2_ln_gamma(&theta);
        let log_y_factorial = O2::constant(ln_gamma(yi + 1.0));
        let loglik = lgamma_theta_plus_y
            .sub(&lgamma_theta)
            .sub(&log_y_factorial)
            .add(&theta.mul(&theta.ln()))
            .sub(&theta.mul(&theta_plus_mu.ln()))
            .add(&mu.ln().scale(yi))
            .sub(&theta_plus_mu.ln().scale(yi));
        loglik.scale(-wi)
    }

    /// #1591 jet-prune oracle: the full two-axis `Order2<2>` Gamma row NLL.
    ///
    /// As with NB, production never reads the mean axis (it is hand-written
    /// and Fisher-orthogonal). The hot path uses the single-axis
    /// [`dispersion_gamma_disp_order2`]; this is kept only as the dense-tower
    /// oracle pin.
    #[inline]
    pub(super) fn dispersion_gamma_nll_order2(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type O2 = gam_math::jet_scalar::Order2<2>;

        let mu = O2::variable(mu_value, 0);
        let nu = O2::variable(nu_value, 1);
        let nu_ln_nu = nu.mul(&nu.ln());
        let nu_ln_mu = nu.mul(&mu.ln());
        let shape_log_y = nu.sub(&O2::constant(1.0)).scale(y_pos.ln());
        let nu_y_over_mu = nu.mul(&mu.recip().scale(yi));
        let loglik = nu_ln_nu
            .sub(&nu_ln_mu)
            .sub(&order2_ln_gamma(&nu))
            .add(&shape_log_y)
            .sub(&nu_y_over_mu);
        loglik.scale(-wi)
    }
}
309
310/// Production `Order2<2>` Beta row NLL (value/grad/Hessian hot path; the cross
311/// channel `h()[0][1]` feeds the Beta observed cross weight).
312#[inline]
313pub(crate) fn dispersion_beta_nll_order2(
314    yi: f64,
315    mu_value: f64,
316    phi_value: f64,
317    wi: f64,
318) -> gam_math::jet_scalar::Order2<2> {
319    type O2 = gam_math::jet_scalar::Order2<2>;
320
321    let mu = O2::variable(mu_value, 0);
322    let phi = O2::variable(phi_value, 1);
323    let one_minus_mu = O2::constant(1.0).sub(&mu);
324    let yc = yi.clamp(1e-12, 1.0 - 1e-12);
325    let a = mu.mul(&phi);
326    let b = one_minus_mu.mul(&phi);
327    let loglik = order2_ln_gamma(&phi)
328        .sub(&order2_ln_gamma(&a))
329        .sub(&order2_ln_gamma(&b))
330        .add(&a.sub(&O2::constant(1.0)).scale(yc.ln()))
331        .add(&b.sub(&O2::constant(1.0)).scale((1.0 - yc).ln()));
332    loglik.scale(-wi)
333}
334
335#[inline]
336fn order2_ln_gamma<const K: usize>(
337    x: &gam_math::jet_scalar::Order2<K>,
338) -> gam_math::jet_scalar::Order2<K> {
339    gam_math::jet_scalar::Order2(
340        x.0.compose_unary(gam_math::jet_tower::ln_gamma_derivative_stack_order2(x.0.v)),
341    )
342}
343
344// ============================================================================
345// #1591 jet-prune: single-axis (`K=1`) dispersion-channel towers.
346//
347// For NegativeBinomial / Gamma / Tweedie the production row kernel consumes ONLY
348// the dispersion-axis derivatives (`g[disp]`, `h[disp][disp]`) and the value;
349// the mean block is Fisher-orthogonal and assembled in closed form. Seeding the
350// mean as a CONSTANT and the dispersion parameter as the SOLE jet variable
351// therefore yields a tower whose `(value, g[0], h[0][0])` are `to_bits`-
352// identical to the consumed `(value, g[1], h[1][1])` of the old `Order2<2>`
353// tower — the mean seed only ever populated the now-discarded `g[mean]` /
354// `h[mean][·]` channels (Leibniz/Faà-di-Bruno never read the dispersion-axis
355// channels off the mean seed). Collapsing `K=2 → K=1` quarters the Hessian
356// tensor (1 entry vs 4) and halves the gradient, with no change to any consumed
357// float bit. The `ln_gamma` derivative stacks are unchanged (the irreducible
358// transcendental cost), so this trims the rational composition, not the special
359// functions.
360// ============================================================================
361
362/// Pruned single-axis NB2 dispersion tower: `θ` is the sole jet variable
363/// (axis 0), `μ` a constant. `value`/`g[0]`/`h[0][0]` reproduce the consumed
364/// `value`/`g[1]`/`h[1][1]` of `dispersion_nb_nll_order2` bit-for-bit.
365#[inline]
366pub(crate) fn dispersion_nb_disp_order2(
367    yi: f64,
368    mu_value: f64,
369    theta_value: f64,
370    wi: f64,
371) -> gam_math::jet_scalar::Order2<1> {
372    type O1 = gam_math::jet_scalar::Order2<1>;
373
374    let mu = O1::constant(mu_value);
375    let theta = O1::variable(theta_value, 0);
376    let tpm = theta.add(&mu);
377    let theta_plus_y = theta.add(&O1::constant(yi));
378    let loglik = order2_ln_gamma(&theta_plus_y)
379        .sub(&order2_ln_gamma(&theta))
380        .sub(&O1::constant(ln_gamma(yi + 1.0)))
381        .add(&theta.mul(&theta.ln()))
382        .sub(&theta.mul(&tpm.ln()))
383        .add(&mu.ln().scale(yi))
384        .sub(&tpm.ln().scale(yi));
385    loglik.scale(-wi)
386}
387
388/// Pruned single-axis Gamma dispersion tower: `ν` is the sole jet variable
389/// (axis 0), `μ` a constant. Consumed channels match
390/// `dispersion_gamma_nll_order2` index-1 bit-for-bit.
391#[inline]
392pub(crate) fn dispersion_gamma_disp_order2(
393    yi: f64,
394    y_pos: f64,
395    mu_value: f64,
396    nu_value: f64,
397    wi: f64,
398) -> gam_math::jet_scalar::Order2<1> {
399    type O1 = gam_math::jet_scalar::Order2<1>;
400
401    let mu = O1::constant(mu_value);
402    let nu = O1::variable(nu_value, 0);
403    let loglik = nu
404        .mul(&nu.ln())
405        .sub(&nu.mul(&mu.ln()))
406        .sub(&order2_ln_gamma(&nu))
407        .add(&nu.sub(&O1::constant(1.0)).scale(y_pos.ln()))
408        .sub(&nu.mul(&mu.recip().scale(yi)));
409    loglik.scale(-wi)
410}
411
412/// Pruned single-axis Tweedie dispersion tower seeded on the predictor `η_d`
413/// (axis 0), with `η_μ` a constant (so `μ = exp(η_μ)` carries no jet). The
414/// `φ = exp(−η_d)` chain and its nonlinear `∂²φ/∂η_d²` curvature are carried
415/// exactly as in `dispersion_tweedie_nll_generic`; `value`/`g[0]`/`h[0][0]`
416/// match that program's `value`/`g[1]`/`h[1][1]` bit-for-bit.
417#[inline]
418pub(crate) fn dispersion_tweedie_disp_order2(
419    yi: f64,
420    eta_mu: f64,
421    eta_d: f64,
422    p: f64,
423    wi: f64,
424) -> gam_math::jet_scalar::Order2<1> {
425    type O1 = gam_math::jet_scalar::Order2<1>;
426
427    let one_minus_p = 1.0 - p;
428    let two_minus_p = 2.0 - p;
429    let mu = O1::constant(eta_mu).exp();
430    let phi = O1::variable(eta_d, 0).scale(-1.0).exp();
431    if yi > 0.0 {
432        let dev = mu
433            .powf(two_minus_p)
434            .scale(1.0 / two_minus_p)
435            .sub(&mu.powf(one_minus_p).scale(yi / one_minus_p))
436            .add(&O1::constant(
437                yi.powf(two_minus_p) / (one_minus_p * two_minus_p),
438            ))
439            .scale(2.0);
440        let loglik = dev
441            .mul(&phi.recip().scale(-0.5))
442            .sub(&phi.scale(2.0 * std::f64::consts::PI).ln().scale(0.5))
443            .sub(&O1::constant(0.5 * p * yi.ln()));
444        loglik.scale(-wi)
445    } else {
446        let c = mu.powf(two_minus_p).scale(1.0 / two_minus_p);
447        let loglik = c.mul(&phi.recip()).scale(-1.0);
448        loglik.scale(-wi)
449    }
450}
451
452// ============================================================================
// #1591 jet-prune: value-only (`K=0`) row log-likelihood (the negated NLL-tower value).
454//
455// `log_likelihood_only` reads ONLY `row.loglik = -tower.value()`; the full row
456// kernel it used to call evaluated every dispersion tower's gradient AND Hessian
457// — including the digamma/trigamma derivative stacks — purely to discard them.
458// These functions evaluate the SAME value-channel program in plain `f64`, so
459// they are `to_bits`-identical to `-tower.value()` (the jet value channel is the
460// naive scalar evaluation: `mul.v = a.v*b.v`, `compose.v = stack[0]`), while
461// touching only `ln_gamma` (stack slot 0) and never the digamma/trigamma slots.
462// On a per-row loglik that is the dominant transcendental saving.
463// ============================================================================
464
465/// NB2 row NLL value, plain `f64`, bit-identical to
466/// `-dispersion_nb_disp_order2(..).value()`.
467#[inline]
468fn dispersion_nb_neg_loglik(yi: f64, mu: f64, theta: f64, wi: f64) -> f64 {
469    let tpm = theta + mu;
470    let s = ln_gamma(theta + yi) - ln_gamma(theta) - ln_gamma(yi + 1.0) + theta * theta.ln()
471        - theta * tpm.ln()
472        + mu.ln() * yi
473        - tpm.ln() * yi;
474    -(s * -wi)
475}
476
477/// Gamma row NLL value, plain `f64`, bit-identical to
478/// `-dispersion_gamma_disp_order2(..).value()`.
479#[inline]
480fn dispersion_gamma_neg_loglik(yi: f64, y_pos: f64, mu: f64, nu: f64, wi: f64) -> f64 {
481    // NB: the jet forms `μ.recip().scale(yi)` = `(1/μ)·yᵢ` (reciprocal then
482    // multiply), NOT `yᵢ/μ` (single divide) — these differ in the last bit, so
483    // the value path must reproduce the reciprocal-then-multiply exactly.
484    let s = nu * nu.ln() - nu * mu.ln() - ln_gamma(nu) + (nu - 1.0) * y_pos.ln()
485        - nu * ((1.0 / mu) * yi);
486    -(s * -wi)
487}
488
489/// Beta row NLL value, plain `f64`, bit-identical to
490/// `-dispersion_beta_nll_order2(..).value()`.
491#[inline]
492fn dispersion_beta_neg_loglik(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
493    let one_minus_mu = 1.0 - mu;
494    let yc = yi.clamp(1e-12, 1.0 - 1e-12);
495    let a = mu * phi;
496    let b = one_minus_mu * phi;
497    let s = ln_gamma(phi) - ln_gamma(a) - ln_gamma(b)
498        + (a - 1.0) * yc.ln()
499        + (b - 1.0) * (1.0 - yc).ln();
500    -(s * -wi)
501}
502
/// Tweedie row log-likelihood in plain `f64`: the *negated* value of the
/// prior-weighted NLL tower, `wᵢ · ℓ(yᵢ; μ = exp(η_μ), φ = exp(−η_d))`.
///
/// Rows with `yᵢ > 0` use the saddlepoint density
/// `ℓ = −d(y, μ)/(2φ) − ½ ln(2πφ) − (p/2) ln y` with the Tweedie unit deviance
/// `d`. Every other row uses the exact zero mass `ℓ = −μ^{2−p}/((2−p) φ)`.
///
/// The `neg_` in the name refers to negating the NLL tower. The result is the
/// (weighted) log-likelihood, not a negative log-likelihood. Both branches
/// mirror `dispersion_tweedie_disp_order2` op-for-op, so this reproduces
/// `-dispersion_tweedie_disp_order2(..).value()` bit-for-bit. Do not
/// reassociate.
///
/// NOTE(review): `p = 1` or `p = 2` divides by zero here; presumably the power
/// is validated upstream — confirm.
#[inline]
fn dispersion_tweedie_neg_loglik(yi: f64, eta_mu: f64, eta_d: f64, p: f64, wi: f64) -> f64 {
    let one_minus_p = 1.0 - p;
    let two_minus_p = 2.0 - p;
    let mu = eta_mu.exp();
    let phi = (-eta_d).exp();
    let s = if yi > 0.0 {
        let dev = (mu.powf(two_minus_p) * (1.0 / two_minus_p)
            - mu.powf(one_minus_p) * (yi / one_minus_p)
            + yi.powf(two_minus_p) / (one_minus_p * two_minus_p))
            * 2.0;
        dev * ((1.0 / phi) * -0.5)
            - (phi * (2.0 * std::f64::consts::PI)).ln() * 0.5
            - 0.5 * p * yi.ln()
    } else {
        let c = mu.powf(two_minus_p) * (1.0 / two_minus_p);
        (c * (1.0 / phi)) * -1.0
    };
    // Reproduce the tower's `scale(-wi)` followed by the caller-side negation.
    -(s * -wi)
}
525
526/// Value-only row negative log-likelihood for one observation — the pruned hot
527/// path for [`CustomFamily::log_likelihood_only`]. Mirrors the link/clamp
528/// preamble of [`dispersion_row_kernel`] exactly, then evaluates ONLY the value
529/// channel (no gradient/Hessian, no digamma/trigamma). Returns `row.loglik`
530/// `to_bits`-identically.
531#[inline]
532pub(crate) fn dispersion_row_loglik(
533    kind: DispersionFamilyKind,
534    yi: f64,
535    eta_mu: f64,
536    eta_d: f64,
537    prior_weight: f64,
538) -> f64 {
539    let wi = prior_weight.max(0.0);
540    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
541    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
542    match kind {
543        DispersionFamilyKind::NegativeBinomial => {
544            let mu = em.exp().max(1e-300);
545            let theta = ed.exp().max(1e-12);
546            dispersion_nb_neg_loglik(yi, mu, theta, wi)
547        }
548        DispersionFamilyKind::Gamma => {
549            let mu = em.exp().max(1e-300);
550            let nu = ed.exp().max(1e-12);
551            let y_pos = yi.max(1e-300);
552            dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)
553        }
554        DispersionFamilyKind::Beta => {
555            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
556            let phi = ed.exp().max(1e-12);
557            dispersion_beta_neg_loglik(yi, mu, phi, wi)
558        }
559        DispersionFamilyKind::Tweedie { p } => dispersion_tweedie_neg_loglik(yi, em, ed, p, wi),
560    }
561}
562
563#[inline]
564pub(crate) fn beta_observed_cross_weight_eta(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
565    let q = (mu * (1.0 - mu)).max(1e-12);
566    let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
567    q * phi * tower.h()[0][1]
568}
569
570#[inline]
571pub(crate) fn dispersion_row_cross_weight(
572    kind: DispersionFamilyKind,
573    yi: f64,
574    eta_mu: f64,
575    eta_d: f64,
576    prior_weight: f64,
577) -> f64 {
578    let wi = prior_weight.max(0.0);
579    if wi == 0.0 {
580        return 0.0;
581    }
582    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
583    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
584    match kind {
585        DispersionFamilyKind::Beta => {
586            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
587            let phi = ed.exp().max(1e-12);
588            beta_observed_cross_weight_eta(yi, mu, phi, wi)
589        }
590        DispersionFamilyKind::NegativeBinomial
591        | DispersionFamilyKind::Gamma
592        | DispersionFamilyKind::Tweedie { .. } => 0.0,
593    }
594}
595
596#[inline]
597pub(crate) fn tower_score_info<const K: usize>(
598    tower: &gam_math::jet_scalar::Order2<K>,
599    idx: usize,
600    wi: f64,
601) -> (f64, f64) {
602    if wi == 0.0 {
603        (0.0, 0.0)
604    } else {
605        (-tower.g()[idx] / wi, tower.h()[idx][idx] / wi)
606    }
607}
608
609/// Evaluate the row log-likelihood and the (mean, log-precision) Fisher-scoring
610/// working sets for one observation. `eta_mu`/`eta_d` already include any
611/// per-channel offset (they are the block predictors). `prior_weight` is the
612/// observation's prior weight.
613pub(super) fn dispersion_row_kernel(
614    kind: DispersionFamilyKind,
615    yi: f64,
616    eta_mu: f64,
617    eta_d: f64,
618    prior_weight: f64,
619) -> DispersionRowKernel {
620    let wi = prior_weight.max(0.0);
621    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
622    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
623    match kind {
624        DispersionFamilyKind::NegativeBinomial => {
625            let mu = em.exp().max(1e-300);
626            let theta = ed.exp().max(1e-12); // precision (size)
627            let tpm = theta + mu;
628            let tower = dispersion_nb_disp_order2(yi, mu, theta, wi);
629            // Only the exact θ-space SCORE is consumed from the tower; the
630            // observed-Hessian channel is discarded in favor of the expected
631            // (Fisher) information assembled below (see the dispersion-curvature
632            // note). Keeping the value channel for the row log-likelihood.
633            let (s_theta, _info_theta_observed) = tower_score_info(&tower, 0, wi);
634            let loglik = -tower.value();
635            let info_mu = if wi == 0.0 {
636                DISPERSION_MIN_CURVATURE
637            } else {
638                (theta / (mu * tpm)).max(DISPERSION_MIN_CURVATURE)
639            };
640            let score_mu = theta * (yi - mu) / (mu * tpm);
641            let mean_weight = wi * mu * mu * info_mu;
642            let mean_response = em + score_mu / (mu * info_mu);
643            // Dispersion (log-θ) IRLS curvature: use the EXPECTED (Fisher)
644            // information in θ, not the per-row OBSERVED Hessian channel
645            // (`_info_theta_observed`). The NB2 log-likelihood is strongly
646            // non-quadratic in θ: `−∂²ℓ/∂θ²` carries the row-specific term
647            // `ψ′(θ+y)` and goes NEGATIVE for every row whose count sits below
648            // its current fitted precision (overestimated size / underestimated
649            // overdispersion). Far from the optimum a majority of rows can be
650            // negative, so the assembled block curvature `Xᵀdiag(w)X` loses
651            // positive-definiteness; flooring each negative row at
652            // `DISPERSION_MIN_CURVATURE` (≈0) then divides the exact score by
653            // ~0 in the working response, producing O(1e12) IRLS targets that
654            // make the dispersion block step explode and the inner block-cyclic
655            // solve stall (never reaching KKT within the cycle budget — the
656            // `nb` location-scale `IntegrationError`, gam#1606). The mean block
657            // already uses its closed-form expected info `θ/(μ(θ+μ))`; the
658            // dispersion block must do the same.
659            //
660            // The Fisher information in θ has the closed form
661            //   I(θ) = ψ′(θ) − E[ψ′(θ+Y)] − 1/θ + 1/(θ+μ),
662            // whose only costly piece is the per-row infinite expectation
663            // `E[ψ′(θ+Y)]`. Replacing it with the Jensen plug-in `ψ′(θ+μ)`
664            // (valid because ψ′ is convex, so this is a tight lower bound on the
665            // expectation) gives a per-row, sum-free, STRICTLY POSITIVE
666            // curvature
667            //   I_θ ≈ ψ′(θ) − ψ′(θ+μ) − 1/θ + 1/(θ+μ) > 0  for all (μ,θ),
668            // since ψ′ is strictly decreasing. The working RESPONSE still
669            // carries the EXACT score `s_theta` (= ∂ℓ/∂θ from the tower), so the
670            // penalized stationary point (score = 0) is byte-unchanged — this is
671            // Fisher scoring, which only re-conditions the inner solve and never
672            // shifts the optimum (cf. the `DISPERSION_MIN_CURVATURE` contract
673            // note above). The observed channel `_info_theta_observed` is no
674            // longer consumed for the weight.
675            let trigamma_theta = gam_math::jet_tower::trigamma_derivative_stack(theta)[0];
676            let trigamma_tpm = gam_math::jet_tower::trigamma_derivative_stack(tpm)[0];
677            let info_theta_fisher = trigamma_theta - trigamma_tpm - 1.0 / theta + 1.0 / tpm;
678            let info_pos = info_theta_fisher.max(DISPERSION_MIN_CURVATURE);
679            let disp_weight = wi * theta * theta * info_pos;
680            let disp_response = ed + s_theta / (theta * info_pos);
681            DispersionRowKernel {
682                loglik,
683                mean_weight,
684                mean_response,
685                disp_weight,
686                disp_response,
687            }
688        }
689        DispersionFamilyKind::Gamma => {
690            let mu = em.exp().max(1e-300);
691            let nu = ed.exp().max(1e-12); // precision = shape ν
692            let y_pos = yi.max(1e-300);
693            let tower = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
694            let (s_nu, info_nu_raw) = tower_score_info(&tower, 0, wi);
695            let loglik = -tower.value();
696            let info_mu = if wi == 0.0 {
697                DISPERSION_MIN_CURVATURE
698            } else {
699                (nu / (mu * mu)).max(DISPERSION_MIN_CURVATURE)
700            };
701            let score_mu = nu * (yi - mu) / (mu * mu);
702            let mean_weight = wi * mu * mu * info_mu;
703            let mean_response = em + score_mu / (mu * info_mu);
704            let info_nu = info_nu_raw.max(DISPERSION_MIN_CURVATURE);
705            let disp_weight = wi * nu * nu * info_nu;
706            let disp_response = ed + s_nu / (nu * info_nu);
707            DispersionRowKernel {
708                loglik,
709                mean_weight,
710                mean_response,
711                disp_weight,
712                disp_response,
713            }
714        }
715        DispersionFamilyKind::Beta => {
716            // logit mean link.
717            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
718            let phi = ed.exp().max(1e-12); // precision
719            let q = (mu * (1.0 - mu)).max(1e-12); // dμ/dη
720            let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
721            let (score_mu, info_mu_raw) = tower_score_info(&tower, 0, wi);
722            let (s_phi, info_phi_raw) = tower_score_info(&tower, 1, wi);
723            let loglik = -tower.value();
724            let info_mu = info_mu_raw.max(DISPERSION_MIN_CURVATURE);
725            let mean_weight = wi * q * q * info_mu;
726            let mean_response = em + score_mu / (q * info_mu);
727            let info_phi = info_phi_raw.max(DISPERSION_MIN_CURVATURE);
728            let disp_weight = wi * phi * phi * info_phi;
729            let disp_response = ed + s_phi / (phi * info_phi);
730            DispersionRowKernel {
731                loglik,
732                mean_weight,
733                mean_response,
734                disp_weight,
735                disp_response,
736            }
737        }
738        DispersionFamilyKind::Tweedie { p } => {
739            let mu = em.exp().max(1e-300);
740            // Precision channel models log(1/φ) ⇒ φ = exp(−η_d).
741            let phi = (-ed).exp().max(1e-12);
742            let two_minus_p = 2.0 - p;
743            // Mean channel: the quasi-score `(y−μ)/μ` and Fisher weight
744            // `μ^{2−p}/φ` are simple closed forms (and the mean block is
745            // Fisher-orthogonal to the dispersion block in this
746            // parameterization), so they stay hand-written exactly as the
747            // NB/Gamma mean arms do.
748            let mean_weight = wi * mu.powf(two_minus_p) / phi;
749            let mean_response = em + (yi - mu) / mu;
750            // Dispersion channel: the η_d-space score and OBSERVED information
751            // come straight off the single-expression tower seeded on `η_d`
752            // (#932), so the saddlepoint/point-mass branch split, the
753            // `φ = exp(−η_d)` chain and its nonlinear `∂²φ/∂η_d²` curvature
754            // correction are all mechanically carried — no per-branch
755            // `s_phi`/`s_eta`/`curvature_eta` hand calculus. #1591: only the
756            // η_d axis is consumed, so the tower is the pruned single-axis
757            // `Order2<1>` (`η_μ` enters as a constant).
758            let tower = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
759            let loglik = -tower.value();
760            // η_d-space score and observed information off the tower, via the
761            // same helper the NB/Gamma/Beta arms use (returns `(0, 0)` when the
762            // prior weight is zero, so the row stays excluded below).
763            let (s_eta, info_eta_raw) = tower_score_info(&tower, 0, wi);
764            let curvature_eta = if wi == 0.0 {
765                DISPERSION_MIN_CURVATURE
766            } else {
767                info_eta_raw.max(DISPERSION_MIN_CURVATURE)
768            };
769            // The working response divides by this per-row curvature so the
770            // prior weight cancels (and a zero-prior-weight row stays excluded
771            // via `disp_weight = 0`).
772            let disp_weight = wi * curvature_eta;
773            let disp_response = ed + s_eta / curvature_eta;
774            DispersionRowKernel {
775                loglik,
776                mean_weight,
777                mean_response,
778                disp_weight,
779                disp_response,
780            }
781        }
782    }
783}
784
785/// Two-block GAMLSS family for the genuine-dispersion mean families (#913).
786#[derive(Clone)]
787pub(crate) struct DispersionGlmLocationScaleFamily {
788    pub(crate) kind: DispersionFamilyKind,
789    pub(crate) y: Array1<f64>,
790    pub(crate) weights: Array1<f64>,
791}
792
793impl DispersionGlmLocationScaleFamily {
794    pub(crate) const BLOCK_MEAN: usize = 0;
795    pub(crate) const BLOCK_DISP: usize = 1;
796}
797
798impl CustomFamily for DispersionGlmLocationScaleFamily {
799    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
800    // flat-prior exact-Newton objective carries no Jeffreys term), so families
801    // that historically armed the term by default opt back in explicitly.
802    fn joint_jeffreys_term_required(&self) -> bool {
803        true
804    }
805
806    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
807        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
808        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
809        let eta_d = &block_states[Self::BLOCK_DISP].eta;
810        let n = self.y.len();
811        if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
812            return Err(format!(
813                "{} row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
814                self.kind.family_tag(),
815                eta_mu.len(),
816                eta_d.len(),
817                self.weights.len()
818            ));
819        }
820        let mut log_likelihood = 0.0;
821        let mut mean_weights = Array1::<f64>::zeros(n);
822        let mut mean_response = Array1::<f64>::zeros(n);
823        let mut disp_weights = Array1::<f64>::zeros(n);
824        let mut disp_response = Array1::<f64>::zeros(n);
825        for i in 0..n {
826            let row =
827                dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i]);
828            if row.loglik.is_finite() {
829                log_likelihood += row.loglik;
830            }
831            mean_weights[i] = row.mean_weight.max(0.0);
832            mean_response[i] = row.mean_response;
833            disp_weights[i] = row.disp_weight.max(0.0);
834            disp_response[i] = row.disp_response;
835        }
836        Ok(FamilyEvaluation {
837            log_likelihood,
838            blockworking_sets: vec![
839                BlockWorkingSet::diagonal_checked(mean_response, mean_weights)?,
840                BlockWorkingSet::diagonal_checked(disp_response, disp_weights)?,
841            ],
842        })
843    }
844
845    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
846        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
847        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
848        let eta_d = &block_states[Self::BLOCK_DISP].eta;
849        let mut ll = 0.0;
850        for i in 0..self.y.len() {
851            // #1591 prune: the objective needs only the row log-likelihood, so
852            // evaluate the value channel alone (`to_bits`-identical to
853            // `dispersion_row_kernel(..).loglik`) and skip every gradient,
854            // Hessian and digamma/trigamma derivative-stack evaluation.
855            let loglik =
856                dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i]);
857            if loglik.is_finite() {
858                ll += loglik;
859            }
860        }
861        Ok(ll)
862    }
863
864    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
865        crate::location_scale_engine::location_scale_coefficient_hessian_cost(
866            self.y.len() as u64,
867            specs,
868        )
869    }
870
871    /// Exact joint coefficient-space Hessian `H_L = -∇²log L` in flattened
872    /// `[mean | log-precision]` block order.
873    ///
874    /// All four members assemble the same `Xᵀ diag(W) X` blocks; the cross
875    /// block is the per-row mixed weight `dispersion_row_cross_weight`. Beta
876    /// carries a genuinely nonzero (η_μ, η_φ) cross weight; the Fisher-
877    /// orthogonal members (NegativeBinomial / Gamma / Tweedie) report a zero
878    /// cross weight, so this returns their exact *block-diagonal* joint
879    /// Hessian. Returning that block-diagonal `H_L` — rather than `None` —
880    /// is what lets the multi-block outer-REML path (`build_joint_hessian_
881    /// closures` → `joint_outer_evaluate`) and the joint posterior covariance
882    /// (`compute_joint_covariance`) run for these families instead of failing
883    /// the "multi-block families must provide a joint outer path" gate and
884    /// silently escalating to a degraded ρ-seed fit with no covariance/EDF
885    /// (gam#1119). The orthogonal members additionally declare
886    /// `likelihood_blocks_uncoupled() = true` so the directional-derivative
887    /// and Jeffreys dispatch route through the block-diagonal-exact fallback
888    /// rather than rejecting the structurally-uncoupled Hessian.
889    fn exact_newton_joint_hessian_with_specs(
890        &self,
891        block_states: &[ParameterBlockState],
892        specs: &[ParameterBlockSpec],
893    ) -> Result<Option<Array2<f64>>, String> {
894        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
895        if specs.len() != 2 {
896            return Err(format!(
897                "{} exact joint Hessian expects 2 specs, got {}",
898                self.kind.family_tag(),
899                specs.len()
900            ));
901        }
902        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
903        let eta_d = &block_states[Self::BLOCK_DISP].eta;
904        let n = self.y.len();
905        if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
906            return Err(format!(
907                "{} exact joint Hessian row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
908                self.kind.family_tag(),
909                eta_mu.len(),
910                eta_d.len(),
911                self.weights.len()
912            ));
913        }
914
915        let eval = self.evaluate(block_states)?;
916        let BlockWorkingSet::Diagonal {
917            working_weights: mean_weights,
918            ..
919        } = &eval.blockworking_sets[Self::BLOCK_MEAN]
920        else {
921            return Err(format!(
922                "{} dispersion mean block did not return diagonal weights",
923                self.kind.family_tag()
924            ));
925        };
926        let BlockWorkingSet::Diagonal {
927            working_weights: disp_weights,
928            ..
929        } = &eval.blockworking_sets[Self::BLOCK_DISP]
930        else {
931            return Err(format!(
932                "{} dispersion precision block did not return diagonal weights",
933                self.kind.family_tag()
934            ));
935        };
936
937        let cross_weights = Array1::from_shape_fn(n, |i| {
938            dispersion_row_cross_weight(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
939        });
940        let mean_spec = &specs[Self::BLOCK_MEAN];
941        let disp_spec = &specs[Self::BLOCK_DISP];
942        if mean_spec.design.nrows() != n || disp_spec.design.nrows() != n {
943            return Err(format!(
944                "{} exact joint Hessian design row mismatch: y={n}, mean rows={}, precision rows={}",
945                self.kind.family_tag(),
946                mean_spec.design.nrows(),
947                disp_spec.design.nrows()
948            ));
949        }
950        let p_mean = mean_spec.design.ncols();
951        let p_disp = disp_spec.design.ncols();
952        if block_states[Self::BLOCK_MEAN].beta.len() != p_mean
953            || block_states[Self::BLOCK_DISP].beta.len() != p_disp
954        {
955            return Err(format!(
956                "{} exact joint Hessian beta/design mismatch: mean beta {} vs cols {}, precision beta {} vs cols {}",
957                self.kind.family_tag(),
958                block_states[Self::BLOCK_MEAN].beta.len(),
959                p_mean,
960                block_states[Self::BLOCK_DISP].beta.len(),
961                p_disp
962            ));
963        }
964
965        let h_mean = xt_diag_x_design(&mean_spec.design, mean_weights)?;
966        let h_cross = xt_diag_y_design(&mean_spec.design, &cross_weights, &disp_spec.design)?;
967        let h_disp = xt_diag_x_design(&disp_spec.design, disp_weights)?;
968        let total = p_mean + p_disp;
969        let mut h = Array2::<f64>::zeros((total, total));
970        h.slice_mut(s![0..p_mean, 0..p_mean]).assign(&h_mean);
971        h.slice_mut(s![0..p_mean, p_mean..total]).assign(&h_cross);
972        h.slice_mut(s![p_mean..total, p_mean..total])
973            .assign(&h_disp);
974        mirror_upper_to_lower(&mut h);
975        Ok(Some(h))
976    }
977
978    /// Whether the joint likelihood Hessian is block-diagonal in the
979    /// `[mean | log-precision]` coefficient vector.
980    ///
981    /// `Beta(μφ, (1−μ)φ)` carries a genuinely nonzero `(η_μ, η_φ)` Fisher
982    /// cross block (see the module header), so its blocks are coupled. The
983    /// remaining members are Fisher-orthogonal in their mean/precision
984    /// parameterizations — NB2 `(μ, θ)`, Gamma shape `ν = 1/φ`, Tweedie
985    /// `log(1/φ)` — so `∂²L/∂β_μ∂β_d = 0` and the joint Hessian is exactly
986    /// block-diagonal. Declaring that here lets the trait's directional-
987    /// derivative / Jeffreys dispatch accept the block-diagonal joint Hessian
988    /// via the working-set-exact fallback instead of rejecting it as an
989    /// untrusted structurally-uncoupled override (which would strand the
990    /// outer-REML gradient with a "dH unavailable" error, gam#1119).
991    fn likelihood_blocks_uncoupled(&self) -> bool {
992        !matches!(self.kind, DispersionFamilyKind::Beta)
993    }
994
995    /// The mean and precision working weights couple across both blocks, which
996    /// the block-local diagonal drift hook cannot represent, so decline the
997    /// dense outer Hessian capability whenever the actual two-block (or
998    /// larger) geometry is in play; a degenerate single-block probe — there
999    /// is no cross-block coupling to reject — keeps the trait default's
1000    /// availability verdict.
1001    ///
1002    /// The override still validates the block-spec slice it is handed (the
1003    /// same consistency check the trait default's assertion bottoms out in)
1004    /// so a malformed probe is reported here rather than downstream.
1005    fn outer_hyper_hessian_dense_available(&self, specs: &[ParameterBlockSpec]) -> bool {
1006        assert!(
1007            crate::custom_family::validate_blockspec_consistency(specs).is_ok(),
1008            "DispersionGlmLocationScale outer hyper-Hessian dense availability: \
1009             inconsistent parameter block specs"
1010        );
1011        specs.len() < 2
1012    }
1013}
1014
1015/// Term spec consumed by [`fit_dispersion_glm_location_scale_terms`]; mirrors
1016/// [`GaussianLocationScaleTermSpec`](super::GaussianLocationScaleTermSpec) with
1017/// the dispersion channel in place of the Gaussian log-σ channel.
1018pub struct DispersionGlmLocationScaleTermSpec {
1019    pub kind: DispersionFamilyKind,
1020    pub y: Array1<f64>,
1021    pub weights: Array1<f64>,
1022    pub meanspec: TermCollectionSpec,
1023    pub log_dispspec: TermCollectionSpec,
1024    pub mean_offset: Array1<f64>,
1025    pub log_disp_offset: Array1<f64>,
1026}
1027
1028pub(crate) struct DispersionGlmLocationScaleTermBuilder {
1029    pub(crate) kind: DispersionFamilyKind,
1030    pub(crate) y: Array1<f64>,
1031    pub(crate) weights: Array1<f64>,
1032    pub(crate) meanspec: TermCollectionSpec,
1033    pub(crate) noisespec: TermCollectionSpec,
1034    pub(crate) mean_offset: Array1<f64>,
1035    pub(crate) noise_offset: Array1<f64>,
1036}
1037
1038/// Warm start for a dispersion location-scale fit: project a link-transformed
1039/// response onto the mean block and seed the log-precision block at a constant
1040/// (precision ≈ 1) baseline. The block-cyclic IRLS then refines both jointly.
1041pub(crate) fn dispersion_location_scale_warm_start(
1042    kind: DispersionFamilyKind,
1043    y: &Array1<f64>,
1044    weights: &Array1<f64>,
1045    mean_block: &ParameterBlockSpec,
1046    disp_block: &ParameterBlockSpec,
1047    mean_beta_hint: Option<&Array1<f64>>,
1048    disp_beta_hint: Option<&Array1<f64>>,
1049) -> Result<(Array1<f64>, Array1<f64>), String> {
1050    let ridge_floor = 1e-10;
1051    let mean_beta = if let Some(beta) = mean_beta_hint {
1052        beta.clone()
1053    } else {
1054        let target = Array1::from_shape_fn(y.len(), |i| {
1055            if kind.mean_is_logit() {
1056                let yi = y[i].clamp(1e-3, 1.0 - 1e-3);
1057                (yi / (1.0 - yi)).ln()
1058            } else {
1059                // log mean link; the +0.1 keeps zero counts finite.
1060                (y[i].max(0.0) + 0.1).ln()
1061            }
1062        });
1063        solve_penalizedweighted_projection(
1064            &mean_block.design,
1065            &mean_block.offset,
1066            &target,
1067            weights,
1068            &mean_block.penalties,
1069            &mean_block.initial_log_lambdas,
1070            ridge_floor,
1071        )?
1072    };
1073    let disp_beta = if let Some(beta) = disp_beta_hint {
1074        beta.clone()
1075    } else {
1076        // Seed the precision block from a smoothed method-of-moments surface
1077        // rather than the old flat η_d=0 constant.  A single observation cannot
1078        // identify its own variance, but for the Fisher-orthogonal dispersion
1079        // members the residual-squared moment contains the correct first-order
1080        // signal:
1081        //
1082        //   Gamma:   Var(Y)=μ²/ν              ⇒ log ν     ≈ log(μ²/e²)
1083        //   NB2:     Var(Y)=μ+μ²/θ            ⇒ log θ     ≈ log(μ²/(e²-μ))
1084        //   Tweedie: Var(Y)=φ μ^p, η_d=log1/φ ⇒ η_d       ≈ log(μ^p/e²)
1085        //
1086        // The targets are deliberately conservative (finite residual floor,
1087        // precision cap, and no fixture-specific constants): they only give the
1088        // block-cyclic likelihood solve a correctly-signed non-flat starting
1089        // surface, while the final estimate is still the penalized joint MLE.
1090        let mean_eta = mean_block.design.apply(&mean_beta) + &mean_block.offset;
1091        let target = Array1::from_shape_fn(y.len(), |i| {
1092            dispersion_moment_log_precision_seed(kind, y[i], mean_eta[i])
1093        });
1094        solve_penalizedweighted_projection(
1095            &disp_block.design,
1096            &disp_block.offset,
1097            &target,
1098            weights,
1099            &disp_block.penalties,
1100            &disp_block.initial_log_lambdas,
1101            ridge_floor,
1102        )?
1103    };
1104    Ok((mean_beta, disp_beta))
1105}
1106
1107#[inline]
1108fn dispersion_moment_log_precision_seed(kind: DispersionFamilyKind, yi: f64, eta_mu: f64) -> f64 {
1109    const LOG_PRECISION_FLOOR: f64 = -10.0;
1110    const LOG_PRECISION_CEILING: f64 = 10.0;
1111    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
1112    let raw = match kind {
1113        DispersionFamilyKind::Beta => {
1114            // Beta's mean and precision scores are not Fisher-orthogonal in
1115            // the (logit μ, log φ) parameterization.  Per-row residual moments
1116            // therefore make a poor block-cyclic seed: an outlying y near 0/1
1117            // can imply a near-zero φ and pull the coupled mean block onto the
1118            // boundary before the joint likelihood has had a chance to settle.
1119            // Keep the neutral precision seed for this one coupled member; the
1120            // exact Beta cross-Hessian below still drives the joint solve and
1121            // covariance with the coherent two-block likelihood geometry.
1122            0.0
1123        }
1124        DispersionFamilyKind::Gamma => {
1125            let mu = em.exp().max(1e-12);
1126            let e2 = (yi - mu).powi(2).max(1e-8 * mu * mu);
1127            (mu * mu / e2).max(1e-6).ln()
1128        }
1129        DispersionFamilyKind::NegativeBinomial => {
1130            let mu = em.exp().max(1e-12);
1131            let e2 = (yi - mu).powi(2);
1132            let excess = (e2 - mu).max(1e-6 * (mu + mu * mu));
1133            (mu * mu / excess).max(1e-6).ln()
1134        }
1135        DispersionFamilyKind::Tweedie { p } => {
1136            let mu = em.exp().max(1e-12);
1137            let e2 = (yi - mu).powi(2).max(1e-8 * mu.powf(p));
1138            (mu.powf(p) / e2).max(1e-6).ln()
1139        }
1140    };
1141    raw.clamp(LOG_PRECISION_FLOOR, LOG_PRECISION_CEILING)
1142}
1143
1144impl LocationScaleFamilyBuilder for DispersionGlmLocationScaleTermBuilder {
1145    type Family = DispersionGlmLocationScaleFamily;
1146
1147    fn meanspec(&self) -> &TermCollectionSpec {
1148        &self.meanspec
1149    }
1150
1151    fn noisespec(&self) -> &TermCollectionSpec {
1152        &self.noisespec
1153    }
1154
1155    fn noise_penalty_count(&self, noise_design: &TermCollectionDesign) -> usize {
1156        // Mirror the Gaussian/Binomial scale block: a full-span shrinkage
1157        // penalty pins the log-precision nullspace so REML does not optimise
1158        // the dispersion smoothing on a flat surface.
1159        noise_design.penalties.len() + 1
1160    }
1161
1162    fn build_blocks(
1163        &self,
1164        theta: &Array1<f64>,
1165        mean_design: &TermCollectionDesign,
1166        noise_design: &TermCollectionDesign,
1167        mean_beta_hint: Option<Array1<f64>>,
1168        noise_beta_hint: Option<Array1<f64>>,
1169    ) -> Result<Vec<ParameterBlockSpec>, String> {
1170        let layout = GamlssLambdaLayout::two_block(
1171            mean_design.penalties.len(),
1172            self.noise_penalty_count(noise_design),
1173        );
1174        layout.validate_theta_len(theta.len(), "dispersion location-scale")?;
1175
1176        let mut meanspec = build_location_scale_block(
1177            "mu",
1178            mean_design.design.clone(),
1179            self.mean_offset.clone(),
1180            mean_design.penalties_as_penalty_matrix(),
1181            mean_design.nullspace_dims.clone(),
1182            layout.mean_from(theta),
1183            mean_beta_hint,
1184            0,
1185            LOCATION_SCALE_N_OUTPUTS,
1186            "DispersionLocationScale::build_blocks: mu",
1187        )?;
1188
1189        let p_disp = noise_design.design.ncols();
1190        let mut disp_penalties = noise_design.penalties_as_penalty_matrix();
1191        disp_penalties.push(PenaltyMatrix::Dense(identity_penalty(p_disp)));
1192        let mut disp_nullspace = noise_design.nullspace_dims.clone();
1193        disp_nullspace.push(0);
1194        let mut dispspec = build_location_scale_block(
1195            "log_precision",
1196            noise_design.design.clone(),
1197            self.noise_offset.clone(),
1198            disp_penalties,
1199            disp_nullspace,
1200            layout.noise_from(theta),
1201            noise_beta_hint,
1202            1,
1203            LOCATION_SCALE_N_OUTPUTS,
1204            "DispersionLocationScale::build_blocks: log_precision",
1205        )?;
1206
1207        if meanspec.initial_beta.is_none() || dispspec.initial_beta.is_none() {
1208            let (mean_beta0, disp_beta0) = dispersion_location_scale_warm_start(
1209                self.kind,
1210                &self.y,
1211                &self.weights,
1212                &meanspec,
1213                &dispspec,
1214                meanspec.initial_beta.as_ref(),
1215                dispspec.initial_beta.as_ref(),
1216            )?;
1217            if meanspec.initial_beta.is_none() {
1218                meanspec.initial_beta = Some(mean_beta0);
1219            }
1220            if dispspec.initial_beta.is_none() {
1221                dispspec.initial_beta = Some(disp_beta0);
1222            }
1223        }
1224
1225        Ok(vec![meanspec, dispspec])
1226    }
1227
1228    fn build_family(
1229        &self,
1230        mean_design: &TermCollectionDesign,
1231        noise_design: &TermCollectionDesign,
1232    ) -> Self::Family {
1233        // The family stores y/weights/kind directly and does not need the
1234        // designs at construction time, but the row geometry of the offered
1235        // designs is the only cross-check that ties this family back to the
1236        // builder's data — assert it before handing the family to the engine
1237        // so a misaligned design surfaces here rather than downstream in the
1238        // inner solver.
1239        assert_eq!(
1240            mean_design.design.nrows(),
1241            self.y.len(),
1242            "DispersionGlmLocationScale::build_family: mean design row count must match y"
1243        );
1244        assert_eq!(
1245            noise_design.design.nrows(),
1246            self.y.len(),
1247            "DispersionGlmLocationScale::build_family: noise design row count must match y"
1248        );
1249        DispersionGlmLocationScaleFamily {
1250            kind: self.kind,
1251            y: self.y.clone(),
1252            weights: self.weights.clone(),
1253        }
1254    }
1255
1256    fn extract_primary_betas(
1257        &self,
1258        fit: &UnifiedFitResult,
1259    ) -> Result<(Array1<f64>, Array1<f64>), String> {
1260        let mean_beta = fit
1261            .block_states
1262            .get(DispersionGlmLocationScaleFamily::BLOCK_MEAN)
1263            .ok_or_else(|| "missing dispersion mean block state".to_string())?
1264            .beta
1265            .clone();
1266        let disp_beta = fit
1267            .block_states
1268            .get(DispersionGlmLocationScaleFamily::BLOCK_DISP)
1269            .ok_or_else(|| "missing dispersion log-precision block state".to_string())?
1270            .beta
1271            .clone();
1272        Ok((mean_beta, disp_beta))
1273    }
1274
1275    fn build_psiderivative_blocks(
1276        &self,
1277        data: ndarray::ArrayView2<'_, f64>,
1278        meanspec: &TermCollectionSpec,
1279        noisespec: &TermCollectionSpec,
1280        mean_design: &TermCollectionDesign,
1281        noise_design: &TermCollectionDesign,
1282    ) -> Result<Vec<Vec<CustomFamilyBlockPsiDerivative>>, String> {
1283        // The dispersion location-scale families have no closed-form analytic
1284        // spatial psi derivatives, and `fit_dispersion_glm_location_scale_terms`
1285        // disables the κ/ψ joint optimizer before the engine ever asks. If we
1286        // do get called (for example by a future caller that forgets the
1287        // disable), return a real diagnostic rather than a sentinel — emit the
1288        // exact data and design shape that was passed in so the bug is
1289        // diagnosable from the error string alone.
1290        Err(format!(
1291            "dispersion location-scale ({:?}) does not implement analytic spatial \
1292             psi derivatives; the κ/ψ joint optimizer must be disabled before \
1293             this builder is consulted. Called with data {n_rows}×{n_cols}, mean \
1294             spec (linear={mean_lin}, random={mean_re}, smooth={mean_sm}), noise \
1295             spec (linear={noise_lin}, random={noise_re}, smooth={noise_sm}), \
1296             mean design cols={mean_p}, noise design cols={noise_p}",
1297            self.kind,
1298            n_rows = data.nrows(),
1299            n_cols = data.ncols(),
1300            mean_lin = meanspec.linear_terms.len(),
1301            mean_re = meanspec.random_effect_terms.len(),
1302            mean_sm = meanspec.smooth_terms.len(),
1303            noise_lin = noisespec.linear_terms.len(),
1304            noise_re = noisespec.random_effect_terms.len(),
1305            noise_sm = noisespec.smooth_terms.len(),
1306            mean_p = mean_design.design.ncols(),
1307            noise_p = noise_design.design.ncols(),
1308        ))
1309    }
1310}
1311
1312/// Fit a dispersion-channel GAMLSS location-scale model (#913). All four
1313/// genuine-dispersion mean families share this single entry; the per-family
1314/// likelihood lives in [`dispersion_row_kernel`].
1315pub fn fit_dispersion_glm_location_scale_terms(
1316    data: ndarray::ArrayView2<'_, f64>,
1317    spec: DispersionGlmLocationScaleTermSpec,
1318    options: &BlockwiseFitOptions,
1319    kappa_options: &SpatialLengthScaleOptimizationOptions,
1320) -> Result<BlockwiseTermFitResult, String> {
1321    if let DispersionFamilyKind::Tweedie { p } = spec.kind {
1322        if !(p.is_finite() && p > 1.0 && p < 2.0) {
1323            return Err(format!(
1324                "Tweedie location-scale requires a variance power strictly in (1, 2); got p={p}"
1325            ));
1326        }
1327    }
1328    // The κ/ψ anisotropic-kernel joint optimizer needs analytic psi
1329    // derivatives this family does not provide; disable it so the engine runs
1330    // the full ρ REML directly via `fit_custom_family` (1-D and tensor smooth
1331    // penalties λ are still REML-selected).
1332    let mut kappa = kappa_options.clone();
1333    kappa.enabled = false;
1334    // A dispersion location-scale model is an inherently *predictable* model:
1335    // posterior-mean prediction (the response-scale predict path the CLI/FFI
1336    // drive) needs the joint `(β_μ, β_d)` posterior covariance, and so does the
1337    // reported total EDF / coefficient SEs. The block-diagonal joint Hessian is
1338    // always assembled here (`exact_newton_joint_hessian_with_specs` →
1339    // `compute_joint_covariance`, which for this family's `RidgedQuadraticReml`
1340    // outer objective uses the never-erroring SPD-retry → positive-part
1341    // pseudo-inverse), so we can — and must — request the covariance
1342    // unconditionally rather than leaving `covariance_conditional = None`
1343    // whenever the outer optimizer happens to *converge* (the only family-
1344    // independent reason NB sometimes populated covariance was that it escalated
1345    // into the never-fail posterior-sampling rung, while a cleanly-converged
1346    // Gamma/Tweedie fit took the `!options.compute_covariance ⇒ None` early
1347    // return and stranded its covariance/EDF — gam#1119). Forcing the flag here
1348    // makes all four genuine-dispersion mean families assemble the joint
1349    // covariance + EDF deterministically, exactly as a predictable model
1350    // requires.
1351    let mut options = options.clone();
1352    options.compute_covariance = true;
1353    fit_location_scale_terms(
1354        data,
1355        DispersionGlmLocationScaleTermBuilder {
1356            kind: spec.kind,
1357            y: spec.y,
1358            weights: spec.weights,
1359            meanspec: spec.meanspec,
1360            noisespec: spec.log_dispspec,
1361            mean_offset: spec.mean_offset,
1362            noise_offset: spec.log_disp_offset,
1363        },
1364        &options,
1365        &kappa,
1366    )
1367}
1368
#[cfg(test)]
mod tests {
    use super::*;
    use super::test_support::{dispersion_gamma_nll_order2, dispersion_nb_nll_order2};
    use crate::gamlss::test_support::dispersion_tweedie_nll_generic;

    /// Closed-form Beta `(μ, φ)` cross-information oracle used by the tests below:
    /// `φ · (μ·ψ'(μφ) − (1−μ)·ψ'((1−μ)φ))`, where `ψ'` is the trigamma function
    /// evaluated through `gam_math::jet_tower::trigamma_derivative_stack`.
    pub(crate) fn beta_fisher_cross_info_mu_phi(mu: f64, phi: f64) -> f64 {
        let a = mu * phi;
        let b = (1.0 - mu) * phi;
        phi * (mu * gam_math::jet_tower::trigamma_derivative_stack(a)[0]
            - (1.0 - mu) * gam_math::jet_tower::trigamma_derivative_stack(b)[0])
    }

    /// Asserts `|got − want| <= tol`, reporting both values and the absolute
    /// difference under `label` on failure.
    pub(crate) fn assert_close(label: &str, got: f64, want: f64, tol: f64) {
        assert!(
            (got - want).abs() <= tol,
            "{label}: got {got:.12e}, want {want:.12e}, |diff|={:.3e}",
            (got - want).abs()
        );
    }

    /// The Beta row tower's mixed `(μ, φ)` Hessian entry, the helper above, and
    /// the η-scale observed cross weight must all agree with an independent
    /// closed-form cross-information value at a visibly non-orthogonal point.
    #[test]
    pub(crate) fn beta_tower_mixed_channel_matches_cross_information_formula() {
        let mu = 0.1;
        let phi = 10.0;
        let a = mu * phi;
        let b = (1.0 - mu) * phi;
        // `trigamma_a` below is the closed form ψ'(1) = π²/6, an oracle that does
        // not go through the `jet_tower` trigamma implementation. It is only
        // valid while μφ = 1, so guard the fixture against silent edits.
        assert_close("fixture requires mu * phi == 1", a, 1.0, 1e-15);
        let digamma_a = gam_math::jet_tower::digamma_derivative_stack(a)[0];
        let digamma_b = gam_math::jet_tower::digamma_derivative_stack(b)[0];
        // y = sigmoid(ψ(a) − ψ(b)); presumably the point where the μ-score
        // vanishes, so observed and expected mixed curvature coincide.
        let score_neutral_y = 1.0 / (1.0 + (-(digamma_a - digamma_b)).exp());

        let tower = dispersion_beta_nll_order2(score_neutral_y, mu, phi, 1.0);
        let trigamma_a = std::f64::consts::PI * std::f64::consts::PI / 6.0;
        let trigamma_b = gam_math::jet_tower::trigamma_derivative_stack(b)[0];
        let analytic = phi * (mu * trigamma_a - (1.0 - mu) * trigamma_b);
        let helper = beta_fisher_cross_info_mu_phi(mu, phi);

        assert!(
            analytic > 0.58,
            "audit example should have visibly nonzero cross information, got {analytic}"
        );
        assert_close("helper cross information", helper, analytic, 1e-12);
        assert_close("tower mixed channel", tower.h()[0][1], analytic, 1e-8);

        // The η-scale weight carries the extra chain-rule factor q·φ with
        // q = μ(1−μ).
        let q = mu * (1.0 - mu);
        let eta_cross = beta_observed_cross_weight_eta(score_neutral_y, mu, phi, 1.0);
        assert_close(
            "eta-scale cross weight",
            eta_cross,
            q * phi * analytic,
            1e-8,
        );
    }

    /// #932 oracle: the production `Order2<2>` evaluation of each dispersion
    /// row NLL must reproduce, channel-for-channel (value/grad/Hessian), the
    /// dense `Tower4<2>` evaluation of the same row expression.
    #[test]
    pub(crate) fn order2_matches_dense_tower_all_channels() {
        use gam_math::jet_scalar::{JetScalar, Order2};
        use gam_math::jet_tower::Tower4;

        // Mixed absolute/relative tolerance of 1e-9 on every channel.
        fn check_o2_vs_tower4(label: &str, o2: Order2<2>, t4: Tower4<2>) {
            let band = |a: f64, b: f64| 1e-9 + 1e-9 * a.abs().max(b.abs());
            assert!(
                (o2.value() - t4.v).abs() <= band(o2.value(), t4.v),
                "{label} value: {} vs {}",
                o2.value(),
                t4.v
            );
            for a in 0..2 {
                assert!(
                    (o2.g()[a] - t4.g[a]).abs() <= band(o2.g()[a], t4.g[a]),
                    "{label} grad[{a}]: {} vs {}",
                    o2.g()[a],
                    t4.g[a]
                );
                for b in 0..2 {
                    assert!(
                        (o2.h()[a][b] - t4.h[a][b]).abs() <= band(o2.h()[a][b], t4.h[a][b]),
                        "{label} hess[{a}][{b}]: {} vs {}",
                        o2.h()[a][b],
                        t4.h[a][b]
                    );
                }
            }
        }

        let wi = 1.7_f64;
        // NB2: (μ, θ).
        for &(yi, mu, theta) in &[(0.0, 1.2, 3.0), (4.0, 2.5, 0.7), (10.0, 0.6, 5.0)] {
            check_o2_vs_tower4(
                "nb",
                dispersion_nb_nll_order2(yi, mu, theta, wi),
                test_support::dispersion_nb_nll_generic::<Tower4<2>>(yi, mu, theta, wi),
            );
        }
        // Gamma: (μ, ν).
        for &(yi, mu, nu) in &[
            (0.5_f64, 1.1_f64, 2.0_f64),
            (3.0, 4.0, 0.9),
            (1.0, 0.3, 6.0),
        ] {
            let y_pos = yi.max(1e-300);
            check_o2_vs_tower4(
                "gamma",
                dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi),
                test_support::dispersion_gamma_nll_generic::<Tower4<2>>(yi, y_pos, mu, nu, wi),
            );
        }
        // Beta: (μ, φ).
        for &(yi, mu, phi) in &[(0.3, 0.4, 5.0), (0.9, 0.6, 12.0), (0.01, 0.2, 3.0)] {
            check_o2_vs_tower4(
                "beta",
                dispersion_beta_nll_order2(yi, mu, phi, wi),
                test_support::dispersion_beta_nll_generic::<Tower4<2>>(yi, mu, phi, wi),
            );
        }
        // Tweedie: (η_μ, η_d), both density branches (y = 0 and y > 0).
        for &(yi, eta_mu, eta_d, p) in &[
            (0.0, 0.4, -0.3, 1.5),
            (2.5, -0.2, 0.5, 1.3),
            (0.0, 1.0, 0.1, 1.7),
            (5.0, 0.7, -0.6, 1.6),
        ] {
            check_o2_vs_tower4(
                "tweedie",
                dispersion_tweedie_nll_generic::<Order2<2>>(yi, eta_mu, eta_d, p, wi),
                dispersion_tweedie_nll_generic::<Tower4<2>>(yi, eta_mu, eta_d, p, wi),
            );
        }
    }

    /// #1591 prune oracle: the pruned single-axis (`K=1`) dispersion towers
    /// reproduce, `to_bits`-exactly, the CONSUMED channels (`value`, dispersion-
    /// axis `g`/`h`) of the full `Order2<2>` towers, and each value-only
    /// `dispersion_*_neg_loglik` path equals `-tower.value()` bit-for-bit.
    /// The sweep draws 600 randomized rows each for NB, Gamma and Beta (Beta is
    /// value-only) and 1200 for Tweedie (600 per density branch), i.e. 3000
    /// randomized rows in total, plus one fixed Tweedie row at the η-clamp
    /// boundary per iteration. This is the bit-identity guarantee that the
    /// K-prune changes no observable float.
    #[test]
    pub(crate) fn pruned_disp_towers_bit_identical_to_full_order2() {
        use gam_math::jet_scalar::{JetScalar, Order2};

        // Deterministic 64-bit LCG (Knuth MMIX constants) so the sweep is
        // reproducible without an rng dep; the top 53 bits give a uniform
        // draw in [0, 1).
        let mut state: u64 = 0x9E3779B97F4A7C15;
        let mut next = || {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            ((state >> 11) as f64) / ((1u64 << 53) as f64)
        };
        let bits = |x: f64| x.to_bits();

        // 600 iterations → 3000 randomized rows (NB, Gamma, Beta, Tweedie ×2)
        // plus 600 fixed Tweedie clamp-edge rows.
        let n_per = 600;
        for _ in 0..n_per {
            let wi = 0.25 + 3.0 * next();
            let yi_count = (next() * 12.0).floor();

            // NB: full O2<2> seeds (μ, θ); pruned seeds θ only.
            {
                let mu = (0.05 + 4.0 * next()).max(1e-300);
                let theta = (0.05 + 6.0 * next()).max(1e-12);
                let full = dispersion_nb_nll_order2(yi_count, mu, theta, wi);
                let prn = dispersion_nb_disp_order2(yi_count, mu, theta, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "nb value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "nb grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "nb hess");
                // value-only path == -tower.value(), bit-for-bit.
                assert_eq!(
                    bits(dispersion_nb_neg_loglik(yi_count, mu, theta, wi)),
                    bits(-prn.value()),
                    "nb value-only"
                );
            }
            // Gamma: seeds (μ, ν) / ν.
            {
                let mu = (0.05 + 4.0 * next()).max(1e-300);
                let nu = (0.05 + 6.0 * next()).max(1e-12);
                let yi = 0.01 + 8.0 * next();
                let y_pos = yi.max(1e-300);
                let full = dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi);
                let prn = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "gamma value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "gamma grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "gamma hess");
                assert_eq!(
                    bits(dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)),
                    bits(-prn.value()),
                    "gamma value-only"
                );
            }
            // Beta value-only path vs full K=2 tower value (no pruned Beta tower
            // is exercised here).
            {
                let mu = (1e-6 + (1.0 - 2e-6) * next()).clamp(1e-12, 1.0 - 1e-12);
                let phi = (0.05 + 20.0 * next()).max(1e-12);
                let yi = next();
                let full = dispersion_beta_nll_order2(yi, mu, phi, wi);
                assert_eq!(
                    bits(dispersion_beta_neg_loglik(yi, mu, phi, wi)),
                    bits(-full.value()),
                    "beta value-only"
                );
            }
            // Tweedie: seeds (η_μ, η_d) / η_d, both density branches & clamp edge.
            // The third row starts outside the clamp so that, after clamping, it
            // sits exactly on ±DISPERSION_ETA_CLAMP.
            for &(yi, eta_mu, eta_d, p) in &[
                (0.0_f64, -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
                (0.01 + 9.0 * next(), -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
                (3.0, -DISPERSION_ETA_CLAMP - 5.0, DISPERSION_ETA_CLAMP + 5.0, 1.5),
            ] {
                let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
                let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
                let full = dispersion_tweedie_nll_generic::<Order2<2>>(yi, em, ed, p, wi);
                let prn = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "tweedie value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "tweedie grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "tweedie hess");
                assert_eq!(
                    bits(dispersion_tweedie_neg_loglik(yi, em, ed, p, wi)),
                    bits(-prn.value()),
                    "tweedie value-only"
                );
            }
        }
    }

    /// NB, Gamma and Tweedie must report a zero row cross weight from
    /// `dispersion_row_cross_weight` (to within 1e-12) at a fixed sample row.
    #[test]
    pub(crate) fn orthogonal_dispersion_families_report_zero_cross_weight() {
        let cases = [
            DispersionFamilyKind::NegativeBinomial,
            DispersionFamilyKind::Gamma,
            DispersionFamilyKind::Tweedie { p: 1.5 },
        ];
        for kind in cases {
            let got = dispersion_row_cross_weight(kind, 1.25, 0.2, -0.3, 2.0);
            assert_close(kind.family_tag(), got, 0.0, 1e-12);
        }
    }
}