Skip to main content

gam_solve/pirls/
curvature.rs

//! Curvature primitives: the variance-function jet, observed-information
//! Hessian weights, and the weight-family / weight-link classification used to
//! choose between Fisher and observed curvature per family.
4
5use super::*;
6
/// Variance function `V(μ)` of an exponential-dispersion family together with
/// its first four derivatives with respect to `μ`, all evaluated at one `μ`.
///
/// The jet is five plain `f64`s, so it derives `Copy`/`Clone` (cheap to pass by
/// value), `PartialEq` (comparable in tests) and `Debug` (printable in
/// diagnostics).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VarianceJet {
    /// `V(μ)`.
    pub v: f64,
    /// `V'(μ)`.
    pub v1: f64,
    /// `V''(μ)`.
    pub v2: f64,
    /// `V'''(μ)`.
    pub v3: f64,
    /// `V''''(μ)`.
    pub v4: f64,
}
14
15impl VarianceJet {
16    /// Lower floor on μ before evaluating power-law variance functions, so that
17    /// `μ^(p−k)` derivatives stay finite as μ → 0 instead of producing inf/NaN.
18    const VARIANCE_MU_FLOOR: f64 = 1e-10;
19
20    /// Bernoulli / binomial variance V(μ) = μ(1−μ).
21    #[inline]
22    pub fn bernoulli(mu: f64) -> Self {
23        Self {
24            v: mu * (1.0 - mu),
25            v1: 1.0 - 2.0 * mu,
26            v2: -2.0,
27            v3: 0.0,
28            v4: 0.0,
29        }
30    }
31
32    /// Poisson variance V(μ) = μ.
33    #[inline]
34    pub fn poisson(mu: f64) -> Self {
35        Self {
36            v: mu,
37            v1: 1.0,
38            v2: 0.0,
39            v3: 0.0,
40            v4: 0.0,
41        }
42    }
43
44    /// Gamma variance V(μ) = μ².
45    #[inline]
46    pub fn gamma(mu: f64) -> Self {
47        Self {
48            v: mu * mu,
49            v1: 2.0 * mu,
50            v2: 2.0,
51            v3: 0.0,
52            v4: 0.0,
53        }
54    }
55
56    /// Tweedie variance V(μ) = μ^p.
57    #[inline]
58    pub fn tweedie(mu: f64, p: f64) -> Self {
59        let mu = mu.max(Self::VARIANCE_MU_FLOOR);
60        Self {
61            v: mu.powf(p),
62            v1: p * mu.powf(p - 1.0),
63            v2: p * (p - 1.0) * mu.powf(p - 2.0),
64            v3: p * (p - 1.0) * (p - 2.0) * mu.powf(p - 3.0),
65            v4: p * (p - 1.0) * (p - 2.0) * (p - 3.0) * mu.powf(p - 4.0),
66        }
67    }
68
69    /// Negative-binomial variance V(μ) = μ + μ² / theta.
70    #[inline]
71    pub fn negative_binomial(mu: f64, theta: f64) -> Self {
72        let mu = mu.max(Self::VARIANCE_MU_FLOOR);
73        let inv_theta = if valid_negbin_theta(theta) {
74            1.0 / theta
75        } else {
76            f64::NAN
77        };
78        Self {
79            v: mu + mu * mu * inv_theta,
80            v1: 1.0 + 2.0 * mu * inv_theta,
81            v2: 2.0 * inv_theta,
82            v3: 0.0,
83            v4: 0.0,
84        }
85    }
86
87    /// Gaussian (identity) variance V(μ) = 1.
88    #[inline]
89    pub fn gaussian() -> Self {
90        Self {
91            v: 1.0,
92            v1: 0.0,
93            v2: 0.0,
94            v3: 0.0,
95            v4: 0.0,
96        }
97    }
98
99    /// Binomial(n, p) variance V(p) = p(1−p), identical to Bernoulli.
100    ///
101    /// The trial count `n` enters as a prior-weight multiplier, not through
102    /// the variance function itself.
103    #[inline]
104    pub fn binomial_n(mu: f64) -> Self {
105        // V(μ) = μ(1−μ), same jet as Bernoulli
106        Self::bernoulli(mu)
107    }
108
109    /// Beta-regression variance V(μ) = μ(1−μ)/(1+φ).
110    #[inline]
111    pub fn beta(mu: f64, phi: f64) -> Self {
112        let scale = 1.0 / (1.0 + phi.max(1e-12));
113        let base = Self::bernoulli(mu);
114        Self {
115            v: base.v * scale,
116            v1: base.v1 * scale,
117            v2: base.v2 * scale,
118            v3: 0.0,
119            v4: 0.0,
120        }
121    }
122}
123
/// Relative part of the solver floor: the fraction of a row's Fisher weight
/// below which an observed-information weight is considered too small.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC: f64 = 1e-6;

/// Absolute part of the solver floor, applied when the relative part is
/// smaller than this value.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR: f64 = 1e-12;

