use super::*;
/// Negated, scale-adjusted location term `-eta_t * exp_sigma_inverse_from_eta_scalar(eta_ls)`.
///
/// NOTE(review): presumably this is `-eta_t / sigma` with `sigma = exp(eta_ls)`; confirm
/// against `exp_sigma_inverse_from_eta_scalar`.
///
/// The magnitude is screened in log space before the product is formed: when
/// `ln|eta_t| + min(-eta_ls, EXP_NEG_STABLE_MAX_ARG)` exceeds `EXP_NEG_STABLE_MAX_ARG`
/// the result saturates to `f64::MAX` carrying the sign of `-eta_t`. An exactly zero
/// `eta_t` returns `0.0`, and a NaN `eta_t` propagates NaN.
#[inline]
pub(crate) fn survival_q0_from_eta(eta_t: f64, eta_ls: f64) -> f64 {
    // Short-circuit before ln|eta_t|, which would be -inf for a zero numerator.
    if eta_t == 0.0 {
        return 0.0;
    }
    let capped_neg_log_scale = (-eta_ls).min(EXP_NEG_STABLE_MAX_ARG);
    let log_magnitude = eta_t.abs().ln() + capped_neg_log_scale;
    // Written as `>` so that a NaN log-magnitude falls through to the direct product.
    let saturates = log_magnitude > EXP_NEG_STABLE_MAX_ARG;
    if !saturates {
        return -eta_t * exp_sigma_inverse_from_eta_scalar(eta_ls);
    }
    // eta_t is non-zero and not NaN here, so this is exactly -MAX for eta_t > 0
    // and +MAX for eta_t < 0.
    f64::MAX.copysign(-eta_t)
}
/// Standard-normal survival function `S(eta) = 1 - Phi(eta) = erfc(eta / sqrt(2)) / 2`.
///
/// NaN propagates; the infinities map to their exact limits (`+inf -> 0`, `-inf -> 1`).
#[inline]
pub(crate) fn probit_survival_value(eta: f64) -> f64 {
    if eta.is_nan() {
        return f64::NAN;
    }
    if eta.is_infinite() {
        // Exact tail limits, without evaluating erfc at an infinite argument.
        return if eta.is_sign_positive() { 0.0 } else { 1.0 };
    }
    0.5 * erfc(eta / std::f64::consts::SQRT_2)
}
/// Log survival and the standard-normal hazard (inverse Mills ratio) with its first three
/// derivatives.
///
/// With `S(eta) = 1 - Phi(eta)` and `r(eta) = phi(eta) / S(eta)`, returns
/// `(ln S, r, r', r'', r''')`, where `r' = r (r - eta)` and the higher derivatives follow
/// by repeated differentiation.
///
/// For `eta >= 0` the scaled complementary error function `erfcx(x) = exp(x^2) erfc(x)`
/// avoids underflow of both `S` and `phi`:
/// `ln S = -eta^2/2 + ln(erfcx(eta/sqrt2)/2)` and `r = sqrt(2/pi) / erfcx(eta/sqrt2)`.
///
/// Non-finite inputs: NaN propagates to every component, `-inf` returns all zeros, and
/// `+inf` returns the limits `(-inf, +inf, 1, 0, 0)`.
///
/// NOTE(review): for large finite positive `eta` the polynomial forms of `r''` and `r'''`
/// cancel catastrophically. Their terms are of size `eta^3` and `eta^4`, while the results
/// are about `2/eta^3` and `-6/eta^4`. The relative error therefore grows roughly like
/// `eps*eta^6` and `eps*eta^8`: `r'''` is noticeably inaccurate past `eta ~ 30` and
/// meaningless near `eta ~ 100`. Confirm callers tolerate this, or switch to an
/// asymptotic form.
#[inline]
pub(crate) fn probit_log_survival_and_ratio_derivatives(eta: f64) -> (f64, f64, f64, f64, f64) {
    if eta.is_nan() {
        return (f64::NAN, f64::NAN, f64::NAN, f64::NAN, f64::NAN);
    }
    if eta == f64::NEG_INFINITY {
        // S -> 1 and phi -> 0: log S, r and every derivative of r vanish.
        return (0.0, 0.0, 0.0, 0.0, 0.0);
    }
    if eta == f64::INFINITY {
        // r(eta) = eta + 1/eta - 2/eta^3 + ..., so r -> inf, r' -> 1, r'' -> 0, r''' -> 0.
        // Running the formulas below would evaluate `ratio - eta` as inf - inf (NaN).
        return (f64::NEG_INFINITY, f64::INFINITY, 1.0, 0.0, 0.0);
    }
    let x = eta / std::f64::consts::SQRT_2;
    let (log_survival, ratio) = if eta >= 0.0 {
        // Upper half-line: work with the scaled erfcx so the tail stays representable.
        let erfcx_val = erfcx_nonnegative(x);
        let log_surv = -0.5 * eta * eta + (0.5 * erfcx_val).ln();
        let r = std::f64::consts::FRAC_2_SQRT_PI / (std::f64::consts::SQRT_2 * erfcx_val);
        (log_surv, r)
    } else {
        // Lower half-line: S lies in (0.5, 1], so the direct quotient is well conditioned.
        let survival = probit_survival_value(eta);
        (survival.ln(), normal_pdf(eta) / survival)
    };
    // r' = r (r - eta)
    let dr = ratio * (ratio - eta);
    // r'' = 2 r^3 - 3 eta r^2 + (eta^2 - 1) r
    let ddr = 2.0 * ratio.powi(3) - 3.0 * eta * ratio.powi(2) + (eta * eta - 1.0) * ratio;
    // r''' = 6 r^4 - 12 eta r^3 + (7 eta^2 - 4) r^2 + (3 eta - eta^3) r
    let dddr = 6.0 * ratio.powi(4) - 12.0 * eta * ratio.powi(3)
        + (7.0 * eta * eta - 4.0) * ratio.powi(2)
        + (-eta * eta * eta + 3.0 * eta) * ratio;
    (log_survival, ratio, dr, ddr, dddr)
}
/// Standardized residual distribution. Each variant is tied to one standard link
/// (see `residual_distribution_link`), and that link's inverse supplies the CDF.
///
/// NOTE(review): index-based serde formats encode unit variants by position, so the
/// variant order is presumably part of any persisted form; avoid reordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ResidualDistribution {
    /// Standard normal residuals; paired with `StandardLink::Probit`.
    Gaussian,
    /// Residuals whose CDF is the complementary log-log inverse link; paired with
    /// `StandardLink::CLogLog`.
    Gumbel,
    /// Standard logistic residuals; paired with `StandardLink::Logit`.
    Logistic,
}
/// Distribution-function and density derivatives of a standardized residual law.
///
/// Every method is evaluated at the standardized residual `z`.
pub trait ResidualDistributionOps {
    /// Cumulative distribution function.
    fn cdf(&self, z: f64) -> f64;
    /// Probability density function, i.e. the first derivative of `cdf`.
    fn pdf(&self, z: f64) -> f64;
    /// First derivative of `pdf` with respect to `z`.
    fn pdf_derivative(&self, z: f64) -> f64;
    /// Second derivative of `pdf` with respect to `z`.
    fn pdfsecond_derivative(&self, z: f64) -> f64;
    /// Third derivative of `pdf` with respect to `z`.
    fn pdfthird_derivative(&self, z: f64) -> f64;
    /// Fourth derivative of `pdf` with respect to `z`.
    fn pdffourth_derivative(&self, z: f64) -> f64;
}
// Gaussian uses closed forms: phi^(k)(z) = (-1)^k He_k(z) phi(z) with the probabilists'
// Hermite polynomials He_k. Gumbel and Logistic read `mu` and its derivatives
// (`d1`, `d2`, `d3`) from the matching inverse-link jet. Their third and fourth pdf
// derivatives come from the dedicated inverse-link evaluators.
impl ResidualDistributionOps for ResidualDistribution {
    fn cdf(&self, z: f64) -> f64 {
        let component = match self {
            ResidualDistribution::Gaussian => return normal_cdf(z),
            ResidualDistribution::Gumbel => gam_problem::LinkComponent::CLogLog,
            ResidualDistribution::Logistic => gam_problem::LinkComponent::Logit,
        };
        component_inverse_link_jet(component, z).mu
    }

