Skip to main content

gam_solve/
mixture_link.rs

1use crate::estimate::EstimationError;
2use crate::quadrature::latent_cloglog_jet5;
3use gam_math::probability::{normal_cdf, normal_pdf};
4use gam_math::special::stable_polynomial_times_exp_neg as stable_nonnegative_poly_times_exp_neg;
5use gam_problem::{
6    InverseLink, LatentCLogLogState, LikelihoodSpec, LinkComponent, LinkFunction, MixtureLinkSpec,
7    MixtureLinkState, ResponseFamily, SasLinkSpec, SasLinkState, StandardLink,
8};
9use ndarray::Array1;
10use statrs::function::beta::{beta_reg, ln_beta};
11use statrs::function::gamma::digamma;
12use std::ops::Neg;
13use std::sync::OnceLock;
14
/// Clamp magnitude associated with the SAS link.
// NOTE(review): the use site is outside this view; presumably it bounds an
// intermediate `u` argument to `[-50, 50]` before exponentiation — confirm.
const SAS_U_CLAMP: f64 = 50.0;
/// Bound B used by the bounded sinh-arcsinh log-delta parameterisation:
/// `delta = exp(B * tanh(raw_log_delta / B))`. Exposed for the outer-strategy
/// edge-barrier helpers in `solver/estimate.rs` that previously had to
/// hard-code the same `12.0` with a "must match" comment.
pub(crate) const SAS_LOG_DELTA_BOUND: f64 = 12.0;
21
22#[inline]
23fn latent_cloglog_quadctx() -> &'static crate::quadrature::QuadratureContext {
24    static QUADCTX: OnceLock<crate::quadrature::QuadratureContext> = OnceLock::new();
25    QUADCTX.get_or_init(crate::quadrature::QuadratureContext::new)
26}
27
28#[inline]
29fn latent_cloglog_point_jet(
30    state: &LatentCLogLogState,
31    eta: f64,
32) -> Result<InverseLinkJet, EstimationError> {
33    let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, state.latent_sd)?;
34    Ok(InverseLinkJet {
35        mu: jet.mean,
36        d1: jet.d1,
37        d2: jet.d2,
38        d3: jet.d3,
39    })
40}
41
/// Pointwise inverse-link jet: the value `mu(eta)` and its first three
/// derivatives with respect to `eta`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InverseLinkJet {
    /// Inverse-link value `mu(eta)`.
    pub mu: f64,
    /// First derivative `mu'(eta)`.
    pub d1: f64,
    /// Second derivative `mu''(eta)`.
    pub d2: f64,
    /// Third derivative `mu'''(eta)`.
    pub d3: f64,
}
49
/// Logistic inverse-link jet carried to fifth order: the value `mu(eta)` and
/// its first five derivatives with respect to `eta`, as produced by
/// `logit_inverse_link_jet5`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LogitJet5 {
    /// Inverse-link value `mu(eta)`.
    pub mu: f64,
    /// First derivative.
    pub d1: f64,
    /// Second derivative.
    pub d2: f64,
    /// Third derivative.
    pub d3: f64,
    /// Fourth derivative.
    pub d4: f64,
    /// Fifth derivative.
    pub d5: f64,
}
59
/// Flush magnitudes below the smallest positive normal `f64` (subnormals and
/// signed zeros) to `+0.0`; every other value, NaN included, passes through.
#[inline]
fn canonicalzero(v: f64) -> f64 {
    let is_subnormal_or_zero = v.abs() < f64::MIN_POSITIVE;
    if is_subnormal_or_zero {
        return 0.0;
    }
    v
}
64
65#[inline]
66fn canonicalize_jet(mut jet: InverseLinkJet) -> InverseLinkJet {
67    jet.d1 = canonicalzero(jet.d1);
68    jet.d2 = canonicalzero(jet.d2);
69    jet.d3 = canonicalzero(jet.d3);
70    jet
71}
72
73#[inline]
74pub fn logit_inverse_link_jet5(eta: f64) -> LogitJet5 {
75    if eta.is_nan() {
76        return LogitJet5 {
77            mu: f64::NAN,
78            d1: f64::NAN,
79            d2: f64::NAN,
80            d3: f64::NAN,
81            d4: f64::NAN,
82            d5: f64::NAN,
83        };
84    }
85    if eta == f64::INFINITY {
86        return LogitJet5 {
87            mu: 1.0,
88            d1: 0.0,
89            d2: 0.0,
90            d3: 0.0,
91            d4: 0.0,
92            d5: 0.0,
93        };
94    }
95    if eta == f64::NEG_INFINITY {
96        return LogitJet5 {
97            mu: 0.0,
98            d1: 0.0,
99            d2: 0.0,
100            d3: 0.0,
101            d4: 0.0,
102            d5: 0.0,
103        };
104    }
105
106    let jet = if eta >= 0.0 {
107        let z = (-eta).exp();
108        let opz = 1.0 + z;
109        let opz2 = opz * opz;
110        let opz3 = opz2 * opz;
111        let opz4 = opz3 * opz;
112        let opz5 = opz4 * opz;
113        let opz6 = opz5 * opz;
114        let z2 = z * z;
115        let z3 = z2 * z;
116        let z4 = z3 * z;
117        LogitJet5 {
118            mu: 1.0 / opz,
119            d1: z / opz2,
120            d2: z * (z - 1.0) / opz3,
121            d3: z * (z2 - 4.0 * z + 1.0) / opz4,
122            d4: z * (z3 - 11.0 * z2 + 11.0 * z - 1.0) / opz5,
123            d5: z * (z4 - 26.0 * z3 + 66.0 * z2 - 26.0 * z + 1.0) / opz6,
124        }
125    } else {
126        let z = eta.exp();
127        let opz = 1.0 + z;
128        let opz2 = opz * opz;
129        let opz3 = opz2 * opz;
130        let opz4 = opz3 * opz;
131        let opz5 = opz4 * opz;
132        let opz6 = opz5 * opz;
133        let z2 = z * z;
134        let z3 = z2 * z;
135        let z4 = z3 * z;
136        LogitJet5 {
137            mu: z / opz,
138            d1: z / opz2,
139            d2: z * (1.0 - z) / opz3,
140            d3: z * (1.0 - 4.0 * z + z2) / opz4,
141            d4: z * (1.0 - 11.0 * z + 11.0 * z2 - z3) / opz5,
142            d5: z * (1.0 - 26.0 * z + 66.0 * z2 - 26.0 * z3 + z4) / opz6,
143        }
144    };
145    LogitJet5 {
146        mu: jet.mu,
147        d1: canonicalzero(jet.d1),
148        d2: canonicalzero(jet.d2),
149        d3: canonicalzero(jet.d3),
150        d4: canonicalzero(jet.d4),
151        d5: canonicalzero(jet.d5),
152    }
153}
154
155#[inline]
156fn probit_jet(eta: f64) -> InverseLinkJet {
157    // Exact probit semantics:
158    //
159    //   mu(eta) = Phi(eta),
160    //   mu'     = phi(eta),
161    //   mu''    = -eta * phi(eta),
162    //   mu'''   = (eta^2 - 1) * phi(eta).
163    //
164    // `normal_cdf` now evaluates the exact special-function form
165    // Phi(x) = 0.5 * erfc(-x / sqrt(2)), so the jet can and should use the
166    // matching closed-form Gaussian identities directly.
167    if eta.is_nan() {
168        return InverseLinkJet {
169            mu: f64::NAN,
170            d1: f64::NAN,
171            d2: f64::NAN,
172            d3: f64::NAN,
173        };
174    }
175    if eta == f64::INFINITY {
176        return InverseLinkJet {
177            mu: 1.0,
178            d1: 0.0,
179            d2: 0.0,
180            d3: 0.0,
181        };
182    }
183    if eta == f64::NEG_INFINITY {
184        return InverseLinkJet {
185            mu: 0.0,
186            d1: 0.0,
187            d2: 0.0,
188            d3: 0.0,
189        };
190    }
191    let x = eta;
192    let phi = normal_pdf(x);
193    InverseLinkJet {
194        mu: normal_cdf(x),
195        d1: phi,
196        d2: -x * phi,
197        d3: (x * x - 1.0) * phi,
198    }
199}
200
201#[inline]
202fn probit_pdfthird_derivative(eta: f64) -> f64 {
203    // Since d1 = mu' = phi(eta), this returns
204    //
205    //   d³/deta³ d1 = mu'''' = -(eta³ - 3 eta) phi(eta).
206    if eta.is_nan() {
207        return f64::NAN;
208    }
209    if !eta.is_finite() {
210        return 0.0;
211    }
212    let x = eta;
213    let phi = normal_pdf(x);
214    canonicalzero(-(x * x * x - 3.0 * x) * phi)
215}
216
217#[inline]
218fn probit_pdffourth_derivative(eta: f64) -> f64 {
219    // mu''''' = Phi^{(5)}(eta) = (eta^4 - 6*eta^2 + 3) * phi(eta).
220    if eta.is_nan() {
221        return f64::NAN;
222    }
223    if !eta.is_finite() {
224        return 0.0;
225    }
226    let x = eta;
227    let phi = normal_pdf(x);
228    canonicalzero((x * x * x * x - 6.0 * x * x + 3.0) * phi)
229}
230
/// Truncated product of two 5-term Taylor series.
///
/// Inputs and output hold coefficients `a_k = g^(k)/k!` for `k = 0..=4`.
/// Terms whose left factor is exactly zero are skipped, so a zero coefficient
/// in `a` never combines with a non-finite coefficient in `b`.
#[inline]
fn taylor5_mul(a: &[f64; 5], b: &[f64; 5]) -> [f64; 5] {
    std::array::from_fn(|order| {
        // Cauchy product for this order, accumulated in ascending index of `a`.
        (0..=order)
            .filter(|&i| a[i] != 0.0)
            .fold(0.0_f64, |acc, i| acc + a[i] * b[order - i])
    })
}
247
/// Reciprocal of a 5-term truncated Taylor series.
///
/// `a[0]` is expected to be nonzero. The coefficients follow the recurrence
/// `b_0 = 1/a_0`, `b_k = -b_0 * sum_{j=1..=k} a_j * b_{k-j}`.
#[inline]
fn taylor5_inv(a: &[f64; 5]) -> [f64; 5] {
    let lead = 1.0 / a[0];
    let mut inverse = [lead, 0.0, 0.0, 0.0, 0.0];
    for order in 1..5 {
        let convolution =
            (1..=order).fold(0.0_f64, |acc, j| acc + a[j] * inverse[order - j]);
        inverse[order] = -convolution * lead;
    }
    inverse
}
262
263/// 5-jet (value + four eta-derivatives) of the GLM Fisher working weight
264/// `W(eta) = mu'(eta)^2 / V(mu(eta))` for the requested standard link, returned
265/// as `(W, W', W'', W''', W'''')`.
266///
267/// For the canonical logit link this is exactly the binomial weight
268/// `W = mu(1 - mu) = mu'`, whose eta-derivatives are the higher derivatives of
269/// the inverse-link jet (`W^(k) = mu^(k+1)`); the dispatch returns
270/// `logit_inverse_link_jet5`'s `d1..d5` byte-for-byte so the existing Firth
271/// logit path is numerically unchanged.
272///
273/// Noncanonical Bernoulli links use the same truncated Taylor-series quotient:
274/// assemble the inverse-link jet through `mu^(5)`, square the `mu'` series, and
275/// divide by the Bernoulli variance series `mu(1-mu)`. As the variance
276/// denominator saturates to zero in either tail, the weight and all derivatives
277/// saturate to zero, matching the inverse-link jet convention.
278pub(crate) fn fisher_weight_jet5(link: StandardLink, eta: f64) -> (f64, f64, f64, f64, f64) {
279    match link {
280        StandardLink::Logit => {
281            let jet = logit_inverse_link_jet5(eta);
282            (jet.d1, jet.d2, jet.d3, jet.d4, jet.d5)
283        }
284        StandardLink::Probit => probit_fisher_weight_jet5(eta),
285        StandardLink::CLogLog => component_fisher_weight_jet5(LinkComponent::CLogLog, eta),
286        StandardLink::LogLog => component_fisher_weight_jet5(LinkComponent::LogLog, eta),
287        StandardLink::Cauchit => component_fisher_weight_jet5(LinkComponent::Cauchit, eta),
288        StandardLink::Identity | StandardLink::Log => (0.0, 0.0, 0.0, 0.0, 0.0),
289    }
290}
291
292pub(crate) fn fisher_weight_jet5_for_inverse_link(
293    link: &InverseLink,
294    eta: f64,
295) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
296    match link {
297        InverseLink::Standard(link) => Ok(fisher_weight_jet5(*link, eta)),
298        InverseLink::LatentCLogLog(_)
299        | InverseLink::Sas(_)
300        | InverseLink::BetaLogistic(_)
301        | InverseLink::Mixture(_) => {
302            let jet = link.jet(eta)?;
303            let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)?;
304            let d5 = inverse_link_pdffourth_derivative_for_inverse_link(link, eta)?;
305            Ok(fisher_weight_jet5_from_inverse_link_derivatives(
306                jet.mu, jet.d1, jet.d2, jet.d3, d4, d5,
307            ))
308        }
309    }
310}
311
312#[inline]
313pub(crate) fn inverse_link_has_fisher_weight_jet(link: &InverseLink) -> bool {
314    // Every standard binomial probability link exposes a full 5-jet Fisher weight
315    // via `fisher_weight_jet5` — LogLog and Cauchit included (their d1..d5 close in
316    // the same stable `poly·exp(-·)` / rational forms as CLogLog/Probit). The gate
317    // must therefore admit them: excluding LogLog/Cauchit here (while the jet, the
318    // gam-spec classifier, `reml_jeffreys_supported_link`, and `is_legal_cell` all
319    // support them) would refuse Firth/Jeffreys on a fully-implemented link.
320    matches!(
321        link,
322        InverseLink::Standard(
323            StandardLink::Logit
324                | StandardLink::Probit
325                | StandardLink::CLogLog
326                | StandardLink::LogLog
327                | StandardLink::Cauchit,
328        )
329            | InverseLink::LatentCLogLog(_)
330            | InverseLink::Sas(_)
331            | InverseLink::BetaLogistic(_)
332            | InverseLink::Mixture(_)
333    )
334}
335
336#[inline]
337fn component_fisher_weight_jet5(component: LinkComponent, eta: f64) -> (f64, f64, f64, f64, f64) {
338    let jet = component_inverse_link_jet(component, eta);
339    let d4 = component_inverse_link_pdfthird_derivative(component, eta);
340    let d5 = component_inverse_link_pdffourth_derivative(component, eta);
341    fisher_weight_jet5_from_inverse_link_derivatives(jet.mu, jet.d1, jet.d2, jet.d3, d4, d5)
342}
343
344#[inline]
345fn fisher_weight_jet5_from_inverse_link_derivatives(
346    mu: f64,
347    d1: f64,
348    d2: f64,
349    d3: f64,
350    d4: f64,
351    d5: f64,
352) -> (f64, f64, f64, f64, f64) {
353    if [mu, d1, d2, d3, d4, d5].iter().any(|v| v.is_nan()) {
354        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
355    }
356    let variance = mu * (1.0 - mu);
357    if !(variance > 0.0) || !variance.is_finite() {
358        return (0.0, 0.0, 0.0, 0.0, 0.0);
359    }
360
361    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
362    let mu_d = [mu, d1, d2, d3, d4];
363    let one_minus_mu_d = [1.0 - mu, -d1, -d2, -d3, -d4];
364    let dmu_d = [d1, d2, d3, d4, d5];
365    let mut mu_t = [0.0_f64; 5];
366    let mut one_minus_mu_t = [0.0_f64; 5];
367    let mut dmu_t = [0.0_f64; 5];
368    for k in 0..5 {
369        let inv_fact = 1.0 / factorial[k];
370        mu_t[k] = mu_d[k] * inv_fact;
371        one_minus_mu_t[k] = one_minus_mu_d[k] * inv_fact;
372        dmu_t[k] = dmu_d[k] * inv_fact;
373    }
374    let num_t = taylor5_mul(&dmu_t, &dmu_t);
375    let den_t = taylor5_mul(&mu_t, &one_minus_mu_t);
376    if !(den_t[0] > 0.0) || !den_t[0].is_finite() {
377        return (0.0, 0.0, 0.0, 0.0, 0.0);
378    }
379    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
380    (
381        canonicalzero(w_t[0] * factorial[0]),
382        canonicalzero(w_t[1] * factorial[1]),
383        canonicalzero(w_t[2] * factorial[2]),
384        canonicalzero(w_t[3] * factorial[3]),
385        canonicalzero(w_t[4] * factorial[4]),
386    )
387}
388
389/// Probit Bernoulli Fisher-weight 5-jet `W = phi^2 / (Phi (1 - Phi))` and its
390/// first four eta-derivatives. See [`fisher_weight_jet5`].
391#[inline]
392fn probit_fisher_weight_jet5(eta: f64) -> (f64, f64, f64, f64, f64) {
393    if eta.is_nan() {
394        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
395    }
396    if !eta.is_finite() {
397        return (0.0, 0.0, 0.0, 0.0, 0.0);
398    }
399    let x = eta;
400    let p = normal_cdf(x);
401    // Compute the complement directly via Phi(-x) rather than `1 - Phi(x)`:
402    // in the positive tail `Phi(x)` rounds to 1.0 and `1 - Phi(x)` cancels to
403    // zero, whereas `Phi(-x)` retains the accurate (tiny) tail mass.
404    let q = normal_cdf(-x);
405    let phi = normal_pdf(x);
406    // Saturated tail: the denominator Phi(1-Phi) has underflowed to zero (or
407    // would divide by zero); the working weight and all derivatives go to zero.
408    if !(p > 0.0) || !(q > 0.0) || p * q <= 0.0 {
409        return (0.0, 0.0, 0.0, 0.0, 0.0);
410    }
411    // Gaussian derivative ladder: phi^(k) for k = 0..=4 using phi' = -x phi.
412    let phi1 = -x * phi;
413    let phi2 = (x * x - 1.0) * phi;
414    let phi3 = -(x * x * x - 3.0 * x) * phi;
415    let phi4 = (x * x * x * x - 6.0 * x * x + 3.0) * phi;
416    // Derivative arrays (d^k/deta^k) for f = phi, p = Phi, q = 1 - Phi.
417    // p^(0) = Phi, p^(k>=1) = phi^(k-1); q is the negated complement.
418    let f_d = [phi, phi1, phi2, phi3, phi4];
419    let p_d = [p, phi, phi1, phi2, phi3];
420    let q_d = [q, -phi, -phi1, -phi2, -phi3];
421    // Convert derivative arrays to Taylor coefficients a_k = g^(k)/k!.
422    let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
423    let mut f_t = [0.0_f64; 5];
424    let mut p_t = [0.0_f64; 5];
425    let mut q_t = [0.0_f64; 5];
426    for k in 0..5 {
427        let inv_fact = 1.0 / factorial[k];
428        f_t[k] = f_d[k] * inv_fact;
429        p_t[k] = p_d[k] * inv_fact;
430        q_t[k] = q_d[k] * inv_fact;
431    }
432    let num_t = taylor5_mul(&f_t, &f_t);
433    let den_t = taylor5_mul(&p_t, &q_t);
434    let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
435    // Back to derivatives W^(k) = w_t[k] * k!.
436    (
437        canonicalzero(w_t[0] * factorial[0]),
438        canonicalzero(w_t[1] * factorial[1]),
439        canonicalzero(w_t[2] * factorial[2]),
440        canonicalzero(w_t[3] * factorial[3]),
441        canonicalzero(w_t[4] * factorial[4]),
442    )
443}
444
445#[inline]
446fn chain_inverse_link_jet(base: InverseLinkJet, z1: f64, z2: f64, z3: f64) -> InverseLinkJet {
447    InverseLinkJet {
448        mu: base.mu,
449        d1: base.d1 * z1,
450        d2: base.d2 * z1 * z1 + base.d1 * z2,
451        d3: base.d3 * z1 * z1 * z1 + 3.0 * base.d2 * z1 * z2 + base.d1 * z3,
452    }
453}
454
455#[inline]
456fn component_inverse_link_pdfthird_derivative(component: LinkComponent, eta: f64) -> f64 {
457    match component {
458        LinkComponent::Probit => probit_pdfthird_derivative(eta),
459        LinkComponent::Logit => logit_inverse_link_jet5(eta).d4,
460        LinkComponent::CLogLog => {
461            // CLogLog link:
462            //   mu = 1 - exp(-t),  t = exp(eta),  d1 = t exp(-t).
463            //
464            // Repeated differentiation closes in the basis `d1 * poly(t)`:
465            //   d2 = d1(-t + 1)
466            //   d3 = d1(t² - 3t + 1)
467            //   d4 = d1(-t³ + 6t² - 7t + 1).
468            if eta.is_nan() {
469                return f64::NAN;
470            }
471            if !eta.is_finite() {
472                return 0.0;
473            }
474            let t = eta.exp();
475            canonicalzero(stable_nonnegative_poly_times_exp_neg(
476                t,
477                &[0.0, 1.0, -7.0, 6.0, -1.0],
478            ))
479        }
480        LinkComponent::LogLog => {
481            // LogLog link is the reflected cloglog family with `r = exp(-eta)`:
482            //   mu = exp(-r), d1 = mu r,
483            // and again higher derivatives are `d1 * poly(r)`:
484            //   d2 = d1(r - 1)
485            //   d3 = d1(r² - 3r + 1)
486            //   d4 = d1(r³ - 6r² + 7r - 1).
487            if eta.is_nan() {
488                return f64::NAN;
489            }
490            if !eta.is_finite() {
491                return 0.0;
492            }
493            let r = (-eta).exp();
494            canonicalzero(stable_nonnegative_poly_times_exp_neg(
495                r,
496                &[0.0, -1.0, 7.0, -6.0, 1.0],
497            ))
498        }
499        LinkComponent::Cauchit => {
500            // Cauchit link:
501            //   mu = 1/2 + atan(eta)/pi,
502            //   d1 = 1 / [pi (1+eta²)].
503            //
504            // Differentiating three more times gives
505            //
506            //   d4 = 24 eta (1-eta²) / [pi (1+eta²)^4].
507            if eta.is_nan() {
508                return f64::NAN;
509            }
510            if !eta.is_finite() {
511                return 0.0;
512            }
513            let denom = 1.0 + eta * eta;
514            24.0 * eta * (1.0 - eta * eta) / (std::f64::consts::PI * denom.powi(4))
515        }
516    }
517}
518
519/// Fifth derivative of a component inverse-link CDF (= fourth derivative of PDF).
520/// Extends `component_inverse_link_pdfthird_derivative` by one derivative order.
521#[inline]
522fn component_inverse_link_pdffourth_derivative(component: LinkComponent, eta: f64) -> f64 {
523    match component {
524        LinkComponent::Probit => probit_pdffourth_derivative(eta),
525        LinkComponent::Logit => logit_inverse_link_jet5(eta).d5,
526        LinkComponent::CLogLog => {
527            // Exact closed form:
528            //   d5 = exp(-t) * (t - 15t^2 + 25t^3 - 10t^4 + t^5)
529            //      = d1 * (1 - 15t + 25t^2 - 10t^3 + t^4),
530            // where t = exp(eta).
531            if eta.is_nan() {
532                return f64::NAN;
533            }
534            if !eta.is_finite() {
535                return 0.0;
536            }
537            let t = eta.exp();
538            canonicalzero(stable_nonnegative_poly_times_exp_neg(
539                t,
540                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
541            ))
542        }
543        LinkComponent::LogLog => {
544            // Exact closed form:
545            //   d5 = exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5)
546            //      = d1 * (1 - 15r + 25r^2 - 10r^3 + r^4),
547            // where r = exp(-eta).
548            if eta.is_nan() {
549                return f64::NAN;
550            }
551            if !eta.is_finite() {
552                return 0.0;
553            }
554            let r = (-eta).exp();
555            canonicalzero(stable_nonnegative_poly_times_exp_neg(
556                r,
557                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
558            ))
559        }
560        LinkComponent::Cauchit => {
561            // d5 = 24(1 - 10eta^2 + 5eta^4) / [pi * (1+eta^2)^5]
562            if eta.is_nan() {
563                return f64::NAN;
564            }
565            if !eta.is_finite() {
566                return 0.0;
567            }
568            let e2 = eta * eta;
569            let denom = 1.0 + e2;
570            24.0 * (1.0 - 10.0 * e2 + 5.0 * e2 * e2) / (std::f64::consts::PI * denom.powi(5))
571        }
572    }
573}
574
/// Mixture inverse-link jet bundled with its partial derivatives with respect
/// to the free mixing logits.
#[derive(Clone, Debug, PartialEq)]
pub struct MixtureJetWithRhoPartials {
    /// The inverse-link jet itself.
    pub jet: InverseLinkJet,
    /// Partial derivatives wrt free logits rho_j, j in [0, K-2].
    /// Each entry stores derivatives of (mu, d1, d2, d3) wrt one rho_j.
    pub djet_drho: Vec<InverseLinkJet>,
}
582
/// SAS inverse-link jet bundled with parameter partials.
// NOTE(review): field meanings below are read off the field names; the code
// that fills them is outside this view — confirm.
#[derive(Clone, Debug, PartialEq)]
pub struct SasJetWithParamPartials {
    /// The inverse-link jet itself.
    pub jet: InverseLinkJet,
    /// Presumably the partial derivative of each jet entry wrt `epsilon`.
    pub djet_depsilon: InverseLinkJet,
    /// Presumably the partial derivative of each jet entry wrt `log_delta`.
    pub djet_dlog_delta: InverseLinkJet,
}
589
/// Parameter partials returned by [`InverseLinkKernel::param_partials`],
/// tagged by the link family that produced them.
#[derive(Clone, Debug, PartialEq)]
pub enum LinkParamPartials {
    /// Partials for a mixture link (wrt the free mixing logits).
    Mixture(MixtureJetWithRhoPartials),
    /// Partials for a SAS link.
    Sas(SasJetWithParamPartials),
}
595
/// Trait-based inverse-link kernel interface.
///
/// Implementors provide pointwise inverse-link derivatives wrt `eta`:
/// `F(eta), F'(eta), F''(eta), F'''(eta)`.
/// Optionally they may expose parameter partials used by outer-loop optimization.
pub trait InverseLinkKernel {
    /// Evaluate the inverse link and its first three `eta`-derivatives at `eta`.
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;

