use super::*;
pub trait WorkingModel {
fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError>;
fn update_with_curvature(
&mut self,
beta: &Coefficients,
_: HessianCurvatureKind,
) -> Result<WorkingState, EstimationError> {
self.update(beta)
}
fn update_candidate(
&mut self,
beta: &Coefficients,
curvature: HessianCurvatureKind,
) -> Result<WorkingState, EstimationError> {
self.update_with_curvature(beta, curvature)
}
fn screen_candidate(
&mut self,
beta: &Coefficients,
arr: &Array1<f64>,
_: &LinearPredictor,
curvature: HessianCurvatureKind,
) -> Result<CandidateEvaluation, EstimationError> {
assert!(arr.iter().all(|v| !v.is_nan()));
self.update_candidate(beta, curvature)
.map(CandidateEvaluation::Full)
}
fn supports_observed_information_curvature(&self) -> bool {
false
}
}
/// Scalar summary of a candidate, carried by `CandidateEvaluation::Screen`.
///
/// NOTE(review): presumably produced by a screening pass that avoids building
/// a full `WorkingState` — confirm against the `WorkingModel` implementors.
#[derive(Debug, Clone)]
pub struct CandidateScreen {
/// Objective value returned verbatim by
/// `CandidateEvaluation::penalized_objective` for screened candidates.
pub penalized_objective: f64,
/// Deviance of the candidate.
/// NOTE(review): presumably already folded into `penalized_objective` — confirm.
pub deviance: f64,
/// Penalty contribution of the candidate.
/// NOTE(review): presumably already folded into `penalized_objective` — confirm.
pub penalty_term: f64,
/// Finiteness flag returned verbatim by
/// `CandidateEvaluation::arithmetic_finite` for screened candidates.
pub arithmetic_finite: bool,
}
/// Result of evaluating a candidate: either a scalar screening summary or a
/// complete working state.
pub enum CandidateEvaluation {
/// Screening summary only; no `WorkingState` is available.
Screen(CandidateScreen),
/// Complete evaluation carrying the full `WorkingState`.
Full(WorkingState),
}
impl CandidateEvaluation {
#[inline]
pub(crate) fn penalized_objective(&self, firth_bias_reduction: bool) -> f64 {
match self {
Self::Screen(s) => s.penalized_objective,
Self::Full(state) => {
let mut value = state.deviance + state.penalty_term;
if firth_bias_reduction && let Some(j) = state.jeffreys_logdet() {
value -= 2.0 * j;
}
value
}
}
}
#[inline]
pub(crate) fn arithmetic_finite(&self) -> bool {
match self {
Self::Screen(s) => s.arithmetic_finite,
Self::Full(state) => state.gradient.iter().all(|g| g.is_finite()),
}
}
#[inline]
pub(crate) fn into_full(self) -> Option<WorkingState> {
match self {
Self::Full(state) => Some(state),
Self::Screen(_) => None,
}
}
}
/// Cache key identifying a PIRLS state by curvature kind, Firth status and the
/// exact bit patterns of the coefficients.
///
/// Floating-point values are stored as their `to_bits` representation, so key
/// equality is bitwise: `0.0` and `-0.0` are distinct, and two NaNs compare
/// equal only when their payloads match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct PirlsAcceptedStateCacheKey {
// Curvature kind the state was requested with or reports.
curvature: HessianCurvatureKind,
// Whether Firth bias reduction is requested/active for the state.
firth_active: bool,
// Bit patterns of the coefficient vector entries.
beta_bits: Vec<u64>,
// Bit patterns of the values returned by the arrow-Schur `snapshot_t`
// callback; `None` when `options.arrow_schur` is not configured.
arrow_latent_bits: Option<Vec<u64>>,
}
impl PirlsAcceptedStateCacheKey {
    /// Builds the key for a state that is about to be requested, taking the
    /// Firth flag from `options.firth_bias_reduction`.
    pub(crate) fn requested(
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
        options: &WorkingModelPirlsOptions,
    ) -> Self {
        let firth_requested = options.firth_bias_reduction;
        Self::new(beta, curvature, firth_requested, options)
    }

