1use crate::estimate::EstimationError;
2use gam_math::probability::{normal_cdf, normal_pdf};
3use gam_math::special::stable_polynomial_times_exp_neg as stable_nonnegative_poly_times_exp_neg;
4use crate::quadrature::latent_cloglog_jet5;
5use gam_problem::{
6 InverseLink, LatentCLogLogState, LikelihoodSpec, LinkComponent, LinkFunction, MixtureLinkSpec,
7 MixtureLinkState, ResponseFamily, SasLinkSpec, SasLinkState, StandardLink,
8};
9use ndarray::Array1;
10use statrs::function::beta::{beta_reg, ln_beta};
11use statrs::function::gamma::digamma;
12use std::ops::Neg;
13use std::sync::OnceLock;
14
// Bound handed to `tanh_bound` for the SAS inner argument `u_raw` before
// `sinh`/`cosh` are applied to it (see `sas_inverse_link_mu_d1`).
const SAS_U_CLAMP: f64 = 50.0;
// Bound handed to `tanh_bound` for the raw SAS `log_delta` parameter
// (see `sas_effective_log_delta`).
pub(crate) const SAS_LOG_DELTA_BOUND: f64 = 12.0;
21
22#[inline]
23fn latent_cloglog_quadctx() -> &'static crate::quadrature::QuadratureContext {
24 static QUADCTX: OnceLock<crate::quadrature::QuadratureContext> = OnceLock::new();
25 QUADCTX.get_or_init(crate::quadrature::QuadratureContext::new)
26}
27
28#[inline]
29fn latent_cloglog_point_jet(
30 state: &LatentCLogLogState,
31 eta: f64,
32) -> Result<InverseLinkJet, EstimationError> {
33 let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, state.latent_sd)?;
34 Ok(InverseLinkJet {
35 mu: jet.mean,
36 d1: jet.d1,
37 d2: jet.d2,
38 d3: jet.d3,
39 })
40}
41
/// Value of an inverse link at a linear predictor `eta` together with its
/// first three derivatives with respect to `eta`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InverseLinkJet {
    /// Inverse-link value at `eta`.
    pub mu: f64,
    /// First derivative of `mu` with respect to `eta`.
    pub d1: f64,
    /// Second derivative of `mu` with respect to `eta`.
    pub d2: f64,
    /// Third derivative of `mu` with respect to `eta`.
    pub d3: f64,
}
49
/// Logistic inverse-link value at `eta` together with its first five
/// derivatives with respect to `eta`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LogitJet5 {
    /// Logistic value at `eta`.
    pub mu: f64,
    /// First derivative with respect to `eta`.
    pub d1: f64,
    /// Second derivative with respect to `eta`.
    pub d2: f64,
    /// Third derivative with respect to `eta`.
    pub d3: f64,
    /// Fourth derivative with respect to `eta`.
    pub d4: f64,
    /// Fifth derivative with respect to `eta`.
    pub d5: f64,
}
59
/// Flushes zeros of either sign and subnormal values to `+0.0`.
///
/// Every other input, including NaN and the infinities, is returned unchanged.
#[inline]
fn canonicalzero(v: f64) -> f64 {
    // "Magnitude below the smallest positive normal" is exactly "zero or subnormal".
    let negligible = v == 0.0 || v.is_subnormal();
    if negligible {
        0.0
    } else {
        v
    }
}
64
65#[inline]
66fn canonicalize_jet(mut jet: InverseLinkJet) -> InverseLinkJet {
67 jet.d1 = canonicalzero(jet.d1);
68 jet.d2 = canonicalzero(jet.d2);
69 jet.d3 = canonicalzero(jet.d3);
70 jet
71}
72
73#[inline]
74pub fn logit_inverse_link_jet5(eta: f64) -> LogitJet5 {
75 if eta.is_nan() {
76 return LogitJet5 {
77 mu: f64::NAN,
78 d1: f64::NAN,
79 d2: f64::NAN,
80 d3: f64::NAN,
81 d4: f64::NAN,
82 d5: f64::NAN,
83 };
84 }
85 if eta == f64::INFINITY {
86 return LogitJet5 {
87 mu: 1.0,
88 d1: 0.0,
89 d2: 0.0,
90 d3: 0.0,
91 d4: 0.0,
92 d5: 0.0,
93 };
94 }
95 if eta == f64::NEG_INFINITY {
96 return LogitJet5 {
97 mu: 0.0,
98 d1: 0.0,
99 d2: 0.0,
100 d3: 0.0,
101 d4: 0.0,
102 d5: 0.0,
103 };
104 }
105
106 let jet = if eta >= 0.0 {
107 let z = (-eta).exp();
108 let opz = 1.0 + z;
109 let opz2 = opz * opz;
110 let opz3 = opz2 * opz;
111 let opz4 = opz3 * opz;
112 let opz5 = opz4 * opz;
113 let opz6 = opz5 * opz;
114 let z2 = z * z;
115 let z3 = z2 * z;
116 let z4 = z3 * z;
117 LogitJet5 {
118 mu: 1.0 / opz,
119 d1: z / opz2,
120 d2: z * (z - 1.0) / opz3,
121 d3: z * (z2 - 4.0 * z + 1.0) / opz4,
122 d4: z * (z3 - 11.0 * z2 + 11.0 * z - 1.0) / opz5,
123 d5: z * (z4 - 26.0 * z3 + 66.0 * z2 - 26.0 * z + 1.0) / opz6,
124 }
125 } else {
126 let z = eta.exp();
127 let opz = 1.0 + z;
128 let opz2 = opz * opz;
129 let opz3 = opz2 * opz;
130 let opz4 = opz3 * opz;
131 let opz5 = opz4 * opz;
132 let opz6 = opz5 * opz;
133 let z2 = z * z;
134 let z3 = z2 * z;
135 let z4 = z3 * z;
136 LogitJet5 {
137 mu: z / opz,
138 d1: z / opz2,
139 d2: z * (1.0 - z) / opz3,
140 d3: z * (1.0 - 4.0 * z + z2) / opz4,
141 d4: z * (1.0 - 11.0 * z + 11.0 * z2 - z3) / opz5,
142 d5: z * (1.0 - 26.0 * z + 66.0 * z2 - 26.0 * z3 + z4) / opz6,
143 }
144 };
145 LogitJet5 {
146 mu: jet.mu,
147 d1: canonicalzero(jet.d1),
148 d2: canonicalzero(jet.d2),
149 d3: canonicalzero(jet.d3),
150 d4: canonicalzero(jet.d4),
151 d5: canonicalzero(jet.d5),
152 }
153}
154
155#[inline]
156fn probit_jet(eta: f64) -> InverseLinkJet {
157 if eta.is_nan() {
168 return InverseLinkJet {
169 mu: f64::NAN,
170 d1: f64::NAN,
171 d2: f64::NAN,
172 d3: f64::NAN,
173 };
174 }
175 if eta == f64::INFINITY {
176 return InverseLinkJet {
177 mu: 1.0,
178 d1: 0.0,
179 d2: 0.0,
180 d3: 0.0,
181 };
182 }
183 if eta == f64::NEG_INFINITY {
184 return InverseLinkJet {
185 mu: 0.0,
186 d1: 0.0,
187 d2: 0.0,
188 d3: 0.0,
189 };
190 }
191 let x = eta;
192 let phi = normal_pdf(x);
193 InverseLinkJet {
194 mu: normal_cdf(x),
195 d1: phi,
196 d2: -x * phi,
197 d3: (x * x - 1.0) * phi,
198 }
199}
200
201#[inline]
202fn probit_pdfthird_derivative(eta: f64) -> f64 {
203 if eta.is_nan() {
207 return f64::NAN;
208 }
209 if !eta.is_finite() {
210 return 0.0;
211 }
212 let x = eta;
213 let phi = normal_pdf(x);
214 canonicalzero(-(x * x * x - 3.0 * x) * phi)
215}
216
217#[inline]
218fn probit_pdffourth_derivative(eta: f64) -> f64 {
219 if eta.is_nan() {
221 return f64::NAN;
222 }
223 if !eta.is_finite() {
224 return 0.0;
225 }
226 let x = eta;
227 let phi = normal_pdf(x);
228 canonicalzero((x * x * x * x - 6.0 * x * x + 3.0) * phi)
229}
230
/// Product of two truncated power series with five coefficients each
/// (orders 0 through 4); terms of order 5 and above are discarded.
///
/// Coefficients of `a` that are exactly zero contribute nothing, even when
/// the matching coefficient of `b` is NaN or infinite.
#[inline]
fn taylor5_mul(a: &[f64; 5], b: &[f64; 5]) -> [f64; 5] {
    std::array::from_fn(|order| {
        // Accumulate a[i] * b[order - i] in increasing i.
        (0..=order)
            .filter(|&i| a[i] != 0.0)
            .fold(0.0_f64, |acc, i| acc + a[i] * b[order - i])
    })
}
247
/// Reciprocal of a truncated power series with five coefficients
/// (orders 0 through 4).
///
/// Uses the recurrence `b[0] = 1 / a[0]`,
/// `b[k] = -(a[1] b[k-1] + ... + a[k] b[0]) * b[0]`. No check is made that
/// `a[0]` is non-zero.
#[inline]
fn taylor5_inv(a: &[f64; 5]) -> [f64; 5] {
    let lead = 1.0 / a[0];
    let mut inverse = [lead, 0.0, 0.0, 0.0, 0.0];
    for order in 1..5 {
        let convolution =
            (1..=order).fold(0.0_f64, |acc, j| acc + a[j] * inverse[order - j]);
        inverse[order] = -convolution * lead;
    }
    inverse
}
262
263pub(crate) fn fisher_weight_jet5(link: StandardLink, eta: f64) -> (f64, f64, f64, f64, f64) {
279 match link {
280 StandardLink::Logit => {
281 let jet = logit_inverse_link_jet5(eta);
282 (jet.d1, jet.d2, jet.d3, jet.d4, jet.d5)
283 }
284 StandardLink::Probit => probit_fisher_weight_jet5(eta),
285 StandardLink::CLogLog => component_fisher_weight_jet5(LinkComponent::CLogLog, eta),
286 StandardLink::Identity | StandardLink::Log => (0.0, 0.0, 0.0, 0.0, 0.0),
287 }
288}
289
290pub(crate) fn fisher_weight_jet5_for_inverse_link(
291 link: &InverseLink,
292 eta: f64,
293) -> Result<(f64, f64, f64, f64, f64), EstimationError> {
294 match link {
295 InverseLink::Standard(link) => Ok(fisher_weight_jet5(*link, eta)),
296 InverseLink::LatentCLogLog(_)
297 | InverseLink::Sas(_)
298 | InverseLink::BetaLogistic(_)
299 | InverseLink::Mixture(_) => {
300 let jet = link.jet(eta)?;
301 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)?;
302 let d5 = inverse_link_pdffourth_derivative_for_inverse_link(link, eta)?;
303 Ok(fisher_weight_jet5_from_inverse_link_derivatives(
304 jet.mu, jet.d1, jet.d2, jet.d3, d4, d5,
305 ))
306 }
307 }
308}
309
310#[inline]
311pub(crate) fn inverse_link_has_fisher_weight_jet(link: &InverseLink) -> bool {
312 matches!(
313 link,
314 InverseLink::Standard(StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,)
315 | InverseLink::LatentCLogLog(_)
316 | InverseLink::Sas(_)
317 | InverseLink::BetaLogistic(_)
318 | InverseLink::Mixture(_)
319 )
320}
321
322#[inline]
323fn component_fisher_weight_jet5(component: LinkComponent, eta: f64) -> (f64, f64, f64, f64, f64) {
324 let jet = component_inverse_link_jet(component, eta);
325 let d4 = component_inverse_link_pdfthird_derivative(component, eta);
326 let d5 = component_inverse_link_pdffourth_derivative(component, eta);
327 fisher_weight_jet5_from_inverse_link_derivatives(jet.mu, jet.d1, jet.d2, jet.d3, d4, d5)
328}
329
330#[inline]
331fn fisher_weight_jet5_from_inverse_link_derivatives(
332 mu: f64,
333 d1: f64,
334 d2: f64,
335 d3: f64,
336 d4: f64,
337 d5: f64,
338) -> (f64, f64, f64, f64, f64) {
339 if [mu, d1, d2, d3, d4, d5].iter().any(|v| v.is_nan()) {
340 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
341 }
342 let variance = mu * (1.0 - mu);
343 if !(variance > 0.0) || !variance.is_finite() {
344 return (0.0, 0.0, 0.0, 0.0, 0.0);
345 }
346
347 let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
348 let mu_d = [mu, d1, d2, d3, d4];
349 let one_minus_mu_d = [1.0 - mu, -d1, -d2, -d3, -d4];
350 let dmu_d = [d1, d2, d3, d4, d5];
351 let mut mu_t = [0.0_f64; 5];
352 let mut one_minus_mu_t = [0.0_f64; 5];
353 let mut dmu_t = [0.0_f64; 5];
354 for k in 0..5 {
355 let inv_fact = 1.0 / factorial[k];
356 mu_t[k] = mu_d[k] * inv_fact;
357 one_minus_mu_t[k] = one_minus_mu_d[k] * inv_fact;
358 dmu_t[k] = dmu_d[k] * inv_fact;
359 }
360 let num_t = taylor5_mul(&dmu_t, &dmu_t);
361 let den_t = taylor5_mul(&mu_t, &one_minus_mu_t);
362 if !(den_t[0] > 0.0) || !den_t[0].is_finite() {
363 return (0.0, 0.0, 0.0, 0.0, 0.0);
364 }
365 let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
366 (
367 canonicalzero(w_t[0] * factorial[0]),
368 canonicalzero(w_t[1] * factorial[1]),
369 canonicalzero(w_t[2] * factorial[2]),
370 canonicalzero(w_t[3] * factorial[3]),
371 canonicalzero(w_t[4] * factorial[4]),
372 )
373}
374
375#[inline]
378fn probit_fisher_weight_jet5(eta: f64) -> (f64, f64, f64, f64, f64) {
379 if eta.is_nan() {
380 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
381 }
382 if !eta.is_finite() {
383 return (0.0, 0.0, 0.0, 0.0, 0.0);
384 }
385 let x = eta;
386 let p = normal_cdf(x);
387 let q = normal_cdf(-x);
391 let phi = normal_pdf(x);
392 if !(p > 0.0) || !(q > 0.0) || p * q <= 0.0 {
395 return (0.0, 0.0, 0.0, 0.0, 0.0);
396 }
397 let phi1 = -x * phi;
399 let phi2 = (x * x - 1.0) * phi;
400 let phi3 = -(x * x * x - 3.0 * x) * phi;
401 let phi4 = (x * x * x * x - 6.0 * x * x + 3.0) * phi;
402 let f_d = [phi, phi1, phi2, phi3, phi4];
405 let p_d = [p, phi, phi1, phi2, phi3];
406 let q_d = [q, -phi, -phi1, -phi2, -phi3];
407 let factorial = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
409 let mut f_t = [0.0_f64; 5];
410 let mut p_t = [0.0_f64; 5];
411 let mut q_t = [0.0_f64; 5];
412 for k in 0..5 {
413 let inv_fact = 1.0 / factorial[k];
414 f_t[k] = f_d[k] * inv_fact;
415 p_t[k] = p_d[k] * inv_fact;
416 q_t[k] = q_d[k] * inv_fact;
417 }
418 let num_t = taylor5_mul(&f_t, &f_t);
419 let den_t = taylor5_mul(&p_t, &q_t);
420 let w_t = taylor5_mul(&num_t, &taylor5_inv(&den_t));
421 (
423 canonicalzero(w_t[0] * factorial[0]),
424 canonicalzero(w_t[1] * factorial[1]),
425 canonicalzero(w_t[2] * factorial[2]),
426 canonicalzero(w_t[3] * factorial[3]),
427 canonicalzero(w_t[4] * factorial[4]),
428 )
429}
430
431#[inline]
432fn chain_inverse_link_jet(base: InverseLinkJet, z1: f64, z2: f64, z3: f64) -> InverseLinkJet {
433 InverseLinkJet {
434 mu: base.mu,
435 d1: base.d1 * z1,
436 d2: base.d2 * z1 * z1 + base.d1 * z2,
437 d3: base.d3 * z1 * z1 * z1 + 3.0 * base.d2 * z1 * z2 + base.d1 * z3,
438 }
439}
440
441#[inline]
442fn component_inverse_link_pdfthird_derivative(component: LinkComponent, eta: f64) -> f64 {
443 match component {
444 LinkComponent::Probit => probit_pdfthird_derivative(eta),
445 LinkComponent::Logit => logit_inverse_link_jet5(eta).d4,
446 LinkComponent::CLogLog => {
447 if eta.is_nan() {
455 return f64::NAN;
456 }
457 if !eta.is_finite() {
458 return 0.0;
459 }
460 let t = eta.exp();
461 canonicalzero(stable_nonnegative_poly_times_exp_neg(
462 t,
463 &[0.0, 1.0, -7.0, 6.0, -1.0],
464 ))
465 }
466 LinkComponent::LogLog => {
467 if eta.is_nan() {
474 return f64::NAN;
475 }
476 if !eta.is_finite() {
477 return 0.0;
478 }
479 let r = (-eta).exp();
480 canonicalzero(stable_nonnegative_poly_times_exp_neg(
481 r,
482 &[0.0, -1.0, 7.0, -6.0, 1.0],
483 ))
484 }
485 LinkComponent::Cauchit => {
486 if eta.is_nan() {
494 return f64::NAN;
495 }
496 if !eta.is_finite() {
497 return 0.0;
498 }
499 let denom = 1.0 + eta * eta;
500 24.0 * eta * (1.0 - eta * eta) / (std::f64::consts::PI * denom.powi(4))
501 }
502 }
503}
504
505#[inline]
508fn component_inverse_link_pdffourth_derivative(component: LinkComponent, eta: f64) -> f64 {
509 match component {
510 LinkComponent::Probit => probit_pdffourth_derivative(eta),
511 LinkComponent::Logit => logit_inverse_link_jet5(eta).d5,
512 LinkComponent::CLogLog => {
513 if eta.is_nan() {
518 return f64::NAN;
519 }
520 if !eta.is_finite() {
521 return 0.0;
522 }
523 let t = eta.exp();
524 canonicalzero(stable_nonnegative_poly_times_exp_neg(
525 t,
526 &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
527 ))
528 }
529 LinkComponent::LogLog => {
530 if eta.is_nan() {
535 return f64::NAN;
536 }
537 if !eta.is_finite() {
538 return 0.0;
539 }
540 let r = (-eta).exp();
541 canonicalzero(stable_nonnegative_poly_times_exp_neg(
542 r,
543 &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
544 ))
545 }
546 LinkComponent::Cauchit => {
547 if eta.is_nan() {
549 return f64::NAN;
550 }
551 if !eta.is_finite() {
552 return 0.0;
553 }
554 let e2 = eta * eta;
555 let denom = 1.0 + e2;
556 24.0 * (1.0 - 10.0 * e2 + 5.0 * e2 * e2) / (std::f64::consts::PI * denom.powi(5))
557 }
558 }
559}
560
/// Jet of a mixture inverse link bundled with jets of its partial derivatives
/// with respect to the mixture parameters `rho`.
#[derive(Clone, Debug, PartialEq)]
pub struct MixtureJetWithRhoPartials {
    /// Jet of the inverse link itself.
    pub jet: InverseLinkJet,
    /// Partial-derivative jets with respect to `rho`
    /// (presumably one entry per `rho` element — confirm against the producer).
    pub djet_drho: Vec<InverseLinkJet>,
}
568
/// Jet of a two-parameter inverse link bundled with jets of its partial
/// derivatives with respect to `epsilon` and `log_delta`.
///
/// Returned for both the SAS and the Beta-Logistic links via
/// `LinkParamPartials::Sas`.
#[derive(Clone, Debug, PartialEq)]
pub struct SasJetWithParamPartials {
    /// Jet of the inverse link itself.
    pub jet: InverseLinkJet,
    /// Jet of the partial derivative with respect to `epsilon`.
    pub djet_depsilon: InverseLinkJet,
    /// Jet of the partial derivative with respect to `log_delta`.
    pub djet_dlog_delta: InverseLinkJet,
}
575
/// Parameter partials reported by `InverseLinkKernel::param_partials`.
#[derive(Clone, Debug, PartialEq)]
pub enum LinkParamPartials {
    /// Partials of a mixture link with respect to its `rho` parameters.
    Mixture(MixtureJetWithRhoPartials),
    /// Partials with respect to `epsilon` and `log_delta`; used by the SAS
    /// and Beta-Logistic kernels.
    Sas(SasJetWithParamPartials),
}
581
/// Common interface for inverse links evaluated at a linear predictor `eta`.
pub trait InverseLinkKernel {
    /// Returns the inverse-link value at `eta` with its first three
    /// `eta`-derivatives.
    ///
    /// # Errors
    /// Implementations may fail, e.g. when the link needs state that was not
    /// supplied.
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;

