1use super::*;
7
8pub trait WorkingModel {
9 fn update(&mut self, beta: &Coefficients) -> Result<WorkingState, EstimationError>;
10
11 fn update_with_curvature(
12 &mut self,
13 beta: &Coefficients,
14 _: HessianCurvatureKind,
15 ) -> Result<WorkingState, EstimationError> {
16 self.update(beta)
17 }
18
19 fn update_candidate(
20 &mut self,
21 beta: &Coefficients,
22 curvature: HessianCurvatureKind,
23 ) -> Result<WorkingState, EstimationError> {
24 self.update_with_curvature(beta, curvature)
25 }
26
27 fn screen_candidate(
28 &mut self,
29 beta: &Coefficients,
30 arr: &Array1<f64>,
31 _: &LinearPredictor,
32 curvature: HessianCurvatureKind,
33 ) -> Result<CandidateEvaluation, EstimationError> {
34 assert!(arr.iter().all(|v| !v.is_nan()));
35 self.update_candidate(beta, curvature)
36 .map(CandidateEvaluation::Full)
37 }
38
39 fn supports_observed_information_curvature(&self) -> bool {
40 false
41 }
42}
43
/// Summary produced when a candidate is screened without building a full
/// [`WorkingState`].
///
/// Wrapped by [`CandidateEvaluation::Screen`].
#[derive(Debug, Clone)]
pub struct CandidateScreen {
    /// Objective value reported for the candidate. Returned as-is by
    /// `CandidateEvaluation::penalized_objective`.
    pub penalized_objective: f64,
    /// Deviance at the candidate.
    pub deviance: f64,
    /// Penalty contribution at the candidate.
    pub penalty_term: f64,
    /// Whether the screening arithmetic stayed finite. Returned as-is by
    /// `CandidateEvaluation::arithmetic_finite`.
    pub arithmetic_finite: bool,
}
53
/// Outcome of evaluating a candidate coefficient vector.
///
/// Either a lightweight screening summary or a fully computed working state.
pub enum CandidateEvaluation {
    /// Screening summary only. No [`WorkingState`] is available, so `into_full`
    /// returns `None`.
    Screen(CandidateScreen),
    /// Full working state evaluated at the candidate.
    Full(WorkingState),
}
61
62impl CandidateEvaluation {
63 #[inline]
64 pub(crate) fn penalized_objective(&self, firth_bias_reduction: bool) -> f64 {
65 match self {
66 Self::Screen(s) => s.penalized_objective,
67 Self::Full(state) => {
68 let mut value = state.deviance + state.penalty_term;
69 if firth_bias_reduction && let Some(j) = state.jeffreys_logdet() {
70 value -= 2.0 * j;
71 }
72 value
73 }
74 }
75 }
76
77 #[inline]
78 pub(crate) fn arithmetic_finite(&self) -> bool {
79 match self {
80 Self::Screen(s) => s.arithmetic_finite,
81 Self::Full(state) => state.gradient.iter().all(|g| g.is_finite()),
82 }
83 }
84
85 #[inline]
86 pub(crate) fn into_full(self) -> Option<WorkingState> {
87 match self {
88 Self::Full(state) => Some(state),
89 Self::Screen(_) => None,
90 }
91 }
92}
93
/// Exact-match cache key for a PIRLS working state.
///
/// Floating-point inputs are stored as `f64::to_bits` patterns, so equality is
/// bitwise. For example, `0.0` and `-0.0` produce different keys.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct PirlsAcceptedStateCacheKey {
    /// Hessian curvature kind requested for, or used by, the state.
    curvature: HessianCurvatureKind,
    /// Whether Firth bias reduction was requested (`requested`) or active in the
    /// state (`accepted`).
    firth_active: bool,
    /// Bit patterns of the coefficient vector, in order.
    beta_bits: Vec<u64>,
    /// Bit patterns of the arrow-Schur latent snapshot. `None` when the options
    /// carry no arrow-Schur configuration.
    arrow_latent_bits: Option<Vec<u64>>,
}
101
102impl PirlsAcceptedStateCacheKey {
103 pub(crate) fn requested(
104 beta: &Coefficients,
105 curvature: HessianCurvatureKind,
106 options: &WorkingModelPirlsOptions,
107 ) -> Self {
108 Self::new(beta, curvature, options.firth_bias_reduction, options)
109 }
110
111 pub(crate) fn accepted(
112 beta: &Coefficients,
113 state: &WorkingState,
114 options: &WorkingModelPirlsOptions,
115 ) -> Self {
116 Self::new(
117 beta,
118 state.hessian_curvature,
119 matches!(state.firth, FirthDiagnostics::Active { .. }),
120 options,
121 )
122 }
123
124 pub(crate) fn new(
125 beta: &Coefficients,
126 curvature: HessianCurvatureKind,
127 firth_active: bool,
128 options: &WorkingModelPirlsOptions,
129 ) -> Self {
130 let arrow_latent_bits = options.arrow_schur.as_ref().map(|arrow_cfg| {
131 arrow_cfg.snapshot_t.as_ref()()
132 .iter()
133 .map(|value| value.to_bits())
134 .collect()
135 });
136 Self {
137 curvature,
138 firth_active,
139 beta_bits: beta.as_ref().iter().map(|value| value.to_bits()).collect(),
140 arrow_latent_bits,
141 }
142 }
143}
144
/// Extra inputs for the integrated (quadrature-based) IRLS update.
///
/// Used by the binomial branch of `GlmLikelihoodSpec::irls_update` and forwarded to
/// `update_glmvectors_integrated_by_family`. All fields are borrowed, so the struct
/// is `Copy`.
#[derive(Clone, Copy)]
pub(crate) struct IntegratedWorkingInput<'a> {
    /// Quadrature context used for the integration.
    pub quadctx: &'a crate::quadrature::QuadratureContext,
    /// Per-observation spread input. Presumably the standard error of the linear
    /// predictor; confirm against `update_glmvectors_integrated_by_family`.
    pub se: ArrayView1<'a, f64>,
    /// State for a mixture inverse link, if one is in use.
    pub mixture_link_state: Option<&'a MixtureLinkState>,
    /// State for a SAS inverse link, if one is in use.
    pub sas_link_state: Option<&'a SasLinkState>,
}
153
/// Caller-owned output buffers for per-observation derivative quantities.
///
/// Filled by the IRLS update when passed as `Some`. They can be viewed as
/// contiguous slices via `working_deriv_slices`.
pub struct WorkingDerivativeBuffersMut<'a> {
    /// `c` term. Its exact meaning is defined by the family-specific writers;
    /// NOTE(review): confirm.
    pub(crate) c: &'a mut Array1<f64>,
    /// `d` term. Its exact meaning is defined by the family-specific writers;
    /// NOTE(review): confirm.
    pub(crate) d: &'a mut Array1<f64>,
    /// First derivative of the mean with respect to the linear predictor.
    pub(crate) dmu_deta: &'a mut Array1<f64>,
    /// Second derivative of the mean with respect to the linear predictor.
    pub(crate) d2mu_deta2: &'a mut Array1<f64>,
    /// Third derivative of the mean with respect to the linear predictor.
    pub(crate) d3mu_deta3: &'a mut Array1<f64>,
}
161
/// Contiguous mutable slice views of the IRLS working vectors.
///
/// Produced by `working_slices`.
pub(super) struct WorkingSlices<'a> {
    /// Fitted means.
    pub mu: &'a mut [f64],
    /// IRLS working weights.
    pub weights: &'a mut [f64],
    /// Working response.
    pub z: &'a mut [f64],
}
169
/// Contiguous mutable slice views of a `WorkingDerivativeBuffersMut`.
///
/// Produced by `working_deriv_slices`.
pub(super) struct WorkingDerivSlices<'a> {
    /// View of the `c` buffer.
    pub c: &'a mut [f64],
    /// View of the `d` buffer.
    pub d: &'a mut [f64],
    /// View of the `dmu_deta` buffer.
    pub dmu: &'a mut [f64],
    /// View of the `d2mu_deta2` buffer.
    pub d2: &'a mut [f64],
    /// View of the `d3mu_deta3` buffer.
    pub d3: &'a mut [f64],
}
179
180#[inline]
185pub(super) fn working_slices<'a>(
186 mu: &'a mut Array1<f64>,
187 weights: &'a mut Array1<f64>,
188 z: &'a mut Array1<f64>,
189) -> WorkingSlices<'a> {
190 WorkingSlices {
191 mu: mu.as_slice_mut().expect("mu must be contiguous"),
192 weights: weights.as_slice_mut().expect("weights must be contiguous"),
193 z: z.as_slice_mut().expect("z must be contiguous"),
194 }
195}
196
197#[inline]
203pub(super) fn working_deriv_slices<'a>(
204 derivs: &'a mut WorkingDerivativeBuffersMut<'_>,
205) -> WorkingDerivSlices<'a> {
206 WorkingDerivSlices {
207 c: derivs.c.as_slice_mut().expect("c must be contiguous"),
208 d: derivs.d.as_slice_mut().expect("d must be contiguous"),
209 dmu: derivs
210 .dmu_deta
211 .as_slice_mut()
212 .expect("dmu_deta must be contiguous"),
213 d2: derivs
214 .d2mu_deta2
215 .as_slice_mut()
216 .expect("d2mu_deta2 must be contiguous"),
217 d3: derivs
218 .d3mu_deta3
219 .as_slice_mut()
220 .expect("d3mu_deta3 must be contiguous"),
221 }
222}
223
/// Per-observation working quantities bundled by value.
///
/// Presumably these describe a single Bernoulli/binomial observation, going by the
/// name. NOTE(review): confirm against its constructors elsewhere in the module.
#[derive(Clone, Copy)]
pub(crate) struct WorkingBernoulliGeometry {
    /// Fitted mean.
    pub(crate) mu: f64,
    /// IRLS working weight.
    pub(crate) weight: f64,
    /// Working response.
    pub(crate) z: f64,
    /// `c` term. This appears to mirror `WorkingDerivativeBuffersMut::c`;
    /// NOTE(review): confirm.
    pub(crate) c: f64,
    /// `d` term. This appears to mirror `WorkingDerivativeBuffersMut::d`;
    /// NOTE(review): confirm.
    pub(crate) d: f64,
}
232
/// Likelihood operations needed by the IRLS loop.
pub(crate) trait WorkingLikelihood {
    /// Recomputes the working vectors `mu`, `weights`, and `z` from the linear
    /// predictor `eta`.
    ///
    /// - `integrated`: optional quadrature inputs for an integrated update.
    /// - `derivatives`: optional output buffers for derivative quantities.
    ///
    /// # Errors
    /// Returns an [`EstimationError`] when the update cannot be performed.
    fn irls_update(
        &self,
        y: ArrayView1<f64>,
        eta: &Array1<f64>,
        priorweights: ArrayView1<f64>,
        mu: &mut Array1<f64>,
        weights: &mut Array1<f64>,
        z: &mut Array1<f64>,
        integrated: Option<IntegratedWorkingInput<'_>>,
        derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
    ) -> Result<(), EstimationError>;

