1use crate::estimate::EstimationError;
2use crate::quadrature::latent_cloglog_jet5;
3use gam_math::probability::{normal_cdf, normal_pdf};
4use gam_math::special::stable_polynomial_times_exp_neg as stable_nonnegative_poly_times_exp_neg;
5use gam_problem::{
6 InverseLink, LatentCLogLogState, LikelihoodSpec, LinkComponent, LinkFunction, MixtureLinkSpec,
7 MixtureLinkState, ResponseFamily, SasLinkSpec, SasLinkState, StandardLink,
8};
9use ndarray::Array1;
10use statrs::function::beta::{beta_reg, ln_beta};
11use statrs::function::gamma::digamma;
12use std::ops::Neg;
13use std::sync::OnceLock;
14
/// Soft bound passed to `tanh_bound` for the SAS inner argument
/// `u = delta * asinh(eta) + epsilon`, so `|u|` stays below this value before `sinh(u)` is taken.
const SAS_U_CLAMP: f64 = 50.0;
/// Soft bound passed to `tanh_bound` for the raw SAS `log_delta` before it is exponentiated
/// into `delta`.
pub(crate) const SAS_LOG_DELTA_BOUND: f64 = 12.0;
21
22#[inline]
23fn latent_cloglog_quadctx() -> &'static crate::quadrature::QuadratureContext {
24 static QUADCTX: OnceLock<crate::quadrature::QuadratureContext> = OnceLock::new();
25 QUADCTX.get_or_init(crate::quadrature::QuadratureContext::new)
26}
27
28#[inline]
29fn latent_cloglog_point_jet(
30 state: &LatentCLogLogState,
31 eta: f64,
32) -> Result<InverseLinkJet, EstimationError> {
33 let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, state.latent_sd)?;
34 Ok(InverseLinkJet {
35 mu: jet.mean,
36 d1: jet.d1,
37 d2: jet.d2,
38 d3: jet.d3,
39 })
40}
41
/// Value and first three eta-derivatives of an inverse link `mu(eta)` at a single `eta`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InverseLinkJet {
    /// `mu(eta)`.
    pub mu: f64,
    /// `d mu / d eta`.
    pub d1: f64,
    /// `d^2 mu / d eta^2`.
    pub d2: f64,
    /// `d^3 mu / d eta^3`.
    pub d3: f64,
}
49
/// Value and first five eta-derivatives of the logistic inverse link,
/// as produced by `logit_inverse_link_jet5`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LogitJet5 {
    /// `sigma(eta)`.
    pub mu: f64,
    /// First derivative.
    pub d1: f64,
    /// Second derivative.
    pub d2: f64,
    /// Third derivative.
    pub d3: f64,
    /// Fourth derivative.
    pub d4: f64,
    /// Fifth derivative.
    pub d5: f64,
}
59
/// Flushes subnormal values and signed zeros to `+0.0`.
///
/// Normal numbers, infinities and NaN pass through unchanged.
#[inline]
fn canonicalzero(v: f64) -> f64 {
    if v == 0.0 || v.is_subnormal() {
        0.0
    } else {
        v
    }
}
64
65#[inline]
66fn canonicalize_jet(mut jet: InverseLinkJet) -> InverseLinkJet {
67 jet.d1 = canonicalzero(jet.d1);
68 jet.d2 = canonicalzero(jet.d2);
69 jet.d3 = canonicalzero(jet.d3);
70 jet
71}
72
73#[inline]
74pub fn logit_inverse_link_jet5(eta: f64) -> LogitJet5 {
75 if eta.is_nan() {
76 return LogitJet5 {
77 mu: f64::NAN,
78 d1: f64::NAN,
79 d2: f64::NAN,
80 d3: f64::NAN,
81 d4: f64::NAN,
82 d5: f64::NAN,
83 };
84 }
85 if eta == f64::INFINITY {
86 return LogitJet5 {
87 mu: 1.0,
88 d1: 0.0,
89 d2: 0.0,
90 d3: 0.0,
91 d4: 0.0,
92 d5: 0.0,
93 };
94 }
95 if eta == f64::NEG_INFINITY {
96 return LogitJet5 {
97 mu: 0.0,
98 d1: 0.0,
99 d2: 0.0,
100 d3: 0.0,
101 d4: 0.0,
102 d5: 0.0,
103 };
104 }
105
106 let jet = if eta >= 0.0 {
107 let z = (-eta).exp();
108 let opz = 1.0 + z;
109 let opz2 = opz * opz;
110 let opz3 = opz2 * opz;
111 let opz4 = opz3 * opz;
112 let opz5 = opz4 * opz;
113 let opz6 = opz5 * opz;
114 let z2 = z * z;
115 let z3 = z2 * z;
116 let z4 = z3 * z;
117 LogitJet5 {
118 mu: 1.0 / opz,
119 d1: z / opz2,
120 d2: z * (z - 1.0) / opz3,
121 d3: z * (z2 - 4.0 * z + 1.0) / opz4,
122 d4: z * (z3 - 11.0 * z2 + 11.0 * z - 1.0) / opz5,
123 d5: z * (z4 - 26.0 * z3 + 66.0 * z2 - 26.0 * z + 1.0) / opz6,
124 }
125 } else {
126 let z = eta.exp();
127 let opz = 1.0 + z;
128 let opz2 = opz * opz;
129 let opz3 = opz2 * opz;
130 let opz4 = opz3 * opz;
131 let opz5 = opz4 * opz;
132 let opz6 = opz5 * opz;
133 let z2 = z * z;
134 let z3 = z2 * z;
135 let z4 = z3 * z;
136 LogitJet5 {
137 mu: z / opz,
138 d1: z / opz2,
139 d2: z * (1.0 - z) / opz3,
140 d3: z * (1.0 - 4.0 * z + z2) / opz4,
141 d4: z * (1.0 - 11.0 * z + 11.0 * z2 - z3) / opz5,
142 d5: z * (1.0 - 26.0 * z + 66.0 * z2 - 26.0 * z3 + z4) / opz6,
143 }
144 };
145 LogitJet5 {
146 mu: jet.mu,
147 d1: canonicalzero(jet.d1),
148 d2: canonicalzero(jet.d2),
149 d3: canonicalzero(jet.d3),
150 d4: canonicalzero(jet.d4),
151 d5: canonicalzero(jet.d5),
152 }
153}
154
155#[inline]
156fn probit_jet(eta: f64) -> InverseLinkJet {
157 if eta.is_nan() {
168 return InverseLinkJet {
169 mu: f64::NAN,
170 d1: f64::NAN,
171 d2: f64::NAN,
172 d3: f64::NAN,
173 };
174 }
175 if eta == f64::INFINITY {
176 return InverseLinkJet {
177 mu: 1.0,
178 d1: 0.0,
179 d2: 0.0,
180 d3: 0.0,
181 };
182 }
183 if eta == f64::NEG_INFINITY {
184 return InverseLinkJet {
185 mu: 0.0,
186 d1: 0.0,
187 d2: 0.0,
188 d3: 0.0,
189 };
190 }
191 let x = eta;
192 let phi = normal_pdf(x);
193 InverseLinkJet {
194 mu: normal_cdf(x),
195 d1: phi,
196 d2: -x * phi,
197 d3: (x * x - 1.0) * phi,
198 }
199}
200
201#[inline]
202fn probit_pdfthird_derivative(eta: f64) -> f64 {
203 if eta.is_nan() {
207 return f64::NAN;
208 }
209 if !eta.is_finite() {
210 return 0.0;
211 }
212 let x = eta;
213 let phi = normal_pdf(x);
214 canonicalzero(-(x * x * x - 3.0 * x) * phi)
215}
216
217#[inline]
218fn probit_pdffourth_derivative(eta: f64) -> f64 {
219 if eta.is_nan() {
221 return f64::NAN;
222 }
223 if !eta.is_finite() {
224 return 0.0;
225 }
226 let x = eta;
227 let phi = normal_pdf(x);
228 canonicalzero((x * x * x * x - 6.0 * x * x + 3.0) * phi)
229}
230
/// Product of two truncated power series (coefficients of `x^0..=x^4`), truncated to degree 4.
///
/// Zero coefficients of `a` are skipped entirely. As a result, entries of `b` they would
/// multiply (even NaN/inf) never reach the output.
#[inline]
fn taylor5_mul(a: &[f64; 5], b: &[f64; 5]) -> [f64; 5] {
    let mut out = [0.0_f64; 5];
    for (shift, &coef) in a.iter().enumerate().filter(|&(_, &coef)| coef != 0.0) {
        // out[shift..] has 5 - shift slots, so zip pairs out[shift + j] with b[j].
        for (slot, &bj) in out[shift..].iter_mut().zip(b.iter()) {
            *slot += coef * bj;
        }
    }
    out
}
247
/// Reciprocal `1 / a(x)` of a truncated power series, truncated to degree 4.
///
/// Uses the recurrence `b_0 = 1 / a_0`, `b_k = -b_0 * sum_{j=1..=k} a_j b_{k-j}`;
/// `a[0]` is assumed non-zero by callers.
#[inline]
fn taylor5_inv(a: &[f64; 5]) -> [f64; 5] {
    let mut inv = [0.0_f64; 5];
    inv[0] = 1.0 / a[0];
    for k in 1..5 {
        // Explicit fold from +0.0 keeps the exact accumulation order and zero sign.
        let acc = (1..=k).fold(0.0_f64, |acc, j| acc + a[j] * inv[k - j]);
        inv[k] = -acc * inv[0];
    }
    inv
}
262
263pub(crate) fn fisher_weight_jet5(link: StandardLink, eta: f64) -> (f64, f64, f64, f64, f64) {
279 match link {
280 StandardLink::Logit => {
281 let jet = logit_inverse_link_jet5(eta);
282 (jet.d1, jet.d2, jet.d3, jet.d4, jet.d5)
283 }
284 StandardLink::Probit => probit_fisher_weight_jet5(eta),
285 StandardLink::CLogLog => component_fisher_weight_jet5(LinkComponent::CLogLog, eta),
286 StandardLink::LogLog => component_fisher_weight_jet5(LinkComponent::LogLog, eta),
287 StandardLink::Cauchit => component_fisher_weight_jet5(LinkComponent::Cauchit, eta),
288 StandardLink::Identity | StandardLink::Log => (0.0, 0.0, 0.0, 0.0, 0.0),
289 }
290}
291
292pub(crate) fn fisher_weight_jet5_for_inverse_link(
293 link: &InverseLink,
294 eta: f64,
295) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
296 match link {
297 InverseLink::Standard(link) => Ok(fisher_weight_jet5(*link, eta)),
298 InverseLink::LatentCLogLog(_)
299 | InverseLink::Sas(_)
300 | InverseLink::BetaLogistic(_)
301 | InverseLink::Mixture(_) => {
302 let jet = link.jet(eta)?;
303 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)?;
304 let d5 = inverse_link_pdffourth_derivative_for_inverse_link(link, eta)?;
305 Ok(fisher_weight_jet5_from_inverse_link_derivatives(
306 jet.mu, jet.d1, jet.d2, jet.d3, d4, d5,
307 ))
308 }
309 }
310}
311
312#[inline]
313pub(crate) fn inverse_link_has_fisher_weight_jet(link: &InverseLink) -> bool {
314 matches!(
321 link,
322 InverseLink::Standard(
323 StandardLink::Logit
324 | StandardLink::Probit
325 | StandardLink::CLogLog
326 | StandardLink::LogLog
327 | StandardLink::Cauchit,
328 )
329 | InverseLink::LatentCLogLog(_)
330 | InverseLink::Sas(_)
331 | InverseLink::BetaLogistic(_)
332 | InverseLink::Mixture(_)
333 )
334}
335
336#[inline]
337fn component_fisher_weight_jet5(component: LinkComponent, eta: f64) -> (f64, f64, f64, f64, f64) {
338 let jet = component_inverse_link_jet(component, eta);
339 let d4 = component_inverse_link_pdfthird_derivative(component, eta);
340 let d5 = component_inverse_link_pdffourth_derivative(component, eta);
341 fisher_weight_jet5_from_inverse_link_derivatives(jet.mu, jet.d1, jet.d2, jet.d3, d4, d5)
342}
343
344#[inline]
345fn fisher_weight_jet5_from_inverse_link_derivatives(
346 mu: f64,
347 d1: f64,
348 d2: f64,
349 d3: f64,
350 d4: f64,
351 d5: f64,
352) -> (f64, f64, f64, f64, f64) {
353 if [mu, d1, d2, d3, d4, d5].iter().any(|v| v.is_nan()) {
354 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
355 }
356 let variance = mu * (1.0 - mu);
357 if !(variance > 0.0) || !variance.is_finite() {
358 return (0.0, 0.0, 0.0, 0.0, 0.0);
359 }
360
361 let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
362 let mu_d = [mu, d1, d2, d3, d4];
363 let one_minus_mu_d = [1.0 - mu, -d1, -d2, -d3, -d4];
364 let dmu_d = [d1, d2, d3, d4, d5];
365 let mut mu_t = [0.0_f64; 5];
366 let mut one_minus_mu_t = [0.0_f64; 5];
367 let mut dmu_t = [0.0_f64; 5];
368 for k in 0..5 {
369 let inv_fact = 1.0 / factorial[k];
370 mu_t[k] = mu_d[k] * inv_fact;
371 one_minus_mu_t[k] = one_minus_mu_d[k] * inv_fact;
372 dmu_t[k] = dmu_d[k] * inv_fact;
373 }
374 let num_t = taylor5_mul(&dmu_t, &dmu_t);
375 let den_t = taylor5_mul(&mu_t, &one_minus_mu_t);
376 if !(den_t[0] > 0.0) || !den_t[0].is_finite() {
377 return (0.0, 0.0, 0.0, 0.0, 0.0);
378 }
379 let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
380 (
381 canonicalzero(w_t[0] * factorial[0]),
382 canonicalzero(w_t[1] * factorial[1]),
383 canonicalzero(w_t[2] * factorial[2]),
384 canonicalzero(w_t[3] * factorial[3]),
385 canonicalzero(w_t[4] * factorial[4]),
386 )
387}
388
389#[inline]
392fn probit_fisher_weight_jet5(eta: f64) -> (f64, f64, f64, f64, f64) {
393 if eta.is_nan() {
394 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
395 }
396 if !eta.is_finite() {
397 return (0.0, 0.0, 0.0, 0.0, 0.0);
398 }
399 let x = eta;
400 let p = normal_cdf(x);
401 let q = normal_cdf(-x);
405 let phi = normal_pdf(x);
406 if !(p > 0.0) || !(q > 0.0) || p * q <= 0.0 {
409 return (0.0, 0.0, 0.0, 0.0, 0.0);
410 }
411 let phi1 = -x * phi;
413 let phi2 = (x * x - 1.0) * phi;
414 let phi3 = -(x * x * x - 3.0 * x) * phi;
415 let phi4 = (x * x * x * x - 6.0 * x * x + 3.0) * phi;
416 let f_d = [phi, phi1, phi2, phi3, phi4];
419 let p_d = [p, phi, phi1, phi2, phi3];
420 let q_d = [q, -phi, -phi1, -phi2, -phi3];
421 let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
423 let mut f_t = [0.0_f64; 5];
424 let mut p_t = [0.0_f64; 5];
425 let mut q_t = [0.0_f64; 5];
426 for k in 0..5 {
427 let inv_fact = 1.0 / factorial[k];
428 f_t[k] = f_d[k] * inv_fact;
429 p_t[k] = p_d[k] * inv_fact;
430 q_t[k] = q_d[k] * inv_fact;
431 }
432 let num_t = taylor5_mul(&f_t, &f_t);
433 let den_t = taylor5_mul(&p_t, &q_t);
434 let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
435 (
437 canonicalzero(w_t[0] * factorial[0]),
438 canonicalzero(w_t[1] * factorial[1]),
439 canonicalzero(w_t[2] * factorial[2]),
440 canonicalzero(w_t[3] * factorial[3]),
441 canonicalzero(w_t[4] * factorial[4]),
442 )
443}
444
445#[inline]
446fn chain_inverse_link_jet(base: InverseLinkJet, z1: f64, z2: f64, z3: f64) -> InverseLinkJet {
447 InverseLinkJet {
448 mu: base.mu,
449 d1: base.d1 * z1,
450 d2: base.d2 * z1 * z1 + base.d1 * z2,
451 d3: base.d3 * z1 * z1 * z1 + 3.0 * base.d2 * z1 * z2 + base.d1 * z3,
452 }
453}
454
455#[inline]
456fn component_inverse_link_pdfthird_derivative(component: LinkComponent, eta: f64) -> f64 {
457 match component {
458 LinkComponent::Probit => probit_pdfthird_derivative(eta),
459 LinkComponent::Logit => logit_inverse_link_jet5(eta).d4,
460 LinkComponent::CLogLog => {
461 if eta.is_nan() {
469 return f64::NAN;
470 }
471 if !eta.is_finite() {
472 return 0.0;
473 }
474 let t = eta.exp();
475 canonicalzero(stable_nonnegative_poly_times_exp_neg(
476 t,
477 &[0.0, 1.0, -7.0, 6.0, -1.0],
478 ))
479 }
480 LinkComponent::LogLog => {
481 if eta.is_nan() {
488 return f64::NAN;
489 }
490 if !eta.is_finite() {
491 return 0.0;
492 }
493 let r = (-eta).exp();
494 canonicalzero(stable_nonnegative_poly_times_exp_neg(
495 r,
496 &[0.0, -1.0, 7.0, -6.0, 1.0],
497 ))
498 }
499 LinkComponent::Cauchit => {
500 if eta.is_nan() {
508 return f64::NAN;
509 }
510 if !eta.is_finite() {
511 return 0.0;
512 }
513 let denom = 1.0 + eta * eta;
514 24.0 * eta * (1.0 - eta * eta) / (std::f64::consts::PI * denom.powi(4))
515 }
516 }
517}
518
519#[inline]
522fn component_inverse_link_pdffourth_derivative(component: LinkComponent, eta: f64) -> f64 {
523 match component {
524 LinkComponent::Probit => probit_pdffourth_derivative(eta),
525 LinkComponent::Logit => logit_inverse_link_jet5(eta).d5,
526 LinkComponent::CLogLog => {
527 if eta.is_nan() {
532 return f64::NAN;
533 }
534 if !eta.is_finite() {
535 return 0.0;
536 }
537 let t = eta.exp();
538 canonicalzero(stable_nonnegative_poly_times_exp_neg(
539 t,
540 &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
541 ))
542 }
543 LinkComponent::LogLog => {
544 if eta.is_nan() {
549 return f64::NAN;
550 }
551 if !eta.is_finite() {
552 return 0.0;
553 }
554 let r = (-eta).exp();
555 canonicalzero(stable_nonnegative_poly_times_exp_neg(
556 r,
557 &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
558 ))
559 }
560 LinkComponent::Cauchit => {
561 if eta.is_nan() {
563 return f64::NAN;
564 }
565 if !eta.is_finite() {
566 return 0.0;
567 }
568 let e2 = eta * eta;
569 let denom = 1.0 + e2;
570 24.0 * (1.0 - 10.0 * e2 + 5.0 * e2 * e2) / (std::f64::consts::PI * denom.powi(5))
571 }
572 }
573}
574
/// Mixture inverse-link jet together with its partial derivatives with respect to the
/// mixture logits `rho`.
#[derive(Clone, Debug, PartialEq)]
pub struct MixtureJetWithRhoPartials {
    /// Jet of the mixture inverse link at `eta`.
    pub jet: InverseLinkJet,
    /// Derivative of `jet` with respect to each `rho` coordinate. Presumably one entry per
    /// free logit (`components.len() - 1`); confirm against
    /// `mixture_inverse_link_jetwith_rho_partials`.
    pub djet_drho: Vec<InverseLinkJet>,
}
582
/// Two-parameter link jet with its partial derivatives with respect to the link parameters.
///
/// Produced by the SAS kernel and also reused by the Beta-Logistic kernel.
#[derive(Clone, Debug, PartialEq)]
pub struct SasJetWithParamPartials {
    /// Jet of the inverse link at `eta`.
    pub jet: InverseLinkJet,
    /// Derivative of `jet` with respect to `epsilon`.
    pub djet_depsilon: InverseLinkJet,
    /// Derivative of `jet` with respect to the (raw) log-delta / log-shape parameter.
    pub djet_dlog_delta: InverseLinkJet,
}
589
/// Link-parameter partials returned by `InverseLinkKernel::param_partials`.
#[derive(Clone, Debug, PartialEq)]
pub enum LinkParamPartials {
    /// Partials with respect to the mixture logits.
    Mixture(MixtureJetWithRhoPartials),
    /// Partials with respect to `(epsilon, log_delta)`; used by both SAS and Beta-Logistic.
    Sas(SasJetWithParamPartials),
}
595
596pub trait InverseLinkKernel {
602 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;
603
604 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
605 assert!(eta.is_finite(), "eta must be finite");
606 Ok(None)
607 }
608}
609
/// Stateless kernel for the probit inverse link; delegates to `component_inverse_link_jet`
/// with `LinkComponent::Probit`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ProbitLinkKernel;