    /// Returns the jet together with its partial derivatives with respect to
    /// the link's own parameters, or `None` for links that report none.
    ///
    /// The default implementation reports `None`.
    ///
    /// # Panics
    /// The default implementation panics when `eta` is NaN or infinite.
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        assert!(eta.is_finite(), "eta must be finite");
        Ok(None)
    }
}
595
/// Stateless kernel for the probit inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ProbitLinkKernel;
598
/// Stateless kernel for the logit inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogitLinkKernel;
601
/// Stateless kernel for the complementary log-log inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CLogLogLinkKernel;
604
/// Stateless kernel for the log-log inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogLogLinkKernel;
607
/// Stateless kernel for the cauchit inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CauchitLinkKernel;
610
611pub fn sas_link_state_from_raw(
619 raw_epsilon: f64,
620 raw_log_delta: f64,
621) -> Result<SasLinkState, String> {
622 if !raw_epsilon.is_finite() || !raw_log_delta.is_finite() {
623 return Err("SAS link parameters must be finite".to_string());
624 }
625 Ok(SasLinkState {
626 epsilon: raw_epsilon,
627 log_delta: raw_log_delta,
628 delta: sas_delta_from_raw_log_delta(raw_log_delta),
629 })
630}
631
632pub fn state_from_sasspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
633 sas_link_state_from_raw(spec.initial_epsilon, spec.initial_log_delta)
634}
635
636pub fn state_from_beta_logisticspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
637 if !spec.initial_epsilon.is_finite() || !spec.initial_log_delta.is_finite() {
638 return Err("Beta-Logistic link parameters must be finite".to_string());
639 }
640 let log_shape_center = spec.initial_log_delta;
646 Ok(SasLinkState {
647 epsilon: spec.initial_epsilon,
648 log_delta: log_shape_center,
649 delta: sas_delta_from_raw_log_delta(log_shape_center),
650 })
651}
652
/// Smoothly squashes `value` into `(-b, b)` via `b * tanh(value / b)`, where
/// `b` is `bound` floored at `f64::EPSILON`.
#[inline]
fn tanh_bound(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let squashed = (value / scale).tanh();
    scale * squashed
}
658
/// First derivative of `tanh_bound` with respect to `value`:
/// `1 - tanh(value / b)^2`, with `b` being `bound` floored at `f64::EPSILON`.
#[inline]
fn tanh_bound_d1(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    1.0 - th * th
}
665
/// Second derivative of `tanh_bound` with respect to `value`:
/// `-2 t (1 - t^2) / b` with `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d2(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let sech2 = 1.0 - th * th;
    let numerator = -2.0 * th * sech2;
    numerator / scale
}
673
/// Third derivative of `tanh_bound` with respect to `value`:
/// `-2 (1 - t^2) (1 - 3 t^2) / b^2` with `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d3(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let sech2 = 1.0 - th * th;
    let numerator = -2.0 * sech2 * (1.0 - 3.0 * th * th);
    numerator / (scale * scale)
}
681
/// Fourth derivative of `tanh_bound` with respect to `value`:
/// `8 t (1 - t^2) (2 - 3 t^2) / b^3` with `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d4(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let sech2 = 1.0 - th * th;
    let numerator = 8.0 * th * sech2 * (2.0 - 3.0 * th * th);
    numerator / (scale * scale * scale)
}
689
/// Fifth derivative of `tanh_bound` with respect to `value`:
/// `8 (1 - t^2) (2 - 15 t^2 + 15 t^4) / b^4` with `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d5(value: f64, bound: f64) -> f64 {
    let scale = f64::max(bound, f64::EPSILON);
    let th = f64::tanh(value / scale);
    let sech2 = 1.0 - th * th;
    let th2 = th * th;
    let scale4 = scale * scale * scale * scale;
    let numerator = 8.0 * sech2 * (2.0 - 15.0 * th2 + 15.0 * th2 * th2);
    numerator / scale4
}
702
703#[inline]
704fn sas_effective_log_delta(raw_log_delta: f64) -> (f64, f64) {
705 let ld_eff = tanh_bound(raw_log_delta, SAS_LOG_DELTA_BOUND);
706 let dld_eff_draw = tanh_bound_d1(raw_log_delta, SAS_LOG_DELTA_BOUND);
707 (ld_eff, dld_eff_draw)
708}
709
710#[inline]
711fn sas_delta_from_raw_log_delta(raw_log_delta: f64) -> f64 {
712 let (ld_eff, _) = sas_effective_log_delta(raw_log_delta);
713 ld_eff.exp()
714}
715
716pub fn validate_mixturespec(spec: &MixtureLinkSpec) -> Result<(), String> {
717 if spec.components.is_empty() {
718 return Err("mixture link requires at least 1 component".to_string());
719 }
720 if spec.initial_rho.len() + 1 != spec.components.len() {
721 return Err(format!(
722 "mixture link rho length mismatch: expected {}, got {}",
723 spec.components.len() - 1,
724 spec.initial_rho.len()
725 ));
726 }
727 for i in 0..spec.components.len() {
728 for j in (i + 1)..spec.components.len() {
729 if spec.components[i] == spec.components[j] {
730 return Err("mixture link components must be unique".to_string());
731 }
732 }
733 }
734 let has_anchor = spec.components.iter().any(|component| {
749 matches!(
750 component,
751 LinkComponent::Logit | LinkComponent::Probit | LinkComponent::CLogLog
752 )
753 });
754 if !has_anchor && spec.components.len() > 1 {
755 let unsupported: Vec<&str> = spec
756 .components
757 .iter()
758 .map(|component| component.name())
759 .collect();
760 return Err(format!(
761 "mixture link components {{{}}} are unsupported: at least one component \
762 must map to a LinkFunction variant (logit/probit/cloglog) so the mixture's \
763 projected LinkFunction is well defined; cauchit and loglog have no \
764 LinkFunction representative",
765 unsupported.join(", ")
766 ));
767 }
768 Ok(())
769}
770
771pub fn softmax_last_fixedzero(rho: &Array1<f64>) -> Array1<f64> {
772 let k = rho.len() + 1;
773 let mut logits = Vec::with_capacity(k);
774 let mut maxv = 0.0_f64;
775 for &v in rho {
776 maxv = maxv.max(v);
777 logits.push(v);
778 }
779 maxv = maxv.max(0.0);
780 logits.push(0.0);
781
782 let mut sum = 0.0_f64;
783 let mut exps = vec![0.0_f64; k];
784 for i in 0..k {
785 let e = (logits[i] - maxv).exp();
786 exps[i] = e;
787 sum += e;
788 }
789 if !sum.is_finite() || sum <= 0.0 {
790 return Array1::from_elem(k, 1.0 / k as f64);
791 }
792 let inv = 1.0 / sum;
793 Array1::from_iter(exps.into_iter().map(|v| v * inv))
794}
795
796pub fn softmaxwith_jacobian_last_fixedzero(
799 rho: &Array1<f64>,
800) -> (Array1<f64>, ndarray::Array2<f64>) {
801 let pi = softmax_last_fixedzero(rho);
802 let k = pi.len();
803 let m = k.saturating_sub(1);
804 let mut jac = ndarray::Array2::<f64>::zeros((k, m));
805 for j in 0..m {
806 let pi_j = pi[j];
807 for kk in 0..k {
808 let delta = if kk == j { 1.0 } else { 0.0 };
809 jac[[kk, j]] = pi[kk] * (delta - pi_j);
810 }
811 }
812 (pi, jac)
813}
814
815pub fn state_fromspec(spec: &MixtureLinkSpec) -> Result<MixtureLinkState, String> {
816 validate_mixturespec(spec)?;
817 let pi = softmax_last_fixedzero(&spec.initial_rho);
818 Ok(MixtureLinkState {
819 components: spec.components.clone(),
820 rho: spec.initial_rho.clone(),
821 pi,
822 })
823}
824
825#[inline]
826pub fn component_inverse_link_jet(component: LinkComponent, eta: f64) -> InverseLinkJet {
827 canonicalize_jet(match component {
828 LinkComponent::Logit => {
829 let jet = logit_inverse_link_jet5(eta);
830 InverseLinkJet {
831 mu: jet.mu,
832 d1: jet.d1,
833 d2: jet.d2,
834 d3: jet.d3,
835 }
836 }
837 LinkComponent::Probit => probit_jet(eta),
838 LinkComponent::CLogLog => {
839 if eta.is_nan() {
840 return InverseLinkJet {
841 mu: f64::NAN,
842 d1: f64::NAN,
843 d2: f64::NAN,
844 d3: f64::NAN,
845 };
846 }
847 let t = eta.exp();
848 if !t.is_finite() {
849 return InverseLinkJet {
850 mu: 1.0,
851 d1: 0.0,
852 d2: 0.0,
853 d3: 0.0,
854 };
855 }
856 InverseLinkJet {
857 mu: -(-t).exp_m1(),
858 d1: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0]),
859 d2: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -1.0]),
860 d3: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -3.0, 1.0]),
861 }
862 }
863 LinkComponent::LogLog => {
864 if eta.is_nan() {
865 return InverseLinkJet {
866 mu: f64::NAN,
867 d1: f64::NAN,
868 d2: f64::NAN,
869 d3: f64::NAN,
870 };
871 }
872 let r = (-eta).exp();
873 if !r.is_finite() {
874 return InverseLinkJet {
875 mu: 0.0,
876 d1: 0.0,
877 d2: 0.0,
878 d3: 0.0,
879 };
880 }
881 InverseLinkJet {
882 mu: (-r).exp(),
883 d1: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0]),
884 d2: stable_nonnegative_poly_times_exp_neg(r, &[0.0, -1.0, 1.0]),
885 d3: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0, -3.0, 1.0]),
886 }
887 }
888 LinkComponent::Cauchit => {
889 if eta.is_nan() {
890 return InverseLinkJet {
891 mu: f64::NAN,
892 d1: f64::NAN,
893 d2: f64::NAN,
894 d3: f64::NAN,
895 };
896 }
897 let den = 1.0 + eta * eta;
898 let d1 = if eta.is_finite() {
899 1.0 / (std::f64::consts::PI * den)
900 } else {
901 0.0
902 };
903 let d2 = if eta.is_finite() {
904 -2.0 * eta / (std::f64::consts::PI * den * den)
905 } else {
906 0.0
907 };
908 let d3 = if eta.is_finite() {
909 (6.0 * eta * eta - 2.0) / (std::f64::consts::PI * den * den * den)
910 } else {
911 0.0
912 };
913 InverseLinkJet {
914 mu: 0.5 + eta.atan() / std::f64::consts::PI,
915 d1,
916 d2,
917 d3,
918 }
919 }
920 })
921}
922
923impl InverseLinkKernel for ProbitLinkKernel {
924 #[inline]
925 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
926 Ok(component_inverse_link_jet(LinkComponent::Probit, eta))
927 }
928}
929
930impl InverseLinkKernel for LogitLinkKernel {
931 #[inline]
932 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
933 Ok(component_inverse_link_jet(LinkComponent::Logit, eta))
934 }
935}
936
937impl InverseLinkKernel for CLogLogLinkKernel {
938 #[inline]
939 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
940 Ok(component_inverse_link_jet(LinkComponent::CLogLog, eta))
941 }
942}
943
944impl InverseLinkKernel for LogLogLinkKernel {
945 #[inline]
946 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
947 Ok(component_inverse_link_jet(LinkComponent::LogLog, eta))
948 }
949}
950
951impl InverseLinkKernel for CauchitLinkKernel {
952 #[inline]
953 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
954 Ok(component_inverse_link_jet(LinkComponent::Cauchit, eta))
955 }
956}
957
958impl InverseLinkKernel for LinkComponent {
959 #[inline]
960 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
961 Ok(component_inverse_link_jet(*self, eta))
962 }
963}
964
965impl InverseLinkKernel for LinkFunction {
966 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
967 match self {
968 LinkFunction::Logit => LogitLinkKernel.jet(eta),
969 LinkFunction::Probit => ProbitLinkKernel.jet(eta),
970 LinkFunction::CLogLog => CLogLogLinkKernel.jet(eta),
971 LinkFunction::Identity => Ok(InverseLinkJet {
972 mu: eta,
973 d1: 1.0,
974 d2: 0.0,
975 d3: 0.0,
976 }),
977 LinkFunction::Log => {
978 let e = eta.clamp(-700.0, 700.0).exp();
990 Ok(InverseLinkJet {
991 mu: e,
992 d1: e,
993 d2: e,
994 d3: e,
995 })
996 }
997 LinkFunction::Sas => Err(EstimationError::InvalidInput(
998 "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
999 )),
1000 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1001 "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1002 .to_string(),
1003 )),
1004 }
1005 }
1006}
1007
1008impl InverseLinkKernel for SasLinkState {
1009 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1010 Ok(sas_inverse_link_jet(eta, self.epsilon, self.log_delta))
1011 }
1012
1013 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1014 Ok(Some(LinkParamPartials::Sas(
1015 sas_inverse_link_jetwith_param_partials(eta, self.epsilon, self.log_delta),
1016 )))
1017 }
1018}
1019
/// Parameters of the Beta-Logistic inverse link, passed to the
/// `beta_logistic_*` evaluation routines.
#[derive(Clone, Copy, Debug)]
pub struct BetaLogisticKernel {
    /// Log shape centre; filled from the link state's `log_delta` field.
    pub log_shape_center: f64,
    /// Second link parameter; filled from the link state's `epsilon` field.
    pub epsilon: f64,
}
1027
1028impl InverseLinkKernel for BetaLogisticKernel {
1029 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1030 Ok(beta_logistic_inverse_link_jet(
1031 eta,
1032 self.log_shape_center,
1033 self.epsilon,
1034 ))
1035 }
1036
1037 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1038 Ok(Some(LinkParamPartials::Sas(
1039 beta_logistic_inverse_link_jetwith_param_partials(
1040 eta,
1041 self.log_shape_center,
1042 self.epsilon,
1043 ),
1044 )))
1045 }
1046}
1047
1048impl InverseLinkKernel for MixtureLinkState {
1049 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1050 Ok(mixture_inverse_link_jet(self, eta))
1051 }
1052
1053 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1054 Ok(Some(LinkParamPartials::Mixture(
1055 mixture_inverse_link_jetwith_rho_partials(self, eta),
1056 )))
1057 }
1058}
1059
1060impl InverseLinkKernel for InverseLink {
1061 fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
1062 match self {
1063 InverseLink::Standard(link_fn) => link_fn.as_link_function().jet(eta),
1064 InverseLink::LatentCLogLog(state) => latent_cloglog_point_jet(state, eta),
1065 InverseLink::Sas(state) => state.jet(eta),
1066 InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1067 log_shape_center: state.log_delta,
1068 epsilon: state.epsilon,
1069 }
1070 .jet(eta),
1071 InverseLink::Mixture(state) => state.jet(eta),
1072 }
1073 }
1074
1075 fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
1076 match self {
1077 InverseLink::Standard(_) => Ok(None),
1078 InverseLink::LatentCLogLog(_) => Ok(None),
1079 InverseLink::Sas(state) => state.param_partials(eta),
1080 InverseLink::BetaLogistic(state) => BetaLogisticKernel {
1081 log_shape_center: state.log_delta,
1082 epsilon: state.epsilon,
1083 }
1084 .param_partials(eta),
1085 InverseLink::Mixture(state) => state.param_partials(eta),
1086 }
1087 }
1088}
1089
1090pub fn inverse_link_jet_for_inverse_link(
1094 link: &InverseLink,
1095 eta: f64,
1096) -> Result<InverseLinkJet, EstimationError> {
1097 link.jet(eta)
1098}
1099
1100pub fn inverse_link_mu_d1_for_inverse_link(
1110 link: &InverseLink,
1111 eta: f64,
1112) -> Result<(f64, f64), EstimationError> {
1113 match link {
1114 InverseLink::Standard(link_fn) => Ok(link_function_mu_d1(link_fn.as_link_function(), eta)?),
1115 InverseLink::LatentCLogLog(state) => {
1116 let jet = latent_cloglog_point_jet(state, eta)?;
1117 Ok((jet.mu, jet.d1))
1118 }
1119 InverseLink::Sas(state) => Ok(sas_inverse_link_mu_d1(eta, state.epsilon, state.log_delta)),
1120 InverseLink::BetaLogistic(state) => Ok(beta_logistic_inverse_link_mu_d1(
1121 eta,
1122 state.log_delta,
1123 state.epsilon,
1124 )),
1125 InverseLink::Mixture(state) => Ok(mixture_inverse_link_mu_d1(state, eta)),
1126 }
1127}
1128
1129fn link_function_mu_d1(link: LinkFunction, eta: f64) -> Result<(f64, f64), EstimationError> {
1130 match link {
1131 LinkFunction::Identity => Ok((eta, 1.0)),
1132 LinkFunction::Log => {
1133 let e = eta.clamp(-700.0, 700.0).exp();
1138 Ok((e, e))
1139 }
1140 LinkFunction::Logit => Ok(component_inverse_link_mu_d1(LinkComponent::Logit, eta)),
1141 LinkFunction::Probit => Ok(component_inverse_link_mu_d1(LinkComponent::Probit, eta)),
1142 LinkFunction::CLogLog => Ok(component_inverse_link_mu_d1(LinkComponent::CLogLog, eta)),
1143 LinkFunction::Sas => Err(EstimationError::InvalidInput(
1144 "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
1145 )),
1146 LinkFunction::BetaLogistic => Err(EstimationError::InvalidInput(
1147 "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
1148 .to_string(),
1149 )),
1150 }
1151}
1152
1153#[inline]
1154fn component_inverse_link_mu_d1(component: LinkComponent, eta: f64) -> (f64, f64) {
1155 match component {
1161 LinkComponent::Logit => {
1162 let jet = logit_inverse_link_jet5(eta);
1163 (jet.mu, canonicalzero(jet.d1))
1164 }
1165 LinkComponent::Probit => {
1166 if eta.is_nan() {
1167 return (f64::NAN, f64::NAN);
1168 }
1169 if eta == f64::INFINITY {
1170 return (1.0, 0.0);
1171 }
1172 if eta == f64::NEG_INFINITY {
1173 return (0.0, 0.0);
1174 }
1175 let phi = normal_pdf(eta);
1176 (normal_cdf(eta), canonicalzero(phi))
1177 }
1178 LinkComponent::CLogLog => {
1179 if eta.is_nan() {
1180 return (f64::NAN, f64::NAN);
1181 }
1182 let t = eta.exp();
1183 if !t.is_finite() {
1184 return (1.0, 0.0);
1185 }
1186 (
1187 -(-t).exp_m1(),
1188 canonicalzero(stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0])),
1189 )
1190 }
1191 LinkComponent::LogLog => {
1192 if eta.is_nan() {
1193 return (f64::NAN, f64::NAN);
1194 }
1195 let r = (-eta).exp();
1196 if !r.is_finite() {
1197 return (0.0, 0.0);
1198 }
1199 (
1200 (-r).exp(),
1201 canonicalzero(stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0])),
1202 )
1203 }
1204 LinkComponent::Cauchit => {
1205 if eta.is_nan() {
1206 return (f64::NAN, f64::NAN);
1207 }
1208 let den = 1.0 + eta * eta;
1209 let d1 = if eta.is_finite() {
1210 1.0 / (std::f64::consts::PI * den)
1211 } else {
1212 0.0
1213 };
1214 (0.5 + eta.atan() / std::f64::consts::PI, canonicalzero(d1))
1215 }
1216 }
1217}
1218
1219fn sas_inverse_link_mu_d1(eta: f64, epsilon: f64, log_delta: f64) -> (f64, f64) {
1220 let delta_id = sas_delta_from_raw_log_delta(log_delta);
1221 if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1222 return component_inverse_link_mu_d1(LinkComponent::Probit, eta);
1223 }
1224 let e = if eta.is_finite() { eta } else { 0.0 };
1225 let a = e.asinh();
1226 let delta = delta_id;
1227 let u_raw = delta * a - epsilon;
1228 let u = tanh_bound(u_raw, SAS_U_CLAMP);
1229 let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1230 let s = u.sinh();
1231 let c = u.cosh();
1232 let z = s;
1233 let q = e.hypot(1.0);
1234 let inv_q = 1.0 / q;
1235 let r1 = delta * inv_q;
1236 let u1 = g1 * r1;
1237 let z1 = c * u1;
1238 let base = probit_jet(z);
1241 (base.mu, canonicalzero(base.d1 * z1))
1242}
1243
1244fn beta_logistic_inverse_link_mu_d1(eta: f64, delta: f64, epsilon: f64) -> (f64, f64) {
1245 let logistic = logistic_uwith_derivatives(eta);
1246 let a = (delta - epsilon).exp();
1247 let b = (delta + epsilon).exp();
1248 let mu = beta_reg_logistic(a, b, logistic);
1249 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1250 (mu, log_d1.exp())
1251}
1252
1253fn mixture_inverse_link_mu_d1(state: &MixtureLinkState, eta: f64) -> (f64, f64) {
1254 let mut mu = 0.0_f64;
1255 let mut d1 = 0.0_f64;
1256 let k = state.components.len().min(state.pi.len());
1257 for i in 0..k {
1258 let (mu_i, d1_i) = component_inverse_link_mu_d1(state.components[i], eta);
1259 let w = state.pi[i];
1260 mu += w * mu_i;
1261 d1 += w * d1_i;
1262 }
1263 (mu, d1)
1264}
1265
/// Selects which higher derivative of a link's density is requested.
#[derive(Clone, Copy)]
enum PdfDerivativeOrder {
    /// Third derivative of the density (fourth derivative of the inverse link).
    Third,
    /// Fourth derivative of the density (fifth derivative of the inverse link).
    Fourth,
}
1271
1272impl PdfDerivativeOrder {
1273 fn probit(self, eta: f64) -> f64 {
1274 match self {
1275 Self::Third => probit_pdfthird_derivative(eta),
1276 Self::Fourth => probit_pdffourth_derivative(eta),
1277 }
1278 }
1279
1280 fn component(self, component: LinkComponent, eta: f64) -> f64 {
1281 match self {
1282 Self::Third => component_inverse_link_pdfthird_derivative(component, eta),
1283 Self::Fourth => component_inverse_link_pdffourth_derivative(component, eta),
1284 }
1285 }
1286
1287 fn latent_cloglog(self, eta: f64, latent_sd: f64) -> Result<f64, EstimationError> {
1288 let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, latent_sd)?;
1289 Ok(match self {
1290 Self::Third => jet.d4,
1291 Self::Fourth => jet.d5,
1292 })
1293 }
1294
1295 fn sas(self, eta: f64, epsilon: f64, log_delta: f64) -> f64 {
1296 match self {
1297 Self::Third => sas_inverse_link_pdfthird_derivative(eta, epsilon, log_delta),
1298 Self::Fourth => sas_inverse_link_pdffourth_derivative(eta, epsilon, log_delta),
1299 }
1300 }
1301
1302 fn beta_logistic(self, eta: f64, log_shape_center: f64, epsilon: f64) -> f64 {
1303 match self {
1304 Self::Third => {
1305 beta_logistic_inverse_link_pdfthird_derivative(eta, log_shape_center, epsilon)
1306 }
1307 Self::Fourth => {
1308 beta_logistic_inverse_link_pdffourth_derivative(eta, log_shape_center, epsilon)
1309 }
1310 }
1311 }
1312}
1313
1314fn inverse_link_pdf_derivative_for_inverse_link(
1315 link: &InverseLink,
1316 eta: f64,
1317 order: PdfDerivativeOrder,
1318) -> Result<f64, EstimationError> {
1319 match link {
1320 InverseLink::Standard(StandardLink::Identity) => Ok(0.0),
1321 InverseLink::Standard(StandardLink::Log) => Ok(eta.clamp(-700.0, 700.0).exp()),
1322 InverseLink::Standard(StandardLink::Probit) => Ok(order.probit(eta)),
1323 InverseLink::Standard(StandardLink::Logit) => {
1324 Ok(order.component(LinkComponent::Logit, eta))
1325 }
1326 InverseLink::Standard(StandardLink::CLogLog) => {
1327 Ok(order.component(LinkComponent::CLogLog, eta))
1328 }
1329 InverseLink::LatentCLogLog(state) => order.latent_cloglog(eta, state.latent_sd),
1330 InverseLink::Sas(state) => Ok(order.sas(eta, state.epsilon, state.log_delta)),
1331 InverseLink::BetaLogistic(state) => {
1332 Ok(order.beta_logistic(eta, state.log_delta, state.epsilon))
1333 }
1334 InverseLink::Mixture(state) => Ok(state
1335 .components
1336 .iter()
1337 .zip(state.pi.iter())
1338 .map(|(&component, &weight)| weight * order.component(component, eta))
1339 .sum()),
1340 }
1341}
1342
1343pub fn inverse_link_pdfthird_derivative_for_inverse_link(
1344 link: &InverseLink,
1345 eta: f64,
1346) -> Result<f64, EstimationError> {
1347 inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Third)
1363}
1364
1365pub fn inverse_link_pdffourth_derivative_for_inverse_link(
1371 link: &InverseLink,
1372 eta: f64,
1373) -> Result<f64, EstimationError> {
1374 inverse_link_pdf_derivative_for_inverse_link(link, eta, PdfDerivativeOrder::Fourth)
1375}
1376
1377
1378#[inline]
1379fn royston_parmar_inverse_link_jet(eta: f64) -> InverseLinkJet {
1380 const SURVIVAL_ETA_CLAMP: f64 = 30.0;
1384
1385 let z = eta.clamp(-SURVIVAL_ETA_CLAMP, SURVIVAL_ETA_CLAMP);
1386 let hazard = z.exp();
1387 let survival = (-hazard).exp();
1388 if !(-SURVIVAL_ETA_CLAMP..=SURVIVAL_ETA_CLAMP).contains(&eta) {
1389 return InverseLinkJet {
1390 mu: survival,
1391 d1: 0.0,
1392 d2: 0.0,
1393 d3: 0.0,
1394 };
1395 }
1396
1397 let d1 = -hazard * survival;
1398 let d2 = hazard * (hazard - 1.0) * survival;
1399 let d3 = (-hazard * hazard * hazard + 3.0 * hazard * hazard - hazard) * survival;
1400 InverseLinkJet {
1401 mu: survival,
1402 d1,
1403 d2,
1404 d3,
1405 }
1406}
1407
1408pub fn inverse_link_jet_for_family(
1409 spec: &LikelihoodSpec,
1410 eta: f64,
1411) -> Result<InverseLinkJet, EstimationError> {
1412 if matches!(spec.response, ResponseFamily::RoystonParmar) {
1415 return Ok(royston_parmar_inverse_link_jet(eta));
1416 }
1417 spec.link.jet(eta)
1418}
1419
1420#[inline]
1427fn log_inverse_link_jet_exact(eta: f64) -> InverseLinkJet {
1428 let e = eta.exp();
1429 InverseLinkJet {
1430 mu: e,
1431 d1: e,
1432 d2: e,
1433 d3: e,
1434 }
1435}
1436
1437pub fn inverse_link_jet_for_family_public(
1450 spec: &LikelihoodSpec,
1451 eta: f64,
1452) -> Result<InverseLinkJet, EstimationError> {
1453 if matches!(spec.response, ResponseFamily::RoystonParmar) {
1454 return Ok(royston_parmar_inverse_link_jet(eta));
1455 }
1456 if let InverseLink::Standard(StandardLink::Log) = spec.link {
1457 return Ok(log_inverse_link_jet_exact(eta));
1458 }
1459 spec.link.jet(eta)
1460}
1461
1462#[inline]
1463pub fn mixture_inverse_link_jet(state: &MixtureLinkState, eta: f64) -> InverseLinkJet {
1464 let mut mu = 0.0_f64;
1465 let mut d1 = 0.0_f64;
1466 let mut d2 = 0.0_f64;
1467 let mut d3 = 0.0_f64;
1468 let k = state.components.len().min(state.pi.len());
1469 for i in 0..k {
1470 let jet = component_inverse_link_jet(state.components[i], eta);
1471 let w = state.pi[i];
1472 mu += w * jet.mu;
1473 d1 += w * jet.d1;
1474 d2 += w * jet.d2;
1475 d3 += w * jet.d3;
1476 }
1477 InverseLinkJet { mu, d1, d2, d3 }
1478}
1479
1480pub fn mixture_inverse_link_jetwith_rho_partials(
1488 state: &MixtureLinkState,
1489 eta: f64,
1490) -> MixtureJetWithRhoPartials {
1491 let k = state.components.len().min(state.pi.len());
1492 let m = k.saturating_sub(1);
1493 let mut djet_drho = vec![
1494 InverseLinkJet {
1495 mu: 0.0,
1496 d1: 0.0,
1497 d2: 0.0,
1498 d3: 0.0,
1499 };
1500 m
1501 ];
1502 let jet = mixture_inverse_link_jetwith_rho_partials_into(state, eta, &mut djet_drho);
1503 MixtureJetWithRhoPartials { jet, djet_drho }
1504}
1505
1506pub fn mixture_inverse_link_jetwith_rho_partials_into(
1509 state: &MixtureLinkState,
1510 eta: f64,
1511 out: &mut [InverseLinkJet],
1512) -> InverseLinkJet {
1513 let k = state.components.len().min(state.pi.len());
1514 let m = k.saturating_sub(1);
1515 assert!(
1516 out.len() >= m,
1517 "rho-partial output buffer too small: got {}, need {}",
1518 out.len(),
1519 m
1520 );
1521 let mut mixed = InverseLinkJet {
1522 mu: 0.0,
1523 d1: 0.0,
1524 d2: 0.0,
1525 d3: 0.0,
1526 };
1527 for i in 0..k {
1528 let jet_i = component_inverse_link_jet(state.components[i], eta);
1529 let w = state.pi[i];
1530 mixed.mu += w * jet_i.mu;
1531 mixed.d1 += w * jet_i.d1;
1532 mixed.d2 += w * jet_i.d2;
1533 mixed.d3 += w * jet_i.d3;
1534 if i < m {
1537 out[i] = jet_i;
1538 }
1539 }
1540 for j in 0..m {
1541 let pi_j = state.pi[j];
1542 let cj = out[j];
1543 out[j] = InverseLinkJet {
1544 mu: pi_j * (cj.mu - mixed.mu),
1545 d1: pi_j * (cj.d1 - mixed.d1),
1546 d2: pi_j * (cj.d2 - mixed.d2),
1547 d3: pi_j * (cj.d3 - mixed.d3),
1548 };
1549 }
1550 mixed
1551}
1552
/// Logistic transform of a linear predictor together with the log-space
/// quantities the beta-logistic helpers consume.
#[derive(Clone, Copy)]
struct LogisticU {
    // `exp(ln_u)`.
    u: f64,
    // `exp(ln_one_minus_u)`.
    one_minus_u: f64,
    // `-softplus(-eta)`.
    ln_u: f64,
    // `-softplus(eta)`.
    ln_one_minus_u: f64,
    // `exp(ln_u + ln_one_minus_u)`, i.e. the product `u * (1 - u)` formed in log space.
    du: f64,
    // Set when `eta >= 0.0`; the beta helpers then evaluate via the complementary tail.
    use_upper_tail: bool,
}
1562
1563#[inline]
1564fn logistic_uwith_derivatives(eta: f64) -> LogisticU {
1565 let ln_u = -gam_linalg::utils::stable_softplus(-eta);
1566 let ln_one_minus_u = -gam_linalg::utils::stable_softplus(eta);
1567 let u = ln_u.exp();
1568 let one_minus_u = ln_one_minus_u.exp();
1569 let du = (ln_u + ln_one_minus_u).exp();
1570 LogisticU {
1571 u,
1572 one_minus_u,
1573 ln_u,
1574 ln_one_minus_u,
1575 du,
1576 use_upper_tail: eta >= 0.0,
1577 }
1578}
1579
1580#[inline]
1581fn beta_reg_logistic(a: f64, b: f64, logistic: LogisticU) -> f64 {
1582 if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1583 return f64::NAN;
1584 }
1585 if logistic.ln_u == f64::NEG_INFINITY {
1586 return 0.0;
1587 }
1588 if logistic.ln_one_minus_u == f64::NEG_INFINITY {
1589 return 1.0;
1590 }
1591 if logistic.use_upper_tail {
1592 1.0 - beta_reg(b, a, logistic.one_minus_u)
1593 } else {
1594 beta_reg(a, b, logistic.u)
1595 }
1596}
1597
1598#[inline]
1599fn beta_reg_with_shape_partials_logistic(a: f64, b: f64, logistic: LogisticU) -> (f64, f64, f64) {
1600 if logistic.ln_u.is_nan() || logistic.ln_one_minus_u.is_nan() {
1601 return (f64::NAN, f64::NAN, f64::NAN);
1602 }
1603 if logistic.use_upper_tail {
1604 let (tail, dtail_db, dtail_da) = beta_reg_with_shape_partials(b, a, logistic.one_minus_u);
1605 (1.0 - tail, -dtail_da, -dtail_db)
1606 } else {
1607 beta_reg_with_shape_partials(a, b, logistic.u)
1608 }
1609}
1610
1611#[inline]
1612fn beta_logistic_log_d1(a: f64, b: f64, logistic: LogisticU) -> f64 {
1613 a * logistic.ln_u + b * logistic.ln_one_minus_u - ln_beta(a, b)
1614}
1615
/// First-order dual number: a value carried together with two partial
/// derivatives (named after the beta shape parameters `a` and `b`).
#[derive(Clone, Copy)]
struct ShapeDual {
    // Primal value.
    v: f64,
    // Partial derivative with respect to `a`.
    da: f64,
    // Partial derivative with respect to `b`.
    db: f64,
}
1622
1623impl ShapeDual {
1624 #[inline]
1625 fn constant(v: f64) -> Self {
1626 Self {
1627 v,
1628 da: 0.0,
1629 db: 0.0,
1630 }
1631 }
1632
1633 #[inline]
1634 fn from_value_partials(v: f64, da: f64, db: f64) -> Self {
1635 Self { v, da, db }
1636 }
1637
1638 #[inline]
1639 fn clamp_small(self, floor: f64) -> Self {
1640 if self.v.abs() < floor {
1641 Self::constant(floor)
1642 } else {
1643 self
1644 }
1645 }
1646}
1647
1648impl std::ops::Add for ShapeDual {
1649 type Output = Self;
1650
1651 #[inline]
1652 fn add(self, rhs: Self) -> Self {
1653 Self {
1654 v: self.v + rhs.v,
1655 da: self.da + rhs.da,
1656 db: self.db + rhs.db,
1657 }
1658 }
1659}
1660
1661impl std::ops::Sub for ShapeDual {
1662 type Output = Self;
1663
1664 #[inline]
1665 fn sub(self, rhs: Self) -> Self {
1666 Self {
1667 v: self.v - rhs.v,
1668 da: self.da - rhs.da,
1669 db: self.db - rhs.db,
1670 }
1671 }
1672}
1673
1674impl std::ops::Mul for ShapeDual {
1675 type Output = Self;
1676
1677 #[inline]
1678 fn mul(self, rhs: Self) -> Self {
1679 Self {
1680 v: self.v * rhs.v,
1681 da: self.da * rhs.v + self.v * rhs.da,
1682 db: self.db * rhs.v + self.v * rhs.db,
1683 }
1684 }
1685}
1686
1687impl std::ops::Div for ShapeDual {
1688 type Output = Self;
1689
1690 #[inline]
1691 fn div(self, rhs: Self) -> Self {
1692 let inv = 1.0 / rhs.v;
1693 let inv2 = inv * inv;
1694 Self {
1695 v: self.v * inv,
1696 da: (self.da * rhs.v - self.v * rhs.da) * inv2,
1697 db: (self.db * rhs.v - self.v * rhs.db) * inv2,
1698 }
1699 }
1700}
1701
1702impl std::ops::Neg for ShapeDual {
1703 type Output = Self;
1704
1705 #[inline]
1706 fn neg(self) -> Self {
1707 ShapeDual {
1708 v: -self.v,
1709 da: -self.da,
1710 db: -self.db,
1711 }
1712 }
1713}
1714
1715#[inline]
1716fn shape_dual(v: f64) -> ShapeDual {
1717 ShapeDual::constant(v)
1718}
1719
/// Regularized incomplete beta `I_{x0}(a0, b0)` together with its partial
/// derivatives in the two shape parameters.
///
/// Returns `(value, d/da0, d/db0)`. The continued fraction is evaluated on
/// `ShapeDual` numbers so the partials are propagated alongside the value.
///
/// Boundary inputs (`x0 <= 0` or `x0 >= 1`) return the exact value with
/// zero partials. If the iteration cap is reached without meeting the
/// tolerance, the last estimate is returned without any signal.
///
/// NOTE(review): the iteration structure, tolerance and cap appear to mirror
/// the continued-fraction evaluation in statrs' `beta_reg` — confirm against
/// that implementation before changing any constant here.
fn beta_reg_with_shape_partials(a0: f64, b0: f64, x0: f64) -> (f64, f64, f64) {
    if x0 <= 0.0 {
        return (0.0, 0.0, 0.0);
    }
    if x0 >= 1.0 {
        return (1.0, 0.0, 0.0);
    }

    // Past this threshold the reflected problem `1 - I_{1-x}(b, a)` is evaluated.
    // The duals are seeded so `da`/`db` always mean d/da0 and d/db0, even though
    // the roles of the two shapes are swapped.
    let symm_transform = x0 >= (a0 + 1.0) / (a0 + b0 + 2.0);
    let (a, b, x) = if symm_transform {
        (
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            1.0 - x0,
        )
    } else {
        (
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            x0,
        )
    };

    // Prefactor x^a (1-x)^b / B(a, b), formed in log space.
    let ln_x = x.ln();
    let ln_1mx = (1.0 - x).ln();
    let psi_ab = digamma(a.v + b.v);
    let log_bt = statrs::function::gamma::ln_gamma(a.v + b.v)
        - statrs::function::gamma::ln_gamma(a.v)
        - statrs::function::gamma::ln_gamma(b.v)
        + a.v * ln_x
        + b.v * ln_1mx;
    let bt_v = log_bt.exp();
    // Log-derivatives of the prefactor w.r.t. the (possibly swapped) shapes;
    // the chain rule through `a.da`/`b.da` maps them back onto (a0, b0).
    let log_bt_a = psi_ab - digamma(a.v) + ln_x;
    let log_bt_b = psi_ab - digamma(b.v) + ln_1mx;
    let bt = ShapeDual {
        v: bt_v,
        da: bt_v * (log_bt_a * a.da + log_bt_b * b.da),
        db: bt_v * (log_bt_a * a.db + log_bt_b * b.db),
    };

    // Convergence tolerance (numerically half of `f64::EPSILON`) and the
    // floor used to keep the running denominators away from zero.
    let eps = 0.00000000000000011102230246251565;
    let fpmin = f64::MIN_POSITIVE / eps;
    let one = shape_dual(1.0);
    let qab = a + b;
    let qap = a + one;
    let qam = a - one;
    let mut c = one;
    let mut d = (one - qab * shape_dual(x) / qap).clamp_small(fpmin);
    d = one / d;
    let mut h = d;

    // At most 140 iterations, each applying two continued-fraction coefficients.
    for m in 1..141 {
        let mf = f64::from(m);
        let m2 = mf * 2.0;
        let md = shape_dual(mf);
        let m2d = shape_dual(m2);
        // First coefficient of this iteration.
        let mut aa = md * (b - md) * shape_dual(x) / ((qam + m2d) * (a + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        h = h * d * c;

        // Second coefficient of this iteration.
        aa = (a + md).neg() * (qab + md) * shape_dual(x) / ((a + m2d) * (qap + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        let del = d * c;
        h = h * del;

        // Converged once the multiplicative update is within `eps` of one.
        if (del.v - 1.0).abs() <= eps {
            let reg = bt * h / a;
            return if symm_transform {
                (1.0 - reg.v, -reg.da, -reg.db)
            } else {
                (reg.v, reg.da, reg.db)
            };
        }
    }
    // Iteration cap reached: fall back to the last estimate.
    let reg = bt * h / a;
    if symm_transform {
        (1.0 - reg.v, -reg.da, -reg.db)
    } else {
        (reg.v, reg.da, reg.db)
    }
}
1808
1809pub fn beta_logistic_inverse_link_jet(
1819 eta: f64,
1820 log_shape_center: f64,
1821 epsilon: f64,
1822) -> InverseLinkJet {
1823 let logistic = logistic_uwith_derivatives(eta);
1824 let a = (log_shape_center - epsilon).exp();
1825 let b = (log_shape_center + epsilon).exp();
1826 let mu = beta_reg_logistic(a, b, logistic);
1827 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1828 let d1 = log_d1.exp();
1829 let t = a * logistic.one_minus_u - b * logistic.u;
1830 let d2 = d1 * t;
1831 let d3 = d1 * (t * t - (a + b) * logistic.du);
1832 InverseLinkJet { mu, d1, d2, d3 }
1833}
1834
1835pub fn beta_logistic_inverse_link_pdfthird_derivative(
1836 eta: f64,
1837 log_shape_center: f64,
1838 epsilon: f64,
1839) -> f64 {
1840 let logistic = logistic_uwith_derivatives(eta);
1863 let a = (log_shape_center - epsilon).exp();
1864 let b = (log_shape_center + epsilon).exp();
1865 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1866 let d1 = log_d1.exp();
1867 let c = a + b;
1868 let t = a * logistic.one_minus_u - b * logistic.u;
1869 let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1870 d1 * (t * t * t - 3.0 * c * t * logistic.du - c * u2)
1871}
1872
1873pub fn beta_logistic_inverse_link_pdffourth_derivative(
1881 eta: f64,
1882 log_shape_center: f64,
1883 epsilon: f64,
1884) -> f64 {
1885 let logistic = logistic_uwith_derivatives(eta);
1886 let a = (log_shape_center - epsilon).exp();
1887 let b = (log_shape_center + epsilon).exp();
1888 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1889 let d1 = log_d1.exp();
1890 let c = a + b;
1891 let t = a * logistic.one_minus_u - b * logistic.u;
1892 let u2 = logistic.du * (logistic.one_minus_u - logistic.u);
1893 let u3 = u2 * (logistic.one_minus_u - logistic.u) - 2.0 * logistic.du * logistic.du;
1894 let t2 = t * t;
1895 d1 * (t2 * t2 - 6.0 * c * t2 * logistic.du - 4.0 * c * t * u2
1896 + 3.0 * c * c * logistic.du * logistic.du
1897 - c * u3)
1898}
1899
1900pub fn beta_logistic_inverse_link_jetwith_param_partials(
1901 eta: f64,
1902 log_shape_center: f64,
1903 epsilon: f64,
1904) -> SasJetWithParamPartials {
1905 let logistic = logistic_uwith_derivatives(eta);
1906 let a = (log_shape_center - epsilon).exp();
1907 let b = (log_shape_center + epsilon).exp();
1908 let (mu, dmu_da, dmu_db) = beta_reg_with_shape_partials_logistic(a, b, logistic);
1909 let dmu_dlog_shape_center = a * dmu_da + b * dmu_db;
1910 let dmu_depsilon = -a * dmu_da + b * dmu_db;
1911 let log_d1 = beta_logistic_log_d1(a, b, logistic);
1912 let d1 = log_d1.exp();
1913 let t = a * logistic.one_minus_u - b * logistic.u;
1914 let d2 = d1 * t;
1915 let k = t * t - (a + b) * logistic.du;
1916 let d3 = d1 * k;
1917 let jet = InverseLinkJet { mu, d1, d2, d3 };
1918
1919 let psi_a = digamma(a);
1920 let psi_b = digamma(b);
1921 let psi_ab = digamma(a + b);
1922 let la = logistic.ln_u - psi_a + psi_ab;
1923 let lb = logistic.ln_one_minus_u - psi_b + psi_ab;
1924
1925 let partials_for = |a_p: f64, b_p: f64, dmu: f64| -> InverseLinkJet {
1926 let logd1_p = a_p * la + b_p * lb;
1927 let d1_p = d1 * logd1_p;
1928 let t_p = a_p * logistic.one_minus_u - b_p * logistic.u;
1929 let d2_p = d1_p * t + d1 * t_p;
1930 let k_p = 2.0 * t * t_p - (a_p + b_p) * logistic.du;
1931 let d3_p = d1_p * k + d1 * k_p;
1932 InverseLinkJet {
1933 mu: dmu,
1934 d1: d1_p,
1935 d2: d2_p,
1936 d3: d3_p,
1937 }
1938 };
1939 let djet_dlog_shape_center = partials_for(a, b, dmu_dlog_shape_center);
1940 let djet_depsilon = partials_for(-a, b, dmu_depsilon);
1941 SasJetWithParamPartials {
1942 jet,
1943 djet_depsilon,
1944 djet_dlog_delta: djet_dlog_shape_center,
1945 }
1946}
1947
1948pub fn sas_inverse_link_jet(eta: f64, epsilon: f64, log_delta: f64) -> InverseLinkJet {
1952 let delta_id = sas_delta_from_raw_log_delta(log_delta);
1953 if epsilon.abs() < 1e-12 && (delta_id - 1.0).abs() < 1e-12 {
1954 return component_inverse_link_jet(LinkComponent::Probit, eta);
1955 }
1956 let e = if eta.is_finite() { eta } else { 0.0 };
1957 let a = e.asinh();
1958 let delta = delta_id;
1959 let u_raw = delta * a - epsilon;
1960 let u = tanh_bound(u_raw, SAS_U_CLAMP);
1961 let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
1962 let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
1963 let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
1964 let s = u.sinh();
1965 let c = u.cosh();
1966 let z = s;
1967 let q = e.hypot(1.0);
1968 let inv_q = 1.0 / q;
1969 let inv_q2 = inv_q * inv_q;
1970 let inv_q3 = inv_q2 * inv_q;
1971 let inv_q5 = inv_q3 * inv_q2;
1972 let r1 = delta * inv_q;
1973 let r2 = -delta * e * inv_q3;
1974 let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
1975 let u1 = g1 * r1;
1976 let u2 = g2 * r1 * r1 + g1 * r2;
1977 let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
1978 let z1 = c * u1;
1979 let z2 = s * u1 * u1 + c * u2;
1980 let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
1981 let base = probit_jet(z);
1982 chain_inverse_link_jet(base, z1, z2, z3)
1983}
1984
/// Fourth `eta`-derivative of the sinh-arcsinh inverse-link mean.
///
/// The mean is the probit function of `z = sinh(u)`, where
/// `u = tanh_bound(delta * asinh(eta) - epsilon, SAS_U_CLAMP)`. The function
/// expands the fourth-order chain rule through each layer in turn and
/// flushes a subnormal result to zero via `canonicalzero`.
///
/// NOTE(review): for `|eta|` beyond roughly `1e154`, `e * e` overflows while
/// the inverse powers of `q` underflow to zero, so `r3`/`r4` look like they
/// become `inf * 0` — confirm whether callers can reach that range.
pub fn sas_inverse_link_pdfthird_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    // A non-finite predictor is evaluated as if it were zero.
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let delta = sas_delta_from_raw_log_delta(log_delta);
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    // g1..g4: derivatives of the bounding map at `u_raw` (per the helper names;
    // their definitions are outside this chunk).
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    // Presumably the normal CDF at `z` and its first three derivatives — confirm
    // against `probit_jet`.
    let base = probit_jet(z);
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    // r1..r4: first four eta-derivatives of `delta * asinh(eta)`.
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    // u1..u4: eta-derivatives of `u`, chain rule through the bounding map.
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    // z1..z4: eta-derivatives of `z = sinh(u)`; sinh/cosh alternate as the
    // outer derivatives.
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    // Outermost layer: the same fourth-order chain-rule expansion with the
    // probit derivatives as outer terms.
    let base4 = probit_pdfthird_derivative(z);
    let out = base4 * z1.powi(4)
        + 6.0 * base.d3 * z1 * z1 * z2
        + 3.0 * base.d2 * z2 * z2
        + 4.0 * base.d2 * z1 * z3
        + base.d1 * z4;
    canonicalzero(out)
}
2065
/// Fifth `eta`-derivative of the sinh-arcsinh inverse-link mean.
///
/// Same layered structure as the third-order helper — probit of
/// `z = sinh(u)` with `u = tanh_bound(delta * asinh(eta) - epsilon,
/// SAS_U_CLAMP)` — carried one order further. A subnormal result is flushed
/// to zero via `canonicalzero`.
///
/// NOTE(review): for `|eta|` beyond roughly `1e154`, `e * e` overflows while
/// the inverse powers of `q` underflow to zero, so `r3`..`r5` look like they
/// become `inf * 0` — confirm whether callers can reach that range.
pub fn sas_inverse_link_pdffourth_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    // A non-finite predictor is evaluated as if it were zero.
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let delta = sas_delta_from_raw_log_delta(log_delta);
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    // g1..g5: derivatives of the bounding map at `u_raw` (per the helper names;
    // their definitions are outside this chunk).
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let g5 = tanh_bound_d5(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;

    // Outer probit derivatives at `z`: `base` supplies the lower orders,
    // `phi3`/`phi4` the two highest.
    let base = probit_jet(z);
    let phi3 = probit_pdfthird_derivative(z);
    let phi4 = probit_pdffourth_derivative(z);
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    let inv_q9 = inv_q7 * inv_q2;

    // r1..r5: first five eta-derivatives of `delta * asinh(eta)`.
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    let r5 = delta * (9.0 - 72.0 * e * e + 24.0 * e * e * e * e) * inv_q9;

    // u1..u5: eta-derivatives of `u`, chain rule through the bounding map.
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    let u5 = g5 * r1.powi(5)
        + 10.0 * g4 * r1 * r1 * r1 * r2
        + 15.0 * g3 * r1 * r2 * r2
        + 10.0 * g3 * r1 * r1 * r3
        + 10.0 * g2 * r2 * r3
        + 5.0 * g2 * r1 * r4
        + g1 * r5;

    // z1..z5: eta-derivatives of `z = sinh(u)`; sinh/cosh alternate as the
    // outer derivatives.
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    let z5 = c * u1.powi(5)
        + 10.0 * s * u1 * u1 * u1 * u2
        + 15.0 * c * u1 * u2 * u2
        + 10.0 * c * u1 * u1 * u3
        + 10.0 * s * u2 * u3
        + 5.0 * s * u1 * u4
        + c * u5;

    // Outermost layer: the same fifth-order chain-rule expansion with the
    // probit derivatives as outer terms.
    let out = phi4 * z1.powi(5)
        + 10.0 * phi3 * z1 * z1 * z1 * z2
        + 15.0 * base.d3 * z1 * z2 * z2
        + 10.0 * base.d3 * z1 * z1 * z3
        + 10.0 * base.d2 * z2 * z3
        + 5.0 * base.d2 * z1 * z4
        + base.d1 * z5;
    canonicalzero(out)
}
2159
/// Sinh-arcsinh jet together with its partial derivatives in `epsilon` and
/// in the raw `log_delta` parameter.
///
/// The jet is the probit jet of `z = sinh(u)` with
/// `u = tanh_bound(delta * asinh(eta) - epsilon, SAS_U_CLAMP)`, pushed
/// through the chain rule. Each parameter partial differentiates that whole
/// chain once more with respect to the parameter. The `log_delta` partials
/// are taken with respect to the raw input, via the derivative returned by
/// `sas_effective_log_delta`.
pub fn sas_inverse_link_jetwith_param_partials(
    eta: f64,
    epsilon: f64,
    log_delta: f64,
) -> SasJetWithParamPartials {
    // A non-finite predictor is evaluated as if it were zero.
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    // Effective log-delta and its derivative w.r.t. the raw parameter.
    let (ld_eff, dld_eff_draw) = sas_effective_log_delta(log_delta);
    let delta = ld_eff.exp();
    let ddelta_draw = delta * dld_eff_draw;
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    // g1..g4: derivatives of the bounding map at `u_raw` (per the helper names;
    // their definitions are outside this chunk).
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    // a1..a3: first three eta-derivatives of `asinh(eta)`; r1..r3 scale them by delta.
    let a1 = inv_q;
    let a2 = -e * inv_q3;
    let a3 = (2.0 * e * e - 1.0) * inv_q5;
    let r1 = delta * a1;
    let r2 = delta * a2;
    let r3 = delta * a3;
    // Eta-derivatives of `u`, then of `z = sinh(u)`.
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;

    let base = probit_jet(z);
    let jet = chain_inverse_link_jet(base, z1, z2, z3);

    // Given the parameter-derivatives of `u` and of its eta-derivatives
    // (`u_t`, `u1_t`, `u2_t`, `u3_t`), produce the parameter-derivative of the
    // whole jet: first through `sinh`, then through the probit layer.
    let param_partials = |u_t: f64, u1_t: f64, u2_t: f64, u3_t: f64| -> InverseLinkJet {
        let z_t = c * u_t;
        let z1_t = s * u_t * u1 + c * u1_t;
        let z2_t = c * u_t * u1 * u1 + 2.0 * s * u1 * u1_t + s * u_t * u2 + c * u2_t;
        let z3_t = s * u_t * u1 * u1 * u1
            + 3.0 * c * u1 * u1 * u1_t
            + 3.0 * c * u_t * u1 * u2
            + 3.0 * s * (u1_t * u2 + u1 * u2_t)
            + s * u_t * u3
            + c * u3_t;

        InverseLinkJet {
            mu: base.d1 * z_t,
            d1: base.d2 * z_t * z1 + base.d1 * z1_t,
            d2: base.d3 * z_t * z1 * z1
                + 2.0 * base.d2 * z1 * z1_t
                + base.d2 * z_t * z2
                + base.d1 * z2_t,
            d3: probit_pdfthird_derivative(z) * z_t * z1.powi(3)
                + 3.0 * base.d3 * z1 * z1 * z1_t
                + 3.0 * base.d3 * z_t * z1 * z2
                + 3.0 * base.d2 * (z1_t * z2 + z1 * z2_t)
                + base.d2 * z_t * z3
                + base.d1 * z3_t,
        }
    };

    // Epsilon enters `u_raw` only as `-epsilon`, so r1..r3 do not depend on it.
    let rt_eps = -1.0;
    let r1t_eps = 0.0;
    let r2t_eps = 0.0;
    let r3t_eps = 0.0;
    let u_eps = g1 * rt_eps;
    let u1_eps = g2 * rt_eps * r1 + g1 * r1t_eps;
    let u2_eps = g3 * rt_eps * r1 * r1 + 2.0 * g2 * r1 * r1t_eps + g2 * rt_eps * r2 + g1 * r2t_eps;
    let u3_eps = g4 * rt_eps * r1 * r1 * r1
        + 3.0 * g3 * r1 * r1 * r1t_eps
        + 3.0 * g3 * rt_eps * r1 * r2
        + 3.0 * g2 * (r1t_eps * r2 + r1 * r2t_eps)
        + g2 * rt_eps * r3
        + g1 * r3t_eps;
    let djet_depsilon = param_partials(u_eps, u1_eps, u2_eps, u3_eps);

    // Raw log-delta scales `asinh(eta)` and each of its eta-derivatives.
    let rt_ld = ddelta_draw * a;
    let r1t_ld = ddelta_draw * a1;
    let r2t_ld = ddelta_draw * a2;
    let r3t_ld = ddelta_draw * a3;
    let u_ld = g1 * rt_ld;
    let u1_ld = g2 * rt_ld * r1 + g1 * r1t_ld;
    let u2_ld = g3 * rt_ld * r1 * r1 + 2.0 * g2 * r1 * r1t_ld + g2 * rt_ld * r2 + g1 * r2t_ld;
    let u3_ld = g4 * rt_ld * r1 * r1 * r1
        + 3.0 * g3 * r1 * r1 * r1t_ld
        + 3.0 * g3 * rt_ld * r1 * r2
        + 3.0 * g2 * (r1t_ld * r2 + r1 * r2t_ld)
        + g2 * rt_ld * r3
        + g1 * r3t_ld;
    let djet_dlog_delta = param_partials(u_ld, u1_ld, u2_ld, u3_ld);

    SasJetWithParamPartials {
        jet,
        djet_depsilon,
        djet_dlog_delta,
    }
}
2267
2268#[cfg(test)]
2269mod tests {
2270 use super::*;
2271 use gam_problem::{InverseLink, LikelihoodSpec, LinkComponent, MixtureLinkSpec, SasLinkState};
2272
2273 #[test]
2274 fn softmax_jacobian_matchesfd() {
2275 let rho = Array1::from_vec(vec![0.7, -1.2, 0.4]);
2276 let (pi, jac) = softmaxwith_jacobian_last_fixedzero(&rho);
2277 let h = 1e-6;
2278 for j in 0..rho.len() {
2279 let mut rp = rho.clone();
2280 rp[j] += h;
2281 let mut rm = rho.clone();
2282 rm[j] -= h;
2283 let pp = softmax_last_fixedzero(&rp);
2284 let pm = softmax_last_fixedzero(&rm);
2285 let fd = (&pp - &pm).mapv(|v| v / (2.0 * h));
2286 for k in 0..pi.len() {
2287 let err = (jac[[k, j]] - fd[k]).abs();
2288 assert_eq!(
2289 jac[[k, j]].signum(),
2290 fd[k].signum(),
2291 "jac sign mismatch at ({k},{j}): analytic={} fd={}",
2292 jac[[k, j]],
2293 fd[k]
2294 );
2295 assert!(err < 5e-6, "jac mismatch at ({k},{j}): err={err:e}");
2296 }
2297 }
2298 }
2299
2300 #[test]
2301 fn mixture_jet_rho_partials_matchfd() {
2302 let spec = MixtureLinkSpec {
2303 components: vec![
2304 LinkComponent::Probit,
2305 LinkComponent::Logit,
2306 LinkComponent::CLogLog,
2307 LinkComponent::Cauchit,
2308 ],
2309 initial_rho: Array1::from_vec(vec![0.3, -0.6, 0.2]),
2310 };
2311 let state = state_fromspec(&spec).expect("state");
2312 let eta = 0.35;
2313 let out = mixture_inverse_link_jetwith_rho_partials(&state, eta);
2314 let h = 1e-6;
2315 for j in 0..state.rho.len() {
2316 let mut rp = state.rho.clone();
2317 rp[j] += h;
2318 let sp = MixtureLinkSpec {
2319 components: state.components.clone(),
2320 initial_rho: rp,
2321 };
2322 let jp = mixture_inverse_link_jet(&state_fromspec(&sp).expect("sp"), eta);
2323 let mut rm = state.rho.clone();
2324 rm[j] -= h;
2325 let sm = MixtureLinkSpec {
2326 components: state.components.clone(),
2327 initial_rho: rm,
2328 };
2329 let jm = mixture_inverse_link_jet(&state_fromspec(&sm).expect("sm"), eta);
2330 let fd = InverseLinkJet {
2331 mu: (jp.mu - jm.mu) / (2.0 * h),
2332 d1: (jp.d1 - jm.d1) / (2.0 * h),
2333 d2: (jp.d2 - jm.d2) / (2.0 * h),
2334 d3: (jp.d3 - jm.d3) / (2.0 * h),
2335 };
2336 let an = out.djet_drho[j];
2337 assert_eq!(an.mu.signum(), fd.mu.signum());
2338 assert_eq!(an.d1.signum(), fd.d1.signum());
2339 assert_eq!(an.d2.signum(), fd.d2.signum());
2340 assert_eq!(an.d3.signum(), fd.d3.signum());
2341 assert!((an.mu - fd.mu).abs() < 1e-6);
2342 assert!((an.d1 - fd.d1).abs() < 1e-6);
2343 assert!((an.d2 - fd.d2).abs() < 1e-6);
2344 assert!((an.d3 - fd.d3).abs() < 1e-6);
2345 }
2346 }
2347
2348 #[test]
2349 fn sas_param_partials_matchfd() {
2350 let eta = 0.37;
2351 let epsilon = -0.12;
2352 let log_delta = 0.21;
2353 let out = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2354 let h = 1e-6;
2355
2356 let ep_p = sas_inverse_link_jet(eta, epsilon + h, log_delta);
2357 let ep_m = sas_inverse_link_jet(eta, epsilon - h, log_delta);
2358 let fd_ep = InverseLinkJet {
2359 mu: (ep_p.mu - ep_m.mu) / (2.0 * h),
2360 d1: (ep_p.d1 - ep_m.d1) / (2.0 * h),
2361 d2: (ep_p.d2 - ep_m.d2) / (2.0 * h),
2362 d3: (ep_p.d3 - ep_m.d3) / (2.0 * h),
2363 };
2364 assert_eq!(out.djet_depsilon.mu.signum(), fd_ep.mu.signum());
2365 assert_eq!(out.djet_depsilon.d1.signum(), fd_ep.d1.signum());
2366 assert_eq!(out.djet_depsilon.d2.signum(), fd_ep.d2.signum());
2367 assert_eq!(out.djet_depsilon.d3.signum(), fd_ep.d3.signum());
2368 assert!((out.djet_depsilon.mu - fd_ep.mu).abs() < 5e-5);
2369 assert!((out.djet_depsilon.d1 - fd_ep.d1).abs() < 5e-5);
2370 assert!((out.djet_depsilon.d2 - fd_ep.d2).abs() < 5e-5);
2371 assert!((out.djet_depsilon.d3 - fd_ep.d3).abs() < 5e-4);
2372
2373 let ld_p = sas_inverse_link_jet(eta, epsilon, log_delta + h);
2374 let ld_m = sas_inverse_link_jet(eta, epsilon, log_delta - h);
2375 let fd_ld = InverseLinkJet {
2376 mu: (ld_p.mu - ld_m.mu) / (2.0 * h),
2377 d1: (ld_p.d1 - ld_m.d1) / (2.0 * h),
2378 d2: (ld_p.d2 - ld_m.d2) / (2.0 * h),
2379 d3: (ld_p.d3 - ld_m.d3) / (2.0 * h),
2380 };
2381 assert_eq!(out.djet_dlog_delta.mu.signum(), fd_ld.mu.signum());
2382 assert_eq!(out.djet_dlog_delta.d1.signum(), fd_ld.d1.signum());
2383 assert_eq!(out.djet_dlog_delta.d2.signum(), fd_ld.d2.signum());
2384 assert_eq!(out.djet_dlog_delta.d3.signum(), fd_ld.d3.signum());
2385 assert!((out.djet_dlog_delta.mu - fd_ld.mu).abs() < 5e-5);
2386 assert!((out.djet_dlog_delta.d1 - fd_ld.d1).abs() < 5e-5);
2387 assert!((out.djet_dlog_delta.d2 - fd_ld.d2).abs() < 5e-5);
2388 assert!((out.djet_dlog_delta.d3 - fd_ld.d3).abs() < 5e-4);
2389 }
2390
2391 #[test]
2392 fn sas_jet_extreme_inputs_stay_finite() {
2393 let cases = [
2394 (-1e6, 0.0, 0.0),
2395 (1e6, 0.0, 0.0),
2396 (3.0, 12.0, 12.0),
2397 (-3.0, -12.0, -12.0),
2398 (0.5, 40.0, 10.0),
2399 (0.5, -40.0, -10.0),
2400 ];
2401 for (eta, eps, log_delta) in cases {
2402 let j = sas_inverse_link_jet(eta, eps, log_delta);
2403 assert!(j.mu.is_finite());
2404 assert!(j.d1.is_finite());
2405 assert!(j.d2.is_finite());
2406 assert!(j.d3.is_finite());
2407 let p = sas_inverse_link_jetwith_param_partials(eta, eps, log_delta);
2408 assert!(p.djet_depsilon.mu.is_finite());
2409 assert!(p.djet_depsilon.d1.is_finite());
2410 assert!(p.djet_depsilon.d2.is_finite());
2411 assert!(p.djet_depsilon.d3.is_finite());
2412 assert!(p.djet_dlog_delta.mu.is_finite());
2413 assert!(p.djet_dlog_delta.d1.is_finite());
2414 assert!(p.djet_dlog_delta.d2.is_finite());
2415 assert!(p.djet_dlog_delta.d3.is_finite());
2416 }
2417 }
2418
2419 #[test]
2420 fn sas_param_partials_remain_finite_in_extreme_region() {
2421 let eta = 10.0;
2422 let epsilon = -60.0;
2423 let log_delta = 40.0;
2424 let j = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
2425 assert!(j.djet_depsilon.mu.is_finite());
2426 assert!(j.djet_depsilon.d1.is_finite());
2427 assert!(j.djet_depsilon.d2.is_finite());
2428 assert!(j.djet_depsilon.d3.is_finite());
2429 assert!(j.djet_dlog_delta.mu.is_finite());
2430 assert!(j.djet_dlog_delta.d1.is_finite());
2431 assert!(j.djet_dlog_delta.d2.is_finite());
2432 assert!(j.djet_dlog_delta.d3.is_finite());
2433 }
2434
2435 #[test]
2436 fn sas_eta_jets_matchfd() {
2437 let eta = -0.43;
2438 let epsilon = 0.27;
2439 let log_delta = -0.31;
2440 let h = 1e-5;
2441 let j0 = sas_inverse_link_jet(eta, epsilon, log_delta);
2442 let jp = sas_inverse_link_jet(eta + h, epsilon, log_delta);
2443 let jm = sas_inverse_link_jet(eta - h, epsilon, log_delta);
2444 let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2445 let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2446 let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2447 assert_eq!(j0.d1.signum(), d1fd.signum());
2448 assert_eq!(j0.d2.signum(), d2fd.signum());
2449 assert_eq!(j0.d3.signum(), d3fd.signum());
2450 assert!((j0.d1 - d1fd).abs() < 5e-5);
2451 assert!((j0.d2 - d2fd).abs() < 2e-4);
2452 assert!((j0.d3 - d3fd).abs() < 1e-3);
2453 }
2454
2455 #[test]
2456 fn family_dispatch_resolves_parameterized_links_from_spec() {
2457 let sas_state = sas_link_state_from_raw(0.0, 0.0).expect("sas state");
2462 let sas_spec = gam_problem::LikelihoodSpec {
2463 response: gam_problem::ResponseFamily::Binomial,
2464 link: InverseLink::Sas(sas_state),
2465 };
2466 let sas_jet = inverse_link_jet_for_family(&sas_spec, 0.1).expect("sas jet");
2467 assert!(sas_jet.mu.is_finite());
2468 assert!(sas_jet.d1.is_finite());
2469
2470 let mix_state = MixtureLinkState {
2471 components: vec![LinkComponent::Logit, LinkComponent::Probit],
2472 rho: ndarray::array![0.0],
2473 pi: ndarray::array![0.5, 0.5],
2474 };
2475 let mix_spec = gam_problem::LikelihoodSpec {
2476 response: gam_problem::ResponseFamily::Binomial,
2477 link: InverseLink::Mixture(mix_state),
2478 };
2479 let mix_jet = inverse_link_jet_for_family(&mix_spec, 0.1).expect("mix jet");
2480 assert!(mix_jet.mu.is_finite());
2481 assert!(mix_jet.d1.is_finite());
2482 }
2483
2484 #[test]
2485 fn beta_logistic_reduces_to_logit_at_delta0_epsilon0() {
2486 let etas = [-40.0, -30.0, -5.0, 0.42, 5.0, 30.0, 40.0];
2487 for eta in etas {
2488 let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2489 let expected_mu = gam_linalg::utils::stable_logistic(eta);
2490 let expected_d1 = (-gam_linalg::utils::stable_softplus(-eta)
2491 - gam_linalg::utils::stable_softplus(eta))
2492 .exp();
2493 assert!(
2494 (j_bl.mu - expected_mu).abs() <= 1e-15 * expected_mu.abs().max(1.0),
2495 "mu mismatch at eta={eta}: got {}, expected {}",
2496 j_bl.mu,
2497 expected_mu
2498 );
2499 assert!(
2500 (j_bl.d1 - expected_d1).abs() <= 1e-12 * expected_d1.abs().max(f64::MIN_POSITIVE),
2501 "d1 mismatch at eta={eta}: got {}, expected {}",
2502 j_bl.d1,
2503 expected_d1
2504 );
2505 assert!(j_bl.d1 > 0.0, "d1 should stay positive at eta={eta}");
2506 }
2507
2508 let eta = 0.42;
2509 let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
2510 let j_logit = component_inverse_link_jet(LinkComponent::Logit, eta);
2511 assert!((j_bl.d2 - j_logit.d2).abs() < 1e-10);
2512 assert!((j_bl.d3 - j_logit.d3).abs() < 1e-10);
2513 }
2514
2515 #[test]
2516 fn beta_logistic_eta_jets_matchfd() {
2517 let eta = -0.31;
2518 let delta = 0.27;
2519 let epsilon = -0.19;
2520 let h = 1e-5;
2521 let j0 = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2522 let jp = beta_logistic_inverse_link_jet(eta + h, delta, epsilon);
2523 let jm = beta_logistic_inverse_link_jet(eta - h, delta, epsilon);
2524 let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2525 let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2526 let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2527 assert_eq!(j0.d1.signum(), d1fd.signum());
2528 assert_eq!(j0.d2.signum(), d2fd.signum());
2529 assert_eq!(j0.d3.signum(), d3fd.signum());
2530 assert!((j0.d1 - d1fd).abs() < 5e-5);
2531 assert!((j0.d2 - d2fd).abs() < 5e-5);
2532 assert!((j0.d3 - d3fd).abs() < 2e-4);
2533 }
2534
2535 #[test]
2536 fn standard_kernel_structs_match_component_jets() {
2537 let eta = 0.73;
2538 assert_eq!(
2539 ProbitLinkKernel.jet(eta).expect("probit"),
2540 component_inverse_link_jet(LinkComponent::Probit, eta)
2541 );
2542 assert_eq!(
2543 LogitLinkKernel.jet(eta).expect("logit"),
2544 component_inverse_link_jet(LinkComponent::Logit, eta)
2545 );
2546 assert_eq!(
2547 CLogLogLinkKernel.jet(eta).expect("cloglog"),
2548 component_inverse_link_jet(LinkComponent::CLogLog, eta)
2549 );
2550 assert_eq!(
2551 LogLogLinkKernel.jet(eta).expect("loglog"),
2552 component_inverse_link_jet(LinkComponent::LogLog, eta)
2553 );
2554 assert_eq!(
2555 CauchitLinkKernel.jet(eta).expect("cauchit"),
2556 component_inverse_link_jet(LinkComponent::Cauchit, eta)
2557 );
2558 }
2559
2560 #[test]
2561 fn all_component_eta_jets_matchfd() {
2562 let components = [
2563 LinkComponent::Logit,
2564 LinkComponent::Probit,
2565 LinkComponent::CLogLog,
2566 LinkComponent::LogLog,
2567 LinkComponent::Cauchit,
2568 ];
2569 let points = [-3.0, -1.1, -0.2, 0.0, 0.7, 1.8, 3.2];
2570 let h = 1e-5;
2571 for c in components {
2572 for &eta in &points {
2573 let j0 = component_inverse_link_jet(c, eta);
2574 let jp = component_inverse_link_jet(c, eta + h);
2575 let jm = component_inverse_link_jet(c, eta - h);
2576 let d1fd = (jp.mu - jm.mu) / (2.0 * h);
2577 let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
2578 let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
2579 let d1_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2580 1.2e-4
2581 } else {
2582 5e-5
2583 };
2584 let d2_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2585 4e-4
2586 } else {
2587 1.2e-4
2588 };
2589 let d3_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
2590 1.2e-3
2591 } else {
2592 4e-4
2593 };
2594 if j0.d1.abs().max(d1fd.abs()) > 1e-10 {
2595 assert_eq!(
2596 j0.d1.signum(),
2597 d1fd.signum(),
2598 "d1 sign mismatch for {c:?} eta={eta}"
2599 );
2600 }
2601 if j0.d2.abs().max(d2fd.abs()) > 1e-10 {
2602 assert_eq!(
2603 j0.d2.signum(),
2604 d2fd.signum(),
2605 "d2 sign mismatch for {c:?} eta={eta}: analytic={} fd={}",
2606 j0.d2,
2607 d2fd
2608 );
2609 }
2610 if j0.d3.abs().max(d3fd.abs()) > 1e-10 {
2611 assert_eq!(
2612 j0.d3.signum(),
2613 d3fd.signum(),
2614 "d3 sign mismatch for {c:?} eta={eta}"
2615 );
2616 }
2617 assert!(
2618 (j0.d1 - d1fd).abs() < d1_tol,
2619 "d1 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2620 j0.d1,
2621 d1fd
2622 );
2623 assert!(
2624 (j0.d2 - d2fd).abs() < d2_tol,
2625 "d2 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2626 j0.d2,
2627 d2fd
2628 );
2629 assert!(
2630 (j0.d3 - d3fd).abs() < d3_tol,
2631 "d3 mismatch for {c:?} eta={eta}: analytic={} fd={}",
2632 j0.d3,
2633 d3fd
2634 );
2635 }
2636 }
2637 }
2638
2639 #[test]
2640 fn sas_center_matches_probit_at_delta1_epsilon0() {
2641 let etas = [-3.0, -1.2, -0.3, 0.0, 0.4, 1.7, 3.0];
2642 for eta in etas {
2643 let sas = sas_inverse_link_jet(eta, 0.0, 0.0);
2644 let probit = ProbitLinkKernel.jet(eta).expect("probit");
2645 assert!(
2648 (sas.mu - probit.mu).abs() < 6e-4,
2649 "mu mismatch at eta={eta}"
2650 );
2651 assert!(
2652 (sas.d1 - probit.d1).abs() < 6e-4,
2653 "d1 mismatch at eta={eta}"
2654 );
2655 assert!(
2656 (sas.d2 - probit.d2).abs() < 2e-3,
2657 "d2 mismatch at eta={eta}"
2658 );
2659 assert!(
2660 (sas.d3 - probit.d3).abs() < 4e-3,
2661 "d3 mismatch at eta={eta}"
2662 );
2663 }
2664 }
2665
2666 #[test]
2667 fn beta_logistic_param_partials_matchfd() {
2668 let eta = -0.41;
2669 let delta = 0.23;
2670 let epsilon = -0.17;
2671 let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2672 let h = 1e-6;
2673
2674 let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
2675 let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
2676 let fd_delta = InverseLinkJet {
2677 mu: (dp.mu - dm.mu) / (2.0 * h),
2678 d1: (dp.d1 - dm.d1) / (2.0 * h),
2679 d2: (dp.d2 - dm.d2) / (2.0 * h),
2680 d3: (dp.d3 - dm.d3) / (2.0 * h),
2681 };
2682 assert_eq!(out.djet_dlog_delta.mu.signum(), fd_delta.mu.signum());
2683 assert_eq!(out.djet_dlog_delta.d1.signum(), fd_delta.d1.signum());
2684 assert_eq!(out.djet_dlog_delta.d2.signum(), fd_delta.d2.signum());
2685 assert_eq!(out.djet_dlog_delta.d3.signum(), fd_delta.d3.signum());
2686 assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
2687 assert!((out.djet_dlog_delta.d1 - fd_delta.d1).abs() < 5e-5);
2688 assert!((out.djet_dlog_delta.d2 - fd_delta.d2).abs() < 1.2e-4);
2689 assert!((out.djet_dlog_delta.d3 - fd_delta.d3).abs() < 4e-4);
2690
2691 let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
2692 let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
2693 let fd_epsilon = InverseLinkJet {
2694 mu: (ep.mu - em.mu) / (2.0 * h),
2695 d1: (ep.d1 - em.d1) / (2.0 * h),
2696 d2: (ep.d2 - em.d2) / (2.0 * h),
2697 d3: (ep.d3 - em.d3) / (2.0 * h),
2698 };
2699 assert_eq!(out.djet_depsilon.mu.signum(), fd_epsilon.mu.signum());
2700 assert_eq!(out.djet_depsilon.d1.signum(), fd_epsilon.d1.signum());
2701 assert_eq!(out.djet_depsilon.d2.signum(), fd_epsilon.d2.signum());
2702 assert_eq!(out.djet_depsilon.d3.signum(), fd_epsilon.d3.signum());
2703 assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
2704 assert!((out.djet_depsilon.d1 - fd_epsilon.d1).abs() < 5e-5);
2705 assert!((out.djet_depsilon.d2 - fd_epsilon.d2).abs() < 1.2e-4);
2706 assert!((out.djet_depsilon.d3 - fd_epsilon.d3).abs() < 4e-4);
2707 }
2708
2709 #[test]
2710 fn beta_logistic_left_tail_uses_unclamped_log_space() {
2711 let eta = -40.0_f64;
2712 let delta = 0.2_f64;
2713 let epsilon = -0.1_f64;
2714 let a = (delta - epsilon).exp();
2715 let b = (delta + epsilon).exp();
2716 let expected_mu = beta_reg(a, b, eta.exp());
2717 let out = beta_logistic_inverse_link_jet(eta, delta, epsilon);
2718
2719 assert!(
2720 (out.mu - expected_mu).abs() <= 1e-12 * expected_mu.abs().max(f64::MIN_POSITIVE),
2721 "left-tail mu mismatch: got {}, expected {}",
2722 out.mu,
2723 expected_mu
2724 );
2725 assert!(out.d1 > 0.0);
2726 assert!(out.d2 > 0.0);
2727 assert!(out.d3 > 0.0);
2728 assert!(out.d1 < 1e-20);
2729
2730 let partials = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
2731 assert!(partials.jet.d1 > 0.0);
2732 assert!(partials.jet.d2 > 0.0);
2733 assert!(partials.jet.d3 > 0.0);
2734 assert!(partials.djet_dlog_delta.d1.is_finite());
2735 assert!(partials.djet_depsilon.d1.is_finite());
2736 }
2737
2738 #[test]
2739 fn beta_logistic_mu_is_symmetric_in_logistic_tails() {
2740 let delta = 0.2;
2741 let epsilon = -0.35;
2742 let etas = [-40.0, -30.0, -5.0, -0.42, 0.0, 0.42, 5.0, 30.0, 40.0];
2743 for eta in etas {
2744 let left = beta_logistic_inverse_link_jet(eta, delta, epsilon).mu;
2745 let right = 1.0 - beta_logistic_inverse_link_jet(-eta, delta, -epsilon).mu;
2746 assert!(
2747 (left - right).abs() <= 1e-14,
2748 "symmetry mismatch at eta={eta}: left={left}, right={right}"
2749 );
2750 }
2751 }
2752
2753 #[test]
2754 fn inverse_link_pdfthird_derivative_matches_d3_finite_difference() {
2755 let sas = InverseLink::Sas(sas_link_state_from_raw(-0.25, 0.35).expect("sas state"));
2756 let beta_logistic = InverseLink::BetaLogistic(SasLinkState {
2757 epsilon: 0.18,
2758 log_delta: -0.22,
2759 delta: (-0.22_f64).exp(),
2760 });
2761 let mixture = InverseLink::Mixture(
2762 state_fromspec(&MixtureLinkSpec {
2763 components: vec![
2764 LinkComponent::Probit,
2765 LinkComponent::Logit,
2766 LinkComponent::CLogLog,
2767 LinkComponent::Cauchit,
2768 ],
2769 initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
2770 })
2771 .expect("mixture state"),
2772 );
2773 let links = [
2774 InverseLink::Standard(StandardLink::Probit),
2775 InverseLink::Standard(StandardLink::Logit),
2776 InverseLink::Standard(StandardLink::CLogLog),
2777 sas,
2778 beta_logistic,
2779 mixture,
2780 ];
2781 let etas = [-1.1, -0.2, 0.6];
2782 let h = 1e-5;
2783
2784 for link in &links {
2785 for &eta in &etas {
2786 let jp = inverse_link_jet_for_inverse_link(link, eta + h).expect("jet+");
2787 let jm = inverse_link_jet_for_inverse_link(link, eta - h).expect("jet-");
2788 let d4fd = (jp.d3 - jm.d3) / (2.0 * h);
2789 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)
2790 .expect("analytic d4");
2791 assert_eq!(
2792 d4.signum(),
2793 d4fd.signum(),
2794 "d4 sign mismatch for {:?} at eta={eta}: analytic={} fd={}",
2795 link,
2796 d4,
2797 d4fd
2798 );
2799 assert!(
2800 (d4 - d4fd).abs() < 5e-3,
2801 "d4 mismatch for {:?} at eta={eta}: analytic={} fd={}",
2802 link,
2803 d4,
2804 d4fd
2805 );
2806 }
2807 }
2808 }
2809
2810 #[test]
2811 fn cloglog_large_finite_eta_should_saturate_without_nan_derivatives() {
2812 let eta = 800.0;
2813 let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2814 assert_eq!(jet.mu, 1.0);
2815 assert!(
2816 jet.d1 == 0.0,
2817 "for mu(eta)=1-exp(-exp(eta)), dmu/deta = exp(eta-exp(eta)) and should underflow to 0 at eta={eta}; got d1={}",
2818 jet.d1
2819 );
2820 assert!(
2821 jet.d2 == 0.0,
2822 "the saturated cloglog second derivative should also be 0 at eta={eta}; got d2={}",
2823 jet.d2
2824 );
2825 assert!(
2826 jet.d3 == 0.0,
2827 "the saturated cloglog third derivative should also be 0 at eta={eta}; got d3={}",
2828 jet.d3
2829 );
2830
2831 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2832 &InverseLink::Standard(StandardLink::CLogLog),
2833 eta,
2834 )
2835 .expect("cloglog d4");
2836 assert!(
2837 d4 == 0.0,
2838 "the saturated cloglog fourth derivative should also be 0 at eta={eta}; got d4={d4}"
2839 );
2840 }
2841
2842 #[test]
2843 fn loglog_large_negative_finite_eta_should_saturate_without_nan_derivatives() {
2844 let eta = -800.0;
2845 let jet = component_inverse_link_jet(LinkComponent::LogLog, eta);
2846 assert_eq!(jet.mu, 0.0);
2847 assert!(
2848 jet.d1 == 0.0,
2849 "for mu(eta)=exp(-exp(-eta)), dmu/deta = exp(-eta-exp(-eta)) and should underflow to 0 at eta={eta}; got d1={}",
2850 jet.d1
2851 );
2852 assert!(
2853 jet.d2 == 0.0,
2854 "the saturated loglog second derivative should also be 0 at eta={eta}; got d2={}",
2855 jet.d2
2856 );
2857 assert!(
2858 jet.d3 == 0.0,
2859 "the saturated loglog third derivative should also be 0 at eta={eta}; got d3={}",
2860 jet.d3
2861 );
2862
2863 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2864 &InverseLink::Mixture(
2865 state_fromspec(&MixtureLinkSpec {
2866 components: vec![LinkComponent::LogLog, LinkComponent::Probit],
2867 initial_rho: Array1::from_vec(vec![12.0]),
2868 })
2869 .expect("mixture state"),
2870 ),
2871 eta,
2872 )
2873 .expect("loglog mixture d4");
2874 assert!(
2875 d4.is_finite(),
2876 "even a nearly pure loglog mixture should not produce NaN fourth derivatives at eta={eta}; got d4={d4}"
2877 );
2878 }
2879
2880 #[test]
2881 fn logit_tail_derivatives_should_match_stable_closed_forms() {
2882 let eta = 50.0_f64;
2883 let z = (-eta).exp();
2884 let denom = 1.0_f64 + z;
2885 let stable_d1 = z / denom.powi(2);
2886 let stable_d2 = z * (z - 1.0) / denom.powi(3);
2887 let stable_d3 = z * (z * z - 4.0 * z + 1.0) / denom.powi(4);
2888 let stable_d4 = z * (z * z * z - 11.0 * z * z + 11.0 * z - 1.0) / denom.powi(5);
2889 let stable_d5 =
2890 z * (z * z * z * z - 26.0 * z * z * z + 66.0 * z * z - 26.0 * z + 1.0) / denom.powi(6);
2891
2892 assert!(stable_d1 > 0.0);
2893 assert!(stable_d2 < 0.0);
2894 assert!(stable_d3 > 0.0);
2895 assert!(stable_d4 < 0.0);
2896 assert!(stable_d5 > 0.0);
2897
2898 let jet = component_inverse_link_jet(LinkComponent::Logit, eta);
2899 assert!(
2900 (jet.d1 - stable_d1).abs() < 1e-30,
2901 "logit d1 should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
2902 jet.d1,
2903 stable_d1
2904 );
2905 assert!(
2906 (jet.d2 - stable_d2).abs() < 1e-30,
2907 "logit d2 should equal the stable tail formula z(z-1)/(1+z)^3 at eta={eta}; got {} vs {}",
2908 jet.d2,
2909 stable_d2
2910 );
2911 assert!(
2912 (jet.d3 - stable_d3).abs() < 1e-30,
2913 "logit d3 should equal the stable tail formula z(z^2-4z+1)/(1+z)^4 at eta={eta}; got {} vs {}",
2914 jet.d3,
2915 stable_d3
2916 );
2917
2918 let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
2919 &InverseLink::Standard(StandardLink::Logit),
2920 eta,
2921 )
2922 .expect("logit d4");
2923 assert!(
2924 (d4 - stable_d4).abs() < 1e-30,
2925 "logit d4 should equal the stable tail formula z(z^3-11z^2+11z-1)/(1+z)^5 at eta={eta}; got {} vs {}",
2926 d4,
2927 stable_d4
2928 );
2929
2930 let d5 = inverse_link_pdffourth_derivative_for_inverse_link(
2931 &InverseLink::Standard(StandardLink::Logit),
2932 eta,
2933 )
2934 .expect("logit d5");
2935 assert!(
2936 (d5 - stable_d5).abs() < 1e-30,
2937 "logit d5 should equal the stable tail formula z(z^4-26z^3+66z^2-26z+1)/(1+z)^6 at eta={eta}; got {} vs {}",
2938 d5,
2939 stable_d5
2940 );
2941 }
2942
2943 #[test]
2944 fn cloglog_negative_tail_value_should_match_expm1_form() {
2945 let eta = -50.0_f64;
2946 let t = eta.exp();
2947 let stable_mu = -(-t).exp_m1();
2948 assert!(stable_mu > 0.0);
2949
2950 let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
2951 assert!(
2952 (jet.mu - stable_mu).abs() < 1e-30,
2953 "cloglog mu should equal -expm1(-exp(eta)) in the negative tail at eta={eta}; got {} vs {}",
2954 jet.mu,
2955 stable_mu
2956 );
2957 }
2958
2959 #[test]
2960 fn non_logit_probit_fisher_weight_jets_match_finite_differences() {
2961 fn rel_err(a: f64, b: f64) -> f64 {
2962 (a - b).abs() / a.abs().max(b.abs()).max(1.0e-8)
2963 }
2964
2965 let cases = [
2966 (LinkComponent::CLogLog, [-3.0_f64, -0.5, 0.4, 1.5]),
2967 (LinkComponent::LogLog, [-1.5_f64, -0.4, 0.5, 3.0]),
2968 (LinkComponent::Cauchit, [-3.0_f64, -0.7, 0.6, 3.0]),
2969 ];
2970 for (component, etas) in cases {
2971 for eta in etas {
2972 let (w, w1, w2, w3, w4) = component_fisher_weight_jet5(component, eta);
2973 let jet = component_inverse_link_jet(component, eta);
2974 let expected = jet.d1 * jet.d1 / (jet.mu * (1.0 - jet.mu));
2975 assert!(
2976 rel_err(w, expected) < 1.0e-12,
2977 "{component:?} Fisher weight mismatch at eta={eta}: got {w}, expected {expected}"
2978 );
2979
2980 let h = 1.0e-4;
2981 let fd1 = (component_fisher_weight_jet5(component, eta + h).0
2982 - component_fisher_weight_jet5(component, eta - h).0)
2983 / (2.0 * h);
2984 let fd2 = (component_fisher_weight_jet5(component, eta + h).1
2985 - component_fisher_weight_jet5(component, eta - h).1)
2986 / (2.0 * h);
2987 let fd3 = (component_fisher_weight_jet5(component, eta + h).2
2988 - component_fisher_weight_jet5(component, eta - h).2)
2989 / (2.0 * h);
2990 let fd4 = (component_fisher_weight_jet5(component, eta + h).3
2991 - component_fisher_weight_jet5(component, eta - h).3)
2992 / (2.0 * h);
2993
2994 assert!(
2995 rel_err(w1, fd1) < 1.0e-5,
2996 "{component:?} W' mismatch at eta={eta}: {w1} vs {fd1}"
2997 );
2998 assert!(
2999 rel_err(w2, fd2) < 1.0e-5,
3000 "{component:?} W'' mismatch at eta={eta}: {w2} vs {fd2}"
3001 );
3002 assert!(
3003 rel_err(w3, fd3) < 5.0e-5,
3004 "{component:?} W''' mismatch at eta={eta}: {w3} vs {fd3}"
3005 );
3006 assert!(
3007 rel_err(w4, fd4) < 5.0e-4,
3008 "{component:?} W'''' mismatch at eta={eta}: {w4} vs {fd4}"
3009 );
3010 }
3011 }
3012 }
3013
3014 #[test]
3015 fn mixture_fisher_weight_jet_covers_loglog_and_cauchit_components() {
3016 let state = state_fromspec(&MixtureLinkSpec {
3017 components: vec![
3018 LinkComponent::CLogLog,
3019 LinkComponent::LogLog,
3020 LinkComponent::Cauchit,
3021 ],
3022 initial_rho: Array1::from_vec(vec![0.3, -0.2]),
3023 })
3024 .expect("mixture state");
3025 let link = InverseLink::Mixture(state);
3026 assert!(
3027 inverse_link_has_fisher_weight_jet(&link),
3028 "anchored mixtures with loglog/cauchit components must remain eligible for Firth"
3029 );
3030 assert!(
3031 LikelihoodSpec::new(ResponseFamily::Binomial, link.clone()).supports_firth(),
3032 "Firth support should use the mixture inverse-link Fisher jet, not standalone LinkFunction coverage"
3033 );
3034
3035 for eta in [-2.0_f64, -0.25, 0.75, 2.5] {
3036 let (w, w1, w2, w3, w4) =
3037 fisher_weight_jet5_for_inverse_link(&link, eta).expect("mixture Fisher jet");
3038 for value in [w, w1, w2, w3, w4] {
3039 assert!(
3040 value.is_finite(),
3041 "mixture Fisher weight jet should be finite at eta={eta}; got {value}"
3042 );
3043 }
3044 assert!(
3045 w > 0.0,
3046 "mixture Fisher working weight should be positive away from saturated tails at eta={eta}; got {w}"
3047 );
3048 }
3049 }
3050
3051 #[test]
3052 fn loglog_fifth_derivative_should_match_closed_form_sign() {
3053 let eta = 0.0_f64;
3054 let r = (-eta).exp();
3055 let expected =
3056 (-r).exp() * (r - 15.0 * r * r + 25.0 * r.powi(3) - 10.0 * r.powi(4) + r.powi(5));
3057 let d5 = component_inverse_link_pdffourth_derivative(LinkComponent::LogLog, eta);
3058 assert!(
3059 (d5 - expected).abs() < 1e-15,
3060 "loglog d5 should equal exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5) at eta={eta}; got {d5} vs {expected}"
3061 );
3062 assert!(d5 > 0.0, "loglog d5 should be positive at eta=0; got {d5}");
3063 }
3064}