Skip to main content

gam_solve/
mixture_link.rs

1use crate::estimate::EstimationError;
2use gam_math::probability::{normal_cdf, normal_pdf};
3use gam_math::special::stable_polynomial_times_exp_neg as stable_nonnegative_poly_times_exp_neg;
4use crate::quadrature::latent_cloglog_jet5;
5use gam_problem::{
6    InverseLink, LatentCLogLogState, LikelihoodSpec, LinkComponent, LinkFunction, MixtureLinkSpec,
7    MixtureLinkState, ResponseFamily, SasLinkSpec, SasLinkState, StandardLink,
8};
9use ndarray::Array1;
10use statrs::function::beta::{beta_reg, ln_beta};
11use statrs::function::gamma::digamma;
12use std::ops::Neg;
13use std::sync::OnceLock;
14
// NOTE(review): `SAS_U_CLAMP` is not referenced in the visible part of this file;
// presumably it clamps the sinh-arcsinh latent argument `u` — confirm at its use site.
const SAS_U_CLAMP: f64 = 50.0;
/// Bound B used by the bounded sinh-arcsinh log-delta parameterisation:
/// `delta = exp(B * tanh(raw_log_delta / B))`. Exposed for the outer-strategy
/// edge-barrier helpers in `solver/estimate.rs` that previously had to
/// hard-code the same `12.0` with a "must match" comment.
pub(crate) const SAS_LOG_DELTA_BOUND: f64 = 12.0;
21
22#[inline]
23fn latent_cloglog_quadctx() -> &'static crate::quadrature::QuadratureContext {
24    static QUADCTX: OnceLock<crate::quadrature::QuadratureContext> = OnceLock::new();
25    QUADCTX.get_or_init(crate::quadrature::QuadratureContext::new)
26}
27
28#[inline]
29fn latent_cloglog_point_jet(
30    state: &LatentCLogLogState,
31    eta: f64,
32) -> Result<InverseLinkJet, EstimationError> {
33    let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, state.latent_sd)?;
34    Ok(InverseLinkJet {
35        mu: jet.mean,
36        d1: jet.d1,
37        d2: jet.d2,
38        d3: jet.d3,
39    })
40}
41
/// Pointwise inverse-link jet: the value `F(eta)` and its first three
/// derivatives with respect to `eta`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InverseLinkJet {
    /// Inverse-link value `F(eta)`.
    pub mu: f64,
    /// First derivative `F'(eta)`.
    pub d1: f64,
    /// Second derivative `F''(eta)`.
    pub d2: f64,
    /// Third derivative `F'''(eta)`.
    pub d3: f64,
}
49
/// Logistic inverse-link jet carried through fifth order, as produced by
/// `logit_inverse_link_jet5`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LogitJet5 {
    /// Inverse-link value (the logistic mean).
    pub mu: f64,
    /// First eta-derivative of `mu`.
    pub d1: f64,
    /// Second eta-derivative of `mu`.
    pub d2: f64,
    /// Third eta-derivative of `mu`.
    pub d3: f64,
    /// Fourth eta-derivative of `mu`.
    pub d4: f64,
    /// Fifth eta-derivative of `mu`.
    pub d5: f64,
}
59
/// Flushes zeros of either sign and subnormal magnitudes to a canonical `+0.0`.
///
/// Any value whose magnitude is at least `f64::MIN_POSITIVE` is returned
/// unchanged, as are NaN and the infinities.
#[inline]
fn canonicalzero(v: f64) -> f64 {
    let below_normal_range = v.abs() < f64::MIN_POSITIVE;
    if below_normal_range {
        return 0.0;
    }
    v
}
64
65#[inline]
66fn canonicalize_jet(mut jet: InverseLinkJet) -> InverseLinkJet {
67    jet.d1 = canonicalzero(jet.d1);
68    jet.d2 = canonicalzero(jet.d2);
69    jet.d3 = canonicalzero(jet.d3);
70    jet
71}
72
73#[inline]
74pub fn logit_inverse_link_jet5(eta: f64) -> LogitJet5 {
75    if eta.is_nan() {
76        return LogitJet5 {
77            mu: f64::NAN,
78            d1: f64::NAN,
79            d2: f64::NAN,
80            d3: f64::NAN,
81            d4: f64::NAN,
82            d5: f64::NAN,
83        };
84    }
85    if eta == f64::INFINITY {
86        return LogitJet5 {
87            mu: 1.0,
88            d1: 0.0,
89            d2: 0.0,
90            d3: 0.0,
91            d4: 0.0,
92            d5: 0.0,
93        };
94    }
95    if eta == f64::NEG_INFINITY {
96        return LogitJet5 {
97            mu: 0.0,
98            d1: 0.0,
99            d2: 0.0,
100            d3: 0.0,
101            d4: 0.0,
102            d5: 0.0,
103        };
104    }
105
106    let jet = if eta >= 0.0 {
107        let z = (-eta).exp();
108        let opz = 1.0 + z;
109        let opz2 = opz * opz;
110        let opz3 = opz2 * opz;
111        let opz4 = opz3 * opz;
112        let opz5 = opz4 * opz;
113        let opz6 = opz5 * opz;
114        let z2 = z * z;
115        let z3 = z2 * z;
116        let z4 = z3 * z;
117        LogitJet5 {
118            mu: 1.0 / opz,
119            d1: z / opz2,
120            d2: z * (z - 1.0) / opz3,
121            d3: z * (z2 - 4.0 * z + 1.0) / opz4,
122            d4: z * (z3 - 11.0 * z2 + 11.0 * z - 1.0) / opz5,
123            d5: z * (z4 - 26.0 * z3 + 66.0 * z2 - 26.0 * z + 1.0) / opz6,
124        }
125    } else {
126        let z = eta.exp();
127        let opz = 1.0 + z;
128        let opz2 = opz * opz;
129        let opz3 = opz2 * opz;
130        let opz4 = opz3 * opz;
131        let opz5 = opz4 * opz;
132        let opz6 = opz5 * opz;
133        let z2 = z * z;
134        let z3 = z2 * z;
135        let z4 = z3 * z;
136        LogitJet5 {
137            mu: z / opz,
138            d1: z / opz2,
139            d2: z * (1.0 - z) / opz3,
140            d3: z * (1.0 - 4.0 * z + z2) / opz4,
141            d4: z * (1.0 - 11.0 * z + 11.0 * z2 - z3) / opz5,
142            d5: z * (1.0 - 26.0 * z + 66.0 * z2 - 26.0 * z3 + z4) / opz6,
143        }
144    };
145    LogitJet5 {
146        mu: jet.mu,
147        d1: canonicalzero(jet.d1),
148        d2: canonicalzero(jet.d2),
149        d3: canonicalzero(jet.d3),
150        d4: canonicalzero(jet.d4),
151        d5: canonicalzero(jet.d5),
152    }
153}
154
155#[inline]
156fn probit_jet(eta: f64) -> InverseLinkJet {
157    // Exact probit semantics:
158    //
159    //   mu(eta) = Phi(eta),
160    //   mu'     = phi(eta),
161    //   mu''    = -eta * phi(eta),
162    //   mu'''   = (eta^2 - 1) * phi(eta).
163    //
164    // `normal_cdf` now evaluates the exact special-function form
165    // Phi(x) = 0.5 * erfc(-x / sqrt(2)), so the jet can and should use the
166    // matching closed-form Gaussian identities directly.
167    if eta.is_nan() {
168        return InverseLinkJet {
169            mu: f64::NAN,
170            d1: f64::NAN,
171            d2: f64::NAN,
172            d3: f64::NAN,
173        };
174    }
175    if eta == f64::INFINITY {
176        return InverseLinkJet {
177            mu: 1.0,
178            d1: 0.0,
179            d2: 0.0,
180            d3: 0.0,
181        };
182    }
183    if eta == f64::NEG_INFINITY {
184        return InverseLinkJet {
185            mu: 0.0,
186            d1: 0.0,
187            d2: 0.0,
188            d3: 0.0,
189        };
190    }
191    let x = eta;
192    let phi = normal_pdf(x);
193    InverseLinkJet {
194        mu: normal_cdf(x),
195        d1: phi,
196        d2: -x * phi,
197        d3: (x * x - 1.0) * phi,
198    }
199}
200
201#[inline]
202fn probit_pdfthird_derivative(eta: f64) -> f64 {
203    // Since d1 = mu' = phi(eta), this returns
204    //
205    //   d³/deta³ d1 = mu'''' = -(eta³ - 3 eta) phi(eta).
206    if eta.is_nan() {
207        return f64::NAN;
208    }
209    if !eta.is_finite() {
210        return 0.0;
211    }
212    let x = eta;
213    let phi = normal_pdf(x);
214    canonicalzero(-(x * x * x - 3.0 * x) * phi)
215}
216
217#[inline]
218fn probit_pdffourth_derivative(eta: f64) -> f64 {
219    // mu''''' = Phi^{(5)}(eta) = (eta^4 - 6*eta^2 + 3) * phi(eta).
220    if eta.is_nan() {
221        return f64::NAN;
222    }
223    if !eta.is_finite() {
224        return 0.0;
225    }
226    let x = eta;
227    let phi = normal_pdf(x);
228    canonicalzero((x * x * x * x - 6.0 * x * x + 3.0) * phi)
229}
230
/// Product of two truncated Taylor series.
///
/// Both inputs hold the coefficients `g^(k) / k!` for `k = 0..=4`; the result
/// holds the coefficients of the product series up to the same order (terms
/// of order five and above are dropped). A left coefficient that is exactly
/// zero contributes nothing, whatever the matching right coefficient is.
#[inline]
fn taylor5_mul(a: &[f64; 5], b: &[f64; 5]) -> [f64; 5] {
    let mut product = [0.0_f64; 5];
    for (degree, slot) in product.iter_mut().enumerate() {
        // Cauchy convolution for this degree, accumulated in ascending order
        // of the left index.
        for (left_index, &left) in a.iter().enumerate().take(degree + 1) {
            if left != 0.0 {
                *slot += left * b[degree - left_index];
            }
        }
    }
    product
}
247
/// Reciprocal of a truncated Taylor series whose constant term is nonzero.
///
/// Solves `a * b = 1` order by order: `b_0 = 1 / a_0` and
/// `b_k = -(sum_{j=1..=k} a_j * b_{k-j}) / a_0` for `k = 1..=4`.
#[inline]
fn taylor5_inv(a: &[f64; 5]) -> [f64; 5] {
    let lead_inverse = 1.0 / a[0];
    let mut inverse = [lead_inverse, 0.0, 0.0, 0.0, 0.0];
    for order in 1..5 {
        let convolution =
            (1..=order).fold(0.0_f64, |acc, j| acc + a[j] * inverse[order - j]);
        inverse[order] = -convolution * lead_inverse;
    }
    inverse
}
262
263/// 5-jet (value + four eta-derivatives) of the GLM Fisher working weight
264/// `W(eta) = mu'(eta)^2 / V(mu(eta))` for the requested standard link, returned
265/// as `(W, W', W'', W''', W'''')`.
266///
267/// For the canonical logit link this is exactly the binomial weight
268/// `W = mu(1 - mu) = mu'`, whose eta-derivatives are the higher derivatives of
269/// the inverse-link jet (`W^(k) = mu^(k+1)`); the dispatch returns
270/// `logit_inverse_link_jet5`'s `d1..d5` byte-for-byte so the existing Firth
271/// logit path is numerically unchanged.
272///
273/// Noncanonical Bernoulli links use the same truncated Taylor-series quotient:
274/// assemble the inverse-link jet through `mu^(5)`, square the `mu'` series, and
275/// divide by the Bernoulli variance series `mu(1-mu)`. As the variance
276/// denominator saturates to zero in either tail, the weight and all derivatives
277/// saturate to zero, matching the inverse-link jet convention.
278pub(crate) fn fisher_weight_jet5(link: StandardLink, eta: f64) -> (f64, f64, f64, f64, f64) {
279    match link {
280        StandardLink::Logit => {
281            let jet = logit_inverse_link_jet5(eta);
282            (jet.d1, jet.d2, jet.d3, jet.d4, jet.d5)
283        }
284        StandardLink::Probit => probit_fisher_weight_jet5(eta),
285        StandardLink::CLogLog => component_fisher_weight_jet5(LinkComponent::CLogLog, eta),
286        StandardLink::Identity | StandardLink::Log => (0.0, 0.0, 0.0, 0.0, 0.0),
287    }
288}
289
290pub(crate) fn fisher_weight_jet5_for_inverse_link(
291    link: &InverseLink,
292    eta: f64,
293) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
294    match link {
295        InverseLink::Standard(link) => Ok(fisher_weight_jet5(*link, eta)),
296        InverseLink::LatentCLogLog(_)
297        | InverseLink::Sas(_)
298        | InverseLink::BetaLogistic(_)
299        | InverseLink::Mixture(_) => {
300            let jet = link.jet(eta)?;
301            let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)?;
302            let d5 = inverse_link_pdffourth_derivative_for_inverse_link(link, eta)?;
303            Ok(fisher_weight_jet5_from_inverse_link_derivatives(
304                jet.mu, jet.d1, jet.d2, jet.d3, d4, d5,
305            ))
306        }
307    }
308}
309
310#[inline]
311pub(crate) fn inverse_link_has_fisher_weight_jet(link: &InverseLink) -> bool {
312    matches!(
313        link,
314        InverseLink::Standard(StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,)
315            | InverseLink::LatentCLogLog(_)
316            | InverseLink::Sas(_)
317            | InverseLink::BetaLogistic(_)
318            | InverseLink::Mixture(_)
319    )
320}
321
322#[inline]
323fn component_fisher_weight_jet5(component: LinkComponent, eta: f64) -> (f64, f64, f64, f64, f64) {
324    let jet = component_inverse_link_jet(component, eta);
325    let d4 = component_inverse_link_pdfthird_derivative(component, eta);
326    let d5 = component_inverse_link_pdffourth_derivative(component, eta);
327    fisher_weight_jet5_from_inverse_link_derivatives(jet.mu, jet.d1, jet.d2, jet.d3, d4, d5)
328}
329
330#[inline]
331fn fisher_weight_jet5_from_inverse_link_derivatives(
332    mu: f64,
333    d1: f64,
334    d2: f64,
335    d3: f64,
336    d4: f64,
337    d5: f64,
338) -> (f64, f64, f64, f64, f64) {
339    if [mu, d1, d2, d3, d4, d5].iter().any(|v| v.is_nan()) {
340        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
341    }
342    let variance = mu * (1.0 - mu);
343    if !(variance > 0.0) || !variance.is_finite() {
344        return (0.0, 0.0, 0.0, 0.0, 0.0);
345    }
346
347    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
348    let mu_d = [mu, d1, d2, d3, d4];
349    let one_minus_mu_d = [1.0 - mu, -d1, -d2, -d3, -d4];
350    let dmu_d = [d1, d2, d3, d4, d5];
351    let mut mu_t = [0.0_f64; 5];
352    let mut one_minus_mu_t = [0.0_f64; 5];
353    let mut dmu_t = [0.0_f64; 5];
354    for k in 0..5 {
355        let inv_fact = 1.0 / factorial[k];
356        mu_t[k] = mu_d[k] * inv_fact;
357        one_minus_mu_t[k] = one_minus_mu_d[k] * inv_fact;
358        dmu_t[k] = dmu_d[k] * inv_fact;
359    }
360    let num_t = taylor5_mul(&dmu_t, &dmu_t);
361    let den_t = taylor5_mul(&mu_t, &one_minus_mu_t);
362    if !(den_t[0] > 0.0) || !den_t[0].is_finite() {
363        return (0.0, 0.0, 0.0, 0.0, 0.0);
364    }
365    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
366    (
367        canonicalzero(w_t[0] * factorial[0]),
368        canonicalzero(w_t[1] * factorial[1]),
369        canonicalzero(w_t[2] * factorial[2]),
370        canonicalzero(w_t[3] * factorial[3]),
371        canonicalzero(w_t[4] * factorial[4]),
372    )
373}
374
375/// Probit Bernoulli Fisher-weight 5-jet `W = phi^2 / (Phi (1 - Phi))` and its
376/// first four eta-derivatives. See [`fisher_weight_jet5`].
377#[inline]
378fn probit_fisher_weight_jet5(eta: f64) -> (f64, f64, f64, f64, f64) {
379    if eta.is_nan() {
380        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
381    }
382    if !eta.is_finite() {
383        return (0.0, 0.0, 0.0, 0.0, 0.0);
384    }
385    let x = eta;
386    let p = normal_cdf(x);
387    // Compute the complement directly via Phi(-x) rather than `1 - Phi(x)`:
388    // in the positive tail `Phi(x)` rounds to 1.0 and `1 - Phi(x)` cancels to
389    // zero, whereas `Phi(-x)` retains the accurate (tiny) tail mass.
390    let q = normal_cdf(-x);
391    let phi = normal_pdf(x);
392    // Saturated tail: the denominator Phi(1-Phi) has underflowed to zero (or
393    // would divide by zero); the working weight and all derivatives go to zero.
394    if !(p > 0.0) || !(q > 0.0) || p * q <= 0.0 {
395        return (0.0, 0.0, 0.0, 0.0, 0.0);
396    }
397    // Gaussian derivative ladder: phi^(k) for k = 0..=4 using phi' = -x phi.
398    let phi1 = -x * phi;
399    let phi2 = (x * x - 1.0) * phi;
400    let phi3 = -(x * x * x - 3.0 * x) * phi;
401    let phi4 = (x * x * x * x - 6.0 * x * x + 3.0) * phi;
402    // Derivative arrays (d^k/deta^k) for f = phi, p = Phi, q = 1 - Phi.
403    // p^(0) = Phi, p^(k>=1) = phi^(k-1); q is the negated complement.
404    let f_d = [phi, phi1, phi2, phi3, phi4];
405    let p_d = [p, phi, phi1, phi2, phi3];
406    let q_d = [q, -phi, -phi1, -phi2, -phi3];
407    // Convert derivative arrays to Taylor coefficients a_k = g^(k)/k!.
408    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
409    let mut f_t = [0.0_f64; 5];
410    let mut p_t = [0.0_f64; 5];
411    let mut q_t = [0.0_f64; 5];
412    for k in 0..5 {
413        let inv_fact = 1.0 / factorial[k];
414        f_t[k] = f_d[k] * inv_fact;
415        p_t[k] = p_d[k] * inv_fact;
416        q_t[k] = q_d[k] * inv_fact;
417    }
418    let num_t = taylor5_mul(&f_t, &f_t);
419    let den_t = taylor5_mul(&p_t, &q_t);
420    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
421    // Back to derivatives W^(k) = w_t[k] * k!.
422    (
423        canonicalzero(w_t[0] * factorial[0]),
424        canonicalzero(w_t[1] * factorial[1]),
425        canonicalzero(w_t[2] * factorial[2]),
426        canonicalzero(w_t[3] * factorial[3]),
427        canonicalzero(w_t[4] * factorial[4]),
428    )
429}
430
431#[inline]
432fn chain_inverse_link_jet(base: InverseLinkJet, z1: f64, z2: f64, z3: f64) -> InverseLinkJet {
433    InverseLinkJet {
434        mu: base.mu,
435        d1: base.d1 * z1,
436        d2: base.d2 * z1 * z1 + base.d1 * z2,
437        d3: base.d3 * z1 * z1 * z1 + 3.0 * base.d2 * z1 * z2 + base.d1 * z3,
438    }
439}
440
441#[inline]
442fn component_inverse_link_pdfthird_derivative(component: LinkComponent, eta: f64) -> f64 {
443    match component {
444        LinkComponent::Probit => probit_pdfthird_derivative(eta),
445        LinkComponent::Logit => logit_inverse_link_jet5(eta).d4,
446        LinkComponent::CLogLog => {
447            // CLogLog link:
448            //   mu = 1 - exp(-t),  t = exp(eta),  d1 = t exp(-t).
449            //
450            // Repeated differentiation closes in the basis `d1 * poly(t)`:
451            //   d2 = d1(-t + 1)
452            //   d3 = d1(t² - 3t + 1)
453            //   d4 = d1(-t³ + 6t² - 7t + 1).
454            if eta.is_nan() {
455                return f64::NAN;
456            }
457            if !eta.is_finite() {
458                return 0.0;
459            }
460            let t = eta.exp();
461            canonicalzero(stable_nonnegative_poly_times_exp_neg(
462                t,
463                &[0.0, 1.0, -7.0, 6.0, -1.0],
464            ))
465        }
466        LinkComponent::LogLog => {
467            // LogLog link is the reflected cloglog family with `r = exp(-eta)`:
468            //   mu = exp(-r), d1 = mu r,
469            // and again higher derivatives are `d1 * poly(r)`:
470            //   d2 = d1(r - 1)
471            //   d3 = d1(r² - 3r + 1)
472            //   d4 = d1(r³ - 6r² + 7r - 1).
473            if eta.is_nan() {
474                return f64::NAN;
475            }
476            if !eta.is_finite() {
477                return 0.0;
478            }
479            let r = (-eta).exp();
480            canonicalzero(stable_nonnegative_poly_times_exp_neg(
481                r,
482                &[0.0, -1.0, 7.0, -6.0, 1.0],
483            ))
484        }
485        LinkComponent::Cauchit => {
486            // Cauchit link:
487            //   mu = 1/2 + atan(eta)/pi,
488            //   d1 = 1 / [pi (1+eta²)].
489            //
490            // Differentiating three more times gives
491            //
492            //   d4 = 24 eta (1-eta²) / [pi (1+eta²)^4].
493            if eta.is_nan() {
494                return f64::NAN;
495            }
496            if !eta.is_finite() {
497                return 0.0;
498            }
499            let denom = 1.0 + eta * eta;
500            24.0 * eta * (1.0 - eta * eta) / (std::f64::consts::PI * denom.powi(4))
501        }
502    }
503}
504
505/// Fifth derivative of a component inverse-link CDF (= fourth derivative of PDF).
506/// Extends `component_inverse_link_pdfthird_derivative` by one derivative order.
507#[inline]
508fn component_inverse_link_pdffourth_derivative(component: LinkComponent, eta: f64) -> f64 {
509    match component {
510        LinkComponent::Probit => probit_pdffourth_derivative(eta),
511        LinkComponent::Logit => logit_inverse_link_jet5(eta).d5,
512        LinkComponent::CLogLog => {
513            // Exact closed form:
514            //   d5 = exp(-t) * (t - 15t^2 + 25t^3 - 10t^4 + t^5)
515            //      = d1 * (1 - 15t + 25t^2 - 10t^3 + t^4),
516            // where t = exp(eta).
517            if eta.is_nan() {
518                return f64::NAN;
519            }
520            if !eta.is_finite() {
521                return 0.0;
522            }
523            let t = eta.exp();
524            canonicalzero(stable_nonnegative_poly_times_exp_neg(
525                t,
526                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
527            ))
528        }
529        LinkComponent::LogLog => {
530            // Exact closed form:
531            //   d5 = exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5)
532            //      = d1 * (1 - 15r + 25r^2 - 10r^3 + r^4),
533            // where r = exp(-eta).
534            if eta.is_nan() {
535                return f64::NAN;
536            }
537            if !eta.is_finite() {
538                return 0.0;
539            }
540            let r = (-eta).exp();
541            canonicalzero(stable_nonnegative_poly_times_exp_neg(
542                r,
543                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
544            ))
545        }
546        LinkComponent::Cauchit => {
547            // d5 = 24(1 - 10eta^2 + 5eta^4) / [pi * (1+eta^2)^5]
548            if eta.is_nan() {
549                return f64::NAN;
550            }
551            if !eta.is_finite() {
552                return 0.0;
553            }
554            let e2 = eta * eta;
555            let denom = 1.0 + e2;
556            24.0 * (1.0 - 10.0 * e2 + 5.0 * e2 * e2) / (std::f64::consts::PI * denom.powi(5))
557        }
558    }
559}
560
/// Inverse-link jet bundled with its partial derivatives with respect to the
/// free mixture logits.
#[derive(Clone, Debug, PartialEq)]
pub struct MixtureJetWithRhoPartials {
    /// The inverse-link jet (`mu`, `d1`, `d2`, `d3`) at the evaluation point.
    pub jet: InverseLinkJet,
    /// Partial derivatives wrt free logits rho_j, j in [0, K-2].
    /// Each entry stores derivatives of (mu, d1, d2, d3) wrt one rho_j.
    pub djet_drho: Vec<InverseLinkJet>,
}
568
/// Inverse-link jet bundled with its partial derivatives with respect to the
/// two SAS link parameters.
#[derive(Clone, Debug, PartialEq)]
pub struct SasJetWithParamPartials {
    /// The inverse-link jet (`mu`, `d1`, `d2`, `d3`) at the evaluation point.
    pub jet: InverseLinkJet,
    /// Derivatives of the jet entries with respect to `epsilon`.
    pub djet_depsilon: InverseLinkJet,
    /// Derivatives of the jet entries with respect to the log-delta parameter.
    // NOTE(review): presumably taken with respect to the raw (unbounded)
    // log-delta rather than the tanh-bounded one — confirm where this is filled in.
    pub djet_dlog_delta: InverseLinkJet,
}
575
/// Link-parameter partials, as returned (inside `Option`) by
/// `InverseLinkKernel::param_partials`; one variant per parameterised family.
#[derive(Clone, Debug, PartialEq)]
pub enum LinkParamPartials {
    /// Partials with respect to the free mixture logits.
    Mixture(MixtureJetWithRhoPartials),
    /// Partials with respect to the SAS `epsilon` and log-delta parameters.
    Sas(SasJetWithParamPartials),
}
581
/// Trait-based inverse-link kernel interface.
///
/// Implementors provide pointwise inverse-link derivatives wrt `eta`:
/// `F(eta), F'(eta), F''(eta), F'''(eta)`.
/// Optionally they may expose parameter partials used by outer-loop optimization.
pub trait InverseLinkKernel {
    /// Evaluates the inverse link and its first three `eta`-derivatives at `eta`.
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;

