use crate::gpu::policy::{PirlsLoopAdmission, PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
use crate::types::{InverseLink, LikelihoodSpec, ResponseFamily, StandardLink};
/// Maps a [`LikelihoodSpec`] onto the PIRLS-loop family it corresponds to.
///
/// Only [`InverseLink::Standard`] links are eligible. The recognised
/// response/link pairings are:
///
/// - `Binomial` + `Logit`    → [`PirlsLoopFamilyKind::BernoulliLogit`]
/// - `Binomial` + `Probit`   → [`PirlsLoopFamilyKind::BernoulliProbit`]
/// - `Binomial` + `CLogLog`  → [`PirlsLoopFamilyKind::BernoulliCLogLog`]
/// - `Poisson`  + `Log`      → [`PirlsLoopFamilyKind::PoissonLog`]
/// - `Gaussian` + `Identity` → [`PirlsLoopFamilyKind::GaussianIdentity`]
/// - `Gamma`    + `Log`      → [`PirlsLoopFamilyKind::GammaLog`]
///
/// Any other combination, including every non-standard link such as a
/// mixture link, returns `None`.
pub fn pirls_loop_family_for(spec: &LikelihoodSpec) -> Option<PirlsLoopFamilyKind> {
    let family = match (&spec.response, &spec.link) {
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
            PirlsLoopFamilyKind::BernoulliLogit
        }
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
            PirlsLoopFamilyKind::BernoulliProbit
        }
        (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
            PirlsLoopFamilyKind::BernoulliCLogLog
        }
        (ResponseFamily::Poisson, InverseLink::Standard(StandardLink::Log)) => {
            PirlsLoopFamilyKind::PoissonLog
        }
        (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity)) => {
            PirlsLoopFamilyKind::GaussianIdentity
        }
        (ResponseFamily::Gamma, InverseLink::Standard(StandardLink::Log)) => {
            PirlsLoopFamilyKind::GammaLog
        }
        // Non-standard links (e.g. mixtures) and unlisted standard pairings
        // have no PIRLS-loop family.
        _ => return None,
    };
    Some(family)
}
/// Selects the curvature the PIRLS loop should use for `family`.
///
/// Bernoulli models with the non-canonical probit and complementary
/// log-log links request [`PirlsLoopCurvatureKind::Observed`]. Every other
/// family uses [`PirlsLoopCurvatureKind::Fisher`].
///
/// The match is deliberately exhaustive (no wildcard), so adding a new
/// [`PirlsLoopFamilyKind`] variant forces a decision here.
pub fn pirls_loop_curvature_for(family: PirlsLoopFamilyKind) -> PirlsLoopCurvatureKind {
    match family {
        // Logit, Poisson-log and Gaussian-identity are canonical pairings,
        // for which observed and expected information coincide.
        // NOTE(review): Gamma-log is not Gamma's canonical pairing, yet it
        // uses Fisher weights here; confirm this is intentional.
        PirlsLoopFamilyKind::BernoulliLogit
        | PirlsLoopFamilyKind::PoissonLog
        | PirlsLoopFamilyKind::GaussianIdentity
        | PirlsLoopFamilyKind::GammaLog => PirlsLoopCurvatureKind::Fisher,
        PirlsLoopFamilyKind::BernoulliProbit | PirlsLoopFamilyKind::BernoulliCLogLog => {
            PirlsLoopCurvatureKind::Observed
        }
    }
}
/// Reports whether a GPU runtime is available.
///
/// This is a thin delegation to `crate::gpu::runtime::GpuRuntime::is_available`.
/// [`admission_for`] uses it to fill the `gpu_available` field of the
/// admission record.
pub fn gpu_runtime_available() -> bool {
crate::gpu::runtime::GpuRuntime::is_available()
}
/// Builds the PIRLS-loop admission record for `spec`.
///
/// Returns `None` when [`pirls_loop_family_for`] does not recognise the
/// response/link pairing. In that case the GPU runtime is never probed.
///
/// Otherwise the record carries:
/// - `n` and `p` unchanged (presumably the observation and coefficient
///   counts; confirm against callers);
/// - the mapped family;
/// - the curvature chosen by [`pirls_loop_curvature_for`];
/// - the current result of [`gpu_runtime_available`].
pub fn admission_for(spec: &LikelihoodSpec, n: usize, p: usize) -> Option<PirlsLoopAdmission> {
    pirls_loop_family_for(spec).map(|family| PirlsLoopAdmission {
        n,
        p,
        family: Some(family),
        curvature: pirls_loop_curvature_for(family),
        gpu_available: gpu_runtime_available(),
    })
}
#[cfg(test)]
mod tests {
    // `use super::*` already brings `LikelihoodSpec` (and the other parent
    // imports) into scope; only the extra test-only types are imported here.
    use super::*;
    use crate::types::MixtureLinkState;
    use ndarray::Array1;