/// Stateless kernel for the logistic inverse link; delegates to `component_inverse_link_jet`
/// with `LinkComponent::Logit`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogitLinkKernel;

/// Stateless kernel for the complementary log-log inverse link; delegates to
/// `component_inverse_link_jet` with `LinkComponent::CLogLog`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CLogLogLinkKernel;

/// Stateless kernel for the log-log inverse link; delegates to `component_inverse_link_jet`
/// with `LinkComponent::LogLog`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogLogLinkKernel;

/// Stateless kernel for the Cauchit inverse link; delegates to `component_inverse_link_jet`
/// with `LinkComponent::Cauchit`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CauchitLinkKernel;
624
625pub fn sas_link_state_from_raw(
633 raw_epsilon: f64,
634 raw_log_delta: f64,
635) -> Result<SasLinkState, String> {
636 if !raw_epsilon.is_finite() || !raw_log_delta.is_finite() {
637 return Err("SAS link parameters must be finite".to_string());
638 }
639 Ok(SasLinkState {
640 epsilon: raw_epsilon,
641 log_delta: raw_log_delta,
642 delta: sas_delta_from_raw_log_delta(raw_log_delta),
643 })
644}
645
646pub fn state_from_sasspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
647 sas_link_state_from_raw(spec.initial_epsilon, spec.initial_log_delta)
648}
649
650pub fn state_from_beta_logisticspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
651 if !spec.initial_epsilon.is_finite() || !spec.initial_log_delta.is_finite() {
652 return Err("Beta-Logistic link parameters must be finite".to_string());
653 }
654 let log_shape_center = spec.initial_log_delta;
660 Ok(SasLinkState {
661 epsilon: spec.initial_epsilon,
662 log_delta: log_shape_center,
663 delta: sas_delta_from_raw_log_delta(log_shape_center),
664 })
665}
666
/// Smooth saturating clamp `b * tanh(value / b)` with `b = max(bound, f64::EPSILON)`.
///
/// The result is close to the identity for `|value| << b` and always lies within `[-b, b]`.
#[inline]
fn tanh_bound(value: f64, bound: f64) -> f64 {
    let scale = f64::EPSILON.max(bound);
    let ratio = value / scale;
    scale * ratio.tanh()
}
672
/// First derivative of `tanh_bound` with respect to `value`: `1 - t^2`, where
/// `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d1(value: f64, bound: f64) -> f64 {
    let scale = f64::EPSILON.max(bound);
    let th = (value / scale).tanh();
    1.0 - th * th
}
679
/// Second derivative of `tanh_bound` with respect to `value`: `-2 t (1 - t^2) / b`,
/// where `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d2(value: f64, bound: f64) -> f64 {
    let scale = f64::EPSILON.max(bound);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    -2.0 * th * sech2 / scale
}
687
/// Third derivative of `tanh_bound` with respect to `value`:
/// `-2 (1 - t^2)(1 - 3 t^2) / b^2`, where `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d3(value: f64, bound: f64) -> f64 {
    let scale = f64::EPSILON.max(bound);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    -2.0 * sech2 * (1.0 - 3.0 * th * th) / (scale * scale)
}
695
/// Fourth derivative of `tanh_bound` with respect to `value`:
/// `8 t (1 - t^2)(2 - 3 t^2) / b^3`, where `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d4(value: f64, bound: f64) -> f64 {
    let scale = f64::EPSILON.max(bound);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    8.0 * th * sech2 * (2.0 - 3.0 * th * th) / (scale * scale * scale)
}
703
/// Fifth derivative of `tanh_bound` with respect to `value`:
/// `8 (1 - t^2)(2 - 15 t^2 + 15 t^4) / b^4`, where `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d5(value: f64, bound: f64) -> f64 {
    let scale = f64::EPSILON.max(bound);
    let th = (value / scale).tanh();
    let sech2 = 1.0 - th * th;
    let th2 = th * th;
    let scale4 = scale * scale * scale * scale;
    8.0 * sech2 * (2.0 - 15.0 * th2 + 15.0 * th2 * th2) / scale4
}
716
717#[inline]
718fn sas_effective_log_delta(raw_log_delta: f64) -> (f64, f64) {
719 let ld_eff = tanh_bound(raw_log_delta, SAS_LOG_DELTA_BOUND);
720 let dld_eff_draw = tanh_bound_d1(raw_log_delta, SAS_LOG_DELTA_BOUND);
721 (ld_eff, dld_eff_draw)
722}
723
724#[inline]
725fn sas_delta_from_raw_log_delta(raw_log_delta: f64) -> f64 {
726 let (ld_eff, _) = sas_effective_log_delta(raw_log_delta);
727 ld_eff.exp()
728}
729
730pub fn validate_mixturespec(spec: &MixtureLinkSpec) -> Result<(), String> {
731 if spec.components.is_empty() {
732 return Err("mixture link requires at least 1 component".to_string());
733 }
734 if spec.initial_rho.len() + 1 != spec.components.len() {
735 return Err(format!(
736 "mixture link rho length mismatch: expected {}, got {}",
737 spec.components.len() - 1,
738 spec.initial_rho.len()
739 ));
740 }
741 for i in 0..spec.components.len() {
742 for j in (i + 1)..spec.components.len() {
743 if spec.components[i] == spec.components[j] {
744 return Err("mixture link components must be unique".to_string());
745 }
746 }
747 }
748 let has_anchor = spec.components.iter().any(|component| {
763 matches!(
764 component,
765 LinkComponent::Logit | LinkComponent::Probit | LinkComponent::CLogLog
766 )
767 });
768 if !has_anchor && spec.components.len() > 1 {
769 let unsupported: Vec<&str> = spec
770 .components
771 .iter()
772 .map(|component| component.name())
773 .collect();
774 return Err(format!(
775 "mixture link components {{{}}} are unsupported: at least one component \
776 must map to a LinkFunction variant (logit/probit/cloglog) so the mixture's \
777 projected LinkFunction is well defined; cauchit and loglog have no \
778 LinkFunction representative",
779 unsupported.join(", ")
780 ));
781 }
782 Ok(())
783}
784
785pub fn softmax_last_fixedzero(rho: &Array1<f64>) -> Array1<f64> {
786 let k = rho.len() + 1;
787 let mut logits = Vec::with_capacity(k);
788 let mut maxv = 0.0_f64;
789 for &v in rho {
790 maxv = maxv.max(v);
791 logits.push(v);
792 }
793 maxv = maxv.max(0.0);
794 logits.push(0.0);
795
796 let mut sum = 0.0_f64;
797 let mut exps = vec![0.0_f64; k];
798 for i in 0..k {
799 let e = (logits[i] - maxv).exp();
800 exps[i] = e;
801 sum += e;
802 }
803 if !sum.is_finite() || sum <= 0.0 {
804 return Array1::from_elem(k, 1.0 / k as f64);
805 }
806 let inv = 1.0 / sum;
807 Array1::from_iter(exps.into_iter().map(|v| v * inv))
808}
809
810pub fn softmaxwith_jacobian_last_fixedzero(
813 rho: &Array1<f64>,
814) -> (Array1<f64>, ndarray::Array2<f64>) {
815 let pi = softmax_last_fixedzero(rho);
816 let k = pi.len();
817 let m = k.saturating_sub(1);
818 let mut jac = ndarray::Array2::<f64>::zeros((k, m));
819 for j in 0..m {
820 let pi_j = pi[j];
821 for kk in 0..k {
822 let delta = if kk == j { 1.0 } else { 0.0 };
823 jac[[kk, j]] = pi[kk] * (delta - pi_j);
824 }
825 }
826 (pi, jac)
827}
828
829pub fn state_fromspec(spec: &MixtureLinkSpec) -> Result<MixtureLinkState, String> {
830 validate_mixturespec(spec)?;
831 let pi = softmax_last_fixedzero(&spec.initial_rho);
832 Ok(MixtureLinkState {
833 components: spec.components.clone(),
834 rho: spec.initial_rho.clone(),
835 pi,
836 })
837}
838
839#[inline]
840pub fn component_inverse_link_jet(component: LinkComponent, eta: f64) -> InverseLinkJet {
841 canonicalize_jet(match component {
842 LinkComponent::Logit => {
843 let jet = logit_inverse_link_jet5(eta);
844 InverseLinkJet {
845 mu: jet.mu,
846 d1: jet.d1,
847 d2: jet.d2,
848 d3: jet.d3,
849 }
850 }
851 LinkComponent::Probit => probit_jet(eta),
852 LinkComponent::CLogLog => {
853 if eta.is_nan() {
854 return InverseLinkJet {
855 mu: f64::NAN,
856 d1: f64::NAN,
857 d2: f64::NAN,
858 d3: f64::NAN,
859 };
860 }
861 let t = eta.exp();
862 if !t.is_finite() {
863 return InverseLinkJet {
864 mu: 1.0,
865 d1: 0.0,
866 d2: 0.0,
867 d3: 0.0,
868 };
869 }
870 InverseLinkJet {
871 mu: -(-t).exp_m1(),
872 d1: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0]),
873 d2: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -1.0]),
874 d3: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -3.0, 1.0]),
875 }
876 }
877 LinkComponent::LogLog => {
878 if eta.is_nan() {
879 return InverseLinkJet {
880 mu: f64::NAN,
881 d1: f64::NAN,
882 d2: f64::NAN,
883 d3: f64::NAN,
884 };
885 }
886 let r = (-eta).exp();
887 if !r.is_finite() {
888 return InverseLinkJet {
889 mu: 0.0,
890 d1: 0.0,
891 d2: 0.0,
892 d3: 0.0,
893 };
894 }
895 InverseLinkJet {
896 mu: (-r).exp(),
897 d1: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0]),
898 d2: stable_nonnegative_poly_times_exp_neg(r, &[0.0, -1.0, 1.0]),
899 d3: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0, -3.0, 1.0]),
900 }
901 }
902 LinkComponent::Cauchit => {
903 if eta.is_nan() {
904 return InverseLinkJet {
905 mu: f64::NAN,
906 d1: f64::NAN,
907 d2: f64::NAN,
908 d3: f64::NAN,
909 };
910 }
911 let den = 1.0 + eta * eta;
912 let d1 = if eta.is_finite() {
913 1.0 / (std::f64::consts::PI * den)
914 } else {
915 0.0
916 };
917 let d2 = if eta.is_finite() {
918 -2.0 * eta / (std::f64::consts::PI * den * den)
919 } else {
920 0.0
921 };
922 let d3 = if eta.is_finite() {
923 (6.0 * eta * eta - 2.0) / (std::f64::consts::PI * den * den * den)
924 } else {
925 0.0
926 };
927 InverseLinkJet {
928 mu: 0.5 + eta.atan() / std::f64::consts::PI,
929 d1,
930 d2,
931 d3,
932 }
933 }
934 })
935}
936
937impl InverseLinkKernel for ProbitLinkKernel {
938 #[inline]
939 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
940 Ok(component_inverse_link_jet(LinkComponent::Probit, eta))
941 }
942}
943
944impl InverseLinkKernel for LogitLinkKernel {
945 #[inline]
946 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
947 Ok(component_inverse_link_jet(LinkComponent::Logit, eta))
948 }
949}
950
951impl InverseLinkKernel for CLogLogLinkKernel {
952 #[inline]
953 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
954 Ok(component_inverse_link_jet(LinkComponent::CLogLog, eta))
955 }
956}
957
958impl InverseLinkKernel for LogLogLinkKernel {
959 #[inline]
960 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
961 Ok(component_inverse_link_jet(LinkComponent::LogLog, eta))
962 }
963}
964
965impl InverseLinkKernel for CauchitLinkKernel {
966 #[inline]
967 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
968 Ok(component_inverse_link_jet(LinkComponent::Cauchit, eta))
969 }
970}
971
972impl InverseLinkKernel for LinkComponent {
973 #[inline]
974 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
975 Ok(component_inverse_link_jet(*self, eta))
976 }
977}
978
979impl InverseLinkKernel for LinkFunction {
980 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
981 match self {
982 LinkFunction::Logit => LogitLinkKernel.jet(eta),
983 LinkFunction::Probit => ProbitLinkKernel.jet(eta),
984 LinkFunction::CLogLog => CLogLogLinkKernel.jet(eta),
985 LinkFunction::LogLog => LogLogLinkKernel.jet(eta),
986 LinkFunction::Cauchit => CauchitLinkKernel.jet(eta),
987 LinkFunction::Identity => Ok(InverseLinkJet {
988 mu: eta,
989 d1: 1.0,
990 d2: 0.0,
991 d3: 0.0,
992 }),
993 LinkFunction::Log => {
994 let e = eta.clamp(-700.0, 700.0).exp();
1006 Ok(InverseLinkJet {
1007 mu: e,
1008 d1: e,
1009 d2: e,
1010 d3: e,
1011 })
1012 }
1013 LinkFunction::Sas => Err(EstimationError::InvalidInput(
1014 "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1015 )),
1016 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1017 "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1018 .to_string(),
1019 )),
1020 }
1021 }
1022}
1023
1024impl InverseLinkKernel for SasLinkState {
1025 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1026 Ok(sas_inverse_link_jet(eta, self.epsilon, self.log_delta))
1027 }
1028
1029 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1030 Ok(Some(LinkParamPartials::Sas(
1031 sas_inverse_link_jetwith_param_partials(eta, self.epsilon, self.log_delta),
1032 )))
1033 }
1034}
1035
/// Beta-Logistic inverse-link kernel.
///
/// As in `beta_logistic_inverse_link_mu_d1`, the two Beta shapes are presumably
/// `exp(log_shape_center - epsilon)` and `exp(log_shape_center + epsilon)`.
#[derive(Clone, Copy, Debug)]
pub struct BetaLogisticKernel {
    /// Log of the shape center; `InverseLink` fills it from the link state's `log_delta`.
    pub log_shape_center: f64,
    /// Half the log-ratio between the two shape parameters.
    pub epsilon: f64,
}
1043
1044impl InverseLinkKernel for BetaLogisticKernel {
1045 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1046 Ok(beta_logistic_inverse_link_jet(
1047 eta,
1048 self.log_shape_center,
1049 self.epsilon,
1050 ))
1051 }
1052
1053 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1054 Ok(Some(LinkParamPartials::Sas(
1055 beta_logistic_inverse_link_jetwith_param_partials(
1056 eta,
1057 self.log_shape_center,
1058 self.epsilon,
1059 ),
1060 )))
1061 }
1062}
1063
1064impl InverseLinkKernel for MixtureLinkState {
1065 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1066 Ok(mixture_inverse_link_jet(self, eta))
1067 }
1068
1069 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1070 Ok(Some(LinkParamPartials::Mixture(
1071 mixture_inverse_link_jetwith_rho_partials(self, eta),
1072 )))
1073 }
1074}
1075
1076impl InverseLinkKernel for InverseLink {
1077 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1078 match self {
1079 InverseLink::Standard(StandardLink::Logit) => LogitLinkKernel.jet(eta),
1080 InverseLink::Standard(StandardLink::Probit) => ProbitLinkKernel.jet(eta),
1081 InverseLink::Standard(StandardLink::CLogLog) => CLogLogLinkKernel.jet(eta),
1082 InverseLink::Standard(StandardLink::LogLog) => LogLogLinkKernel.jet(eta),
1083 InverseLink::Standard(StandardLink::Cauchit) => CauchitLinkKernel.jet(eta),
1084 InverseLink::Standard(StandardLink::Identity) => LinkFunction::Identity.jet(eta),
1085 InverseLink::Standard(StandardLink::Log) => LinkFunction::Log.jet(eta),
1086 InverseLink::LatentCLogLog(state) => latent_cloglog_point_jet(state, eta),
1087 InverseLink::Sas(state) => state.jet(eta),
1088 InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1089 log_shape_center: state.log_delta,
1090 epsilon: state.epsilon,
1091 }
1092 .jet(eta),
1093 InverseLink::Mixture(state) => state.jet(eta),
1094 }
1095 }
1096
1097 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1098 match self {
1099 InverseLink::Standard(_) => Ok(None),
1100 InverseLink::LatentCLogLog(_) => Ok(None),
1101 InverseLink::Sas(state) => state.param_partials(eta),
1102 InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1103 log_shape_center: state.log_delta,
1104 epsilon: state.epsilon,
1105 }
1106 .param_partials(eta),
1107 InverseLink::Mixture(state) => state.param_partials(eta),
1108 }
1109 }
1110}
1111
1112pub fn inverse_link_jet_for_inverse_link(
1116 link: &InverseLink,
1117 eta: f64,
1118) -> Result<InverseLinkJet, EstimationError> {
1119 link.jet(eta)
1120}
1121
1122pub fn inverse_link_mu_d1_for_inverse_link(
1132 link: &InverseLink,
1133 eta: f64,
1134) -> Result<(f64, f64), EstimationError> {
1135 match link {
1136 InverseLink::Standard(link_fn) => Ok(link_function_mu_d1(link_fn.as_link_function(), eta)?),
1137 InverseLink::LatentCLogLog(state) => {
1138 let jet = latent_cloglog_point_jet(state, eta)?;
1139 Ok((jet.mu, jet.d1))
1140 }
1141 InverseLink::Sas(state) => Ok(sas_inverse_link_mu_d1(eta, state.epsilon, state.log_delta)),
1142 InverseLink::BetaLogistic(state) => Ok(beta_logistic_inverse_link_mu_d1(
1143 eta,
1144 state.log_delta,
1145 state.epsilon,
1146 )),
1147 InverseLink::Mixture(state) => Ok(mixture_inverse_link_mu_d1(state, eta)),
1148 }
1149}
1150
1151fn link_function_mu_d1(link: LinkFunction, eta: f64) -> Result<(f64, f64), EstimationError> {
1152 match link {
1153 LinkFunction::Identity => Ok((eta, 1.0)),
1154 LinkFunction::Log => {
1155 let e = eta.clamp(-700.0, 700.0).exp();
1160 Ok((e, e))
1161 }
1162 LinkFunction::Logit => Ok(component_inverse_link_mu_d1(LinkComponent::Logit, eta)),
1163 LinkFunction::Probit => Ok(component_inverse_link_mu_d1(LinkComponent::Probit, eta)),
1164 LinkFunction::CLogLog => Ok(component_inverse_link_mu_d1(LinkComponent::CLogLog, eta)),
1165 LinkFunction::LogLog => Ok(component_inverse_link_mu_d1(LinkComponent::LogLog, eta)),
1166 LinkFunction::Cauchit => Ok(component_inverse_link_mu_d1(LinkComponent::Cauchit, eta)),
1167 LinkFunction::Sas => Err(EstimationError::InvalidInput(
1168 "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1169 )),
1170 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1171 "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1172 .to_string(),
1173 )),
1174 }
1175}
1176
1177#[inline]
1178fn component_inverse_link_mu_d1(component: LinkComponent, eta: f64) -> (f64, f64) {
1179 match component {
1185 LinkComponent::Logit => {
1186 let jet = logit_inverse_link_jet5(eta);
1187 (jet.mu, canonicalzero(jet.d1))
1188 }
1189 LinkComponent::Probit => {
1190 if eta.is_nan() {
1191 return (f64::NAN, f64::NAN);
1192 }
1193 if eta == f64::INFINITY {
1194 return (1.0, 0.0);
1195 }
1196 if eta == f64::NEG_INFINITY {
1197 return (0.0, 0.0);
1198 }
1199 let phi = normal_pdf(eta);
1200 (normal_cdf(eta), canonicalzero(phi))
1201 }
1202 LinkComponent::CLogLog => {
1203 if eta.is_nan() {
1204 return (f64::NAN, f64::NAN);
1205 }
1206 let t = eta.exp();
1207 if !t.is_finite() {
1208 return (1.0, 0.0);
1209 }
1210 (
1211 -(-t).exp_m1(),
1212 canonicalzero(stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0])),
1213 )
1214 }
1215 LinkComponent::LogLog => {
1216 if eta.is_nan() {
1217 return (f64::NAN, f64::NAN);
1218 }
1219 let r = (-eta).exp();
1220 if !r.is_finite() {
1221 return (0.0, 0.0);
1222 }
1223 (
1224 (-r).exp(),
1225 canonicalzero(stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0])),
1226 )
1227 }
1228 LinkComponent::Cauchit => {
1229 if eta.is_nan() {
1230 return (f64::NAN, f64::NAN);
1231 }
1232 let den = 1.0 + eta * eta;
1233 let d1 = if eta.is_finite() {
1234 1.0 / (std::f64::consts::PI * den)
1235 } else {
1236 0.0
1237 };
1238 (0.5 + eta.atan() / std::f64::consts::PI, canonicalzero(d1))
1239 }
1240 }
1241}
1242
1243fn sas_inverse_link_mu_d1(eta: f64, epsilon: f64, log_delta: f64) -> (f64, f64) {
1244 let delta_id = sas_delta_from_raw_log_delta(log_delta);
1245 if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1246 return component_inverse_link_mu_d1(LinkComponent::Probit, eta);
1247 }
1248 let e = if eta.is_finite() { eta } else { 0.0 };
1249 let a = e.asinh();
1250 let delta = delta_id;
1251 let u_raw = delta * a + epsilon;
1252 let u = tanh_bound(u_raw, SAS_U_CLAMP);
1253 let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1254 let s = u.sinh();
1255 let c = u.cosh();
1256 let z = s;
1257 let q = e.hypot(1.0);
1258 let inv_q = 1.0 / q;
1259 let r1 = delta * inv_q;
1260 let u1 = g1 * r1;
1261 let z1 = c * u1;
1262 let base = probit_jet(z);
1265 (base.mu, canonicalzero(base.d1 * z1))
1266}
1267
1268fn beta_logistic_inverse_link_mu_d1(eta: f64, delta: f64, epsilon: f64) -> (f64, f64) {
1269 let logistic = logistic_uwith_derivatives(eta);
1270 let a = (delta - epsilon).exp();
1271 let b = (delta + epsilon).exp();
1272 let mu = beta_reg_logistic(a, b, logistic);
1273 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1274 (mu, log_d1.exp())
1275}
1276
1277fn mixture_inverse_link_mu_d1(state: &MixtureLinkState, eta: f64) -> (f64, f64) {
1278 let mut mu = 0.0_f64;
1279 let mut d1 = 0.0_f64;
1280 let k = state.components.len().min(state.pi.len());
1281 for i in 0..k {
1282 let (mu_i, d1_i) = component_inverse_link_mu_d1(state.components[i], eta);
1283 let w = state.pi[i];
1284 mu += w * mu_i;
1285 d1 += w * d1_i;
1286 }
1287 (mu, d1)
1288}
1289
/// Which derivative of the link density `dmu/deta` to evaluate.
#[derive(Clone, Copy)]
enum PdfDerivativeOrder {
    /// Third derivative of the density, i.e. `d^4 mu / d eta^4`.
    Third,
    /// Fourth derivative of the density, i.e. `d^5 mu / d eta^5`.
    Fourth,
}
1295
1296impl PdfDerivativeOrder {
1297 fn probit(self, eta: f64) -> f64 {
1298 match self {
1299 Self::Third => probit_pdfthird_derivative(eta),
1300 Self::Fourth => probit_pdffourth_derivative(eta),
1301 }
1302 }
1303
1304 fn component(self, component: LinkComponent, eta: f64) -> f64 {
1305 match self {
1306 Self::Third => component_inverse_link_pdfthird_derivative(component, eta),
1307 Self::Fourth => component_inverse_link_pdffourth_derivative(component, eta),
1308 }
1309 }
1310
1311 fn latent_cloglog(self, eta: f64, latent_sd: f64) -> Result<f64, EstimationError> {
1312 let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, latent_sd)?;
1313 Ok(match self {
1314 Self::Third => jet.d4,
1315 Self::Fourth => jet.d5,
1316 })
1317 }
1318
1319 fn sas(self, eta: f64, epsilon: f64, log_delta: f64) -> f64 {
1320 match self {
1321 Self::Third => sas_inverse_link_pdfthird_derivative(eta, epsilon, log_delta),
1322 Self::Fourth => sas_inverse_link_pdffourth_derivative(eta, epsilon, log_delta),
1323 }
1324 }
1325
1326 fn beta_logistic(self, eta: f64, log_shape_center: f64, epsilon: f64) -> f64 {
1327 match self {
1328 Self::Third => {
1329 beta_logistic_inverse_link_pdfthird_derivative(eta, log_shape_center, epsilon)
1330 }
1331 Self::Fourth => {
1332 beta_logistic_inverse_link_pdffourth_derivative(eta, log_shape_center, epsilon)
1333 }
1334 }
1335 }
1336}
1337
1338fn inverse_link_pdf_derivative_for_inverse_link(
1339 link: &InverseLink,
1340 eta: f64,
1341 order: PdfDerivativeOrder,
1342) -> Result<f64, EstimationError> {
1343 match link {
1344 InverseLink::Standard(StandardLink::Identity) => Ok(0.0),
1345 InverseLink::Standard(StandardLink::Log) => Ok(eta.clamp(-700.0, 700.0).exp()),
1346 InverseLink::Standard(StandardLink::Probit) => Ok(order.probit(eta)),
1347 InverseLink::Standard(StandardLink::Logit) => {
1348 Ok(order.component(LinkComponent::Logit, eta))
1349 }
1350 InverseLink::Standard(StandardLink::CLogLog) => {
1351 Ok(order.component(LinkComponent::CLogLog, eta))
1352 }
1353 InverseLink::Standard(StandardLink::LogLog) => {
1354 Ok(order.component(LinkComponent::LogLog, eta))
1355 }
1356 InverseLink::Standard(StandardLink::Cauchit) => {
1357 Ok(order.component(LinkComponent::Cauchit, eta))
1358 }
1359 InverseLink::LatentCLogLog(state) => order.latent_cloglog(eta, state.latent_sd),
1360 InverseLink::Sas(state) => Ok(order.sas(eta, state.epsilon, state.log_delta)),
1361 InverseLink::BetaLogistic(state) => {
1362 Ok(order.beta_logistic(eta, state.log_delta, state.epsilon))
1363 }
1364 InverseLink::Mixture(state) => Ok(state
1365 .components
1366 .iter()
1367 .zip(state.pi.iter())
1368 .map(|(&component, &weight)| weight * order.component(component, eta))
1369 .sum()),
1370 }
1371}
1372
1373pub fn inverse_link_pdfthird_derivative_for_inverse_link(
1374 link: &InverseLink,
1375 eta: f64,
1376) -> Result<f64, EstimationError> {
1377 inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Third)
1393}
1394
1395pub fn inverse_link_pdffourth_derivative_for_inverse_link(
1401 link: &InverseLink,
1402 eta: f64,
1403) -> Result<f64, EstimationError> {
1404 inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Fourth)
1405}
1406
1407#[inline]
1408fn royston_parmar_inverse_link_jet(eta: f64) -> InverseLinkJet {
1409 const SURVIVAL_ETA_CLAMP: f64 = 30.0;
1413
1414 let z = eta.clamp(-SURVIVAL_ETA_CLAMP, SURVIVAL_ETA_CLAMP);
1415 let hazard = z.exp();
1416 let survival = (-hazard).exp();
1417 if !(-SURVIVAL_ETA_CLAMP..=SURVIVAL_ETA_CLAMP).contains(&eta) {
1418 return InverseLinkJet {
1419 mu: survival,
1420 d1: 0.0,
1421 d2: 0.0,
1422 d3: 0.0,
1423 };
1424 }
1425
1426 let d1 = -hazard * survival;
1427 let d2 = hazard * (hazard - 1.0) * survival;
1428 let d3 = (-hazard * hazard * hazard + 3.0 * hazard * hazard - hazard) * survival;
1429 InverseLinkJet {
1430 mu: survival,
1431 d1,
1432 d2,
1433 d3,
1434 }
1435}
1436
1437pub fn inverse_link_jet_for_family(
1438 spec: &LikelihoodSpec,
1439 eta: f64,
1440) -> Result<InverseLinkJet, EstimationError> {
1441 if matches!(spec.response, ResponseFamily::RoystonParmar) {
1444 return Ok(royston_parmar_inverse_link_jet(eta));
1445 }
1446 spec.link.jet(eta)
1447}
1448
1449#[inline]
1456fn log_inverse_link_jet_exact(eta: f64) -> InverseLinkJet {
1457 let e = eta.exp();
1458 InverseLinkJet {
1459 mu: e,
1460 d1: e,
1461 d2: e,
1462 d3: e,
1463 }
1464}
1465
1466pub fn inverse_link_jet_for_family_public(
1479 spec: &LikelihoodSpec,
1480 eta: f64,
1481) -> Result<InverseLinkJet, EstimationError> {
1482 if matches!(spec.response, ResponseFamily::RoystonParmar) {
1483 return Ok(royston_parmar_inverse_link_jet(eta));
1484 }
1485 if let InverseLink::Standard(StandardLink::Log) = spec.link {
1486 return Ok(log_inverse_link_jet_exact(eta));
1487 }
1488 spec.link.jet(eta)
1489}
1490
1491#[inline]
1492pub fn mixture_inverse_link_jet(state: &MixtureLinkState, eta: f64) -> InverseLinkJet {
1493 let mut mu = 0.0_f64;
1494 let mut d1 = 0.0_f64;
1495 let mut d2 = 0.0_f64;
1496 let mut d3 = 0.0_f64;
1497 let k = state.components.len().min(state.pi.len());
1498 for i in 0..k {
1499 let jet = component_inverse_link_jet(state.components[i], eta);
1500 let w = state.pi[i];
1501 mu += w * jet.mu;
1502 d1 += w * jet.d1;
1503 d2 += w * jet.d2;
1504 d3 += w * jet.d3;
1505 }
1506 InverseLinkJet { mu, d1, d2, d3 }
1507}
1508
1509pub fn mixture_inverse_link_jetwith_rho_partials(
1517 state: &MixtureLinkState,
1518 eta: f64,
1519) -> MixtureJetWithRhoPartials {
1520 let k = state.components.len().min(state.pi.len());
1521 let m = k.saturating_sub(1);
1522 let mut djet_drho = vec![
1523 InverseLinkJet {
1524 mu: 0.0,
1525 d1: 0.0,
1526 d2: 0.0,
1527 d3: 0.0,
1528 };
1529 m
1530 ];
1531 let jet = mixture_inverse_link_jetwith_rho_partials_into(state, eta, &mut djet_drho);
1532 MixtureJetWithRhoPartials { jet, djet_drho }
1533}
1534
1535pub fn mixture_inverse_link_jetwith_rho_partials_into(
1538 state: &MixtureLinkState,
1539 eta: f64,
1540 out: &mut [InverseLinkJet],
1541) -> InverseLinkJet {
1542 let k = state.components.len().min(state.pi.len());
1543 let m = k.saturating_sub(1);
1544 assert!(
1545 out.len() >= m,
1546 "rho-partial output buffer too small: got {}, need {}",
1547 out.len(),
1548 m
1549 );
1550 let mut mixed = InverseLinkJet {
1551 mu: 0.0,
1552 d1: 0.0,
1553 d2: 0.0,
1554 d3: 0.0,
1555 };
1556 for i in 0..k {
1557 let jet_i = component_inverse_link_jet(state.components[i], eta);
1558 let w = state.pi[i];
1559 mixed.mu += w * jet_i.mu;
1560 mixed.d1 += w * jet_i.d1;
1561 mixed.d2 += w * jet_i.d2;
1562 mixed.d3 += w * jet_i.d3;
1563 if i < m {
1566 out[i] = jet_i;
1567 }
1568 }
1569 for j in 0..m {
1570 let pi_j = state.pi[j];
1571 let cj = out[j];
1572 out[j] = InverseLinkJet {
1573 mu: pi_j * (cj.mu - mixed.mu),
1574 d1: pi_j * (cj.d1 - mixed.d1),
1575 d2: pi_j * (cj.d2 - mixed.d2),
1576 d3: pi_j * (cj.d3 - mixed.d3),
1577 };
1578 }
1579 mixed
1580}
1581
/// Logistic transform `u = logistic(η)` bundled with the quantities the
/// beta-logistic link needs. These are the complement, both logarithms,
/// `du/dη = u(1 − u)`, and a flag selecting which tail to evaluate.
#[derive(Clone, Copy)]
struct LogisticU {
    /// `u = logistic(η)`.
    u: f64,
    /// `ln u`.
    ln_u: f64,
    /// `1 − u`, stored separately from `u`.
    one_minus_u: f64,
    /// `ln(1 − u)`.
    ln_one_minus_u: f64,
    /// `du/dη = u·(1 − u)`.
    du: f64,
    /// Set when `η ≥ 0`; the beta-CDF helpers then evaluate through the
    /// complement `1 − u`.
    use_upper_tail: bool,
}
1591
1592#[inline]
1593fn logistic_uwith_derivatives(eta: f64) -> LogisticU {
1594 let ln_u = -gam_linalg::utils::stable_softplus(-eta);
1595 let ln_one_minus_u = -gam_linalg::utils::stable_softplus(eta);
1596 let u = ln_u.exp();
1597 let one_minus_u = ln_one_minus_u.exp();
1598 let du = (ln_u + ln_one_minus_u).exp();
1599 LogisticU {
1600 u,
1601 one_minus_u,
1602 ln_u,
1603 ln_one_minus_u,
1604 du,
1605 use_upper_tail: eta >= 0.0,
1606 }
1607}
1608
1609#[inline]
1610fn beta_reg_logistic(a: f64, b: f64, logistic: LogisticU) -> f64 {
1611 if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1612 return f64::NAN;
1613 }
1614 if logistic.ln_u == f64::NEG_INFINITY {
1615 return 0.0;
1616 }
1617 if logistic.ln_one_minus_u == f64::NEG_INFINITY {
1618 return 1.0;
1619 }
1620 if logistic.use_upper_tail {
1621 1.0 - beta_reg(b, a, logistic.one_minus_u)
1622 } else {
1623 beta_reg(a, b, logistic.u)
1624 }
1625}
1626
1627#[inline]
1628fn beta_reg_with_shape_partials_logistic(a: f64, b: f64, logistic: LogisticU) -> (f64, f64, f64) {
1629 if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1630 return (f64::NAN, f64::NAN, f64::NAN);
1631 }
1632 if logistic.use_upper_tail {
1633 let (tail, dtail_db, dtail_da) = beta_reg_with_shape_partials(b, a, logistic.one_minus_u);
1634 (1.0 - tail, -dtail_da, -dtail_db)
1635 } else {
1636 beta_reg_with_shape_partials(a, b, logistic.u)
1637 }
1638}
1639
1640#[inline]
1641fn beta_logistic_log_d1(a: f64, b: f64, logistic: LogisticU) -> f64 {
1642 a * logistic.ln_u + b * logistic.ln_one_minus_u - ln_beta(a, b)
1643}
1644
/// Forward-mode dual number carrying a value and its partial derivatives with
/// respect to the two beta shape parameters `a` and `b`.
#[derive(Clone, Copy)]
struct ShapeDual {
    /// Primal value.
    v: f64,
    /// Partial derivative with respect to shape `a`.
    da: f64,
    /// Partial derivative with respect to shape `b`.
    db: f64,
}