    /// Partial derivatives of the jet with respect to the link's own parameters.
    ///
    /// The default implementation reports that the kernel has no parameter
    /// partials by returning `Ok(None)`.
    ///
    /// # Panics
    ///
    /// The default implementation panics if `eta` is not finite.
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        assert!(eta.is_finite(), "eta must be finite");
        Ok(None)
    }
}
595
/// Stateless kernel for the probit link; its `InverseLinkKernel::jet` forwards
/// to `component_inverse_link_jet` with `LinkComponent::Probit`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ProbitLinkKernel;

/// Stateless kernel for the logit link; its `InverseLinkKernel::jet` forwards
/// to `component_inverse_link_jet` with `LinkComponent::Logit`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogitLinkKernel;

/// Stateless kernel for the complementary log-log link; its
/// `InverseLinkKernel::jet` forwards to `component_inverse_link_jet` with
/// `LinkComponent::CLogLog`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CLogLogLinkKernel;

/// Stateless kernel for the log-log link; its `InverseLinkKernel::jet` forwards
/// to `component_inverse_link_jet` with `LinkComponent::LogLog`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogLogLinkKernel;

/// Stateless kernel for the cauchit link; its `InverseLinkKernel::jet` forwards
/// to `component_inverse_link_jet` with `LinkComponent::Cauchit`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CauchitLinkKernel;
610
611/// Construct SAS state from raw optimizer parameters using the same bounded
612/// transform used everywhere in fitting/evaluation.
613///
614/// A free function rather than an inherent `SasLinkState::new` because the
615/// bounded `delta` transform is solver-side math, so the constructor is hosted
616/// here next to the transform rather than on the type. `SasLinkState`'s fields
617/// are `pub`, so it builds directly.
618pub fn sas_link_state_from_raw(
619    raw_epsilon: f64,
620    raw_log_delta: f64,
621) -> Result<SasLinkState, String> {
622    if !raw_epsilon.is_finite() || !raw_log_delta.is_finite() {
623        return Err("SAS link parameters must be finite".to_string());
624    }
625    Ok(SasLinkState {
626        epsilon: raw_epsilon,
627        log_delta: raw_log_delta,
628        delta: sas_delta_from_raw_log_delta(raw_log_delta),
629    })
630}
631
632pub fn state_from_sasspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
633    sas_link_state_from_raw(spec.initial_epsilon, spec.initial_log_delta)
634}
635
636pub fn state_from_beta_logisticspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
637    if !spec.initial_epsilon.is_finite() || !spec.initial_log_delta.is_finite() {
638        return Err("Beta-Logistic link parameters must be finite".to_string());
639    }
640    // For Beta-Logistic, `log_delta` is the unconstrained log geometric-mean beta
641    // shape (the kernels' `log_shape_center`). Evaluation consumes `log_delta`,
642    // never `delta`, but keep the shared `SasLinkState::delta` field on the same
643    // bounded SAS parameterization used by `state_from_sasspec` so constructing a
644    // state from a large finite raw log-delta cannot overflow this derived field.
645    let log_shape_center = spec.initial_log_delta;
646    Ok(SasLinkState {
647        epsilon: spec.initial_epsilon,
648        log_delta: log_shape_center,
649        delta: sas_delta_from_raw_log_delta(log_shape_center),
650    })
651}
652
/// Smoothly squashes `value` into `(-B, B)` via `B * tanh(value / B)`.
///
/// The bound is floored at `f64::EPSILON`, so a zero, negative or NaN `bound`
/// behaves as `B = f64::EPSILON`.
#[inline]
fn tanh_bound(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = f64::tanh(value / scale);
    scale * squashed
}
658
/// First derivative of `B * tanh(x / B)` with respect to `x`:
/// `1 - t^2`, where `t = tanh(x / B)`.
///
/// The bound is floored at `f64::EPSILON`.
#[inline]
fn tanh_bound_d1(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = f64::tanh(value / scale);
    1.0 - squashed * squashed
}
665
/// Second derivative of `B * tanh(x / B)` with respect to `x`:
/// `-2 t s / B`, where `t = tanh(x / B)` and `s = 1 - t^2`.
///
/// The bound is floored at `f64::EPSILON`.
#[inline]
fn tanh_bound_d2(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = f64::tanh(value / scale);
    let sech_sq = 1.0 - squashed * squashed;
    (-2.0 * squashed * sech_sq) / scale
}
673
/// Third derivative of `B * tanh(x / B)` with respect to `x`:
/// `-2 s (1 - 3 t^2) / B^2`, where `t = tanh(x / B)` and `s = 1 - t^2`.
///
/// The bound is floored at `f64::EPSILON`.
#[inline]
fn tanh_bound_d3(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = f64::tanh(value / scale);
    let sech_sq = 1.0 - squashed * squashed;
    let shape = 1.0 - 3.0 * squashed * squashed;
    (-2.0 * sech_sq * shape) / (scale * scale)
}
681
/// Fourth derivative of `B * tanh(x / B)` with respect to `x`:
/// `8 t s (2 - 3 t^2) / B^3`, where `t = tanh(x / B)` and `s = 1 - t^2`.
///
/// The bound is floored at `f64::EPSILON`.
#[inline]
fn tanh_bound_d4(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = f64::tanh(value / scale);
    let sech_sq = 1.0 - squashed * squashed;
    let shape = 2.0 - 3.0 * squashed * squashed;
    (8.0 * squashed * sech_sq * shape) / (scale * scale * scale)
}
689
/// Fifth derivative of `B * tanh(x / B)` with respect to `x`:
///
///   g5 = 8 * s * (2 - 15*t^2 + 15*t^4) / B^4,
///
/// where `t = tanh(x / B)` and `s = 1 - t^2`. The bound is floored at
/// `f64::EPSILON`.
#[inline]
fn tanh_bound_d5(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = f64::tanh(value / scale);
    let sech_sq = 1.0 - squashed * squashed;
    let squashed_sq = squashed * squashed;
    let scale_pow4 = scale * scale * scale * scale;
    let shape = 2.0 - 15.0 * squashed_sq + 15.0 * squashed_sq * squashed_sq;
    (8.0 * sech_sq * shape) / scale_pow4
}
702
703#[inline]
704fn sas_effective_log_delta(raw_log_delta: f64) -> (f64, f64) {
705    let ld_eff = tanh_bound(raw_log_delta, SAS_LOG_DELTA_BOUND);
706    let dld_eff_draw = tanh_bound_d1(raw_log_delta, SAS_LOG_DELTA_BOUND);
707    (ld_eff, dld_eff_draw)
708}
709
710#[inline]
711fn sas_delta_from_raw_log_delta(raw_log_delta: f64) -> f64 {
712    let (ld_eff, _) = sas_effective_log_delta(raw_log_delta);
713    ld_eff.exp()
714}
715
716pub fn validate_mixturespec(spec: &MixtureLinkSpec) -> Result<(), String> {
717    if spec.components.is_empty() {
718        return Err("mixture link requires at least 1 component".to_string());
719    }
720    if spec.initial_rho.len() + 1 != spec.components.len() {
721        return Err(format!(
722            "mixture link rho length mismatch: expected {}, got {}",
723            spec.components.len() - 1,
724            spec.initial_rho.len()
725        ));
726    }
727    for i in 0..spec.components.len() {
728        for j in (i + 1)..spec.components.len() {
729            if spec.components[i] == spec.components[j] {
730                return Err("mixture link components must be unique".to_string());
731            }
732        }
733    }
734    // `LinkComponent` admits two variants (Cauchit, LogLog) that have no matching
735    // `LinkFunction` entry. The mixture-link pipeline projects every mixture back onto
736    // a single `LinkFunction` value for downstream solver/IO bookkeeping (see
737    // `InverseLink::link_function`), so a mixture composed solely of components without
738    // a LinkFunction representative would silently lie about its projected link. We
739    // require every spec to contain at least one Logit/Probit/CLogLog "anchor" so the
740    // projection is meaningful. A mixture of only {Cauchit, LogLog} (or any single
741    // component restricted to that pair) has no anchor and is rejected here.
742    let has_anchor = spec.components.iter().any(|component| {
743        matches!(
744            component,
745            LinkComponent::Logit | LinkComponent::Probit | LinkComponent::CLogLog
746        )
747    });
748    if !has_anchor {
749        let unsupported: Vec<&str> = spec
750            .components
751            .iter()
752            .map(|component| component.name())
753            .collect();
754        return Err(format!(
755            "mixture link components {{{}}} are unsupported: at least one component \
756             must map to a LinkFunction variant (logit/probit/cloglog) so the mixture's \
757             projected LinkFunction is well defined; cauchit and loglog have no \
758             LinkFunction representative",
759            unsupported.join(", ")
760        ));
761    }
762    Ok(())
763}
764
765pub fn softmax_last_fixedzero(rho: &Array1<f64>) -> Array1<f64> {
766    let k = rho.len() + 1;
767    let mut logits = Vec::with_capacity(k);
768    let mut maxv = 0.0_f64;
769    for &v in rho {
770        maxv = maxv.max(v);
771        logits.push(v);
772    }
773    maxv = maxv.max(0.0);
774    logits.push(0.0);
775
776    let mut sum = 0.0_f64;
777    let mut exps = vec![0.0_f64; k];
778    for i in 0..k {
779        let e = (logits[i] - maxv).exp();
780        exps[i] = e;
781        sum += e;
782    }
783    if !sum.is_finite() || sum <= 0.0 {
784        return Array1::from_elem(k, 1.0 / k as f64);
785    }
786    let inv = 1.0 / sum;
787    Array1::from_iter(exps.into_iter().map(|v| v * inv))
788}
789
790/// Returns softmax weights and Jacobian wrt free logits (last logit fixed at zero).
791/// Jacobian shape is (K, K-1): d pi_k / d rho_j.
792pub fn softmaxwith_jacobian_last_fixedzero(
793    rho: &Array1<f64>,
794) -> (Array1<f64>, ndarray::Array2<f64>) {
795    let pi = softmax_last_fixedzero(rho);
796    let k = pi.len();
797    let m = k.saturating_sub(1);
798    let mut jac = ndarray::Array2::<f64>::zeros((k, m));
799    for j in 0..m {
800        let pi_j = pi[j];
801        for kk in 0..k {
802            let delta = if kk == j { 1.0 } else { 0.0 };
803            jac[[kk, j]] = pi[kk] * (delta - pi_j);
804        }
805    }
806    (pi, jac)
807}
808
809pub fn state_fromspec(spec: &MixtureLinkSpec) -> Result<MixtureLinkState, String> {
810    validate_mixturespec(spec)?;
811    let pi = softmax_last_fixedzero(&spec.initial_rho);
812    Ok(MixtureLinkState {
813        components: spec.components.clone(),
814        rho: spec.initial_rho.clone(),
815        pi,
816    })
817}
818
819#[inline]
820pub fn component_inverse_link_jet(component: LinkComponent, eta: f64) -> InverseLinkJet {
821    canonicalize_jet(match component {
822        LinkComponent::Logit => {
823            let jet = logit_inverse_link_jet5(eta);
824            InverseLinkJet {
825                mu: jet.mu,
826                d1: jet.d1,
827                d2: jet.d2,
828                d3: jet.d3,
829            }
830        }
831        LinkComponent::Probit => probit_jet(eta),
832        LinkComponent::CLogLog => {
833            if eta.is_nan() {
834                return InverseLinkJet {
835                    mu: f64::NAN,
836                    d1: f64::NAN,
837                    d2: f64::NAN,
838                    d3: f64::NAN,
839                };
840            }
841            let t = eta.exp();
842            if !t.is_finite() {
843                return InverseLinkJet {
844                    mu: 1.0,
845                    d1: 0.0,
846                    d2: 0.0,
847                    d3: 0.0,
848                };
849            }
850            InverseLinkJet {
851                mu: -(-t).exp_m1(),
852                d1: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0]),
853                d2: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -1.0]),
854                d3: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -3.0, 1.0]),
855            }
856        }
857        LinkComponent::LogLog => {
858            if eta.is_nan() {
859                return InverseLinkJet {
860                    mu: f64::NAN,
861                    d1: f64::NAN,
862                    d2: f64::NAN,
863                    d3: f64::NAN,
864                };
865            }
866            let r = (-eta).exp();
867            if !r.is_finite() {
868                return InverseLinkJet {
869                    mu: 0.0,
870                    d1: 0.0,
871                    d2: 0.0,
872                    d3: 0.0,
873                };
874            }
875            InverseLinkJet {
876                mu: (-r).exp(),
877                d1: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0]),
878                d2: stable_nonnegative_poly_times_exp_neg(r, &[0.0, -1.0, 1.0]),
879                d3: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0, -3.0, 1.0]),
880            }
881        }
882        LinkComponent::Cauchit => {
883            if eta.is_nan() {
884                return InverseLinkJet {
885                    mu: f64::NAN,
886                    d1: f64::NAN,
887                    d2: f64::NAN,
888                    d3: f64::NAN,
889                };
890            }
891            let den = 1.0 + eta * eta;
892            let d1 = if eta.is_finite() {
893                1.0 / (std::f64::consts::PI * den)
894            } else {
895                0.0
896            };
897            let d2 = if eta.is_finite() {
898                -2.0 * eta / (std::f64::consts::PI * den * den)
899            } else {
900                0.0
901            };
902            let d3 = if eta.is_finite() {
903                (6.0 * eta * eta - 2.0) / (std::f64::consts::PI * den * den * den)
904            } else {
905                0.0
906            };
907            InverseLinkJet {
908                mu: 0.5 + eta.atan() / std::f64::consts::PI,
909                d1,
910                d2,
911                d3,
912            }
913        }
914    })
915}
916
917impl InverseLinkKernel for ProbitLinkKernel {
918    #[inline]
919    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
920        Ok(component_inverse_link_jet(LinkComponent::Probit, eta))
921    }
922}
923
924impl InverseLinkKernel for LogitLinkKernel {
925    #[inline]
926    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
927        Ok(component_inverse_link_jet(LinkComponent::Logit, eta))
928    }
929}
930
931impl InverseLinkKernel for CLogLogLinkKernel {
932    #[inline]
933    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
934        Ok(component_inverse_link_jet(LinkComponent::CLogLog, eta))
935    }
936}
937
938impl InverseLinkKernel for LogLogLinkKernel {
939    #[inline]
940    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
941        Ok(component_inverse_link_jet(LinkComponent::LogLog, eta))
942    }
943}
944
945impl InverseLinkKernel for CauchitLinkKernel {
946    #[inline]
947    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
948        Ok(component_inverse_link_jet(LinkComponent::Cauchit, eta))
949    }
950}
951
952impl InverseLinkKernel for LinkComponent {
953    #[inline]
954    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
955        Ok(component_inverse_link_jet(*self, eta))
956    }
957}
958
959impl InverseLinkKernel for LinkFunction {
960    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
961        match self {
962            LinkFunction::Logit => LogitLinkKernel.jet(eta),
963            LinkFunction::Probit => ProbitLinkKernel.jet(eta),
964            LinkFunction::CLogLog => CLogLogLinkKernel.jet(eta),
965            LinkFunction::Identity => Ok(InverseLinkJet {
966                mu: eta,
967                d1: 1.0,
968                d2: 0.0,
969                d3: 0.0,
970            }),
971            LinkFunction::Log => {
972                // SOLVER-INTERNAL inverse-link jet: `η.clamp(−700, 700).exp()`.
973                // The clamp is an intentional conditioning hack so the IRLS/REML
974                // normal equations stay well posed when η wanders into the tails
975                // during a trust-region step — it is NOT the public response
976                // transform. Public response-scale outputs (predictions, FFI
977                // `apply_inverse_link_array`, posterior bands) must use the EXACT
978                // `exp(η)` in `families::inverse_link::apply_inverse_link_vec`,
979                // which is finite wherever representable. Do not reroute a public
980                // output through this clamped jet (issue #963). Keep the clamp:
981                // solver consumers (e.g. `reml/runtime.rs` trust-region `excess`)
982                // pass raw η and rely on it to keep μ finite.
983                let e = eta.clamp(-700.0, 700.0).exp();
984                Ok(InverseLinkJet {
985                    mu: e,
986                    d1: e,
987                    d2: e,
988                    d3: e,
989                })
990            }
991            LinkFunction::Sas => Err(EstimationError::InvalidInput(
992                "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
993            )),
994            LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
995                "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
996                    .to_string(),
997            )),
998        }
999    }
1000}
1001
1002impl InverseLinkKernel for SasLinkState {
1003    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1004        Ok(sas_inverse_link_jet(eta, self.epsilon, self.log_delta))
1005    }
1006
1007    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1008        Ok(Some(LinkParamPartials::Sas(
1009            sas_inverse_link_jetwith_param_partials(eta, self.epsilon, self.log_delta),
1010        )))
1011    }
1012}
1013
1014#[derive(Clone, Copy, Debug)]
1015pub struct BetaLogisticKernel {
1016    /// Unconstrained log of the geometric-mean beta shape — the raw optimization
1017    /// parameter `SasLinkState::log_delta`, NOT the derived `SasLinkState::delta`.
1018    pub log_shape_center: f64,
1019    pub epsilon: f64,
1020}
1021
1022impl InverseLinkKernel for BetaLogisticKernel {
1023    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1024        Ok(beta_logistic_inverse_link_jet(
1025            eta,
1026            self.log_shape_center,
1027            self.epsilon,
1028        ))
1029    }
1030
1031    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1032        Ok(Some(LinkParamPartials::Sas(
1033            beta_logistic_inverse_link_jetwith_param_partials(
1034                eta,
1035                self.log_shape_center,
1036                self.epsilon,
1037            ),
1038        )))
1039    }
1040}
1041
1042impl InverseLinkKernel for MixtureLinkState {
1043    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1044        Ok(mixture_inverse_link_jet(self, eta))
1045    }
1046
1047    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1048        Ok(Some(LinkParamPartials::Mixture(
1049            mixture_inverse_link_jetwith_rho_partials(self, eta),
1050        )))
1051    }
1052}
1053
1054impl InverseLinkKernel for InverseLink {
1055    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1056        match self {
1057            InverseLink::Standard(link_fn) => link_fn.as_link_function().jet(eta),
1058            InverseLink::LatentCLogLog(state) => latent_cloglog_point_jet(state, eta),
1059            InverseLink::Sas(state) => state.jet(eta),
1060            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1061                log_shape_center: state.log_delta,
1062                epsilon: state.epsilon,
1063            }
1064            .jet(eta),
1065            InverseLink::Mixture(state) => state.jet(eta),
1066        }
1067    }
1068
1069    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1070        match self {
1071            InverseLink::Standard(_) => Ok(None),
1072            InverseLink::LatentCLogLog(_) => Ok(None),
1073            InverseLink::Sas(state) => state.param_partials(eta),
1074            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1075                log_shape_center: state.log_delta,
1076                epsilon: state.epsilon,
1077            }
1078            .param_partials(eta),
1079            InverseLink::Mixture(state) => state.param_partials(eta),
1080        }
1081    }
1082}
1083
1084/// Central family-aware inverse-link jet dispatch.
1085///
1086/// For `BinomialSas` and `BinomialMixture`, required state must be provided.
1087pub fn inverse_link_jet_for_inverse_link(
1088    link: &InverseLink,
1089    eta: f64,
1090) -> Result<InverseLinkJet, EstimationError> {
1091    link.jet(eta)
1092}
1093
1094/// Specialized `(mu, d1)` inverse-link evaluation that skips the d2/d3
1095/// polynomial chain used by the full jet. Numerical semantics are preserved:
1096/// the returned `mu` and `d1` are bit-identical to the corresponding fields of
1097/// `inverse_link_jet_for_inverse_link(link, eta)?` for every supported link.
1098///
1099/// For latent cloglog the underlying lognormal-Laplace kernel produces all
1100/// orders together, so this falls back to the full jet for that branch — the
1101/// savings come from the parameterised polynomial links (SAS, beta-logistic,
1102/// mixture) and the simple analytic links where d2/d3 are pure waste.
1103pub fn inverse_link_mu_d1_for_inverse_link(
1104    link: &InverseLink,
1105    eta: f64,
1106) -> Result<(f64, f64), EstimationError> {
1107    match link {
1108        InverseLink::Standard(link_fn) => Ok(link_function_mu_d1(link_fn.as_link_function(), eta)?),
1109        InverseLink::LatentCLogLog(state) => {
1110            let jet = latent_cloglog_point_jet(state, eta)?;
1111            Ok((jet.mu, jet.d1))
1112        }
1113        InverseLink::Sas(state) => Ok(sas_inverse_link_mu_d1(eta, state.epsilon, state.log_delta)),
1114        InverseLink::BetaLogistic(state) => Ok(beta_logistic_inverse_link_mu_d1(
1115            eta,
1116            state.log_delta,
1117            state.epsilon,
1118        )),
1119        InverseLink::Mixture(state) => Ok(mixture_inverse_link_mu_d1(state, eta)),
1120    }
1121}
1122
1123fn link_function_mu_d1(link: LinkFunction, eta: f64) -> Result<(f64, f64), EstimationError> {
1124    match link {
1125        LinkFunction::Identity => Ok((eta, 1.0)),
1126        LinkFunction::Log => {
1127            // SOLVER-INTERNAL clamped `(μ, dμ/dη)`; see the matching note on the
1128            // full `LinkFunction::Log` jet above. Public response transforms use
1129            // exact `exp(η)` via `families::inverse_link::apply_inverse_link_vec`
1130            // (issue #963).
1131            let e = eta.clamp(-700.0, 700.0).exp();
1132            Ok((e, e))
1133        }
1134        LinkFunction::Logit => Ok(component_inverse_link_mu_d1(LinkComponent::Logit, eta)),
1135        LinkFunction::Probit => Ok(component_inverse_link_mu_d1(LinkComponent::Probit, eta)),
1136        LinkFunction::CLogLog => Ok(component_inverse_link_mu_d1(LinkComponent::CLogLog, eta)),
1137        LinkFunction::Sas => Err(EstimationError::InvalidInput(
1138            "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1139        )),
1140        LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1141            "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1142                .to_string(),
1143        )),
1144    }
1145}
1146
1147#[inline]
1148fn component_inverse_link_mu_d1(component: LinkComponent, eta: f64) -> (f64, f64) {
1149    // The full per-component jet already factors `mu` and `d1` exactly the same
1150    // way the higher orders are derived, so we either reuse the cheap closed
1151    // forms directly (Logit/Probit/CLogLog/LogLog/Cauchit) or fall back to the
1152    // existing canonicalised jet for the few cases without a separate fast
1153    // path — bit-identical to `component_inverse_link_jet(...).{mu,d1}`.
1154    match component {
1155        LinkComponent::Logit => {
1156            let jet = logit_inverse_link_jet5(eta);
1157            (jet.mu, canonicalzero(jet.d1))
1158        }
1159        LinkComponent::Probit => {
1160            if eta.is_nan() {
1161                return (f64::NAN, f64::NAN);
1162            }
1163            if eta == f64::INFINITY {
1164                return (1.0, 0.0);
1165            }
1166            if eta == f64::NEG_INFINITY {
1167                return (0.0, 0.0);
1168            }
1169            let phi = normal_pdf(eta);
1170            (normal_cdf(eta), canonicalzero(phi))
1171        }
1172        LinkComponent::CLogLog => {
1173            if eta.is_nan() {
1174                return (f64::NAN, f64::NAN);
1175            }
1176            let t = eta.exp();
1177            if !t.is_finite() {
1178                return (1.0, 0.0);
1179            }
1180            (
1181                -(-t).exp_m1(),
1182                canonicalzero(stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0])),
1183            )
1184        }
1185        LinkComponent::LogLog => {
1186            if eta.is_nan() {
1187                return (f64::NAN, f64::NAN);
1188            }
1189            let r = (-eta).exp();
1190            if !r.is_finite() {
1191                return (0.0, 0.0);
1192            }
1193            (
1194                (-r).exp(),
1195                canonicalzero(stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0])),
1196            )
1197        }
1198        LinkComponent::Cauchit => {
1199            if eta.is_nan() {
1200                return (f64::NAN, f64::NAN);
1201            }
1202            let den = 1.0 + eta * eta;
1203            let d1 = if eta.is_finite() {
1204                1.0 / (std::f64::consts::PI * den)
1205            } else {
1206                0.0
1207            };
1208            (0.5 + eta.atan() / std::f64::consts::PI, canonicalzero(d1))
1209        }
1210    }
1211}
1212
1213fn sas_inverse_link_mu_d1(eta: f64, epsilon: f64, log_delta: f64) -> (f64, f64) {
1214    let delta_id = sas_delta_from_raw_log_delta(log_delta);
1215    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1216        return component_inverse_link_mu_d1(LinkComponent::Probit, eta);
1217    }
1218    let e = if eta.is_finite() { eta } else { 0.0 };
1219    let a = e.asinh();
1220    let delta = delta_id;
1221    let u_raw = delta * a - epsilon;
1222    let u = tanh_bound(u_raw, SAS_U_CLAMP);
1223    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1224    let s = u.sinh();
1225    let c = u.cosh();
1226    let z = s;
1227    let q = e.hypot(1.0);
1228    let inv_q = 1.0 / q;
1229    let r1 = delta * inv_q;
1230    let u1 = g1 * r1;
1231    let z1 = c * u1;
1232    // `mu = Phi(z)` and `d1 = phi(z) * z1`, the same closed forms used by the
1233    // full jet via `chain_inverse_link_jet(probit_jet(z), z1, _, _)`.
1234    let base = probit_jet(z);
1235    (base.mu, canonicalzero(base.d1 * z1))
1236}
1237
1238fn beta_logistic_inverse_link_mu_d1(eta: f64, delta: f64, epsilon: f64) -> (f64, f64) {
1239    let logistic = logistic_uwith_derivatives(eta);
1240    let a = (delta - epsilon).exp();
1241    let b = (delta + epsilon).exp();
1242    let mu = beta_reg_logistic(a, b, logistic);
1243    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1244    (mu, log_d1.exp())
1245}
1246
1247fn mixture_inverse_link_mu_d1(state: &MixtureLinkState, eta: f64) -> (f64, f64) {
1248    let mut mu = 0.0_f64;
1249    let mut d1 = 0.0_f64;
1250    let k = state.components.len().min(state.pi.len());
1251    for i in 0..k {
1252        let (mu_i, d1_i) = component_inverse_link_mu_d1(state.components[i], eta);
1253        let w = state.pi[i];
1254        mu += w * mu_i;
1255        d1 += w * d1_i;
1256    }
1257    (mu, d1)
1258}
1259
1260#[derive(Clone, Copy)]
1261enum PdfDerivativeOrder {
1262    Third,
1263    Fourth,
1264}
1265
1266impl PdfDerivativeOrder {
1267    fn probit(self, eta: f64) -> f64 {
1268        match self {
1269            Self::Third => probit_pdfthird_derivative(eta),
1270            Self::Fourth => probit_pdffourth_derivative(eta),
1271        }
1272    }
1273
1274    fn component(self, component: LinkComponent, eta: f64) -> f64 {
1275        match self {
1276            Self::Third => component_inverse_link_pdfthird_derivative(component, eta),
1277            Self::Fourth => component_inverse_link_pdffourth_derivative(component, eta),
1278        }
1279    }
1280
1281    fn latent_cloglog(self, eta: f64, latent_sd: f64) -> Result<f64, EstimationError> {
1282        let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, latent_sd)?;
1283        Ok(match self {
1284            Self::Third => jet.d4,
1285            Self::Fourth => jet.d5,
1286        })
1287    }
1288
1289    fn sas(self, eta: f64, epsilon: f64, log_delta: f64) -> f64 {
1290        match self {
1291            Self::Third => sas_inverse_link_pdfthird_derivative(eta, epsilon, log_delta),
1292            Self::Fourth => sas_inverse_link_pdffourth_derivative(eta, epsilon, log_delta),
1293        }
1294    }
1295
1296    fn beta_logistic(self, eta: f64, log_shape_center: f64, epsilon: f64) -> f64 {
1297        match self {
1298            Self::Third => {
1299                beta_logistic_inverse_link_pdfthird_derivative(eta, log_shape_center, epsilon)
1300            }
1301            Self::Fourth => {
1302                beta_logistic_inverse_link_pdffourth_derivative(eta, log_shape_center, epsilon)
1303            }
1304        }
1305    }
1306}
1307
1308fn inverse_link_pdf_derivative_for_inverse_link(
1309    link: &InverseLink,
1310    eta: f64,
1311    order: PdfDerivativeOrder,
1312) -> Result<f64, EstimationError> {
1313    match link {
1314        InverseLink::Standard(StandardLink::Identity) => Ok(0.0),
1315        InverseLink::Standard(StandardLink::Log) => Ok(eta.clamp(-700.0, 700.0).exp()),
1316        InverseLink::Standard(StandardLink::Probit) => Ok(order.probit(eta)),
1317        InverseLink::Standard(StandardLink::Logit) => {
1318            Ok(order.component(LinkComponent::Logit, eta))
1319        }
1320        InverseLink::Standard(StandardLink::CLogLog) => {
1321            Ok(order.component(LinkComponent::CLogLog, eta))
1322        }
1323        InverseLink::LatentCLogLog(state) => order.latent_cloglog(eta, state.latent_sd),
1324        InverseLink::Sas(state) => Ok(order.sas(eta, state.epsilon, state.log_delta)),
1325        InverseLink::BetaLogistic(state) => {
1326            Ok(order.beta_logistic(eta, state.log_delta, state.epsilon))
1327        }
1328        InverseLink::Mixture(state) => Ok(state
1329            .components
1330            .iter()
1331            .zip(state.pi.iter())
1332            .map(|(&component, &weight)| weight * order.component(component, eta))
1333            .sum()),
1334    }
1335}
1336
1337pub fn inverse_link_pdfthird_derivative_for_inverse_link(
1338    link: &InverseLink,
1339    eta: f64,
1340) -> Result<f64, EstimationError> {
1341    // This dispatch returns the fourth eta-derivative of the inverse-link CDF,
1342    // equivalently the third derivative of the inverse-link density
1343    //
1344    //   f(eta) = d/deta mu(eta).
1345    //
1346    // It is used downstream as the `f'''` input in
1347    //
1348    //   d³/deta³ log f = f'''/f - 3 f'f''/f² + 2(f')³/f³.
1349    //
1350    // Mixture links preserve linearity:
1351    //
1352    //   mu = sum_j pi_j mu_j
1353    //   => f''' = sum_j pi_j f_j'''
1354    //
1355    // because the mixture weights `pi_j` are constant with respect to `eta`.
1356    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Third)
1357}
1358
1359/// Fifth derivative of the inverse-link CDF (= fourth derivative of the PDF).
1360///
1361/// Extends `inverse_link_pdfthird_derivative_for_inverse_link` by one order.
1362/// Used for the outer REML Hessian Q[v_k, v_l] term in survival models,
1363/// specifically the `m1 * u_{abcd}` Arbogast contribution.
1364pub fn inverse_link_pdffourth_derivative_for_inverse_link(
1365    link: &InverseLink,
1366    eta: f64,
1367) -> Result<f64, EstimationError> {
1368    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Fourth)
1369}
1370
1371
1372#[inline]
1373fn royston_parmar_inverse_link_jet(eta: f64) -> InverseLinkJet {
1374    // ApproxKind: NumericalApproximation — exp(exp(eta)) overflows f64 for
1375    // eta > ~3.7; the |eta| > 30 saturation tail returns survival ≈ 0 with
1376    // d_k = 0, matching the analytic limit to within IEEE-754 underflow.
1377    const SURVIVAL_ETA_CLAMP: f64 = 30.0;
1378
1379    let z = eta.clamp(-SURVIVAL_ETA_CLAMP, SURVIVAL_ETA_CLAMP);
1380    let hazard = z.exp();
1381    let survival = (-hazard).exp();
1382    if !(-SURVIVAL_ETA_CLAMP..=SURVIVAL_ETA_CLAMP).contains(&eta) {
1383        return InverseLinkJet {
1384            mu: survival,
1385            d1: 0.0,
1386            d2: 0.0,
1387            d3: 0.0,
1388        };
1389    }
1390
1391    let d1 = -hazard * survival;
1392    let d2 = hazard * (hazard - 1.0) * survival;
1393    let d3 = (-hazard * hazard * hazard + 3.0 * hazard * hazard - hazard) * survival;
1394    InverseLinkJet {
1395        mu: survival,
1396        d1,
1397        d2,
1398        d3,
1399    }
1400}
1401
1402pub fn inverse_link_jet_for_family(
1403    spec: &LikelihoodSpec,
1404    eta: f64,
1405) -> Result<InverseLinkJet, EstimationError> {
1406    // RoystonParmar uses its own analytic survival inverse link irrespective of
1407    // the (nominal `Identity`) link slot carried in the spec.
1408    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1409        return Ok(royston_parmar_inverse_link_jet(eta));
1410    }
1411    spec.link.jet(eta)
1412}
1413
1414/// Exact-public log inverse-link jet: `mu = d1 = d2 = d3 = exp(η)` with NO
1415/// `η`-clamp. Sibling of the solver-internal `LinkFunction::Log` jet (which
1416/// clamps `η` to `[−700, 700]` as an IRLS/REML conditioning hack); see issue
1417/// #963. Every derivative of `exp` is `exp`, so all four jet slots carry the
1418/// same exact value — finite wherever representable, `0.0` on underflow,
1419/// `+∞` on overflow.
1420#[inline]
1421fn log_inverse_link_jet_exact(eta: f64) -> InverseLinkJet {
1422    let e = eta.exp();
1423    InverseLinkJet {
1424        mu: e,
1425        d1: e,
1426        d2: e,
1427        d3: e,
1428    }
1429}
1430
1431/// EXACT public inverse-link jet for response-scale prediction outputs.
1432///
1433/// Identical to [`inverse_link_jet_for_family`] for every link EXCEPT the
1434/// standard `Log` link, where it returns the exact `exp(η)` jet instead of the
1435/// solver's `η.clamp(−700, 700).exp()` conditioning transform. On finite `η`
1436/// the two diverge (η = 705: exact `exp(705)` ≈ 1.5e306 vs clamped
1437/// `exp(700)` ≈ 1.0e304; η = −720: exact ≈ 2e−313 vs clamped ≈ 9.9e−305), so
1438/// public predictions (`FamilyStrategy::inverse_link_jet`/`inverse_link_array`,
1439/// the predict response + delta-method SE path) route here. The solver/REML/
1440/// PIRLS engines keep the clamped jet (issue #963). For `|η| ≤ 700` this is
1441/// byte-identical to the clamped jet (the clamp is inert there), so no
1442/// in-range prediction changes.
1443pub fn inverse_link_jet_for_family_public(
1444    spec: &LikelihoodSpec,
1445    eta: f64,
1446) -> Result<InverseLinkJet, EstimationError> {
1447    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1448        return Ok(royston_parmar_inverse_link_jet(eta));
1449    }
1450    if let InverseLink::Standard(StandardLink::Log) = spec.link {
1451        return Ok(log_inverse_link_jet_exact(eta));
1452    }
1453    spec.link.jet(eta)
1454}
1455
1456#[inline]
1457pub fn mixture_inverse_link_jet(state: &MixtureLinkState, eta: f64) -> InverseLinkJet {
1458    let mut mu = 0.0_f64;
1459    let mut d1 = 0.0_f64;
1460    let mut d2 = 0.0_f64;
1461    let mut d3 = 0.0_f64;
1462    let k = state.components.len().min(state.pi.len());
1463    for i in 0..k {
1464        let jet = component_inverse_link_jet(state.components[i], eta);
1465        let w = state.pi[i];
1466        mu += w * jet.mu;
1467        d1 += w * jet.d1;
1468        d2 += w * jet.d2;
1469        d3 += w * jet.d3;
1470    }
1471    InverseLinkJet { mu, d1, d2, d3 }
1472}
1473
1474/// Computes mixture jet and exact partial derivatives wrt free softmax logits.
1475///
1476/// Uses identities:
1477///   d mu     / d rho_j = pi_j (mu_j     - mu)
1478///   d mu'    / d rho_j = pi_j (mu_j'    - mu')
1479///   d mu''   / d rho_j = pi_j (mu_j''   - mu'')
1480///   d mu'''  / d rho_j = pi_j (mu_j'''  - mu''')
1481pub fn mixture_inverse_link_jetwith_rho_partials(
1482    state: &MixtureLinkState,
1483    eta: f64,
1484) -> MixtureJetWithRhoPartials {
1485    let k = state.components.len().min(state.pi.len());
1486    let m = k.saturating_sub(1);
1487    let mut djet_drho = vec![
1488        InverseLinkJet {
1489            mu: 0.0,
1490            d1: 0.0,
1491            d2: 0.0,
1492            d3: 0.0,
1493        };
1494        m
1495    ];
1496    let jet = mixture_inverse_link_jetwith_rho_partials_into(state, eta, &mut djet_drho);
1497    MixtureJetWithRhoPartials { jet, djet_drho }
1498}
1499
1500/// Computes mixture jet and writes exact rho partial jets into `out` (length >= K-1).
1501/// This avoids heap allocation in hot loops.
1502pub fn mixture_inverse_link_jetwith_rho_partials_into(
1503    state: &MixtureLinkState,
1504    eta: f64,
1505    out: &mut [InverseLinkJet],
1506) -> InverseLinkJet {
1507    let k = state.components.len().min(state.pi.len());
1508    let m = k.saturating_sub(1);
1509    assert!(
1510        out.len() >= m,
1511        "rho-partial output buffer too small: got {}, need {}",
1512        out.len(),
1513        m
1514    );
1515    let mut mixed = InverseLinkJet {
1516        mu: 0.0,
1517        d1: 0.0,
1518        d2: 0.0,
1519        d3: 0.0,
1520    };
1521    for i in 0..k {
1522        let jet_i = component_inverse_link_jet(state.components[i], eta);
1523        let w = state.pi[i];
1524        mixed.mu += w * jet_i.mu;
1525        mixed.d1 += w * jet_i.d1;
1526        mixed.d2 += w * jet_i.d2;
1527        mixed.d3 += w * jet_i.d3;
1528        // Cache the first K-1 component jets directly in the output buffer so
1529        // we don't recompute them in the partial loop.
1530        if i < m {
1531            out[i] = jet_i;
1532        }
1533    }
1534    for j in 0..m {
1535        let pi_j = state.pi[j];
1536        let cj = out[j];
1537        out[j] = InverseLinkJet {
1538            mu: pi_j * (cj.mu - mixed.mu),
1539            d1: pi_j * (cj.d1 - mixed.d1),
1540            d2: pi_j * (cj.d2 - mixed.d2),
1541            d3: pi_j * (cj.d3 - mixed.d3),
1542        };
1543    }
1544    mixed
1545}
1546
1547#[derive(Clone, Copy)]
1548struct LogisticU {
1549    u: f64,
1550    one_minus_u: f64,
1551    ln_u: f64,
1552    ln_one_minus_u: f64,
1553    du: f64,
1554    use_upper_tail: bool,
1555}
1556
1557#[inline]
1558fn logistic_uwith_derivatives(eta: f64) -> LogisticU {
1559    let ln_u = -gam_linalg::utils::stable_softplus(-eta);
1560    let ln_one_minus_u = -gam_linalg::utils::stable_softplus(eta);
1561    let u = ln_u.exp();
1562    let one_minus_u = ln_one_minus_u.exp();
1563    let du = (ln_u + ln_one_minus_u).exp();
1564    LogisticU {
1565        u,
1566        one_minus_u,
1567        ln_u,
1568        ln_one_minus_u,
1569        du,
1570        use_upper_tail: eta >= 0.0,
1571    }
1572}
1573
1574#[inline]
1575fn beta_reg_logistic(a: f64, b: f64, logistic: LogisticU) -> f64 {
1576    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1577        return f64::NAN;
1578    }
1579    if logistic.ln_u == f64::NEG_INFINITY {
1580        return 0.0;
1581    }
1582    if logistic.ln_one_minus_u == f64::NEG_INFINITY {
1583        return 1.0;
1584    }
1585    if logistic.use_upper_tail {
1586        1.0 - beta_reg(b, a, logistic.one_minus_u)
1587    } else {
1588        beta_reg(a, b, logistic.u)
1589    }
1590}
1591
1592#[inline]
1593fn beta_reg_with_shape_partials_logistic(a: f64, b: f64, logistic: LogisticU) -> (f64, f64, f64) {
1594    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1595        return (f64::NAN, f64::NAN, f64::NAN);
1596    }
1597    if logistic.use_upper_tail {
1598        let (tail, dtail_db, dtail_da) = beta_reg_with_shape_partials(b, a, logistic.one_minus_u);
1599        (1.0 - tail, -dtail_da, -dtail_db)
1600    } else {
1601        beta_reg_with_shape_partials(a, b, logistic.u)
1602    }
1603}
1604
1605#[inline]
1606fn beta_logistic_log_d1(a: f64, b: f64, logistic: LogisticU) -> f64 {
1607    a * logistic.ln_u + b * logistic.ln_one_minus_u - ln_beta(a, b)
1608}
1609
/// First-order dual number tracking a value and its partials with respect to
/// the two beta shape parameters `a` and `b`.
#[derive(Clone, Copy)]
struct ShapeDual {
    /// Value.
    v: f64,
    /// Partial derivative with respect to `a`.
    da: f64,
    /// Partial derivative with respect to `b`.
    db: f64,
}

