// gam_solve/mixture_link.rs

1use crate::estimate::EstimationError;
2use gam_math::probability::{normal_cdf, normal_pdf};
3use gam_math::special::stable_polynomial_times_exp_neg as stable_nonnegative_poly_times_exp_neg;
4use crate::quadrature::latent_cloglog_jet5;
5use gam_problem::{
6    InverseLink, LatentCLogLogState, LikelihoodSpec, LinkComponent, LinkFunction, MixtureLinkSpec,
7    MixtureLinkState, ResponseFamily, SasLinkSpec, SasLinkState, StandardLink,
8};
9use ndarray::Array1;
10use statrs::function::beta::{beta_reg, ln_beta};
11use statrs::function::gamma::digamma;
12use std::ops::Neg;
13use std::sync::OnceLock;
14
/// Clamp magnitude named for the sinh-arcsinh (SAS) link's intermediate `u`.
///
/// NOTE(review): not referenced in this part of the module; presumably used by
/// the SAS kernel further down the file — confirm before changing the value.
const SAS_U_CLAMP: f64 = 50.0;
/// Bound `B` of the bounded sinh-arcsinh log-delta parameterisation
/// `delta = exp(B * tanh(raw_log_delta / B))`, which keeps the effective
/// `log(delta)` within `[-B, B]`. Shared with the outer-strategy edge-barrier
/// helpers in `solver/estimate.rs`, which previously hard-coded a matching
/// `12.0`.
pub(crate) const SAS_LOG_DELTA_BOUND: f64 = 12.0;
21
22#[inline]
23fn latent_cloglog_quadctx() -> &'static crate::quadrature::QuadratureContext {
24    static QUADCTX: OnceLock<crate::quadrature::QuadratureContext> = OnceLock::new();
25    QUADCTX.get_or_init(crate::quadrature::QuadratureContext::new)
26}
27
28#[inline]
29fn latent_cloglog_point_jet(
30    state: &LatentCLogLogState,
31    eta: f64,
32) -> Result<InverseLinkJet, EstimationError> {
33    let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, state.latent_sd)?;
34    Ok(InverseLinkJet {
35        mu: jet.mean,
36        d1: jet.d1,
37        d2: jet.d2,
38        d3: jet.d3,
39    })
40}
41
/// Value and first three `eta`-derivatives of an inverse link `mu = F(eta)`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InverseLinkJet {
    /// `F(eta)`, the mean.
    pub mu: f64,
    /// `F'(eta)`.
    pub d1: f64,
    /// `F''(eta)`.
    pub d2: f64,
    /// `F'''(eta)`.
    pub d3: f64,
}
49
/// Logistic inverse link `mu = 1 / (1 + exp(-eta))` with its first five
/// `eta`-derivatives, as produced by `logit_inverse_link_jet5`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LogitJet5 {
    /// `mu(eta)`.
    pub mu: f64,
    /// `mu'(eta)`.
    pub d1: f64,
    /// `mu''(eta)`.
    pub d2: f64,
    /// `mu'''(eta)`.
    pub d3: f64,
    /// `mu''''(eta)`.
    pub d4: f64,
    /// `mu^(5)(eta)`.
    pub d5: f64,
}
59
/// Flush subnormal magnitudes, and negative zero, to `+0.0`.
///
/// Normal numbers, infinities and NaN are returned unchanged.
#[inline]
fn canonicalzero(v: f64) -> f64 {
    if v.is_nan() || v.abs() >= f64::MIN_POSITIVE {
        v
    } else {
        0.0
    }
}
64
65#[inline]
66fn canonicalize_jet(mut jet: InverseLinkJet) -> InverseLinkJet {
67    jet.d1 = canonicalzero(jet.d1);
68    jet.d2 = canonicalzero(jet.d2);
69    jet.d3 = canonicalzero(jet.d3);
70    jet
71}
72
73#[inline]
74pub fn logit_inverse_link_jet5(eta: f64) -> LogitJet5 {
75    if eta.is_nan() {
76        return LogitJet5 {
77            mu: f64::NAN,
78            d1: f64::NAN,
79            d2: f64::NAN,
80            d3: f64::NAN,
81            d4: f64::NAN,
82            d5: f64::NAN,
83        };
84    }
85    if eta == f64::INFINITY {
86        return LogitJet5 {
87            mu: 1.0,
88            d1: 0.0,
89            d2: 0.0,
90            d3: 0.0,
91            d4: 0.0,
92            d5: 0.0,
93        };
94    }
95    if eta == f64::NEG_INFINITY {
96        return LogitJet5 {
97            mu: 0.0,
98            d1: 0.0,
99            d2: 0.0,
100            d3: 0.0,
101            d4: 0.0,
102            d5: 0.0,
103        };
104    }
105
106    let jet = if eta >= 0.0 {
107        let z = (-eta).exp();
108        let opz = 1.0 + z;
109        let opz2 = opz * opz;
110        let opz3 = opz2 * opz;
111        let opz4 = opz3 * opz;
112        let opz5 = opz4 * opz;
113        let opz6 = opz5 * opz;
114        let z2 = z * z;
115        let z3 = z2 * z;
116        let z4 = z3 * z;
117        LogitJet5 {
118            mu: 1.0 / opz,
119            d1: z / opz2,
120            d2: z * (z - 1.0) / opz3,
121            d3: z * (z2 - 4.0 * z + 1.0) / opz4,
122            d4: z * (z3 - 11.0 * z2 + 11.0 * z - 1.0) / opz5,
123            d5: z * (z4 - 26.0 * z3 + 66.0 * z2 - 26.0 * z + 1.0) / opz6,
124        }
125    } else {
126        let z = eta.exp();
127        let opz = 1.0 + z;
128        let opz2 = opz * opz;
129        let opz3 = opz2 * opz;
130        let opz4 = opz3 * opz;
131        let opz5 = opz4 * opz;
132        let opz6 = opz5 * opz;
133        let z2 = z * z;
134        let z3 = z2 * z;
135        let z4 = z3 * z;
136        LogitJet5 {
137            mu: z / opz,
138            d1: z / opz2,
139            d2: z * (1.0 - z) / opz3,
140            d3: z * (1.0 - 4.0 * z + z2) / opz4,
141            d4: z * (1.0 - 11.0 * z + 11.0 * z2 - z3) / opz5,
142            d5: z * (1.0 - 26.0 * z + 66.0 * z2 - 26.0 * z3 + z4) / opz6,
143        }
144    };
145    LogitJet5 {
146        mu: jet.mu,
147        d1: canonicalzero(jet.d1),
148        d2: canonicalzero(jet.d2),
149        d3: canonicalzero(jet.d3),
150        d4: canonicalzero(jet.d4),
151        d5: canonicalzero(jet.d5),
152    }
153}
154
155#[inline]
156fn probit_jet(eta: f64) -> InverseLinkJet {
157    // Exact probit semantics:
158    //
159    //   mu(eta) = Phi(eta),
160    //   mu'     = phi(eta),
161    //   mu''    = -eta * phi(eta),
162    //   mu'''   = (eta^2 - 1) * phi(eta).
163    //
164    // `normal_cdf` now evaluates the exact special-function form
165    // Phi(x) = 0.5 * erfc(-x / sqrt(2)), so the jet can and should use the
166    // matching closed-form Gaussian identities directly.
167    if eta.is_nan() {
168        return InverseLinkJet {
169            mu: f64::NAN,
170            d1: f64::NAN,
171            d2: f64::NAN,
172            d3: f64::NAN,
173        };
174    }
175    if eta == f64::INFINITY {
176        return InverseLinkJet {
177            mu: 1.0,
178            d1: 0.0,
179            d2: 0.0,
180            d3: 0.0,
181        };
182    }
183    if eta == f64::NEG_INFINITY {
184        return InverseLinkJet {
185            mu: 0.0,
186            d1: 0.0,
187            d2: 0.0,
188            d3: 0.0,
189        };
190    }
191    let x = eta;
192    let phi = normal_pdf(x);
193    InverseLinkJet {
194        mu: normal_cdf(x),
195        d1: phi,
196        d2: -x * phi,
197        d3: (x * x - 1.0) * phi,
198    }
199}
200
201#[inline]
202fn probit_pdfthird_derivative(eta: f64) -> f64 {
203    // Since d1 = mu' = phi(eta), this returns
204    //
205    //   d³/deta³ d1 = mu'''' = -(eta³ - 3 eta) phi(eta).
206    if eta.is_nan() {
207        return f64::NAN;
208    }
209    if !eta.is_finite() {
210        return 0.0;
211    }
212    let x = eta;
213    let phi = normal_pdf(x);
214    canonicalzero(-(x * x * x - 3.0 * x) * phi)
215}
216
217#[inline]
218fn probit_pdffourth_derivative(eta: f64) -> f64 {
219    // mu''''' = Phi^{(5)}(eta) = (eta^4 - 6*eta^2 + 3) * phi(eta).
220    if eta.is_nan() {
221        return f64::NAN;
222    }
223    if !eta.is_finite() {
224        return 0.0;
225    }
226    let x = eta;
227    let phi = normal_pdf(x);
228    canonicalzero((x * x * x * x - 6.0 * x * x + 3.0) * phi)
229}
230
/// Product of two 5-term truncated Taylor series (coefficients
/// `a_k = g^(k)/k!`, `k = 0..=4`), truncated to the same order.
///
/// Coefficient `k` of the result is `sum_{i=0..=k} a[i] * b[k-i]`, accumulated
/// in increasing `i` starting from `0.0`. Terms with `a[i] == 0.0` are skipped
/// outright. A zero coefficient in `a` therefore never meets an infinite or NaN
/// coefficient of `b`, which would otherwise inject NaN through `0 * inf`.
#[inline]
fn taylor5_mul(a: &[f64; 5], b: &[f64; 5]) -> [f64; 5] {
    std::array::from_fn(|k| {
        (0..=k)
            .filter(|&i| a[i] != 0.0)
            .fold(0.0_f64, |acc, i| acc + a[i] * b[k - i])
    })
}
247
/// Reciprocal of a 5-term truncated Taylor series whose constant term is
/// nonzero.
///
/// Solves `a * b = 1` order by order:
/// * `b[0] = 1/a[0]`;
/// * for `k >= 1`, `b[k] = -(sum_{j=1..=k} a[j] * b[k-j]) * b[0]`.
///
/// A zero `a[0]` is not checked here; the callers in this module verify the
/// constant term is positive and finite first.
#[inline]
fn taylor5_inv(a: &[f64; 5]) -> [f64; 5] {
    let lead = 1.0 / a[0];
    let mut out = [lead, 0.0, 0.0, 0.0, 0.0];
    for k in 1..out.len() {
        let conv = (1..=k).fold(0.0_f64, |acc, j| acc + a[j] * out[k - j]);
        out[k] = -conv * lead;
    }
    out
}
262
263/// 5-jet (value + four eta-derivatives) of the GLM Fisher working weight
264/// `W(eta) = mu'(eta)^2 / V(mu(eta))` for the requested standard link, returned
265/// as `(W, W', W'', W''', W'''')`.
266///
267/// For the canonical logit link this is exactly the binomial weight
268/// `W = mu(1 - mu) = mu'`, whose eta-derivatives are the higher derivatives of
269/// the inverse-link jet (`W^(k) = mu^(k+1)`); the dispatch returns
270/// `logit_inverse_link_jet5`'s `d1..d5` byte-for-byte so the existing Firth
271/// logit path is numerically unchanged.
272///
273/// Noncanonical Bernoulli links use the same truncated Taylor-series quotient:
274/// assemble the inverse-link jet through `mu^(5)`, square the `mu'` series, and
275/// divide by the Bernoulli variance series `mu(1-mu)`. As the variance
276/// denominator saturates to zero in either tail, the weight and all derivatives
277/// saturate to zero, matching the inverse-link jet convention.
278pub(crate) fn fisher_weight_jet5(link: StandardLink, eta: f64) -> (f64, f64, f64, f64, f64) {
279    match link {
280        StandardLink::Logit => {
281            let jet = logit_inverse_link_jet5(eta);
282            (jet.d1, jet.d2, jet.d3, jet.d4, jet.d5)
283        }
284        StandardLink::Probit => probit_fisher_weight_jet5(eta),
285        StandardLink::CLogLog => component_fisher_weight_jet5(LinkComponent::CLogLog, eta),
286        StandardLink::Identity | StandardLink::Log => (0.0, 0.0, 0.0, 0.0, 0.0),
287    }
288}
289
290pub(crate) fn fisher_weight_jet5_for_inverse_link(
291    link: &InverseLink,
292    eta: f64,
293) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
294    match link {
295        InverseLink::Standard(link) => Ok(fisher_weight_jet5(*link, eta)),
296        InverseLink::LatentCLogLog(_)
297        | InverseLink::Sas(_)
298        | InverseLink::BetaLogistic(_)
299        | InverseLink::Mixture(_) => {
300            let jet = link.jet(eta)?;
301            let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)?;
302            let d5 = inverse_link_pdffourth_derivative_for_inverse_link(link, eta)?;
303            Ok(fisher_weight_jet5_from_inverse_link_derivatives(
304                jet.mu, jet.d1, jet.d2, jet.d3, d4, d5,
305            ))
306        }
307    }
308}
309
310#[inline]
311pub(crate) fn inverse_link_has_fisher_weight_jet(link: &InverseLink) -> bool {
312    matches!(
313        link,
314        InverseLink::Standard(StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,)
315            | InverseLink::LatentCLogLog(_)
316            | InverseLink::Sas(_)
317            | InverseLink::BetaLogistic(_)
318            | InverseLink::Mixture(_)
319    )
320}
321
322#[inline]
323fn component_fisher_weight_jet5(component: LinkComponent, eta: f64) -> (f64, f64, f64, f64, f64) {
324    let jet = component_inverse_link_jet(component, eta);
325    let d4 = component_inverse_link_pdfthird_derivative(component, eta);
326    let d5 = component_inverse_link_pdffourth_derivative(component, eta);
327    fisher_weight_jet5_from_inverse_link_derivatives(jet.mu, jet.d1, jet.d2, jet.d3, d4, d5)
328}
329
330#[inline]
331fn fisher_weight_jet5_from_inverse_link_derivatives(
332    mu: f64,
333    d1: f64,
334    d2: f64,
335    d3: f64,
336    d4: f64,
337    d5: f64,
338) -> (f64, f64, f64, f64, f64) {
339    if [mu, d1, d2, d3, d4, d5].iter().any(|v| v.is_nan()) {
340        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
341    }
342    let variance = mu * (1.0 - mu);
343    if !(variance > 0.0) || !variance.is_finite() {
344        return (0.0, 0.0, 0.0, 0.0, 0.0);
345    }
346
347    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
348    let mu_d = [mu, d1, d2, d3, d4];
349    let one_minus_mu_d = [1.0 - mu, -d1, -d2, -d3, -d4];
350    let dmu_d = [d1, d2, d3, d4, d5];
351    let mut mu_t = [0.0_f64; 5];
352    let mut one_minus_mu_t = [0.0_f64; 5];
353    let mut dmu_t = [0.0_f64; 5];
354    for k in 0..5 {
355        let inv_fact = 1.0 / factorial[k];
356        mu_t[k] = mu_d[k] * inv_fact;
357        one_minus_mu_t[k] = one_minus_mu_d[k] * inv_fact;
358        dmu_t[k] = dmu_d[k] * inv_fact;
359    }
360    let num_t = taylor5_mul(&dmu_t, &dmu_t);
361    let den_t = taylor5_mul(&mu_t, &one_minus_mu_t);
362    if !(den_t[0] > 0.0) || !den_t[0].is_finite() {
363        return (0.0, 0.0, 0.0, 0.0, 0.0);
364    }
365    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
366    (
367        canonicalzero(w_t[0] * factorial[0]),
368        canonicalzero(w_t[1] * factorial[1]),
369        canonicalzero(w_t[2] * factorial[2]),
370        canonicalzero(w_t[3] * factorial[3]),
371        canonicalzero(w_t[4] * factorial[4]),
372    )
373}
374
375/// Probit Bernoulli Fisher-weight 5-jet `W = phi^2 / (Phi (1 - Phi))` and its
376/// first four eta-derivatives. See [`fisher_weight_jet5`].
377#[inline]
378fn probit_fisher_weight_jet5(eta: f64) -> (f64, f64, f64, f64, f64) {
379    if eta.is_nan() {
380        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
381    }
382    if !eta.is_finite() {
383        return (0.0, 0.0, 0.0, 0.0, 0.0);
384    }
385    let x = eta;
386    let p = normal_cdf(x);
387    // Compute the complement directly via Phi(-x) rather than `1 - Phi(x)`:
388    // in the positive tail `Phi(x)` rounds to 1.0 and `1 - Phi(x)` cancels to
389    // zero, whereas `Phi(-x)` retains the accurate (tiny) tail mass.
390    let q = normal_cdf(-x);
391    let phi = normal_pdf(x);
392    // Saturated tail: the denominator Phi(1-Phi) has underflowed to zero (or
393    // would divide by zero); the working weight and all derivatives go to zero.
394    if !(p > 0.0) || !(q > 0.0) || p * q <= 0.0 {
395        return (0.0, 0.0, 0.0, 0.0, 0.0);
396    }
397    // Gaussian derivative ladder: phi^(k) for k = 0..=4 using phi' = -x phi.
398    let phi1 = -x * phi;
399    let phi2 = (x * x - 1.0) * phi;
400    let phi3 = -(x * x * x - 3.0 * x) * phi;
401    let phi4 = (x * x * x * x - 6.0 * x * x + 3.0) * phi;
402    // Derivative arrays (d^k/deta^k) for f = phi, p = Phi, q = 1 - Phi.
403    // p^(0) = Phi, p^(k>=1) = phi^(k-1); q is the negated complement.
404    let f_d = [phi, phi1, phi2, phi3, phi4];
405    let p_d = [p, phi, phi1, phi2, phi3];
406    let q_d = [q, -phi, -phi1, -phi2, -phi3];
407    // Convert derivative arrays to Taylor coefficients a_k = g^(k)/k!.
408    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
409    let mut f_t = [0.0_f64; 5];
410    let mut p_t = [0.0_f64; 5];
411    let mut q_t = [0.0_f64; 5];
412    for k in 0..5 {
413        let inv_fact = 1.0 / factorial[k];
414        f_t[k] = f_d[k] * inv_fact;
415        p_t[k] = p_d[k] * inv_fact;
416        q_t[k] = q_d[k] * inv_fact;
417    }
418    let num_t = taylor5_mul(&f_t, &f_t);
419    let den_t = taylor5_mul(&p_t, &q_t);
420    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
421    // Back to derivatives W^(k) = w_t[k] * k!.
422    (
423        canonicalzero(w_t[0] * factorial[0]),
424        canonicalzero(w_t[1] * factorial[1]),
425        canonicalzero(w_t[2] * factorial[2]),
426        canonicalzero(w_t[3] * factorial[3]),
427        canonicalzero(w_t[4] * factorial[4]),
428    )
429}
430
431#[inline]
432fn chain_inverse_link_jet(base: InverseLinkJet, z1: f64, z2: f64, z3: f64) -> InverseLinkJet {
433    InverseLinkJet {
434        mu: base.mu,
435        d1: base.d1 * z1,
436        d2: base.d2 * z1 * z1 + base.d1 * z2,
437        d3: base.d3 * z1 * z1 * z1 + 3.0 * base.d2 * z1 * z2 + base.d1 * z3,
438    }
439}
440
441#[inline]
442fn component_inverse_link_pdfthird_derivative(component: LinkComponent, eta: f64) -> f64 {
443    match component {
444        LinkComponent::Probit => probit_pdfthird_derivative(eta),
445        LinkComponent::Logit => logit_inverse_link_jet5(eta).d4,
446        LinkComponent::CLogLog => {
447            // CLogLog link:
448            //   mu = 1 - exp(-t),  t = exp(eta),  d1 = t exp(-t).
449            //
450            // Repeated differentiation closes in the basis `d1 * poly(t)`:
451            //   d2 = d1(-t + 1)
452            //   d3 = d1(t² - 3t + 1)
453            //   d4 = d1(-t³ + 6t² - 7t + 1).
454            if eta.is_nan() {
455                return f64::NAN;
456            }
457            if !eta.is_finite() {
458                return 0.0;
459            }
460            let t = eta.exp();
461            canonicalzero(stable_nonnegative_poly_times_exp_neg(
462                t,
463                &[0.0, 1.0, -7.0, 6.0, -1.0],
464            ))
465        }
466        LinkComponent::LogLog => {
467            // LogLog link is the reflected cloglog family with `r = exp(-eta)`:
468            //   mu = exp(-r), d1 = mu r,
469            // and again higher derivatives are `d1 * poly(r)`:
470            //   d2 = d1(r - 1)
471            //   d3 = d1(r² - 3r + 1)
472            //   d4 = d1(r³ - 6r² + 7r - 1).
473            if eta.is_nan() {
474                return f64::NAN;
475            }
476            if !eta.is_finite() {
477                return 0.0;
478            }
479            let r = (-eta).exp();
480            canonicalzero(stable_nonnegative_poly_times_exp_neg(
481                r,
482                &[0.0, -1.0, 7.0, -6.0, 1.0],
483            ))
484        }
485        LinkComponent::Cauchit => {
486            // Cauchit link:
487            //   mu = 1/2 + atan(eta)/pi,
488            //   d1 = 1 / [pi (1+eta²)].
489            //
490            // Differentiating three more times gives
491            //
492            //   d4 = 24 eta (1-eta²) / [pi (1+eta²)^4].
493            if eta.is_nan() {
494                return f64::NAN;
495            }
496            if !eta.is_finite() {
497                return 0.0;
498            }
499            let denom = 1.0 + eta * eta;
500            24.0 * eta * (1.0 - eta * eta) / (std::f64::consts::PI * denom.powi(4))
501        }
502    }
503}
504
505/// Fifth derivative of a component inverse-link CDF (= fourth derivative of PDF).
506/// Extends `component_inverse_link_pdfthird_derivative` by one derivative order.
507#[inline]
508fn component_inverse_link_pdffourth_derivative(component: LinkComponent, eta: f64) -> f64 {
509    match component {
510        LinkComponent::Probit => probit_pdffourth_derivative(eta),
511        LinkComponent::Logit => logit_inverse_link_jet5(eta).d5,
512        LinkComponent::CLogLog => {
513            // Exact closed form:
514            //   d5 = exp(-t) * (t - 15t^2 + 25t^3 - 10t^4 + t^5)
515            //      = d1 * (1 - 15t + 25t^2 - 10t^3 + t^4),
516            // where t = exp(eta).
517            if eta.is_nan() {
518                return f64::NAN;
519            }
520            if !eta.is_finite() {
521                return 0.0;
522            }
523            let t = eta.exp();
524            canonicalzero(stable_nonnegative_poly_times_exp_neg(
525                t,
526                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
527            ))
528        }
529        LinkComponent::LogLog => {
530            // Exact closed form:
531            //   d5 = exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5)
532            //      = d1 * (1 - 15r + 25r^2 - 10r^3 + r^4),
533            // where r = exp(-eta).
534            if eta.is_nan() {
535                return f64::NAN;
536            }
537            if !eta.is_finite() {
538                return 0.0;
539            }
540            let r = (-eta).exp();
541            canonicalzero(stable_nonnegative_poly_times_exp_neg(
542                r,
543                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
544            ))
545        }
546        LinkComponent::Cauchit => {
547            // d5 = 24(1 - 10eta^2 + 5eta^4) / [pi * (1+eta^2)^5]
548            if eta.is_nan() {
549                return f64::NAN;
550            }
551            if !eta.is_finite() {
552                return 0.0;
553            }
554            let e2 = eta * eta;
555            let denom = 1.0 + e2;
556            24.0 * (1.0 - 10.0 * e2 + 5.0 * e2 * e2) / (std::f64::consts::PI * denom.powi(5))
557        }
558    }
559}
560
561#[derive(Clone, Debug, PartialEq)]
562pub struct MixtureJetWithRhoPartials {
563    pub jet: InverseLinkJet,
564    /// Partial derivatives wrt free logits rho_j, j in [0, K-2].
565    /// Each entry stores derivatives of (mu, d1, d2, d3) wrt one rho_j.
566    pub djet_drho: Vec<InverseLinkJet>,
567}
568
569#[derive(Clone, Debug, PartialEq)]
570pub struct SasJetWithParamPartials {
571    pub jet: InverseLinkJet,
572    pub djet_depsilon: InverseLinkJet,
573    pub djet_dlog_delta: InverseLinkJet,
574}
575
576#[derive(Clone, Debug, PartialEq)]
577pub enum LinkParamPartials {
578    Mixture(MixtureJetWithRhoPartials),
579    Sas(SasJetWithParamPartials),
580}
581
582/// Trait-based inverse-link kernel interface.
583///
584/// Implementors provide pointwise inverse-link derivatives wrt `eta`:
585/// `F(eta), F'(eta), F''(eta), F'''(eta)`.
586/// Optionally they may expose parameter partials used by outer-loop optimization.
587pub trait InverseLinkKernel {
588    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;
589
590    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
591        assert!(eta.is_finite(), "eta must be finite");
592        Ok(None)
593    }
594}
595
/// Parameter-free kernel for the probit inverse link `F = Phi`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ProbitLinkKernel;

