use ndarray::{Array1, ArrayView1};
use serde::{Deserialize, Serialize};
/// Prior specification for a coefficient group.
///
/// Each variant is converted to the corresponding [`RhoPrior`] variant by
/// `to_rho_prior`, and its parameters are checked by `validate`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CoefficientGroupPrior {
/// No prior; `validate` always accepts it.
Flat,
/// Normal prior on the log-precision (wording taken from the validation
/// messages). `validate` requires a finite `mean` and a finite `sd > 0`.
NormalLogPrecision {
mean: f64,
sd: f64,
},
/// Gamma prior on the precision. `validate` requires a finite `shape > 0`
/// and a finite `rate >= 0`.
GammaPrecision {
shape: f64,
rate: f64,
},
/// Penalized-complexity prior. `validate` requires a finite `upper > 0` and a
/// tail probability strictly inside `(0, 1)`.
PenalizedComplexity {
upper: f64,
tail_prob: f64,
},
}
impl CoefficientGroupPrior {
    /// Converts this prior into the matching [`RhoPrior`] variant, copying the
    /// parameters unchanged (`NormalLogPrecision` becomes `RhoPrior::Normal`).
    pub fn to_rho_prior(&self) -> RhoPrior {
        match self {
            Self::Flat => RhoPrior::Flat,
            Self::NormalLogPrecision { mean, sd } => RhoPrior::Normal {
                mean: *mean,
                sd: *sd,
            },
            Self::GammaPrecision { shape, rate } => RhoPrior::GammaPrecision {
                shape: *shape,
                rate: *rate,
            },
            Self::PenalizedComplexity { upper, tail_prob } => RhoPrior::PenalizedComplexity {
                upper: *upper,
                tail_prob: *tail_prob,
            },
        }
    }

    /// Checks the prior's parameters.
    ///
    /// `context` is prepended verbatim to any error message.
    ///
    /// # Errors
    /// Returns a description of the first violated constraint; parameters are
    /// checked in declaration order.
    pub fn validate(&self, context: &str) -> Result<(), String> {
        // Collect the first problem (if any), then convert to a `Result` once.
        let problem = match self {
            Self::Flat => None,
            Self::NormalLogPrecision { mean, sd } => {
                if !mean.is_finite() {
                    Some(format!(
                        "{context} Normal log-precision prior requires finite mean, got {mean}"
                    ))
                } else if !(sd.is_finite() && *sd > 0.0) {
                    Some(format!(
                        "{context} Normal log-precision prior requires sd > 0, got {sd}"
                    ))
                } else {
                    None
                }
            }
            Self::GammaPrecision { shape, rate } => {
                if !(shape.is_finite() && *shape > 0.0) {
                    Some(format!(
                        "{context} Gamma precision prior requires shape > 0, got {shape}"
                    ))
                } else if !(rate.is_finite() && *rate >= 0.0) {
                    Some(format!(
                        "{context} Gamma precision prior requires rate >= 0, got {rate}"
                    ))
                } else {
                    None
                }
            }
            Self::PenalizedComplexity { upper, tail_prob } => {
                if !(upper.is_finite() && *upper > 0.0) {
                    Some(format!(
                        "{context} penalized-complexity prior requires upper > 0, got {upper}"
                    ))
                } else if !(tail_prob.is_finite() && *tail_prob > 0.0 && *tail_prob < 1.0) {
                    Some(format!(
                        "{context} penalized-complexity prior requires tail probability in (0, 1), got {tail_prob}"
                    ))
                } else {
                    None
                }
            }
        };
        match problem {
            Some(message) => Err(message),
            None => Ok(()),
        }
    }
}
/// Configuration for a "wiggle" penalty.
///
/// NOTE(review): field meanings below are inferred from the field names only;
/// the code that consumes this struct is not visible here, so confirm before
/// relying on them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WigglePenaltyConfig {
/// Presumably the spline degree (3 in the cubic default).
pub degree: usize,
/// Presumably the number of interior knots.
pub num_internal_knots: usize,
/// Presumably the orders of the penalties to apply.
pub penalty_orders: Vec<usize>,
/// Presumably toggles an additional (double) penalty.
pub double_penalty: bool,
/// Presumably a small slack used by a monotonicity constraint.
pub monotonicity_eps: f64,
}
impl WigglePenaltyConfig {
    /// Returns the default configuration: degree 3, 8 internal knots, penalty
    /// orders 1, 2 and 3, double penalty enabled, `monotonicity_eps = 1e-4`.
    pub fn cubic_triple_operator_default() -> Self {
        let penalty_orders = Vec::from([1, 2, 3]);
        Self {
            degree: 3,
            num_internal_knots: 8,
            penalty_orders,
            double_penalty: true,
            monotonicity_eps: 1e-4,
        }
    }
}
/// Identifier for a link function, without any link state.
///
/// `Sas` and `BetaLogistic` are the "state-bearing" links: converting them to a
/// [`StandardLink`] via `TryFrom` fails with [`StateBearingLinkInStandardSlot`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkFunction {
Logit,
Probit,
CLogLog,
Sas,
BetaLogistic,
Identity,
Log,
}
impl LinkFunction {
#[inline]
pub const fn name(self) -> &'static str {
match self {
Self::Logit => "logit",
Self::Probit => "probit",
Self::CLogLog => "cloglog",
Self::Sas => "sas",
Self::BetaLogistic => "beta-logistic",
Self::Identity => "identity",
Self::Log => "log",
}
}
}
/// The subset of [`LinkFunction`] values that carry no extra state and can be
/// stored in `InverseLink::Standard`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StandardLink {
Logit,
Probit,
CLogLog,
Identity,
Log,
}
impl StandardLink {
#[inline]
pub const fn name(self) -> &'static str {
self.as_link_function().name()
}
#[inline]
pub const fn as_link_function(self) -> LinkFunction {
match self {
Self::Logit => LinkFunction::Logit,
Self::Probit => LinkFunction::Probit,
Self::CLogLog => LinkFunction::CLogLog,
Self::Identity => LinkFunction::Identity,
Self::Log => LinkFunction::Log,
}
}
}
impl From<StandardLink> for LinkFunction {
#[inline]
fn from(link: StandardLink) -> Self {
link.as_link_function()
}
}
/// Error returned when a state-bearing [`LinkFunction`] (`Sas` or
/// `BetaLogistic`) is converted to a [`StandardLink`]. Wraps the rejected link.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StateBearingLinkInStandardSlot(pub LinkFunction);

impl std::fmt::Display for StateBearingLinkInStandardSlot {
    /// Names the rejected link and points at the `InverseLink` variants that
    /// can carry it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let StateBearingLinkInStandardSlot(rejected) = self;
        let name = rejected.name();
        write!(
            f,
            "state-bearing link `{name}` cannot be carried by `InverseLink::Standard`; \
             route through `InverseLink::Sas` / `InverseLink::BetaLogistic`"
        )
    }
}