    /// Parameter partials of the jet at `eta`, if the link has free parameters.
    ///
    /// The default implementation returns `Ok(None)`.
    ///
    /// # Panics
    /// The default implementation panics if `eta` is not finite.
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        assert!(eta.is_finite(), "eta must be finite");
        Ok(None)
    }
}
609
// NOTE(review): the five unit structs below are named after the standard
// links; their `InverseLinkKernel` impls are outside this view — confirm.

/// Unit marker type for the probit link kernel.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ProbitLinkKernel;

/// Unit marker type for the logit link kernel.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogitLinkKernel;

/// Unit marker type for the complementary log-log link kernel.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CLogLogLinkKernel;

/// Unit marker type for the log-log link kernel.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogLogLinkKernel;

/// Unit marker type for the cauchit link kernel.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CauchitLinkKernel;
624
625/// Construct SAS state from raw optimizer parameters using the same bounded
626/// transform used everywhere in fitting/evaluation.
627///
628/// A free function rather than an inherent `SasLinkState::new` because the
629/// bounded `delta` transform is solver-side math, so the constructor is hosted
630/// here next to the transform rather than on the type. `SasLinkState`'s fields
631/// are `pub`, so it builds directly.
632pub fn sas_link_state_from_raw(
633    raw_epsilon: f64,
634    raw_log_delta: f64,
635) -> Result<SasLinkState, String> {
636    if !raw_epsilon.is_finite() || !raw_log_delta.is_finite() {
637        return Err("SAS link parameters must be finite".to_string());
638    }
639    Ok(SasLinkState {
640        epsilon: raw_epsilon,
641        log_delta: raw_log_delta,
642        delta: sas_delta_from_raw_log_delta(raw_log_delta),
643    })
644}
645
646pub fn state_from_sasspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
647    sas_link_state_from_raw(spec.initial_epsilon, spec.initial_log_delta)
648}
649
650pub fn state_from_beta_logisticspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
651    if !spec.initial_epsilon.is_finite() || !spec.initial_log_delta.is_finite() {
652        return Err("Beta-Logistic link parameters must be finite".to_string());
653    }
654    // For Beta-Logistic, `log_delta` is the unconstrained log geometric-mean beta
655    // shape (the kernels' `log_shape_center`). Evaluation consumes `log_delta`,
656    // never `delta`, but keep the shared `SasLinkState::delta` field on the same
657    // bounded SAS parameterization used by `state_from_sasspec` so constructing a
658    // state from a large finite raw log-delta cannot overflow this derived field.
659    let log_shape_center = spec.initial_log_delta;
660    Ok(SasLinkState {
661        epsilon: spec.initial_epsilon,
662        log_delta: log_shape_center,
663        delta: sas_delta_from_raw_log_delta(log_shape_center),
664    })
665}
666
/// Smoothly squash `value` into `(-B, B)` via `B * tanh(value / B)`, where
/// `B = max(bound, f64::EPSILON)`.
#[inline]
fn tanh_bound(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = f64::tanh(value / scale);
    scale * squashed
}
672
/// First derivative of `B * tanh(x / B)` at `x = value`: `1 - tanh(x/B)^2`,
/// with `B = max(bound, f64::EPSILON)`.
#[inline]
fn tanh_bound_d1(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let th_sq = th * th;
    1.0 - th_sq
}
679
/// Second derivative of `B * tanh(x / B)` at `x = value`:
/// `-2 t s / B` with `t = tanh(x/B)`, `s = 1 - t^2`,
/// `B = max(bound, f64::EPSILON)`.
#[inline]
fn tanh_bound_d2(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let sech_sq = 1.0 - th * th;
    let numerator = -2.0 * th * sech_sq;
    numerator / scale
}
687
/// Third derivative of `B * tanh(x / B)` at `x = value`:
/// `-2 s (1 - 3 t^2) / B^2` with `t = tanh(x/B)`, `s = 1 - t^2`,
/// `B = max(bound, f64::EPSILON)`.
#[inline]
fn tanh_bound_d3(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let sech_sq = 1.0 - th * th;
    let bracket = 1.0 - 3.0 * th * th;
    let scale_sq = scale * scale;
    -2.0 * sech_sq * bracket / scale_sq
}
695
/// Fourth derivative of `B * tanh(x / B)` at `x = value`:
/// `8 t s (2 - 3 t^2) / B^3` with `t = tanh(x/B)`, `s = 1 - t^2`,
/// `B = max(bound, f64::EPSILON)`.
#[inline]
fn tanh_bound_d4(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let sech_sq = 1.0 - th * th;
    let bracket = 2.0 - 3.0 * th * th;
    let scale_cubed = scale * scale * scale;
    8.0 * th * sech_sq * bracket / scale_cubed
}
703
/// Fifth derivative of `B * tanh(x / B)` at `x = value`:
/// `8 s (2 - 15 t^2 + 15 t^4) / B^4` with `t = tanh(x/B)`, `s = 1 - t^2`,
/// `B = max(bound, f64::EPSILON)`.
#[inline]
fn tanh_bound_d5(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let th_sq = th * th;
    let sech_sq = 1.0 - th * th;
    let bracket = 2.0 - 15.0 * th_sq + 15.0 * th_sq * th_sq;
    let scale_fourth = scale * scale * scale * scale;
    8.0 * sech_sq * bracket / scale_fourth
}
716
717#[inline]
718fn sas_effective_log_delta(raw_log_delta: f64) -> (f64, f64) {
719    let ld_eff = tanh_bound(raw_log_delta, SAS_LOG_DELTA_BOUND);
720    let dld_eff_draw = tanh_bound_d1(raw_log_delta, SAS_LOG_DELTA_BOUND);
721    (ld_eff, dld_eff_draw)
722}
723
724#[inline]
725fn sas_delta_from_raw_log_delta(raw_log_delta: f64) -> f64 {
726    let (ld_eff, _) = sas_effective_log_delta(raw_log_delta);
727    ld_eff.exp()
728}
729
730pub fn validate_mixturespec(spec: &MixtureLinkSpec) -> Result<(), String> {
731    if spec.components.is_empty() {
732        return Err("mixture link requires at least 1 component".to_string());
733    }
734    if spec.initial_rho.len() + 1 != spec.components.len() {
735        return Err(format!(
736            "mixture link rho length mismatch: expected {}, got {}",
737            spec.components.len() - 1,
738            spec.initial_rho.len()
739        ));
740    }
741    for i in 0..spec.components.len() {
742        for j in (i + 1)..spec.components.len() {
743            if spec.components[i] == spec.components[j] {
744                return Err("mixture link components must be unique".to_string());
745            }
746        }
747    }
748    // `LinkComponent` admits two variants (Cauchit, LogLog) that have no matching
749    // `LinkFunction` entry. When two or more components are *blended*, the mixture-link
750    // pipeline projects the blend back onto a single `LinkFunction` value for downstream
751    // solver/IO bookkeeping (see `InverseLink::link_function`), so a multi-component blend
752    // composed solely of components without a LinkFunction representative would silently
753    // lie about its projected link. We therefore require any genuine *blend* (two or more
754    // components) to contain at least one Logit/Probit/CLogLog "anchor" so the projection
755    // is meaningful, and reject e.g. a blend of only {Cauchit, LogLog}.
756    //
757    // A *single-component* spec is not a blend at all: it is that one link, with weight
758    // 1.0 and no free mixing logits. `LinkComponent::LogLog` / `LinkComponent::Cauchit`
759    // implement their inverse link and derivative jets exactly, so a single-component
760    // `{LogLog}` / `{Cauchit}` spec is a fully-defined standalone link and is accepted
761    // here (this is how survival `--link loglog` / `--link cauchit` are represented).
762    let has_anchor = spec.components.iter().any(|component| {
763        matches!(
764            component,
765            LinkComponent::Logit | LinkComponent::Probit | LinkComponent::CLogLog
766        )
767    });
768    if !has_anchor && spec.components.len() > 1 {
769        let unsupported: Vec<&str> = spec
770            .components
771            .iter()
772            .map(|component| component.name())
773            .collect();
774        return Err(format!(
775            "mixture link components {{{}}} are unsupported: at least one component \
776             must map to a LinkFunction variant (logit/probit/cloglog) so the mixture's \
777             projected LinkFunction is well defined; cauchit and loglog have no \
778             LinkFunction representative",
779            unsupported.join(", ")
780        ));
781    }
782    Ok(())
783}
784
785pub fn softmax_last_fixedzero(rho: &Array1<f64>) -> Array1<f64> {
786    let k = rho.len() + 1;
787    let mut logits = Vec::with_capacity(k);
788    let mut maxv = 0.0_f64;
789    for &v in rho {
790        maxv = maxv.max(v);
791        logits.push(v);
792    }
793    maxv = maxv.max(0.0);
794    logits.push(0.0);
795
796    let mut sum = 0.0_f64;
797    let mut exps = vec![0.0_f64; k];
798    for i in 0..k {
799        let e = (logits[i] - maxv).exp();
800        exps[i] = e;
801        sum += e;
802    }
803    if !sum.is_finite() || sum <= 0.0 {
804        return Array1::from_elem(k, 1.0 / k as f64);
805    }
806    let inv = 1.0 / sum;
807    Array1::from_iter(exps.into_iter().map(|v| v * inv))
808}
809
810/// Returns softmax weights and Jacobian wrt free logits (last logit fixed at zero).
811/// Jacobian shape is (K, K-1): d pi_k / d rho_j.
812pub fn softmaxwith_jacobian_last_fixedzero(
813    rho: &Array1<f64>,
814) -> (Array1<f64>, ndarray::Array2<f64>) {
815    let pi = softmax_last_fixedzero(rho);
816    let k = pi.len();
817    let m = k.saturating_sub(1);
818    let mut jac = ndarray::Array2::<f64>::zeros((k, m));
819    for j in 0..m {
820        let pi_j = pi[j];
821        for kk in 0..k {
822            let delta = if kk == j { 1.0 } else { 0.0 };
823            jac[[kk, j]] = pi[kk] * (delta - pi_j);
824        }
825    }
826    (pi, jac)
827}
828
829pub fn state_fromspec(spec: &MixtureLinkSpec) -> Result<MixtureLinkState, String> {
830    validate_mixturespec(spec)?;
831    let pi = softmax_last_fixedzero(&spec.initial_rho);
832    Ok(MixtureLinkState {
833        components: spec.components.clone(),
834        rho: spec.initial_rho.clone(),
835        pi,
836    })
837}
838
839#[inline]
840pub fn component_inverse_link_jet(component: LinkComponent, eta: f64) -> InverseLinkJet {
841    canonicalize_jet(match component {
842        LinkComponent::Logit => {
843            let jet = logit_inverse_link_jet5(eta);
844            InverseLinkJet {
845                mu: jet.mu,
846                d1: jet.d1,
847                d2: jet.d2,
848                d3: jet.d3,
849            }
850        }
851        LinkComponent::Probit => probit_jet(eta),
852        LinkComponent::CLogLog => {
853            if eta.is_nan() {
854                return InverseLinkJet {
855                    mu: f64::NAN,
856                    d1: f64::NAN,
857                    d2: f64::NAN,
858                    d3: f64::NAN,
859                };
860            }
861            let t = eta.exp();
862            if !t.is_finite() {
863                return InverseLinkJet {
864                    mu: 1.0,
865                    d1: 0.0,
866                    d2: 0.0,
867                    d3: 0.0,
868                };
869            }
870            InverseLinkJet {
871                mu: -(-t).exp_m1(),
872                d1: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0]),
873                d2: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -1.0]),
874                d3: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -3.0, 1.0]),
875            }
876        }
877        LinkComponent::LogLog => {
878            if eta.is_nan() {
879                return InverseLinkJet {
880                    mu: f64::NAN,
881                    d1: f64::NAN,
882                    d2: f64::NAN,
883                    d3: f64::NAN,
884                };
885            }
886            let r = (-eta).exp();
887            if !r.is_finite() {
888                return InverseLinkJet {
889                    mu: 0.0,
890                    d1: 0.0,
891                    d2: 0.0,
892                    d3: 0.0,
893                };
894            }
895            InverseLinkJet {
896                mu: (-r).exp(),
897                d1: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0]),
898                d2: stable_nonnegative_poly_times_exp_neg(r, &[0.0, -1.0, 1.0]),
899                d3: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0, -3.0, 1.0]),
900            }
901        }
902        LinkComponent::Cauchit => {
903            if eta.is_nan() {
904                return InverseLinkJet {
905                    mu: f64::NAN,
906                    d1: f64::NAN,
907                    d2: f64::NAN,
908                    d3: f64::NAN,
909                };
910            }
911            let den = 1.0 + eta * eta;
912            let d1 = if eta.is_finite() {
913                1.0 / (std::f64::consts::PI * den)
914            } else {
915                0.0
916            };
917            let d2 = if eta.is_finite() {
918                -2.0 * eta / (std::f64::consts::PI * den * den)
919            } else {
920                0.0
921            };
922            let d3 = if eta.is_finite() {
923                (6.0 * eta * eta - 2.0) / (std::f64::consts::PI * den * den * den)
924            } else {
925                0.0
926            };
927            InverseLinkJet {
928                mu: 0.5 + eta.atan() / std::f64::consts::PI,
929                d1,
930                d2,
931                d3,
932            }
933        }
934    })
935}
936
937impl InverseLinkKernel for ProbitLinkKernel {
938    #[inline]
939    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
940        Ok(component_inverse_link_jet(LinkComponent::Probit, eta))
941    }
942}
943
944impl InverseLinkKernel for LogitLinkKernel {
945    #[inline]
946    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
947        Ok(component_inverse_link_jet(LinkComponent::Logit, eta))
948    }
949}
950
951impl InverseLinkKernel for CLogLogLinkKernel {
952    #[inline]
953    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
954        Ok(component_inverse_link_jet(LinkComponent::CLogLog, eta))
955    }
956}
957
958impl InverseLinkKernel for LogLogLinkKernel {
959    #[inline]
960    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
961        Ok(component_inverse_link_jet(LinkComponent::LogLog, eta))
962    }
963}
964
965impl InverseLinkKernel for CauchitLinkKernel {
966    #[inline]
967    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
968        Ok(component_inverse_link_jet(LinkComponent::Cauchit, eta))
969    }
970}
971
972impl InverseLinkKernel for LinkComponent {
973    #[inline]
974    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
975        Ok(component_inverse_link_jet(*self, eta))
976    }
977}
978
979impl InverseLinkKernel for LinkFunction {
980    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
981        match self {
982            LinkFunction::Logit => LogitLinkKernel.jet(eta),
983            LinkFunction::Probit => ProbitLinkKernel.jet(eta),
984            LinkFunction::CLogLog => CLogLogLinkKernel.jet(eta),
985            LinkFunction::LogLog => LogLogLinkKernel.jet(eta),
986            LinkFunction::Cauchit => CauchitLinkKernel.jet(eta),
987            LinkFunction::Identity => Ok(InverseLinkJet {
988                mu: eta,
989                d1: 1.0,
990                d2: 0.0,
991                d3: 0.0,
992            }),
993            LinkFunction::Log => {
994                // SOLVER-INTERNAL inverse-link jet: `η.clamp(−700, 700).exp()`.
995                // The clamp is an intentional conditioning hack so the IRLS/REML
996                // normal equations stay well posed when η wanders into the tails
997                // during a trust-region step — it is NOT the public response
998                // transform. Public response-scale outputs (predictions, FFI
999                // `apply_inverse_link_array`, posterior bands) must use the EXACT
1000                // `exp(η)` in `families::inverse_link::apply_inverse_link_vec`,
1001                // which is finite wherever representable. Do not reroute a public
1002                // output through this clamped jet (issue #963). Keep the clamp:
1003                // solver consumers (e.g. `reml/runtime.rs` trust-region `excess`)
1004                // pass raw η and rely on it to keep μ finite.
1005                let e = eta.clamp(-700.0, 700.0).exp();
1006                Ok(InverseLinkJet {
1007                    mu: e,
1008                    d1: e,
1009                    d2: e,
1010                    d3: e,
1011                })
1012            }
1013            LinkFunction::Sas => Err(EstimationError::InvalidInput(
1014                "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1015            )),
1016            LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1017                "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1018                    .to_string(),
1019            )),
1020        }
1021    }
1022}
1023
1024impl InverseLinkKernel for SasLinkState {
1025    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1026        Ok(sas_inverse_link_jet(eta, self.epsilon, self.log_delta))
1027    }
1028
1029    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1030        Ok(Some(LinkParamPartials::Sas(
1031            sas_inverse_link_jetwith_param_partials(eta, self.epsilon, self.log_delta),
1032        )))
1033    }
1034}
1035
1036#[derive(Clone, Copy, Debug)]
1037pub struct BetaLogisticKernel {
1038    /// Unconstrained log of the geometric-mean beta shape — the raw optimization
1039    /// parameter `SasLinkState::log_delta`, NOT the derived `SasLinkState::delta`.
1040    pub log_shape_center: f64,
1041    pub epsilon: f64,
1042}
1043
1044impl InverseLinkKernel for BetaLogisticKernel {
1045    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1046        Ok(beta_logistic_inverse_link_jet(
1047            eta,
1048            self.log_shape_center,
1049            self.epsilon,
1050        ))
1051    }
1052
1053    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1054        Ok(Some(LinkParamPartials::Sas(
1055            beta_logistic_inverse_link_jetwith_param_partials(
1056                eta,
1057                self.log_shape_center,
1058                self.epsilon,
1059            ),
1060        )))
1061    }
1062}
1063
1064impl InverseLinkKernel for MixtureLinkState {
1065    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1066        Ok(mixture_inverse_link_jet(self, eta))
1067    }
1068
1069    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1070        Ok(Some(LinkParamPartials::Mixture(
1071            mixture_inverse_link_jetwith_rho_partials(self, eta),
1072        )))
1073    }
1074}
1075
1076impl InverseLinkKernel for InverseLink {
1077    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1078        match self {
1079            InverseLink::Standard(StandardLink::Logit) => LogitLinkKernel.jet(eta),
1080            InverseLink::Standard(StandardLink::Probit) => ProbitLinkKernel.jet(eta),
1081            InverseLink::Standard(StandardLink::CLogLog) => CLogLogLinkKernel.jet(eta),
1082            InverseLink::Standard(StandardLink::LogLog) => LogLogLinkKernel.jet(eta),
1083            InverseLink::Standard(StandardLink::Cauchit) => CauchitLinkKernel.jet(eta),
1084            InverseLink::Standard(StandardLink::Identity) => LinkFunction::Identity.jet(eta),
1085            InverseLink::Standard(StandardLink::Log) => LinkFunction::Log.jet(eta),
1086            InverseLink::LatentCLogLog(state) => latent_cloglog_point_jet(state, eta),
1087            InverseLink::Sas(state) => state.jet(eta),
1088            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1089                log_shape_center: state.log_delta,
1090                epsilon: state.epsilon,
1091            }
1092            .jet(eta),
1093            InverseLink::Mixture(state) => state.jet(eta),
1094        }
1095    }
1096
1097    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1098        match self {
1099            InverseLink::Standard(_) => Ok(None),
1100            InverseLink::LatentCLogLog(_) => Ok(None),
1101            InverseLink::Sas(state) => state.param_partials(eta),
1102            InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1103                log_shape_center: state.log_delta,
1104                epsilon: state.epsilon,
1105            }
1106            .param_partials(eta),
1107            InverseLink::Mixture(state) => state.param_partials(eta),
1108        }
1109    }
1110}
1111
1112/// Central family-aware inverse-link jet dispatch.
1113///
1114/// For `BinomialSas` and `BinomialMixture`, required state must be provided.
1115pub fn inverse_link_jet_for_inverse_link(
1116    link: &InverseLink,
1117    eta: f64,
1118) -> Result<InverseLinkJet, EstimationError> {
1119    link.jet(eta)
1120}
1121
1122/// Specialized `(mu, d1)` inverse-link evaluation that skips the d2/d3
1123/// polynomial chain used by the full jet. Numerical semantics are preserved:
1124/// the returned `mu` and `d1` are bit-identical to the corresponding fields of
1125/// `inverse_link_jet_for_inverse_link(link, eta)?` for every supported link.
1126///
1127/// For latent cloglog the underlying lognormal-Laplace kernel produces all
1128/// orders together, so this falls back to the full jet for that branch — the
1129/// savings come from the parameterised polynomial links (SAS, beta-logistic,
1130/// mixture) and the simple analytic links where d2/d3 are pure waste.
1131pub fn inverse_link_mu_d1_for_inverse_link(
1132    link: &InverseLink,
1133    eta: f64,
1134) -> Result<(f64, f64), EstimationError> {
1135    match link {
1136        InverseLink::Standard(link_fn) => Ok(link_function_mu_d1(link_fn.as_link_function(), eta)?),
1137        InverseLink::LatentCLogLog(state) => {
1138            let jet = latent_cloglog_point_jet(state, eta)?;
1139            Ok((jet.mu, jet.d1))
1140        }
1141        InverseLink::Sas(state) => Ok(sas_inverse_link_mu_d1(eta, state.epsilon, state.log_delta)),
1142        InverseLink::BetaLogistic(state) => Ok(beta_logistic_inverse_link_mu_d1(
1143            eta,
1144            state.log_delta,
1145            state.epsilon,
1146        )),
1147        InverseLink::Mixture(state) => Ok(mixture_inverse_link_mu_d1(state, eta)),
1148    }
1149}
1150
1151fn link_function_mu_d1(link: LinkFunction, eta: f64) -> Result<(f64, f64), EstimationError> {
1152    match link {
1153        LinkFunction::Identity => Ok((eta, 1.0)),
1154        LinkFunction::Log => {
1155            // SOLVER-INTERNAL clamped `(μ, dμ/dη)`; see the matching note on the
1156            // full `LinkFunction::Log` jet above. Public response transforms use
1157            // exact `exp(η)` via `families::inverse_link::apply_inverse_link_vec`
1158            // (issue #963).
1159            let e = eta.clamp(-700.0, 700.0).exp();
1160            Ok((e, e))
1161        }
1162        LinkFunction::Logit => Ok(component_inverse_link_mu_d1(LinkComponent::Logit, eta)),
1163        LinkFunction::Probit => Ok(component_inverse_link_mu_d1(LinkComponent::Probit, eta)),
1164        LinkFunction::CLogLog => Ok(component_inverse_link_mu_d1(LinkComponent::CLogLog, eta)),
1165        LinkFunction::LogLog => Ok(component_inverse_link_mu_d1(LinkComponent::LogLog, eta)),
1166        LinkFunction::Cauchit => Ok(component_inverse_link_mu_d1(LinkComponent::Cauchit, eta)),
1167        LinkFunction::Sas => Err(EstimationError::InvalidInput(
1168            "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1169        )),
1170        LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1171            "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1172                .to_string(),
1173        )),
1174    }
1175}
1176
1177#[inline]
1178fn component_inverse_link_mu_d1(component: LinkComponent, eta: f64) -> (f64, f64) {
1179    // The full per-component jet already factors `mu` and `d1` exactly the same
1180    // way the higher orders are derived, so we either reuse the cheap closed
1181    // forms directly (Logit/Probit/CLogLog/LogLog/Cauchit) or fall back to the
1182    // existing canonicalised jet for the few cases without a separate fast
1183    // path — bit-identical to `component_inverse_link_jet(...).{mu,d1}`.
1184    match component {
1185        LinkComponent::Logit => {
1186            let jet = logit_inverse_link_jet5(eta);
1187            (jet.mu, canonicalzero(jet.d1))
1188        }
1189        LinkComponent::Probit => {
1190            if eta.is_nan() {
1191                return (f64::NAN, f64::NAN);
1192            }
1193            if eta == f64::INFINITY {
1194                return (1.0, 0.0);
1195            }
1196            if eta == f64::NEG_INFINITY {
1197                return (0.0, 0.0);
1198            }
1199            let phi = normal_pdf(eta);
1200            (normal_cdf(eta), canonicalzero(phi))
1201        }
1202        LinkComponent::CLogLog => {
1203            if eta.is_nan() {
1204                return (f64::NAN, f64::NAN);
1205            }
1206            let t = eta.exp();
1207            if !t.is_finite() {
1208                return (1.0, 0.0);
1209            }
1210            (
1211                -(-t).exp_m1(),
1212                canonicalzero(stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0])),
1213            )
1214        }
1215        LinkComponent::LogLog => {
1216            if eta.is_nan() {
1217                return (f64::NAN, f64::NAN);
1218            }
1219            let r = (-eta).exp();
1220            if !r.is_finite() {
1221                return (0.0, 0.0);
1222            }
1223            (
1224                (-r).exp(),
1225                canonicalzero(stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0])),
1226            )
1227        }
1228        LinkComponent::Cauchit => {
1229            if eta.is_nan() {
1230                return (f64::NAN, f64::NAN);
1231            }
1232            let den = 1.0 + eta * eta;
1233            let d1 = if eta.is_finite() {
1234                1.0 / (std::f64::consts::PI * den)
1235            } else {
1236                0.0
1237            };
1238            (0.5 + eta.atan() / std::f64::consts::PI, canonicalzero(d1))
1239        }
1240    }
1241}
1242
1243fn sas_inverse_link_mu_d1(eta: f64, epsilon: f64, log_delta: f64) -> (f64, f64) {
1244    let delta_id = sas_delta_from_raw_log_delta(log_delta);
1245    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1246        return component_inverse_link_mu_d1(LinkComponent::Probit, eta);
1247    }
1248    let e = if eta.is_finite() { eta } else { 0.0 };
1249    let a = e.asinh();
1250    let delta = delta_id;
1251    let u_raw = delta * a + epsilon;
1252    let u = tanh_bound(u_raw, SAS_U_CLAMP);
1253    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1254    let s = u.sinh();
1255    let c = u.cosh();
1256    let z = s;
1257    let q = e.hypot(1.0);
1258    let inv_q = 1.0 / q;
1259    let r1 = delta * inv_q;
1260    let u1 = g1 * r1;
1261    let z1 = c * u1;
1262    // `mu = Phi(z)` and `d1 = phi(z) * z1`, the same closed forms used by the
1263    // full jet via `chain_inverse_link_jet(probit_jet(z), z1, _, _)`.
1264    let base = probit_jet(z);
1265    (base.mu, canonicalzero(base.d1 * z1))
1266}
1267
1268fn beta_logistic_inverse_link_mu_d1(eta: f64, delta: f64, epsilon: f64) -> (f64, f64) {
1269    let logistic = logistic_uwith_derivatives(eta);
1270    let a = (delta - epsilon).exp();
1271    let b = (delta + epsilon).exp();
1272    let mu = beta_reg_logistic(a, b, logistic);
1273    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1274    (mu, log_d1.exp())
1275}
1276
1277fn mixture_inverse_link_mu_d1(state: &MixtureLinkState, eta: f64) -> (f64, f64) {
1278    let mut mu = 0.0_f64;
1279    let mut d1 = 0.0_f64;
1280    let k = state.components.len().min(state.pi.len());
1281    for i in 0..k {
1282        let (mu_i, d1_i) = component_inverse_link_mu_d1(state.components[i], eta);
1283        let w = state.pi[i];
1284        mu += w * mu_i;
1285        d1 += w * d1_i;
1286    }
1287    (mu, d1)
1288}
1289
1290#[derive(Clone, Copy)]
1291enum PdfDerivativeOrder {
1292    Third,
1293    Fourth,
1294}
1295
1296impl PdfDerivativeOrder {
1297    fn probit(self, eta: f64) -> f64 {
1298        match self {
1299            Self::Third => probit_pdfthird_derivative(eta),
1300            Self::Fourth => probit_pdffourth_derivative(eta),
1301        }
1302    }
1303
1304    fn component(self, component: LinkComponent, eta: f64) -> f64 {
1305        match self {
1306            Self::Third => component_inverse_link_pdfthird_derivative(component, eta),
1307            Self::Fourth => component_inverse_link_pdffourth_derivative(component, eta),
1308        }
1309    }
1310
1311    fn latent_cloglog(self, eta: f64, latent_sd: f64) -> Result<f64, EstimationError> {
1312        let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, latent_sd)?;
1313        Ok(match self {
1314            Self::Third => jet.d4,
1315            Self::Fourth => jet.d5,
1316        })
1317    }
1318
1319    fn sas(self, eta: f64, epsilon: f64, log_delta: f64) -> f64 {
1320        match self {
1321            Self::Third => sas_inverse_link_pdfthird_derivative(eta, epsilon, log_delta),
1322            Self::Fourth => sas_inverse_link_pdffourth_derivative(eta, epsilon, log_delta),
1323        }
1324    }
1325
1326    fn beta_logistic(self, eta: f64, log_shape_center: f64, epsilon: f64) -> f64 {
1327        match self {
1328            Self::Third => {
1329                beta_logistic_inverse_link_pdfthird_derivative(eta, log_shape_center, epsilon)
1330            }
1331            Self::Fourth => {
1332                beta_logistic_inverse_link_pdffourth_derivative(eta, log_shape_center, epsilon)
1333            }
1334        }
1335    }
1336}
1337
1338fn inverse_link_pdf_derivative_for_inverse_link(
1339    link: &InverseLink,
1340    eta: f64,
1341    order: PdfDerivativeOrder,
1342) -> Result<f64, EstimationError> {
1343    match link {
1344        InverseLink::Standard(StandardLink::Identity) => Ok(0.0),
1345        InverseLink::Standard(StandardLink::Log) => Ok(eta.clamp(-700.0, 700.0).exp()),
1346        InverseLink::Standard(StandardLink::Probit) => Ok(order.probit(eta)),
1347        InverseLink::Standard(StandardLink::Logit) => {
1348            Ok(order.component(LinkComponent::Logit, eta))
1349        }
1350        InverseLink::Standard(StandardLink::CLogLog) => {
1351            Ok(order.component(LinkComponent::CLogLog, eta))
1352        }
1353        InverseLink::Standard(StandardLink::LogLog) => {
1354            Ok(order.component(LinkComponent::LogLog, eta))
1355        }
1356        InverseLink::Standard(StandardLink::Cauchit) => {
1357            Ok(order.component(LinkComponent::Cauchit, eta))
1358        }
1359        InverseLink::LatentCLogLog(state) => order.latent_cloglog(eta, state.latent_sd),
1360        InverseLink::Sas(state) => Ok(order.sas(eta, state.epsilon, state.log_delta)),
1361        InverseLink::BetaLogistic(state) => {
1362            Ok(order.beta_logistic(eta, state.log_delta, state.epsilon))
1363        }
1364        InverseLink::Mixture(state) => Ok(state
1365            .components
1366            .iter()
1367            .zip(state.pi.iter())
1368            .map(|(&component, &weight)| weight * order.component(component, eta))
1369            .sum()),
1370    }
1371}
1372
1373pub fn inverse_link_pdfthird_derivative_for_inverse_link(
1374    link: &InverseLink,
1375    eta: f64,
1376) -> Result<f64, EstimationError> {
1377    // This dispatch returns the fourth eta-derivative of the inverse-link CDF,
1378    // equivalently the third derivative of the inverse-link density
1379    //
1380    //   f(eta) = d/deta mu(eta).
1381    //
1382    // It is used downstream as the `f'''` input in
1383    //
1384    //   d³/deta³ log f = f'''/f - 3 f'f''/f² + 2(f')³/f³.
1385    //
1386    // Mixture links preserve linearity:
1387    //
1388    //   mu = sum_j pi_j mu_j
1389    //   => f''' = sum_j pi_j f_j'''
1390    //
1391    // because the mixture weights `pi_j` are constant with respect to `eta`.
1392    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Third)
1393}
1394
1395/// Fifth derivative of the inverse-link CDF (= fourth derivative of the PDF).
1396///
1397/// Extends `inverse_link_pdfthird_derivative_for_inverse_link` by one order.
1398/// Used for the outer REML Hessian Q[v_k, v_l] term in survival models,
1399/// specifically the `m1 * u_{abcd}` Arbogast contribution.
1400pub fn inverse_link_pdffourth_derivative_for_inverse_link(
1401    link: &InverseLink,
1402    eta: f64,
1403) -> Result<f64, EstimationError> {
1404    inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Fourth)
1405}
1406
1407#[inline]
1408fn royston_parmar_inverse_link_jet(eta: f64) -> InverseLinkJet {
1409    // ApproxKind: NumericalApproximation — exp(exp(eta)) overflows f64 for
1410    // eta > ~3.7; the |eta| > 30 saturation tail returns survival ≈ 0 with
1411    // d_k = 0, matching the analytic limit to within IEEE-754 underflow.
1412    const SURVIVAL_ETA_CLAMP: f64 = 30.0;
1413
1414    let z = eta.clamp(-SURVIVAL_ETA_CLAMP, SURVIVAL_ETA_CLAMP);
1415    let hazard = z.exp();
1416    let survival = (-hazard).exp();
1417    if !(-SURVIVAL_ETA_CLAMP..=SURVIVAL_ETA_CLAMP).contains(&eta) {
1418        return InverseLinkJet {
1419            mu: survival,
1420            d1: 0.0,
1421            d2: 0.0,
1422            d3: 0.0,
1423        };
1424    }
1425
1426    let d1 = -hazard * survival;
1427    let d2 = hazard * (hazard - 1.0) * survival;
1428    let d3 = (-hazard * hazard * hazard + 3.0 * hazard * hazard - hazard) * survival;
1429    InverseLinkJet {
1430        mu: survival,
1431        d1,
1432        d2,
1433        d3,
1434    }
1435}
1436
1437pub fn inverse_link_jet_for_family(
1438    spec: &LikelihoodSpec,
1439    eta: f64,
1440) -> Result<InverseLinkJet, EstimationError> {
1441    // RoystonParmar uses its own analytic survival inverse link irrespective of
1442    // the (nominal `Identity`) link slot carried in the spec.
1443    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1444        return Ok(royston_parmar_inverse_link_jet(eta));
1445    }
1446    spec.link.jet(eta)
1447}
1448
1449/// Exact-public log inverse-link jet: `mu = d1 = d2 = d3 = exp(η)` with NO
1450/// `η`-clamp. Sibling of the solver-internal `LinkFunction::Log` jet (which
1451/// clamps `η` to `[−700, 700]` as an IRLS/REML conditioning hack); see issue
1452/// #963. Every derivative of `exp` is `exp`, so all four jet slots carry the
1453/// same exact value — finite wherever representable, `0.0` on underflow,
1454/// `+∞` on overflow.
1455#[inline]
1456fn log_inverse_link_jet_exact(eta: f64) -> InverseLinkJet {
1457    let e = eta.exp();
1458    InverseLinkJet {
1459        mu: e,
1460        d1: e,
1461        d2: e,
1462        d3: e,
1463    }
1464}
1465
1466/// EXACT public inverse-link jet for response-scale prediction outputs.
1467///
1468/// Identical to [`inverse_link_jet_for_family`] for every link EXCEPT the
1469/// standard `Log` link, where it returns the exact `exp(η)` jet instead of the
1470/// solver's `η.clamp(−700, 700).exp()` conditioning transform. On finite `η`
1471/// the two diverge (η = 705: exact `exp(705)` ≈ 1.5e306 vs clamped
1472/// `exp(700)` ≈ 1.0e304; η = −720: exact ≈ 2e−313 vs clamped ≈ 9.9e−305), so
1473/// public predictions (`FamilyStrategy::inverse_link_jet`/`inverse_link_array`,
1474/// the predict response + delta-method SE path) route here. The solver/REML/
1475/// PIRLS engines keep the clamped jet (issue #963). For `|η| ≤ 700` this is
1476/// byte-identical to the clamped jet (the clamp is inert there), so no
1477/// in-range prediction changes.
1478pub fn inverse_link_jet_for_family_public(
1479    spec: &LikelihoodSpec,
1480    eta: f64,
1481) -> Result<InverseLinkJet, EstimationError> {
1482    if matches!(spec.response, ResponseFamily::RoystonParmar) {
1483        return Ok(royston_parmar_inverse_link_jet(eta));
1484    }
1485    if let InverseLink::Standard(StandardLink::Log) = spec.link {
1486        return Ok(log_inverse_link_jet_exact(eta));
1487    }
1488    spec.link.jet(eta)
1489}
1490
1491#[inline]
1492pub fn mixture_inverse_link_jet(state: &MixtureLinkState, eta: f64) -> InverseLinkJet {
1493    let mut mu = 0.0_f64;
1494    let mut d1 = 0.0_f64;
1495    let mut d2 = 0.0_f64;
1496    let mut d3 = 0.0_f64;
1497    let k = state.components.len().min(state.pi.len());
1498    for i in 0..k {
1499        let jet = component_inverse_link_jet(state.components[i], eta);
1500        let w = state.pi[i];
1501        mu += w * jet.mu;
1502        d1 += w * jet.d1;
1503        d2 += w * jet.d2;
1504        d3 += w * jet.d3;
1505    }
1506    InverseLinkJet { mu, d1, d2, d3 }
1507}
1508
1509/// Computes mixture jet and exact partial derivatives wrt free softmax logits.
1510///
1511/// Uses identities:
1512///   d mu     / d rho_j = pi_j (mu_j     - mu)
1513///   d mu'    / d rho_j = pi_j (mu_j'    - mu')
1514///   d mu''   / d rho_j = pi_j (mu_j''   - mu'')
1515///   d mu'''  / d rho_j = pi_j (mu_j'''  - mu''')
1516pub fn mixture_inverse_link_jetwith_rho_partials(
1517    state: &MixtureLinkState,
1518    eta: f64,
1519) -> MixtureJetWithRhoPartials {
1520    let k = state.components.len().min(state.pi.len());
1521    let m = k.saturating_sub(1);
1522    let mut djet_drho = vec![
1523        InverseLinkJet {
1524            mu: 0.0,
1525            d1: 0.0,
1526            d2: 0.0,
1527            d3: 0.0,
1528        };
1529        m
1530    ];
1531    let jet = mixture_inverse_link_jetwith_rho_partials_into(state, eta, &mut djet_drho);
1532    MixtureJetWithRhoPartials { jet, djet_drho }
1533}
1534
1535/// Computes mixture jet and writes exact rho partial jets into `out` (length >= K-1).
1536/// This avoids heap allocation in hot loops.
1537pub fn mixture_inverse_link_jetwith_rho_partials_into(
1538    state: &MixtureLinkState,
1539    eta: f64,
1540    out: &mut [InverseLinkJet],
1541) -> InverseLinkJet {
1542    let k = state.components.len().min(state.pi.len());
1543    let m = k.saturating_sub(1);
1544    assert!(
1545        out.len() >= m,
1546        "rho-partial output buffer too small: got {}, need {}",
1547        out.len(),
1548        m
1549    );
1550    let mut mixed = InverseLinkJet {
1551        mu: 0.0,
1552        d1: 0.0,
1553        d2: 0.0,
1554        d3: 0.0,
1555    };
1556    for i in 0..k {
1557        let jet_i = component_inverse_link_jet(state.components[i], eta);
1558        let w = state.pi[i];
1559        mixed.mu += w * jet_i.mu;
1560        mixed.d1 += w * jet_i.d1;
1561        mixed.d2 += w * jet_i.d2;
1562        mixed.d3 += w * jet_i.d3;
1563        // Cache the first K-1 component jets directly in the output buffer so
1564        // we don't recompute them in the partial loop.
1565        if i < m {
1566            out[i] = jet_i;
1567        }
1568    }
1569    for j in 0..m {
1570        let pi_j = state.pi[j];
1571        let cj = out[j];
1572        out[j] = InverseLinkJet {
1573            mu: pi_j * (cj.mu - mixed.mu),
1574            d1: pi_j * (cj.d1 - mixed.d1),
1575            d2: pi_j * (cj.d2 - mixed.d2),
1576            d3: pi_j * (cj.d3 - mixed.d3),
1577        };
1578    }
1579    mixed
1580}
1581
1582#[derive(Clone, Copy)]
1583struct LogisticU {
1584    u: f64,
1585    one_minus_u: f64,
1586    ln_u: f64,
1587    ln_one_minus_u: f64,
1588    du: f64,
1589    use_upper_tail: bool,
1590}
1591
1592#[inline]
1593fn logistic_uwith_derivatives(eta: f64) -> LogisticU {
1594    let ln_u = -gam_linalg::utils::stable_softplus(-eta);
1595    let ln_one_minus_u = -gam_linalg::utils::stable_softplus(eta);
1596    let u = ln_u.exp();
1597    let one_minus_u = ln_one_minus_u.exp();
1598    let du = (ln_u + ln_one_minus_u).exp();
1599    LogisticU {
1600        u,
1601        one_minus_u,
1602        ln_u,
1603        ln_one_minus_u,
1604        du,
1605        use_upper_tail: eta >= 0.0,
1606    }
1607}
1608
1609#[inline]
1610fn beta_reg_logistic(a: f64, b: f64, logistic: LogisticU) -> f64 {
1611    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1612        return f64::NAN;
1613    }
1614    if logistic.ln_u == f64::NEG_INFINITY {
1615        return 0.0;
1616    }
1617    if logistic.ln_one_minus_u == f64::NEG_INFINITY {
1618        return 1.0;
1619    }
1620    if logistic.use_upper_tail {
1621        1.0 - beta_reg(b, a, logistic.one_minus_u)
1622    } else {
1623        beta_reg(a, b, logistic.u)
1624    }
1625}
1626
1627#[inline]
1628fn beta_reg_with_shape_partials_logistic(a: f64, b: f64, logistic: LogisticU) -> (f64, f64, f64) {
1629    if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1630        return (f64::NAN, f64::NAN, f64::NAN);
1631    }
1632    if logistic.use_upper_tail {
1633        let (tail, dtail_db, dtail_da) = beta_reg_with_shape_partials(b, a, logistic.one_minus_u);
1634        (1.0 - tail, -dtail_da, -dtail_db)
1635    } else {
1636        beta_reg_with_shape_partials(a, b, logistic.u)
1637    }
1638}
1639
1640#[inline]
1641fn beta_logistic_log_d1(a: f64, b: f64, logistic: LogisticU) -> f64 {
1642    a * logistic.ln_u + b * logistic.ln_one_minus_u - ln_beta(a, b)
1643}
1644
/// Forward-mode dual number: a value together with its first partials with
/// respect to the two beta shape parameters `a` and `b`.
#[derive(Clone, Copy)]
struct ShapeDual {
    /// Primal value.
    v: f64,
    /// Partial derivative with respect to `a`.
    da: f64,
    /// Partial derivative with respect to `b`.
    db: f64,
}