    fn pdf(&self, z: f64) -> f64 {
        let component = match self {
            ResidualDistribution::Gaussian => return normal_pdf(z),
            ResidualDistribution::Gumbel => gam_problem::LinkComponent::CLogLog,
            ResidualDistribution::Logistic => gam_problem::LinkComponent::Logit,
        };
        component_inverse_link_jet(component, z).d1
    }

    fn pdf_derivative(&self, z: f64) -> f64 {
        let component = match self {
            // phi'(z) = -z phi(z)
            ResidualDistribution::Gaussian => return -z * normal_pdf(z),
            ResidualDistribution::Gumbel => gam_problem::LinkComponent::CLogLog,
            ResidualDistribution::Logistic => gam_problem::LinkComponent::Logit,
        };
        component_inverse_link_jet(component, z).d2
    }

    fn pdfsecond_derivative(&self, z: f64) -> f64 {
        let component = match self {
            ResidualDistribution::Gaussian => {
                // phi''(z) = (z^2 - 1) phi(z)
                let density = normal_pdf(z);
                return (z * z - 1.0) * density;
            }
            ResidualDistribution::Gumbel => gam_problem::LinkComponent::CLogLog,
            ResidualDistribution::Logistic => gam_problem::LinkComponent::Logit,
        };
        component_inverse_link_jet(component, z).d3
    }