impl std::error::Error for StateBearingLinkInStandardSlot {}
impl TryFrom<LinkFunction> for StandardLink {
type Error = StateBearingLinkInStandardSlot;
#[inline]
fn try_from(link: LinkFunction) -> Result<Self, Self::Error> {
match link {
LinkFunction::Logit => Ok(Self::Logit),
LinkFunction::Probit => Ok(Self::Probit),
LinkFunction::CLogLog => Ok(Self::CLogLog),
LinkFunction::Identity => Ok(Self::Identity),
LinkFunction::Log => Ok(Self::Log),
LinkFunction::Sas | LinkFunction::BetaLogistic => {
Err(StateBearingLinkInStandardSlot(link))
}
}
}
}
/// A single component link usable inside a mixture link
/// (see `MixtureLinkSpec::components` / `MixtureLinkState::components`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkComponent {
Probit,
Logit,
CLogLog,
LogLog,
Cauchit,
}
impl LinkComponent {
#[inline]
pub const fn name(self) -> &'static str {
match self {
Self::Probit => "probit",
Self::Logit => "logit",
Self::CLogLog => "cloglog",
Self::LogLog => "loglog",
Self::Cauchit => "cauchit",
}
}
}
/// User-facing specification of a mixture link: the component links plus an
/// initial `rho` vector.
///
/// NOTE(review): the required relationship between `initial_rho.len()` and
/// `components.len()` is not enforced or visible here — confirm with the consumer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkSpec {
pub components: Vec<LinkComponent>,
pub initial_rho: Array1<f64>,
}
/// State of a mixture link, carried by `InverseLink::Mixture`.
///
/// NOTE(review): `pi` is presumably the vector of mixing weights derived from
/// `rho`; that derivation is not shown in this file — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkState {
pub components: Vec<LinkComponent>,
pub rho: Array1<f64>,
pub pi: Array1<f64>,
}
/// Initial parameter values for an SAS-type link (`epsilon` and `log_delta`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkSpec {
pub initial_epsilon: f64,
pub initial_log_delta: f64,
}
/// Parameter state shared by `InverseLink::Sas` and `InverseLink::BetaLogistic`.
///
/// NOTE(review): `delta` is presumably `exp(log_delta)`; nothing in this file
/// enforces that the two stay in sync — confirm with the constructor.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkState {
pub epsilon: f64,
pub log_delta: f64,
pub delta: f64,
}
/// State of the latent complementary log-log link: the latent standard deviation.
///
/// `LatentCLogLogState::new` validates the value; the field is public and the
/// type derives `Deserialize`, so instances built other ways are not validated.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct LatentCLogLogState {
pub latent_sd: f64,
}
impl LatentCLogLogState {
    /// Builds a validated state.
    ///
    /// # Errors
    /// Returns a message when `latent_sd` is NaN, infinite, or negative.
    #[inline]
    pub fn new(latent_sd: f64) -> Result<Self, String> {
        let acceptable = latent_sd.is_finite() && latent_sd >= 0.0;
        if acceptable {
            Ok(Self { latent_sd })
        } else {
            Err(format!(
                "latent cloglog standard deviation must be finite and >= 0, got {latent_sd}"
            ))
        }
    }
}
/// Reports whether `link` is one of the links this module treats as having a
/// Fisher-weight jet: standard logit/probit/cloglog and every state-bearing
/// variant. Standard identity and log are excluded.
#[inline]
fn inverse_link_has_fisher_weight_jet(link: &InverseLink) -> bool {
    match link {
        InverseLink::Standard(standard) => matches!(
            standard,
            StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog
        ),
        InverseLink::LatentCLogLog(_)
        | InverseLink::Sas(_)
        | InverseLink::BetaLogistic(_)
        | InverseLink::Mixture(_) => true,
    }
}
/// An inverse link together with whatever state it needs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InverseLink {
/// A stateless standard link.
Standard(StandardLink),
/// Latent cloglog link; reported as `LinkFunction::CLogLog` by `link_function`.
LatentCLogLog(LatentCLogLogState),
/// SAS link with its parameter state.
Sas(SasLinkState),
/// Beta-logistic link; reuses the `SasLinkState` parameter container.
BetaLogistic(SasLinkState),
/// Mixture link; reported as `LinkFunction::Logit` by `link_function`.
Mixture(MixtureLinkState),
}
impl InverseLink {
#[inline]
pub const fn link_function(&self) -> LinkFunction {
match self {
Self::Standard(link) => link.as_link_function(),
Self::LatentCLogLog(_) => LinkFunction::CLogLog,
Self::Sas(_) => LinkFunction::Sas,
Self::BetaLogistic(_) => LinkFunction::BetaLogistic,
Self::Mixture(_) => LinkFunction::Logit,
}
}
#[inline]
pub const fn mixture_state(&self) -> Option<&MixtureLinkState> {
match self {
Self::Mixture(state) => Some(state),
_ => None,
}
}
#[inline]
pub const fn sas_state(&self) -> Option<&SasLinkState> {
match self {
Self::Sas(state) | Self::BetaLogistic(state) => Some(state),
_ => None,
}
}
#[inline]
pub const fn latent_cloglog_state(&self) -> Option<&LatentCLogLogState> {
match self {
Self::LatentCLogLog(state) => Some(state),
_ => None,
}
}
}
/// Prior on `rho`. The `Default` implementation is `Normal { mean: 0.0, sd: 3.0 }`.
///
/// The first four variants mirror [`CoefficientGroupPrior`] (which converts into
/// them); `Independent` holds a list of nested priors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RhoPrior {
Flat,
Normal {
mean: f64,
sd: f64,
},
GammaPrecision {
shape: f64,
rate: f64,
},
PenalizedComplexity {
upper: f64,
tail_prob: f64,
},
/// A collection of priors. NOTE(review): presumably one per `rho` coordinate —
/// confirm with the consumer.
Independent(Vec<RhoPrior>),
}
impl Default for RhoPrior {
fn default() -> Self {
Self::Normal { mean: 0.0, sd: 3.0 }
}
}
/// Response distribution family of a likelihood.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ResponseFamily {
Gaussian,
Binomial,
Poisson,
/// Tweedie family with power parameter `p`.
Tweedie {
p: f64,
},
/// Negative binomial with parameter `theta`. `theta_fixed` selects between the
/// `FixedNegBinTheta` and `EstimatedNegBinTheta` scale metadata and blocks
/// `GlmLikelihoodSpec::with_negbin_theta` from overwriting `theta`.
NegativeBinomial {
theta: f64,
theta_fixed: bool,
},
/// Beta family with parameter `phi`.
Beta {
phi: f64,
},
Gamma,
RoystonParmar,
}
impl ResponseFamily {
#[inline]
pub const fn name(&self) -> &'static str {
match self {
Self::Gaussian => "gaussian",
Self::Binomial => "binomial",
Self::Poisson => "poisson",
Self::Tweedie { .. } => "tweedie",
Self::NegativeBinomial { .. } => "negative-binomial",
Self::Beta { .. } => "beta",
Self::Gamma => "gamma",
Self::RoystonParmar => "royston-parmar",
}
}
#[inline]
pub fn mean_clamp_bounds(&self) -> Option<(f64, f64)> {
match self {
Self::Binomial | Self::RoystonParmar => Some((0.0, 1.0)),
Self::Beta { .. } => Some((1e-10, 1.0 - 1e-10)),
Self::Gaussian
| Self::Poisson
| Self::Tweedie { .. }
| Self::NegativeBinomial { .. }
| Self::Gamma => None,
}
}
#[inline]
pub fn response_support_bounds(&self) -> Option<(f64, f64)> {
match self {
Self::Gamma | Self::Poisson | Self::NegativeBinomial { .. } | Self::Tweedie { .. } => {
Some((0.0, f64::INFINITY))
}
Self::Beta { .. } | Self::Binomial => Some((0.0, 1.0)),
Self::Gaussian | Self::RoystonParmar => None,
}
}
#[inline]
pub fn response_support_requirement(&self) -> Option<&'static str> {
match self {
Self::Gamma => Some("strictly positive response values (y > 0)"),
Self::Poisson | Self::NegativeBinomial { .. } | Self::Tweedie { .. } => {
Some("non-negative response values (y ≥ 0)")
}
Self::Beta { .. } => Some("response values strictly in the open interval (0, 1)"),
Self::Binomial => Some("binary response values (y ∈ {0, 1})"),
Self::Gaussian | Self::RoystonParmar => None,
}
}
#[inline]
fn response_support_contains(&self, yi: f64) -> bool {
match self {
Self::Gamma => yi.is_finite() && yi > 0.0,
Self::Poisson | Self::NegativeBinomial { .. } | Self::Tweedie { .. } => {
yi.is_finite() && yi >= 0.0
}
Self::Beta { .. } => yi.is_finite() && yi > 0.0 && yi < 1.0,
Self::Binomial => {
yi.is_finite()
&& ((yi - 0.0).abs() < BINOMIAL_BINARY_TOL
|| (yi - 1.0).abs() < BINOMIAL_BINARY_TOL)
}
Self::Gaussian | Self::RoystonParmar => true,
}
}
#[inline]
fn response_support_label(&self) -> &'static str {
match self {
Self::Gaussian => "Gaussian",
Self::Binomial => "Binomial",
Self::Poisson => "Poisson",
Self::Tweedie { .. } => "Tweedie",
Self::NegativeBinomial { .. } => "Negative-Binomial",
Self::Beta { .. } => "Beta",
Self::Gamma => "Gamma",
Self::RoystonParmar => "Royston-Parmar",
}
}
pub fn validate_response_support(
&self,
y: ArrayView1<'_, f64>,
) -> Result<(), ResponseSupportViolation> {
let requirement = match self.response_support_requirement() {
Some(r) => r,
None => return Ok(()),
};
let mut offending: Vec<(usize, f64)> = Vec::new();
let mut total_violations: usize = 0;
for (i, &yi) in y.iter().enumerate() {
if !self.response_support_contains(yi) {
total_violations += 1;
if offending.len() < ResponseSupportViolation::MAX_REPORTED {
offending.push((i, yi));
}
}
}
if total_violations == 0 {
Ok(())
} else {
Err(ResponseSupportViolation {
family_label: self.response_support_label(),
requirement,
offending,
total_violations,
})
}
}
pub fn validate_response_degeneracy(
&self,
y: ArrayView1<'_, f64>,
) -> Result<(), ResponseDegeneracy> {
match self {
Self::Binomial => {
let mut saw_zero = false;
let mut saw_one = false;
for &yi in y.iter() {
if (yi - 0.0).abs() < BINOMIAL_BINARY_TOL {
saw_zero = true;
} else if (yi - 1.0).abs() < BINOMIAL_BINARY_TOL {
saw_one = true;
}
if saw_zero && saw_one {
return Ok(());
}
}
let kind = if saw_one {
ResponseDegeneracyKind::BinomialAllOnes
} else if saw_zero {
ResponseDegeneracyKind::BinomialAllZeros
} else {
return Ok(());
};
Err(ResponseDegeneracy {
family_label: self.response_support_label(),
kind,
})
}
Self::Gaussian => Ok(()),
Self::Poisson
| Self::Tweedie { .. }
| Self::NegativeBinomial { .. }
| Self::Beta { .. }
| Self::Gamma
| Self::RoystonParmar => Ok(()),
}
}
pub fn infer_from_response(
y: ArrayView1<'_, f64>,
y_kind: ResponseColumnKind,
) -> Result<Self, ResponseInferenceRefusal> {
match y_kind {
ResponseColumnKind::Categorical { levels } => Err(ResponseInferenceRefusal {
reason: ResponseInferenceRefusalReason::NonNumericResponse,
levels,
}),
ResponseColumnKind::Binary => Ok(Self::Binomial),
ResponseColumnKind::Numeric => {
let binary = !y.is_empty()
&& y.iter().all(|v| {
v.is_finite()
&& ((*v - 0.0).abs() < BINOMIAL_BINARY_TOL
|| (*v - 1.0).abs() < BINOMIAL_BINARY_TOL)
});
if binary {
return Ok(Self::Binomial);
}
let count = !y.is_empty()
&& y.iter().all(|v| {
v.is_finite() && *v >= 0.0 && (*v - v.round()).abs() <= COUNT_INTEGER_TOL
})
&& y.iter().any(|v| *v >= 2.0 - COUNT_INTEGER_TOL);
if count {
Ok(Self::Poisson)
} else {
Ok(Self::Gaussian)
}
}
}
}
}
/// Error describing response values that fall outside a family's support.
#[derive(Debug, Clone)]
pub struct ResponseSupportViolation {
    /// Capitalized family label used at the start of the message.
    pub family_label: &'static str,
    /// Human-readable statement of the violated requirement.
    pub requirement: &'static str,
    /// Preview of offending `(row index, value)` pairs.
    pub offending: Vec<(usize, f64)>,
    /// Total number of offending rows; may exceed `offending.len()`.
    pub total_violations: usize,
}