    /// Builds the key for an accepted state, taking the curvature kind and the
    /// Firth flag from what `state` itself reports.
    pub(crate) fn accepted(
        beta: &Coefficients,
        state: &WorkingState,
        options: &WorkingModelPirlsOptions,
    ) -> Self {
        let firth_active = matches!(state.firth, FirthDiagnostics::Active { .. });
        Self::new(beta, state.hessian_curvature, firth_active, options)
    }

    /// Assembles a key from its parts.
    ///
    /// When `options.arrow_schur` is configured, its `snapshot_t` callback is
    /// invoked and the bit patterns of the returned values become part of the
    /// key; otherwise that component is `None`.
    pub(crate) fn new(
        beta: &Coefficients,
        curvature: HessianCurvatureKind,
        firth_active: bool,
        options: &WorkingModelPirlsOptions,
    ) -> Self {
        // Snapshot first, then the coefficients, matching the original order.
        let arrow_latent_bits = options.arrow_schur.as_ref().map(|cfg| {
            let latent = cfg.snapshot_t.as_ref()();
            latent
                .iter()
                .map(|entry| entry.to_bits())
                .collect::<Vec<u64>>()
        });
        let beta_bits: Vec<u64> = beta
            .as_ref()
            .iter()
            .map(|entry| entry.to_bits())
            .collect();
        Self {
            curvature,
            firth_active,
            beta_bits,
            arrow_latent_bits,
        }
    }
}
/// Borrowed inputs for the integrated IRLS update path.
///
/// `GlmLikelihoodSpec::irls_update` forwards every field to
/// `update_glmvectors_integrated_by_family` when the response is binomial.
#[derive(Clone, Copy)]
pub(crate) struct IntegratedWorkingInput<'a> {
/// Quadrature context handed to the integrated update.
pub quadctx: &'a crate::quadrature::QuadratureContext,
/// NOTE(review): presumably per-observation standard errors of the linear
/// predictor — confirm against the integrated update routine.
pub se: ArrayView1<'a, f64>,
/// Optional mixture-link state forwarded to the integrated update.
pub mixture_link_state: Option<&'a MixtureLinkState>,
/// Optional SAS-link state forwarded to the integrated update.
pub sas_link_state: Option<&'a SasLinkState>,
}
/// Mutable borrows of the optional derivative output buffers that an IRLS
/// update may fill alongside `mu`, `weights` and `z`.
pub struct WorkingDerivativeBuffersMut<'a> {
// NOTE(review): exact meaning of `c` is not visible here; presumably a
// higher-order working-weight derivative term — confirm.
pub(crate) c: &'a mut Array1<f64>,
// NOTE(review): exact meaning of `d` is not visible here — confirm.
pub(crate) d: &'a mut Array1<f64>,
// Named for the first derivative of mu with respect to eta.
pub(crate) dmu_deta: &'a mut Array1<f64>,
// Named for the second derivative of mu with respect to eta.
pub(crate) d2mu_deta2: &'a mut Array1<f64>,
// Named for the third derivative of mu with respect to eta.
pub(crate) d3mu_deta3: &'a mut Array1<f64>,
}
/// Mutable slice views over the `mu`, `weights` and `z` working arrays, as
/// produced by `working_slices`.
pub(super) struct WorkingSlices<'a> {
/// Slice view of the `mu` array.
pub mu: &'a mut [f64],
/// Slice view of the `weights` array.
pub weights: &'a mut [f64],
/// Slice view of the `z` array.
pub z: &'a mut [f64],
}
/// Mutable slice views over the derivative buffers of a
/// `WorkingDerivativeBuffersMut`, as produced by `working_deriv_slices`.
pub(super) struct WorkingDerivSlices<'a> {
/// Slice view of the `c` buffer.
pub c: &'a mut [f64],
/// Slice view of the `d` buffer.
pub d: &'a mut [f64],
/// Slice view of the `dmu_deta` buffer.
pub dmu: &'a mut [f64],
/// Slice view of the `d2mu_deta2` buffer.
pub d2: &'a mut [f64],
/// Slice view of the `d3mu_deta3` buffer.
pub d3: &'a mut [f64],
}
/// Borrows the three working arrays as mutable slices.
///
/// # Panics
/// Panics when any of the arrays cannot be viewed as a contiguous slice; the
/// arrays are checked in the order `mu`, `weights`, `z`.
#[inline]
pub(super) fn working_slices<'a>(
    mu: &'a mut Array1<f64>,
    weights: &'a mut Array1<f64>,
    z: &'a mut Array1<f64>,
) -> WorkingSlices<'a> {
    let mu = mu.as_slice_mut().expect("mu must be contiguous");
    let weights = weights.as_slice_mut().expect("weights must be contiguous");
    let z = z.as_slice_mut().expect("z must be contiguous");
    WorkingSlices { mu, weights, z }
}
/// Borrows every derivative buffer in `derivs` as a mutable slice.
///
/// # Panics
/// Panics when any buffer cannot be viewed as a contiguous slice; buffers are
/// checked in the order `c`, `d`, `dmu_deta`, `d2mu_deta2`, `d3mu_deta3`.
#[inline]
pub(super) fn working_deriv_slices<'a>(
    derivs: &'a mut WorkingDerivativeBuffersMut<'_>,
) -> WorkingDerivSlices<'a> {
    let c = derivs.c.as_slice_mut().expect("c must be contiguous");
    let d = derivs.d.as_slice_mut().expect("d must be contiguous");
    let dmu = derivs
        .dmu_deta
        .as_slice_mut()
        .expect("dmu_deta must be contiguous");
    let d2 = derivs
        .d2mu_deta2
        .as_slice_mut()
        .expect("d2mu_deta2 must be contiguous");
    let d3 = derivs
        .d3mu_deta3
        .as_slice_mut()
        .expect("d3mu_deta3 must be contiguous");
    WorkingDerivSlices { c, d, dmu, d2, d3 }
}
/// Bundle of five scalar working quantities whose field names mirror the
/// IRLS working buffers (`mu`, `weights`, `z`, `c`, `d`).
///
/// NOTE(review): presumably the per-observation working state for a Bernoulli
/// response — confirm against the code that constructs it.
#[derive(Clone, Copy)]
pub(crate) struct WorkingBernoulliGeometry {
// Scalar counterpart of the `mu` buffer.
pub(crate) mu: f64,
// Scalar counterpart of the `weights` buffer.
pub(crate) weight: f64,
// Scalar counterpart of the `z` buffer.
pub(crate) z: f64,
// Scalar counterpart of the `c` derivative buffer.
pub(crate) c: f64,
// Scalar counterpart of the `d` derivative buffer.
pub(crate) d: f64,
}
/// Likelihood-specific pieces of an IRLS iteration: refreshing the working
/// vectors and computing the deviance.
pub(crate) trait WorkingLikelihood {
/// Refreshes the working vectors for the current linear predictor.
///
/// Takes the response `y`, the linear predictor `eta` and `priorweights` as
/// inputs, and writes into the caller-provided `mu`, `weights` and `z`
/// buffers. `integrated` optionally supplies inputs for an integrated update
/// path, and `derivatives` optionally supplies extra output buffers; how each
/// is used is up to the implementation.
///
/// # Errors
/// Returns an `EstimationError` when the implementation rejects the inputs
/// or the update fails.
fn irls_update(
&self,
y: ArrayView1<f64>,
eta: &Array1<f64>,
priorweights: ArrayView1<f64>,
mu: &mut Array1<f64>,
weights: &mut Array1<f64>,
z: &mut Array1<f64>,
integrated: Option<IntegratedWorkingInput<'_>>,
derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError>;
/// Computes a deviance value from the response `y`, the mean `mu` and
/// `priorweights`.
///
/// # Errors
/// Returns an `EstimationError` when the implementation rejects the inputs.
fn loglik_deviance(
&self,
y: ArrayView1<f64>,
mu: &Array1<f64>,
priorweights: ArrayView1<f64>,
) -> Result<f64, EstimationError>;
}
impl WorkingLikelihood for GlmLikelihoodSpec {
fn irls_update(
&self,
y: ArrayView1<f64>,
eta: &Array1<f64>,
priorweights: ArrayView1<f64>,
mu: &mut Array1<f64>,
weights: &mut Array1<f64>,
z: &mut Array1<f64>,
integrated: Option<IntegratedWorkingInput<'_>>,
derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
) -> Result<(), EstimationError> {
match (&self.spec.response, &self.spec.link, integrated.is_some()) {
(ResponseFamily::Binomial, _, true) => {
let integ = integrated.unwrap();
update_glmvectors_integrated_by_family(
integ.quadctx,
y,
eta,
integ.se,
&self.spec,
priorweights,
mu,
weights,
z,
derivatives,
integ.mixture_link_state,
integ.sas_link_state,
)?;
Ok(())
}
(ResponseFamily::Binomial, link, false) => {
if matches!(link, InverseLink::Mixture(_)) {
crate::bail_invalid_estim!(
"BinomialMixture IRLS update requires explicit mixture link state"
.to_string(),
);
}
update_glmvectors(
y,
eta,
&self.spec.link,
priorweights,
mu,
weights,
z,
derivatives,
)?;
Ok(())
}
(ResponseFamily::Gaussian, _, _) => {
update_glmvectors(
y,
eta,
&InverseLink::Standard(StandardLink::Identity),
priorweights,
mu,
weights,
z,
None,
)?;
if let Some(phi) = self.scale.fixed_phi() {
if !(phi.is_finite() && phi > 0.0) {
crate::bail_invalid_estim!(
"Gaussian fixed dispersion phi must be finite and positive (got {})",
phi
);
}
if phi != 1.0 {
let inv_phi = 1.0 / phi;
weights.mapv_inplace(|w| w * inv_phi);
}
}
Ok(())
}
(ResponseFamily::Poisson, _, _) => {
write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
Ok(())
}
(ResponseFamily::Tweedie { p }, _, _) => {
let p = *p;
write_tweedie_log_working_state(
y,
eta,
priorweights,
p,
fixed_glm_dispersion(self),
mu,
weights,
z,
derivatives,
)?;
Ok(())
}
(ResponseFamily::NegativeBinomial { theta, .. }, _, _) => {
let theta = *theta;
write_negative_binomial_log_working_state(
y,
eta,
priorweights,
theta,
mu,
weights,
z,
derivatives,
)?;
Ok(())
}
(ResponseFamily::Beta { phi }, _, _) => {
let phi = *phi;
write_beta_logit_working_state(
y,
eta,
priorweights,
phi,
mu,
weights,
z,
derivatives,
)?;
Ok(())
}
(ResponseFamily::Gamma, _, _) => {
write_gamma_log_working_state(
y,
eta,
priorweights,
self.gamma_shape().unwrap_or(1.0),
mu,
weights,
z,
derivatives,
);
Ok(())
}
(ResponseFamily::RoystonParmar, _, _) => Err(EstimationError::InvalidInput(
"RoystonParmar is survival-specific and not a GLM IRLS family".to_string(),
)),
}
}
fn loglik_deviance(
&self,
y: ArrayView1<f64>,
mu: &Array1<f64>,
priorweights: ArrayView1<f64>,
) -> Result<f64, EstimationError> {
if matches!(self.spec.response, ResponseFamily::Tweedie { .. }) {
validate_tweedie_responses(&y, &priorweights)?;
}
Ok(calculate_deviance(y, mu, self, priorweights))
}
}