    fn pdfthird_derivative(&self, z: f64) -> f64 {
        let (link, failure) = match self {
            ResidualDistribution::Gaussian => {
                // phi'''(z) = -(z^3 - 3z) phi(z)
                let density = normal_pdf(z);
                return -(z * z * z - 3.0 * z) * density;
            }
            ResidualDistribution::Gumbel => (
                InverseLink::Standard(StandardLink::CLogLog),
                "standard cloglog inverse-link third derivative should evaluate",
            ),
            ResidualDistribution::Logistic => (
                InverseLink::Standard(StandardLink::Logit),
                "standard logit inverse-link third derivative should evaluate",
            ),
        };
        inverse_link_pdfthird_derivative_for_inverse_link(&link, z).expect(failure)
    }

    fn pdffourth_derivative(&self, z: f64) -> f64 {
        let (link, failure) = match self {
            ResidualDistribution::Gaussian => {
                // phi''''(z) = (z^4 - 6 z^2 + 3) phi(z)
                let density = normal_pdf(z);
                let z_sq = z * z;
                return (z_sq * z_sq - 6.0 * z_sq + 3.0) * density;
            }
            ResidualDistribution::Gumbel => (
                InverseLink::Standard(StandardLink::CLogLog),
                "standard cloglog inverse-link fourth derivative should evaluate",
            ),
            ResidualDistribution::Logistic => (
                InverseLink::Standard(StandardLink::Logit),
                "standard logit inverse-link fourth derivative should evaluate",
            ),
        };
        inverse_link_pdffourth_derivative_for_inverse_link(&link, z).expect(failure)
    }
}
#[inline]
pub(crate) fn residual_distribution_link(distribution: ResidualDistribution) -> StandardLink {
match distribution {
ResidualDistribution::Gaussian => StandardLink::Probit,
ResidualDistribution::Gumbel => StandardLink::CLogLog,
ResidualDistribution::Logistic => StandardLink::Logit,
}
}
/// Inverse link paired with `distribution`, always an `InverseLink::Standard` wrapper around
/// the standard link chosen by `residual_distribution_link`.
#[inline]
pub fn residual_distribution_inverse_link(distribution: ResidualDistribution) -> InverseLink {
    let standard = residual_distribution_link(distribution);
    InverseLink::Standard(standard)
}
/// Residual distribution paired with `link`, if any.
///
/// Only the standard probit, complementary log-log and logit links map to a distribution.
/// Every other standard link, and every non-standard inverse link, yields `None`.
#[inline]
pub fn residual_distribution_from_inverse_link(link: &InverseLink) -> Option<ResidualDistribution> {
    let standard = match link {
        InverseLink::Standard(standard) => standard,
        _ => return None,
    };
    match standard {
        StandardLink::Probit => Some(ResidualDistribution::Gaussian),
        StandardLink::CLogLog => Some(ResidualDistribution::Gumbel),
        StandardLink::Logit => Some(ResidualDistribution::Logistic),
        _ => None,
    }
}
pub(crate) fn inverse_link_pdffourth_derivative(
inverse_link: &InverseLink,
eta: f64,
) -> Result<f64, SurvivalLocationScaleError> {
match inverse_link {
InverseLink::Standard(StandardLink::Probit) => {
Ok(ResidualDistribution::Gaussian.pdffourth_derivative(eta))
}
InverseLink::Standard(StandardLink::Logit) => {
Ok(ResidualDistribution::Logistic.pdffourth_derivative(eta))
}
InverseLink::Standard(StandardLink::CLogLog) => {
Ok(ResidualDistribution::Gumbel.pdffourth_derivative(eta))
}
_ => gam_solve::mixture_link::inverse_link_pdffourth_derivative_for_inverse_link(
inverse_link,
eta,
)
.map_err(|e| SurvivalLocationScaleError::NumericalFailure {
reason: format!("inverse link fourth-derivative evaluation failed at eta={eta}: {e}"),
}),
}
}