impl ShapeDual {
    /// A value that does not depend on either shape (both partials zero).
    #[inline]
    fn constant(v: f64) -> Self {
        Self::from_value_partials(v, 0.0, 0.0)
    }

    /// Assembles a dual from an explicit value and its two partials.
    #[inline]
    fn from_value_partials(v: f64, da: f64, db: f64) -> Self {
        Self { v, da, db }
    }

    /// Replaces a value whose magnitude is below `floor` with the constant
    /// `floor` (partials reset to zero); anything else is returned unchanged.
    #[inline]
    fn clamp_small(self, floor: f64) -> Self {
        let below_floor = self.v.abs() < floor;
        if !below_floor {
            return self;
        }
        Self::constant(floor)
    }
}

// Sum rule.
impl std::ops::Add for ShapeDual {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self {
        let Self { v: lv, da: lda, db: ldb } = self;
        let Self { v: rv, da: rda, db: rdb } = rhs;
        Self::from_value_partials(lv + rv, lda + rda, ldb + rdb)
    }
}

// Difference rule.
impl std::ops::Sub for ShapeDual {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let Self { v: lv, da: lda, db: ldb } = self;
        let Self { v: rv, da: rda, db: rdb } = rhs;
        Self::from_value_partials(lv - rv, lda - rda, ldb - rdb)
    }
}

