Skip to main content

gam_solve/pirls/
curvature.rs

1//! Curvature primitives: the variance-function jet, observed-information
2//! Hessian weights, and the weight-family / weight-link classification used to
3//! choose between Fisher and observed curvature per family.
4
5use super::*;
6
/// Value and first four μ-derivatives of a GLM variance function `V(μ)`.
///
/// Field `vk` holds `dᵏV/dμᵏ` evaluated at a single μ. The type is plain data
/// (five `f64`s), so it is `Copy` and is passed by value into the curvature
/// formulas.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VarianceJet {
    /// `V(μ)`.
    pub v: f64,
    /// `V'(μ)`.
    pub v1: f64,
    /// `V''(μ)`.
    pub v2: f64,
    /// `V'''(μ)`.
    pub v3: f64,
    /// `V''''(μ)`.
    pub v4: f64,
}
14
15impl VarianceJet {
16    /// Lower floor on μ before evaluating power-law variance functions, so that
17    /// `μ^(p−k)` derivatives stay finite as μ → 0 instead of producing inf/NaN.
18    const VARIANCE_MU_FLOOR: f64 = 1e-10;
19
20    /// Bernoulli / binomial variance V(μ) = μ(1−μ).
21    #[inline]
22    pub fn bernoulli(mu: f64) -> Self {
23        Self {
24            v: mu * (1.0 - mu),
25            v1: 1.0 - 2.0 * mu,
26            v2: -2.0,
27            v3: 0.0,
28            v4: 0.0,
29        }
30    }
31
32    /// Poisson variance V(μ) = μ.
33    #[inline]
34    pub fn poisson(mu: f64) -> Self {
35        Self {
36            v: mu,
37            v1: 1.0,
38            v2: 0.0,
39            v3: 0.0,
40            v4: 0.0,
41        }
42    }
43
44    /// Gamma variance V(μ) = μ².
45    #[inline]
46    pub fn gamma(mu: f64) -> Self {
47        Self {
48            v: mu * mu,
49            v1: 2.0 * mu,
50            v2: 2.0,
51            v3: 0.0,
52            v4: 0.0,
53        }
54    }
55
56    /// Tweedie variance V(μ) = μ^p.
57    #[inline]
58    pub fn tweedie(mu: f64, p: f64) -> Self {
59        let mu = mu.max(Self::VARIANCE_MU_FLOOR);
60        Self {
61            v: mu.powf(p),
62            v1: p * mu.powf(p - 1.0),
63            v2: p * (p - 1.0) * mu.powf(p - 2.0),
64            v3: p * (p - 1.0) * (p - 2.0) * mu.powf(p - 3.0),
65            v4: p * (p - 1.0) * (p - 2.0) * (p - 3.0) * mu.powf(p - 4.0),
66        }
67    }
68
69    /// Negative-binomial variance V(μ) = μ + μ² / theta.
70    #[inline]
71    pub fn negative_binomial(mu: f64, theta: f64) -> Self {
72        let mu = mu.max(Self::VARIANCE_MU_FLOOR);
73        let inv_theta = if valid_negbin_theta(theta) {
74            1.0 / theta
75        } else {
76            f64::NAN
77        };
78        Self {
79            v: mu + mu * mu * inv_theta,
80            v1: 1.0 + 2.0 * mu * inv_theta,
81            v2: 2.0 * inv_theta,
82            v3: 0.0,
83            v4: 0.0,
84        }
85    }
86
87    /// Gaussian (identity) variance V(μ) = 1.
88    #[inline]
89    pub fn gaussian() -> Self {
90        Self {
91            v: 1.0,
92            v1: 0.0,
93            v2: 0.0,
94            v3: 0.0,
95            v4: 0.0,
96        }
97    }
98
99    /// Binomial(n, p) variance V(p) = p(1−p), identical to Bernoulli.
100    ///
101    /// The trial count `n` enters as a prior-weight multiplier, not through
102    /// the variance function itself.
103    #[inline]
104    pub fn binomial_n(mu: f64) -> Self {
105        // V(μ) = μ(1−μ), same jet as Bernoulli
106        Self::bernoulli(mu)
107    }
108
109    /// Beta-regression variance V(μ) = μ(1−μ)/(1+φ).
110    #[inline]
111    pub fn beta(mu: f64, phi: f64) -> Self {
112        let scale = 1.0 / (1.0 + phi.max(1e-12));
113        let base = Self::bernoulli(mu);
114        Self {
115            v: base.v * scale,
116            v1: base.v1 * scale,
117            v2: base.v2 * scale,
118            v3: 0.0,
119            v4: 0.0,
120        }
121    }
122}
123
/// Relative part of the solver floor. A row's weight is never allowed below
/// this fraction of its (non-negative) Fisher weight.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC: f64 = 1e-6;

/// Absolute part of the solver floor. It applies when the relative part is
/// smaller, e.g. when the Fisher weight is zero, negative or NaN.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR: f64 = 1e-12;

