use super::*;
/// Value and first four derivatives, with respect to the mean `mu`, of a
/// variance function `V(mu)`, as filled in by the `VarianceJet` constructors.
///
/// The type is a plain bundle of five `f64`s, so it derives `Clone`/`Copy`
/// (one jet can be handed by value to several consumers) and
/// `Debug`/`PartialEq` (printable and comparable in diagnostics and tests).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VarianceJet {
    /// `V(mu)`.
    pub v: f64,
    /// First derivative `V'(mu)`.
    pub v1: f64,
    /// Second derivative `V''(mu)`.
    pub v2: f64,
    /// Third derivative `V'''(mu)`.
    pub v3: f64,
    /// Fourth derivative `V''''(mu)`.
    pub v4: f64,
}
/// Constructors producing the variance jet of each supported variance function.
impl VarianceJet {
    /// Lower bound substituted for `mu` by `tweedie` and `negative_binomial`
    /// before the variance function is evaluated.
    const VARIANCE_MU_FLOOR: f64 = 1e-10;

    /// `V(mu) = mu * (1 - mu)`.
    #[inline]
    pub fn bernoulli(mu: f64) -> Self {
        let v = mu * (1.0 - mu);
        let v1 = 1.0 - 2.0 * mu;
        Self {
            v,
            v1,
            v2: -2.0,
            v3: 0.0,
            v4: 0.0,
        }
    }

    /// `V(mu) = mu`.
    #[inline]
    pub fn poisson(mu: f64) -> Self {
        let v = mu;
        Self {
            v,
            v1: 1.0,
            v2: 0.0,
            v3: 0.0,
            v4: 0.0,
        }
    }

    /// `V(mu) = mu^2`.
    #[inline]
    pub fn gamma(mu: f64) -> Self {
        let v = mu * mu;
        let v1 = 2.0 * mu;
        Self {
            v,
            v1,
            v2: 2.0,
            v3: 0.0,
            v4: 0.0,
        }
    }

    /// `V(mu) = mu^p`, with `mu` raised to at least `VARIANCE_MU_FLOOR` first.
    ///
    /// The k-th derivative is the falling-factorial coefficient
    /// `p (p-1) ... (p-k+1)` times `mu^(p-k)`.
    #[inline]
    pub fn tweedie(mu: f64, p: f64) -> Self {
        let m = mu.max(Self::VARIANCE_MU_FLOOR);
        // Falling-factorial coefficients, accumulated left to right.
        let k1 = p;
        let k2 = k1 * (p - 1.0);
        let k3 = k2 * (p - 2.0);
        let k4 = k3 * (p - 3.0);
        Self {
            v: m.powf(p),
            v1: k1 * m.powf(p - 1.0),
            v2: k2 * m.powf(p - 2.0),
            v3: k3 * m.powf(p - 3.0),
            v4: k4 * m.powf(p - 4.0),
        }
    }

    /// `V(mu) = mu + mu^2 / theta`, with `mu` raised to at least
    /// `VARIANCE_MU_FLOOR` first.
    ///
    /// When `valid_negbin_theta(theta)` is false the reciprocal is replaced by
    /// NaN, so `v`, `v1` and `v2` come out as NaN.
    #[inline]
    pub fn negative_binomial(mu: f64, theta: f64) -> Self {
        let m = mu.max(Self::VARIANCE_MU_FLOOR);
        let inv_theta = valid_negbin_theta(theta)
            .then(|| 1.0 / theta)
            .unwrap_or(f64::NAN);
        let v = m + m * m * inv_theta;
        let v1 = 1.0 + 2.0 * m * inv_theta;
        let v2 = 2.0 * inv_theta;
        Self {
            v,
            v1,
            v2,
            v3: 0.0,
            v4: 0.0,
        }
    }

    /// Constant variance `V(mu) = 1`.
    #[inline]
    pub fn gaussian() -> Self {
        Self {
            v: 1.0,
            v1: 0.0,
            v2: 0.0,
            v3: 0.0,
            v4: 0.0,
        }
    }

