Skip to main content

gam_models/gamlss/
dispersion_family.rs

1//! #913: dispersion-channel GAMLSS location-scale families.
2//!
3//! Extracted from `gamlss.rs` (issue #780); this module now owns the
4//! dispersion-channel joint-curvature corrections.
5
6use super::weighted_design_products::{mirror_upper_to_lower, xt_diag_x_design, xt_diag_y_design};
// `value()` on the jet towers (e.g. `Order2<2>::value()`) is a `JetScalar`
// trait method; the `gam_math::jet_scalar::JetScalar` import below brings the
// trait into scope so the dispersion row-NLL value reads (`-tower.value()`)
// resolve (E0599 fix).
9use super::{
10    BlockwiseTermFitResult, GamlssLambdaLayout, LOCATION_SCALE_N_OUTPUTS,
11    LocationScaleFamilyBuilder, build_location_scale_block, fit_location_scale_terms,
12    identity_penalty, solve_penalizedweighted_projection,
13};
14use crate::block_layout::block_count::validate_block_count;
15use crate::custom_family::{
16    BlockWorkingSet, BlockwiseFitOptions, CustomFamily, CustomFamilyBlockPsiDerivative,
17    FamilyEvaluation, ParameterBlockSpec, ParameterBlockState, PenaltyMatrix,
18};
19use crate::gamlss::GamlssError;
20use crate::model_types::UnifiedFitResult;
21use gam_linalg::matrix::LinearOperator;
22use gam_math::jet_scalar::JetScalar;
23use gam_terms::smooth::{
24    SpatialLengthScaleOptimizationOptions, TermCollectionDesign, TermCollectionSpec,
25};
26use ndarray::{Array1, Array2, s};
27use statrs::function::gamma::ln_gamma;
28
29// ============================================================================
30// #913: dispersion-channel GAMLSS location-scale families.
31//
32// `noise_formula` (a second linear predictor on the dispersion channel) was
33// wired only for Gaussian/Binomial location-scale and the survival families.
34// The genuine-dispersion mean families — NegativeBinomial, Gamma, Beta and
35// Tweedie — were mean-only with a single scalar dispersion. This module adds a
36// SINGLE generic two-block family that routes all four through the existing
37// blockwise REML engine and the shared `LocationScaleFamilyBuilder` /
38// `fit_location_scale_terms` plumbing, so the κ-coordinate assembly, warm
39// start, shrinkage-penalised scale block and result extraction are reused
40// verbatim. A family is added by supplying only its per-row log-likelihood and
41// the (mean, log-precision) working sets — everything else is shared.
42//
43// Block layout: block 0 = mean predictor (η_μ, log link for NB/Gamma/Tweedie,
44// logit for Beta); block 1 = log-precision predictor (η_d). The dispersion
45// channel models log(precision) uniformly — `θ` for NegativeBinomial, the
// shape `ν` for Gamma, `φ` for Beta, and `1/φ` for Tweedie — so a larger η_d
// always means *less* dispersion. This is the sign-flipped counterpart of the
// Gaussian/Binomial log-scale channel, where a *smaller* η_logσ means a
// tighter fit. With no `noise_formula` the log-precision block is a single
// intercept and the fit reduces to the scalar-dispersion model.
51//
52// NB2 with `(μ, θ)` and the exponential-dispersion members here with
53// `(μ, φ)` are Fisher-orthogonal in their standard mean/dispersion
54// parameterizations: Gamma uses shape `ν = 1/φ`, and Tweedie models
55// `log(1/φ)`, so those precision-channel transforms preserve zero expected
56// mean/dispersion cross information. Beta is the exception in this module's
57// mean/precision parameterization. For `Beta(μφ, (1−μ)φ)`,
58//
59//   I_{μ,φ} = φ · (μ ψ'(μφ) − (1−μ) ψ'((1−μ)φ)),
60//
61// so in predictor coordinates `(η_μ = logit μ, η_φ = log φ)` the Fisher cross
62// block is
63//
64//   I_{η_μ,η_φ} = μ(1−μ) φ² · (μ ψ'(μφ) − (1−μ) ψ'((1−μ)φ)),
65//
66// which is generically nonzero. Block-cyclic Fisher-scoring IRLS is still a
67// valid block coordinate solve for the point estimate, but joint-curvature
68// consumers (`log|H|`, coefficient covariance, posterior draws) must receive
69// Beta's off-diagonal coefficient block. Smoothing-parameter selection still
70// runs through the engine's first-order (gradient-only) outer path: the family
71// declines the dense outer Hessian capability because its working weights
72// couple the two blocks (`W_μ` depends on the precision and vice-versa), which
73// the block-local diagonal-drift hook cannot represent exactly.
74// ============================================================================
75
/// Mean families with a genuine dispersion parameter whose precision
/// (overdispersion) channel may be driven by its own `noise_formula` linear
/// predictor (issue #913).
///
/// Each variant's documentation names the precision-like quantity whose
/// logarithm is modelled on the second channel.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DispersionFamilyKind {
    /// Negative binomial (NB2), `Var = μ + μ²/θ`. Precision channel: `log θ`.
    NegativeBinomial,
    /// Gamma, `Var = μ²/ν`. Precision channel: `log ν` (the shape).
    Gamma,
    /// `Beta(μφ, (1−μ)φ)` with a logit mean link. Precision channel: `log φ`.
    Beta,
    /// Tweedie compound Poisson–Gamma, `Var = φ μ^p` with the power `p` held
    /// fixed. Precision channel: `log(1/φ)`.
    ///
    /// The per-row density is the saddlepoint (Nelder–Pregibon) approximation
    /// for `y > 0` and the exact point mass at `y = 0` — the standard
    /// tractable Tweedie ML surface. An exact-series φ-derivative is the
    /// remaining hard sub-item of #913.
    Tweedie {
        /// Fixed variance power `p` in `Var = φ μ^p`.
        p: f64,
    },
}
95
96impl DispersionFamilyKind {
97    pub const fn family_tag(self) -> &'static str {
98        match self {
99            DispersionFamilyKind::NegativeBinomial => FAMILY_NEGBIN_LOCATION_SCALE,
100            DispersionFamilyKind::Gamma => FAMILY_GAMMA_LOCATION_SCALE,
101            DispersionFamilyKind::Beta => FAMILY_BETA_LOCATION_SCALE,
102            DispersionFamilyKind::Tweedie { .. } => FAMILY_TWEEDIE_LOCATION_SCALE,
103        }
104    }
105
106    /// The mean link is logit for Beta (a probability mean) and log otherwise.
107    pub(crate) const fn mean_is_logit(self) -> bool {
108        matches!(self, DispersionFamilyKind::Beta)
109    }
110
111    /// The mean inverse link this dispersion family fits on: log for
112    /// NegativeBinomial / Gamma / Tweedie, logit for Beta. Single source of
113    /// truth shared by the CLI and FFI save paths so the persisted
114    /// `base_link` never diverges from the fitted channel.
115    pub fn base_link(self) -> gam_problem::InverseLink {
116        use gam_problem::{InverseLink, StandardLink};
117        if self.mean_is_logit() {
118            InverseLink::Standard(StandardLink::Logit)
119        } else {
120            InverseLink::Standard(StandardLink::Log)
121        }
122    }
123
124    /// The family's canonical [`LikelihoodSpec`] (mean response × mean link).
125    /// The overdispersion parameter is estimated by the log-precision channel,
126    /// so the response-family placeholder parameters (`phi`, `theta`) mirror
127    /// the [`resolve_family`](crate::fit_orchestration::materialize::resolve_family) defaults
128    /// and are not consumed as fixed values at predict time. This is the single
129    /// source of truth for the persisted location-scale likelihood so the CLI
130    /// and FFI save paths cannot diverge.
131    pub fn likelihood_spec(self) -> gam_problem::LikelihoodSpec {
132        use gam_problem::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
133        let response = match self {
134            DispersionFamilyKind::NegativeBinomial => ResponseFamily::NegativeBinomial {
135                theta: 1.0,
136                theta_fixed: false,
137            },
138            DispersionFamilyKind::Gamma => ResponseFamily::Gamma,
139            DispersionFamilyKind::Beta => ResponseFamily::Beta { phi: 1.0 },
140            DispersionFamilyKind::Tweedie { p } => ResponseFamily::Tweedie { p },
141        };
142        let link = if self.mean_is_logit() {
143            InverseLink::Standard(StandardLink::Logit)
144        } else {
145            InverseLink::Standard(StandardLink::Log)
146        };
147        LikelihoodSpec::new(response, link)
148    }
149}
150
// Family tags returned by `DispersionFamilyKind::family_tag`, one per variant.

/// Tag for the Beta dispersion location-scale family.
pub const FAMILY_BETA_LOCATION_SCALE: &str = "beta-location-scale";
/// Tag for the Gamma dispersion location-scale family.
pub const FAMILY_GAMMA_LOCATION_SCALE: &str = "gamma-location-scale";
/// Tag for the negative-binomial (NB2) dispersion location-scale family.
pub const FAMILY_NEGBIN_LOCATION_SCALE: &str = "negbin-location-scale";
/// Tag for the Tweedie dispersion location-scale family.
pub const FAMILY_TWEEDIE_LOCATION_SCALE: &str = "tweedie-location-scale";
155
156/// `η` magnitude clamp shared by both channels (mirrors PIRLS `ETA_CLAMP`):
157/// keeps `exp(η)` and the logit jet away from overflow while staying in the
158/// smooth interior of every link.
159pub(super) const DISPERSION_ETA_CLAMP: f64 = 30.0;
160/// Floor for a per-row IRLS working weight / curvature so the block normal
161/// equations stay positive-definite. The working *response* always carries the
162/// exact score, so the stationary point (penalised score = 0) is independent
163/// of this floor; it only conditions the inner solve.
164pub(super) const DISPERSION_MIN_CURVATURE: f64 = 1e-12;
165
166/// Row count above which the per-row dispersion-kernel map fans out across
167/// rayon workers (only when not already running on a worker, to avoid nested
168/// oversubscription). Below it the serial map beats the fork/join overhead.
169/// Mirrors the row-chunk guard in
170/// [`row_coeff_operator`](super::gaussian::row_coeff_operator).
171const DISPERSION_PARALLEL_ROW_THRESHOLD: usize = 1024;
172
173/// Per-row working quantities for both channels at the current `(η_μ, η_d)`.
174pub(super) struct DispersionRowKernel {
175    pub(super) loglik: f64,
176    pub(super) mean_weight: f64,
177    pub(super) mean_response: f64,
178    pub(super) disp_weight: f64,
179    pub(super) disp_response: f64,
180}
181
#[cfg(test)]
mod test_support {
    use super::*;

    /// Reference NB2 row NLL for tests, generic over any [`JetScalar<2>`].
    ///
    /// Axis 0 is seeded with `μ`, axis 1 with `θ`. The returned jet is the row
    /// log-likelihood scaled by `-wi`.
    #[inline]
    pub(super) fn dispersion_nb_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let theta = S::variable(theta_value, 1);
        let theta_plus_mu = theta.add(&mu);

        // Summands of the NB2 log-density, listed in accumulation order:
        //   lnΓ(θ+y) − lnΓ(θ) − lnΓ(y+1) + θ ln θ − θ ln(θ+μ) + y ln μ − y ln(θ+μ)
        let lg_theta_plus_y = theta.add(&S::constant(yi)).ln_gamma();
        let lg_theta = theta.ln_gamma();
        let lg_y_plus_one = S::constant(ln_gamma(yi + 1.0));
        let theta_ln_theta = theta.mul(&theta.ln());
        let theta_ln_tpm = theta.mul(&theta_plus_mu.ln());
        let y_ln_mu = mu.ln().scale(yi);
        let y_ln_tpm = theta_plus_mu.ln().scale(yi);

