1use crate::survival::lognormal_kernel::latent_cloglog_inverse_link_jet;
2use crate::inference::generative::NoiseModel;
3use gam_solve::mixture_link::{
4 InverseLinkJet, inverse_link_jet_for_family_public, mixture_inverse_link_jet,
5};
6use crate::model_types::{EstimationError, FittedLinkState, UnifiedFitResult};
7use crate::quadrature::{
8 IntegratedMomentsJet, QuadratureContext, cloglog_posterior_meanvariance,
9 integrated_family_moments_jet, integrated_inverse_link_jetwith_state,
10 integrated_inverse_link_mean_and_derivative, logit_posterior_meanvariance,
11 normal_expectation_1d_adaptive, normal_expectation_1d_adaptive_pair,
12 probit_posterior_meanvariance, survival_posterior_mean, survival_posterior_meanvariance,
13};
14use gam_problem::{
15 InverseLink, LikelihoodScaleMetadata, LikelihoodSpec, LinkFunction, ResponseFamily,
16 StandardLink,
17};
18use ndarray::{Array1, ArrayView1};
19
/// Lower bound applied to Bernoulli-style variances `mean * (1 - mean)` so the
/// reported variance is never exactly zero.
const PROB_VARIANCE_FLOOR: f64 = 1.0e-12;
26
/// Family-specific behaviour used at prediction time.
///
/// Covers inverse links and their derivatives, moments of the mean under Gaussian
/// uncertainty in the linear predictor, and noise models for simulation.
///
/// The `Send + Sync` supertraits allow a strategy to be shared across threads.
pub trait FamilyStrategy: std::fmt::Debug + Send + Sync {
    /// Short static identifier of the resolved family.
    fn name(&self) -> &'static str;

    /// The full likelihood specification (response family plus inverse link).
    fn family(&self) -> LikelihoodSpec;

    /// The link function of the resolved inverse link.
    fn link_function(&self) -> LinkFunction;

    /// Mean `mu = g^{-1}(eta)` at a single linear-predictor value.
    ///
    /// # Errors
    /// Propagates any failure of the underlying inverse-link evaluation.
    fn inverse_link(&self, eta: f64) -> Result<f64, EstimationError>;

