gam_solve/pirls/curvature.rs
1//! Curvature primitives: the variance-function jet, observed-information
2//! Hessian weights, and the weight-family / weight-link classification used to
3//! choose between Fisher and observed curvature per family.
4
5use super::*;
6
/// Variance function `V(μ)` of an exponential-dispersion family together with
/// its first four derivatives with respect to μ, all evaluated at one μ.
///
/// The type is a plain bundle of five `f64` values, so it derives `Copy` and
/// the usual diagnostic/comparison traits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VarianceJet {
    /// `V(μ)`.
    pub v: f64,
    /// `V'(μ)`.
    pub v1: f64,
    /// `V''(μ)`.
    pub v2: f64,
    /// `V'''(μ)`.
    pub v3: f64,
    /// `V''''(μ)`.
    pub v4: f64,
}
14
15impl VarianceJet {
16 /// Lower floor on μ before evaluating power-law variance functions, so that
17 /// `μ^(p−k)` derivatives stay finite as μ → 0 instead of producing inf/NaN.
18 const VARIANCE_MU_FLOOR: f64 = 1e-10;
19
20 /// Bernoulli / binomial variance V(μ) = μ(1−μ).
21 #[inline]
22 pub fn bernoulli(mu: f64) -> Self {
23 Self {
24 v: mu * (1.0 - mu),
25 v1: 1.0 - 2.0 * mu,
26 v2: -2.0,
27 v3: 0.0,
28 v4: 0.0,
29 }
30 }
31
32 /// Poisson variance V(μ) = μ.
33 #[inline]
34 pub fn poisson(mu: f64) -> Self {
35 Self {
36 v: mu,
37 v1: 1.0,
38 v2: 0.0,
39 v3: 0.0,
40 v4: 0.0,
41 }
42 }
43
44 /// Gamma variance V(μ) = μ².
45 #[inline]
46 pub fn gamma(mu: f64) -> Self {
47 Self {
48 v: mu * mu,
49 v1: 2.0 * mu,
50 v2: 2.0,
51 v3: 0.0,
52 v4: 0.0,
53 }
54 }
55
56 /// Tweedie variance V(μ) = μ^p.
57 #[inline]
58 pub fn tweedie(mu: f64, p: f64) -> Self {
59 let mu = mu.max(Self::VARIANCE_MU_FLOOR);
60 Self {
61 v: mu.powf(p),
62 v1: p * mu.powf(p - 1.0),
63 v2: p * (p - 1.0) * mu.powf(p - 2.0),
64 v3: p * (p - 1.0) * (p - 2.0) * mu.powf(p - 3.0),
65 v4: p * (p - 1.0) * (p - 2.0) * (p - 3.0) * mu.powf(p - 4.0),
66 }
67 }
68
69 /// Negative-binomial variance V(μ) = μ + μ² / theta.
70 #[inline]
71 pub fn negative_binomial(mu: f64, theta: f64) -> Self {
72 let mu = mu.max(Self::VARIANCE_MU_FLOOR);
73 let inv_theta = if valid_negbin_theta(theta) {
74 1.0 / theta
75 } else {
76 f64::NAN
77 };
78 Self {
79 v: mu + mu * mu * inv_theta,
80 v1: 1.0 + 2.0 * mu * inv_theta,
81 v2: 2.0 * inv_theta,
82 v3: 0.0,
83 v4: 0.0,
84 }
85 }
86
87 /// Gaussian (identity) variance V(μ) = 1.
88 #[inline]
89 pub fn gaussian() -> Self {
90 Self {
91 v: 1.0,
92 v1: 0.0,
93 v2: 0.0,
94 v3: 0.0,
95 v4: 0.0,
96 }
97 }
98
99 /// Binomial(n, p) variance V(p) = p(1−p), identical to Bernoulli.
100 ///
101 /// The trial count `n` enters as a prior-weight multiplier, not through
102 /// the variance function itself.
103 #[inline]
104 pub fn binomial_n(mu: f64) -> Self {
105 // V(μ) = μ(1−μ), same jet as Bernoulli
106 Self::bernoulli(mu)
107 }
108
109 /// Beta-regression variance V(μ) = μ(1−μ)/(1+φ).
110 #[inline]
111 pub fn beta(mu: f64, phi: f64) -> Self {
112 let scale = 1.0 / (1.0 + phi.max(1e-12));
113 let base = Self::bernoulli(mu);
114 Self {
115 v: base.v * scale,
116 v1: base.v1 * scale,
117 v2: base.v2 * scale,
118 v3: 0.0,
119 v4: 0.0,
120 }
121 }
122}
123
/// Relative part of the solver weight floor: the fraction of the
/// (non-negative) Fisher weight used by `solver_hessian_weight_floor`.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC: f64 = 1e-6;