impl ShapeDual {
    /// A value that does not depend on either shape parameter.
    #[inline]
    fn constant(v: f64) -> Self {
        Self::from_value_partials(v, 0.0, 0.0)
    }

    /// A value with explicit partials `(∂/∂a, ∂/∂b)`.
    #[inline]
    fn from_value_partials(v: f64, da: f64, db: f64) -> Self {
        Self { v, da, db }
    }

    /// If `|v| < floor`, replace the whole dual by the constant `floor`
    /// (partials dropped); otherwise return `self`. A NaN value fails the
    /// comparison and is returned unchanged.
    #[inline]
    fn clamp_small(self, floor: f64) -> Self {
        if self.v.abs() < floor {
            Self::constant(floor)
        } else {
            self
        }
    }

    /// Apply `op` independently to the value and to each partial. This is
    /// only correct for linear operations such as `+` and `-`.
    #[inline]
    fn componentwise(self, rhs: Self, op: impl Fn(f64, f64) -> f64) -> Self {
        Self {
            v: op(self.v, rhs.v),
            da: op(self.da, rhs.da),
            db: op(self.db, rhs.db),
        }
    }
}

impl std::ops::Add for ShapeDual {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self {
        self.componentwise(rhs, |x, y| x + y)
    }
}

impl std::ops::Sub for ShapeDual {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self {
        self.componentwise(rhs, |x, y| x - y)
    }
}

