use super::{
WorkingDerivSlices, WorkingDerivativeBuffersMut, WorkingSlices, standard_inverse_link_jet,
working_deriv_slices, working_slices,
};
use crate::estimate::EstimationError;
use crate::types::{InverseLink, MIN_WEIGHT};
use ndarray::{Array1, ArrayView1};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
/// Lower bound applied to `mu` after exponentiating the linear predictor, keeping
/// `mu` strictly positive so `(y - mu) / mu` never divides by zero.
pub(crate) const MIN_MU: f64 = 1.0e-10;
/// Symmetric bound applied to the linear predictor before exponentiation;
/// `exp(700.0)` (about 1e304) is still finite in `f64`.
pub(crate) const ETA_CLAMP: f64 = 700.0;
/// Selects how an observation's working weight depends on its mean `mu`.
///
/// Evaluated by `raw_weight`, which multiplies the per-unit value described on
/// each variant by the observation's prior weight.
pub(super) enum WorkingWeight {
    /// Per-unit weight equal to `mu`.
    PoissonIdentity,
    /// Per-unit weight that does not depend on `mu`.
    Constant {
        /// The fixed per-unit weight.
        factor: f64,
    },
    /// Per-unit weight `tweedie_log_weight_mu_power(mu, p) / phi`.
    TweediePower {
        /// Power parameter forwarded to `tweedie_log_weight_mu_power`.
        p: f64,
        /// Divisor applied to the weight (presumably the dispersion).
        phi: f64,
    },
    /// Per-unit weight `mu * theta / (mu + theta)`.
    NegativeBinomial {
        /// Shape parameter entering the weight formula above.
        theta: f64,
    },
}
/// Selects how the `(c, d)` curvature terms are formed from a working weight `w`.
///
/// Evaluated by `curvature_terms`.
pub(super) enum WorkingCurvature {
    /// `(c, d) = (c_ratio * w, d_ratio * w)`.
    Proportional {
        /// Multiplier turning the weight into `c`.
        c_ratio: f64,
        /// Multiplier turning the weight into `d`.
        d_ratio: f64,
    },
    /// `c = w * theta / (theta + mu)` and
    /// `d = w * theta * (theta - mu) / (theta + mu)^2`; for `w` proportional to
    /// `mu * theta / (mu + theta)` with `mu = exp(eta)` these are the first and
    /// second derivatives of `w` with respect to eta.
    NegativeBinomial {
        /// Shape parameter used in the formulas above.
        theta: f64,
    },
}
/// Family-specific recipe used to fill the log-link working-state and
/// eta-derivative buffers.
pub(super) struct LogLinkRule {
    /// Working weight of an observation as a function of `mu`.
    pub weight: WorkingWeight,
    /// When set, positive working weights are raised to at least `MIN_WEIGHT`,
    /// and observations whose raw weight lies in `(0, MIN_WEIGHT]` get zero `c`/`d`.
    pub floor_weight: bool,
    /// How the `(c, d)` curvature terms are derived from the working weight.
    pub curvature: WorkingCurvature,
    /// When set, the mu jet (`dmu`, `d2`, `d3`) is zeroed for observations whose
    /// linear predictor had to be clamped.
    pub zero_mu_jet_on_clamp: bool,
}
/// Working weight of one observation before any flooring is applied.
///
/// Returns `prior_weight` times the per-unit weight selected by `weight` at mean
/// `mu`. This function does not clamp `prior_weight`; the callers in this module
/// pass it already clamped at zero, and they apply the `MIN_WEIGHT` floor
/// themselves.
#[inline]
pub(crate) fn raw_weight(weight: &WorkingWeight, mu: f64, prior_weight: f64) -> f64 {
    match weight {
        WorkingWeight::PoissonIdentity => prior_weight * mu,
        WorkingWeight::Constant { factor } => prior_weight * *factor,
        WorkingWeight::TweediePower { p, phi } => {
            let mu_power = super::tweedie_log_weight_mu_power(mu, *p);
            prior_weight * mu_power / *phi
        }
        WorkingWeight::NegativeBinomial { theta } => {
            let theta = *theta;
            // Both arms equal mu * theta / (mu + theta); dividing through by the
            // larger of the two keeps an infinite theta (or mu) finite.
            let per_unit = if mu < theta {
                mu / (1.0 + mu / theta)
            } else {
                theta / (1.0 + theta / mu)
            };
            prior_weight * per_unit
        }
    }
}
/// Curvature terms `(c, d)` for a working weight `weight` evaluated at mean `mu`.
///
/// * `Proportional`: `(c_ratio * weight, d_ratio * weight)`.
/// * `NegativeBinomial`: with `r = theta / (theta + mu)`, returns
///   `(weight * r, weight * r * (theta - mu) / (theta + mu))`. For a weight of the
///   form `k * mu * theta / (mu + theta)` with `mu = exp(eta)`, these are its first
///   and second derivatives with respect to eta.
///
/// The negative-binomial terms are evaluated without squaring `theta` or
/// `theta + mu`. The naive `theta * (theta - mu) / (theta + mu)^2` turns into
/// inf/inf = NaN once `theta` is infinite (the Poisson limit) or large enough for
/// its square to overflow. Here, `theta = +inf` yields `(weight, weight)`, which is
/// the limit of the formulas above.
#[inline]
pub(crate) fn curvature_terms(curvature: &WorkingCurvature, mu: f64, weight: f64) -> (f64, f64) {
    match *curvature {
        WorkingCurvature::Proportional { c_ratio, d_ratio } => (c_ratio * weight, d_ratio * weight),
        WorkingCurvature::NegativeBinomial { theta } => {
            // r = theta / (theta + mu), branched like `raw_weight` so an infinite
            // theta gives r = 1 rather than inf / inf.
            let r = if theta > mu {
                1.0 / (1.0 + mu / theta)
            } else {
                theta / (theta + mu)
            };
            // (theta - mu) / (theta + mu) is taken directly so the subtraction stays
            // exact when theta is close to mu; its limit as theta -> +inf is 1.
            let spread = if theta == f64::INFINITY {
                1.0
            } else {
                (theta - mu) / (theta + mu)
            };
            let c = weight * r;
            (c, c * spread)
        }
    }
}
/// Per-observation log-link quantities shared by both passes of
/// `write_log_link_working_state`.
struct LogLinkObservation {
    /// Linear predictor after clamping to `[-ETA_CLAMP, ETA_CLAMP]`.
    eta: f64,
    /// `exp(eta)` floored at `MIN_MU`.
    mu: f64,
    /// Stored working weight: 0 unless the raw weight is positive, in which case
    /// the raw weight (raised to `MIN_WEIGHT` when the rule floors weights).
    weight: f64,
    /// The input eta differs from the clamped eta (also true for a NaN input).
    clamp_active: bool,
    /// The rule floors weights and the raw weight lay in `(0, MIN_WEIGHT]`.
    floor_active: bool,
}