/// Parameter-free kernel for the logistic inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogitLinkKernel;

/// Parameter-free kernel for the complementary log-log inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CLogLogLinkKernel;

/// Parameter-free kernel for the log-log inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogLogLinkKernel;

/// Parameter-free kernel for the Cauchit inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CauchitLinkKernel;
610
611/// Construct SAS state from raw optimizer parameters using the same bounded
612/// transform used everywhere in fitting/evaluation.
613///
614/// A free function rather than an inherent `SasLinkState::new` because the
615/// bounded `delta` transform is solver-side math, so the constructor is hosted
616/// here next to the transform rather than on the type. `SasLinkState`'s fields
617/// are `pub`, so it builds directly.
618pub fn sas_link_state_from_raw(
619    raw_epsilon: f64,
620    raw_log_delta: f64,
621) -> Result<SasLinkState, String> {
622    if !raw_epsilon.is_finite() || !raw_log_delta.is_finite() {
623        return Err("SAS link parameters must be finite".to_string());
624    }
625    Ok(SasLinkState {
626        epsilon: raw_epsilon,
627        log_delta: raw_log_delta,
628        delta: sas_delta_from_raw_log_delta(raw_log_delta),
629    })
630}
631
632pub fn state_from_sasspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
633    sas_link_state_from_raw(spec.initial_epsilon, spec.initial_log_delta)
634}
635
636pub fn state_from_beta_logisticspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
637    if !spec.initial_epsilon.is_finite() || !spec.initial_log_delta.is_finite() {
638        return Err("Beta-Logistic link parameters must be finite".to_string());
639    }
640    // For Beta-Logistic, `log_delta` is the unconstrained log geometric-mean beta
641    // shape (the kernels' `log_shape_center`); the derived `delta = exp(log_delta)`
642    // is the positive geometric-mean shape `sqrt(a*b)`. Evaluation consumes
643    // `log_delta`, never `delta`.
644    let log_shape_center = spec.initial_log_delta;
645    Ok(SasLinkState {
646        epsilon: spec.initial_epsilon,
647        log_delta: log_shape_center,
648        delta: log_shape_center.exp(),
649    })
650}
651
/// Shared prelude of the `tanh_bound*` family.
///
/// Returns `(b, t)`:
/// * `b` is `bound` floored at `f64::EPSILON`, so a zero, negative or NaN
///   bound never divides by zero;
/// * `t` is `tanh(value / b)`.
#[inline]
fn tanh_bound_parts(value: f64, bound: f64) -> (f64, f64) {
    let b = bound.max(f64::EPSILON);
    (b, (value / b).tanh())
}

/// Smooth clamp `g(x) = B * tanh(x / B)`, mapping the real line into `[-B, B]`.
#[inline]
fn tanh_bound(value: f64, bound: f64) -> f64 {
    let (b, t) = tanh_bound_parts(value, bound);
    b * t
}

/// `g'(x) = 1 - t²`, with `t = tanh(x / B)`.
#[inline]
fn tanh_bound_d1(value: f64, bound: f64) -> f64 {
    let (_, t) = tanh_bound_parts(value, bound);
    1.0 - t * t
}

/// `g''(x) = -2 t s / B`, with `s = 1 - t²`.
#[inline]
fn tanh_bound_d2(value: f64, bound: f64) -> f64 {
    let (b, t) = tanh_bound_parts(value, bound);
    let sech2 = 1.0 - t * t;
    -2.0 * t * sech2 / b
}

/// `g'''(x) = -2 s (1 - 3t²) / B²`.
#[inline]
fn tanh_bound_d3(value: f64, bound: f64) -> f64 {
    let (b, t) = tanh_bound_parts(value, bound);
    let sech2 = 1.0 - t * t;
    -2.0 * sech2 * (1.0 - 3.0 * t * t) / (b * b)
}

/// `g''''(x) = 8 t s (2 - 3t²) / B³`.
#[inline]
fn tanh_bound_d4(value: f64, bound: f64) -> f64 {
    let (b, t) = tanh_bound_parts(value, bound);
    let sech2 = 1.0 - t * t;
    8.0 * t * sech2 * (2.0 - 3.0 * t * t) / (b * b * b)
}