impl std::ops::Mul for ShapeDual {
    type Output = Self;

    /// Product rule: `(xy)' = x'·y + x·y'`.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let Self { v: x, da: x_a, db: x_b } = self;
        let Self { v: y, da: y_a, db: y_b } = rhs;
        Self {
            v: x * y,
            da: x_a * y + x * y_a,
            db: x_b * y + x * y_b,
        }
    }
}

impl std::ops::Div for ShapeDual {
    type Output = Self;

    /// Quotient rule: `(x/y)' = (x'·y − x·y') / y²`.
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let Self { v: x, da: x_a, db: x_b } = self;
        let Self { v: y, da: y_a, db: y_b } = rhs;
        let recip = 1.0 / y;
        let recip_sq = recip * recip;
        Self {
            v: x * recip,
            da: (x_a * y - x * y_a) * recip_sq,
            db: (x_b * y - x * y_b) * recip_sq,
        }
    }
}

impl std::ops::Neg for ShapeDual {
    type Output = Self;

    #[inline]
    fn neg(self) -> Self {
        Self {
            v: -self.v,
            da: -self.da,
            db: -self.db,
        }
    }
}
1743
1744#[inline]
1745fn shape_dual(v: f64) -> ShapeDual {
1746 ShapeDual::constant(v)
1747}
1748
1749fn beta_reg_with_shape_partials(a0: f64, b0: f64, x0: f64) -> (f64, f64, f64) {
1753 if x0 <= 0.0 {
1754 return (0.0, 0.0, 0.0);
1755 }
1756 if x0 >= 1.0 {
1757 return (1.0, 0.0, 0.0);
1758 }
1759
1760 let symm_transform = x0 >= (a0 + 1.0) / (a0 + b0 + 2.0);
1761 let (a, b, x) = if symm_transform {
1762 (
1763 ShapeDual::from_value_partials(b0, 0.0, 1.0),
1764 ShapeDual::from_value_partials(a0, 1.0, 0.0),
1765 1.0 - x0,
1766 )
1767 } else {
1768 (
1769 ShapeDual::from_value_partials(a0, 1.0, 0.0),
1770 ShapeDual::from_value_partials(b0, 0.0, 1.0),
1771 x0,
1772 )
1773 };
1774
1775 let ln_x = x.ln();
1776 let ln_1mx = (1.0 - x).ln();
1777 let psi_ab = digamma(a.v + b.v);
1778 let log_bt = statrs::function::gamma::ln_gamma(a.v + b.v)
1779 - statrs::function::gamma::ln_gamma(a.v)
1780 - statrs::function::gamma::ln_gamma(b.v)
1781 + a.v * ln_x
1782 + b.v * ln_1mx;
1783 let bt_v = log_bt.exp();
1784 let log_bt_a = psi_ab - digamma(a.v) + ln_x;
1785 let log_bt_b = psi_ab - digamma(b.v) + ln_1mx;
1786 let bt = ShapeDual {
1787 v: bt_v,
1788 da: bt_v * (log_bt_a * a.da + log_bt_b * b.da),
1789 db: bt_v * (log_bt_a * a.db + log_bt_b * b.db),
1790 };
1791
1792 let eps = 0.00000000000000011102230246251565;
1793 let fpmin = f64::MIN_POSITIVE / eps;
1794 let one = shape_dual(1.0);
1795 let qab = a + b;
1796 let qap = a + one;
1797 let qam = a - one;
1798 let mut c = one;
1799 let mut d = (one - qab * shape_dual(x) / qap).clamp_small(fpmin);
1800 d = one / d;
1801 let mut h = d;
1802
1803 for m in 1..141 {
1804 let mf = f64::from(m);
1805 let m2 = mf * 2.0;
1806 let md = shape_dual(mf);
1807 let m2d = shape_dual(m2);
1808 let mut aa = md * (b - md) * shape_dual(x) / ((qam + m2d) * (a + m2d));
1809 d = (one + aa * d).clamp_small(fpmin);
1810 c = (one + aa / c).clamp_small(fpmin);
1811 d = one / d;
1812 h = h * d * c;
1813
1814 aa = (a + md).neg() * (qab + md) * shape_dual(x) / ((a + m2d) * (qap + m2d));
1815 d = (one + aa * d).clamp_small(fpmin);
1816 c = (one + aa / c).clamp_small(fpmin);
1817 d = one / d;
1818 let del = d * c;
1819 h = h * del;
1820
1821 if (del.v - 1.0).abs() <= eps {
1822 let reg = bt * h / a;
1823 return if symm_transform {
1824 (1.0 - reg.v, -reg.da, -reg.db)
1825 } else {
1826 (reg.v, reg.da, reg.db)
1827 };
1828 }
1829 }
1830 let reg = bt * h / a;
1831 if symm_transform {
1832 (1.0 - reg.v, -reg.da, -reg.db)
1833 } else {
1834 (reg.v, reg.da, reg.db)
1835 }
1836}
1837
1838pub fn beta_logistic_inverse_link_jet(
1848 eta: f64,
1849 log_shape_center: f64,
1850 epsilon: f64,
1851) -> InverseLinkJet {
1852 let logistic = logistic_uwith_derivatives(eta);
1853 let a = (log_shape_center - epsilon).exp();
1854 let b = (log_shape_center + epsilon).exp();
1855 let mu = beta_reg_logistic(a, b, logistic);
1856 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1857 let d1 = log_d1.exp();
1858 let t = a * logistic.one_minus_u - b * logistic.u;
1859 let d2 = d1 * t;
1860 let d3 = d1 * (t * t - (a + b) * logistic.du);
1861 InverseLinkJet { mu, d1, d2, d3 }
1862}
1863
1864pub fn beta_logistic_inverse_link_pdfthird_derivative(
1865 eta: f64,
1866 log_shape_center: f64,
1867 epsilon: f64,
1868) -> f64 {
1869 let logistic = logistic_uwith_derivatives(eta);
1892 let a = (log_shape_center - epsilon).exp();
1893 let b = (log_shape_center + epsilon).exp();
1894 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1895 let d1 = log_d1.exp();
1896 let c = a + b;
1897 let t = a * logistic.one_minus_u - b * logistic.u;
1898 let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1899 d1 * (t * t * t - 3.0 * c * t * logistic.du - c * u2)
1900}
1901
1902pub fn beta_logistic_inverse_link_pdffourth_derivative(
1910 eta: f64,
1911 log_shape_center: f64,
1912 epsilon: f64,
1913) -> f64 {
1914 let logistic = logistic_uwith_derivatives(eta);
1915 let a = (log_shape_center - epsilon).exp();
1916 let b = (log_shape_center + epsilon).exp();
1917 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1918 let d1 = log_d1.exp();
1919 let c = a + b;
1920 let t = a * logistic.one_minus_u - b * logistic.u;
1921 let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1922 let u3 = u2 * (logistic.one_minus_u - logistic.u) - 2.0 * logistic.du * logistic.du;
1923 let t2 = t * t;
1924 d1 * (t2 * t2 - 6.0 * c * t2 * logistic.du - 4.0 * c * t * u2
1925 + 3.0 * c * c * logistic.du * logistic.du
1926 - c * u3)
1927}
1928
1929pub fn beta_logistic_inverse_link_jetwith_param_partials(
1930 eta: f64,
1931 log_shape_center: f64,
1932 epsilon: f64,
1933) -> SasJetWithParamPartials {
1934 let logistic = logistic_uwith_derivatives(eta);
1935 let a = (log_shape_center - epsilon).exp();
1936 let b = (log_shape_center + epsilon).exp();
1937 let (mu, dmu_da, dmu_db) = beta_reg_with_shape_partials_logistic(a, b, logistic);
1938 let dmu_dlog_shape_center = a * dmu_da + b * dmu_db;
1939 let dmu_depsilon = -a * dmu_da + b * dmu_db;
1940 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1941 let d1 = log_d1.exp();
1942 let t = a * logistic.one_minus_u - b * logistic.u;
1943 let d2 = d1 * t;
1944 let k = t * t - (a + b) * logistic.du;
1945 let d3 = d1 * k;
1946 let jet = InverseLinkJet { mu, d1, d2, d3 };
1947
1948 let psi_a = digamma(a);
1949 let psi_b = digamma(b);
1950 let psi_ab = digamma(a + b);
1951 let la = logistic.ln_u - psi_a + psi_ab;
1952 let lb = logistic.ln_one_minus_u - psi_b + psi_ab;
1953
1954 let partials_for = |a_p: f64, b_p: f64, dmu: f64| -> InverseLinkJet {
1955 let logd1_p = a_p * la + b_p * lb;
1956 let d1_p = d1 * logd1_p;
1957 let t_p = a_p * logistic.one_minus_u - b_p * logistic.u;
1958 let d2_p = d1_p * t + d1 * t_p;
1959 let k_p = 2.0 * t * t_p - (a_p + b_p) * logistic.du;
1960 let d3_p = d1_p * k + d1 * k_p;
1961 InverseLinkJet {
1962 mu: dmu,
1963 d1: d1_p,
1964 d2: d2_p,
1965 d3: d3_p,
1966 }
1967 };
1968 let djet_dlog_shape_center = partials_for(a, b, dmu_dlog_shape_center);
1969 let djet_depsilon = partials_for(-a, b, dmu_depsilon);
1970 SasJetWithParamPartials {
1971 jet,
1972 djet_depsilon,
1973 djet_dlog_delta: djet_dlog_shape_center,
1974 }
1975}
1976
1977pub fn sas_inverse_link_jet(eta: f64, epsilon: f64, log_delta: f64) -> InverseLinkJet {
1981 let delta_id = sas_delta_from_raw_log_delta(log_delta);
1982 if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1983 return component_inverse_link_jet(LinkComponent::Probit, eta);
1984 }
1985 let e = if eta.is_finite() { eta } else { 0.0 };
1986 let a = e.asinh();
1987 let delta = delta_id;
1988 let u_raw = delta * a + epsilon;
1989 let u = tanh_bound(u_raw, SAS_U_CLAMP);
1990 let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1991 let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
1992 let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
1993 let s = u.sinh();
1994 let c = u.cosh();
1995 let z = s;
1996 let q = e.hypot(1.0);
1997 let inv_q = 1.0 / q;
1998 let inv_q2 = inv_q * inv_q;
1999 let inv_q3 = inv_q2 * inv_q;
2000 let inv_q5 = inv_q3 * inv_q2;
2001 let r1 = delta * inv_q;
2002 let r2 = -delta * e * inv_q3;
2003 let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2004 let u1 = g1 * r1;
2005 let u2 = g2 * r1 * r1 + g1 * r2;
2006 let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2007 let z1 = c * u1;
2008 let z2 = s * u1 * u1 + c * u2;
2009 let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2010 let base = probit_jet(z);
2011 chain_inverse_link_jet(base, z1, z2, z3)
2012}
2013
2014pub fn sas_inverse_link_pdfthird_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
2015 let e = if eta.is_finite() { eta } else { 0.0 };
2051 let a = e.asinh();
2052 let delta = sas_delta_from_raw_log_delta(log_delta);
2053 let u_raw = delta * a + epsilon;
2054 let u = tanh_bound(u_raw, SAS_U_CLAMP);
2055 let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2056 let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2057 let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2058 let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2059 let s = u.sinh();
2060 let c = u.cosh();
2061 let z = s;
2062 let base = probit_jet(z);
2063 let q = e.hypot(1.0);
2064 let inv_q = 1.0 / q;
2065 let inv_q2 = inv_q * inv_q;
2066 let inv_q3 = inv_q2 * inv_q;
2067 let inv_q5 = inv_q3 * inv_q2;
2068 let inv_q7 = inv_q5 * inv_q2;
2069 let r1 = delta * inv_q;
2070 let r2 = -delta * e * inv_q3;
2071 let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2072 let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
2073 let u1 = g1 * r1;
2074 let u2 = g2 * r1 * r1 + g1 * r2;
2075 let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2076 let u4 = g4 * r1.powi(4)
2077 + 6.0 * g3 * r1 * r1 * r2
2078 + 3.0 * g2 * r2 * r2
2079 + 4.0 * g2 * r1 * r3
2080 + g1 * r4;
2081 let z1 = c * u1;
2082 let z2 = s * u1 * u1 + c * u2;
2083 let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2084 let z4 =
2085 s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
2086 let base4 = probit_pdfthird_derivative(z);
2087 let out = base4 * z1.powi(4)
2088 + 6.0 * base.d3 * z1 * z1 * z2
2089 + 3.0 * base.d2 * z2 * z2
2090 + 4.0 * base.d2 * z1 * z3
2091 + base.d1 * z4;
2092 canonicalzero(out)
2093}
2094
2095pub fn sas_inverse_link_pdffourth_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
2112 let e = if eta.is_finite() { eta } else { 0.0 };
2113 let a = e.asinh();
2114 let delta = sas_delta_from_raw_log_delta(log_delta);
2115 let u_raw = delta * a + epsilon;
2116 let u = tanh_bound(u_raw, SAS_U_CLAMP);
2117 let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2118 let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2119 let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2120 let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2121 let g5 = tanh_bound_d5(u_raw, SAS_U_CLAMP);
2122 let s = u.sinh();
2123 let c = u.cosh();
2124 let z = s;
2125
2126 let base = probit_jet(z);
2128 let phi3 = probit_pdfthird_derivative(z); let phi4 = probit_pdffourth_derivative(z); let q = e.hypot(1.0);
2133 let inv_q = 1.0 / q;
2134 let inv_q2 = inv_q * inv_q;
2135 let inv_q3 = inv_q2 * inv_q;
2136 let inv_q5 = inv_q3 * inv_q2;
2137 let inv_q7 = inv_q5 * inv_q2;
2138 let inv_q9 = inv_q7 * inv_q2;
2139
2140 let r1 = delta * inv_q;
2141 let r2 = -delta * e * inv_q3;
2142 let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
2143 let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
2144 let r5 = delta * (9.0 - 72.0 * e * e + 24.0 * e * e * e * e) * inv_q9;
2145
2146 let u1 = g1 * r1;
2148 let u2 = g2 * r1 * r1 + g1 * r2;
2149 let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2150 let u4 = g4 * r1.powi(4)
2151 + 6.0 * g3 * r1 * r1 * r2
2152 + 3.0 * g2 * r2 * r2
2153 + 4.0 * g2 * r1 * r3
2154 + g1 * r4;
2155 let u5 = g5 * r1.powi(5)
2156 + 10.0 * g4 * r1 * r1 * r1 * r2
2157 + 15.0 * g3 * r1 * r2 * r2
2158 + 10.0 * g3 * r1 * r1 * r3
2159 + 10.0 * g2 * r2 * r3
2160 + 5.0 * g2 * r1 * r4
2161 + g1 * r5;
2162
2163 let z1 = c * u1;
2165 let z2 = s * u1 * u1 + c * u2;
2166 let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2167 let z4 =
2168 s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
2169 let z5 = c * u1.powi(5)
2170 + 10.0 * s * u1 * u1 * u1 * u2
2171 + 15.0 * c * u1 * u2 * u2
2172 + 10.0 * c * u1 * u1 * u3
2173 + 10.0 * s * u2 * u3
2174 + 5.0 * s * u1 * u4
2175 + c * u5;
2176
2177 let out = phi4 * z1.powi(5)
2180 + 10.0 * phi3 * z1 * z1 * z1 * z2
2181 + 15.0 * base.d3 * z1 * z2 * z2
2182 + 10.0 * base.d3 * z1 * z1 * z3
2183 + 10.0 * base.d2 * z2 * z3
2184 + 5.0 * base.d2 * z1 * z4
2185 + base.d1 * z5;
2186 canonicalzero(out)
2187}
2188
2189pub fn sas_inverse_link_jetwith_param_partials(
2190 eta: f64,
2191 epsilon: f64,
2192 log_delta: f64,
2193) -> SasJetWithParamPartials {
2194 let e = if eta.is_finite() { eta } else { 0.0 };
2195 let a = e.asinh();
2196 let (ld_eff, dld_eff_draw) = sas_effective_log_delta(log_delta);
2197 let delta = ld_eff.exp();
2198 let ddelta_draw = delta * dld_eff_draw;
2199 let u_raw = delta * a + epsilon;
2200 let u = tanh_bound(u_raw, SAS_U_CLAMP);
2201 let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
2202 let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
2203 let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
2204 let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
2205 let s = u.sinh();
2206 let c = u.cosh();
2207 let z = s;
2208 let q = e.hypot(1.0);
2209 let inv_q = 1.0 / q;
2210 let inv_q2 = inv_q * inv_q;
2211 let inv_q3 = inv_q2 * inv_q;
2212 let inv_q5 = inv_q3 * inv_q2;
2213 let a1 = inv_q;
2214 let a2 = -e * inv_q3;
2215 let a3 = (2.0 * e * e - 1.0) * inv_q5;
2216 let r1 = delta * a1;
2217 let r2 = delta * a2;
2218 let r3 = delta * a3;
2219 let u1 = g1 * r1;
2220 let u2 = g2 * r1 * r1 + g1 * r2;
2221 let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
2222 let z1 = c * u1;
2223 let z2 = s * u1 * u1 + c * u2;
2224 let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
2225
2226 let base = probit_jet(z);
2227 let jet = chain_inverse_link_jet(base, z1, z2, z3);
2228
2229 let param_partials = |u_t: f64, u1_t: f64, u2_t: f64, u3_t: f64| -> InverseLinkJet {
2232 let z_t = c * u_t;
2233 let z1_t = s * u_t * u1 + c * u1_t;
2234 let z2_t = c * u_t * u1 * u1 + 2.0 * s * u1 * u1_t + s * u_t * u2 + c * u2_t;
2235 let z3_t = s * u_t * u1 * u1 * u1
2236 + 3.0 * c * u1 * u1 * u1_t
2237 + 3.0 * c * u_t * u1 * u2
2238 + 3.0 * s * (u1_t * u2 + u1 * u2_t)
2239 + s * u_t * u3
2240 + c * u3_t;
2241
2242 InverseLinkJet {
2243 mu: base.d1 * z_t,
2244 d1: base.d2 * z_t * z1 + base.d1 * z1_t,
2245 d2: base.d3 * z_t * z1 * z1
2246 + 2.0 * base.d2 * z1 * z1_t
2247 + base.d2 * z_t * z2
2248 + base.d1 * z2_t,
2249 d3: probit_pdfthird_derivative(z) * z_t * z1.powi(3)
2250 + 3.0 * base.d3 * z1 * z1 * z1_t
2251 + 3.0 * base.d3 * z_t * z1 * z2
2252 + 3.0 * base.d2 * (z1_t * z2 + z1 * z2_t)
2253 + base.d2 * z_t * z3
2254 + base.d1 * z3_t,
2255 }
2256 };
2257
2258 let rt_eps = 1.0;
2260 let r1t_eps = 0.0;
2261 let r2t_eps = 0.0;
2262 let r3t_eps = 0.0;
2263 let u_eps = g1 * rt_eps;
2264 let u1_eps = g2 * rt_eps * r1 + g1 * r1t_eps;
2265 let u2_eps = g3 * rt_eps * r1 * r1 + 2.0 * g2 * r1 * r1t_eps + g2 * rt_eps * r2 + g1 * r2t_eps;
2266 let u3_eps = g4 * rt_eps * r1 * r1 * r1
2267 + 3.0 * g3 * r1 * r1 * r1t_eps
2268 + 3.0 * g3 * rt_eps * r1 * r2
2269 + 3.0 * g2 * (r1t_eps * r2 + r1 * r2t_eps)
2270 + g2 * rt_eps * r3
2271 + g1 * r3t_eps;
2272 let djet_depsilon = param_partials(u_eps, u1_eps, u2_eps, u3_eps);
2273
2274 let rt_ld = ddelta_draw * a;
2276 let r1t_ld = ddelta_draw * a1;
2277 let r2t_ld = ddelta_draw * a2;
2278 let r3t_ld = ddelta_draw * a3;
2279 let u_ld = g1 * rt_ld;
2280 let u1_ld = g2 * rt_ld * r1 + g1 * r1t_ld;
2281 let u2_ld = g3 * rt_ld * r1 * r1 + 2.0 * g2 * r1 * r1t_ld + g2 * rt_ld * r2 + g1 * r2t_ld;
2282 let u3_ld = g4 * rt_ld * r1 * r1 * r1
2283 + 3.0 * g3 * r1 * r1 * r1t_ld
2284 + 3.0 * g3 * rt_ld * r1 * r2
2285 + 3.0 * g2 * (r1t_ld * r2 + r1 * r2t_ld)
2286 + g2 * rt_ld * r3
2287 + g1 * r3t_ld;
2288 let djet_dlog_delta = param_partials(u_ld, u1_ld, u2_ld, u3_ld);
2289
2290 SasJetWithParamPartials {
2291 jet,
2292 djet_depsilon,
2293 djet_dlog_delta,
2294 }
2295}
2296
2297#[cfg(test)]
2298mod tests {
2299 use super::*;
2300 use gam_problem::{InverseLink, LikelihoodSpec, LinkComponent, MixtureLinkSpec, SasLinkState};
2301
2302 #[test]
2303 fn softmax_jacobian_matchesfd() {
2304 let rho = Array1::from_vec(vec![0.7, -1.2, 0.4]);
2305 let (pi, jac) = softmaxwith_jacobian_last_fixedzero(&rho);
2306 let h = 1e-6;
2307 for j in 0..rho.len() {
2308 let mut rp = rho.clone();
2309 rp[j] += h;
2310 let mut rm = rho.clone();
2311 rm[j] -= h;
2312 let pp = softmax_last_fixedzero(&rp);
2313 let pm = softmax_last_fixedzero(&rm);
2314 let fd = (&pp - &pm).mapv(|v| v / (2.0 * h));
2315 for k in 0..pi.len() {
2316 let err = (jac[[k, j]] - fd[k]).abs();
2317 assert_eq!(
2318 jac[[k, j]].signum(),
2319 fd[k].signum(),
2320 "jac sign mismatch at ({k},{j}): analytic={} fd={}",
2321 jac[[k, j]],
2322 fd[k]
2323 );
2324 assert!(err < 5e-6, "jac mismatch at ({k},{j}): err={err:e}");
2325 }
2326 }
2327 }
2328
2329 #[test]
2330 fn mixture_jet_rho_partials_matchfd() {
2331 let spec = MixtureLinkSpec {
2332 components: vec![
2333 LinkComponent::Probit,
2334 LinkComponent::Logit,
2335 LinkComponent::CLogLog,
2336 LinkComponent::Cauchit,
2337 ],
2338 initial_rho: Array1::from_vec(vec![0.3, -0.6, 0.2]),
2339 };
2340 let state = state_fromspec(&spec).expect("state");
2341 let eta = 0.35;
2342 let out = mixture_inverse_link_jetwith_rho_partials(&state, eta);
2343 let h = 1e-6;
2344 for j in 0..state.rho.len() {
2345 let mut rp = state.rho.clone();
2346 rp[j] += h;
2347 let sp = MixtureLinkSpec {
2348 components: state.components.clone(),
2349 initial_rho: rp,
2350 };
2351 let jp = mixture_inverse_link_jet(&state_fromspec(&sp).expect("sp"), eta);
2352 let mut rm = state.rho.clone();
2353 rm[j] -= h;
2354 let sm = MixtureLinkSpec {
2355 components: state.components.clone(),
2356 initial_rho: rm,
2357 };
2358 let jm = mixture_inverse_link_jet(&state_fromspec(&sm).expect("sm"), eta);
2359 let fd = InverseLinkJet {
2360 mu: (jp.mu - jm.mu) / (2.0 * h),
2361 d1: (jp.d1 - jm.d1) / (2.0 * h),
2362 d2: (jp.d2 - jm.d2) / (2.0 * h),
2363 d3: (jp.d3 - jm.d3) / (2.0 * h),
2364 };
2365 let an = out.djet_drho[j];
2366 assert_eq!(an.mu.signum(), fd.mu.signum());
2367 assert_eq!(an.d1.signum(), fd.d1.signum());
2368 assert_eq!(an.d2.signum(), fd.d2.signum());
2369 assert_eq!(an.d3.signum(), fd.d3.signum());
2370 assert!((an.mu - fd.mu).abs() < 1e-6);
2371 assert!((an.d1 - fd.d1).abs() < 1e-6);
2372 assert!((an.d2 - fd.d2).abs() < 1e-6);
2373 assert!((an.d3 - fd.d3).abs() < 1e-6);
2374 }
2375 }
2376
2377 #[test]
2378 fn sas_param_partials_matchfd() {
2379 let eta = 0.37;
2380 let epsilon = -0.12;
2381 let log_delta = 0.21;
2382 let out = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2383 let h = 1e-6;
2384
2385 let ep_p = sas_inverse_link_jet(eta, epsilon + h, log_delta);
2386 let ep_m = sas_inverse_link_jet(eta, epsilon - h, log_delta);
2387 let fd_ep = InverseLinkJet {
2388 mu: (ep_p.mu - ep_m.mu) / (2.0 * h),
2389 d1: (ep_p.d1 - ep_m.d1) / (2.0 * h),
2390 d2: (ep_p.d2 - ep_m.d2) / (2.0 * h),
2391 d3: (ep_p.d3 - ep_m.d3) / (2.0 * h),
2392 };
2393 assert_eq!(out.djet_depsilon.mu.signum(), fd_ep.mu.signum());
2394 assert_eq!(out.djet_depsilon.d1.signum(), fd_ep.d1.signum());
2395 assert_eq!(out.djet_depsilon.d2.signum(), fd_ep.d2.signum());
2396 assert_eq!(out.djet_depsilon.d3.signum(), fd_ep.d3.signum());
2397 assert!((out.djet_depsilon.mu - fd_ep.mu).abs() < 5e-5);
2398 assert!((out.djet_depsilon.d1 - fd_ep.d1).abs() < 5e-5);
2399 assert!((out.djet_depsilon.d2 - fd_ep.d2).abs() < 5e-5);
2400 assert!((out.djet_depsilon.d3 - fd_ep.d3).abs() < 5e-4);
2401
2402 let ld_p = sas_inverse_link_jet(eta, epsilon, log_delta + h);
2403 let ld_m = sas_inverse_link_jet(eta, epsilon, log_delta - h);
2404 let fd_ld = InverseLinkJet {
2405 mu: (ld_p.mu - ld_m.mu) / (2.0 * h),
2406 d1: (ld_p.d1 - ld_m.d1) / (2.0 * h),
2407 d2: (ld_p.d2 - ld_m.d2) / (2.0 * h),
2408 d3: (ld_p.d3 - ld_m.d3) / (2.0 * h),
2409 };
2410 assert_eq!(out.djet_dlog_delta.mu.signum(), fd_ld.mu.signum());
2411 assert_eq!(out.djet_dlog_delta.d1.signum(), fd_ld.d1.signum());
2412 assert_eq!(out.djet_dlog_delta.d2.signum(), fd_ld.d2.signum());
2413 assert_eq!(out.djet_dlog_delta.d3.signum(), fd_ld.d3.signum());
2414 assert!((out.djet_dlog_delta.mu - fd_ld.mu).abs() < 5e-5);
2415 assert!((out.djet_dlog_delta.d1 - fd_ld.d1).abs() < 5e-5);
2416 assert!((out.djet_dlog_delta.d2 - fd_ld.d2).abs() < 5e-5);
2417 assert!((out.djet_dlog_delta.d3 - fd_ld.d3).abs() < 5e-4);
2418 }
2419
2420 #[test]
2421 fn sas_jet_extreme_inputs_stay_finite() {
2422 let cases = [
2423 (-1e6, 0.0, 0.0),
2424 (1e6, 0.0, 0.0),
2425 (3.0, 12.0, 12.0),
2426 (-3.0, -12.0, -12.0),
2427 (0.5, 40.0, 10.0),
2428 (0.5, -40.0, -10.0),
2429 ];
2430 for (eta, eps, log_delta) in cases {
2431 let j = sas_inverse_link_jet(eta, eps, log_delta);
2432 assert!(j.mu.is_finite());
2433 assert!(j.d1.is_finite());
2434 assert!(j.d2.is_finite());
2435 assert!(j.d3.is_finite());
2436 let p = sas_inverse_link_jetwith_param_partials(eta, eps, log_delta);
2437 assert!(p.djet_depsilon.mu.is_finite());
2438 assert!(p.djet_depsilon.d1.is_finite());
2439 assert!(p.djet_depsilon.d2.is_finite());
2440 assert!(p.djet_depsilon.d3.is_finite());
2441 assert!(p.djet_dlog_delta.mu.is_finite());
2442 assert!(p.djet_dlog_delta.d1.is_finite());
2443 assert!(p.djet_dlog_delta.d2.is_finite());
2444 assert!(p.djet_dlog_delta.d3.is_finite());
2445 }
2446 }
2447
2448 #[test]
2449 fn sas_param_partials_remain_finite_in_extreme_region() {
2450 let eta = 10.0;
2451 let epsilon = -60.0;
2452 let log_delta = 40.0;
2453 let j = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2454 assert!(j.djet_depsilon.mu.is_finite());
2455 assert!(j.djet_depsilon.d1.is_finite());
2456 assert!(j.djet_depsilon.d2.is_finite());
2457 assert!(j.djet_depsilon.d3.is_finite());
2458 assert!(j.djet_dlog_delta.mu.is_finite());
2459 assert!(j.djet_dlog_delta.d1.is_finite());
2460 assert!(j.djet_dlog_delta.d2.is_finite());
2461 assert!(j.djet_dlog_delta.d3.is_finite());
2462 }
2463
2464 #[test]
2465 fn sas_eta_jets_matchfd() {
2466 let eta = -0.43;
2467 let epsilon = 0.27;
2468 let log_delta = -0.31;
2469 let h = 1e-5;
2470 let j0 = sas_inverse_link_jet(eta, epsilon, log_delta);
2471 let jp = sas_inverse_link_jet(eta + h, epsilon, log_delta);
2472 let jm = sas_inverse_link_jet(eta - h, epsilon, log_delta);
2473 let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2474 let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2475 let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2476 assert_eq!(j0.d1.signum(), d1fd.signum());
2477 assert_eq!(j0.d2.signum(), d2fd.signum());
2478 assert_eq!(j0.d3.signum(), d3fd.signum());
2479 assert!((j0.d1 - d1fd).abs() < 5e-5);
2480 assert!((j0.d2 - d2fd).abs() < 2e-4);
2481 assert!((j0.d3 - d3fd).abs() < 1e-3);
2482 }
2483
2484 #[test]
2485 fn family_dispatch_resolves_parameterized_links_from_spec() {
2486 let sas_state = sas_link_state_from_raw(0.0, 0.0).expect("sas state");
2491 let sas_spec = gam_problem::LikelihoodSpec {
2492 response: gam_problem::ResponseFamily::Binomial,
2493 link: InverseLink::Sas(sas_state),
2494 };
2495 let sas_jet = inverse_link_jet_for_family(&sas_spec, 0.1).expect("sas jet");
2496 assert!(sas_jet.mu.is_finite());
2497 assert!(sas_jet.d1.is_finite());
2498
2499 let mix_state = MixtureLinkState {
2500 components: vec![LinkComponent::Logit, LinkComponent::Probit],
2501 rho: ndarray::array![0.0],
2502 pi: ndarray::array![0.5, 0.5],
2503 };
2504 let mix_spec = gam_problem::LikelihoodSpec {
2505 response: gam_problem::ResponseFamily::Binomial,
2506 link: InverseLink::Mixture(mix_state),
2507 };
2508 let mix_jet = inverse_link_jet_for_family(&mix_spec, 0.1).expect("mix jet");
2509 assert!(mix_jet.mu.is_finite());
2510 assert!(mix_jet.d1.is_finite());
2511 }
2512
2513 #[test]
2514 fn beta_logistic_reduces_to_logit_at_delta0_epsilon0() {
2515 let etas = [-40.0, -30.0, -5.0, 0.42, 5.0, 30.0, 40.0];
2516 for eta in etas {
2517 let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2518 let expected_mu = gam_linalg::utils::stable_logistic(eta);
2519 let expected_d1 = (-gam_linalg::utils::stable_softplus(-eta)
2520 - gam_linalg::utils::stable_softplus(eta))
2521 .exp();
2522 assert!(
2523 (j_bl.mu - expected_mu).abs() <= 1e-15 * expected_mu.abs().max(1.0),
2524 "mu mismatch at eta={eta}: got {}, expected {}",
2525 j_bl.mu,
2526 expected_mu
2527 );
2528 assert!(
2529 (j_bl.d1 - expected_d1).abs() <= 1e-12 * expected_d1.abs().max(f64::MIN_POSITIVE),
2530 "d1 mismatch at eta={eta}: got {}, expected {}",
2531 j_bl.d1,
2532 expected_d1
2533 );
2534 assert!(j_bl.d1 > 0.0, "d1 should stay positive at eta={eta}");
2535 }
2536
2537 let eta = 0.42;
2538 let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2539 let j_logit = component_inverse_link_jet(LinkComponent::Logit, eta);
2540 assert!((j_bl.d2 - j_logit.d2).abs() < 1e-10);
2541 assert!((j_bl.d3 - j_logit.d3).abs() < 1e-10);
2542 }
2543
2544 #[test]
2545 fn beta_logistic_eta_jets_matchfd() {
2546 let eta = -0.31;
2547 let delta = 0.27;
2548 let epsilon = -0.19;
2549 let h = 1e-5;
2550 let j0 = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2551 let jp = beta_logistic_inverse_link_jet(eta + h, delta, epsilon);
2552 let jm = beta_logistic_inverse_link_jet(eta - h, delta, epsilon);
2553 let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2554 let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2555 let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2556 assert_eq!(j0.d1.signum(), d1fd.signum());
2557 assert_eq!(j0.d2.signum(), d2fd.signum());
2558 assert_eq!(j0.d3.signum(), d3fd.signum());
2559 assert!((j0.d1 - d1fd).abs() < 5e-5);
2560 assert!((j0.d2 - d2fd).abs() < 5e-5);
2561 assert!((j0.d3 - d3fd).abs() < 2e-4);
2562 }
2563
2564 #[test]
2565 fn standard_kernel_structs_match_component_jets() {
2566 let eta = 0.73;
2567 assert_eq!(
2568 ProbitLinkKernel.jet(eta).expect("probit"),
2569 component_inverse_link_jet(LinkComponent::Probit, eta)
2570 );
2571 assert_eq!(
2572 LogitLinkKernel.jet(eta).expect("logit"),
2573 component_inverse_link_jet(LinkComponent::Logit, eta)
2574 );
2575 assert_eq!(
2576 CLogLogLinkKernel.jet(eta).expect("cloglog"),
2577 component_inverse_link_jet(LinkComponent::CLogLog, eta)
2578 );
2579 assert_eq!(
2580 LogLogLinkKernel.jet(eta).expect("loglog"),
2581 component_inverse_link_jet(LinkComponent::LogLog, eta)
2582 );
2583 assert_eq!(
2584 CauchitLinkKernel.jet(eta).expect("cauchit"),
2585 component_inverse_link_jet(LinkComponent::Cauchit, eta)
2586 );
2587 }
2588
2589 #[test]
2590 fn all_component_eta_jets_matchfd() {
2591 let components = [
2592 LinkComponent::Logit,
2593 LinkComponent::Probit,
2594 LinkComponent::CLogLog,
2595 LinkComponent::LogLog,
2596 LinkComponent::Cauchit,
2597 ];
2598 let points = [-3.0, -1.1, -0.2, 0.0, 0.7, 1.8, 3.2];
2599 let h = 1e-5;
2600 for c in components {
2601 for &eta in &points {
2602 let j0 = component_inverse_link_jet(c, eta);
2603 let jp = component_inverse_link_jet(c, eta + h);
2604 let jm = component_inverse_link_jet(c, eta - h);
2605 let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2606 let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2607 let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2608 let d1_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2609 1.2e-4
2610 } else {
2611 5e-5
2612 };
2613 let d2_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2614 4e-4
2615 } else {
2616 1.2e-4
2617 };
2618 let d3_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2619 1.2e-3
2620 } else {
2621 4e-4
2622 };
2623 if j0.d1.abs().max(d1fd.abs()) > 1e-10 {
2624 assert_eq!(
2625 j0.d1.signum(),
2626 d1fd.signum(),
2627 "d1 sign mismatch for {c:?} eta={eta}"
2628 );
2629 }
2630 if j0.d2.abs().max(d2fd.abs()) > 1e-10 {
2631 assert_eq!(
2632 j0.d2.signum(),
2633 d2fd.signum(),
2634 "d2 sign mismatch for {c:?} eta={eta}: analytic={} fd={}",
2635 j0.d2,
2636 d2fd
2637 );
2638 }
2639 if j0.d3.abs().max(d3fd.abs()) > 1e-10 {
2640 assert_eq!(
2641 j0.d3.signum(),
2642 d3fd.signum(),
2643 "d3 sign mismatch for {c:?} eta={eta}"
2644 );
2645 }
2646 assert!(
2647 (j0.d1 - d1fd).abs() < d1_tol,
2648 "d1 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2649 j0.d1,
2650 d1fd
2651 );
2652 assert!(
2653 (j0.d2 - d2fd).abs() < d2_tol,
2654 "d2 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2655 j0.d2,
2656 d2fd
2657 );
2658 assert!(
2659 (j0.d3 - d3fd).abs() < d3_tol,
2660 "d3 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2661 j0.d3,
2662 d3fd
2663 );
2664 }
2665 }
2666 }
2667
2668 #[test]
2669 fn sas_center_matches_probit_at_delta1_epsilon0() {
2670 let etas = [-3.0, -1.2, -0.3, 0.0, 0.4, 1.7, 3.0];
2671 for eta in etas {
2672 let sas = sas_inverse_link_jet(eta, 0.0, 0.0);
2673 let probit = ProbitLinkKernel.jet(eta).expect("probit");
2674 assert!(
2677 (sas.mu - probit.mu).abs() < 6e-4,
2678 "mu mismatch at eta={eta}"
2679 );
2680 assert!(
2681 (sas.d1 - probit.d1).abs() < 6e-4,
2682 "d1 mismatch at eta={eta}"
2683 );
2684 assert!(
2685 (sas.d2 - probit.d2).abs() < 2e-3,
2686 "d2 mismatch at eta={eta}"
2687 );
2688 assert!(
2689 (sas.d3 - probit.d3).abs() < 4e-3,
2690 "d3 mismatch at eta={eta}"
2691 );
2692 }
2693 }
2694
2695 #[test]
2696 fn beta_logistic_param_partials_matchfd() {
2697 let eta = -0.41;
2698 let delta = 0.23;
2699 let epsilon = -0.17;
2700 let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2701 let h = 1e-6;
2702
2703 let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
2704 let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
2705 let fd_delta = InverseLinkJet {
2706 mu: (dp.mu - dm.mu) / (2.0 * h),
2707 d1: (dp.d1 - dm.d1) / (2.0 * h),
2708 d2: (dp.d2 - dm.d2) / (2.0 * h),
2709 d3: (dp.d3 - dm.d3) / (2.0 * h),
2710 };
2711 assert_eq!(out.djet_dlog_delta.mu.signum(), fd_delta.mu.signum());
2712 assert_eq!(out.djet_dlog_delta.d1.signum(), fd_delta.d1.signum());
2713 assert_eq!(out.djet_dlog_delta.d2.signum(), fd_delta.d2.signum());
2714 assert_eq!(out.djet_dlog_delta.d3.signum(), fd_delta.d3.signum());
2715 assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
2716 assert!((out.djet_dlog_delta.d1 - fd_delta.d1).abs() < 5e-5);
2717 assert!((out.djet_dlog_delta.d2 - fd_delta.d2).abs() < 1.2e-4);
2718 assert!((out.djet_dlog_delta.d3 - fd_delta.d3).abs() < 4e-4);
2719
2720 let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
2721 let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
2722 let fd_epsilon = InverseLinkJet {
2723 mu: (ep.mu - em.mu) / (2.0 * h),
2724 d1: (ep.d1 - em.d1) / (2.0 * h),
2725 d2: (ep.d2 - em.d2) / (2.0 * h),
2726 d3: (ep.d3 - em.d3) / (2.0 * h),
2727 };
2728 assert_eq!(out.djet_depsilon.mu.signum(), fd_epsilon.mu.signum());
2729 assert_eq!(out.djet_depsilon.d1.signum(), fd_epsilon.d1.signum());
2730 assert_eq!(out.djet_depsilon.d2.signum(), fd_epsilon.d2.signum());
2731 assert_eq!(out.djet_depsilon.d3.signum(), fd_epsilon.d3.signum());
2732 assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
2733 assert!((out.djet_depsilon.d1 - fd_epsilon.d1).abs() < 5e-5);
2734 assert!((out.djet_depsilon.d2 - fd_epsilon.d2).abs() < 1.2e-4);
2735 assert!((out.djet_depsilon.d3 - fd_epsilon.d3).abs() < 4e-4);
2736 }
2737
2738 #[test]
2739 fn beta_logistic_left_tail_uses_unclamped_log_space() {
2740 let eta = -40.0_f64;
2741 let delta = 0.2_f64;
2742 let epsilon = -0.1_f64;
2743 let a = (delta - epsilon).exp();
2744 let b = (delta + epsilon).exp();
2745 let expected_mu = beta_reg(a, b, eta.exp());
2746 let out = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2747
2748 assert!(
2749 (out.mu - expected_mu).abs() <= 1e-12 * expected_mu.abs().max(f64::MIN_POSITIVE),
2750 "left-tail mu mismatch: got {}, expected {}",
2751 out.mu,
2752 expected_mu
2753 );
2754 assert!(out.d1 > 0.0);
2755 assert!(out.d2 > 0.0);
2756 assert!(out.d3 > 0.0);
2757 assert!(out.d1 < 1e-20);
2758
2759 let partials = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2760 assert!(partials.jet.d1 > 0.0);
2761 assert!(partials.jet.d2 > 0.0);
2762 assert!(partials.jet.d3 > 0.0);
2763 assert!(partials.djet_dlog_delta.d1.is_finite());
2764 assert!(partials.djet_depsilon.d1.is_finite());
2765 }
2766
2767 #[test]
2768 fn beta_logistic_mu_is_symmetric_in_logistic_tails() {
2769 let delta = 0.2;
2770 let epsilon = -0.35;
2771 let etas = [-40.0, -30.0, -5.0, -0.42, 0.0, 0.42, 5.0, 30.0, 40.0];
2772 for eta in etas {
2773 let left = beta_logistic_inverse_link_jet(eta, delta, epsilon).mu;
2774 let right = 1.0 - beta_logistic_inverse_link_jet(-eta, delta, -epsilon).mu;
2775 assert!(
2776 (left - right).abs() <= 1e-14,
2777 "symmetry mismatch at eta={eta}: left={left}, right={right}"
2778 );
2779 }
2780 }
2781
2782 #[test]
2783 fn inverse_link_pdfthird_derivative_matches_d3_finite_difference() {
2784 let sas = InverseLink::Sas(sas_link_state_from_raw(-0.25, 0.35).expect("sas state"));
2785 let beta_logistic = InverseLink::BetaLogistic(SasLinkState {
2786 epsilon: 0.18,
2787 log_delta: -0.22,
2788 delta: (-0.22_f64).exp(),
2789 });
2790 let mixture = InverseLink::Mixture(
2791 state_fromspec(&MixtureLinkSpec {
2792 components: vec![
2793 LinkComponent::Probit,
2794 LinkComponent::Logit,
2795 LinkComponent::CLogLog,
2796 LinkComponent::Cauchit,
2797 ],
2798 initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
2799 })
2800 .expect("mixture state"),
2801 );
2802 let links = [
2803 InverseLink::Standard(StandardLink::Probit),
2804 InverseLink::Standard(StandardLink::Logit),
2805 InverseLink::Standard(StandardLink::CLogLog),
2806 sas,
2807 beta_logistic,
2808 mixture,
2809 ];
2810 let etas = [-1.1, -0.2, 0.6];
2811 let h = 1e-5;
2812
2813 for link in &links {
2814 for &eta in &etas {
2815 let jp = inverse_link_jet_for_inverse_link(link, eta + h).expect("jet+");
2816 let jm = inverse_link_jet_for_inverse_link(link, eta - h).expect("jet-");
2817 let d4fd = (jp.d3 - jm.d3) / (2.0 * h);
2818 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)
2819 .expect("analytic d4");
2820 assert_eq!(
2821 d4.signum(),
2822 d4fd.signum(),
2823 "d4 sign mismatch for {:?} at eta={eta}: analytic={} fd={}",
2824 link,
2825 d4,
2826 d4fd
2827 );
2828 assert!(
2829 (d4 - d4fd).abs() < 5e-3,
2830 "d4 mismatch for {:?} at eta={eta}: analytic={} fd={}",
2831 link,
2832 d4,
2833 d4fd
2834 );
2835 }
2836 }
2837 }
2838
2839 #[test]
2840 fn cloglog_large_finite_eta_should_saturate_without_nan_derivatives() {
2841 let eta = 800.0;
2842 let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2843 assert_eq!(jet.mu, 1.0);
2844 assert!(
2845 jet.d1 == 0.0,
2846 "for mu(eta)=1-exp(-exp(eta)), dmu/deta = exp(eta-exp(eta)) and should underflow to 0 at eta={eta}; got d1={}",
2847 jet.d1
2848 );
2849 assert!(
2850 jet.d2 == 0.0,
2851 "the saturated cloglog second derivative should also be 0 at eta={eta}; got d2={}",
2852 jet.d2
2853 );
2854 assert!(
2855 jet.d3 == 0.0,
2856 "the saturated cloglog third derivative should also be 0 at eta={eta}; got d3={}",
2857 jet.d3
2858 );
2859
2860 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2861 &InverseLink::Standard(StandardLink::CLogLog),
2862 eta,
2863 )
2864 .expect("cloglog d4");
2865 assert!(
2866 d4 == 0.0,
2867 "the saturated cloglog fourth derivative should also be 0 at eta={eta}; got d4={d4}"
2868 );
2869 }
2870
2871 #[test]
2872 fn loglog_large_negative_finite_eta_should_saturate_without_nan_derivatives() {
2873 let eta = -800.0;
2874 let jet = component_inverse_link_jet(LinkComponent::LogLog, eta);
2875 assert_eq!(jet.mu, 0.0);
2876 assert!(
2877 jet.d1 == 0.0,
2878 "for mu(eta)=exp(-exp(-eta)), dmu/deta = exp(-eta-exp(-eta)) and should underflow to 0 at eta={eta}; got d1={}",
2879 jet.d1
2880 );
2881 assert!(
2882 jet.d2 == 0.0,
2883 "the saturated loglog second derivative should also be 0 at eta={eta}; got d2={}",
2884 jet.d2
2885 );
2886 assert!(
2887 jet.d3 == 0.0,
2888 "the saturated loglog third derivative should also be 0 at eta={eta}; got d3={}",
2889 jet.d3
2890 );
2891
2892 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2893 &InverseLink::Mixture(
2894 state_fromspec(&MixtureLinkSpec {
2895 components: vec![LinkComponent::LogLog, LinkComponent::Probit],
2896 initial_rho: Array1::from_vec(vec![12.0]),
2897 })
2898 .expect("mixture state"),
2899 ),
2900 eta,
2901 )
2902 .expect("loglog mixture d4");
2903 assert!(
2904 d4.is_finite(),
2905 "even a nearly pure loglog mixture should not produce NaN fourth derivatives at eta={eta}; got d4={d4}"
2906 );
2907 }
2908
2909 #[test]
2910 fn logit_tail_derivatives_should_match_stable_closed_forms() {
2911 let eta = 50.0_f64;
2912 let z = (-eta).exp();
2913 let denom = 1.0_f64 + z;
2914 let stable_d1 = z / denom.powi(2);
2915 let stable_d2 = z * (z - 1.0) / denom.powi(3);
2916 let stable_d3 = z * (z * z - 4.0 * z + 1.0) / denom.powi(4);
2917 let stable_d4 = z * (z * z * z - 11.0 * z * z + 11.0 * z - 1.0) / denom.powi(5);
2918 let stable_d5 =
2919 z * (z * z * z * z - 26.0 * z * z * z + 66.0 * z * z - 26.0 * z + 1.0) / denom.powi(6);
2920
2921 assert!(stable_d1 > 0.0);
2922 assert!(stable_d2 < 0.0);
2923 assert!(stable_d3 > 0.0);
2924 assert!(stable_d4 < 0.0);
2925 assert!(stable_d5 > 0.0);
2926
2927 let jet = component_inverse_link_jet(LinkComponent::Logit, eta);
2928 assert!(
2929 (jet.d1 - stable_d1).abs() < 1e-30,
2930 "logit d1 should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
2931 jet.d1,
2932 stable_d1
2933 );
2934 assert!(
2935 (jet.d2 - stable_d2).abs() < 1e-30,
2936 "logit d2 should equal the stable tail formula z(z-1)/(1+z)^3 at eta={eta}; got {} vs {}",
2937 jet.d2,
2938 stable_d2
2939 );
2940 assert!(
2941 (jet.d3 - stable_d3).abs() < 1e-30,
2942 "logit d3 should equal the stable tail formula z(z^2-4z+1)/(1+z)^4 at eta={eta}; got {} vs {}",
2943 jet.d3,
2944 stable_d3
2945 );
2946
2947 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2948 &InverseLink::Standard(StandardLink::Logit),
2949 eta,
2950 )
2951 .expect("logit d4");
2952 assert!(
2953 (d4 - stable_d4).abs() < 1e-30,
2954 "logit d4 should equal the stable tail formula z(z^3-11z^2+11z-1)/(1+z)^5 at eta={eta}; got {} vs {}",
2955 d4,
2956 stable_d4
2957 );
2958
2959 let d5 = inverse_link_pdffourth_derivative_for_inverse_link(
2960 &InverseLink::Standard(StandardLink::Logit),
2961 eta,
2962 )
2963 .expect("logit d5");
2964 assert!(
2965 (d5 - stable_d5).abs() < 1e-30,
2966 "logit d5 should equal the stable tail formula z(z^4-26z^3+66z^2-26z+1)/(1+z)^6 at eta={eta}; got {} vs {}",
2967 d5,
2968 stable_d5
2969 );
2970 }
2971
2972 #[test]
2973 fn cloglog_negative_tail_value_should_match_expm1_form() {
2974 let eta = -50.0_f64;
2975 let t = eta.exp();
2976 let stable_mu = -(-t).exp_m1();
2977 assert!(stable_mu > 0.0);
2978
2979 let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2980 assert!(
2981 (jet.mu - stable_mu).abs() < 1e-30,
2982 "cloglog mu should equal -expm1(-exp(eta)) in the negative tail at eta={eta}; got {} vs {}",
2983 jet.mu,
2984 stable_mu
2985 );
2986 }
2987
2988 #[test]
2989 fn non_logit_probit_fisher_weight_jets_match_finite_differences() {
2990 fn rel_err(a: f64, b: f64) -> f64 {
2991 (a - b).abs() / a.abs().max(b.abs()).max(1.0e-8)
2992 }
2993
2994 let cases = [
2995 (LinkComponent::CLogLog, [-3.0_f64, -0.5, 0.4, 1.5]),
2996 (LinkComponent::LogLog, [-1.5_f64, -0.4, 0.5, 3.0]),
2997 (LinkComponent::Cauchit, [-3.0_f64, -0.7, 0.6, 3.0]),
2998 ];
2999 for (component, etas) in cases {
3000 for eta in etas {
3001 let (w, w1, w2, w3, w4) = component_fisher_weight_jet5(component, eta);
3002 let jet = component_inverse_link_jet(component, eta);
3003 let expected = jet.d1 * jet.d1 / (jet.mu * (1.0 - jet.mu));
3004 assert!(
3005 rel_err(w, expected) < 1.0e-12,
3006 "{component:?} Fisher weight mismatch at eta={eta}: got {w}, expected {expected}"
3007 );
3008
3009 let h = 1.0e-4;
3010 let fd1 = (component_fisher_weight_jet5(component, eta + h).0
3011 - component_fisher_weight_jet5(component, eta - h).0)
3012 / (2.0 * h);
3013 let fd2 = (component_fisher_weight_jet5(component, eta + h).1
3014 - component_fisher_weight_jet5(component, eta - h).1)
3015 / (2.0 * h);
3016 let fd3 = (component_fisher_weight_jet5(component, eta + h).2
3017 - component_fisher_weight_jet5(component, eta - h).2)
3018 / (2.0 * h);
3019 let fd4 = (component_fisher_weight_jet5(component, eta + h).3
3020 - component_fisher_weight_jet5(component, eta - h).3)
3021 / (2.0 * h);
3022
3023 assert!(
3024 rel_err(w1, fd1) < 1.0e-5,
3025 "{component:?} W' mismatch at eta={eta}: {w1} vs {fd1}"
3026 );
3027 assert!(
3028 rel_err(w2, fd2) < 1.0e-5,
3029 "{component:?} W'' mismatch at eta={eta}: {w2} vs {fd2}"
3030 );
3031 assert!(
3032 rel_err(w3, fd3) < 5.0e-5,
3033 "{component:?} W''' mismatch at eta={eta}: {w3} vs {fd3}"
3034 );
3035 assert!(
3036 rel_err(w4, fd4) < 5.0e-4,
3037 "{component:?} W'''' mismatch at eta={eta}: {w4} vs {fd4}"
3038 );
3039 }
3040 }
3041 }
3042
3043 #[test]
3044 fn mixture_fisher_weight_jet_covers_loglog_and_cauchit_components() {
3045 let state = state_fromspec(&MixtureLinkSpec {
3046 components: vec![
3047 LinkComponent::CLogLog,
3048 LinkComponent::LogLog,
3049 LinkComponent::Cauchit,
3050 ],
3051 initial_rho: Array1::from_vec(vec![0.3, -0.2]),
3052 })
3053 .expect("mixture state");
3054 let link = InverseLink::Mixture(state);
3055 assert!(
3056 inverse_link_has_fisher_weight_jet(&link),
3057 "anchored mixtures with loglog/cauchit components must remain eligible for Firth"
3058 );
3059 assert!(
3060 LikelihoodSpec::new(ResponseFamily::Binomial, link.clone()).supports_firth(),
3061 "Firth support should use the mixture inverse-link Fisher jet, not standalone LinkFunction coverage"
3062 );
3063
3064 for eta in [-2.0_f64, -0.25, 0.75, 2.5] {
3065 let (w, w1, w2, w3, w4) =
3066 fisher_weight_jet5_for_inverse_link(&link, eta).expect("mixture Fisher jet");
3067 for value in [w, w1, w2, w3, w4] {
3068 assert!(
3069 value.is_finite(),
3070 "mixture Fisher weight jet should be finite at eta={eta}; got {value}"
3071 );
3072 }
3073 assert!(
3074 w > 0.0,
3075 "mixture Fisher working weight should be positive away from saturated tails at eta={eta}; got {w}"
3076 );
3077 }
3078 }
3079
3080 #[test]
3081 fn loglog_fifth_derivative_should_match_closed_form_sign() {
3082 let eta = 0.0_f64;
3083 let r = (-eta).exp();
3084 let expected =
3085 (-r).exp() * (r - 15.0 * r * r + 25.0 * r.powi(3) - 10.0 * r.powi(4) + r.powi(5));
3086 let d5 = component_inverse_link_pdffourth_derivative(LinkComponent::LogLog, eta);
3087 assert!(
3088 (d5 - expected).abs() < 1e-15,
3089 "loglog d5 should equal exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5) at eta={eta}; got {d5} vs {expected}"
3090 );
3091 assert!(d5 > 0.0, "loglog d5 should be positive at eta=0; got {d5}");
3092 }
3093}