// Product rule: (fg)' = f'g + fg'.
impl std::ops::Mul for ShapeDual {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let Self { v: lv, da: lda, db: ldb } = self;
        let Self { v: rv, da: rda, db: rdb } = rhs;
        Self::from_value_partials(lv * rv, lda * rv + lv * rda, ldb * rv + lv * rdb)
    }
}

// Quotient rule: (f/g)' = (f'g - fg') / g².
impl std::ops::Div for ShapeDual {
    type Output = Self;

    #[inline]
    fn div(self, rhs: Self) -> Self {
        let Self { v: lv, da: lda, db: ldb } = self;
        let Self { v: rv, da: rda, db: rdb } = rhs;
        let recip = 1.0 / rv;
        let recip_sq = recip * recip;
        Self::from_value_partials(
            lv * recip,
            (lda * rv - lv * rda) * recip_sq,
            (ldb * rv - lv * rdb) * recip_sq,
        )
    }
}

// Negation acts on the value and on both partials.
impl std::ops::Neg for ShapeDual {
    type Output = Self;

    #[inline]
    fn neg(self) -> Self {
        let Self { v, da, db } = self;
        Self::from_value_partials(-v, -da, -db)
    }
}

/// Shorthand for lifting a plain number into a shape-independent dual.
#[inline]
fn shape_dual(v: f64) -> ShapeDual {
    ShapeDual {
        v,
        da: 0.0,
        db: 0.0,
    }
}
1748
1749// Analytic shape partials for I_x(a,b), obtained by differentiating the same
1750// regularized-beta continued fraction used by statrs. The normalizing term uses
1751// d log B(a,b) / da = psi(a) - psi(a+b) and likewise for b.
1752fn beta_reg_with_shape_partials(a0: f64, b0: f64, x0: f64) -> (f64, f64, f64) {
1753    if x0 <= 0.0 {
1754        return (0.0, 0.0, 0.0);
1755    }
1756    if x0 >= 1.0 {
1757        return (1.0, 0.0, 0.0);
1758    }
1759
1760    let symm_transform = x0 >= (a0 + 1.0) / (a0 + b0 + 2.0);
1761    let (a, b, x) = if symm_transform {
1762        (
1763            ShapeDual::from_value_partials(b0, 0.0, 1.0),
1764            ShapeDual::from_value_partials(a0, 1.0, 0.0),
1765            1.0 - x0,
1766        )
1767    } else {
1768        (
1769            ShapeDual::from_value_partials(a0, 1.0, 0.0),
1770            ShapeDual::from_value_partials(b0, 0.0, 1.0),
1771            x0,
1772        )
1773    };
1774
1775    let ln_x = x.ln();
1776    let ln_1mx = (1.0 - x).ln();
1777    let psi_ab = digamma(a.v + b.v);
1778    let log_bt = statrs::function::gamma::ln_gamma(a.v + b.v)
1779        - statrs::function::gamma::ln_gamma(a.v)
1780        - statrs::function::gamma::ln_gamma(b.v)
1781        + a.v * ln_x
1782        + b.v * ln_1mx;
1783    let bt_v = log_bt.exp();
1784    let log_bt_a = psi_ab - digamma(a.v) + ln_x;
1785    let log_bt_b = psi_ab - digamma(b.v) + ln_1mx;
1786    let bt = ShapeDual {
1787        v: bt_v,
1788        da: bt_v * (log_bt_a * a.da + log_bt_b * b.da),
1789        db: bt_v * (log_bt_a * a.db + log_bt_b * b.db),
1790    };
1791
1792    let eps = 0.00000000000000011102230246251565;
1793    let fpmin = f64::MIN_POSITIVE / eps;
1794    let one = shape_dual(1.0);
1795    let qab = a + b;
1796    let qap = a + one;
1797    let qam = a - one;
1798    let mut c = one;
1799    let mut d = (one - qab * shape_dual(x) / qap).clamp_small(fpmin);
1800    d = one / d;
1801    let mut h = d;
1802
1803    for m in 1..141 {
1804        let mf = f64::from(m);
1805        let m2 = mf * 2.0;
1806        let md = shape_dual(mf);
1807        let m2d = shape_dual(m2);
1808        let mut aa = md * (b - md) * shape_dual(x) / ((qam + m2d) * (a + m2d));
1809        d = (one + aa * d).clamp_small(fpmin);
1810        c = (one + aa / c).clamp_small(fpmin);
1811        d = one / d;
1812        h = h * d * c;
1813
1814        aa = (a + md).neg() * (qab + md) * shape_dual(x) / ((a + m2d) * (qap + m2d));
1815        d = (one + aa * d).clamp_small(fpmin);
1816        c = (one + aa / c).clamp_small(fpmin);
1817        d = one / d;
1818        let del = d * c;
1819        h = h * del;
1820
1821        if (del.v - 1.0).abs() <= eps {
1822            let reg = bt * h / a;
1823            return if symm_transform {
1824                (1.0 - reg.v, -reg.da, -reg.db)
1825            } else {
1826                (reg.v, reg.da, reg.db)
1827            };
1828        }
1829    }
1830    let reg = bt * h / a;
1831    if symm_transform {
1832        (1.0 - reg.v, -reg.da, -reg.db)
1833    } else {
1834        (reg.v, reg.da, reg.db)
1835    }
1836}
1837
1838/// Beta-Logistic inverse-link jet for:
1839///   u = logistic(eta)
1840///   a = exp(log_shape_center - epsilon), b = exp(log_shape_center + epsilon)
1841///   mu = I_u(a, b)
1842///
1843/// NOTE: `log_shape_center` is the *unconstrained* log of the geometric-mean
1844/// beta shape (so a·b = exp(2·log_shape_center)). Callers must pass the raw
1845/// optimization parameter `SasLinkState::log_delta`, NOT the derived positive
1846/// `SasLinkState::delta = exp(log_shape_center)`.
1847pub fn beta_logistic_inverse_link_jet(
1848    eta: f64,
1849    log_shape_center: f64,
1850    epsilon: f64,
1851) -> InverseLinkJet {
1852    let logistic = logistic_uwith_derivatives(eta);
1853    let a = (log_shape_center - epsilon).exp();
1854    let b = (log_shape_center + epsilon).exp();
1855    let mu = beta_reg_logistic(a, b, logistic);
1856    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1857    let d1 = log_d1.exp();
1858    let t = a * logistic.one_minus_u - b * logistic.u;
1859    let d2 = d1 * t;
1860    let d3 = d1 * (t * t - (a + b) * logistic.du);
1861    InverseLinkJet { mu, d1, d2, d3 }
1862}
1863
1864pub fn beta_logistic_inverse_link_pdfthird_derivative(
1865    eta: f64,
1866    log_shape_center: f64,
1867    epsilon: f64,
1868) -> f64 {
1869    // Beta-logistic link:
1870    //
1871    //   u = logistic(eta),
1872    //   d1 = C * u^a (1-u)^b,
1873    //   t  = a(1-u) - b u,
1874    //   c  = a + b,
1875    //
1876    // so
1877    //
1878    //   d2 = d1 * t
1879    //   d3 = d1 * (t² - c u')
1880    //
1881    // with `u' = u(1-u)`.
1882    //
1883    // Differentiate once more:
1884    //
1885    //   d4 = d/deta[d1 (t² - c u')]
1886    //      = d1' (t² - c u') + d1 (2 t t' - c u'')
1887    //      = d1 [ t(t² - c u') - 2 c t u' - c u'' ]
1888    //      = d1 [ t³ - 3 c t u' - c u'' ],
1889    //
1890    // since `t' = -c u'`.
1891    let logistic = logistic_uwith_derivatives(eta);
1892    let a = (log_shape_center - epsilon).exp();
1893    let b = (log_shape_center + epsilon).exp();
1894    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1895    let d1 = log_d1.exp();
1896    let c = a + b;
1897    let t = a * logistic.one_minus_u - b * logistic.u;
1898    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1899    d1 * (t * t * t - 3.0 * c * t * logistic.du - c * u2)
1900}
1901
1902/// Fifth derivative of the beta-logistic inverse-link CDF (= 4th deriv of PDF).
1903///
1904/// With `P_4 = t^3 - 3ct*u' - c*u''` giving `d4 = d1 * P_4`, the next order is:
1905///
1906///   d5 = d1 * [t^4 - 6c*t^2*u' - 4c*t*u'' + 3c^2*u'^2 - c*u''']
1907///
1908/// where u' = u(1-u), u'' = u'(1-2u), u''' = u''(1-2u) - 2*u'^2.
1909pub fn beta_logistic_inverse_link_pdffourth_derivative(
1910    eta: f64,
1911    log_shape_center: f64,
1912    epsilon: f64,
1913) -> f64 {
1914    let logistic = logistic_uwith_derivatives(eta);
1915    let a = (log_shape_center - epsilon).exp();
1916    let b = (log_shape_center + epsilon).exp();
1917    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1918    let d1 = log_d1.exp();
1919    let c = a + b;
1920    let t = a * logistic.one_minus_u - b * logistic.u;
1921    let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1922    let u3 = u2 * (logistic.one_minus_u - logistic.u) - 2.0 * logistic.du * logistic.du;
1923    let t2 = t * t;
1924    d1 * (t2 * t2 - 6.0 * c * t2 * logistic.du - 4.0 * c * t * u2
1925        + 3.0 * c * c * logistic.du * logistic.du
1926        - c * u3)
1927}
1928
1929pub fn beta_logistic_inverse_link_jetwith_param_partials(
1930    eta: f64,
1931    log_shape_center: f64,
1932    epsilon: f64,
1933) -> SasJetWithParamPartials {
1934    let logistic = logistic_uwith_derivatives(eta);
1935    let a = (log_shape_center - epsilon).exp();
1936    let b = (log_shape_center + epsilon).exp();
1937    let (mu, dmu_da, dmu_db) = beta_reg_with_shape_partials_logistic(a, b, logistic);
1938    let dmu_dlog_shape_center = a * dmu_da + b * dmu_db;
1939    let dmu_depsilon = -a * dmu_da + b * dmu_db;
1940    let log_d1 = beta_logistic_log_d1(a, b, logistic);
1941    let d1 = log_d1.exp();
1942    let t = a * logistic.one_minus_u - b * logistic.u;
1943    let d2 = d1 * t;
1944    let k = t * t - (a + b) * logistic.du;
1945    let d3 = d1 * k;
1946    let jet = InverseLinkJet { mu, d1, d2, d3 };
1947
1948    let psi_a = digamma(a);
1949    let psi_b = digamma(b);
1950    let psi_ab = digamma(a + b);
1951    let la = logistic.ln_u - psi_a + psi_ab;
1952    let lb = logistic.ln_one_minus_u - psi_b + psi_ab;
1953
1954    let partials_for = |a_p: f64, b_p: f64, dmu: f64| -> InverseLinkJet {
1955        let logd1_p = a_p * la + b_p * lb;
1956        let d1_p = d1 * logd1_p;
1957        let t_p = a_p * logistic.one_minus_u - b_p * logistic.u;
1958        let d2_p = d1_p * t + d1 * t_p;
1959        let k_p = 2.0 * t * t_p - (a_p + b_p) * logistic.du;
1960        let d3_p = d1_p * k + d1 * k_p;
1961        InverseLinkJet {
1962            mu: dmu,
1963            d1: d1_p,
1964            d2: d2_p,
1965            d3: d3_p,
1966        }
1967    };
1968    let djet_dlog_shape_center = partials_for(a, b, dmu_dlog_shape_center);
1969    let djet_depsilon = partials_for(-a, b, dmu_depsilon);
1970    SasJetWithParamPartials {
1971        jet,
1972        djet_depsilon,
1973        djet_dlog_delta: djet_dlog_shape_center,
1974    }
1975}
1976
1977/// SAS inverse-link jet for:
1978///   mu(eta) = Phi(sinh(delta * asinh(eta) + epsilon)),
1979///   delta = exp(B * tanh(log_delta / B)), B = SAS_LOG_DELTA_BOUND.
1980pub fn sas_inverse_link_jet(eta: f64, epsilon: f64, log_delta: f64) -> InverseLinkJet {
1981    let delta_id = sas_delta_from_raw_log_delta(log_delta);
1982    if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1983        return component_inverse_link_jet(LinkComponent::Probit, eta);
1984    }
1985    let e = if eta.is_finite() { eta } else { 0.0 };
1986    let a = e.asinh();
1987    let delta = delta_id;
1988    let u_raw = delta * a + epsilon;
1989    let u = tanh_bound(u_raw, SAS_U_CLAMP);
1990    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1991    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
1992    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
1993    let s = u.sinh();
1994    let c = u.cosh();
1995    let z = s;
1996    let q = e.hypot(1.0);
1997    let inv_q = 1.0 / q;
1998    let inv_q2 = inv_q * inv_q;
1999    let inv_q3 = inv_q2 * inv_q;
2000    let inv_q5 = inv_q3 * inv_q2;
2001    let r1 = delta * inv_q;
2002    let r2 = -delta * e * inv_q3;
2003    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2004    let u1 = g1 * r1;
2005    let u2 = g2 * r1 * r1 + g1 * r2;
2006    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2007    let z1 = c * u1;
2008    let z2 = s * u1 * u1 + c * u2;
2009    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2010    let base = probit_jet(z);
2011    chain_inverse_link_jet(base, z1, z2, z3)
2012}
2013
2014pub fn sas_inverse_link_pdfthird_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
2015    // SAS link with bounded latent transform:
2016    //
2017    //   a  = asinh(eta),
2018    //   u  = tanh_bound(delta * a + epsilon),
2019    //   z  = sinh(u),
2020    //   mu = Phi(z).
2021    //
2022    // Write:
2023    //
2024    //   z1 = z'
2025    //   z2 = z''
2026    //   z3 = z'''
2027    //   z4 = z''''.
2028    //
2029    // Since `mu' = phi(z) z1`, repeated differentiation factors through the
2030    // standard normal Hermite-polynomial identities:
2031    //
2032    //   mu''   = phi(z) [ z2 - z z1² ]
2033    //
2034    //   mu'''  = phi(z) [ z3 - 3 z z1 z2 + (z² - 1) z1³ ]
2035    //          = phi(z) k3
2036    //
2037    //   mu'''' = phi(z) [ k4 - z z1 k3 ],
2038    //
2039    // where `k4` is the derivative of `k3` after collecting like terms. The
2040    // code below computes `u1..u4`, then `z1..z4`, then `k3` and `k4`, exactly
2041    // matching that chain.
2042    //
2043    // The needed fourth derivative of `u(eta)` is obtained from the nested
2044    // composition `u(eta) = g(r(eta))` with
2045    //   g = tanh_bound, r = delta * asinh(eta) - epsilon:
2046    //
2047    //   u4 = g'''' r1^4 + 6 g''' r1² r2 + 3 g'' r2² + 4 g'' r1 r3 + g' r4,
2048    //
2049    // which is the standard scalar Arbogast expansion for order four.
2050    let e = if eta.is_finite() { eta } else { 0.0 };
2051    let a = e.asinh();
2052    let delta = sas_delta_from_raw_log_delta(log_delta);
2053    let u_raw = delta * a + epsilon;
2054    let u = tanh_bound(u_raw, SAS_U_CLAMP);
2055    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2056    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2057    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2058    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2059    let s = u.sinh();
2060    let c = u.cosh();
2061    let z = s;
2062    let base = probit_jet(z);
2063    let q = e.hypot(1.0);
2064    let inv_q = 1.0 / q;
2065    let inv_q2 = inv_q * inv_q;
2066    let inv_q3 = inv_q2 * inv_q;
2067    let inv_q5 = inv_q3 * inv_q2;
2068    let inv_q7 = inv_q5 * inv_q2;
2069    let r1 = delta * inv_q;
2070    let r2 = -delta * e * inv_q3;
2071    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2072    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
2073    let u1 = g1 * r1;
2074    let u2 = g2 * r1 * r1 + g1 * r2;
2075    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2076    let u4 = g4 * r1.powi(4)
2077        + 6.0 * g3 * r1 * r1 * r2
2078        + 3.0 * g2 * r2 * r2
2079        + 4.0 * g2 * r1 * r3
2080        + g1 * r4;
2081    let z1 = c * u1;
2082    let z2 = s * u1 * u1 + c * u2;
2083    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2084    let z4 =
2085        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
2086    let base4 = probit_pdfthird_derivative(z);
2087    let out = base4 * z1.powi(4)
2088        + 6.0 * base.d3 * z1 * z1 * z2
2089        + 3.0 * base.d2 * z2 * z2
2090        + 4.0 * base.d2 * z1 * z3
2091        + base.d1 * z4;
2092    canonicalzero(out)
2093}
2094
2095/// Fifth derivative of the SAS inverse-link CDF (= fourth derivative of the PDF).
2096///
2097/// Extends `sas_inverse_link_pdfthird_derivative` by one more derivative order,
2098/// using the same composition chain u(eta) = g(r(eta)), z = sinh(u), mu = Phi(z).
2099///
2100/// The Arbogast expansion at order 5 for u(eta) = g(r(eta)) is:
2101///   u5 = g5 r1^5 + 10 g4 r1^3 r2 + 15 g3 r1 r2^2 + 10 g3 r1^2 r3
2102///        + 10 g2 r2 r3 + 5 g2 r1 r4 + g1 r5
2103///
2104/// The z = sinh(u) expansion at order 5 is the standard Arbogast for sinh:
2105///   z5 = c*u1^5 + 10*s*u1^3*u2 + 15*c*u1*u2^2 + 10*c*u1^2*u3
2106///        + 10*s*u2*u3 + 5*s*u1*u4 + c*u5
2107///
2108/// The mu = Phi(z) expansion at order 5 uses probit derivatives:
2109///   mu^(5) = Phi5*z1^5 + 10*Phi4*z1^3*z2 + 15*Phi3*z1*z2^2 + 10*Phi3*z1^2*z3
2110///            + 10*Phi2*z2*z3 + 5*Phi2*z1*z4 + Phi1*z5
2111pub fn sas_inverse_link_pdffourth_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
2112    let e = if eta.is_finite() { eta } else { 0.0 };
2113    let a = e.asinh();
2114    let delta = sas_delta_from_raw_log_delta(log_delta);
2115    let u_raw = delta * a + epsilon;
2116    let u = tanh_bound(u_raw, SAS_U_CLAMP);
2117    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2118    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2119    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2120    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2121    let g5 = tanh_bound_d5(u_raw, SAS_U_CLAMP);
2122    let s = u.sinh();
2123    let c = u.cosh();
2124    let z = s;
2125
2126    // Probit derivatives at z.
2127    let base = probit_jet(z);
2128    let phi3 = probit_pdfthird_derivative(z); // Phi^{(4)}
2129    let phi4 = probit_pdffourth_derivative(z); // Phi^{(5)}
2130
2131    // Powers of q = sqrt(1 + eta^2) for r1..r5.
2132    let q = e.hypot(1.0);
2133    let inv_q = 1.0 / q;
2134    let inv_q2 = inv_q * inv_q;
2135    let inv_q3 = inv_q2 * inv_q;
2136    let inv_q5 = inv_q3 * inv_q2;
2137    let inv_q7 = inv_q5 * inv_q2;
2138    let inv_q9 = inv_q7 * inv_q2;
2139
2140    let r1 = delta * inv_q;
2141    let r2 = -delta * e * inv_q3;
2142    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2143    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
2144    let r5 = delta * (9.0 - 72.0 * e * e + 24.0 * e * e * e * e) * inv_q9;
2145
2146    // u1..u5 via Arbogast for g(r(eta)).
2147    let u1 = g1 * r1;
2148    let u2 = g2 * r1 * r1 + g1 * r2;
2149    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2150    let u4 = g4 * r1.powi(4)
2151        + 6.0 * g3 * r1 * r1 * r2
2152        + 3.0 * g2 * r2 * r2
2153        + 4.0 * g2 * r1 * r3
2154        + g1 * r4;
2155    let u5 = g5 * r1.powi(5)
2156        + 10.0 * g4 * r1 * r1 * r1 * r2
2157        + 15.0 * g3 * r1 * r2 * r2
2158        + 10.0 * g3 * r1 * r1 * r3
2159        + 10.0 * g2 * r2 * r3
2160        + 5.0 * g2 * r1 * r4
2161        + g1 * r5;
2162
2163    // z1..z5 via Arbogast for sinh(u(eta)).
2164    let z1 = c * u1;
2165    let z2 = s * u1 * u1 + c * u2;
2166    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2167    let z4 =
2168        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
2169    let z5 = c * u1.powi(5)
2170        + 10.0 * s * u1 * u1 * u1 * u2
2171        + 15.0 * c * u1 * u2 * u2
2172        + 10.0 * c * u1 * u1 * u3
2173        + 10.0 * s * u2 * u3
2174        + 5.0 * s * u1 * u4
2175        + c * u5;
2176
2177    // mu^(5) = Phi^(5)*z1^5 + 10*Phi^(4)*z1^3*z2 + 15*Phi^(3)*z1*z2^2
2178    //        + 10*Phi^(3)*z1^2*z3 + 10*Phi^(2)*z2*z3 + 5*Phi^(2)*z1*z4 + Phi^(1)*z5
2179    let out = phi4 * z1.powi(5)
2180        + 10.0 * phi3 * z1 * z1 * z1 * z2
2181        + 15.0 * base.d3 * z1 * z2 * z2
2182        + 10.0 * base.d3 * z1 * z1 * z3
2183        + 10.0 * base.d2 * z2 * z3
2184        + 5.0 * base.d2 * z1 * z4
2185        + base.d1 * z5;
2186    canonicalzero(out)
2187}
2188
2189pub fn sas_inverse_link_jetwith_param_partials(
2190    eta: f64,
2191    epsilon: f64,
2192    log_delta: f64,
2193) -> SasJetWithParamPartials {
2194    let e = if eta.is_finite() { eta } else { 0.0 };
2195    let a = e.asinh();
2196    let (ld_eff, dld_eff_draw) = sas_effective_log_delta(log_delta);
2197    let delta = ld_eff.exp();
2198    let ddelta_draw = delta * dld_eff_draw;
2199    let u_raw = delta * a + epsilon;
2200    let u = tanh_bound(u_raw, SAS_U_CLAMP);
2201    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2202    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2203    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2204    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2205    let s = u.sinh();
2206    let c = u.cosh();
2207    let z = s;
2208    let q = e.hypot(1.0);
2209    let inv_q = 1.0 / q;
2210    let inv_q2 = inv_q * inv_q;
2211    let inv_q3 = inv_q2 * inv_q;
2212    let inv_q5 = inv_q3 * inv_q2;
2213    let a1 = inv_q;
2214    let a2 = -e * inv_q3;
2215    let a3 = (2.0 * e * e - 1.0) * inv_q5;
2216    let r1 = delta * a1;
2217    let r2 = delta * a2;
2218    let r3 = delta * a3;
2219    let u1 = g1 * r1;
2220    let u2 = g2 * r1 * r1 + g1 * r2;
2221    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2222    let z1 = c * u1;
2223    let z2 = s * u1 * u1 + c * u2;
2224    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2225
2226    let base = probit_jet(z);
2227    let jet = chain_inverse_link_jet(base, z1, z2, z3);
2228
2229    // Generic chain for parameter t:
2230    // u_t, u1_t, u2_t, u3_t -> z_t,z1_t,z2_t,z3_t -> mu_t,d1_t,d2_t,d3_t
2231    let param_partials = |u_t: f64, u1_t: f64, u2_t: f64, u3_t: f64| -> InverseLinkJet {
2232        let z_t = c * u_t;
2233        let z1_t = s * u_t * u1 + c * u1_t;
2234        let z2_t = c * u_t * u1 * u1 + 2.0 * s * u1 * u1_t + s * u_t * u2 + c * u2_t;
2235        let z3_t = s * u_t * u1 * u1 * u1
2236            + 3.0 * c * u1 * u1 * u1_t
2237            + 3.0 * c * u_t * u1 * u2
2238            + 3.0 * s * (u1_t * u2 + u1 * u2_t)
2239            + s * u_t * u3
2240            + c * u3_t;
2241
2242        InverseLinkJet {
2243            mu: base.d1 * z_t,
2244            d1: base.d2 * z_t * z1 + base.d1 * z1_t,
2245            d2: base.d3 * z_t * z1 * z1
2246                + 2.0 * base.d2 * z1 * z1_t
2247                + base.d2 * z_t * z2
2248                + base.d1 * z2_t,
2249            d3: probit_pdfthird_derivative(z) * z_t * z1.powi(3)
2250                + 3.0 * base.d3 * z1 * z1 * z1_t
2251                + 3.0 * base.d3 * z_t * z1 * z2
2252                + 3.0 * base.d2 * (z1_t * z2 + z1 * z2_t)
2253                + base.d2 * z_t * z3
2254                + base.d1 * z3_t,
2255        }
2256    };
2257
2258    // epsilon partials (raw_u_t = +1).
2259    let rt_eps = 1.0;
2260    let r1t_eps = 0.0;
2261    let r2t_eps = 0.0;
2262    let r3t_eps = 0.0;
2263    let u_eps = g1 * rt_eps;
2264    let u1_eps = g2 * rt_eps * r1 + g1 * r1t_eps;
2265    let u2_eps = g3 * rt_eps * r1 * r1 + 2.0 * g2 * r1 * r1t_eps + g2 * rt_eps * r2 + g1 * r2t_eps;
2266    let u3_eps = g4 * rt_eps * r1 * r1 * r1
2267        + 3.0 * g3 * r1 * r1 * r1t_eps
2268        + 3.0 * g3 * rt_eps * r1 * r2
2269        + 3.0 * g2 * (r1t_eps * r2 + r1 * r2t_eps)
2270        + g2 * rt_eps * r3
2271        + g1 * r3t_eps;
2272    let djet_depsilon = param_partials(u_eps, u1_eps, u2_eps, u3_eps);
2273
2274    // raw log-delta partials (through smooth bounded effective log-delta).
2275    let rt_ld = ddelta_draw * a;
2276    let r1t_ld = ddelta_draw * a1;
2277    let r2t_ld = ddelta_draw * a2;
2278    let r3t_ld = ddelta_draw * a3;
2279    let u_ld = g1 * rt_ld;
2280    let u1_ld = g2 * rt_ld * r1 + g1 * r1t_ld;
2281    let u2_ld = g3 * rt_ld * r1 * r1 + 2.0 * g2 * r1 * r1t_ld + g2 * rt_ld * r2 + g1 * r2t_ld;
2282    let u3_ld = g4 * rt_ld * r1 * r1 * r1
2283        + 3.0 * g3 * r1 * r1 * r1t_ld
2284        + 3.0 * g3 * rt_ld * r1 * r2
2285        + 3.0 * g2 * (r1t_ld * r2 + r1 * r2t_ld)
2286        + g2 * rt_ld * r3
2287        + g1 * r3t_ld;
2288    let djet_dlog_delta = param_partials(u_ld, u1_ld, u2_ld, u3_ld);
2289
2290    SasJetWithParamPartials {
2291        jet,
2292        djet_depsilon,
2293        djet_dlog_delta,
2294    }
2295}
2296
2297#[cfg(test)]
2298mod tests {
2299    use super::*;
2300    use gam_problem::{InverseLink, LikelihoodSpec, LinkComponent, MixtureLinkSpec, SasLinkState};
2301
2302    #[test]
2303    fn softmax_jacobian_matchesfd() {
2304        let rho = Array1::from_vec(vec![0.7, -1.2, 0.4]);
2305        let (pi, jac) = softmaxwith_jacobian_last_fixedzero(&rho);
2306        let h = 1e-6;
2307        for j in 0..rho.len() {
2308            let mut rp = rho.clone();
2309            rp[j] += h;
2310            let mut rm = rho.clone();
2311            rm[j] -= h;
2312            let pp = softmax_last_fixedzero(&rp);
2313            let pm = softmax_last_fixedzero(&rm);
2314            let fd = (&pp - &pm).mapv(|v| v / (2.0 * h));
2315            for k in 0..pi.len() {
2316                let err = (jac[[k, j]] - fd[k]).abs();
2317                assert_eq!(
2318                    jac[[k, j]].signum(),
2319                    fd[k].signum(),
2320                    "jac sign mismatch at ({k},{j}): analytic={} fd={}",
2321                    jac[[k, j]],
2322                    fd[k]
2323                );
2324                assert!(err < 5e-6, "jac mismatch at ({k},{j}): err={err:e}");
2325            }
2326        }
2327    }
2328
2329    #[test]
2330    fn mixture_jet_rho_partials_matchfd() {
2331        let spec = MixtureLinkSpec {
2332            components: vec![
2333                LinkComponent::Probit,
2334                LinkComponent::Logit,
2335                LinkComponent::CLogLog,
2336                LinkComponent::Cauchit,
2337            ],
2338            initial_rho: Array1::from_vec(vec![0.3, -0.6, 0.2]),
2339        };
2340        let state = state_fromspec(&spec).expect("state");
2341        let eta = 0.35;
2342        let out = mixture_inverse_link_jetwith_rho_partials(&state, eta);
2343        let h = 1e-6;
2344        for j in 0..state.rho.len() {
2345            let mut rp = state.rho.clone();
2346            rp[j] += h;
2347            let sp = MixtureLinkSpec {
2348                components: state.components.clone(),
2349                initial_rho: rp,
2350            };
2351            let jp = mixture_inverse_link_jet(&state_fromspec(&sp).expect("sp"), eta);
2352            let mut rm = state.rho.clone();
2353            rm[j] -= h;
2354            let sm = MixtureLinkSpec {
2355                components: state.components.clone(),
2356                initial_rho: rm,
2357            };
2358            let jm = mixture_inverse_link_jet(&state_fromspec(&sm).expect("sm"), eta);
2359            let fd = InverseLinkJet {
2360                mu: (jp.mu - jm.mu) / (2.0 * h),
2361                d1: (jp.d1 - jm.d1) / (2.0 * h),
2362                d2: (jp.d2 - jm.d2) / (2.0 * h),
2363                d3: (jp.d3 - jm.d3) / (2.0 * h),
2364            };
2365            let an = out.djet_drho[j];
2366            assert_eq!(an.mu.signum(), fd.mu.signum());
2367            assert_eq!(an.d1.signum(), fd.d1.signum());
2368            assert_eq!(an.d2.signum(), fd.d2.signum());
2369            assert_eq!(an.d3.signum(), fd.d3.signum());
2370            assert!((an.mu - fd.mu).abs() < 1e-6);
2371            assert!((an.d1 - fd.d1).abs() < 1e-6);
2372            assert!((an.d2 - fd.d2).abs() < 1e-6);
2373            assert!((an.d3 - fd.d3).abs() < 1e-6);
2374        }
2375    }
2376
2377    #[test]
2378    fn sas_param_partials_matchfd() {
2379        let eta = 0.37;
2380        let epsilon = -0.12;
2381        let log_delta = 0.21;
2382        let out = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2383        let h = 1e-6;
2384
2385        let ep_p = sas_inverse_link_jet(eta, epsilon + h, log_delta);
2386        let ep_m = sas_inverse_link_jet(eta, epsilon - h, log_delta);
2387        let fd_ep = InverseLinkJet {
2388            mu: (ep_p.mu - ep_m.mu) / (2.0 * h),
2389            d1: (ep_p.d1 - ep_m.d1) / (2.0 * h),
2390            d2: (ep_p.d2 - ep_m.d2) / (2.0 * h),
2391            d3: (ep_p.d3 - ep_m.d3) / (2.0 * h),
2392        };
2393        assert_eq!(out.djet_depsilon.mu.signum(), fd_ep.mu.signum());
2394        assert_eq!(out.djet_depsilon.d1.signum(), fd_ep.d1.signum());
2395        assert_eq!(out.djet_depsilon.d2.signum(), fd_ep.d2.signum());
2396        assert_eq!(out.djet_depsilon.d3.signum(), fd_ep.d3.signum());
2397        assert!((out.djet_depsilon.mu - fd_ep.mu).abs() < 5e-5);
2398        assert!((out.djet_depsilon.d1 - fd_ep.d1).abs() < 5e-5);
2399        assert!((out.djet_depsilon.d2 - fd_ep.d2).abs() < 5e-5);
2400        assert!((out.djet_depsilon.d3 - fd_ep.d3).abs() < 5e-4);
2401
2402        let ld_p = sas_inverse_link_jet(eta, epsilon, log_delta + h);
2403        let ld_m = sas_inverse_link_jet(eta, epsilon, log_delta - h);
2404        let fd_ld = InverseLinkJet {
2405            mu: (ld_p.mu - ld_m.mu) / (2.0 * h),
2406            d1: (ld_p.d1 - ld_m.d1) / (2.0 * h),
2407            d2: (ld_p.d2 - ld_m.d2) / (2.0 * h),
2408            d3: (ld_p.d3 - ld_m.d3) / (2.0 * h),
2409        };
2410        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_ld.mu.signum());
2411        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_ld.d1.signum());
2412        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_ld.d2.signum());
2413        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_ld.d3.signum());
2414        assert!((out.djet_dlog_delta.mu - fd_ld.mu).abs() < 5e-5);
2415        assert!((out.djet_dlog_delta.d1 - fd_ld.d1).abs() < 5e-5);
2416        assert!((out.djet_dlog_delta.d2 - fd_ld.d2).abs() < 5e-5);
2417        assert!((out.djet_dlog_delta.d3 - fd_ld.d3).abs() < 5e-4);
2418    }
2419
2420    #[test]
2421    fn sas_jet_extreme_inputs_stay_finite() {
2422        let cases = [
2423            (-1e6, 0.0, 0.0),
2424            (1e6, 0.0, 0.0),
2425            (3.0, 12.0, 12.0),
2426            (-3.0, -12.0, -12.0),
2427            (0.5, 40.0, 10.0),
2428            (0.5, -40.0, -10.0),
2429        ];
2430        for (eta, eps, log_delta) in cases {
2431            let j = sas_inverse_link_jet(eta, eps, log_delta);
2432            assert!(j.mu.is_finite());
2433            assert!(j.d1.is_finite());
2434            assert!(j.d2.is_finite());
2435            assert!(j.d3.is_finite());
2436            let p = sas_inverse_link_jetwith_param_partials(eta, eps, log_delta);
2437            assert!(p.djet_depsilon.mu.is_finite());
2438            assert!(p.djet_depsilon.d1.is_finite());
2439            assert!(p.djet_depsilon.d2.is_finite());
2440            assert!(p.djet_depsilon.d3.is_finite());
2441            assert!(p.djet_dlog_delta.mu.is_finite());
2442            assert!(p.djet_dlog_delta.d1.is_finite());
2443            assert!(p.djet_dlog_delta.d2.is_finite());
2444            assert!(p.djet_dlog_delta.d3.is_finite());
2445        }
2446    }
2447
2448    #[test]
2449    fn sas_param_partials_remain_finite_in_extreme_region() {
2450        let eta = 10.0;
2451        let epsilon = -60.0;
2452        let log_delta = 40.0;
2453        let j = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2454        assert!(j.djet_depsilon.mu.is_finite());
2455        assert!(j.djet_depsilon.d1.is_finite());
2456        assert!(j.djet_depsilon.d2.is_finite());
2457        assert!(j.djet_depsilon.d3.is_finite());
2458        assert!(j.djet_dlog_delta.mu.is_finite());
2459        assert!(j.djet_dlog_delta.d1.is_finite());
2460        assert!(j.djet_dlog_delta.d2.is_finite());
2461        assert!(j.djet_dlog_delta.d3.is_finite());
2462    }
2463
2464    #[test]
2465    fn sas_eta_jets_matchfd() {
2466        let eta = -0.43;
2467        let epsilon = 0.27;
2468        let log_delta = -0.31;
2469        let h = 1e-5;
2470        let j0 = sas_inverse_link_jet(eta, epsilon, log_delta);
2471        let jp = sas_inverse_link_jet(eta + h, epsilon, log_delta);
2472        let jm = sas_inverse_link_jet(eta - h, epsilon, log_delta);
2473        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2474        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2475        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2476        assert_eq!(j0.d1.signum(), d1fd.signum());
2477        assert_eq!(j0.d2.signum(), d2fd.signum());
2478        assert_eq!(j0.d3.signum(), d3fd.signum());
2479        assert!((j0.d1 - d1fd).abs() < 5e-5);
2480        assert!((j0.d2 - d2fd).abs() < 2e-4);
2481        assert!((j0.d3 - d3fd).abs() < 1e-3);
2482    }
2483
2484    #[test]
2485    fn family_dispatch_resolves_parameterized_links_from_spec() {
2486        // After the LikelihoodSpec migration, the dispatch no longer needs
2487        // out-of-band state arguments — the parameterized link state lives on
2488        // `spec.link`. Verify that supplying explicit SAS/Mixture link states
2489        // through the spec produces finite jets at a representative eta.
2490        let sas_state = sas_link_state_from_raw(0.0, 0.0).expect("sas state");
2491        let sas_spec = gam_problem::LikelihoodSpec {
2492            response: gam_problem::ResponseFamily::Binomial,
2493            link: InverseLink::Sas(sas_state),
2494        };
2495        let sas_jet = inverse_link_jet_for_family(&sas_spec, 0.1).expect("sas jet");
2496        assert!(sas_jet.mu.is_finite());
2497        assert!(sas_jet.d1.is_finite());
2498
2499        let mix_state = MixtureLinkState {
2500            components: vec![LinkComponent::Logit, LinkComponent::Probit],
2501            rho: ndarray::array![0.0],
2502            pi: ndarray::array![0.5, 0.5],
2503        };
2504        let mix_spec = gam_problem::LikelihoodSpec {
2505            response: gam_problem::ResponseFamily::Binomial,
2506            link: InverseLink::Mixture(mix_state),
2507        };
2508        let mix_jet = inverse_link_jet_for_family(&mix_spec, 0.1).expect("mix jet");
2509        assert!(mix_jet.mu.is_finite());
2510        assert!(mix_jet.d1.is_finite());
2511    }
2512
2513    #[test]
2514    fn beta_logistic_reduces_to_logit_at_delta0_epsilon0() {
2515        let etas = [-40.0, -30.0, -5.0, 0.42, 5.0, 30.0, 40.0];
2516        for eta in etas {
2517            let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2518            let expected_mu = gam_linalg::utils::stable_logistic(eta);
2519            let expected_d1 = (-gam_linalg::utils::stable_softplus(-eta)
2520                - gam_linalg::utils::stable_softplus(eta))
2521            .exp();
2522            assert!(
2523                (j_bl.mu - expected_mu).abs() <= 1e-15 * expected_mu.abs().max(1.0),
2524                "mu mismatch at eta={eta}: got {}, expected {}",
2525                j_bl.mu,
2526                expected_mu
2527            );
2528            assert!(
2529                (j_bl.d1 - expected_d1).abs() <= 1e-12 * expected_d1.abs().max(f64::MIN_POSITIVE),
2530                "d1 mismatch at eta={eta}: got {}, expected {}",
2531                j_bl.d1,
2532                expected_d1
2533            );
2534            assert!(j_bl.d1 > 0.0, "d1 should stay positive at eta={eta}");
2535        }
2536
2537        let eta = 0.42;
2538        let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2539        let j_logit = component_inverse_link_jet(LinkComponent::Logit, eta);
2540        assert!((j_bl.d2 - j_logit.d2).abs() < 1e-10);
2541        assert!((j_bl.d3 - j_logit.d3).abs() < 1e-10);
2542    }
2543
2544    #[test]
2545    fn beta_logistic_eta_jets_matchfd() {
2546        let eta = -0.31;
2547        let delta = 0.27;
2548        let epsilon = -0.19;
2549        let h = 1e-5;
2550        let j0 = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2551        let jp = beta_logistic_inverse_link_jet(eta + h, delta, epsilon);
2552        let jm = beta_logistic_inverse_link_jet(eta - h, delta, epsilon);
2553        let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2554        let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2555        let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2556        assert_eq!(j0.d1.signum(), d1fd.signum());
2557        assert_eq!(j0.d2.signum(), d2fd.signum());
2558        assert_eq!(j0.d3.signum(), d3fd.signum());
2559        assert!((j0.d1 - d1fd).abs() < 5e-5);
2560        assert!((j0.d2 - d2fd).abs() < 5e-5);
2561        assert!((j0.d3 - d3fd).abs() < 2e-4);
2562    }
2563
2564    #[test]
2565    fn standard_kernel_structs_match_component_jets() {
2566        let eta = 0.73;
2567        assert_eq!(
2568            ProbitLinkKernel.jet(eta).expect("probit"),
2569            component_inverse_link_jet(LinkComponent::Probit, eta)
2570        );
2571        assert_eq!(
2572            LogitLinkKernel.jet(eta).expect("logit"),
2573            component_inverse_link_jet(LinkComponent::Logit, eta)
2574        );
2575        assert_eq!(
2576            CLogLogLinkKernel.jet(eta).expect("cloglog"),
2577            component_inverse_link_jet(LinkComponent::CLogLog, eta)
2578        );
2579        assert_eq!(
2580            LogLogLinkKernel.jet(eta).expect("loglog"),
2581            component_inverse_link_jet(LinkComponent::LogLog, eta)
2582        );
2583        assert_eq!(
2584            CauchitLinkKernel.jet(eta).expect("cauchit"),
2585            component_inverse_link_jet(LinkComponent::Cauchit, eta)
2586        );
2587    }
2588
2589    #[test]
2590    fn all_component_eta_jets_matchfd() {
2591        let components = [
2592            LinkComponent::Logit,
2593            LinkComponent::Probit,
2594            LinkComponent::CLogLog,
2595            LinkComponent::LogLog,
2596            LinkComponent::Cauchit,
2597        ];
2598        let points = [-3.0, -1.1, -0.2, 0.0, 0.7, 1.8, 3.2];
2599        let h = 1e-5;
2600        for c in components {
2601            for &eta in &points {
2602                let j0 = component_inverse_link_jet(c, eta);
2603                let jp = component_inverse_link_jet(c, eta + h);
2604                let jm = component_inverse_link_jet(c, eta - h);
2605                let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2606                let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2607                let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2608                let d1_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2609                    1.2e-4
2610                } else {
2611                    5e-5
2612                };
2613                let d2_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2614                    4e-4
2615                } else {
2616                    1.2e-4
2617                };
2618                let d3_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2619                    1.2e-3
2620                } else {
2621                    4e-4
2622                };
2623                if j0.d1.abs().max(d1fd.abs()) > 1e-10 {
2624                    assert_eq!(
2625                        j0.d1.signum(),
2626                        d1fd.signum(),
2627                        "d1 sign mismatch for {c:?} eta={eta}"
2628                    );
2629                }
2630                if j0.d2.abs().max(d2fd.abs()) > 1e-10 {
2631                    assert_eq!(
2632                        j0.d2.signum(),
2633                        d2fd.signum(),
2634                        "d2 sign mismatch for {c:?} eta={eta}: analytic={} fd={}",
2635                        j0.d2,
2636                        d2fd
2637                    );
2638                }
2639                if j0.d3.abs().max(d3fd.abs()) > 1e-10 {
2640                    assert_eq!(
2641                        j0.d3.signum(),
2642                        d3fd.signum(),
2643                        "d3 sign mismatch for {c:?} eta={eta}"
2644                    );
2645                }
2646                assert!(
2647                    (j0.d1 - d1fd).abs() < d1_tol,
2648                    "d1 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2649                    j0.d1,
2650                    d1fd
2651                );
2652                assert!(
2653                    (j0.d2 - d2fd).abs() < d2_tol,
2654                    "d2 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2655                    j0.d2,
2656                    d2fd
2657                );
2658                assert!(
2659                    (j0.d3 - d3fd).abs() < d3_tol,
2660                    "d3 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2661                    j0.d3,
2662                    d3fd
2663                );
2664            }
2665        }
2666    }
2667
2668    #[test]
2669    fn sas_center_matches_probit_at_delta1_epsilon0() {
2670        let etas = [-3.0, -1.2, -0.3, 0.0, 0.4, 1.7, 3.0];
2671        for eta in etas {
2672            let sas = sas_inverse_link_jet(eta, 0.0, 0.0);
2673            let probit = ProbitLinkKernel.jet(eta).expect("probit");
2674            // SAS implementation uses a smooth bounded latent (`tanh_bound`) for
2675            // numerical robustness, so the probit center is approximate in practice.
2676            assert!(
2677                (sas.mu - probit.mu).abs() < 6e-4,
2678                "mu mismatch at eta={eta}"
2679            );
2680            assert!(
2681                (sas.d1 - probit.d1).abs() < 6e-4,
2682                "d1 mismatch at eta={eta}"
2683            );
2684            assert!(
2685                (sas.d2 - probit.d2).abs() < 2e-3,
2686                "d2 mismatch at eta={eta}"
2687            );
2688            assert!(
2689                (sas.d3 - probit.d3).abs() < 4e-3,
2690                "d3 mismatch at eta={eta}"
2691            );
2692        }
2693    }
2694
2695    #[test]
2696    fn beta_logistic_param_partials_matchfd() {
2697        let eta = -0.41;
2698        let delta = 0.23;
2699        let epsilon = -0.17;
2700        let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2701        let h = 1e-6;
2702
2703        let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
2704        let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
2705        let fd_delta = InverseLinkJet {
2706            mu: (dp.mu - dm.mu) / (2.0 * h),
2707            d1: (dp.d1 - dm.d1) / (2.0 * h),
2708            d2: (dp.d2 - dm.d2) / (2.0 * h),
2709            d3: (dp.d3 - dm.d3) / (2.0 * h),
2710        };
2711        assert_eq!(out.djet_dlog_delta.mu.signum(), fd_delta.mu.signum());
2712        assert_eq!(out.djet_dlog_delta.d1.signum(), fd_delta.d1.signum());
2713        assert_eq!(out.djet_dlog_delta.d2.signum(), fd_delta.d2.signum());
2714        assert_eq!(out.djet_dlog_delta.d3.signum(), fd_delta.d3.signum());
2715        assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
2716        assert!((out.djet_dlog_delta.d1 - fd_delta.d1).abs() < 5e-5);
2717        assert!((out.djet_dlog_delta.d2 - fd_delta.d2).abs() < 1.2e-4);
2718        assert!((out.djet_dlog_delta.d3 - fd_delta.d3).abs() < 4e-4);
2719
2720        let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
2721        let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
2722        let fd_epsilon = InverseLinkJet {
2723            mu: (ep.mu - em.mu) / (2.0 * h),
2724            d1: (ep.d1 - em.d1) / (2.0 * h),
2725            d2: (ep.d2 - em.d2) / (2.0 * h),
2726            d3: (ep.d3 - em.d3) / (2.0 * h),
2727        };
2728        assert_eq!(out.djet_depsilon.mu.signum(), fd_epsilon.mu.signum());
2729        assert_eq!(out.djet_depsilon.d1.signum(), fd_epsilon.d1.signum());
2730        assert_eq!(out.djet_depsilon.d2.signum(), fd_epsilon.d2.signum());
2731        assert_eq!(out.djet_depsilon.d3.signum(), fd_epsilon.d3.signum());
2732        assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
2733        assert!((out.djet_depsilon.d1 - fd_epsilon.d1).abs() < 5e-5);
2734        assert!((out.djet_depsilon.d2 - fd_epsilon.d2).abs() < 1.2e-4);
2735        assert!((out.djet_depsilon.d3 - fd_epsilon.d3).abs() < 4e-4);
2736    }
2737
2738    #[test]
2739    fn beta_logistic_left_tail_uses_unclamped_log_space() {
2740        let eta = -40.0_f64;
2741        let delta = 0.2_f64;
2742        let epsilon = -0.1_f64;
2743        let a = (delta - epsilon).exp();
2744        let b = (delta + epsilon).exp();
2745        let expected_mu = beta_reg(a, b, eta.exp());
2746        let out = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2747
2748        assert!(
2749            (out.mu - expected_mu).abs() <= 1e-12 * expected_mu.abs().max(f64::MIN_POSITIVE),
2750            "left-tail mu mismatch: got {}, expected {}",
2751            out.mu,
2752            expected_mu
2753        );
2754        assert!(out.d1 > 0.0);
2755        assert!(out.d2 > 0.0);
2756        assert!(out.d3 > 0.0);
2757        assert!(out.d1 < 1e-20);
2758
2759        let partials = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2760        assert!(partials.jet.d1 > 0.0);
2761        assert!(partials.jet.d2 > 0.0);
2762        assert!(partials.jet.d3 > 0.0);
2763        assert!(partials.djet_dlog_delta.d1.is_finite());
2764        assert!(partials.djet_depsilon.d1.is_finite());
2765    }
2766
2767    #[test]
2768    fn beta_logistic_mu_is_symmetric_in_logistic_tails() {
2769        let delta = 0.2;
2770        let epsilon = -0.35;
2771        let etas = [-40.0, -30.0, -5.0, -0.42, 0.0, 0.42, 5.0, 30.0, 40.0];
2772        for eta in etas {
2773            let left = beta_logistic_inverse_link_jet(eta, delta, epsilon).mu;
2774            let right = 1.0 - beta_logistic_inverse_link_jet(-eta, delta, -epsilon).mu;
2775            assert!(
2776                (left - right).abs() <= 1e-14,
2777                "symmetry mismatch at eta={eta}: left={left}, right={right}"
2778            );
2779        }
2780    }
2781
2782    #[test]
2783    fn inverse_link_pdfthird_derivative_matches_d3_finite_difference() {
2784        let sas = InverseLink::Sas(sas_link_state_from_raw(-0.25, 0.35).expect("sas state"));
2785        let beta_logistic = InverseLink::BetaLogistic(SasLinkState {
2786            epsilon: 0.18,
2787            log_delta: -0.22,
2788            delta: (-0.22_f64).exp(),
2789        });
2790        let mixture = InverseLink::Mixture(
2791            state_fromspec(&MixtureLinkSpec {
2792                components: vec![
2793                    LinkComponent::Probit,
2794                    LinkComponent::Logit,
2795                    LinkComponent::CLogLog,
2796                    LinkComponent::Cauchit,
2797                ],
2798                initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
2799            })
2800            .expect("mixture state"),
2801        );
2802        let links = [
2803            InverseLink::Standard(StandardLink::Probit),
2804            InverseLink::Standard(StandardLink::Logit),
2805            InverseLink::Standard(StandardLink::CLogLog),
2806            sas,
2807            beta_logistic,
2808            mixture,
2809        ];
2810        let etas = [-1.1, -0.2, 0.6];
2811        let h = 1e-5;
2812
2813        for link in &links {
2814            for &eta in &etas {
2815                let jp = inverse_link_jet_for_inverse_link(link, eta + h).expect("jet+");
2816                let jm = inverse_link_jet_for_inverse_link(link, eta - h).expect("jet-");
2817                let d4fd = (jp.d3 - jm.d3) / (2.0 * h);
2818                let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)
2819                    .expect("analytic d4");
2820                assert_eq!(
2821                    d4.signum(),
2822                    d4fd.signum(),
2823                    "d4 sign mismatch for {:?} at eta={eta}: analytic={} fd={}",
2824                    link,
2825                    d4,
2826                    d4fd
2827                );
2828                assert!(
2829                    (d4 - d4fd).abs() < 5e-3,
2830                    "d4 mismatch for {:?} at eta={eta}: analytic={} fd={}",
2831                    link,
2832                    d4,
2833                    d4fd
2834                );
2835            }
2836        }
2837    }
2838
2839    #[test]
2840    fn cloglog_large_finite_eta_should_saturate_without_nan_derivatives() {
2841        let eta = 800.0;
2842        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2843        assert_eq!(jet.mu, 1.0);
2844        assert!(
2845            jet.d1 == 0.0,
2846            "for mu(eta)=1-exp(-exp(eta)), dmu/deta = exp(eta-exp(eta)) and should underflow to 0 at eta={eta}; got d1={}",
2847            jet.d1
2848        );
2849        assert!(
2850            jet.d2 == 0.0,
2851            "the saturated cloglog second derivative should also be 0 at eta={eta}; got d2={}",
2852            jet.d2
2853        );
2854        assert!(
2855            jet.d3 == 0.0,
2856            "the saturated cloglog third derivative should also be 0 at eta={eta}; got d3={}",
2857            jet.d3
2858        );
2859
2860        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2861            &InverseLink::Standard(StandardLink::CLogLog),
2862            eta,
2863        )
2864        .expect("cloglog d4");
2865        assert!(
2866            d4 == 0.0,
2867            "the saturated cloglog fourth derivative should also be 0 at eta={eta}; got d4={d4}"
2868        );
2869    }
2870
2871    #[test]
2872    fn loglog_large_negative_finite_eta_should_saturate_without_nan_derivatives() {
2873        let eta = -800.0;
2874        let jet = component_inverse_link_jet(LinkComponent::LogLog, eta);
2875        assert_eq!(jet.mu, 0.0);
2876        assert!(
2877            jet.d1 == 0.0,
2878            "for mu(eta)=exp(-exp(-eta)), dmu/deta = exp(-eta-exp(-eta)) and should underflow to 0 at eta={eta}; got d1={}",
2879            jet.d1
2880        );
2881        assert!(
2882            jet.d2 == 0.0,
2883            "the saturated loglog second derivative should also be 0 at eta={eta}; got d2={}",
2884            jet.d2
2885        );
2886        assert!(
2887            jet.d3 == 0.0,
2888            "the saturated loglog third derivative should also be 0 at eta={eta}; got d3={}",
2889            jet.d3
2890        );
2891
2892        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2893            &InverseLink::Mixture(
2894                state_fromspec(&MixtureLinkSpec {
2895                    components: vec![LinkComponent::LogLog, LinkComponent::Probit],
2896                    initial_rho: Array1::from_vec(vec![12.0]),
2897                })
2898                .expect("mixture state"),
2899            ),
2900            eta,
2901        )
2902        .expect("loglog mixture d4");
2903        assert!(
2904            d4.is_finite(),
2905            "even a nearly pure loglog mixture should not produce NaN fourth derivatives at eta={eta}; got d4={d4}"
2906        );
2907    }
2908
2909    #[test]
2910    fn logit_tail_derivatives_should_match_stable_closed_forms() {
2911        let eta = 50.0_f64;
2912        let z = (-eta).exp();
2913        let denom = 1.0_f64 + z;
2914        let stable_d1 = z / denom.powi(2);
2915        let stable_d2 = z * (z - 1.0) / denom.powi(3);
2916        let stable_d3 = z * (z * z - 4.0 * z + 1.0) / denom.powi(4);
2917        let stable_d4 = z * (z * z * z - 11.0 * z * z + 11.0 * z - 1.0) / denom.powi(5);
2918        let stable_d5 =
2919            z * (z * z * z * z - 26.0 * z * z * z + 66.0 * z * z - 26.0 * z + 1.0) / denom.powi(6);
2920
2921        assert!(stable_d1 > 0.0);
2922        assert!(stable_d2 < 0.0);
2923        assert!(stable_d3 > 0.0);
2924        assert!(stable_d4 < 0.0);
2925        assert!(stable_d5 > 0.0);
2926
2927        let jet = component_inverse_link_jet(LinkComponent::Logit, eta);
2928        assert!(
2929            (jet.d1 - stable_d1).abs() < 1e-30,
2930            "logit d1 should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
2931            jet.d1,
2932            stable_d1
2933        );
2934        assert!(
2935            (jet.d2 - stable_d2).abs() < 1e-30,
2936            "logit d2 should equal the stable tail formula z(z-1)/(1+z)^3 at eta={eta}; got {} vs {}",
2937            jet.d2,
2938            stable_d2
2939        );
2940        assert!(
2941            (jet.d3 - stable_d3).abs() < 1e-30,
2942            "logit d3 should equal the stable tail formula z(z^2-4z+1)/(1+z)^4 at eta={eta}; got {} vs {}",
2943            jet.d3,
2944            stable_d3
2945        );
2946
2947        let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2948            &InverseLink::Standard(StandardLink::Logit),
2949            eta,
2950        )
2951        .expect("logit d4");
2952        assert!(
2953            (d4 - stable_d4).abs() < 1e-30,
2954            "logit d4 should equal the stable tail formula z(z^3-11z^2+11z-1)/(1+z)^5 at eta={eta}; got {} vs {}",
2955            d4,
2956            stable_d4
2957        );
2958
2959        let d5 = inverse_link_pdffourth_derivative_for_inverse_link(
2960            &InverseLink::Standard(StandardLink::Logit),
2961            eta,
2962        )
2963        .expect("logit d5");
2964        assert!(
2965            (d5 - stable_d5).abs() < 1e-30,
2966            "logit d5 should equal the stable tail formula z(z^4-26z^3+66z^2-26z+1)/(1+z)^6 at eta={eta}; got {} vs {}",
2967            d5,
2968            stable_d5
2969        );
2970    }
2971
2972    #[test]
2973    fn cloglog_negative_tail_value_should_match_expm1_form() {
2974        let eta = -50.0_f64;
2975        let t = eta.exp();
2976        let stable_mu = -(-t).exp_m1();
2977        assert!(stable_mu > 0.0);
2978
2979        let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2980        assert!(
2981            (jet.mu - stable_mu).abs() < 1e-30,
2982            "cloglog mu should equal -expm1(-exp(eta)) in the negative tail at eta={eta}; got {} vs {}",
2983            jet.mu,
2984            stable_mu
2985        );
2986    }
2987
2988    #[test]
2989    fn non_logit_probit_fisher_weight_jets_match_finite_differences() {
2990        fn rel_err(a: f64, b: f64) -> f64 {
2991            (a - b).abs() / a.abs().max(b.abs()).max(1.0e-8)
2992        }
2993
2994        let cases = [
2995            (LinkComponent::CLogLog, [-3.0_f64, -0.5, 0.4, 1.5]),
2996            (LinkComponent::LogLog, [-1.5_f64, -0.4, 0.5, 3.0]),
2997            (LinkComponent::Cauchit, [-3.0_f64, -0.7, 0.6, 3.0]),
2998        ];
2999        for (component, etas) in cases {
3000            for eta in etas {
3001                let (w, w1, w2, w3, w4) = component_fisher_weight_jet5(component, eta);
3002                let jet = component_inverse_link_jet(component, eta);
3003                let expected = jet.d1 * jet.d1 / (jet.mu * (1.0 - jet.mu));
3004                assert!(
3005                    rel_err(w, expected) < 1.0e-12,
3006                    "{component:?} Fisher weight mismatch at eta={eta}: got {w}, expected {expected}"
3007                );
3008
3009                let h = 1.0e-4;
3010                let fd1 = (component_fisher_weight_jet5(component, eta + h).0
3011                    - component_fisher_weight_jet5(component, eta - h).0)
3012                    / (2.0 * h);
3013                let fd2 = (component_fisher_weight_jet5(component, eta + h).1
3014                    - component_fisher_weight_jet5(component, eta - h).1)
3015                    / (2.0 * h);
3016                let fd3 = (component_fisher_weight_jet5(component, eta + h).2
3017                    - component_fisher_weight_jet5(component, eta - h).2)
3018                    / (2.0 * h);
3019                let fd4 = (component_fisher_weight_jet5(component, eta + h).3
3020                    - component_fisher_weight_jet5(component, eta - h).3)
3021                    / (2.0 * h);
3022
3023                assert!(
3024                    rel_err(w1, fd1) < 1.0e-5,
3025                    "{component:?} W' mismatch at eta={eta}: {w1} vs {fd1}"
3026                );
3027                assert!(
3028                    rel_err(w2, fd2) < 1.0e-5,
3029                    "{component:?} W'' mismatch at eta={eta}: {w2} vs {fd2}"
3030                );
3031                assert!(
3032                    rel_err(w3, fd3) < 5.0e-5,
3033                    "{component:?} W''' mismatch at eta={eta}: {w3} vs {fd3}"
3034                );
3035                assert!(
3036                    rel_err(w4, fd4) < 5.0e-4,
3037                    "{component:?} W'''' mismatch at eta={eta}: {w4} vs {fd4}"
3038                );
3039            }
3040        }
3041    }
3042
3043    #[test]
3044    fn mixture_fisher_weight_jet_covers_loglog_and_cauchit_components() {
3045        let state = state_fromspec(&MixtureLinkSpec {
3046            components: vec![
3047                LinkComponent::CLogLog,
3048                LinkComponent::LogLog,
3049                LinkComponent::Cauchit,
3050            ],
3051            initial_rho: Array1::from_vec(vec![0.3, -0.2]),
3052        })
3053        .expect("mixture state");
3054        let link = InverseLink::Mixture(state);
3055        assert!(
3056            inverse_link_has_fisher_weight_jet(&link),
3057            "anchored mixtures with loglog/cauchit components must remain eligible for Firth"
3058        );
3059        assert!(
3060            LikelihoodSpec::new(ResponseFamily::Binomial, link.clone()).supports_firth(),
3061            "Firth support should use the mixture inverse-link Fisher jet, not standalone LinkFunction coverage"
3062        );
3063
3064        for eta in [-2.0_f64, -0.25, 0.75, 2.5] {
3065            let (w, w1, w2, w3, w4) =
3066                fisher_weight_jet5_for_inverse_link(&link, eta).expect("mixture Fisher jet");
3067            for value in [w, w1, w2, w3, w4] {
3068                assert!(
3069                    value.is_finite(),
3070                    "mixture Fisher weight jet should be finite at eta={eta}; got {value}"
3071                );
3072            }
3073            assert!(
3074                w > 0.0,
3075                "mixture Fisher working weight should be positive away from saturated tails at eta={eta}; got {w}"
3076            );
3077        }
3078    }
3079
3080    #[test]
3081    fn loglog_fifth_derivative_should_match_closed_form_sign() {
3082        let eta = 0.0_f64;
3083        let r = (-eta).exp();
3084        let expected =
3085            (-r).exp() * (r - 15.0 * r * r + 25.0 * r.powi(3) - 10.0 * r.powi(4) + r.powi(5));
3086        let d5 = component_inverse_link_pdffourth_derivative(LinkComponent::LogLog, eta);
3087        assert!(
3088            (d5 - expected).abs() < 1e-15,
3089            "loglog d5 should equal exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5) at eta={eta}; got {d5} vs {expected}"
3090        );
3091        assert!(d5 > 0.0, "loglog d5 should be positive at eta=0; got {d5}");
3092    }
3093}