impl ShapeDual {
    /// A quantity that does not depend on the shapes: both partials are zero.
    #[inline]
    fn constant(v: f64) -> Self {
        Self::from_value_partials(v, 0.0, 0.0)
    }

    /// Assembles a dual from an explicit value and its two partials.
    #[inline]
    fn from_value_partials(v: f64, da: f64, db: f64) -> Self {
        Self { v, da, db }
    }

    /// Replaces a value whose magnitude is strictly below `floor` with the
    /// constant `+floor` (sign and partials are discarded). Anything else,
    /// including NaN, is returned untouched.
    #[inline]
    fn clamp_small(self, floor: f64) -> Self {
        let below_floor = self.v.abs() < floor;
        if below_floor {
            Self::constant(floor)
        } else {
            self
        }
    }
}

impl std::ops::Add for ShapeDual {
    type Output = Self;

    /// Sum rule: values and partials add component-wise.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self::from_value_partials(self.v + rhs.v, self.da + rhs.da, self.db + rhs.db)
    }
}

impl std::ops::Sub for ShapeDual {
    type Output = Self;

    /// Difference rule: values and partials subtract component-wise.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        Self::from_value_partials(self.v - rhs.v, self.da - rhs.da, self.db - rhs.db)
    }
}

impl std::ops::Mul for ShapeDual {
    type Output = Self;