    /// Computes the deviance of `y` given fitted means `mu` and prior weights.
    ///
    /// # Errors
    /// Returns an [`EstimationError`] when the inputs are invalid for the family.
    fn loglik_deviance(
        &self,
        y: ArrayView1<f64>,
        mu: &Array1<f64>,
        priorweights: ArrayView1<f64>,
    ) -> Result<f64, EstimationError>;
}
258
259impl WorkingLikelihood for GlmLikelihoodSpec {
260 fn irls_update(
261 &self,
262 y: ArrayView1<f64>,
263 eta: &Array1<f64>,
264 priorweights: ArrayView1<f64>,
265 mu: &mut Array1<f64>,
266 weights: &mut Array1<f64>,
267 z: &mut Array1<f64>,
268 integrated: Option<IntegratedWorkingInput<'_>>,
269 derivatives: Option<WorkingDerivativeBuffersMut<'_>>,
270 ) -> Result<(), EstimationError> {
271 match (&self.spec.response, &self.spec.link, integrated.is_some()) {
272 (ResponseFamily::Binomial, _, true) => {
273 let integ = integrated.unwrap();
274 update_glmvectors_integrated_by_family(
275 integ.quadctx,
276 y,
277 eta,
278 integ.se,
279 &self.spec,
280 priorweights,
281 mu,
282 weights,
283 z,
284 derivatives,
285 integ.mixture_link_state,
286 integ.sas_link_state,
287 )?;
288 Ok(())
289 }
290 (ResponseFamily::Binomial, link, false) => {
291 if matches!(link, InverseLink::Mixture(_)) {
292 crate::bail_invalid_estim!(
293 "BinomialMixture IRLS update requires explicit mixture link state"
294 .to_string(),
295 );
296 }
297 update_glmvectors(
298 y,
299 eta,
300 &self.spec.link,
301 priorweights,
302 mu,
303 weights,
304 z,
305 derivatives,
306 )?;
307 Ok(())
308 }
309 (ResponseFamily::Gaussian, _, _) => {
310 update_glmvectors(
311 y,
312 eta,
313 &InverseLink::Standard(StandardLink::Identity),
314 priorweights,
315 mu,
316 weights,
317 z,
318 None,
319 )?;
320 if let Some(phi) = self.scale.fixed_phi() {
330 if !(phi.is_finite() && phi > 0.0) {
331 crate::bail_invalid_estim!(
332 "Gaussian fixed dispersion phi must be finite and positive (got {})",
333 phi
334 );
335 }
336 if phi != 1.0 {
337 let inv_phi = 1.0 / phi;
338 weights.mapv_inplace(|w| w * inv_phi);
339 }
340 }
341 Ok(())
342 }
343 (ResponseFamily::Poisson, _, _) => {
344 write_poisson_log_working_state(y, eta, priorweights, mu, weights, z, derivatives);
345 Ok(())
346 }
347 (ResponseFamily::Tweedie { p }, _, _) => {
348 let p = *p;
349 write_tweedie_log_working_state(
350 y,
351 eta,
352 priorweights,
353 p,
354 fixed_glm_dispersion(self),
355 mu,
356 weights,
357 z,
358 derivatives,
359 )?;
360 Ok(())
361 }
362 (ResponseFamily::NegativeBinomial { theta, .. }, _, _) => {
363 let theta = *theta;
364 write_negative_binomial_log_working_state(
365 y,
366 eta,
367 priorweights,
368 theta,
369 mu,
370 weights,
371 z,
372 derivatives,
373 )?;
374 Ok(())
375 }
376 (ResponseFamily::Beta { phi }, _, _) => {
377 let phi = *phi;
378 write_beta_logit_working_state(
379 y,
380 eta,
381 priorweights,
382 phi,
383 mu,
384 weights,
385 z,
386 derivatives,
387 )?;
388 Ok(())
389 }
390 (ResponseFamily::Gamma, _, _) => {
391 write_gamma_log_working_state(
392 y,
393 eta,
394 priorweights,
395 self.gamma_shape().unwrap_or(1.0),
396 mu,
397 weights,
398 z,
399 derivatives,
400 );
401 Ok(())
402 }
403 (ResponseFamily::RoystonParmar, _, _) => Err(EstimationError::InvalidInput(
404 "RoystonParmar is survival-specific and not a GLM IRLS family".to_string(),
405 )),
406 }
407 }
408
409 fn loglik_deviance(
410 &self,
411 y: ArrayView1<f64>,
412 mu: &Array1<f64>,
413 priorweights: ArrayView1<f64>,
414 ) -> Result<f64, EstimationError> {
415 if matches!(self.spec.response, ResponseFamily::Tweedie { .. }) {
416 validate_tweedie_responses(&y, &priorweights)?;
417 }
418 Ok(calculate_deviance(y, mu, self, priorweights))
419 }
420}