impl ResponseSupportViolation {
    /// Maximum number of offending rows kept in `offending`.
    pub const MAX_REPORTED: usize = 5;

    /// Renders the diagnostic using `response_name` as the column name.
    ///
    /// When more rows violate the constraint than are listed, the list is
    /// followed by `, ... (N total)`.
    pub fn message_for(&self, response_name: &str) -> String {
        let rows: Vec<String> = self
            .offending
            .iter()
            .map(|&(row, value)| format!("y[{row}]={value}"))
            .collect();
        let mut listing = rows.join(", ");
        if self.offending.len() < self.total_violations {
            listing.push_str(&format!(", ... ({} total)", self.total_violations));
        }
        format!(
            "{} family requires {}; response column '{}' violates this constraint at row(s) [{}]",
            self.family_label, self.requirement, response_name, listing,
        )
    }
}

impl std::fmt::Display for ResponseSupportViolation {
    /// Same as `message_for("y")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.message_for("y");
        f.write_str(&rendered)
    }
}

impl std::error::Error for ResponseSupportViolation {}
/// Absolute tolerance within which a response value is treated as exactly 0 or 1
/// by the binomial support, degeneracy, and family-inference checks.
pub const BINOMIAL_BINARY_TOL: f64 = 1.0e-12;
/// Absolute tolerance within which a response value is treated as an integer
/// (and as having reached 2) when inferring a count response.
pub const COUNT_INTEGER_TOL: f64 = 1.0e-9;
/// The ways a response column can be degenerate.
#[derive(Debug, Clone)]
pub enum ResponseDegeneracyKind {
    /// Only zeros were observed.
    BinomialAllZeros,
    /// Only ones were observed.
    BinomialAllOnes,
}