/// Per-row floor `max(fisher · 1e-6, 1e-12)` that PIRLS uses to stabilize the
/// observed-information Hessian H = X' W X + S.
///
/// Negative (and NaN) Fisher weights are treated as zero before scaling, so
/// the result is never below `OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR`.
///
/// This is the single source of truth for the floor formula: the inner solver
/// (`solver_hessian_weights_into`) and the outer derivative path
/// (`outer_hessian_curvature_arrays`) both call it, so the stabilized `H` and
/// the outer `dH/dψ` are built from the **same** floored `W` and cannot drift
/// apart.
#[inline]
pub fn solver_hessian_weight_floor(fisher_weight: f64) -> f64 {
    let nonnegative_fisher = f64::max(fisher_weight, 0.0);
    let relative_floor = nonnegative_fisher * OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC;
    f64::max(relative_floor, OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)
}
143
144/// Build the (W, c, d) triple that matches PIRLS's stabilized H = X' W X + S.
145///
146/// PIRLS internally uses `W[i] = max(W_obs[i], floor(W_F[i]))` to keep H PD,
147/// but `pirls_result.finalweights` stores the **unfloored** observed weights.
148/// Reusing those directly in `∂H/∂ψ = X_τ' W X + … + X' diag(c · X_τ β̂) X`
149/// produces an operator that disagrees with `H` at every saturated row — a
150/// 5%-Frobenius bias that `tr(G_ε(H) · op)` amplifies by O(1/σ_min(H)),
151/// driving the analytic gradient off by orders of magnitude.
152///
153/// This helper returns the floored W, plus c and d masked to zero wherever
154/// the floor is active (so `∂W/∂η` is zero on the constant-floor branch).
155pub fn outer_hessian_curvature_arrays(
156    hessian_weights: gam_linalg::matrix::SignedWeightsView<'_>,
157    fisher_weights: gam_linalg::matrix::PsdWeightsView<'_>,
158    c_array: &Array1<f64>,
159    d_array: &Array1<f64>,
160    eta: &Array1<f64>,
161    inverse_link: &InverseLink,
162) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
163    let hessian_view = hessian_weights.view();
164    let fisher_view = fisher_weights.view();
165    let n = hessian_view.len();
166    let mut w_out = Array1::<f64>::zeros(n);
167    let mut c_out = Array1::<f64>::zeros(n);
168    let mut d_out = Array1::<f64>::zeros(n);
169    for i in 0..n {
170        let floor = solver_hessian_weight_floor(fisher_view[i]);
171        let w = hessian_view[i];
172        let clamp_active = eta_clamp_active(inverse_link, eta[i]);
173        let w_below_floor = !(w.is_finite() && w > floor);
174        if w_below_floor {
175            w_out[i] = floor;
176            c_out[i] = 0.0;
177            d_out[i] = 0.0;
178        } else if clamp_active {
179            w_out[i] = w;
180            c_out[i] = 0.0;
181            d_out[i] = 0.0;
182        } else {
183            w_out[i] = w;
184            c_out[i] = c_array[i];
185            d_out[i] = d_array[i];
186        }
187    }
188    (w_out, c_out, d_out)
189}
190
191#[inline]
192pub(crate) fn fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
193    likelihood.fixed_phi().unwrap_or(1.0)
194}
195
196#[inline]
197pub fn weight_family_for_glm_likelihood(likelihood: &GlmLikelihoodSpec) -> WeightFamily {
198    match &likelihood.spec.response {
199        ResponseFamily::Gaussian => WeightFamily::Gaussian,
200        ResponseFamily::Poisson => WeightFamily::Poisson,
201        ResponseFamily::Tweedie { p } => WeightFamily::Tweedie { p: *p },
202        ResponseFamily::NegativeBinomial { theta, .. } => {
203            WeightFamily::NegativeBinomial { theta: *theta }
204        }
205        ResponseFamily::Beta { phi } => WeightFamily::Beta { phi: *phi },
206        ResponseFamily::Gamma => WeightFamily::Gamma,
207        ResponseFamily::Binomial => WeightFamily::Binomial,
208        ResponseFamily::RoystonParmar => WeightFamily::Gaussian,
209    }
210}
211
212#[inline]
213pub(crate) fn weight_link_for_inverse_link(inverse_link: &InverseLink) -> WeightLink {
214    match inverse_link {
215        InverseLink::Standard(StandardLink::Identity) => WeightLink::Identity,
216        InverseLink::Standard(StandardLink::Log) => WeightLink::Log,
217        InverseLink::Standard(StandardLink::Logit) => WeightLink::Logit,
218        InverseLink::Standard(StandardLink::Probit)
219        | InverseLink::Standard(StandardLink::CLogLog)
220        | InverseLink::LatentCLogLog(_)
221        | InverseLink::Sas(_)
222        | InverseLink::BetaLogistic(_)
223        | InverseLink::Mixture(_) => WeightLink::Other,
224    }
225}
226
227#[inline]
228pub(crate) fn supports_observed_hessian_curvature_for_likelihood(
229    likelihood: &GlmLikelihoodSpec,
230    inverse_link: &InverseLink,
231) -> bool {
232    let spec = &likelihood.spec;
233    if matches!(spec.response, ResponseFamily::NegativeBinomial { .. }) {
234        return matches!(inverse_link, InverseLink::Standard(StandardLink::Log));
235    }
236    if matches!(spec.response, ResponseFamily::Gamma) {
237        return true;
238    }
239    if !matches!(spec.response, ResponseFamily::Binomial) {
240        return false;
241    }
242    matches!(
243        spec.link,
244        InverseLink::Standard(StandardLink::Probit)
245            | InverseLink::Standard(StandardLink::CLogLog)
246            | InverseLink::Sas(_)
247            | InverseLink::BetaLogistic(_)
248            | InverseLink::Mixture(_)
249    )
250}
251
252#[inline]
253pub(crate) fn eta_for_observed_hessian_jet(inverse_link: &InverseLink, eta: f64) -> f64 {
254    match inverse_link {
255        // Why: canonical links keep V(mu) representable across the full f64 eta range; only guard against inf.
256        InverseLink::Standard(StandardLink::Logit | StandardLink::Log) => {
257            eta.clamp(-ETA_CLAMP, ETA_CLAMP)
258        }
259        InverseLink::Standard(StandardLink::Identity) => eta,
260        // Why: probit mu=Phi(eta) saturates to 1.0 in f64 by |eta|~8.3; +/-6 keeps V=mu(1-mu) ~ 1e-9 representable.
261        InverseLink::Standard(StandardLink::Probit) => eta.clamp(-6.0, 6.0),
262        // Why: cloglog has mu~exp(eta) for eta<<0 (underflows below ~-23) and 1-mu~exp(-exp(eta)) collapses by eta=3.
263        InverseLink::Standard(StandardLink::CLogLog) | InverseLink::LatentCLogLog(_) => {
264            eta.clamp(-23.0, 3.0)
265        }
266        // Why: SAS / beta-logistic / mixture compose logistic-like sigmoids that saturate by |eta|~20 (logistic(20)~1-2e-9).
267        InverseLink::Sas(_) | InverseLink::BetaLogistic(_) | InverseLink::Mixture(_) => {
268            eta.clamp(-20.0, 20.0)
269        }
270    }
271}
272
273/// Returns true at rows where PIRLS clamped η (so the observed-info weights
274/// were computed at the clamped value, making `∂W/∂η` zero w.r.t. the
275/// **unclamped** η).  Outer REML/LAML derivative formulas must mask `c_obs`
276/// and `d_obs` to zero on these rows or the analytic ∂H/∂ψ disagrees with
277/// the H whose log-det we differentiate.
278#[inline]
279pub fn eta_clamp_active(inverse_link: &InverseLink, eta: f64) -> bool {
280    let clamped = eta_for_observed_hessian_jet(inverse_link, eta);
281    clamped != eta
282}
283
284/// Build solver-conditioned weights from the exact hessian weights.
285///
286/// The returned array applies a solver-only floor per observation so the
287/// Newton linear system X'W X + S stays numerically usable. This floor is
288/// purely a linear-algebra concern: the exact statistical weights stored in
289/// `lasthessian_weights` / `finalweights` are not affected.
290pub(crate) fn solver_hessian_weights_into(
291    hessian_weights: &Array1<f64>,
292    fisher_weights: &Array1<f64>,
293    out: &mut Array1<f64>,
294) {
295    if out.len() != hessian_weights.len() {
296        *out = Array1::<f64>::zeros(hessian_weights.len());
297    }
298    ndarray::Zip::from(out)
299        .and(hessian_weights)
300        .and(fisher_weights)
301        .par_for_each(|o, &w, &fw| {
302            let floor = solver_hessian_weight_floor(fw);
303            *o = if w.is_finite() && w > floor { w } else { floor };
304        });
305}
306
307/// Compute vectorised observed-information curvature arrays (w_obs, c_obs, d_obs)
308/// for the Hessian surface at the mode.
309///
310/// This function is the primary entry point for obtaining the observed weights
311/// that flow into the outer REML/LAML Hessian H_obs = X' W_obs X + S. The
312/// observed corrections include residual-dependent terms that vanish for
313/// canonical links but are nonzero for probit, cloglog, SAS, mixture, Gamma-log,
314/// and other flexible links.
315///
316/// The output arrays are:
317/// - `hessian_weights`: W_obs per observation (exact; solver floor applied separately).
318/// - `hessian_c`: c_obs = dW_obs/deta per observation (for outer gradient C[v]).
319/// - `hessian_d`: d_obs = d^2W_obs/deta^2 per observation (for outer Hessian Q[v_k,v_l]).
320///
321/// See `observed_weight_noncanonical` for the per-observation formulas and
322/// response.md Section 3 for the mathematical justification of why observed
323/// (not Fisher) information is required.
324pub(crate) fn compute_observed_hessian_curvature_arrays_into(
325    likelihood: &GlmLikelihoodSpec,
326    inverse_link: &InverseLink,
327    eta: &Array1<f64>,
328    y: ArrayView1<'_, f64>,
329    fisher_weights: &Array1<f64>,
330    priorweights: ArrayView1<'_, f64>,
331    hessian_weights: &mut Array1<f64>,
332    hessian_c: &mut Array1<f64>,
333    hessian_d: &mut Array1<f64>,
334) -> Result<(), EstimationError> {
335    assert!(supports_observed_hessian_curvature_for_likelihood(
336        likelihood,
337        inverse_link
338    ));
339    let n = eta.len();
340    if hessian_weights.len() != n {
341        *hessian_weights = Array1::<f64>::zeros(n);
342    }
343    if hessian_c.len() != n {
344        *hessian_c = Array1::<f64>::zeros(n);
345    }
346    if hessian_d.len() != n {
347        *hessian_d = Array1::<f64>::zeros(n);
348    }
349
350    let weight_family = weight_family_for_glm_likelihood(likelihood);
351    let weight_link = weight_link_for_inverse_link(inverse_link);
352    let phi = fixed_glm_dispersion(likelihood);
353
354    // Parallel per-row weight assembly. At large scale (n = 320k) this loop
355    // dominates non-canonical paths because each row independently evaluates
356    // inverse-link jets and residual-dependent observed curvature. Write
357    // directly into reusable output slices rather than collecting row tuples,
358    // which removes an O(n) temporary allocation on every PIRLS update.
359    hessian_weights
360        .as_slice_mut()
361        .expect("hessian weights must be contiguous")
362        .par_iter_mut()
363        .zip(
364            hessian_c
365                .as_slice_mut()
366                .expect("hessian c must be contiguous")
367                .par_iter_mut(),
368        )
369        .zip(
370            hessian_d
371                .as_slice_mut()
372                .expect("hessian d must be contiguous")
373                .par_iter_mut(),
374        )
375        .enumerate()
376        .try_for_each(|(i, ((w_out, c_out), d_out))| -> Result<(), EstimationError> {
377            let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
378            // Why: closed-form observed_weight_noncanonical requires (mu, d1..d3, h4) at one consistent eta;
379            // mixing PIRLS-state jets at unclamped eta with h4 at eta_used produced 0/0 in phi_v* divisions,
380            // surfacing as: "observed Hessian curvature is not positive finite at row N: observed=NaN, fisher=0".
381            let jet =
382                crate::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta_used)?;
383            let h4 = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
384                inverse_link, eta_used,
385            )?;
386            let (w_obs, c_obs, d_obs) = observed_weight_dispatch(
387                weight_family,
388                weight_link,
389                eta_used,
390                y[i],
391                jet.mu,
392                phi,
393                priorweights[i].max(0.0),
394                jet,
395                h4,
396            );
397            let fisher_weight = fisher_weights[i].max(0.0);
398            // A *finite* but non-positive observed weight is NOT a failure: the
399            // observed information `W_obs = W_Fisher - (y-μ)·B` legitimately goes
400            // indefinite on individual rows for a non-canonical link (probit,
401            // cloglog, SAS, and — critically for #1598 — a blended/mixture link)
402            // whenever a large residual flips the sign of the residual-dependent
403            // correction. The inner Newton system never uses this raw value: it
404            // is clamped to the SPD floor `max(W_Fisher·1e-6, 1e-12)` by
405            // `solver_hessian_weights_into`, and the outer REML/LAML derivative
406            // path applies the *same* floor through `outer_hessian_curvature_arrays`
407            // (which also zeroes c/d on the floored row). Both consumers are
408            // designed precisely to absorb an indefinite W_obs, so hard-bailing
409            // here defeats that stabilization and aborts an otherwise well-posed
410            // solve — the mixture/SAS joint link fit on data its own pure
411            // components fit trivially (clean logit under blended(logit, probit)).
412            //
413            // We therefore reject ONLY a genuinely non-finite (NaN/Inf) weight,
414            // which signals a broken jet rather than benign indefiniteness, and
415            // pass finite values (including non-positive ones) straight through to
416            // the flooring consumers. Likewise `c_obs`/`d_obs` only need to be
417            // finite; they are zeroed automatically on any floored row downstream.
418            if !w_obs.is_finite() {
419                crate::bail_invalid_estim!(
420                    "observed Hessian curvature is not finite at row {i}: observed={w_obs}, fisher={fisher_weight}"
421                );
422            }
423            if !c_obs.is_finite() || !d_obs.is_finite() {
424                crate::bail_invalid_estim!(
425                    "observed Hessian curvature derivatives are non-finite at row {i}: c={c_obs}, d={d_obs}"
426                );
427            }
428            *w_out = w_obs;
429            *c_out = c_obs;
430            *d_out = d_obs;
431            Ok(())
432        })
433}
434
435pub(crate) fn compute_observed_hessian_curvature_arrays(
436    likelihood: &GlmLikelihoodSpec,
437    inverse_link: &InverseLink,
438    eta: &Array1<f64>,
439    y: ArrayView1<'_, f64>,
440    fisher_weights: &Array1<f64>,
441    priorweights: ArrayView1<'_, f64>,
442) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
443    let n = eta.len();
444    let mut hessian_weights = Array1::<f64>::zeros(n);
445    let mut hessian_c = Array1::<f64>::zeros(n);
446    let mut hessian_d = Array1::<f64>::zeros(n);
447    compute_observed_hessian_curvature_arrays_into(
448        likelihood,
449        inverse_link,
450        eta,
451        y,
452        fisher_weights,
453        priorweights,
454        &mut hessian_weights,
455        &mut hessian_c,
456        &mut hessian_d,
457    )?;
458    Ok((hessian_weights, hessian_c, hessian_d))
459}
460
461/// Per-observation observed-information weights and their first two
462/// eta-derivatives for a general exponential-dispersion family with a
463/// noncanonical link.
464///
465/// The observed weight differs from the Fisher (expected) weight by a
466/// residual-dependent correction (see response.md Section 3):
467///
468///   W_obs = W_Fisher - (y - mu) * B
469///   B = (h'' V - h'^2 V') / (phi V^2)
470///
471///   c_obs = c_Fisher + h' * B - (y - mu) * B_eta
472///   d_obs = d_Fisher + h'' * B + 2*h' * B_eta - (y - mu) * B_etaeta
473///
474/// For canonical links (for example logit-Binomial and log-Poisson), B = 0
475/// so observed = Fisher and no correction is needed.
476///
477/// These observed quantities are required for:
478/// 1. The outer REML/LAML Hessian H_obs = X' W_obs X + S (log|H| term).
479/// 2. The outer gradient's C[v] correction (uses c_obs).
480/// 3. The outer Hessian's Q[v_k, v_l] correction (uses d_obs).
481///
482/// Using Fisher weights in the outer REML would yield a PQL-type surrogate
483/// rather than the exact Laplace approximation.
484///
485/// # Arguments
486/// * `y`   -- response value
487/// * `mu`  -- fitted mean h(eta)
488/// * `h1`...`h4` -- inverse-link derivatives h'(eta) ... h''''(eta)
489/// * `vj`  -- variance-function jet (V, V', V'', V''') evaluated at mu
490/// * `phi` -- dispersion parameter (1.0 for Bernoulli/Poisson)
491/// * `pw`  -- prior weight for this observation
492///
493/// # Returns
494/// `(w_obs, c_obs, d_obs)` -- the observed weight and its first two
495/// eta-derivatives, all pre-multiplied by `pw`.
496#[inline]
497pub fn observed_weight_noncanonical(
498    y: f64,
499    mu: f64,
500    h1: f64,
501    h2: f64,
502    h3: f64,
503    h4: f64,
504    vj: VarianceJet,
505    phi: f64,
506    pw: f64,
507) -> (f64, f64, f64) {
508    let VarianceJet {
509        v,
510        v1,
511        v2,
512        v3,
513        v4: _,
514    } = vj;
515    let phi_v = phi * v;
516    let phi_v2 = phi * v * v;
517    let phi_v3 = phi * v * v * v;
518
519    // ---- Fisher weight and derivatives ----
520    let h1_sq = h1 * h1;
521    let w_f = h1_sq / phi_v;
522
523    // c_F = (2 h₁ h₂ V − h₁³ V₁) / (φ V²)
524    let n0 = h1_sq; // numerator of w_F
525    let n1 = 2.0 * h1 * h2; // ∂(h₁²)/∂η
526    let n2 = 2.0 * (h2 * h2 + h1 * h3); // ∂²(h₁²)/∂η²
527    let vd1 = h1 * v1; // ∂V/∂η = V'·h'
528    let vd2 = h2 * v1 + h1_sq * v2; // ∂²V/∂η²
529
530    let c_f = (n1 * v - n0 * vd1) / phi_v2;
531
532    // d_F = ∂c_F/∂η via quotient rule on c_F = (n1·v − n0·vd1) / (φ·v²)
533    // numerator of c_F and its η-derivative (cross terms cancel):
534    let numer_cf = n1 * v - n0 * vd1;
535    let dnumer_cf = n2 * v - n0 * vd2;
536    let d_f = (dnumer_cf * v - 2.0 * numer_cf * vd1) / (phi_v3);
537
538    // ---- Observed correction term B and its η-derivatives ----
539    // B = (h₂ V − h₁² V₁) / (φ V²)
540    let b_num = h2 * v - h1_sq * v1;
541    let b = b_num / phi_v2;
542
543    // B_η = (h₃ V² − 3 h₁ h₂ V V₁ − h₁³ V V₂ + 2 h₁³ V₁²) / (φ V³)
544    let b_eta_num =
545        h3 * v * v - 3.0 * h1 * h2 * v * v1 - h1_sq * h1 * v * v2 + 2.0 * h1_sq * h1 * v1 * v1;
546    let b_eta = b_eta_num / phi_v3;
547
548    // B_ηη = ∂B_η/∂η.
549    //
550    // We differentiate b_eta_num / (φ V³) using the quotient rule.
551    //
552    // Numerator derivative of b_eta_num w.r.t. η, using chain rule ∂/∂η = h₁·∂/∂μ
553    // for the V-dependent parts:
554    //
555    //   ∂/∂η [h₃ V²]               = h₄ V² + 2 h₃ V h₁ V₁
556    //   ∂/∂η [3 h₁ h₂ V V₁]        = 3(h₂² + h₁ h₃)V V₁ + 3 h₁ h₂(h₁ V₁² + V h₁ V₂)
557    //   ∂/∂η [h₁³ V V₂]            = 3 h₁² h₂ V V₂ + h₁³(h₁ V₁ V₂ + V h₁ V₃)
558    //   ∂/∂η [2 h₁³ V₁²]           = 6 h₁² h₂ V₁² + 4 h₁³ V₁ h₁ V₂
559    //                                = 6 h₁² h₂ V₁² + 4 h1_sq * h1_sq * v1 * v2
560    //
561    // Denominator derivative: ∂/∂η [φ V³] = 3 φ V² h₁ V₁.
562
563    let h1_cu = h1_sq * h1;
564    let h1_qu = h1_sq * h1_sq;
565
566    let db_eta_num = h4 * v * v + 2.0 * h3 * v * h1 * v1
567        - 3.0 * (h2 * h2 + h1 * h3) * v * v1
568        - 3.0 * h1 * h2 * (h1 * v1 * v1 + v * h1 * v2)
569        - 3.0 * h1_sq * h2 * v * v2
570        - h1_cu * (h1 * v1 * v2 + v * h1 * v3)
571        + 6.0 * h1_sq * h2 * v1 * v1
572        + 4.0 * h1_qu * v1 * v2;
573
574    let phi_v4 = phi_v3 * v;
575    let b_etaeta = (db_eta_num * v - 3.0 * b_eta_num * h1 * v1) / phi_v4;
576
577    // ---- Assemble observed quantities ----
578    let resid = y - mu;
579
580    let w_obs = w_f - resid * b;
581    let c_obs = c_f + h1 * b - resid * b_eta;
582    let d_obs = d_f + h2 * b + 2.0 * h1 * b_eta - resid * b_etaeta;
583
584    (pw * w_obs, pw * c_obs, pw * d_obs)
585}
586
587/// Per-observation third η-derivative of the observed-information weight,
588/// `e_obs := ∂³W_obs/∂η³`, for a general exponential-dispersion family with
589/// any (canonical or non-canonical) link.
590///
591/// Closed-form derivation:
592///   Define `T(η) := h₁(η)/(φ V(μ(η)))`. Then
593///   * Fisher weight `W_F = h₁ · T`
594///   * Observed correction `B = T'`, so `B_η = T''`, `B_ηη = T'''`,
595///     `B_ηηη = T''''`
596///   * `W_obs = W_F − (y−μ) · T'`
597///
598/// Differentiating three times:
599///   `∂³W_obs/∂η³ = W_F''' + h₃·T' + 3 h₂·T'' + 3 h₁·T''' − (y−μ)·T''''`
600///
601/// `T` is computed via Leibniz on `T·Q = h₁` with `Q = φV`; `W_F` via
602/// Leibniz on `W_F·1 = h₁·T` (product rule).
603///
604/// All inverse-link derivatives `h₁..h₅` and variance-function derivatives
605/// `V..V₄` are required as inputs. Caller supplies them.
606///
607/// Returns `pw * e_obs` (pre-multiplied by the prior weight) so the result
608/// scales identically to `(w_obs, c_obs, d_obs)` from
609/// `observed_weight_noncanonical`.
610#[inline]
611pub fn e_obs_from_jets(
612    y: f64,
613    mu: f64,
614    h1: f64,
615    h2: f64,
616    h3: f64,
617    h4: f64,
618    h5: f64,
619    vj: VarianceJet,
620    phi: f64,
621    pw: f64,
622) -> f64 {
623    let VarianceJet { v, v1, v2, v3, v4 } = vj;
624    let q = phi * v;
625
626    // Q = φV and its η-derivatives.
627    //   Q'    = φ V₁ h₁
628    //   Q''   = φ (V₁ h₂ + V₂ h₁²)
629    //   Q'''  = φ (V₁ h₃ + 3 V₂ h₁ h₂ + V₃ h₁³)
630    //   Q'''' = φ (V₁ h₄ + 4 V₂ h₁ h₃ + 3 V₂ h₂² + 6 V₃ h₁² h₂ + V₄ h₁⁴)
631    let h1_sq = h1 * h1;
632    let h1_cu = h1_sq * h1;
633    let h1_qu = h1_sq * h1_sq;
634
635    let q1 = phi * v1 * h1;
636    let q2 = phi * (v1 * h2 + v2 * h1_sq);
637    let q3 = phi * (v1 * h3 + 3.0 * v2 * h1 * h2 + v3 * h1_cu);
638    let q4 = phi
639        * (v1 * h4 + 4.0 * v2 * h1 * h3 + 3.0 * v2 * h2 * h2 + 6.0 * v3 * h1_sq * h2 + v4 * h1_qu);
640
641    // T = h₁/Q and T', T'', T''', T'''' via Leibniz on T·Q = h₁.
642    //   T'    = (h₂  − T·Q')/Q
643    //   T''   = (h₃  − 2 T'·Q' − T·Q'')/Q
644    //   T'''  = (h₄  − 3 T''·Q' − 3 T'·Q'' − T·Q''')/Q
645    //   T'''' = (h₅  − 4 T'''·Q' − 6 T''·Q'' − 4 T'·Q''' − T·Q'''')/Q
646    let t0 = h1 / q;
647    let t1 = (h2 - t0 * q1) / q;
648    let t2 = (h3 - 2.0 * t1 * q1 - t0 * q2) / q;
649    let t3 = (h4 - 3.0 * t2 * q1 - 3.0 * t1 * q2 - t0 * q3) / q;
650    let t4 = (h5 - 4.0 * t3 * q1 - 6.0 * t2 * q2 - 4.0 * t1 * q3 - t0 * q4) / q;
651
652    // Fisher weight derivatives via product rule on W_F = h₁·T.
653    //   W_F^(0) = h₁ T
654    //   W_F^(1) = h₁ T₁ + h₂ T
655    //   W_F^(2) = h₁ T₂ + 2 h₂ T₁ + h₃ T
656    //   W_F^(3) = h₁ T₃ + 3 h₂ T₂ + 3 h₃ T₁ + h₄ T
657    let w_f3 = h1 * t3 + 3.0 * h2 * t2 + 3.0 * h3 * t1 + h4 * t0;
658
659    // Observed third derivative: differentiate W_obs = W_F − (y−μ)·T₁ thrice.
660    // (resid)' = −h₁, so iterating product rule yields
661    //   ∂³((y−μ)·T₁)/∂η³ = −h₃·T₁ − 3 h₂·T₂ − 3 h₁·T₃ + (y−μ)·T₄
662    let resid = y - mu;
663    let e_obs = w_f3 + h3 * t1 + 3.0 * h2 * t2 + 3.0 * h1 * t3 - resid * t4;
664
665    pw * e_obs
666}
667
// Direct (closed-form) observed-information weights for specific family-link
// combinations. These avoid the overhead of the generic noncanonical formula
// when the algebra simplifies.
671
/// Closed-form observed weights for a Gaussian response with log link:
/// y ~ N(μ, φ), μ = exp(η).
///
/// Returns `(w_obs, c_obs, d_obs)`, each already multiplied by the prior
/// weight `pw` (written ω below).
///
/// ```text
/// w_obs = ω μ(2μ − y) / φ
/// c_obs = ω μ(4μ − y) / φ
/// d_obs = ω μ(8μ − y) / φ
/// ```
#[inline]
pub fn observed_weight_gaussian_log(y: f64, mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let scale = pw / phi;
    // The three quantities differ only in the multiple of μ inside the bracket.
    let term = |k: f64| scale * mu * (k * mu - y);
    (term(2.0), term(4.0), term(8.0))
}
689
/// Closed-form observed weights for a Gaussian response with inverse link:
/// y ~ N(μ, φ), μ = 1/η.
///
/// Returns `(w_obs, c_obs, d_obs)`, each already multiplied by the prior
/// weight `pw` (written ω below).
///
/// ```text
/// w_obs = ω (3 − 2ηy) / (φ η⁴)
/// c_obs = 6ω (ηy − 2) / (φ η⁵)
/// d_obs = 12ω (5 − 2ηy) / (φ η⁶)
/// ```
#[inline]
pub fn observed_weight_gaussian_inverse(y: f64, eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let eta_sq = eta * eta;
    let eta_4th = eta_sq * eta_sq;
    let eta_5th = eta_4th * eta;
    let eta_6th = eta_4th * eta_sq;
    let eta_y = eta * y;
    let scale = pw / phi;
    (
        scale * (3.0 - 2.0 * eta_y) / eta_4th,
        scale * 6.0 * (eta_y - 2.0) / eta_5th,
        scale * 12.0 * (5.0 - 2.0 * eta_y) / eta_6th,
    )
}
712
713#[inline]
714pub(crate) fn observed_weight_binomial_logit_from_jet(
715    n_trials: f64,
716    jet: MixtureInverseLinkJet,
717    pw: f64,
718) -> (f64, f64, f64) {
719    let scale = pw * n_trials;
720    (scale * jet.d1, scale * jet.d2, scale * jet.d3)
721}
722
/// Family tag for the observed-information weight dispatch.
///
/// A simplified tag that pins down the variance function and nothing about
/// the link. [`observed_weight_dispatch`] uses it to pick a closed-form weight
/// specialization, and `variance_jet_for_weight_family` uses it to build the
/// matching variance jet.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFamily {
    /// Constant variance, `V(μ) = 1`.
    Gaussian,
    /// `V(μ) = μ(1 − μ)`.
    Binomial,
    /// `V(μ) = μ`.
    Poisson,
    /// `V(μ) = μ^p` with power parameter `p`.
    Tweedie { p: f64 },
    /// `V(μ) = μ + μ² / theta`.
    NegativeBinomial { theta: f64 },
    /// `V(μ) = μ(1 − μ) / (1 + φ)` with precision-like parameter `phi`.
    Beta { phi: f64 },
    /// `V(μ) = μ²`.
    Gamma,
}
738
/// Link tag for the observed-information weight dispatch.
///
/// Names the link function so that [`observed_weight_dispatch`] can select a
/// closed-form weight specialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightLink {
    /// Identity link.
    Identity,
    /// Log link, `μ = exp(η)`.
    Log,
    /// Logit link.
    Logit,
    /// Inverse link, `μ = 1/η`.
    Inverse,
    /// Any other link — falls back to the generic noncanonical formula.
    Other,
}
752
753#[inline]
754pub fn variance_jet_for_weight_family(family: WeightFamily, mu: f64) -> VarianceJet {
755    match family {
756        WeightFamily::Gaussian => VarianceJet::gaussian(),
757        WeightFamily::Binomial => VarianceJet::binomial_n(mu),
758        WeightFamily::Poisson => VarianceJet::poisson(mu),
759        WeightFamily::Tweedie { p } => VarianceJet::tweedie(mu, p),
760        WeightFamily::NegativeBinomial { theta } => VarianceJet::negative_binomial(mu, theta),
761        WeightFamily::Beta { phi } => VarianceJet::beta(mu, phi),
762        WeightFamily::Gamma => VarianceJet::gamma(mu),
763    }
764}
765
766/// Dispatch to closed-form observed-information weights for known family-link
767/// combinations, falling back to the generic noncanonical formula.
768///
769/// Returns `(w_obs, c_obs, d_obs)` pre-multiplied by the prior weight.
770///
771/// For the `Binomial + Logit` case, `n_trials` is passed as `phi` (dispersion
772/// slot is unused for binomial) and the prior weight controls the
773/// observation-level scaling. For all other cases, `phi` is the dispersion
774/// parameter.
775///
776/// `jet` and `h4` are the inverse-link derivatives used by the generic
777/// noncanonical fallback path. They may be zero for the specialized paths.
778pub fn observed_weight_dispatch(
779    family: WeightFamily,
780    link: WeightLink,
781    eta: f64,
782    y: f64,
783    mu: f64,
784    phi: f64,
785    prior_weight: f64,
786    jet: MixtureInverseLinkJet,
787    h4: f64,
788) -> (f64, f64, f64) {
789    match (family, link) {
790        (WeightFamily::Gaussian, WeightLink::Log) => {
791            observed_weight_gaussian_log(y, mu, phi, prior_weight)
792        }
793        (WeightFamily::Gaussian, WeightLink::Inverse) => {
794            observed_weight_gaussian_inverse(y, eta, phi, prior_weight)
795        }
796        (WeightFamily::Binomial, WeightLink::Logit) => {
797            observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
798        }
799        _ => {
800            // Generic noncanonical path via the full variance-function jet.
801            let vj = variance_jet_for_weight_family(family, mu);
802            observed_weight_noncanonical(y, mu, jet.d1, jet.d2, jet.d3, h4, vj, phi, prior_weight)
803        }
804    }
805}
806
807#[derive(Clone)]
808pub enum DirectionalWorkingCurvature {
809    /// Directional derivative of the PIRLS curvature when the working
810    /// curvature is diagonal in observation space:
811    ///   W_τ = diag(w_τ).
812    Diagonal(Array1<f64>),
813}
814
815pub fn directionalworking_curvature_from_c_array(
816    c_array: &Array1<f64>,
817    hessian_weights: &Array1<f64>,
818    eta_direction: &Array1<f64>,
819) -> DirectionalWorkingCurvature {
820    let mut w_direction = c_array * eta_direction;
821    for i in 0..w_direction.len() {
822        if hessian_weights[i] <= 0.0 || !w_direction[i].is_finite() {
823            w_direction[i] = 0.0;
824        }
825    }
826    DirectionalWorkingCurvature::Diagonal(w_direction)
827}