use crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link;
use crate::probability::signed_probit_logcdf_and_mills_ratio;
use crate::types::{InverseLink, StandardLink};
/// Derivatives of the per-observation binomial log-likelihood with respect to `mu`.
///
/// The log-likelihood is `ℓ(μ) = y·ln μ + (1 − y)·ln(1 − μ)`, and the returned tuple is
/// `(ℓ′, ℓ″, ℓ‴, ℓ⁗)`. Each of the two terms contributes exactly zero when its count
/// (`y` or `1 − y`) is zero. A compatible boundary such as `y = 0, μ = 0` or
/// `y = 1, μ = 1` therefore yields finite values instead of `0 · ∞`.
/// No range check is applied to `mu`.
#[inline]
pub(super) fn binomial_loglik_mu_derivatives(y: f64, mu: f64) -> (f64, f64, f64, f64) {
    /// First four μ-derivatives of one log term, given its count, the denominator
    /// (`μ` or `1 − μ`), and the signed factorial coefficient of each order.
    #[inline]
    fn term_jet(count: f64, denom: f64, coeffs: [f64; 4]) -> [f64; 4] {
        if count == 0.0 {
            return [0.0; 4];
        }
        let inv = 1.0 / denom;
        let base = count * inv;
        let mut out = [0.0; 4];
        for (order, (slot, &coeff)) in out.iter_mut().zip(coeffs.iter()).enumerate() {
            // Apply the coefficient first, then one extra `inv` factor per order.
            // This keeps the exact left-to-right rounding of `c * base * inv * inv ...`.
            let mut value = coeff * base;
            for _ in 0..order {
                value *= inv;
            }
            *slot = value;
        }
        out
    }

    // y·ln μ  →  y/μ, −y/μ², 2y/μ³, −6y/μ⁴
    let success = term_jet(y, mu, [1.0, -1.0, 2.0, -6.0]);
    // (1 − y)·ln(1 − μ)  →  −(1−y)/(1−μ), −(1−y)/(1−μ)², −2(1−y)/(1−μ)³, −6(1−y)/(1−μ)⁴
    let failure = term_jet(1.0 - y, 1.0 - mu, [-1.0, -1.0, -2.0, -6.0]);
    (
        success[0] + failure[0],
        success[1] + failure[1],
        success[2] + failure[2],
        success[3] + failure[3],
    )
}
/// Whether `mu` lies strictly inside the open interval `(0, 1)`.
///
/// Both endpoints (including `-0.0`) are rejected. NaN is rejected as well, because
/// every ordered comparison involving NaN is false.
#[inline]
fn binomial_mu_is_interior(mu: f64) -> bool {
    let above_zero = 0.0 < mu;
    let below_one = mu < 1.0;
    above_zero && below_one
}
/// Chain-rule derivatives of the weighted binomial log-likelihood with respect to the
/// linear predictor `q`.
///
/// `d1`, `d2`, `d3` are the first three derivatives of `μ(q)` (the inverse-link jet).
/// The result is `(w·dℓ/dq, −w·d²ℓ/dq², w·d³ℓ/dq³)`: the score, the curvature (only
/// the second component is sign-flipped), and the third derivative.
///
/// All three are zero when `weight == 0.0` or `mu` is not strictly inside `(0, 1)`.
/// Non-finite `d1..d3` are not screened here.
#[inline]
pub(super) fn binomial_score_curvaturethird_from_jet(
    y: f64,
    weight: f64,
    mu: f64,
    d1: f64,
    d2: f64,
    d3: f64,
) -> (f64, f64, f64) {
    let active = weight != 0.0 && binomial_mu_is_interior(mu);
    if !active {
        return (0.0, 0.0, 0.0);
    }
    let (l1, l2, l3, _) = binomial_loglik_mu_derivatives(y, mu);
    // Faà di Bruno up to order 3 for ℓ(μ(q)).
    let score = weight * l1 * d1;
    let second = l2 * d1 * d1 + l1 * d2;
    let third_terms = l3 * d1 * d1 * d1 + 3.0 * l2 * d1 * d2 + l1 * d3;
    (score, -(weight * second), weight * third_terms)
}
/// `(m1, m2, m3)`: the first three `q`-derivatives of the weighted *negative*
/// binomial log-likelihood, computed from the inverse-link jet `d1..d3`.
///
/// Negating the log-likelihood flips the odd orders. The curvature from
/// `binomial_score_curvaturethird_from_jet` is already the negated second
/// derivative, so it passes through unchanged.
#[inline]
pub(super) fn binomial_neglog_q_derivatives_from_jet(
    y: f64,
    weight: f64,
    mu: f64,
    d1: f64,
    d2: f64,
    d3: f64,
) -> (f64, f64, f64) {
    let jet = binomial_score_curvaturethird_from_jet(y, weight, mu, d1, d2, d3);
    (-jet.0, jet.1, -jet.2)
}
/// Closed-form `(m1, m2, m3)` for the weighted negative binomial log-likelihood under
/// the probit link, differentiated with respect to `q`.
///
/// The objective is `−w[y·ln Φ(q) + (1 − y)·ln Φ(−q)]`. The identities below
/// (`λ′ = −λ(q + λ)`, `r′ = r(r − q)`) assume that the second component of
/// `signed_probit_logcdf_and_mills_ratio(t)` is the inverse Mills ratio
/// `λ(t) = φ(t)/Φ(t)`. NOTE(review): confirm against that helper.
///
/// Returns zeros when `weight == 0.0` or `q` is not finite.
#[inline]
pub(super) fn binomial_neglog_q_derivatives_probit_closed_form(
    y: f64,
    weight: f64,
    q: f64,
) -> (f64, f64, f64) {
    let skip = weight == 0.0 || !q.is_finite();
    if skip {
        return (0.0, 0.0, 0.0);
    }
    let (_, lam_up) = signed_probit_logcdf_and_mills_ratio(q);
    let (_, lam_down) = signed_probit_logcdf_and_mills_ratio(-q);

    // Success tail −ln Φ(q): first derivative −λ, then −λ′, then −λ″.
    let dlam_up = -lam_up * (q + lam_up);
    let succ_d2 = -dlam_up;
    let succ_d3 = lam_up + dlam_up * (q + 2.0 * lam_up);

    // Failure tail −ln Φ(−q) with r(q) = λ(−q): first derivative r, then r′, then r″.
    let dlam_down = lam_down * (lam_down - q);
    let fail_d3 = dlam_down * (2.0 * lam_down - q) - lam_down;

    let failures = 1.0 - y;
    (
        weight * (failures * lam_down - y * lam_up),
        weight * (failures * dlam_down + y * succ_d2),
        weight * (failures * fail_d3 + y * succ_d3),
    )
}
/// Closed-form fourth `q`-derivative of the weighted negative binomial log-likelihood
/// under the probit link.
///
/// This extends the third-order identities by one more differentiation. It relies on
/// the same assumption that `signed_probit_logcdf_and_mills_ratio` returns the inverse
/// Mills ratio `φ/Φ` as its second component (NOTE(review): confirm).
///
/// Returns `0.0` when `weight == 0.0` or `q` is not finite.
#[inline]
pub(super) fn binomial_neglog_q_fourth_derivative_probit_closed_form(
    y: f64,
    weight: f64,
    q: f64,
) -> f64 {
    if weight == 0.0 || !q.is_finite() {
        return 0.0;
    }
    let (_, lam_up) = signed_probit_logcdf_and_mills_ratio(q);
    let (_, lam_down) = signed_probit_logcdf_and_mills_ratio(-q);

    // Success tail −ln Φ(q): the shift (q + 2λ) appears in both the third and fourth orders.
    let dlam_up = -lam_up * (q + lam_up);
    let shift_up = q + 2.0 * lam_up;
    let succ_d3 = lam_up + dlam_up * shift_up;
    let succ_d4 = 2.0 * dlam_up - succ_d3 * shift_up + 2.0 * dlam_up * dlam_up;

    // Failure tail −ln Φ(−q): the shift (2r − q) appears in both the third and fourth orders.
    let dlam_down = lam_down * (lam_down - q);
    let shift_down = 2.0 * lam_down - q;
    let fail_d3 = dlam_down * shift_down - lam_down;
    let fail_d4 = fail_d3 * shift_down + 2.0 * dlam_down * dlam_down - 2.0 * dlam_down;

    weight * ((1.0 - y) * fail_d4 + y * succ_d4)
}
/// Logistic probability `p = 1/(1 + e^{−q})` and variance `s = p(1 − p)`.
///
/// Only `e^{−|q|}` is ever exponentiated, so the exponential never overflows. The
/// variance is formed directly as `t/(1 + t)²`, which keeps it accurate deep in the
/// tails instead of computing it from `p·(1 − p)`.
#[inline]
fn logit_probability_and_variance(q: f64) -> (f64, f64) {
    let t = (-q.abs()).exp();
    let denom = 1.0 + t;
    let variance = t / (denom * denom);
    let probability = if q >= 0.0 { 1.0 / denom } else { t / denom };
    (probability, variance)
}
/// Closed-form `(m1, m2, m3)` for the weighted negative binomial log-likelihood
/// under the logit link, differentiated with respect to `q`.
///
/// With `p = σ(q)` and `s = p(1 − p)`, the values are `m1 = w(p − y)`, `m2 = w·s`
/// and `m3 = w·s·(1 − 2p)`. Returns zeros when `weight == 0.0` or `q` is not finite.
#[inline]
pub(super) fn binomial_neglog_q_derivatives_logit_closed_form(
    y: f64,
    weight: f64,
    q: f64,
) -> (f64, f64, f64) {
    if weight != 0.0 && q.is_finite() {
        let (p, s) = logit_probability_and_variance(q);
        let weighted_variance = weight * s;
        (
            weight * (p - y),
            weighted_variance,
            weighted_variance * (1.0 - 2.0 * p),
        )
    } else {
        (0.0, 0.0, 0.0)
    }
}
/// Closed-form fourth `q`-derivative of the weighted negative binomial log-likelihood
/// under the logit link: `m4 = w·s·(1 − 6s)` with `s = p(1 − p)`.
///
/// For the logit (canonical) link, the response `y` only enters `m1`, so it does not
/// affect this value. The parameter is kept (and named `y`, like the sibling kernels)
/// so that all closed-form fourth-derivative kernels share one signature.
///
/// Returns `0.0` when `weight == 0.0` or `q` is not finite.
///
/// # Panics
/// In debug builds only, panics if `y` is NaN. Release builds never panic here,
/// matching the probit and cloglog kernels.
#[inline]
pub(super) fn binomial_neglog_q_fourth_derivative_logit_closed_form(
    y: f64,
    weight: f64,
    q: f64,
) -> f64 {
    // Sanity check on the response only in debug builds: a release-mode panic inside
    // a per-observation numeric kernel is disproportionate, since y is unused here.
    debug_assert!(!y.is_nan(), "binomial logit m4: response y is NaN");
    if weight == 0.0 || !q.is_finite() {
        return 0.0;
    }
    let (_, s) = logit_probability_and_variance(q);
    weight * s * (1.0 - 6.0 * s)
}
/// Numerically stable `h(z) = z / (e^z − 1)`, used by the cloglog kernels.
///
/// - For `|z| < 1e-12`, the Taylor series `1 − z/2 + z²/12` avoids `0/0`.
/// - When `expm1(z)` overflows, the expression is rewritten as `z·e^{−z}/(1 − e^{−z})`.
///   It returns exactly `0.0` once `e^{−z}` underflows to zero.
/// - Otherwise it evaluates `z / expm1(z)` directly.
#[inline]
pub(super) fn cloglog_stable_h(z: f64) -> f64 {
    if z.abs() < 1e-12 {
        1.0 - z * 0.5 + z * z / 12.0
    } else {
        let em1 = z.exp_m1();
        if !em1.is_infinite() {
            z / em1
        } else {
            let tail = (-z).exp();
            if tail == 0.0 {
                0.0
            } else {
                z * tail / (1.0 - tail)
            }
        }
    }
}
/// Closed-form `(m1, m2, m3)` for the weighted negative binomial log-likelihood under
/// the complementary log-log link `μ = 1 − exp(−e^q)`.
///
/// Write `z = e^q` and `h = z/(e^z − 1)` (from `cloglog_stable_h`). The failure term
/// `(1 − y)·z` equals each of its own `q`-derivatives. The success term
/// `−y·ln(1 − e^{−z})` contributes `−y·h`, then `y·h·(h + z − 1)`, then
/// `−y·h·(2h² + 3(z − 1)h + z² − 3z + 1)`.
///
/// When `y == 0` or `h == 0`, the success term vanishes and all three values equal
/// `weight·(1 − y)·z`. Returns zeros when `weight == 0.0` or `q` is not finite.
#[inline]
pub(super) fn binomial_neglog_q_derivatives_cloglog_closed_form(
    y: f64,
    weight: f64,
    q: f64,
) -> (f64, f64, f64) {
    if weight == 0.0 || !q.is_finite() {
        return (0.0, 0.0, 0.0);
    }
    let exp_q = q.exp();
    let h = cloglog_stable_h(exp_q);
    let failures = 1.0 - y;
    // Skip the product when there are no failures, so 0 · ∞ (overflowed e^q) cannot give NaN.
    let failure_part = if failures == 0.0 { 0.0 } else { failures * exp_q };
    let success_active = y != 0.0 && h != 0.0;
    if !success_active {
        let shared = weight * failure_part;
        return (shared, shared, shared);
    }
    let yh = y * h;
    let curvature_factor = h + exp_q - 1.0;
    let third_factor =
        2.0 * h * h + 3.0 * (exp_q - 1.0) * h + exp_q * exp_q - 3.0 * exp_q + 1.0;
    (
        weight * (failure_part - yh),
        weight * (failure_part + yh * curvature_factor),
        weight * (failure_part - yh * third_factor),
    )
}
/// Closed-form fourth `q`-derivative of the weighted negative binomial log-likelihood
/// under the complementary log-log link.
///
/// With `z = e^q` and `h = z/(e^z − 1)`, the value is `w·[(1 − y)·z + y·h·P(h, z)]`, where
/// `P = 6h³ + 12(z − 1)h² + (7z² − 18z + 7)h + z³ − 6z² + 7z − 1`.
/// When `y == 0` or `h == 0`, only the failure term `w·(1 − y)·z` remains.
/// Returns `0.0` when `weight == 0.0` or `q` is not finite.
#[inline]
pub(super) fn binomial_neglog_q_fourth_derivative_cloglog_closed_form(
    y: f64,
    weight: f64,
    q: f64,
) -> f64 {
    if weight == 0.0 || !q.is_finite() {
        return 0.0;
    }
    let exp_q = q.exp();
    let h = cloglog_stable_h(exp_q);
    let failures = 1.0 - y;
    // Avoid 0 · ∞ when e^q overflowed but there are no failures.
    let failure_part = if failures == 0.0 { 0.0 } else { failures * exp_q };
    if y == 0.0 || h == 0.0 {
        return weight * failure_part;
    }
    let (h_sq, z_sq) = (h * h, exp_q * exp_q);
    let (h_cu, z_cu) = (h_sq * h, z_sq * exp_q);
    let success_poly = 6.0 * h_cu
        + 12.0 * (exp_q - 1.0) * h_sq
        + (7.0 * z_sq - 18.0 * exp_q + 7.0) * h
        + z_cu
        - 6.0 * z_sq
        + 7.0 * exp_q
        - 1.0;
    weight * (failure_part + y * h * success_poly)
}
/// Fourth `q`-derivative of the weighted negative binomial log-likelihood, computed from
/// the inverse-link jet `d1..d4` (the first four derivatives of `μ(q)`).
///
/// Uses Faà di Bruno's formula at order 4:
/// `ℓ⁗·d1⁴ + 6ℓ‴·d1²·d2 + ℓ″·(3d2² + 4d1·d3) + ℓ′·d4`, weighted and negated.
/// Returns `0.0` if `weight == 0.0`, if `mu` is not strictly inside `(0, 1)`, or if
/// any jet component is non-finite. The raw `mu` is used; it is not clamped.
#[inline]
pub(super) fn binomial_neglog_q_fourth_derivative_from_jet(
    y: f64,
    weight: f64,
    mu: f64,
    d1: f64,
    d2: f64,
    d3: f64,
    d4: f64,
) -> f64 {
    let jet_is_finite = [d1, d2, d3, d4].iter().all(|d| d.is_finite());
    if weight == 0.0 || !binomial_mu_is_interior(mu) || !jet_is_finite {
        return 0.0;
    }
    let (l1, l2, l3, l4) = binomial_loglik_mu_derivatives(y, mu);
    let quartic = l4 * d1.powi(4);
    let cubic = 6.0 * l3 * d1 * d1 * d2;
    let quadratic = l2 * (3.0 * d2 * d2 + 4.0 * d1 * d3);
    let linear = l1 * d4;
    -(weight * (quartic + cubic + quadratic + linear))
}
/// `(m1, m2, m3)` for any binomial inverse link.
///
/// Links with a closed form (probit, logit, cloglog) ignore `mu` and `d1..d3`. They
/// evaluate directly from `q`. All other links fall back to the chain-rule path
/// driven by the caller-supplied μ-jet.
#[inline]
pub(super) fn binomial_neglog_q_derivatives_dispatch(
    y: f64,
    weight: f64,
    q: f64,
    mu: f64,
    d1: f64,
    d2: f64,
    d3: f64,
    link_kind: &InverseLink,
) -> (f64, f64, f64) {
    if !binomial_link_has_closed_form(link_kind) {
        return binomial_neglog_q_derivatives_from_jet(y, weight, mu, d1, d2, d3);
    }
    binomial_neglog_q_derivatives_closed_form_dispatch(y, weight, q, link_kind)
}
/// Fourth `q`-derivative of the weighted negative binomial log-likelihood for any link.
///
/// Closed-form links use their dedicated kernel and never fail. Other links obtain
/// `d4` from `inverse_link_pdfthird_derivative_for_inverse_link` and combine it with
/// the caller's `d1..d3` jet. NOTE(review): presumably that helper returns the third
/// derivative of the inverse-link density, i.e. `d⁴μ/dq⁴`; confirm.
///
/// # Errors
/// Returns `Err` with a descriptive message when evaluating the inverse-link
/// derivative fails.
#[inline]
pub(super) fn binomial_neglog_q_fourth_derivative_dispatch(
    y: f64,
    weight: f64,
    q: f64,
    mu: f64,
    d1: f64,
    d2: f64,
    d3: f64,
    link_kind: &InverseLink,
) -> Result<f64, String> {
    if binomial_link_has_closed_form(link_kind) {
        let closed = binomial_neglog_q_fourth_derivative_closed_form_dispatch(y, weight, q, link_kind);
        return Ok(closed);
    }
    inverse_link_pdfthird_derivative_for_inverse_link(link_kind, q)
        .map(|d4| binomial_neglog_q_fourth_derivative_from_jet(y, weight, mu, d1, d2, d3, d4))
        .map_err(|e| format!("binomial inverse-link third derivative evaluation failed: {e}"))
}
/// Routes the standard logit, probit and cloglog links to their closed-form
/// `(m1, m2, m3)` kernels.
///
/// Any other link yields `(0.0, 0.0, 0.0)`; this function has no jet fallback. Callers
/// should gate on `binomial_link_has_closed_form` first, as
/// `binomial_neglog_q_derivatives_dispatch` does.
#[inline]
pub(super) fn binomial_neglog_q_derivatives_closed_form_dispatch(
    y: f64,
    weight: f64,
    q: f64,
    link_kind: &InverseLink,
) -> (f64, f64, f64) {
    match link_kind {
        InverseLink::Standard(StandardLink::Logit) => {
            binomial_neglog_q_derivatives_logit_closed_form(y, weight, q)
        }
        InverseLink::Standard(StandardLink::Probit) => {
            binomial_neglog_q_derivatives_probit_closed_form(y, weight, q)
        }
        InverseLink::Standard(StandardLink::CLogLog) => {
            binomial_neglog_q_derivatives_cloglog_closed_form(y, weight, q)
        }
        // No closed form for this link: nothing to compute here.
        _ => (0.0, 0.0, 0.0),
    }
}
/// Routes the standard logit, probit and cloglog links to their closed-form fourth
/// `q`-derivative kernels.
///
/// Any other link yields `0.0`; this function has no jet fallback. Callers should gate
/// on `binomial_link_has_closed_form` first, as
/// `binomial_neglog_q_fourth_derivative_dispatch` does.
#[inline]
pub(super) fn binomial_neglog_q_fourth_derivative_closed_form_dispatch(
    y: f64,
    weight: f64,
    q: f64,
    link_kind: &InverseLink,
) -> f64 {
    match link_kind {
        InverseLink::Standard(StandardLink::Logit) => {
            binomial_neglog_q_fourth_derivative_logit_closed_form(y, weight, q)
        }
        InverseLink::Standard(StandardLink::Probit) => {
            binomial_neglog_q_fourth_derivative_probit_closed_form(y, weight, q)
        }
        InverseLink::Standard(StandardLink::CLogLog) => {
            binomial_neglog_q_fourth_derivative_cloglog_closed_form(y, weight, q)
        }
        // No closed form for this link: nothing to compute here.
        _ => 0.0,
    }
}
/// True exactly for the standard probit, logit and cloglog links. These are the links
/// handled by the closed-form dispatchers in this module.
#[inline]
pub(super) fn binomial_link_has_closed_form(link_kind: &InverseLink) -> bool {
    matches!(
        link_kind,
        InverseLink::Standard(StandardLink::Probit | StandardLink::Logit | StandardLink::CLogLog)
    )
}
// Unit tests for the binomial q-derivative kernels. They cover:
// - the generic μ-jet path against finite differences and a pinned Cauchit value;
// - the logit closed forms against the jet path;
// - behavior at saturated / near-boundary μ.
#[cfg(test)]
mod tests {
use super::*;
// Cauchit inverse link μ(q) = 1/2 + atan(q)/π together with its first four
// q-derivatives, returned as (μ, μ′, μ″, μ‴, μ⁗).
fn cauchit_jet(q: f64) -> (f64, f64, f64, f64, f64) {
let u = 1.0 + q * q;
let mu = 0.5 + q.atan() / std::f64::consts::PI;
let d1 = 1.0 / (std::f64::consts::PI * u);
let d2 = -2.0 * q / (std::f64::consts::PI * u.powi(2));
let d3 = 2.0 * (3.0 * q * q - 1.0) / (std::f64::consts::PI * u.powi(3));
let d4 = 24.0 * q * (1.0 - q * q) / (std::f64::consts::PI * u.powi(4));
(mu, d1, d2, d3, d4)
}
// m4 through the generic jet path for the Cauchit link.
fn cauchit_m4(y: f64, weight: f64, q: f64) -> f64 {
let (mu, d1, d2, d3, d4) = cauchit_jet(q);
binomial_neglog_q_fourth_derivative_from_jet(y, weight, mu, d1, d2, d3, d4)
}
// m3 (third component of the jet-path triple) for the Cauchit link.
fn cauchit_m3(y: f64, weight: f64, q: f64) -> f64 {
let (mu, d1, d2, d3, _d4) = cauchit_jet(q);
binomial_neglog_q_derivatives_from_jet(y, weight, mu, d1, d2, d3).2
}
// The analytic m4 must match a five-point central difference of m3 in q.
#[test]
fn generic_binomial_m4_matches_finite_difference_of_m3_cauchit() {
let h = 1e-4;
for &(y, weight, q) in &[
(0.3_f64, 2.0_f64, 0.7_f64),
(0.8_f64, 1.0_f64, -0.4_f64),
(0.1_f64, 3.0_f64, 1.3_f64),
(0.6_f64, 0.5_f64, -1.1_f64),
] {
// Stencil: (−f(q+2h) + 8f(q+h) − 8f(q−h) + f(q−2h)) / (12h).
let fd = (-cauchit_m3(y, weight, q + 2.0 * h) + 8.0 * cauchit_m3(y, weight, q + h)
- 8.0 * cauchit_m3(y, weight, q - h)
+ cauchit_m3(y, weight, q - 2.0 * h))
/ (12.0 * h);
let analytic = cauchit_m4(y, weight, q);
let tol = 1e-5 * (1.0 + analytic.abs());
assert!(
(analytic - fd).abs() < tol,
"cauchit m4 (y={y}, w={weight}, q={q}): analytic={analytic}, fd={fd}, diff={}",
(analytic - fd).abs()
);
}
}
// Regression pin: m4 at (y=0.3, w=2, q=0.7) equals a fixed reference value and stays
// well away from the previously produced wrong value −10.3779706944.
#[test]
fn generic_binomial_m4_matches_analytic_cauchit_ground_truth() {
let analytic = cauchit_m4(0.3, 2.0, 0.7);
assert!(
(analytic - 2.1168155916).abs() < 1e-7,
"expected +2.1168155916, got {analytic}"
);
assert!(
(analytic - (-10.3779706944)).abs() > 1.0,
"matched the buggy m4 value {analytic}"
);
}
// Logistic μ(q) and its first four q-derivatives, expressed through p and s = p(1−p).
fn logit_jet(q: f64) -> (f64, f64, f64, f64, f64) {
let p = 1.0 / (1.0 + (-q).exp());
let s = p * (1.0 - p);
(
p,
s,
s * (1.0 - 2.0 * p),
s * (1.0 - 6.0 * s),
s * (1.0 - 2.0 * p) * (1.0 - 12.0 * s),
)
}
// The logit closed forms (m1..m4) must agree with the generic jet path to ~1e-9.
#[test]
fn logit_closed_form_agrees_with_generic_jet_path() {
for &(y, w, q) in &[
(0.3_f64, 2.0_f64, 0.5_f64),
(0.7, 1.0, -1.3),
(0.0, 1.5, 2.0),
(1.0, 0.5, -0.8),
(0.42, 3.0, 0.0),
] {
let (m1, m2, m3) = binomial_neglog_q_derivatives_logit_closed_form(y, w, q);
let m4 = binomial_neglog_q_fourth_derivative_logit_closed_form(y, w, q);
let (mu, d1, d2, d3, d4) = logit_jet(q);
let (s1, c2, t3) = binomial_neglog_q_derivatives_from_jet(y, w, mu, d1, d2, d3);
let j4 = binomial_neglog_q_fourth_derivative_from_jet(y, w, mu, d1, d2, d3, d4);
let tol = 1e-9 * (1.0 + m1.abs() + m2.abs() + m3.abs() + m4.abs());
assert!(
(m1 - s1).abs() < tol,
"m1 mismatch q={q}: closed={m1} jet={s1}"
);
assert!(
(m2 - c2).abs() < tol,
"m2 mismatch q={q}: closed={m2} jet={c2}"
);
assert!(
(m3 - t3).abs() < tol,
"m3 mismatch q={q}: closed={m3} jet={t3}"
);
assert!(
(m4 - j4).abs() < tol,
"m4 mismatch q={q}: closed={m4} jet={j4}"
);
}
}
// For large q, m2 and m4 must track the exact logistic variance (relative 1e-12).
// They must not be floored; the assertion message refers to a MIN_PROB clamp that is
// not defined in this view.
#[test]
fn logit_curvature_exact_through_the_old_clamp_boundary() {
for &q in &[24.0_f64, 30.0, 40.0, 50.0] {
let t = (-q).exp();
let denom = 1.0 + t;
let s_exact = t / (denom * denom);
let (_, m2, _) = binomial_neglog_q_derivatives_logit_closed_form(1.0, 1.0, q);
let m4 = binomial_neglog_q_fourth_derivative_logit_closed_form(1.0, 1.0, q);
assert!(
(m2 - s_exact).abs() <= 1e-12 * s_exact,
"q={q}: m2={m2} != exact s={s_exact}"
);
assert!(
m2 < 1e-10,
"q={q}: m2={m2} looks floored at MIN_PROB·(1-MIN_PROB)"
);
assert!(
(m4 - s_exact * (1.0 - 6.0 * s_exact)).abs() <= 1e-12 * s_exact,
"q={q}: m4={m4} not exact ws(1-6s)"
);
}
}
// μ = 1e-12 must be used as-is by the jet path. With d1 = 1 and d2 = d3 = d4 = 0,
// m4 reduces to 6yw/μ⁴ and the score to wy/μ. The comparison against 6yw/(1e-10)⁴
// detects a clamp of μ to 1e-10.
#[test]
fn generic_jet_uses_raw_sub_min_prob_mu_not_floored() {
let (y, w) = (1.0_f64, 1.0_f64);
let mu = 1e-12_f64;
let m4 = binomial_neglog_q_fourth_derivative_from_jet(y, w, mu, 1.0, 0.0, 0.0, 0.0);
let exact = 6.0 * w * y / mu.powi(4);
assert!(
(m4 - exact).abs() <= 1e-6 * exact,
"raw-μ m4 should be 6yw/μ⁴={exact}, got {m4}"
);
let floored = 6.0 * w * y / 1e-10_f64.powi(4);
assert!(
m4 > 100.0 * floored,
"m4={m4} is near the floored value {floored}; μ was clamped"
);
let (score, _curv, _third) =
binomial_score_curvaturethird_from_jet(y, w, mu, 1.0, 0.0, 0.0);
let exact_score = w * y / mu;
assert!(
(score - exact_score).abs() <= 1e-6 * exact_score,
"raw-μ score should be wy/μ={exact_score}, got {score}"
);
}
// μ on or outside the boundary (0, 1, −0, NaN, ∞) must give exact zeros for all orders.
#[test]
fn generic_jet_saturated_boundary_collapses_to_zero() {
for &mu in &[0.0_f64, 1.0, -0.0, f64::NAN, f64::INFINITY] {
let (s, c, t) = binomial_score_curvaturethird_from_jet(0.7, 2.0, mu, 1.0, 0.5, 0.1);
let m4 = binomial_neglog_q_fourth_derivative_from_jet(0.7, 2.0, mu, 1.0, 0.5, 0.1, 0.2);
assert_eq!(
(s, c, t),
(0.0, 0.0, 0.0),
"boundary μ={mu} must give zero score/curv/third"
);
assert_eq!(m4, 0.0, "boundary μ={mu} must give zero m4");
}
}
// At compatible boundaries (y=0,μ=0 and y=1,μ=1), the zero-count term is skipped,
// so all μ-derivatives stay finite.
#[test]
fn loglik_mu_derivatives_no_nan_at_compatible_boundary() {
let (e1, e2, e3, e4) = binomial_loglik_mu_derivatives(0.0, 0.0);
for v in [e1, e2, e3, e4] {
assert!(v.is_finite(), "y=0,μ=0 produced non-finite {v}");
}
assert_eq!(e1, -1.0, "ℓ'(0)=-(1-y)/(1-μ)=-1 at y=0,μ=0");
let (f1, f2, f3, f4) = binomial_loglik_mu_derivatives(1.0, 1.0);
for v in [f1, f2, f3, f4] {
assert!(v.is_finite(), "y=1,μ=1 produced non-finite {v}");
}
assert_eq!(f1, 1.0, "ℓ'(1)=y/μ=1 at y=1,μ=1");
}
}