/// `g^(5)(x) = 8 s (2 - 15t² + 15t⁴) / B⁴`.
#[inline]
fn tanh_bound_d5(value: f64, bound: f64) -> f64 {
    let (b, t) = tanh_bound_parts(value, bound);
    let sech2 = 1.0 - t * t;
    let t2 = t * t;
    let b4 = b * b * b * b;
    8.0 * sech2 * (2.0 - 15.0 * t2 + 15.0 * t2 * t2) / b4
}
701
702#[inline]
703fn sas_effective_log_delta(raw_log_delta: f64) -> (f64, f64) {
704    let ld_eff = tanh_bound(raw_log_delta, SAS_LOG_DELTA_BOUND);
705    let dld_eff_draw = tanh_bound_d1(raw_log_delta, SAS_LOG_DELTA_BOUND);
706    (ld_eff, dld_eff_draw)
707}
708
709#[inline]
710fn sas_delta_from_raw_log_delta(raw_log_delta: f64) -> f64 {
711    let (ld_eff, _) = sas_effective_log_delta(raw_log_delta);
712    ld_eff.exp()
713}
714
715pub fn validate_mixturespec(spec: &MixtureLinkSpec) -> Result<(), String> {
716    if spec.components.is_empty() {
717        return Err("mixture link requires at least 1 component".to_string());
718    }
719    if spec.initial_rho.len() + 1 != spec.components.len() {
720        return Err(format!(
721            "mixture link rho length mismatch: expected {}, got {}",
722            spec.components.len() - 1,
723            spec.initial_rho.len()
724        ));
725    }
726    for i in 0..spec.components.len() {
727        for j in (i + 1)..spec.components.len() {
728            if spec.components[i] == spec.components[j] {
729                return Err("mixture link components must be unique".to_string());
730            }
731        }
732    }
733    // `LinkComponent` admits two variants (Cauchit, LogLog) that have no matching
734    // `LinkFunction` entry. The mixture-link pipeline projects every mixture back onto
735    // a single `LinkFunction` value for downstream solver/IO bookkeeping (see
736    // `InverseLink::link_function`), so a mixture composed solely of components without
737    // a LinkFunction representative would silently lie about its projected link. We
738    // require every spec to contain at least one Logit/Probit/CLogLog "anchor" so the
739    // projection is meaningful. A mixture of only {Cauchit, LogLog} (or any single
740    // component restricted to that pair) has no anchor and is rejected here.
741    let has_anchor = spec.components.iter().any(|component| {
742        matches!(
743            component,
744            LinkComponent::Logit | LinkComponent::Probit | LinkComponent::CLogLog
745        )
746    });
747    if !has_anchor {
748        let unsupported: Vec<&str> = spec
749            .components
750            .iter()
751            .map(|component| component.name())
752            .collect();
753        return Err(format!(
754            "mixture link components {{{}}} are unsupported: at least one component \
755             must map to a LinkFunction variant (logit/probit/cloglog) so the mixture's \
756             projected LinkFunction is well defined; cauchit and loglog have no \
757             LinkFunction representative",
758            unsupported.join(", ")
759        ));
760    }
761    Ok(())
762}
763
764pub fn softmax_last_fixedzero(rho: &Array1<f64>) -> Array1<f64> {
765    let k = rho.len() + 1;
766    let mut logits = Vec::with_capacity(k);
767    let mut maxv = 0.0_f64;
768    for &v in rho {
769        maxv = maxv.max(v);
770        logits.push(v);
771    }
772    maxv = maxv.max(0.0);
773    logits.push(0.0);
774
775    let mut sum = 0.0_f64;
776    let mut exps = vec![0.0_f64; k];
777    for i in 0..k {
778        let e = (logits[i] - maxv).exp();
779        exps[i] = e;
780        sum += e;
781    }
782    if !sum.is_finite() || sum <= 0.0 {
783        return Array1::from_elem(k, 1.0 / k as f64);
784    }
785    let inv = 1.0 / sum;
786    Array1::from_iter(exps.into_iter().map(|v| v * inv))
787}
788
789/// Returns softmax weights and Jacobian wrt free logits (last logit fixed at zero).
790/// Jacobian shape is (K, K-1): d pi_k / d rho_j.
791pub fn softmaxwith_jacobian_last_fixedzero(
792    rho: &Array1<f64>,
793) -> (Array1<f64>, ndarray::Array2<f64>) {
794    let pi = softmax_last_fixedzero(rho);
795    let k = pi.len();
796    let m = k.saturating_sub(1);
797    let mut jac = ndarray::Array2::<f64>::zeros((k, m));
798    for j in 0..m {
799        let pi_j = pi[j];
800        for kk in 0..k {
801            let delta = if kk == j { 1.0 } else { 0.0 };
802            jac[[kk, j]] = pi[kk] * (delta - pi_j);
803        }
804    }
805    (pi, jac)
806}
807
808pub fn state_fromspec(spec: &MixtureLinkSpec) -> Result<MixtureLinkState, String> {
809    validate_mixturespec(spec)?;
810    let pi = softmax_last_fixedzero(&spec.initial_rho);
811    Ok(MixtureLinkState {
812        components: spec.components.clone(),
813        rho: spec.initial_rho.clone(),
814        pi,
815    })
816}
817
818#[inline]
819pub fn component_inverse_link_jet(component: LinkComponent, eta: f64) -> InverseLinkJet {
820    canonicalize_jet(match component {
821        LinkComponent::Logit => {
822            let jet = logit_inverse_link_jet5(eta);
823            InverseLinkJet {
824                mu: jet.mu,
825                d1: jet.d1,
826                d2: jet.d2,
827                d3: jet.d3,
828            }
829        }
830        LinkComponent::Probit => probit_jet(eta),
831        LinkComponent::CLogLog => {
832            if eta.is_nan() {
833                return InverseLinkJet {
834                    mu: f64::NAN,
835                    d1: f64::NAN,
836                    d2: f64::NAN,
837                    d3: f64::NAN,
838                };
839            }
840            let t = eta.exp();
841            if !t.is_finite() {
842                return InverseLinkJet {
843                    mu: 1.0,
844                    d1: 0.0,
845                    d2: 0.0,
846                    d3: 0.0,
847                };
848            }
849            InverseLinkJet {
850                mu: -(-t).exp_m1(),
851                d1: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0]),
852                d2: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -1.0]),
853                d3: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -3.0, 1.0]),
854            }
855        }
856        LinkComponent::LogLog => {
857            if eta.is_nan() {
858                return InverseLinkJet {
859                    mu: f64::NAN,
860                    d1: f64::NAN,
861                    d2: f64::NAN,
862                    d3: f64::NAN,
863                };
864            }
865            let r = (-eta).exp();
866            if !r.is_finite() {
867                return InverseLinkJet {
868                    mu: 0.0,
869                    d1: 0.0,
870                    d2: 0.0,
871                    d3: 0.0,
872                };
873            }
874            InverseLinkJet {
875                mu: (-r).exp(),
876                d1: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0]),
877                d2: stable_nonnegative_poly_times_exp_neg(r, &[0.0, -1.0, 1.0]),
878                d3: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0, -3.0, 1.0]),
879            }
880        }
881        LinkComponent::Cauchit => {
882            if eta.is_nan() {
883                return InverseLinkJet {
884                    mu: f64::NAN,
885                    d1: f64::NAN,
886                    d2: f64::NAN,
887                    d3: f64::NAN,
888                };
889            }
890            let den = 1.0 + eta * eta;
891            let d1 = if eta.is_finite() {
892                1.0 / (std::f64::consts::PI * den)
893            } else {
894                0.0
895            };
896            let d2 = if eta.is_finite() {
897                -2.0 * eta / (std::f64::consts::PI * den * den)
898            } else {
899                0.0
900            };
901            let d3 = if eta.is_finite() {
902                (6.0 * eta * eta - 2.0) / (std::f64::consts::PI * den * den * den)
903            } else {
904                0.0
905            };
906            InverseLinkJet {
907                mu: 0.5 + eta.atan() / std::f64::consts::PI,
908                d1,
909                d2,
910                d3,
911            }
912        }
913    })
914}
915
916impl InverseLinkKernel for ProbitLinkKernel {
917    #[inline]
918    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
919        Ok(component_inverse_link_jet(LinkComponent::Probit, eta))
920    }
921}
922
923impl InverseLinkKernel for LogitLinkKernel {
924    #[inline]
925    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
926        Ok(component_inverse_link_jet(LinkComponent::Logit, eta))
927    }
928}
929
930impl InverseLinkKernel for CLogLogLinkKernel {
931    #[inline]
932    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
933        Ok(component_inverse_link_jet(LinkComponent::CLogLog, eta))
934    }
935}
936
937impl InverseLinkKernel for LogLogLinkKernel {
938    #[inline]
939    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
940        Ok(component_inverse_link_jet(LinkComponent::LogLog, eta))
941    }
942}
943
944impl InverseLinkKernel for CauchitLinkKernel {
945    #[inline]
946    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
947        Ok(component_inverse_link_jet(LinkComponent::Cauchit, eta))
948    }
949}
950
951impl InverseLinkKernel for LinkComponent {
952    #[inline]
953    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
954        Ok(component_inverse_link_jet(*self, eta))
955    }
956}
957
958impl InverseLinkKernel for LinkFunction {
959    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
960        match self {
961            LinkFunction::Logit => LogitLinkKernel.jet(eta),
962            LinkFunction::Probit => ProbitLinkKernel.jet(eta),
963            LinkFunction::CLogLog => CLogLogLinkKernel.jet(eta),
964            LinkFunction::Identity => Ok(InverseLinkJet {
965                mu: eta,
966                d1: 1.0,
967                d2: 0.0,
968                d3: 0.0,
969            }),
970            LinkFunction::Log => {
971                // SOLVER-INTERNAL inverse-link jet: `η.clamp(−700, 700).exp()`.
972                // The clamp is an intentional conditioning hack so the IRLS/REML
973                // normal equations stay well posed when η wanders into the tails
974                // during a trust-region step — it is NOT the public response
975                // transform. Public response-scale outputs (predictions, FFI
976                // `apply_inverse_link_array`, posterior bands) must use the EXACT
977                // `exp(η)` in `families::inverse_link::apply_inverse_link_vec`,
978                // which is finite wherever representable. Do not reroute a public
979                // output through this clamped jet (issue #963). Keep the clamp:
980                // solver consumers (e.g. `reml/runtime.rs` trust-region `excess`)
981                // pass raw η and rely on it to keep μ finite.
982                let e = eta.clamp(-700.0, 700.0).exp();
983                Ok(InverseLinkJet {
984                    mu: e,
985                    d1: e,
986                    d2: e,
987                    d3: e,
988                })
989            }
990            LinkFunction::Sas => Err(EstimationError::InvalidInput(
991                "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
992            )),
993            LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
994                "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
995                    .to_string(),
996            )),
997        }
998    }
999}
1000
1001impl InverseLinkKernel for SasLinkState {
1002    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1003        Ok(sas_inverse_link_jet(eta, self.epsilon, self.log_delta))
1004    }
1005
1006    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1007        Ok(Some(LinkParamPartials::Sas(
1008            sas_inverse_link_jetwith_param_partials(eta, self.epsilon, self.log_delta),
1009        )))
1010    }
1011}
1012
/// Parameterised beta-logistic inverse link:
/// `mu = I_u(a, b)` with `u = logistic(eta)`,
/// `a = exp(log_shape_center - epsilon)` and `b = exp(log_shape_center + epsilon)`.
#[derive(Clone, Copy, Debug)]
pub struct BetaLogisticKernel {
    /// Unconstrained log of the geometric-mean beta shape — the raw optimization
    /// parameter `SasLinkState::log_delta`, NOT the derived `SasLinkState::delta`.
    pub log_shape_center: f64,
    /// Log-scale split between the two beta shapes: `a` uses
    /// `log_shape_center - epsilon`, while `b` uses `log_shape_center + epsilon`.
    pub epsilon: f64,
}
1020
1021impl InverseLinkKernel for BetaLogisticKernel {
1022    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1023        Ok(beta_logistic_inverse_link_jet(
1024            eta,
1025            self.log_shape_center,
1026            self.epsilon,
1027        ))
1028    }
1029
1030    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1031        Ok(Some(LinkParamPartials::Sas(
1032            beta_logistic_inverse_link_jetwith_param_partials(
1033                eta,
1034                self.log_shape_center,
1035                self.epsilon,
1036            ),
1037        )))
1038    }
1039}
1040
1041impl InverseLinkKernel for MixtureLinkState {
1042    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1043        Ok(mixture_inverse_link_jet(self, eta))
1044    }
1045
1046    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1047        Ok(Some(LinkParamPartials::Mixture(
1048            mixture_inverse_link_jetwith_rho_partials(self, eta),
1049        )))
1050    }
1051}
1052
1053impl InverseLinkKernel for InverseLink {
1054    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1055        match self {
1056            InverseLink::Standard(link_fn) => link_fn.as_link_function().jet(eta),
1057            InverseLink::LatentCLogLog(state) => latent_cloglog_point_jet(state, eta),
1058            InverseLink::Sas(state) => state.jet(eta),
1059            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1060                log_shape_center: state.log_delta,
1061                epsilon: state.epsilon,
1062            }
1063            .jet(eta),
1064            InverseLink::Mixture(state) => state.jet(eta),
1065        }
1066    }
1067
1068    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1069        match self {
1070            InverseLink::Standard(_) => Ok(None),
1071            InverseLink::LatentCLogLog(_) => Ok(None),
1072            InverseLink::Sas(state) => state.param_partials(eta),
1073            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1074                log_shape_center: state.log_delta,
1075                epsilon: state.epsilon,
1076            }
1077            .param_partials(eta),
1078            InverseLink::Mixture(state) => state.param_partials(eta),
1079        }
1080    }
1081}
1082
1083/// Central family-aware inverse-link jet dispatch.
1084///
1085/// For `BinomialSas` and `BinomialMixture`, required state must be provided.
1086pub fn inverse_link_jet_for_inverse_link(
1087    link: &InverseLink,
1088    eta: f64,
1089) -> Result<InverseLinkJet, EstimationError> {
1090    link.jet(eta)
1091}
1092
1093/// Specialized `(mu, d1)` inverse-link evaluation that skips the d2/d3
1094/// polynomial chain used by the full jet. Numerical semantics are preserved:
1095/// the returned `mu` and `d1` are bit-identical to the corresponding fields of
1096/// `inverse_link_jet_for_inverse_link(link, eta)?` for every supported link.
1097///
1098/// For latent cloglog the underlying lognormal-Laplace kernel produces all
1099/// orders together, so this falls back to the full jet for that branch — the
1100/// savings come from the parameterised polynomial links (SAS, beta-logistic,
1101/// mixture) and the simple analytic links where d2/d3 are pure waste.
1102pub fn inverse_link_mu_d1_for_inverse_link(
1103    link: &InverseLink,
1104    eta: f64,
1105) -> Result<(f64, f64), EstimationError> {
1106    match link {
1107        InverseLink::Standard(link_fn) => Ok(link_function_mu_d1(link_fn.as_link_function(), eta)?),
1108        InverseLink::LatentCLogLog(state) => {
1109            let jet = latent_cloglog_point_jet(state, eta)?;
1110            Ok((jet.mu, jet.d1))
1111        }
1112        InverseLink::Sas(state) => Ok(sas_inverse_link_mu_d1(eta, state.epsilon, state.log_delta)),
1113        InverseLink::BetaLogistic(state) => Ok(beta_logistic_inverse_link_mu_d1(
1114            eta,
1115            state.log_delta,
1116            state.epsilon,
1117        )),
1118        InverseLink::Mixture(state) => Ok(mixture_inverse_link_mu_d1(state, eta)),
1119    }
1120}
1121
1122fn link_function_mu_d1(link: LinkFunction, eta: f64) -> Result<(f64, f64), EstimationError> {
1123    match link {
1124        LinkFunction::Identity => Ok((eta, 1.0)),
1125        LinkFunction::Log => {
1126            // SOLVER-INTERNAL clamped `(μ, dμ/dη)`; see the matching note on the
1127            // full `LinkFunction::Log` jet above. Public response transforms use
1128            // exact `exp(η)` via `families::inverse_link::apply_inverse_link_vec`
1129            // (issue #963).
1130            let e = eta.clamp(-700.0, 700.0).exp();
1131            Ok((e, e))
1132        }
1133        LinkFunction::Logit => Ok(component_inverse_link_mu_d1(LinkComponent::Logit, eta)),
1134        LinkFunction::Probit => Ok(component_inverse_link_mu_d1(LinkComponent::Probit, eta)),
1135        LinkFunction::CLogLog => Ok(component_inverse_link_mu_d1(LinkComponent::CLogLog, eta)),
1136        LinkFunction::Sas => Err(EstimationError::InvalidInput(
1137            "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1138        )),
1139        LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1140            "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1141                .to_string(),
1142        )),
1143    }
1144}
1145
1146#[inline]
1147fn component_inverse_link_mu_d1(component: LinkComponent, eta: f64) -> (f64, f64) {
1148    // The full per-component jet already factors `mu` and `d1` exactly the same
1149    // way the higher orders are derived, so we either reuse the cheap closed
1150    // forms directly (Logit/Probit/CLogLog/LogLog/Cauchit) or fall back to the
1151    // existing canonicalised jet for the few cases without a separate fast
1152    // path — bit-identical to `component_inverse_link_jet(...).{mu,d1}`.
1153    match component {
1154        LinkComponent::Logit => {
1155            let jet = logit_inverse_link_jet5(eta);
1156            (jet.mu, canonicalzero(jet.d1))
1157        }
1158        LinkComponent::Probit => {
1159            if eta.is_nan() {
1160                return (f64::NAN, f64::NAN);
1161            }
1162            if eta == f64::INFINITY {
1163                return (1.0, 0.0);
1164            }
1165            if eta == f64::NEG_INFINITY {
1166                return (0.0, 0.0);
1167            }
1168            let phi = normal_pdf(eta);
1169            (normal_cdf(eta), canonicalzero(phi))
1170        }
1171        LinkComponent::CLogLog => {
1172            if eta.is_nan() {
1173                return (f64::NAN, f64::NAN);
1174            }
1175            let t = eta.exp();
1176            if !t.is_finite() {
1177                return (1.0, 0.0);
1178            }
1179            (
1180                -(-t).exp_m1(),
1181                canonicalzero(stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0])),
1182            )
1183        }
1184        LinkComponent::LogLog => {
1185            if eta.is_nan() {
1186                return (f64::NAN, f64::NAN);
1187            }
1188            let r = (-eta).exp();
1189            if !r.is_finite() {
1190                return (0.0, 0.0);
1191            }
1192            (
1193                (-r).exp(),
1194                canonicalzero(stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0])),
1195            )
1196        }
1197        LinkComponent::Cauchit => {
1198            if eta.is_nan() {
1199                return (f64::NAN, f64::NAN);
1200            }
1201            let den = 1.0 + eta * eta;
1202            let d1 = if eta.is_finite() {
1203                1.0 / (std::f64::consts::PI * den)
1204            } else {
1205                0.0
1206            };
1207            (0.5 + eta.atan() / std::f64::consts::PI, canonicalzero(d1))
1208        }
1209    }
1210}
1211
1212fn sas_inverse_link_mu_d1(eta: f64, epsilon: f64, log_delta: f64) -> (f64, f64) {
1213    let delta_id = sas_delta_from_raw_log_delta(log_delta);
1214    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1215        return component_inverse_link_mu_d1(LinkComponent::Probit, eta);
1216    }
1217    let e = if eta.is_finite() { eta } else { 0.0 };
1218    let a = e.asinh();
1219    let delta = delta_id;
1220    let u_raw = delta * a - epsilon;
1221    let u = tanh_bound(u_raw, SAS_U_CLAMP);
1222    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1223    let s = u.sinh();
1224    let c = u.cosh();
1225    let z = s;
1226    let q = e.hypot(1.0);
1227    let inv_q = 1.0 / q;
1228    let r1 = delta * inv_q;
1229    let u1 = g1 * r1;
1230    let z1 = c * u1;
1231    // `mu = Phi(z)` and `d1 = phi(z) * z1`, the same closed forms used by the
1232    // full jet via `chain_inverse_link_jet(probit_jet(z), z1, _, _)`.
1233    let base = probit_jet(z);
1234    (base.mu, canonicalzero(base.d1 * z1))
1235}
1236
1237fn beta_logistic_inverse_link_mu_d1(eta: f64, delta: f64, epsilon: f64) -> (f64, f64) {
1238    let logistic = logistic_uwith_derivatives(eta);
1239    let a = (delta - epsilon).exp();
1240    let b = (delta + epsilon).exp();
1241    let mu = beta_reg_logistic(a, b, logistic);
1242    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1243    (mu, log_d1.exp())
1244}
1245
1246fn mixture_inverse_link_mu_d1(state: &MixtureLinkState, eta: f64) -> (f64, f64) {
1247    let mut mu = 0.0_f64;
1248    let mut d1 = 0.0_f64;
1249    let k = state.components.len().min(state.pi.len());
1250    for i in 0..k {
1251        let (mu_i, d1_i) = component_inverse_link_mu_d1(state.components[i], eta);
1252        let w = state.pi[i];
1253        mu += w * mu_i;
1254        d1 += w * d1_i;
1255    }
1256    (mu, d1)
1257}
1258
/// Which higher derivative of the inverse-link density `f = dmu/deta` to
/// evaluate.
#[derive(Clone, Copy)]
enum PdfDerivativeOrder {
    /// `f'''`, i.e. the fourth eta-derivative of the CDF `mu`.
    Third,
    /// `f''''`, i.e. the fifth eta-derivative of the CDF `mu`.
    Fourth,
}
1264
1265impl PdfDerivativeOrder {
1266    fn probit(self, eta: f64) -> f64 {
1267        match self {
1268            Self::Third => probit_pdfthird_derivative(eta),
1269            Self::Fourth => probit_pdffourth_derivative(eta),
1270        }
1271    }
1272
1273    fn component(self, component: LinkComponent, eta: f64) -> f64 {
1274        match self {
1275            Self::Third => component_inverse_link_pdfthird_derivative(component, eta),
1276            Self::Fourth => component_inverse_link_pdffourth_derivative(component, eta),
1277        }
1278    }
1279
1280    fn latent_cloglog(self, eta: f64, latent_sd: f64) -> Result<f64, EstimationError> {
1281        let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, latent_sd)?;
1282        Ok(match self {
1283            Self::Third => jet.d4,
1284            Self::Fourth => jet.d5,
1285        })
1286    }
1287
1288    fn sas(self, eta: f64, epsilon: f64, log_delta: f64) -> f64 {
1289        match self {
1290            Self::Third => sas_inverse_link_pdfthird_derivative(eta, epsilon, log_delta),
1291            Self::Fourth => sas_inverse_link_pdffourth_derivative(eta, epsilon, log_delta),
1292        }
1293    }
1294
1295    fn beta_logistic(self, eta: f64, log_shape_center: f64, epsilon: f64) -> f64 {
1296        match self {
1297            Self::Third => {
1298                beta_logistic_inverse_link_pdfthird_derivative(eta, log_shape_center, epsilon)
1299            }
1300            Self::Fourth => {
1301                beta_logistic_inverse_link_pdffourth_derivative(eta, log_shape_center, epsilon)
1302            }
1303        }
1304    }
1305}
1306
1307fn inverse_link_pdf_derivative_for_inverse_link(
1308    link: &InverseLink,
1309    eta: f64,
1310    order: PdfDerivativeOrder,
1311) -> Result<f64, EstimationError> {
1312    match link {
1313        InverseLink::Standard(StandardLink::Identity) => Ok(0.0),
1314        InverseLink::Standard(StandardLink::Log) => Ok(eta.clamp(-700.0, 700.0).exp()),
1315        InverseLink::Standard(StandardLink::Probit) => Ok(order.probit(eta)),
1316        InverseLink::Standard(StandardLink::Logit) => {
1317            Ok(order.component(LinkComponent::Logit, eta))
1318        }
1319        InverseLink::Standard(StandardLink::CLogLog) => {
1320            Ok(order.component(LinkComponent::CLogLog, eta))
1321        }
1322        InverseLink::LatentCLogLog(state) => order.latent_cloglog(eta, state.latent_sd),
1323        InverseLink::Sas(state) => Ok(order.sas(eta, state.epsilon, state.log_delta)),
1324        InverseLink::BetaLogistic(state) => {
1325            Ok(order.beta_logistic(eta, state.log_delta, state.epsilon))
1326        }
1327        InverseLink::Mixture(state) => Ok(state
1328            .components
1329            .iter()
1330            .zip(state.pi.iter())
1331            .map(|(&component, &weight)| weight * order.component(component, eta))
1332            .sum()),
1333    }
1334}
1335
1336pub fn inverse_link_pdfthird_derivative_for_inverse_link(
1337    link: &InverseLink,
1338    eta: f64,
1339) -> Result<f64, EstimationError> {
1340    // This dispatch returns the fourth eta-derivative of the inverse-link CDF,
1341    // equivalently the third derivative of the inverse-link density
1342    //
1343    //   f(eta) = d/deta mu(eta).
1344    //
1345    // It is used downstream as the `f'''` input in
1346    //
1347    //   d³/deta³ log f = f'''/f - 3 f'f''/f² + 2(f')³/f³.
1348    //
1349    // Mixture links preserve linearity:
1350    //
1351    //   mu = sum_j pi_j mu_j
1352    //   => f''' = sum_j pi_j f_j'''
1353    //
1354    // because the mixture weights `pi_j` are constant with respect to `eta`.
1355    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Third)
1356}
1357
1358/// Fifth derivative of the inverse-link CDF (= fourth derivative of the PDF).
1359///
1360/// Extends `inverse_link_pdfthird_derivative_for_inverse_link` by one order.
1361/// Used for the outer REML Hessian Q[v_k, v_l] term in survival models,
1362/// specifically the `m1 * u_{abcd}` Arbogast contribution.
1363pub fn inverse_link_pdffourth_derivative_for_inverse_link(
1364    link: &InverseLink,
1365    eta: f64,
1366) -> Result<f64, EstimationError> {
1367    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Fourth)
1368}
1369
1370
1371#[inline]
1372fn royston_parmar_inverse_link_jet(eta: f64) -> InverseLinkJet {
1373    // ApproxKind: NumericalApproximation — exp(exp(eta)) overflows f64 for
1374    // eta > ~3.7; the |eta| > 30 saturation tail returns survival ≈ 0 with
1375    // d_k = 0, matching the analytic limit to within IEEE-754 underflow.
1376    const SURVIVAL_ETA_CLAMP: f64 = 30.0;
1377
1378    let z = eta.clamp(-SURVIVAL_ETA_CLAMP, SURVIVAL_ETA_CLAMP);
1379    let hazard = z.exp();
1380    let survival = (-hazard).exp();
1381    if !(-SURVIVAL_ETA_CLAMP..=SURVIVAL_ETA_CLAMP).contains(&eta) {
1382        return InverseLinkJet {
1383            mu: survival,
1384            d1: 0.0,
1385            d2: 0.0,
1386            d3: 0.0,
1387        };
1388    }
1389
1390    let d1 = -hazard * survival;
1391    let d2 = hazard * (hazard - 1.0) * survival;
1392    let d3 = (-hazard * hazard * hazard + 3.0 * hazard * hazard - hazard) * survival;
1393    InverseLinkJet {
1394        mu: survival,
1395        d1,
1396        d2,
1397        d3,
1398    }
1399}
1400
1401pub fn inverse_link_jet_for_family(
1402    spec: &LikelihoodSpec,
1403    eta: f64,
1404) -> Result<InverseLinkJet, EstimationError> {
1405    // RoystonParmar uses its own analytic survival inverse link irrespective of
1406    // the (nominal `Identity`) link slot carried in the spec.
1407    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1408        return Ok(royston_parmar_inverse_link_jet(eta));
1409    }
1410    spec.link.jet(eta)
1411}
1412
1413/// Exact-public log inverse-link jet: `mu = d1 = d2 = d3 = exp(η)` with NO
1414/// `η`-clamp. Sibling of the solver-internal `LinkFunction::Log` jet (which
1415/// clamps `η` to `[−700, 700]` as an IRLS/REML conditioning hack); see issue
1416/// #963. Every derivative of `exp` is `exp`, so all four jet slots carry the
1417/// same exact value — finite wherever representable, `0.0` on underflow,
1418/// `+∞` on overflow.
1419#[inline]
1420fn log_inverse_link_jet_exact(eta: f64) -> InverseLinkJet {
1421    let e = eta.exp();
1422    InverseLinkJet {
1423        mu: e,
1424        d1: e,
1425        d2: e,
1426        d3: e,
1427    }
1428}
1429
1430/// EXACT public inverse-link jet for response-scale prediction outputs.
1431///
1432/// Identical to [`inverse_link_jet_for_family`] for every link EXCEPT the
1433/// standard `Log` link, where it returns the exact `exp(η)` jet instead of the
1434/// solver's `η.clamp(−700, 700).exp()` conditioning transform. On finite `η`
1435/// the two diverge (η = 705: exact `exp(705)` ≈ 1.5e306 vs clamped
1436/// `exp(700)` ≈ 1.0e304; η = −720: exact ≈ 2e−313 vs clamped ≈ 9.9e−305), so
1437/// public predictions (`FamilyStrategy::inverse_link_jet`/`inverse_link_array`,
1438/// the predict response + delta-method SE path) route here. The solver/REML/
1439/// PIRLS engines keep the clamped jet (issue #963). For `|η| ≤ 700` this is
1440/// byte-identical to the clamped jet (the clamp is inert there), so no
1441/// in-range prediction changes.
1442pub fn inverse_link_jet_for_family_public(
1443    spec: &LikelihoodSpec,
1444    eta: f64,
1445) -> Result<InverseLinkJet, EstimationError> {
1446    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1447        return Ok(royston_parmar_inverse_link_jet(eta));
1448    }
1449    if let InverseLink::Standard(StandardLink::Log) = spec.link {
1450        return Ok(log_inverse_link_jet_exact(eta));
1451    }
1452    spec.link.jet(eta)
1453}
1454
1455#[inline]
1456pub fn mixture_inverse_link_jet(state: &MixtureLinkState, eta: f64) -> InverseLinkJet {
1457    let mut mu = 0.0_f64;
1458    let mut d1 = 0.0_f64;
1459    let mut d2 = 0.0_f64;
1460    let mut d3 = 0.0_f64;
1461    let k = state.components.len().min(state.pi.len());
1462    for i in 0..k {
1463        let jet = component_inverse_link_jet(state.components[i], eta);
1464        let w = state.pi[i];
1465        mu += w * jet.mu;
1466        d1 += w * jet.d1;
1467        d2 += w * jet.d2;
1468        d3 += w * jet.d3;
1469    }
1470    InverseLinkJet { mu, d1, d2, d3 }
1471}
1472
1473/// Computes mixture jet and exact partial derivatives wrt free softmax logits.
1474///
1475/// Uses identities:
1476///   d mu     / d rho_j = pi_j (mu_j     - mu)
1477///   d mu'    / d rho_j = pi_j (mu_j'    - mu')
1478///   d mu''   / d rho_j = pi_j (mu_j''   - mu'')
1479///   d mu'''  / d rho_j = pi_j (mu_j'''  - mu''')
1480pub fn mixture_inverse_link_jetwith_rho_partials(
1481    state: &MixtureLinkState,
1482    eta: f64,
1483) -> MixtureJetWithRhoPartials {
1484    let k = state.components.len().min(state.pi.len());
1485    let m = k.saturating_sub(1);
1486    let mut djet_drho = vec![
1487        InverseLinkJet {
1488            mu: 0.0,
1489            d1: 0.0,
1490            d2: 0.0,
1491            d3: 0.0,
1492        };
1493        m
1494    ];
1495    let jet = mixture_inverse_link_jetwith_rho_partials_into(state, eta, &mut djet_drho);
1496    MixtureJetWithRhoPartials { jet, djet_drho }
1497}
1498
1499/// Computes mixture jet and writes exact rho partial jets into `out` (length >= K-1).
1500/// This avoids heap allocation in hot loops.
1501pub fn mixture_inverse_link_jetwith_rho_partials_into(
1502    state: &MixtureLinkState,
1503    eta: f64,
1504    out: &mut [InverseLinkJet],
1505) -> InverseLinkJet {
1506    let k = state.components.len().min(state.pi.len());
1507    let m = k.saturating_sub(1);
1508    assert!(
1509        out.len() >= m,
1510        "rho-partial output buffer too small: got {}, need {}",
1511        out.len(),
1512        m
1513    );
1514    let mut mixed = InverseLinkJet {
1515        mu: 0.0,
1516        d1: 0.0,
1517        d2: 0.0,
1518        d3: 0.0,
1519    };
1520    for i in 0..k {
1521        let jet_i = component_inverse_link_jet(state.components[i], eta);
1522        let w = state.pi[i];
1523        mixed.mu += w * jet_i.mu;
1524        mixed.d1 += w * jet_i.d1;
1525        mixed.d2 += w * jet_i.d2;
1526        mixed.d3 += w * jet_i.d3;
1527        // Cache the first K-1 component jets directly in the output buffer so
1528        // we don't recompute them in the partial loop.
1529        if i < m {
1530            out[i] = jet_i;
1531        }
1532    }
1533    for j in 0..m {
1534        let pi_j = state.pi[j];
1535        let cj = out[j];
1536        out[j] = InverseLinkJet {
1537            mu: pi_j * (cj.mu - mixed.mu),
1538            d1: pi_j * (cj.d1 - mixed.d1),
1539            d2: pi_j * (cj.d2 - mixed.d2),
1540            d3: pi_j * (cj.d3 - mixed.d3),
1541        };
1542    }
1543    mixed
1544}
1545
/// Logistic transform `u = logistic(eta)` carried in both linear and log form,
/// so the beta-logistic kernels can work in whichever tail is more accurate.
#[derive(Clone, Copy)]
struct LogisticU {
    /// `u = logistic(eta)`.
    u: f64,
    /// `1 - u`, computed from its own log (not as `1.0 - u`) so it keeps
    /// precision when `u` is close to 1.
    one_minus_u: f64,
    /// `ln(u)`.
    ln_u: f64,
    /// `ln(1 - u)`.
    ln_one_minus_u: f64,
    /// `du/deta = u (1 - u)`, formed as `exp(ln_u + ln_one_minus_u)`.
    du: f64,
    /// `true` when `eta >= 0`; beta-CDF evaluations then use the complementary
    /// tail `1 - I_{1-u}(b, a)`.
    use_upper_tail: bool,
}
1555
1556#[inline]
1557fn logistic_uwith_derivatives(eta: f64) -> LogisticU {
1558    let ln_u = -gam_linalg::utils::stable_softplus(-eta);
1559    let ln_one_minus_u = -gam_linalg::utils::stable_softplus(eta);
1560    let u = ln_u.exp();
1561    let one_minus_u = ln_one_minus_u.exp();
1562    let du = (ln_u + ln_one_minus_u).exp();
1563    LogisticU {
1564        u,
1565        one_minus_u,
1566        ln_u,
1567        ln_one_minus_u,
1568        du,
1569        use_upper_tail: eta >= 0.0,
1570    }
1571}
1572
1573#[inline]
1574fn beta_reg_logistic(a: f64, b: f64, logistic: LogisticU) -> f64 {
1575    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1576        return f64::NAN;
1577    }
1578    if logistic.ln_u == f64::NEG_INFINITY {
1579        return 0.0;
1580    }
1581    if logistic.ln_one_minus_u == f64::NEG_INFINITY {
1582        return 1.0;
1583    }
1584    if logistic.use_upper_tail {
1585        1.0 - beta_reg(b, a, logistic.one_minus_u)
1586    } else {
1587        beta_reg(a, b, logistic.u)
1588    }
1589}
1590
1591#[inline]
1592fn beta_reg_with_shape_partials_logistic(a: f64, b: f64, logistic: LogisticU) -> (f64, f64, f64) {
1593    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1594        return (f64::NAN, f64::NAN, f64::NAN);
1595    }
1596    if logistic.use_upper_tail {
1597        let (tail, dtail_db, dtail_da) = beta_reg_with_shape_partials(b, a, logistic.one_minus_u);
1598        (1.0 - tail, -dtail_da, -dtail_db)
1599    } else {
1600        beta_reg_with_shape_partials(a, b, logistic.u)
1601    }
1602}
1603
1604#[inline]
1605fn beta_logistic_log_d1(a: f64, b: f64, logistic: LogisticU) -> f64 {
1606    a * logistic.ln_u + b * logistic.ln_one_minus_u - ln_beta(a, b)
1607}
1608
/// Forward-mode dual number carrying a value and its first partials with
/// respect to the two beta shape parameters `a` and `b`.
#[derive(Clone, Copy)]
struct ShapeDual {
    /// Value.
    v: f64,
    /// Partial derivative of `v` with respect to shape `a`.
    da: f64,
    /// Partial derivative of `v` with respect to shape `b`.
    db: f64,
}
1615
1616impl ShapeDual {
1617    #[inline]
1618    fn constant(v: f64) -> Self {
1619        Self {
1620            v,
1621            da: 0.0,
1622            db: 0.0,
1623        }
1624    }
1625
1626    #[inline]
1627    fn from_value_partials(v: f64, da: f64, db: f64) -> Self {
1628        Self { v, da, db }
1629    }
1630
1631    #[inline]
1632    fn clamp_small(self, floor: f64) -> Self {
1633        if self.v.abs() < floor {
1634            Self::constant(floor)
1635        } else {
1636            self
1637        }
1638    }
1639}
1640
1641impl std::ops::Add for ShapeDual {
1642    type Output = Self;
1643
1644    #[inline]
1645    fn add(self, rhs: Self) -> Self {
1646        Self {
1647            v: self.v + rhs.v,
1648            da: self.da + rhs.da,
1649            db: self.db + rhs.db,
1650        }
1651    }
1652}
1653
1654impl std::ops::Sub for ShapeDual {
1655    type Output = Self;
1656
1657    #[inline]
1658    fn sub(self, rhs: Self) -> Self {
1659        Self {
1660            v: self.v - rhs.v,
1661            da: self.da - rhs.da,
1662            db: self.db - rhs.db,
1663        }
1664    }
1665}
1666
1667impl std::ops::Mul for ShapeDual {
1668    type Output = Self;
1669
1670    #[inline]
1671    fn mul(self, rhs: Self) -> Self {
1672        Self {
1673            v: self.v * rhs.v,
1674            da: self.da * rhs.v + self.v * rhs.da,
1675            db: self.db * rhs.v + self.v * rhs.db,
1676        }
1677    }
1678}
1679
1680impl std::ops::Div for ShapeDual {
1681    type Output = Self;
1682
1683    #[inline]
1684    fn div(self, rhs: Self) -> Self {
1685        let inv = 1.0 / rhs.v;
1686        let inv2 = inv * inv;
1687        Self {
1688            v: self.v * inv,
1689            da: (self.da * rhs.v - self.v * rhs.da) * inv2,
1690            db: (self.db * rhs.v - self.v * rhs.db) * inv2,
1691        }
1692    }
1693}
1694
1695impl std::ops::Neg for ShapeDual {
1696    type Output = Self;
1697
1698    #[inline]
1699    fn neg(self) -> Self {
1700        ShapeDual {
1701            v: -self.v,
1702            da: -self.da,
1703            db: -self.db,
1704        }
1705    }
1706}
1707
1708#[inline]
1709fn shape_dual(v: f64) -> ShapeDual {
1710    ShapeDual::constant(v)
1711}
1712
/// Regularized incomplete beta `I_x(a0, b0)` together with its analytic
/// partials with respect to both shapes.
///
/// Returns `(I, dI/da0, dI/db0)`. The partials come from pushing `ShapeDual`
/// numbers through the same regularized-beta continued fraction used by
/// statrs (Numerical-Recipes-style modified Lentz iteration). The normalizing
/// term uses `d log B(a,b) / da = psi(a) - psi(a+b)`, and likewise for `b`.
///
/// `x0 <= 0` and `x0 >= 1` return the exact boundary values with zero partials.
fn beta_reg_with_shape_partials(a0: f64, b0: f64, x0: f64) -> (f64, f64, f64) {
    if x0 <= 0.0 {
        return (0.0, 0.0, 0.0);
    }
    if x0 >= 1.0 {
        return (1.0, 0.0, 0.0);
    }

    // The continued fraction converges quickly only below this threshold;
    // otherwise evaluate I_x(a0, b0) = 1 - I_{1-x}(b0, a0).
    let symm_transform = x0 >= (a0 + 1.0) / (a0 + b0 + 2.0);
    // Seed the duals so `.da`/`.db` always track the ORIGINAL a0/b0, even when
    // the shapes are swapped by the symmetry transform.
    let (a, b, x) = if symm_transform {
        (
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            1.0 - x0,
        )
    } else {
        (
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            x0,
        )
    };

    // Prefactor bt = x^a (1-x)^b / B(a, b), built in log space. Its shape
    // partials are bt * d(log bt), with d(log bt)/da = psi(a+b) - psi(a) + ln x.
    let ln_x = x.ln();
    let ln_1mx = (1.0 - x).ln();
    let psi_ab = digamma(a.v + b.v);
    let log_bt = statrs::function::gamma::ln_gamma(a.v + b.v)
        - statrs::function::gamma::ln_gamma(a.v)
        - statrs::function::gamma::ln_gamma(b.v)
        + a.v * ln_x
        + b.v * ln_1mx;
    let bt_v = log_bt.exp();
    let log_bt_a = psi_ab - digamma(a.v) + ln_x;
    let log_bt_b = psi_ab - digamma(b.v) + ln_1mx;
    let bt = ShapeDual {
        v: bt_v,
        da: bt_v * (log_bt_a * a.da + log_bt_b * b.da),
        db: bt_v * (log_bt_a * a.db + log_bt_b * b.db),
    };

    // Convergence tolerance 2^-53 (half of f64::EPSILON), as in statrs.
    let eps = 0.00000000000000011102230246251565;
    // Lentz guard against division by ~0. A clamped dual becomes a constant,
    // so its derivative contribution is dropped at that step.
    let fpmin = f64::MIN_POSITIVE / eps;
    let one = shape_dual(1.0);
    let qab = a + b;
    let qap = a + one;
    let qam = a - one;
    let mut c = one;
    let mut d = (one - qab * shape_dual(x) / qap).clamp_small(fpmin);
    d = one / d;
    let mut h = d;

    // At most 140 iterations, each handling the even then the odd CF term.
    for m in 1..141 {
        let mf = f64::from(m);
        let m2 = mf * 2.0;
        let md = shape_dual(mf);
        let m2d = shape_dual(m2);
        // Even step: d_{2m} = m (b - m) x / ((a + 2m - 1)(a + 2m)).
        let mut aa = md * (b - md) * shape_dual(x) / ((qam + m2d) * (a + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        h = h * d * c;

        // Odd step: d_{2m+1} = -(a + m)(a + b + m) x / ((a + 2m)(a + 2m + 1)).
        aa = (a + md).neg() * (qab + md) * shape_dual(x) / ((a + m2d) * (qap + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        let del = d * c;
        h = h * del;

        // Convergence is judged on the value only; the derivative parts are
        // assumed to have settled alongside it.
        if (del.v - 1.0).abs() <= eps {
            let reg = bt * h / a;
            return if symm_transform {
                (1.0 - reg.v, -reg.da, -reg.db)
            } else {
                (reg.v, reg.da, reg.db)
            };
        }
    }
    // Not converged within the iteration cap: return the last iterate (no
    // error is raised).
    let reg = bt * h / a;
    if symm_transform {
        (1.0 - reg.v, -reg.da, -reg.db)
    } else {
        (reg.v, reg.da, reg.db)
    }
}
1801
1802/// Beta-Logistic inverse-link jet for:
1803///   u = logistic(eta)
1804///   a = exp(log_shape_center - epsilon), b = exp(log_shape_center + epsilon)
1805///   mu = I_u(a, b)
1806///
1807/// NOTE: `log_shape_center` is the *unconstrained* log of the geometric-mean
1808/// beta shape (so a·b = exp(2·log_shape_center)). Callers must pass the raw
1809/// optimization parameter `SasLinkState::log_delta`, NOT the derived positive
1810/// `SasLinkState::delta = exp(log_shape_center)`.
1811pub fn beta_logistic_inverse_link_jet(
1812    eta: f64,
1813    log_shape_center: f64,
1814    epsilon: f64,
1815) -> InverseLinkJet {
1816    let logistic = logistic_uwith_derivatives(eta);
1817    let a = (log_shape_center - epsilon).exp();
1818    let b = (log_shape_center + epsilon).exp();
1819    let mu = beta_reg_logistic(a, b, logistic);
1820    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1821    let d1 = log_d1.exp();
1822    let t = a * logistic.one_minus_u - b * logistic.u;
1823    let d2 = d1 * t;
1824    let d3 = d1 * (t * t - (a + b) * logistic.du);
1825    InverseLinkJet { mu, d1, d2, d3 }
1826}
1827
1828pub fn beta_logistic_inverse_link_pdfthird_derivative(
1829    eta: f64,
1830    log_shape_center: f64,
1831    epsilon: f64,
1832) -> f64 {
1833    // Beta-logistic link:
1834    //
1835    //   u = logistic(eta),
1836    //   d1 = C * u^a (1-u)^b,
1837    //   t  = a(1-u) - b u,
1838    //   c  = a + b,
1839    //
1840    // so
1841    //
1842    //   d2 = d1 * t
1843    //   d3 = d1 * (t² - c u')
1844    //
1845    // with `u' = u(1-u)`.
1846    //
1847    // Differentiate once more:
1848    //
1849    //   d4 = d/deta[d1 (t² - c u')]
1850    //      = d1' (t² - c u') + d1 (2 t t' - c u'')
1851    //      = d1 [ t(t² - c u') - 2 c t u' - c u'' ]
1852    //      = d1 [ t³ - 3 c t u' - c u'' ],
1853    //
1854    // since `t' = -c u'`.
1855    let logistic = logistic_uwith_derivatives(eta);
1856    let a = (log_shape_center - epsilon).exp();
1857    let b = (log_shape_center + epsilon).exp();
1858    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1859    let d1 = log_d1.exp();
1860    let c = a + b;
1861    let t = a * logistic.one_minus_u - b * logistic.u;
1862    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1863    d1 * (t * t * t - 3.0 * c * t * logistic.du - c * u2)
1864}
1865
1866/// Fifth derivative of the beta-logistic inverse-link CDF (= 4th deriv of PDF).
1867///
1868/// With `P_4 = t^3 - 3ct*u' - c*u''` giving `d4 = d1 * P_4`, the next order is:
1869///
1870///   d5 = d1 * [t^4 - 6c*t^2*u' - 4c*t*u'' + 3c^2*u'^2 - c*u''']
1871///
1872/// where u' = u(1-u), u'' = u'(1-2u), u''' = u''(1-2u) - 2*u'^2.
1873pub fn beta_logistic_inverse_link_pdffourth_derivative(
1874    eta: f64,
1875    log_shape_center: f64,
1876    epsilon: f64,
1877) -> f64 {
1878    let logistic = logistic_uwith_derivatives(eta);
1879    let a = (log_shape_center - epsilon).exp();
1880    let b = (log_shape_center + epsilon).exp();
1881    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1882    let d1 = log_d1.exp();
1883    let c = a + b;
1884    let t = a * logistic.one_minus_u - b * logistic.u;
1885    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1886    let u3 = u2 * (logistic.one_minus_u - logistic.u) - 2.0 * logistic.du * logistic.du;
1887    let t2 = t * t;
1888    d1 * (t2 * t2 - 6.0 * c * t2 * logistic.du - 4.0 * c * t * u2
1889        + 3.0 * c * c * logistic.du * logistic.du
1890        - c * u3)
1891}
1892
1893pub fn beta_logistic_inverse_link_jetwith_param_partials(
1894    eta: f64,
1895    log_shape_center: f64,
1896    epsilon: f64,
1897) -> SasJetWithParamPartials {
1898    let logistic = logistic_uwith_derivatives(eta);
1899    let a = (log_shape_center - epsilon).exp();
1900    let b = (log_shape_center + epsilon).exp();
1901    let (mu, dmu_da, dmu_db) = beta_reg_with_shape_partials_logistic(a, b, logistic);
1902    let dmu_dlog_shape_center = a * dmu_da + b * dmu_db;
1903    let dmu_depsilon = -a * dmu_da + b * dmu_db;
1904    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1905    let d1 = log_d1.exp();
1906    let t = a * logistic.one_minus_u - b * logistic.u;
1907    let d2 = d1 * t;
1908    let k = t * t - (a + b) * logistic.du;
1909    let d3 = d1 * k;
1910    let jet = InverseLinkJet { mu, d1, d2, d3 };
1911
1912    let psi_a = digamma(a);
1913    let psi_b = digamma(b);
1914    let psi_ab = digamma(a + b);
1915    let la = logistic.ln_u - psi_a + psi_ab;
1916    let lb = logistic.ln_one_minus_u - psi_b + psi_ab;
1917
1918    let partials_for = |a_p: f64, b_p: f64, dmu: f64| -> InverseLinkJet {
1919        let logd1_p = a_p * la + b_p * lb;
1920        let d1_p = d1 * logd1_p;
1921        let t_p = a_p * logistic.one_minus_u - b_p * logistic.u;
1922        let d2_p = d1_p * t + d1 * t_p;
1923        let k_p = 2.0 * t * t_p - (a_p + b_p) * logistic.du;
1924        let d3_p = d1_p * k + d1 * k_p;
1925        InverseLinkJet {
1926            mu: dmu,
1927            d1: d1_p,
1928            d2: d2_p,
1929            d3: d3_p,
1930        }
1931    };
1932    let djet_dlog_shape_center = partials_for(a, b, dmu_dlog_shape_center);
1933    let djet_depsilon = partials_for(-a, b, dmu_depsilon);
1934    SasJetWithParamPartials {
1935        jet,
1936        djet_depsilon,
1937        djet_dlog_delta: djet_dlog_shape_center,
1938    }
1939}
1940
1941/// SAS inverse-link jet for:
1942///   mu(eta) = Phi(sinh(delta * asinh(eta) - epsilon)),
1943///   delta = exp(B * tanh(log_delta / B)), B = SAS_LOG_DELTA_BOUND.
1944pub fn sas_inverse_link_jet(eta: f64, epsilon: f64, log_delta: f64) -> InverseLinkJet {
1945    let delta_id = sas_delta_from_raw_log_delta(log_delta);
1946    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1947        return component_inverse_link_jet(LinkComponent::Probit, eta);
1948    }
1949    let e = if eta.is_finite() { eta } else { 0.0 };
1950    let a = e.asinh();
1951    let delta = delta_id;
1952    let u_raw = delta * a - epsilon;
1953    let u = tanh_bound(u_raw, SAS_U_CLAMP);
1954    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1955    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
1956    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
1957    let s = u.sinh();
1958    let c = u.cosh();
1959    let z = s;
1960    let q = e.hypot(1.0);
1961    let inv_q = 1.0 / q;
1962    let inv_q2 = inv_q * inv_q;
1963    let inv_q3 = inv_q2 * inv_q;
1964    let inv_q5 = inv_q3 * inv_q2;
1965    let r1 = delta * inv_q;
1966    let r2 = -delta * e * inv_q3;
1967    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
1968    let u1 = g1 * r1;
1969    let u2 = g2 * r1 * r1 + g1 * r2;
1970    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
1971    let z1 = c * u1;
1972    let z2 = s * u1 * u1 + c * u2;
1973    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
1974    let base = probit_jet(z);
1975    chain_inverse_link_jet(base, z1, z2, z3)
1976}
1977
/// Fourth `eta`-derivative of the SAS inverse-link CDF (= third derivative of
/// its density).
///
/// Unlike `sas_inverse_link_jet`, there is no probit shortcut at identity
/// parameters; the bounded general formula is always evaluated. A non-finite
/// `eta` is evaluated as `eta = 0`.
pub fn sas_inverse_link_pdfthird_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    // SAS link with bounded latent transform:
    //
    //   a  = asinh(eta),
    //   u  = tanh_bound(delta * a - epsilon),
    //   z  = sinh(u),
    //   mu = Phi(z).
    //
    // The result mu'''' is assembled from three nested order-4 Faà di Bruno
    // (Arbogast) expansions.
    //
    // 1. r = delta * asinh(eta) - epsilon, with q = sqrt(1 + eta^2):
    //
    //      r1 = delta / q,
    //      r2 = -delta eta / q^3,
    //      r3 = delta (2 eta^2 - 1) / q^5,
    //      r4 = delta eta (9 - 6 eta^2) / q^7.
    //
    // 2. u(eta) = g(r(eta)) with g = tanh_bound:
    //
    //      u4 = g'''' r1^4 + 6 g''' r1² r2 + 3 g'' r2² + 4 g'' r1 r3 + g' r4.
    //
    // 3. z = sinh(u), using sinh' = cosh, sinh'' = sinh:
    //
    //      z4 = s u1^4 + 6 c u1² u2 + 3 s u2² + 4 s u1 u3 + c u4.
    //
    // Finally,
    //
    //   mu'''' = Phi'''' z1^4 + 6 Phi''' z1² z2 + 3 Phi'' z2² + 4 Phi'' z1 z3 + Phi' z4.
    //
    // Phi'..Phi''' come from `probit_jet(z)`; Phi'''' comes from
    // `probit_pdfthird_derivative(z)`. This is algebraically the same as the
    // Hermite form mu'''' = phi(z) [k4 - z z1 k3], where
    // k3 = z3 - 3 z z1 z2 + (z² - 1) z1³ and k4 = dk3/deta. The code does NOT
    // compute k3/k4 explicitly; it uses the expanded form above.
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let delta = sas_delta_from_raw_log_delta(log_delta);
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    // base.d1..d3 = Phi'(z), Phi''(z), Phi'''(z).
    let base = probit_jet(z);
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    // Phi''''(z).
    let base4 = probit_pdfthird_derivative(z);
    let out = base4 * z1.powi(4)
        + 6.0 * base.d3 * z1 * z1 * z2
        + 3.0 * base.d2 * z2 * z2
        + 4.0 * base.d2 * z1 * z3
        + base.d1 * z4;
    // NOTE(review): `canonicalzero` presumably normalises signed zero — confirm.
    canonicalzero(out)
}
2058
/// Fifth derivative of the SAS inverse-link CDF (= fourth derivative of the PDF).
///
/// Extends `sas_inverse_link_pdfthird_derivative` by one more derivative order,
/// using the same composition chain u(eta) = g(r(eta)), z = sinh(u), mu = Phi(z).
///
/// The Arbogast expansion at order 5 for u(eta) = g(r(eta)) is:
///   u5 = g5 r1^5 + 10 g4 r1^3 r2 + 15 g3 r1 r2^2 + 10 g3 r1^2 r3
///        + 10 g2 r2 r3 + 5 g2 r1 r4 + g1 r5
///
/// The z = sinh(u) expansion at order 5 is the standard Arbogast for sinh:
///   z5 = c*u1^5 + 10*s*u1^3*u2 + 15*c*u1*u2^2 + 10*c*u1^2*u3
///        + 10*s*u2*u3 + 5*s*u1*u4 + c*u5
///
/// The mu = Phi(z) expansion at order 5 uses probit derivatives:
///   mu^(5) = Phi5*z1^5 + 10*Phi4*z1^3*z2 + 15*Phi3*z1*z2^2 + 10*Phi3*z1^2*z3
///            + 10*Phi2*z2*z3 + 5*Phi2*z1*z4 + Phi1*z5
///
/// Like the third-derivative variant, this always evaluates the bounded
/// general formula; there is no probit shortcut at identity parameters. A
/// non-finite `eta` is evaluated as `eta = 0`.
pub fn sas_inverse_link_pdffourth_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let delta = sas_delta_from_raw_log_delta(log_delta);
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let g5 = tanh_bound_d5(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;

    // Probit derivatives at z.
    // Naming: `phi3`/`phi4` are the 3rd/4th derivatives of the normal PDF,
    // i.e. Phi^(4)/Phi^(5) of the CDF; base.d1..d3 are Phi^(1)..Phi^(3).
    let base = probit_jet(z);
    let phi3 = probit_pdfthird_derivative(z); // Phi^{(4)}
    let phi4 = probit_pdffourth_derivative(z); // Phi^{(5)}

    // Powers of q = sqrt(1 + eta^2) for r1..r5.
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    let inv_q9 = inv_q7 * inv_q2;

    // Successive eta-derivatives of r = delta * asinh(eta) - epsilon.
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    let r5 = delta * (9.0 - 72.0 * e * e + 24.0 * e * e * e * e) * inv_q9;

    // u1..u5 via Arbogast for g(r(eta)).
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    let u5 = g5 * r1.powi(5)
        + 10.0 * g4 * r1 * r1 * r1 * r2
        + 15.0 * g3 * r1 * r2 * r2
        + 10.0 * g3 * r1 * r1 * r3
        + 10.0 * g2 * r2 * r3
        + 5.0 * g2 * r1 * r4
        + g1 * r5;

    // z1..z5 via Arbogast for sinh(u(eta)).
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    let z5 = c * u1.powi(5)
        + 10.0 * s * u1 * u1 * u1 * u2
        + 15.0 * c * u1 * u2 * u2
        + 10.0 * c * u1 * u1 * u3
        + 10.0 * s * u2 * u3
        + 5.0 * s * u1 * u4
        + c * u5;

    // mu^(5) = Phi^(5)*z1^5 + 10*Phi^(4)*z1^3*z2 + 15*Phi^(3)*z1*z2^2
    //        + 10*Phi^(3)*z1^2*z3 + 10*Phi^(2)*z2*z3 + 5*Phi^(2)*z1*z4 + Phi^(1)*z5
    let out = phi4 * z1.powi(5)
        + 10.0 * phi3 * z1 * z1 * z1 * z2
        + 15.0 * base.d3 * z1 * z2 * z2
        + 10.0 * base.d3 * z1 * z1 * z3
        + 10.0 * base.d2 * z2 * z3
        + 5.0 * base.d2 * z1 * z4
        + base.d1 * z5;
    // NOTE(review): `canonicalzero` presumably normalises signed zero — confirm.
    canonicalzero(out)
}
2152
/// SAS inverse-link jet (`mu`, d1..d3 in `eta`) together with its partial
/// derivatives with respect to `epsilon` and the raw `log_delta`.
///
/// The `log_delta` partial is taken through the smooth bounded effective
/// log-delta returned by `sas_effective_log_delta`. That helper yields both
/// the effective value and its derivative w.r.t. the raw parameter.
///
/// Unlike `sas_inverse_link_jet`, the jet here is always computed via the
/// general bounded formula, with no probit shortcut at identity parameters. A
/// non-finite `eta` is evaluated as `eta = 0`.
pub fn sas_inverse_link_jetwith_param_partials(
    eta: f64,
    epsilon: f64,
    log_delta: f64,
) -> SasJetWithParamPartials {
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let (ld_eff, dld_eff_draw) = sas_effective_log_delta(log_delta);
    let delta = ld_eff.exp();
    // d delta / d raw_log_delta = delta * d(ld_eff)/d(raw).
    let ddelta_draw = delta * dld_eff_draw;
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    // g4 is only needed for the parameter partial of u3.
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    // a1..a3: eta-derivatives of asinh(eta). They are kept separate from r1..r3
    // so the delta-partials of r_k are simply ddelta_draw * a_k.
    let a1 = inv_q;
    let a2 = -e * inv_q3;
    let a3 = (2.0 * e * e - 1.0) * inv_q5;
    let r1 = delta * a1;
    let r2 = delta * a2;
    let r3 = delta * a3;
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;

    let base = probit_jet(z);
    let jet = chain_inverse_link_jet(base, z1, z2, z3);

    // Generic chain for parameter t:
    // u_t, u1_t, u2_t, u3_t -> z_t,z1_t,z2_t,z3_t -> mu_t,d1_t,d2_t,d3_t
    // Each line is the t-derivative of the corresponding z_k / d_k expression
    // above (product rule plus sinh' = cosh, cosh' = sinh).
    let param_partials = |u_t: f64, u1_t: f64, u2_t: f64, u3_t: f64| -> InverseLinkJet {
        let z_t = c * u_t;
        let z1_t = s * u_t * u1 + c * u1_t;
        let z2_t = c * u_t * u1 * u1 + 2.0 * s * u1 * u1_t + s * u_t * u2 + c * u2_t;
        let z3_t = s * u_t * u1 * u1 * u1
            + 3.0 * c * u1 * u1 * u1_t
            + 3.0 * c * u_t * u1 * u2
            + 3.0 * s * (u1_t * u2 + u1 * u2_t)
            + s * u_t * u3
            + c * u3_t;

        InverseLinkJet {
            mu: base.d1 * z_t,
            d1: base.d2 * z_t * z1 + base.d1 * z1_t,
            d2: base.d3 * z_t * z1 * z1
                + 2.0 * base.d2 * z1 * z1_t
                + base.d2 * z_t * z2
                + base.d1 * z2_t,
            d3: probit_pdfthird_derivative(z) * z_t * z1.powi(3)
                + 3.0 * base.d3 * z1 * z1 * z1_t
                + 3.0 * base.d3 * z_t * z1 * z2
                + 3.0 * base.d2 * (z1_t * z2 + z1 * z2_t)
                + base.d2 * z_t * z3
                + base.d1 * z3_t,
        }
    };

    // epsilon partials (raw_u_t = -1).
    // epsilon enters r only additively, so the eta-derivatives r1..r3 do not
    // depend on it; the zero terms keep the same generic formula shape as below.
    let rt_eps = -1.0;
    let r1t_eps = 0.0;
    let r2t_eps = 0.0;
    let r3t_eps = 0.0;
    let u_eps = g1 * rt_eps;
    let u1_eps = g2 * rt_eps * r1 + g1 * r1t_eps;
    let u2_eps = g3 * rt_eps * r1 * r1 + 2.0 * g2 * r1 * r1t_eps + g2 * rt_eps * r2 + g1 * r2t_eps;
    let u3_eps = g4 * rt_eps * r1 * r1 * r1
        + 3.0 * g3 * r1 * r1 * r1t_eps
        + 3.0 * g3 * rt_eps * r1 * r2
        + 3.0 * g2 * (r1t_eps * r2 + r1 * r2t_eps)
        + g2 * rt_eps * r3
        + g1 * r3t_eps;
    let djet_depsilon = param_partials(u_eps, u1_eps, u2_eps, u3_eps);

    // raw log-delta partials (through smooth bounded effective log-delta).
    // r = delta * asinh(eta) - epsilon, so d r_k / d raw = ddelta_draw * a_k.
    let rt_ld = ddelta_draw * a;
    let r1t_ld = ddelta_draw * a1;
    let r2t_ld = ddelta_draw * a2;
    let r3t_ld = ddelta_draw * a3;
    let u_ld = g1 * rt_ld;
    let u1_ld = g2 * rt_ld * r1 + g1 * r1t_ld;
    let u2_ld = g3 * rt_ld * r1 * r1 + 2.0 * g2 * r1 * r1t_ld + g2 * rt_ld * r2 + g1 * r2t_ld;
    let u3_ld = g4 * rt_ld * r1 * r1 * r1
        + 3.0 * g3 * r1 * r1 * r1t_ld
        + 3.0 * g3 * rt_ld * r1 * r2
        + 3.0 * g2 * (r1t_ld * r2 + r1 * r2t_ld)
        + g2 * rt_ld * r3
        + g1 * r3t_ld;
    let djet_dlog_delta = param_partials(u_ld, u1_ld, u2_ld, u3_ld);

    SasJetWithParamPartials {
        jet,
        djet_depsilon,
        djet_dlog_delta,
    }
}
2260
2261#[cfg(test)]
2262mod tests {
2263    use super::*;
2264    use gam_problem::{InverseLink, LikelihoodSpec, LinkComponent, MixtureLinkSpec, SasLinkState};
2265
2266    #[test]
2267    fn softmax_jacobian_matchesfd() {
2268        let rho = Array1::from_vec(vec![0.7, -1.2, 0.4]);
2269        let (pi, jac) = softmaxwith_jacobian_last_fixedzero(&rho);
2270        let h = 1e-6;
2271        for j in 0..rho.len() {
2272            let mut rp = rho.clone();
2273            rp[j] += h;
2274            let mut rm = rho.clone();
2275            rm[j] -= h;
2276            let pp = softmax_last_fixedzero(&rp);
2277            let pm = softmax_last_fixedzero(&rm);
2278            let fd = (&pp - &pm).mapv(|v| v / (2.0 * h));
2279            for k in 0..pi.len() {
2280                let err = (jac[[k, j]] - fd[k]).abs();
2281                assert_eq!(
2282                    jac[[k, j]].signum(),
2283                    fd[k].signum(),
2284                    "jac sign mismatch at ({k},{j}): analytic={} fd={}",
2285                    jac[[k, j]],
2286                    fd[k]
2287                );
2288                assert!(err < 5e-6, "jac mismatch at ({k},{j}): err={err:e}");
2289            }
2290        }
2291    }
2292
2293    #[test]
2294    fn mixture_jet_rho_partials_matchfd() {
2295        let spec = MixtureLinkSpec {
2296            components: vec![
2297                LinkComponent::Probit,
2298                LinkComponent::Logit,
2299                LinkComponent::CLogLog,
2300                LinkComponent::Cauchit,
2301            ],
2302            initial_rho: Array1::from_vec(vec![0.3, -0.6, 0.2]),
2303        };
2304        let state = state_fromspec(&spec).expect("state");
2305        let eta = 0.35;
2306        let out = mixture_inverse_link_jetwith_rho_partials(&state, eta);
2307        let h = 1e-6;
2308        for j in 0..state.rho.len() {
2309            let mut rp = state.rho.clone();
2310            rp[j] += h;
2311            let sp = MixtureLinkSpec {
2312                components: state.components.clone(),
2313                initial_rho: rp,
2314            };
2315            let jp = mixture_inverse_link_jet(&state_fromspec(&sp).expect("sp"), eta);
2316            let mut rm = state.rho.clone();
2317            rm[j] -= h;
2318            let sm = MixtureLinkSpec {
2319                components: state.components.clone(),
2320                initial_rho: rm,
2321            };
2322            let jm = mixture_inverse_link_jet(&state_fromspec(&sm).expect("sm"), eta);
2323            let fd = InverseLinkJet {
2324                mu: (jp.mu - jm.mu) / (2.0 * h),
2325                d1: (jp.d1 - jm.d1) / (2.0 * h),
2326                d2: (jp.d2 - jm.d2) / (2.0 * h),
2327                d3: (jp.d3 - jm.d3) / (2.0 * h),
2328            };
2329            let an = out.djet_drho[j];
2330            assert_eq!(an.mu.signum(), fd.mu.signum());
2331            assert_eq!(an.d1.signum(), fd.d1.signum());
2332            assert_eq!(an.d2.signum(), fd.d2.signum());
2333            assert_eq!(an.d3.signum(), fd.d3.signum());
2334            assert!((an.mu - fd.mu).abs() < 1e-6);
2335            assert!((an.d1 - fd.d1).abs() < 1e-6);
2336            assert!((an.d2 - fd.d2).abs() < 1e-6);
2337            assert!((an.d3 - fd.d3).abs() < 1e-6);
2338        }
2339    }
2340
2341    #[test]
2342    fn sas_param_partials_matchfd() {
2343        let eta = 0.37;
2344        let epsilon = -0.12;
2345        let log_delta = 0.21;
2346        let out = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2347        let h = 1e-6;
2348
2349        let ep_p = sas_inverse_link_jet(eta, epsilon + h, log_delta);
2350        let ep_m = sas_inverse_link_jet(eta, epsilon - h, log_delta);
2351        let fd_ep = InverseLinkJet {
2352            mu: (ep_p.mu - ep_m.mu) / (2.0 * h),
2353            d1: (ep_p.d1 - ep_m.d1) / (2.0 * h),
2354            d2: (ep_p.d2 - ep_m.d2) / (2.0 * h),
2355            d3: (ep_p.d3 - ep_m.d3) / (2.0 * h),
2356        };
2357        assert_eq!(out.djet_depsilon.mu.signum(), fd_ep.mu.signum());
2358        assert_eq!(out.djet_depsilon.d1.signum(), fd_ep.d1.signum());
2359        assert_eq!(out.djet_depsilon.d2.signum(), fd_ep.d2.signum());
2360        assert_eq!(out.djet_depsilon.d3.signum(), fd_ep.d3.signum());
2361        assert!((out.djet_depsilon.mu - fd_ep.mu).abs() < 5e-5);
2362        assert!((out.djet_depsilon.d1 - fd_ep.d1).abs() < 5e-5);
2363        assert!((out.djet_depsilon.d2 - fd_ep.d2).abs() < 5e-5);
2364        assert!((out.djet_depsilon.d3 - fd_ep.d3).abs() < 5e-4);
2365
2366        let ld_p = sas_inverse_link_jet(eta, epsilon, log_delta + h);
2367        let ld_m = sas_inverse_link_jet(eta, epsilon, log_delta - h);
2368        let fd_ld = InverseLinkJet {
2369            mu: (ld_p.mu - ld_m.mu) / (2.0 * h),
2370            d1: (ld_p.d1 - ld_m.d1) / (2.0 * h),
2371            d2: (ld_p.d2 - ld_m.d2) / (2.0 * h),
2372            d3: (ld_p.d3 - ld_m.d3) / (2.0 * h),
2373        };
2374        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_ld.mu.signum());
2375        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_ld.d1.signum());
2376        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_ld.d2.signum());
2377        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_ld.d3.signum());
2378        assert!((out.djet_dlog_delta.mu - fd_ld.mu).abs() < 5e-5);
2379        assert!((out.djet_dlog_delta.d1 - fd_ld.d1).abs() < 5e-5);
2380        assert!((out.djet_dlog_delta.d2 - fd_ld.d2).abs() < 5e-5);
2381        assert!((out.djet_dlog_delta.d3 - fd_ld.d3).abs() < 5e-4);
2382    }
2383
2384    #[test]
2385    fn sas_jet_extreme_inputs_stay_finite() {
2386        let cases = [
2387            (-1e6, 0.0, 0.0),
2388            (1e6, 0.0, 0.0),
2389            (3.0, 12.0, 12.0),
2390            (-3.0, -12.0, -12.0),
2391            (0.5, 40.0, 10.0),
2392            (0.5, -40.0, -10.0),
2393        ];
2394        for (eta, eps, log_delta) in cases {
2395            let j = sas_inverse_link_jet(eta, eps, log_delta);
2396            assert!(j.mu.is_finite());
2397            assert!(j.d1.is_finite());
2398            assert!(j.d2.is_finite());
2399            assert!(j.d3.is_finite());
2400            let p = sas_inverse_link_jetwith_param_partials(eta, eps, log_delta);
2401            assert!(p.djet_depsilon.mu.is_finite());
2402            assert!(p.djet_depsilon.d1.is_finite());
2403            assert!(p.djet_depsilon.d2.is_finite());
2404            assert!(p.djet_depsilon.d3.is_finite());
2405            assert!(p.djet_dlog_delta.mu.is_finite());
2406            assert!(p.djet_dlog_delta.d1.is_finite());
2407            assert!(p.djet_dlog_delta.d2.is_finite());
2408            assert!(p.djet_dlog_delta.d3.is_finite());
2409        }
2410    }
2411
2412    #[test]
2413    fn sas_param_partials_remain_finite_in_extreme_region() {
2414        let eta = 10.0;
2415        let epsilon = -60.0;
2416        let log_delta = 40.0;
2417        let j = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2418        assert!(j.djet_depsilon.mu.is_finite());
2419        assert!(j.djet_depsilon.d1.is_finite());
2420        assert!(j.djet_depsilon.d2.is_finite());
2421        assert!(j.djet_depsilon.d3.is_finite());
2422        assert!(j.djet_dlog_delta.mu.is_finite());
2423        assert!(j.djet_dlog_delta.d1.is_finite());
2424        assert!(j.djet_dlog_delta.d2.is_finite());
2425        assert!(j.djet_dlog_delta.d3.is_finite());
2426    }
2427
2428    #[test]
2429    fn sas_eta_jets_matchfd() {
2430        let eta = -0.43;
2431        let epsilon = 0.27;
2432        let log_delta = -0.31;
2433        let h = 1e-5;
2434        let j0 = sas_inverse_link_jet(eta, epsilon, log_delta);
2435        let jp = sas_inverse_link_jet(eta + h, epsilon, log_delta);
2436        let jm = sas_inverse_link_jet(eta - h, epsilon, log_delta);
2437        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2438        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2439        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2440        assert_eq!(j0.d1.signum(), d1fd.signum());
2441        assert_eq!(j0.d2.signum(), d2fd.signum());
2442        assert_eq!(j0.d3.signum(), d3fd.signum());
2443        assert!((j0.d1 - d1fd).abs() < 5e-5);
2444        assert!((j0.d2 - d2fd).abs() < 2e-4);
2445        assert!((j0.d3 - d3fd).abs() < 1e-3);
2446    }
2447
2448    #[test]
2449    fn family_dispatch_resolves_parameterized_links_from_spec() {
2450        // After the LikelihoodSpec migration, the dispatch no longer needs
2451        // out-of-band state arguments — the parameterized link state lives on
2452        // `spec.link`. Verify that supplying explicit SAS/Mixture link states
2453        // through the spec produces finite jets at a representative eta.
2454        let sas_state = sas_link_state_from_raw(0.0, 0.0).expect("sas state");
2455        let sas_spec = gam_problem::LikelihoodSpec {
2456            response: gam_problem::ResponseFamily::Binomial,
2457            link: InverseLink::Sas(sas_state),
2458        };
2459        let sas_jet = inverse_link_jet_for_family(&sas_spec, 0.1).expect("sas jet");
2460        assert!(sas_jet.mu.is_finite());
2461        assert!(sas_jet.d1.is_finite());
2462
2463        let mix_state = MixtureLinkState {
2464            components: vec![LinkComponent::Logit, LinkComponent::Probit],
2465            rho: ndarray::array![0.0],
2466            pi: ndarray::array![0.5, 0.5],
2467        };
2468        let mix_spec = gam_problem::LikelihoodSpec {
2469            response: gam_problem::ResponseFamily::Binomial,
2470            link: InverseLink::Mixture(mix_state),
2471        };
2472        let mix_jet = inverse_link_jet_for_family(&mix_spec, 0.1).expect("mix jet");
2473        assert!(mix_jet.mu.is_finite());
2474        assert!(mix_jet.d1.is_finite());
2475    }
2476
2477    #[test]
2478    fn beta_logistic_reduces_to_logit_at_delta0_epsilon0() {
2479        let etas = [-40.0, -30.0, -5.0, 0.42, 5.0, 30.0, 40.0];
2480        for eta in etas {
2481            let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2482            let expected_mu = gam_linalg::utils::stable_logistic(eta);
2483            let expected_d1 = (-gam_linalg::utils::stable_softplus(-eta)
2484                - gam_linalg::utils::stable_softplus(eta))
2485            .exp();
2486            assert!(
2487                (j_bl.mu - expected_mu).abs() <= 1e-15 * expected_mu.abs().max(1.0),
2488                "mu mismatch at eta={eta}: got {}, expected {}",
2489                j_bl.mu,
2490                expected_mu
2491            );
2492            assert!(
2493                (j_bl.d1 - expected_d1).abs() <= 1e-12 * expected_d1.abs().max(f64::MIN_POSITIVE),
2494                "d1 mismatch at eta={eta}: got {}, expected {}",
2495                j_bl.d1,
2496                expected_d1
2497            );
2498            assert!(j_bl.d1 > 0.0, "d1 should stay positive at eta={eta}");
2499        }
2500
2501        let eta = 0.42;
2502        let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2503        let j_logit = component_inverse_link_jet(LinkComponent::Logit, eta);
2504        assert!((j_bl.d2 - j_logit.d2).abs() < 1e-10);
2505        assert!((j_bl.d3 - j_logit.d3).abs() < 1e-10);
2506    }
2507
2508    #[test]
2509    fn beta_logistic_eta_jets_matchfd() {
2510        let eta = -0.31;
2511        let delta = 0.27;
2512        let epsilon = -0.19;
2513        let h = 1e-5;
2514        let j0 = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2515        let jp = beta_logistic_inverse_link_jet(eta + h, delta, epsilon);
2516        let jm = beta_logistic_inverse_link_jet(eta - h, delta, epsilon);
2517        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2518        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2519        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2520        assert_eq!(j0.d1.signum(), d1fd.signum());
2521        assert_eq!(j0.d2.signum(), d2fd.signum());
2522        assert_eq!(j0.d3.signum(), d3fd.signum());
2523        assert!((j0.d1 - d1fd).abs() < 5e-5);
2524        assert!((j0.d2 - d2fd).abs() < 5e-5);
2525        assert!((j0.d3 - d3fd).abs() < 2e-4);
2526    }
2527
2528    #[test]
2529    fn standard_kernel_structs_match_component_jets() {
2530        let eta = 0.73;
2531        assert_eq!(
2532            ProbitLinkKernel.jet(eta).expect("probit"),
2533            component_inverse_link_jet(LinkComponent::Probit, eta)
2534        );
2535        assert_eq!(
2536            LogitLinkKernel.jet(eta).expect("logit"),
2537            component_inverse_link_jet(LinkComponent::Logit, eta)
2538        );
2539        assert_eq!(
2540            CLogLogLinkKernel.jet(eta).expect("cloglog"),
2541            component_inverse_link_jet(LinkComponent::CLogLog, eta)
2542        );
2543        assert_eq!(
2544            LogLogLinkKernel.jet(eta).expect("loglog"),
2545            component_inverse_link_jet(LinkComponent::LogLog, eta)
2546        );
2547        assert_eq!(
2548            CauchitLinkKernel.jet(eta).expect("cauchit"),
2549            component_inverse_link_jet(LinkComponent::Cauchit, eta)
2550        );
2551    }
2552
2553    #[test]
2554    fn all_component_eta_jets_matchfd() {
2555        let components = [
2556            LinkComponent::Logit,
2557            LinkComponent::Probit,
2558            LinkComponent::CLogLog,
2559            LinkComponent::LogLog,
2560            LinkComponent::Cauchit,
2561        ];
2562        let points = [-3.0, -1.1, -0.2, 0.0, 0.7, 1.8, 3.2];
2563        let h = 1e-5;
2564        for c in components {
2565            for &eta in &points {
2566                let j0 = component_inverse_link_jet(c, eta);
2567                let jp = component_inverse_link_jet(c, eta + h);
2568                let jm = component_inverse_link_jet(c, eta - h);
2569                let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2570                let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2571                let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2572                let d1_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2573                    1.2e-4
2574                } else {
2575                    5e-5
2576                };
2577                let d2_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2578                    4e-4
2579                } else {
2580                    1.2e-4
2581                };
2582                let d3_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2583                    1.2e-3
2584                } else {
2585                    4e-4
2586                };
2587                if j0.d1.abs().max(d1fd.abs()) > 1e-10 {
2588                    assert_eq!(
2589                        j0.d1.signum(),
2590                        d1fd.signum(),
2591                        "d1 sign mismatch for {c:?} eta={eta}"
2592                    );
2593                }
2594                if j0.d2.abs().max(d2fd.abs()) > 1e-10 {
2595                    assert_eq!(
2596                        j0.d2.signum(),
2597                        d2fd.signum(),
2598                        "d2 sign mismatch for {c:?} eta={eta}: analytic={} fd={}",
2599                        j0.d2,
2600                        d2fd
2601                    );
2602                }
2603                if j0.d3.abs().max(d3fd.abs()) > 1e-10 {
2604                    assert_eq!(
2605                        j0.d3.signum(),
2606                        d3fd.signum(),
2607                        "d3 sign mismatch for {c:?} eta={eta}"
2608                    );
2609                }
2610                assert!(
2611                    (j0.d1 - d1fd).abs() < d1_tol,
2612                    "d1 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2613                    j0.d1,
2614                    d1fd
2615                );
2616                assert!(
2617                    (j0.d2 - d2fd).abs() < d2_tol,
2618                    "d2 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2619                    j0.d2,
2620                    d2fd
2621                );
2622                assert!(
2623                    (j0.d3 - d3fd).abs() < d3_tol,
2624                    "d3 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2625                    j0.d3,
2626                    d3fd
2627                );
2628            }
2629        }
2630    }
2631
2632    #[test]
2633    fn sas_center_matches_probit_at_delta1_epsilon0() {
2634        let etas = [-3.0, -1.2, -0.3, 0.0, 0.4, 1.7, 3.0];
2635        for eta in etas {
2636            let sas = sas_inverse_link_jet(eta, 0.0, 0.0);
2637            let probit = ProbitLinkKernel.jet(eta).expect("probit");
2638            // SAS implementation uses a smooth bounded latent (`tanh_bound`) for
2639            // numerical robustness, so the probit center is approximate in practice.
2640            assert!(
2641                (sas.mu - probit.mu).abs() < 6e-4,
2642                "mu mismatch at eta={eta}"
2643            );
2644            assert!(
2645                (sas.d1 - probit.d1).abs() < 6e-4,
2646                "d1 mismatch at eta={eta}"
2647            );
2648            assert!(
2649                (sas.d2 - probit.d2).abs() < 2e-3,
2650                "d2 mismatch at eta={eta}"
2651            );
2652            assert!(
2653                (sas.d3 - probit.d3).abs() < 4e-3,
2654                "d3 mismatch at eta={eta}"
2655            );
2656        }
2657    }
2658
2659    #[test]
2660    fn beta_logistic_param_partials_matchfd() {
2661        let eta = -0.41;
2662        let delta = 0.23;
2663        let epsilon = -0.17;
2664        let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2665        let h = 1e-6;
2666
2667        let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
2668        let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
2669        let fd_delta = InverseLinkJet {
2670            mu: (dp.mu - dm.mu) / (2.0 * h),
2671            d1: (dp.d1 - dm.d1) / (2.0 * h),
2672            d2: (dp.d2 - dm.d2) / (2.0 * h),
2673            d3: (dp.d3 - dm.d3) / (2.0 * h),
2674        };
2675        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_delta.mu.signum());
2676        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_delta.d1.signum());
2677        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_delta.d2.signum());
2678        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_delta.d3.signum());
2679        assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
2680        assert!((out.djet_dlog_delta.d1 - fd_delta.d1).abs() < 5e-5);
2681        assert!((out.djet_dlog_delta.d2 - fd_delta.d2).abs() < 1.2e-4);
2682        assert!((out.djet_dlog_delta.d3 - fd_delta.d3).abs() < 4e-4);
2683
2684        let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
2685        let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
2686        let fd_epsilon = InverseLinkJet {
2687            mu: (ep.mu - em.mu) / (2.0 * h),
2688            d1: (ep.d1 - em.d1) / (2.0 * h),
2689            d2: (ep.d2 - em.d2) / (2.0 * h),
2690            d3: (ep.d3 - em.d3) / (2.0 * h),
2691        };
2692        assert_eq!(out.djet_depsilon.mu.signum(), fd_epsilon.mu.signum());
2693        assert_eq!(out.djet_depsilon.d1.signum(), fd_epsilon.d1.signum());
2694        assert_eq!(out.djet_depsilon.d2.signum(), fd_epsilon.d2.signum());
2695        assert_eq!(out.djet_depsilon.d3.signum(), fd_epsilon.d3.signum());
2696        assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
2697        assert!((out.djet_depsilon.d1 - fd_epsilon.d1).abs() < 5e-5);
2698        assert!((out.djet_depsilon.d2 - fd_epsilon.d2).abs() < 1.2e-4);
2699        assert!((out.djet_depsilon.d3 - fd_epsilon.d3).abs() < 4e-4);
2700    }
2701
2702    #[test]
2703    fn beta_logistic_left_tail_uses_unclamped_log_space() {
2704        let eta = -40.0_f64;
2705        let delta = 0.2_f64;
2706        let epsilon = -0.1_f64;
2707        let a = (delta - epsilon).exp();
2708        let b = (delta + epsilon).exp();
2709        let expected_mu = beta_reg(a, b, eta.exp());
2710        let out = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2711
2712        assert!(
2713            (out.mu - expected_mu).abs() <= 1e-12 * expected_mu.abs().max(f64::MIN_POSITIVE),
2714            "left-tail mu mismatch: got {}, expected {}",
2715            out.mu,
2716            expected_mu
2717        );
2718        assert!(out.d1 > 0.0);
2719        assert!(out.d2 > 0.0);
2720        assert!(out.d3 > 0.0);
2721        assert!(out.d1 < 1e-20);
2722
2723        let partials = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2724        assert!(partials.jet.d1 > 0.0);
2725        assert!(partials.jet.d2 > 0.0);
2726        assert!(partials.jet.d3 > 0.0);
2727        assert!(partials.djet_dlog_delta.d1.is_finite());
2728        assert!(partials.djet_depsilon.d1.is_finite());
2729    }
2730
2731    #[test]
2732    fn beta_logistic_mu_is_symmetric_in_logistic_tails() {
2733        let delta = 0.2;
2734        let epsilon = -0.35;
2735        let etas = [-40.0, -30.0, -5.0, -0.42, 0.0, 0.42, 5.0, 30.0, 40.0];
2736        for eta in etas {
2737            let left = beta_logistic_inverse_link_jet(eta, delta, epsilon).mu;
2738            let right = 1.0 - beta_logistic_inverse_link_jet(-eta, delta, -epsilon).mu;
2739            assert!(
2740                (left - right).abs() <= 1e-14,
2741                "symmetry mismatch at eta={eta}: left={left}, right={right}"
2742            );
2743        }
2744    }
2745
2746    #[test]
2747    fn inverse_link_pdfthird_derivative_matches_d3_finite_difference() {
2748        let sas = InverseLink::Sas(sas_link_state_from_raw(-0.25, 0.35).expect("sas state"));
2749        let beta_logistic = InverseLink::BetaLogistic(SasLinkState {
2750            epsilon: 0.18,
2751            log_delta: -0.22,
2752            delta: (-0.22_f64).exp(),
2753        });
2754        let mixture = InverseLink::Mixture(
2755            state_fromspec(&MixtureLinkSpec {
2756                components: vec![
2757                    LinkComponent::Probit,
2758                    LinkComponent::Logit,
2759                    LinkComponent::CLogLog,
2760                    LinkComponent::Cauchit,
2761                ],
2762                initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
2763            })
2764            .expect("mixture state"),
2765        );
2766        let links = [
2767            InverseLink::Standard(StandardLink::Probit),
2768            InverseLink::Standard(StandardLink::Logit),
2769            InverseLink::Standard(StandardLink::CLogLog),
2770            sas,
2771            beta_logistic,
2772            mixture,
2773        ];
2774        let etas = [-1.1, -0.2, 0.6];
2775        let h = 1e-5;
2776
2777        for link in &links {
2778            for &eta in &etas {
2779                let jp = inverse_link_jet_for_inverse_link(link, eta + h).expect("jet+");
2780                let jm = inverse_link_jet_for_inverse_link(link, eta - h).expect("jet-");
2781                let d4fd = (jp.d3 - jm.d3) / (2.0 * h);
2782                let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)
2783                    .expect("analytic d4");
2784                assert_eq!(
2785                    d4.signum(),
2786                    d4fd.signum(),
2787                    "d4 sign mismatch for {:?} at eta={eta}: analytic={} fd={}",
2788                    link,
2789                    d4,
2790                    d4fd
2791                );
2792                assert!(
2793                    (d4 - d4fd).abs() < 5e-3,
2794                    "d4 mismatch for {:?} at eta={eta}: analytic={} fd={}",
2795                    link,
2796                    d4,
2797                    d4fd
2798                );
2799            }
2800        }
2801    }
2802
2803    #[test]
2804    fn cloglog_large_finite_eta_should_saturate_without_nan_derivatives() {
2805        let eta = 800.0;
2806        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2807        assert_eq!(jet.mu, 1.0);
2808        assert!(
2809            jet.d1 == 0.0,
2810            "for mu(eta)=1-exp(-exp(eta)), dmu/deta = exp(eta-exp(eta)) and should underflow to 0 at eta={eta}; got d1={}",
2811            jet.d1
2812        );
2813        assert!(
2814            jet.d2 == 0.0,
2815            "the saturated cloglog second derivative should also be 0 at eta={eta}; got d2={}",
2816            jet.d2
2817        );
2818        assert!(
2819            jet.d3 == 0.0,
2820            "the saturated cloglog third derivative should also be 0 at eta={eta}; got d3={}",
2821            jet.d3
2822        );
2823
2824        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2825            &InverseLink::Standard(StandardLink::CLogLog),
2826            eta,
2827        )
2828        .expect("cloglog d4");
2829        assert!(
2830            d4 == 0.0,
2831            "the saturated cloglog fourth derivative should also be 0 at eta={eta}; got d4={d4}"
2832        );
2833    }
2834
2835    #[test]
2836    fn loglog_large_negative_finite_eta_should_saturate_without_nan_derivatives() {
2837        let eta = -800.0;
2838        let jet = component_inverse_link_jet(LinkComponent::LogLog, eta);
2839        assert_eq!(jet.mu, 0.0);
2840        assert!(
2841            jet.d1 == 0.0,
2842            "for mu(eta)=exp(-exp(-eta)), dmu/deta = exp(-eta-exp(-eta)) and should underflow to 0 at eta={eta}; got d1={}",
2843            jet.d1
2844        );
2845        assert!(
2846            jet.d2 == 0.0,
2847            "the saturated loglog second derivative should also be 0 at eta={eta}; got d2={}",
2848            jet.d2
2849        );
2850        assert!(
2851            jet.d3 == 0.0,
2852            "the saturated loglog third derivative should also be 0 at eta={eta}; got d3={}",
2853            jet.d3
2854        );
2855
2856        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2857            &InverseLink::Mixture(
2858                state_fromspec(&MixtureLinkSpec {
2859                    components: vec![LinkComponent::LogLog, LinkComponent::Probit],
2860                    initial_rho: Array1::from_vec(vec![12.0]),
2861                })
2862                .expect("mixture state"),
2863            ),
2864            eta,
2865        )
2866        .expect("loglog mixture d4");
2867        assert!(
2868            d4.is_finite(),
2869            "even a nearly pure loglog mixture should not produce NaN fourth derivatives at eta={eta}; got d4={d4}"
2870        );
2871    }
2872
2873    #[test]
2874    fn logit_tail_derivatives_should_match_stable_closed_forms() {
2875        let eta = 50.0_f64;
2876        let z = (-eta).exp();
2877        let denom = 1.0_f64 + z;
2878        let stable_d1 = z / denom.powi(2);
2879        let stable_d2 = z * (z - 1.0) / denom.powi(3);
2880        let stable_d3 = z * (z * z - 4.0 * z + 1.0) / denom.powi(4);
2881        let stable_d4 = z * (z * z * z - 11.0 * z * z + 11.0 * z - 1.0) / denom.powi(5);
2882        let stable_d5 =
2883            z * (z * z * z * z - 26.0 * z * z * z + 66.0 * z * z - 26.0 * z + 1.0) / denom.powi(6);
2884
2885        assert!(stable_d1 > 0.0);
2886        assert!(stable_d2 < 0.0);
2887        assert!(stable_d3 > 0.0);
2888        assert!(stable_d4 < 0.0);
2889        assert!(stable_d5 > 0.0);
2890
2891        let jet = component_inverse_link_jet(LinkComponent::Logit, eta);
2892        assert!(
2893            (jet.d1 - stable_d1).abs() < 1e-30,
2894            "logit d1 should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
2895            jet.d1,
2896            stable_d1
2897        );
2898        assert!(
2899            (jet.d2 - stable_d2).abs() < 1e-30,
2900            "logit d2 should equal the stable tail formula z(z-1)/(1+z)^3 at eta={eta}; got {} vs {}",
2901            jet.d2,
2902            stable_d2
2903        );
2904        assert!(
2905            (jet.d3 - stable_d3).abs() < 1e-30,
2906            "logit d3 should equal the stable tail formula z(z^2-4z+1)/(1+z)^4 at eta={eta}; got {} vs {}",
2907            jet.d3,
2908            stable_d3
2909        );
2910
2911        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2912            &InverseLink::Standard(StandardLink::Logit),
2913            eta,
2914        )
2915        .expect("logit d4");
2916        assert!(
2917            (d4 - stable_d4).abs() < 1e-30,
2918            "logit d4 should equal the stable tail formula z(z^3-11z^2+11z-1)/(1+z)^5 at eta={eta}; got {} vs {}",
2919            d4,
2920            stable_d4
2921        );
2922
2923        let d5 = inverse_link_pdffourth_derivative_for_inverse_link(
2924            &InverseLink::Standard(StandardLink::Logit),
2925            eta,
2926        )
2927        .expect("logit d5");
2928        assert!(
2929            (d5 - stable_d5).abs() < 1e-30,
2930            "logit d5 should equal the stable tail formula z(z^4-26z^3+66z^2-26z+1)/(1+z)^6 at eta={eta}; got {} vs {}",
2931            d5,
2932            stable_d5
2933        );
2934    }
2935
2936    #[test]
2937    fn cloglog_negative_tail_value_should_match_expm1_form() {
2938        let eta = -50.0_f64;
2939        let t = eta.exp();
2940        let stable_mu = -(-t).exp_m1();
2941        assert!(stable_mu > 0.0);
2942
2943        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2944        assert!(
2945            (jet.mu - stable_mu).abs() < 1e-30,
2946            "cloglog mu should equal -expm1(-exp(eta)) in the negative tail at eta={eta}; got {} vs {}",
2947            jet.mu,
2948            stable_mu
2949        );
2950    }
2951
2952    #[test]
2953    fn non_logit_probit_fisher_weight_jets_match_finite_differences() {
2954        fn rel_err(a: f64, b: f64) -> f64 {
2955            (a - b).abs() / a.abs().max(b.abs()).max(1.0e-8)
2956        }
2957
2958        let cases = [
2959            (LinkComponent::CLogLog, [-3.0_f64, -0.5, 0.4, 1.5]),
2960            (LinkComponent::LogLog, [-1.5_f64, -0.4, 0.5, 3.0]),
2961            (LinkComponent::Cauchit, [-3.0_f64, -0.7, 0.6, 3.0]),
2962        ];
2963        for (component, etas) in cases {
2964            for eta in etas {
2965                let (w, w1, w2, w3, w4) = component_fisher_weight_jet5(component, eta);
2966                let jet = component_inverse_link_jet(component, eta);
2967                let expected = jet.d1 * jet.d1 / (jet.mu * (1.0 - jet.mu));
2968                assert!(
2969                    rel_err(w, expected) < 1.0e-12,
2970                    "{component:?} Fisher weight mismatch at eta={eta}: got {w}, expected {expected}"
2971                );
2972
2973                let h = 1.0e-4;
2974                let fd1 = (component_fisher_weight_jet5(component, eta + h).0
2975                    - component_fisher_weight_jet5(component, eta - h).0)
2976                    / (2.0 * h);
2977                let fd2 = (component_fisher_weight_jet5(component, eta + h).1
2978                    - component_fisher_weight_jet5(component, eta - h).1)
2979                    / (2.0 * h);
2980                let fd3 = (component_fisher_weight_jet5(component, eta + h).2
2981                    - component_fisher_weight_jet5(component, eta - h).2)
2982                    / (2.0 * h);
2983                let fd4 = (component_fisher_weight_jet5(component, eta + h).3
2984                    - component_fisher_weight_jet5(component, eta - h).3)
2985                    / (2.0 * h);
2986
2987                assert!(
2988                    rel_err(w1, fd1) < 1.0e-5,
2989                    "{component:?} W' mismatch at eta={eta}: {w1} vs {fd1}"
2990                );
2991                assert!(
2992                    rel_err(w2, fd2) < 1.0e-5,
2993                    "{component:?} W'' mismatch at eta={eta}: {w2} vs {fd2}"
2994                );
2995                assert!(
2996                    rel_err(w3, fd3) < 5.0e-5,
2997                    "{component:?} W''' mismatch at eta={eta}: {w3} vs {fd3}"
2998                );
2999                assert!(
3000                    rel_err(w4, fd4) < 5.0e-4,
3001                    "{component:?} W'''' mismatch at eta={eta}: {w4} vs {fd4}"
3002                );
3003            }
3004        }
3005    }
3006
3007    #[test]
3008    fn mixture_fisher_weight_jet_covers_loglog_and_cauchit_components() {
3009        let state = state_fromspec(&MixtureLinkSpec {
3010            components: vec![
3011                LinkComponent::CLogLog,
3012                LinkComponent::LogLog,
3013                LinkComponent::Cauchit,
3014            ],
3015            initial_rho: Array1::from_vec(vec![0.3, -0.2]),
3016        })
3017        .expect("mixture state");
3018        let link = InverseLink::Mixture(state);
3019        assert!(
3020            inverse_link_has_fisher_weight_jet(&link),
3021            "anchored mixtures with loglog/cauchit components must remain eligible for Firth"
3022        );
3023        assert!(
3024            LikelihoodSpec::new(ResponseFamily::Binomial, link.clone()).supports_firth(),
3025            "Firth support should use the mixture inverse-link Fisher jet, not standalone LinkFunction coverage"
3026        );
3027
3028        for eta in [-2.0_f64, -0.25, 0.75, 2.5] {
3029            let (w, w1, w2, w3, w4) =
3030                fisher_weight_jet5_for_inverse_link(&link, eta).expect("mixture Fisher jet");
3031            for value in [w, w1, w2, w3, w4] {
3032                assert!(
3033                    value.is_finite(),
3034                    "mixture Fisher weight jet should be finite at eta={eta}; got {value}"
3035                );
3036            }
3037            assert!(
3038                w > 0.0,
3039                "mixture Fisher working weight should be positive away from saturated tails at eta={eta}; got {w}"
3040            );
3041        }
3042    }
3043
3044    #[test]
3045    fn loglog_fifth_derivative_should_match_closed_form_sign() {
3046        let eta = 0.0_f64;
3047        let r = (-eta).exp();
3048        let expected =
3049            (-r).exp() * (r - 15.0 * r * r + 25.0 * r.powi(3) - 10.0 * r.powi(4) + r.powi(5));
3050        let d5 = component_inverse_link_pdffourth_derivative(LinkComponent::LogLog, eta);
3051        assert!(
3052            (d5 - expected).abs() < 1e-15,
3053            "loglog d5 should equal exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5) at eta={eta}; got {d5} vs {expected}"
3054        );
3055        assert!(d5 > 0.0, "loglog d5 should be positive at eta=0; got {d5}");
3056    }
3057}