    /// Two-component (logit + probit) mixture link state, used to exercise
    /// the non-standard-link rejection path.
    fn dummy_mixture_state() -> MixtureLinkState {
        MixtureLinkState {
            components: vec![
                crate::types::LinkComponent::Logit,
                crate::types::LinkComponent::Probit,
            ],
            rho: Array1::from(vec![0.0_f64]),
            pi: Array1::from(vec![0.5_f64, 0.5_f64]),
        }
    }

    /// Every supported built-in response/link pairing maps to its family.
    /// Probit, cloglog and Gamma-log are not canonical links, so the test
    /// name says "supported" rather than "canonical".
    #[test]
    fn maps_six_supported_built_in_pairings() {
        for (spec, want) in [
            (
                LikelihoodSpec::new(
                    ResponseFamily::Binomial,
                    InverseLink::Standard(StandardLink::Logit),
                ),
                PirlsLoopFamilyKind::BernoulliLogit,
            ),
            (
                LikelihoodSpec::new(
                    ResponseFamily::Binomial,
                    InverseLink::Standard(StandardLink::Probit),
                ),
                PirlsLoopFamilyKind::BernoulliProbit,
            ),
            (
                LikelihoodSpec::new(
                    ResponseFamily::Binomial,
                    InverseLink::Standard(StandardLink::CLogLog),
                ),
                PirlsLoopFamilyKind::BernoulliCLogLog,
            ),
            (
                LikelihoodSpec::new(
                    ResponseFamily::Poisson,
                    InverseLink::Standard(StandardLink::Log),
                ),
                PirlsLoopFamilyKind::PoissonLog,
            ),
            (
                LikelihoodSpec::gaussian_identity(),
                PirlsLoopFamilyKind::GaussianIdentity,
            ),
            (
                LikelihoodSpec::new(
                    ResponseFamily::Gamma,
                    InverseLink::Standard(StandardLink::Log),
                ),
                PirlsLoopFamilyKind::GammaLog,
            ),
        ] {
            assert_eq!(pirls_loop_family_for(&spec), Some(want), "for {:?}", spec);
        }
    }

    /// Mixture links, mismatched standard links and unsupported responses
    /// are all rejected.
    #[test]
    fn declines_unsupported_response_link_pairings() {
        let mixture_state = dummy_mixture_state();
        assert_eq!(
            pirls_loop_family_for(&LikelihoodSpec::new(
                ResponseFamily::Binomial,
                InverseLink::Mixture(mixture_state),
            )),
            None
        );
        assert_eq!(
            pirls_loop_family_for(&LikelihoodSpec::new(
                ResponseFamily::Poisson,
                InverseLink::Standard(StandardLink::Identity),
            )),
            None
        );
        assert_eq!(
            pirls_loop_family_for(&LikelihoodSpec::new(
                ResponseFamily::Tweedie { p: 1.5 },
                InverseLink::Standard(StandardLink::Log),
            )),
            None
        );
    }

    #[test]
    fn non_canonical_bernoulli_links_request_observed_curvature() {
        assert_eq!(
            pirls_loop_curvature_for(PirlsLoopFamilyKind::BernoulliProbit),
            PirlsLoopCurvatureKind::Observed
        );
        assert_eq!(
            pirls_loop_curvature_for(PirlsLoopFamilyKind::BernoulliCLogLog),
            PirlsLoopCurvatureKind::Observed
        );
        assert_eq!(
            pirls_loop_curvature_for(PirlsLoopFamilyKind::BernoulliLogit),
            PirlsLoopCurvatureKind::Fisher
        );
    }

    /// Pins Fisher curvature for the non-Bernoulli families, so that the
    /// curvature of every `PirlsLoopFamilyKind` variant is covered.
    /// Gamma-log is not a canonical pairing; it is pinned here so that any
    /// switch to observed curvature becomes a deliberate, test-visible change.
    #[test]
    fn remaining_families_request_fisher_curvature() {
        for family in [
            PirlsLoopFamilyKind::PoissonLog,
            PirlsLoopFamilyKind::GaussianIdentity,
            PirlsLoopFamilyKind::GammaLog,
        ] {
            assert_eq!(
                pirls_loop_curvature_for(family),
                PirlsLoopCurvatureKind::Fisher,
                "for {:?}",
                family
            );
        }
    }

    /// An unmapped family yields no admission record at all.
    #[test]
    fn admission_is_none_for_unmapped_family() {
        let mixture_state = dummy_mixture_state();
        let spec = LikelihoodSpec::new(
            ResponseFamily::Binomial,
            InverseLink::Mixture(mixture_state),
        );
        assert!(admission_for(&spec, 80_000, 44).is_none());
    }
}