/// Error describing a degenerate response for a given family.
#[derive(Debug, Clone)]
pub struct ResponseDegeneracy {
    /// Capitalized family label used at the start of the message.
    pub family_label: &'static str,
    /// Which degeneracy was detected.
    pub kind: ResponseDegeneracyKind,
}

impl ResponseDegeneracy {
    /// Renders the diagnostic using `response_name` as the column name.
    pub fn message_for(&self, response_name: &str) -> String {
        // The two messages differ only in these three fragments.
        let (value, absent, limit) = match self.kind {
            ResponseDegeneracyKind::BinomialAllZeros => ("0", "no events", "−∞"),
            ResponseDegeneracyKind::BinomialAllOnes => ("1", "no non-events", "+∞"),
        };
        format!(
            "{family} response '{name}' is degenerate: all values are {value} ({absent}). \
             The maximum-likelihood logit is {limit} at this boundary, so the REML score \
             is not finite. Fix: ensure the response contains at least one 0 and \
             at least one 1 (e.g. drop the offending subgroup, or refit on a pooled \
             sample that includes both classes).",
            family = self.family_label,
            name = response_name,
        )
    }
}

impl std::fmt::Display for ResponseDegeneracy {
    /// Same as `message_for("y")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.message_for("y");
        f.write_str(&rendered)
    }
}

impl std::error::Error for ResponseDegeneracy {}
/// Classification of the response column, supplied by the caller to
/// `ResponseFamily::infer_from_response`.
#[derive(Debug, Clone)]
pub enum ResponseColumnKind {
/// Numeric column; the family is inferred from the values.
Numeric,
/// Column already known to be binary; inferred as Binomial.
Binary,
/// Non-numeric column; inference is refused and `levels` are echoed back in
/// the refusal.
Categorical { levels: Vec<String> },
}
/// Why family inference was refused.
#[derive(Debug, Clone)]
pub enum ResponseInferenceRefusalReason {
    /// The response column holds non-numeric (categorical) values.
    NonNumericResponse,
}

/// Error returned when a response family cannot be inferred.
#[derive(Debug, Clone)]
pub struct ResponseInferenceRefusal {
    /// The reason inference was refused.
    pub reason: ResponseInferenceRefusalReason,
    /// Levels observed in the response column; previewed in the message.
    pub levels: Vec<String>,
}

impl ResponseInferenceRefusal {
    /// Renders the diagnostic using `response_name` as the column name.
    ///
    /// At most five levels are shown, each single-quoted; a trailing `...`
    /// marks that further levels were omitted.
    pub fn message_for(&self, response_name: &str) -> String {
        match self.reason {
            ResponseInferenceRefusalReason::NonNumericResponse => {
                let mut shown: Vec<String> = self
                    .levels
                    .iter()
                    .take(5)
                    .map(|level| format!("'{level}'"))
                    .collect();
                if self.levels.len() > shown.len() {
                    shown.push(String::from("..."));
                }
                let preview = format!("[{}]", shown.join(", "));
                format!(
                    "response column '{name}' contains non-numeric values {preview}. \
                     Did you mean to use family='binomial' for a binary outcome, \
                     or does '{name}' contain categorical labels that should be encoded first?",
                    name = response_name,
                    preview = preview,
                )
            }
        }
    }
}

impl std::fmt::Display for ResponseInferenceRefusal {
    /// Same as `message_for("y")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.message_for("y");
        f.write_str(&rendered)
    }
}

impl std::error::Error for ResponseInferenceRefusal {}
/// A response family paired with an inverse link.
///
/// Deserialization goes through [`LikelihoodSpecWire`] and `try_new`, so illegal
/// family/link cells are rejected on the wire. Both fields are public and
/// `LikelihoodSpec::new` performs no check, so legality is not guaranteed for
/// values built directly in code.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(try_from = "LikelihoodSpecWire", into = "LikelihoodSpecWire")]
pub struct LikelihoodSpec {
pub response: ResponseFamily,
pub link: InverseLink,
}
/// Unvalidated serialized form of [`LikelihoodSpec`]; converted with `TryFrom`
/// (validating) and `From` (infallible) in the two directions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LikelihoodSpecWire {
pub response: ResponseFamily,
pub link: InverseLink,
}
impl From<LikelihoodSpec> for LikelihoodSpecWire {
#[inline]
fn from(spec: LikelihoodSpec) -> Self {
Self {
response: spec.response,
link: spec.link,
}
}
}
impl TryFrom<LikelihoodSpecWire> for LikelihoodSpec {
type Error = IllegalLikelihoodCell;
#[inline]
fn try_from(wire: LikelihoodSpecWire) -> Result<Self, Self::Error> {
Self::try_new(wire.response, wire.link)
}
}
/// Error for a response family paired with an inverse link it does not admit.
#[derive(Debug, Clone, PartialEq)]
pub struct IllegalLikelihoodCell {
    /// Label of the response family.
    pub response: &'static str,
    /// Label of the rejected link.
    pub link: &'static str,
}

impl std::fmt::Display for IllegalLikelihoodCell {
    /// Names the rejected pair, then lists which links each family admits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { response, link } = self;
        write!(
            f,
            "illegal likelihood cell: response `{response}` does not admit inverse link `{link}`. "
        )?;
        f.write_str(
            "Each non-binomial family is pinned to one link (Gaussian/Royston-Parmar→identity, \
             Poisson/Gamma/Tweedie/Negative-Binomial→log, Beta→logit); the binomial family \
             admits logit/probit/cloglog and the latent-cloglog/SAS/beta-logistic/blended \
             links, but not identity/log.",
        )
    }
}