    /// Element-wise [`FamilyStrategy::inverse_link`] over `eta`.
    fn inverse_link_array(&self, eta: ArrayView1<'_, f64>) -> Result<Array1<f64>, EstimationError>;

    /// Inverse-link value `mu` together with its derivatives with respect to `eta`
    /// (`d1` is `dmu/deta`; see [`InverseLinkJet`]).
    fn inverse_link_jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError>;

    /// Expected mean when the linear predictor is uncertain, i.e. `E[g^{-1}(Z)]` with
    /// `Z ~ N(eta, se_eta^2)`.
    ///
    /// `quadctx` supplies the quadrature rules used for non-closed-form families.
    fn posterior_mean(
        &self,
        quadctx: &QuadratureContext,
        eta: f64,
        se_eta: f64,
    ) -> Result<f64, EstimationError>;

    /// `(E[g^{-1}(Z)], Var[g^{-1}(Z)])` for `Z ~ N(eta, se_eta^2)`.
    ///
    /// This is the spread of the mean induced by predictor uncertainty, not the
    /// observation noise.
    fn posterior_meanvariance(
        &self,
        quadctx: &QuadratureContext,
        eta: f64,
        se_eta: f64,
    ) -> Result<(f64, f64), EstimationError>;

    /// Noise model for simulating responses around `mean`.
    ///
    /// `gaussian_scale` is forwarded to the noise model. NOTE(review): presumably this
    /// is the Gaussian residual scale; confirm. `ResolvedFamilyStrategy` only uses
    /// `mean.len()`, not the mean values themselves.
    fn simulate_noise(
        &self,
        mean: &Array1<f64>,
        gaussian_scale: Option<f64>,
    ) -> Result<NoiseModel, EstimationError>;

    /// Integrated moments of the mean over `Z ~ N(eta, se_eta^2)`: mean, variance,
    /// derivative terms `d1..d3` and `mode` (see [`IntegratedMomentsJet`]).
    fn integrated_moments(
        &self,
        quadctx: &QuadratureContext,
        eta: f64,
        se_eta: f64,
    ) -> Result<IntegratedMomentsJet, EstimationError>;
}
69
70#[derive(Clone, Debug)]
79pub struct ResolvedFamilyStrategy {
80 spec: LikelihoodSpec,
81}
82
83fn spec_from_family(family: LikelihoodSpec, inverse_link: Option<&InverseLink>) -> LikelihoodSpec {
87 if let Some(link) = inverse_link {
88 return LikelihoodSpec {
89 response: family.response,
90 link: link.clone(),
91 };
92 }
93 family
94}
95
96#[inline]
101pub fn strategy_for_family(
102 family: LikelihoodSpec,
103 inverse_link: Option<&InverseLink>,
104) -> ResolvedFamilyStrategy {
105 ResolvedFamilyStrategy {
106 spec: spec_from_family(family, inverse_link),
107 }
108}
109
110#[inline]
115pub fn strategy_for_spec(spec: &LikelihoodSpec) -> ResolvedFamilyStrategy {
116 ResolvedFamilyStrategy { spec: spec.clone() }
117}
118
119pub fn strategy_from_fit(
125 family: &LikelihoodSpec,
126 fit: &UnifiedFitResult,
127) -> Result<ResolvedFamilyStrategy, EstimationError> {
128 let inverse_link = match fit.fitted_link_state(family)? {
129 FittedLinkState::Standard(Some(link)) => Some(InverseLink::Standard(link)),
130 FittedLinkState::Standard(None) => None,
131 FittedLinkState::LatentCLogLog { state } => Some(InverseLink::LatentCLogLog(state)),
132 FittedLinkState::Sas { state, .. } => Some(InverseLink::Sas(state)),
133 FittedLinkState::BetaLogistic { state, .. } => Some(InverseLink::BetaLogistic(state)),
134 FittedLinkState::Mixture { state, .. } => Some(InverseLink::Mixture(state)),
135 };
136 let spec = if let Some(link) = inverse_link {
137 LikelihoodSpec::new(family.response.clone(), link)
138 } else {
139 family.clone()
140 };
141 Ok(strategy_for_spec(&spec))
142}
143
144impl ResolvedFamilyStrategy {
145 #[inline]
146 fn mixture_state(&self) -> Option<&gam_problem::MixtureLinkState> {
147 self.spec.link.mixture_state()
148 }
149
150 #[inline]
151 fn sas_state(&self) -> Option<&gam_problem::SasLinkState> {
152 self.spec.link.sas_state()
153 }
154
155 #[inline]
156 fn latent_cloglog_state(&self) -> Option<&gam_problem::LatentCLogLogState> {
157 self.spec.link.latent_cloglog_state()
158 }
159
160 #[inline]
161 fn require_latent_cloglog_state(
162 &self,
163 ) -> Result<&gam_problem::LatentCLogLogState, EstimationError> {
164 self.latent_cloglog_state()
165 .ok_or_else(|| missing_state(&self.spec, "latent cloglog"))
166 }
167
168 #[inline]
169 fn require_sas_state(&self) -> Result<&gam_problem::SasLinkState, EstimationError> {
170 self.sas_state()
171 .ok_or_else(|| missing_state(&self.spec, "SAS link"))
172 }
173
174 #[inline]
175 fn require_mixture_state(&self) -> Result<&gam_problem::MixtureLinkState, EstimationError> {
176 self.mixture_state()
177 .ok_or_else(|| missing_state(&self.spec, "mixture link"))
178 }
179}
180
181#[cold]
182fn missing_state(spec: &LikelihoodSpec, what: &str) -> EstimationError {
183 EstimationError::InvalidInput(format!(
184 "{} requires fitted {} state",
185 spec.pretty_name(),
186 what
187 ))
188}
189
190#[inline]
195fn posterior_mv_from_prob_kernel<F>(
196 quadctx: &QuadratureContext,
197 eta: f64,
198 se_eta: f64,
199 prob: F,
200) -> (f64, f64)
201where
202 F: Fn(f64) -> f64,
203{
204 let (m1, m2) = normal_expectation_1d_adaptive_pair(quadctx, eta, se_eta, |x| {
205 let p = prob(x);
206 (p, p * p)
207 });
208 (m1, (m2 - m1 * m1).max(0.0))
209}
210
211impl FamilyStrategy for ResolvedFamilyStrategy {
212 fn name(&self) -> &'static str {
213 self.spec.name()
214 }
215
216 fn family(&self) -> LikelihoodSpec {
217 self.spec.clone()
218 }
219
220 fn link_function(&self) -> LinkFunction {
221 self.spec.link.link_function()
222 }
223
224 fn inverse_link(&self, eta: f64) -> Result<f64, EstimationError> {
225 self.inverse_link_jet(eta).map(|jet| jet.mu)
226 }
227
228 fn inverse_link_array(&self, eta: ArrayView1<'_, f64>) -> Result<Array1<f64>, EstimationError> {
229 let mut out = Array1::<f64>::zeros(eta.len());
230 for i in 0..eta.len() {
231 out[i] = self.inverse_link(eta[i])?;
232 }
233 Ok(out)
234 }
235
236 fn inverse_link_jet(&self, eta: f64) -> Result<InverseLinkJet, EstimationError> {
237 inverse_link_jet_for_family_public(&self.spec, eta)
244 }
245
246 fn posterior_mean(
247 &self,
248 quadctx: &QuadratureContext,
249 eta: f64,
250 se_eta: f64,
251 ) -> Result<f64, EstimationError> {
252 match (&self.spec.response, &self.spec.link) {
253 (ResponseFamily::Gaussian, _) => Ok(eta),
254 (ResponseFamily::Binomial, InverseLink::Standard(_)) => {
255 integrated_inverse_link_mean_and_derivative(
256 quadctx,
257 self.link_function(),
258 eta,
259 se_eta,
260 )
261 .map(|v| v.mean)
262 }
263 (ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => {
264 let state = self.require_latent_cloglog_state()?;
265 latent_cloglog_inverse_link_jet(quadctx, eta, se_eta.hypot(state.latent_sd))
266 .map(|v| v.mean)
267 }
268 (ResponseFamily::Binomial, InverseLink::Sas(_))
269 | (ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => {
270 integrated_inverse_link_jetwith_state(
271 quadctx,
272 self.link_function(),
273 eta,
274 se_eta,
275 self.mixture_state(),
276 self.sas_state(),
277 )
278 .map(|v| v.mean)
279 }
280 (ResponseFamily::Binomial, InverseLink::Mixture(_)) => {
281 let state = self.require_mixture_state()?;
282 integrated_family_moments_jet(
283 quadctx,
284 &LikelihoodSpec::binomial_mixture(state.clone()),
285 LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 },
289 eta,
290 se_eta,
291 )
292 .map(|v| v.mean)
293 }
294 (ResponseFamily::Poisson, _)
295 | (ResponseFamily::Tweedie { .. }, _)
296 | (ResponseFamily::NegativeBinomial { .. }, _)
297 | (ResponseFamily::Gamma, _) => {
298 let exponent = eta + 0.5 * se_eta * se_eta;
308 if exponent.is_finite() {
309 let mgf = exponent.exp();
310 if mgf.is_finite() {
311 return Ok(mgf);
312 }
313 }
314 let plugin = eta.exp();
320 if plugin.is_finite() {
321 Ok(plugin)
322 } else {
323 Ok(f64::MAX)
324 }
325 }
326 (ResponseFamily::Beta { .. }, _) => {
327 Ok(logit_posterior_meanvariance(quadctx, eta, se_eta).0)
328 }
329 (ResponseFamily::RoystonParmar, _) => Ok(survival_posterior_mean(quadctx, eta, se_eta)),
330 }
331 }
332
333 fn posterior_meanvariance(
334 &self,
335 quadctx: &QuadratureContext,
336 eta: f64,
337 se_eta: f64,
338 ) -> Result<(f64, f64), EstimationError> {
339 match (&self.spec.response, &self.spec.link) {
340 (ResponseFamily::Gaussian, _) => Ok((eta, (se_eta * se_eta).max(0.0))),
341 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
342 Ok(logit_posterior_meanvariance(quadctx, eta, se_eta))
343 }
344 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
345 Ok(probit_posterior_meanvariance(quadctx, eta, se_eta))
346 }
347 (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
348 Ok(cloglog_posterior_meanvariance(quadctx, eta, se_eta))
349 }
350 (ResponseFamily::Binomial, InverseLink::Standard(_)) => {
351 Ok(logit_posterior_meanvariance(quadctx, eta, se_eta))
355 }
356 (ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => {
357 let state = self.require_latent_cloglog_state()?;
358 let total_sigma = se_eta.hypot(state.latent_sd);
359 let m1 = latent_cloglog_inverse_link_jet(quadctx, eta, total_sigma)?.mean;
360 let m2 = normal_expectation_1d_adaptive(quadctx, eta, se_eta, |x| {
361 latent_cloglog_inverse_link_jet(quadctx, x, state.latent_sd)
362 .map(|jet| {
363 let p = jet.mean;
364 p * p
365 })
366 .unwrap_or(f64::NAN)
367 });
368 Ok((m1, (m2 - m1 * m1).max(0.0)))
369 }
370 (ResponseFamily::Binomial, InverseLink::Sas(_)) => {
371 let state = self.require_sas_state()?;
372 Ok(posterior_mv_from_prob_kernel(quadctx, eta, se_eta, |x| {
373 gam_solve::mixture_link::sas_inverse_link_jet(x, state.epsilon, state.log_delta).mu
374 }))
375 }
376 (ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => {
377 let state = self.require_sas_state()?;
378 Ok(posterior_mv_from_prob_kernel(quadctx, eta, se_eta, |x| {
379 gam_solve::mixture_link::beta_logistic_inverse_link_jet(
380 x,
381 state.log_delta,
382 state.epsilon,
383 )
384 .mu
385 }))
386 }
387 (ResponseFamily::Binomial, InverseLink::Mixture(_)) => {
388 let state = self.require_mixture_state()?;
389 let m1 = integrated_family_moments_jet(
390 quadctx,
391 &LikelihoodSpec::binomial_mixture(state.clone()),
392 LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 },
396 eta,
397 se_eta,
398 )?
399 .mean;
400 let m2 = normal_expectation_1d_adaptive(quadctx, eta, se_eta, |x| {
401 let p = mixture_inverse_link_jet(state, x).mu;
402 p * p
403 });
404 Ok((m1, (m2 - m1 * m1).max(0.0)))
405 }
406 (ResponseFamily::Poisson, _)
407 | (ResponseFamily::Tweedie { .. }, _)
408 | (ResponseFamily::NegativeBinomial { .. }, _)
409 | (ResponseFamily::Gamma, _) => {
410 let s2 = se_eta * se_eta;
413 let m1 = (eta + 0.5 * s2).exp();
414 let m2 = (2.0 * eta + s2).exp() * (s2.exp() - 1.0);
415 Ok((m1, m2.max(0.0)))
416 }
417 (ResponseFamily::Beta { .. }, _) => {
418 Ok(logit_posterior_meanvariance(quadctx, eta, se_eta))
419 }
420 (ResponseFamily::RoystonParmar, _) => {
421 Ok(survival_posterior_meanvariance(quadctx, eta, se_eta))
422 }
423 }
424 }
425
426 fn simulate_noise(
427 &self,
428 mean: &Array1<f64>,
429 gaussian_scale: Option<f64>,
430 ) -> Result<NoiseModel, EstimationError> {
431 NoiseModel::from_likelihood(&self.spec, mean.len(), gaussian_scale)
436 }
437
438 fn integrated_moments(
439 &self,
440 quadctx: &QuadratureContext,
441 eta: f64,
442 se_eta: f64,
443 ) -> Result<IntegratedMomentsJet, EstimationError> {
444 if let Some(state) = self.latent_cloglog_state() {
445 let jet = latent_cloglog_inverse_link_jet(quadctx, eta, se_eta.hypot(state.latent_sd))?;
446 let mean = jet.mean;
447 return Ok(IntegratedMomentsJet {
448 mean,
449 variance: (mean * (1.0 - mean)).max(PROB_VARIANCE_FLOOR),
450 d1: jet.d1,
451 d2: jet.d2,
452 d3: jet.d3,
453 mode: jet.mode,
454 });
455 }
456 integrated_family_moments_jet(
463 quadctx,
464 &self.spec,
465 self.spec.default_scale_metadata(),
466 eta,
467 se_eta,
468 )
469 }
470}
471
#[cfg(test)]
mod log_link_public_jet_tests {
    //! Pins the public (prediction-path) log inverse link used by
    //! `ResolvedFamilyStrategy::inverse_link_jet`.
    //!
    //! It must return the exact `exp(eta)` at extreme `eta`. Inside the range it must
    //! agree bit-for-bit with `inverse_link_jet_for_family`. NOTE(review): the
    //! assertion messages imply that the fitting-path jet clamps `eta` to
    //! `[-700, 700]`; confirm in `gam_solve`.
    use super::*;
    use gam_solve::mixture_link::inverse_link_jet_for_family;
    use gam_problem::LikelihoodSpec;
    use ndarray::Array1;

