gam_solve/pirls/curvature.rs
1//! Curvature primitives: the variance-function jet, observed-information
2//! Hessian weights, and the weight-family / weight-link classification used to
3//! choose between Fisher and observed curvature per family.
4
5use super::*;
6
7pub struct VarianceJet {
8 pub v: f64,
9 pub v1: f64,
10 pub v2: f64,
11 pub v3: f64,
12 pub v4: f64,
13}
14
15impl VarianceJet {
16 /// Lower floor on μ before evaluating power-law variance functions, so that
17 /// `μ^(p−k)` derivatives stay finite as μ → 0 instead of producing inf/NaN.
18 const VARIANCE_MU_FLOOR: f64 = 1e-10;
19
20 /// Bernoulli / binomial variance V(μ) = μ(1−μ).
21 #[inline]
22 pub fn bernoulli(mu: f64) -> Self {
23 Self {
24 v: mu * (1.0 - mu),
25 v1: 1.0 - 2.0 * mu,
26 v2: -2.0,
27 v3: 0.0,
28 v4: 0.0,
29 }
30 }
31
32 /// Poisson variance V(μ) = μ.
33 #[inline]
34 pub fn poisson(mu: f64) -> Self {
35 Self {
36 v: mu,
37 v1: 1.0,
38 v2: 0.0,
39 v3: 0.0,
40 v4: 0.0,
41 }
42 }
43
44 /// Gamma variance V(μ) = μ².
45 #[inline]
46 pub fn gamma(mu: f64) -> Self {
47 Self {
48 v: mu * mu,
49 v1: 2.0 * mu,
50 v2: 2.0,
51 v3: 0.0,
52 v4: 0.0,
53 }
54 }
55
56 /// Tweedie variance V(μ) = μ^p.
57 #[inline]
58 pub fn tweedie(mu: f64, p: f64) -> Self {
59 let mu = mu.max(Self::VARIANCE_MU_FLOOR);
60 Self {
61 v: mu.powf(p),
62 v1: p * mu.powf(p - 1.0),
63 v2: p * (p - 1.0) * mu.powf(p - 2.0),
64 v3: p * (p - 1.0) * (p - 2.0) * mu.powf(p - 3.0),
65 v4: p * (p - 1.0) * (p - 2.0) * (p - 3.0) * mu.powf(p - 4.0),
66 }
67 }
68
69 /// Negative-binomial variance V(μ) = μ + μ² / theta.
70 #[inline]
71 pub fn negative_binomial(mu: f64, theta: f64) -> Self {
72 let mu = mu.max(Self::VARIANCE_MU_FLOOR);
73 let inv_theta = if valid_negbin_theta(theta) {
74 1.0 / theta
75 } else {
76 f64::NAN
77 };
78 Self {
79 v: mu + mu * mu * inv_theta,
80 v1: 1.0 + 2.0 * mu * inv_theta,
81 v2: 2.0 * inv_theta,
82 v3: 0.0,
83 v4: 0.0,
84 }
85 }
86
87 /// Gaussian (identity) variance V(μ) = 1.
88 #[inline]
89 pub fn gaussian() -> Self {
90 Self {
91 v: 1.0,
92 v1: 0.0,
93 v2: 0.0,
94 v3: 0.0,
95 v4: 0.0,
96 }
97 }
98
99 /// Binomial(n, p) variance V(p) = p(1−p), identical to Bernoulli.
100 ///
101 /// The trial count `n` enters as a prior-weight multiplier, not through
102 /// the variance function itself.
103 #[inline]
104 pub fn binomial_n(mu: f64) -> Self {
105 // V(μ) = μ(1−μ), same jet as Bernoulli
106 Self::bernoulli(mu)
107 }
108
109 /// Beta-regression variance V(μ) = μ(1−μ)/(1+φ).
110 #[inline]
111 pub fn beta(mu: f64, phi: f64) -> Self {
112 let scale = 1.0 / (1.0 + phi.max(1e-12));
113 let base = Self::bernoulli(mu);
114 Self {
115 v: base.v * scale,
116 v1: base.v1 * scale,
117 v2: base.v2 * scale,
118 v3: 0.0,
119 v4: 0.0,
120 }
121 }
122}
123
/// Relative part of the solver weight floor: the fraction of a row's Fisher
/// weight below which the observed weight is replaced.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC: f64 = 1e-6;

/// Absolute part of the solver weight floor: the smallest floor ever returned,
/// whatever the Fisher weight is.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR: f64 = 1e-12;