    /// Product rule: `(fg)' = f'g + fg'`.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let (f, g) = (self, rhs);
        Self::from_value_partials(
            f.v * g.v,
            f.da * g.v + f.v * g.da,
            f.db * g.v + f.v * g.db,
        )
    }
}

impl std::ops::Div for ShapeDual {
    type Output = Self;

    /// Quotient rule: `(f/g)' = (f'g - fg') / g²`, using a precomputed
    /// reciprocal of the denominator.
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let (f, g) = (self, rhs);
        let inv = 1.0 / g.v;
        let inv2 = inv * inv;
        Self::from_value_partials(
            f.v * inv,
            (f.da * g.v - f.v * g.da) * inv2,
            (f.db * g.v - f.v * g.db) * inv2,
        )
    }
}

impl std::ops::Neg for ShapeDual {
    type Output = Self;

    /// Negates the value and both partials.
    #[inline]
    fn neg(self) -> Self {
        Self::from_value_partials(-self.v, -self.da, -self.db)
    }
}

/// Shorthand for lifting a plain `f64` into a shape-independent dual.
#[inline]
fn shape_dual(v: f64) -> ShapeDual {
    ShapeDual::constant(v)
}
1713
1714// Analytic shape partials for I_x(a,b), obtained by differentiating the same
1715// regularized-beta continued fraction used by statrs. The normalizing term uses
1716// d log B(a,b) / da = psi(a) - psi(a+b) and likewise for b.
1717fn beta_reg_with_shape_partials(a0: f64, b0: f64, x0: f64) -> (f64, f64, f64) {
1718    if x0 <= 0.0 {
1719        return (0.0, 0.0, 0.0);
1720    }
1721    if x0 >= 1.0 {
1722        return (1.0, 0.0, 0.0);
1723    }
1724
1725    let symm_transform = x0 >= (a0 + 1.0) / (a0 + b0 + 2.0);
1726    let (a, b, x) = if symm_transform {
1727        (
1728            ShapeDual::from_value_partials(b0, 0.0, 1.0),
1729            ShapeDual::from_value_partials(a0, 1.0, 0.0),
1730            1.0 - x0,
1731        )
1732    } else {
1733        (
1734            ShapeDual::from_value_partials(a0, 1.0, 0.0),
1735            ShapeDual::from_value_partials(b0, 0.0, 1.0),
1736            x0,
1737        )
1738    };
1739
1740    let ln_x = x.ln();
1741    let ln_1mx = (1.0 - x).ln();
1742    let psi_ab = digamma(a.v + b.v);
1743    let log_bt = statrs::function::gamma::ln_gamma(a.v + b.v)
1744        - statrs::function::gamma::ln_gamma(a.v)
1745        - statrs::function::gamma::ln_gamma(b.v)
1746        + a.v * ln_x
1747        + b.v * ln_1mx;
1748    let bt_v = log_bt.exp();
1749    let log_bt_a = psi_ab - digamma(a.v) + ln_x;
1750    let log_bt_b = psi_ab - digamma(b.v) + ln_1mx;
1751    let bt = ShapeDual {
1752        v: bt_v,
1753        da: bt_v * (log_bt_a * a.da + log_bt_b * b.da),
1754        db: bt_v * (log_bt_a * a.db + log_bt_b * b.db),
1755    };
1756
1757    let eps = 0.00000000000000011102230246251565;
1758    let fpmin = f64::MIN_POSITIVE / eps;
1759    let one = shape_dual(1.0);
1760    let qab = a + b;
1761    let qap = a + one;
1762    let qam = a - one;
1763    let mut c = one;
1764    let mut d = (one - qab * shape_dual(x) / qap).clamp_small(fpmin);
1765    d = one / d;
1766    let mut h = d;
1767
1768    for m in 1..141 {
1769        let mf = f64::from(m);
1770        let m2 = mf * 2.0;
1771        let md = shape_dual(mf);
1772        let m2d = shape_dual(m2);
1773        let mut aa = md * (b - md) * shape_dual(x) / ((qam + m2d) * (a + m2d));
1774        d = (one + aa * d).clamp_small(fpmin);
1775        c = (one + aa / c).clamp_small(fpmin);
1776        d = one / d;
1777        h = h * d * c;
1778
1779        aa = (a + md).neg() * (qab + md) * shape_dual(x) / ((a + m2d) * (qap + m2d));
1780        d = (one + aa * d).clamp_small(fpmin);
1781        c = (one + aa / c).clamp_small(fpmin);
1782        d = one / d;
1783        let del = d * c;
1784        h = h * del;
1785
1786        if (del.v - 1.0).abs() <= eps {
1787            let reg = bt * h / a;
1788            return if symm_transform {
1789                (1.0 - reg.v, -reg.da, -reg.db)
1790            } else {
1791                (reg.v, reg.da, reg.db)
1792            };
1793        }
1794    }
1795    let reg = bt * h / a;
1796    if symm_transform {
1797        (1.0 - reg.v, -reg.da, -reg.db)
1798    } else {
1799        (reg.v, reg.da, reg.db)
1800    }
1801}
1802
1803/// Beta-Logistic inverse-link jet for:
1804///   u = logistic(eta)
1805///   a = exp(log_shape_center - epsilon), b = exp(log_shape_center + epsilon)
1806///   mu = I_u(a, b)
1807///
1808/// NOTE: `log_shape_center` is the *unconstrained* log of the geometric-mean
1809/// beta shape (so a·b = exp(2·log_shape_center)). Callers must pass the raw
1810/// optimization parameter `SasLinkState::log_delta`, NOT the derived positive
1811/// `SasLinkState::delta = exp(log_shape_center)`.
1812pub fn beta_logistic_inverse_link_jet(
1813    eta: f64,
1814    log_shape_center: f64,
1815    epsilon: f64,
1816) -> InverseLinkJet {
1817    let logistic = logistic_uwith_derivatives(eta);
1818    let a = (log_shape_center - epsilon).exp();
1819    let b = (log_shape_center + epsilon).exp();
1820    let mu = beta_reg_logistic(a, b, logistic);
1821    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1822    let d1 = log_d1.exp();
1823    let t = a * logistic.one_minus_u - b * logistic.u;
1824    let d2 = d1 * t;
1825    let d3 = d1 * (t * t - (a + b) * logistic.du);
1826    InverseLinkJet { mu, d1, d2, d3 }
1827}
1828
1829pub fn beta_logistic_inverse_link_pdfthird_derivative(
1830    eta: f64,
1831    log_shape_center: f64,
1832    epsilon: f64,
1833) -> f64 {
1834    // Beta-logistic link:
1835    //
1836    //   u = logistic(eta),
1837    //   d1 = C * u^a (1-u)^b,
1838    //   t  = a(1-u) - b u,
1839    //   c  = a + b,
1840    //
1841    // so
1842    //
1843    //   d2 = d1 * t
1844    //   d3 = d1 * (t² - c u')
1845    //
1846    // with `u' = u(1-u)`.
1847    //
1848    // Differentiate once more:
1849    //
1850    //   d4 = d/deta[d1 (t² - c u')]
1851    //      = d1' (t² - c u') + d1 (2 t t' - c u'')
1852    //      = d1 [ t(t² - c u') - 2 c t u' - c u'' ]
1853    //      = d1 [ t³ - 3 c t u' - c u'' ],
1854    //
1855    // since `t' = -c u'`.
1856    let logistic = logistic_uwith_derivatives(eta);
1857    let a = (log_shape_center - epsilon).exp();
1858    let b = (log_shape_center + epsilon).exp();
1859    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1860    let d1 = log_d1.exp();
1861    let c = a + b;
1862    let t = a * logistic.one_minus_u - b * logistic.u;
1863    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1864    d1 * (t * t * t - 3.0 * c * t * logistic.du - c * u2)
1865}
1866
1867/// Fifth derivative of the beta-logistic inverse-link CDF (= 4th deriv of PDF).
1868///
1869/// With `P_4 = t^3 - 3ct*u' - c*u''` giving `d4 = d1 * P_4`, the next order is:
1870///
1871///   d5 = d1 * [t^4 - 6c*t^2*u' - 4c*t*u'' + 3c^2*u'^2 - c*u''']
1872///
1873/// where u' = u(1-u), u'' = u'(1-2u), u''' = u''(1-2u) - 2*u'^2.
1874pub fn beta_logistic_inverse_link_pdffourth_derivative(
1875    eta: f64,
1876    log_shape_center: f64,
1877    epsilon: f64,
1878) -> f64 {
1879    let logistic = logistic_uwith_derivatives(eta);
1880    let a = (log_shape_center - epsilon).exp();
1881    let b = (log_shape_center + epsilon).exp();
1882    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1883    let d1 = log_d1.exp();
1884    let c = a + b;
1885    let t = a * logistic.one_minus_u - b * logistic.u;
1886    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1887    let u3 = u2 * (logistic.one_minus_u - logistic.u) - 2.0 * logistic.du * logistic.du;
1888    let t2 = t * t;
1889    d1 * (t2 * t2 - 6.0 * c * t2 * logistic.du - 4.0 * c * t * u2
1890        + 3.0 * c * c * logistic.du * logistic.du
1891        - c * u3)
1892}
1893
1894pub fn beta_logistic_inverse_link_jetwith_param_partials(
1895    eta: f64,
1896    log_shape_center: f64,
1897    epsilon: f64,
1898) -> SasJetWithParamPartials {
1899    let logistic = logistic_uwith_derivatives(eta);
1900    let a = (log_shape_center - epsilon).exp();
1901    let b = (log_shape_center + epsilon).exp();
1902    let (mu, dmu_da, dmu_db) = beta_reg_with_shape_partials_logistic(a, b, logistic);
1903    let dmu_dlog_shape_center = a * dmu_da + b * dmu_db;
1904    let dmu_depsilon = -a * dmu_da + b * dmu_db;
1905    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1906    let d1 = log_d1.exp();
1907    let t = a * logistic.one_minus_u - b * logistic.u;
1908    let d2 = d1 * t;
1909    let k = t * t - (a + b) * logistic.du;
1910    let d3 = d1 * k;
1911    let jet = InverseLinkJet { mu, d1, d2, d3 };
1912
1913    let psi_a = digamma(a);
1914    let psi_b = digamma(b);
1915    let psi_ab = digamma(a + b);
1916    let la = logistic.ln_u - psi_a + psi_ab;
1917    let lb = logistic.ln_one_minus_u - psi_b + psi_ab;
1918
1919    let partials_for = |a_p: f64, b_p: f64, dmu: f64| -> InverseLinkJet {
1920        let logd1_p = a_p * la + b_p * lb;
1921        let d1_p = d1 * logd1_p;
1922        let t_p = a_p * logistic.one_minus_u - b_p * logistic.u;
1923        let d2_p = d1_p * t + d1 * t_p;
1924        let k_p = 2.0 * t * t_p - (a_p + b_p) * logistic.du;
1925        let d3_p = d1_p * k + d1 * k_p;
1926        InverseLinkJet {
1927            mu: dmu,
1928            d1: d1_p,
1929            d2: d2_p,
1930            d3: d3_p,
1931        }
1932    };
1933    let djet_dlog_shape_center = partials_for(a, b, dmu_dlog_shape_center);
1934    let djet_depsilon = partials_for(-a, b, dmu_depsilon);
1935    SasJetWithParamPartials {
1936        jet,
1937        djet_depsilon,
1938        djet_dlog_delta: djet_dlog_shape_center,
1939    }
1940}
1941
1942/// SAS inverse-link jet for:
1943///   mu(eta) = Phi(sinh(delta * asinh(eta) - epsilon)),
1944///   delta = exp(B * tanh(log_delta / B)), B = SAS_LOG_DELTA_BOUND.
1945pub fn sas_inverse_link_jet(eta: f64, epsilon: f64, log_delta: f64) -> InverseLinkJet {
1946    let delta_id = sas_delta_from_raw_log_delta(log_delta);
1947    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1948        return component_inverse_link_jet(LinkComponent::Probit, eta);
1949    }
1950    let e = if eta.is_finite() { eta } else { 0.0 };
1951    let a = e.asinh();
1952    let delta = delta_id;
1953    let u_raw = delta * a - epsilon;
1954    let u = tanh_bound(u_raw, SAS_U_CLAMP);
1955    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1956    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
1957    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
1958    let s = u.sinh();
1959    let c = u.cosh();
1960    let z = s;
1961    let q = e.hypot(1.0);
1962    let inv_q = 1.0 / q;
1963    let inv_q2 = inv_q * inv_q;
1964    let inv_q3 = inv_q2 * inv_q;
1965    let inv_q5 = inv_q3 * inv_q2;
1966    let r1 = delta * inv_q;
1967    let r2 = -delta * e * inv_q3;
1968    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
1969    let u1 = g1 * r1;
1970    let u2 = g2 * r1 * r1 + g1 * r2;
1971    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
1972    let z1 = c * u1;
1973    let z2 = s * u1 * u1 + c * u2;
1974    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
1975    let base = probit_jet(z);
1976    chain_inverse_link_jet(base, z1, z2, z3)
1977}
1978
/// Fourth eta-derivative of the SAS inverse-link CDF (= third derivative of
/// the PDF), for the same bounded latent transform used by
/// `sas_inverse_link_jet`.
pub fn sas_inverse_link_pdfthird_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    // SAS link with bounded latent transform:
    //
    //   a  = asinh(eta),
    //   u  = tanh_bound(delta * a - epsilon),
    //   z  = sinh(u),
    //   mu = Phi(z).
    //
    // Write:
    //
    //   z1 = z'
    //   z2 = z''
    //   z3 = z'''
    //   z4 = z''''.
    //
    // Since `mu' = phi(z) z1`, repeated differentiation factors through the
    // standard normal Hermite-polynomial identities:
    //
    //   mu''   = phi(z) [ z2 - z z1² ]
    //
    //   mu'''  = phi(z) [ z3 - 3 z z1 z2 + (z² - 1) z1³ ]
    //          = phi(z) k3
    //
    //   mu'''' = phi(z) [ k4 - z z1 k3 ],
    //
    // where `k4` is the derivative of `k3` after collecting like terms.
    //
    // The code below does not form `k3`/`k4` explicitly. It computes
    // `u1..u4`, then `z1..z4`, and assembles the same quantity through the
    // order-four chain rule for `Phi(z(eta))`:
    //
    //   mu'''' = Phi'''' z1^4 + 6 Phi''' z1² z2 + 3 Phi'' z2²
    //          + 4 Phi'' z1 z3 + Phi' z4,
    //
    // with `Phi'..Phi'''` taken from `probit_jet(z)` (`d1..d3`) and `Phi''''`
    // from `probit_pdfthird_derivative(z)`.
    //
    // The needed fourth derivative of `u(eta)` is obtained from the nested
    // composition `u(eta) = g(r(eta))` with
    //   g = tanh_bound, r = delta * asinh(eta) - epsilon:
    //
    //   u4 = g'''' r1^4 + 6 g''' r1² r2 + 3 g'' r2² + 4 g'' r1 r3 + g' r4,
    //
    // which is the standard scalar Arbogast expansion for order four.

    // A non-finite eta is evaluated as 0.0.
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let delta = sas_delta_from_raw_log_delta(log_delta);
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    let base = probit_jet(z);
    // q = sqrt(1 + eta²); r1..r4 are the eta-derivatives of delta * asinh(eta).
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    // u1..u4: derivatives of u(eta) = g(r(eta)).
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    // z1..z4: derivatives of z(eta) = sinh(u(eta)); sinh derivatives alternate c, s.
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    // Phi'''' at z.
    let base4 = probit_pdfthird_derivative(z);
    let out = base4 * z1.powi(4)
        + 6.0 * base.d3 * z1 * z1 * z2
        + 3.0 * base.d2 * z2 * z2
        + 4.0 * base.d2 * z1 * z3
        + base.d1 * z4;
    // NOTE(review): `canonicalzero` presumably normalizes a signed zero — confirm.
    canonicalzero(out)
}
2059
/// Fifth derivative of the SAS inverse-link CDF (= fourth derivative of the PDF).
///
/// Extends `sas_inverse_link_pdfthird_derivative` by one more derivative order,
/// using the same composition chain u(eta) = g(r(eta)), z = sinh(u), mu = Phi(z),
/// where g = tanh_bound(·, SAS_U_CLAMP) and r = delta * asinh(eta) - epsilon.
///
/// The Arbogast expansion at order 5 for u(eta) = g(r(eta)) is:
///   u5 = g5 r1^5 + 10 g4 r1^3 r2 + 15 g3 r1 r2^2 + 10 g3 r1^2 r3
///        + 10 g2 r2 r3 + 5 g2 r1 r4 + g1 r5
///
/// The z = sinh(u) expansion at order 5 is the standard Arbogast for sinh:
///   z5 = c*u1^5 + 10*s*u1^3*u2 + 15*c*u1*u2^2 + 10*c*u1^2*u3
///        + 10*s*u2*u3 + 5*s*u1*u4 + c*u5
///
/// The mu = Phi(z) expansion at order 5 uses probit derivatives:
///   mu^(5) = Phi5*z1^5 + 10*Phi4*z1^3*z2 + 15*Phi3*z1*z2^2 + 10*Phi3*z1^2*z3
///            + 10*Phi2*z2*z3 + 5*Phi2*z1*z4 + Phi1*z5
pub fn sas_inverse_link_pdffourth_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    // A non-finite eta is evaluated as 0.0.
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let delta = sas_delta_from_raw_log_delta(log_delta);
    // Raw latent argument r, its bounded image u = g(r), and g', ..., g^(5) at r.
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let g5 = tanh_bound_d5(u_raw, SAS_U_CLAMP);
    // s and c double as the alternating derivatives of sinh at u.
    let s = u.sinh();
    let c = u.cosh();
    let z = s;

    // Probit derivatives at z.
    let base = probit_jet(z);
    let phi3 = probit_pdfthird_derivative(z); // Phi^{(4)}
    let phi4 = probit_pdffourth_derivative(z); // Phi^{(5)}

    // Powers of q = sqrt(1 + eta^2) for r1..r5.
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    let inv_q9 = inv_q7 * inv_q2;

    // r_k = delta * d^k/deta^k asinh(eta); epsilon drops out after the
    // first differentiation.
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    let r5 = delta * (9.0 - 72.0 * e * e + 24.0 * e * e * e * e) * inv_q9;

    // u1..u5 via Arbogast for g(r(eta)).
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    let u5 = g5 * r1.powi(5)
        + 10.0 * g4 * r1 * r1 * r1 * r2
        + 15.0 * g3 * r1 * r2 * r2
        + 10.0 * g3 * r1 * r1 * r3
        + 10.0 * g2 * r2 * r3
        + 5.0 * g2 * r1 * r4
        + g1 * r5;

    // z1..z5 via Arbogast for sinh(u(eta)).
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    let z5 = c * u1.powi(5)
        + 10.0 * s * u1 * u1 * u1 * u2
        + 15.0 * c * u1 * u2 * u2
        + 10.0 * c * u1 * u1 * u3
        + 10.0 * s * u2 * u3
        + 5.0 * s * u1 * u4
        + c * u5;

    // mu^(5) = Phi^(5)*z1^5 + 10*Phi^(4)*z1^3*z2 + 15*Phi^(3)*z1*z2^2
    //        + 10*Phi^(3)*z1^2*z3 + 10*Phi^(2)*z2*z3 + 5*Phi^(2)*z1*z4 + Phi^(1)*z5
    let out = phi4 * z1.powi(5)
        + 10.0 * phi3 * z1 * z1 * z1 * z2
        + 15.0 * base.d3 * z1 * z2 * z2
        + 10.0 * base.d3 * z1 * z1 * z3
        + 10.0 * base.d2 * z2 * z3
        + 5.0 * base.d2 * z1 * z4
        + base.d1 * z5;
    // NOTE(review): `canonicalzero` presumably normalizes a signed zero — confirm.
    canonicalzero(out)
}
2153
2154pub fn sas_inverse_link_jetwith_param_partials(
2155    eta: f64,
2156    epsilon: f64,
2157    log_delta: f64,
2158) -> SasJetWithParamPartials {
2159    let e = if eta.is_finite() { eta } else { 0.0 };
2160    let a = e.asinh();
2161    let (ld_eff, dld_eff_draw) = sas_effective_log_delta(log_delta);
2162    let delta = ld_eff.exp();
2163    let ddelta_draw = delta * dld_eff_draw;
2164    let u_raw = delta * a - epsilon;
2165    let u = tanh_bound(u_raw, SAS_U_CLAMP);
2166    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2167    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2168    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2169    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2170    let s = u.sinh();
2171    let c = u.cosh();
2172    let z = s;
2173    let q = e.hypot(1.0);
2174    let inv_q = 1.0 / q;
2175    let inv_q2 = inv_q * inv_q;
2176    let inv_q3 = inv_q2 * inv_q;
2177    let inv_q5 = inv_q3 * inv_q2;
2178    let a1 = inv_q;
2179    let a2 = -e * inv_q3;
2180    let a3 = (2.0 * e * e - 1.0) * inv_q5;
2181    let r1 = delta * a1;
2182    let r2 = delta * a2;
2183    let r3 = delta * a3;
2184    let u1 = g1 * r1;
2185    let u2 = g2 * r1 * r1 + g1 * r2;
2186    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2187    let z1 = c * u1;
2188    let z2 = s * u1 * u1 + c * u2;
2189    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2190
2191    let base = probit_jet(z);
2192    let jet = chain_inverse_link_jet(base, z1, z2, z3);
2193
2194    // Generic chain for parameter t:
2195    // u_t, u1_t, u2_t, u3_t -> z_t,z1_t,z2_t,z3_t -> mu_t,d1_t,d2_t,d3_t
2196    let param_partials = |u_t: f64, u1_t: f64, u2_t: f64, u3_t: f64| -> InverseLinkJet {
2197        let z_t = c * u_t;
2198        let z1_t = s * u_t * u1 + c * u1_t;
2199        let z2_t = c * u_t * u1 * u1 + 2.0 * s * u1 * u1_t + s * u_t * u2 + c * u2_t;
2200        let z3_t = s * u_t * u1 * u1 * u1
2201            + 3.0 * c * u1 * u1 * u1_t
2202            + 3.0 * c * u_t * u1 * u2
2203            + 3.0 * s * (u1_t * u2 + u1 * u2_t)
2204            + s * u_t * u3
2205            + c * u3_t;
2206
2207        InverseLinkJet {
2208            mu: base.d1 * z_t,
2209            d1: base.d2 * z_t * z1 + base.d1 * z1_t,
2210            d2: base.d3 * z_t * z1 * z1
2211                + 2.0 * base.d2 * z1 * z1_t
2212                + base.d2 * z_t * z2
2213                + base.d1 * z2_t,
2214            d3: probit_pdfthird_derivative(z) * z_t * z1.powi(3)
2215                + 3.0 * base.d3 * z1 * z1 * z1_t
2216                + 3.0 * base.d3 * z_t * z1 * z2
2217                + 3.0 * base.d2 * (z1_t * z2 + z1 * z2_t)
2218                + base.d2 * z_t * z3
2219                + base.d1 * z3_t,
2220        }
2221    };
2222
2223    // epsilon partials (raw_u_t = -1).
2224    let rt_eps = -1.0;
2225    let r1t_eps = 0.0;
2226    let r2t_eps = 0.0;
2227    let r3t_eps = 0.0;
2228    let u_eps = g1 * rt_eps;
2229    let u1_eps = g2 * rt_eps * r1 + g1 * r1t_eps;
2230    let u2_eps = g3 * rt_eps * r1 * r1 + 2.0 * g2 * r1 * r1t_eps + g2 * rt_eps * r2 + g1 * r2t_eps;
2231    let u3_eps = g4 * rt_eps * r1 * r1 * r1
2232        + 3.0 * g3 * r1 * r1 * r1t_eps
2233        + 3.0 * g3 * rt_eps * r1 * r2
2234        + 3.0 * g2 * (r1t_eps * r2 + r1 * r2t_eps)
2235        + g2 * rt_eps * r3
2236        + g1 * r3t_eps;
2237    let djet_depsilon = param_partials(u_eps, u1_eps, u2_eps, u3_eps);
2238
2239    // raw log-delta partials (through smooth bounded effective log-delta).
2240    let rt_ld = ddelta_draw * a;
2241    let r1t_ld = ddelta_draw * a1;
2242    let r2t_ld = ddelta_draw * a2;
2243    let r3t_ld = ddelta_draw * a3;
2244    let u_ld = g1 * rt_ld;
2245    let u1_ld = g2 * rt_ld * r1 + g1 * r1t_ld;
2246    let u2_ld = g3 * rt_ld * r1 * r1 + 2.0 * g2 * r1 * r1t_ld + g2 * rt_ld * r2 + g1 * r2t_ld;
2247    let u3_ld = g4 * rt_ld * r1 * r1 * r1
2248        + 3.0 * g3 * r1 * r1 * r1t_ld
2249        + 3.0 * g3 * rt_ld * r1 * r2
2250        + 3.0 * g2 * (r1t_ld * r2 + r1 * r2t_ld)
2251        + g2 * rt_ld * r3
2252        + g1 * r3t_ld;
2253    let djet_dlog_delta = param_partials(u_ld, u1_ld, u2_ld, u3_ld);
2254
2255    SasJetWithParamPartials {
2256        jet,
2257        djet_depsilon,
2258        djet_dlog_delta,
2259    }
2260}
2261
2262#[cfg(test)]
2263mod tests {
2264    use super::*;
2265    use gam_problem::{InverseLink, LikelihoodSpec, LinkComponent, MixtureLinkSpec, SasLinkState};
2266
2267    #[test]
2268    fn softmax_jacobian_matchesfd() {
2269        let rho = Array1::from_vec(vec![0.7, -1.2, 0.4]);
2270        let (pi, jac) = softmaxwith_jacobian_last_fixedzero(&rho);
2271        let h = 1e-6;
2272        for j in 0..rho.len() {
2273            let mut rp = rho.clone();
2274            rp[j] += h;
2275            let mut rm = rho.clone();
2276            rm[j] -= h;
2277            let pp = softmax_last_fixedzero(&rp);
2278            let pm = softmax_last_fixedzero(&rm);
2279            let fd = (&pp - &pm).mapv(|v| v / (2.0 * h));
2280            for k in 0..pi.len() {
2281                let err = (jac[[k, j]] - fd[k]).abs();
2282                assert_eq!(
2283                    jac[[k, j]].signum(),
2284                    fd[k].signum(),
2285                    "jac sign mismatch at ({k},{j}): analytic={} fd={}",
2286                    jac[[k, j]],
2287                    fd[k]
2288                );
2289                assert!(err < 5e-6, "jac mismatch at ({k},{j}): err={err:e}");
2290            }
2291        }
2292    }
2293
2294    #[test]
2295    fn mixture_jet_rho_partials_matchfd() {
2296        let spec = MixtureLinkSpec {
2297            components: vec![
2298                LinkComponent::Probit,
2299                LinkComponent::Logit,
2300                LinkComponent::CLogLog,
2301                LinkComponent::Cauchit,
2302            ],
2303            initial_rho: Array1::from_vec(vec![0.3, -0.6, 0.2]),
2304        };
2305        let state = state_fromspec(&spec).expect("state");
2306        let eta = 0.35;
2307        let out = mixture_inverse_link_jetwith_rho_partials(&state, eta);
2308        let h = 1e-6;
2309        for j in 0..state.rho.len() {
2310            let mut rp = state.rho.clone();
2311            rp[j] += h;
2312            let sp = MixtureLinkSpec {
2313                components: state.components.clone(),
2314                initial_rho: rp,
2315            };
2316            let jp = mixture_inverse_link_jet(&state_fromspec(&sp).expect("sp"), eta);
2317            let mut rm = state.rho.clone();
2318            rm[j] -= h;
2319            let sm = MixtureLinkSpec {
2320                components: state.components.clone(),
2321                initial_rho: rm,
2322            };
2323            let jm = mixture_inverse_link_jet(&state_fromspec(&sm).expect("sm"), eta);
2324            let fd = InverseLinkJet {
2325                mu: (jp.mu - jm.mu) / (2.0 * h),
2326                d1: (jp.d1 - jm.d1) / (2.0 * h),
2327                d2: (jp.d2 - jm.d2) / (2.0 * h),
2328                d3: (jp.d3 - jm.d3) / (2.0 * h),
2329            };
2330            let an = out.djet_drho[j];
2331            assert_eq!(an.mu.signum(), fd.mu.signum());
2332            assert_eq!(an.d1.signum(), fd.d1.signum());
2333            assert_eq!(an.d2.signum(), fd.d2.signum());
2334            assert_eq!(an.d3.signum(), fd.d3.signum());
2335            assert!((an.mu - fd.mu).abs() < 1e-6);
2336            assert!((an.d1 - fd.d1).abs() < 1e-6);
2337            assert!((an.d2 - fd.d2).abs() < 1e-6);
2338            assert!((an.d3 - fd.d3).abs() < 1e-6);
2339        }
2340    }
2341
2342    #[test]
2343    fn sas_param_partials_matchfd() {
2344        let eta = 0.37;
2345        let epsilon = -0.12;
2346        let log_delta = 0.21;
2347        let out = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2348        let h = 1e-6;
2349
2350        let ep_p = sas_inverse_link_jet(eta, epsilon + h, log_delta);
2351        let ep_m = sas_inverse_link_jet(eta, epsilon - h, log_delta);
2352        let fd_ep = InverseLinkJet {
2353            mu: (ep_p.mu - ep_m.mu) / (2.0 * h),
2354            d1: (ep_p.d1 - ep_m.d1) / (2.0 * h),
2355            d2: (ep_p.d2 - ep_m.d2) / (2.0 * h),
2356            d3: (ep_p.d3 - ep_m.d3) / (2.0 * h),
2357        };
2358        assert_eq!(out.djet_depsilon.mu.signum(), fd_ep.mu.signum());
2359        assert_eq!(out.djet_depsilon.d1.signum(), fd_ep.d1.signum());
2360        assert_eq!(out.djet_depsilon.d2.signum(), fd_ep.d2.signum());
2361        assert_eq!(out.djet_depsilon.d3.signum(), fd_ep.d3.signum());
2362        assert!((out.djet_depsilon.mu - fd_ep.mu).abs() < 5e-5);
2363        assert!((out.djet_depsilon.d1 - fd_ep.d1).abs() < 5e-5);
2364        assert!((out.djet_depsilon.d2 - fd_ep.d2).abs() < 5e-5);
2365        assert!((out.djet_depsilon.d3 - fd_ep.d3).abs() < 5e-4);
2366
2367        let ld_p = sas_inverse_link_jet(eta, epsilon, log_delta + h);
2368        let ld_m = sas_inverse_link_jet(eta, epsilon, log_delta - h);
2369        let fd_ld = InverseLinkJet {
2370            mu: (ld_p.mu - ld_m.mu) / (2.0 * h),
2371            d1: (ld_p.d1 - ld_m.d1) / (2.0 * h),
2372            d2: (ld_p.d2 - ld_m.d2) / (2.0 * h),
2373            d3: (ld_p.d3 - ld_m.d3) / (2.0 * h),
2374        };
2375        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_ld.mu.signum());
2376        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_ld.d1.signum());
2377        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_ld.d2.signum());
2378        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_ld.d3.signum());
2379        assert!((out.djet_dlog_delta.mu - fd_ld.mu).abs() < 5e-5);
2380        assert!((out.djet_dlog_delta.d1 - fd_ld.d1).abs() < 5e-5);
2381        assert!((out.djet_dlog_delta.d2 - fd_ld.d2).abs() < 5e-5);
2382        assert!((out.djet_dlog_delta.d3 - fd_ld.d3).abs() < 5e-4);
2383    }
2384
2385    #[test]
2386    fn sas_jet_extreme_inputs_stay_finite() {
2387        let cases = [
2388            (-1e6, 0.0, 0.0),
2389            (1e6, 0.0, 0.0),
2390            (3.0, 12.0, 12.0),
2391            (-3.0, -12.0, -12.0),
2392            (0.5, 40.0, 10.0),
2393            (0.5, -40.0, -10.0),
2394        ];
2395        for (eta, eps, log_delta) in cases {
2396            let j = sas_inverse_link_jet(eta, eps, log_delta);
2397            assert!(j.mu.is_finite());
2398            assert!(j.d1.is_finite());
2399            assert!(j.d2.is_finite());
2400            assert!(j.d3.is_finite());
2401            let p = sas_inverse_link_jetwith_param_partials(eta, eps, log_delta);
2402            assert!(p.djet_depsilon.mu.is_finite());
2403            assert!(p.djet_depsilon.d1.is_finite());
2404            assert!(p.djet_depsilon.d2.is_finite());
2405            assert!(p.djet_depsilon.d3.is_finite());
2406            assert!(p.djet_dlog_delta.mu.is_finite());
2407            assert!(p.djet_dlog_delta.d1.is_finite());
2408            assert!(p.djet_dlog_delta.d2.is_finite());
2409            assert!(p.djet_dlog_delta.d3.is_finite());
2410        }
2411    }
2412
2413    #[test]
2414    fn sas_param_partials_remain_finite_in_extreme_region() {
2415        let eta = 10.0;
2416        let epsilon = -60.0;
2417        let log_delta = 40.0;
2418        let j = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2419        assert!(j.djet_depsilon.mu.is_finite());
2420        assert!(j.djet_depsilon.d1.is_finite());
2421        assert!(j.djet_depsilon.d2.is_finite());
2422        assert!(j.djet_depsilon.d3.is_finite());
2423        assert!(j.djet_dlog_delta.mu.is_finite());
2424        assert!(j.djet_dlog_delta.d1.is_finite());
2425        assert!(j.djet_dlog_delta.d2.is_finite());
2426        assert!(j.djet_dlog_delta.d3.is_finite());
2427    }
2428
2429    #[test]
2430    fn sas_eta_jets_matchfd() {
2431        let eta = -0.43;
2432        let epsilon = 0.27;
2433        let log_delta = -0.31;
2434        let h = 1e-5;
2435        let j0 = sas_inverse_link_jet(eta, epsilon, log_delta);
2436        let jp = sas_inverse_link_jet(eta + h, epsilon, log_delta);
2437        let jm = sas_inverse_link_jet(eta - h, epsilon, log_delta);
2438        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2439        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2440        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2441        assert_eq!(j0.d1.signum(), d1fd.signum());
2442        assert_eq!(j0.d2.signum(), d2fd.signum());
2443        assert_eq!(j0.d3.signum(), d3fd.signum());
2444        assert!((j0.d1 - d1fd).abs() < 5e-5);
2445        assert!((j0.d2 - d2fd).abs() < 2e-4);
2446        assert!((j0.d3 - d3fd).abs() < 1e-3);
2447    }
2448
2449    #[test]
2450    fn family_dispatch_resolves_parameterized_links_from_spec() {
2451        // After the LikelihoodSpec migration, the dispatch no longer needs
2452        // out-of-band state arguments — the parameterized link state lives on
2453        // `spec.link`. Verify that supplying explicit SAS/Mixture link states
2454        // through the spec produces finite jets at a representative eta.
2455        let sas_state = sas_link_state_from_raw(0.0, 0.0).expect("sas state");
2456        let sas_spec = gam_problem::LikelihoodSpec {
2457            response: gam_problem::ResponseFamily::Binomial,
2458            link: InverseLink::Sas(sas_state),
2459        };
2460        let sas_jet = inverse_link_jet_for_family(&sas_spec, 0.1).expect("sas jet");
2461        assert!(sas_jet.mu.is_finite());
2462        assert!(sas_jet.d1.is_finite());
2463
2464        let mix_state = MixtureLinkState {
2465            components: vec![LinkComponent::Logit, LinkComponent::Probit],
2466            rho: ndarray::array![0.0],
2467            pi: ndarray::array![0.5, 0.5],
2468        };
2469        let mix_spec = gam_problem::LikelihoodSpec {
2470            response: gam_problem::ResponseFamily::Binomial,
2471            link: InverseLink::Mixture(mix_state),
2472        };
2473        let mix_jet = inverse_link_jet_for_family(&mix_spec, 0.1).expect("mix jet");
2474        assert!(mix_jet.mu.is_finite());
2475        assert!(mix_jet.d1.is_finite());
2476    }
2477
2478    #[test]
2479    fn beta_logistic_reduces_to_logit_at_delta0_epsilon0() {
2480        let etas = [-40.0, -30.0, -5.0, 0.42, 5.0, 30.0, 40.0];
2481        for eta in etas {
2482            let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2483            let expected_mu = gam_linalg::utils::stable_logistic(eta);
2484            let expected_d1 = (-gam_linalg::utils::stable_softplus(-eta)
2485                - gam_linalg::utils::stable_softplus(eta))
2486            .exp();
2487            assert!(
2488                (j_bl.mu - expected_mu).abs() <= 1e-15 * expected_mu.abs().max(1.0),
2489                "mu mismatch at eta={eta}: got {}, expected {}",
2490                j_bl.mu,
2491                expected_mu
2492            );
2493            assert!(
2494                (j_bl.d1 - expected_d1).abs() <= 1e-12 * expected_d1.abs().max(f64::MIN_POSITIVE),
2495                "d1 mismatch at eta={eta}: got {}, expected {}",
2496                j_bl.d1,
2497                expected_d1
2498            );
2499            assert!(j_bl.d1 > 0.0, "d1 should stay positive at eta={eta}");
2500        }
2501
2502        let eta = 0.42;
2503        let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2504        let j_logit = component_inverse_link_jet(LinkComponent::Logit, eta);
2505        assert!((j_bl.d2 - j_logit.d2).abs() < 1e-10);
2506        assert!((j_bl.d3 - j_logit.d3).abs() < 1e-10);
2507    }
2508
2509    #[test]
2510    fn beta_logistic_eta_jets_matchfd() {
2511        let eta = -0.31;
2512        let delta = 0.27;
2513        let epsilon = -0.19;
2514        let h = 1e-5;
2515        let j0 = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2516        let jp = beta_logistic_inverse_link_jet(eta + h, delta, epsilon);
2517        let jm = beta_logistic_inverse_link_jet(eta - h, delta, epsilon);
2518        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2519        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2520        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2521        assert_eq!(j0.d1.signum(), d1fd.signum());
2522        assert_eq!(j0.d2.signum(), d2fd.signum());
2523        assert_eq!(j0.d3.signum(), d3fd.signum());
2524        assert!((j0.d1 - d1fd).abs() < 5e-5);
2525        assert!((j0.d2 - d2fd).abs() < 5e-5);
2526        assert!((j0.d3 - d3fd).abs() < 2e-4);
2527    }
2528
2529    #[test]
2530    fn standard_kernel_structs_match_component_jets() {
2531        let eta = 0.73;
2532        assert_eq!(
2533            ProbitLinkKernel.jet(eta).expect("probit"),
2534            component_inverse_link_jet(LinkComponent::Probit, eta)
2535        );
2536        assert_eq!(
2537            LogitLinkKernel.jet(eta).expect("logit"),
2538            component_inverse_link_jet(LinkComponent::Logit, eta)
2539        );
2540        assert_eq!(
2541            CLogLogLinkKernel.jet(eta).expect("cloglog"),
2542            component_inverse_link_jet(LinkComponent::CLogLog, eta)
2543        );
2544        assert_eq!(
2545            LogLogLinkKernel.jet(eta).expect("loglog"),
2546            component_inverse_link_jet(LinkComponent::LogLog, eta)
2547        );
2548        assert_eq!(
2549            CauchitLinkKernel.jet(eta).expect("cauchit"),
2550            component_inverse_link_jet(LinkComponent::Cauchit, eta)
2551        );
2552    }
2553
2554    #[test]
2555    fn all_component_eta_jets_matchfd() {
2556        let components = [
2557            LinkComponent::Logit,
2558            LinkComponent::Probit,
2559            LinkComponent::CLogLog,
2560            LinkComponent::LogLog,
2561            LinkComponent::Cauchit,
2562        ];
2563        let points = [-3.0, -1.1, -0.2, 0.0, 0.7, 1.8, 3.2];
2564        let h = 1e-5;
2565        for c in components {
2566            for &eta in &points {
2567                let j0 = component_inverse_link_jet(c, eta);
2568                let jp = component_inverse_link_jet(c, eta + h);
2569                let jm = component_inverse_link_jet(c, eta - h);
2570                let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2571                let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2572                let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2573                let d1_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2574                    1.2e-4
2575                } else {
2576                    5e-5
2577                };
2578                let d2_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2579                    4e-4
2580                } else {
2581                    1.2e-4
2582                };
2583                let d3_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2584                    1.2e-3
2585                } else {
2586                    4e-4
2587                };
2588                if j0.d1.abs().max(d1fd.abs()) > 1e-10 {
2589                    assert_eq!(
2590                        j0.d1.signum(),
2591                        d1fd.signum(),
2592                        "d1 sign mismatch for {c:?} eta={eta}"
2593                    );
2594                }
2595                if j0.d2.abs().max(d2fd.abs()) > 1e-10 {
2596                    assert_eq!(
2597                        j0.d2.signum(),
2598                        d2fd.signum(),
2599                        "d2 sign mismatch for {c:?} eta={eta}: analytic={} fd={}",
2600                        j0.d2,
2601                        d2fd
2602                    );
2603                }
2604                if j0.d3.abs().max(d3fd.abs()) > 1e-10 {
2605                    assert_eq!(
2606                        j0.d3.signum(),
2607                        d3fd.signum(),
2608                        "d3 sign mismatch for {c:?} eta={eta}"
2609                    );
2610                }
2611                assert!(
2612                    (j0.d1 - d1fd).abs() < d1_tol,
2613                    "d1 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2614                    j0.d1,
2615                    d1fd
2616                );
2617                assert!(
2618                    (j0.d2 - d2fd).abs() < d2_tol,
2619                    "d2 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2620                    j0.d2,
2621                    d2fd
2622                );
2623                assert!(
2624                    (j0.d3 - d3fd).abs() < d3_tol,
2625                    "d3 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2626                    j0.d3,
2627                    d3fd
2628                );
2629            }
2630        }
2631    }
2632
2633    #[test]
2634    fn sas_center_matches_probit_at_delta1_epsilon0() {
2635        let etas = [-3.0, -1.2, -0.3, 0.0, 0.4, 1.7, 3.0];
2636        for eta in etas {
2637            let sas = sas_inverse_link_jet(eta, 0.0, 0.0);
2638            let probit = ProbitLinkKernel.jet(eta).expect("probit");
2639            // SAS implementation uses a smooth bounded latent (`tanh_bound`) for
2640            // numerical robustness, so the probit center is approximate in practice.
2641            assert!(
2642                (sas.mu - probit.mu).abs() < 6e-4,
2643                "mu mismatch at eta={eta}"
2644            );
2645            assert!(
2646                (sas.d1 - probit.d1).abs() < 6e-4,
2647                "d1 mismatch at eta={eta}"
2648            );
2649            assert!(
2650                (sas.d2 - probit.d2).abs() < 2e-3,
2651                "d2 mismatch at eta={eta}"
2652            );
2653            assert!(
2654                (sas.d3 - probit.d3).abs() < 4e-3,
2655                "d3 mismatch at eta={eta}"
2656            );
2657        }
2658    }
2659
2660    #[test]
2661    fn beta_logistic_param_partials_matchfd() {
2662        let eta = -0.41;
2663        let delta = 0.23;
2664        let epsilon = -0.17;
2665        let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2666        let h = 1e-6;
2667
2668        let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
2669        let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
2670        let fd_delta = InverseLinkJet {
2671            mu: (dp.mu - dm.mu) / (2.0 * h),
2672            d1: (dp.d1 - dm.d1) / (2.0 * h),
2673            d2: (dp.d2 - dm.d2) / (2.0 * h),
2674            d3: (dp.d3 - dm.d3) / (2.0 * h),
2675        };
2676        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_delta.mu.signum());
2677        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_delta.d1.signum());
2678        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_delta.d2.signum());
2679        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_delta.d3.signum());
2680        assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
2681        assert!((out.djet_dlog_delta.d1 - fd_delta.d1).abs() < 5e-5);
2682        assert!((out.djet_dlog_delta.d2 - fd_delta.d2).abs() < 1.2e-4);
2683        assert!((out.djet_dlog_delta.d3 - fd_delta.d3).abs() < 4e-4);
2684
2685        let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
2686        let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
2687        let fd_epsilon = InverseLinkJet {
2688            mu: (ep.mu - em.mu) / (2.0 * h),
2689            d1: (ep.d1 - em.d1) / (2.0 * h),
2690            d2: (ep.d2 - em.d2) / (2.0 * h),
2691            d3: (ep.d3 - em.d3) / (2.0 * h),
2692        };
2693        assert_eq!(out.djet_depsilon.mu.signum(), fd_epsilon.mu.signum());
2694        assert_eq!(out.djet_depsilon.d1.signum(), fd_epsilon.d1.signum());
2695        assert_eq!(out.djet_depsilon.d2.signum(), fd_epsilon.d2.signum());
2696        assert_eq!(out.djet_depsilon.d3.signum(), fd_epsilon.d3.signum());
2697        assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
2698        assert!((out.djet_depsilon.d1 - fd_epsilon.d1).abs() < 5e-5);
2699        assert!((out.djet_depsilon.d2 - fd_epsilon.d2).abs() < 1.2e-4);
2700        assert!((out.djet_depsilon.d3 - fd_epsilon.d3).abs() < 4e-4);
2701    }
2702
2703    #[test]
2704    fn beta_logistic_left_tail_uses_unclamped_log_space() {
2705        let eta = -40.0_f64;
2706        let delta = 0.2_f64;
2707        let epsilon = -0.1_f64;
2708        let a = (delta - epsilon).exp();
2709        let b = (delta + epsilon).exp();
2710        let expected_mu = beta_reg(a, b, eta.exp());
2711        let out = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2712
2713        assert!(
2714            (out.mu - expected_mu).abs() <= 1e-12 * expected_mu.abs().max(f64::MIN_POSITIVE),
2715            "left-tail mu mismatch: got {}, expected {}",
2716            out.mu,
2717            expected_mu
2718        );
2719        assert!(out.d1 > 0.0);
2720        assert!(out.d2 > 0.0);
2721        assert!(out.d3 > 0.0);
2722        assert!(out.d1 < 1e-20);
2723
2724        let partials = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2725        assert!(partials.jet.d1 > 0.0);
2726        assert!(partials.jet.d2 > 0.0);
2727        assert!(partials.jet.d3 > 0.0);
2728        assert!(partials.djet_dlog_delta.d1.is_finite());
2729        assert!(partials.djet_depsilon.d1.is_finite());
2730    }
2731
2732    #[test]
2733    fn beta_logistic_mu_is_symmetric_in_logistic_tails() {
2734        let delta = 0.2;
2735        let epsilon = -0.35;
2736        let etas = [-40.0, -30.0, -5.0, -0.42, 0.0, 0.42, 5.0, 30.0, 40.0];
2737        for eta in etas {
2738            let left = beta_logistic_inverse_link_jet(eta, delta, epsilon).mu;
2739            let right = 1.0 - beta_logistic_inverse_link_jet(-eta, delta, -epsilon).mu;
2740            assert!(
2741                (left - right).abs() <= 1e-14,
2742                "symmetry mismatch at eta={eta}: left={left}, right={right}"
2743            );
2744        }
2745    }
2746
2747    #[test]
2748    fn inverse_link_pdfthird_derivative_matches_d3_finite_difference() {
2749        let sas = InverseLink::Sas(sas_link_state_from_raw(-0.25, 0.35).expect("sas state"));
2750        let beta_logistic = InverseLink::BetaLogistic(SasLinkState {
2751            epsilon: 0.18,
2752            log_delta: -0.22,
2753            delta: (-0.22_f64).exp(),
2754        });
2755        let mixture = InverseLink::Mixture(
2756            state_fromspec(&MixtureLinkSpec {
2757                components: vec![
2758                    LinkComponent::Probit,
2759                    LinkComponent::Logit,
2760                    LinkComponent::CLogLog,
2761                    LinkComponent::Cauchit,
2762                ],
2763                initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
2764            })
2765            .expect("mixture state"),
2766        );
2767        let links = [
2768            InverseLink::Standard(StandardLink::Probit),
2769            InverseLink::Standard(StandardLink::Logit),
2770            InverseLink::Standard(StandardLink::CLogLog),
2771            sas,
2772            beta_logistic,
2773            mixture,
2774        ];
2775        let etas = [-1.1, -0.2, 0.6];
2776        let h = 1e-5;
2777
2778        for link in &links {
2779            for &eta in &etas {
2780                let jp = inverse_link_jet_for_inverse_link(link, eta + h).expect("jet+");
2781                let jm = inverse_link_jet_for_inverse_link(link, eta - h).expect("jet-");
2782                let d4fd = (jp.d3 - jm.d3) / (2.0 * h);
2783                let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)
2784                    .expect("analytic d4");
2785                assert_eq!(
2786                    d4.signum(),
2787                    d4fd.signum(),
2788                    "d4 sign mismatch for {:?} at eta={eta}: analytic={} fd={}",
2789                    link,
2790                    d4,
2791                    d4fd
2792                );
2793                assert!(
2794                    (d4 - d4fd).abs() < 5e-3,
2795                    "d4 mismatch for {:?} at eta={eta}: analytic={} fd={}",
2796                    link,
2797                    d4,
2798                    d4fd
2799                );
2800            }
2801        }
2802    }
2803
2804    #[test]
2805    fn cloglog_large_finite_eta_should_saturate_without_nan_derivatives() {
2806        let eta = 800.0;
2807        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2808        assert_eq!(jet.mu, 1.0);
2809        assert!(
2810            jet.d1 == 0.0,
2811            "for mu(eta)=1-exp(-exp(eta)), dmu/deta = exp(eta-exp(eta)) and should underflow to 0 at eta={eta}; got d1={}",
2812            jet.d1
2813        );
2814        assert!(
2815            jet.d2 == 0.0,
2816            "the saturated cloglog second derivative should also be 0 at eta={eta}; got d2={}",
2817            jet.d2
2818        );
2819        assert!(
2820            jet.d3 == 0.0,
2821            "the saturated cloglog third derivative should also be 0 at eta={eta}; got d3={}",
2822            jet.d3
2823        );
2824
2825        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2826            &InverseLink::Standard(StandardLink::CLogLog),
2827            eta,
2828        )
2829        .expect("cloglog d4");
2830        assert!(
2831            d4 == 0.0,
2832            "the saturated cloglog fourth derivative should also be 0 at eta={eta}; got d4={d4}"
2833        );
2834    }
2835
2836    #[test]
2837    fn loglog_large_negative_finite_eta_should_saturate_without_nan_derivatives() {
2838        let eta = -800.0;
2839        let jet = component_inverse_link_jet(LinkComponent::LogLog, eta);
2840        assert_eq!(jet.mu, 0.0);
2841        assert!(
2842            jet.d1 == 0.0,
2843            "for mu(eta)=exp(-exp(-eta)), dmu/deta = exp(-eta-exp(-eta)) and should underflow to 0 at eta={eta}; got d1={}",
2844            jet.d1
2845        );
2846        assert!(
2847            jet.d2 == 0.0,
2848            "the saturated loglog second derivative should also be 0 at eta={eta}; got d2={}",
2849            jet.d2
2850        );
2851        assert!(
2852            jet.d3 == 0.0,
2853            "the saturated loglog third derivative should also be 0 at eta={eta}; got d3={}",
2854            jet.d3
2855        );
2856
2857        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2858            &InverseLink::Mixture(
2859                state_fromspec(&MixtureLinkSpec {
2860                    components: vec![LinkComponent::LogLog, LinkComponent::Probit],
2861                    initial_rho: Array1::from_vec(vec![12.0]),
2862                })
2863                .expect("mixture state"),
2864            ),
2865            eta,
2866        )
2867        .expect("loglog mixture d4");
2868        assert!(
2869            d4.is_finite(),
2870            "even a nearly pure loglog mixture should not produce NaN fourth derivatives at eta={eta}; got d4={d4}"
2871        );
2872    }
2873
2874    #[test]
2875    fn logit_tail_derivatives_should_match_stable_closed_forms() {
2876        let eta = 50.0_f64;
2877        let z = (-eta).exp();
2878        let denom = 1.0_f64 + z;
2879        let stable_d1 = z / denom.powi(2);
2880        let stable_d2 = z * (z - 1.0) / denom.powi(3);
2881        let stable_d3 = z * (z * z - 4.0 * z + 1.0) / denom.powi(4);
2882        let stable_d4 = z * (z * z * z - 11.0 * z * z + 11.0 * z - 1.0) / denom.powi(5);
2883        let stable_d5 =
2884            z * (z * z * z * z - 26.0 * z * z * z + 66.0 * z * z - 26.0 * z + 1.0) / denom.powi(6);
2885
2886        assert!(stable_d1 > 0.0);
2887        assert!(stable_d2 < 0.0);
2888        assert!(stable_d3 > 0.0);
2889        assert!(stable_d4 < 0.0);
2890        assert!(stable_d5 > 0.0);
2891
2892        let jet = component_inverse_link_jet(LinkComponent::Logit, eta);
2893        assert!(
2894            (jet.d1 - stable_d1).abs() < 1e-30,
2895            "logit d1 should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
2896            jet.d1,
2897            stable_d1
2898        );
2899        assert!(
2900            (jet.d2 - stable_d2).abs() < 1e-30,
2901            "logit d2 should equal the stable tail formula z(z-1)/(1+z)^3 at eta={eta}; got {} vs {}",
2902            jet.d2,
2903            stable_d2
2904        );
2905        assert!(
2906            (jet.d3 - stable_d3).abs() < 1e-30,
2907            "logit d3 should equal the stable tail formula z(z^2-4z+1)/(1+z)^4 at eta={eta}; got {} vs {}",
2908            jet.d3,
2909            stable_d3
2910        );
2911
2912        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2913            &InverseLink::Standard(StandardLink::Logit),
2914            eta,
2915        )
2916        .expect("logit d4");
2917        assert!(
2918            (d4 - stable_d4).abs() < 1e-30,
2919            "logit d4 should equal the stable tail formula z(z^3-11z^2+11z-1)/(1+z)^5 at eta={eta}; got {} vs {}",
2920            d4,
2921            stable_d4
2922        );
2923
2924        let d5 = inverse_link_pdffourth_derivative_for_inverse_link(
2925            &InverseLink::Standard(StandardLink::Logit),
2926            eta,
2927        )
2928        .expect("logit d5");
2929        assert!(
2930            (d5 - stable_d5).abs() < 1e-30,
2931            "logit d5 should equal the stable tail formula z(z^4-26z^3+66z^2-26z+1)/(1+z)^6 at eta={eta}; got {} vs {}",
2932            d5,
2933            stable_d5
2934        );
2935    }
2936
2937    #[test]
2938    fn cloglog_negative_tail_value_should_match_expm1_form() {
2939        let eta = -50.0_f64;
2940        let t = eta.exp();
2941        let stable_mu = -(-t).exp_m1();
2942        assert!(stable_mu > 0.0);
2943
2944        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2945        assert!(
2946            (jet.mu - stable_mu).abs() < 1e-30,
2947            "cloglog mu should equal -expm1(-exp(eta)) in the negative tail at eta={eta}; got {} vs {}",
2948            jet.mu,
2949            stable_mu
2950        );
2951    }
2952
2953    #[test]
2954    fn non_logit_probit_fisher_weight_jets_match_finite_differences() {
2955        fn rel_err(a: f64, b: f64) -> f64 {
2956            (a - b).abs() / a.abs().max(b.abs()).max(1.0e-8)
2957        }
2958
2959        let cases = [
2960            (LinkComponent::CLogLog, [-3.0_f64, -0.5, 0.4, 1.5]),
2961            (LinkComponent::LogLog, [-1.5_f64, -0.4, 0.5, 3.0]),
2962            (LinkComponent::Cauchit, [-3.0_f64, -0.7, 0.6, 3.0]),
2963        ];
2964        for (component, etas) in cases {
2965            for eta in etas {
2966                let (w, w1, w2, w3, w4) = component_fisher_weight_jet5(component, eta);
2967                let jet = component_inverse_link_jet(component, eta);
2968                let expected = jet.d1 * jet.d1 / (jet.mu * (1.0 - jet.mu));
2969                assert!(
2970                    rel_err(w, expected) < 1.0e-12,
2971                    "{component:?} Fisher weight mismatch at eta={eta}: got {w}, expected {expected}"
2972                );
2973
2974                let h = 1.0e-4;
2975                let fd1 = (component_fisher_weight_jet5(component, eta + h).0
2976                    - component_fisher_weight_jet5(component, eta - h).0)
2977                    / (2.0 * h);
2978                let fd2 = (component_fisher_weight_jet5(component, eta + h).1
2979                    - component_fisher_weight_jet5(component, eta - h).1)
2980                    / (2.0 * h);
2981                let fd3 = (component_fisher_weight_jet5(component, eta + h).2
2982                    - component_fisher_weight_jet5(component, eta - h).2)
2983                    / (2.0 * h);
2984                let fd4 = (component_fisher_weight_jet5(component, eta + h).3
2985                    - component_fisher_weight_jet5(component, eta - h).3)
2986                    / (2.0 * h);
2987
2988                assert!(
2989                    rel_err(w1, fd1) < 1.0e-5,
2990                    "{component:?} W' mismatch at eta={eta}: {w1} vs {fd1}"
2991                );
2992                assert!(
2993                    rel_err(w2, fd2) < 1.0e-5,
2994                    "{component:?} W'' mismatch at eta={eta}: {w2} vs {fd2}"
2995                );
2996                assert!(
2997                    rel_err(w3, fd3) < 5.0e-5,
2998                    "{component:?} W''' mismatch at eta={eta}: {w3} vs {fd3}"
2999                );
3000                assert!(
3001                    rel_err(w4, fd4) < 5.0e-4,
3002                    "{component:?} W'''' mismatch at eta={eta}: {w4} vs {fd4}"
3003                );
3004            }
3005        }
3006    }
3007
3008    #[test]
3009    fn mixture_fisher_weight_jet_covers_loglog_and_cauchit_components() {
3010        let state = state_fromspec(&MixtureLinkSpec {
3011            components: vec![
3012                LinkComponent::CLogLog,
3013                LinkComponent::LogLog,
3014                LinkComponent::Cauchit,
3015            ],
3016            initial_rho: Array1::from_vec(vec![0.3, -0.2]),
3017        })
3018        .expect("mixture state");
3019        let link = InverseLink::Mixture(state);
3020        assert!(
3021            inverse_link_has_fisher_weight_jet(&link),
3022            "anchored mixtures with loglog/cauchit components must remain eligible for Firth"
3023        );
3024        assert!(
3025            LikelihoodSpec::new(ResponseFamily::Binomial, link.clone()).supports_firth(),
3026            "Firth support should use the mixture inverse-link Fisher jet, not standalone LinkFunction coverage"
3027        );
3028
3029        for eta in [-2.0_f64, -0.25, 0.75, 2.5] {
3030            let (w, w1, w2, w3, w4) =
3031                fisher_weight_jet5_for_inverse_link(&link, eta).expect("mixture Fisher jet");
3032            for value in [w, w1, w2, w3, w4] {
3033                assert!(
3034                    value.is_finite(),
3035                    "mixture Fisher weight jet should be finite at eta={eta}; got {value}"
3036                );
3037            }
3038            assert!(
3039                w > 0.0,
3040                "mixture Fisher working weight should be positive away from saturated tails at eta={eta}; got {w}"
3041            );
3042        }
3043    }
3044
3045    #[test]
3046    fn loglog_fifth_derivative_should_match_closed_form_sign() {
3047        let eta = 0.0_f64;
3048        let r = (-eta).exp();
3049        let expected =
3050            (-r).exp() * (r - 15.0 * r * r + 25.0 * r.powi(3) - 10.0 * r.powi(4) + r.powi(5));
3051        let d5 = component_inverse_link_pdffourth_derivative(LinkComponent::LogLog, eta);
3052        assert!(
3053            (d5 - expected).abs() < 1e-15,
3054            "loglog d5 should equal exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5) at eta={eta}; got {d5} vs {expected}"
3055        );
3056        assert!(d5 > 0.0, "loglog d5 should be positive at eta=0; got {d5}");
3057    }
3058}