    // Beyond the (implied) ±700 clamp the public jet must still be exact exp(eta),
    // overflowing to +inf and underflowing to 0.0 only where f64 itself does.
    #[test]
    fn public_predict_log_inverse_link_is_exact_exp_at_boundary() {
        let strategy = strategy_for_spec(&LikelihoodSpec::poisson_log());

        // exp(705) is just below the f64 overflow threshold (~exp(709.78)).
        let exact = 705.0_f64.exp();
        assert!(exact.is_finite(), "exp(705) must be representable in f64");
        let jet = strategy.inverse_link_jet(705.0).expect("jet");
        assert_eq!(jet.mu, exact, "predict mean must be exact exp(705)");
        assert_eq!(jet.d1, exact, "predict dmu/deta must be exact exp(705)");
        // For the log link every derivative of exp(eta) equals exp(eta).
        assert_eq!(jet.d2, exact);
        assert_eq!(jet.d3, exact);
        let clamped = 700.0_f64.exp();
        // exp(5) ~= 148, so a clamped value at 700 would fail this by a wide margin.
        assert!(
            jet.mu > clamped * 100.0,
            "exact exp(705) must exceed the clamped exp(700) by ~exp(5)"
        );

        // The array entry point must follow the same exact path as the scalar jet.
        let arr = strategy
            .inverse_link_array(Array1::from(vec![705.0]).view())
            .expect("array");
        assert_eq!(arr[0], exact, "inverse_link_array must be exact exp(705)");

        // exp(-720) is subnormal in f64 but still nonzero.
        let exact_neg = (-720.0_f64).exp();
        let jet = strategy.inverse_link_jet(-720.0).expect("jet");
        assert_eq!(jet.mu, exact_neg, "predict mean must be exact exp(-720)");
        let clamped_neg = (-700.0_f64).exp();
        assert!(
            jet.mu < clamped_neg,
            "exact exp(-720) must be strictly below the clamped exp(-700)"
        );

        // Past the f64 range the jet must not error: it returns IEEE overflow/underflow.
        let over = strategy.inverse_link_jet(710.0).expect("jet");
        assert!(over.mu.is_infinite() && over.mu > 0.0, "exp(710) -> +inf");
        let under = strategy.inverse_link_jet(-746.0).expect("jet");
        assert_eq!(under.mu, 0.0, "exp(-746) -> 0.0");
    }