        let loglik = lg_theta_plus_y
            .sub(&lg_theta)
            .sub(&lg_y_plus_one)
            .add(&theta_ln_theta)
            .sub(&theta_ln_tpm)
            .add(&y_ln_mu)
            .sub(&y_ln_tpm);
        loglik.scale(-wi)
    }

    /// Reference Gamma row NLL for tests, generic over any [`JetScalar<2>`].
    ///
    /// Axis 0 is seeded with `μ`, axis 1 with the shape `ν`. `y_pos` is the
    /// response value whose logarithm is taken; `yi` enters the `y/μ` term.
    #[inline]
    pub(super) fn dispersion_gamma_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let nu = S::variable(nu_value, 1);

        // ν ln ν − ν ln μ − lnΓ(ν) + (ν−1) ln y − ν · ((1/μ) · y)
        let nu_ln_nu = nu.mul(&nu.ln());
        let nu_ln_mu = nu.mul(&mu.ln());
        let lg_nu = nu.ln_gamma();
        let shape_minus_one_ln_y = nu.sub(&S::constant(1.0)).scale(y_pos.ln());
        let nu_y_over_mu = nu.mul(&mu.recip().scale(yi));

        let loglik = nu_ln_nu
            .sub(&nu_ln_mu)
            .sub(&lg_nu)
            .add(&shape_minus_one_ln_y)
            .sub(&nu_y_over_mu);
        loglik.scale(-wi)
    }

    /// Reference Beta row NLL for tests, generic over any [`JetScalar<2>`].
    ///
    /// Axis 0 is seeded with `μ`, axis 1 with `φ`. The response is clamped to
    /// `[1e-12, 1 − 1e-12]` before its logarithms are taken.
    #[inline]
    pub(super) fn dispersion_beta_nll_generic<S: gam_math::jet_scalar::JetScalar<2>>(
        yi: f64,
        mu_value: f64,
        phi_value: f64,
        wi: f64,
    ) -> S {
        let mu = S::variable(mu_value, 0);
        let phi = S::variable(phi_value, 1);
        let one_minus_mu = S::constant(1.0).sub(&mu);
        let yc = yi.clamp(1e-12, 1.0 - 1e-12);

        // Beta shape parameters a = μφ and b = (1−μ)φ.
        let a = mu.mul(&phi);
        let b = one_minus_mu.mul(&phi);

        // lnΓ(φ) − lnΓ(a) − lnΓ(b) + (a−1) ln y + (b−1) ln(1−y)
        let lg_phi = phi.ln_gamma();
        let lg_a = a.ln_gamma();
        let lg_b = b.ln_gamma();
        let a_term = a.sub(&S::constant(1.0)).scale(yc.ln());
        let b_term = b.sub(&S::constant(1.0)).scale((1.0 - yc).ln());

        let loglik = lg_phi.sub(&lg_a).sub(&lg_b).add(&a_term).add(&b_term);
        loglik.scale(-wi)
    }

    /// #1591 jet-prune oracle: the full `Order2<2>` (value/grad/Hessian) NB2
    /// row NLL, with `μ` on axis 0 and `θ` on axis 1.
    ///
    /// Production no longer reads the mean-axis derivative channels of this
    /// tower: the NB mean block is Fisher-orthogonal and written out by hand
    /// in [`dispersion_row_kernel`], so the hot path relies on the pruned
    /// single-axis [`dispersion_nb_disp_order2`] instead. This `K=2` form is
    /// retained only as the dense-`Tower4<2>` oracle pin
    /// (`order2_matches_dense_tower_all_channels`).
    #[inline]
    pub(super) fn dispersion_nb_nll_order2(
        yi: f64,
        mu_value: f64,
        theta_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type Jet2 = gam_math::jet_scalar::Order2<2>;

        let mu = Jet2::variable(mu_value, 0);
        let theta = Jet2::variable(theta_value, 1);
        let theta_plus_mu = theta.add(&mu);
        let theta_plus_y = theta.add(&Jet2::constant(yi));

        // Same summands, in the same accumulation order, as the generic form.
        let lg_theta_plus_y = order2_ln_gamma(&theta_plus_y);
        let lg_theta = order2_ln_gamma(&theta);
        let lg_y_plus_one = Jet2::constant(ln_gamma(yi + 1.0));
        let theta_ln_theta = theta.mul(&theta.ln());
        let theta_ln_tpm = theta.mul(&theta_plus_mu.ln());
        let y_ln_mu = mu.ln().scale(yi);
        let y_ln_tpm = theta_plus_mu.ln().scale(yi);

        let loglik = lg_theta_plus_y
            .sub(&lg_theta)
            .sub(&lg_y_plus_one)
            .add(&theta_ln_theta)
            .sub(&theta_ln_tpm)
            .add(&y_ln_mu)
            .sub(&y_ln_tpm);
        loglik.scale(-wi)
    }

    /// #1591 jet-prune oracle: the full `Order2<2>` Gamma row NLL, with `μ` on
    /// axis 0 and `ν` on axis 1.
    ///
    /// As for NB, production does not use the mean axis (it is hand-written
    /// and Fisher-orthogonal); the hot path uses the single-axis
    /// [`dispersion_gamma_disp_order2`]. Retained only as the dense-tower
    /// oracle pin.
    #[inline]
    pub(super) fn dispersion_gamma_nll_order2(
        yi: f64,
        y_pos: f64,
        mu_value: f64,
        nu_value: f64,
        wi: f64,
    ) -> gam_math::jet_scalar::Order2<2> {
        type Jet2 = gam_math::jet_scalar::Order2<2>;

        let mu = Jet2::variable(mu_value, 0);
        let nu = Jet2::variable(nu_value, 1);

        let nu_ln_nu = nu.mul(&nu.ln());
        let nu_ln_mu = nu.mul(&mu.ln());
        let lg_nu = order2_ln_gamma(&nu);
        let shape_minus_one_ln_y = nu.sub(&Jet2::constant(1.0)).scale(y_pos.ln());
        let nu_y_over_mu = nu.mul(&mu.recip().scale(yi));

        let loglik = nu_ln_nu
            .sub(&nu_ln_mu)
            .sub(&lg_nu)
            .add(&shape_minus_one_ln_y)
            .sub(&nu_y_over_mu);
        loglik.scale(-wi)
    }
}
316
317/// Production `Order2<2>` Beta row NLL (value/grad/Hessian hot path; the cross
318/// channel `h()[0][1]` feeds the Beta observed cross weight).
319#[inline]
320pub(crate) fn dispersion_beta_nll_order2(
321    yi: f64,
322    mu_value: f64,
323    phi_value: f64,
324    wi: f64,
325) -> gam_math::jet_scalar::Order2<2> {
326    type O2 = gam_math::jet_scalar::Order2<2>;
327
328    let mu = O2::variable(mu_value, 0);
329    let phi = O2::variable(phi_value, 1);
330    let one_minus_mu = O2::constant(1.0).sub(&mu);
331    let yc = yi.clamp(1e-12, 1.0 - 1e-12);
332    let a = mu.mul(&phi);
333    let b = one_minus_mu.mul(&phi);
334    let loglik = order2_ln_gamma(&phi)
335        .sub(&order2_ln_gamma(&a))
336        .sub(&order2_ln_gamma(&b))
337        .add(&a.sub(&O2::constant(1.0)).scale(yc.ln()))
338        .add(&b.sub(&O2::constant(1.0)).scale((1.0 - yc).ln()));
339    loglik.scale(-wi)
340}
341
342#[inline]
343fn order2_ln_gamma<const K: usize>(
344    x: &gam_math::jet_scalar::Order2<K>,
345) -> gam_math::jet_scalar::Order2<K> {
346    gam_math::jet_scalar::Order2(
347        x.0.compose_unary(gam_math::jet_tower::ln_gamma_derivative_stack_order2(x.0.v)),
348    )
349}
350
351// ============================================================================
352// #1591 jet-prune: single-axis (`K=1`) dispersion-channel towers.
353//
354// For NegativeBinomial / Gamma / Tweedie the production row kernel consumes ONLY
355// the dispersion-axis derivatives (`g[disp]`, `h[disp][disp]`) and the value;
356// the mean block is Fisher-orthogonal and assembled in closed form. Seeding the
357// mean as a CONSTANT and the dispersion parameter as the SOLE jet variable
358// therefore yields a tower whose `(value, g[0], h[0][0])` are `to_bits`-
359// identical to the consumed `(value, g[1], h[1][1])` of the old `Order2<2>`
360// tower — the mean seed only ever populated the now-discarded `g[mean]` /
361// `h[mean][·]` channels (Leibniz/Faà-di-Bruno never read the dispersion-axis
362// channels off the mean seed). Collapsing `K=2 → K=1` quarters the Hessian
363// tensor (1 entry vs 4) and halves the gradient, with no change to any consumed
364// float bit. The `ln_gamma` derivative stacks are unchanged (the irreducible
365// transcendental cost), so this trims the rational composition, not the special
366// functions.
367// ============================================================================
368
369/// Order-≤1 `ln Γ` compose: only the value (`d[0] = lnΓ`) and first-derivative
370/// (`d[1] = ψ`) stack slots are consumed by an [`Order1`] jet, so we evaluate
371/// ONLY `lnΓ` and `ψ` — never `ψ′` (trigamma). This is the value/gradient
372/// twin of [`order2_ln_gamma`]; its `(value, g)` channels are bit-identical to
373/// that function's order-≤1 channels because [`Order1`] runs the same Leibniz /
374/// Faà-di-Bruno value+gradient float ops as [`Order2`] (doc on `Order1`).
375#[inline]
376fn order1_ln_gamma<const K: usize>(
377    x: &gam_math::jet_scalar::Order1<K>,
378) -> gam_math::jet_scalar::Order1<K> {
379    x.compose_unary([ln_gamma(x.v), gam_math::jet_tower::digamma(x.v), 0.0, 0.0, 0.0])
380}
381
382/// Value+gradient-only NB2 dispersion tower: `θ` is the sole jet variable
383/// (axis 0), `μ` a constant. `value`/`g[0]` reproduce the consumed
384/// `value`/`g[0]` of [`dispersion_nb_disp_order2`] bit-for-bit, but as an
385/// [`Order1`] jet it never evaluates the trigamma (`ψ′`) that the discarded
386/// observed-Hessian channel would need. The NB2 row kernel uses the EXPECTED
387/// (Fisher) information in θ, not the observed Hessian, so `h[0][0]` was pure
388/// discarded work — dropping it here removes two `ψ′` evaluations (at `θ+y`
389/// and `θ`) per NB2 row on top of the tensor-shrink.
390#[inline]
391pub(crate) fn dispersion_nb_disp_order1(
392    yi: f64,
393    mu_value: f64,
394    theta_value: f64,
395    wi: f64,
396) -> gam_math::jet_scalar::Order1<1> {
397    type O1 = gam_math::jet_scalar::Order1<1>;
398
399    let mu = O1::constant(mu_value);
400    let theta = O1::variable(theta_value, 0);
401    let tpm = theta.add(&mu);
402    let theta_plus_y = theta.add(&O1::constant(yi));
403    let loglik = order1_ln_gamma(&theta_plus_y)
404        .sub(&order1_ln_gamma(&theta))
405        .sub(&O1::constant(ln_gamma(yi + 1.0)))
406        .add(&theta.mul(&theta.ln()))
407        .sub(&theta.mul(&tpm.ln()))
408        .add(&mu.ln().scale(yi))
409        .sub(&tpm.ln().scale(yi));
410    loglik.scale(-wi)
411}
412
413// `dispersion_nb_disp_order2` — the `Order2<1>` NB2 dispersion oracle pin used
414// only by `prune_towers_*` — now lives in the `#[cfg(test)] mod tests` below,
415// next to its sole consumer (it is genuinely test-support, not production, so it
416// belongs in a `#[cfg(test)]` mod rather than carrying `allow(dead_code)`).
417
418/// Pruned single-axis Gamma dispersion tower: `ν` is the sole jet variable
419/// (axis 0), `μ` a constant. Consumed channels match
420/// `dispersion_gamma_nll_order2` index-1 bit-for-bit.
421#[inline]
422pub(crate) fn dispersion_gamma_disp_order2(
423    yi: f64,
424    y_pos: f64,
425    mu_value: f64,
426    nu_value: f64,
427    wi: f64,
428) -> gam_math::jet_scalar::Order2<1> {
429    type O1 = gam_math::jet_scalar::Order2<1>;
430
431    let mu = O1::constant(mu_value);
432    let nu = O1::variable(nu_value, 0);
433    let loglik = nu
434        .mul(&nu.ln())
435        .sub(&nu.mul(&mu.ln()))
436        .sub(&order2_ln_gamma(&nu))
437        .add(&nu.sub(&O1::constant(1.0)).scale(y_pos.ln()))
438        .sub(&nu.mul(&mu.recip().scale(yi)));
439    loglik.scale(-wi)
440}
441
442/// Pruned single-axis Tweedie dispersion tower seeded on the predictor `η_d`
443/// (axis 0), with `η_μ` a constant (so `μ = exp(η_μ)` carries no jet). The
444/// `φ = exp(−η_d)` chain and its nonlinear `∂²φ/∂η_d²` curvature are carried
445/// exactly as in `dispersion_tweedie_nll_generic`; `value`/`g[0]`/`h[0][0]`
446/// match that program's `value`/`g[1]`/`h[1][1]` bit-for-bit.
447#[inline]
448pub(crate) fn dispersion_tweedie_disp_order2(
449    yi: f64,
450    eta_mu: f64,
451    eta_d: f64,
452    p: f64,
453    wi: f64,
454) -> gam_math::jet_scalar::Order2<1> {
455    type O1 = gam_math::jet_scalar::Order2<1>;
456
457    let one_minus_p = 1.0 - p;
458    let two_minus_p = 2.0 - p;
459    let mu = O1::constant(eta_mu).exp();
460    let phi = O1::variable(eta_d, 0).scale(-1.0).exp();
461    if yi > 0.0 {
462        let dev = mu
463            .powf(two_minus_p)
464            .scale(1.0 / two_minus_p)
465            .sub(&mu.powf(one_minus_p).scale(yi / one_minus_p))
466            .add(&O1::constant(
467                yi.powf(two_minus_p) / (one_minus_p * two_minus_p),
468            ))
469            .scale(2.0);
470        let loglik = dev
471            .mul(&phi.recip().scale(-0.5))
472            .sub(&phi.scale(2.0 * std::f64::consts::PI).ln().scale(0.5))
473            .sub(&O1::constant(0.5 * p * yi.ln()));
474        loglik.scale(-wi)
475    } else {
476        let c = mu.powf(two_minus_p).scale(1.0 / two_minus_p);
477        let loglik = c.mul(&phi.recip()).scale(-1.0);
478        loglik.scale(-wi)
479    }
480}
481
482// ============================================================================
483// #1591 jet-prune: value-only (`K=0`) row negative-log-likelihood.
484//
485// `log_likelihood_only` reads ONLY `row.loglik = -tower.value()`; the full row
486// kernel it used to call evaluated every dispersion tower's gradient AND Hessian
487// — including the digamma/trigamma derivative stacks — purely to discard them.
488// These functions evaluate the SAME value-channel program in plain `f64`, so
489// they are `to_bits`-identical to `-tower.value()` (the jet value channel is the
490// naive scalar evaluation: `mul.v = a.v*b.v`, `compose.v = stack[0]`), while
491// touching only `ln_gamma` (stack slot 0) and never the digamma/trigamma slots.
492// On a per-row loglik that is the dominant transcendental saving.
493// ============================================================================
494
495/// NB2 row NLL value, plain `f64`, bit-identical to
496/// `-dispersion_nb_disp_order2(..).value()`.
497#[inline]
498fn dispersion_nb_neg_loglik(yi: f64, mu: f64, theta: f64, wi: f64) -> f64 {
499    let tpm = theta + mu;
500    let s = ln_gamma(theta + yi) - ln_gamma(theta) - ln_gamma(yi + 1.0) + theta * theta.ln()
501        - theta * tpm.ln()
502        + mu.ln() * yi
503        - tpm.ln() * yi;
504    -(s * -wi)
505}
506
507/// Gamma row NLL value, plain `f64`, bit-identical to
508/// `-dispersion_gamma_disp_order2(..).value()`.
509#[inline]
510fn dispersion_gamma_neg_loglik(yi: f64, y_pos: f64, mu: f64, nu: f64, wi: f64) -> f64 {
511    // NB: the jet forms `μ.recip().scale(yi)` = `(1/μ)·yᵢ` (reciprocal then
512    // multiply), NOT `yᵢ/μ` (single divide) — these differ in the last bit, so
513    // the value path must reproduce the reciprocal-then-multiply exactly.
514    let s = nu * nu.ln() - nu * mu.ln() - ln_gamma(nu) + (nu - 1.0) * y_pos.ln()
515        - nu * ((1.0 / mu) * yi);
516    -(s * -wi)
517}
518
519/// Beta row NLL value, plain `f64`, bit-identical to
520/// `-dispersion_beta_nll_order2(..).value()`.
521#[inline]
522fn dispersion_beta_neg_loglik(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
523    let one_minus_mu = 1.0 - mu;
524    let yc = yi.clamp(1e-12, 1.0 - 1e-12);
525    let a = mu * phi;
526    let b = one_minus_mu * phi;
527    let s = ln_gamma(phi) - ln_gamma(a) - ln_gamma(b)
528        + (a - 1.0) * yc.ln()
529        + (b - 1.0) * (1.0 - yc).ln();
530    -(s * -wi)
531}
532
/// Value channel of the Tweedie row tower in plain `f64`, covering both
/// density branches.
///
/// Returns the negated weighted NLL — i.e. the log-likelihood `s` pushed
/// through `-(s * -wi)` — bit-identical to
/// `-dispersion_tweedie_disp_order2(..).value()`. Here `μ = exp(eta_mu)` and
/// `φ = exp(−eta_d)`. For `yi > 0` the saddlepoint form is evaluated; for any
/// other `yi` the `y = 0` point-mass form `−μ^{2−p} / ((2−p) φ)`.
#[inline]
fn dispersion_tweedie_neg_loglik(yi: f64, eta_mu: f64, eta_d: f64, p: f64, wi: f64) -> f64 {
    let one_minus_p = 1.0 - p;
    let two_minus_p = 2.0 - p;
    let mu = eta_mu.exp();
    let phi = (-eta_d).exp();

    // μ^{2−p} / (2−p), written as a multiply by the reciprocal to match the
    // jet's `.scale(1.0 / two_minus_p)`; shared by both branches.
    let mu_power_term = mu.powf(two_minus_p) * (1.0 / two_minus_p);
    let inv_phi = 1.0 / phi;

    let loglik = if yi > 0.0 {
        let cross_term = mu.powf(one_minus_p) * (yi / one_minus_p);
        let y_term = yi.powf(two_minus_p) / (one_minus_p * two_minus_p);
        let dev = (mu_power_term - cross_term + y_term) * 2.0;

        let deviance_part = dev * (inv_phi * -0.5);
        let half_log_two_pi_phi = (phi * (2.0 * std::f64::consts::PI)).ln() * 0.5;
        let half_p_ln_y = 0.5 * p * yi.ln();
        deviance_part - half_log_two_pi_phi - half_p_ln_y
    } else {
        (mu_power_term * inv_phi) * -1.0
    };
    -(loglik * -wi)
}
555
556/// Value-only row negative log-likelihood for one observation — the pruned hot
557/// path for [`CustomFamily::log_likelihood_only`]. Mirrors the link/clamp
558/// preamble of [`dispersion_row_kernel`] exactly, then evaluates ONLY the value
559/// channel (no gradient/Hessian, no digamma/trigamma). Returns `row.loglik`
560/// `to_bits`-identically.
561#[inline]
562pub(crate) fn dispersion_row_loglik(
563    kind: DispersionFamilyKind,
564    yi: f64,
565    eta_mu: f64,
566    eta_d: f64,
567    prior_weight: f64,
568) -> f64 {
569    let wi = prior_weight.max(0.0);
570    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
571    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
572    match kind {
573        DispersionFamilyKind::NegativeBinomial => {
574            let mu = em.exp().max(1e-300);
575            let theta = ed.exp().max(1e-12);
576            dispersion_nb_neg_loglik(yi, mu, theta, wi)
577        }
578        DispersionFamilyKind::Gamma => {
579            let mu = em.exp().max(1e-300);
580            let nu = ed.exp().max(1e-12);
581            let y_pos = yi.max(1e-300);
582            dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)
583        }
584        DispersionFamilyKind::Beta => {
585            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
586            let phi = ed.exp().max(1e-12);
587            dispersion_beta_neg_loglik(yi, mu, phi, wi)
588        }
589        DispersionFamilyKind::Tweedie { p } => dispersion_tweedie_neg_loglik(yi, em, ed, p, wi),
590    }
591}
592
593#[inline]
594pub(crate) fn beta_observed_cross_weight_eta(yi: f64, mu: f64, phi: f64, wi: f64) -> f64 {
595    let q = (mu * (1.0 - mu)).max(1e-12);
596    let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
597    q * phi * tower.h()[0][1]
598}
599
600#[inline]
601pub(crate) fn dispersion_row_cross_weight(
602    kind: DispersionFamilyKind,
603    yi: f64,
604    eta_mu: f64,
605    eta_d: f64,
606    prior_weight: f64,
607) -> f64 {
608    let wi = prior_weight.max(0.0);
609    if wi == 0.0 {
610        return 0.0;
611    }
612    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
613    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
614    match kind {
615        DispersionFamilyKind::Beta => {
616            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
617            let phi = ed.exp().max(1e-12);
618            beta_observed_cross_weight_eta(yi, mu, phi, wi)
619        }
620        DispersionFamilyKind::NegativeBinomial
621        | DispersionFamilyKind::Gamma
622        | DispersionFamilyKind::Tweedie { .. } => 0.0,
623    }
624}
625
626#[inline]
627pub(crate) fn tower_score_info<const K: usize>(
628    tower: &gam_math::jet_scalar::Order2<K>,
629    idx: usize,
630    wi: f64,
631) -> (f64, f64) {
632    if wi == 0.0 {
633        (0.0, 0.0)
634    } else {
635        (-tower.g()[idx] / wi, tower.h()[idx][idx] / wi)
636    }
637}
638
639/// Evaluate the row log-likelihood and the (mean, log-precision) Fisher-scoring
640/// working sets for one observation. `eta_mu`/`eta_d` already include any
641/// per-channel offset (they are the block predictors). `prior_weight` is the
642/// observation's prior weight.
643pub(super) fn dispersion_row_kernel(
644    kind: DispersionFamilyKind,
645    yi: f64,
646    eta_mu: f64,
647    eta_d: f64,
648    prior_weight: f64,
649) -> DispersionRowKernel {
650    let wi = prior_weight.max(0.0);
651    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
652    let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
653    match kind {
654        DispersionFamilyKind::NegativeBinomial => {
655            let mu = em.exp().max(1e-300);
656            let theta = ed.exp().max(1e-12); // precision (size)
657            let tpm = theta + mu;
658            // Only the exact θ-space SCORE and the value are consumed here; the
659            // observed-Hessian channel is discarded in favor of the expected
660            // (Fisher) information assembled below (see the dispersion-curvature
661            // note). So the tower is the value+gradient-only `Order1` twin, which
662            // never evaluates the discarded observed Hessian's trigamma at `θ+y`
663            // and `θ`; `value`/`g[0]` are bit-identical to the `Order2` form.
664            let tower = dispersion_nb_disp_order1(yi, mu, theta, wi);
665            let s_theta = if wi == 0.0 { 0.0 } else { -tower.g()[0] / wi };
666            let loglik = -tower.value();
667            let info_mu = if wi == 0.0 {
668                DISPERSION_MIN_CURVATURE
669            } else {
670                (theta / (mu * tpm)).max(DISPERSION_MIN_CURVATURE)
671            };
672            let score_mu = theta * (yi - mu) / (mu * tpm);
673            let mean_weight = wi * mu * mu * info_mu;
674            let mean_response = em + score_mu / (mu * info_mu);
675            // Dispersion (log-θ) IRLS curvature: use the EXPECTED (Fisher)
676            // information in θ, not the per-row OBSERVED Hessian channel
677            // (`_info_theta_observed`). The NB2 log-likelihood is strongly
678            // non-quadratic in θ: `−∂²ℓ/∂θ²` carries the row-specific term
679            // `ψ′(θ+y)` and goes NEGATIVE for every row whose count sits below
680            // its current fitted precision (overestimated size / underestimated
681            // overdispersion). Far from the optimum a majority of rows can be
682            // negative, so the assembled block curvature `Xᵀdiag(w)X` loses
683            // positive-definiteness; flooring each negative row at
684            // `DISPERSION_MIN_CURVATURE` (≈0) then divides the exact score by
685            // ~0 in the working response, producing O(1e12) IRLS targets that
686            // make the dispersion block step explode and the inner block-cyclic
687            // solve stall (never reaching KKT within the cycle budget — the
688            // `nb` location-scale `IntegrationError`, gam#1606). The mean block
689            // already uses its closed-form expected info `θ/(μ(θ+μ))`; the
690            // dispersion block must do the same.
691            //
692            // The Fisher information in θ has the closed form
693            //   I(θ) = ψ′(θ) − E[ψ′(θ+Y)] − 1/θ + 1/(θ+μ),
694            // whose only costly piece is the per-row infinite expectation
695            // `E[ψ′(θ+Y)]`. Replacing it with the Jensen plug-in `ψ′(θ+μ)`
696            // (valid because ψ′ is convex, so this is a tight lower bound on the
697            // expectation) gives a per-row, sum-free, STRICTLY POSITIVE
698            // curvature
699            //   I_θ ≈ ψ′(θ) − ψ′(θ+μ) − 1/θ + 1/(θ+μ) > 0  for all (μ,θ),
700            // since ψ′ is strictly decreasing. The working RESPONSE still
701            // carries the EXACT score `s_theta` (= ∂ℓ/∂θ from the tower), so the
702            // penalized stationary point (score = 0) is byte-unchanged — this is
703            // Fisher scoring, which only re-conditions the inner solve and never
704            // shifts the optimum (cf. the `DISPERSION_MIN_CURVATURE` contract
705            // note above). The observed channel `_info_theta_observed` is no
706            // longer consumed for the weight.
707            // #1591-follow-up: scalar `trigamma` (== `trigamma_derivative_stack
708            // (·)[0]` bit-for-bit) evaluates ONLY ψ′; the old `[0]`-index form
709            // built the full order-1..5 polygamma stack and discarded four of
710            // five per call (8 wasted polygamma evaluations per NB2 row).
711            let trigamma_theta = gam_math::jet_tower::trigamma(theta);
712            let trigamma_tpm = gam_math::jet_tower::trigamma(tpm);
713            let info_theta_fisher = trigamma_theta - trigamma_tpm - 1.0 / theta + 1.0 / tpm;
714            let info_pos = info_theta_fisher.max(DISPERSION_MIN_CURVATURE);
715            let disp_weight = wi * theta * theta * info_pos;
716            let disp_response = ed + s_theta / (theta * info_pos);
717            DispersionRowKernel {
718                loglik,
719                mean_weight,
720                mean_response,
721                disp_weight,
722                disp_response,
723            }
724        }
725        DispersionFamilyKind::Gamma => {
726            let mu = em.exp().max(1e-300);
727            let nu = ed.exp().max(1e-12); // precision = shape ν
728            let y_pos = yi.max(1e-300);
729            let tower = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
730            let (s_nu, info_nu_raw) = tower_score_info(&tower, 0, wi);
731            let loglik = -tower.value();
732            let info_mu = if wi == 0.0 {
733                DISPERSION_MIN_CURVATURE
734            } else {
735                (nu / (mu * mu)).max(DISPERSION_MIN_CURVATURE)
736            };
737            let score_mu = nu * (yi - mu) / (mu * mu);
738            let mean_weight = wi * mu * mu * info_mu;
739            let mean_response = em + score_mu / (mu * info_mu);
740            let info_nu = info_nu_raw.max(DISPERSION_MIN_CURVATURE);
741            let disp_weight = wi * nu * nu * info_nu;
742            let disp_response = ed + s_nu / (nu * info_nu);
743            DispersionRowKernel {
744                loglik,
745                mean_weight,
746                mean_response,
747                disp_weight,
748                disp_response,
749            }
750        }
751        DispersionFamilyKind::Beta => {
752            // logit mean link.
753            let mu = (1.0 / (1.0 + (-em).exp())).clamp(1e-12, 1.0 - 1e-12);
754            let phi = ed.exp().max(1e-12); // precision
755            let q = (mu * (1.0 - mu)).max(1e-12); // dμ/dη
756            let tower = dispersion_beta_nll_order2(yi, mu, phi, wi);
757            let (score_mu, info_mu_raw) = tower_score_info(&tower, 0, wi);
758            let (s_phi, info_phi_raw) = tower_score_info(&tower, 1, wi);
759            let loglik = -tower.value();
760            let info_mu = info_mu_raw.max(DISPERSION_MIN_CURVATURE);
761            let mean_weight = wi * q * q * info_mu;
762            let mean_response = em + score_mu / (q * info_mu);
763            let info_phi = info_phi_raw.max(DISPERSION_MIN_CURVATURE);
764            let disp_weight = wi * phi * phi * info_phi;
765            let disp_response = ed + s_phi / (phi * info_phi);
766            DispersionRowKernel {
767                loglik,
768                mean_weight,
769                mean_response,
770                disp_weight,
771                disp_response,
772            }
773        }
774        DispersionFamilyKind::Tweedie { p } => {
775            let mu = em.exp().max(1e-300);
776            // Precision channel models log(1/φ) ⇒ φ = exp(−η_d).
777            let phi = (-ed).exp().max(1e-12);
778            let two_minus_p = 2.0 - p;
779            // Mean channel: the quasi-score `(y−μ)/μ` and Fisher weight
780            // `μ^{2−p}/φ` are simple closed forms (and the mean block is
781            // Fisher-orthogonal to the dispersion block in this
782            // parameterization), so they stay hand-written exactly as the
783            // NB/Gamma mean arms do.
784            let mean_weight = wi * mu.powf(two_minus_p) / phi;
785            let mean_response = em + (yi - mu) / mu;
786            // Dispersion channel: the η_d-space score and OBSERVED information
787            // come straight off the single-expression tower seeded on `η_d`
788            // (#932), so the saddlepoint/point-mass branch split, the
789            // `φ = exp(−η_d)` chain and its nonlinear `∂²φ/∂η_d²` curvature
790            // correction are all mechanically carried — no per-branch
791            // `s_phi`/`s_eta`/`curvature_eta` hand calculus. #1591: only the
792            // η_d axis is consumed, so the tower is the pruned single-axis
793            // `Order2<1>` (`η_μ` enters as a constant).
794            let tower = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
795            let loglik = -tower.value();
796            // η_d-space score and observed information off the tower, via the
797            // same helper the NB/Gamma/Beta arms use (returns `(0, 0)` when the
798            // prior weight is zero, so the row stays excluded below).
799            let (s_eta, info_eta_raw) = tower_score_info(&tower, 0, wi);
800            let curvature_eta = if wi == 0.0 {
801                DISPERSION_MIN_CURVATURE
802            } else {
803                info_eta_raw.max(DISPERSION_MIN_CURVATURE)
804            };
805            // The working response divides by this per-row curvature so the
806            // prior weight cancels (and a zero-prior-weight row stays excluded
807            // via `disp_weight = 0`).
808            let disp_weight = wi * curvature_eta;
809            let disp_response = ed + s_eta / curvature_eta;
810            DispersionRowKernel {
811                loglik,
812                mean_weight,
813                mean_response,
814                disp_weight,
815                disp_response,
816            }
817        }
818    }
819}
820
/// Two-block GAMLSS family for the genuine-dispersion mean families (#913).
///
/// Stores the per-row data the likelihood kernels read. Every evaluation path
/// indexes `y[i]` and `weights[i]` next to the two block predictors, so the
/// arrays are expected to share one row count; `evaluate` rejects a mismatch
/// with an error rather than checking at construction time.
#[derive(Clone)]
pub(crate) struct DispersionGlmLocationScaleFamily {
    /// Family member the per-row kernels dispatch on.
    pub(crate) kind: DispersionFamilyKind,
    /// Observed response, one entry per row.
    pub(crate) y: Array1<f64>,
    /// Per-row prior weights handed to the row kernels.
    pub(crate) weights: Array1<f64>,
}
828
impl DispersionGlmLocationScaleFamily {
    /// Position of the mean block in the block-state and block-spec slices.
    pub(crate) const BLOCK_MEAN: usize = 0;
    /// Position of the dispersion (log-precision) block in the same slices.
    pub(crate) const BLOCK_DISP: usize = 1;
}
833
834impl CustomFamily for DispersionGlmLocationScaleFamily {
    /// Always `true`: this family opts into the joint Jeffreys term.
    // Preserve the pre-gam#1395 behavior: the trait default flipped to OFF (the
    // flat-prior exact-Newton objective carries no Jeffreys term), so families
    // that historically armed the term by default opt back in explicitly.
    fn joint_jeffreys_term_required(&self) -> bool {
        true
    }
841
842    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
843        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
844        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
845        let eta_d = &block_states[Self::BLOCK_DISP].eta;
846        let n = self.y.len();
847        if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
848            return Err(format!(
849                "{} row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
850                self.kind.family_tag(),
851                eta_mu.len(),
852                eta_d.len(),
853                self.weights.len()
854            ));
855        }
856        let mut mean_weights = Array1::<f64>::zeros(n);
857        let mut mean_response = Array1::<f64>::zeros(n);
858        let mut disp_weights = Array1::<f64>::zeros(n);
859        let mut disp_response = Array1::<f64>::zeros(n);
860
861        // `dispersion_row_kernel` is a pure, row-independent map — each row reads
862        // only `y[i]`/`eta_mu[i]`/`eta_d[i]`/`weights[i]` and writes nothing
863        // shared — and it is transcendental-heavy (per-row digamma/trigamma
864        // derivative stacks), so the per-row evaluation is embarrassingly
865        // row-parallel. Materialize the per-row kernels (in parallel for large
866        // `n` when not already on a rayon worker; mirrors the
867        // `row_coeff_operator` guard), then reduce SERIALLY in index order so
868        // the log-likelihood sum is bit-identical to the old serial loop — no
869        // float reassociation. The reduction touches no transcendentals, so the
870        // parallel kernel map captures essentially all the savings.
871        let kernels: Vec<DispersionRowKernel> = if rayon::current_thread_index().is_none()
872            && n > DISPERSION_PARALLEL_ROW_THRESHOLD
873        {
874            use rayon::iter::{IntoParallelIterator, ParallelIterator};
875            (0..n)
876                .into_par_iter()
877                .map(|i| {
878                    dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
879                })
880                .collect()
881        } else {
882            (0..n)
883                .map(|i| {
884                    dispersion_row_kernel(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
885                })
886                .collect()
887        };
888
889        let mut log_likelihood = 0.0;
890        for (i, row) in kernels.into_iter().enumerate() {
891            if row.loglik.is_finite() {
892                log_likelihood += row.loglik;
893            }
894            mean_weights[i] = row.mean_weight.max(0.0);
895            mean_response[i] = row.mean_response;
896            disp_weights[i] = row.disp_weight.max(0.0);
897            disp_response[i] = row.disp_response;
898        }
899        Ok(FamilyEvaluation {
900            log_likelihood,
901            blockworking_sets: vec![
902                BlockWorkingSet::diagonal_checked(mean_response, mean_weights)?,
903                BlockWorkingSet::diagonal_checked(disp_response, disp_weights)?,
904            ],
905        })
906    }
907
908    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
909        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
910        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
911        let eta_d = &block_states[Self::BLOCK_DISP].eta;
912        let n = self.y.len();
913        // #1591 prune: the objective needs only the row log-likelihood, so each
914        // row evaluates the value channel alone (`to_bits`-identical to
915        // `dispersion_row_kernel(..).loglik`), skipping every gradient/Hessian
916        // and digamma/trigamma derivative-stack evaluation. That value-only map
917        // is still a pure, row-independent per-row `ln_gamma` evaluation, so it
918        // is row-parallel; fan it out (large `n`, off a rayon worker) into a
919        // per-row buffer, then sum SERIALLY in index order to keep the objective
920        // bit-identical to the serial loop (no float reassociation).
921        let per_row: Vec<f64> = if rayon::current_thread_index().is_none()
922            && n > DISPERSION_PARALLEL_ROW_THRESHOLD
923        {
924            use rayon::iter::{IntoParallelIterator, ParallelIterator};
925            (0..n)
926                .into_par_iter()
927                .map(|i| {
928                    dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
929                })
930                .collect()
931        } else {
932            (0..n)
933                .map(|i| {
934                    dispersion_row_loglik(self.kind, self.y[i], eta_mu[i], eta_d[i], self.weights[i])
935                })
936                .collect()
937        };
938        let mut ll = 0.0;
939        for loglik in per_row {
940            if loglik.is_finite() {
941                ll += loglik;
942            }
943        }
944        Ok(ll)
945    }
946
947    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
948        crate::location_scale_engine::location_scale_coefficient_hessian_cost(
949            self.y.len() as u64,
950            specs,
951        )
952    }
953
954    /// Exact joint coefficient-space Hessian `H_L = -∇²log L` in flattened
955    /// `[mean | log-precision]` block order.
956    ///
957    /// All four members assemble the same `Xᵀ diag(W) X` blocks; the cross
958    /// block is the per-row mixed weight `dispersion_row_cross_weight`. Beta
959    /// carries a genuinely nonzero (η_μ, η_φ) cross weight; the Fisher-
960    /// orthogonal members (NegativeBinomial / Gamma / Tweedie) report a zero
961    /// cross weight, so this returns their exact *block-diagonal* joint
962    /// Hessian. Returning that block-diagonal `H_L` — rather than `None` —
963    /// is what lets the multi-block outer-REML path (`build_joint_hessian_
964    /// closures` → `joint_outer_evaluate`) and the joint posterior covariance
965    /// (`compute_joint_covariance`) run for these families instead of failing
966    /// the "multi-block families must provide a joint outer path" gate and
967    /// silently escalating to a degraded ρ-seed fit with no covariance/EDF
968    /// (gam#1119). The orthogonal members additionally declare
969    /// `likelihood_blocks_uncoupled() = true` so the directional-derivative
970    /// and Jeffreys dispatch route through the block-diagonal-exact fallback
971    /// rather than rejecting the structurally-uncoupled Hessian.
972    fn exact_newton_joint_hessian_with_specs(
973        &self,
974        block_states: &[ParameterBlockState],
975        specs: &[ParameterBlockSpec],
976    ) -> Result<Option<Array2<f64>>, String> {
977        validate_block_count::<GamlssError>(self.kind.family_tag(), 2, block_states.len())?;
978        if specs.len() != 2 {
979            return Err(format!(
980                "{} exact joint Hessian expects 2 specs, got {}",
981                self.kind.family_tag(),
982                specs.len()
983            ));
984        }
985        let eta_mu = &block_states[Self::BLOCK_MEAN].eta;
986        let eta_d = &block_states[Self::BLOCK_DISP].eta;
987        let n = self.y.len();
988        if eta_mu.len() != n || eta_d.len() != n || self.weights.len() != n {
989            return Err(format!(
990                "{} exact joint Hessian row-count mismatch: y={n}, eta_mu={}, eta_d={}, weights={}",
991                self.kind.family_tag(),
992                eta_mu.len(),
993                eta_d.len(),
994                self.weights.len()
995            ));
996        }
997
998        let eval = self.evaluate(block_states)?;
999        let BlockWorkingSet::Diagonal {
1000            working_weights: mean_weights,
1001            ..
1002        } = &eval.blockworking_sets[Self::BLOCK_MEAN]
1003        else {
1004            return Err(format!(
1005                "{} dispersion mean block did not return diagonal weights",
1006                self.kind.family_tag()
1007            ));
1008        };
1009        let BlockWorkingSet::Diagonal {
1010            working_weights: disp_weights,
1011            ..
1012        } = &eval.blockworking_sets[Self::BLOCK_DISP]
1013        else {
1014            return Err(format!(
1015                "{} dispersion precision block did not return diagonal weights",
1016                self.kind.family_tag()
1017            ));
1018        };
1019
1020        // Per-row mixed `(η_μ, η_d)` weight; for Beta this is a full `Order2<2>`
1021        // tower per row (the orthogonal members return 0 cheaply). Row-
1022        // independent, so fan it out for large `n` (off a rayon worker) into a
1023        // per-row buffer — index-ordered, no reduction, so byte-identical to the
1024        // serial `from_shape_fn`.
1025        let cross_weights = if rayon::current_thread_index().is_none()
1026            && n > DISPERSION_PARALLEL_ROW_THRESHOLD
1027        {
1028            use rayon::iter::{IntoParallelIterator, ParallelIterator};
1029            Array1::from_vec(
1030                (0..n)
1031                    .into_par_iter()
1032                    .map(|i| {
1033                        dispersion_row_cross_weight(
1034                            self.kind,
1035                            self.y[i],
1036                            eta_mu[i],
1037                            eta_d[i],
1038                            self.weights[i],
1039                        )
1040                    })
1041                    .collect::<Vec<f64>>(),
1042            )
1043        } else {
1044            Array1::from_shape_fn(n, |i| {
1045                dispersion_row_cross_weight(
1046                    self.kind,
1047                    self.y[i],
1048                    eta_mu[i],
1049                    eta_d[i],
1050                    self.weights[i],
1051                )
1052            })
1053        };
1054        let mean_spec = &specs[Self::BLOCK_MEAN];
1055        let disp_spec = &specs[Self::BLOCK_DISP];
1056        if mean_spec.design.nrows() != n || disp_spec.design.nrows() != n {
1057            return Err(format!(
1058                "{} exact joint Hessian design row mismatch: y={n}, mean rows={}, precision rows={}",
1059                self.kind.family_tag(),
1060                mean_spec.design.nrows(),
1061                disp_spec.design.nrows()
1062            ));
1063        }
1064        let p_mean = mean_spec.design.ncols();
1065        let p_disp = disp_spec.design.ncols();
1066        if block_states[Self::BLOCK_MEAN].beta.len() != p_mean
1067            || block_states[Self::BLOCK_DISP].beta.len() != p_disp
1068        {
1069            return Err(format!(
1070                "{} exact joint Hessian beta/design mismatch: mean beta {} vs cols {}, precision beta {} vs cols {}",
1071                self.kind.family_tag(),
1072                block_states[Self::BLOCK_MEAN].beta.len(),
1073                p_mean,
1074                block_states[Self::BLOCK_DISP].beta.len(),
1075                p_disp
1076            ));
1077        }
1078
1079        let h_mean = xt_diag_x_design(&mean_spec.design, mean_weights)?;
1080        let h_cross = xt_diag_y_design(&mean_spec.design, &cross_weights, &disp_spec.design)?;
1081        let h_disp = xt_diag_x_design(&disp_spec.design, disp_weights)?;
1082        let total = p_mean + p_disp;
1083        let mut h = Array2::<f64>::zeros((total, total));
1084        h.slice_mut(s![0..p_mean, 0..p_mean]).assign(&h_mean);
1085        h.slice_mut(s![0..p_mean, p_mean..total]).assign(&h_cross);
1086        h.slice_mut(s![p_mean..total, p_mean..total])
1087            .assign(&h_disp);
1088        mirror_upper_to_lower(&mut h);
1089        Ok(Some(h))
1090    }
1091
1092    /// Whether the joint likelihood Hessian is block-diagonal in the
1093    /// `[mean | log-precision]` coefficient vector.
1094    ///
1095    /// `Beta(μφ, (1−μ)φ)` carries a genuinely nonzero `(η_μ, η_φ)` Fisher
1096    /// cross block (see the module header), so its blocks are coupled. The
1097    /// remaining members are Fisher-orthogonal in their mean/precision
1098    /// parameterizations — NB2 `(μ, θ)`, Gamma shape `ν = 1/φ`, Tweedie
1099    /// `log(1/φ)` — so `∂²L/∂β_μ∂β_d = 0` and the joint Hessian is exactly
1100    /// block-diagonal. Declaring that here lets the trait's directional-
1101    /// derivative / Jeffreys dispatch accept the block-diagonal joint Hessian
1102    /// via the working-set-exact fallback instead of rejecting it as an
1103    /// untrusted structurally-uncoupled override (which would strand the
1104    /// outer-REML gradient with a "dH unavailable" error, gam#1119).
1105    fn likelihood_blocks_uncoupled(&self) -> bool {
1106        !matches!(self.kind, DispersionFamilyKind::Beta)
1107    }
1108
1109    /// The mean and precision working weights couple across both blocks, which
1110    /// the block-local diagonal drift hook cannot represent, so decline the
1111    /// dense outer Hessian capability whenever the actual two-block (or
1112    /// larger) geometry is in play; a degenerate single-block probe — there
1113    /// is no cross-block coupling to reject — keeps the trait default's
1114    /// availability verdict.
1115    ///
1116    /// The override still validates the block-spec slice it is handed (the
1117    /// same consistency check the trait default's assertion bottoms out in)
1118    /// so a malformed probe is reported here rather than downstream.
1119    fn outer_hyper_hessian_dense_available(&self, specs: &[ParameterBlockSpec]) -> bool {
1120        assert!(
1121            crate::custom_family::validate_blockspec_consistency(specs).is_ok(),
1122            "DispersionGlmLocationScale outer hyper-Hessian dense availability: \
1123             inconsistent parameter block specs"
1124        );
1125        specs.len() < 2
1126    }
1127}
1128
/// Term spec consumed by [`fit_dispersion_glm_location_scale_terms`]; mirrors
/// [`GaussianLocationScaleTermSpec`](super::GaussianLocationScaleTermSpec) with
/// the dispersion channel in place of the Gaussian log-σ channel.
pub struct DispersionGlmLocationScaleTermSpec {
    /// Dispersion family member to fit.
    pub kind: DispersionFamilyKind,
    /// Observed response vector.
    pub y: Array1<f64>,
    /// Per-row weights.
    pub weights: Array1<f64>,
    /// Term collection for the mean predictor.
    pub meanspec: TermCollectionSpec,
    /// Term collection for the dispersion-channel predictor.
    // NOTE(review): presumably becomes the builder's `noisespec` — confirm at
    // the call site, which is outside this view.
    pub log_dispspec: TermCollectionSpec,
    /// Offset for the mean predictor.
    pub mean_offset: Array1<f64>,
    /// Offset for the dispersion-channel predictor.
    // NOTE(review): presumably becomes the builder's `noise_offset` — confirm.
    pub log_disp_offset: Array1<f64>,
}
1141
/// Builder state behind the `LocationScaleFamilyBuilder` implementation for the
/// dispersion families: it supplies the term specs, assembles the two
/// parameter blocks and constructs the `DispersionGlmLocationScaleFamily`.
pub(crate) struct DispersionGlmLocationScaleTermBuilder {
    /// Family member; copied into the family by `build_family`.
    pub(crate) kind: DispersionFamilyKind,
    /// Observed response; cloned into the family and read by the warm start.
    pub(crate) y: Array1<f64>,
    /// Per-row weights; cloned into the family and used as projection weights
    /// by the warm start.
    pub(crate) weights: Array1<f64>,
    /// Term collection returned by `meanspec()`.
    pub(crate) meanspec: TermCollectionSpec,
    /// Term collection returned by `noisespec()` (the dispersion channel).
    pub(crate) noisespec: TermCollectionSpec,
    /// Offset installed on the `"mu"` block by `build_blocks`.
    pub(crate) mean_offset: Array1<f64>,
    /// Offset installed on the `"log_precision"` block by `build_blocks`.
    pub(crate) noise_offset: Array1<f64>,
}
1151
1152/// Warm start for a dispersion location-scale fit: project a link-transformed
1153/// response onto the mean block and seed the log-precision block at a constant
1154/// (precision ≈ 1) baseline. The block-cyclic IRLS then refines both jointly.
1155pub(crate) fn dispersion_location_scale_warm_start(
1156    kind: DispersionFamilyKind,
1157    y: &Array1<f64>,
1158    weights: &Array1<f64>,
1159    mean_block: &ParameterBlockSpec,
1160    disp_block: &ParameterBlockSpec,
1161    mean_beta_hint: Option<&Array1<f64>>,
1162    disp_beta_hint: Option<&Array1<f64>>,
1163) -> Result<(Array1<f64>, Array1<f64>), String> {
1164    let ridge_floor = 1e-10;
1165    let mean_beta = if let Some(beta) = mean_beta_hint {
1166        beta.clone()
1167    } else {
1168        let target = Array1::from_shape_fn(y.len(), |i| {
1169            if kind.mean_is_logit() {
1170                let yi = y[i].clamp(1e-3, 1.0 - 1e-3);
1171                (yi / (1.0 - yi)).ln()
1172            } else {
1173                // log mean link; the +0.1 keeps zero counts finite.
1174                (y[i].max(0.0) + 0.1).ln()
1175            }
1176        });
1177        solve_penalizedweighted_projection(
1178            &mean_block.design,
1179            &mean_block.offset,
1180            &target,
1181            weights,
1182            &mean_block.penalties,
1183            &mean_block.initial_log_lambdas,
1184            ridge_floor,
1185        )?
1186    };
1187    let disp_beta = if let Some(beta) = disp_beta_hint {
1188        beta.clone()
1189    } else {
1190        // Seed the precision block from a smoothed method-of-moments surface
1191        // rather than the old flat η_d=0 constant.  A single observation cannot
1192        // identify its own variance, but for the Fisher-orthogonal dispersion
1193        // members the residual-squared moment contains the correct first-order
1194        // signal:
1195        //
1196        //   Gamma:   Var(Y)=μ²/ν              ⇒ log ν     ≈ log(μ²/e²)
1197        //   NB2:     Var(Y)=μ+μ²/θ            ⇒ log θ     ≈ log(μ²/(e²-μ))
1198        //   Tweedie: Var(Y)=φ μ^p, η_d=log1/φ ⇒ η_d       ≈ log(μ^p/e²)
1199        //
1200        // The targets are deliberately conservative (finite residual floor,
1201        // precision cap, and no fixture-specific constants): they only give the
1202        // block-cyclic likelihood solve a correctly-signed non-flat starting
1203        // surface, while the final estimate is still the penalized joint MLE.
1204        let mean_eta = mean_block.design.apply(&mean_beta) + &mean_block.offset;
1205        let target = Array1::from_shape_fn(y.len(), |i| {
1206            dispersion_moment_log_precision_seed(kind, y[i], mean_eta[i])
1207        });
1208        solve_penalizedweighted_projection(
1209            &disp_block.design,
1210            &disp_block.offset,
1211            &target,
1212            weights,
1213            &disp_block.penalties,
1214            &disp_block.initial_log_lambdas,
1215            ridge_floor,
1216        )?
1217    };
1218    Ok((mean_beta, disp_beta))
1219}
1220
1221#[inline]
1222fn dispersion_moment_log_precision_seed(kind: DispersionFamilyKind, yi: f64, eta_mu: f64) -> f64 {
1223    const LOG_PRECISION_FLOOR: f64 = -10.0;
1224    const LOG_PRECISION_CEILING: f64 = 10.0;
1225    let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
1226    let raw = match kind {
1227        DispersionFamilyKind::Beta => {
1228            // Beta's mean and precision scores are not Fisher-orthogonal in
1229            // the (logit μ, log φ) parameterization.  Per-row residual moments
1230            // therefore make a poor block-cyclic seed: an outlying y near 0/1
1231            // can imply a near-zero φ and pull the coupled mean block onto the
1232            // boundary before the joint likelihood has had a chance to settle.
1233            // Keep the neutral precision seed for this one coupled member; the
1234            // exact Beta cross-Hessian below still drives the joint solve and
1235            // covariance with the coherent two-block likelihood geometry.
1236            0.0
1237        }
1238        DispersionFamilyKind::Gamma => {
1239            let mu = em.exp().max(1e-12);
1240            let e2 = (yi - mu).powi(2).max(1e-8 * mu * mu);
1241            (mu * mu / e2).max(1e-6).ln()
1242        }
1243        DispersionFamilyKind::NegativeBinomial => {
1244            let mu = em.exp().max(1e-12);
1245            let e2 = (yi - mu).powi(2);
1246            let excess = (e2 - mu).max(1e-6 * (mu + mu * mu));
1247            (mu * mu / excess).max(1e-6).ln()
1248        }
1249        DispersionFamilyKind::Tweedie { p } => {
1250            let mu = em.exp().max(1e-12);
1251            let e2 = (yi - mu).powi(2).max(1e-8 * mu.powf(p));
1252            (mu.powf(p) / e2).max(1e-6).ln()
1253        }
1254    };
1255    raw.clamp(LOG_PRECISION_FLOOR, LOG_PRECISION_CEILING)
1256}
1257
1258impl LocationScaleFamilyBuilder for DispersionGlmLocationScaleTermBuilder {
1259    type Family = DispersionGlmLocationScaleFamily;
1260
    /// Borrows the term collection for the mean predictor.
    fn meanspec(&self) -> &TermCollectionSpec {
        &self.meanspec
    }
1264
    /// Borrows the term collection for the dispersion-channel predictor.
    fn noisespec(&self) -> &TermCollectionSpec {
        &self.noisespec
    }
1268
1269    fn noise_penalty_count(&self, noise_design: &TermCollectionDesign) -> usize {
1270        // Mirror the Gaussian/Binomial scale block: a full-span shrinkage
1271        // penalty pins the log-precision nullspace so REML does not optimise
1272        // the dispersion smoothing on a flat surface.
1273        noise_design.penalties.len() + 1
1274    }
1275
1276    fn build_blocks(
1277        &self,
1278        theta: &Array1<f64>,
1279        mean_design: &TermCollectionDesign,
1280        noise_design: &TermCollectionDesign,
1281        mean_beta_hint: Option<Array1<f64>>,
1282        noise_beta_hint: Option<Array1<f64>>,
1283    ) -> Result<Vec<ParameterBlockSpec>, String> {
1284        let layout = GamlssLambdaLayout::two_block(
1285            mean_design.penalties.len(),
1286            self.noise_penalty_count(noise_design),
1287        );
1288        layout.validate_theta_len(theta.len(), "dispersion location-scale")?;
1289
1290        let mut meanspec = build_location_scale_block(
1291            "mu",
1292            mean_design.design.clone(),
1293            self.mean_offset.clone(),
1294            mean_design.penalties_as_penalty_matrix(),
1295            mean_design.nullspace_dims.clone(),
1296            layout.mean_from(theta),
1297            mean_beta_hint,
1298            0,
1299            LOCATION_SCALE_N_OUTPUTS,
1300            "DispersionLocationScale::build_blocks: mu",
1301        )?;
1302
1303        let p_disp = noise_design.design.ncols();
1304        let mut disp_penalties = noise_design.penalties_as_penalty_matrix();
1305        disp_penalties.push(PenaltyMatrix::Dense(identity_penalty(p_disp)));
1306        let mut disp_nullspace = noise_design.nullspace_dims.clone();
1307        disp_nullspace.push(0);
1308        let mut dispspec = build_location_scale_block(
1309            "log_precision",
1310            noise_design.design.clone(),
1311            self.noise_offset.clone(),
1312            disp_penalties,
1313            disp_nullspace,
1314            layout.noise_from(theta),
1315            noise_beta_hint,
1316            1,
1317            LOCATION_SCALE_N_OUTPUTS,
1318            "DispersionLocationScale::build_blocks: log_precision",
1319        )?;
1320
1321        if meanspec.initial_beta.is_none() || dispspec.initial_beta.is_none() {
1322            let (mean_beta0, disp_beta0) = dispersion_location_scale_warm_start(
1323                self.kind,
1324                &self.y,
1325                &self.weights,
1326                &meanspec,
1327                &dispspec,
1328                meanspec.initial_beta.as_ref(),
1329                dispspec.initial_beta.as_ref(),
1330            )?;
1331            if meanspec.initial_beta.is_none() {
1332                meanspec.initial_beta = Some(mean_beta0);
1333            }
1334            if dispspec.initial_beta.is_none() {
1335                dispspec.initial_beta = Some(disp_beta0);
1336            }
1337        }
1338
1339        Ok(vec![meanspec, dispspec])
1340    }
1341
1342    fn build_family(
1343        &self,
1344        mean_design: &TermCollectionDesign,
1345        noise_design: &TermCollectionDesign,
1346    ) -> Self::Family {
1347        // The family stores y/weights/kind directly and does not need the
1348        // designs at construction time, but the row geometry of the offered
1349        // designs is the only cross-check that ties this family back to the
1350        // builder's data — assert it before handing the family to the engine
1351        // so a misaligned design surfaces here rather than downstream in the
1352        // inner solver.
1353        assert_eq!(
1354            mean_design.design.nrows(),
1355            self.y.len(),
1356            "DispersionGlmLocationScale::build_family: mean design row count must match y"
1357        );
1358        assert_eq!(
1359            noise_design.design.nrows(),
1360            self.y.len(),
1361            "DispersionGlmLocationScale::build_family: noise design row count must match y"
1362        );
1363        DispersionGlmLocationScaleFamily {
1364            kind: self.kind,
1365            y: self.y.clone(),
1366            weights: self.weights.clone(),
1367        }
1368    }
1369
1370    fn extract_primary_betas(
1371        &self,
1372        fit: &UnifiedFitResult,
1373    ) -> Result<(Array1<f64>, Array1<f64>), String> {
1374        let mean_beta = fit
1375            .block_states
1376            .get(DispersionGlmLocationScaleFamily::BLOCK_MEAN)
1377            .ok_or_else(|| "missing dispersion mean block state".to_string())?
1378            .beta
1379            .clone();
1380        let disp_beta = fit
1381            .block_states
1382            .get(DispersionGlmLocationScaleFamily::BLOCK_DISP)
1383            .ok_or_else(|| "missing dispersion log-precision block state".to_string())?
1384            .beta
1385            .clone();
1386        Ok((mean_beta, disp_beta))
1387    }
1388
1389    fn build_psiderivative_blocks(
1390        &self,
1391        data: ndarray::ArrayView2<'_, f64>,
1392        meanspec: &TermCollectionSpec,
1393        noisespec: &TermCollectionSpec,
1394        mean_design: &TermCollectionDesign,
1395        noise_design: &TermCollectionDesign,
1396    ) -> Result<Vec<Vec<CustomFamilyBlockPsiDerivative>>, String> {
1397        // The dispersion location-scale families have no closed-form analytic
1398        // spatial psi derivatives, and `fit_dispersion_glm_location_scale_terms`
1399        // disables the κ/ψ joint optimizer before the engine ever asks. If we
1400        // do get called (for example by a future caller that forgets the
1401        // disable), return a real diagnostic rather than a sentinel — emit the
1402        // exact data and design shape that was passed in so the bug is
1403        // diagnosable from the error string alone.
1404        Err(format!(
1405            "dispersion location-scale ({:?}) does not implement analytic spatial \
1406             psi derivatives; the κ/ψ joint optimizer must be disabled before \
1407             this builder is consulted. Called with data {n_rows}×{n_cols}, mean \
1408             spec (linear={mean_lin}, random={mean_re}, smooth={mean_sm}), noise \
1409             spec (linear={noise_lin}, random={noise_re}, smooth={noise_sm}), \
1410             mean design cols={mean_p}, noise design cols={noise_p}",
1411            self.kind,
1412            n_rows = data.nrows(),
1413            n_cols = data.ncols(),
1414            mean_lin = meanspec.linear_terms.len(),
1415            mean_re = meanspec.random_effect_terms.len(),
1416            mean_sm = meanspec.smooth_terms.len(),
1417            noise_lin = noisespec.linear_terms.len(),
1418            noise_re = noisespec.random_effect_terms.len(),
1419            noise_sm = noisespec.smooth_terms.len(),
1420            mean_p = mean_design.design.ncols(),
1421            noise_p = noise_design.design.ncols(),
1422        ))
1423    }
1424}
1425
1426/// Fit a dispersion-channel GAMLSS location-scale model (#913). All four
1427/// genuine-dispersion mean families share this single entry; the per-family
1428/// likelihood lives in [`dispersion_row_kernel`].
1429pub fn fit_dispersion_glm_location_scale_terms(
1430    data: ndarray::ArrayView2<'_, f64>,
1431    spec: DispersionGlmLocationScaleTermSpec,
1432    options: &BlockwiseFitOptions,
1433    kappa_options: &SpatialLengthScaleOptimizationOptions,
1434) -> Result<BlockwiseTermFitResult, String> {
1435    if let DispersionFamilyKind::Tweedie { p } = spec.kind {
1436        if !(p.is_finite() && p > 1.0 && p < 2.0) {
1437            return Err(format!(
1438                "Tweedie location-scale requires a variance power strictly in (1, 2); got p={p}"
1439            ));
1440        }
1441    }
1442    // The κ/ψ anisotropic-kernel joint optimizer needs analytic psi
1443    // derivatives this family does not provide; disable it so the engine runs
1444    // the full ρ REML directly via `fit_custom_family` (1-D and tensor smooth
1445    // penalties λ are still REML-selected).
1446    let mut kappa = kappa_options.clone();
1447    kappa.enabled = false;
1448    // A dispersion location-scale model is an inherently *predictable* model:
1449    // posterior-mean prediction (the response-scale predict path the CLI/FFI
1450    // drive) needs the joint `(β_μ, β_d)` posterior covariance, and so does the
1451    // reported total EDF / coefficient SEs. The block-diagonal joint Hessian is
1452    // always assembled here (`exact_newton_joint_hessian_with_specs` →
1453    // `compute_joint_covariance`, which for this family's `RidgedQuadraticReml`
1454    // outer objective uses the never-erroring SPD-retry → positive-part
1455    // pseudo-inverse), so we can — and must — request the covariance
1456    // unconditionally rather than leaving `covariance_conditional = None`
1457    // whenever the outer optimizer happens to *converge* (the only family-
1458    // independent reason NB sometimes populated covariance was that it escalated
1459    // into the never-fail posterior-sampling rung, while a cleanly-converged
1460    // Gamma/Tweedie fit took the `!options.compute_covariance ⇒ None` early
1461    // return and stranded its covariance/EDF — gam#1119). Forcing the flag here
1462    // makes all four genuine-dispersion mean families assemble the joint
1463    // covariance + EDF deterministically, exactly as a predictable model
1464    // requires.
1465    let mut options = options.clone();
1466    options.compute_covariance = true;
1467    fit_location_scale_terms(
1468        data,
1469        DispersionGlmLocationScaleTermBuilder {
1470            kind: spec.kind,
1471            y: spec.y,
1472            weights: spec.weights,
1473            meanspec: spec.meanspec,
1474            noisespec: spec.log_dispspec,
1475            mean_offset: spec.mean_offset,
1476            noise_offset: spec.log_disp_offset,
1477        },
1478        &options,
1479        &kappa,
1480    )
1481}
1482
1483#[cfg(test)]
1484mod tests {
1485    use super::*;
1486    use super::test_support::{dispersion_gamma_nll_order2, dispersion_nb_nll_order2};
1487    use crate::gamlss::test_support::dispersion_tweedie_nll_generic;
1488
1489    /// Pruned single-axis NB2 dispersion tower: `θ` is the sole jet variable
1490    /// (axis 0), `μ` a constant. `value`/`g[0]`/`h[0][0]` reproduce the consumed
1491    /// `value`/`g[1]`/`h[1][1]` of `dispersion_nb_nll_order2` bit-for-bit. The
1492    /// `Order2` oracle pin for `prune_towers_match_dense_all_channels`; the
1493    /// production NB2 row kernel uses the cheaper `dispersion_nb_disp_order1`.
1494    #[inline]
1495    fn dispersion_nb_disp_order2(
1496        yi: f64,
1497        mu_value: f64,
1498        theta_value: f64,
1499        wi: f64,
1500    ) -> gam_math::jet_scalar::Order2<1> {
1501        use gam_math::jet_scalar::JetScalar;
1502        use statrs::function::gamma::ln_gamma;
1503        type O1 = gam_math::jet_scalar::Order2<1>;
1504
1505        let mu = O1::constant(mu_value);
1506        let theta = O1::variable(theta_value, 0);
1507        let tpm = theta.add(&mu);
1508        let theta_plus_y = theta.add(&O1::constant(yi));
1509        let loglik = order2_ln_gamma(&theta_plus_y)
1510            .sub(&order2_ln_gamma(&theta))
1511            .sub(&O1::constant(ln_gamma(yi + 1.0)))
1512            .add(&theta.mul(&theta.ln()))
1513            .sub(&theta.mul(&tpm.ln()))
1514            .add(&mu.ln().scale(yi))
1515            .sub(&tpm.ln().scale(yi));
1516        loglik.scale(-wi)
1517    }
1518
1519    pub(crate) fn beta_fisher_cross_info_mu_phi(mu: f64, phi: f64) -> f64 {
1520        let a = mu * phi;
1521        let b = (1.0 - mu) * phi;
1522        phi * (mu * gam_math::jet_tower::trigamma_derivative_stack(a)[0]
1523            - (1.0 - mu) * gam_math::jet_tower::trigamma_derivative_stack(b)[0])
1524    }
1525
1526    pub(crate) fn assert_close(label: &str, got: f64, want: f64, tol: f64) {
1527        assert!(
1528            (got - want).abs() <= tol,
1529            "{label}: got {got:.12e}, want {want:.12e}, |diff|={:.3e}",
1530            (got - want).abs()
1531        );
1532    }
1533
1534    #[test]
1535    pub(crate) fn beta_tower_mixed_channel_matches_cross_information_formula() {
1536        let mu = 0.1;
1537        let phi = 10.0;
1538        let a = mu * phi;
1539        let b = (1.0 - mu) * phi;
1540        let digamma_a = gam_math::jet_tower::digamma_derivative_stack(a)[0];
1541        let digamma_b = gam_math::jet_tower::digamma_derivative_stack(b)[0];
1542        let score_neutral_y = 1.0 / (1.0 + (-(digamma_a - digamma_b)).exp());
1543
1544        let tower = dispersion_beta_nll_order2(score_neutral_y, mu, phi, 1.0);
1545        let trigamma_a = std::f64::consts::PI * std::f64::consts::PI / 6.0;
1546        let trigamma_b = gam_math::jet_tower::trigamma_derivative_stack(b)[0];
1547        let analytic = phi * (mu * trigamma_a - (1.0 - mu) * trigamma_b);
1548        let helper = beta_fisher_cross_info_mu_phi(mu, phi);
1549
1550        assert!(
1551            analytic > 0.58,
1552            "audit example should have visibly nonzero cross information, got {analytic}"
1553        );
1554        assert_close("helper cross information", helper, analytic, 1e-12);
1555        assert_close("tower mixed channel", tower.h()[0][1], analytic, 1e-8);
1556
1557        let q = mu * (1.0 - mu);
1558        let eta_cross = beta_observed_cross_weight_eta(score_neutral_y, mu, phi, 1.0);
1559        assert_close(
1560            "eta-scale cross weight",
1561            eta_cross,
1562            q * phi * analytic,
1563            1e-8,
1564        );
1565    }
1566
1567    /// #932 oracle: the production `Order2<2>` evaluation of each dispersion
1568    /// row NLL must reproduce, channel-for-channel (value/grad/Hessian), the
1569    /// dense `Tower4<2>` evaluation of the same row expression.
1570    #[test]
1571    pub(crate) fn order2_matches_dense_tower_all_channels() {
1572        use gam_math::jet_scalar::{JetScalar, Order2};
1573        use gam_math::jet_tower::Tower4;
1574
1575        fn check_o2_vs_tower4(label: &str, o2: Order2<2>, t4: Tower4<2>) {
1576            let band = |a: f64, b: f64| 1e-9 + 1e-9 * a.abs().max(b.abs());
1577            assert!(
1578                (o2.value() - t4.v).abs() <= band(o2.value(), t4.v),
1579                "{label} value: {} vs {}",
1580                o2.value(),
1581                t4.v
1582            );
1583            for a in 0..2 {
1584                assert!(
1585                    (o2.g()[a] - t4.g[a]).abs() <= band(o2.g()[a], t4.g[a]),
1586                    "{label} grad[{a}]: {} vs {}",
1587                    o2.g()[a],
1588                    t4.g[a]
1589                );
1590                for b in 0..2 {
1591                    assert!(
1592                        (o2.h()[a][b] - t4.h[a][b]).abs() <= band(o2.h()[a][b], t4.h[a][b]),
1593                        "{label} hess[{a}][{b}]: {} vs {}",
1594                        o2.h()[a][b],
1595                        t4.h[a][b]
1596                    );
1597                }
1598            }
1599        }
1600
1601        let wi = 1.7_f64;
1602        // NB2: (μ, θ).
1603        for &(yi, mu, theta) in &[(0.0, 1.2, 3.0), (4.0, 2.5, 0.7), (10.0, 0.6, 5.0)] {
1604            check_o2_vs_tower4(
1605                "nb",
1606                dispersion_nb_nll_order2(yi, mu, theta, wi),
1607                test_support::dispersion_nb_nll_generic::<Tower4<2>>(yi, mu, theta, wi),
1608            );
1609        }
1610        // Gamma: (μ, ν).
1611        for &(yi, mu, nu) in &[
1612            (0.5_f64, 1.1_f64, 2.0_f64),
1613            (3.0, 4.0, 0.9),
1614            (1.0, 0.3, 6.0),
1615        ] {
1616            let y_pos = yi.max(1e-300);
1617            check_o2_vs_tower4(
1618                "gamma",
1619                dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi),
1620                test_support::dispersion_gamma_nll_generic::<Tower4<2>>(yi, y_pos, mu, nu, wi),
1621            );
1622        }
1623        // Beta: (μ, φ).
1624        for &(yi, mu, phi) in &[(0.3, 0.4, 5.0), (0.9, 0.6, 12.0), (0.01, 0.2, 3.0)] {
1625            check_o2_vs_tower4(
1626                "beta",
1627                dispersion_beta_nll_order2(yi, mu, phi, wi),
1628                test_support::dispersion_beta_nll_generic::<Tower4<2>>(yi, mu, phi, wi),
1629            );
1630        }
1631        // Tweedie: (η_μ, η_d), both density branches.
1632        for &(yi, eta_mu, eta_d, p) in &[
1633            (0.0, 0.4, -0.3, 1.5),
1634            (2.5, -0.2, 0.5, 1.3),
1635            (0.0, 1.0, 0.1, 1.7),
1636            (5.0, 0.7, -0.6, 1.6),
1637        ] {
1638            check_o2_vs_tower4(
1639                "tweedie",
1640                dispersion_tweedie_nll_generic::<Order2<2>>(yi, eta_mu, eta_d, p, wi),
1641                dispersion_tweedie_nll_generic::<Tower4<2>>(yi, eta_mu, eta_d, p, wi),
1642            );
1643        }
1644    }
1645
    /// #1591 prune oracle: the pruned single-axis (`K=1`) dispersion towers
    /// reproduce, `to_bits`-exactly, the CONSUMED channels (`value`, dispersion-
    /// axis `g`/`h`) of the full `Order2<2>` towers. The sweep runs 600
    /// randomized iterations; each iteration checks one NB, one Gamma and one
    /// Beta row plus three Tweedie rows (a `y = 0` row, a `y > 0` row and a
    /// fixed row at the η-clamp boundary), i.e. more than 2000 rows in total.
    /// Each family's value-only `*_neg_loglik` path is additionally pinned to
    /// the negated tower value. This is the bit-identity guarantee that the
    /// K-prune changes no observable float.
    ///
    /// The inputs come from a single LCG stream, so the order of the `next()`
    /// calls below determines which rows are tested.
    #[test]
    pub(crate) fn pruned_disp_towers_bit_identical_to_full_order2() {
        use gam_math::jet_scalar::{JetScalar, Order2};

        // Deterministic LCG so the sweep is reproducible without an rng dep.
        // `next()` maps the top 53 bits of the state to a float in [0, 1).
        let mut state: u64 = 0x9E3779B97F4A7C15;
        let mut next = || {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            ((state >> 11) as f64) / ((1u64 << 53) as f64)
        };
        // Compare floats by raw bit pattern rather than by numeric equality.
        let bits = |x: f64| x.to_bits();

        let n_per = 600; // 600 rows × 4 families (Tweedie ×2 branches) > 2000.
        for _ in 0..n_per {
            // Shared per-iteration weight and a count-valued response for NB.
            let wi = 0.25 + 3.0 * next();
            let yi_count = (next() * 12.0).floor();

            // NB: full O2<2> seeds (μ, θ); pruned seeds θ only.
            {
                let mu = (0.05 + 4.0 * next()).max(1e-300);
                let theta = (0.05 + 6.0 * next()).max(1e-12);
                let full = dispersion_nb_nll_order2(yi_count, mu, theta, wi);
                let prn = dispersion_nb_disp_order2(yi_count, mu, theta, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "nb value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "nb grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "nb hess");
                // Value+gradient-only production tower (`Order1`, trigamma-free):
                // its consumed `value`/`g[0]` must match the `Order2` form
                // bit-for-bit (the observed Hessian it drops is unused by the NB2
                // Fisher-scoring row kernel).
                let prn1 = dispersion_nb_disp_order1(yi_count, mu, theta, wi);
                assert_eq!(bits(prn.value()), bits(prn1.value()), "nb order1 value");
                assert_eq!(bits(prn.g()[0]), bits(prn1.g()[0]), "nb order1 grad");
                // value-only path == -tower.value(), bit-for-bit.
                assert_eq!(
                    bits(dispersion_nb_neg_loglik(yi_count, mu, theta, wi)),
                    bits(-prn.value()),
                    "nb value-only"
                );
            }
            // Gamma: seeds (μ, ν) / ν.
            {
                let mu = (0.05 + 4.0 * next()).max(1e-300);
                let nu = (0.05 + 6.0 * next()).max(1e-12);
                let yi = 0.01 + 8.0 * next();
                let y_pos = yi.max(1e-300);
                let full = dispersion_gamma_nll_order2(yi, y_pos, mu, nu, wi);
                let prn = dispersion_gamma_disp_order2(yi, y_pos, mu, nu, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "gamma value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "gamma grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "gamma hess");
                assert_eq!(
                    bits(dispersion_gamma_neg_loglik(yi, y_pos, mu, nu, wi)),
                    bits(-prn.value()),
                    "gamma value-only"
                );
            }
            // Beta value-only path vs full K=2 tower value.
            {
                let mu = (1e-6 + (1.0 - 2e-6) * next()).clamp(1e-12, 1.0 - 1e-12);
                let phi = (0.05 + 20.0 * next()).max(1e-12);
                let yi = next();
                let full = dispersion_beta_nll_order2(yi, mu, phi, wi);
                assert_eq!(
                    bits(dispersion_beta_neg_loglik(yi, mu, phi, wi)),
                    bits(-full.value()),
                    "beta value-only"
                );
            }
            // Tweedie: seeds (η_μ, η_d) / η_d, both density branches & clamp edge.
            // Row 1 has y = 0, row 2 has y > 0, and row 3 starts outside
            // ±DISPERSION_ETA_CLAMP so the clamp below lands exactly on the bound.
            for &(yi, eta_mu, eta_d, p) in &[
                (0.0_f64, -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
                (0.01 + 9.0 * next(), -4.0 + 8.0 * next(), -4.0 + 8.0 * next(), 1.1 + 0.8 * next()),
                (3.0, -DISPERSION_ETA_CLAMP - 5.0, DISPERSION_ETA_CLAMP + 5.0, 1.5),
            ] {
                // Both towers receive the already-clamped predictors.
                let em = eta_mu.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
                let ed = eta_d.clamp(-DISPERSION_ETA_CLAMP, DISPERSION_ETA_CLAMP);
                let full = dispersion_tweedie_nll_generic::<Order2<2>>(yi, em, ed, p, wi);
                let prn = dispersion_tweedie_disp_order2(yi, em, ed, p, wi);
                assert_eq!(bits(full.value()), bits(prn.value()), "tweedie value");
                assert_eq!(bits(full.g()[1]), bits(prn.g()[0]), "tweedie grad");
                assert_eq!(bits(full.h()[1][1]), bits(prn.h()[0][0]), "tweedie hess");
                assert_eq!(
                    bits(dispersion_tweedie_neg_loglik(yi, em, ed, p, wi)),
                    bits(-prn.value()),
                    "tweedie value-only"
                );
            }
        }
    }
1742
1743    #[test]
1744    pub(crate) fn orthogonal_dispersion_families_report_zero_cross_weight() {
1745        let cases = [
1746            DispersionFamilyKind::NegativeBinomial,
1747            DispersionFamilyKind::Gamma,
1748            DispersionFamilyKind::Tweedie { p: 1.5 },
1749        ];
1750        for kind in cases {
1751            let got = dispersion_row_cross_weight(kind, 1.25, 0.2, -0.3, 2.0);
1752            assert_close(kind.family_tag(), got, 0.0, 1e-12);
1753        }
1754    }
1755
1756    /// Speed-path guard (#932): `evaluate` / `log_likelihood_only` materialize
1757    /// the row-kernel map in parallel for large `n`, then reduce SERIALLY in
1758    /// index order. This pins the parallel output (log-likelihood + both
1759    /// blocks' working response/weight vectors) to a hand-rolled serial
1760    /// reference so CI catches any reassociation or row-misindex regression.
1761    /// `n` sits well above `DISPERSION_PARALLEL_ROW_THRESHOLD`, and the test
1762    /// runs on the main thread (not a rayon worker), so the parallel branch is
1763    /// the one exercised. Because the reduction order is preserved the match is
1764    /// in fact bit-exact; the `1e-9` band is the contract floor.
1765    #[test]
1766    pub(crate) fn parallel_evaluate_matches_serial_reference() {
1767        let n = DISPERSION_PARALLEL_ROW_THRESHOLD * 3 + 7;
1768        // Deterministic LCG row data (no rng dependency).
1769        let mut state: u64 = 0xD1B5_4A32_D192_ED03;
1770        let mut next = || {
1771            state = state
1772                .wrapping_mul(6364136223846793005)
1773                .wrapping_add(1442695040888963407);
1774            ((state >> 11) as f64) / ((1u64 << 53) as f64)
1775        };
1776
1777        for kind in [
1778            DispersionFamilyKind::NegativeBinomial,
1779            DispersionFamilyKind::Gamma,
1780            DispersionFamilyKind::Beta,
1781            DispersionFamilyKind::Tweedie { p: 1.5 },
1782        ] {
1783            let y = Array1::from_shape_fn(n, |_| match kind {
1784                DispersionFamilyKind::Beta => 1e-3 + (1.0 - 2e-3) * next(),
1785                DispersionFamilyKind::NegativeBinomial => (next() * 12.0).floor(),
1786                _ => 0.05 + 8.0 * next(),
1787            });
1788            let weights = Array1::from_shape_fn(n, |_| 0.25 + 2.0 * next());
1789            let eta_mu = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());
1790            let eta_d = Array1::from_shape_fn(n, |_| -1.0 + 2.0 * next());
1791
1792            let family = DispersionGlmLocationScaleFamily {
1793                kind,
1794                y: y.clone(),
1795                weights: weights.clone(),
1796            };
1797            let states = vec![
1798                ParameterBlockState {
1799                    beta: Array1::zeros(0),
1800                    eta: eta_mu.clone(),
1801                },
1802                ParameterBlockState {
1803                    beta: Array1::zeros(0),
1804                    eta: eta_d.clone(),
1805                },
1806            ];
1807
1808            // Serial reference, computed exactly as the pre-parallel loop did.
1809            let mut ll_ref = 0.0;
1810            let mut mw_ref = Array1::<f64>::zeros(n);
1811            let mut mr_ref = Array1::<f64>::zeros(n);
1812            let mut dw_ref = Array1::<f64>::zeros(n);
1813            let mut dr_ref = Array1::<f64>::zeros(n);
1814            for i in 0..n {
1815                let row = dispersion_row_kernel(kind, y[i], eta_mu[i], eta_d[i], weights[i]);
1816                if row.loglik.is_finite() {
1817                    ll_ref += row.loglik;
1818                }
1819                mw_ref[i] = row.mean_weight.max(0.0);
1820                mr_ref[i] = row.mean_response;
1821                dw_ref[i] = row.disp_weight.max(0.0);
1822                dr_ref[i] = row.disp_response;
1823            }
1824
1825            let eval = family.evaluate(&states).expect("parallel evaluate");
1826            assert_close(
1827                &format!("{kind:?} evaluate log-likelihood"),
1828                eval.log_likelihood,
1829                ll_ref,
1830                1e-9,
1831            );
1832
1833            let BlockWorkingSet::Diagonal {
1834                working_response: mr,
1835                working_weights: mw,
1836            } = &eval.blockworking_sets[0]
1837            else {
1838                panic!("mean block not diagonal");
1839            };
1840            let BlockWorkingSet::Diagonal {
1841                working_response: dr,
1842                working_weights: dw,
1843            } = &eval.blockworking_sets[1]
1844            else {
1845                panic!("dispersion block not diagonal");
1846            };
1847            for i in 0..n {
1848                assert_close("mean weight", mw[i], mw_ref[i], 1e-9);
1849                assert_close("mean response", mr[i], mr_ref[i], 1e-9);
1850                assert_close("disp weight", dw[i], dw_ref[i], 1e-9);
1851                assert_close("disp response", dr[i], dr_ref[i], 1e-9);
1852            }
1853
1854            // `log_likelihood_only` takes the same parallel-then-serial-sum
1855            // path; its value-only kernel is bit-identical to evaluate's loglik.
1856            let ll_only = family
1857                .log_likelihood_only(&states)
1858                .expect("parallel log_likelihood_only");
1859            assert_close(
1860                &format!("{kind:?} log_likelihood_only"),
1861                ll_only,
1862                ll_ref,
1863                1e-9,
1864            );
1865        }
1866    }
1867}