/// Returns the per-row floor `max(fisher · 1e-6, 1e-12)` used by PIRLS to
/// stabilize the observed-information Hessian H = X' W X + S. Saturated
/// rows where W_obs ≤ floor were silently raised to `floor` when PIRLS
/// built the inner Hessian; outer REML/LAML derivatives must use the
/// **same** floored W to keep `H` and `dH/dψ` on one surface.
///
/// This is the single source of truth for the floor formula. Both the
/// inner solver (`solver_hessian_weights_into`) and the outer derivative
/// path (`outer_hessian_curvature_arrays`) route through this helper so
/// the inner-stabilized H and the outer dH/dψ cannot drift apart.
///
/// Negative and non-finite Fisher weights contribute nothing to the relative
/// term, so the result is always finite and at least
/// [`OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR`]. In particular a `+∞` Fisher weight
/// yields the absolute floor rather than an infinite floor, which would
/// otherwise overwrite every observed weight on that row with `∞`.
#[inline]
pub fn solver_hessian_weight_floor(fisher_weight: f64) -> f64 {
    let relative_floor = if fisher_weight.is_finite() {
        fisher_weight.max(0.0) * OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC
    } else {
        // NaN and -inf already collapsed to 0 via `max(0.0)`; treating +inf the
        // same way keeps the floor a small, finite stabilizer.
        0.0
    };
    relative_floor.max(OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)
}
143
144/// Build the (W, c, d) triple that matches PIRLS's stabilized H = X' W X + S.
145///
146/// PIRLS internally uses `W[i] = max(W_obs[i], floor(W_F[i]))` to keep H PD,
147/// but `pirls_result.finalweights` stores the **unfloored** observed weights.
148/// Reusing those directly in `∂H/∂ψ = X_τ' W X + … + X' diag(c · X_τ β̂) X`
149/// produces an operator that disagrees with `H` at every saturated row — a
150/// 5%-Frobenius bias that `tr(G_ε(H) · op)` amplifies by O(1/σ_min(H)),
151/// driving the analytic gradient off by orders of magnitude.
152///
153/// This helper returns the floored W, plus c and d masked to zero wherever
154/// the floor is active (so `∂W/∂η` is zero on the constant-floor branch).
155pub fn outer_hessian_curvature_arrays(
156 hessian_weights: gam_linalg::matrix::SignedWeightsView<'_>,
157 fisher_weights: gam_linalg::matrix::PsdWeightsView<'_>,
158 c_array: &Array1<f64>,
159 d_array: &Array1<f64>,
160 eta: &Array1<f64>,
161 inverse_link: &InverseLink,
162) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
163 let hessian_view = hessian_weights.view();
164 let fisher_view = fisher_weights.view();
165 let n = hessian_view.len();
166 let mut w_out = Array1::<f64>::zeros(n);
167 let mut c_out = Array1::<f64>::zeros(n);
168 let mut d_out = Array1::<f64>::zeros(n);
169 for i in 0..n {
170 let floor = solver_hessian_weight_floor(fisher_view[i]);
171 let w = hessian_view[i];
172 let clamp_active = eta_clamp_active(inverse_link, eta[i]);
173 let w_below_floor = !(w.is_finite() && w > floor);
174 if w_below_floor {
175 w_out[i] = floor;
176 c_out[i] = 0.0;
177 d_out[i] = 0.0;
178 } else if clamp_active {
179 w_out[i] = w;
180 c_out[i] = 0.0;
181 d_out[i] = 0.0;
182 } else {
183 w_out[i] = w;
184 c_out[i] = c_array[i];
185 d_out[i] = d_array[i];
186 }
187 }
188 (w_out, c_out, d_out)
189}
190
191#[inline]
192pub(crate) fn fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
193 likelihood.fixed_phi().unwrap_or(1.0)
194}
195
196#[inline]
197pub fn weight_family_for_glm_likelihood(likelihood: &GlmLikelihoodSpec) -> WeightFamily {
198 match &likelihood.spec.response {
199 ResponseFamily::Gaussian => WeightFamily::Gaussian,
200 ResponseFamily::Poisson => WeightFamily::Poisson,
201 ResponseFamily::Tweedie { p } => WeightFamily::Tweedie { p: *p },
202 ResponseFamily::NegativeBinomial { theta, .. } => {
203 WeightFamily::NegativeBinomial { theta: *theta }
204 }
205 ResponseFamily::Beta { phi } => WeightFamily::Beta { phi: *phi },
206 ResponseFamily::Gamma => WeightFamily::Gamma,
207 ResponseFamily::Binomial => WeightFamily::Binomial,
208 ResponseFamily::RoystonParmar => WeightFamily::Gaussian,
209 }
210}
211
212#[inline]
213pub(crate) fn weight_link_for_inverse_link(inverse_link: &InverseLink) -> WeightLink {
214 match inverse_link {
215 InverseLink::Standard(StandardLink::Identity) => WeightLink::Identity,
216 InverseLink::Standard(StandardLink::Log) => WeightLink::Log,
217 InverseLink::Standard(StandardLink::Logit) => WeightLink::Logit,
218 InverseLink::Standard(StandardLink::Probit)
219 | InverseLink::Standard(StandardLink::CLogLog)
220 | InverseLink::LatentCLogLog(_)
221 | InverseLink::Sas(_)
222 | InverseLink::BetaLogistic(_)
223 | InverseLink::Mixture(_) => WeightLink::Other,
224 }
225}
226
227#[inline]
228pub(crate) fn supports_observed_hessian_curvature_for_likelihood(
229 likelihood: &GlmLikelihoodSpec,
230 inverse_link: &InverseLink,
231) -> bool {
232 let spec = &likelihood.spec;
233 if matches!(spec.response, ResponseFamily::NegativeBinomial { .. }) {
234 return matches!(inverse_link, InverseLink::Standard(StandardLink::Log));
235 }
236 if matches!(spec.response, ResponseFamily::Gamma) {
237 return true;
238 }
239 if !matches!(spec.response, ResponseFamily::Binomial) {
240 return false;
241 }
242 matches!(
243 spec.link,
244 InverseLink::Standard(StandardLink::Probit)
245 | InverseLink::Standard(StandardLink::CLogLog)
246 | InverseLink::Sas(_)
247 | InverseLink::BetaLogistic(_)
248 | InverseLink::Mixture(_)
249 )
250}
251
252#[inline]
253pub(crate) fn eta_for_observed_hessian_jet(inverse_link: &InverseLink, eta: f64) -> f64 {
254 match inverse_link {
255 // Why: canonical links keep V(mu) representable across the full f64 eta range; only guard against inf.
256 InverseLink::Standard(StandardLink::Logit | StandardLink::Log) => {
257 eta.clamp(-ETA_CLAMP, ETA_CLAMP)
258 }
259 InverseLink::Standard(StandardLink::Identity) => eta,
260 // Why: probit mu=Phi(eta) saturates to 1.0 in f64 by |eta|~8.3; +/-6 keeps V=mu(1-mu) ~ 1e-9 representable.
261 InverseLink::Standard(StandardLink::Probit) => eta.clamp(-6.0, 6.0),
262 // Why: cloglog has mu~exp(eta) for eta<<0 (underflows below ~-23) and 1-mu~exp(-exp(eta)) collapses by eta=3.
263 InverseLink::Standard(StandardLink::CLogLog) | InverseLink::LatentCLogLog(_) => {
264 eta.clamp(-23.0, 3.0)
265 }
266 // Why: SAS / beta-logistic / mixture compose logistic-like sigmoids that saturate by |eta|~20 (logistic(20)~1-2e-9).
267 InverseLink::Sas(_) | InverseLink::BetaLogistic(_) | InverseLink::Mixture(_) => {
268 eta.clamp(-20.0, 20.0)
269 }
270 }
271}
272
273/// Returns true at rows where PIRLS clamped η (so the observed-info weights
274/// were computed at the clamped value, making `∂W/∂η` zero w.r.t. the
275/// **unclamped** η). Outer REML/LAML derivative formulas must mask `c_obs`
276/// and `d_obs` to zero on these rows or the analytic ∂H/∂ψ disagrees with
277/// the H whose log-det we differentiate.
278#[inline]
279pub fn eta_clamp_active(inverse_link: &InverseLink, eta: f64) -> bool {
280 let clamped = eta_for_observed_hessian_jet(inverse_link, eta);
281 clamped != eta
282}
283
284/// Build solver-conditioned weights from the exact hessian weights.
285///
286/// The returned array applies a solver-only floor per observation so the
287/// Newton linear system X'W X + S stays numerically usable. This floor is
288/// purely a linear-algebra concern: the exact statistical weights stored in
289/// `lasthessian_weights` / `finalweights` are not affected.
290pub(crate) fn solver_hessian_weights_into(
291 hessian_weights: &Array1<f64>,
292 fisher_weights: &Array1<f64>,
293 out: &mut Array1<f64>,
294) {
295 if out.len() != hessian_weights.len() {
296 *out = Array1::<f64>::zeros(hessian_weights.len());
297 }
298 ndarray::Zip::from(out)
299 .and(hessian_weights)
300 .and(fisher_weights)
301 .par_for_each(|o, &w, &fw| {
302 let floor = solver_hessian_weight_floor(fw);
303 *o = if w.is_finite() && w > floor { w } else { floor };
304 });
305}
306
307/// Compute vectorised observed-information curvature arrays (w_obs, c_obs, d_obs)
308/// for the Hessian surface at the mode.
309///
310/// This function is the primary entry point for obtaining the observed weights
311/// that flow into the outer REML/LAML Hessian H_obs = X' W_obs X + S. The
312/// observed corrections include residual-dependent terms that vanish for
313/// canonical links but are nonzero for probit, cloglog, SAS, mixture, Gamma-log,
314/// and other flexible links.
315///
316/// The output arrays are:
317/// - `hessian_weights`: W_obs per observation (exact; solver floor applied separately).
318/// - `hessian_c`: c_obs = dW_obs/deta per observation (for outer gradient C[v]).
319/// - `hessian_d`: d_obs = d^2W_obs/deta^2 per observation (for outer Hessian Q[v_k,v_l]).
320///
321/// See `observed_weight_noncanonical` for the per-observation formulas and
322/// response.md Section 3 for the mathematical justification of why observed
323/// (not Fisher) information is required.
324pub(crate) fn compute_observed_hessian_curvature_arrays_into(
325 likelihood: &GlmLikelihoodSpec,
326 inverse_link: &InverseLink,
327 eta: &Array1<f64>,
328 y: ArrayView1<'_, f64>,
329 fisher_weights: &Array1<f64>,
330 priorweights: ArrayView1<'_, f64>,
331 hessian_weights: &mut Array1<f64>,
332 hessian_c: &mut Array1<f64>,
333 hessian_d: &mut Array1<f64>,
334) -> Result<(), EstimationError> {
335 assert!(supports_observed_hessian_curvature_for_likelihood(
336 likelihood,
337 inverse_link
338 ));
339 let n = eta.len();
340 if hessian_weights.len() != n {
341 *hessian_weights = Array1::<f64>::zeros(n);
342 }
343 if hessian_c.len() != n {
344 *hessian_c = Array1::<f64>::zeros(n);
345 }
346 if hessian_d.len() != n {
347 *hessian_d = Array1::<f64>::zeros(n);
348 }
349
350 let weight_family = weight_family_for_glm_likelihood(likelihood);
351 let weight_link = weight_link_for_inverse_link(inverse_link);
352 let phi = fixed_glm_dispersion(likelihood);
353
354 // Parallel per-row weight assembly. At large scale (n = 320k) this loop
355 // dominates non-canonical paths because each row independently evaluates
356 // inverse-link jets and residual-dependent observed curvature. Write
357 // directly into reusable output slices rather than collecting row tuples,
358 // which removes an O(n) temporary allocation on every PIRLS update.
359 hessian_weights
360 .as_slice_mut()
361 .expect("hessian weights must be contiguous")
362 .par_iter_mut()
363 .zip(
364 hessian_c
365 .as_slice_mut()
366 .expect("hessian c must be contiguous")
367 .par_iter_mut(),
368 )
369 .zip(
370 hessian_d
371 .as_slice_mut()
372 .expect("hessian d must be contiguous")
373 .par_iter_mut(),
374 )
375 .enumerate()
376 .try_for_each(|(i, ((w_out, c_out), d_out))| -> Result<(), EstimationError> {
377 let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
378 // Why: closed-form observed_weight_noncanonical requires (mu, d1..d3, h4) at one consistent eta;
379 // mixing PIRLS-state jets at unclamped eta with h4 at eta_used produced 0/0 in phi_v* divisions,
380 // surfacing as: "observed Hessian curvature is not positive finite at row N: observed=NaN, fisher=0".
381 let jet =
382 crate::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta_used)?;
383 let h4 = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
384 inverse_link, eta_used,
385 )?;
386 let (w_obs, c_obs, d_obs) = observed_weight_dispatch(
387 weight_family,
388 weight_link,
389 eta_used,
390 y[i],
391 jet.mu,
392 phi,
393 priorweights[i].max(0.0),
394 jet,
395 h4,
396 );
397 let fisher_weight = fisher_weights[i].max(0.0);
398 // A *finite* but non-positive observed weight is NOT a failure: the
399 // observed information `W_obs = W_Fisher - (y-μ)·B` legitimately goes
400 // indefinite on individual rows for a non-canonical link (probit,
401 // cloglog, SAS, and — critically for #1598 — a blended/mixture link)
402 // whenever a large residual flips the sign of the residual-dependent
403 // correction. The inner Newton system never uses this raw value: it
404 // is clamped to the SPD floor `max(W_Fisher·1e-6, 1e-12)` by
405 // `solver_hessian_weights_into`, and the outer REML/LAML derivative
406 // path applies the *same* floor through `outer_hessian_curvature_arrays`
407 // (which also zeroes c/d on the floored row). Both consumers are
408 // designed precisely to absorb an indefinite W_obs, so hard-bailing
409 // here defeats that stabilization and aborts an otherwise well-posed
410 // solve — the mixture/SAS joint link fit on data its own pure
411 // components fit trivially (clean logit under blended(logit, probit)).
412 //
413 // We therefore reject ONLY a genuinely non-finite (NaN/Inf) weight,
414 // which signals a broken jet rather than benign indefiniteness, and
415 // pass finite values (including non-positive ones) straight through to
416 // the flooring consumers. Likewise `c_obs`/`d_obs` only need to be
417 // finite; they are zeroed automatically on any floored row downstream.
418 if !w_obs.is_finite() {
419 crate::bail_invalid_estim!(
420 "observed Hessian curvature is not finite at row {i}: observed={w_obs}, fisher={fisher_weight}"
421 );
422 }
423 if !c_obs.is_finite() || !d_obs.is_finite() {
424 crate::bail_invalid_estim!(
425 "observed Hessian curvature derivatives are non-finite at row {i}: c={c_obs}, d={d_obs}"
426 );
427 }
428 *w_out = w_obs;
429 *c_out = c_obs;
430 *d_out = d_obs;
431 Ok(())
432 })
433}
434
435pub(crate) fn compute_observed_hessian_curvature_arrays(
436 likelihood: &GlmLikelihoodSpec,
437 inverse_link: &InverseLink,
438 eta: &Array1<f64>,
439 y: ArrayView1<'_, f64>,
440 fisher_weights: &Array1<f64>,
441 priorweights: ArrayView1<'_, f64>,
442) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
443 let n = eta.len();
444 let mut hessian_weights = Array1::<f64>::zeros(n);
445 let mut hessian_c = Array1::<f64>::zeros(n);
446 let mut hessian_d = Array1::<f64>::zeros(n);
447 compute_observed_hessian_curvature_arrays_into(
448 likelihood,
449 inverse_link,
450 eta,
451 y,
452 fisher_weights,
453 priorweights,
454 &mut hessian_weights,
455 &mut hessian_c,
456 &mut hessian_d,
457 )?;
458 Ok((hessian_weights, hessian_c, hessian_d))
459}
460
461/// Per-observation observed-information weights and their first two
462/// eta-derivatives for a general exponential-dispersion family with a
463/// noncanonical link.
464///
465/// The observed weight differs from the Fisher (expected) weight by a
466/// residual-dependent correction (see response.md Section 3):
467///
468/// W_obs = W_Fisher - (y - mu) * B
469/// B = (h'' V - h'^2 V') / (phi V^2)
470///
471/// c_obs = c_Fisher + h' * B - (y - mu) * B_eta
472/// d_obs = d_Fisher + h'' * B + 2*h' * B_eta - (y - mu) * B_etaeta
473///
474/// For canonical links (for example logit-Binomial and log-Poisson), B = 0
475/// so observed = Fisher and no correction is needed.
476///
477/// These observed quantities are required for:
478/// 1. The outer REML/LAML Hessian H_obs = X' W_obs X + S (log|H| term).
479/// 2. The outer gradient's C[v] correction (uses c_obs).
480/// 3. The outer Hessian's Q[v_k, v_l] correction (uses d_obs).
481///
482/// Using Fisher weights in the outer REML would yield a PQL-type surrogate
483/// rather than the exact Laplace approximation.
484///
485/// # Arguments
486/// * `y` -- response value
487/// * `mu` -- fitted mean h(eta)
488/// * `h1`...`h4` -- inverse-link derivatives h'(eta) ... h''''(eta)
489/// * `vj` -- variance-function jet (V, V', V'', V''') evaluated at mu
490/// * `phi` -- dispersion parameter (1.0 for Bernoulli/Poisson)
491/// * `pw` -- prior weight for this observation
492///
493/// # Returns
494/// `(w_obs, c_obs, d_obs)` -- the observed weight and its first two
495/// eta-derivatives, all pre-multiplied by `pw`.
496#[inline]
497pub fn observed_weight_noncanonical(
498 y: f64,
499 mu: f64,
500 h1: f64,
501 h2: f64,
502 h3: f64,
503 h4: f64,
504 vj: VarianceJet,
505 phi: f64,
506 pw: f64,
507) -> (f64, f64, f64) {
508 let VarianceJet {
509 v,
510 v1,
511 v2,
512 v3,
513 v4: _,
514 } = vj;
515 let phi_v = phi * v;
516 let phi_v2 = phi * v * v;
517 let phi_v3 = phi * v * v * v;
518
519 // ---- Fisher weight and derivatives ----
520 let h1_sq = h1 * h1;
521 let w_f = h1_sq / phi_v;
522
523 // c_F = (2 h₁ h₂ V − h₁³ V₁) / (φ V²)
524 let n0 = h1_sq; // numerator of w_F
525 let n1 = 2.0 * h1 * h2; // ∂(h₁²)/∂η
526 let n2 = 2.0 * (h2 * h2 + h1 * h3); // ∂²(h₁²)/∂η²
527 let vd1 = h1 * v1; // ∂V/∂η = V'·h'
528 let vd2 = h2 * v1 + h1_sq * v2; // ∂²V/∂η²
529
530 let c_f = (n1 * v - n0 * vd1) / phi_v2;
531
532 // d_F = ∂c_F/∂η via quotient rule on c_F = (n1·v − n0·vd1) / (φ·v²)
533 // numerator of c_F and its η-derivative (cross terms cancel):
534 let numer_cf = n1 * v - n0 * vd1;
535 let dnumer_cf = n2 * v - n0 * vd2;
536 let d_f = (dnumer_cf * v - 2.0 * numer_cf * vd1) / (phi_v3);
537
538 // ---- Observed correction term B and its η-derivatives ----
539 // B = (h₂ V − h₁² V₁) / (φ V²)
540 let b_num = h2 * v - h1_sq * v1;
541 let b = b_num / phi_v2;
542
543 // B_η = (h₃ V² − 3 h₁ h₂ V V₁ − h₁³ V V₂ + 2 h₁³ V₁²) / (φ V³)
544 let b_eta_num =
545 h3 * v * v - 3.0 * h1 * h2 * v * v1 - h1_sq * h1 * v * v2 + 2.0 * h1_sq * h1 * v1 * v1;
546 let b_eta = b_eta_num / phi_v3;
547
548 // B_ηη = ∂B_η/∂η.
549 //
550 // We differentiate b_eta_num / (φ V³) using the quotient rule.
551 //
552 // Numerator derivative of b_eta_num w.r.t. η, using chain rule ∂/∂η = h₁·∂/∂μ
553 // for the V-dependent parts:
554 //
555 // ∂/∂η [h₃ V²] = h₄ V² + 2 h₃ V h₁ V₁
556 // ∂/∂η [3 h₁ h₂ V V₁] = 3(h₂² + h₁ h₃)V V₁ + 3 h₁ h₂(h₁ V₁² + V h₁ V₂)
557 // ∂/∂η [h₁³ V V₂] = 3 h₁² h₂ V V₂ + h₁³(h₁ V₁ V₂ + V h₁ V₃)
558 // ∂/∂η [2 h₁³ V₁²] = 6 h₁² h₂ V₁² + 4 h₁³ V₁ h₁ V₂
559 // = 6 h₁² h₂ V₁² + 4 h1_sq * h1_sq * v1 * v2
560 //
561 // Denominator derivative: ∂/∂η [φ V³] = 3 φ V² h₁ V₁.
562
563 let h1_cu = h1_sq * h1;
564 let h1_qu = h1_sq * h1_sq;
565
566 let db_eta_num = h4 * v * v + 2.0 * h3 * v * h1 * v1
567 - 3.0 * (h2 * h2 + h1 * h3) * v * v1
568 - 3.0 * h1 * h2 * (h1 * v1 * v1 + v * h1 * v2)
569 - 3.0 * h1_sq * h2 * v * v2
570 - h1_cu * (h1 * v1 * v2 + v * h1 * v3)
571 + 6.0 * h1_sq * h2 * v1 * v1
572 + 4.0 * h1_qu * v1 * v2;
573
574 let phi_v4 = phi_v3 * v;
575 let b_etaeta = (db_eta_num * v - 3.0 * b_eta_num * h1 * v1) / phi_v4;
576
577 // ---- Assemble observed quantities ----
578 let resid = y - mu;
579
580 let w_obs = w_f - resid * b;
581 let c_obs = c_f + h1 * b - resid * b_eta;
582 let d_obs = d_f + h2 * b + 2.0 * h1 * b_eta - resid * b_etaeta;
583
584 (pw * w_obs, pw * c_obs, pw * d_obs)
585}
586
587/// Per-observation third η-derivative of the observed-information weight,
588/// `e_obs := ∂³W_obs/∂η³`, for a general exponential-dispersion family with
589/// any (canonical or non-canonical) link.
590///
591/// Closed-form derivation:
592/// Define `T(η) := h₁(η)/(φ V(μ(η)))`. Then
593/// * Fisher weight `W_F = h₁ · T`
594/// * Observed correction `B = T'`, so `B_η = T''`, `B_ηη = T'''`,
595/// `B_ηηη = T''''`
596/// * `W_obs = W_F − (y−μ) · T'`
597///
598/// Differentiating three times:
599/// `∂³W_obs/∂η³ = W_F''' + h₃·T' + 3 h₂·T'' + 3 h₁·T''' − (y−μ)·T''''`
600///
601/// `T` is computed via Leibniz on `T·Q = h₁` with `Q = φV`; `W_F` via
602/// Leibniz on `W_F·1 = h₁·T` (product rule).
603///
604/// All inverse-link derivatives `h₁..h₅` and variance-function derivatives
605/// `V..V₄` are required as inputs. Caller supplies them.
606///
607/// Returns `pw * e_obs` (pre-multiplied by the prior weight) so the result
608/// scales identically to `(w_obs, c_obs, d_obs)` from
609/// `observed_weight_noncanonical`.
610#[inline]
611pub fn e_obs_from_jets(
612 y: f64,
613 mu: f64,
614 h1: f64,
615 h2: f64,
616 h3: f64,
617 h4: f64,
618 h5: f64,
619 vj: VarianceJet,
620 phi: f64,
621 pw: f64,
622) -> f64 {
623 let VarianceJet { v, v1, v2, v3, v4 } = vj;
624 let q = phi * v;
625
626 // Q = φV and its η-derivatives.
627 // Q' = φ V₁ h₁
628 // Q'' = φ (V₁ h₂ + V₂ h₁²)
629 // Q''' = φ (V₁ h₃ + 3 V₂ h₁ h₂ + V₃ h₁³)
630 // Q'''' = φ (V₁ h₄ + 4 V₂ h₁ h₃ + 3 V₂ h₂² + 6 V₃ h₁² h₂ + V₄ h₁⁴)
631 let h1_sq = h1 * h1;
632 let h1_cu = h1_sq * h1;
633 let h1_qu = h1_sq * h1_sq;
634
635 let q1 = phi * v1 * h1;
636 let q2 = phi * (v1 * h2 + v2 * h1_sq);
637 let q3 = phi * (v1 * h3 + 3.0 * v2 * h1 * h2 + v3 * h1_cu);
638 let q4 = phi
639 * (v1 * h4 + 4.0 * v2 * h1 * h3 + 3.0 * v2 * h2 * h2 + 6.0 * v3 * h1_sq * h2 + v4 * h1_qu);
640
641 // T = h₁/Q and T', T'', T''', T'''' via Leibniz on T·Q = h₁.
642 // T' = (h₂ − T·Q')/Q
643 // T'' = (h₃ − 2 T'·Q' − T·Q'')/Q
644 // T''' = (h₄ − 3 T''·Q' − 3 T'·Q'' − T·Q''')/Q
645 // T'''' = (h₅ − 4 T'''·Q' − 6 T''·Q'' − 4 T'·Q''' − T·Q'''')/Q
646 let t0 = h1 / q;
647 let t1 = (h2 - t0 * q1) / q;
648 let t2 = (h3 - 2.0 * t1 * q1 - t0 * q2) / q;
649 let t3 = (h4 - 3.0 * t2 * q1 - 3.0 * t1 * q2 - t0 * q3) / q;
650 let t4 = (h5 - 4.0 * t3 * q1 - 6.0 * t2 * q2 - 4.0 * t1 * q3 - t0 * q4) / q;
651
652 // Fisher weight derivatives via product rule on W_F = h₁·T.
653 // W_F^(0) = h₁ T
654 // W_F^(1) = h₁ T₁ + h₂ T
655 // W_F^(2) = h₁ T₂ + 2 h₂ T₁ + h₃ T
656 // W_F^(3) = h₁ T₃ + 3 h₂ T₂ + 3 h₃ T₁ + h₄ T
657 let w_f3 = h1 * t3 + 3.0 * h2 * t2 + 3.0 * h3 * t1 + h4 * t0;
658
659 // Observed third derivative: differentiate W_obs = W_F − (y−μ)·T₁ thrice.
660 // (resid)' = −h₁, so iterating product rule yields
661 // ∂³((y−μ)·T₁)/∂η³ = −h₃·T₁ − 3 h₂·T₂ − 3 h₁·T₃ + (y−μ)·T₄
662 let resid = y - mu;
663 let e_obs = w_f3 + h3 * t1 + 3.0 * h2 * t2 + 3.0 * h1 * t3 - resid * t4;
664
665 pw * e_obs
666}
667
668// Direct (closed-form) observed-information weights for specific family-link
669// combinations. These avoid the overhead of the generic noncanonical formula
670// when the algebra simplifies.
671
/// Closed-form observed weights for the Gaussian family with log link:
/// y ~ N(μ, φ), μ = exp(η).
///
/// Returns `(w_obs, c_obs, d_obs)`, each already multiplied by the prior
/// weight `pw` (written ω below):
///
/// ```text
/// w_obs = ω μ(2μ − y) / φ
/// c_obs = ω μ(4μ − y) / φ
/// d_obs = ω μ(8μ − y) / φ
/// ```
#[inline]
pub fn observed_weight_gaussian_log(y: f64, mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    // Common factor ω μ / φ shared by all three outputs.
    let scaled_mu = pw / phi * mu;
    let term = |multiple: f64| scaled_mu * (multiple * mu - y);
    (term(2.0), term(4.0), term(8.0))
}
689
/// Closed-form observed weights for the Gaussian family with inverse link:
/// y ~ N(μ, φ), μ = 1/η.
///
/// Returns `(w_obs, c_obs, d_obs)`, each already multiplied by the prior
/// weight `pw` (written ω below):
///
/// ```text
/// w_obs =   ω (3 − 2ηy) / (φ η⁴)
/// c_obs =  6ω (ηy − 2)  / (φ η⁵)
/// d_obs = 12ω (5 − 2ηy) / (φ η⁶)
/// ```
///
/// No guard is applied for `η = 0`; the divisions then yield non-finite
/// values.
#[inline]
pub fn observed_weight_gaussian_inverse(y: f64, eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    // Powers of η used as denominators.
    let eta_sq = eta * eta;
    let eta_pow4 = eta_sq * eta_sq;
    let eta_pow5 = eta_pow4 * eta;
    let eta_pow6 = eta_pow4 * eta_sq;

    let eta_y = eta * y;
    let twice_eta_y = 2.0 * eta_y;
    let weight_over_phi = pw / phi;

    let w_obs = weight_over_phi * (3.0 - twice_eta_y) / eta_pow4;
    let c_obs = weight_over_phi * 6.0 * (eta_y - 2.0) / eta_pow5;
    let d_obs = weight_over_phi * 12.0 * (5.0 - twice_eta_y) / eta_pow6;
    (w_obs, c_obs, d_obs)
}
712
713#[inline]
714pub(crate) fn observed_weight_binomial_logit_from_jet(
715 n_trials: f64,
716 jet: MixtureInverseLinkJet,
717 pw: f64,
718) -> (f64, f64, f64) {
719 let scale = pw * n_trials;
720 (scale * jet.d1, scale * jet.d2, scale * jet.d3)
721}
722
/// Family tag for the observed-information weight dispatch.
///
/// This is a simplified family tag that identifies the variance function,
/// independent of the link function. It is used by [`observed_weight_dispatch`]
/// to select closed-form weight specializations, and by
/// [`variance_jet_for_weight_family`] to pick the matching [`VarianceJet`]
/// constructor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFamily {
    /// Constant variance; uses `VarianceJet::gaussian`.
    Gaussian,
    /// Variance μ(1−μ); uses `VarianceJet::binomial_n`.
    Binomial,
    /// Variance μ; uses `VarianceJet::poisson`.
    Poisson,
    /// Variance μ^p with power `p`; uses `VarianceJet::tweedie`.
    Tweedie { p: f64 },
    /// Variance μ + μ²/theta; uses `VarianceJet::negative_binomial`.
    NegativeBinomial { theta: f64 },
    /// Variance μ(1−μ)/(1+phi); uses `VarianceJet::beta`.
    Beta { phi: f64 },
    /// Variance μ²; uses `VarianceJet::gamma`.
    Gamma,
}
738
/// Link tag for the observed-information weight dispatch.
///
/// Identifies the link function for selecting closed-form weight
/// specializations in [`observed_weight_dispatch`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightLink {
    /// Identity link; no closed-form specialization keys on it, so it takes
    /// the generic path.
    Identity,
    /// Log link; paired with `Gaussian` it selects `observed_weight_gaussian_log`.
    Log,
    /// Logit link; paired with `Binomial` it selects
    /// `observed_weight_binomial_logit_from_jet`.
    Logit,
    /// Inverse link (μ = 1/η); paired with `Gaussian` it selects
    /// `observed_weight_gaussian_inverse`.
    ///
    /// NOTE(review): `weight_link_for_inverse_link` in this file never
    /// produces this variant — confirm which caller constructs it.
    Inverse,
    /// Any other link — falls back to the generic noncanonical formula.
    Other,
}
752
753#[inline]
754pub fn variance_jet_for_weight_family(family: WeightFamily, mu: f64) -> VarianceJet {
755 match family {
756 WeightFamily::Gaussian => VarianceJet::gaussian(),
757 WeightFamily::Binomial => VarianceJet::binomial_n(mu),
758 WeightFamily::Poisson => VarianceJet::poisson(mu),
759 WeightFamily::Tweedie { p } => VarianceJet::tweedie(mu, p),
760 WeightFamily::NegativeBinomial { theta } => VarianceJet::negative_binomial(mu, theta),
761 WeightFamily::Beta { phi } => VarianceJet::beta(mu, phi),
762 WeightFamily::Gamma => VarianceJet::gamma(mu),
763 }
764}
765
766/// Dispatch to closed-form observed-information weights for known family-link
767/// combinations, falling back to the generic noncanonical formula.
768///
769/// Returns `(w_obs, c_obs, d_obs)` pre-multiplied by the prior weight.
770///
771/// For the `Binomial + Logit` case, `n_trials` is passed as `phi` (dispersion
772/// slot is unused for binomial) and the prior weight controls the
773/// observation-level scaling. For all other cases, `phi` is the dispersion
774/// parameter.
775///
776/// `jet` and `h4` are the inverse-link derivatives used by the generic
777/// noncanonical fallback path. They may be zero for the specialized paths.
778pub fn observed_weight_dispatch(
779 family: WeightFamily,
780 link: WeightLink,
781 eta: f64,
782 y: f64,
783 mu: f64,
784 phi: f64,
785 prior_weight: f64,
786 jet: MixtureInverseLinkJet,
787 h4: f64,
788) -> (f64, f64, f64) {
789 match (family, link) {
790 (WeightFamily::Gaussian, WeightLink::Log) => {
791 observed_weight_gaussian_log(y, mu, phi, prior_weight)
792 }
793 (WeightFamily::Gaussian, WeightLink::Inverse) => {
794 observed_weight_gaussian_inverse(y, eta, phi, prior_weight)
795 }
796 (WeightFamily::Binomial, WeightLink::Logit) => {
797 observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
798 }
799 _ => {
800 // Generic noncanonical path via the full variance-function jet.
801 let vj = variance_jet_for_weight_family(family, mu);
802 observed_weight_noncanonical(y, mu, jet.d1, jet.d2, jet.d3, h4, vj, phi, prior_weight)
803 }
804 }
805}
806
807#[derive(Clone)]
808pub enum DirectionalWorkingCurvature {
809 /// Directional derivative of the PIRLS curvature when the working
810 /// curvature is diagonal in observation space:
811 /// W_τ = diag(w_τ).
812 Diagonal(Array1<f64>),
813}
814
815pub fn directionalworking_curvature_from_c_array(
816 c_array: &Array1<f64>,
817 hessian_weights: &Array1<f64>,
818 eta_direction: &Array1<f64>,
819) -> DirectionalWorkingCurvature {
820 let mut w_direction = c_array * eta_direction;
821 for i in 0..w_direction.len() {
822 if hessian_weights[i] <= 0.0 || !w_direction[i].is_finite() {
823 w_direction[i] = 0.0;
824 }
825 }
826 DirectionalWorkingCurvature::Diagonal(w_direction)
827}