impl std::error::Error for IllegalLikelihoodCell {}
/// Flat enumeration of the family/link combinations that
/// `LikelihoodSpec::kind` can produce, each carrying the parameters or link
/// state of its cell.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FamilySpecKind {
GaussianIdentity,
PoissonLog,
GammaLog,
TweedieLog { p: f64 },
// Note: carries only `theta`; the family's `theta_fixed` flag is dropped.
NegativeBinomialLog { theta: f64 },
BetaLogit { phi: f64 },
RoystonParmar,
BinomialLogit,
BinomialProbit,
BinomialCLogLog,
BinomialLatentCLogLog(LatentCLogLogState),
BinomialSas(SasLinkState),
BinomialBetaLogistic(SasLinkState),
BinomialMixture(MixtureLinkState),
}
impl FamilySpecKind {
#[inline]
pub const fn name(&self) -> &'static str {
match self {
Self::GaussianIdentity => "gaussian",
Self::PoissonLog => "poisson-log",
Self::TweedieLog { .. } => "tweedie-log",
Self::NegativeBinomialLog { .. } => "negative-binomial-log",
Self::BetaLogit { .. } => "beta-regression-logit",
Self::GammaLog => "gamma-log",
Self::RoystonParmar => "royston-parmar",
Self::BinomialLogit => "binomial-logit",
Self::BinomialProbit => "binomial-probit",
Self::BinomialCLogLog => "binomial-cloglog",
Self::BinomialLatentCLogLog(_) => "latent-cloglog-binomial",
Self::BinomialSas(_) => "binomial-sas",
Self::BinomialBetaLogistic(_) => "binomial-beta-logistic",
Self::BinomialMixture(_) => "binomial-blended-inverse-link",
}
}
#[inline]
pub const fn pretty_name(&self) -> &'static str {
match self {
Self::GaussianIdentity => "Gaussian Identity",
Self::PoissonLog => "Poisson Log",
Self::TweedieLog { .. } => "Tweedie Log",
Self::NegativeBinomialLog { .. } => "Negative-Binomial Log",
Self::BetaLogit { .. } => "Beta Regression Logit",
Self::GammaLog => "Gamma Log",
Self::RoystonParmar => "Royston Parmar",
Self::BinomialLogit => "Binomial Logit",
Self::BinomialProbit => "Binomial Probit",
Self::BinomialCLogLog => "Binomial CLogLog",
Self::BinomialLatentCLogLog(_) => "Latent CLogLog Binomial",
Self::BinomialSas(_) => "Binomial SAS",
Self::BinomialBetaLogistic(_) => "Binomial Beta-Logistic",
Self::BinomialMixture(_) => "Binomial Blended Inverse-Link",
}
}
#[inline]
pub const fn is_binomial(&self) -> bool {
matches!(
self,
Self::BinomialLogit
| Self::BinomialProbit
| Self::BinomialCLogLog
| Self::BinomialLatentCLogLog(_)
| Self::BinomialSas(_)
| Self::BinomialBetaLogistic(_)
| Self::BinomialMixture(_)
)
}
#[inline]
pub const fn is_gaussian_identity(&self) -> bool {
matches!(self, Self::GaussianIdentity)
}
#[inline]
pub const fn is_royston_parmar(&self) -> bool {
matches!(self, Self::RoystonParmar)
}
#[inline]
pub const fn is_latent_cloglog(&self) -> bool {
matches!(self, Self::BinomialLatentCLogLog(_))
}
#[inline]
pub const fn is_binomial_mixture(&self) -> bool {
matches!(self, Self::BinomialMixture(_))
}
#[inline]
pub const fn is_binomial_sas(&self) -> bool {
matches!(self, Self::BinomialSas(_))
}
#[inline]
pub const fn is_binomial_beta_logistic(&self) -> bool {
matches!(self, Self::BinomialBetaLogistic(_))
}
#[inline]
pub const fn supports_firth(&self) -> bool {
self.is_binomial()
}
}
impl LikelihoodSpec {
    /// Pairs `response` with `link` WITHOUT checking that the cell is legal.
    ///
    /// Use [`LikelihoodSpec::try_new`] for untrusted input. Methods that go
    /// through [`LikelihoodSpec::kind`] panic on an illegal cell.
    #[inline]
    pub const fn new(response: ResponseFamily, link: InverseLink) -> Self {
        Self { response, link }
    }

    /// Reports whether `response` admits `link`.
    ///
    /// Gaussian and Royston-Parmar admit only identity; Poisson, Gamma, Tweedie
    /// and Negative-Binomial only log; Beta only logit. Binomial admits
    /// logit/probit/cloglog and every state-bearing link, but not identity/log.
    #[inline]
    pub fn is_legal_cell(response: &ResponseFamily, link: &InverseLink) -> bool {
        match response {
            ResponseFamily::Gaussian | ResponseFamily::RoystonParmar => {
                matches!(link, InverseLink::Standard(StandardLink::Identity))
            }
            ResponseFamily::Poisson
            | ResponseFamily::Gamma
            | ResponseFamily::Tweedie { .. }
            | ResponseFamily::NegativeBinomial { .. } => {
                matches!(link, InverseLink::Standard(StandardLink::Log))
            }
            ResponseFamily::Beta { .. } => {
                matches!(link, InverseLink::Standard(StandardLink::Logit))
            }
            // Exhaustive on purpose: a new link variant must be classified here.
            ResponseFamily::Binomial => match link {
                InverseLink::Standard(
                    StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,
                ) => true,
                InverseLink::Standard(StandardLink::Identity | StandardLink::Log) => false,
                InverseLink::LatentCLogLog(_)
                | InverseLink::Sas(_)
                | InverseLink::BetaLogistic(_)
                | InverseLink::Mixture(_) => true,
            },
        }
    }