    // Inside [-700, 700] the public and fitting-path jets must agree to the bit, so
    // switching the prediction path cannot perturb in-range results.
    #[test]
    fn public_predict_log_jet_byte_identical_to_clamped_in_range() {
        let spec = LikelihoodSpec::poisson_log();
        let strategy = strategy_for_spec(&spec);
        for &eta in &[
            -700.0, -300.0, -12.5, -1.0, -0.25, 0.0, 0.25, 1.0, 12.5, 300.0, 700.0,
        ] {
            let public_jet = strategy.inverse_link_jet(eta).expect("public jet");
            let clamped_jet = inverse_link_jet_for_family(&spec, eta).expect("clamped jet");
            // `to_bits` comparison distinguishes values that `==` would treat as
            // equal (e.g. -0.0 vs 0.0).
            assert_eq!(
                public_jet.mu.to_bits(),
                clamped_jet.mu.to_bits(),
                "mu must be byte-identical in range at eta={eta}"
            );
            assert_eq!(
                public_jet.d1.to_bits(),
                clamped_jet.d1.to_bits(),
                "d1 must be byte-identical in range at eta={eta}"
            );
            assert_eq!(public_jet.d2.to_bits(), clamped_jet.d2.to_bits());
            assert_eq!(public_jet.d3.to_bits(), clamped_jet.d3.to_bits());
        }
    }
}