    /// Same jet as [`VarianceJet::bernoulli`].
    #[inline]
    pub fn binomial_n(mu: f64) -> Self {
        Self::bernoulli(mu)
    }

    /// Bernoulli jet multiplied by `1 / (1 + max(phi, 1e-12))`; `v3` and `v4`
    /// are set to zero.
    #[inline]
    pub fn beta(mu: f64, phi: f64) -> Self {
        let Self { v, v1, v2, .. } = Self::bernoulli(mu);
        let scale = 1.0 / (1.0 + phi.max(1e-12));
        Self {
            v: v * scale,
            v1: v1 * scale,
            v2: v2 * scale,
            v3: 0.0,
            v4: 0.0,
        }
    }
}
/// Fraction of the non-negative Fisher weight that `solver_hessian_weight_floor`
/// uses as its relative lower bound.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC: f64 = 1e-6;
/// Absolute lower bound returned by `solver_hessian_weight_floor` whenever the
/// relative bound is smaller than this value.
pub(crate) const OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR: f64 = 1e-12;
/// Smallest weight accepted for a row whose Fisher weight is `fisher_weight`.
///
/// The result is `max(max(fisher_weight, 0) * OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC,
/// OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)`. Because `f64::max` ignores a NaN
/// operand, a NaN Fisher weight yields the absolute floor.
#[inline]
pub fn solver_hessian_weight_floor(fisher_weight: f64) -> f64 {
    let nonnegative_fisher = fisher_weight.max(0.0);
    let relative_floor = nonnegative_fisher * OBSERVED_HESSIAN_WEIGHT_FLOOR_FRAC;
    relative_floor.max(OBSERVED_HESSIAN_WEIGHT_ABS_FLOOR)
}
/// Builds the `(w, c, d)` arrays from per-row Hessian weights and their
/// derivative arrays, applying the solver weight floor and the eta clamp.
///
/// For each row `i` (the row count is the length of `hessian_weights`):
/// * if the Hessian weight is non-finite or not above
///   `solver_hessian_weight_floor(fisher_weights[i])`, `w` is set to that floor
///   and `c`, `d` are zero;
/// * otherwise `w` is the Hessian weight, and `c`, `d` are copied from
///   `c_array`/`d_array` unless `eta_clamp_active` reports that `eta[i]` is
///   being clamped, in which case they are zero.
///
/// Indexing panics if `fisher_weights` or `eta` is shorter than
/// `hessian_weights`, or if `c_array`/`d_array` is too short for a row that is
/// copied through.
pub fn outer_hessian_curvature_arrays(
    hessian_weights: crate::matrix::SignedWeightsView<'_>,
    fisher_weights: crate::matrix::PsdWeightsView<'_>,
    c_array: &Array1<f64>,
    d_array: &Array1<f64>,
    eta: &Array1<f64>,
    inverse_link: &InverseLink,
) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
    let observed = hessian_weights.view();
    let fisher = fisher_weights.view();
    let rows = observed.len();

    let mut w_out = Array1::<f64>::zeros(rows);
    // `c_out`/`d_out` start at zero, so only pass-through rows are written.
    let mut c_out = Array1::<f64>::zeros(rows);
    let mut d_out = Array1::<f64>::zeros(rows);

    for i in 0..rows {
        let floor = solver_hessian_weight_floor(fisher[i]);
        let w = observed[i];
        let clamped = eta_clamp_active(inverse_link, eta[i]);
        let usable = w.is_finite() && w > floor;

        w_out[i] = if usable { w } else { floor };
        if usable && !clamped {
            c_out[i] = c_array[i];
            d_out[i] = d_array[i];
        }
    }

    (w_out, c_out, d_out)
}
/// Dispersion used when forming GLM weights: the value reported by the
/// likelihood's `fixed_phi()`, falling back to `1.0` when it yields none.
#[inline]
pub(crate) fn fixed_glm_dispersion(likelihood: &GlmLikelihoodSpec) -> f64 {
    const DEFAULT_DISPERSION: f64 = 1.0;
    let fixed = likelihood.fixed_phi();
    fixed.unwrap_or(DEFAULT_DISPERSION)
}
/// Maps the likelihood's response family onto the [`WeightFamily`] used to
/// pick a variance function.
///
/// Parameterised families carry their parameter across unchanged.
/// `RoystonParmar` is mapped to `WeightFamily::Gaussian`.
#[inline]
pub fn weight_family_for_glm_likelihood(likelihood: &GlmLikelihoodSpec) -> WeightFamily {
    match likelihood.spec.response {
        ResponseFamily::Gaussian => WeightFamily::Gaussian,
        ResponseFamily::Poisson => WeightFamily::Poisson,
        ResponseFamily::Tweedie { p } => WeightFamily::Tweedie { p },
        ResponseFamily::NegativeBinomial { theta, .. } => WeightFamily::NegativeBinomial { theta },
        ResponseFamily::Beta { phi } => WeightFamily::Beta { phi },
        ResponseFamily::Gamma => WeightFamily::Gamma,
        ResponseFamily::Binomial => WeightFamily::Binomial,
        ResponseFamily::RoystonParmar => WeightFamily::Gaussian,
    }
}
/// Classifies an inverse link into the coarse [`WeightLink`] categories.
///
/// Only the standard identity, log and logit links get their own category;
/// every other link is reported as `WeightLink::Other`.
#[inline]
pub(crate) fn weight_link_for_inverse_link(inverse_link: &InverseLink) -> WeightLink {
    match inverse_link {
        InverseLink::Standard(standard) => match standard {
            StandardLink::Identity => WeightLink::Identity,
            StandardLink::Log => WeightLink::Log,
            StandardLink::Logit => WeightLink::Logit,
            StandardLink::Probit | StandardLink::CLogLog => WeightLink::Other,
        },
        InverseLink::LatentCLogLog(_)
        | InverseLink::Sas(_)
        | InverseLink::BetaLogistic(_)
        | InverseLink::Mixture(_) => WeightLink::Other,
    }
}
/// Reports whether observed-Hessian curvature is available for this
/// likelihood / link combination.
///
/// * `NegativeBinomial`: only when `inverse_link` is the standard log link.
/// * `Gamma`: always.
/// * `Binomial`: when `spec.link` is probit, cloglog, SAS, beta-logistic or
///   mixture (`LatentCLogLog` is not in the list, so it yields `false`).
/// * Every other response family: never.
#[inline]
pub(crate) fn supports_observed_hessian_curvature_for_likelihood(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
) -> bool {
    let spec = &likelihood.spec;
    match spec.response {
        ResponseFamily::NegativeBinomial { .. } => {
            matches!(inverse_link, InverseLink::Standard(StandardLink::Log))
        }
        ResponseFamily::Gamma => true,
        // NOTE(review): this branch inspects `spec.link`, whereas the
        // NegativeBinomial branch inspects the `inverse_link` argument.
        // Confirm the two are always the same link, or that this is intended.
        ResponseFamily::Binomial => matches!(
            spec.link,
            InverseLink::Standard(StandardLink::Probit)
                | InverseLink::Standard(StandardLink::CLogLog)
                | InverseLink::Sas(_)
                | InverseLink::BetaLogistic(_)
                | InverseLink::Mixture(_)
        ),
        _ => false,
    }
}
/// Returns the linear predictor at which the observed-Hessian jet is
/// evaluated: `eta` clamped to a link-specific interval.
///
/// The identity link is returned unclamped. Logit and log use
/// `[-ETA_CLAMP, ETA_CLAMP]`, probit `[-6, 6]`, cloglog and latent cloglog
/// `[-23, 3]`, and SAS / beta-logistic / mixture `[-20, 20]`.
#[inline]
pub(crate) fn eta_for_observed_hessian_jet(inverse_link: &InverseLink, eta: f64) -> f64 {
    let (lower, upper) = match inverse_link {
        InverseLink::Standard(StandardLink::Identity) => return eta,
        InverseLink::Standard(StandardLink::Logit | StandardLink::Log) => (-ETA_CLAMP, ETA_CLAMP),
        InverseLink::Standard(StandardLink::Probit) => (-6.0, 6.0),
        InverseLink::Standard(StandardLink::CLogLog) | InverseLink::LatentCLogLog(_) => {
            (-23.0, 3.0)
        }
        InverseLink::Sas(_) | InverseLink::BetaLogistic(_) | InverseLink::Mixture(_) => {
            (-20.0, 20.0)
        }
    };
    eta.clamp(lower, upper)
}
/// True when `eta_for_observed_hessian_jet` returns a value different from
/// `eta`, i.e. the clamp changed it.
///
/// The comparison is `!=`, so a NaN `eta` is reported as clamped.
#[inline]
pub fn eta_clamp_active(inverse_link: &InverseLink, eta: f64) -> bool {
    eta_for_observed_hessian_jet(inverse_link, eta) != eta
}
/// Writes floored Hessian weights into `out`.
///
/// `out` is replaced by a fresh zero array when its length differs from
/// `hessian_weights`. Each element becomes the Hessian weight when it is
/// finite and above `solver_hessian_weight_floor` of the matching Fisher
/// weight, and that floor otherwise.
pub(crate) fn solver_hessian_weights_into(
    hessian_weights: &Array1<f64>,
    fisher_weights: &Array1<f64>,
    out: &mut Array1<f64>,
) {
    let rows = hessian_weights.len();
    if out.len() != rows {
        *out = Array1::<f64>::zeros(rows);
    }

    ndarray::Zip::from(out)
        .and(hessian_weights)
        .and(fisher_weights)
        .par_for_each(|slot, &observed, &fisher| {
            let floor = solver_hessian_weight_floor(fisher);
            let usable = observed.is_finite() && observed > floor;
            *slot = if usable { observed } else { floor };
        });
}
/// Fills `hessian_weights`, `hessian_c` and `hessian_d` with the observed
/// Hessian weight and its two derivative arrays, one entry per element of
/// `eta`.
///
/// Each output buffer is replaced by a fresh zero array when its length
/// differs from `eta.len()`. Rows are processed in parallel; each row
/// evaluates the inverse-link jet at the clamped eta and passes it, together
/// with a non-negative prior weight, to `observed_weight_dispatch`.
///
/// # Errors
/// Propagates errors from the inverse-link jet helpers, and bails through
/// `crate::bail_invalid_estim!` when a row's observed weight is not positive
/// and finite or its derivatives are non-finite.
///
/// # Panics
/// Panics when `supports_observed_hessian_curvature_for_likelihood` is false
/// for the given likelihood and link, when an output buffer is not
/// contiguous, or when an input array is indexed past its end.
pub(crate) fn compute_observed_hessian_curvature_arrays_into(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    y: ArrayView1<'_, f64>,
    fisher_weights: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    hessian_weights: &mut Array1<f64>,
    hessian_c: &mut Array1<f64>,
    hessian_d: &mut Array1<f64>,
) -> Result<(), EstimationError> {
    assert!(supports_observed_hessian_curvature_for_likelihood(
        likelihood,
        inverse_link
    ));

    let n = eta.len();
    for buffer in [&mut *hessian_weights, &mut *hessian_c, &mut *hessian_d] {
        if buffer.len() != n {
            *buffer = Array1::<f64>::zeros(n);
        }
    }

    let weight_family = weight_family_for_glm_likelihood(likelihood);
    let weight_link = weight_link_for_inverse_link(inverse_link);
    let phi = fixed_glm_dispersion(likelihood);

    let w_rows = hessian_weights
        .as_slice_mut()
        .expect("hessian weights must be contiguous");
    let c_rows = hessian_c
        .as_slice_mut()
        .expect("hessian c must be contiguous");
    let d_rows = hessian_d
        .as_slice_mut()
        .expect("hessian d must be contiguous");

    w_rows
        .par_iter_mut()
        .zip(c_rows.par_iter_mut())
        .zip(d_rows.par_iter_mut())
        .enumerate()
        .try_for_each(|(i, ((w_out, c_out), d_out))| -> Result<(), EstimationError> {
            let eta_i = eta_for_observed_hessian_jet(inverse_link, eta[i]);
            let link_jet =
                crate::mixture_link::inverse_link_jet_for_inverse_link(inverse_link, eta_i)?;
            let fourth = crate::mixture_link::inverse_link_pdfthird_derivative_for_inverse_link(
                inverse_link,
                eta_i,
            )?;
            let (w_obs, c_obs, d_obs) = observed_weight_dispatch(
                weight_family,
                weight_link,
                eta_i,
                y[i],
                link_jet.mu,
                phi,
                priorweights[i].max(0.0),
                link_jet,
                fourth,
            );
            // Only used in the diagnostic message below.
            let fisher_weight = fisher_weights[i].max(0.0);
            if !w_obs.is_finite() || w_obs <= 0.0 {
                crate::bail_invalid_estim!(
                    "observed Hessian curvature is not positive finite at row {i}: observed={w_obs}, fisher={fisher_weight}"
                );
            }
            if !(c_obs.is_finite() && d_obs.is_finite()) {
                crate::bail_invalid_estim!(
                    "observed Hessian curvature derivatives are non-finite at row {i}: c={c_obs}, d={d_obs}"
                );
            }
            *w_out = w_obs;
            *c_out = c_obs;
            *d_out = d_obs;
            Ok(())
        })
}
/// Allocating wrapper around `compute_observed_hessian_curvature_arrays_into`.
///
/// Creates three zero arrays of length `eta.len()`, fills them, and returns
/// them as `(hessian_weights, hessian_c, hessian_d)`. Errors from the filling
/// routine are returned unchanged.
pub(crate) fn compute_observed_hessian_curvature_arrays(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    y: ArrayView1<'_, f64>,
    fisher_weights: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>), EstimationError> {
    let rows = eta.len();
    let fresh = || Array1::<f64>::zeros(rows);
    let (mut weights, mut c, mut d) = (fresh(), fresh(), fresh());

    compute_observed_hessian_curvature_arrays_into(
        likelihood,
        inverse_link,
        eta,
        y,
        fisher_weights,
        priorweights,
        &mut weights,
        &mut c,
        &mut d,
    )?;

    Ok((weights, c, d))
}
/// Observed Hessian weight and its first two eta-derivatives for a general
/// (possibly non-canonical) link, each scaled by the prior weight `pw`.
///
/// `h1..h4` are taken to be successive eta-derivatives of the mean and `vj`
/// the variance jet at `mu`. With `b = d/d_eta [h1 / (phi V)]` the returned
/// triple is
///
/// * `w = h1^2 / (phi V) - (y - mu) b`
/// * `c = dw/d_eta`
/// * `d = dc/d_eta`
///
/// all expanded by the quotient and chain rules. `vj.v4` is not used.
#[inline]
pub fn observed_weight_noncanonical(
    y: f64,
    mu: f64,
    h1: f64,
    h2: f64,
    h3: f64,
    h4: f64,
    vj: VarianceJet,
    phi: f64,
    pw: f64,
) -> (f64, f64, f64) {
    let VarianceJet { v, v1, v2, v3, .. } = vj;

    // Powers of h1 and of the scaled variance phi * V.
    let h1_sq = h1 * h1;
    let h1_cu = h1_sq * h1;
    let h1_qu = h1_sq * h1_sq;
    let phi_v = phi * v;
    let phi_v2 = phi_v * v;
    let phi_v3 = phi_v2 * v;
    let phi_v4 = phi_v3 * v;

    // Fisher part w_f = h1^2 / (phi V) and its two eta-derivatives.
    let curv = h2 * h2 + h1 * h3;
    let n0 = h1_sq;
    let n1 = 2.0 * h1 * h2;
    let n2 = 2.0 * curv;
    let vd1 = h1 * v1; // dV/d_eta
    let vd2 = h2 * v1 + h1_sq * v2; // d^2 V / d_eta^2
    let numer_cf = n1 * v - n0 * vd1;
    let dnumer_cf = n2 * v - n0 * vd2;
    let w_f = h1_sq / phi_v;
    let c_f = numer_cf / phi_v2;
    let d_f = (dnumer_cf * v - 2.0 * numer_cf * vd1) / phi_v3;

    // Residual coefficient b and its two eta-derivatives.
    let b_num = h2 * v - h1_sq * v1;
    let b_eta_num =
        h3 * v * v - 3.0 * h1 * h2 * v * v1 - h1_cu * v * v2 + 2.0 * h1_sq * h1 * v1 * v1;
    let db_eta_num = h4 * v * v + 2.0 * h3 * v * h1 * v1
        - 3.0 * curv * v * v1
        - 3.0 * h1 * h2 * (h1 * v1 * v1 + v * h1 * v2)
        - 3.0 * h1_sq * h2 * v * v2
        - h1_cu * (h1 * v1 * v2 + v * h1 * v3)
        + 6.0 * h1_sq * h2 * v1 * v1
        + 4.0 * h1_qu * v1 * v2;
    let b = b_num / phi_v2;
    let b_eta = b_eta_num / phi_v3;
    let b_etaeta = (db_eta_num * v - 3.0 * b_eta_num * h1 * v1) / phi_v4;

    // Combine; d(resid)/d_eta = -h1 produces the `+ h1 * b` style terms.
    let resid = y - mu;
    let w_obs = w_f - resid * b;
    let c_obs = c_f + h1 * b - resid * b_eta;
    let d_obs = d_f + h2 * b + 2.0 * h1 * b_eta - resid * b_etaeta;

    (pw * w_obs, pw * c_obs, pw * d_obs)
}
/// Prior-weighted combination
/// `h4 t0 + 4 h3 t1 + 6 h2 t2 + 4 h1 t3 - (y - mu) t4`, where `t0 = h1 / q`,
/// `q = phi * V(mu)`, and `t_k` is the k-th eta-derivative of `t0`.
///
/// Assuming `h1..h5` are the successive eta-derivatives of the mean, this is
/// algebraically the third eta-derivative of `h1 * t0 - (y - mu) * t1`.
///
/// The `q_k` are the eta-derivatives of `q` via the chain rule on the
/// variance jet; the `t_k` follow from differentiating `t0 * q = h1`
/// repeatedly (Leibniz rule).
#[inline]
pub fn e_obs_from_jets(
    y: f64,
    mu: f64,
    h1: f64,
    h2: f64,
    h3: f64,
    h4: f64,
    h5: f64,
    vj: VarianceJet,
    phi: f64,
    pw: f64,
) -> f64 {
    let VarianceJet { v, v1, v2, v3, v4 } = vj;

    let h1_sq = h1 * h1;
    let h1_cu = h1_sq * h1;
    let h1_qu = h1_sq * h1_sq;

    // q = phi * V(mu(eta)) and its first four eta-derivatives.
    let q0 = phi * v;
    let dq1 = phi * v1 * h1;
    let dq2 = phi * (v1 * h2 + v2 * h1_sq);
    let dq3 = phi * (v1 * h3 + 3.0 * v2 * h1 * h2 + v3 * h1_cu);
    let dq4 = phi
        * (v1 * h4 + 4.0 * v2 * h1 * h3 + 3.0 * v2 * h2 * h2 + 6.0 * v3 * h1_sq * h2 + v4 * h1_qu);

    // t0 = h1 / q and its derivatives, solved from (t0 * q)^(k) = h_{k+1}.
    let t0 = h1 / q0;
    let t1 = (h2 - t0 * dq1) / q0;
    let t2 = (h3 - 2.0 * t1 * dq1 - t0 * dq2) / q0;
    let t3 = (h4 - 3.0 * t2 * dq1 - 3.0 * t1 * dq2 - t0 * dq3) / q0;
    let t4 = (h5 - 4.0 * t3 * dq1 - 6.0 * t2 * dq2 - 4.0 * t1 * dq3 - t0 * dq4) / q0;

    // Third derivative of the Fisher part h1 * t0.
    let fisher_third = h1 * t3 + 3.0 * h2 * t2 + 3.0 * h3 * t1 + h4 * t0;

    let resid = y - mu;
    let e_obs = fisher_third + h3 * t1 + 3.0 * h2 * t2 + 3.0 * h1 * t3 - resid * t4;
    pw * e_obs
}
/// Closed-form observed weight and its two eta-derivatives for the Gaussian
/// family with a log link.
///
/// Returns `(pw/phi) * mu * (k*mu - y)` for `k = 2, 4, 8` respectively.
#[inline]
pub fn observed_weight_gaussian_log(y: f64, mu: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let scaled_mu = pw / phi * mu;
    let term = |k: f64| scaled_mu * (k * mu - y);
    (term(2.0), term(4.0), term(8.0))
}
/// Closed-form observed weight and its two eta-derivatives for the Gaussian
/// family with an inverse link (`mu = 1 / eta`).
///
/// With `s = pw / phi`:
/// * `w = s (3 - 2 eta y) / eta^4`
/// * `c = 6 s (eta y - 2) / eta^5`
/// * `d = 12 s (5 - 2 eta y) / eta^6`
#[inline]
pub fn observed_weight_gaussian_inverse(y: f64, eta: f64, phi: f64, pw: f64) -> (f64, f64, f64) {
    let scale = pw / phi;
    let eta_y = eta * y;

    // Powers of eta, built up by repeated multiplication.
    let eta_sq = eta * eta;
    let eta_p4 = eta_sq * eta_sq;
    let eta_p5 = eta_p4 * eta;
    let eta_p6 = eta_p4 * eta_sq;

    (
        scale * (3.0 - 2.0 * eta_y) / eta_p4,
        scale * 6.0 * (eta_y - 2.0) / eta_p5,
        scale * 12.0 * (5.0 - 2.0 * eta_y) / eta_p6,
    )
}
/// Binomial/logit weight triple: the jet's `d1`, `d2` and `d3`, each
/// multiplied by `pw * n_trials`.
#[inline]
pub(crate) fn observed_weight_binomial_logit_from_jet(
    n_trials: f64,
    jet: MixtureInverseLinkJet,
    pw: f64,
) -> (f64, f64, f64) {
    let scale = pw * n_trials;
    let [w, c, d] = [jet.d1, jet.d2, jet.d3].map(|derivative| scale * derivative);
    (w, c, d)
}
/// Response family used to select a variance function
/// (see `variance_jet_for_weight_family`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFamily {
/// Constant variance (`VarianceJet::gaussian`).
Gaussian,
/// Variance `mu (1 - mu)` (`VarianceJet::binomial_n`).
Binomial,
/// Variance `mu` (`VarianceJet::poisson`).
Poisson,
/// Variance `mu^p` (`VarianceJet::tweedie`); `p` is the power.
Tweedie { p: f64 },
/// Variance `mu + mu^2 / theta` (`VarianceJet::negative_binomial`).
NegativeBinomial { theta: f64 },
/// Bernoulli variance scaled by `1 / (1 + phi)` (`VarianceJet::beta`).
Beta { phi: f64 },
/// Variance `mu^2` (`VarianceJet::gamma`).
Gamma,
}
/// Coarse link category used by `observed_weight_dispatch` to pick a
/// closed-form weight formula where one exists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightLink {
/// Standard identity link.
Identity,
/// Standard log link; has a closed form for the Gaussian family.
Log,
/// Standard logit link; has a closed form for the Binomial family.
Logit,
/// Inverse link; has a closed form for the Gaussian family.
/// `weight_link_for_inverse_link` never produces this variant.
Inverse,
/// Any other link; handled by the general non-canonical formula.
Other,
}
#[inline]
pub fn variance_jet_for_weight_family(family: WeightFamily, mu: f64) -> VarianceJet {
match family {
WeightFamily::Gaussian => VarianceJet::gaussian(),
WeightFamily::Binomial => VarianceJet::binomial_n(mu),
WeightFamily::Poisson => VarianceJet::poisson(mu),
WeightFamily::Tweedie { p } => VarianceJet::tweedie(mu, p),
WeightFamily::NegativeBinomial { theta } => VarianceJet::negative_binomial(mu, theta),
WeightFamily::Beta { phi } => VarianceJet::beta(mu, phi),
WeightFamily::Gamma => VarianceJet::gamma(mu),
}
}
/// Selects the observed-weight formula for a family/link pair and returns
/// the weight together with its two eta-derivatives.
///
/// Closed forms are used for Gaussian/log, Gaussian/inverse and
/// Binomial/logit (the latter with `n_trials = 1.0`). Every other
/// combination goes through `observed_weight_noncanonical` with the family's
/// variance jet at `mu` and the link derivatives `jet.d1..jet.d3` and `h4`.
pub fn observed_weight_dispatch(
    family: WeightFamily,
    link: WeightLink,
    eta: f64,
    y: f64,
    mu: f64,
    phi: f64,
    prior_weight: f64,
    jet: MixtureInverseLinkJet,
    h4: f64,
) -> (f64, f64, f64) {
    let gaussian = matches!(family, WeightFamily::Gaussian);
    let binomial = matches!(family, WeightFamily::Binomial);

    match link {
        WeightLink::Log if gaussian => observed_weight_gaussian_log(y, mu, phi, prior_weight),
        WeightLink::Inverse if gaussian => {
            observed_weight_gaussian_inverse(y, eta, phi, prior_weight)
        }
        WeightLink::Logit if binomial => {
            observed_weight_binomial_logit_from_jet(1.0, jet, prior_weight)
        }
        _ => {
            let variance = variance_jet_for_weight_family(family, mu);
            observed_weight_noncanonical(
                y,
                mu,
                jet.d1,
                jet.d2,
                jet.d3,
                h4,
                variance,
                phi,
                prior_weight,
            )
        }
    }
}
/// Directional derivative of the working curvature, as produced by
/// `directionalworking_curvature_from_c_array`.
#[derive(Clone)]
pub enum DirectionalWorkingCurvature {
/// One value per row.
Diagonal(Array1<f64>),
}
/// Builds the diagonal directional curvature `c_array * eta_direction`
/// (elementwise), zeroing every entry whose Hessian weight is `<= 0.0` or
/// whose product is non-finite.
///
/// A NaN Hessian weight does not satisfy `<= 0.0`, so such a row keeps its
/// product as long as the product itself is finite.
pub fn directionalworking_curvature_from_c_array(
    c_array: &Array1<f64>,
    hessian_weights: &Array1<f64>,
    eta_direction: &Array1<f64>,
) -> DirectionalWorkingCurvature {
    let mut w_direction = c_array * eta_direction;
    for (i, value) in w_direction.iter_mut().enumerate() {
        let discard = hessian_weights[i] <= 0.0 || !value.is_finite();
        if discard {
            *value = 0.0;
        }
    }
    DirectionalWorkingCurvature::Diagonal(w_direction)
}