    /// Validating constructor.
    ///
    /// # Errors
    /// Returns [`IllegalLikelihoodCell`] naming the family and the rejected
    /// link when [`LikelihoodSpec::is_legal_cell`] is false.
    #[inline]
    pub fn try_new(
        response: ResponseFamily,
        link: InverseLink,
    ) -> Result<Self, IllegalLikelihoodCell> {
        if Self::is_legal_cell(&response, &link) {
            return Ok(Self::new(response, link));
        }
        // Name the link *variant*, not `link.link_function()`: that tag reports
        // `Mixture` as "logit" and `LatentCLogLog` as "cloglog", which would
        // yield misleading errors such as "`beta` does not admit inverse link
        // `logit`". Labels match `inverse_link_diagnostic_name`.
        let link_name: &'static str = match &link {
            InverseLink::Standard(standard) => standard.name(),
            InverseLink::LatentCLogLog(_) => "latent-cloglog",
            InverseLink::Sas(_) => "sas",
            InverseLink::BetaLogistic(_) => "beta-logistic",
            InverseLink::Mixture(_) => "mixture",
        };
        Err(IllegalLikelihoodCell {
            response: response.name(),
            link: link_name,
        })
    }

    /// Gaussian response with identity link.
    #[inline]
    pub const fn gaussian_identity() -> Self {
        Self::new(
            ResponseFamily::Gaussian,
            InverseLink::Standard(StandardLink::Identity),
        )
    }

    /// Binomial response with logit link.
    #[inline]
    pub const fn binomial_logit() -> Self {
        Self::new(
            ResponseFamily::Binomial,
            InverseLink::Standard(StandardLink::Logit),
        )
    }

    /// Binomial response with probit link.
    #[inline]
    pub const fn binomial_probit() -> Self {
        Self::new(
            ResponseFamily::Binomial,
            InverseLink::Standard(StandardLink::Probit),
        )
    }

    /// Binomial response with complementary log-log link.
    #[inline]
    pub const fn binomial_cloglog() -> Self {
        Self::new(
            ResponseFamily::Binomial,
            InverseLink::Standard(StandardLink::CLogLog),
        )
    }

    /// Binomial response with the latent cloglog link in the given state.
    #[inline]
    pub const fn binomial_latent_cloglog(state: LatentCLogLogState) -> Self {
        Self::new(ResponseFamily::Binomial, InverseLink::LatentCLogLog(state))
    }

    /// Binomial response with the SAS link in the given state.
    #[inline]
    pub const fn binomial_sas(state: SasLinkState) -> Self {
        Self::new(ResponseFamily::Binomial, InverseLink::Sas(state))
    }

    /// Binomial response with the beta-logistic link in the given state.
    #[inline]
    pub const fn binomial_beta_logistic(state: SasLinkState) -> Self {
        Self::new(ResponseFamily::Binomial, InverseLink::BetaLogistic(state))
    }

    /// Binomial response with the mixture link in the given state.
    #[inline]
    pub fn binomial_mixture(state: MixtureLinkState) -> Self {
        Self::new(ResponseFamily::Binomial, InverseLink::Mixture(state))
    }

    /// Poisson response with log link.
    #[inline]
    pub const fn poisson_log() -> Self {
        Self::new(
            ResponseFamily::Poisson,
            InverseLink::Standard(StandardLink::Log),
        )
    }

    /// Tweedie response with power `p` and log link. `p` is not validated here.
    #[inline]
    pub const fn tweedie_log(p: f64) -> Self {
        Self::new(
            ResponseFamily::Tweedie { p },
            InverseLink::Standard(StandardLink::Log),
        )
    }

    /// Negative-binomial response with log link and `theta_fixed = false`.
    #[inline]
    pub const fn negative_binomial_log(theta: f64) -> Self {
        Self::new(
            ResponseFamily::NegativeBinomial {
                theta,
                theta_fixed: false,
            },
            InverseLink::Standard(StandardLink::Log),
        )
    }

    /// Negative-binomial response with log link and `theta_fixed = true`.
    #[inline]
    pub const fn negative_binomial_log_fixed(theta: f64) -> Self {
        Self::new(
            ResponseFamily::NegativeBinomial {
                theta,
                theta_fixed: true,
            },
            InverseLink::Standard(StandardLink::Log),
        )
    }

    /// Beta response with parameter `phi` and logit link.
    #[inline]
    pub const fn beta_logit(phi: f64) -> Self {
        Self::new(
            ResponseFamily::Beta { phi },
            InverseLink::Standard(StandardLink::Logit),
        )
    }

    /// Gamma response with log link.
    #[inline]
    pub const fn gamma_log() -> Self {
        Self::new(
            ResponseFamily::Gamma,
            InverseLink::Standard(StandardLink::Log),
        )
    }

    /// Royston-Parmar response with identity link.
    #[inline]
    pub const fn royston_parmar() -> Self {
        Self::new(
            ResponseFamily::RoystonParmar,
            InverseLink::Standard(StandardLink::Identity),
        )
    }

    /// Stateless link tag of this spec; see `InverseLink::link_function` for
    /// how state-bearing links are reported.
    #[inline]
    pub const fn link_function(&self) -> LinkFunction {
        self.link.link_function()
    }

    /// Classifies this spec as a [`FamilySpecKind`], copying (or, for the
    /// mixture link, cloning) the cell's parameters.
    ///
    /// # Panics
    /// Panics when the family/link pair is not a legal cell, which can only
    /// happen for specs built via `new` or direct field mutation.
    pub fn kind(&self) -> FamilySpecKind {
        self.legal_cell_kind().expect(
            "illegal likelihood cell reached kind(): construction (try_new) and \
             deserialization (LikelihoodSpecWire) guarantee legality",
        )
    }

    /// Non-panicking core of `kind`: `None` for an illegal cell.
    fn legal_cell_kind(&self) -> Option<FamilySpecKind> {
        let kind = match (&self.response, &self.link) {
            (ResponseFamily::Gaussian, InverseLink::Standard(StandardLink::Identity)) => {
                FamilySpecKind::GaussianIdentity
            }
            (ResponseFamily::RoystonParmar, InverseLink::Standard(StandardLink::Identity)) => {
                FamilySpecKind::RoystonParmar
            }
            (ResponseFamily::Poisson, InverseLink::Standard(StandardLink::Log)) => {
                FamilySpecKind::PoissonLog
            }
            (ResponseFamily::Gamma, InverseLink::Standard(StandardLink::Log)) => {
                FamilySpecKind::GammaLog
            }
            (ResponseFamily::Tweedie { p }, InverseLink::Standard(StandardLink::Log)) => {
                FamilySpecKind::TweedieLog { p: *p }
            }
            (
                ResponseFamily::NegativeBinomial { theta, .. },
                InverseLink::Standard(StandardLink::Log),
            ) => FamilySpecKind::NegativeBinomialLog { theta: *theta },
            (ResponseFamily::Beta { phi }, InverseLink::Standard(StandardLink::Logit)) => {
                FamilySpecKind::BetaLogit { phi: *phi }
            }
            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
                FamilySpecKind::BinomialLogit
            }
            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
                FamilySpecKind::BinomialProbit
            }
            (ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
                FamilySpecKind::BinomialCLogLog
            }
            (ResponseFamily::Binomial, InverseLink::LatentCLogLog(state)) => {
                FamilySpecKind::BinomialLatentCLogLog(*state)
            }
            (ResponseFamily::Binomial, InverseLink::Sas(state)) => {
                FamilySpecKind::BinomialSas(*state)
            }
            (ResponseFamily::Binomial, InverseLink::BetaLogistic(state)) => {
                FamilySpecKind::BinomialBetaLogistic(*state)
            }
            (ResponseFamily::Binomial, InverseLink::Mixture(state)) => {
                FamilySpecKind::BinomialMixture(state.clone())
            }
            // Every remaining pair is an illegal cell.
            _ => return None,
        };
        Some(kind)
    }

    /// See `FamilySpecKind::is_binomial`. Panics on an illegal cell (via `kind`).
    #[inline]
    pub fn is_binomial(&self) -> bool {
        self.kind().is_binomial()
    }

    /// See `FamilySpecKind::is_gaussian_identity`. Panics on an illegal cell.
    #[inline]
    pub fn is_gaussian_identity(&self) -> bool {
        self.kind().is_gaussian_identity()
    }

    /// See `FamilySpecKind::is_royston_parmar`. Panics on an illegal cell.
    #[inline]
    pub fn is_royston_parmar(&self) -> bool {
        self.kind().is_royston_parmar()
    }

    /// See `FamilySpecKind::is_latent_cloglog`. Panics on an illegal cell.
    #[inline]
    pub fn is_latent_cloglog(&self) -> bool {
        self.kind().is_latent_cloglog()
    }

    /// See `FamilySpecKind::is_binomial_mixture`. Panics on an illegal cell.
    #[inline]
    pub fn is_binomial_mixture(&self) -> bool {
        self.kind().is_binomial_mixture()
    }

    /// See `FamilySpecKind::is_binomial_sas`. Panics on an illegal cell.
    #[inline]
    pub fn is_binomial_sas(&self) -> bool {
        self.kind().is_binomial_sas()
    }

    /// See `FamilySpecKind::is_binomial_beta_logistic`. Panics on an illegal cell.
    #[inline]
    pub fn is_binomial_beta_logistic(&self) -> bool {
        self.kind().is_binomial_beta_logistic()
    }

    /// Returns the scale metadata a freshly-built spec starts with; depends on
    /// the response family only.
    #[inline]
    pub fn default_scale_metadata(&self) -> LikelihoodScaleMetadata {
        match &self.response {
            ResponseFamily::Gaussian => LikelihoodScaleMetadata::ProfiledGaussian,
            ResponseFamily::Gamma => LikelihoodScaleMetadata::EstimatedGammaShape { shape: 1.0 },
            ResponseFamily::Binomial | ResponseFamily::Poisson => {
                LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 }
            }
            ResponseFamily::NegativeBinomial { theta, theta_fixed } => {
                if *theta_fixed {
                    LikelihoodScaleMetadata::FixedNegBinTheta { theta: *theta }
                } else {
                    LikelihoodScaleMetadata::EstimatedNegBinTheta { theta: *theta }
                }
            }
            ResponseFamily::Tweedie { .. } => {
                LikelihoodScaleMetadata::EstimatedTweediePhi { phi: 1.0 }
            }
            ResponseFamily::Beta { phi } => LikelihoodScaleMetadata::EstimatedBetaPhi { phi: *phi },
            ResponseFamily::RoystonParmar => LikelihoodScaleMetadata::Unspecified,
        }
    }

    /// Title-cased name of the cell. Panics on an illegal cell (via `kind`).
    #[inline]
    pub fn pretty_name(&self) -> &'static str {
        self.kind().pretty_name()
    }

    /// Hyphenated identifier of the cell. Panics on an illegal cell (via `kind`).
    #[inline]
    pub fn name(&self) -> &'static str {
        self.kind().name()
    }

    /// True when the response is Binomial and the link is one of the links
    /// `inverse_link_has_fisher_weight_jet` accepts. Never panics.
    #[inline]
    pub fn supports_firth(&self) -> bool {
        matches!(self.response, ResponseFamily::Binomial)
            && inverse_link_has_fisher_weight_jet(&self.link)
    }

    /// Returns the dispersion implied by the family alone: `None` for
    /// Gaussian/Gamma/Royston-Parmar, the stored `phi` for Beta, and `1.0`
    /// for the rest.
    #[inline]
    pub const fn fixed_dispersion(&self) -> Option<f64> {
        match self.response {
            ResponseFamily::Gaussian | ResponseFamily::Gamma | ResponseFamily::RoystonParmar => {
                None
            }
            ResponseFamily::Binomial
            | ResponseFamily::Poisson
            | ResponseFamily::Tweedie { .. }
            | ResponseFamily::NegativeBinomial { .. } => Some(1.0),
            ResponseFamily::Beta { phi } => Some(phi),
        }
    }
}
/// Returns `true` when `p` is a finite Tweedie power strictly between 1 and 2.
#[inline]
pub const fn is_valid_tweedie_power(p: f64) -> bool {
    if !p.is_finite() {
        return false;
    }
    1.0 < p && p < 2.0
}
/// Error for an inverse link that a response family does not support.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnsupportedLinkError {
    /// Name of the response family that rejected the link.
    pub family: &'static str,
    /// Diagnostic name of the rejected link.
    pub link_name: String,
}

