Skip to main content

gam_solve/
mixture_link.rs

1use crate::estimate::EstimationError;
2use gam_math::probability::{normal_cdf, normal_pdf};
3use gam_math::special::stable_polynomial_times_exp_neg as stable_nonnegative_poly_times_exp_neg;
4use crate::quadrature::latent_cloglog_jet5;
5use gam_problem::{
6    InverseLink, LatentCLogLogState, LikelihoodSpec, LinkComponent, LinkFunction, MixtureLinkSpec,
7    MixtureLinkState, ResponseFamily, SasLinkSpec, SasLinkState, StandardLink,
8};
9use ndarray::Array1;
10use statrs::function::beta::{beta_reg, ln_beta};
11use statrs::function::gamma::digamma;
12use std::ops::Neg;
13use std::sync::OnceLock;
14
/// Clamp magnitude associated with the sinh-arcsinh (SAS) `u` quantity.
///
/// NOTE(review): the use site is outside this view; the reading above is taken
/// from the name only — confirm against the SAS kernel before relying on it.
const SAS_U_CLAMP: f64 = 50.0_f64;

/// Bound `B` of the bounded sinh-arcsinh log-delta parameterisation
/// `delta = exp(B * tanh(raw_log_delta / B))`, which keeps the effective
/// `log(delta)` inside `[-B, B]`.
///
/// Crate-visible so the outer-strategy edge-barrier helpers in
/// `solver/estimate.rs` read this single value instead of repeating the literal.
pub(crate) const SAS_LOG_DELTA_BOUND: f64 = 12.0_f64;
21
22#[inline]
23fn latent_cloglog_quadctx() -> &'static crate::quadrature::QuadratureContext {
24    static QUADCTX: OnceLock<crate::quadrature::QuadratureContext> = OnceLock::new();
25    QUADCTX.get_or_init(crate::quadrature::QuadratureContext::new)
26}
27
28#[inline]
29fn latent_cloglog_point_jet(
30    state: &LatentCLogLogState,
31    eta: f64,
32) -> Result<InverseLinkJet, EstimationError> {
33    let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, state.latent_sd)?;
34    Ok(InverseLinkJet {
35        mu: jet.mean,
36        d1: jet.d1,
37        d2: jet.d2,
38        d3: jet.d3,
39    })
40}
41
/// Inverse-link value and its first three eta-derivatives at one point:
/// `F(eta)`, `F'(eta)`, `F''(eta)`, `F'''(eta)`.
///
/// The same four-slot layout is also used to carry parameter partials of those
/// quantities (see `MixtureJetWithRhoPartials` / `SasJetWithParamPartials`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct InverseLinkJet {
    /// `F(eta)`, the mean.
    pub mu: f64,
    /// `F'(eta)`.
    pub d1: f64,
    /// `F''(eta)`.
    pub d2: f64,
    /// `F'''(eta)`.
    pub d3: f64,
}
49
/// Logistic inverse-link value and its first five eta-derivatives at one point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LogitJet5 {
    /// `sigma(eta)`.
    pub mu: f64,
    /// First derivative.
    pub d1: f64,
    /// Second derivative.
    pub d2: f64,
    /// Third derivative.
    pub d3: f64,
    /// Fourth derivative.
    pub d4: f64,
    /// Fifth derivative.
    pub d5: f64,
}
59
/// Flushes `±0.0` and subnormal values to `+0.0`; all other values, including
/// `NaN` and the infinities, pass through unchanged.
///
/// A finite `v` satisfies `|v| < f64::MIN_POSITIVE` exactly when it is zero or
/// subnormal, so this is the same test expressed through the float class.
#[inline]
fn canonicalzero(v: f64) -> f64 {
    if v == 0.0 || v.is_subnormal() {
        0.0
    } else {
        v
    }
}
64
65#[inline]
66fn canonicalize_jet(mut jet: InverseLinkJet) -> InverseLinkJet {
67    jet.d1 = canonicalzero(jet.d1);
68    jet.d2 = canonicalzero(jet.d2);
69    jet.d3 = canonicalzero(jet.d3);
70    jet
71}
72
73#[inline]
74pub fn logit_inverse_link_jet5(eta: f64) -> LogitJet5 {
75    if eta.is_nan() {
76        return LogitJet5 {
77            mu: f64::NAN,
78            d1: f64::NAN,
79            d2: f64::NAN,
80            d3: f64::NAN,
81            d4: f64::NAN,
82            d5: f64::NAN,
83        };
84    }
85    if eta == f64::INFINITY {
86        return LogitJet5 {
87            mu: 1.0,
88            d1: 0.0,
89            d2: 0.0,
90            d3: 0.0,
91            d4: 0.0,
92            d5: 0.0,
93        };
94    }
95    if eta == f64::NEG_INFINITY {
96        return LogitJet5 {
97            mu: 0.0,
98            d1: 0.0,
99            d2: 0.0,
100            d3: 0.0,
101            d4: 0.0,
102            d5: 0.0,
103        };
104    }
105
106    let jet = if eta >= 0.0 {
107        let z = (-eta).exp();
108        let opz = 1.0 + z;
109        let opz2 = opz * opz;
110        let opz3 = opz2 * opz;
111        let opz4 = opz3 * opz;
112        let opz5 = opz4 * opz;
113        let opz6 = opz5 * opz;
114        let z2 = z * z;
115        let z3 = z2 * z;
116        let z4 = z3 * z;
117        LogitJet5 {
118            mu: 1.0 / opz,
119            d1: z / opz2,
120            d2: z * (z - 1.0) / opz3,
121            d3: z * (z2 - 4.0 * z + 1.0) / opz4,
122            d4: z * (z3 - 11.0 * z2 + 11.0 * z - 1.0) / opz5,
123            d5: z * (z4 - 26.0 * z3 + 66.0 * z2 - 26.0 * z + 1.0) / opz6,
124        }
125    } else {
126        let z = eta.exp();
127        let opz = 1.0 + z;
128        let opz2 = opz * opz;
129        let opz3 = opz2 * opz;
130        let opz4 = opz3 * opz;
131        let opz5 = opz4 * opz;
132        let opz6 = opz5 * opz;
133        let z2 = z * z;
134        let z3 = z2 * z;
135        let z4 = z3 * z;
136        LogitJet5 {
137            mu: z / opz,
138            d1: z / opz2,
139            d2: z * (1.0 - z) / opz3,
140            d3: z * (1.0 - 4.0 * z + z2) / opz4,
141            d4: z * (1.0 - 11.0 * z + 11.0 * z2 - z3) / opz5,
142            d5: z * (1.0 - 26.0 * z + 66.0 * z2 - 26.0 * z3 + z4) / opz6,
143        }
144    };
145    LogitJet5 {
146        mu: jet.mu,
147        d1: canonicalzero(jet.d1),
148        d2: canonicalzero(jet.d2),
149        d3: canonicalzero(jet.d3),
150        d4: canonicalzero(jet.d4),
151        d5: canonicalzero(jet.d5),
152    }
153}
154
155#[inline]
156fn probit_jet(eta: f64) -> InverseLinkJet {
157    // Exact probit semantics:
158    //
159    //   mu(eta) = Phi(eta),
160    //   mu'     = phi(eta),
161    //   mu''    = -eta * phi(eta),
162    //   mu'''   = (eta^2 - 1) * phi(eta).
163    //
164    // `normal_cdf` now evaluates the exact special-function form
165    // Phi(x) = 0.5 * erfc(-x / sqrt(2)), so the jet can and should use the
166    // matching closed-form Gaussian identities directly.
167    if eta.is_nan() {
168        return InverseLinkJet {
169            mu: f64::NAN,
170            d1: f64::NAN,
171            d2: f64::NAN,
172            d3: f64::NAN,
173        };
174    }
175    if eta == f64::INFINITY {
176        return InverseLinkJet {
177            mu: 1.0,
178            d1: 0.0,
179            d2: 0.0,
180            d3: 0.0,
181        };
182    }
183    if eta == f64::NEG_INFINITY {
184        return InverseLinkJet {
185            mu: 0.0,
186            d1: 0.0,
187            d2: 0.0,
188            d3: 0.0,
189        };
190    }
191    let x = eta;
192    let phi = normal_pdf(x);
193    InverseLinkJet {
194        mu: normal_cdf(x),
195        d1: phi,
196        d2: -x * phi,
197        d3: (x * x - 1.0) * phi,
198    }
199}
200
201#[inline]
202fn probit_pdfthird_derivative(eta: f64) -> f64 {
203    // Since d1 = mu' = phi(eta), this returns
204    //
205    //   d³/deta³ d1 = mu'''' = -(eta³ - 3 eta) phi(eta).
206    if eta.is_nan() {
207        return f64::NAN;
208    }
209    if !eta.is_finite() {
210        return 0.0;
211    }
212    let x = eta;
213    let phi = normal_pdf(x);
214    canonicalzero(-(x * x * x - 3.0 * x) * phi)
215}
216
217#[inline]
218fn probit_pdffourth_derivative(eta: f64) -> f64 {
219    // mu''''' = Phi^{(5)}(eta) = (eta^4 - 6*eta^2 + 3) * phi(eta).
220    if eta.is_nan() {
221        return f64::NAN;
222    }
223    if !eta.is_finite() {
224        return 0.0;
225    }
226    let x = eta;
227    let phi = normal_pdf(x);
228    canonicalzero((x * x * x * x - 6.0 * x * x + 3.0) * phi)
229}
230
/// Product of two truncated Taylor series with coefficients `a_k = g^(k) / k!`
/// for `k = 0..=4`; terms of order 5 and above are discarded.
///
/// Zero coefficients of `a` are skipped entirely. This matters in the saturated
/// tails: a zero numerator coefficient paired with an overflowed (infinite)
/// reciprocal coefficient then contributes nothing, instead of `0 * inf = NaN`.
#[inline]
fn taylor5_mul(a: &[f64; 5], b: &[f64; 5]) -> [f64; 5] {
    let mut product = [0.0_f64; 5];
    for (offset, &coef) in a.iter().enumerate().filter(|&(_, &c)| c != 0.0) {
        // product[offset + j] += a[offset] * b[j] for every j that stays in range.
        for (slot, &other) in product[offset..].iter_mut().zip(b.iter()) {
            *slot += coef * other;
        }
    }
    product
}
247
/// Reciprocal `1 / a` of a truncated 5-term Taylor series whose constant term
/// `a[0]` is nonzero.
///
/// Uses the standard recursion `b_0 = 1 / a_0`,
/// `b_k = -(sum_{j=1..=k} a_j b_{k-j}) / a_0`.
#[inline]
fn taylor5_inv(a: &[f64; 5]) -> [f64; 5] {
    let lead_inv = 1.0 / a[0];
    let mut recip = [0.0_f64; 5];
    recip[0] = lead_inv;
    for k in 1..5 {
        let acc = (1..=k).fold(0.0_f64, |acc, j| acc + a[j] * recip[k - j]);
        recip[k] = -acc * lead_inv;
    }
    recip
}
262
263/// 5-jet (value + four eta-derivatives) of the GLM Fisher working weight
264/// `W(eta) = mu'(eta)^2 / V(mu(eta))` for the requested standard link, returned
265/// as `(W, W', W'', W''', W'''')`.
266///
267/// For the canonical logit link this is exactly the binomial weight
268/// `W = mu(1 - mu) = mu'`, whose eta-derivatives are the higher derivatives of
269/// the inverse-link jet (`W^(k) = mu^(k+1)`); the dispatch returns
270/// `logit_inverse_link_jet5`'s `d1..d5` byte-for-byte so the existing Firth
271/// logit path is numerically unchanged.
272///
273/// Noncanonical Bernoulli links use the same truncated Taylor-series quotient:
274/// assemble the inverse-link jet through `mu^(5)`, square the `mu'` series, and
275/// divide by the Bernoulli variance series `mu(1-mu)`. As the variance
276/// denominator saturates to zero in either tail, the weight and all derivatives
277/// saturate to zero, matching the inverse-link jet convention.
278pub(crate) fn fisher_weight_jet5(link: StandardLink, eta: f64) -> (f64, f64, f64, f64, f64) {
279    match link {
280        StandardLink::Logit => {
281            let jet = logit_inverse_link_jet5(eta);
282            (jet.d1, jet.d2, jet.d3, jet.d4, jet.d5)
283        }
284        StandardLink::Probit => probit_fisher_weight_jet5(eta),
285        StandardLink::CLogLog => component_fisher_weight_jet5(LinkComponent::CLogLog, eta),
286        StandardLink::Identity | StandardLink::Log => (0.0, 0.0, 0.0, 0.0, 0.0),
287    }
288}
289
290pub(crate) fn fisher_weight_jet5_for_inverse_link(
291    link: &InverseLink,
292    eta: f64,
293) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
294    match link {
295        InverseLink::Standard(link) => Ok(fisher_weight_jet5(*link, eta)),
296        InverseLink::LatentCLogLog(_)
297        | InverseLink::Sas(_)
298        | InverseLink::BetaLogistic(_)
299        | InverseLink::Mixture(_) => {
300            let jet = link.jet(eta)?;
301            let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)?;
302            let d5 = inverse_link_pdffourth_derivative_for_inverse_link(link, eta)?;
303            Ok(fisher_weight_jet5_from_inverse_link_derivatives(
304                jet.mu, jet.d1, jet.d2, jet.d3, d4, d5,
305            ))
306        }
307    }
308}
309
310#[inline]
311pub(crate) fn inverse_link_has_fisher_weight_jet(link: &InverseLink) -> bool {
312    matches!(
313        link,
314        InverseLink::Standard(StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,)
315            | InverseLink::LatentCLogLog(_)
316            | InverseLink::Sas(_)
317            | InverseLink::BetaLogistic(_)
318            | InverseLink::Mixture(_)
319    )
320}
321
322#[inline]
323fn component_fisher_weight_jet5(component: LinkComponent, eta: f64) -> (f64, f64, f64, f64, f64) {
324    let jet = component_inverse_link_jet(component, eta);
325    let d4 = component_inverse_link_pdfthird_derivative(component, eta);
326    let d5 = component_inverse_link_pdffourth_derivative(component, eta);
327    fisher_weight_jet5_from_inverse_link_derivatives(jet.mu, jet.d1, jet.d2, jet.d3, d4, d5)
328}
329
330#[inline]
331fn fisher_weight_jet5_from_inverse_link_derivatives(
332    mu: f64,
333    d1: f64,
334    d2: f64,
335    d3: f64,
336    d4: f64,
337    d5: f64,
338) -> (f64, f64, f64, f64, f64) {
339    if [mu, d1, d2, d3, d4, d5].iter().any(|v| v.is_nan()) {
340        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
341    }
342    let variance = mu * (1.0 - mu);
343    if !(variance > 0.0) || !variance.is_finite() {
344        return (0.0, 0.0, 0.0, 0.0, 0.0);
345    }
346
347    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
348    let mu_d = [mu, d1, d2, d3, d4];
349    let one_minus_mu_d = [1.0 - mu, -d1, -d2, -d3, -d4];
350    let dmu_d = [d1, d2, d3, d4, d5];
351    let mut mu_t = [0.0_f64; 5];
352    let mut one_minus_mu_t = [0.0_f64; 5];
353    let mut dmu_t = [0.0_f64; 5];
354    for k in 0..5 {
355        let inv_fact = 1.0 / factorial[k];
356        mu_t[k] = mu_d[k] * inv_fact;
357        one_minus_mu_t[k] = one_minus_mu_d[k] * inv_fact;
358        dmu_t[k] = dmu_d[k] * inv_fact;
359    }
360    let num_t = taylor5_mul(&dmu_t, &dmu_t);
361    let den_t = taylor5_mul(&mu_t, &one_minus_mu_t);
362    if !(den_t[0] > 0.0) || !den_t[0].is_finite() {
363        return (0.0, 0.0, 0.0, 0.0, 0.0);
364    }
365    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
366    (
367        canonicalzero(w_t[0] * factorial[0]),
368        canonicalzero(w_t[1] * factorial[1]),
369        canonicalzero(w_t[2] * factorial[2]),
370        canonicalzero(w_t[3] * factorial[3]),
371        canonicalzero(w_t[4] * factorial[4]),
372    )
373}
374
375/// Probit Bernoulli Fisher-weight 5-jet `W = phi^2 / (Phi (1 - Phi))` and its
376/// first four eta-derivatives. See [`fisher_weight_jet5`].
377#[inline]
378fn probit_fisher_weight_jet5(eta: f64) -> (f64, f64, f64, f64, f64) {
379    if eta.is_nan() {
380        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
381    }
382    if !eta.is_finite() {
383        return (0.0, 0.0, 0.0, 0.0, 0.0);
384    }
385    let x = eta;
386    let p = normal_cdf(x);
387    // Compute the complement directly via Phi(-x) rather than `1 - Phi(x)`:
388    // in the positive tail `Phi(x)` rounds to 1.0 and `1 - Phi(x)` cancels to
389    // zero, whereas `Phi(-x)` retains the accurate (tiny) tail mass.
390    let q = normal_cdf(-x);
391    let phi = normal_pdf(x);
392    // Saturated tail: the denominator Phi(1-Phi) has underflowed to zero (or
393    // would divide by zero); the working weight and all derivatives go to zero.
394    if !(p > 0.0) || !(q > 0.0) || p * q <= 0.0 {
395        return (0.0, 0.0, 0.0, 0.0, 0.0);
396    }
397    // Gaussian derivative ladder: phi^(k) for k = 0..=4 using phi' = -x phi.
398    let phi1 = -x * phi;
399    let phi2 = (x * x - 1.0) * phi;
400    let phi3 = -(x * x * x - 3.0 * x) * phi;
401    let phi4 = (x * x * x * x - 6.0 * x * x + 3.0) * phi;
402    // Derivative arrays (d^k/deta^k) for f = phi, p = Phi, q = 1 - Phi.
403    // p^(0) = Phi, p^(k>=1) = phi^(k-1); q is the negated complement.
404    let f_d = [phi, phi1, phi2, phi3, phi4];
405    let p_d = [p, phi, phi1, phi2, phi3];
406    let q_d = [q, -phi, -phi1, -phi2, -phi3];
407    // Convert derivative arrays to Taylor coefficients a_k = g^(k)/k!.
408    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
409    let mut f_t = [0.0_f64; 5];
410    let mut p_t = [0.0_f64; 5];
411    let mut q_t = [0.0_f64; 5];
412    for k in 0..5 {
413        let inv_fact = 1.0 / factorial[k];
414        f_t[k] = f_d[k] * inv_fact;
415        p_t[k] = p_d[k] * inv_fact;
416        q_t[k] = q_d[k] * inv_fact;
417    }
418    let num_t = taylor5_mul(&f_t, &f_t);
419    let den_t = taylor5_mul(&p_t, &q_t);
420    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
421    // Back to derivatives W^(k) = w_t[k] * k!.
422    (
423        canonicalzero(w_t[0] * factorial[0]),
424        canonicalzero(w_t[1] * factorial[1]),
425        canonicalzero(w_t[2] * factorial[2]),
426        canonicalzero(w_t[3] * factorial[3]),
427        canonicalzero(w_t[4] * factorial[4]),
428    )
429}
430
431#[inline]
432fn chain_inverse_link_jet(base: InverseLinkJet, z1: f64, z2: f64, z3: f64) -> InverseLinkJet {
433    InverseLinkJet {
434        mu: base.mu,
435        d1: base.d1 * z1,
436        d2: base.d2 * z1 * z1 + base.d1 * z2,
437        d3: base.d3 * z1 * z1 * z1 + 3.0 * base.d2 * z1 * z2 + base.d1 * z3,
438    }
439}
440
441#[inline]
442fn component_inverse_link_pdfthird_derivative(component: LinkComponent, eta: f64) -> f64 {
443    match component {
444        LinkComponent::Probit => probit_pdfthird_derivative(eta),
445        LinkComponent::Logit => logit_inverse_link_jet5(eta).d4,
446        LinkComponent::CLogLog => {
447            // CLogLog link:
448            //   mu = 1 - exp(-t),  t = exp(eta),  d1 = t exp(-t).
449            //
450            // Repeated differentiation closes in the basis `d1 * poly(t)`:
451            //   d2 = d1(-t + 1)
452            //   d3 = d1(t² - 3t + 1)
453            //   d4 = d1(-t³ + 6t² - 7t + 1).
454            if eta.is_nan() {
455                return f64::NAN;
456            }
457            if !eta.is_finite() {
458                return 0.0;
459            }
460            let t = eta.exp();
461            canonicalzero(stable_nonnegative_poly_times_exp_neg(
462                t,
463                &[0.0, 1.0, -7.0, 6.0, -1.0],
464            ))
465        }
466        LinkComponent::LogLog => {
467            // LogLog link is the reflected cloglog family with `r = exp(-eta)`:
468            //   mu = exp(-r), d1 = mu r,
469            // and again higher derivatives are `d1 * poly(r)`:
470            //   d2 = d1(r - 1)
471            //   d3 = d1(r² - 3r + 1)
472            //   d4 = d1(r³ - 6r² + 7r - 1).
473            if eta.is_nan() {
474                return f64::NAN;
475            }
476            if !eta.is_finite() {
477                return 0.0;
478            }
479            let r = (-eta).exp();
480            canonicalzero(stable_nonnegative_poly_times_exp_neg(
481                r,
482                &[0.0, -1.0, 7.0, -6.0, 1.0],
483            ))
484        }
485        LinkComponent::Cauchit => {
486            // Cauchit link:
487            //   mu = 1/2 + atan(eta)/pi,
488            //   d1 = 1 / [pi (1+eta²)].
489            //
490            // Differentiating three more times gives
491            //
492            //   d4 = 24 eta (1-eta²) / [pi (1+eta²)^4].
493            if eta.is_nan() {
494                return f64::NAN;
495            }
496            if !eta.is_finite() {
497                return 0.0;
498            }
499            let denom = 1.0 + eta * eta;
500            24.0 * eta * (1.0 - eta * eta) / (std::f64::consts::PI * denom.powi(4))
501        }
502    }
503}
504
505/// Fifth derivative of a component inverse-link CDF (= fourth derivative of PDF).
506/// Extends `component_inverse_link_pdfthird_derivative` by one derivative order.
507#[inline]
508fn component_inverse_link_pdffourth_derivative(component: LinkComponent, eta: f64) -> f64 {
509    match component {
510        LinkComponent::Probit => probit_pdffourth_derivative(eta),
511        LinkComponent::Logit => logit_inverse_link_jet5(eta).d5,
512        LinkComponent::CLogLog => {
513            // Exact closed form:
514            //   d5 = exp(-t) * (t - 15t^2 + 25t^3 - 10t^4 + t^5)
515            //      = d1 * (1 - 15t + 25t^2 - 10t^3 + t^4),
516            // where t = exp(eta).
517            if eta.is_nan() {
518                return f64::NAN;
519            }
520            if !eta.is_finite() {
521                return 0.0;
522            }
523            let t = eta.exp();
524            canonicalzero(stable_nonnegative_poly_times_exp_neg(
525                t,
526                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
527            ))
528        }
529        LinkComponent::LogLog => {
530            // Exact closed form:
531            //   d5 = exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5)
532            //      = d1 * (1 - 15r + 25r^2 - 10r^3 + r^4),
533            // where r = exp(-eta).
534            if eta.is_nan() {
535                return f64::NAN;
536            }
537            if !eta.is_finite() {
538                return 0.0;
539            }
540            let r = (-eta).exp();
541            canonicalzero(stable_nonnegative_poly_times_exp_neg(
542                r,
543                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
544            ))
545        }
546        LinkComponent::Cauchit => {
547            // d5 = 24(1 - 10eta^2 + 5eta^4) / [pi * (1+eta^2)^5]
548            if eta.is_nan() {
549                return f64::NAN;
550            }
551            if !eta.is_finite() {
552                return 0.0;
553            }
554            let e2 = eta * eta;
555            let denom = 1.0 + e2;
556            24.0 * (1.0 - 10.0 * e2 + 5.0 * e2 * e2) / (std::f64::consts::PI * denom.powi(5))
557        }
558    }
559}
560
561#[derive(Clone, Debug, PartialEq)]
562pub struct MixtureJetWithRhoPartials {
563    pub jet: InverseLinkJet,
564    /// Partial derivatives wrt free logits rho_j, j in [0, K-2].
565    /// Each entry stores derivatives of (mu, d1, d2, d3) wrt one rho_j.
566    pub djet_drho: Vec<InverseLinkJet>,
567}
568
569#[derive(Clone, Debug, PartialEq)]
570pub struct SasJetWithParamPartials {
571    pub jet: InverseLinkJet,
572    pub djet_depsilon: InverseLinkJet,
573    pub djet_dlog_delta: InverseLinkJet,
574}
575
576#[derive(Clone, Debug, PartialEq)]
577pub enum LinkParamPartials {
578    Mixture(MixtureJetWithRhoPartials),
579    Sas(SasJetWithParamPartials),
580}
581
582/// Trait-based inverse-link kernel interface.
583///
584/// Implementors provide pointwise inverse-link derivatives wrt `eta`:
585/// `F(eta), F'(eta), F''(eta), F'''(eta)`.
586/// Optionally they may expose parameter partials used by outer-loop optimization.
587pub trait InverseLinkKernel {
588    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;
589
590    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
591        assert!(eta.is_finite(), "eta must be finite");
592        Ok(None)
593    }
594}
595
/// Stateless kernel for the probit link (standard-normal CDF).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ProbitLinkKernel;

