Skip to main content

gam_solve/pirls/
working_model_trait.rs

1//! The `WorkingModel` / `WorkingLikelihood` trait surface plus the shared
2//! working-buffer machinery: candidate-screen results, the accepted-state cache
3//! key, and the contiguous mu/weights/z and Newton-derivative buffer slices that
4//! every per-family working-state writer routes through.
5
6use super::*;
7
8pub trait WorkingModel {
9    fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError>;
10
11    fn update_with_curvature(
12        &mut self,
13        beta: &Coefficients,
14        _: HessianCurvatureKind,
15    ) -> Result<WorkingState, EstimationError> {
16        self.update(beta)
17    }
18
19    fn update_candidate(
20        &mut self,
21        beta: &Coefficients,
22        curvature: HessianCurvatureKind,
23    ) -> Result<WorkingState, EstimationError> {
24        self.update_with_curvature(beta, curvature)
25    }
26
27    fn screen_candidate(
28        &mut self,
29        beta: &Coefficients,
30        arr: &Array1<f64>,
31        _: &LinearPredictor,
32        curvature: HessianCurvatureKind,
33    ) -> Result<CandidateEvaluation, EstimationError> {
34        assert!(arr.iter().all(|v| !v.is_nan()));
35        self.update_candidate(beta, curvature)
36            .map(CandidateEvaluation::Full)
37    }
38
39    fn supports_observed_information_curvature(&self) -> bool {
40        false
41    }
42}
43
/// Result of a cheap LM-candidate screen: penalized objective + arithmetic
/// finiteness, without the gradient/Hessian needed for an accepted step.
#[derive(Debug, Clone)]
pub struct CandidateScreen {
    /// Penalized objective at the candidate; returned unchanged by
    /// `CandidateEvaluation::penalized_objective` for screened candidates.
    pub penalized_objective: f64,
    /// Deviance at the candidate.
    // NOTE(review): the relationship between `penalized_objective`,
    // `deviance` and `penalty_term` is set by whoever builds the screen;
    // confirm against the producing code.
    pub deviance: f64,
    /// Penalty term at the candidate.
    pub penalty_term: f64,
    /// Finiteness flag recorded by the screen; returned unchanged by
    /// `CandidateEvaluation::arithmetic_finite` for screened candidates.
    pub arithmetic_finite: bool,
}
53
/// Outcome of `WorkingModel::screen_candidate`: either a cheap screen result
/// (LM loop must upgrade with `update_with_curvature` on acceptance) or the
/// full state when screening was not applicable.
pub enum CandidateEvaluation {
    /// Cheap screen only: objective and finiteness, no gradient/Hessian.
    Screen(CandidateScreen),
    /// Fully evaluated working state for the candidate.
    Full(WorkingState),
}
61
62impl CandidateEvaluation {
63    #[inline]
64    pub(crate) fn penalized_objective(&self, firth_bias_reduction: bool) -> f64 {
65        match self {
66            Self::Screen(s) => s.penalized_objective,
67            Self::Full(state) => {
68                let mut value = state.deviance + state.penalty_term;
69                if firth_bias_reduction && let Some(j) = state.jeffreys_logdet() {
70                    value -= 2.0 * j;
71                }
72                value
73            }
74        }
75    }
76
77    #[inline]
78    pub(crate) fn arithmetic_finite(&self) -> bool {
79        match self {
80            Self::Screen(s) => s.arithmetic_finite,
81            Self::Full(state) => state.gradient.iter().all(|g| g.is_finite()),
82        }
83    }
84
85    #[inline]
86    pub(crate) fn into_full(self) -> Option<WorkingState> {
87        match self {
88            Self::Full(state) => Some(state),
89            Self::Screen(_) => None,
90        }
91    }
92}
93
/// Equality key for the accepted-state cache.
///
/// Two keys are equal only when the curvature kind, the Firth flag, and the
/// exact bit patterns of the coefficients (and of the optional arrow latent
/// snapshot) all match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct PirlsAcceptedStateCacheKey {
    /// Curvature kind the state was requested or evaluated with.
    curvature: HessianCurvatureKind,
    /// Firth flag supplied to the constructor.
    firth_active: bool,
    /// `f64::to_bits` of every coefficient, so equality is bit-for-bit.
    beta_bits: Vec<u64>,
    /// `f64::to_bits` of the arrow snapshot values; `None` when the options
    /// carry no `arrow_schur` configuration.
    arrow_latent_bits: Option<Vec<u64>>,
}
101
102impl PirlsAcceptedStateCacheKey {
103    pub(crate) fn requested(
104        beta: &Coefficients,
105        curvature: HessianCurvatureKind,
106        options: &WorkingModelPirlsOptions,
107    ) -> Self {
108        Self::new(beta, curvature, options.firth_bias_reduction, options)
109    }
110
111    pub(crate) fn accepted(
112        beta: &Coefficients,
113        state: &WorkingState,
114        options: &WorkingModelPirlsOptions,
115    ) -> Self {
116        Self::new(
117            beta,
118            state.hessian_curvature,
119            matches!(state.firth, FirthDiagnostics::Active { .. }),
120            options,
121        )
122    }
123
124    pub(crate) fn new(
125        beta: &Coefficients,
126        curvature: HessianCurvatureKind,
127        firth_active: bool,
128        options: &WorkingModelPirlsOptions,
129    ) -> Self {
130        let arrow_latent_bits = options.arrow_schur.as_ref().map(|arrow_cfg| {
131            arrow_cfg.snapshot_t.as_ref()()
132                .iter()
133                .map(|value| value.to_bits())
134                .collect()
135        });
136        Self {
137            curvature,
138            firth_active,
139            beta_bits: beta.as_ref().iter().map(|value| value.to_bits()).collect(),
140            arrow_latent_bits,
141        }
142    }
143}
144
/// Uncertainty inputs for integrated (GHQ) IRLS updates.
///
/// `GlmLikelihoodSpec::irls_update` forwards every field to the integrated
/// update routine when the response family is Binomial.
#[derive(Clone, Copy)]
pub(crate) struct IntegratedWorkingInput<'a> {
    /// Quadrature context handed to the integrated update.
    pub quadctx: &'a crate::quadrature::QuadratureContext,
    /// Per-observation uncertainty vector.
    // NOTE(review): presumably standard errors of the linear predictor,
    // one per observation; confirm against the caller.
    pub se: ArrayView1<'a, f64>,
    /// Optional mixture-link state passed through to the integrated update.
    pub mixture_link_state: Option<&'a MixtureLinkState>,
    /// Optional SAS-link state passed through to the integrated update.
    pub sas_link_state: Option<&'a SasLinkState>,
}
153
/// Mutable borrows of the derivative/curvature output buffers that an IRLS
/// update may fill in addition to `mu`, `weights` and `z`.
///
/// `working_deriv_slices` unpacks these into contiguous slices.
// NOTE(review): the exact mathematical definition of `c` and `d` lives with
// the per-family writers; confirm there before relying on it.
pub struct WorkingDerivativeBuffersMut<'a> {
    /// Curvature buffer `c`.
    pub(crate) c: &'a mut Array1<f64>,
    /// Curvature buffer `d`.
    pub(crate) d: &'a mut Array1<f64>,
    /// First derivative of `mu` with respect to `eta`.
    pub(crate) dmu_deta: &'a mut Array1<f64>,
    /// Second derivative of `mu` with respect to `eta`.
    pub(crate) d2mu_deta2: &'a mut Array1<f64>,
    /// Third derivative of `mu` with respect to `eta`.
    pub(crate) d3mu_deta3: &'a mut Array1<f64>,
}
161
/// Contiguous mutable views of the three core working buffers (`mu`, `weights`,
/// `z`) shared by every PIRLS working-state writer.
///
/// Built by `working_slices`, which panics if any buffer is not contiguous.
pub(super) struct WorkingSlices<'a> {
    /// Slice view of the `mu` buffer.
    pub mu: &'a mut [f64],
    /// Slice view of the `weights` buffer.
    pub weights: &'a mut [f64],
    /// Slice view of the `z` buffer.
    pub z: &'a mut [f64],
}
169
/// Contiguous mutable views of the Newton derivative/curvature buffers
/// (`c`, `d`, `dmu/deta` jet) shared by the full-derivative PIRLS writers.
///
/// Built by `working_deriv_slices` from a `WorkingDerivativeBuffersMut`.
pub(super) struct WorkingDerivSlices<'a> {
    /// Slice view of the `c` buffer.
    pub c: &'a mut [f64],
    /// Slice view of the `d` buffer.
    pub d: &'a mut [f64],
    /// Slice view of the `dmu_deta` buffer.
    pub dmu: &'a mut [f64],
    /// Slice view of the `d2mu_deta2` buffer.
    pub d2: &'a mut [f64],
    /// Slice view of the `d3mu_deta3` buffer.
    pub d3: &'a mut [f64],
}
179
180/// Canonical "contiguous-or-panic" unpacking of the three core working buffers.
181///
182/// Single source of truth for the contiguity contract and panic messages that
183/// every working-state writer relies on; every writer routes through this.
184#[inline]
185pub(super) fn working_slices<'a>(
186    mu: &'a mut Array1<f64>,
187    weights: &'a mut Array1<f64>,
188    z: &'a mut Array1<f64>,
189) -> WorkingSlices<'a> {
190    WorkingSlices {
191        mu: mu.as_slice_mut().expect("mu must be contiguous"),
192        weights: weights.as_slice_mut().expect("weights must be contiguous"),
193        z: z.as_slice_mut().expect("z must be contiguous"),
194    }
195}
196
197/// Canonical "contiguous-or-panic" unpacking of the Newton derivative buffers.
198///
199/// Single source of truth for the contiguity contract and panic messages of the
200/// `c`/`d`/`dmu`/`d2`/`d3` curvature buffers; every full-derivative writer routes
201/// through this.
202#[inline]
203pub(super) fn working_deriv_slices<'a>(
204    derivs: &'a mut WorkingDerivativeBuffersMut<'_>,
205) -> WorkingDerivSlices<'a> {
206    WorkingDerivSlices {
207        c: derivs.c.as_slice_mut().expect("c must be contiguous"),
208        d: derivs.d.as_slice_mut().expect("d must be contiguous"),
209        dmu: derivs
210            .dmu_deta
211            .as_slice_mut()
212            .expect("dmu_deta must be contiguous"),
213        d2: derivs
214            .d2mu_deta2
215            .as_slice_mut()
216            .expect("d2mu_deta2 must be contiguous"),
217        d3: derivs
218            .d3mu_deta3
219            .as_slice_mut()
220            .expect("d3mu_deta3 must be contiguous"),
221    }
222}
223
/// Scalar bundle of working quantities (`mu`, `weight`, `z`, `c`, `d`).
// NOTE(review): the name and fields suggest a single-observation counterpart
// of the vector working buffers for the Bernoulli family; confirm against the
// code that produces it.
#[derive(Clone, Copy)]
pub(crate) struct WorkingBernoulliGeometry {
    /// Scalar `mu` value.
    pub(crate) mu: f64,
    /// Scalar working weight.
    pub(crate) weight: f64,
    /// Scalar `z` value.
    pub(crate) z: f64,
    /// Scalar `c` value.
    pub(crate) c: f64,
    /// Scalar `d` value.
    pub(crate) d: f64,
}
232
/// Shared likelihood interface used by PIRLS working updates.
///
/// This keeps the update/deviance math in one place so engine-level likelihoods
/// and higher-level wrappers (custom family, GAMLSS warm starts) can share a
/// consistent implementation.
pub(crate) trait WorkingLikelihood {
    /// Performs one IRLS working-state update.
    ///
    /// Inputs are the response `y`, the linear predictor `eta` and the prior
    /// weights; `mu`, `weights` and `z` are output buffers. `integrated`
    /// optionally supplies the inputs for an integrated (GHQ) update, and
    /// `derivatives` optionally supplies extra derivative buffers to fill.
    ///
    /// # Errors
    /// Implementation-defined; failures are reported as `EstimationError`.
    fn irls_update(
        &self,
        y: ArrayView1<f64>,
        eta: &Array1<f64>,
        priorweights: ArrayView1<f64>,
        mu: &mut Array1<f64>,
        weights: &mut Array1<f64>,
        z: &mut Array1<f64>,
        integrated: Option<IntegratedWorkingInput<'_>>,
        derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    ) -> Result<(), EstimationError>;