/// Absolute lower bound of the solver weight floor, applied by
/// `solver_hessian_weight_floor` after the relative term.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR: f64 = 1e-12;
127
128/// Returns the per-row floor `max(fisher · 1e-6, 1e-12)` used by PIRLS to
129/// stabilize the observed-information Hessian H = X' W X + S. Saturated
130/// rows where W_obs ≤ floor were silently raised to `floor` when PIRLS
131/// built the inner Hessian; outer REML/LAML derivatives must use the
132/// **same** floored W to keep `H` and `dH/dψ` on one surface.
133///
134/// This is the single source of truth for the floor formula. Both the
135/// inner solver (`solver_hessian_weights_into`) and the outer derivative
136/// path (`outer_hessian_curvature_arrays`) route through this helper so
137/// the inner-stabilized H and the outer dH/dψ cannot drift apart.
138#[inline]
139pub fn solver_hessian_weight_floor(fisher_weight: f64) -> f64 {
140 (fisher_weight.max(0.0) * OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC)
141 .max(OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)
142}
143
144/// Build the (W, c, d) triple that matches PIRLS's stabilized H = X' W X + S.
145///
146/// PIRLS internally uses `W[i] = max(W_obs[i], floor(W_F[i]))` to keep H PD,
147/// but `pirls_result.finalweights` stores the **unfloored** observed weights.
148/// Reusing those directly in `∂H/∂ψ = X_τ' W X + … + X' diag(c · X_τ β̂) X`
149/// produces an operator that disagrees with `H` at every saturated row — a
150/// 5%-Frobenius bias that `tr(G_ε(H) · op)` amplifies by O(1/σ_min(H)),
151/// driving the analytic gradient off by orders of magnitude.
152///
153/// This helper returns the floored W, plus c and d masked to zero wherever
154/// the floor is active (so `∂W/∂η` is zero on the constant-floor branch).
155pub fn outer_hessian_curvature_arrays(
156 hessian_weights: gam_linalg::matrix::SignedWeightsView<'_>,
157 fisher_weights: gam_linalg::matrix::PsdWeightsView<'_>,
158 c_array: &Array1<f64>,
159 d_array: &Array1<f64>,
160 eta: &Array1<f64>,
161 inverse_link: &InverseLink,
162) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
163 let hessian_view = hessian_weights.view();
164 let fisher_view = fisher_weights.view();
165 let n = hessian_view.len();
166 let mut w_out = Array1::<f64>::zeros(n);
167 let mut c_out = Array1::<f64>::zeros(n);
168 let mut d_out = Array1::<f64>::zeros(n);
169 for i in 0..n {
170 let floor = solver_hessian_weight_floor(fisher_view[i]);
171 let w = hessian_view[i];
172 let clamp_active = eta_clamp_active(inverse_link, eta[i]);
173 let w_below_floor = !(w.is_finite() && w > floor);
174 if w_below_floor {
175 w_out[i] = floor;
176 c_out[i] = 0.0;
177 d_out[i] = 0.0;
178 } else if clamp_active {
179 w_out[i] = w;
180 c_out[i] = 0.0;
181 d_out[i] = 0.0;
182 } else {
183 w_out[i] = w;
184 c_out[i] = c_array[i];
185 d_out[i] = d_array[i];
186 }
187 }
188 (w_out, c_out, d_out)
189}
190
191#[inline]
192pub(crate) fn fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
193 likelihood.fixed_phi().unwrap_or(1.0)
194}
195
196#[inline]
197pub fn weight_family_for_glm_likelihood(likelihood: &GlmLikelihoodSpec) -> WeightFamily {
198 match &likelihood.spec.response {
199 ResponseFamily::Gaussian => WeightFamily::Gaussian,
200 ResponseFamily::Poisson => WeightFamily::Poisson,
201 ResponseFamily::Tweedie { p } => WeightFamily::Tweedie { p: *p },
202 ResponseFamily::NegativeBinomial { theta, .. } => {
203 WeightFamily::NegativeBinomial { theta: *theta }
204 }
205 ResponseFamily::Beta { phi } => WeightFamily::Beta { phi: *phi },
206 ResponseFamily::Gamma => WeightFamily::Gamma,
207 ResponseFamily::Binomial => WeightFamily::Binomial,
208 ResponseFamily::RoystonParmar => WeightFamily::Gaussian,
209 }
210}
211
212#[inline]
213pub(crate) fn weight_link_for_inverse_link(inverse_link: &InverseLink) -> WeightLink {
214 match inverse_link {
215 InverseLink::Standard(StandardLink::Identity) => WeightLink::Identity,
216 InverseLink::Standard(StandardLink::Log) => WeightLink::Log,
217 InverseLink::Standard(StandardLink::Logit) => WeightLink::Logit,
218 InverseLink::Standard(StandardLink::Probit)
219 | InverseLink::Standard(StandardLink::CLogLog)
220 | InverseLink::Standard(StandardLink::LogLog)
221 | InverseLink::Standard(StandardLink::Cauchit)
222 | InverseLink::LatentCLogLog(_)
223 | InverseLink::Sas(_)
224 | InverseLink::BetaLogistic(_)
225 | InverseLink::Mixture(_) => WeightLink::Other,
226 }
227}
228
229#[inline]
230pub(crate) fn supports_observed_hessian_curvature_for_likelihood(
231 likelihood: &GlmLikelihoodSpec,
232 inverse_link: &InverseLink,
233) -> bool {
234 let spec = &likelihood.spec;
235 if matches!(spec.response, ResponseFamily::NegativeBinomial { .. }) {
236 return matches!(inverse_link, InverseLink::Standard(StandardLink::Log));
237 }
238 if matches!(spec.response, ResponseFamily::Gamma) {
239 return true;
240 }
241 if !matches!(spec.response, ResponseFamily::Binomial) {
242 return false;
243 }
244 matches!(
245 spec.link,
246 InverseLink::Standard(StandardLink::Probit)
247 | InverseLink::Standard(StandardLink::CLogLog)
248 | InverseLink::Standard(StandardLink::LogLog)
249 | InverseLink::Standard(StandardLink::Cauchit)
250 | InverseLink::Sas(_)
251 | InverseLink::BetaLogistic(_)
252 | InverseLink::Mixture(_)
253 )
254}
255
256#[inline]
257pub(crate) fn eta_for_observed_hessian_jet(inverse_link: &InverseLink, eta: f64) -> f64 {
258 match inverse_link {
259 // Why: canonical links keep V(mu) representable across the full f64 eta range; only guard against inf.
260 InverseLink::Standard(StandardLink::Logit | StandardLink::Log) => {
261 eta.clamp(-ETA_CLAMP, ETA_CLAMP)
262 }
263 InverseLink::Standard(StandardLink::Identity) => eta,
264 // Why: probit mu=Phi(eta) saturates to 1.0 in f64 by |eta|~8.3; +/-6 keeps V=mu(1-mu) ~ 1e-9 representable.
265 InverseLink::Standard(StandardLink::Probit) => eta.clamp(-6.0, 6.0),
266 // Why: cloglog has mu~exp(eta) for eta<<0 (underflows below ~-23) and 1-mu~exp(-exp(eta)) collapses by eta=3.
267 InverseLink::Standard(StandardLink::CLogLog) | InverseLink::LatentCLogLog(_) => {
268 eta.clamp(-23.0, 3.0)
269 }
270 InverseLink::Standard(StandardLink::LogLog) => eta.clamp(-3.0, 23.0),
271 InverseLink::Standard(StandardLink::Cauchit) => eta.clamp(-1.0e6, 1.0e6),
272 // Why: SAS / beta-logistic / mixture compose logistic-like sigmoids that saturate by |eta|~20 (logistic(20)~1-2e-9).
273 InverseLink::Sas(_) | InverseLink::BetaLogistic(_) | InverseLink::Mixture(_) => {
274 eta.clamp(-20.0, 20.0)
275 }
276 }
277}
278
279/// Returns true at rows where PIRLS clamped η (so the observed-info weights
280/// were computed at the clamped value, making `∂W/∂η` zero w.r.t. the
281/// **unclamped** η). Outer REML/LAML derivative formulas must mask `c_obs`
282/// and `d_obs` to zero on these rows or the analytic ∂H/∂ψ disagrees with
283/// the H whose log-det we differentiate.
284#[inline]
285pub fn eta_clamp_active(inverse_link: &InverseLink, eta: f64) -> bool {
286 let clamped = eta_for_observed_hessian_jet(inverse_link, eta);
287 clamped != eta
288}
289
290/// Build solver-conditioned weights from the exact hessian weights.
291///
292/// The returned array applies a solver-only floor per observation so the
293/// Newton linear system X'W X + S stays numerically usable. This floor is
294/// purely a linear-algebra concern: the exact statistical weights stored in
295/// `lasthessian_weights` / `finalweights` are not affected.
296pub(crate) fn solver_hessian_weights_into(
297 hessian_weights: &Array1<f64>,
298 fisher_weights: &Array1<f64>,
299 out: &mut Array1<f64>,
300) {
301 if out.len() != hessian_weights.len() {
302 *out = Array1::<f64>::zeros(hessian_weights.len());
303 }
304 ndarray::Zip::from(out)
305 .and(hessian_weights)
306 .and(fisher_weights)
307 .par_for_each(|o, &w, &fw| {
308 let floor = solver_hessian_weight_floor(fw);
309 *o = if w.is_finite() && w > floor { w } else { floor };
310 });
311}
312
/// Compute vectorised observed-information curvature arrays (w_obs, c_obs, d_obs)
/// for the Hessian surface at the mode, writing them into caller-owned buffers.
///
/// This function is the primary entry point for obtaining the observed weights
/// that flow into the outer REML/LAML Hessian H_obs = X' W_obs X + S. The
/// observed corrections include residual-dependent terms that vanish for
/// canonical links but are nonzero for probit, cloglog, SAS, mixture, Gamma-log,
/// and other flexible links.
///
/// The output arrays are:
/// - `hessian_weights`: W_obs per observation (exact; solver floor applied separately).
/// - `hessian_c`: c_obs = dW_obs/deta per observation (for outer gradient C[v]).
/// - `hessian_d`: d_obs = d^2W_obs/deta^2 per observation (for outer Hessian Q[v_k,v_l]).
///
/// Each output buffer is reallocated to `eta.len()` when its length differs.
/// Every row is evaluated at `eta_for_observed_hessian_jet(inverse_link, eta[i])`,
/// and prior weights are clamped to be non-negative before use.
/// `fisher_weights` is only read to report the Fisher weight in the
/// non-finite-weight error message.
///
/// See `observed_weight_noncanonical` for the per-observation formulas and
/// response.md Section 3 for the mathematical justification of why observed
/// (not Fisher) information is required.
///
/// # Errors
/// Propagates any error from the inverse-link jet helpers, and returns an
/// error when a row's `w_obs`, `c_obs` or `d_obs` is NaN or infinite. A finite
/// non-positive `w_obs` is accepted.
///
/// # Panics
/// Panics if `supports_observed_hessian_curvature_for_likelihood` is false for
/// the given likelihood/link, if an output array is not contiguous, or if
/// `y`, `priorweights` or `fisher_weights` is indexed past its end.
pub(crate) fn compute_observed_hessian_curvature_arrays_into(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    y: ArrayView1<'_, f64>,
    fisher_weights: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    hessian_weights: &mut Array1<f64>,
    hessian_c: &mut Array1<f64>,
    hessian_d: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    // Precondition: callers must only take this path for supported pairings.
    assert!(supports_observed_hessian_curvature_for_likelihood(
        likelihood,
        inverse_link
    ));
    // Make sure all three output buffers have one slot per observation.
    let n = eta.len();
    if hessian_weights.len() != n {
        *hessian_weights = Array1::<f64>::zeros(n);
    }
    if hessian_c.len() != n {
        *hessian_c = Array1::<f64>::zeros(n);
    }
    if hessian_d.len() != n {
        *hessian_d = Array1::<f64>::zeros(n);
    }

    // Row-invariant dispatch inputs, resolved once outside the parallel loop.
    let weight_family = weight_family_for_glm_likelihood(likelihood);
    let weight_link = weight_link_for_inverse_link(inverse_link);
    let phi = fixed_glm_dispersion(likelihood);

    // Parallel per-row weight assembly. At large scale (n = 320k) this loop
    // dominates non-canonical paths because each row independently evaluates
    // inverse-link jets and residual-dependent observed curvature. Write
    // directly into reusable output slices rather than collecting row tuples,
    // which removes an O(n) temporary allocation on every PIRLS update.
    hessian_weights
        .as_slice_mut()
        .expect("hessian weights must be contiguous")
        .par_iter_mut()
        .zip(
            hessian_c
                .as_slice_mut()
                .expect("hessian c must be contiguous")
                .par_iter_mut(),
        )
        .zip(
            hessian_d
                .as_slice_mut()
                .expect("hessian d must be contiguous")
                .par_iter_mut(),
        )
        .enumerate()
        .try_for_each(|(i, ((w_out, c_out), d_out))| -> Result<(), EstimationError> {
            let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
            // Why: closed-form observed_weight_noncanonical requires (mu, d1..d3, h4) at one consistent eta;
            // mixing PIRLS-state jets at unclamped eta with h4 at eta_used produced 0/0 in phi_v* divisions,
            // surfacing as: "observed Hessian curvature is not positive finite at row N: observed=NaN, fisher=0".
            let jet =
                crate::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta_used)?;
            let h4 = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
                inverse_link, eta_used,
            )?;
            let (w_obs, c_obs, d_obs) = observed_weight_dispatch(
                weight_family,
                weight_link,
                eta_used,
                y[i],
                jet.mu,
                phi,
                priorweights[i].max(0.0),
                jet,
                h4,
            );
            // Only used for the diagnostic message below.
            let fisher_weight = fisher_weights[i].max(0.0);
            // A *finite* but non-positive observed weight is NOT a failure: the
            // observed information `W_obs = W_Fisher - (y-μ)·B` legitimately goes
            // indefinite on individual rows for a non-canonical link (probit,
            // cloglog, SAS, and — critically for #1598 — a blended/mixture link)
            // whenever a large residual flips the sign of the residual-dependent
            // correction. The inner Newton system never uses this raw value: it
            // is clamped to the SPD floor `max(W_Fisher·1e-6, 1e-12)` by
            // `solver_hessian_weights_into`, and the outer REML/LAML derivative
            // path applies the *same* floor through `outer_hessian_curvature_arrays`
            // (which also zeroes c/d on the floored row). Both consumers are
            // designed precisely to absorb an indefinite W_obs, so hard-bailing
            // here defeats that stabilization and aborts an otherwise well-posed
            // solve — the mixture/SAS joint link fit on data its own pure
            // components fit trivially (clean logit under blended(logit, probit)).
            //
            // We therefore reject ONLY a genuinely non-finite (NaN/Inf) weight,
            // which signals a broken jet rather than benign indefiniteness, and
            // pass finite values (including non-positive ones) straight through to
            // the flooring consumers. Likewise `c_obs`/`d_obs` only need to be
            // finite; they are zeroed automatically on any floored row downstream.
            if !w_obs.is_finite() {
                crate::bail_invalid_estim!(
                    "observed Hessian curvature is not finite at row {i}: observed={w_obs}, fisher={fisher_weight}"
                );
            }
            if !c_obs.is_finite() || !d_obs.is_finite() {
                crate::bail_invalid_estim!(
                    "observed Hessian curvature derivatives are non-finite at row {i}: c={c_obs}, d={d_obs}"
                );
            }
            *w_out = w_obs;
            *c_out = c_obs;
            *d_out = d_obs;
            Ok(())
        })
}
440
441pub(crate) fn compute_observed_hessian_curvature_arrays(
442 likelihood: &GlmLikelihoodSpec,
443 inverse_link: &InverseLink,
444 eta: &Array1<f64>,
445 y: ArrayView1<'_, f64>,
446 fisher_weights: &Array1<f64>,
447 priorweights: ArrayView1<'_, f64>,
448) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
449 let n = eta.len();
450 let mut hessian_weights = Array1::<f64>::zeros(n);
451 let mut hessian_c = Array1::<f64>::zeros(n);
452 let mut hessian_d = Array1::<f64>::zeros(n);
453 compute_observed_hessian_curvature_arrays_into(
454 likelihood,
455 inverse_link,
456 eta,
457 y,
458 fisher_weights,
459 priorweights,
460 &mut hessian_weights,
461 &mut hessian_c,
462 &mut hessian_d,
463 )?;
464 Ok((hessian_weights, hessian_c, hessian_d))
465}
466
467/// Per-observation observed-information weights and their first two
468/// eta-derivatives for a general exponential-dispersion family with a
469/// noncanonical link.
470///
471/// The observed weight differs from the Fisher (expected) weight by a
472/// residual-dependent correction (see response.md Section 3):
473///
474/// W_obs = W_Fisher - (y - mu) * B
475/// B = (h'' V - h'^2 V') / (phi V^2)
476///
477/// c_obs = c_Fisher + h' * B - (y - mu) * B_eta
478/// d_obs = d_Fisher + h'' * B + 2*h' * B_eta - (y - mu) * B_etaeta
479///
480/// For canonical links (for example logit-Binomial and log-Poisson), B = 0
481/// so observed = Fisher and no correction is needed.
482///
483/// These observed quantities are required for:
484/// 1. The outer REML/LAML Hessian H_obs = X' W_obs X + S (log|H| term).
485/// 2. The outer gradient's C[v] correction (uses c_obs).
486/// 3. The outer Hessian's Q[v_k, v_l] correction (uses d_obs).
487///
488/// Using Fisher weights in the outer REML would yield a PQL-type surrogate
489/// rather than the exact Laplace approximation.
490///
491/// # Arguments
492/// * `y` -- response value
493/// * `mu` -- fitted mean h(eta)
494/// * `h1`...`h4` -- inverse-link derivatives h'(eta) ... h''''(eta)
495/// * `vj` -- variance-function jet (V, V', V'', V''') evaluated at mu
496/// * `phi` -- dispersion parameter (1.0 for Bernoulli/Poisson)
497/// * `pw` -- prior weight for this observation
498///
499/// # Returns
500/// `(w_obs, c_obs, d_obs)` -- the observed weight and its first two
501/// eta-derivatives, all pre-multiplied by `pw`.
502#[inline]
503pub fn observed_weight_noncanonical(
504 y: f64,
505 mu: f64,
506 h1: f64,
507 h2: f64,
508 h3: f64,
509 h4: f64,
510 vj: VarianceJet,
511 phi: f64,
512 pw: f64,
513) -> (f64, f64, f64) {
514 let VarianceJet {
515 v,
516 v1,
517 v2,
518 v3,
519 v4: _,
520 } = vj;
521 let phi_v = phi * v;
522 let phi_v2 = phi * v * v;
523 let phi_v3 = phi * v * v * v;
524
525 // ---- Fisher weight and derivatives ----
526 let h1_sq = h1 * h1;
527 let w_f = h1_sq / phi_v;
528
529 // c_F = (2 h₁ h₂ V − h₁³ V₁) / (φ V²)
530 let n0 = h1_sq; // numerator of w_F
531 let n1 = 2.0 * h1 * h2; // ∂(h₁²)/∂η
532 let n2 = 2.0 * (h2 * h2 + h1 * h3); // ∂²(h₁²)/∂η²
533 let vd1 = h1 * v1; // ∂V/∂η = V'·h'
534 let vd2 = h2 * v1 + h1_sq * v2; // ∂²V/∂η²
535
536 let c_f = (n1 * v - n0 * vd1) / phi_v2;
537
538 // d_F = ∂c_F/∂η via quotient rule on c_F = (n1·v − n0·vd1) / (φ·v²)
539 // numerator of c_F and its η-derivative (cross terms cancel):
540 let numer_cf = n1 * v - n0 * vd1;
541 let dnumer_cf = n2 * v - n0 * vd2;
542 let d_f = (dnumer_cf * v - 2.0 * numer_cf * vd1) / (phi_v3);
543
544 // ---- Observed correction term B and its η-derivatives ----
545 // B = (h₂ V − h₁² V₁) / (φ V²)
546 let b_num = h2 * v - h1_sq * v1;
547 let b = b_num / phi_v2;
548
549 // B_η = (h₃ V² − 3 h₁ h₂ V V₁ − h₁³ V V₂ + 2 h₁³ V₁²) / (φ V³)
550 let b_eta_num =
551 h3 * v * v - 3.0 * h1 * h2 * v * v1 - h1_sq * h1 * v * v2 + 2.0 * h1_sq * h1 * v1 * v1;
552 let b_eta = b_eta_num / phi_v3;
553
554 // B_ηη = ∂B_η/∂η.
555 //
556 // We differentiate b_eta_num / (φ V³) using the quotient rule.
557 //
558 // Numerator derivative of b_eta_num w.r.t. η, using chain rule ∂/∂η = h₁·∂/∂μ
559 // for the V-dependent parts:
560 //
561 // ∂/∂η [h₃ V²] = h₄ V² + 2 h₃ V h₁ V₁
562 // ∂/∂η [3 h₁ h₂ V V₁] = 3(h₂² + h₁ h₃)V V₁ + 3 h₁ h₂(h₁ V₁² + V h₁ V₂)
563 // ∂/∂η [h₁³ V V₂] = 3 h₁² h₂ V V₂ + h₁³(h₁ V₁ V₂ + V h₁ V₃)
564 // ∂/∂η [2 h₁³ V₁²] = 6 h₁² h₂ V₁² + 4 h₁³ V₁ h₁ V₂
565 // = 6 h₁² h₂ V₁² + 4 h1_sq * h1_sq * v1 * v2
566 //
567 // Denominator derivative: ∂/∂η [φ V³] = 3 φ V² h₁ V₁.
568
569 let h1_cu = h1_sq * h1;
570 let h1_qu = h1_sq * h1_sq;
571
572 let db_eta_num = h4 * v * v + 2.0 * h3 * v * h1 * v1
573 - 3.0 * (h2 * h2 + h1 * h3) * v * v1
574 - 3.0 * h1 * h2 * (h1 * v1 * v1 + v * h1 * v2)
575 - 3.0 * h1_sq * h2 * v * v2
576 - h1_cu * (h1 * v1 * v2 + v * h1 * v3)
577 + 6.0 * h1_sq * h2 * v1 * v1
578 + 4.0 * h1_qu * v1 * v2;
579
580 let phi_v4 = phi_v3 * v;
581 let b_etaeta = (db_eta_num * v - 3.0 * b_eta_num * h1 * v1) / phi_v4;
582
583 // ---- Assemble observed quantities ----
584 let resid = y - mu;
585
586 let w_obs = w_f - resid * b;
587 let c_obs = c_f + h1 * b - resid * b_eta;
588 let d_obs = d_f + h2 * b + 2.0 * h1 * b_eta - resid * b_etaeta;
589
590 (pw * w_obs, pw * c_obs, pw * d_obs)
591}
592
593/// Per-observation third η-derivative of the observed-information weight,
594/// `e_obs := ∂³W_obs/∂η³`, for a general exponential-dispersion family with
595/// any (canonical or non-canonical) link.
596///
597/// Closed-form derivation:
598/// Define `T(η) := h₁(η)/(φ V(μ(η)))`. Then
599/// * Fisher weight `W_F = h₁ · T`
600/// * Observed correction `B = T'`, so `B_η = T''`, `B_ηη = T'''`,
601/// `B_ηηη = T''''`
602/// * `W_obs = W_F − (y−μ) · T'`
603///
604/// Differentiating three times:
605/// `∂³W_obs/∂η³ = W_F''' + h₃·T' + 3 h₂·T'' + 3 h₁·T''' − (y−μ)·T''''`
606///
607/// `T` is computed via Leibniz on `T·Q = h₁` with `Q = φV`; `W_F` via
608/// Leibniz on `W_F·1 = h₁·T` (product rule).
609///
610/// All inverse-link derivatives `h₁..h₅` and variance-function derivatives
611/// `V..V₄` are required as inputs. Caller supplies them.
612///
613/// Returns `pw * e_obs` (pre-multiplied by the prior weight) so the result
614/// scales identically to `(w_obs, c_obs, d_obs)` from
615/// `observed_weight_noncanonical`.
616#[inline]
617pub fn e_obs_from_jets(
618 y: f64,
619 mu: f64,
620 h1: f64,
621 h2: f64,
622 h3: f64,
623 h4: f64,
624 h5: f64,
625 vj: VarianceJet,
626 phi: f64,
627 pw: f64,
628) -> f64 {
629 let VarianceJet { v, v1, v2, v3, v4 } = vj;
630 let q = phi * v;
631
632 // Q = φV and its η-derivatives.
633 // Q' = φ V₁ h₁
634 // Q'' = φ (V₁ h₂ + V₂ h₁²)
635 // Q''' = φ (V₁ h₃ + 3 V₂ h₁ h₂ + V₃ h₁³)
636 // Q'''' = φ (V₁ h₄ + 4 V₂ h₁ h₃ + 3 V₂ h₂² + 6 V₃ h₁² h₂ + V₄ h₁⁴)
637 let h1_sq = h1 * h1;
638 let h1_cu = h1_sq * h1;
639 let h1_qu = h1_sq * h1_sq;
640
641 let q1 = phi * v1 * h1;
642 let q2 = phi * (v1 * h2 + v2 * h1_sq);
643 let q3 = phi * (v1 * h3 + 3.0 * v2 * h1 * h2 + v3 * h1_cu);
644 let q4 = phi
645 * (v1 * h4 + 4.0 * v2 * h1 * h3 + 3.0 * v2 * h2 * h2 + 6.0 * v3 * h1_sq * h2 + v4 * h1_qu);
646
647 // T = h₁/Q and T', T'', T''', T'''' via Leibniz on T·Q = h₁.
648 // T' = (h₂ − T·Q')/Q
649 // T'' = (h₃ − 2 T'·Q' − T·Q'')/Q
650 // T''' = (h₄ − 3 T''·Q' − 3 T'·Q'' − T·Q''')/Q
651 // T'''' = (h₅ − 4 T'''·Q' − 6 T''·Q'' − 4 T'·Q''' − T·Q'''')/Q
652 let t0 = h1 / q;
653 let t1 = (h2 - t0 * q1) / q;
654 let t2 = (h3 - 2.0 * t1 * q1 - t0 * q2) / q;
655 let t3 = (h4 - 3.0 * t2 * q1 - 3.0 * t1 * q2 - t0 * q3) / q;
656 let t4 = (h5 - 4.0 * t3 * q1 - 6.0 * t2 * q2 - 4.0 * t1 * q3 - t0 * q4) / q;
657
658 // Fisher weight derivatives via product rule on W_F = h₁·T.
659 // W_F^(0) = h₁ T
660 // W_F^(1) = h₁ T₁ + h₂ T
661 // W_F^(2) = h₁ T₂ + 2 h₂ T₁ + h₃ T
662 // W_F^(3) = h₁ T₃ + 3 h₂ T₂ + 3 h₃ T₁ + h₄ T
663 let w_f3 = h1 * t3 + 3.0 * h2 * t2 + 3.0 * h3 * t1 + h4 * t0;
664
665 // Observed third derivative: differentiate W_obs = W_F − (y−μ)·T₁ thrice.
666 // (resid)' = −h₁, so iterating product rule yields
667 // ∂³((y−μ)·T₁)/∂η³ = −h₃·T₁ − 3 h₂·T₂ − 3 h₁·T₃ + (y−μ)·T₄
668 let resid = y - mu;
669 let e_obs = w_f3 + h3 * t1 + 3.0 * h2 * t2 + 3.0 * h1 * t3 - resid * t4;
670
671 pw * e_obs
672}
673
674// Direct (closed-form) observed-information weights for specific family-link
675// combinations. These avoid the overhead of the generic noncanonical formula
676// when the algebra simplifies.
677
/// Closed-form observed weights for the Gaussian family with log link:
/// y ~ N(μ, φ), μ = exp(η).
///
/// Returns `(w_obs, c_obs, d_obs)`, each already multiplied by the prior
/// weight `pw` (written ω below):
///
/// ```text
/// w_obs = ω μ(2μ − y) / φ
/// c_obs = ω μ(4μ − y) / φ
/// d_obs = ω μ(8μ − y) / φ
/// ```
#[inline]
pub fn observed_weight_gaussian_log(y: f64, mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    // Shared prefactor ω μ / φ; the three outputs differ only in the multiple of μ.
    let prefactor = pw / phi * mu;
    let with_multiple = |k: f64| prefactor * (k * mu - y);
    (with_multiple(2.0), with_multiple(4.0), with_multiple(8.0))
}
695
/// Closed-form observed weights for the Gaussian family with inverse link:
/// y ~ N(μ, φ), μ = 1/η.
///
/// Returns `(w_obs, c_obs, d_obs)`, each already multiplied by the prior
/// weight `pw` (written ω below):
///
/// ```text
/// w_obs = ω (3 − 2ηy) / (φ η⁴)
/// c_obs = 6ω (ηy − 2) / (φ η⁵)
/// d_obs = 12ω (5 − 2ηy) / (φ η⁶)
/// ```
///
/// No guard is applied for `eta == 0`; the divisions then yield non-finite
/// values.
#[inline]
pub fn observed_weight_gaussian_inverse(y: f64, eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let weight_over_phi = pw / phi;
    let eta_y = eta * y;

    // Even powers first; η⁵ and η⁶ are formed from η⁴.
    let eta_sq = eta * eta;
    let eta_4 = eta_sq * eta_sq;
    let eta_5 = eta_4 * eta;
    let eta_6 = eta_4 * eta_sq;

    (
        weight_over_phi * (3.0 - 2.0 * eta_y) / eta_4,
        weight_over_phi * 6.0 * (eta_y - 2.0) / eta_5,
        weight_over_phi * 12.0 * (5.0 - 2.0 * eta_y) / eta_6,
    )
}
718
719#[inline]
720pub(crate) fn observed_weight_binomial_logit_from_jet(
721 n_trials: f64,
722 jet: MixtureInverseLinkJet,
723 pw: f64,
724) -> (f64, f64, f64) {
725 let scale = pw * n_trials;
726 (scale * jet.d1, scale * jet.d2, scale * jet.d3)
727}
728
/// Family tag for the observed-information weight dispatch.
///
/// This is a simplified family tag that identifies the variance function,
/// independent of the link function. It is used by [`observed_weight_dispatch`]
/// to select closed-form weight specializations, and by
/// `variance_jet_for_weight_family` to pick the matching `VarianceJet`
/// constructor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFamily {
    /// Constant variance, V(μ) = 1.
    Gaussian,
    /// Binomial / Bernoulli variance, V(μ) = μ(1−μ).
    Binomial,
    /// Poisson variance, V(μ) = μ.
    Poisson,
    /// Tweedie variance, V(μ) = μ^p, with power parameter `p`.
    Tweedie { p: f64 },
    /// Negative-binomial variance, V(μ) = μ + μ² / theta.
    NegativeBinomial { theta: f64 },
    /// Beta-regression variance, V(μ) = μ(1−μ)/(1+φ).
    Beta { phi: f64 },
    /// Gamma variance, V(μ) = μ².
    Gamma,
}
744
/// Link tag for the observed-information weight dispatch.
///
/// Identifies the link function for selecting closed-form weight
/// specializations in [`observed_weight_dispatch`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightLink {
    /// Identity link.
    Identity,
    /// Log link; paired with the Gaussian family it selects the Gaussian-log
    /// closed form.
    Log,
    /// Logit link; paired with the Binomial family it selects the
    /// binomial-logit closed form.
    Logit,
    /// Inverse link; paired with the Gaussian family it selects the
    /// Gaussian-inverse closed form. `weight_link_for_inverse_link` in this
    /// file never returns this variant.
    Inverse,
    /// Any other link — falls back to the generic noncanonical formula.
    Other,
}
758
759#[inline]
760pub fn variance_jet_for_weight_family(family: WeightFamily, mu: f64) -> VarianceJet {
761 match family {
762 WeightFamily::Gaussian => VarianceJet::gaussian(),
763 WeightFamily::Binomial => VarianceJet::binomial_n(mu),
764 WeightFamily::Poisson => VarianceJet::poisson(mu),
765 WeightFamily::Tweedie { p } => VarianceJet::tweedie(mu, p),
766 WeightFamily::NegativeBinomial { theta } => VarianceJet::negative_binomial(mu, theta),
767 WeightFamily::Beta { phi } => VarianceJet::beta(mu, phi),
768 WeightFamily::Gamma => VarianceJet::gamma(mu),
769 }
770}
771
772/// Dispatch to closed-form observed-information weights for known family-link
773/// combinations, falling back to the generic noncanonical formula.
774///
775/// Returns `(w_obs, c_obs, d_obs)` pre-multiplied by the prior weight.
776///
777/// For the `Binomial + Logit` case, `n_trials` is passed as `phi` (dispersion
778/// slot is unused for binomial) and the prior weight controls the
779/// observation-level scaling. For all other cases, `phi` is the dispersion
780/// parameter.
781///
782/// `jet` and `h4` are the inverse-link derivatives used by the generic
783/// noncanonical fallback path. They may be zero for the specialized paths.
784pub fn observed_weight_dispatch(
785 family: WeightFamily,
786 link: WeightLink,
787 eta: f64,
788 y: f64,
789 mu: f64,
790 phi: f64,
791 prior_weight: f64,
792 jet: MixtureInverseLinkJet,
793 h4: f64,
794) -> (f64, f64, f64) {
795 match (family, link) {
796 (WeightFamily::Gaussian, WeightLink::Log) => {
797 observed_weight_gaussian_log(y, mu, phi, prior_weight)
798 }
799 (WeightFamily::Gaussian, WeightLink::Inverse) => {
800 observed_weight_gaussian_inverse(y, eta, phi, prior_weight)
801 }
802 (WeightFamily::Binomial, WeightLink::Logit) => {
803 observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
804 }
805 _ => {
806 // Generic noncanonical path via the full variance-function jet.
807 let vj = variance_jet_for_weight_family(family, mu);
808 observed_weight_noncanonical(y, mu, jet.d1, jet.d2, jet.d3, h4, vj, phi, prior_weight)
809 }
810 }
811}
812
/// Directional derivative of the PIRLS working curvature along some direction.
///
/// Only a diagonal representation exists at present; it is what
/// `directionalworking_curvature_from_c_array` returns.
#[derive(Clone)]
pub enum DirectionalWorkingCurvature {
    /// Directional derivative of the PIRLS curvature when the working
    /// curvature is diagonal in observation space:
    /// W_τ = diag(w_τ).
    ///
    /// The payload holds the diagonal entries `w_τ`, one per observation.
    Diagonal(Array1<f64>),
}
820
821pub fn directionalworking_curvature_from_c_array(
822 c_array: &Array1<f64>,
823 hessian_weights: &Array1<f64>,
824 eta_direction: &Array1<f64>,
825) -> DirectionalWorkingCurvature {
826 let mut w_direction = c_array * eta_direction;
827 for i in 0..w_direction.len() {
828 if hessian_weights[i] <= 0.0 || !w_direction[i].is_finite() {
829 w_direction[i] = 0.0;
830 }
831 }
832 DirectionalWorkingCurvature::Diagonal(w_direction)
833}