use super::*;
use crate::types::MIN_WEIGHT;
pub fn update_glmvectors(
y: ArrayView1<f64>,
eta: &Array1<f64>,
inverse_link: &InverseLink,
priorweights: ArrayView1<f64>,
mu: &mut Array1<f64>,
weights: &mut Array1<f64>,
z: &mut Array1<f64>,
derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
let link = inverse_link.link_function();
if matches!(link, LinkFunction::Logit)
&& inverse_link.mixture_state().is_none()
&& inverse_link.sas_state().is_none()
{
if let Some(mut derivs) = derivatives {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
let WorkingDerivSlices {
c: c_s,
d: d_s,
dmu: dmu_s,
d2: d2_s,
d3: d3_s,
} = working_deriv_slices(&mut derivs);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.zip(c_s.par_iter_mut())
.zip(d_s.par_iter_mut())
.zip(dmu_s.par_iter_mut())
.zip(d2_s.par_iter_mut())
.zip(d3_s.par_iter_mut())
.enumerate()
.for_each(
|(i, (((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o))| {
let eta_raw = eta[i];
let eta_c = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
let jet = logit_inverse_link_jet5(eta_c);
let geom = bernoulli_logit_geometry_from_jet(
eta_raw,
eta_c,
y[i],
priorweights[i],
jet,
true,
);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
*c_o = geom.c;
*d_o = geom.d;
*dmu_o = jet.d1;
*d2_o = jet.d2;
*d3_o = jet.d3;
},
);
} else {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.enumerate()
.for_each(|(i, ((mu_o, w_o), z_o))| {
let eta_raw = eta[i];
let eta_c = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
let jet = logit_inverse_link_jet5(eta_c);
let geom = bernoulli_logit_geometry_from_jet(
eta_raw,
eta_c,
y[i],
priorweights[i],
jet,
true,
);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
});
}
return Ok(());
}
match link {
LinkFunction::Logit
| LinkFunction::Probit
| LinkFunction::CLogLog
| LinkFunction::Sas
| LinkFunction::BetaLogistic => {
let zero_on_nonsmooth = matches!(link, LinkFunction::Logit);
if let Some(mut derivs) = derivatives {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
let WorkingDerivSlices {
c: c_s,
d: d_s,
dmu: dmu_s,
d2: d2_s,
d3: d3_s,
} = working_deriv_slices(&mut derivs);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.zip(c_s.par_iter_mut())
.zip(d_s.par_iter_mut())
.zip(dmu_s.par_iter_mut())
.zip(d2_s.par_iter_mut())
.zip(d3_s.par_iter_mut())
.enumerate()
.try_for_each(
|(
i,
(((((((mu_o, w_o), z_o), c_o), d_o), dmu_o), d2_o), d3_o),
)|
-> Result<(), EstimationError> {
let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
if matches!(link, LinkFunction::Logit) {
let jet = logit_inverse_link_jet5(eta_used);
let geom = bernoulli_logit_geometry_from_jet(
eta[i],
eta_used,
y[i],
priorweights[i],
jet,
zero_on_nonsmooth,
);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
*c_o = geom.c;
*d_o = geom.d;
*dmu_o = jet.d1;
*d2_o = jet.d2;
*d3_o = jet.d3;
} else {
let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
let geom = bernoulli_geometry_from_jet(
eta[i],
eta_used,
y[i],
priorweights[i],
jet,
);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
*c_o = geom.c;
*d_o = geom.d;
*dmu_o = jet.d1;
*d2_o = jet.d2;
*d3_o = jet.d3;
}
Ok(())
},
)?;
} else {
let WorkingSlices {
mu: mu_s,
weights: weights_s,
z: z_s,
} = working_slices(mu, weights, z);
mu_s.par_iter_mut()
.zip(weights_s.par_iter_mut())
.zip(z_s.par_iter_mut())
.enumerate()
.try_for_each(|(i, ((mu_o, w_o), z_o))| -> Result<(), EstimationError> {
let eta_used = eta_for_observed_hessian_jet(inverse_link, eta[i]);
if matches!(link, LinkFunction::Logit) {
let jet = logit_inverse_link_jet5(eta_used);
let geom = bernoulli_logit_geometry_from_jet(
eta[i],
eta_used,
y[i],
priorweights[i],
jet,
zero_on_nonsmooth,
);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
} else {
let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
let geom = bernoulli_geometry_from_jet(
eta[i],
eta_used,
y[i],
priorweights[i],
jet,
);
*mu_o = geom.mu;
*w_o = geom.weight;
*z_o = geom.z;
}
Ok(())
})?;
}
Ok(())
}
LinkFunction::Identity => {
write_identityworking_state(y, eta, priorweights, mu, weights, z, derivatives);
Ok(())
}
LinkFunction::Log => {
write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
Ok(())
}
}
}
/// Updates the IRLS working vectors by delegating to `likelihood.irls_update`.
///
/// `y`, `eta`, `priorweights`, `mu`, `weights` and `z` are forwarded unchanged,
/// and the two trailing arguments of `irls_update` are passed as `None`.
/// NOTE(review): those trailing arguments are presumably optional extras such
/// as derivative buffers — confirm against the `irls_update` signature.
///
/// # Errors
/// Returns whatever error `irls_update` reports.
#[inline]
pub fn update_glmvectors_by_family(
y: ArrayView1<f64>,
eta: &Array1<f64>,
likelihood: &GlmLikelihoodSpec,
priorweights: ArrayView1<f64>,
mu: &mut Array1<f64>,
weights: &mut Array1<f64>,
z: &mut Array1<f64>,
) -> Result<(), EstimationError> {
// Pure pass-through: no derivative output is requested from this entry point.
likelihood.irls_update(y, eta, priorweights, mu, weights, z, None, None)
}
/// Resolves the [`InverseLink`] to use for an integrated link-runtime update of
/// the likelihood described by `spec`.
///
/// Only the `Binomial` response is accepted:
/// * standard `Logit`, `Probit` and `CLogLog` links are returned as a clone of
///   `spec.link`;
/// * `Sas` and `BetaLogistic` links are rebuilt from `sas_link_state`;
/// * `Mixture` links are rebuilt from a clone of `mixture_link_state`.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` when the response/link combination
/// is not one of the above, or when the state required by the link is `None`.
pub(crate) fn integrated_inverse_link_from_family(
    spec: &LikelihoodSpec,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Result<InverseLink, EstimationError> {
    /// Error for a link whose runtime state was not supplied by the caller.
    fn missing_state(message: &str) -> EstimationError {
        EstimationError::InvalidInput(message.to_string())
    }

    // Built lazily so the `Debug` formatting only happens on the error path.
    let unsupported = || {
        EstimationError::InvalidInput(format!(
            "Integrated link-runtime update is not supported for likelihood (response={:?}, link={:?})",
            spec.response, spec.link
        ))
    };

    if !matches!(&spec.response, ResponseFamily::Binomial) {
        return Err(unsupported());
    }

    match &spec.link {
        InverseLink::Standard(
            StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,
        ) => Ok(spec.link.clone()),
        InverseLink::Sas(_) => sas_link_state
            .map(|state| InverseLink::Sas(*state))
            .ok_or_else(|| {
                missing_state("Integrated BinomialSas update requires explicit SasLinkState")
            }),
        InverseLink::BetaLogistic(_) => sas_link_state
            .map(|state| InverseLink::BetaLogistic(*state))
            .ok_or_else(|| {
                missing_state(
                    "Integrated BinomialBetaLogistic update requires explicit SasLinkState",
                )
            }),
        InverseLink::Mixture(_) => mixture_link_state
            .map(|state| InverseLink::Mixture(state.clone()))
            .ok_or_else(|| {
                missing_state(
                    "Integrated BinomialMixture update requires explicit MixtureLinkState",
                )
            }),
        _ => Err(unsupported()),
    }
}
/// Updates the IRLS working state (`mu`, `weights`, `z`) using inverse-link
/// jets obtained from the `crate::quadrature` routines, which receive the
/// per-observation `eta[i]` together with `se[i]`.
///
/// The jet source depends on `inverse_link`:
/// * `LatentCLogLog(state)` calls `latent_cloglog_inverse_link_jet` with
///   `se[i].hypot(state.latent_sd)`;
/// * `Standard(Logit)` calls `integrated_logit_inverse_link_jet_pirls`;
/// * every other accepted link calls `integrated_inverse_link_jetwith_state`
///   with the link's mixture/SAS state.
///
/// The resulting jet is handed to `bernoulli_geometry_from_jet` together with
/// `eta[i]` and `eta[i]` clamped to `[-ETA_CLAMP, ETA_CLAMP]`. When
/// `derivatives` is `Some`, the `c`, `d`, `dmu`, `d2` and `d3` slices are
/// filled as well.
///
/// # Errors
/// Fails for an inverse link outside the accepted set (standard
/// logit/probit/cloglog, latent cloglog, SAS, beta-logistic, mixture) and
/// propagates any error from the quadrature routines.
///
/// # Panics
/// Row indexing into `eta`, `se`, `y` and `priorweights` panics if any of them
/// is shorter than the zipped output slices.
#[inline]
pub fn update_glmvectors_integrated_for_link(
    quadctx: &crate::quadrature::QuadratureContext,
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    se: ArrayView1<f64>,
    inverse_link: &InverseLink,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
    let link = inverse_link.link_function();

    let supported = matches!(
        inverse_link,
        InverseLink::Standard(StandardLink::Logit)
            | InverseLink::Standard(StandardLink::Probit)
            | InverseLink::Standard(StandardLink::CLogLog)
            | InverseLink::LatentCLogLog(_)
            | InverseLink::Sas(_)
            | InverseLink::BetaLogistic(_)
            | InverseLink::Mixture(_)
    );
    if !supported {
        crate::bail_invalid_estim!(
            "Integrated link-runtime update is not supported for inverse link {:?}",
            inverse_link
        );
    }

    // Computes the jet and the Bernoulli geometry for one observation; shared
    // by both output layouts below.
    let eval_row = |row: usize| {
        let eta_raw = eta[row];
        let se_row = se[row];
        let integrated = match inverse_link {
            InverseLink::LatentCLogLog(state) => {
                crate::quadrature::latent_cloglog_inverse_link_jet(
                    quadctx,
                    eta_raw,
                    se_row.hypot(state.latent_sd),
                )?
            }
            InverseLink::Standard(StandardLink::Logit) => {
                crate::quadrature::integrated_logit_inverse_link_jet_pirls(
                    quadctx, eta_raw, se_row,
                )?
            }
            _ => crate::quadrature::integrated_inverse_link_jetwith_state(
                quadctx,
                link,
                eta_raw,
                se_row,
                inverse_link.mixture_state(),
                inverse_link.sas_state(),
            )?,
        };
        let jet = MixtureInverseLinkJet {
            mu: integrated.mean,
            d1: integrated.d1,
            d2: integrated.d2,
            d3: integrated.d3,
        };
        let eta_clamped = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
        let geom =
            bernoulli_geometry_from_jet(eta_raw, eta_clamped, y[row], priorweights[row], jet);
        Ok::<_, EstimationError>((geom, jet))
    };

    match derivatives {
        Some(mut derivs) => {
            let WorkingSlices {
                mu: mu_out,
                weights: weight_out,
                z: z_out,
            } = working_slices(mu, weights, z);
            let WorkingDerivSlices {
                c: c_out,
                d: d_out,
                dmu: d1_out,
                d2: d2_out,
                d3: d3_out,
            } = working_deriv_slices(&mut derivs);
            mu_out
                .par_iter_mut()
                .zip(weight_out.par_iter_mut())
                .zip(z_out.par_iter_mut())
                .zip(c_out.par_iter_mut())
                .zip(d_out.par_iter_mut())
                .zip(d1_out.par_iter_mut())
                .zip(d2_out.par_iter_mut())
                .zip(d3_out.par_iter_mut())
                .enumerate()
                .try_for_each(
                    |(row, (((((((mu_i, w_i), z_i), c_i), d_i), d1_i), d2_i), d3_i))|
                     -> Result<(), EstimationError> {
                        let (geom, jet) = eval_row(row)?;
                        *mu_i = geom.mu;
                        *w_i = geom.weight;
                        *z_i = geom.z;
                        *c_i = geom.c;
                        *d_i = geom.d;
                        *d1_i = jet.d1;
                        *d2_i = jet.d2;
                        *d3_i = jet.d3;
                        Ok(())
                    },
                )
        }
        None => {
            let WorkingSlices {
                mu: mu_out,
                weights: weight_out,
                z: z_out,
            } = working_slices(mu, weights, z);
            mu_out
                .par_iter_mut()
                .zip(weight_out.par_iter_mut())
                .zip(z_out.par_iter_mut())
                .enumerate()
                .try_for_each(
                    |(row, ((mu_i, w_i), z_i))| -> Result<(), EstimationError> {
                        let (geom, _jet) = eval_row(row)?;
                        *mu_i = geom.mu;
                        *w_i = geom.weight;
                        *z_i = geom.z;
                        Ok(())
                    },
                )
        }
    }
}
/// Family-level entry point for the integrated working-state update.
///
/// Resolves the runtime inverse link with
/// `integrated_inverse_link_from_family(spec, mixture_link_state, sas_link_state)`
/// and, if that succeeds, forwards every remaining argument to
/// `update_glmvectors_integrated_for_link`.
///
/// # Errors
/// Returns the error from link resolution, or otherwise the error from the
/// link-level update.
#[inline]
pub fn update_glmvectors_integrated_by_family(
    quadctx: &crate::quadrature::QuadratureContext,
    y: ArrayView1<f64>,
    eta: &Array1<f64>,
    se: ArrayView1<f64>,
    spec: &LikelihoodSpec,
    priorweights: ArrayView1<f64>,
    mu: &mut Array1<f64>,
    weights: &mut Array1<f64>,
    z: &mut Array1<f64>,
    derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    mixture_link_state: Option<&MixtureLinkState>,
    sas_link_state: Option<&SasLinkState>,
) -> Result<(), EstimationError> {
    integrated_inverse_link_from_family(spec, mixture_link_state, sas_link_state).and_then(
        |resolved_link| {
            update_glmvectors_integrated_for_link(
                quadctx,
                y,
                eta,
                se,
                &resolved_link,
                priorweights,
                mu,
                weights,
                z,
                derivatives,
            )
        },
    )
}
/// Evaluates, for every entry of `eta`, the working-weight curvature terms and
/// the inverse-link derivatives of `likelihood`.
///
/// Returns `(c, d, dmu_deta, d2mu_deta2, d3mu_deta3)`, each of length
/// `eta.len()` and initialised to zero before the family-specific fill.
///
/// Behaviour per response family:
/// * `Gaussian`: `dmu_deta` is filled with `1.0`; everything else stays zero.
/// * `Poisson`, `Tweedie`, `NegativeBinomial`, `Gamma`: a `LogLinkRule` is
///   assembled and `write_log_link_eta_curvature` fills all five buffers.
/// * `Beta`: values are derived from the logit jet at `eta` clamped to
///   `[-ETA_CLAMP, ETA_CLAMP]`; `c` and `d` are forced to zero when the clamp
///   changed `eta` or the raw weight lies in `(0, MIN_WEIGHT]`.
/// * `Binomial`: `eta` is clamped per link, then the logit or standard jet is
///   fed to the Bernoulli geometry helper with `jet.mu` as its third argument
///   (no response vector is available in this function).
/// * `RoystonParmar`: rejected.
///
/// # Errors
/// Fails on an invalid Tweedie power or dispersion, an invalid
/// negative-binomial `theta`, an invalid beta `phi`, the `RoystonParmar`
/// family, or any error propagated from `write_log_link_eta_curvature` or
/// `standard_inverse_link_jet`.
///
/// # Panics
/// Panics if `priorweights` is shorter than `eta` on the Beta/Binomial paths,
/// or if one of the freshly allocated buffers is not contiguous.
pub(crate) fn computeworkingweight_derivatives_from_eta(
    likelihood: &GlmLikelihoodSpec,
    inverse_link: &InverseLink,
    eta: &Array1<f64>,
    priorweights: ArrayView1<f64>,
) -> Result<
    (
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
        Array1<f64>,
    ),
    EstimationError,
> {
    let n_obs = eta.len();
    let mut c = Array1::<f64>::zeros(n_obs);
    let mut d = Array1::<f64>::zeros(n_obs);
    let mut dmu_deta = Array1::<f64>::zeros(n_obs);
    let mut d2mu_deta2 = Array1::<f64>::zeros(n_obs);
    let mut d3mu_deta3 = Array1::<f64>::zeros(n_obs);

    // Log-link families only describe their rule here; the shared writer runs
    // once after the match. Families that fill the buffers themselves yield
    // `None`.
    let log_rule = match &likelihood.spec.response {
        ResponseFamily::Gaussian => {
            dmu_deta.fill(1.0);
            None
        }
        ResponseFamily::Poisson => Some(log_link_working_state::LogLinkRule {
            weight: log_link_working_state::WorkingWeight::PoissonIdentity,
            curvature: log_link_working_state::WorkingCurvature::Proportional {
                c_ratio: 1.0,
                d_ratio: 1.0,
            },
            floor_weight: true,
            zero_mu_jet_on_clamp: false,
        }),
        ResponseFamily::Tweedie { p: power } => {
            let p = *power;
            let phi = fixed_glm_dispersion(likelihood);
            if !is_valid_tweedie_power(p) {
                crate::bail_invalid_estim!(
                    "Tweedie variance power must be finite and strictly between 1 and 2; got {p}",
                    p = p
                );
            }
            if !(phi.is_finite() && phi > 0.0) {
                crate::bail_invalid_estim!(
                    "Tweedie dispersion phi must be finite and > 0; got {phi}",
                    phi = phi
                );
            }
            let exponent = 2.0 - p;
            Some(log_link_working_state::LogLinkRule {
                weight: log_link_working_state::WorkingWeight::TweediePower { p, phi },
                curvature: log_link_working_state::WorkingCurvature::Proportional {
                    c_ratio: exponent,
                    d_ratio: exponent * exponent,
                },
                floor_weight: true,
                zero_mu_jet_on_clamp: true,
            })
        }
        ResponseFamily::NegativeBinomial {
            theta: theta_ref, ..
        } => {
            let theta = *theta_ref;
            if !valid_negbin_theta(theta) {
                crate::bail_invalid_estim!(
                    "negative-binomial theta must be finite and > 0; got {theta}",
                    theta = theta
                );
            }
            Some(log_link_working_state::LogLinkRule {
                weight: log_link_working_state::WorkingWeight::NegativeBinomial { theta },
                curvature: log_link_working_state::WorkingCurvature::NegativeBinomial { theta },
                floor_weight: true,
                zero_mu_jet_on_clamp: false,
            })
        }
        ResponseFamily::Gamma => Some(log_link_working_state::LogLinkRule {
            weight: log_link_working_state::WorkingWeight::Constant { factor: 1.0 },
            curvature: log_link_working_state::WorkingCurvature::Proportional {
                c_ratio: 0.0,
                d_ratio: 0.0,
            },
            floor_weight: false,
            zero_mu_jet_on_clamp: false,
        }),
        ResponseFamily::Beta { phi: phi_ref } => {
            let phi = *phi_ref;
            if !valid_beta_phi(phi) {
                crate::bail_invalid_estim!("beta-regression phi must be finite and > 0; got {phi}");
            }
            let c_out = c.as_slice_mut().expect("c must be contiguous");
            let d_out = d.as_slice_mut().expect("d must be contiguous");
            let d1_out = dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_out = d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_out = d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            c_out
                .par_iter_mut()
                .zip(d_out.par_iter_mut())
                .zip(d1_out.par_iter_mut())
                .zip(d2_out.par_iter_mut())
                .zip(d3_out.par_iter_mut())
                .enumerate()
                .for_each(|(row, ((((c_i, d_i), d1_i), d2_i), d3_i))| {
                    let eta_raw = eta[row];
                    let eta_clamped = eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP);
                    let mean = safe_beta_mu(logit_inverse_link_jet5(eta_clamped).mu);
                    // mean * (1 - mean), floored; also stored as dmu/deta below.
                    let slope = (mean * (1.0 - mean)).max(BETA_MU_EPS);
                    let shape_a = (mean * phi).max(BETA_MU_EPS);
                    let shape_b = ((1.0 - mean) * phi).max(BETA_MU_EPS);
                    let trigamma_sum = trigamma(shape_a) + trigamma(shape_b);
                    let prior = priorweights[row].max(0.0);
                    let raw_weight = prior * slope * slope * phi * phi * trigamma_sum;
                    let floor_active = raw_weight > 0.0 && raw_weight <= MIN_WEIGHT;
                    // Curvature is zeroed when the weight floor is active or
                    // the clamp moved eta.
                    let (c_val, d_val) = if floor_active || eta_raw != eta_clamped {
                        (0.0, 0.0)
                    } else {
                        beta_logit_working_curvature_eta_derivatives(
                            prior,
                            phi,
                            mean,
                            slope,
                            shape_a,
                            shape_b,
                            trigamma_sum,
                        )
                    };
                    *c_i = c_val;
                    *d_i = d_val;
                    *d1_i = slope;
                    *d2_i = slope * (1.0 - 2.0 * mean);
                    *d3_i = slope * (1.0 - 6.0 * slope);
                });
            None
        }
        ResponseFamily::Binomial => {
            let link = inverse_link.link_function();
            let is_logit = matches!(link, LinkFunction::Logit);
            let c_out = c.as_slice_mut().expect("c must be contiguous");
            let d_out = d.as_slice_mut().expect("d must be contiguous");
            let d1_out = dmu_deta
                .as_slice_mut()
                .expect("dmu_deta must be contiguous");
            let d2_out = d2mu_deta2
                .as_slice_mut()
                .expect("d2mu_deta2 must be contiguous");
            let d3_out = d3mu_deta3
                .as_slice_mut()
                .expect("d3mu_deta3 must be contiguous");
            c_out
                .par_iter_mut()
                .zip(d_out.par_iter_mut())
                .zip(d1_out.par_iter_mut())
                .zip(d2_out.par_iter_mut())
                .zip(d3_out.par_iter_mut())
                .enumerate()
                .try_for_each(
                    |(row, ((((c_i, d_i), d1_i), d2_i), d3_i))| -> Result<(), EstimationError> {
                        let eta_raw = eta[row];
                        let eta_used = match link {
                            LinkFunction::Logit | LinkFunction::Log => {
                                eta_raw.clamp(-ETA_CLAMP, ETA_CLAMP)
                            }
                            LinkFunction::Probit
                            | LinkFunction::CLogLog
                            | LinkFunction::Sas
                            | LinkFunction::BetaLogistic => eta_raw.clamp(-30.0, 30.0),
                            LinkFunction::Identity => eta_raw,
                        };
                        let (c_val, d_val, d1, d2, d3) = if is_logit {
                            let jet = logit_inverse_link_jet5(eta_used);
                            let geom = bernoulli_logit_geometry_from_jet(
                                eta_raw,
                                eta_used,
                                jet.mu,
                                priorweights[row],
                                jet,
                                is_logit,
                            );
                            (geom.c, geom.d, jet.d1, jet.d2, jet.d3)
                        } else {
                            let jet = standard_inverse_link_jet(inverse_link, eta_used)?;
                            let geom = bernoulli_geometry_from_jet(
                                eta_raw,
                                eta_used,
                                jet.mu,
                                priorweights[row],
                                jet,
                            );
                            (geom.c, geom.d, jet.d1, jet.d2, jet.d3)
                        };
                        *c_i = c_val;
                        *d_i = d_val;
                        *d1_i = d1;
                        *d2_i = d2;
                        *d3_i = d3;
                        Ok(())
                    },
                )?;
            None
        }
        ResponseFamily::RoystonParmar => {
            crate::bail_invalid_estim!(
                "RoystonParmar is survival-specific and not a GLM IRLS family"
            );
        }
    };

    if let Some(rule) = log_rule {
        log_link_working_state::write_log_link_eta_curvature(
            &rule,
            inverse_link,
            eta,
            priorweights,
            WorkingDerivativeBuffersMut {
                c: &mut c,
                d: &mut d,
                dmu_deta: &mut dmu_deta,
                d2mu_deta2: &mut d2mu_deta2,
                d3mu_deta3: &mut d3mu_deta3,
            },
        )?;
    }

    Ok((c, d, dmu_deta, d2mu_deta2, d3mu_deta3))
}