use super::*;
/// Smallest gamma shape considered by `estimate_gamma_shape_from_eta`; it is the initial
/// lower end of that function's bisection bracket.
pub(crate) const GAMMA_SHAPE_MIN: f64 = 1.0e-8;
/// Largest gamma shape `estimate_gamma_shape_from_eta` will report. It caps the bracket
/// expansion and is returned when the deviance target is (numerically) zero or the root
/// lies beyond it.
pub(crate) const GAMMA_SHAPE_MAX: f64 = 1.0e12;
/// Tolerance shared by the gamma-shape solver: deviance targets at or below it are treated
/// as zero, and it is the relative bracket width at which bisection stops.
pub(crate) const GAMMA_SHAPE_TARGET_TOL: f64 = 1.0e-12;
/// Absolute cap on the linear predictor `eta`. The name suggests it bounds `|eta|` inside the
/// PIRLS loop; it is not referenced in this part of the file, so confirm against its users.
pub(super) const PIRLS_ETA_ABS_CAP: f64 = 40.0;
/// Estimating equation for the gamma shape: `ln(k) - ψ(k) - target`.
///
/// `ln(k) - ψ(k)` decreases monotonically for `k > 0`. A positive value therefore means
/// `shape` is still below the root. `estimate_gamma_shape_from_eta` bisects on exactly this
/// sign.
#[inline]
pub(crate) fn gamma_shape_score(shape: f64, target: f64) -> f64 {
    let log_shape = shape.ln();
    let psi = digamma(shape);
    (log_shape - psi) - target
}
/// Maximum-likelihood estimate of the gamma shape `k`, with the means held fixed at
/// `mu_i = exp(clamp(eta_i, -ETA_CLAMP, ETA_CLAMP))`.
///
/// The shape solves `ln(k) - ψ(k) = target`. Here `target` is the prior-weighted mean of
/// `y_i/mu_i - ln(y_i/mu_i) - 1`, with responses and means first floored at `1e-12`.
/// Negative prior weights are treated as zero.
///
/// The solve proceeds in three steps:
/// 1. The closed-form approximation `(3 - s + sqrt((s-3)^2 + 24 s)) / (12 s)` seeds an
///    upper bracket.
/// 2. The bracket is doubled until it contains the root.
/// 3. Bisection refines the root.
///
/// Returns:
/// - `1.0` when no row has positive weight, the total weight overflows, or the target is
///   NaN (non-finite responses such as `+inf`);
/// - `GAMMA_SHAPE_MAX` when the target is at most `GAMMA_SHAPE_TARGET_TOL`, or when the
///   root lies beyond `GAMMA_SHAPE_MAX`;
/// - `GAMMA_SHAPE_MIN` when the target is `+inf`;
/// - otherwise the bisected root, to relative tolerance `GAMMA_SHAPE_TARGET_TOL`.
pub(crate) fn estimate_gamma_shape_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> f64 {
    const EPS: f64 = 1e-12;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let (weighted_target, total_weight) = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let wi = priorweights[i].max(0.0);
            if wi == 0.0 {
                return (0.0_f64, 0.0_f64);
            }
            let yi = y[i].max(EPS);
            let mui = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(EPS);
            let ratio = yi / mui;
            // Half the unit gamma deviance; non-negative because x - ln(x) - 1 >= 0 for x > 0.
            (wi * (ratio - ratio.ln() - 1.0), wi)
        })
        .reduce(
            || (0.0_f64, 0.0_f64),
            |(t1, w1), (t2, w2)| (t1 + t2, w1 + w2),
        );
    // No usable weight, or an overflowing total: fall back to a neutral shape.
    if total_weight <= 0.0 || !total_weight.is_finite() {
        return 1.0;
    }
    let target = weighted_target / total_weight;
    if target.is_nan() {
        // NaN only arises from non-finite data (a `+inf` response yields inf - inf).
        // `f64::max` below would silently map it to 0.0 and report a near-perfect fit.
        return 1.0;
    }
    let target = target.max(0.0);
    if target <= GAMMA_SHAPE_TARGET_TOL {
        return GAMMA_SHAPE_MAX;
    }
    if target.is_infinite() {
        // ln(k) - ψ(k) grows without bound only as k -> 0.
        return GAMMA_SHAPE_MIN;
    }
    let discriminant = (target - 3.0) * (target - 3.0) + 24.0 * target;
    let approx = ((3.0 - target) + discriminant.sqrt()) / (12.0 * target);
    // Bisection invariant: score(lo) > 0 >= score(hi). The score at GAMMA_SHAPE_MIN is about
    // 1e8, so it is positive for any target below that.
    let mut lo = GAMMA_SHAPE_MIN;
    let mut hi = approx.max(1.0);
    while hi < GAMMA_SHAPE_MAX && gamma_shape_score(hi, target) > 0.0 {
        // `hi` is still left of the root, so it is a valid (and tighter) lower end.
        lo = hi;
        hi = (hi * 2.0).min(GAMMA_SHAPE_MAX);
    }
    if gamma_shape_score(hi, target) > 0.0 {
        return GAMMA_SHAPE_MAX;
    }
    for _ in 0..80 {
        let mid = 0.5 * (lo + hi);
        if gamma_shape_score(mid, target) > 0.0 {
            lo = mid;
        } else {
            hi = mid;
        }
        if (hi - lo) <= GAMMA_SHAPE_TARGET_TOL * hi.max(1.0) {
            break;
        }
    }
    0.5 * (lo + hi)
}
/// Moment estimate of the beta-regression precision `phi`.
///
/// The means are held fixed at `mu_i = 1 / (1 + exp(-clamp(eta_i, -ETA_CLAMP, ETA_CLAMP)))`,
/// clipped to `[1e-9, 1 - 1e-9]`.
///
/// The prior-weighted Pearson statistic `Σ w_i (y_i - mu_i)^2 / (mu_i (1 - mu_i))`, divided
/// by `Σ w_i`, is treated as an estimate of `1 / (1 + phi)` and inverted. This amounts to
/// assuming `Var(y) = mu (1 - mu) / (1 + phi)`. Negative prior weights count as zero.
///
/// Returns `1.0` when no row has positive weight, or when the Pearson sum is zero, negative
/// or non-finite (NaN/inf inputs). Otherwise returns the estimate clamped to `[1e-3, 1e6]`.
pub(crate) fn estimate_beta_phi_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
) -> f64 {
    const PHI_MIN: f64 = 1e-3;
    const PHI_MAX: f64 = 1e6;
    const MU_EPS: f64 = 1e-9;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    let (weighted_pearson, total_weight) = (0..eta.len())
        .into_par_iter()
        .map(|i| {
            let wi = priorweights[i].max(0.0);
            if wi == 0.0 {
                return (0.0_f64, 0.0_f64);
            }
            let mui = (1.0 / (1.0 + (-eta[i].clamp(-ETA_CLAMP, ETA_CLAMP)).exp()))
                .clamp(MU_EPS, 1.0 - MU_EPS);
            let var_unit = mui * (1.0 - mui);
            let resid = y[i] - mui;
            (wi * resid * resid / var_unit, wi)
        })
        .reduce(
            || (0.0_f64, 0.0_f64),
            |(p1, w1), (p2, w2)| (p1 + p2, w1 + w2),
        );
    // Degenerate input. The `is_finite` check matters: a NaN Pearson sum would otherwise
    // pass through `f64::max` below and silently become PHI_MIN. This mirrors the guard in
    // `estimate_tweedie_phi_from_eta`.
    if total_weight <= 0.0 || !weighted_pearson.is_finite() || weighted_pearson <= 0.0 {
        return 1.0;
    }
    let one_plus_phi = (total_weight / weighted_pearson).max(1.0 + PHI_MIN);
    (one_plus_phi - 1.0).clamp(PHI_MIN, PHI_MAX)
}
/// Pearson-type estimate of the Tweedie dispersion `phi` for variance power `p`.
///
/// The means are held fixed at `mu_i = exp(clamp(eta_i, -ETA_CLAMP, ETA_CLAMP))`. The
/// estimate is `Σ w_i (y_i - mu_i)^2 / mu_i^p` divided by `Σ w_i`. Both `mu_i` and `mu_i^p`
/// are floored at `1e-300`, and negative prior weights count as zero.
///
/// Returns `1.0` when no row has positive weight, or when the Pearson sum is non-finite or
/// not positive. Otherwise returns the ratio clamped to `[1e-6, 1e12]`.
pub(crate) fn estimate_tweedie_phi_from_eta(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    p: f64,
) -> f64 {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    const PHI_MIN: f64 = 1e-6;
    const PHI_MAX: f64 = 1e12;
    const MU_EPS: f64 = 1e-300;

    // (weighted Pearson contribution, weight) for one observation.
    let pearson_term = |i: usize| -> (f64, f64) {
        let weight = priorweights[i].max(0.0);
        if weight > 0.0 {
            let mu = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(MU_EPS);
            let variance = mu.powf(p).max(MU_EPS);
            let diff = y[i] - mu;
            (weight * diff * diff / variance, weight)
        } else {
            (0.0, 0.0)
        }
    };

    let (pearson_sum, weight_sum) = (0..eta.len())
        .into_par_iter()
        .map(pearson_term)
        .reduce(|| (0.0, 0.0), |(pa, wa), (pb, wb)| (pa + pb, wa + wb));

    if weight_sum <= 0.0 || !(pearson_sum.is_finite() && pearson_sum > 0.0) {
        return 1.0;
    }
    (pearson_sum / weight_sum).clamp(PHI_MIN, PHI_MAX)
}
/// Smallest negative-binomial `theta` considered by `estimate_negbin_theta_from_eta`. It is
/// returned when the score at this bound is not positive (or not finite).
pub(crate) const NEGBIN_THETA_MIN: f64 = 1.0e-3;
/// Largest negative-binomial `theta` considered by `estimate_negbin_theta_from_eta`. It is
/// returned when the score at this bound is still non-negative. It is also the seed when the
/// Pearson ratio shows no excess dispersion.
pub(crate) const NEGBIN_THETA_MAX: f64 = 1.0e6;
/// Weighted score and observed information of the negative-binomial log-likelihood with
/// respect to the size parameter `theta`.
///
/// The means are held fixed at `mu_i = exp(clamp(eta_i, -ETA_CLAMP, ETA_CLAMP))`, floored at
/// `1e-300`. With `t = theta` and `m = mu_i`, each row contributes:
/// - score: `ψ(y+t) - ψ(t) + ln t + 1 - ln(t+m) - (t+y)/(t+m)`
/// - information: `-ψ'(y+t) + ψ'(t) - 1/t + 2/(t+m) - (t+y)/(t+m)^2`
///
/// Returns `(Σ w_i·score_i, Σ w_i·info_i)`. Rows whose prior weight is not positive
/// contribute nothing.
pub(crate) fn negbin_theta_score_and_info(
    y: ArrayView1<'_, f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<'_, f64>,
    theta: f64,
) -> (f64, f64) {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    // Terms that depend only on theta, hoisted out of the per-row work.
    let digamma_t = digamma(theta);
    let trigamma_t = trigamma(theta);
    let log_t = theta.ln();
    let recip_t = 1.0 / theta;

    let row = |i: usize| -> (f64, f64) {
        let w = priorweights[i].max(0.0);
        if w > 0.0 {
            let mu = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(1e-300);
            let t_plus_m = theta + mu;
            let t_plus_y = theta + y[i];
            let score_i = digamma(t_plus_y) - digamma_t + log_t + 1.0
                - t_plus_m.ln()
                - t_plus_y / t_plus_m;
            let info_i = -trigamma(t_plus_y) + trigamma_t - recip_t + 2.0 / t_plus_m
                - t_plus_y / (t_plus_m * t_plus_m);
            (w * score_i, w * info_i)
        } else {
            (0.0, 0.0)
        }
    };

    (0..eta.len())
        .into_par_iter()
        .map(row)
        .reduce(|| (0.0, 0.0), |(sa, ia), (sb, ib)| (sa + sb, ia + ib))
}
pub(crate) fn estimate_negbin_theta_from_eta(
y: ArrayView1<'_, f64>,
eta: &Array1<f64>,
priorweights: ArrayView1<'_, f64>,
) -> f64 {
use rayon::iter::{IntoParallelIterator, ParallelIterator};
let (wsum, wmu, wpearson) = (0..eta.len())
.into_par_iter()
.map(|i| {
let wi = priorweights[i].max(0.0);
if wi == 0.0 {
return (0.0_f64, 0.0_f64, 0.0_f64);
}
let mui = eta[i].clamp(-ETA_CLAMP, ETA_CLAMP).exp().max(1e-300);
let resid = y[i] - mui;
(wi, wi * mui, wi * resid * resid / mui)
})
.reduce(
|| (0.0_f64, 0.0_f64, 0.0_f64),
|(a1, b1, c1), (a2, b2, c2)| (a1 + a2, b1 + b2, c1 + c2),
);
if wsum <= 0.0 {
return 1.0;
}
let mu_bar = wmu / wsum;
let pearson_ratio = wpearson / wsum;
let mut theta = if pearson_ratio > 1.0 + 1e-6 {
(mu_bar / (pearson_ratio - 1.0)).clamp(NEGBIN_THETA_MIN, NEGBIN_THETA_MAX)
} else {
NEGBIN_THETA_MAX
};
let (score_hi, _) = negbin_theta_score_and_info(y, eta, priorweights, NEGBIN_THETA_MAX);
if !score_hi.is_finite() {
return 1.0;
}
if score_hi >= 0.0 {
return NEGBIN_THETA_MAX;
}
let (score_lo, _) = negbin_theta_score_and_info(y, eta, priorweights, NEGBIN_THETA_MIN);
if !score_lo.is_finite() || score_lo <= 0.0 {
return NEGBIN_THETA_MIN;
}
let mut lo = NEGBIN_THETA_MIN;
let mut hi = NEGBIN_THETA_MAX;
theta = theta.clamp(lo, hi);
const MAX_NEWTON_ITERS: usize = 100;
const REL_TOL: f64 = 1e-10;
for _ in 0..MAX_NEWTON_ITERS {
let (score, info) = negbin_theta_score_and_info(y, eta, priorweights, theta);
if !score.is_finite() {
break;
}
if score > 0.0 {
lo = theta;
} else {
hi = theta;
}
let next = if info.is_finite() && info > 0.0 {
let candidate = theta + score / info;
if candidate > lo && candidate < hi {
candidate
} else {
0.5 * (lo + hi)
}
} else {
0.5 * (lo + hi)
};
if (next - theta).abs() <= REL_TOL * theta.max(1.0) {
theta = next;
break;
}
theta = next;
}
theta.clamp(NEGBIN_THETA_MIN, NEGBIN_THETA_MAX)
}