use crate::estimate::EstimationError;
use crate::families::lognormal_kernel::latent_cloglog_jet5;
use crate::probability::{
normal_cdf, normal_pdf,
stable_polynomial_times_exp_neg as stable_nonnegative_poly_times_exp_neg,
};
use crate::types::{
InverseLink, LatentCLogLogState, LikelihoodSpec, LinkComponent, LinkFunction, MixtureLinkSpec,
MixtureLinkState, ResponseFamily, SasLinkSpec, SasLinkState,
};
use ndarray::Array1;
use statrs::function::beta::{beta_reg, ln_beta};
use statrs::function::gamma::digamma;
use std::sync::OnceLock;
const SAS_U_CLAMP: f64 = 50.0;
pub(crate) const SAS_LOG_DELTA_BOUND: f64 = 12.0;
const BETA_LOGISTIC_U_EPS: f64 = 1e-12;
#[inline]
fn latent_cloglog_quadctx() -> &'static crate::quadrature::QuadratureContext {
static QUADCTX: OnceLock<crate::quadrature::QuadratureContext> = OnceLock::new();
QUADCTX.get_or_init(crate::quadrature::QuadratureContext::new)
}
#[inline]
fn latent_cloglog_point_jet(
state: &LatentCLogLogState,
eta: f64,
) -> Result<InverseLinkJet, EstimationError> {
let jet = latent_cloglog_jet5(latent_cloglog_quadctx(), eta, state.latent_sd)?;
Ok(InverseLinkJet {
mu: jet.mean,
d1: jet.d1,
d2: jet.d2,
d3: jet.d3,
})
}
/// Value of an inverse link μ(η) together with its first three η-derivatives.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct InverseLinkJet {
    /// μ(η).
    pub mu: f64,
    /// dμ/dη.
    pub d1: f64,
    /// d²μ/dη².
    pub d2: f64,
    /// d³μ/dη³.
    pub d3: f64,
}
/// Value of the logistic inverse link together with its first five η-derivatives.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct LogitJet5 {
    pub mu: f64,
    pub d1: f64,
    pub d2: f64,
    pub d3: f64,
    pub d4: f64,
    pub d5: f64,
}
/// Flushes any value whose magnitude is below `f64::MIN_POSITIVE` (subnormals
/// and both signed zeros) to `+0.0`. NaN and all other values are returned as is.
#[inline]
fn canonicalzero(v: f64) -> f64 {
    let below_normal_range = v.abs() < f64::MIN_POSITIVE;
    if below_normal_range {
        0.0
    } else {
        v
    }
}
/// Applies [`canonicalzero`] to the three derivative slots of `jet`; `mu` is
/// left untouched.
#[inline]
fn canonicalize_jet(jet: InverseLinkJet) -> InverseLinkJet {
    InverseLinkJet {
        d1: canonicalzero(jet.d1),
        d2: canonicalzero(jet.d2),
        d3: canonicalzero(jet.d3),
        ..jet
    }
}
/// Logistic inverse link σ(η) = 1 / (1 + e^{−η}) with its first five η-derivatives.
///
/// All quantities are rational functions of z = e^{−|η|} ∈ (0, 1], so the
/// exponential never overflows. The derivative numerators are the signed
/// Eulerian-number polynomials (1; 1,1; 1,4,1; 1,11,11,1; 1,26,66,26,1).
///
/// Special inputs: NaN propagates to every field, and ±∞ saturate to μ = 1 / 0
/// with all derivatives zero. Derivatives (not `mu`) are flushed through
/// `canonicalzero`, so subnormal results become exactly zero.
#[inline]
pub(crate) fn logit_inverse_link_jet5(eta: f64) -> LogitJet5 {
    let saturated = |mu: f64| LogitJet5 {
        mu,
        d1: 0.0,
        d2: 0.0,
        d3: 0.0,
        d4: 0.0,
        d5: 0.0,
    };
    if eta.is_nan() {
        return LogitJet5 {
            mu: f64::NAN,
            d1: f64::NAN,
            d2: f64::NAN,
            d3: f64::NAN,
            d4: f64::NAN,
            d5: f64::NAN,
        };
    }
    if eta.is_infinite() {
        return saturated(if eta > 0.0 { 1.0 } else { 0.0 });
    }
    let upper_half = eta >= 0.0;
    let z = if upper_half { (-eta).exp() } else { eta.exp() };
    let opz = 1.0 + z;
    // opz_pow[k] holds (1 + z)^(k + 1), built by repeated multiplication.
    let mut opz_pow = [opz; 6];
    for k in 1..opz_pow.len() {
        opz_pow[k] = opz_pow[k - 1] * opz;
    }
    let z2 = z * z;
    let z3 = z2 * z;
    let z4 = z3 * z;
    let raw = if upper_half {
        // Here z = e^{−η}; odd/even symmetry of σ flips some numerator signs.
        LogitJet5 {
            mu: 1.0 / opz,
            d1: z / opz_pow[1],
            d2: z * (z - 1.0) / opz_pow[2],
            d3: z * (z2 - 4.0 * z + 1.0) / opz_pow[3],
            d4: z * (z3 - 11.0 * z2 + 11.0 * z - 1.0) / opz_pow[4],
            d5: z * (z4 - 26.0 * z3 + 66.0 * z2 - 26.0 * z + 1.0) / opz_pow[5],
        }
    } else {
        // Here z = e^{η}.
        LogitJet5 {
            mu: z / opz,
            d1: z / opz_pow[1],
            d2: z * (1.0 - z) / opz_pow[2],
            d3: z * (1.0 - 4.0 * z + z2) / opz_pow[3],
            d4: z * (1.0 - 11.0 * z + 11.0 * z2 - z3) / opz_pow[4],
            d5: z * (1.0 - 26.0 * z + 66.0 * z2 - 26.0 * z3 + z4) / opz_pow[5],
        }
    };
    LogitJet5 {
        d1: canonicalzero(raw.d1),
        d2: canonicalzero(raw.d2),
        d3: canonicalzero(raw.d3),
        d4: canonicalzero(raw.d4),
        d5: canonicalzero(raw.d5),
        ..raw
    }
}
#[inline]
fn logistic_stable(eta: f64) -> f64 {
logit_inverse_link_jet5(eta).mu
}
/// Probit inverse link Φ(η) with derivatives φ(η), −η·φ(η) and (η² − 1)·φ(η).
///
/// NaN propagates to every field; ±∞ saturate to 1 / 0 with zero derivatives.
/// The derivatives are not flushed to zero here.
#[inline]
fn probit_jet(eta: f64) -> InverseLinkJet {
    let flat = |mu: f64| InverseLinkJet {
        mu,
        d1: 0.0,
        d2: 0.0,
        d3: 0.0,
    };
    if eta.is_nan() {
        return InverseLinkJet {
            mu: f64::NAN,
            d1: f64::NAN,
            d2: f64::NAN,
            d3: f64::NAN,
        };
    }
    if eta.is_infinite() {
        return flat(if eta.is_sign_positive() { 1.0 } else { 0.0 });
    }
    let density = normal_pdf(eta);
    InverseLinkJet {
        mu: normal_cdf(eta),
        d1: density,
        d2: -eta * density,
        d3: (eta * eta - 1.0) * density,
    }
}
/// Third derivative of the standard normal density, φ'''(η) = −(η³ − 3η)·φ(η)
/// (the fourth η-derivative of the probit inverse link).
///
/// NaN propagates, ±∞ return 0, and subnormal results are flushed to zero.
#[inline]
fn probit_pdfthird_derivative(eta: f64) -> f64 {
    match eta {
        e if e.is_nan() => f64::NAN,
        e if e.is_infinite() => 0.0,
        x => canonicalzero(-(x * x * x - 3.0 * x) * normal_pdf(x)),
    }
}
/// Fourth derivative of the standard normal density,
/// φ''''(η) = (η⁴ − 6η² + 3)·φ(η) (the fifth η-derivative of the probit link).
///
/// NaN propagates, ±∞ return 0, and subnormal results are flushed to zero.
#[inline]
fn probit_pdffourth_derivative(eta: f64) -> f64 {
    match eta {
        e if e.is_nan() => f64::NAN,
        e if e.is_infinite() => 0.0,
        x => canonicalzero((x * x * x * x - 6.0 * x * x + 3.0) * normal_pdf(x)),
    }
}
/// Third-order chain rule (Faà di Bruno) for μ(z(η)).
///
/// `base` holds μ and its derivatives with respect to z; `z1`, `z2`, `z3` are
/// z', z'', z''' with respect to η. Returns the jet with respect to η, and
/// `mu` is passed through unchanged.
#[inline]
fn chain_inverse_link_jet(base: InverseLinkJet, z1: f64, z2: f64, z3: f64) -> InverseLinkJet {
    let InverseLinkJet {
        mu,
        d1: f1,
        d2: f2,
        d3: f3,
    } = base;
    let d1 = f1 * z1;
    let d2 = f2 * z1 * z1 + f1 * z2;
    let d3 = f3 * z1 * z1 * z1 + 3.0 * f2 * z1 * z2 + f1 * z3;
    InverseLinkJet { mu, d1, d2, d3 }
}
/// Third derivative of the density, i.e. the fourth η-derivative of μ, for one
/// built-in link component.
///
/// NaN `eta` propagates; ±∞ return 0. Every arm flushes subnormal results to
/// exactly zero through `canonicalzero`.
#[inline]
fn component_inverse_link_pdfthird_derivative(component: LinkComponent, eta: f64) -> f64 {
    // For |η| ≥ CAUCHIT_TAIL the exact Cauchit value (≈ −24 / (π η⁵), magnitude
    // below 7.6e-320) is under f64::MIN_POSITIVE and would be flushed to zero.
    // Evaluating the closed form there first yields ±0 (denom⁴ overflows). For
    // |η| beyond roughly 2e102 it yields ∓inf / inf = NaN, because 24η(1 − η²)
    // overflows as well.
    const CAUCHIT_TAIL: f64 = 1e64;
    match component {
        LinkComponent::Probit => probit_pdfthird_derivative(eta),
        LinkComponent::Logit => logit_inverse_link_jet5(eta).d4,
        LinkComponent::CLogLog => {
            if eta.is_nan() {
                return f64::NAN;
            }
            if !eta.is_finite() {
                return 0.0;
            }
            let t = eta.exp();
            canonicalzero(stable_nonnegative_poly_times_exp_neg(
                t,
                &[0.0, 1.0, -7.0, 6.0, -1.0],
            ))
        }
        LinkComponent::LogLog => {
            if eta.is_nan() {
                return f64::NAN;
            }
            if !eta.is_finite() {
                return 0.0;
            }
            let r = (-eta).exp();
            canonicalzero(stable_nonnegative_poly_times_exp_neg(
                r,
                &[0.0, -1.0, 7.0, -6.0, 1.0],
            ))
        }
        LinkComponent::Cauchit => {
            if eta.is_nan() {
                return f64::NAN;
            }
            // Covers ±∞ as well as huge finite η (see CAUCHIT_TAIL).
            if eta.abs() >= CAUCHIT_TAIL {
                return 0.0;
            }
            let denom = 1.0 + eta * eta;
            canonicalzero(
                24.0 * eta * (1.0 - eta * eta) / (std::f64::consts::PI * denom.powi(4)),
            )
        }
    }
}
/// Fourth derivative of the density, i.e. the fifth η-derivative of μ, for one
/// built-in link component.
///
/// NaN `eta` propagates; ±∞ return 0. Every arm flushes subnormal results to
/// exactly zero through `canonicalzero`.
#[inline]
fn component_inverse_link_pdffourth_derivative(component: LinkComponent, eta: f64) -> f64 {
    // For |η| ≥ CAUCHIT_TAIL the exact Cauchit value (≈ 120 / (π η⁶)) is far
    // below f64::MIN_POSITIVE. The closed form already returns 0 there, because
    // denom⁵ overflows. For |η| beyond roughly 4e76 it becomes inf / inf = NaN,
    // because 5η⁴ overflows too.
    const CAUCHIT_TAIL: f64 = 1e64;
    match component {
        LinkComponent::Probit => probit_pdffourth_derivative(eta),
        LinkComponent::Logit => logit_inverse_link_jet5(eta).d5,
        LinkComponent::CLogLog => {
            if eta.is_nan() {
                return f64::NAN;
            }
            if !eta.is_finite() {
                return 0.0;
            }
            let t = eta.exp();
            canonicalzero(stable_nonnegative_poly_times_exp_neg(
                t,
                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
            ))
        }
        LinkComponent::LogLog => {
            if eta.is_nan() {
                return f64::NAN;
            }
            if !eta.is_finite() {
                return 0.0;
            }
            let r = (-eta).exp();
            canonicalzero(stable_nonnegative_poly_times_exp_neg(
                r,
                &[0.0, 1.0, -15.0, 25.0, -10.0, 1.0],
            ))
        }
        LinkComponent::Cauchit => {
            if eta.is_nan() {
                return f64::NAN;
            }
            // Covers ±∞ as well as huge finite η (see CAUCHIT_TAIL).
            if eta.abs() >= CAUCHIT_TAIL {
                return 0.0;
            }
            let e2 = eta * eta;
            let denom = 1.0 + e2;
            canonicalzero(
                24.0 * (1.0 - 10.0 * e2 + 5.0 * e2 * e2) / (std::f64::consts::PI * denom.powi(5)),
            )
        }
    }
}
/// Mixture inverse-link jet together with its partial derivatives with respect
/// to each free mixture logit `rho[j]`.
#[derive(Clone, Debug, PartialEq)]
pub struct MixtureJetWithRhoPartials {
    /// Jet of the mixture inverse link at the evaluation point.
    pub jet: InverseLinkJet,
    /// `djet_drho[j]` is ∂jet/∂rho[j]. There is one entry per free logit; the
    /// last component's logit is fixed at zero and has no entry.
    pub djet_drho: Vec<InverseLinkJet>,
}
/// Jet plus partial derivatives with respect to a two-parameter link's
/// (epsilon, log-delta) pair. It is used for both the SAS and the
/// Beta-Logistic links in this module.
#[derive(Clone, Debug, PartialEq)]
pub struct SasJetWithParamPartials {
    /// Jet of the inverse link at the evaluation point.
    pub jet: InverseLinkJet,
    /// ∂jet/∂epsilon.
    pub djet_depsilon: InverseLinkJet,
    /// ∂jet/∂(log-delta parameter).
    pub djet_dlog_delta: InverseLinkJet,
}
/// Link-parameter partials returned by `InverseLinkKernel::param_partials`.
#[derive(Clone, Debug, PartialEq)]
pub enum LinkParamPartials {
    /// Partials with respect to mixture logits.
    Mixture(MixtureJetWithRhoPartials),
    /// Partials with respect to (epsilon, log-delta); also used for Beta-Logistic.
    Sas(SasJetWithParamPartials),
}
/// Evaluation interface shared by every inverse-link representation in this module.
pub trait InverseLinkKernel {
    /// Returns μ(η) and its first three η-derivatives.
    ///
    /// # Errors
    /// Implementations return an error when the link cannot be evaluated (for
    /// example a `LinkFunction` variant that needs explicit state), or when an
    /// underlying computation fails.
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;
    /// Returns the jet's partial derivatives with respect to the link's own
    /// parameters, or `Ok(None)` for links without free parameters.
    ///
    /// The default implementation serves parameter-free links. It ignores
    /// `eta` and never fails or panics. Previously it asserted that `eta` was
    /// finite, which aborted on ±∞/NaN even though `jet` handles those inputs.
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        let _ = eta;
        Ok(None)
    }
}
/// Stateless kernel for the probit inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ProbitLinkKernel;
/// Stateless kernel for the logit inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogitLinkKernel;
/// Stateless kernel for the complementary log-log inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CLogLogLinkKernel;
/// Stateless kernel for the log-log inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct LogLogLinkKernel;
/// Stateless kernel for the Cauchit inverse link.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct CauchitLinkKernel;
impl SasLinkState {
    /// Builds a SAS link state from unconstrained parameters.
    ///
    /// `log_delta` stores the raw value. `delta` is derived from it through the
    /// tanh-bounded mapping in `sas_delta_from_raw_log_delta`.
    ///
    /// # Errors
    /// Returns a message if either parameter is NaN or infinite.
    pub fn new(raw_epsilon: f64, raw_log_delta: f64) -> Result<Self, String> {
        let all_finite = raw_epsilon.is_finite() && raw_log_delta.is_finite();
        if !all_finite {
            return Err(String::from("SAS link parameters must be finite"));
        }
        let delta = sas_delta_from_raw_log_delta(raw_log_delta);
        Ok(Self {
            epsilon: raw_epsilon,
            log_delta: raw_log_delta,
            delta,
        })
    }
}
/// Builds a validated SAS link state from a specification's initial values.
///
/// # Errors
/// Propagates the finiteness error from `SasLinkState::new`.
pub fn state_from_sasspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
    let (epsilon, log_delta) = (spec.initial_epsilon, spec.initial_log_delta);
    SasLinkState::new(epsilon, log_delta)
}
/// Builds a Beta-Logistic link state from a specification's initial values.
///
/// Unlike the SAS constructor, `delta` is the plain exponential of the raw
/// `initial_log_delta`, with no tanh bounding.
///
/// # Errors
/// Returns a message if either initial parameter is NaN or infinite.
pub fn state_from_beta_logisticspec(spec: SasLinkSpec) -> Result<SasLinkState, String> {
    let epsilon = spec.initial_epsilon;
    let log_delta = spec.initial_log_delta;
    if !(epsilon.is_finite() && log_delta.is_finite()) {
        return Err("Beta-Logistic link parameters must be finite".to_string());
    }
    Ok(SasLinkState {
        epsilon,
        log_delta,
        delta: log_delta.exp(),
    })
}
/// Smooth saturating bound: `b * tanh(value / b)` with `b = max(bound, f64::EPSILON)`.
///
/// It has slope 1 at the origin and approaches ±b for large |value|, so the
/// result stays within [−b, b]. The `f64::EPSILON` floor prevents division by zero.
#[inline]
fn tanh_bound(value: f64, bound: f64) -> f64 {
    let b = bound.max(f64::EPSILON);
    b * (value / b).tanh()
}
/// First derivative of [`tanh_bound`] with respect to `value`: `1 − t²`, where
/// `t = tanh(value / b)`.
#[inline]
fn tanh_bound_d1(value: f64, bound: f64) -> f64 {
    let b = bound.max(f64::EPSILON);
    let t = (value / b).tanh();
    1.0 - t * t
}
/// Second derivative of [`tanh_bound`]: `−2 t s / b`, with `s = 1 − t²`.
#[inline]
fn tanh_bound_d2(value: f64, bound: f64) -> f64 {
    let b = bound.max(f64::EPSILON);
    let t = (value / b).tanh();
    let s = 1.0 - t * t;
    -2.0 * t * s / b
}
/// Third derivative of [`tanh_bound`]: `−2 s (1 − 3t²) / b²`, with `s = 1 − t²`.
#[inline]
fn tanh_bound_d3(value: f64, bound: f64) -> f64 {
    let b = bound.max(f64::EPSILON);
    let t = (value / b).tanh();
    let s = 1.0 - t * t;
    -2.0 * s * (1.0 - 3.0 * t * t) / (b * b)
}
/// Fourth derivative of [`tanh_bound`]: `8 t s (2 − 3t²) / b³`, with `s = 1 − t²`.
#[inline]
fn tanh_bound_d4(value: f64, bound: f64) -> f64 {
    let b = bound.max(f64::EPSILON);
    let t = (value / b).tanh();
    let s = 1.0 - t * t;
    8.0 * t * s * (2.0 - 3.0 * t * t) / (b * b * b)
}
/// Fifth derivative of [`tanh_bound`]: `8 s (2 − 15t² + 15t⁴) / b⁴`, with `s = 1 − t²`.
#[inline]
fn tanh_bound_d5(value: f64, bound: f64) -> f64 {
    let b = bound.max(f64::EPSILON);
    let t = (value / b).tanh();
    let s = 1.0 - t * t;
    let t2 = t * t;
    let b4 = b * b * b * b;
    8.0 * s * (2.0 - 15.0 * t2 + 15.0 * t2 * t2) / b4
}
/// Maps a raw (unbounded) SAS log-delta into [−SAS_LOG_DELTA_BOUND, SAS_LOG_DELTA_BOUND]
/// with `tanh_bound`. Returns `(bounded value, d bounded / d raw)`.
#[inline]
fn sas_effective_log_delta(raw_log_delta: f64) -> (f64, f64) {
    (
        tanh_bound(raw_log_delta, SAS_LOG_DELTA_BOUND),
        tanh_bound_d1(raw_log_delta, SAS_LOG_DELTA_BOUND),
    )
}
/// SAS tail-weight δ = exp(bounded log-delta) for a raw log-delta; always positive.
#[inline]
fn sas_delta_from_raw_log_delta(raw_log_delta: f64) -> f64 {
    sas_effective_log_delta(raw_log_delta).0.exp()
}
pub fn validate_mixturespec(spec: &MixtureLinkSpec) -> Result<(), String> {
if spec.components.is_empty() {
return Err("mixture link requires at least 1 component".to_string());
}
if spec.initial_rho.len() + 1 != spec.components.len() {
return Err(format!(
"mixture link rho length mismatch: expected {}, got {}",
spec.components.len() - 1,
spec.initial_rho.len()
));
}
for i in 0..spec.components.len() {
for j in (i + 1)..spec.components.len() {
if spec.components[i] == spec.components[j] {
return Err("mixture link components must be unique".to_string());
}
}
}
let has_anchor = spec.components.iter().any(|component| {
matches!(
component,
LinkComponent::Logit | LinkComponent::Probit | LinkComponent::CLogLog
)
});
if !has_anchor {
let unsupported: Vec<&str> = spec
.components
.iter()
.map(|component| component.name())
.collect();
return Err(format!(
"mixture link components {{{}}} are unsupported: at least one component \
must map to a LinkFunction variant (logit/probit/cloglog) so the mixture's \
projected LinkFunction is well defined; cauchit and loglog have no \
LinkFunction representative",
unsupported.join(", ")
));
}
Ok(())
}
/// Softmax over the logits `rho` extended with a trailing logit fixed at 0.
///
/// Returns `rho.len() + 1` weights. The logits are shifted by their maximum
/// (never below the pinned 0.0) to avoid overflow. If the normaliser is not
/// finite and positive (e.g. a NaN/∞ logit), uniform weights are returned.
pub fn softmax_last_fixedzero(rho: &Array1<f64>) -> Array1<f64> {
    let k = rho.len() + 1;
    let shift = rho.iter().fold(0.0_f64, |acc, &v| acc.max(v));
    let mut weights: Vec<f64> = rho
        .iter()
        .copied()
        .chain(std::iter::once(0.0))
        .map(|logit| (logit - shift).exp())
        .collect();
    let total = weights.iter().fold(0.0_f64, |acc, &w| acc + w);
    if !(total.is_finite() && total > 0.0) {
        return Array1::from_elem(k, 1.0 / k as f64);
    }
    let inv = 1.0 / total;
    weights.iter_mut().for_each(|w| *w *= inv);
    Array1::from_vec(weights)
}
/// Softmax weights (see `softmax_last_fixedzero`) plus their Jacobian with
/// respect to the free logits.
///
/// The Jacobian has shape `(k, k − 1)`, with entry `[row, col] = π_row (δ_row,col − π_col)`.
pub fn softmaxwith_jacobian_last_fixedzero(
    rho: &Array1<f64>,
) -> (Array1<f64>, ndarray::Array2<f64>) {
    let pi = softmax_last_fixedzero(rho);
    let k = pi.len();
    let m = k.saturating_sub(1);
    let jac = ndarray::Array2::from_shape_fn((k, m), |(row, col)| {
        let kronecker = if row == col { 1.0 } else { 0.0 };
        pi[row] * (kronecker - pi[col])
    });
    (pi, jac)
}
/// Validates a mixture specification and builds its initial state, with the
/// weights computed from `initial_rho` through `softmax_last_fixedzero`.
///
/// # Errors
/// Propagates the message from `validate_mixturespec`.
pub fn state_fromspec(spec: &MixtureLinkSpec) -> Result<MixtureLinkState, String> {
    validate_mixturespec(spec).map(|()| MixtureLinkState {
        components: spec.components.clone(),
        rho: spec.initial_rho.clone(),
        pi: softmax_last_fixedzero(&spec.initial_rho),
    })
}
/// Evaluates μ(η) and its first three η-derivatives for one built-in link component.
///
/// * NaN `eta` propagates to every field.
/// * `eta = ±∞` (or an exponent that overflows) yields the saturated
///   probability with zero derivatives.
/// * Derivatives are passed through `canonicalize_jet`, so values below
///   `f64::MIN_POSITIVE` in magnitude are flushed to exactly zero.
#[inline]
pub fn component_inverse_link_jet(component: LinkComponent, eta: f64) -> InverseLinkJet {
    // For |η| ≥ CAUCHIT_DERIV_TAIL, the unguarded Cauchit d2/d3 formulas already
    // evaluate to ±0, because the powers of the denominator overflow. Further
    // out (|η| ≳ 1.3e154 for d3, ≳ 9e307 for d2) the numerators overflow too,
    // and the result is inf / inf = NaN. Returning 0 in that range keeps the
    // previous values and removes the NaN.
    const CAUCHIT_DERIV_TAIL: f64 = 1e100;
    canonicalize_jet(match component {
        LinkComponent::Logit => {
            let jet = logit_inverse_link_jet5(eta);
            InverseLinkJet {
                mu: jet.mu,
                d1: jet.d1,
                d2: jet.d2,
                d3: jet.d3,
            }
        }
        LinkComponent::Probit => probit_jet(eta),
        LinkComponent::CLogLog => {
            if eta.is_nan() {
                return InverseLinkJet {
                    mu: f64::NAN,
                    d1: f64::NAN,
                    d2: f64::NAN,
                    d3: f64::NAN,
                };
            }
            let t = eta.exp();
            if !t.is_finite() {
                return InverseLinkJet {
                    mu: 1.0,
                    d1: 0.0,
                    d2: 0.0,
                    d3: 0.0,
                };
            }
            InverseLinkJet {
                mu: -(-t).exp_m1(),
                d1: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0]),
                d2: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -1.0]),
                d3: stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0, -3.0, 1.0]),
            }
        }
        LinkComponent::LogLog => {
            if eta.is_nan() {
                return InverseLinkJet {
                    mu: f64::NAN,
                    d1: f64::NAN,
                    d2: f64::NAN,
                    d3: f64::NAN,
                };
            }
            let r = (-eta).exp();
            if !r.is_finite() {
                return InverseLinkJet {
                    mu: 0.0,
                    d1: 0.0,
                    d2: 0.0,
                    d3: 0.0,
                };
            }
            InverseLinkJet {
                mu: (-r).exp(),
                d1: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0]),
                d2: stable_nonnegative_poly_times_exp_neg(r, &[0.0, -1.0, 1.0]),
                d3: stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0, -3.0, 1.0]),
            }
        }
        LinkComponent::Cauchit => {
            if eta.is_nan() {
                return InverseLinkJet {
                    mu: f64::NAN,
                    d1: f64::NAN,
                    d2: f64::NAN,
                    d3: f64::NAN,
                };
            }
            let den = 1.0 + eta * eta;
            // 1 / (π·∞) = 0, so d1 never turns into NaN for huge finite η.
            let d1 = if eta.is_finite() {
                1.0 / (std::f64::consts::PI * den)
            } else {
                0.0
            };
            // `abs() < tail` is false for ±∞, which also covers the infinite case.
            let (d2, d3) = if eta.abs() < CAUCHIT_DERIV_TAIL {
                (
                    -2.0 * eta / (std::f64::consts::PI * den * den),
                    (6.0 * eta * eta - 2.0) / (std::f64::consts::PI * den * den * den),
                )
            } else {
                (0.0, 0.0)
            };
            InverseLinkJet {
                mu: 0.5 + eta.atan() / std::f64::consts::PI,
                d1,
                d2,
                d3,
            }
        }
    })
}
impl InverseLinkKernel for ProbitLinkKernel {
#[inline]
fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
Ok(component_inverse_link_jet(LinkComponent::Probit, eta))
}
}
impl InverseLinkKernel for LogitLinkKernel {
#[inline]
fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
Ok(component_inverse_link_jet(LinkComponent::Logit, eta))
}
}
impl InverseLinkKernel for CLogLogLinkKernel {
#[inline]
fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
Ok(component_inverse_link_jet(LinkComponent::CLogLog, eta))
}
}
impl InverseLinkKernel for LogLogLinkKernel {
#[inline]
fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
Ok(component_inverse_link_jet(LinkComponent::LogLog, eta))
}
}
impl InverseLinkKernel for CauchitLinkKernel {
#[inline]
fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
Ok(component_inverse_link_jet(LinkComponent::Cauchit, eta))
}
}
impl InverseLinkKernel for LinkComponent {
#[inline]
fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
Ok(component_inverse_link_jet(*self, eta))
}
}
/// Jet of a plain `LinkFunction`. The stateful variants (Sas, BetaLogistic)
/// cannot be evaluated without their parameters and return `InvalidInput`.
impl InverseLinkKernel for LinkFunction {
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
        let missing_state = |msg: &str| -> Result<InverseLinkJet, EstimationError> {
            Err(EstimationError::InvalidInput(msg.to_string()))
        };
        match *self {
            LinkFunction::Logit => LogitLinkKernel.jet(eta),
            LinkFunction::Probit => ProbitLinkKernel.jet(eta),
            LinkFunction::CLogLog => CLogLogLinkKernel.jet(eta),
            LinkFunction::Identity => Ok(InverseLinkJet {
                mu: eta,
                d1: 1.0,
                d2: 0.0,
                d3: 0.0,
            }),
            LinkFunction::Log => {
                // Clamp keeps exp() finite; every derivative of exp equals exp.
                let value = eta.clamp(-700.0, 700.0).exp();
                Ok(InverseLinkJet {
                    mu: value,
                    d1: value,
                    d2: value,
                    d3: value,
                })
            }
            LinkFunction::Sas => {
                missing_state("LinkFunction::Sas inverse-link requires explicit SAS link state")
            }
            LinkFunction::BetaLogistic => missing_state(
                "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state",
            ),
        }
    }
}
/// SAS inverse link evaluated with the state's (epsilon, log_delta).
impl InverseLinkKernel for SasLinkState {
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
        let (epsilon, log_delta) = (self.epsilon, self.log_delta);
        Ok(sas_inverse_link_jet(eta, epsilon, log_delta))
    }
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        let (epsilon, log_delta) = (self.epsilon, self.log_delta);
        let partials = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
        Ok(Some(LinkParamPartials::Sas(partials)))
    }
}
/// Beta-Logistic inverse link μ(η) = I_u(a, b), with u = σ(η),
/// a = exp(delta − epsilon) and b = exp(delta + epsilon).
///
/// `delta` enters only through `exp`, so it is a log-scale shape parameter;
/// callers in this module fill it from the state's `log_delta`.
#[derive(Clone, Copy, Debug)]
pub struct BetaLogisticKernel {
    pub delta: f64,
    pub epsilon: f64,
}
impl InverseLinkKernel for BetaLogisticKernel {
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
        let Self { delta, epsilon } = *self;
        Ok(beta_logistic_inverse_link_jet(eta, delta, epsilon))
    }
    /// Parameter partials are reported through the `Sas` variant, which has
    /// the same (epsilon, log-delta) layout.
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        let Self { delta, epsilon } = *self;
        let partials = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
        Ok(Some(LinkParamPartials::Sas(partials)))
    }
}
/// Weighted mixture of component inverse links.
impl InverseLinkKernel for MixtureLinkState {
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
        let mixed = mixture_inverse_link_jet(self, eta);
        Ok(mixed)
    }
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        let partials = mixture_inverse_link_jetwith_rho_partials(self, eta);
        Ok(Some(LinkParamPartials::Mixture(partials)))
    }
}
/// Dispatches to the kernel that matches each `InverseLink` variant.
impl InverseLinkKernel for InverseLink {
    fn jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
        match self {
            InverseLink::Standard(link_fn) => link_fn.jet(eta),
            InverseLink::LatentCLogLog(state) => latent_cloglog_point_jet(state, eta),
            InverseLink::Sas(state) => state.jet(eta),
            InverseLink::BetaLogistic(state) => {
                let kernel = BetaLogisticKernel {
                    delta: state.log_delta,
                    epsilon: state.epsilon,
                };
                kernel.jet(eta)
            }
            InverseLink::Mixture(state) => state.jet(eta),
        }
    }
    /// Only the parameterised links (SAS, Beta-Logistic, mixture) report partials.
    fn param_partials(&self, eta: f64) -> Result<Option<LinkParamPartials>, EstimationError> {
        match self {
            InverseLink::Standard(_) | InverseLink::LatentCLogLog(_) => Ok(None),
            InverseLink::Sas(state) => state.param_partials(eta),
            InverseLink::BetaLogistic(state) => {
                let kernel = BetaLogisticKernel {
                    delta: state.log_delta,
                    epsilon: state.epsilon,
                };
                kernel.param_partials(eta)
            }
            InverseLink::Mixture(state) => state.param_partials(eta),
        }
    }
}
/// Evaluates the jet of `link` at `eta`. This is a free-function form of
/// `InverseLinkKernel::jet`.
///
/// # Errors
/// Propagates any error from the underlying link evaluation.
pub fn inverse_link_jet_for_inverse_link(
    link: &InverseLink,
    eta: f64,
) -> Result<InverseLinkJet, EstimationError> {
    let jet = link.jet(eta)?;
    Ok(jet)
}
/// Cheaper evaluation of only (μ, dμ/dη) for an `InverseLink`.
///
/// # Errors
/// Fails for stateful `LinkFunction` variants passed as `Standard`, and
/// propagates latent-cloglog quadrature errors.
pub fn inverse_link_mu_d1_for_inverse_link(
    link: &InverseLink,
    eta: f64,
) -> Result<(f64, f64), EstimationError> {
    match link {
        InverseLink::Standard(link_fn) => link_function_mu_d1(*link_fn, eta),
        InverseLink::LatentCLogLog(state) => {
            latent_cloglog_point_jet(state, eta).map(|jet| (jet.mu, jet.d1))
        }
        InverseLink::Sas(state) => {
            let pair = sas_inverse_link_mu_d1(eta, state.epsilon, state.log_delta);
            Ok(pair)
        }
        InverseLink::BetaLogistic(state) => {
            let pair = beta_logistic_inverse_link_mu_d1(eta, state.log_delta, state.epsilon);
            Ok(pair)
        }
        InverseLink::Mixture(state) => {
            let pair = mixture_inverse_link_mu_d1(state, eta);
            Ok(pair)
        }
    }
}
/// (μ, dμ/dη) for a plain `LinkFunction`. The stateful variants return `InvalidInput`.
fn link_function_mu_d1(link: LinkFunction, eta: f64) -> Result<(f64, f64), EstimationError> {
    let component = match link {
        LinkFunction::Identity => return Ok((eta, 1.0)),
        LinkFunction::Log => {
            let value = eta.clamp(-700.0, 700.0).exp();
            return Ok((value, value));
        }
        LinkFunction::Logit => LinkComponent::Logit,
        LinkFunction::Probit => LinkComponent::Probit,
        LinkFunction::CLogLog => LinkComponent::CLogLog,
        LinkFunction::Sas => {
            return Err(EstimationError::InvalidInput(
                "LinkFunction::Sas inverse-link requires explicit SAS link state".to_string(),
            ));
        }
        LinkFunction::BetaLogistic => {
            return Err(EstimationError::InvalidInput(
                "LinkFunction::BetaLogistic inverse-link requires explicit Beta-Logistic link state"
                    .to_string(),
            ));
        }
    };
    Ok(component_inverse_link_mu_d1(component, eta))
}
#[inline]
fn component_inverse_link_mu_d1(component: LinkComponent, eta: f64) -> (f64, f64) {
match component {
LinkComponent::Logit => {
let jet = logit_inverse_link_jet5(eta);
(jet.mu, canonicalzero(jet.d1))
}
LinkComponent::Probit => {
if eta.is_nan() {
return (f64::NAN, f64::NAN);
}
if eta == f64::INFINITY {
return (1.0, 0.0);
}
if eta == f64::NEG_INFINITY {
return (0.0, 0.0);
}
let phi = normal_pdf(eta);
(normal_cdf(eta), canonicalzero(phi))
}
LinkComponent::CLogLog => {
if eta.is_nan() {
return (f64::NAN, f64::NAN);
}
let t = eta.exp();
if !t.is_finite() {
return (1.0, 0.0);
}
(
-(-t).exp_m1(),
canonicalzero(stable_nonnegative_poly_times_exp_neg(t, &[0.0, 1.0])),
)
}
LinkComponent::LogLog => {
if eta.is_nan() {
return (f64::NAN, f64::NAN);
}
let r = (-eta).exp();
if !r.is_finite() {
return (0.0, 0.0);
}
(
(-r).exp(),
canonicalzero(stable_nonnegative_poly_times_exp_neg(r, &[0.0, 1.0])),
)
}
LinkComponent::Cauchit => {
if eta.is_nan() {
return (f64::NAN, f64::NAN);
}
let den = 1.0 + eta * eta;
let d1 = if eta.is_finite() {
1.0 / (std::f64::consts::PI * den)
} else {
0.0
};
(0.5 + eta.atan() / std::f64::consts::PI, canonicalzero(d1))
}
}
}
/// Value and first η-derivative of the sinh-arcsinh (SAS) inverse link
///
///   μ(η) = Φ(sinh(u)),  u = tanh_bound(δ·asinh(η) − ε, SAS_U_CLAMP),
///
/// where δ = `sas_delta_from_raw_log_delta(log_delta)`.
///
/// When (ε, δ) is numerically the identity (|ε| < 1e-12 and |δ − 1| < 1e-12),
/// the plain probit link is used. NaN `eta` gives (NaN, NaN). `eta = +∞` gives
/// (1, 0) and `eta = −∞` gives (0, 0). These are the limits, because δ > 0 and
/// every stage (asinh, tanh_bound, sinh, Φ) is increasing, so μ is monotone in η.
///
/// Non-finite η used to be replaced by 0, which reported μ ≈ Φ(sinh(−ε)) with
/// a non-zero slope at ±∞. That contradicted the probit shortcut above.
fn sas_inverse_link_mu_d1(eta: f64, epsilon: f64, log_delta: f64) -> (f64, f64) {
    let delta = sas_delta_from_raw_log_delta(log_delta);
    if epsilon.abs() < 1e-12 && (delta - 1.0).abs() < 1e-12 {
        return component_inverse_link_mu_d1(LinkComponent::Probit, eta);
    }
    if eta.is_nan() {
        return (f64::NAN, f64::NAN);
    }
    if eta == f64::INFINITY {
        return (1.0, 0.0);
    }
    if eta == f64::NEG_INFINITY {
        return (0.0, 0.0);
    }
    let a = eta.asinh();
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    // d asinh(η)/dη = 1 / sqrt(1 + η²); hypot avoids overflow in η².
    let inv_q = 1.0 / eta.hypot(1.0);
    let r1 = delta * inv_q;
    let u1 = g1 * r1;
    // dz/dη with z = sinh(u).
    let z1 = u.cosh() * u1;
    let base = probit_jet(u.sinh());
    (base.mu, canonicalzero(base.d1 * z1))
}
/// (μ, dμ/dη) for the Beta-Logistic link, μ = I_u(a, b) with u = σ(η).
///
/// dμ/dη = u^a (1 − u)^b / B(a, b) is evaluated in log space. It is reported
/// as 0 wherever `logistic_uwith_derivatives` clamped u.
fn beta_logistic_inverse_link_mu_d1(eta: f64, delta: f64, epsilon: f64) -> (f64, f64) {
    let (u, du) = logistic_uwith_derivatives(eta);
    let (shape_a, shape_b) = ((delta - epsilon).exp(), (delta + epsilon).exp());
    let mu = beta_reg(shape_a, shape_b, u);
    let slope = if du == 0.0 {
        0.0
    } else {
        (shape_a * u.ln() + shape_b * (1.0 - u).ln() - ln_beta(shape_a, shape_b)).exp()
    };
    (mu, slope)
}
/// (μ, dμ/dη) of a mixture link: the π-weighted sum of component values and slopes.
/// Only the first `min(components, pi)` entries take part.
fn mixture_inverse_link_mu_d1(state: &MixtureLinkState, eta: f64) -> (f64, f64) {
    state
        .components
        .iter()
        .zip(state.pi.iter())
        .fold((0.0_f64, 0.0_f64), |(mu, d1), (&component, &weight)| {
            let (mu_i, d1_i) = component_inverse_link_mu_d1(component, eta);
            (mu + weight * mu_i, d1 + weight * d1_i)
        })
}
/// Which density derivative to evaluate: the third (fourth η-derivative of μ)
/// or the fourth (fifth η-derivative of μ).
#[derive(Clone, Copy)]
enum PdfDerivativeOrder {
    Third,
    Fourth,
}
impl PdfDerivativeOrder {
    /// `true` for [`PdfDerivativeOrder::Third`].
    #[inline]
    fn is_third(self) -> bool {
        matches!(self, Self::Third)
    }
    /// Probit density derivative of the selected order.
    fn probit(self, eta: f64) -> f64 {
        if self.is_third() {
            probit_pdfthird_derivative(eta)
        } else {
            probit_pdffourth_derivative(eta)
        }
    }
    /// Density derivative of the selected order for a built-in component.
    fn component(self, component: LinkComponent, eta: f64) -> f64 {
        if self.is_third() {
            component_inverse_link_pdfthird_derivative(component, eta)
        } else {
            component_inverse_link_pdffourth_derivative(component, eta)
        }
    }
    /// Latent-cloglog density derivative of the selected order, taken from the
    /// fifth-order jet (d4 or d5).
    fn latent_cloglog(self, eta: f64, latent_sd: f64) -> Result<f64, EstimationError> {
        let ctx = latent_cloglog_quadctx();
        let jet = latent_cloglog_jet5(ctx, eta, latent_sd)?;
        Ok(if self.is_third() { jet.d4 } else { jet.d5 })
    }
    /// SAS density derivative of the selected order.
    fn sas(self, eta: f64, epsilon: f64, log_delta: f64) -> f64 {
        if self.is_third() {
            sas_inverse_link_pdfthird_derivative(eta, epsilon, log_delta)
        } else {
            sas_inverse_link_pdffourth_derivative(eta, epsilon, log_delta)
        }
    }
    /// Beta-Logistic density derivative of the selected order.
    fn beta_logistic(self, eta: f64, delta: f64, epsilon: f64) -> f64 {
        if self.is_third() {
            beta_logistic_inverse_link_pdfthird_derivative(eta, delta, epsilon)
        } else {
            beta_logistic_inverse_link_pdffourth_derivative(eta, delta, epsilon)
        }
    }
}
/// Density derivative of the requested order for any `InverseLink`.
///
/// `Standard(Sas)` and `Standard(BetaLogistic)` are evaluated at zero
/// parameters. Mixtures return the π-weighted sum over their components.
fn inverse_link_pdf_derivative_for_inverse_link(
    link: &InverseLink,
    eta: f64,
    order: PdfDerivativeOrder,
) -> Result<f64, EstimationError> {
    let value = match link {
        InverseLink::Standard(link_fn) => match link_fn {
            LinkFunction::Identity => 0.0,
            LinkFunction::Log => eta.clamp(-700.0, 700.0).exp(),
            LinkFunction::Probit => order.probit(eta),
            LinkFunction::Logit => order.component(LinkComponent::Logit, eta),
            LinkFunction::CLogLog => order.component(LinkComponent::CLogLog, eta),
            LinkFunction::Sas => order.sas(eta, 0.0, 0.0),
            LinkFunction::BetaLogistic => order.beta_logistic(eta, 0.0, 0.0),
        },
        InverseLink::LatentCLogLog(state) => {
            return order.latent_cloglog(eta, state.latent_sd);
        }
        InverseLink::Sas(state) => order.sas(eta, state.epsilon, state.log_delta),
        InverseLink::BetaLogistic(state) => {
            order.beta_logistic(eta, state.log_delta, state.epsilon)
        }
        InverseLink::Mixture(state) => state
            .components
            .iter()
            .zip(state.pi.iter())
            .map(|(&component, &weight)| weight * order.component(component, eta))
            .sum::<f64>(),
    };
    Ok(value)
}
/// Third derivative of the density (fourth η-derivative of μ) for `link` at `eta`.
///
/// # Errors
/// Propagates latent-cloglog quadrature errors.
pub fn inverse_link_pdfthird_derivative_for_inverse_link(
    link: &InverseLink,
    eta: f64,
) -> Result<f64, EstimationError> {
    let order = PdfDerivativeOrder::Third;
    inverse_link_pdf_derivative_for_inverse_link(link, eta, order)
}
/// Fourth derivative of the density (fifth η-derivative of μ) for `link` at `eta`.
///
/// # Errors
/// Propagates latent-cloglog quadrature errors.
pub fn inverse_link_pdffourth_derivative_for_inverse_link(
    link: &InverseLink,
    eta: f64,
) -> Result<f64, EstimationError> {
    let order = PdfDerivativeOrder::Fourth;
    inverse_link_pdf_derivative_for_inverse_link(link, eta, order)
}
/// Jet of `link` at `eta`, using optional explicit link state.
///
/// Precedence: a mixture state wins outright. Otherwise a SAS-style state is
/// applied to the `Sas` and `BetaLogistic` links. Everything else falls back
/// to the plain `LinkFunction` kernel.
///
/// # Errors
/// Propagates errors from the selected kernel, e.g. a stateful link with no state.
pub fn inverse_link_jet_for_link_function(
    link: LinkFunction,
    eta: f64,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Result<InverseLinkJet, EstimationError> {
    match (mixture_link_state, sas_link_state) {
        (Some(mixture), _) => mixture.jet(eta),
        (None, Some(sas)) => match link {
            LinkFunction::BetaLogistic => {
                let kernel = BetaLogisticKernel {
                    delta: sas.log_delta,
                    epsilon: sas.epsilon,
                };
                kernel.jet(eta)
            }
            LinkFunction::Sas => sas.jet(eta),
            LinkFunction::Logit
            | LinkFunction::Probit
            | LinkFunction::CLogLog
            | LinkFunction::Identity
            | LinkFunction::Log => link.jet(eta),
        },
        (None, None) => link.jet(eta),
    }
}
/// Royston–Parmar survival jet: S(η) = exp(−e^η) with
/// S' = −H·S, S'' = H(H − 1)·S and S''' = (−H³ + 3H² − H)·S, where H = e^η.
///
/// η is clamped to ±30 when computing S. Outside that window, and for NaN
/// η, the derivatives are reported as 0.
#[inline]
fn royston_parmar_inverse_link_jet(eta: f64) -> InverseLinkJet {
    const SURVIVAL_ETA_CLAMP: f64 = 30.0;
    let hazard = eta.clamp(-SURVIVAL_ETA_CLAMP, SURVIVAL_ETA_CLAMP).exp();
    let survival = (-hazard).exp();
    let within_window = (-SURVIVAL_ETA_CLAMP..=SURVIVAL_ETA_CLAMP).contains(&eta);
    if within_window {
        InverseLinkJet {
            mu: survival,
            d1: -hazard * survival,
            d2: hazard * (hazard - 1.0) * survival,
            d3: (-hazard * hazard * hazard + 3.0 * hazard * hazard - hazard) * survival,
        }
    } else {
        InverseLinkJet {
            mu: survival,
            d1: 0.0,
            d2: 0.0,
            d3: 0.0,
        }
    }
}
/// Inverse-link jet for a likelihood specification. Royston–Parmar responses
/// use the survival jet; every other family uses the spec's link.
///
/// # Errors
/// Propagates errors from the link kernel.
pub fn inverse_link_jet_for_family(
    spec: &LikelihoodSpec,
    eta: f64,
) -> Result<InverseLinkJet, EstimationError> {
    match spec.response {
        ResponseFamily::RoystonParmar => Ok(royston_parmar_inverse_link_jet(eta)),
        _ => spec.link.jet(eta),
    }
}
/// Jet of a mixture link: the π-weighted sum of the component jets. Only the
/// first `min(components, pi)` entries contribute.
#[inline]
pub fn mixture_inverse_link_jet(state: &MixtureLinkState, eta: f64) -> InverseLinkJet {
    let zero = InverseLinkJet {
        mu: 0.0,
        d1: 0.0,
        d2: 0.0,
        d3: 0.0,
    };
    state
        .components
        .iter()
        .zip(state.pi.iter())
        .fold(zero, |acc, (&component, &weight)| {
            let part = component_inverse_link_jet(component, eta);
            InverseLinkJet {
                mu: acc.mu + weight * part.mu,
                d1: acc.d1 + weight * part.d1,
                d2: acc.d2 + weight * part.d2,
                d3: acc.d3 + weight * part.d3,
            }
        })
}
/// Allocating wrapper around `mixture_inverse_link_jetwith_rho_partials_into`.
/// Returns the mixture jet and one ∂jet/∂rho[j] per free logit.
pub fn mixture_inverse_link_jetwith_rho_partials(
    state: &MixtureLinkState,
    eta: f64,
) -> MixtureJetWithRhoPartials {
    let n_free = state
        .components
        .len()
        .min(state.pi.len())
        .saturating_sub(1);
    let blank = InverseLinkJet {
        mu: 0.0,
        d1: 0.0,
        d2: 0.0,
        d3: 0.0,
    };
    let mut djet_drho = vec![blank; n_free];
    let jet = mixture_inverse_link_jetwith_rho_partials_into(state, eta, &mut djet_drho);
    MixtureJetWithRhoPartials { jet, djet_drho }
}
/// Mixture jet plus its rho-partials, written into a caller-provided buffer.
///
/// With weights π = softmax(rho, 0), ∂jet/∂rho[j] = π_j (jet_j − jet_mix).
/// The first `k − 1` slots of `out` receive these partials, where
/// `k = min(components, pi)`; any extra slots are left untouched.
///
/// # Panics
/// Panics if `out` has fewer than `k − 1` slots.
pub fn mixture_inverse_link_jetwith_rho_partials_into(
    state: &MixtureLinkState,
    eta: f64,
    out: &mut [InverseLinkJet],
) -> InverseLinkJet {
    let k = state.components.len().min(state.pi.len());
    let m = k.saturating_sub(1);
    assert!(
        out.len() >= m,
        "rho-partial output buffer too small: got {}, need {}",
        out.len(),
        m
    );
    let free = &mut out[..m];
    let mut mixed = InverseLinkJet {
        mu: 0.0,
        d1: 0.0,
        d2: 0.0,
        d3: 0.0,
    };
    let pairs = state.components.iter().zip(state.pi.iter());
    for (i, (&component, &weight)) in pairs.enumerate() {
        let jet_i = component_inverse_link_jet(component, eta);
        mixed.mu += weight * jet_i.mu;
        mixed.d1 += weight * jet_i.d1;
        mixed.d2 += weight * jet_i.d2;
        mixed.d3 += weight * jet_i.d3;
        // Stash the raw component jet; it is turned into a partial below.
        if let Some(slot) = free.get_mut(i) {
            *slot = jet_i;
        }
    }
    for (slot, &pi_j) in free.iter_mut().zip(state.pi.iter()) {
        let cj = *slot;
        *slot = InverseLinkJet {
            mu: pi_j * (cj.mu - mixed.mu),
            d1: pi_j * (cj.d1 - mixed.d1),
            d2: pi_j * (cj.d2 - mixed.d2),
            d3: pi_j * (cj.d3 - mixed.d3),
        };
    }
    mixed
}
/// Logistic u = σ(η), clamped to [BETA_LOGISTIC_U_EPS, 1 − BETA_LOGISTIC_U_EPS],
/// together with du/dη = u(1 − u).
///
/// The derivative is reported as 0 whenever η is non-finite or the clamp changed u.
#[inline]
fn logistic_uwith_derivatives(eta: f64) -> (f64, f64) {
    let raw = logistic_stable(eta);
    let u = raw.clamp(BETA_LOGISTIC_U_EPS, 1.0 - BETA_LOGISTIC_U_EPS);
    let untouched = eta.is_finite() && u == raw;
    let du = if untouched { u * (1.0 - u) } else { 0.0 };
    (u, du)
}
/// Forward-mode dual number carrying a value and its partial derivatives with
/// respect to two shape parameters (`da`, `db`).
#[derive(Clone, Copy)]
struct ShapeDual {
    // Primal value.
    v: f64,
    // ∂v/∂a.
    da: f64,
    // ∂v/∂b.
    db: f64,
}
impl ShapeDual {
    /// A constant: the value `v` with zero partials.
    #[inline]
    fn constant(v: f64) -> Self {
        Self {
            v,
            da: 0.0,
            db: 0.0,
        }
    }
    /// A dual with explicitly supplied value and partials.
    #[inline]
    fn from_value_partials(v: f64, da: f64, db: f64) -> Self {
        Self { v, da, db }
    }
    /// If |v| < `floor`, replaces the whole dual with the constant `floor`.
    /// The partials are dropped in that case. Otherwise returns `self` unchanged.
    #[inline]
    fn clamp_small(self, floor: f64) -> Self {
        if self.v.abs() < floor {
            Self::constant(floor)
        } else {
            self
        }
    }
}
/// Sum rule.
impl std::ops::Add for ShapeDual {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self {
            v: self.v + rhs.v,
            da: self.da + rhs.da,
            db: self.db + rhs.db,
        }
    }
}
/// Difference rule.
impl std::ops::Sub for ShapeDual {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        Self {
            v: self.v - rhs.v,
            da: self.da - rhs.da,
            db: self.db - rhs.db,
        }
    }
}
/// Product rule.
impl std::ops::Mul for ShapeDual {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        Self {
            v: self.v * rhs.v,
            da: self.da * rhs.v + self.v * rhs.da,
            db: self.db * rhs.v + self.v * rhs.db,
        }
    }
}
/// Quotient rule: (f/g)' = (f'g − fg') / g², with 1/g computed once.
impl std::ops::Div for ShapeDual {
    type Output = Self;
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let inv = 1.0 / rhs.v;
        let inv2 = inv * inv;
        Self {
            v: self.v * inv,
            da: (self.da * rhs.v - self.v * rhs.da) * inv2,
            db: (self.db * rhs.v - self.v * rhs.db) * inv2,
        }
    }
}
/// Negation of the value and both partials.
impl std::ops::Neg for ShapeDual {
    type Output = Self;
    #[inline]
    fn neg(self) -> Self {
        Self {
            v: -self.v,
            da: -self.da,
            db: -self.db,
        }
    }
}
/// Shorthand for [`ShapeDual::constant`].
#[inline]
fn shape_dual(v: f64) -> ShapeDual {
    ShapeDual::constant(v)
}
/// Regularized incomplete beta I_x(a, b) together with ∂I/∂a and ∂I/∂b.
///
/// Returns `(I, ∂I/∂a, ∂I/∂b)` evaluated at `(a0, b0, x0)`. The value comes from
/// a Numerical-Recipes-style continued fraction (betacf) with `fpmin` guards
/// against vanishing denominators. The partials are obtained by carrying
/// [`ShapeDual`] numbers through the same recurrence. For `x0 <= 0` it returns
/// `(0, 0, 0)`, and for `x0 >= 1` it returns `(1, 0, 0)`.
///
/// NOTE(review): convergence is tested on the value only (`del.v`), so the
/// partials may be less converged than the value when the loop exits early.
fn beta_reg_with_shape_partials(a0: f64, b0: f64, x0: f64) -> (f64, f64, f64) {
    if x0 <= 0.0 {
        return (0.0, 0.0, 0.0);
    }
    if x0 >= 1.0 {
        return (1.0, 0.0, 0.0);
    }
    // Past the distribution's bulk, use I_x(a, b) = 1 − I_{1−x}(b, a). The
    // swapped duals keep their seeds, so partials stay w.r.t. the original a0, b0.
    let symm_transform = x0 >= (a0 + 1.0) / (a0 + b0 + 2.0);
    let (a, b, x) = if symm_transform {
        (
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            1.0 - x0,
        )
    } else {
        (
            ShapeDual::from_value_partials(a0, 1.0, 0.0),
            ShapeDual::from_value_partials(b0, 0.0, 1.0),
            x0,
        )
    };
    let ln_x = x.ln();
    let ln_1mx = (1.0 - x).ln();
    let psi_ab = digamma(a.v + b.v);
    // Prefactor bt = x^a (1 − x)^b / B(a, b), computed in log space.
    let log_bt = statrs::function::gamma::ln_gamma(a.v + b.v)
        - statrs::function::gamma::ln_gamma(a.v)
        - statrs::function::gamma::ln_gamma(b.v)
        + a.v * ln_x
        + b.v * ln_1mx;
    let bt_v = log_bt.exp();
    // ∂ log bt / ∂(current a) and ∂ log bt / ∂(current b).
    let log_bt_a = psi_ab - digamma(a.v) + ln_x;
    let log_bt_b = psi_ab - digamma(b.v) + ln_1mx;
    let bt = ShapeDual {
        v: bt_v,
        da: bt_v * (log_bt_a * a.da + log_bt_b * b.da),
        db: bt_v * (log_bt_a * a.db + log_bt_b * b.db),
    };
    // 2^-53: convergence tolerance for the continued fraction.
    let eps = 0.00000000000000011102230246251565;
    // Smallest magnitude allowed for c/d before they are replaced (Lentz-type guard).
    let fpmin = f64::MIN_POSITIVE / eps;
    let one = shape_dual(1.0);
    let qab = a + b;
    let qap = a + one;
    let qam = a - one;
    let mut c = one;
    let mut d = (one - qab * shape_dual(x) / qap).clamp_small(fpmin);
    d = one / d;
    let mut h = d;
    for m in 1..141 {
        let mf = f64::from(m);
        let m2 = mf * 2.0;
        let md = shape_dual(mf);
        let m2d = shape_dual(m2);
        // Even step of the continued fraction.
        let mut aa = md * (b - md) * shape_dual(x) / ((qam + m2d) * (a + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        h = h * d * c;
        // Odd step of the continued fraction.
        aa = -(a + md) * (qab + md) * shape_dual(x) / ((a + m2d) * (qap + m2d));
        d = (one + aa * d).clamp_small(fpmin);
        c = (one + aa / c).clamp_small(fpmin);
        d = one / d;
        let del = d * c;
        h = h * del;
        if (del.v - 1.0).abs() <= eps {
            let reg = bt * h / a;
            return if symm_transform {
                (1.0 - reg.v, -reg.da, -reg.db)
            } else {
                (reg.v, reg.da, reg.db)
            };
        }
    }
    // Iteration cap reached: return the best available approximation.
    let reg = bt * h / a;
    if symm_transform {
        (1.0 - reg.v, -reg.da, -reg.db)
    } else {
        (reg.v, reg.da, reg.db)
    }
}
/// Beta-Logistic inverse-link jet, μ = I_u(a, b) with u = σ(η),
/// a = exp(delta − epsilon) and b = exp(delta + epsilon).
///
/// With d1 = u^a (1 − u)^b / B(a, b) and t = a(1 − u) − b·u:
/// d2 = d1·t and d3 = d1·(t² − (a + b)·u(1 − u)). All derivatives are 0
/// where `logistic_uwith_derivatives` clamped u.
pub fn beta_logistic_inverse_link_jet(eta: f64, delta: f64, epsilon: f64) -> InverseLinkJet {
    let (u, du) = logistic_uwith_derivatives(eta);
    let shape_a = (delta - epsilon).exp();
    let shape_b = (delta + epsilon).exp();
    let mu = beta_reg(shape_a, shape_b, u);
    if du != 0.0 {
        let d1 = (shape_a * u.ln() + shape_b * (1.0 - u).ln() - ln_beta(shape_a, shape_b)).exp();
        let slope = shape_a * (1.0 - u) - shape_b * u;
        InverseLinkJet {
            mu,
            d1,
            d2: d1 * slope,
            d3: d1 * (slope * slope - (shape_a + shape_b) * du),
        }
    } else {
        InverseLinkJet {
            mu,
            d1: 0.0,
            d2: 0.0,
            d3: 0.0,
        }
    }
}
/// Fourth η-derivative of the Beta-Logistic inverse link (third density derivative):
/// d1·(t³ − 3(a+b)·t·u' − (a+b)·u''), with u' = u(1 − u) and u'' = u'(1 − 2u).
/// Returns 0 where u was clamped.
pub fn beta_logistic_inverse_link_pdfthird_derivative(eta: f64, delta: f64, epsilon: f64) -> f64 {
    let (u, du) = logistic_uwith_derivatives(eta);
    if du == 0.0 {
        return 0.0;
    }
    let (shape_a, shape_b) = ((delta - epsilon).exp(), (delta + epsilon).exp());
    let first = (shape_a * u.ln() + shape_b * (1.0 - u).ln() - ln_beta(shape_a, shape_b)).exp();
    let total = shape_a + shape_b;
    let slope = shape_a * (1.0 - u) - shape_b * u;
    let du2 = du * (1.0 - 2.0 * u);
    first * (slope * slope * slope - 3.0 * total * slope * du - total * du2)
}
/// Fifth eta-derivative of the beta-logistic mean (fourth derivative of its density).
///
/// Same building blocks as the fourth-derivative helper (`mu'`, `t = a (1 - u) - b u`,
/// `c = a + b`, `u'' = u' (1 - 2u)`), plus `u''' = u'' (1 - 2u) - 2 u'^2`:
///
/// `mu''''' = mu' * (t^4 - 6 c t^2 u' - 4 c t u'' + 3 c^2 u'^2 - c u''')`
///
/// Returns `0.0` when the logistic helper reports `u' == 0`.
pub fn beta_logistic_inverse_link_pdffourth_derivative(eta: f64, delta: f64, epsilon: f64) -> f64 {
    let (u, du) = logistic_uwith_derivatives(eta);
    if du == 0.0 {
        return 0.0;
    }
    let (shape_a, shape_b) = ((delta - epsilon).exp(), (delta + epsilon).exp());
    let pdf = (shape_a * u.ln() + shape_b * (1.0 - u).ln() - ln_beta(shape_a, shape_b)).exp();
    let total = shape_a + shape_b;
    let slope = shape_a * (1.0 - u) - shape_b * u;
    let slope_sq = slope * slope;
    let du2 = du * (1.0 - 2.0 * u);
    let du3 = du2 * (1.0 - 2.0 * u) - 2.0 * du * du;
    let poly = slope_sq * slope_sq - 6.0 * total * slope_sq * du - 4.0 * total * slope * du2
        + 3.0 * total * total * du * du
        - total * du3;
    pdf * poly
}
/// Beta-logistic inverse-link jet plus its partial derivatives with respect to the
/// link parameters `delta` and `epsilon`.
///
/// Shapes are `a = exp(delta - epsilon)` and `b = exp(delta + epsilon)`, so
/// `da/ddelta = a`, `db/ddelta = b`, `da/depsilon = -a`, `db/depsilon = b`.
/// The mean and its shape partials come from `beta_reg_with_shape_partials`. The
/// eta-derivatives `mu' = u^a (1 - u)^b / B(a, b)`, `mu'' = mu' t`,
/// `mu''' = mu' (t^2 - (a + b) u')` (with `t = a (1 - u) - b u`) are differentiated
/// directionally using `d log(mu')/da = ln u - psi(a) + psi(a + b)` and
/// `d log(mu')/db = ln(1 - u) - psi(b) + psi(a + b)`.
///
/// Note: the `delta` partials are stored in the `djet_dlog_delta` field. When the
/// logistic helper reports `u' == 0`, only the `mu` components are non-zero.
pub fn beta_logistic_inverse_link_jetwith_param_partials(
    eta: f64,
    delta: f64,
    epsilon: f64,
) -> SasJetWithParamPartials {
    let (u, du) = logistic_uwith_derivatives(eta);
    let shape_a = (delta - epsilon).exp();
    let shape_b = (delta + epsilon).exp();
    let (mu, dmu_da, dmu_db) = beta_reg_with_shape_partials(shape_a, shape_b, u);
    let dmu_ddelta = shape_a * dmu_da + shape_b * dmu_db;
    let dmu_depsilon = -shape_a * dmu_da + shape_b * dmu_db;

    // A jet that only carries a value in its `mu` slot.
    let mu_only = |value: f64| InverseLinkJet {
        mu: value,
        d1: 0.0,
        d2: 0.0,
        d3: 0.0,
    };
    if du == 0.0 {
        return SasJetWithParamPartials {
            jet: mu_only(mu),
            djet_depsilon: mu_only(dmu_depsilon),
            djet_dlog_delta: mu_only(dmu_ddelta),
        };
    }

    let ln_u = u.ln();
    let ln_1mu = (1.0 - u).ln();
    let pdf = (shape_a * ln_u + shape_b * ln_1mu - ln_beta(shape_a, shape_b)).exp();
    let slope = shape_a * (1.0 - u) - shape_b * u;
    let curvature = slope * slope - (shape_a + shape_b) * du;
    let jet = InverseLinkJet {
        mu,
        d1: pdf,
        d2: pdf * slope,
        d3: pdf * curvature,
    };

    let psi_sum = digamma(shape_a + shape_b);
    let dlog_pdf_da = ln_u - digamma(shape_a) + psi_sum;
    let dlog_pdf_db = ln_1mu - digamma(shape_b) + psi_sum;

    // Directional derivative of the jet along (da, db), with the matching mean partial.
    let directional = |da: f64, db: f64, dmu: f64| {
        let pdf_t = pdf * (da * dlog_pdf_da + db * dlog_pdf_db);
        let slope_t = da * (1.0 - u) - db * u;
        let curvature_t = 2.0 * slope * slope_t - (da + db) * du;
        InverseLinkJet {
            mu: dmu,
            d1: pdf_t,
            d2: pdf_t * slope + pdf * slope_t,
            d3: pdf_t * curvature + pdf * curvature_t,
        }
    };

    SasJetWithParamPartials {
        jet,
        djet_depsilon: directional(-shape_a, shape_b, dmu_depsilon),
        djet_dlog_delta: directional(shape_a, shape_b, dmu_ddelta),
    }
}
/// SAS (sinh-arcsinh) inverse-link jet `(mu, mu', mu'', mu''')` in `eta`.
///
/// The mean is `Phi(sinh(u))` with `u = tanh_bound(delta * asinh(eta) - epsilon,
/// SAS_U_CLAMP)`, where `delta` comes from the raw `log_delta` via
/// `sas_delta_from_raw_log_delta`. Derivatives are chained through
/// `w = delta * asinh(eta) - epsilon -> u -> z = sinh(u) -> Phi(z)`.
///
/// When `epsilon` is within `1e-12` of 0 and `delta` within `1e-12` of 1, the probit
/// component jet at the original `eta` is returned directly. Otherwise a non-finite
/// `eta` is evaluated as `0.0`.
pub fn sas_inverse_link_jet(eta: f64, epsilon: f64, log_delta: f64) -> InverseLinkJet {
    let delta = sas_delta_from_raw_log_delta(log_delta);
    let is_identity = epsilon.abs() < 1e-12 && (delta - 1.0).abs() < 1e-12;
    if is_identity {
        return component_inverse_link_jet(LinkComponent::Probit, eta);
    }

    let x = if eta.is_finite() { eta } else { 0.0 };
    let w = delta * x.asinh() - epsilon;

    // Eta-derivatives of w: delta * asinh^{(k)}(x) with q = sqrt(1 + x^2).
    let inv_q = 1.0 / x.hypot(1.0);
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let w1 = delta * inv_q;
    let w2 = -delta * x * inv_q3;
    let w3 = delta * (2.0 * x * x - 1.0) * inv_q5;

    // Through the bounding map u = tanh_bound(w).
    let u = tanh_bound(w, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(w, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(w, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(w, SAS_U_CLAMP);
    let u1 = g1 * w1;
    let u2 = g2 * w1 * w1 + g1 * w2;
    let u3 = g3 * w1 * w1 * w1 + 3.0 * g2 * w1 * w2 + g1 * w3;

    // Through z = sinh(u) (sinh' = cosh, cosh' = sinh).
    let (sh, ch) = (u.sinh(), u.cosh());
    let z1 = ch * u1;
    let z2 = sh * u1 * u1 + ch * u2;
    let z3 = ch * u1 * u1 * u1 + 3.0 * sh * u1 * u2 + ch * u3;

    chain_inverse_link_jet(probit_jet(sh), z1, z2, z3)
}
/// Fourth eta-derivative of the SAS mean (third derivative of its density).
///
/// The SAS mean is `mu = Phi(z)` with `z = sinh(u)` and
/// `u = tanh_bound(delta * asinh(eta) - epsilon, SAS_U_CLAMP)`. The result is built by
/// applying the order-4 Faa di Bruno formula along
/// `w = delta * asinh(eta) - epsilon -> u -> z -> Phi(z)`.
///
/// At the identity parameters (`|epsilon| < 1e-12`, `|delta - 1| < 1e-12`, the same
/// test `sas_inverse_link_jet` uses to return the plain probit jet) and finite `eta`,
/// the probit fourth derivative `probit_pdfthird_derivative(eta)` is returned. The
/// value is then the eta-derivative of the `d3` that `sas_inverse_link_jet` reports
/// there, rather than of the `tanh_bound`-ed curve below. Otherwise a non-finite `eta`
/// is evaluated as `0.0`. The result is passed through `canonicalzero`.
pub fn sas_inverse_link_pdfthird_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    let delta = sas_delta_from_raw_log_delta(log_delta);
    // Keep d4 consistent with the probit short-circuit in `sas_inverse_link_jet`.
    if eta.is_finite() && epsilon.abs() < 1e-12 && (delta - 1.0).abs() < 1e-12 {
        return canonicalzero(probit_pdfthird_derivative(eta));
    }
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    let base = probit_jet(z);
    // Eta-derivatives of w = delta * asinh(e), with q = sqrt(1 + e^2).
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    // Through u = tanh_bound(w).
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    // Through z = sinh(u): derivatives of sinh alternate cosh, sinh, cosh, sinh.
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    // Finally through Phi(z); base.d1..d3 are Phi'..Phi''', base4 is Phi''''.
    let base4 = probit_pdfthird_derivative(z);
    let out = base4 * z1.powi(4)
        + 6.0 * base.d3 * z1 * z1 * z2
        + 3.0 * base.d2 * z2 * z2
        + 4.0 * base.d2 * z1 * z3
        + base.d1 * z4;
    canonicalzero(out)
}
/// Fifth eta-derivative of the SAS mean (fourth derivative of its density).
///
/// Same chain as the fourth-derivative helper
/// (`w = delta * asinh(eta) - epsilon -> u = tanh_bound(w) -> z = sinh(u) -> Phi(z)`),
/// combined with the order-5 Faa di Bruno formula at every stage.
///
/// At the identity parameters (`|epsilon| < 1e-12`, `|delta - 1| < 1e-12`, matching
/// the probit short-circuit in `sas_inverse_link_jet`) and finite `eta`, the probit
/// fifth derivative `probit_pdffourth_derivative(eta)` is returned. This keeps the
/// derivative chain consistent with the jet reported there. Otherwise a non-finite
/// `eta` is evaluated as `0.0`. The result is passed through `canonicalzero`.
pub fn sas_inverse_link_pdffourth_derivative(eta: f64, epsilon: f64, log_delta: f64) -> f64 {
    let delta = sas_delta_from_raw_log_delta(log_delta);
    // Keep d5 consistent with the probit short-circuit in `sas_inverse_link_jet`.
    if eta.is_finite() && epsilon.abs() < 1e-12 && (delta - 1.0).abs() < 1e-12 {
        return canonicalzero(probit_pdffourth_derivative(eta));
    }
    let e = if eta.is_finite() { eta } else { 0.0 };
    let a = e.asinh();
    let u_raw = delta * a - epsilon;
    let u = tanh_bound(u_raw, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(u_raw, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(u_raw, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(u_raw, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(u_raw, SAS_U_CLAMP);
    let g5 = tanh_bound_d5(u_raw, SAS_U_CLAMP);
    let s = u.sinh();
    let c = u.cosh();
    let z = s;
    // Phi' .. Phi''' from the jet, Phi'''' and Phi''''' from the pdf-derivative helpers.
    let base = probit_jet(z);
    let phi3 = probit_pdfthird_derivative(z);
    let phi4 = probit_pdffourth_derivative(z);
    // Eta-derivatives of w = delta * asinh(e), with q = sqrt(1 + e^2).
    let q = e.hypot(1.0);
    let inv_q = 1.0 / q;
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let inv_q7 = inv_q5 * inv_q2;
    let inv_q9 = inv_q7 * inv_q2;
    let r1 = delta * inv_q;
    let r2 = -delta * e * inv_q3;
    let r3 = delta * (2.0 * e * e - 1.0) * inv_q5;
    let r4 = delta * e * (9.0 - 6.0 * e * e) * inv_q7;
    let r5 = delta * (9.0 - 72.0 * e * e + 24.0 * e * e * e * e) * inv_q9;
    // Through u = tanh_bound(w).
    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let u4 = g4 * r1.powi(4)
        + 6.0 * g3 * r1 * r1 * r2
        + 3.0 * g2 * r2 * r2
        + 4.0 * g2 * r1 * r3
        + g1 * r4;
    let u5 = g5 * r1.powi(5)
        + 10.0 * g4 * r1 * r1 * r1 * r2
        + 15.0 * g3 * r1 * r2 * r2
        + 10.0 * g3 * r1 * r1 * r3
        + 10.0 * g2 * r2 * r3
        + 5.0 * g2 * r1 * r4
        + g1 * r5;
    // Through z = sinh(u): derivatives of sinh alternate cosh, sinh, cosh, sinh, cosh.
    let z1 = c * u1;
    let z2 = s * u1 * u1 + c * u2;
    let z3 = c * u1 * u1 * u1 + 3.0 * s * u1 * u2 + c * u3;
    let z4 =
        s * u1.powi(4) + 6.0 * c * u1 * u1 * u2 + 3.0 * s * u2 * u2 + 4.0 * s * u1 * u3 + c * u4;
    let z5 = c * u1.powi(5)
        + 10.0 * s * u1 * u1 * u1 * u2
        + 15.0 * c * u1 * u2 * u2
        + 10.0 * c * u1 * u1 * u3
        + 10.0 * s * u2 * u3
        + 5.0 * s * u1 * u4
        + c * u5;
    // Finally through Phi(z).
    let out = phi4 * z1.powi(5)
        + 10.0 * phi3 * z1 * z1 * z1 * z2
        + 15.0 * base.d3 * z1 * z2 * z2
        + 10.0 * base.d3 * z1 * z1 * z3
        + 10.0 * base.d2 * z2 * z3
        + 5.0 * base.d2 * z1 * z4
        + base.d1 * z5;
    canonicalzero(out)
}
/// SAS inverse-link jet together with its partial derivatives with respect to the link
/// parameters `epsilon` and the raw `log_delta`.
///
/// The eta-jet follows the generic (non-short-circuit) branch of `sas_inverse_link_jet`:
/// `mu = Phi(sinh(u))`, `u = tanh_bound(w, SAS_U_CLAMP)`,
/// `w = delta * asinh(eta) - epsilon`. For each parameter `theta`, the whole chain
/// `w -> u -> z -> mu` (and its eta-derivatives up to third order) is differentiated
/// once more with respect to `theta`:
///
/// * `epsilon`:   `dw/dtheta = -1`; the eta-derivatives of `w` do not depend on it.
/// * `log_delta`: `dw/dtheta = (d delta / d log_delta) * asinh(eta)`, and the k-th
///   eta-derivative of `w` moves by `(d delta / d log_delta) * asinh^{(k)}(eta)`.
///
/// `delta = exp(ld_eff)` where `sas_effective_log_delta` returns the effective log-delta
/// and its derivative with respect to the raw value. Non-finite `eta` is evaluated as `0.0`.
pub fn sas_inverse_link_jetwith_param_partials(
    eta: f64,
    epsilon: f64,
    log_delta: f64,
) -> SasJetWithParamPartials {
    let x = if eta.is_finite() { eta } else { 0.0 };
    let asinh_x = x.asinh();
    let (effective_log_delta, deff_draw) = sas_effective_log_delta(log_delta);
    let delta = effective_log_delta.exp();
    let ddelta_draw = delta * deff_draw;

    let w = delta * asinh_x - epsilon;
    let u = tanh_bound(w, SAS_U_CLAMP);
    let g1 = tanh_bound_d1(w, SAS_U_CLAMP);
    let g2 = tanh_bound_d2(w, SAS_U_CLAMP);
    let g3 = tanh_bound_d3(w, SAS_U_CLAMP);
    let g4 = tanh_bound_d4(w, SAS_U_CLAMP);
    let (sh, ch) = (u.sinh(), u.cosh());

    // asinh^{(k)}(x) for k = 1..3, with q = sqrt(1 + x^2).
    let inv_q = 1.0 / x.hypot(1.0);
    let inv_q2 = inv_q * inv_q;
    let inv_q3 = inv_q2 * inv_q;
    let inv_q5 = inv_q3 * inv_q2;
    let (a1, a2, a3) = (inv_q, -x * inv_q3, (2.0 * x * x - 1.0) * inv_q5);
    let (r1, r2, r3) = (delta * a1, delta * a2, delta * a3);

    let u1 = g1 * r1;
    let u2 = g2 * r1 * r1 + g1 * r2;
    let u3 = g3 * r1 * r1 * r1 + 3.0 * g2 * r1 * r2 + g1 * r3;
    let z1 = ch * u1;
    let z2 = sh * u1 * u1 + ch * u2;
    let z3 = ch * u1 * u1 * u1 + 3.0 * sh * u1 * u2 + ch * u3;
    let base = probit_jet(sh);
    let phi3 = probit_pdfthird_derivative(sh);
    let jet = chain_inverse_link_jet(base, z1, z2, z3);

    // Parameter-derivatives of (u, u1, u2, u3) given the parameter-derivatives
    // (w_t, r1_t, r2_t, r3_t) of w and of its eta-derivatives.
    let u_partials = |w_t: f64, r1_t: f64, r2_t: f64, r3_t: f64| {
        (
            g1 * w_t,
            g2 * w_t * r1 + g1 * r1_t,
            g3 * w_t * r1 * r1 + 2.0 * g2 * r1 * r1_t + g2 * w_t * r2 + g1 * r2_t,
            g4 * w_t * r1 * r1 * r1
                + 3.0 * g3 * r1 * r1 * r1_t
                + 3.0 * g3 * w_t * r1 * r2
                + 3.0 * g2 * (r1_t * r2 + r1 * r2_t)
                + g2 * w_t * r3
                + g1 * r3_t,
        )
    };

    // Push (u_t, u1_t, u2_t, u3_t) through z = sinh(u) and then mu = Phi(z).
    let jet_partials = |(u_t, u1_t, u2_t, u3_t): (f64, f64, f64, f64)| {
        let z_t = ch * u_t;
        let z1_t = sh * u_t * u1 + ch * u1_t;
        let z2_t = ch * u_t * u1 * u1 + 2.0 * sh * u1 * u1_t + sh * u_t * u2 + ch * u2_t;
        let z3_t = sh * u_t * u1 * u1 * u1
            + 3.0 * ch * u1 * u1 * u1_t
            + 3.0 * ch * u_t * u1 * u2
            + 3.0 * sh * (u1_t * u2 + u1 * u2_t)
            + sh * u_t * u3
            + ch * u3_t;
        InverseLinkJet {
            mu: base.d1 * z_t,
            d1: base.d2 * z_t * z1 + base.d1 * z1_t,
            d2: base.d3 * z_t * z1 * z1
                + 2.0 * base.d2 * z1 * z1_t
                + base.d2 * z_t * z2
                + base.d1 * z2_t,
            d3: phi3 * z_t * z1.powi(3)
                + 3.0 * base.d3 * z1 * z1 * z1_t
                + 3.0 * base.d3 * z_t * z1 * z2
                + 3.0 * base.d2 * (z1_t * z2 + z1 * z2_t)
                + base.d2 * z_t * z3
                + base.d1 * z3_t,
        }
    };

    let djet_depsilon = jet_partials(u_partials(-1.0, 0.0, 0.0, 0.0));
    let djet_dlog_delta = jet_partials(u_partials(
        ddelta_draw * asinh_x,
        ddelta_draw * a1,
        ddelta_draw * a2,
        ddelta_draw * a3,
    ));
    SasJetWithParamPartials {
        jet,
        djet_depsilon,
        djet_dlog_delta,
    }
}
// Unit tests for the inverse-link kernels: analytic eta-jets, higher pdf derivatives
// and parameter partials are checked against central finite differences, plus
// saturation/tail checks for extreme eta.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): these names appear to be reachable through `use super::*` already;
// the explicit import is kept as-is.
use crate::types::{InverseLink, LinkComponent, LinkFunction, MixtureLinkSpec, SasLinkState};
// Softmax with the last logit fixed at zero: analytic Jacobian vs central differences.
#[test]
fn softmax_jacobian_matchesfd() {
let rho = Array1::from_vec(vec![0.7, -1.2, 0.4]);
let (pi, jac) = softmaxwith_jacobian_last_fixedzero(&rho);
let h = 1e-6;
for j in 0..rho.len() {
let mut rp = rho.clone();
rp[j] += h;
let mut rm = rho.clone();
rm[j] -= h;
let pp = softmax_last_fixedzero(&rp);
let pm = softmax_last_fixedzero(&rm);
let fd = (&pp - &pm).mapv(|v| v / (2.0 * h));
for k in 0..pi.len() {
let err = (jac[[k, j]] - fd[k]).abs();
assert_eq!(
jac[[k, j]].signum(),
fd[k].signum(),
"jac sign mismatch at ({k},{j}): analytic={} fd={}",
jac[[k, j]],
fd[k]
);
assert!(err < 5e-6, "jac mismatch at ({k},{j}): err={err:e}");
}
}
}
// Mixture-link rho partials vs central differences obtained by rebuilding the state
// from a perturbed spec.
// NOTE(review): the `signum` checks below (and in several later tests) are not guarded
// against near-zero values, unlike `all_component_eta_jets_matchfd`; they assume every
// checked quantity is clearly non-zero at the chosen point.
#[test]
fn mixture_jet_rho_partials_matchfd() {
let spec = MixtureLinkSpec {
components: vec![
LinkComponent::Probit,
LinkComponent::Logit,
LinkComponent::CLogLog,
LinkComponent::Cauchit,
],
initial_rho: Array1::from_vec(vec![0.3, -0.6, 0.2]),
};
let state = state_fromspec(&spec).expect("state");
let eta = 0.35;
let out = mixture_inverse_link_jetwith_rho_partials(&state, eta);
let h = 1e-6;
for j in 0..state.rho.len() {
let mut rp = state.rho.clone();
rp[j] += h;
let sp = MixtureLinkSpec {
components: state.components.clone(),
initial_rho: rp,
};
let jp = mixture_inverse_link_jet(&state_fromspec(&sp).expect("sp"), eta);
let mut rm = state.rho.clone();
rm[j] -= h;
let sm = MixtureLinkSpec {
components: state.components.clone(),
initial_rho: rm,
};
let jm = mixture_inverse_link_jet(&state_fromspec(&sm).expect("sm"), eta);
let fd = InverseLinkJet {
mu: (jp.mu - jm.mu) / (2.0 * h),
d1: (jp.d1 - jm.d1) / (2.0 * h),
d2: (jp.d2 - jm.d2) / (2.0 * h),
d3: (jp.d3 - jm.d3) / (2.0 * h),
};
let an = out.djet_drho[j];
assert_eq!(an.mu.signum(), fd.mu.signum());
assert_eq!(an.d1.signum(), fd.d1.signum());
assert_eq!(an.d2.signum(), fd.d2.signum());
assert_eq!(an.d3.signum(), fd.d3.signum());
assert!((an.mu - fd.mu).abs() < 1e-6);
assert!((an.d1 - fd.d1).abs() < 1e-6);
assert!((an.d2 - fd.d2).abs() < 1e-6);
assert!((an.d3 - fd.d3).abs() < 1e-6);
}
}
// SAS epsilon / log_delta partials vs central differences of `sas_inverse_link_jet`
// at a point away from the identity parameters.
#[test]
fn sas_param_partials_matchfd() {
let eta = 0.37;
let epsilon = -0.12;
let log_delta = 0.21;
let out = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
let h = 1e-6;
let ep_p = sas_inverse_link_jet(eta, epsilon + h, log_delta);
let ep_m = sas_inverse_link_jet(eta, epsilon - h, log_delta);
let fd_ep = InverseLinkJet {
mu: (ep_p.mu - ep_m.mu) / (2.0 * h),
d1: (ep_p.d1 - ep_m.d1) / (2.0 * h),
d2: (ep_p.d2 - ep_m.d2) / (2.0 * h),
d3: (ep_p.d3 - ep_m.d3) / (2.0 * h),
};
assert_eq!(out.djet_depsilon.mu.signum(), fd_ep.mu.signum());
assert_eq!(out.djet_depsilon.d1.signum(), fd_ep.d1.signum());
assert_eq!(out.djet_depsilon.d2.signum(), fd_ep.d2.signum());
assert_eq!(out.djet_depsilon.d3.signum(), fd_ep.d3.signum());
assert!((out.djet_depsilon.mu - fd_ep.mu).abs() < 5e-5);
assert!((out.djet_depsilon.d1 - fd_ep.d1).abs() < 5e-5);
assert!((out.djet_depsilon.d2 - fd_ep.d2).abs() < 5e-5);
assert!((out.djet_depsilon.d3 - fd_ep.d3).abs() < 5e-4);
let ld_p = sas_inverse_link_jet(eta, epsilon, log_delta + h);
let ld_m = sas_inverse_link_jet(eta, epsilon, log_delta - h);
let fd_ld = InverseLinkJet {
mu: (ld_p.mu - ld_m.mu) / (2.0 * h),
d1: (ld_p.d1 - ld_m.d1) / (2.0 * h),
d2: (ld_p.d2 - ld_m.d2) / (2.0 * h),
d3: (ld_p.d3 - ld_m.d3) / (2.0 * h),
};
assert_eq!(out.djet_dlog_delta.mu.signum(), fd_ld.mu.signum());
assert_eq!(out.djet_dlog_delta.d1.signum(), fd_ld.d1.signum());
assert_eq!(out.djet_dlog_delta.d2.signum(), fd_ld.d2.signum());
assert_eq!(out.djet_dlog_delta.d3.signum(), fd_ld.d3.signum());
assert!((out.djet_dlog_delta.mu - fd_ld.mu).abs() < 5e-5);
assert!((out.djet_dlog_delta.d1 - fd_ld.d1).abs() < 5e-5);
assert!((out.djet_dlog_delta.d2 - fd_ld.d2).abs() < 5e-5);
assert!((out.djet_dlog_delta.d3 - fd_ld.d3).abs() < 5e-4);
}
// Extreme eta / epsilon / log_delta combinations must not yield NaN or infinities in
// either the jet or its parameter partials.
#[test]
fn sas_jet_extreme_inputs_stay_finite() {
let cases = [
(-1e6, 0.0, 0.0),
(1e6, 0.0, 0.0),
(3.0, 12.0, 12.0),
(-3.0, -12.0, -12.0),
(0.5, 40.0, 10.0),
(0.5, -40.0, -10.0),
];
for (eta, eps, log_delta) in cases {
let j = sas_inverse_link_jet(eta, eps, log_delta);
assert!(j.mu.is_finite());
assert!(j.d1.is_finite());
assert!(j.d2.is_finite());
assert!(j.d3.is_finite());
let p = sas_inverse_link_jetwith_param_partials(eta, eps, log_delta);
assert!(p.djet_depsilon.mu.is_finite());
assert!(p.djet_depsilon.d1.is_finite());
assert!(p.djet_depsilon.d2.is_finite());
assert!(p.djet_depsilon.d3.is_finite());
assert!(p.djet_dlog_delta.mu.is_finite());
assert!(p.djet_dlog_delta.d1.is_finite());
assert!(p.djet_dlog_delta.d2.is_finite());
assert!(p.djet_dlog_delta.d3.is_finite());
}
}
// log_delta = 40 lies beyond SAS_LOG_DELTA_BOUND (12.0); the partials must still be
// finite (presumably because `sas_effective_log_delta` bounds it — verify there).
#[test]
fn sas_param_partials_remain_finite_in_extreme_region() {
let eta = 10.0;
let epsilon = -60.0;
let log_delta = 40.0;
let j = sas_inverse_link_jetwith_param_partials(eta, epsilon, log_delta);
assert!(j.djet_depsilon.mu.is_finite());
assert!(j.djet_depsilon.d1.is_finite());
assert!(j.djet_depsilon.d2.is_finite());
assert!(j.djet_depsilon.d3.is_finite());
assert!(j.djet_dlog_delta.mu.is_finite());
assert!(j.djet_dlog_delta.d1.is_finite());
assert!(j.djet_dlog_delta.d2.is_finite());
assert!(j.djet_dlog_delta.d3.is_finite());
}
// SAS eta-derivatives d1..d3 vs central differences of the lower-order entries.
#[test]
fn sas_eta_jets_matchfd() {
let eta = -0.43;
let epsilon = 0.27;
let log_delta = -0.31;
let h = 1e-5;
let j0 = sas_inverse_link_jet(eta, epsilon, log_delta);
let jp = sas_inverse_link_jet(eta + h, epsilon, log_delta);
let jm = sas_inverse_link_jet(eta - h, epsilon, log_delta);
let d1fd = (jp.mu - jm.mu) / (2.0 * h);
let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
assert_eq!(j0.d1.signum(), d1fd.signum());
assert_eq!(j0.d2.signum(), d2fd.signum());
assert_eq!(j0.d3.signum(), d3fd.signum());
assert!((j0.d1 - d1fd).abs() < 5e-5);
assert!((j0.d2 - d2fd).abs() < 2e-4);
assert!((j0.d3 - d3fd).abs() < 1e-3);
}
// `inverse_link_jet_for_family` must resolve parameterized (SAS, mixture) links
// carried inside a LikelihoodSpec.
#[test]
fn family_dispatch_resolves_parameterized_links_from_spec() {
let sas_state = SasLinkState::new(0.0, 0.0).expect("sas state");
let sas_spec = crate::types::LikelihoodSpec {
response: crate::types::ResponseFamily::Binomial,
link: InverseLink::Sas(sas_state),
};
let sas_jet = inverse_link_jet_for_family(&sas_spec, 0.1).expect("sas jet");
assert!(sas_jet.mu.is_finite());
assert!(sas_jet.d1.is_finite());
let mix_state = MixtureLinkState {
components: vec![LinkComponent::Logit, LinkComponent::Probit],
rho: ndarray::array![0.0],
pi: ndarray::array![0.5, 0.5],
};
let mix_spec = crate::types::LikelihoodSpec {
response: crate::types::ResponseFamily::Binomial,
link: InverseLink::Mixture(mix_state),
};
let mix_jet = inverse_link_jet_for_family(&mix_spec, 0.1).expect("mix jet");
assert!(mix_jet.mu.is_finite());
assert!(mix_jet.d1.is_finite());
}
// delta = epsilon = 0 gives shapes a = b = 1, and I_u(1, 1) = u, i.e. plain logit.
#[test]
fn beta_logistic_reduces_to_logit_at_delta0_epsilon0() {
let eta = 0.42;
let j_bl = beta_logistic_inverse_link_jet(eta, 0.0, 0.0);
let j_logit = component_inverse_link_jet(LinkComponent::Logit, eta);
assert!((j_bl.mu - j_logit.mu).abs() < 1e-10);
assert!((j_bl.d1 - j_logit.d1).abs() < 1e-10);
assert!((j_bl.d2 - j_logit.d2).abs() < 1e-10);
assert!((j_bl.d3 - j_logit.d3).abs() < 1e-10);
}
// Beta-logistic eta-derivatives d1..d3 vs central differences.
#[test]
fn beta_logistic_eta_jets_matchfd() {
let eta = -0.31;
let delta = 0.27;
let epsilon = -0.19;
let h = 1e-5;
let j0 = beta_logistic_inverse_link_jet(eta, delta, epsilon);
let jp = beta_logistic_inverse_link_jet(eta + h, delta, epsilon);
let jm = beta_logistic_inverse_link_jet(eta - h, delta, epsilon);
let d1fd = (jp.mu - jm.mu) / (2.0 * h);
let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
assert_eq!(j0.d1.signum(), d1fd.signum());
assert_eq!(j0.d2.signum(), d2fd.signum());
assert_eq!(j0.d3.signum(), d3fd.signum());
assert!((j0.d1 - d1fd).abs() < 5e-5);
assert!((j0.d2 - d2fd).abs() < 5e-5);
assert!((j0.d3 - d3fd).abs() < 2e-4);
}
// The standalone kernel structs must return exactly the component jets.
#[test]
fn standard_kernel_structs_match_component_jets() {
let eta = 0.73;
assert_eq!(
ProbitLinkKernel.jet(eta).expect("probit"),
component_inverse_link_jet(LinkComponent::Probit, eta)
);
assert_eq!(
LogitLinkKernel.jet(eta).expect("logit"),
component_inverse_link_jet(LinkComponent::Logit, eta)
);
assert_eq!(
CLogLogLinkKernel.jet(eta).expect("cloglog"),
component_inverse_link_jet(LinkComponent::CLogLog, eta)
);
assert_eq!(
LogLogLinkKernel.jet(eta).expect("loglog"),
component_inverse_link_jet(LinkComponent::LogLog, eta)
);
assert_eq!(
CauchitLinkKernel.jet(eta).expect("cauchit"),
component_inverse_link_jet(LinkComponent::Cauchit, eta)
);
}
// Every standard component's d1..d3 vs central differences over a grid of eta values.
// Sign checks are skipped when both values are at most 1e-10 in magnitude; the
// CLogLog/LogLog tolerances are looser.
#[test]
fn all_component_eta_jets_matchfd() {
let components = [
LinkComponent::Logit,
LinkComponent::Probit,
LinkComponent::CLogLog,
LinkComponent::LogLog,
LinkComponent::Cauchit,
];
let points = [-3.0, -1.1, -0.2, 0.0, 0.7, 1.8, 3.2];
let h = 1e-5;
for c in components {
for &eta in &points {
let j0 = component_inverse_link_jet(c, eta);
let jp = component_inverse_link_jet(c, eta + h);
let jm = component_inverse_link_jet(c, eta - h);
let d1fd = (jp.mu - jm.mu) / (2.0 * h);
let d2fd = (jp.d1 - jm.d1) / (2.0 * h);
let d3fd = (jp.d2 - jm.d2) / (2.0 * h);
let d1_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
1.2e-4
} else {
5e-5
};
let d2_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
4e-4
} else {
1.2e-4
};
let d3_tol = if matches!(c, LinkComponent::CLogLog | LinkComponent::LogLog) {
1.2e-3
} else {
4e-4
};
if j0.d1.abs().max(d1fd.abs()) > 1e-10 {
assert_eq!(
j0.d1.signum(),
d1fd.signum(),
"d1 sign mismatch for {c:?} eta={eta}"
);
}
if j0.d2.abs().max(d2fd.abs()) > 1e-10 {
assert_eq!(
j0.d2.signum(),
d2fd.signum(),
"d2 sign mismatch for {c:?} eta={eta}: analytic={} fd={}",
j0.d2,
d2fd
);
}
if j0.d3.abs().max(d3fd.abs()) > 1e-10 {
assert_eq!(
j0.d3.signum(),
d3fd.signum(),
"d3 sign mismatch for {c:?} eta={eta}"
);
}
assert!(
(j0.d1 - d1fd).abs() < d1_tol,
"d1 mismatch for {c:?} eta={eta}: analytic={} fd={}",
j0.d1,
d1fd
);
assert!(
(j0.d2 - d2fd).abs() < d2_tol,
"d2 mismatch for {c:?} eta={eta}: analytic={} fd={}",
j0.d2,
d2fd
);
assert!(
(j0.d3 - d3fd).abs() < d3_tol,
"d3 mismatch for {c:?} eta={eta}: analytic={} fd={}",
j0.d3,
d3fd
);
}
}
}
// At (epsilon, log_delta) = (0, 0) the SAS jet must agree with probit. If
// `sas_delta_from_raw_log_delta(0.0)` is exactly 1.0 this hits the probit short-circuit
// in `sas_inverse_link_jet`; the tolerances also admit the tanh-bounded formula.
#[test]
fn sas_center_matches_probit_at_delta1_epsilon0() {
let etas = [-3.0, -1.2, -0.3, 0.0, 0.4, 1.7, 3.0];
for eta in etas {
let sas = sas_inverse_link_jet(eta, 0.0, 0.0);
let probit = ProbitLinkKernel.jet(eta).expect("probit");
assert!(
(sas.mu - probit.mu).abs() < 6e-4,
"mu mismatch at eta={eta}"
);
assert!(
(sas.d1 - probit.d1).abs() < 6e-4,
"d1 mismatch at eta={eta}"
);
assert!(
(sas.d2 - probit.d2).abs() < 2e-3,
"d2 mismatch at eta={eta}"
);
assert!(
(sas.d3 - probit.d3).abs() < 4e-3,
"d3 mismatch at eta={eta}"
);
}
}
// Beta-logistic delta / epsilon partials vs central differences. The delta partials
// are stored in the `djet_dlog_delta` field.
#[test]
fn beta_logistic_param_partials_matchfd() {
let eta = -0.41;
let delta = 0.23;
let epsilon = -0.17;
let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
let h = 1e-6;
let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
let fd_delta = InverseLinkJet {
mu: (dp.mu - dm.mu) / (2.0 * h),
d1: (dp.d1 - dm.d1) / (2.0 * h),
d2: (dp.d2 - dm.d2) / (2.0 * h),
d3: (dp.d3 - dm.d3) / (2.0 * h),
};
assert_eq!(out.djet_dlog_delta.mu.signum(), fd_delta.mu.signum());
assert_eq!(out.djet_dlog_delta.d1.signum(), fd_delta.d1.signum());
assert_eq!(out.djet_dlog_delta.d2.signum(), fd_delta.d2.signum());
assert_eq!(out.djet_dlog_delta.d3.signum(), fd_delta.d3.signum());
assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
assert!((out.djet_dlog_delta.d1 - fd_delta.d1).abs() < 5e-5);
assert!((out.djet_dlog_delta.d2 - fd_delta.d2).abs() < 1.2e-4);
assert!((out.djet_dlog_delta.d3 - fd_delta.d3).abs() < 4e-4);
let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
let fd_epsilon = InverseLinkJet {
mu: (ep.mu - em.mu) / (2.0 * h),
d1: (ep.d1 - em.d1) / (2.0 * h),
d2: (ep.d2 - em.d2) / (2.0 * h),
d3: (ep.d3 - em.d3) / (2.0 * h),
};
assert_eq!(out.djet_depsilon.mu.signum(), fd_epsilon.mu.signum());
assert_eq!(out.djet_depsilon.d1.signum(), fd_epsilon.d1.signum());
assert_eq!(out.djet_depsilon.d2.signum(), fd_epsilon.d2.signum());
assert_eq!(out.djet_depsilon.d3.signum(), fd_epsilon.d3.signum());
assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
assert!((out.djet_depsilon.d1 - fd_epsilon.d1).abs() < 5e-5);
assert!((out.djet_depsilon.d2 - fd_epsilon.d2).abs() < 1.2e-4);
assert!((out.djet_depsilon.d3 - fd_epsilon.d3).abs() < 4e-4);
}
// At a saturated eta the eta-derivatives (and their parameter partials) are exactly
// zero, while the mu-partials still track finite differences.
#[test]
fn beta_logistic_param_partials_survive_eta_clamp() {
let eta = -1000.0;
let delta = -1.5;
let epsilon = 0.4;
let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
let h = 1e-6;
assert_eq!(out.jet.d1, 0.0);
assert_eq!(out.jet.d2, 0.0);
assert_eq!(out.jet.d3, 0.0);
let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
let fd_delta = InverseLinkJet {
mu: (dp.mu - dm.mu) / (2.0 * h),
d1: (dp.d1 - dm.d1) / (2.0 * h),
d2: (dp.d2 - dm.d2) / (2.0 * h),
d3: (dp.d3 - dm.d3) / (2.0 * h),
};
assert!((out.djet_dlog_delta.mu - fd_delta.mu).abs() < 5e-5);
assert_eq!(out.djet_dlog_delta.d1, 0.0);
assert_eq!(out.djet_dlog_delta.d2, 0.0);
assert_eq!(out.djet_dlog_delta.d3, 0.0);
let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
let fd_epsilon = InverseLinkJet {
mu: (ep.mu - em.mu) / (2.0 * h),
d1: (ep.d1 - em.d1) / (2.0 * h),
d2: (ep.d2 - em.d2) / (2.0 * h),
d3: (ep.d3 - em.d3) / (2.0 * h),
};
assert!((out.djet_depsilon.mu - fd_epsilon.mu).abs() < 5e-5);
assert_eq!(out.djet_depsilon.d1, 0.0);
assert_eq!(out.djet_depsilon.d2, 0.0);
assert_eq!(out.djet_depsilon.d3, 0.0);
}
// With eta clamped and mu below BETA_LOGISTIC_U_EPS, the mu-partials must be non-zero
// and match finite differences to 1e-12 (i.e. they are not flattened by the clamp).
#[test]
fn beta_logistic_param_partials_match_unclamped_mu_when_eta_clamps() {
let eta = -1000.0;
let delta = 0.01;
let epsilon = 0.0;
let out = beta_logistic_inverse_link_jetwith_param_partials(eta, delta, epsilon);
let h = 1e-6 * (1.0 + delta.abs());
assert!(out.jet.mu < BETA_LOGISTIC_U_EPS);
let dp = beta_logistic_inverse_link_jet(eta, delta + h, epsilon);
let dm = beta_logistic_inverse_link_jet(eta, delta - h, epsilon);
let fd_delta_mu = (dp.mu - dm.mu) / (2.0 * h);
assert!(fd_delta_mu != 0.0);
assert!((out.djet_dlog_delta.mu - fd_delta_mu).abs() < 1e-12);
let ep = beta_logistic_inverse_link_jet(eta, delta, epsilon + h);
let em = beta_logistic_inverse_link_jet(eta, delta, epsilon - h);
let fd_epsilon_mu = (ep.mu - em.mu) / (2.0 * h);
assert!(fd_epsilon_mu != 0.0);
assert!((out.djet_depsilon.mu - fd_epsilon_mu).abs() < 1e-12);
}
// The analytic fourth eta-derivative (third pdf derivative) from the dispatcher must
// match a central difference of the jet's d3 for every link kind.
#[test]
fn inverse_link_pdfthird_derivative_matches_d3_finite_difference() {
let sas = InverseLink::Sas(SasLinkState::new(-0.25, 0.35).expect("sas state"));
let beta_logistic = InverseLink::BetaLogistic(SasLinkState {
epsilon: 0.18,
log_delta: -0.22,
delta: (-0.22_f64).exp(),
});
let mixture = InverseLink::Mixture(
state_fromspec(&MixtureLinkSpec {
components: vec![
LinkComponent::Probit,
LinkComponent::Logit,
LinkComponent::CLogLog,
LinkComponent::Cauchit,
],
initial_rho: Array1::from_vec(vec![0.35, -0.45, 0.2]),
})
.expect("mixture state"),
);
let links = [
InverseLink::Standard(LinkFunction::Probit),
InverseLink::Standard(LinkFunction::Logit),
InverseLink::Standard(LinkFunction::CLogLog),
sas,
beta_logistic,
mixture,
];
let etas = [-1.1, -0.2, 0.6];
let h = 1e-5;
for link in &links {
for &eta in &etas {
let jp = inverse_link_jet_for_inverse_link(link, eta + h).expect("jet+");
let jm = inverse_link_jet_for_inverse_link(link, eta - h).expect("jet-");
let d4fd = (jp.d3 - jm.d3) / (2.0 * h);
let d4 = inverse_link_pdfthird_derivative_for_inverse_link(link, eta)
.expect("analytic d4");
assert_eq!(
d4.signum(),
d4fd.signum(),
"d4 sign mismatch for {:?} at eta={eta}: analytic={} fd={}",
link,
d4,
d4fd
);
assert!(
(d4 - d4fd).abs() < 5e-3,
"d4 mismatch for {:?} at eta={eta}: analytic={} fd={}",
link,
d4,
d4fd
);
}
}
}
// Large positive eta: cloglog saturates to mu = 1 with all derivatives exactly 0.
#[test]
fn cloglog_large_finite_eta_should_saturate_without_nan_derivatives() {
let eta = 800.0;
let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
assert_eq!(jet.mu, 1.0);
assert!(
jet.d1 == 0.0,
"for mu(eta)=1-exp(-exp(eta)), dmu/deta = exp(eta-exp(eta)) and should underflow to 0 at eta={eta}; got d1={}",
jet.d1
);
assert!(
jet.d2 == 0.0,
"the saturated cloglog second derivative should also be 0 at eta={eta}; got d2={}",
jet.d2
);
assert!(
jet.d3 == 0.0,
"the saturated cloglog third derivative should also be 0 at eta={eta}; got d3={}",
jet.d3
);
let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
&InverseLink::Standard(LinkFunction::CLogLog),
eta,
)
.expect("cloglog d4");
assert!(
d4 == 0.0,
"the saturated cloglog fourth derivative should also be 0 at eta={eta}; got d4={d4}"
);
}
// Large negative eta: loglog saturates to mu = 0 with zero derivatives, and a mixture
// dominated by loglog must still return a finite fourth derivative.
#[test]
fn loglog_large_negative_finite_eta_should_saturate_without_nan_derivatives() {
let eta = -800.0;
let jet = component_inverse_link_jet(LinkComponent::LogLog, eta);
assert_eq!(jet.mu, 0.0);
assert!(
jet.d1 == 0.0,
"for mu(eta)=exp(-exp(-eta)), dmu/deta = exp(-eta-exp(-eta)) and should underflow to 0 at eta={eta}; got d1={}",
jet.d1
);
assert!(
jet.d2 == 0.0,
"the saturated loglog second derivative should also be 0 at eta={eta}; got d2={}",
jet.d2
);
assert!(
jet.d3 == 0.0,
"the saturated loglog third derivative should also be 0 at eta={eta}; got d3={}",
jet.d3
);
let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
&InverseLink::Mixture(
state_fromspec(&MixtureLinkSpec {
components: vec![LinkComponent::LogLog, LinkComponent::Probit],
initial_rho: Array1::from_vec(vec![12.0]),
})
.expect("mixture state"),
),
eta,
)
.expect("loglog mixture d4");
assert!(
d4.is_finite(),
"even a nearly pure loglog mixture should not produce NaN fourth derivatives at eta={eta}; got d4={d4}"
);
}
// Deep in the upper logit tail (z = exp(-eta) tiny), d1..d5 must agree with the
// closed forms written in terms of z to an absolute 1e-30.
#[test]
fn logit_tail_derivatives_should_match_stable_closed_forms() {
let eta = 50.0_f64;
let z = (-eta).exp();
let denom = 1.0_f64 + z;
let stable_d1 = z / denom.powi(2);
let stable_d2 = z * (z - 1.0) / denom.powi(3);
let stable_d3 = z * (z * z - 4.0 * z + 1.0) / denom.powi(4);
let stable_d4 = z * (z * z * z - 11.0 * z * z + 11.0 * z - 1.0) / denom.powi(5);
let stable_d5 =
z * (z * z * z * z - 26.0 * z * z * z + 66.0 * z * z - 26.0 * z + 1.0) / denom.powi(6);
assert!(stable_d1 > 0.0);
assert!(stable_d2 < 0.0);
assert!(stable_d3 > 0.0);
assert!(stable_d4 < 0.0);
assert!(stable_d5 > 0.0);
let jet = component_inverse_link_jet(LinkComponent::Logit, eta);
assert!(
(jet.d1 - stable_d1).abs() < 1e-30,
"logit d1 should equal the stable tail formula z/(1+z)^2 at eta={eta}; got {} vs {}",
jet.d1,
stable_d1
);
assert!(
(jet.d2 - stable_d2).abs() < 1e-30,
"logit d2 should equal the stable tail formula z(z-1)/(1+z)^3 at eta={eta}; got {} vs {}",
jet.d2,
stable_d2
);
assert!(
(jet.d3 - stable_d3).abs() < 1e-30,
"logit d3 should equal the stable tail formula z(z^2-4z+1)/(1+z)^4 at eta={eta}; got {} vs {}",
jet.d3,
stable_d3
);
let d4 = inverse_link_pdfthird_derivative_for_inverse_link(
&InverseLink::Standard(LinkFunction::Logit),
eta,
)
.expect("logit d4");
assert!(
(d4 - stable_d4).abs() < 1e-30,
"logit d4 should equal the stable tail formula z(z^3-11z^2+11z-1)/(1+z)^5 at eta={eta}; got {} vs {}",
d4,
stable_d4
);
let d5 = inverse_link_pdffourth_derivative_for_inverse_link(
&InverseLink::Standard(LinkFunction::Logit),
eta,
)
.expect("logit d5");
assert!(
(d5 - stable_d5).abs() < 1e-30,
"logit d5 should equal the stable tail formula z(z^4-26z^3+66z^2-26z+1)/(1+z)^6 at eta={eta}; got {} vs {}",
d5,
stable_d5
);
}
// In the lower cloglog tail mu must match -expm1(-exp(eta)) (no cancellation in 1 - exp).
#[test]
fn cloglog_negative_tail_value_should_match_expm1_form() {
let eta = -50.0_f64;
let t = eta.exp();
let stable_mu = -(-t).exp_m1();
assert!(stable_mu > 0.0);
let jet = component_inverse_link_jet(LinkComponent::CLogLog, eta);
assert!(
(jet.mu - stable_mu).abs() < 1e-30,
"cloglog mu should equal -expm1(-exp(eta)) in the negative tail at eta={eta}; got {} vs {}",
jet.mu,
stable_mu
);
}
// LogLog fifth derivative at eta = 0 vs its closed form in r = exp(-eta).
#[test]
fn loglog_fifth_derivative_should_match_closed_form_sign() {
let eta = 0.0_f64;
let r = (-eta).exp();
let expected =
(-r).exp() * (r - 15.0 * r * r + 25.0 * r.powi(3) - 10.0 * r.powi(4) + r.powi(5));
let d5 = component_inverse_link_pdffourth_derivative(LinkComponent::LogLog, eta);
assert!(
(d5 - expected).abs() < 1e-15,
"loglog d5 should equal exp(-r) * (r - 15r^2 + 25r^3 - 10r^4 + r^5) at eta={eta}; got {d5} vs {expected}"
);
assert!(d5 > 0.0, "loglog d5 should be positive at eta=0; got {d5}");
}
}