impl UnsupportedLinkError {
    /// Builds the error, recording the link's diagnostic name
    /// (see `inverse_link_diagnostic_name`).
    #[inline]
    pub fn new(family: &'static str, link: &InverseLink) -> Self {
        let link_name = inverse_link_diagnostic_name(link);
        Self { family, link_name }
    }
}

impl std::fmt::Display for UnsupportedLinkError {
    /// Renders the link name followed by the family that rejected it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { family, link_name } = self;
        write!(
            f,
            "inverse link `{link_name}` is not supported by the {family} response family"
        )
    }
}

impl std::error::Error for UnsupportedLinkError {}
/// Returns a diagnostic name that distinguishes every `InverseLink` variant
/// (unlike `InverseLink::link_function`, which maps several variants onto the
/// same tag).
#[inline]
pub fn inverse_link_diagnostic_name(link: &InverseLink) -> String {
    let label: &'static str = match link {
        InverseLink::Standard(lf) => lf.name(),
        InverseLink::LatentCLogLog(_) => "latent-cloglog",
        InverseLink::Sas(_) => "sas",
        InverseLink::BetaLogistic(_) => "beta-logistic",
        InverseLink::Mixture(_) => "mixture",
    };
    label.to_string()
}
/// Builds a Binomial [`LikelihoodSpec`] around a clone of `link`.
///
/// # Errors
/// Standard identity and log links are rejected with an
/// [`UnsupportedLinkError`] for the `"binomial"` family.
#[inline]
pub fn inverse_link_to_binomial_spec(
    link: &InverseLink,
) -> Result<LikelihoodSpec, UnsupportedLinkError> {
    match link {
        InverseLink::Standard(StandardLink::Log | StandardLink::Identity) => {
            Err(UnsupportedLinkError::new("binomial", link))
        }
        InverseLink::Standard(
            StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog,
        )
        | InverseLink::LatentCLogLog(_)
        | InverseLink::Sas(_)
        | InverseLink::BetaLogistic(_)
        | InverseLink::Mixture(_) => {
            let spec = LikelihoodSpec::new(ResponseFamily::Binomial, link.clone());
            Ok(spec)
        }
    }
}
/// Describes how the likelihood's scale/dispersion-type parameter is handled,
/// together with its current value where one applies.
///
/// `Fixed*` / `Estimated*` pairs carry the same payload; the `*_is_estimated`
/// accessors distinguish them.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LikelihoodScaleMetadata {
/// Default for the Gaussian family; `fixed_phi` reports `None`.
ProfiledGaussian,
FixedDispersion { phi: f64 },
FixedGammaShape { shape: f64 },
EstimatedGammaShape { shape: f64 },
EstimatedBetaPhi { phi: f64 },
EstimatedTweediePhi { phi: f64 },
EstimatedNegBinTheta { theta: f64 },
FixedNegBinTheta { theta: f64 },
/// No scale information (default for Royston-Parmar).
Unspecified,
}
impl LikelihoodScaleMetadata {
#[inline]
pub const fn fixed_phi(self) -> Option<f64> {
match self {
Self::FixedDispersion { phi }
| Self::EstimatedBetaPhi { phi }
| Self::EstimatedTweediePhi { phi } => Some(phi),
Self::FixedGammaShape { shape } | Self::EstimatedGammaShape { shape } => {
Some(1.0 / shape)
}
Self::EstimatedNegBinTheta { .. } | Self::FixedNegBinTheta { .. } => Some(1.0),
Self::ProfiledGaussian | Self::Unspecified => None,
}
}
#[inline]
pub const fn negbin_theta_is_estimated(self) -> bool {
matches!(self, Self::EstimatedNegBinTheta { .. })
}
#[inline]
pub const fn negbin_theta(self) -> Option<f64> {
match self {
Self::EstimatedNegBinTheta { theta } | Self::FixedNegBinTheta { theta } => Some(theta),
_ => None,
}
}
#[inline]
pub const fn beta_phi_is_estimated(self) -> bool {
matches!(self, Self::EstimatedBetaPhi { .. })
}
#[inline]
pub const fn tweedie_phi_is_estimated(self) -> bool {
matches!(self, Self::EstimatedTweediePhi { .. })
}
#[inline]
pub const fn gamma_shape(self) -> Option<f64> {
match self {
Self::FixedGammaShape { shape } | Self::EstimatedGammaShape { shape } => Some(shape),
_ => None,
}
}
#[inline]
pub const fn gamma_shape_is_estimated(self) -> bool {
matches!(self, Self::EstimatedGammaShape { .. })
}
}
/// Tag describing how a log-likelihood value is normalized.
///
/// NOTE(review): variant meanings are inferred from their names; no code in
/// this file reads the tag — confirm against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LogLikelihoodNormalization {
Full,
OmittingResponseConstants,
UserProvided,
}
/// A [`LikelihoodSpec`] bundled with its scale metadata.
///
/// `GlmLikelihoodSpec::canonical` initializes `scale` from the spec's default;
/// the `with_*` builders update `scale` (and, where applicable, the family's
/// own parameter) together.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GlmLikelihoodSpec {
pub spec: LikelihoodSpec,
pub scale: LikelihoodScaleMetadata,
}
impl GlmLikelihoodSpec {
    /// Wraps `spec` with the scale metadata returned by
    /// `LikelihoodSpec::default_scale_metadata`.
    #[inline]
    pub fn canonical(spec: LikelihoodSpec) -> Self {
        Self {
            scale: spec.default_scale_metadata(),
            spec,
        }
    }