    /// Computes the log-likelihood deviance of `y` against `mu` under the
    /// given prior weights.
    ///
    /// # Errors
    /// Implementation-defined; failures are reported as `EstimationError`.
    fn loglik_deviance(
        &self,
        y: ArrayView1<f64>,
        mu: &Array1<f64>,
        priorweights: ArrayView1<f64>,
    ) -> Result<f64, EstimationError>;
}
258
259impl WorkingLikelihood for GlmLikelihoodSpec {
260    fn irls_update(
261        &self,
262        y: ArrayView1<f64>,
263        eta: &Array1<f64>,
264        priorweights: ArrayView1<f64>,
265        mu: &mut Array1<f64>,
266        weights: &mut Array1<f64>,
267        z: &mut Array1<f64>,
268        integrated: Option<IntegratedWorkingInput<'_>>,
269        derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
270    ) -> Result<(), EstimationError> {
271        match (&self.spec.response, &self.spec.link, integrated.is_some()) {
272            (ResponseFamily::Binomial, _, true) => {
273                let integ = integrated.unwrap();
274                update_glmvectors_integrated_by_family(
275                    integ.quadctx,
276                    y,
277                    eta,
278                    integ.se,
279                    &self.spec,
280                    priorweights,
281                    mu,
282                    weights,
283                    z,
284                    derivatives,
285                    integ.mixture_link_state,
286                    integ.sas_link_state,
287                )?;
288                Ok(())
289            }
290            (ResponseFamily::Binomial, link, false) => {
291                if matches!(link, InverseLink::Mixture(_)) {
292                    crate::bail_invalid_estim!(
293                        "BinomialMixture IRLS update requires explicit mixture link state"
294                            .to_string(),
295                    );
296                }
297                update_glmvectors(
298                    y,
299                    eta,
300                    &self.spec.link,
301                    priorweights,
302                    mu,
303                    weights,
304                    z,
305                    derivatives,
306                )?;
307                Ok(())
308            }
309            (ResponseFamily::Gaussian, _, _) => {
310                update_glmvectors(
311                    y,
312                    eta,
313                    &InverseLink::Standard(StandardLink::Identity),
314                    priorweights,
315                    mu,
316                    weights,
317                    z,
318                    None,
319                )?;
320                // For Gaussian identity, the canonical IRLS working weight is
321                //     w_i = prior_i * (dmu/deta)^2 / Var(Y_i | mu_i) = prior_i / phi.
322                // When the scale metadata explicitly fixes phi (rather than
323                // profiling sigma out), the working weights must include 1/phi
324                // so that PIRLS minimises the scaled deviance / scaled negative
325                // log-likelihood that the calibrator and downstream variance
326                // calculations expect. `ProfiledGaussian` returns `None` here,
327                // preserving the historical "weights == prior" behaviour for
328                // the default profiled case.
329                if let Some(phi) = self.scale.fixed_phi() {
330                    if !(phi.is_finite() && phi > 0.0) {
331                        crate::bail_invalid_estim!(
332                            "Gaussian fixed dispersion phi must be finite and positive (got {})",
333                            phi
334                        );
335                    }
336                    if phi != 1.0 {
337                        let inv_phi = 1.0 / phi;
338                        weights.mapv_inplace(|w| w * inv_phi);
339                    }
340                }
341                Ok(())
342            }
343            (ResponseFamily::Poisson, _, _) => {
344                write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
345                Ok(())
346            }
347            (ResponseFamily::Tweedie { p }, _, _) => {
348                let p = *p;
349                write_tweedie_log_working_state(
350                    y,
351                    eta,
352                    priorweights,
353                    p,
354                    fixed_glm_dispersion(self),
355                    mu,
356                    weights,
357                    z,
358                    derivatives,
359                )?;
360                Ok(())
361            }
362            (ResponseFamily::NegativeBinomial { theta, .. }, _, _) => {
363                let theta = *theta;
364                write_negative_binomial_log_working_state(
365                    y,
366                    eta,
367                    priorweights,
368                    theta,
369                    mu,
370                    weights,
371                    z,
372                    derivatives,
373                )?;
374                Ok(())
375            }
376            (ResponseFamily::Beta { phi }, _, _) => {
377                let phi = *phi;
378                write_beta_logit_working_state(
379                    y,
380                    eta,
381                    priorweights,
382                    phi,
383                    mu,
384                    weights,
385                    z,
386                    derivatives,
387                )?;
388                Ok(())
389            }
390            (ResponseFamily::Gamma, _, _) => {
391                write_gamma_log_working_state(
392                    y,
393                    eta,
394                    priorweights,
395                    self.gamma_shape().unwrap_or(1.0),
396                    mu,
397                    weights,
398                    z,
399                    derivatives,
400                );
401                Ok(())
402            }
403            (ResponseFamily::RoystonParmar, _, _) => Err(EstimationError::InvalidInput(
404                "RoystonParmar is survival-specific and not a GLM IRLS family".to_string(),
405            )),
406        }
407    }
408
409    fn loglik_deviance(
410        &self,
411        y: ArrayView1<f64>,
412        mu: &Array1<f64>,
413        priorweights: ArrayView1<f64>,
414    ) -> Result<f64, EstimationError> {
415        if matches!(self.spec.response, ResponseFamily::Tweedie { .. }) {
416            validate_tweedie_responses(&y, &priorweights)?;
417        }
418        Ok(calculate_deviance(y, mu, self, priorweights))
419    }
420}