/// Stateless kernel for the logit link.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct LogitLinkKernel;

/// Stateless kernel for the complementary log-log link.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct CLogLogLinkKernel;

/// Stateless kernel for the log-log link.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct LogLogLinkKernel;

/// Stateless kernel for the Cauchit link.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct CauchitLinkKernel;
610
611/// Construct SAS state from raw optimizer parameters using the same bounded
612/// transform used everywhere in fitting/evaluation.
613///
614/// A free function rather than an inherent `SasLinkState::new` because the
615/// bounded `delta` transform is solver-side math, so the constructor is hosted
616/// here next to the transform rather than on the type. `SasLinkState`'s fields
617/// are `pub`, so it builds directly.
618pub fn sas_link_state_from_raw(
619    raw_epsilon: f64,
620    raw_log_delta: f64,
621) -> Result<SasLinkState, String> {
622    if !raw_epsilon.is_finite() || !raw_log_delta.is_finite() {
623        return Err("SAS link parameters must be finite".to_string());
624    }
625    Ok(SasLinkState {
626        epsilon: raw_epsilon,
627        log_delta: raw_log_delta,
628        delta: sas_delta_from_raw_log_delta(raw_log_delta),
629    })
630}
631
632pub fn state_from_sasspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
633    sas_link_state_from_raw(spec.initial_epsilon, spec.initial_log_delta)
634}
635
636pub fn state_from_beta_logisticspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
637    if !spec.initial_epsilon.is_finite() || !spec.initial_log_delta.is_finite() {
638        return Err("Beta-Logistic link parameters must be finite".to_string());
639    }
640    // For Beta-Logistic, `log_delta` is the unconstrained log geometric-mean beta
641    // shape (the kernels' `log_shape_center`). Evaluation consumes `log_delta`,
642    // never `delta`, but keep the shared `SasLinkState::delta` field on the same
643    // bounded SAS parameterization used by `state_from_sasspec` so constructing a
644    // state from a large finite raw log-delta cannot overflow this derived field.
645    let log_shape_center = spec.initial_log_delta;
646    Ok(SasLinkState {
647        epsilon: spec.initial_epsilon,
648        log_delta: log_shape_center,
649        delta: sas_delta_from_raw_log_delta(log_shape_center),
650    })
651}
652
/// Smoothly bounded identity `B * tanh(value / B)`, with range `[-B, B]`.
///
/// `bound` is floored at `f64::EPSILON` so the division is always defined.
#[inline]
fn tanh_bound(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = (value / scale).tanh();
    scale * squashed
}
658
/// First derivative of `B * tanh(x / B)`: `1 - tanh^2(x / B)`.
#[inline]
fn tanh_bound_d1(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = (value / scale).tanh();
    1.0 - th * th
}
665
/// Second derivative of `B * tanh(x / B)`: `-2 t s / B`, where `t = tanh(x / B)`
/// and `s = 1 - t^2`.
#[inline]
fn tanh_bound_d2(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    -2.0 * th * sech2 / scale
}
673
/// Third derivative of `B * tanh(x / B)`: `-2 s (1 - 3 t^2) / B^2`, where
/// `t = tanh(x / B)` and `s = 1 - t^2`.
#[inline]
fn tanh_bound_d3(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    -2.0 * sech2 * (1.0 - 3.0 * th * th) / (scale * scale)
}
681
/// Fourth derivative of `B * tanh(x / B)`: `8 t s (2 - 3 t^2) / B^3`, where
/// `t = tanh(x / B)` and `s = 1 - t^2`.
#[inline]
fn tanh_bound_d4(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    8.0 * th * sech2 * (2.0 - 3.0 * th * th) / (scale * scale * scale)
}
689
/// Fifth derivative of `B * tanh(x / B)`:
///
///   `8 s (2 - 15 t^2 + 15 t^4) / B^4`,
///
/// where `t = tanh(x / B)` and `s = 1 - t^2`.
#[inline]
fn tanh_bound_d5(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    let th2 = th * th;
    let scale4 = scale * scale * scale * scale;
    8.0 * sech2 * (2.0 - 15.0 * th2 + 15.0 * th2 * th2) / scale4
}
702
703#[inline]
704fn sas_effective_log_delta(raw_log_delta: f64) -> (f64, f64) {
705    let ld_eff = tanh_bound(raw_log_delta, SAS_LOG_DELTA_BOUND);
706    let dld_eff_draw = tanh_bound_d1(raw_log_delta, SAS_LOG_DELTA_BOUND);
707    (ld_eff, dld_eff_draw)
708}
709
710#[inline]
711fn sas_delta_from_raw_log_delta(raw_log_delta: f64) -> f64 {
712    let (ld_eff, _) = sas_effective_log_delta(raw_log_delta);
713    ld_eff.exp()
714}
715
716pub fn validate_mixturespec(spec: &MixtureLinkSpec) -> Result<(), String> {
717    if spec.components.is_empty() {
718        return Err("mixture link requires at least 1 component".to_string());
719    }
720    if spec.initial_rho.len() + 1 != spec.components.len() {
721        return Err(format!(
722            "mixture link rho length mismatch: expected {}, got {}",
723            spec.components.len() - 1,
724            spec.initial_rho.len()
725        ));
726    }
727    for i in 0..spec.components.len() {
728        for j in (i + 1)..spec.components.len() {
729            if spec.components[i] == spec.components[j] {
730                return Err("mixture link components must be unique".to_string());
731            }
732        }
733    }
734    // `LinkComponent` admits two variants (Cauchit, LogLog) that have no matching
735    // `LinkFunction` entry. When two or more components are *blended*, the mixture-link
736    // pipeline projects the blend back onto a single `LinkFunction` value for downstream
737    // solver/IO bookkeeping (see `InverseLink::link_function`), so a multi-component blend
738    // composed solely of components without a LinkFunction representative would silently
739    // lie about its projected link. We therefore require any genuine *blend* (two or more
740    // components) to contain at least one Logit/Probit/CLogLog "anchor" so the projection
741    // is meaningful, and reject e.g. a blend of only {Cauchit, LogLog}.
742    //
743    // A *single-component* spec is not a blend at all: it is that one link, with weight
744    // 1.0 and no free mixing logits. `LinkComponent::LogLog` / `LinkComponent::Cauchit`
745    // implement their inverse link and derivative jets exactly, so a single-component
746    // `{LogLog}` / `{Cauchit}` spec is a fully-defined standalone link and is accepted
747    // here (this is how survival `--link loglog` / `--link cauchit` are represented).
748    let has_anchor = spec.components.iter().any(|component| {
749        matches!(
750            component,
751            LinkComponent::Logit | LinkComponent::Probit | LinkComponent::CLogLog
752        )
753    });
754    if !has_anchor && spec.components.len() > 1 {
755        let unsupported: Vec<&str> = spec
756            .components
757            .iter()
758            .map(|component| component.name())
759            .collect();
760        return Err(format!(
761            "mixture link components {{{}}} are unsupported: at least one component \
762             must map to a LinkFunction variant (logit/probit/cloglog) so the mixture's \
763             projected LinkFunction is well defined; cauchit and loglog have no \
764             LinkFunction representative",
765            unsupported.join(", ")
766        ));
767    }
768    Ok(())
769}
770
771pub fn softmax_last_fixedzero(rho: &Array1<f64>) -> Array1<f64> {
772    let k = rho.len() + 1;
773    let mut logits = Vec::with_capacity(k);
774    let mut maxv = 0.0_f64;
775    for &v in rho {
776        maxv = maxv.max(v);
777        logits.push(v);
778    }
779    maxv = maxv.max(0.0);
780    logits.push(0.0);
781
782    let mut sum = 0.0_f64;
783    let mut exps = vec![0.0_f64; k];
784    for i in 0..k {
785        let e = (logits[i] - maxv).exp();
786        exps[i] = e;
787        sum += e;
788    }
789    if !sum.is_finite() || sum <= 0.0 {
790        return Array1::from_elem(k, 1.0 / k as f64);
791    }
792    let inv = 1.0 / sum;
793    Array1::from_iter(exps.into_iter().map(|v| v * inv))
794}
795
796/// Returns softmax weights and Jacobian wrt free logits (last logit fixed at zero).
797/// Jacobian shape is (K, K-1): d pi_k / d rho_j.
798pub fn softmaxwith_jacobian_last_fixedzero(
799    rho: &Array1<f64>,
800) -> (Array1<f64>, ndarray::Array2<f64>) {
801    let pi = softmax_last_fixedzero(rho);
802    let k = pi.len();
803    let m = k.saturating_sub(1);
804    let mut jac = ndarray::Array2::<f64>::zeros((k, m));
805    for j in 0..m {
806        let pi_j = pi[j];
807        for kk in 0..k {
808            let delta = if kk == j { 1.0 } else { 0.0 };
809            jac[[kk, j]] = pi[kk] * (delta - pi_j);
810        }
811    }
812    (pi, jac)
813}
814
815pub fn state_fromspec(spec: &MixtureLinkSpec) -> Result<MixtureLinkState, String> {
816    validate_mixturespec(spec)?;
817    let pi = softmax_last_fixedzero(&spec.initial_rho);
818    Ok(MixtureLinkState {
819        components: spec.components.clone(),
820        rho: spec.initial_rho.clone(),
821        pi,
822    })
823}
824
825#[inline]
826pub fn component_inverse_link_jet(component: LinkComponent, eta: f64) -> InverseLinkJet {
827    canonicalize_jet(match component {
828        LinkComponent::Logit => {
829            let jet = logit_inverse_link_jet5(eta);
830            InverseLinkJet {
831                mu: jet.mu,
832                d1: jet.d1,
833                d2: jet.d2,
834                d3: jet.d3,
835            }
836        }
837        LinkComponent::Probit => probit_jet(eta),
838        LinkComponent::CLogLog => {
839            if eta.is_nan() {
840                return InverseLinkJet {
841                    mu: f64::NAN,
842                    d1: f64::NAN,
843                    d2: f64::NAN,
844                    d3: f64::NAN,
845                };
846            }
847            let t = eta.exp();
848            if !t.is_finite() {
849                return InverseLinkJet {
850                    mu: 1.0,
851                    d1: 0.0,
852                    d2: 0.0,
853                    d3: 0.0,
854                };
855            }
856            InverseLinkJet {
857                mu: -(-t).exp_m1(),
858                d1: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0]),
859                d2: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -1.0]),
860                d3: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -3.0, 1.0]),
861            }
862        }
863        LinkComponent::LogLog => {
864            if eta.is_nan() {
865                return InverseLinkJet {
866                    mu: f64::NAN,
867                    d1: f64::NAN,
868                    d2: f64::NAN,
869                    d3: f64::NAN,
870                };
871            }
872            let r = (-eta).exp();
873            if !r.is_finite() {
874                return InverseLinkJet {
875                    mu: 0.0,
876                    d1: 0.0,
877                    d2: 0.0,
878                    d3: 0.0,
879                };
880            }
881            InverseLinkJet {
882                mu: (-r).exp(),
883                d1: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0]),
884                d2: stable_nonnegative_poly_times_exp_neg(r, &[0.0, -1.0, 1.0]),
885                d3: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0, -3.0, 1.0]),
886            }
887        }
888        LinkComponent::Cauchit => {
889            if eta.is_nan() {
890                return InverseLinkJet {
891                    mu: f64::NAN,
892                    d1: f64::NAN,
893                    d2: f64::NAN,
894                    d3: f64::NAN,
895                };
896            }
897            let den = 1.0 + eta * eta;
898            let d1 = if eta.is_finite() {
899                1.0 / (std::f64::consts::PI * den)
900            } else {
901                0.0
902            };
903            let d2 = if eta.is_finite() {
904                -2.0 * eta / (std::f64::consts::PI * den * den)
905            } else {
906                0.0
907            };
908            let d3 = if eta.is_finite() {
909                (6.0 * eta * eta - 2.0) / (std::f64::consts::PI * den * den * den)
910            } else {
911                0.0
912            };
913            InverseLinkJet {
914                mu: 0.5 + eta.atan() / std::f64::consts::PI,
915                d1,
916                d2,
917                d3,
918            }
919        }
920    })
921}
922
923impl InverseLinkKernel for ProbitLinkKernel {
924    #[inline]
925    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
926        Ok(component_inverse_link_jet(LinkComponent::Probit, eta))
927    }
928}
929
930impl InverseLinkKernel for LogitLinkKernel {
931    #[inline]
932    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
933        Ok(component_inverse_link_jet(LinkComponent::Logit, eta))
934    }
935}
936
937impl InverseLinkKernel for CLogLogLinkKernel {
938    #[inline]
939    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
940        Ok(component_inverse_link_jet(LinkComponent::CLogLog, eta))
941    }
942}
943
944impl InverseLinkKernel for LogLogLinkKernel {
945    #[inline]
946    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
947        Ok(component_inverse_link_jet(LinkComponent::LogLog, eta))
948    }
949}
950
951impl InverseLinkKernel for CauchitLinkKernel {
952    #[inline]
953    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
954        Ok(component_inverse_link_jet(LinkComponent::Cauchit, eta))
955    }
956}
957
958impl InverseLinkKernel for LinkComponent {
959    #[inline]
960    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
961        Ok(component_inverse_link_jet(*self, eta))
962    }
963}
964
965impl InverseLinkKernel for LinkFunction {
966    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
967        match self {
968            LinkFunction::Logit => LogitLinkKernel.jet(eta),
969            LinkFunction::Probit => ProbitLinkKernel.jet(eta),
970            LinkFunction::CLogLog => CLogLogLinkKernel.jet(eta),
971            LinkFunction::Identity => Ok(InverseLinkJet {
972                mu: eta,
973                d1: 1.0,
974                d2: 0.0,
975                d3: 0.0,
976            }),
977            LinkFunction::Log => {
978                // SOLVER-INTERNAL inverse-link jet: `η.clamp(−700, 700).exp()`.
979                // The clamp is an intentional conditioning hack so the IRLS/REML
980                // normal equations stay well posed when η wanders into the tails
981                // during a trust-region step — it is NOT the public response
982                // transform. Public response-scale outputs (predictions, FFI
983                // `apply_inverse_link_array`, posterior bands) must use the EXACT
984                // `exp(η)` in `families::inverse_link::apply_inverse_link_vec`,
985                // which is finite wherever representable. Do not reroute a public
986                // output through this clamped jet (issue #963). Keep the clamp:
987                // solver consumers (e.g. `reml/runtime.rs` trust-region `excess`)
988                // pass raw η and rely on it to keep μ finite.
989                let e = eta.clamp(-700.0, 700.0).exp();
990                Ok(InverseLinkJet {
991                    mu: e,
992                    d1: e,
993                    d2: e,
994                    d3: e,
995                })
996            }
997            LinkFunction::Sas => Err(EstimationError::InvalidInput(
998                "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
999            )),
1000            LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1001                "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1002                    .to_string(),
1003            )),
1004        }
1005    }
1006}
1007
1008impl InverseLinkKernel for SasLinkState {
1009    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1010        Ok(sas_inverse_link_jet(eta, self.epsilon, self.log_delta))
1011    }
1012
1013    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1014        Ok(Some(LinkParamPartials::Sas(
1015            sas_inverse_link_jetwith_param_partials(eta, self.epsilon, self.log_delta),
1016        )))
1017    }
1018}
1019
/// Parameters of the Beta-Logistic inverse link, packaged so the link can be
/// evaluated through [`InverseLinkKernel`].
#[derive(Clone, Copy, Debug)]
pub struct BetaLogisticKernel {
    /// Unconstrained log of the geometric-mean beta shape — the raw optimization
    /// parameter `SasLinkState::log_delta`, NOT the derived `SasLinkState::delta`.
    pub log_shape_center: f64,
    /// Shape asymmetry (the link state's `epsilon` in `InverseLink`'s dispatch).
    /// In `beta_logistic_inverse_link_mu_d1` the two beta shapes are
    /// `a = exp(log_shape_center - epsilon)` and `b = exp(log_shape_center + epsilon)`;
    /// presumably the full jet uses the same mapping — TODO confirm.
    pub epsilon: f64,
}
1027
1028impl InverseLinkKernel for BetaLogisticKernel {
1029    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1030        Ok(beta_logistic_inverse_link_jet(
1031            eta,
1032            self.log_shape_center,
1033            self.epsilon,
1034        ))
1035    }
1036
1037    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1038        Ok(Some(LinkParamPartials::Sas(
1039            beta_logistic_inverse_link_jetwith_param_partials(
1040                eta,
1041                self.log_shape_center,
1042                self.epsilon,
1043            ),
1044        )))
1045    }
1046}
1047
1048impl InverseLinkKernel for MixtureLinkState {
1049    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1050        Ok(mixture_inverse_link_jet(self, eta))
1051    }
1052
1053    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1054        Ok(Some(LinkParamPartials::Mixture(
1055            mixture_inverse_link_jetwith_rho_partials(self, eta),
1056        )))
1057    }
1058}
1059
1060impl InverseLinkKernel for InverseLink {
1061    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1062        match self {
1063            InverseLink::Standard(link_fn) => link_fn.as_link_function().jet(eta),
1064            InverseLink::LatentCLogLog(state) => latent_cloglog_point_jet(state, eta),
1065            InverseLink::Sas(state) => state.jet(eta),
1066            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1067                log_shape_center: state.log_delta,
1068                epsilon: state.epsilon,
1069            }
1070            .jet(eta),
1071            InverseLink::Mixture(state) => state.jet(eta),
1072        }
1073    }
1074
1075    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1076        match self {
1077            InverseLink::Standard(_) => Ok(None),
1078            InverseLink::LatentCLogLog(_) => Ok(None),
1079            InverseLink::Sas(state) => state.param_partials(eta),
1080            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1081                log_shape_center: state.log_delta,
1082                epsilon: state.epsilon,
1083            }
1084            .param_partials(eta),
1085            InverseLink::Mixture(state) => state.param_partials(eta),
1086        }
1087    }
1088}
1089
1090/// Central family-aware inverse-link jet dispatch.
1091///
1092/// For `BinomialSas` and `BinomialMixture`, required state must be provided.
1093pub fn inverse_link_jet_for_inverse_link(
1094    link: &InverseLink,
1095    eta: f64,
1096) -> Result<InverseLinkJet, EstimationError> {
1097    link.jet(eta)
1098}
1099
1100/// Specialized `(mu, d1)` inverse-link evaluation that skips the d2/d3
1101/// polynomial chain used by the full jet. Numerical semantics are preserved:
1102/// the returned `mu` and `d1` are bit-identical to the corresponding fields of
1103/// `inverse_link_jet_for_inverse_link(link, eta)?` for every supported link.
1104///
1105/// For latent cloglog the underlying lognormal-Laplace kernel produces all
1106/// orders together, so this falls back to the full jet for that branch — the
1107/// savings come from the parameterised polynomial links (SAS, beta-logistic,
1108/// mixture) and the simple analytic links where d2/d3 are pure waste.
1109pub fn inverse_link_mu_d1_for_inverse_link(
1110    link: &InverseLink,
1111    eta: f64,
1112) -> Result<(f64, f64), EstimationError> {
1113    match link {
1114        InverseLink::Standard(link_fn) => Ok(link_function_mu_d1(link_fn.as_link_function(), eta)?),
1115        InverseLink::LatentCLogLog(state) => {
1116            let jet = latent_cloglog_point_jet(state, eta)?;
1117            Ok((jet.mu, jet.d1))
1118        }
1119        InverseLink::Sas(state) => Ok(sas_inverse_link_mu_d1(eta, state.epsilon, state.log_delta)),
1120        InverseLink::BetaLogistic(state) => Ok(beta_logistic_inverse_link_mu_d1(
1121            eta,
1122            state.log_delta,
1123            state.epsilon,
1124        )),
1125        InverseLink::Mixture(state) => Ok(mixture_inverse_link_mu_d1(state, eta)),
1126    }
1127}
1128
1129fn link_function_mu_d1(link: LinkFunction, eta: f64) -> Result<(f64, f64), EstimationError> {
1130    match link {
1131        LinkFunction::Identity => Ok((eta, 1.0)),
1132        LinkFunction::Log => {
1133            // SOLVER-INTERNAL clamped `(μ, dμ/dη)`; see the matching note on the
1134            // full `LinkFunction::Log` jet above. Public response transforms use
1135            // exact `exp(η)` via `families::inverse_link::apply_inverse_link_vec`
1136            // (issue #963).
1137            let e = eta.clamp(-700.0, 700.0).exp();
1138            Ok((e, e))
1139        }
1140        LinkFunction::Logit => Ok(component_inverse_link_mu_d1(LinkComponent::Logit, eta)),
1141        LinkFunction::Probit => Ok(component_inverse_link_mu_d1(LinkComponent::Probit, eta)),
1142        LinkFunction::CLogLog => Ok(component_inverse_link_mu_d1(LinkComponent::CLogLog, eta)),
1143        LinkFunction::Sas => Err(EstimationError::InvalidInput(
1144            "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1145        )),
1146        LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1147            "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1148                .to_string(),
1149        )),
1150    }
1151}
1152
/// `(μ, dμ/dη)` for a single fixed-shape link component.
///
/// Intended to be bit-identical to `component_inverse_link_jet(component, eta)`
/// restricted to its `mu`/`d1` fields; each arm mirrors the corresponding arm
/// of the full jet (same NaN / ±∞ guards, same closed forms).
#[inline]
fn component_inverse_link_mu_d1(component: LinkComponent, eta: f64) -> (f64, f64) {
    // Every component has its own arm here; there is no generic fallback. `d1`
    // is passed through `canonicalzero` (presumably normalising zero-like
    // values the same way the full jet does — TODO confirm) so the pair stays
    // bit-identical to `component_inverse_link_jet(...).{mu,d1}`.
    match component {
        LinkComponent::Logit => {
            // NOTE(review): this evaluates the full fifth-order logit jet and
            // discards d2..d5, so it saves nothing over the full jet. A direct
            // closed form would be cheaper but must first be proven
            // bit-identical to `logit_inverse_link_jet5`.
            let jet = logit_inverse_link_jet5(eta);
            (jet.mu, canonicalzero(jet.d1))
        }
        LinkComponent::Probit => {
            // μ = Φ(η), dμ/dη = φ(η); explicit NaN / ±∞ limits.
            if eta.is_nan() {
                return (f64::NAN, f64::NAN);
            }
            if eta == f64::INFINITY {
                return (1.0, 0.0);
            }
            if eta == f64::NEG_INFINITY {
                return (0.0, 0.0);
            }
            let phi = normal_pdf(eta);
            (normal_cdf(eta), canonicalzero(phi))
        }
        LinkComponent::CLogLog => {
            // t = e^η, μ = 1 − e^{−t} (via exp_m1 for small-t accuracy),
            // dμ/dη = t·e^{−t}. Overflowing t saturates to μ = 1, d1 = 0.
            if eta.is_nan() {
                return (f64::NAN, f64::NAN);
            }
            let t = eta.exp();
            if !t.is_finite() {
                return (1.0, 0.0);
            }
            (
                -(-t).exp_m1(),
                canonicalzero(stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0])),
            )
        }
        LinkComponent::LogLog => {
            // r = e^{−η}, μ = e^{−r}, dμ/dη = r·e^{−r}. Overflowing r
            // saturates to μ = 0, d1 = 0.
            if eta.is_nan() {
                return (f64::NAN, f64::NAN);
            }
            let r = (-eta).exp();
            if !r.is_finite() {
                return (0.0, 0.0);
            }
            (
                (-r).exp(),
                canonicalzero(stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0])),
            )
        }
        LinkComponent::Cauchit => {
            // μ = 1/2 + atan(η)/π, dμ/dη = 1/(π(1 + η²)); d1 is forced to 0
            // at η = ±∞, where 1 + η² is infinite.
            if eta.is_nan() {
                return (f64::NAN, f64::NAN);
            }
            let den = 1.0 + eta * eta;
            let d1 = if eta.is_finite() {
                1.0 / (std::f64::consts::PI * den)
            } else {
                0.0
            };
            (0.5 + eta.atan() / std::f64::consts::PI, canonicalzero(d1))
        }
    }
}
1218
/// `(μ, dμ/dη)` for the sinh-arcsinh (SAS) inverse link.
///
/// With `δ = sas_delta_from_raw_log_delta(log_delta)` (the bounded
/// `exp(B·tanh(log_delta/B))` map), the link is
/// `μ = Φ(z)`, `z = sinh(u)`, `u = tanh_bound(δ·asinh(η) − ε, SAS_U_CLAMP)`,
/// and `dμ/dη = φ(z) · cosh(u) · tanh_bound'(·) · δ / sqrt(1 + η²)`.
/// When `ε ≈ 0` and `δ ≈ 1` (within 1e-12), SAS reduces to probit and the
/// probit fast path is used.
///
/// The arithmetic grouping below must stay exactly as written: the result is
/// meant to be bit-identical to the `mu`/`d1` of the full SAS jet.
fn sas_inverse_link_mu_d1(eta: f64, epsilon: f64, log_delta: f64) -> (f64, f64) {
    let delta_id = sas_delta_from_raw_log_delta(log_delta);
    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
        return component_inverse_link_mu_d1(LinkComponent::Probit, eta);
    }
    // NOTE(review): non-finite η (±∞ and NaN) is evaluated as η = 0, so ±∞
    // does not reach the 0/1 limits and NaN yields a finite result.
    // Presumably the full `sas_inverse_link_jet` does the same (the
    // bit-identity contract depends on it) — confirm before changing either.
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let delta = delta_id;
    // Pre-saturation argument and its bounded version; `g1` is the slope of
    // the bound at `u_raw` (presumably a smooth tanh-style squash to
    // ±SAS_U_CLAMP — see `tanh_bound`).
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    // d asinh(η)/dη = 1 / sqrt(1 + η²) = 1 / hypot(η, 1).
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let r1 = delta * inv_q;
    let u1 = g1 * r1;
    // dz/dη = cosh(u) · du/dη.
    let z1 = c * u1;
    // `mu = Phi(z)` and `d1 = phi(z) * z1`, the same closed forms used by the
    // full jet via `chain_inverse_link_jet(probit_jet(z), z1, _, _)`.
    let base = probit_jet(z);
    (base.mu, canonicalzero(base.d1 * z1))
}
1243
1244fn beta_logistic_inverse_link_mu_d1(eta: f64, delta: f64, epsilon: f64) -> (f64, f64) {
1245    let logistic = logistic_uwith_derivatives(eta);
1246    let a = (delta - epsilon).exp();
1247    let b = (delta + epsilon).exp();
1248    let mu = beta_reg_logistic(a, b, logistic);
1249    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1250    (mu, log_d1.exp())
1251}
1252
1253fn mixture_inverse_link_mu_d1(state: &MixtureLinkState, eta: f64) -> (f64, f64) {
1254    let mut mu = 0.0_f64;
1255    let mut d1 = 0.0_f64;
1256    let k = state.components.len().min(state.pi.len());
1257    for i in 0..k {
1258        let (mu_i, d1_i) = component_inverse_link_mu_d1(state.components[i], eta);
1259        let w = state.pi[i];
1260        mu += w * mu_i;
1261        d1 += w * d1_i;
1262    }
1263    (mu, d1)
1264}
1265
/// Which η-derivative of the inverse-link *density* `f = dμ/dη` to evaluate.
///
/// For latent cloglog these map to the `d4` / `d5` slots of the fifth-order
/// quadrature jet.
#[derive(Clone, Copy)]
enum PdfDerivativeOrder {
    /// `f'''`, i.e. the fourth η-derivative of the CDF `μ`.
    Third,
    /// `f''''`, i.e. the fifth η-derivative of the CDF `μ`.
    Fourth,
}
1271
1272impl PdfDerivativeOrder {
1273    fn probit(self, eta: f64) -> f64 {
1274        match self {
1275            Self::Third => probit_pdfthird_derivative(eta),
1276            Self::Fourth => probit_pdffourth_derivative(eta),
1277        }
1278    }
1279
1280    fn component(self, component: LinkComponent, eta: f64) -> f64 {
1281        match self {
1282            Self::Third => component_inverse_link_pdfthird_derivative(component, eta),
1283            Self::Fourth => component_inverse_link_pdffourth_derivative(component, eta),
1284        }
1285    }
1286
1287    fn latent_cloglog(self, eta: f64, latent_sd: f64) -> Result<f64, EstimationError> {
1288        let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, latent_sd)?;
1289        Ok(match self {
1290            Self::Third => jet.d4,
1291            Self::Fourth => jet.d5,
1292        })
1293    }
1294
1295    fn sas(self, eta: f64, epsilon: f64, log_delta: f64) -> f64 {
1296        match self {
1297            Self::Third => sas_inverse_link_pdfthird_derivative(eta, epsilon, log_delta),
1298            Self::Fourth => sas_inverse_link_pdffourth_derivative(eta, epsilon, log_delta),
1299        }
1300    }
1301
1302    fn beta_logistic(self, eta: f64, log_shape_center: f64, epsilon: f64) -> f64 {
1303        match self {
1304            Self::Third => {
1305                beta_logistic_inverse_link_pdfthird_derivative(eta, log_shape_center, epsilon)
1306            }
1307            Self::Fourth => {
1308                beta_logistic_inverse_link_pdffourth_derivative(eta, log_shape_center, epsilon)
1309            }
1310        }
1311    }
1312}
1313
1314fn inverse_link_pdf_derivative_for_inverse_link(
1315    link: &InverseLink,
1316    eta: f64,
1317    order: PdfDerivativeOrder,
1318) -> Result<f64, EstimationError> {
1319    match link {
1320        InverseLink::Standard(StandardLink::Identity) => Ok(0.0),
1321        InverseLink::Standard(StandardLink::Log) => Ok(eta.clamp(-700.0, 700.0).exp()),
1322        InverseLink::Standard(StandardLink::Probit) => Ok(order.probit(eta)),
1323        InverseLink::Standard(StandardLink::Logit) => {
1324            Ok(order.component(LinkComponent::Logit, eta))
1325        }
1326        InverseLink::Standard(StandardLink::CLogLog) => {
1327            Ok(order.component(LinkComponent::CLogLog, eta))
1328        }
1329        InverseLink::LatentCLogLog(state) => order.latent_cloglog(eta, state.latent_sd),
1330        InverseLink::Sas(state) => Ok(order.sas(eta, state.epsilon, state.log_delta)),
1331        InverseLink::BetaLogistic(state) => {
1332            Ok(order.beta_logistic(eta, state.log_delta, state.epsilon))
1333        }
1334        InverseLink::Mixture(state) => Ok(state
1335            .components
1336            .iter()
1337            .zip(state.pi.iter())
1338            .map(|(&component, &weight)| weight * order.component(component, eta))
1339            .sum()),
1340    }
1341}
1342
1343pub fn inverse_link_pdfthird_derivative_for_inverse_link(
1344    link: &InverseLink,
1345    eta: f64,
1346) -> Result<f64, EstimationError> {
1347    // This dispatch returns the fourth eta-derivative of the inverse-link CDF,
1348    // equivalently the third derivative of the inverse-link density
1349    //
1350    //   f(eta) = d/deta mu(eta).
1351    //
1352    // It is used downstream as the `f'''` input in
1353    //
1354    //   d³/deta³ log f = f'''/f - 3 f'f''/f² + 2(f')³/f³.
1355    //
1356    // Mixture links preserve linearity:
1357    //
1358    //   mu = sum_j pi_j mu_j
1359    //   => f''' = sum_j pi_j f_j'''
1360    //
1361    // because the mixture weights `pi_j` are constant with respect to `eta`.
1362    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Third)
1363}
1364
1365/// Fifth derivative of the inverse-link CDF (= fourth derivative of the PDF).
1366///
1367/// Extends `inverse_link_pdfthird_derivative_for_inverse_link` by one order.
1368/// Used for the outer REML Hessian Q[v_k, v_l] term in survival models,
1369/// specifically the `m1 * u_{abcd}` Arbogast contribution.
1370pub fn inverse_link_pdffourth_derivative_for_inverse_link(
1371    link: &InverseLink,
1372    eta: f64,
1373) -> Result<f64, EstimationError> {
1374    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Fourth)
1375}
1376
1377
1378#[inline]
1379fn royston_parmar_inverse_link_jet(eta: f64) -> InverseLinkJet {
1380    // ApproxKind: NumericalApproximation — exp(exp(eta)) overflows f64 for
1381    // eta > ~3.7; the |eta| > 30 saturation tail returns survival ≈ 0 with
1382    // d_k = 0, matching the analytic limit to within IEEE-754 underflow.
1383    const SURVIVAL_ETA_CLAMP: f64 = 30.0;
1384
1385    let z = eta.clamp(-SURVIVAL_ETA_CLAMP, SURVIVAL_ETA_CLAMP);
1386    let hazard = z.exp();
1387    let survival = (-hazard).exp();
1388    if !(-SURVIVAL_ETA_CLAMP..=SURVIVAL_ETA_CLAMP).contains(&eta) {
1389        return InverseLinkJet {
1390            mu: survival,
1391            d1: 0.0,
1392            d2: 0.0,
1393            d3: 0.0,
1394        };
1395    }
1396
1397    let d1 = -hazard * survival;
1398    let d2 = hazard * (hazard - 1.0) * survival;
1399    let d3 = (-hazard * hazard * hazard + 3.0 * hazard * hazard - hazard) * survival;
1400    InverseLinkJet {
1401        mu: survival,
1402        d1,
1403        d2,
1404        d3,
1405    }
1406}
1407
1408pub fn inverse_link_jet_for_family(
1409    spec: &LikelihoodSpec,
1410    eta: f64,
1411) -> Result<InverseLinkJet, EstimationError> {
1412    // RoystonParmar uses its own analytic survival inverse link irrespective of
1413    // the (nominal `Identity`) link slot carried in the spec.
1414    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1415        return Ok(royston_parmar_inverse_link_jet(eta));
1416    }
1417    spec.link.jet(eta)
1418}
1419
1420/// Exact-public log inverse-link jet: `mu = d1 = d2 = d3 = exp(η)` with NO
1421/// `η`-clamp. Sibling of the solver-internal `LinkFunction::Log` jet (which
1422/// clamps `η` to `[−700, 700]` as an IRLS/REML conditioning hack); see issue
1423/// #963. Every derivative of `exp` is `exp`, so all four jet slots carry the
1424/// same exact value — finite wherever representable, `0.0` on underflow,
1425/// `+∞` on overflow.
1426#[inline]
1427fn log_inverse_link_jet_exact(eta: f64) -> InverseLinkJet {
1428    let e = eta.exp();
1429    InverseLinkJet {
1430        mu: e,
1431        d1: e,
1432        d2: e,
1433        d3: e,
1434    }
1435}
1436
1437/// EXACT public inverse-link jet for response-scale prediction outputs.
1438///
1439/// Identical to [`inverse_link_jet_for_family`] for every link EXCEPT the
1440/// standard `Log` link, where it returns the exact `exp(η)` jet instead of the
1441/// solver's `η.clamp(−700, 700).exp()` conditioning transform. On finite `η`
1442/// the two diverge (η = 705: exact `exp(705)` ≈ 1.5e306 vs clamped
1443/// `exp(700)` ≈ 1.0e304; η = −720: exact ≈ 2e−313 vs clamped ≈ 9.9e−305), so
1444/// public predictions (`FamilyStrategy::inverse_link_jet`/`inverse_link_array`,
1445/// the predict response + delta-method SE path) route here. The solver/REML/
1446/// PIRLS engines keep the clamped jet (issue #963). For `|η| ≤ 700` this is
1447/// byte-identical to the clamped jet (the clamp is inert there), so no
1448/// in-range prediction changes.
1449pub fn inverse_link_jet_for_family_public(
1450    spec: &LikelihoodSpec,
1451    eta: f64,
1452) -> Result<InverseLinkJet, EstimationError> {
1453    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1454        return Ok(royston_parmar_inverse_link_jet(eta));
1455    }
1456    if let InverseLink::Standard(StandardLink::Log) = spec.link {
1457        return Ok(log_inverse_link_jet_exact(eta));
1458    }
1459    spec.link.jet(eta)
1460}
1461
1462#[inline]
1463pub fn mixture_inverse_link_jet(state: &MixtureLinkState, eta: f64) -> InverseLinkJet {
1464    let mut mu = 0.0_f64;
1465    let mut d1 = 0.0_f64;
1466    let mut d2 = 0.0_f64;
1467    let mut d3 = 0.0_f64;
1468    let k = state.components.len().min(state.pi.len());
1469    for i in 0..k {
1470        let jet = component_inverse_link_jet(state.components[i], eta);
1471        let w = state.pi[i];
1472        mu += w * jet.mu;
1473        d1 += w * jet.d1;
1474        d2 += w * jet.d2;
1475        d3 += w * jet.d3;
1476    }
1477    InverseLinkJet { mu, d1, d2, d3 }
1478}
1479
1480/// Computes mixture jet and exact partial derivatives wrt free softmax logits.
1481///
1482/// Uses identities:
1483///   d mu     / d rho_j = pi_j (mu_j     - mu)
1484///   d mu'    / d rho_j = pi_j (mu_j'    - mu')
1485///   d mu''   / d rho_j = pi_j (mu_j''   - mu'')
1486///   d mu'''  / d rho_j = pi_j (mu_j'''  - mu''')
1487pub fn mixture_inverse_link_jetwith_rho_partials(
1488    state: &MixtureLinkState,
1489    eta: f64,
1490) -> MixtureJetWithRhoPartials {
1491    let k = state.components.len().min(state.pi.len());
1492    let m = k.saturating_sub(1);
1493    let mut djet_drho = vec![
1494        InverseLinkJet {
1495            mu: 0.0,
1496            d1: 0.0,
1497            d2: 0.0,
1498            d3: 0.0,
1499        };
1500        m
1501    ];
1502    let jet = mixture_inverse_link_jetwith_rho_partials_into(state, eta, &mut djet_drho);
1503    MixtureJetWithRhoPartials { jet, djet_drho }
1504}
1505
1506/// Computes mixture jet and writes exact rho partial jets into `out` (length >= K-1).
1507/// This avoids heap allocation in hot loops.
1508pub fn mixture_inverse_link_jetwith_rho_partials_into(
1509    state: &MixtureLinkState,
1510    eta: f64,
1511    out: &mut [InverseLinkJet],
1512) -> InverseLinkJet {
1513    let k = state.components.len().min(state.pi.len());
1514    let m = k.saturating_sub(1);
1515    assert!(
1516        out.len() >= m,
1517        "rho-partial output buffer too small: got {}, need {}",
1518        out.len(),
1519        m
1520    );
1521    let mut mixed = InverseLinkJet {
1522        mu: 0.0,
1523        d1: 0.0,
1524        d2: 0.0,
1525        d3: 0.0,
1526    };
1527    for i in 0..k {
1528        let jet_i = component_inverse_link_jet(state.components[i], eta);
1529        let w = state.pi[i];
1530        mixed.mu += w * jet_i.mu;
1531        mixed.d1 += w * jet_i.d1;
1532        mixed.d2 += w * jet_i.d2;
1533        mixed.d3 += w * jet_i.d3;
1534        // Cache the first K-1 component jets directly in the output buffer so
1535        // we don't recompute them in the partial loop.
1536        if i < m {
1537            out[i] = jet_i;
1538        }
1539    }
1540    for j in 0..m {
1541        let pi_j = state.pi[j];
1542        let cj = out[j];
1543        out[j] = InverseLinkJet {
1544            mu: pi_j * (cj.mu - mixed.mu),
1545            d1: pi_j * (cj.d1 - mixed.d1),
1546            d2: pi_j * (cj.d2 - mixed.d2),
1547            d3: pi_j * (cj.d3 - mixed.d3),
1548        };
1549    }
1550    mixed
1551}
1552
/// Logistic point `u` cached in both linear and log space, together with its
/// η-derivative. Consumed by the Beta-Logistic link helpers.
///
/// Built by `logistic_uwith_derivatives`, which sets
/// `ln_u = −stable_softplus(−η)` and `ln_one_minus_u = −stable_softplus(η)`.
/// Assuming `stable_softplus(x) = ln(1 + eˣ)`, this makes `u = σ(η)` — TODO
/// confirm.
#[derive(Clone, Copy)]
struct LogisticU {
    /// `u = exp(ln_u)`.
    u: f64,
    /// `1 − u`, computed as `exp(ln_one_minus_u)` rather than `1.0 - u`, so it
    /// keeps relative precision when `u` is close to 1.
    one_minus_u: f64,
    /// `ln u`. `beta_reg_logistic` treats `−∞` as `u = 0`.
    ln_u: f64,
    /// `ln(1 − u)`. `beta_reg_logistic` treats `−∞` as `u = 1`.
    ln_one_minus_u: f64,
    /// `exp(ln_u + ln_one_minus_u) = u(1 − u)`, the logistic derivative `du/dη`.
    du: f64,
    /// `η ≥ 0`. When set, the regularized-beta helpers evaluate the
    /// complementary tail `1 − I_{1−u}(b, a)` instead of `I_u(a, b)`.
    use_upper_tail: bool,
}
1562
1563#[inline]
1564fn logistic_uwith_derivatives(eta: f64) -> LogisticU {
1565    let ln_u = -gam_linalg::utils::stable_softplus(-eta);
1566    let ln_one_minus_u = -gam_linalg::utils::stable_softplus(eta);
1567    let u = ln_u.exp();
1568    let one_minus_u = ln_one_minus_u.exp();
1569    let du = (ln_u + ln_one_minus_u).exp();
1570    LogisticU {
1571        u,
1572        one_minus_u,
1573        ln_u,
1574        ln_one_minus_u,
1575        du,
1576        use_upper_tail: eta >= 0.0,
1577    }
1578}
1579
1580#[inline]
1581fn beta_reg_logistic(a: f64, b: f64, logistic: LogisticU) -> f64 {
1582    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1583        return f64::NAN;
1584    }
1585    if logistic.ln_u == f64::NEG_INFINITY {
1586        return 0.0;
1587    }
1588    if logistic.ln_one_minus_u == f64::NEG_INFINITY {
1589        return 1.0;
1590    }
1591    if logistic.use_upper_tail {
1592        1.0 - beta_reg(b, a, logistic.one_minus_u)
1593    } else {
1594        beta_reg(a, b, logistic.u)
1595    }
1596}
1597
1598#[inline]
1599fn beta_reg_with_shape_partials_logistic(a: f64, b: f64, logistic: LogisticU) -> (f64, f64, f64) {
1600    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1601        return (f64::NAN, f64::NAN, f64::NAN);
1602    }
1603    if logistic.use_upper_tail {
1604        let (tail, dtail_db, dtail_da) = beta_reg_with_shape_partials(b, a, logistic.one_minus_u);
1605        (1.0 - tail, -dtail_da, -dtail_db)
1606    } else {
1607        beta_reg_with_shape_partials(a, b, logistic.u)
1608    }
1609}
1610
1611#[inline]
1612fn beta_logistic_log_d1(a: f64, b: f64, logistic: LogisticU) -> f64 {
1613    a * logistic.ln_u + b * logistic.ln_one_minus_u - ln_beta(a, b)
1614}
1615
/// Forward-mode dual number. It carries a value and its partial derivatives
/// with respect to the two Beta shape parameters `(a, b)`.
///
/// `+`, `−`, `×`, `÷` and unary `−` propagate both partials by the
/// sum/product/quotient rules. It is used to differentiate the regularized
/// incomplete beta continued fraction in `beta_reg_with_shape_partials`.
#[derive(Clone, Copy)]
struct ShapeDual {
    /// Primal value.
    v: f64,
    /// ∂v/∂a.
    da: f64,
    /// ∂v/∂b.
    db: f64,
}
1622
1623impl ShapeDual {
1624    #[inline]
1625    fn constant(v: f64) -> Self {
1626        Self {
1627            v,
1628            da: 0.0,
1629            db: 0.0,
1630        }
1631    }
1632
1633    #[inline]
1634    fn from_value_partials(v: f64, da: f64, db: f64) -> Self {
1635        Self { v, da, db }
1636    }
1637
1638    #[inline]
1639    fn clamp_small(self, floor: f64) -> Self {
1640        if self.v.abs() < floor {
1641            Self::constant(floor)
1642        } else {
1643            self
1644        }
1645    }
1646}
1647
1648impl std::ops::Add for ShapeDual {
1649    type Output = Self;
1650
1651    #[inline]
1652    fn add(self, rhs: Self) -> Self {
1653        Self {
1654            v: self.v + rhs.v,
1655            da: self.da + rhs.da,
1656            db: self.db + rhs.db,
1657        }
1658    }
1659}
1660
1661impl std::ops::Sub for ShapeDual {
1662    type Output = Self;
1663
1664    #[inline]
1665    fn sub(self, rhs: Self) -> Self {
1666        Self {
1667            v: self.v - rhs.v,
1668            da: self.da - rhs.da,
1669            db: self.db - rhs.db,
1670        }
1671    }
1672}
1673
1674impl std::ops::Mul for ShapeDual {
1675    type Output = Self;
1676
1677    #[inline]
1678    fn mul(self, rhs: Self) -> Self {
1679        Self {
1680            v: self.v * rhs.v,
1681            da: self.da * rhs.v + self.v * rhs.da,
1682            db: self.db * rhs.v + self.v * rhs.db,
1683        }
1684    }
1685}
1686
1687impl std::ops::Div for ShapeDual {
1688    type Output = Self;
1689
1690    #[inline]
1691    fn div(self, rhs: Self) -> Self {
1692        let inv = 1.0 / rhs.v;
1693        let inv2 = inv * inv;
1694        Self {
1695            v: self.v * inv,
1696            da: (self.da * rhs.v - self.v * rhs.da) * inv2,
1697            db: (self.db * rhs.v - self.v * rhs.db) * inv2,
1698        }
1699    }
1700}
1701
1702impl std::ops::Neg for ShapeDual {
1703    type Output = Self;
1704
1705    #[inline]
1706    fn neg(self) -> Self {
1707        ShapeDual {
1708            v: -self.v,
1709            da: -self.da,
1710            db: -self.db,
1711        }
1712    }
1713}
1714
1715#[inline]
1716fn shape_dual(v: f64) -> ShapeDual {
1717    ShapeDual::constant(v)
1718}
1719
// Analytic shape partials for I_x(a,b), obtained by differentiating the same
// regularized-beta continued fraction used by statrs. The normalizing term uses
// d log B(a,b) / da = psi(a) - psi(a+b) and likewise for b.
//
// Returns `(I_x(a0, b0), ∂I/∂a0, ∂I/∂b0)`. The continued fraction is run on
// `ShapeDual` values seeded with unit partials for a0/b0, so every Lentz
// update carries its derivative along.
//
// Notes:
// * x0 ≤ 0 and x0 ≥ 1 return the exact limits with zero partials.
// * Convergence is tested on the value only (`del.v`). The partials are not
//   checked separately, and after 140 iterations the current estimate is
//   returned without signalling non-convergence.
// * `clamp_small` replaces tiny denominators with a constant, which drops
//   their partials at that step.
// * No validation of a0, b0 > 0 happens here; callers pass `exp(...)` shapes.
fn beta_reg_with_shape_partials(a0: f64, b0: f64, x0: f64) -> (f64, f64, f64) {
    if x0 <= 0.0 {
        return (0.0, 0.0, 0.0);
    }
    if x0 >= 1.0 {
        return (1.0, 0.0, 0.0);
    }

    // Same symmetry switch as statrs: past the mean-ish split point, evaluate
    // I_{1-x}(b, a) and complement at the end. Swapping the roles of a and b
    // also swaps which seed partial (∂/∂a0 vs ∂/∂b0) each dual carries.
    let symm_transform = x0 >= (a0 + 1.0) / (a0 + b0 + 2.0);
    let (a, b, x) = if symm_transform {
        (
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            1.0 - x0,
        )
    } else {
        (
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            x0,
        )
    };

    // Prefactor bt = x^a (1-x)^b / B(a, b), with
    // ∂ log bt / ∂a = psi(a+b) − psi(a) + ln x (and symmetrically for b),
    // chained through each dual's seed partials.
    let ln_x = x.ln();
    let ln_1mx = (1.0 - x).ln();
    let psi_ab = digamma(a.v + b.v);
    let log_bt = statrs::function::gamma::ln_gamma(a.v + b.v)
        - statrs::function::gamma::ln_gamma(a.v)
        - statrs::function::gamma::ln_gamma(b.v)
        + a.v * ln_x
        + b.v * ln_1mx;
    let bt_v = log_bt.exp();
    let log_bt_a = psi_ab - digamma(a.v) + ln_x;
    let log_bt_b = psi_ab - digamma(b.v) + ln_1mx;
    let bt = ShapeDual {
        v: bt_v,
        da: bt_v * (log_bt_a * a.da + log_bt_b * b.da),
        db: bt_v * (log_bt_a * a.db + log_bt_b * b.db),
    };

    // eps is 2^-53 (f64 unit roundoff); fpmin is the tiny-denominator floor
    // used by the modified Lentz method.
    let eps = 0.00000000000000011102230246251565;
    let fpmin = f64::MIN_POSITIVE / eps;
    let one = shape_dual(1.0);
    let qab = a + b;
    let qap = a + one;
    let qam = a - one;
    let mut c = one;
    let mut d = (one - qab * shape_dual(x) / qap).clamp_small(fpmin);
    d = one / d;
    let mut h = d;

    for m in 1..141 {
        let mf = f64::from(m);
        let m2 = mf * 2.0;
        let md = shape_dual(mf);
        let m2d = shape_dual(m2);
        // Even continued-fraction step: coefficient d_{2m}.
        let mut aa = md * (b - md) * shape_dual(x) / ((qam + m2d) * (a + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        h = h * d * c;

        // Odd continued-fraction step: coefficient d_{2m+1}.
        aa = (a + md).neg() * (qab + md) * shape_dual(x) / ((a + m2d) * (qap + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        let del = d * c;
        h = h * del;

        // Value-only convergence test.
        if (del.v - 1.0).abs() <= eps {
            let reg = bt * h / a;
            return if symm_transform {
                (1.0 - reg.v, -reg.da, -reg.db)
            } else {
                (reg.v, reg.da, reg.db)
            };
        }
    }
    // Not converged within 140 iterations: return the current estimate.
    let reg = bt * h / a;
    if symm_transform {
        (1.0 - reg.v, -reg.da, -reg.db)
    } else {
        (reg.v, reg.da, reg.db)
    }
}
1808
1809/// Beta-Logistic inverse-link jet for:
1810///   u = logistic(eta)
1811///   a = exp(log_shape_center - epsilon), b = exp(log_shape_center + epsilon)
1812///   mu = I_u(a, b)
1813///
1814/// NOTE: `log_shape_center` is the *unconstrained* log of the geometric-mean
1815/// beta shape (so a·b = exp(2·log_shape_center)). Callers must pass the raw
1816/// optimization parameter `SasLinkState::log_delta`, NOT the derived positive
1817/// `SasLinkState::delta = exp(log_shape_center)`.
1818pub fn beta_logistic_inverse_link_jet(
1819    eta: f64,
1820    log_shape_center: f64,
1821    epsilon: f64,
1822) -> InverseLinkJet {
1823    let logistic = logistic_uwith_derivatives(eta);
1824    let a = (log_shape_center - epsilon).exp();
1825    let b = (log_shape_center + epsilon).exp();
1826    let mu = beta_reg_logistic(a, b, logistic);
1827    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1828    let d1 = log_d1.exp();
1829    let t = a * logistic.one_minus_u - b * logistic.u;
1830    let d2 = d1 * t;
1831    let d3 = d1 * (t * t - (a + b) * logistic.du);
1832    InverseLinkJet { mu, d1, d2, d3 }
1833}
1834
1835pub fn beta_logistic_inverse_link_pdfthird_derivative(
1836    eta: f64,
1837    log_shape_center: f64,
1838    epsilon: f64,
1839) -> f64 {
1840    // Beta-logistic link:
1841    //
1842    //   u = logistic(eta),
1843    //   d1 = C * u^a (1-u)^b,
1844    //   t  = a(1-u) - b u,
1845    //   c  = a + b,
1846    //
1847    // so
1848    //
1849    //   d2 = d1 * t
1850    //   d3 = d1 * (t² - c u')
1851    //
1852    // with `u' = u(1-u)`.
1853    //
1854    // Differentiate once more:
1855    //
1856    //   d4 = d/deta[d1 (t² - c u')]
1857    //      = d1' (t² - c u') + d1 (2 t t' - c u'')
1858    //      = d1 [ t(t² - c u') - 2 c t u' - c u'' ]
1859    //      = d1 [ t³ - 3 c t u' - c u'' ],
1860    //
1861    // since `t' = -c u'`.
1862    let logistic = logistic_uwith_derivatives(eta);
1863    let a = (log_shape_center - epsilon).exp();
1864    let b = (log_shape_center + epsilon).exp();
1865    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1866    let d1 = log_d1.exp();
1867    let c = a + b;
1868    let t = a * logistic.one_minus_u - b * logistic.u;
1869    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1870    d1 * (t * t * t - 3.0 * c * t * logistic.du - c * u2)
1871}
1872
1873/// Fifth derivative of the beta-logistic inverse-link CDF (= 4th deriv of PDF).
1874///
1875/// With `P_4 = t^3 - 3ct*u' - c*u''` giving `d4 = d1 * P_4`, the next order is:
1876///
1877///   d5 = d1 * [t^4 - 6c*t^2*u' - 4c*t*u'' + 3c^2*u'^2 - c*u''']
1878///
1879/// where u' = u(1-u), u'' = u'(1-2u), u''' = u''(1-2u) - 2*u'^2.
1880pub fn beta_logistic_inverse_link_pdffourth_derivative(
1881    eta: f64,
1882    log_shape_center: f64,
1883    epsilon: f64,
1884) -> f64 {
1885    let logistic = logistic_uwith_derivatives(eta);
1886    let a = (log_shape_center - epsilon).exp();
1887    let b = (log_shape_center + epsilon).exp();
1888    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1889    let d1 = log_d1.exp();
1890    let c = a + b;
1891    let t = a * logistic.one_minus_u - b * logistic.u;
1892    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1893    let u3 = u2 * (logistic.one_minus_u - logistic.u) - 2.0 * logistic.du * logistic.du;
1894    let t2 = t * t;
1895    d1 * (t2 * t2 - 6.0 * c * t2 * logistic.du - 4.0 * c * t * u2
1896        + 3.0 * c * c * logistic.du * logistic.du
1897        - c * u3)
1898}
1899
1900pub fn beta_logistic_inverse_link_jetwith_param_partials(
1901    eta: f64,
1902    log_shape_center: f64,
1903    epsilon: f64,
1904) -> SasJetWithParamPartials {
1905    let logistic = logistic_uwith_derivatives(eta);
1906    let a = (log_shape_center - epsilon).exp();
1907    let b = (log_shape_center + epsilon).exp();
1908    let (mu, dmu_da, dmu_db) = beta_reg_with_shape_partials_logistic(a, b, logistic);
1909    let dmu_dlog_shape_center = a * dmu_da + b * dmu_db;
1910    let dmu_depsilon = -a * dmu_da + b * dmu_db;
1911    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1912    let d1 = log_d1.exp();
1913    let t = a * logistic.one_minus_u - b * logistic.u;
1914    let d2 = d1 * t;
1915    let k = t * t - (a + b) * logistic.du;
1916    let d3 = d1 * k;
1917    let jet = InverseLinkJet { mu, d1, d2, d3 };
1918
1919    let psi_a = digamma(a);
1920    let psi_b = digamma(b);
1921    let psi_ab = digamma(a + b);
1922    let la = logistic.ln_u - psi_a + psi_ab;
1923    let lb = logistic.ln_one_minus_u - psi_b + psi_ab;
1924
1925    let partials_for = |a_p: f64, b_p: f64, dmu: f64| -> InverseLinkJet {
1926        let logd1_p = a_p * la + b_p * lb;
1927        let d1_p = d1 * logd1_p;
1928        let t_p = a_p * logistic.one_minus_u - b_p * logistic.u;
1929        let d2_p = d1_p * t + d1 * t_p;
1930        let k_p = 2.0 * t * t_p - (a_p + b_p) * logistic.du;
1931        let d3_p = d1_p * k + d1 * k_p;
1932        InverseLinkJet {
1933            mu: dmu,
1934            d1: d1_p,
1935            d2: d2_p,
1936            d3: d3_p,
1937        }
1938    };
1939    let djet_dlog_shape_center = partials_for(a, b, dmu_dlog_shape_center);
1940    let djet_depsilon = partials_for(-a, b, dmu_depsilon);
1941    SasJetWithParamPartials {
1942        jet,
1943        djet_depsilon,
1944        djet_dlog_delta: djet_dlog_shape_center,
1945    }
1946}
1947
1948/// SAS inverse-link jet for:
1949///   mu(eta) = Phi(sinh(delta * asinh(eta) - epsilon)),
1950///   delta = exp(B * tanh(log_delta / B)), B = SAS_LOG_DELTA_BOUND.
1951pub fn sas_inverse_link_jet(eta: f64, epsilon: f64, log_delta: f64) -> InverseLinkJet {
1952    let delta_id = sas_delta_from_raw_log_delta(log_delta);
1953    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1954        return component_inverse_link_jet(LinkComponent::Probit, eta);
1955    }
1956    let e = if eta.is_finite() { eta } else { 0.0 };
1957    let a = e.asinh();
1958    let delta = delta_id;
1959    let u_raw = delta * a - epsilon;
1960    let u = tanh_bound(u_raw, SAS_U_CLAMP);
1961    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1962    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
1963    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
1964    let s = u.sinh();
1965    let c = u.cosh();
1966    let z = s;
1967    let q = e.hypot(1.0);
1968    let inv_q = 1.0 / q;
1969    let inv_q2 = inv_q * inv_q;
1970    let inv_q3 = inv_q2 * inv_q;
1971    let inv_q5 = inv_q3 * inv_q2;
1972    let r1 = delta * inv_q;
1973    let r2 = -delta * e * inv_q3;
1974    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
1975    let u1 = g1 * r1;
1976    let u2 = g2 * r1 * r1 + g1 * r2;
1977    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
1978    let z1 = c * u1;
1979    let z2 = s * u1 * u1 + c * u2;
1980    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
1981    let base = probit_jet(z);
1982    chain_inverse_link_jet(base, z1, z2, z3)
1983}
1984
1985pub fn sas_inverse_link_pdfthird_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
1986    // SAS link with bounded latent transform:
1987    //
1988    //   a  = asinh(eta),
1989    //   u  = tanh_bound(delta * a - epsilon),
1990    //   z  = sinh(u),
1991    //   mu = Phi(z).
1992    //
1993    // Write:
1994    //
1995    //   z1 = z'
1996    //   z2 = z''
1997    //   z3 = z'''
1998    //   z4 = z''''.
1999    //
2000    // Since `mu' = phi(z) z1`, repeated differentiation factors through the
2001    // standard normal Hermite-polynomial identities:
2002    //
2003    //   mu''   = phi(z) [ z2 - z z1² ]
2004    //
2005    //   mu'''  = phi(z) [ z3 - 3 z z1 z2 + (z² - 1) z1³ ]
2006    //          = phi(z) k3
2007    //
2008    //   mu'''' = phi(z) [ k4 - z z1 k3 ],
2009    //
2010    // where `k4` is the derivative of `k3` after collecting like terms. The
2011    // code below computes `u1..u4`, then `z1..z4`, then `k3` and `k4`, exactly
2012    // matching that chain.
2013    //
2014    // The needed fourth derivative of `u(eta)` is obtained from the nested
2015    // composition `u(eta) = g(r(eta))` with
2016    //   g = tanh_bound, r = delta * asinh(eta) - epsilon:
2017    //
2018    //   u4 = g'''' r1^4 + 6 g''' r1² r2 + 3 g'' r2² + 4 g'' r1 r3 + g' r4,
2019    //
2020    // which is the standard scalar Arbogast expansion for order four.
2021    let e = if eta.is_finite() { eta } else { 0.0 };
2022    let a = e.asinh();
2023    let delta = sas_delta_from_raw_log_delta(log_delta);
2024    let u_raw = delta * a - epsilon;
2025    let u = tanh_bound(u_raw, SAS_U_CLAMP);
2026    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2027    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2028    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2029    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2030    let s = u.sinh();
2031    let c = u.cosh();
2032    let z = s;
2033    let base = probit_jet(z);
2034    let q = e.hypot(1.0);
2035    let inv_q = 1.0 / q;
2036    let inv_q2 = inv_q * inv_q;
2037    let inv_q3 = inv_q2 * inv_q;
2038    let inv_q5 = inv_q3 * inv_q2;
2039    let inv_q7 = inv_q5 * inv_q2;
2040    let r1 = delta * inv_q;
2041    let r2 = -delta * e * inv_q3;
2042    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2043    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
2044    let u1 = g1 * r1;
2045    let u2 = g2 * r1 * r1 + g1 * r2;
2046    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2047    let u4 = g4 * r1.powi(4)
2048        + 6.0 * g3 * r1 * r1 * r2
2049        + 3.0 * g2 * r2 * r2
2050        + 4.0 * g2 * r1 * r3
2051        + g1 * r4;
2052    let z1 = c * u1;
2053    let z2 = s * u1 * u1 + c * u2;
2054    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2055    let z4 =
2056        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
2057    let base4 = probit_pdfthird_derivative(z);
2058    let out = base4 * z1.powi(4)
2059        + 6.0 * base.d3 * z1 * z1 * z2
2060        + 3.0 * base.d2 * z2 * z2
2061        + 4.0 * base.d2 * z1 * z3
2062        + base.d1 * z4;
2063    canonicalzero(out)
2064}
2065
2066/// Fifth derivative of the SAS inverse-link CDF (= fourth derivative of the PDF).
2067///
2068/// Extends `sas_inverse_link_pdfthird_derivative` by one more derivative order,
2069/// using the same composition chain u(eta) = g(r(eta)), z = sinh(u), mu = Phi(z).
2070///
2071/// The Arbogast expansion at order 5 for u(eta) = g(r(eta)) is:
2072///   u5 = g5 r1^5 + 10 g4 r1^3 r2 + 15 g3 r1 r2^2 + 10 g3 r1^2 r3
2073///        + 10 g2 r2 r3 + 5 g2 r1 r4 + g1 r5
2074///
2075/// The z = sinh(u) expansion at order 5 is the standard Arbogast for sinh:
2076///   z5 = c*u1^5 + 10*s*u1^3*u2 + 15*c*u1*u2^2 + 10*c*u1^2*u3
2077///        + 10*s*u2*u3 + 5*s*u1*u4 + c*u5
2078///
2079/// The mu = Phi(z) expansion at order 5 uses probit derivatives:
2080///   mu^(5) = Phi5*z1^5 + 10*Phi4*z1^3*z2 + 15*Phi3*z1*z2^2 + 10*Phi3*z1^2*z3
2081///            + 10*Phi2*z2*z3 + 5*Phi2*z1*z4 + Phi1*z5
2082pub fn sas_inverse_link_pdffourth_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
2083    let e = if eta.is_finite() { eta } else { 0.0 };
2084    let a = e.asinh();
2085    let delta = sas_delta_from_raw_log_delta(log_delta);
2086    let u_raw = delta * a - epsilon;
2087    let u = tanh_bound(u_raw, SAS_U_CLAMP);
2088    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2089    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2090    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2091    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2092    let g5 = tanh_bound_d5(u_raw, SAS_U_CLAMP);
2093    let s = u.sinh();
2094    let c = u.cosh();
2095    let z = s;
2096
2097    // Probit derivatives at z.
2098    let base = probit_jet(z);
2099    let phi3 = probit_pdfthird_derivative(z); // Phi^{(4)}
2100    let phi4 = probit_pdffourth_derivative(z); // Phi^{(5)}
2101
2102    // Powers of q = sqrt(1 + eta^2) for r1..r5.
2103    let q = e.hypot(1.0);
2104    let inv_q = 1.0 / q;
2105    let inv_q2 = inv_q * inv_q;
2106    let inv_q3 = inv_q2 * inv_q;
2107    let inv_q5 = inv_q3 * inv_q2;
2108    let inv_q7 = inv_q5 * inv_q2;
2109    let inv_q9 = inv_q7 * inv_q2;
2110
2111    let r1 = delta * inv_q;
2112    let r2 = -delta * e * inv_q3;
2113    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2114    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
2115    let r5 = delta * (9.0 - 72.0 * e * e + 24.0 * e * e * e * e) * inv_q9;
2116
2117    // u1..u5 via Arbogast for g(r(eta)).
2118    let u1 = g1 * r1;
2119    let u2 = g2 * r1 * r1 + g1 * r2;
2120    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2121    let u4 = g4 * r1.powi(4)
2122        + 6.0 * g3 * r1 * r1 * r2
2123        + 3.0 * g2 * r2 * r2
2124        + 4.0 * g2 * r1 * r3
2125        + g1 * r4;
2126    let u5 = g5 * r1.powi(5)
2127        + 10.0 * g4 * r1 * r1 * r1 * r2
2128        + 15.0 * g3 * r1 * r2 * r2
2129        + 10.0 * g3 * r1 * r1 * r3
2130        + 10.0 * g2 * r2 * r3
2131        + 5.0 * g2 * r1 * r4
2132        + g1 * r5;
2133
2134    // z1..z5 via Arbogast for sinh(u(eta)).
2135    let z1 = c * u1;
2136    let z2 = s * u1 * u1 + c * u2;
2137    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2138    let z4 =
2139        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
2140    let z5 = c * u1.powi(5)
2141        + 10.0 * s * u1 * u1 * u1 * u2
2142        + 15.0 * c * u1 * u2 * u2
2143        + 10.0 * c * u1 * u1 * u3
2144        + 10.0 * s * u2 * u3
2145        + 5.0 * s * u1 * u4
2146        + c * u5;
2147
2148    // mu^(5) = Phi^(5)*z1^5 + 10*Phi^(4)*z1^3*z2 + 15*Phi^(3)*z1*z2^2
2149    //        + 10*Phi^(3)*z1^2*z3 + 10*Phi^(2)*z2*z3 + 5*Phi^(2)*z1*z4 + Phi^(1)*z5
2150    let out = phi4 * z1.powi(5)
2151        + 10.0 * phi3 * z1 * z1 * z1 * z2
2152        + 15.0 * base.d3 * z1 * z2 * z2
2153        + 10.0 * base.d3 * z1 * z1 * z3
2154        + 10.0 * base.d2 * z2 * z3
2155        + 5.0 * base.d2 * z1 * z4
2156        + base.d1 * z5;
2157    canonicalzero(out)
2158}
2159
2160pub fn sas_inverse_link_jetwith_param_partials(
2161    eta: f64,
2162    epsilon: f64,
2163    log_delta: f64,
2164) -> SasJetWithParamPartials {
2165    let e = if eta.is_finite() { eta } else { 0.0 };
2166    let a = e.asinh();
2167    let (ld_eff, dld_eff_draw) = sas_effective_log_delta(log_delta);
2168    let delta = ld_eff.exp();
2169    let ddelta_draw = delta * dld_eff_draw;
2170    let u_raw = delta * a - epsilon;
2171    let u = tanh_bound(u_raw, SAS_U_CLAMP);
2172    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2173    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2174    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2175    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2176    let s = u.sinh();
2177    let c = u.cosh();
2178    let z = s;
2179    let q = e.hypot(1.0);
2180    let inv_q = 1.0 / q;
2181    let inv_q2 = inv_q * inv_q;
2182    let inv_q3 = inv_q2 * inv_q;
2183    let inv_q5 = inv_q3 * inv_q2;
2184    let a1 = inv_q;
2185    let a2 = -e * inv_q3;
2186    let a3 = (2.0 * e * e - 1.0) * inv_q5;
2187    let r1 = delta * a1;
2188    let r2 = delta * a2;
2189    let r3 = delta * a3;
2190    let u1 = g1 * r1;
2191    let u2 = g2 * r1 * r1 + g1 * r2;
2192    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2193    let z1 = c * u1;
2194    let z2 = s * u1 * u1 + c * u2;
2195    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2196
2197    let base = probit_jet(z);
2198    let jet = chain_inverse_link_jet(base, z1, z2, z3);
2199
2200    // Generic chain for parameter t:
2201    // u_t, u1_t, u2_t, u3_t -> z_t,z1_t,z2_t,z3_t -> mu_t,d1_t,d2_t,d3_t
2202    let param_partials = |u_t: f64, u1_t: f64, u2_t: f64, u3_t: f64| -> InverseLinkJet {
2203        let z_t = c * u_t;
2204        let z1_t = s * u_t * u1 + c * u1_t;
2205        let z2_t = c * u_t * u1 * u1 + 2.0 * s * u1 * u1_t + s * u_t * u2 + c * u2_t;
2206        let z3_t = s * u_t * u1 * u1 * u1
2207            + 3.0 * c * u1 * u1 * u1_t
2208            + 3.0 * c * u_t * u1 * u2
2209            + 3.0 * s * (u1_t * u2 + u1 * u2_t)
2210            + s * u_t * u3
2211            + c * u3_t;
2212
2213        InverseLinkJet {
2214            mu: base.d1 * z_t,
2215            d1: base.d2 * z_t * z1 + base.d1 * z1_t,
2216            d2: base.d3 * z_t * z1 * z1
2217                + 2.0 * base.d2 * z1 * z1_t
2218                + base.d2 * z_t * z2
2219                + base.d1 * z2_t,
2220            d3: probit_pdfthird_derivative(z) * z_t * z1.powi(3)
2221                + 3.0 * base.d3 * z1 * z1 * z1_t
2222                + 3.0 * base.d3 * z_t * z1 * z2
2223                + 3.0 * base.d2 * (z1_t * z2 + z1 * z2_t)
2224                + base.d2 * z_t * z3
2225                + base.d1 * z3_t,
2226        }
2227    };
2228
2229    // epsilon partials (raw_u_t = -1).
2230    let rt_eps = -1.0;
2231    let r1t_eps = 0.0;
2232    let r2t_eps = 0.0;
2233    let r3t_eps = 0.0;
2234    let u_eps = g1 * rt_eps;
2235    let u1_eps = g2 * rt_eps * r1 + g1 * r1t_eps;
2236    let u2_eps = g3 * rt_eps * r1 * r1 + 2.0 * g2 * r1 * r1t_eps + g2 * rt_eps * r2 + g1 * r2t_eps;
2237    let u3_eps = g4 * rt_eps * r1 * r1 * r1
2238        + 3.0 * g3 * r1 * r1 * r1t_eps
2239        + 3.0 * g3 * rt_eps * r1 * r2
2240        + 3.0 * g2 * (r1t_eps * r2 + r1 * r2t_eps)
2241        + g2 * rt_eps * r3
2242        + g1 * r3t_eps;
2243    let djet_depsilon = param_partials(u_eps, u1_eps, u2_eps, u3_eps);
2244
2245    // raw log-delta partials (through smooth bounded effective log-delta).
2246    let rt_ld = ddelta_draw * a;
2247    let r1t_ld = ddelta_draw * a1;
2248    let r2t_ld = ddelta_draw * a2;
2249    let r3t_ld = ddelta_draw * a3;
2250    let u_ld = g1 * rt_ld;
2251    let u1_ld = g2 * rt_ld * r1 + g1 * r1t_ld;
2252    let u2_ld = g3 * rt_ld * r1 * r1 + 2.0 * g2 * r1 * r1t_ld + g2 * rt_ld * r2 + g1 * r2t_ld;
2253    let u3_ld = g4 * rt_ld * r1 * r1 * r1
2254        + 3.0 * g3 * r1 * r1 * r1t_ld
2255        + 3.0 * g3 * rt_ld * r1 * r2
2256        + 3.0 * g2 * (r1t_ld * r2 + r1 * r2t_ld)
2257        + g2 * rt_ld * r3
2258        + g1 * r3t_ld;
2259    let djet_dlog_delta = param_partials(u_ld, u1_ld, u2_ld, u3_ld);
2260
2261    SasJetWithParamPartials {
2262        jet,
2263        djet_depsilon,
2264        djet_dlog_delta,
2265    }
2266}
2267
2268#[cfg(test)]
2269mod tests {
2270    use super::*;
2271    use gam_problem::{InverseLink, LikelihoodSpec, LinkComponent, MixtureLinkSpec, SasLinkState};
2272
2273    #[test]
2274    fn softmax_jacobian_matchesfd() {
2275        let rho = Array1::from_vec(vec![0.7, -1.2, 0.4]);
2276        let (pi, jac) = softmaxwith_jacobian_last_fixedzero(&rho);
2277        let h = 1e-6;
2278        for j in 0..rho.len() {
2279            let mut rp = rho.clone();
2280            rp[j] += h;
2281            let mut rm = rho.clone();
2282            rm[j] -= h;
2283            let pp = softmax_last_fixedzero(&rp);
2284            let pm = softmax_last_fixedzero(&rm);
2285            let fd = (&pp - &pm).mapv(|v| v / (2.0 * h));
2286            for k in 0..pi.len() {
2287                let err = (jac[[k, j]] - fd[k]).abs();
2288                assert_eq!(
2289                    jac[[k, j]].signum(),
2290                    fd[k].signum(),
2291                    "jac sign mismatch at ({k},{j}): analytic={} fd={}",
2292                    jac[[k, j]],
2293                    fd[k]
2294                );
2295                assert!(err < 5e-6, "jac mismatch at ({k},{j}): err={err:e}");
2296            }
2297        }
2298    }
2299
2300    #[test]
2301    fn mixture_jet_rho_partials_matchfd() {
2302        let spec = MixtureLinkSpec {
2303            components: vec![
2304                LinkComponent::Probit,
2305                LinkComponent::Logit,
2306                LinkComponent::CLogLog,
2307                LinkComponent::Cauchit,
2308            ],
2309            initial_rho: Array1::from_vec(vec![0.3, -0.6, 0.2]),
2310        };
2311        let state = state_fromspec(&spec).expect("state");
2312        let eta = 0.35;
2313        let out = mixture_inverse_link_jetwith_rho_partials(&state, eta);
2314        let h = 1e-6;
2315        for j in 0..state.rho.len() {
2316            let mut rp = state.rho.clone();
2317            rp[j] += h;
2318            let sp = MixtureLinkSpec {
2319                components: state.components.clone(),
2320                initial_rho: rp,
2321            };
2322            let jp = mixture_inverse_link_jet(&state_fromspec(&sp).expect("sp"), eta);
2323            let mut rm = state.rho.clone();
2324            rm[j] -= h;
2325            let sm = MixtureLinkSpec {
2326                components: state.components.clone(),
2327                initial_rho: rm,
2328            };
2329            let jm = mixture_inverse_link_jet(&state_fromspec(&sm).expect("sm"), eta);
2330            let fd = InverseLinkJet {
2331                mu: (jp.mu - jm.mu) / (2.0 * h),
2332                d1: (jp.d1 - jm.d1) / (2.0 * h),
2333                d2: (jp.d2 - jm.d2) / (2.0 * h),
2334                d3: (jp.d3 - jm.d3) / (2.0 * h),
2335            };
2336            let an = out.djet_drho[j];
2337            assert_eq!(an.mu.signum(), fd.mu.signum());
2338            assert_eq!(an.d1.signum(), fd.d1.signum());
2339            assert_eq!(an.d2.signum(), fd.d2.signum());
2340            assert_eq!(an.d3.signum(), fd.d3.signum());
2341            assert!((an.mu - fd.mu).abs() < 1e-6);
2342            assert!((an.d1 - fd.d1).abs() < 1e-6);
2343            assert!((an.d2 - fd.d2).abs() < 1e-6);
2344            assert!((an.d3 - fd.d3).abs() < 1e-6);
2345        }
2346    }
2347
2348    #[test]
2349    fn sas_param_partials_matchfd() {
2350        let eta = 0.37;
2351        let epsilon = -0.12;
2352        let log_delta = 0.21;
2353        let out = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2354        let h = 1e-6;
2355
2356        let ep_p = sas_inverse_link_jet(eta, epsilon + h, log_delta);
2357        let ep_m = sas_inverse_link_jet(eta, epsilon - h, log_delta);
2358        let fd_ep = InverseLinkJet {
2359            mu: (ep_p.mu - ep_m.mu) / (2.0 * h),
2360            d1: (ep_p.d1 - ep_m.d1) / (2.0 * h),
2361            d2: (ep_p.d2 - ep_m.d2) / (2.0 * h),
2362            d3: (ep_p.d3 - ep_m.d3) / (2.0 * h),
2363        };
2364        assert_eq!(out.djet_depsilon.mu.signum(), fd_ep.mu.signum());
2365        assert_eq!(out.djet_depsilon.d1.signum(), fd_ep.d1.signum());
2366        assert_eq!(out.djet_depsilon.d2.signum(), fd_ep.d2.signum());
2367        assert_eq!(out.djet_depsilon.d3.signum(), fd_ep.d3.signum());
2368        assert!((out.djet_depsilon.mu - fd_ep.mu).abs() < 5e-5);
2369        assert!((out.djet_depsilon.d1 - fd_ep.d1).abs() < 5e-5);
2370        assert!((out.djet_depsilon.d2 - fd_ep.d2).abs() < 5e-5);
2371        assert!((out.djet_depsilon.d3 - fd_ep.d3).abs() < 5e-4);
2372
2373        let ld_p = sas_inverse_link_jet(eta, epsilon, log_delta + h);
2374        let ld_m = sas_inverse_link_jet(eta, epsilon, log_delta - h);
2375        let fd_ld = InverseLinkJet {
2376            mu: (ld_p.mu - ld_m.mu) / (2.0 * h),
2377            d1: (ld_p.d1 - ld_m.d1) / (2.0 * h),
2378            d2: (ld_p.d2 - ld_m.d2) / (2.0 * h),
2379            d3: (ld_p.d3 - ld_m.d3) / (2.0 * h),
2380        };
2381        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_ld.mu.signum());
2382        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_ld.d1.signum());
2383        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_ld.d2.signum());
2384        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_ld.d3.signum());
2385        assert!((out.djet_dlog_delta.mu - fd_ld.mu).abs() < 5e-5);
2386        assert!((out.djet_dlog_delta.d1 - fd_ld.d1).abs() < 5e-5);
2387        assert!((out.djet_dlog_delta.d2 - fd_ld.d2).abs() < 5e-5);
2388        assert!((out.djet_dlog_delta.d3 - fd_ld.d3).abs() < 5e-4);
2389    }
2390
2391    #[test]
2392    fn sas_jet_extreme_inputs_stay_finite() {
2393        let cases = [
2394            (-1e6, 0.0, 0.0),
2395            (1e6, 0.0, 0.0),
2396            (3.0, 12.0, 12.0),
2397            (-3.0, -12.0, -12.0),
2398            (0.5, 40.0, 10.0),
2399            (0.5, -40.0, -10.0),
2400        ];
2401        for (eta, eps, log_delta) in cases {
2402            let j = sas_inverse_link_jet(eta, eps, log_delta);
2403            assert!(j.mu.is_finite());
2404            assert!(j.d1.is_finite());
2405            assert!(j.d2.is_finite());
2406            assert!(j.d3.is_finite());
2407            let p = sas_inverse_link_jetwith_param_partials(eta, eps, log_delta);
2408            assert!(p.djet_depsilon.mu.is_finite());
2409            assert!(p.djet_depsilon.d1.is_finite());
2410            assert!(p.djet_depsilon.d2.is_finite());
2411            assert!(p.djet_depsilon.d3.is_finite());
2412            assert!(p.djet_dlog_delta.mu.is_finite());
2413            assert!(p.djet_dlog_delta.d1.is_finite());
2414            assert!(p.djet_dlog_delta.d2.is_finite());
2415            assert!(p.djet_dlog_delta.d3.is_finite());
2416        }
2417    }
2418
2419    #[test]
2420    fn sas_param_partials_remain_finite_in_extreme_region() {
2421        let eta = 10.0;
2422        let epsilon = -60.0;
2423        let log_delta = 40.0;
2424        let j = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2425        assert!(j.djet_depsilon.mu.is_finite());
2426        assert!(j.djet_depsilon.d1.is_finite());
2427        assert!(j.djet_depsilon.d2.is_finite());
2428        assert!(j.djet_depsilon.d3.is_finite());
2429        assert!(j.djet_dlog_delta.mu.is_finite());
2430        assert!(j.djet_dlog_delta.d1.is_finite());
2431        assert!(j.djet_dlog_delta.d2.is_finite());
2432        assert!(j.djet_dlog_delta.d3.is_finite());
2433    }
2434
2435    #[test]
2436    fn sas_eta_jets_matchfd() {
2437        let eta = -0.43;
2438        let epsilon = 0.27;
2439        let log_delta = -0.31;
2440        let h = 1e-5;
2441        let j0 = sas_inverse_link_jet(eta, epsilon, log_delta);
2442        let jp = sas_inverse_link_jet(eta + h, epsilon, log_delta);
2443        let jm = sas_inverse_link_jet(eta - h, epsilon, log_delta);
2444        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2445        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2446        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2447        assert_eq!(j0.d1.signum(), d1fd.signum());
2448        assert_eq!(j0.d2.signum(), d2fd.signum());
2449        assert_eq!(j0.d3.signum(), d3fd.signum());
2450        assert!((j0.d1 - d1fd).abs() < 5e-5);
2451        assert!((j0.d2 - d2fd).abs() < 2e-4);
2452        assert!((j0.d3 - d3fd).abs() < 1e-3);
2453    }
2454
2455    #[test]
2456    fn family_dispatch_resolves_parameterized_links_from_spec() {
2457        // After the LikelihoodSpec migration, the dispatch no longer needs
2458        // out-of-band state arguments — the parameterized link state lives on
2459        // `spec.link`. Verify that supplying explicit SAS/Mixture link states
2460        // through the spec produces finite jets at a representative eta.
2461        let sas_state = sas_link_state_from_raw(0.0, 0.0).expect("sas state");
2462        let sas_spec = gam_problem::LikelihoodSpec {
2463            response: gam_problem::ResponseFamily::Binomial,
2464            link: InverseLink::Sas(sas_state),
2465        };
2466        let sas_jet = inverse_link_jet_for_family(&sas_spec, 0.1).expect("sas jet");
2467        assert!(sas_jet.mu.is_finite());
2468        assert!(sas_jet.d1.is_finite());
2469
2470        let mix_state = MixtureLinkState {
2471            components: vec![LinkComponent::Logit, LinkComponent::Probit],
2472            rho: ndarray::array![0.0],
2473            pi: ndarray::array![0.5, 0.5],
2474        };
2475        let mix_spec = gam_problem::LikelihoodSpec {
2476            response: gam_problem::ResponseFamily::Binomial,
2477            link: InverseLink::Mixture(mix_state),
2478        };
2479        let mix_jet = inverse_link_jet_for_family(&mix_spec, 0.1).expect("mix jet");
2480        assert!(mix_jet.mu.is_finite());
2481        assert!(mix_jet.d1.is_finite());
2482    }
2483
2484    #[test]
2485    fn beta_logistic_reduces_to_logit_at_delta0_epsilon0() {
2486        let etas = [-40.0, -30.0, -5.0, 0.42, 5.0, 30.0, 40.0];
2487        for eta in etas {
2488            let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2489            let expected_mu = gam_linalg::utils::stable_logistic(eta);
2490            let expected_d1 = (-gam_linalg::utils::stable_softplus(-eta)
2491                - gam_linalg::utils::stable_softplus(eta))
2492            .exp();
2493            assert!(
2494                (j_bl.mu - expected_mu).abs() <= 1e-15 * expected_mu.abs().max(1.0),
2495                "mu mismatch at eta={eta}: got {}, expected {}",
2496                j_bl.mu,
2497                expected_mu
2498            );
2499            assert!(
2500                (j_bl.d1 - expected_d1).abs() <= 1e-12 * expected_d1.abs().max(f64::MIN_POSITIVE),
2501                "d1 mismatch at eta={eta}: got {}, expected {}",
2502                j_bl.d1,
2503                expected_d1
2504            );
2505            assert!(j_bl.d1 > 0.0, "d1 should stay positive at eta={eta}");
2506        }
2507
2508        let eta = 0.42;
2509        let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2510        let j_logit = component_inverse_link_jet(LinkComponent::Logit, eta);
2511        assert!((j_bl.d2 - j_logit.d2).abs() < 1e-10);
2512        assert!((j_bl.d3 - j_logit.d3).abs() < 1e-10);
2513    }
2514
2515    #[test]
2516    fn beta_logistic_eta_jets_matchfd() {
2517        let eta = -0.31;
2518        let delta = 0.27;
2519        let epsilon = -0.19;
2520        let h = 1e-5;
2521        let j0 = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2522        let jp = beta_logistic_inverse_link_jet(eta + h, delta, epsilon);
2523        let jm = beta_logistic_inverse_link_jet(eta - h, delta, epsilon);
2524        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2525        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2526        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2527        assert_eq!(j0.d1.signum(), d1fd.signum());
2528        assert_eq!(j0.d2.signum(), d2fd.signum());
2529        assert_eq!(j0.d3.signum(), d3fd.signum());
2530        assert!((j0.d1 - d1fd).abs() < 5e-5);
2531        assert!((j0.d2 - d2fd).abs() < 5e-5);
2532        assert!((j0.d3 - d3fd).abs() < 2e-4);
2533    }
2534
2535    #[test]
2536    fn standard_kernel_structs_match_component_jets() {
2537        let eta = 0.73;
2538        assert_eq!(
2539            ProbitLinkKernel.jet(eta).expect("probit"),
2540            component_inverse_link_jet(LinkComponent::Probit, eta)
2541        );
2542        assert_eq!(
2543            LogitLinkKernel.jet(eta).expect("logit"),
2544            component_inverse_link_jet(LinkComponent::Logit, eta)
2545        );
2546        assert_eq!(
2547            CLogLogLinkKernel.jet(eta).expect("cloglog"),
2548            component_inverse_link_jet(LinkComponent::CLogLog, eta)
2549        );
2550        assert_eq!(
2551            LogLogLinkKernel.jet(eta).expect("loglog"),
2552            component_inverse_link_jet(LinkComponent::LogLog, eta)
2553        );
2554        assert_eq!(
2555            CauchitLinkKernel.jet(eta).expect("cauchit"),
2556            component_inverse_link_jet(LinkComponent::Cauchit, eta)
2557        );
2558    }
2559
2560    #[test]
2561    fn all_component_eta_jets_matchfd() {
2562        let components = [
2563            LinkComponent::Logit,
2564            LinkComponent::Probit,
2565            LinkComponent::CLogLog,
2566            LinkComponent::LogLog,
2567            LinkComponent::Cauchit,
2568        ];
2569        let points = [-3.0, -1.1, -0.2, 0.0, 0.7, 1.8, 3.2];
2570        let h = 1e-5;
2571        for c in components {
2572            for &eta in &points {
2573                let j0 = component_inverse_link_jet(c, eta);
2574                let jp = component_inverse_link_jet(c, eta + h);
2575                let jm = component_inverse_link_jet(c, eta - h);
2576                let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2577                let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2578                let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2579                let d1_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2580                    1.2e-4
2581                } else {
2582                    5e-5
2583                };
2584                let d2_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2585                    4e-4
2586                } else {
2587                    1.2e-4
2588                };
2589                let d3_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2590                    1.2e-3
2591                } else {
2592                    4e-4
2593                };
2594                if j0.d1.abs().max(d1fd.abs()) > 1e-10 {
2595                    assert_eq!(
2596                        j0.d1.signum(),
2597                        d1fd.signum(),
2598                        "d1 sign mismatch for {c:?} eta={eta}"
2599                    );
2600                }
2601                if j0.d2.abs().max(d2fd.abs()) > 1e-10 {
2602                    assert_eq!(
2603                        j0.d2.signum(),
2604                        d2fd.signum(),
2605                        "d2 sign mismatch for {c:?} eta={eta}: analytic={} fd={}",
2606                        j0.d2,
2607                        d2fd
2608                    );
2609                }
2610                if j0.d3.abs().max(d3fd.abs()) > 1e-10 {
2611                    assert_eq!(
2612                        j0.d3.signum(),
2613                        d3fd.signum(),
2614                        "d3 sign mismatch for {c:?} eta={eta}"
2615                    );
2616                }
2617                assert!(
2618                    (j0.d1 - d1fd).abs() < d1_tol,
2619                    "d1 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2620                    j0.d1,
2621                    d1fd
2622                );
2623                assert!(
2624                    (j0.d2 - d2fd).abs() < d2_tol,
2625                    "d2 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2626                    j0.d2,
2627                    d2fd
2628                );
2629                assert!(
2630                    (j0.d3 - d3fd).abs() < d3_tol,
2631                    "d3 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2632                    j0.d3,
2633                    d3fd
2634                );
2635            }
2636        }
2637    }
2638
2639    #[test]
2640    fn sas_center_matches_probit_at_delta1_epsilon0() {
2641        let etas = [-3.0, -1.2, -0.3, 0.0, 0.4, 1.7, 3.0];
2642        for eta in etas {
2643            let sas = sas_inverse_link_jet(eta, 0.0, 0.0);
2644            let probit = ProbitLinkKernel.jet(eta).expect("probit");
2645            // SAS implementation uses a smooth bounded latent (`tanh_bound`) for
2646            // numerical robustness, so the probit center is approximate in practice.
2647            assert!(
2648                (sas.mu - probit.mu).abs() < 6e-4,
2649                "mu mismatch at eta={eta}"
2650            );
2651            assert!(
2652                (sas.d1 - probit.d1).abs() < 6e-4,
2653                "d1 mismatch at eta={eta}"
2654            );
2655            assert!(
2656                (sas.d2 - probit.d2).abs() < 2e-3,
2657                "d2 mismatch at eta={eta}"
2658            );
2659            assert!(
2660                (sas.d3 - probit.d3).abs() < 4e-3,
2661                "d3 mismatch at eta={eta}"
2662            );
2663        }
2664    }
2665
2666    #[test]
2667    fn beta_logistic_param_partials_matchfd() {
2668        let eta = -0.41;
2669        let delta = 0.23;
2670        let epsilon = -0.17;
2671        let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2672        let h = 1e-6;
2673
2674        let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
2675        let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
2676        let fd_delta = InverseLinkJet {
2677            mu: (dp.mu - dm.mu) / (2.0 * h),
2678            d1: (dp.d1 - dm.d1) / (2.0 * h),
2679            d2: (dp.d2 - dm.d2) / (2.0 * h),
2680            d3: (dp.d3 - dm.d3) / (2.0 * h),
2681        };
2682        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_delta.mu.signum());
2683        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_delta.d1.signum());
2684        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_delta.d2.signum());
2685        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_delta.d3.signum());
2686        assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
2687        assert!((out.djet_dlog_delta.d1 - fd_delta.d1).abs() < 5e-5);
2688        assert!((out.djet_dlog_delta.d2 - fd_delta.d2).abs() < 1.2e-4);
2689        assert!((out.djet_dlog_delta.d3 - fd_delta.d3).abs() < 4e-4);
2690
2691        let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
2692        let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
2693        let fd_epsilon = InverseLinkJet {
2694            mu: (ep.mu - em.mu) / (2.0 * h),
2695            d1: (ep.d1 - em.d1) / (2.0 * h),
2696            d2: (ep.d2 - em.d2) / (2.0 * h),
2697            d3: (ep.d3 - em.d3) / (2.0 * h),
2698        };
2699        assert_eq!(out.djet_depsilon.mu.signum(), fd_epsilon.mu.signum());
2700        assert_eq!(out.djet_depsilon.d1.signum(), fd_epsilon.d1.signum());
2701        assert_eq!(out.djet_depsilon.d2.signum(), fd_epsilon.d2.signum());
2702        assert_eq!(out.djet_depsilon.d3.signum(), fd_epsilon.d3.signum());
2703        assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
2704        assert!((out.djet_depsilon.d1 - fd_epsilon.d1).abs() < 5e-5);
2705        assert!((out.djet_depsilon.d2 - fd_epsilon.d2).abs() < 1.2e-4);
2706        assert!((out.djet_depsilon.d3 - fd_epsilon.d3).abs() < 4e-4);
2707    }
2708
2709    #[test]
2710    fn beta_logistic_left_tail_uses_unclamped_log_space() {
2711        let eta = -40.0_f64;
2712        let delta = 0.2_f64;
2713        let epsilon = -0.1_f64;
2714        let a = (delta - epsilon).exp();
2715        let b = (delta + epsilon).exp();
2716        let expected_mu = beta_reg(a, b, eta.exp());
2717        let out = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2718
2719        assert!(
2720            (out.mu - expected_mu).abs() <= 1e-12 * expected_mu.abs().max(f64::MIN_POSITIVE),
2721            "left-tail mu mismatch: got {}, expected {}",
2722            out.mu,
2723            expected_mu
2724        );
2725        assert!(out.d1 > 0.0);
2726        assert!(out.d2 > 0.0);
2727        assert!(out.d3 > 0.0);
2728        assert!(out.d1 < 1e-20);
2729
2730        let partials = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2731        assert!(partials.jet.d1 > 0.0);
2732        assert!(partials.jet.d2 > 0.0);
2733        assert!(partials.jet.d3 > 0.0);
2734        assert!(partials.djet_dlog_delta.d1.is_finite());
2735        assert!(partials.djet_depsilon.d1.is_finite());
2736    }
2737
2738    #[test]
2739    fn beta_logistic_mu_is_symmetric_in_logistic_tails() {
2740        let delta = 0.2;
2741        let epsilon = -0.35;
2742        let etas = [-40.0, -30.0, -5.0, -0.42, 0.0, 0.42, 5.0, 30.0, 40.0];
2743        for eta in etas {
2744            let left = beta_logistic_inverse_link_jet(eta, delta, epsilon).mu;
2745            let right = 1.0 - beta_logistic_inverse_link_jet(-eta, delta, -epsilon).mu;
2746            assert!(
2747                (left - right).abs() <= 1e-14,
2748                "symmetry mismatch at eta={eta}: left={left}, right={right}"
2749            );
2750        }
2751    }
2752
2753    #[test]
2754    fn inverse_link_pdfthird_derivative_matches_d3_finite_difference() {
2755        let sas = InverseLink::Sas(sas_link_state_from_raw(-0.25, 0.35).expect("sas state"));
2756        let beta_logistic = InverseLink::BetaLogistic(SasLinkState {
2757            epsilon: 0.18,
2758            log_delta: -0.22,
2759            delta: (-0.22_f64).exp(),
2760        });
2761        let mixture = InverseLink::Mixture(
2762            state_fromspec(&MixtureLinkSpec {
2763                components: vec![
2764                    LinkComponent::Probit,
2765                    LinkComponent::Logit,
2766                    LinkComponent::CLogLog,
2767                    LinkComponent::Cauchit,
2768                ],
2769                initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
2770            })
2771            .expect("mixture state"),
2772        );
2773        let links = [
2774            InverseLink::Standard(StandardLink::Probit),
2775            InverseLink::Standard(StandardLink::Logit),
2776            InverseLink::Standard(StandardLink::CLogLog),
2777            sas,
2778            beta_logistic,
2779            mixture,
2780        ];
2781        let etas = [-1.1, -0.2, 0.6];
2782        let h = 1e-5;
2783
2784        for link in &links {
2785            for &eta in &etas {
2786                let jp = inverse_link_jet_for_inverse_link(link, eta + h).expect("jet+");
2787                let jm = inverse_link_jet_for_inverse_link(link, eta - h).expect("jet-");
2788                let d4fd = (jp.d3 - jm.d3) / (2.0 * h);
2789                let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)
2790                    .expect("analytic d4");
2791                assert_eq!(
2792                    d4.signum(),
2793                    d4fd.signum(),
2794                    "d4 sign mismatch for {:?} at eta={eta}: analytic={} fd={}",
2795                    link,
2796                    d4,
2797                    d4fd
2798                );
2799                assert!(
2800                    (d4 - d4fd).abs() < 5e-3,
2801                    "d4 mismatch for {:?} at eta={eta}: analytic={} fd={}",
2802                    link,
2803                    d4,
2804                    d4fd
2805                );
2806            }
2807        }
2808    }
2809
2810    #[test]
2811    fn cloglog_large_finite_eta_should_saturate_without_nan_derivatives() {
2812        let eta = 800.0;
2813        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2814        assert_eq!(jet.mu, 1.0);
2815        assert!(
2816            jet.d1 == 0.0,
2817            "for mu(eta)=1-exp(-exp(eta)), dmu/deta = exp(eta-exp(eta)) and should underflow to 0 at eta={eta}; got d1={}",
2818            jet.d1
2819        );
2820        assert!(
2821            jet.d2 == 0.0,
2822            "the saturated cloglog second derivative should also be 0 at eta={eta}; got d2={}",
2823            jet.d2
2824        );
2825        assert!(
2826            jet.d3 == 0.0,
2827            "the saturated cloglog third derivative should also be 0 at eta={eta}; got d3={}",
2828            jet.d3
2829        );
2830
2831        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2832            &InverseLink::Standard(StandardLink::CLogLog),
2833            eta,
2834        )
2835        .expect("cloglog d4");
2836        assert!(
2837            d4 == 0.0,
2838            "the saturated cloglog fourth derivative should also be 0 at eta={eta}; got d4={d4}"
2839        );
2840    }
2841
2842    #[test]
2843    fn loglog_large_negative_finite_eta_should_saturate_without_nan_derivatives() {
2844        let eta = -800.0;
2845        let jet = component_inverse_link_jet(LinkComponent::LogLog, eta);
2846        assert_eq!(jet.mu, 0.0);
2847        assert!(
2848            jet.d1 == 0.0,
2849            "for mu(eta)=exp(-exp(-eta)), dmu/deta = exp(-eta-exp(-eta)) and should underflow to 0 at eta={eta}; got d1={}",
2850            jet.d1
2851        );
2852        assert!(
2853            jet.d2 == 0.0,
2854            "the saturated loglog second derivative should also be 0 at eta={eta}; got d2={}",
2855            jet.d2
2856        );
2857        assert!(
2858            jet.d3 == 0.0,
2859            "the saturated loglog third derivative should also be 0 at eta={eta}; got d3={}",
2860            jet.d3
2861        );
2862
2863        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2864            &InverseLink::Mixture(
2865                state_fromspec(&MixtureLinkSpec {
2866                    components: vec![LinkComponent::LogLog, LinkComponent::Probit],
2867                    initial_rho: Array1::from_vec(vec![12.0]),
2868                })
2869                .expect("mixture state"),
2870            ),
2871            eta,
2872        )
2873        .expect("loglog mixture d4");
2874        assert!(
2875            d4.is_finite(),
2876            "even a nearly pure loglog mixture should not produce NaN fourth derivatives at eta={eta}; got d4={d4}"
2877        );
2878    }
2879
2880    #[test]
2881    fn logit_tail_derivatives_should_match_stable_closed_forms() {
2882        let eta = 50.0_f64;
2883        let z = (-eta).exp();
2884        let denom = 1.0_f64 + z;
2885        let stable_d1 = z / denom.powi(2);
2886        let stable_d2 = z * (z - 1.0) / denom.powi(3);
2887        let stable_d3 = z * (z * z - 4.0 * z + 1.0) / denom.powi(4);
2888        let stable_d4 = z * (z * z * z - 11.0 * z * z + 11.0 * z - 1.0) / denom.powi(5);
2889        let stable_d5 =
2890            z * (z * z * z * z - 26.0 * z * z * z + 66.0 * z * z - 26.0 * z + 1.0) / denom.powi(6);
2891
2892        assert!(stable_d1 > 0.0);
2893        assert!(stable_d2 < 0.0);
2894        assert!(stable_d3 > 0.0);
2895        assert!(stable_d4 < 0.0);
2896        assert!(stable_d5 > 0.0);
2897
2898        let jet = component_inverse_link_jet(LinkComponent::Logit, eta);
2899        assert!(
2900            (jet.d1 - stable_d1).abs() < 1e-30,
2901            "logit d1 should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
2902            jet.d1,
2903            stable_d1
2904        );
2905        assert!(
2906            (jet.d2 - stable_d2).abs() < 1e-30,
2907            "logit d2 should equal the stable tail formula z(z-1)/(1+z)^3 at eta={eta}; got {} vs {}",
2908            jet.d2,
2909            stable_d2
2910        );
2911        assert!(
2912            (jet.d3 - stable_d3).abs() < 1e-30,
2913            "logit d3 should equal the stable tail formula z(z^2-4z+1)/(1+z)^4 at eta={eta}; got {} vs {}",
2914            jet.d3,
2915            stable_d3
2916        );
2917
2918        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2919            &InverseLink::Standard(StandardLink::Logit),
2920            eta,
2921        )
2922        .expect("logit d4");
2923        assert!(
2924            (d4 - stable_d4).abs() < 1e-30,
2925            "logit d4 should equal the stable tail formula z(z^3-11z^2+11z-1)/(1+z)^5 at eta={eta}; got {} vs {}",
2926            d4,
2927            stable_d4
2928        );
2929
2930        let d5 = inverse_link_pdffourth_derivative_for_inverse_link(
2931            &InverseLink::Standard(StandardLink::Logit),
2932            eta,
2933        )
2934        .expect("logit d5");
2935        assert!(
2936            (d5 - stable_d5).abs() < 1e-30,
2937            "logit d5 should equal the stable tail formula z(z^4-26z^3+66z^2-26z+1)/(1+z)^6 at eta={eta}; got {} vs {}",
2938            d5,
2939            stable_d5
2940        );
2941    }
2942
2943    #[test]
2944    fn cloglog_negative_tail_value_should_match_expm1_form() {
2945        let eta = -50.0_f64;
2946        let t = eta.exp();
2947        let stable_mu = -(-t).exp_m1();
2948        assert!(stable_mu > 0.0);
2949
2950        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2951        assert!(
2952            (jet.mu - stable_mu).abs() < 1e-30,
2953            "cloglog mu should equal -expm1(-exp(eta)) in the negative tail at eta={eta}; got {} vs {}",
2954            jet.mu,
2955            stable_mu
2956        );
2957    }
2958
2959    #[test]
2960    fn non_logit_probit_fisher_weight_jets_match_finite_differences() {
2961        fn rel_err(a: f64, b: f64) -> f64 {
2962            (a - b).abs() / a.abs().max(b.abs()).max(1.0e-8)
2963        }
2964
2965        let cases = [
2966            (LinkComponent::CLogLog, [-3.0_f64, -0.5, 0.4, 1.5]),
2967            (LinkComponent::LogLog, [-1.5_f64, -0.4, 0.5, 3.0]),
2968            (LinkComponent::Cauchit, [-3.0_f64, -0.7, 0.6, 3.0]),
2969        ];
2970        for (component, etas) in cases {
2971            for eta in etas {
2972                let (w, w1, w2, w3, w4) = component_fisher_weight_jet5(component, eta);
2973                let jet = component_inverse_link_jet(component, eta);
2974                let expected = jet.d1 * jet.d1 / (jet.mu * (1.0 - jet.mu));
2975                assert!(
2976                    rel_err(w, expected) < 1.0e-12,
2977                    "{component:?} Fisher weight mismatch at eta={eta}: got {w}, expected {expected}"
2978                );
2979
2980                let h = 1.0e-4;
2981                let fd1 = (component_fisher_weight_jet5(component, eta + h).0
2982                    - component_fisher_weight_jet5(component, eta - h).0)
2983                    / (2.0 * h);
2984                let fd2 = (component_fisher_weight_jet5(component, eta + h).1
2985                    - component_fisher_weight_jet5(component, eta - h).1)
2986                    / (2.0 * h);
2987                let fd3 = (component_fisher_weight_jet5(component, eta + h).2
2988                    - component_fisher_weight_jet5(component, eta - h).2)
2989                    / (2.0 * h);
2990                let fd4 = (component_fisher_weight_jet5(component, eta + h).3
2991                    - component_fisher_weight_jet5(component, eta - h).3)
2992                    / (2.0 * h);
2993
2994                assert!(
2995                    rel_err(w1, fd1) < 1.0e-5,
2996                    "{component:?} W' mismatch at eta={eta}: {w1} vs {fd1}"
2997                );
2998                assert!(
2999                    rel_err(w2, fd2) < 1.0e-5,
3000                    "{component:?} W'' mismatch at eta={eta}: {w2} vs {fd2}"
3001                );
3002                assert!(
3003                    rel_err(w3, fd3) < 5.0e-5,
3004                    "{component:?} W''' mismatch at eta={eta}: {w3} vs {fd3}"
3005                );
3006                assert!(
3007                    rel_err(w4, fd4) < 5.0e-4,
3008                    "{component:?} W'''' mismatch at eta={eta}: {w4} vs {fd4}"
3009                );
3010            }
3011        }
3012    }
3013
3014    #[test]
3015    fn mixture_fisher_weight_jet_covers_loglog_and_cauchit_components() {
3016        let state = state_fromspec(&MixtureLinkSpec {
3017            components: vec![
3018                LinkComponent::CLogLog,
3019                LinkComponent::LogLog,
3020                LinkComponent::Cauchit,
3021            ],
3022            initial_rho: Array1::from_vec(vec![0.3, -0.2]),
3023        })
3024        .expect("mixture state");
3025        let link = InverseLink::Mixture(state);
3026        assert!(
3027            inverse_link_has_fisher_weight_jet(&link),
3028            "anchored mixtures with loglog/cauchit components must remain eligible for Firth"
3029        );
3030        assert!(
3031            LikelihoodSpec::new(ResponseFamily::Binomial, link.clone()).supports_firth(),
3032            "Firth support should use the mixture inverse-link Fisher jet, not standalone LinkFunction coverage"
3033        );
3034
3035        for eta in [-2.0_f64, -0.25, 0.75, 2.5] {
3036            let (w, w1, w2, w3, w4) =
3037                fisher_weight_jet5_for_inverse_link(&link, eta).expect("mixture Fisher jet");
3038            for value in [w, w1, w2, w3, w4] {
3039                assert!(
3040                    value.is_finite(),
3041                    "mixture Fisher weight jet should be finite at eta={eta}; got {value}"
3042                );
3043            }
3044            assert!(
3045                w > 0.0,
3046                "mixture Fisher working weight should be positive away from saturated tails at eta={eta}; got {w}"
3047            );
3048        }
3049    }
3050
3051    #[test]
3052    fn loglog_fifth_derivative_should_match_closed_form_sign() {
3053        let eta = 0.0_f64;
3054        let r = (-eta).exp();
3055        let expected =
3056            (-r).exp() * (r - 15.0 * r * r + 25.0 * r.powi(3) - 10.0 * r.powi(4) + r.powi(5));
3057        let d5 = component_inverse_link_pdffourth_derivative(LinkComponent::LogLog, eta);
3058        assert!(
3059            (d5 - expected).abs() < 1e-15,
3060            "loglog d5 should equal exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5) at eta={eta}; got {d5} vs {expected}"
3061        );
3062        assert!(d5 > 0.0, "loglog d5 should be positive at eta=0; got {d5}");
3063    }
3064}