/// Per-row floor `max(max(fisher, 0) · 1e-6, 1e-12)` used by PIRLS to keep
/// the observed-information Hessian `H = X' W X + S` numerically usable.
///
/// Rows whose observed weight is non-finite or `<= floor` are raised to the
/// floor when PIRLS assembles the inner Hessian. The outer REML/LAML
/// derivatives must apply the **same** floor so that `H` and `dH/dψ`
/// describe one surface.
///
/// This is the single source of truth for the floor. Both
/// `solver_hessian_weights_into` (inner solve) and
/// `outer_hessian_curvature_arrays` (outer derivatives) call it, so the two
/// cannot drift apart. A NaN Fisher weight falls back to the absolute floor.
#[inline]
pub fn solver_hessian_weight_floor(fisher_weight: f64) -> f64 {
    let nonneg_fisher = fisher_weight.max(0.0);
    let relative = nonneg_fisher * OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC;
    relative.max(OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)
}
143
144/// Build the (W, c, d) triple that matches PIRLS's stabilized H = X' W X + S.
145///
146/// PIRLS internally uses `W[i] = max(W_obs[i], floor(W_F[i]))` to keep H PD,
147/// but `pirls_result.finalweights` stores the **unfloored** observed weights.
148/// Reusing those directly in `∂H/∂ψ = X_τ' W X + … + X' diag(c · X_τ β̂) X`
149/// produces an operator that disagrees with `H` at every saturated row — a
150/// 5%-Frobenius bias that `tr(G_ε(H) · op)` amplifies by O(1/σ_min(H)),
151/// driving the analytic gradient off by orders of magnitude.
152///
153/// This helper returns the floored W, plus c and d masked to zero wherever
154/// the floor is active (so `∂W/∂η` is zero on the constant-floor branch).
155pub fn outer_hessian_curvature_arrays(
156    hessian_weights: gam_linalg::matrix::SignedWeightsView<'_>,
157    fisher_weights: gam_linalg::matrix::PsdWeightsView<'_>,
158    c_array: &Array1<f64>,
159    d_array: &Array1<f64>,
160    eta: &Array1<f64>,
161    inverse_link: &InverseLink,
162) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
163    let hessian_view = hessian_weights.view();
164    let fisher_view = fisher_weights.view();
165    let n = hessian_view.len();
166    let mut w_out = Array1::<f64>::zeros(n);
167    let mut c_out = Array1::<f64>::zeros(n);
168    let mut d_out = Array1::<f64>::zeros(n);
169    for i in 0..n {
170        let floor = solver_hessian_weight_floor(fisher_view[i]);
171        let w = hessian_view[i];
172        let clamp_active = eta_clamp_active(inverse_link, eta[i]);
173        let w_below_floor = !(w.is_finite() && w > floor);
174        if w_below_floor {
175            w_out[i] = floor;
176            c_out[i] = 0.0;
177            d_out[i] = 0.0;
178        } else if clamp_active {
179            w_out[i] = w;
180            c_out[i] = 0.0;
181            d_out[i] = 0.0;
182        } else {
183            w_out[i] = w;
184            c_out[i] = c_array[i];
185            d_out[i] = d_array[i];
186        }
187    }
188    (w_out, c_out, d_out)
189}
190
191#[inline]
192pub(crate) fn fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
193    likelihood.fixed_phi().unwrap_or(1.0)
194}
195
196#[inline]
197pub fn weight_family_for_glm_likelihood(likelihood: &GlmLikelihoodSpec) -> WeightFamily {
198    match &likelihood.spec.response {
199        ResponseFamily::Gaussian => WeightFamily::Gaussian,
200        ResponseFamily::Poisson => WeightFamily::Poisson,
201        ResponseFamily::Tweedie { p } => WeightFamily::Tweedie { p: *p },
202        ResponseFamily::NegativeBinomial { theta, .. } => {
203            WeightFamily::NegativeBinomial { theta: *theta }
204        }
205        ResponseFamily::Beta { phi } => WeightFamily::Beta { phi: *phi },
206        ResponseFamily::Gamma => WeightFamily::Gamma,
207        ResponseFamily::Binomial => WeightFamily::Binomial,
208        ResponseFamily::RoystonParmar => WeightFamily::Gaussian,
209    }
210}
211
212#[inline]
213pub(crate) fn weight_link_for_inverse_link(inverse_link: &InverseLink) -> WeightLink {
214    match inverse_link {
215        InverseLink::Standard(StandardLink::Identity) => WeightLink::Identity,
216        InverseLink::Standard(StandardLink::Log) => WeightLink::Log,
217        InverseLink::Standard(StandardLink::Logit) => WeightLink::Logit,
218        InverseLink::Standard(StandardLink::Probit)
219        | InverseLink::Standard(StandardLink::CLogLog)
220        | InverseLink::Standard(StandardLink::LogLog)
221        | InverseLink::Standard(StandardLink::Cauchit)
222        | InverseLink::LatentCLogLog(_)
223        | InverseLink::Sas(_)
224        | InverseLink::BetaLogistic(_)
225        | InverseLink::Mixture(_) => WeightLink::Other,
226    }
227}
228
229#[inline]
230pub(crate) fn supports_observed_hessian_curvature_for_likelihood(
231    likelihood: &GlmLikelihoodSpec,
232    inverse_link: &InverseLink,
233) -> bool {
234    let spec = &likelihood.spec;
235    if matches!(spec.response, ResponseFamily::NegativeBinomial { .. }) {
236        return matches!(inverse_link, InverseLink::Standard(StandardLink::Log));
237    }
238    if matches!(spec.response, ResponseFamily::Gamma) {
239        return true;
240    }
241    if !matches!(spec.response, ResponseFamily::Binomial) {
242        return false;
243    }
244    matches!(
245        spec.link,
246        InverseLink::Standard(StandardLink::Probit)
247            | InverseLink::Standard(StandardLink::CLogLog)
248            | InverseLink::Standard(StandardLink::LogLog)
249            | InverseLink::Standard(StandardLink::Cauchit)
250            | InverseLink::Sas(_)
251            | InverseLink::BetaLogistic(_)
252            | InverseLink::Mixture(_)
253    )
254}
255
256#[inline]
257pub(crate) fn eta_for_observed_hessian_jet(inverse_link: &InverseLink, eta: f64) -> f64 {
258    match inverse_link {
259        // Why: canonical links keep V(mu) representable across the full f64 eta range; only guard against inf.
260        InverseLink::Standard(StandardLink::Logit | StandardLink::Log) => {
261            eta.clamp(-ETA_CLAMP, ETA_CLAMP)
262        }
263        InverseLink::Standard(StandardLink::Identity) => eta,
264        // Why: probit mu=Phi(eta) saturates to 1.0 in f64 by |eta|~8.3; +/-6 keeps V=mu(1-mu) ~ 1e-9 representable.
265        InverseLink::Standard(StandardLink::Probit) => eta.clamp(-6.0, 6.0),
266        // Why: cloglog has mu~exp(eta) for eta<<0 (underflows below ~-23) and 1-mu~exp(-exp(eta)) collapses by eta=3.
267        InverseLink::Standard(StandardLink::CLogLog) | InverseLink::LatentCLogLog(_) => {
268            eta.clamp(-23.0, 3.0)
269        }
270        InverseLink::Standard(StandardLink::LogLog) => eta.clamp(-3.0, 23.0),
271        InverseLink::Standard(StandardLink::Cauchit) => eta.clamp(-1.0e6, 1.0e6),
272        // Why: SAS / beta-logistic / mixture compose logistic-like sigmoids that saturate by |eta|~20 (logistic(20)~1-2e-9).
273        InverseLink::Sas(_) | InverseLink::BetaLogistic(_) | InverseLink::Mixture(_) => {
274            eta.clamp(-20.0, 20.0)
275        }
276    }
277}
278
279/// Returns true at rows where PIRLS clamped η (so the observed-info weights
280/// were computed at the clamped value, making `∂W/∂η` zero w.r.t. the
281/// **unclamped** η).  Outer REML/LAML derivative formulas must mask `c_obs`
282/// and `d_obs` to zero on these rows or the analytic ∂H/∂ψ disagrees with
283/// the H whose log-det we differentiate.
284#[inline]
285pub fn eta_clamp_active(inverse_link: &InverseLink, eta: f64) -> bool {
286    let clamped = eta_for_observed_hessian_jet(inverse_link, eta);
287    clamped != eta
288}
289
290/// Build solver-conditioned weights from the exact hessian weights.
291///
292/// The returned array applies a solver-only floor per observation so the
293/// Newton linear system X'W X + S stays numerically usable. This floor is
294/// purely a linear-algebra concern: the exact statistical weights stored in
295/// `lasthessian_weights` / `finalweights` are not affected.
296pub(crate) fn solver_hessian_weights_into(
297    hessian_weights: &Array1<f64>,
298    fisher_weights: &Array1<f64>,
299    out: &mut Array1<f64>,
300) {
301    if out.len() != hessian_weights.len() {
302        *out = Array1::<f64>::zeros(hessian_weights.len());
303    }
304    ndarray::Zip::from(out)
305        .and(hessian_weights)
306        .and(fisher_weights)
307        .par_for_each(|o, &w, &fw| {
308            let floor = solver_hessian_weight_floor(fw);
309            *o = if w.is_finite() && w > floor { w } else { floor };
310        });
311}
312
313/// Compute vectorised observed-information curvature arrays (w_obs, c_obs, d_obs)
314/// for the Hessian surface at the mode.
315///
316/// This function is the primary entry point for obtaining the observed weights
317/// that flow into the outer REML/LAML Hessian H_obs = X' W_obs X + S. The
318/// observed corrections include residual-dependent terms that vanish for
319/// canonical links but are nonzero for probit, cloglog, SAS, mixture, Gamma-log,
320/// and other flexible links.
321///
322/// The output arrays are:
323/// - `hessian_weights`: W_obs per observation (exact; solver floor applied separately).
324/// - `hessian_c`: c_obs = dW_obs/deta per observation (for outer gradient C[v]).
325/// - `hessian_d`: d_obs = d^2W_obs/deta^2 per observation (for outer Hessian Q[v_k,v_l]).
326///
327/// See `observed_weight_noncanonical` for the per-observation formulas and
328/// response.md Section 3 for the mathematical justification of why observed
329/// (not Fisher) information is required.
330pub(crate) fn compute_observed_hessian_curvature_arrays_into(
331    likelihood: &GlmLikelihoodSpec,
332    inverse_link: &InverseLink,
333    eta: &Array1<f64>,
334    y: ArrayView1<'_, f64>,
335    fisher_weights: &Array1<f64>,
336    priorweights: ArrayView1<'_, f64>,
337    hessian_weights: &mut Array1<f64>,
338    hessian_c: &mut Array1<f64>,
339    hessian_d: &mut Array1<f64>,
340) -> Result<(), EstimationError> {
341    assert!(supports_observed_hessian_curvature_for_likelihood(
342        likelihood,
343        inverse_link
344    ));
345    let n = eta.len();
346    if hessian_weights.len() != n {
347        *hessian_weights = Array1::<f64>::zeros(n);
348    }
349    if hessian_c.len() != n {
350        *hessian_c = Array1::<f64>::zeros(n);
351    }
352    if hessian_d.len() != n {
353        *hessian_d = Array1::<f64>::zeros(n);
354    }
355
356    let weight_family = weight_family_for_glm_likelihood(likelihood);
357    let weight_link = weight_link_for_inverse_link(inverse_link);
358    let phi = fixed_glm_dispersion(likelihood);
359
360    // Parallel per-row weight assembly. At large scale (n = 320k) this loop
361    // dominates non-canonical paths because each row independently evaluates
362    // inverse-link jets and residual-dependent observed curvature. Write
363    // directly into reusable output slices rather than collecting row tuples,
364    // which removes an O(n) temporary allocation on every PIRLS update.
365    hessian_weights
366        .as_slice_mut()
367        .expect("hessian weights must be contiguous")
368        .par_iter_mut()
369        .zip(
370            hessian_c
371                .as_slice_mut()
372                .expect("hessian c must be contiguous")
373                .par_iter_mut(),
374        )
375        .zip(
376            hessian_d
377                .as_slice_mut()
378                .expect("hessian d must be contiguous")
379                .par_iter_mut(),
380        )
381        .enumerate()
382        .try_for_each(|(i, ((w_out, c_out), d_out))| -> Result<(), EstimationError> {
383            let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
384            // Why: closed-form observed_weight_noncanonical requires (mu, d1..d3, h4) at one consistent eta;
385            // mixing PIRLS-state jets at unclamped eta with h4 at eta_used produced 0/0 in phi_v* divisions,
386            // surfacing as: "observed Hessian curvature is not positive finite at row N: observed=NaN, fisher=0".
387            let jet =
388                crate::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta_used)?;
389            let h4 = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
390                inverse_link, eta_used,
391            )?;
392            let (w_obs, c_obs, d_obs) = observed_weight_dispatch(
393                weight_family,
394                weight_link,
395                eta_used,
396                y[i],
397                jet.mu,
398                phi,
399                priorweights[i].max(0.0),
400                jet,
401                h4,
402            );
403            let fisher_weight = fisher_weights[i].max(0.0);
404            // A *finite* but non-positive observed weight is NOT a failure: the
405            // observed information `W_obs = W_Fisher - (y-μ)·B` legitimately goes
406            // indefinite on individual rows for a non-canonical link (probit,
407            // cloglog, SAS, and — critically for #1598 — a blended/mixture link)
408            // whenever a large residual flips the sign of the residual-dependent
409            // correction. The inner Newton system never uses this raw value: it
410            // is clamped to the SPD floor `max(W_Fisher·1e-6, 1e-12)` by
411            // `solver_hessian_weights_into`, and the outer REML/LAML derivative
412            // path applies the *same* floor through `outer_hessian_curvature_arrays`
413            // (which also zeroes c/d on the floored row). Both consumers are
414            // designed precisely to absorb an indefinite W_obs, so hard-bailing
415            // here defeats that stabilization and aborts an otherwise well-posed
416            // solve — the mixture/SAS joint link fit on data its own pure
417            // components fit trivially (clean logit under blended(logit, probit)).
418            //
419            // We therefore reject ONLY a genuinely non-finite (NaN/Inf) weight,
420            // which signals a broken jet rather than benign indefiniteness, and
421            // pass finite values (including non-positive ones) straight through to
422            // the flooring consumers. Likewise `c_obs`/`d_obs` only need to be
423            // finite; they are zeroed automatically on any floored row downstream.
424            if !w_obs.is_finite() {
425                crate::bail_invalid_estim!(
426                    "observed Hessian curvature is not finite at row {i}: observed={w_obs}, fisher={fisher_weight}"
427                );
428            }
429            if !c_obs.is_finite() || !d_obs.is_finite() {
430                crate::bail_invalid_estim!(
431                    "observed Hessian curvature derivatives are non-finite at row {i}: c={c_obs}, d={d_obs}"
432                );
433            }
434            *w_out = w_obs;
435            *c_out = c_obs;
436            *d_out = d_obs;
437            Ok(())
438        })
439}
440
441pub(crate) fn compute_observed_hessian_curvature_arrays(
442    likelihood: &GlmLikelihoodSpec,
443    inverse_link: &InverseLink,
444    eta: &Array1<f64>,
445    y: ArrayView1<'_, f64>,
446    fisher_weights: &Array1<f64>,
447    priorweights: ArrayView1<'_, f64>,
448) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
449    let n = eta.len();
450    let mut hessian_weights = Array1::<f64>::zeros(n);
451    let mut hessian_c = Array1::<f64>::zeros(n);
452    let mut hessian_d = Array1::<f64>::zeros(n);
453    compute_observed_hessian_curvature_arrays_into(
454        likelihood,
455        inverse_link,
456        eta,
457        y,
458        fisher_weights,
459        priorweights,
460        &mut hessian_weights,
461        &mut hessian_c,
462        &mut hessian_d,
463    )?;
464    Ok((hessian_weights, hessian_c, hessian_d))
465}
466
/// Per-observation observed-information weights and their first two
/// eta-derivatives for a general exponential-dispersion family with a
/// noncanonical link.
///
/// The observed weight differs from the Fisher (expected) weight by a
/// residual-dependent correction (see response.md Section 3):
///
///   W_obs = W_Fisher - (y - mu) * B
///   B = (h'' V - h'^2 V') / (phi V^2)
///
///   c_obs = c_Fisher + h' * B - (y - mu) * B_eta
///   d_obs = d_Fisher + h'' * B + 2*h' * B_eta - (y - mu) * B_etaeta
///
/// Equivalently, with T(eta) = h'/(phi V(mu)): W_Fisher = h' * T and B = T',
/// so B_eta = T'' and B_etaeta = T'''. The c/d formulas are the product rule
/// applied to W_obs, using d(y - mu)/deta = -h'.
///
/// For canonical links (for example logit-Binomial and log-Poisson), B = 0,
/// so observed = Fisher and no correction is needed.
///
/// These observed quantities are required for:
/// 1. The outer REML/LAML Hessian H_obs = X' W_obs X + S (log|H| term).
/// 2. The outer gradient's C[v] correction (uses c_obs).
/// 3. The outer Hessian's Q[v_k, v_l] correction (uses d_obs).
///
/// Using Fisher weights in the outer REML would yield a PQL-type surrogate
/// rather than the exact Laplace approximation.
///
/// # Arguments
/// * `y`   -- response value
/// * `mu`  -- fitted mean h(eta)
/// * `h1`...`h4` -- inverse-link derivatives h'(eta) ... h''''(eta)
/// * `vj`  -- variance-function jet evaluated at mu; only V, V', V'', V'''
///   are read (`vj.v4` is ignored)
/// * `phi` -- dispersion parameter (1.0 for Bernoulli/Poisson)
/// * `pw`  -- prior weight for this observation
///
/// # Returns
/// `(w_obs, c_obs, d_obs)` -- the observed weight and its first two
/// eta-derivatives, all pre-multiplied by `pw`.
///
/// Divisions by `phi * V^k` are unguarded, so `phi * V == 0` yields
/// non-finite outputs; callers must check finiteness.
#[inline]
pub fn observed_weight_noncanonical(
    y: f64,
    mu: f64,
    h1: f64,
    h2: f64,
    h3: f64,
    h4: f64,
    vj: VarianceJet,
    phi: f64,
    pw: f64,
) -> (f64, f64, f64) {
    let VarianceJet {
        v,
        v1,
        v2,
        v3,
        v4: _,
    } = vj;
    // Powers of V scaled by φ, reused as quotient-rule denominators below.
    let phi_v = phi * v;
    let phi_v2 = phi * v * v;
    let phi_v3 = phi * v * v * v;

    // ---- Fisher weight and derivatives ----
    // W_F = h₁² / (φ V)
    let h1_sq = h1 * h1;
    let w_f = h1_sq / phi_v;

    // c_F = (2 h₁ h₂ V − h₁³ V₁) / (φ V²)
    let n0 = h1_sq; // numerator of w_F
    let n1 = 2.0 * h1 * h2; // ∂(h₁²)/∂η
    let n2 = 2.0 * (h2 * h2 + h1 * h3); // ∂²(h₁²)/∂η²
    let vd1 = h1 * v1; // ∂V/∂η = V'·h'
    let vd2 = h2 * v1 + h1_sq * v2; // ∂²V/∂η² = V'·h'' + V''·h'²

    let c_f = (n1 * v - n0 * vd1) / phi_v2;

    // d_F = ∂c_F/∂η via the quotient rule on c_F = (n1·v − n0·vd1) / (φ·v²).
    // The η-derivative of the numerator is n2·v − n0·vd2: the n1·vd1 cross
    // terms cancel.
    let numer_cf = n1 * v - n0 * vd1;
    let dnumer_cf = n2 * v - n0 * vd2;
    let d_f = (dnumer_cf * v - 2.0 * numer_cf * vd1) / (phi_v3);

    // ---- Observed correction term B and its η-derivatives ----
    // B = (h₂ V − h₁² V₁) / (φ V²)
    let b_num = h2 * v - h1_sq * v1;
    let b = b_num / phi_v2;

    // B_η = (h₃ V² − 3 h₁ h₂ V V₁ − h₁³ V V₂ + 2 h₁³ V₁²) / (φ V³)
    let b_eta_num =
        h3 * v * v - 3.0 * h1 * h2 * v * v1 - h1_sq * h1 * v * v2 + 2.0 * h1_sq * h1 * v1 * v1;
    let b_eta = b_eta_num / phi_v3;

    // B_ηη = ∂B_η/∂η.
    //
    // We differentiate b_eta_num / (φ V³) using the quotient rule.
    //
    // Derivative of each summand of b_eta_num w.r.t. η, using the chain rule
    // ∂/∂η = h₁·∂/∂μ for the V-dependent parts (signs are applied when the
    // summands are combined in db_eta_num):
    //
    //   ∂/∂η [h₃ V²]               = h₄ V² + 2 h₃ V h₁ V₁
    //   ∂/∂η [3 h₁ h₂ V V₁]        = 3(h₂² + h₁ h₃)V V₁ + 3 h₁ h₂(h₁ V₁² + V h₁ V₂)
    //   ∂/∂η [h₁³ V V₂]            = 3 h₁² h₂ V V₂ + h₁³(h₁ V₁ V₂ + V h₁ V₃)
    //   ∂/∂η [2 h₁³ V₁²]           = 6 h₁² h₂ V₁² + 4 h₁³ V₁ h₁ V₂
    //                                = 6 h₁² h₂ V₁² + 4 h₁⁴ V₁ V₂
    //
    // Denominator derivative: ∂/∂η [φ V³] = 3 φ V² h₁ V₁.

    let h1_cu = h1_sq * h1;
    let h1_qu = h1_sq * h1_sq;

    let db_eta_num = h4 * v * v + 2.0 * h3 * v * h1 * v1
        - 3.0 * (h2 * h2 + h1 * h3) * v * v1
        - 3.0 * h1 * h2 * (h1 * v1 * v1 + v * h1 * v2)
        - 3.0 * h1_sq * h2 * v * v2
        - h1_cu * (h1 * v1 * v2 + v * h1 * v3)
        + 6.0 * h1_sq * h2 * v1 * v1
        + 4.0 * h1_qu * v1 * v2;

    // Quotient rule with one factor of V cancelled: (N'·V − 3·N·h₁V₁) / (φ V⁴).
    let phi_v4 = phi_v3 * v;
    let b_etaeta = (db_eta_num * v - 3.0 * b_eta_num * h1 * v1) / phi_v4;

    // ---- Assemble observed quantities ----
    let resid = y - mu;

    let w_obs = w_f - resid * b;
    let c_obs = c_f + h1 * b - resid * b_eta;
    let d_obs = d_f + h2 * b + 2.0 * h1 * b_eta - resid * b_etaeta;

    (pw * w_obs, pw * c_obs, pw * d_obs)
}
592
/// Per-observation third η-derivative of the observed-information weight,
/// `e_obs := ∂³W_obs/∂η³`, for a general exponential-dispersion family with
/// any (canonical or non-canonical) link.
///
/// Closed-form derivation:
///   Define `T(η) := h₁(η)/(φ V(μ(η)))`. Then
///   * Fisher weight `W_F = h₁ · T`
///   * Observed correction `B = T'`, so `B_η = T''`, `B_ηη = T'''`,
///     `B_ηηη = T''''`
///   * `W_obs = W_F − (y−μ) · T'`
///
/// Differentiating three times:
///   `∂³W_obs/∂η³ = W_F''' + h₃·T' + 3 h₂·T'' + 3 h₁·T''' − (y−μ)·T''''`
///
/// `T` and its derivatives come from Leibniz's rule applied to `T·Q = h₁`
/// with `Q = φV`. `W_F'''` comes from Leibniz's rule applied to the product
/// `W_F = h₁·T`.
///
/// All inverse-link derivatives `h₁..h₅` and variance-function derivatives
/// `V..V₄` are required as inputs; unlike `observed_weight_noncanonical`,
/// this function reads `vj.v4`. The caller supplies them.
///
/// Divisions by `Q = φV` are unguarded, so `φV == 0` yields a non-finite
/// result.
///
/// Returns `pw * e_obs` (pre-multiplied by the prior weight), so the result
/// scales identically to `(w_obs, c_obs, d_obs)` from
/// `observed_weight_noncanonical`.
#[inline]
pub fn e_obs_from_jets(
    y: f64,
    mu: f64,
    h1: f64,
    h2: f64,
    h3: f64,
    h4: f64,
    h5: f64,
    vj: VarianceJet,
    phi: f64,
    pw: f64,
) -> f64 {
    let VarianceJet { v, v1, v2, v3, v4 } = vj;
    let q = phi * v;

    // Q = φV and its η-derivatives (Faà di Bruno on V(μ(η))):
    //   Q'    = φ V₁ h₁
    //   Q''   = φ (V₁ h₂ + V₂ h₁²)
    //   Q'''  = φ (V₁ h₃ + 3 V₂ h₁ h₂ + V₃ h₁³)
    //   Q'''' = φ (V₁ h₄ + 4 V₂ h₁ h₃ + 3 V₂ h₂² + 6 V₃ h₁² h₂ + V₄ h₁⁴)
    let h1_sq = h1 * h1;
    let h1_cu = h1_sq * h1;
    let h1_qu = h1_sq * h1_sq;

    let q1 = phi * v1 * h1;
    let q2 = phi * (v1 * h2 + v2 * h1_sq);
    let q3 = phi * (v1 * h3 + 3.0 * v2 * h1 * h2 + v3 * h1_cu);
    let q4 = phi
        * (v1 * h4 + 4.0 * v2 * h1 * h3 + 3.0 * v2 * h2 * h2 + 6.0 * v3 * h1_sq * h2 + v4 * h1_qu);

    // T = h₁/Q and T', T'', T''', T'''' via Leibniz on T·Q = h₁
    // (binomial coefficients 1, 2-1, 3-3-1, 4-6-4-1):
    //   T'    = (h₂  − T·Q')/Q
    //   T''   = (h₃  − 2 T'·Q' − T·Q'')/Q
    //   T'''  = (h₄  − 3 T''·Q' − 3 T'·Q'' − T·Q''')/Q
    //   T'''' = (h₅  − 4 T'''·Q' − 6 T''·Q'' − 4 T'·Q''' − T·Q'''')/Q
    let t0 = h1 / q;
    let t1 = (h2 - t0 * q1) / q;
    let t2 = (h3 - 2.0 * t1 * q1 - t0 * q2) / q;
    let t3 = (h4 - 3.0 * t2 * q1 - 3.0 * t1 * q2 - t0 * q3) / q;
    let t4 = (h5 - 4.0 * t3 * q1 - 6.0 * t2 * q2 - 4.0 * t1 * q3 - t0 * q4) / q;

    // Fisher weight derivatives via the product rule on W_F = h₁·T.
    //   W_F^(0) = h₁ T
    //   W_F^(1) = h₁ T₁ + h₂ T
    //   W_F^(2) = h₁ T₂ + 2 h₂ T₁ + h₃ T
    //   W_F^(3) = h₁ T₃ + 3 h₂ T₂ + 3 h₃ T₁ + h₄ T
    // Only the third derivative is needed here.
    let w_f3 = h1 * t3 + 3.0 * h2 * t2 + 3.0 * h3 * t1 + h4 * t0;

    // Observed third derivative: differentiate W_obs = W_F − (y−μ)·T₁ three times.
    // Since (y−μ)' = −h₁, iterating the product rule yields
    //   ∂³((y−μ)·T₁)/∂η³ = −h₃·T₁ − 3 h₂·T₂ − 3 h₁·T₃ + (y−μ)·T₄
    let resid = y - mu;
    let e_obs = w_f3 + h3 * t1 + 3.0 * h2 * t2 + 3.0 * h1 * t3 - resid * t4;

    pw * e_obs
}
673
674// Direct (closed-form) observed-information weights for specific family-link
675// combinations.  These avoid the overhead of the generic noncanonical formula
676// when the algebra simplifies.
677
/// Observed-information weight for the Gaussian family with log link,
/// `y ~ N(μ, φ)`, `μ = exp(η)`.
///
/// Here `h' = h'' = h''' = μ` and `V = 1`, so `W_obs = μ(2μ − y)/φ`. Each
/// η-derivative doubles the `μ²` coefficient and leaves the `−yμ` term alone:
///
/// ```text
/// w_obs = ω μ(2μ − y) / φ
/// c_obs = ω μ(4μ − y) / φ
/// d_obs = ω μ(8μ − y) / φ
/// ```
///
/// `ω` is the prior weight `pw`; all three results are pre-multiplied by it.
#[inline]
pub fn observed_weight_gaussian_log(y: f64, mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let scale = pw / phi;
    let [w, c, d] = [2.0, 4.0, 8.0].map(|k: f64| scale * mu * (k * mu - y));
    (w, c, d)
}
695
/// Observed-information weight for the Gaussian family with inverse link,
/// `y ~ N(μ, φ)`, `μ = 1/η`.
///
/// Returns `(w_obs, c_obs, d_obs)` pre-multiplied by the prior weight `pw`
/// (written `ω` below):
///
/// ```text
/// w_obs = ω (3 − 2ηy) / (φ η⁴)
/// c_obs = 6ω (ηy − 2) / (φ η⁵)
/// d_obs = 12ω (5 − 2ηy) / (φ η⁶)
/// ```
///
/// `η = 0` is not guarded and produces non-finite values.
#[inline]
pub fn observed_weight_gaussian_inverse(y: f64, eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let scale = pw / phi;
    let e2 = eta * eta;
    let e4 = e2 * e2;
    let e5 = e4 * eta;
    let e6 = e4 * e2;
    let eta_y = eta * y;
    (
        scale * (3.0 - 2.0 * eta_y) / e4,
        scale * 6.0 * (eta_y - 2.0) / e5,
        scale * 12.0 * (5.0 - 2.0 * eta_y) / e6,
    )
}
718
719#[inline]
720pub(crate) fn observed_weight_binomial_logit_from_jet(
721    n_trials: f64,
722    jet: MixtureInverseLinkJet,
723    pw: f64,
724) -> (f64, f64, f64) {
725    let scale = pw * n_trials;
726    (scale * jet.d1, scale * jet.d2, scale * jet.d3)
727}
728
/// Variance-function tag for the observed-information weight dispatch.
///
/// The tag identifies the response family's variance function independently
/// of the link. [`observed_weight_dispatch`] pairs it with a [`WeightLink`] to
/// pick a closed-form specialisation, and [`variance_jet_for_weight_family`]
/// turns it into a [`VarianceJet`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFamily {
    /// Constant variance, V(μ) = 1.
    Gaussian,
    /// V(μ) = μ(1−μ).
    Binomial,
    /// V(μ) = μ.
    Poisson,
    /// Power variance V(μ) = μ^p.
    Tweedie { p: f64 },
    /// V(μ) = μ + μ²/θ.
    NegativeBinomial { theta: f64 },
    /// V(μ) = μ(1−μ)/(1+φ) with beta precision φ.
    Beta { phi: f64 },
    /// V(μ) = μ².
    Gamma,
}
744
/// Link tag for the observed-information weight dispatch.
///
/// [`observed_weight_dispatch`] uses it to select closed-form
/// specialisations. Links without one are tagged `Other` and take the
/// generic path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightLink {
    /// `μ = η`.
    Identity,
    /// `μ = exp(η)`.
    Log,
    /// `μ = 1 / (1 + exp(−η))`.
    Logit,
    /// `μ = 1/η`. `weight_link_for_inverse_link` never returns this variant.
    Inverse,
    /// Any other link — falls back to the generic noncanonical formula.
    Other,
}
758
759#[inline]
760pub fn variance_jet_for_weight_family(family: WeightFamily, mu: f64) -> VarianceJet {
761    match family {
762        WeightFamily::Gaussian => VarianceJet::gaussian(),
763        WeightFamily::Binomial => VarianceJet::binomial_n(mu),
764        WeightFamily::Poisson => VarianceJet::poisson(mu),
765        WeightFamily::Tweedie { p } => VarianceJet::tweedie(mu, p),
766        WeightFamily::NegativeBinomial { theta } => VarianceJet::negative_binomial(mu, theta),
767        WeightFamily::Beta { phi } => VarianceJet::beta(mu, phi),
768        WeightFamily::Gamma => VarianceJet::gamma(mu),
769    }
770}
771
772/// Dispatch to closed-form observed-information weights for known family-link
773/// combinations, falling back to the generic noncanonical formula.
774///
775/// Returns `(w_obs, c_obs, d_obs)` pre-multiplied by the prior weight.
776///
777/// For the `Binomial + Logit` case, `n_trials` is passed as `phi` (dispersion
778/// slot is unused for binomial) and the prior weight controls the
779/// observation-level scaling. For all other cases, `phi` is the dispersion
780/// parameter.
781///
782/// `jet` and `h4` are the inverse-link derivatives used by the generic
783/// noncanonical fallback path. They may be zero for the specialized paths.
784pub fn observed_weight_dispatch(
785    family: WeightFamily,
786    link: WeightLink,
787    eta: f64,
788    y: f64,
789    mu: f64,
790    phi: f64,
791    prior_weight: f64,
792    jet: MixtureInverseLinkJet,
793    h4: f64,
794) -> (f64, f64, f64) {
795    match (family, link) {
796        (WeightFamily::Gaussian, WeightLink::Log) => {
797            observed_weight_gaussian_log(y, mu, phi, prior_weight)
798        }
799        (WeightFamily::Gaussian, WeightLink::Inverse) => {
800            observed_weight_gaussian_inverse(y, eta, phi, prior_weight)
801        }
802        (WeightFamily::Binomial, WeightLink::Logit) => {
803            observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
804        }
805        _ => {
806            // Generic noncanonical path via the full variance-function jet.
807            let vj = variance_jet_for_weight_family(family, mu);
808            observed_weight_noncanonical(y, mu, jet.d1, jet.d2, jet.d3, h4, vj, phi, prior_weight)
809        }
810    }
811}
812
813#[derive(Clone)]
814pub enum DirectionalWorkingCurvature {
815    /// Directional derivative of the PIRLS curvature when the working
816    /// curvature is diagonal in observation space:
817    ///   W_τ = diag(w_τ).
818    Diagonal(Array1<f64>),
819}
820
821pub fn directionalworking_curvature_from_c_array(
822    c_array: &Array1<f64>,
823    hessian_weights: &Array1<f64>,
824    eta_direction: &Array1<f64>,
825) -> DirectionalWorkingCurvature {
826    let mut w_direction = c_array * eta_direction;
827    for i in 0..w_direction.len() {
828        if hessian_weights[i] <= 0.0 || !w_direction[i].is_finite() {
829            w_direction[i] = 0.0;
830        }
831    }
832    DirectionalWorkingCurvature::Diagonal(w_direction)
833}