/// Evaluates the clamped eta, mean, stored weight and guard flags for one observation.
fn evaluate_log_link_observation(
    rule: &LogLinkRule,
    eta_raw: f64,
    prior_weight: f64,
) -> LogLinkObservation {
    let eta = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
    let mu = eta.exp().max(MIN_MU);
    let w_raw = raw_weight(&rule.weight, mu, prior_weight.max(0.0));
    // `w_raw > 0.0` is false for NaN as well, so a NaN raw weight is stored as 0.
    let weight = match (w_raw > 0.0, rule.floor_weight) {
        (false, _) => 0.0,
        (true, true) => w_raw.max(MIN_WEIGHT),
        (true, false) => w_raw,
    };
    LogLinkObservation {
        eta,
        mu,
        weight,
        clamp_active: eta_raw != eta,
        floor_active: rule.floor_weight && w_raw > 0.0 && w_raw <= MIN_WEIGHT,
    }
}

/// Writes the log-link working state (`mu`, `weights`, `z`) for every observation
/// and, when `derivatives` is supplied, the matching derivative buffers.
///
/// For each observation, eta is clamped to `[-ETA_CLAMP, ETA_CLAMP]`, and
/// `mu = exp(eta)` is floored at `MIN_MU`. The working response is
/// `z = eta + (y - mu) / mu`. The weight is `raw_weight` with the prior weight
/// clamped at zero: non-positive (or NaN) results become 0, and positive ones are
/// raised to `MIN_WEIGHT` when `rule.floor_weight` is set.
///
/// With `derivatives`, `dmu`, `d2` and `d3` all receive `mu`, or zero when the
/// clamp was active and `rule.zero_mu_jet_on_clamp` is set. `c` and `d` come from
/// `curvature_terms` at the stored weight, unless the clamp or the weight floor
/// was active, in which case both are zero.
///
/// Observations are processed in parallel with rayon. `eta`, `y` and
/// `priorweights` are indexed by position, so they are presumably at least as long
/// as the output buffers.
pub(super) fn write_log_link_working_state(
    rule: &LogLinkRule,
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) {
    let WorkingSlices {
        mu: mu_s,
        weights: weights_s,
        z: z_s,
    } = working_slices(mu, weights, z);

    match derivatives {
        None => {
            mu_s.par_iter_mut()
                .zip(weights_s.par_iter_mut())
                .zip(z_s.par_iter_mut())
                .enumerate()
                .for_each(|(i, ((mu_out, w_out), z_out))| {
                    let obs = evaluate_log_link_observation(rule, eta[i], priorweights[i]);
                    *mu_out = obs.mu;
                    *w_out = obs.weight;
                    *z_out = obs.eta + (y[i] - obs.mu) / obs.mu;
                });
        }
        Some(mut derivs) => {
            let WorkingDerivSlices {
                c: c_s,
                d: d_s,
                dmu: dmu_s,
                d2: d2_s,
                d3: d3_s,
            } = working_deriv_slices(&mut derivs);

            let state = mu_s
                .par_iter_mut()
                .zip(weights_s.par_iter_mut())
                .zip(z_s.par_iter_mut());
            let mu_jet = dmu_s
                .par_iter_mut()
                .zip(d2_s.par_iter_mut())
                .zip(d3_s.par_iter_mut());
            let curvature = c_s.par_iter_mut().zip(d_s.par_iter_mut());

            state.zip(mu_jet).zip(curvature).enumerate().for_each(
                |(i, ((((mu_out, w_out), z_out), ((dmu_out, d2_out), d3_out)), (c_out, d_out)))| {
                    let obs = evaluate_log_link_observation(rule, eta[i], priorweights[i]);
                    *mu_out = obs.mu;
                    *w_out = obs.weight;
                    *z_out = obs.eta + (y[i] - obs.mu) / obs.mu;

                    // Every derivative of exp(eta) is mu itself.
                    let jet_value = if rule.zero_mu_jet_on_clamp && obs.clamp_active {
                        0.0
                    } else {
                        obs.mu
                    };
                    *dmu_out = jet_value;
                    *d2_out = jet_value;
                    *d3_out = jet_value;

                    let (c_i, d_i) = if obs.floor_active || obs.clamp_active {
                        (0.0, 0.0)
                    } else {
                        curvature_terms(&rule.curvature, obs.mu, obs.weight)
                    };
                    *c_out = c_i;
                    *d_out = d_i;
                },
            );
        }
    }
}
/// Fills the eta-derivative buffers (`c`, `d`, `dmu`, `d2`, `d3`) for a log-link
/// family, using the inverse-link jet returned by `standard_inverse_link_jet`.
///
/// For each observation, eta is first clamped to `[-ETA_CLAMP, ETA_CLAMP]`. The
/// clamp counts as active when the clamped value differs from the input, which
/// includes a NaN input. Then:
/// * `c` and `d` are zero when the clamp is active, or when `rule.floor_weight` is
///   set and the raw weight lies in `(0, MIN_WEIGHT]`. Otherwise they come from
///   `curvature_terms` at the working weight, with non-positive (or NaN) raw
///   weights treated as 0.
/// * `dmu`, `d2` and `d3` are the jet's `d1`, `d2` and `d3`, or zero when the clamp
///   is active and `rule.zero_mu_jet_on_clamp` is set.
///
/// # Errors
/// Returns an error if `standard_inverse_link_jet` fails for any observation.
pub(super) fn write_log_link_eta_curvature(
    rule: &LogLinkRule,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
    mut buffers: WorkingDerivativeBuffersMut<'_>,
) -> Result<(), EstimationError> {
    let floor_weight = rule.floor_weight;
    let zero_mu_jet_on_clamp = rule.zero_mu_jet_on_clamp;
    let WorkingDerivSlices {
        c: c_s,
        d: d_s,
        dmu: dmu_s,
        d2: d2_s,
        d3: d3_s,
    } = working_deriv_slices(&mut buffers);
    c_s.par_iter_mut()
        .zip(d_s.par_iter_mut())
        .zip(dmu_s.par_iter_mut())
        .zip(d2_s.par_iter_mut())
        .zip(d3_s.par_iter_mut())
        .enumerate()
        .try_for_each(
            |(i, ((((c_o, d_o), dmu_o), d2_o), d3_o))| -> Result<(), EstimationError> {
                let eta_raw = eta[i];
                let eta_used = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
                // A NaN eta also compares unequal here and is handled as clamped.
                let clamp_active = eta_raw != eta_used;
                let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                let raw_w = raw_weight(&rule.weight, jet.mu, priorweights[i].max(0.0));
                let floor_active = floor_weight && raw_w > 0.0 && raw_w <= MIN_WEIGHT;
                if clamp_active || floor_active {
                    *c_o = 0.0;
                    *d_o = 0.0;
                } else {
                    // `write_log_link_working_state` stores a zero working weight
                    // whenever the raw weight is not positive (including NaN), so
                    // differentiate that same zero weight here; positive raw weights
                    // that reach this branch are never changed by the floor.
                    let working_w = if raw_w > 0.0 { raw_w } else { 0.0 };
                    let (c_i, d_i) = curvature_terms(&rule.curvature, jet.mu, working_w);
                    *c_o = c_i;
                    *d_o = d_i;
                }
                if zero_mu_jet_on_clamp && clamp_active {
                    *dmu_o = 0.0;
                    *d2_o = 0.0;
                    *d3_o = 0.0;
                } else {
                    *dmu_o = jet.d1;
                    *d2_o = jet.d2;
                    *d3_o = jet.d3;
                }
                Ok(())
            },
        )
}