    /// Stateless link tag of the wrapped spec.
    #[inline]
    pub fn link_function(&self) -> LinkFunction {
        self.spec.link_function()
    }

    /// Dispersion implied by the scale metadata; see
    /// `LikelihoodScaleMetadata::fixed_phi`.
    #[inline]
    pub fn fixed_phi(&self) -> Option<f64> {
        self.scale.fixed_phi()
    }

    /// Returns the factor applied to the coefficient covariance:
    /// `profiled_gaussian_phi` for `ProfiledGaussian`, `1.0` for everything else.
    #[inline]
    pub fn coefficient_covariance_scale(&self, profiled_gaussian_phi: f64) -> f64 {
        let factor = match self.scale {
            LikelihoodScaleMetadata::ProfiledGaussian => profiled_gaussian_phi,
            LikelihoodScaleMetadata::Unspecified
            | LikelihoodScaleMetadata::FixedDispersion { .. }
            | LikelihoodScaleMetadata::FixedGammaShape { .. }
            | LikelihoodScaleMetadata::EstimatedGammaShape { .. }
            | LikelihoodScaleMetadata::EstimatedBetaPhi { .. }
            | LikelihoodScaleMetadata::EstimatedTweediePhi { .. }
            | LikelihoodScaleMetadata::FixedNegBinTheta { .. }
            | LikelihoodScaleMetadata::EstimatedNegBinTheta { .. } => 1.0,
        };
        factor
    }

    /// Gamma shape stored in the scale metadata, if any.
    #[inline]
    pub fn gamma_shape(&self) -> Option<f64> {
        self.scale.gamma_shape()
    }

    /// Sets the gamma shape.
    ///
    /// An existing gamma-shape variant keeps its fixed/estimated flavour. Any
    /// other metadata becomes `EstimatedGammaShape` when the family is Gamma,
    /// and is left untouched otherwise.
    #[inline]
    pub fn with_gamma_shape(mut self, shape: f64) -> Self {
        let family_is_gamma = matches!(&self.spec.response, ResponseFamily::Gamma);
        let updated = match self.scale {
            LikelihoodScaleMetadata::FixedGammaShape { .. } => {
                LikelihoodScaleMetadata::FixedGammaShape { shape }
            }
            LikelihoodScaleMetadata::EstimatedGammaShape { .. } => {
                LikelihoodScaleMetadata::EstimatedGammaShape { shape }
            }
            _ if family_is_gamma => LikelihoodScaleMetadata::EstimatedGammaShape { shape },
            unchanged => unchanged,
        };
        self.scale = updated;
        self
    }

    /// True when the scale metadata is `EstimatedBetaPhi`.
    #[inline]
    pub fn beta_phi_is_estimated(&self) -> bool {
        self.scale.beta_phi_is_estimated()
    }

    /// For a Beta family, stores `phi` in both the family and the scale
    /// metadata (as `EstimatedBetaPhi`); a no-op for other families.
    #[inline]
    pub fn with_beta_phi(mut self, phi: f64) -> Self {
        match &mut self.spec.response {
            ResponseFamily::Beta { phi: stored_phi } => {
                *stored_phi = phi;
                self.scale = LikelihoodScaleMetadata::EstimatedBetaPhi { phi };
            }
            _ => {}
        }
        self
    }

    /// True when the scale metadata is `EstimatedTweediePhi`.
    #[inline]
    pub fn tweedie_phi_is_estimated(&self) -> bool {
        self.scale.tweedie_phi_is_estimated()
    }

    /// For a Tweedie family, sets the scale metadata to `EstimatedTweediePhi`
    /// with the given `phi`; a no-op for other families.
    #[inline]
    pub fn with_tweedie_phi(mut self, phi: f64) -> Self {
        if let ResponseFamily::Tweedie { .. } = self.spec.response {
            self.scale = LikelihoodScaleMetadata::EstimatedTweediePhi { phi };
        }
        self
    }

    /// True when the scale metadata is `EstimatedNegBinTheta`.
    #[inline]
    pub fn negbin_theta_is_estimated(&self) -> bool {
        self.scale.negbin_theta_is_estimated()
    }

    /// For a Negative-Binomial family whose `theta` is not fixed, stores
    /// `theta` in both the family and the scale metadata (as
    /// `EstimatedNegBinTheta`). Otherwise a no-op.
    #[inline]
    pub fn with_negbin_theta(mut self, theta: f64) -> Self {
        // The literal `false` pattern restricts the update to non-fixed theta.
        if let ResponseFamily::NegativeBinomial {
            theta: stored_theta,
            theta_fixed: false,
        } = &mut self.spec.response
        {
            *stored_theta = theta;
            self.scale = LikelihoodScaleMetadata::EstimatedNegBinTheta { theta };
        }
        self
    }

    /// Returns the family's `theta` for a Negative-Binomial family.
    #[inline]
    pub fn negbin_theta(&self) -> Option<f64> {
        if let ResponseFamily::NegativeBinomial { theta, .. } = self.spec.response {
            Some(theta)
        } else {
            None
        }
    }

    /// For a Negative-Binomial family, stores `theta`, marks it fixed, and sets
    /// the scale metadata to `FixedNegBinTheta` — regardless of whether it was
    /// fixed before. A no-op for other families.
    #[inline]
    pub fn with_negbin_theta_frozen_for_search(mut self, theta: f64) -> Self {
        match &mut self.spec.response {
            ResponseFamily::NegativeBinomial {
                theta: stored_theta,
                theta_fixed,
            } => {
                *stored_theta = theta;
                *theta_fixed = true;
                self.scale = LikelihoodScaleMetadata::FixedNegBinTheta { theta };
            }
            _ => {}
        }
        self
    }
}