use ndarray::{Array1, ArrayView1};
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut};
pub use crate::hull::PeeledHull;
/// Settings for a spline "wiggle" penalty.
///
/// This type only stores the settings; interpretation is left to whatever
/// consumes it. The single preset provided here is
/// [`WigglePenaltyConfig::cubic_triple_operator_default`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WigglePenaltyConfig {
    /// Spline degree (the preset uses 3).
    pub degree: usize,
    /// Number of interior knots.
    pub num_internal_knots: usize,
    /// Orders of the penalty operators to build (the preset uses 1, 2 and 3).
    pub penalty_orders: Vec<usize>,
    /// Whether the double-penalty option is requested.
    pub double_penalty: bool,
    /// Tolerance used for monotonicity handling by the consumer.
    pub monotonicity_eps: f64,
}
impl WigglePenaltyConfig {
    /// Cubic preset with three penalty operators.
    ///
    /// Degree 3, 8 interior knots, penalty orders `[1, 2, 3]`, double penalty
    /// enabled, and `monotonicity_eps = 1e-4`.
    pub fn cubic_triple_operator_default() -> Self {
        // Orders one through three: the "triple operator" in the name.
        let penalty_orders: Vec<usize> = (1..=3).collect();
        Self {
            monotonicity_eps: 1e-4,
            double_penalty: true,
            penalty_orders,
            num_internal_knots: 8,
            degree: 3,
        }
    }
}
/// Tag identifying an inverse-link function.
///
/// `Sas` and `BetaLogistic` are state-bearing: they need parameters and so
/// cannot be represented by [`StandardLink`]. The remaining variants are
/// parameter-free.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkFunction {
    Logit,
    Probit,
    CLogLog,
    Sas,
    BetaLogistic,
    Identity,
    Log,
}
impl LinkFunction {
    /// Stable lowercase identifier used in diagnostics (e.g. `"beta-logistic"`).
    #[inline]
    pub const fn name(self) -> &'static str {
        use LinkFunction as L;
        match self {
            L::Identity => "identity",
            L::Log => "log",
            L::Logit => "logit",
            L::Probit => "probit",
            L::CLogLog => "cloglog",
            L::Sas => "sas",
            L::BetaLogistic => "beta-logistic",
        }
    }
}
/// Parameter-free inverse links: the ones that fit in `InverseLink::Standard`.
///
/// Every variant has a same-named counterpart in [`LinkFunction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StandardLink {
    Logit,
    Probit,
    CLogLog,
    Identity,
    Log,
}
impl StandardLink {
    /// Widens to the [`LinkFunction`] variant with the same name.
    #[inline]
    pub const fn as_link_function(self) -> LinkFunction {
        match self {
            StandardLink::Identity => LinkFunction::Identity,
            StandardLink::Log => LinkFunction::Log,
            StandardLink::Logit => LinkFunction::Logit,
            StandardLink::Probit => LinkFunction::Probit,
            StandardLink::CLogLog => LinkFunction::CLogLog,
        }
    }
    /// Diagnostic name; always equal to the name of the matching [`LinkFunction`].
    #[inline]
    pub const fn name(self) -> &'static str {
        let widened = self.as_link_function();
        widened.name()
    }
}
impl From<StandardLink> for LinkFunction {
#[inline]
fn from(link: StandardLink) -> Self {
link.as_link_function()
}
}
/// Error from `StandardLink::try_from` when the link needs parameters
/// (`Sas` or `BetaLogistic`) and so cannot be a `StandardLink`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StateBearingLinkInStandardSlot(pub LinkFunction);
impl std::fmt::Display for StateBearingLinkInStandardSlot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let StateBearingLinkInStandardSlot(rejected) = *self;
        write!(
            f,
            "state-bearing link `{name}` cannot be carried by `InverseLink::Standard`; \
             route through `InverseLink::Sas` / `InverseLink::BetaLogistic`",
            name = rejected.name()
        )
    }
}
/// Plain error type: no source, message comes from `Display`.
impl std::error::Error for StateBearingLinkInStandardSlot {}
impl TryFrom<LinkFunction> for StandardLink {
type Error = StateBearingLinkInStandardSlot;
#[inline]
fn try_from(link: LinkFunction) -> Result<Self, Self::Error> {
match link {
LinkFunction::Logit => Ok(Self::Logit),
LinkFunction::Probit => Ok(Self::Probit),
LinkFunction::CLogLog => Ok(Self::CLogLog),
LinkFunction::Identity => Ok(Self::Identity),
LinkFunction::Log => Ok(Self::Log),
LinkFunction::Sas | LinkFunction::BetaLogistic => {
Err(StateBearingLinkInStandardSlot(link))
}
}
}
}
/// One candidate inverse link that can take part in a blended (mixture)
/// link; see `MixtureLinkSpec::components`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkComponent {
    Probit,
    Logit,
    CLogLog,
    LogLog,
    Cauchit,
}
impl LinkComponent {
    /// Stable lowercase identifier of the component (e.g. `"cauchit"`).
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            LinkComponent::Cauchit => "cauchit",
            LinkComponent::LogLog => "loglog",
            LinkComponent::CLogLog => "cloglog",
            LinkComponent::Logit => "logit",
            LinkComponent::Probit => "probit",
        }
    }
}
/// Requested blended ("mixture") inverse link: the candidate components and
/// the starting value of the `rho` parameter vector.
///
/// NOTE(review): the expected length of `initial_rho` relative to
/// `components` is not established here — confirm with the fitting code.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkSpec {
/// Inverse-link components to blend.
pub components: Vec<LinkComponent>,
/// Initial value of `rho`.
pub initial_rho: Array1<f64>,
}
/// Current state of a blended inverse link, as carried by
/// `InverseLink::Mixture` and `FamilySpecKind::BinomialMixture`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkState {
/// Components being blended.
pub components: Vec<LinkComponent>,
/// Current `rho` parameter vector.
pub rho: Array1<f64>,
/// Presumably the mixing weights implied by `rho` (one per component?) —
/// TODO confirm; nothing in this type enforces a relation between the two.
pub pi: Array1<f64>,
}
/// Starting values for fitting a SAS link.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkSpec {
/// Initial value of `epsilon`.
pub initial_epsilon: f64,
/// Initial value of `log_delta` (the log-scale `delta`, going by the name).
pub initial_log_delta: f64,
}
/// Parameters of a SAS-type link.
///
/// The same state type is used by both `InverseLink::Sas` and
/// `InverseLink::BetaLogistic`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkState {
/// `epsilon` parameter.
pub epsilon: f64,
/// `delta` on the log scale (going by the name).
pub log_delta: f64,
/// NOTE(review): presumably `log_delta.exp()` stored for convenience; this
/// type does not enforce that relation.
pub delta: f64,
}
/// Parameter of the latent complementary log-log link.
///
/// Construct via [`LatentCLogLogState::new`] to get the validity check; the
/// public field itself is unchecked.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct LatentCLogLogState {
    /// Latent standard deviation; `new` requires it to be finite and `>= 0`.
    pub latent_sd: f64,
}
impl LatentCLogLogState {
    /// Builds the state after validating `latent_sd`.
    ///
    /// # Errors
    /// Returns a descriptive message when `latent_sd` is NaN, infinite, or
    /// negative. Zero (including `-0.0`) is accepted.
    #[inline]
    pub fn new(latent_sd: f64) -> Result<Self, String> {
        let acceptable = latent_sd.is_finite() && latent_sd >= 0.0;
        if acceptable {
            return Ok(Self { latent_sd });
        }
        Err(format!(
            "latent cloglog standard deviation must be finite and >= 0, got {latent_sd}"
        ))
    }
}
/// Fully specified inverse link, including the parameters of any
/// state-bearing link.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InverseLink {
    /// One of the parameter-free links.
    Standard(StandardLink),
    /// Complementary log-log with a latent standard deviation.
    LatentCLogLog(LatentCLogLogState),
    /// SAS link with its current parameters.
    Sas(SasLinkState),
    /// Beta-logistic link; its parameters reuse [`SasLinkState`].
    BetaLogistic(SasLinkState),
    /// Blended link with its mixture state.
    Mixture(MixtureLinkState),
}
impl InverseLink {
    /// Coarse [`LinkFunction`] tag for this link.
    ///
    /// The mapping is lossy: `LatentCLogLog` reports `CLogLog`, and `Mixture`
    /// reports `Logit` because `LinkFunction` has no mixture tag.
    #[inline]
    pub const fn link_function(&self) -> LinkFunction {
        match self {
            InverseLink::Standard(standard) => standard.as_link_function(),
            InverseLink::LatentCLogLog(_) => LinkFunction::CLogLog,
            InverseLink::Sas(_) => LinkFunction::Sas,
            InverseLink::BetaLogistic(_) => LinkFunction::BetaLogistic,
            InverseLink::Mixture(_) => LinkFunction::Logit,
        }
    }
    /// Mixture state, if this is a `Mixture` link.
    #[inline]
    pub const fn mixture_state(&self) -> Option<&MixtureLinkState> {
        if let InverseLink::Mixture(state) = self {
            Some(state)
        } else {
            None
        }
    }
    /// SAS-style parameters for either `Sas` or `BetaLogistic` (both store a
    /// [`SasLinkState`]); `None` for every other link.
    #[inline]
    pub const fn sas_state(&self) -> Option<&SasLinkState> {
        if let InverseLink::Sas(state) | InverseLink::BetaLogistic(state) = self {
            Some(state)
        } else {
            None
        }
    }
    /// Latent cloglog state, if this is a `LatentCLogLog` link.
    #[inline]
    pub const fn latent_cloglog_state(&self) -> Option<&LatentCLogLogState> {
        if let InverseLink::LatentCLogLog(state) = self {
            Some(state)
        } else {
            None
        }
    }
}
/// Prior specification, presumably on the log smoothing parameters `rho`
/// (going by the name — confirm with the consumer).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RhoPrior {
    /// No prior contribution.
    Flat,
    /// Normal prior with the given mean and standard deviation.
    Normal {
        mean: f64,
        sd: f64,
    },
    /// Gamma prior, presumably on a precision, with `shape` and `rate`.
    GammaPrecision {
        shape: f64,
        rate: f64,
    },
    /// Separate priors, presumably one per element.
    Independent(Vec<RhoPrior>),
}
impl Default for RhoPrior {
    /// `RhoPrior::Normal { mean: 0.0, sd: 3.0 }`.
    fn default() -> Self {
        let (mean, sd) = (0.0, 3.0);
        RhoPrior::Normal { mean, sd }
    }
}
/// Distribution family of the response variable.
///
/// Variants with a payload carry the family's extra parameter.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ResponseFamily {
/// Normal response; no support restriction.
Gaussian,
/// Binomial response; no support check, mean bounded to `(0, 1)`.
Binomial,
/// Poisson response; requires `y >= 0`.
Poisson,
/// Tweedie response with power `p` (see `is_valid_tweedie_power`); requires `y >= 0`.
Tweedie { p: f64 },
/// Negative binomial with parameter `theta`; requires `y >= 0`.
NegativeBinomial { theta: f64 },
/// Beta regression with parameter `phi`; requires `0 < y < 1`.
Beta { phi: f64 },
/// Gamma response; requires `y > 0`.
Gamma,
/// Royston–Parmar model; no support check, mean bounded to `(0, 1)`.
RoystonParmar,
}
impl ResponseFamily {
    /// Stable lowercase identifier of the family (e.g. `"negative-binomial"`).
    #[inline]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Gaussian => "gaussian",
            Self::Binomial => "binomial",
            Self::Poisson => "poisson",
            Self::Tweedie { .. } => "tweedie",
            Self::NegativeBinomial { .. } => "negative-binomial",
            Self::Beta { .. } => "beta",
            Self::Gamma => "gamma",
            Self::RoystonParmar => "royston-parmar",
        }
    }

    /// `(lower, upper)` bounds reported for this family's mean.
    ///
    /// Binomial and Royston–Parmar report `(0, 1)`. Beta reports bounds held
    /// `1e-10` inside `(0, 1)`. Every other family returns `None`. Applying
    /// the bounds is left to the caller.
    #[inline]
    pub fn mean_clamp_bounds(&self) -> Option<(f64, f64)> {
        match self {
            Self::Binomial | Self::RoystonParmar => Some((0.0, 1.0)),
            Self::Beta { .. } => Some((1e-10, 1.0 - 1e-10)),
            Self::Gaussian
            | Self::Poisson
            | Self::Tweedie { .. }
            | Self::NegativeBinomial { .. }
            | Self::Gamma => None,
        }
    }

    /// Human-readable description of the response support, or `None` when
    /// the family performs no support check (Gaussian, Binomial,
    /// Royston–Parmar).
    #[inline]
    pub fn response_support_requirement(&self) -> Option<&'static str> {
        match self {
            Self::Gamma => Some("strictly positive response values (y > 0)"),
            Self::Poisson | Self::NegativeBinomial { .. } | Self::Tweedie { .. } => {
                Some("non-negative response values (y ≥ 0)")
            }
            Self::Beta { .. } => Some("response values strictly in the open interval (0, 1)"),
            Self::Gaussian | Self::Binomial | Self::RoystonParmar => None,
        }
    }

    /// Whether `yi` lies in the support described by
    /// [`Self::response_support_requirement`].
    ///
    /// Families with a requirement reject non-finite values. Families without
    /// one accept every value, including NaN.
    #[inline]
    fn response_support_contains(&self, yi: f64) -> bool {
        match self {
            Self::Gamma => yi.is_finite() && yi > 0.0,
            Self::Poisson | Self::NegativeBinomial { .. } | Self::Tweedie { .. } => {
                yi.is_finite() && yi >= 0.0
            }
            Self::Beta { .. } => yi.is_finite() && yi > 0.0 && yi < 1.0,
            Self::Gaussian | Self::Binomial | Self::RoystonParmar => true,
        }
    }

    /// Capitalised family label used in user-facing error messages.
    #[inline]
    fn response_support_label(&self) -> &'static str {
        match self {
            Self::Gaussian => "Gaussian",
            Self::Binomial => "Binomial",
            Self::Poisson => "Poisson",
            Self::Tweedie { .. } => "Tweedie",
            Self::NegativeBinomial { .. } => "Negative-Binomial",
            Self::Beta { .. } => "Beta",
            Self::Gamma => "Gamma",
            Self::RoystonParmar => "Royston-Parmar",
        }
    }

    /// Checks that every response value lies in the family's support.
    ///
    /// Families without a support requirement always pass.
    ///
    /// # Errors
    /// Returns a [`ResponseSupportViolation`] when any value is outside the
    /// support. It records up to [`ResponseSupportViolation::MAX_REPORTED`]
    /// `(row, value)` pairs in row order, plus the total number of violations.
    pub fn validate_response_support(
        &self,
        y: ArrayView1<'_, f64>,
    ) -> Result<(), ResponseSupportViolation> {
        let requirement = match self.response_support_requirement() {
            Some(r) => r,
            None => return Ok(()),
        };
        let mut offending: Vec<(usize, f64)> = Vec::new();
        let mut total_violations: usize = 0;
        for (i, &yi) in y.iter().enumerate() {
            if !self.response_support_contains(yi) {
                total_violations += 1;
                if offending.len() < ResponseSupportViolation::MAX_REPORTED {
                    offending.push((i, yi));
                }
            }
        }
        if total_violations == 0 {
            Ok(())
        } else {
            Err(ResponseSupportViolation {
                family_label: self.response_support_label(),
                requirement,
                offending,
                total_violations,
            })
        }
    }

    /// Rejects responses for which the fit is degenerate.
    ///
    /// * Binomial: rejected when *every* value is 0 (`BinomialAllZeros`) or
    ///   *every* value is 1 (`BinomialAllOnes`). Matching uses a `1e-12`
    ///   tolerance. Empty responses pass. So does any response containing
    ///   another value, such as a fractional proportion.
    /// * Gaussian: rejected when the sample standard deviation is non-finite
    ///   or `<= GAUSSIAN_MIN_SAMPLE_SD`. The sd uses an `n - 1` denominator,
    ///   or the raw sum of squares when `n == 1`, so a single observation is
    ///   always rejected. Empty responses pass.
    /// * All other families always pass.
    ///
    /// NOTE(review): a NaN or infinite Gaussian response gives a non-finite
    /// sd and is reported as "effectively constant".
    ///
    /// # Errors
    /// Returns a [`ResponseDegeneracy`] describing the degenerate case.
    pub fn validate_response_degeneracy(
        &self,
        y: ArrayView1<'_, f64>,
    ) -> Result<(), ResponseDegeneracy> {
        match self {
            Self::Binomial => {
                // Tolerance for recognising an exact 0 / 1 label.
                const LABEL_TOL: f64 = 1e-12;
                if y.is_empty() {
                    return Ok(());
                }
                // Only an all-0 or all-1 response pins the intercept at ∓∞.
                // Any other value present, e.g. a proportion such as 0.3,
                // keeps it finite. The error messages also promise that *all*
                // values are 0 / 1, so values other than 0/1 must not be
                // ignored.
                let kind = if y.iter().all(|&v| v.abs() < LABEL_TOL) {
                    ResponseDegeneracyKind::BinomialAllZeros
                } else if y.iter().all(|&v| (v - 1.0).abs() < LABEL_TOL) {
                    ResponseDegeneracyKind::BinomialAllOnes
                } else {
                    return Ok(());
                };
                Err(ResponseDegeneracy {
                    family_label: self.response_support_label(),
                    kind,
                })
            }
            Self::Gaussian => {
                if y.is_empty() {
                    return Ok(());
                }
                let n = y.len();
                let mean: f64 = y.iter().copied().sum::<f64>() / (n as f64);
                let ssq: f64 = y.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>();
                let var = if n > 1 { ssq / ((n - 1) as f64) } else { ssq };
                let sd = var.sqrt();
                if !sd.is_finite() || sd <= GAUSSIAN_MIN_SAMPLE_SD {
                    Err(ResponseDegeneracy {
                        family_label: self.response_support_label(),
                        kind: ResponseDegeneracyKind::GaussianNearConstant {
                            sample_sd: sd,
                            min_sd: GAUSSIAN_MIN_SAMPLE_SD,
                        },
                    })
                } else {
                    Ok(())
                }
            }
            Self::Poisson
            | Self::Tweedie { .. }
            | Self::NegativeBinomial { .. }
            | Self::Beta { .. }
            | Self::Gamma
            | Self::RoystonParmar => Ok(()),
        }
    }

    /// Picks a default family from the response column.
    ///
    /// * `Categorical` columns are refused.
    /// * `Binary` columns map to `Binomial`.
    /// * `Numeric` columns map to `Binomial` when non-empty and every value
    ///   is finite and within `1e-12` of 0 or 1; otherwise to `Gaussian`.
    ///
    /// # Errors
    /// Returns a [`ResponseInferenceRefusal`] (`NonNumericResponse`) for
    /// categorical columns, carrying their levels.
    pub fn infer_from_response(
        y: ArrayView1<'_, f64>,
        y_kind: ResponseColumnKind,
    ) -> Result<Self, ResponseInferenceRefusal> {
        match y_kind {
            ResponseColumnKind::Categorical { levels } => Err(ResponseInferenceRefusal {
                reason: ResponseInferenceRefusalReason::NonNumericResponse,
                levels,
            }),
            ResponseColumnKind::Binary => Ok(Self::Binomial),
            ResponseColumnKind::Numeric => {
                let binary = !y.is_empty()
                    && y.iter().all(|v| {
                        v.is_finite() && (v.abs() < 1e-12 || (*v - 1.0).abs() < 1e-12)
                    });
                if binary {
                    Ok(Self::Binomial)
                } else {
                    Ok(Self::Gaussian)
                }
            }
        }
    }
}
/// Response values outside a family's support, as collected by
/// `ResponseFamily::validate_response_support`.
#[derive(Debug, Clone)]
pub struct ResponseSupportViolation {
    /// Capitalised family label used in the message (e.g. `"Gamma"`).
    pub family_label: &'static str,
    /// Description of the support the family requires.
    pub requirement: &'static str,
    /// Listed `(row index, value)` pairs; the validator keeps at most
    /// [`Self::MAX_REPORTED`] of them.
    pub offending: Vec<(usize, f64)>,
    /// Total number of violating rows, including unlisted ones.
    pub total_violations: usize,
}
impl ResponseSupportViolation {
    /// Cap on how many offending rows are listed individually.
    pub const MAX_REPORTED: usize = 5;

    /// Builds the user-facing message, naming the response column
    /// `response_name`.
    ///
    /// Every entry in `offending` is rendered as `y[i]=v`. When
    /// `total_violations` exceeds the number listed, the total is appended.
    pub fn message_for(&self, response_name: &str) -> String {
        let mut shown = String::new();
        for (position, (row, value)) in self.offending.iter().enumerate() {
            if position > 0 {
                shown.push_str(", ");
            }
            shown.push_str(&format!("y[{row}]={value}"));
        }
        let listed = self.offending.len();
        let more = match self.total_violations > listed {
            true => format!(", ... ({} total)", self.total_violations),
            false => String::new(),
        };
        let family = self.family_label;
        let req = self.requirement;
        let name = response_name;
        format!(
            "{family} family requires {req}; response column '{name}' violates this constraint at row(s) [{shown}{more}]"
        )
    }
}
impl std::fmt::Display for ResponseSupportViolation {
    /// Same text as `message_for("y")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = self.message_for("y");
        f.write_str(&text)
    }
}
impl std::error::Error for ResponseSupportViolation {}
/// Gaussian responses whose sample standard deviation is at or below this
/// value are rejected as effectively constant.
pub const GAUSSIAN_MIN_SAMPLE_SD: f64 = 1.0e-10;
/// Which degenerate case was detected.
#[derive(Debug, Clone)]
pub enum ResponseDegeneracyKind {
    /// Binomial response whose values are all 0.
    BinomialAllZeros,
    /// Binomial response whose values are all 1.
    BinomialAllOnes,
    /// Gaussian response with sample sd `sample_sd` not above `min_sd`.
    GaussianNearConstant { sample_sd: f64, min_sd: f64 },
}
/// A response rejected as degenerate for its family.
#[derive(Debug, Clone)]
pub struct ResponseDegeneracy {
    /// Capitalised family label used in the message.
    pub family_label: &'static str,
    /// The degenerate case found.
    pub kind: ResponseDegeneracyKind,
}
impl ResponseDegeneracy {
    /// Builds the user-facing explanation and suggested fix, naming the
    /// response column `response_name`.
    pub fn message_for(&self, response_name: &str) -> String {
        let family = self.family_label;
        let name = response_name;
        match &self.kind {
            ResponseDegeneracyKind::BinomialAllZeros => format!(
                "{family} response '{name}' is degenerate: all values are 0 (no events). \
                 The maximum-likelihood logit is −∞ at this boundary, so the REML score \
                 is not finite. Fix: ensure the response contains at least one 0 and \
                 at least one 1 (e.g. drop the offending subgroup, or refit on a pooled \
                 sample that includes both classes).",
            ),
            ResponseDegeneracyKind::BinomialAllOnes => format!(
                "{family} response '{name}' is degenerate: all values are 1 (no non-events). \
                 The maximum-likelihood logit is +∞ at this boundary, so the REML score \
                 is not finite. Fix: ensure the response contains at least one 0 and \
                 at least one 1 (e.g. drop the offending subgroup, or refit on a pooled \
                 sample that includes both classes).",
            ),
            ResponseDegeneracyKind::GaussianNearConstant { sample_sd, min_sd } => {
                let (sd, floor) = (*sample_sd, *min_sd);
                format!(
                    "{family} response '{name}' is effectively constant (sample sd ≈ {sd:.3e} ≤ {floor:.0e}); \
                     the marginal REML log-likelihood −n/2·log σ² diverges to +∞ as σ → 0. \
                     Fix: check the response column units (is it being read in the right \
                     scale?), centre/rescale the response, or drop the column if it carries \
                     no signal.",
                )
            }
        }
    }
}
impl std::fmt::Display for ResponseDegeneracy {
    /// Same text as `message_for("y")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = self.message_for("y");
        f.write_str(&text)
    }
}
impl std::error::Error for ResponseDegeneracy {}
/// Kind of the response column, as consumed by
/// `ResponseFamily::infer_from_response`.
#[derive(Debug, Clone)]
pub enum ResponseColumnKind {
/// Numeric values; the values themselves decide Binomial vs Gaussian.
Numeric,
/// Known binary column; inferred as Binomial.
Binary,
/// Non-numeric labels; inference is refused and the levels are reported.
Categorical { levels: Vec<String> },
}
/// Why a response family could not be inferred.
#[derive(Debug, Clone)]
pub enum ResponseInferenceRefusalReason {
    /// The response column holds non-numeric (categorical) values.
    NonNumericResponse,
}
/// Refusal to infer a response family, with the offending column's levels.
#[derive(Debug, Clone)]
pub struct ResponseInferenceRefusal {
    /// Reason inference was refused.
    pub reason: ResponseInferenceRefusalReason,
    /// Distinct levels of the response column, in the order supplied.
    pub levels: Vec<String>,
}
impl ResponseInferenceRefusal {
    /// Builds the user-facing message, naming the response column
    /// `response_name`.
    ///
    /// At most the first five levels are quoted. `", ..."` is appended when
    /// there are more.
    pub fn message_for(&self, response_name: &str) -> String {
        const PREVIEW_LIMIT: usize = 5;
        match self.reason {
            ResponseInferenceRefusalReason::NonNumericResponse => {
                let quoted: Vec<String> = self
                    .levels
                    .iter()
                    .take(PREVIEW_LIMIT)
                    .map(|level| format!("'{level}'"))
                    .collect();
                let ellipsis = if self.levels.len() > PREVIEW_LIMIT {
                    ", ..."
                } else {
                    ""
                };
                let preview = format!("[{}{ellipsis}]", quoted.join(", "));
                let name = response_name;
                format!(
                    "response column '{name}' contains non-numeric values {preview}. \
                     Did you mean to use family='binomial' for a binary outcome, \
                     or does '{name}' contain categorical labels that should be encoded first?"
                )
            }
        }
    }
}
impl std::fmt::Display for ResponseInferenceRefusal {
    /// Same text as `message_for("y")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = self.message_for("y");
        f.write_str(&text)
    }
}
impl std::error::Error for ResponseInferenceRefusal {}
/// A response family paired with an inverse link.
///
/// Nothing in the type prevents incompatible pairs; `kind()` maps every pair
/// to a [`FamilySpecKind`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LikelihoodSpec {
/// Distribution family of the response.
pub response: ResponseFamily,
/// Inverse link, including any link parameters.
pub link: InverseLink,
}
/// Flattened classification of a `LikelihoodSpec` (family + link) produced
/// by `LikelihoodSpec::kind`.
///
/// Non-binomial families have a single variant each. Binomial has one
/// variant per supported link, carrying any link state.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FamilySpecKind {
GaussianIdentity,
PoissonLog,
GammaLog,
/// Tweedie with power `p`.
TweedieLog { p: f64 },
/// Negative binomial with parameter `theta`.
NegativeBinomialLog { theta: f64 },
/// Beta regression with parameter `phi`.
BetaLogit { phi: f64 },
RoystonParmar,
BinomialLogit,
BinomialProbit,
BinomialCLogLog,
BinomialLatentCLogLog(LatentCLogLogState),
BinomialSas(SasLinkState),
BinomialBetaLogistic(SasLinkState),
BinomialMixture(MixtureLinkState),
}
impl FamilySpecKind {
#[inline]
pub const fn name(&self) -> &'static str {
match self {
Self::GaussianIdentity => "gaussian",
Self::PoissonLog => "poisson-log",
Self::TweedieLog { .. } => "tweedie-log",
Self::NegativeBinomialLog { .. } => "negative-binomial-log",
Self::BetaLogit { .. } => "beta-regression-logit",
Self::GammaLog => "gamma-log",
Self::RoystonParmar => "royston-parmar",
Self::BinomialLogit => "binomial-logit",
Self::BinomialProbit => "binomial-probit",
Self::BinomialCLogLog => "binomial-cloglog",
Self::BinomialLatentCLogLog(_) => "latent-cloglog-binomial",
Self::BinomialSas(_) => "binomial-sas",
Self::BinomialBetaLogistic(_) => "binomial-beta-logistic",
Self::BinomialMixture(_) => "binomial-blended-inverse-link",
}
}
#[inline]
pub const fn pretty_name(&self) -> &'static str {
match self {
Self::GaussianIdentity => "Gaussian Identity",
Self::PoissonLog => "Poisson Log",
Self::TweedieLog { .. } => "Tweedie Log",
Self::NegativeBinomialLog { .. } => "Negative-Binomial Log",
Self::BetaLogit { .. } => "Beta Regression Logit",
Self::GammaLog => "Gamma Log",
Self::RoystonParmar => "Royston Parmar",
Self::BinomialLogit => "Binomial Logit",
Self::BinomialProbit => "Binomial Probit",
Self::BinomialCLogLog => "Binomial CLogLog",
Self::BinomialLatentCLogLog(_) => "Latent CLogLog Binomial",
Self::BinomialSas(_) => "Binomial SAS",
Self::BinomialBetaLogistic(_) => "Binomial Beta-Logistic",
Self::BinomialMixture(_) => "Binomial Blended Inverse-Link",
}
}
#[inline]
pub const fn is_binomial(&self) -> bool {
matches!(
self,
Self::BinomialLogit
| Self::BinomialProbit
| Self::BinomialCLogLog
| Self::BinomialLatentCLogLog(_)
| Self::BinomialSas(_)
| Self::BinomialBetaLogistic(_)
| Self::BinomialMixture(_)
)
}
#[inline]
pub const fn is_gaussian_identity(&self) -> bool {
matches!(self, Self::GaussianIdentity)
}
#[inline]
pub const fn is_royston_parmar(&self) -> bool {
matches!(self, Self::RoystonParmar)
}
#[inline]
pub const fn is_latent_cloglog(&self) -> bool {
matches!(self, Self::BinomialLatentCLogLog(_))
}
#[inline]
pub const fn is_binomial_mixture(&self) -> bool {
matches!(self, Self::BinomialMixture(_))
}
#[inline]
pub const fn is_binomial_sas(&self) -> bool {
matches!(self, Self::BinomialSas(_))
}
#[inline]
pub const fn is_binomial_beta_logistic(&self) -> bool {
matches!(self, Self::BinomialBetaLogistic(_))
}
#[inline]
pub const fn supports_firth(&self) -> bool {
matches!(self, Self::BinomialLogit)
}
}
impl LikelihoodSpec {
#[inline]
pub const fn new(response: ResponseFamily, link: InverseLink) -> Self {
Self { response, link }
}
#[inline]
pub const fn gaussian_identity() -> Self {
Self::new(
ResponseFamily::Gaussian,
InverseLink::Standard(StandardLink::Identity),
)
}
#[inline]
pub const fn binomial_logit() -> Self {
Self::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Logit),
)
}
#[inline]
pub const fn binomial_probit() -> Self {
Self::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Probit),
)
}
#[inline]
pub const fn binomial_cloglog() -> Self {
Self::new(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::CLogLog),
)
}
#[inline]
pub const fn binomial_latent_cloglog(state: LatentCLogLogState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::LatentCLogLog(state))
}
#[inline]
pub const fn binomial_sas(state: SasLinkState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::Sas(state))
}
#[inline]
pub const fn binomial_beta_logistic(state: SasLinkState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::BetaLogistic(state))
}
#[inline]
pub fn binomial_mixture(state: MixtureLinkState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::Mixture(state))
}
#[inline]
pub const fn poisson_log() -> Self {
Self::new(
ResponseFamily::Poisson,
InverseLink::Standard(StandardLink::Log),
)
}
#[inline]
pub const fn tweedie_log(p: f64) -> Self {
Self::new(
ResponseFamily::Tweedie { p },
InverseLink::Standard(StandardLink::Log),
)
}
#[inline]
pub const fn negative_binomial_log(theta: f64) -> Self {
Self::new(
ResponseFamily::NegativeBinomial { theta },
InverseLink::Standard(StandardLink::Log),
)
}
#[inline]
pub const fn beta_logit(phi: f64) -> Self {
Self::new(
ResponseFamily::Beta { phi },
InverseLink::Standard(StandardLink::Logit),
)
}
#[inline]
pub const fn gamma_log() -> Self {
Self::new(
ResponseFamily::Gamma,
InverseLink::Standard(StandardLink::Log),
)
}
#[inline]
pub const fn royston_parmar() -> Self {
Self::new(
ResponseFamily::RoystonParmar,
InverseLink::Standard(StandardLink::Identity),
)
}
#[inline]
pub const fn link_function(&self) -> LinkFunction {
self.link.link_function()
}
pub fn kind(&self) -> FamilySpecKind {
match (&self.response, &self.link) {
(ResponseFamily::Gaussian, _) => FamilySpecKind::GaussianIdentity,
(ResponseFamily::Poisson, _) => FamilySpecKind::PoissonLog,
(ResponseFamily::Tweedie { p }, _) => FamilySpecKind::TweedieLog { p: *p },
(ResponseFamily::NegativeBinomial { theta }, _) => {
FamilySpecKind::NegativeBinomialLog { theta: *theta }
}
(ResponseFamily::Beta { phi }, _) => FamilySpecKind::BetaLogit { phi: *phi },
(ResponseFamily::Gamma, _) => FamilySpecKind::GammaLog,
(ResponseFamily::RoystonParmar, _) => FamilySpecKind::RoystonParmar,
(ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Logit)) => {
FamilySpecKind::BinomialLogit
}
(ResponseFamily::Binomial, InverseLink::Standard(StandardLink::Probit)) => {
FamilySpecKind::BinomialProbit
}
(ResponseFamily::Binomial, InverseLink::Standard(StandardLink::CLogLog)) => {
FamilySpecKind::BinomialCLogLog
}
(
ResponseFamily::Binomial,
InverseLink::Standard(StandardLink::Identity | StandardLink::Log),
) => FamilySpecKind::BinomialLogit,
(ResponseFamily::Binomial, InverseLink::LatentCLogLog(state)) => {
FamilySpecKind::BinomialLatentCLogLog(*state)
}
(ResponseFamily::Binomial, InverseLink::Sas(state)) => {
FamilySpecKind::BinomialSas(*state)
}
(ResponseFamily::Binomial, InverseLink::BetaLogistic(state)) => {
FamilySpecKind::BinomialBetaLogistic(*state)
}
(ResponseFamily::Binomial, InverseLink::Mixture(state)) => {
FamilySpecKind::BinomialMixture(state.clone())
}
}
}
#[inline]
pub fn is_binomial(&self) -> bool {
self.kind().is_binomial()
}
#[inline]
pub fn is_gaussian_identity(&self) -> bool {
self.kind().is_gaussian_identity()
}
#[inline]
pub fn is_royston_parmar(&self) -> bool {
self.kind().is_royston_parmar()
}
#[inline]
pub fn is_latent_cloglog(&self) -> bool {
self.kind().is_latent_cloglog()
}
#[inline]
pub fn is_binomial_mixture(&self) -> bool {
self.kind().is_binomial_mixture()
}
#[inline]
pub fn is_binomial_sas(&self) -> bool {
self.kind().is_binomial_sas()
}
#[inline]
pub fn is_binomial_beta_logistic(&self) -> bool {
self.kind().is_binomial_beta_logistic()
}
#[inline]
pub fn default_scale_metadata(&self) -> LikelihoodScaleMetadata {
match &self.response {
ResponseFamily::Gaussian => LikelihoodScaleMetadata::ProfiledGaussian,
ResponseFamily::Gamma => LikelihoodScaleMetadata::EstimatedGammaShape { shape: 1.0 },
ResponseFamily::Binomial
| ResponseFamily::Poisson
| ResponseFamily::Tweedie { .. }
| ResponseFamily::NegativeBinomial { .. } => {
LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 }
}
ResponseFamily::Beta { phi } => LikelihoodScaleMetadata::FixedDispersion { phi: *phi },
ResponseFamily::RoystonParmar => LikelihoodScaleMetadata::Unspecified,
}
}
#[inline]
pub fn pretty_name(&self) -> &'static str {
self.kind().pretty_name()
}
#[inline]
pub fn name(&self) -> &'static str {
self.kind().name()
}
#[inline]
pub fn supports_firth(&self) -> bool {
self.kind().supports_firth()
}
#[inline]
pub const fn fixed_dispersion(&self) -> Option<f64> {
match self.response {
ResponseFamily::Gaussian | ResponseFamily::Gamma | ResponseFamily::RoystonParmar => {
None
}
ResponseFamily::Binomial
| ResponseFamily::Poisson
| ResponseFamily::Tweedie { .. }
| ResponseFamily::NegativeBinomial { .. } => Some(1.0),
ResponseFamily::Beta { phi } => Some(phi),
}
}
}
/// Whether `p` is an accepted Tweedie power: strictly between 1 and 2.
///
/// The open-interval test alone also rejects NaN and both infinities:
/// every comparison with NaN is false, and neither infinity lies in `(1, 2)`.
#[inline]
pub const fn is_valid_tweedie_power(p: f64) -> bool {
    1.0 < p && p < 2.0
}
/// Error for an inverse link that a response family cannot use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnsupportedLinkError {
    /// Lowercase family name used in the message (e.g. `"binomial"`).
    pub family: &'static str,
    /// Diagnostic name of the rejected link.
    pub link_name: String,
}
impl UnsupportedLinkError {
    /// Records `family` and the diagnostic name of `link`.
    #[inline]
    pub fn new(family: &'static str, link: &InverseLink) -> Self {
        let link_name = inverse_link_diagnostic_name(link);
        Self { link_name, family }
    }
}
impl std::fmt::Display for UnsupportedLinkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { family, link_name } = self;
        write!(
            f,
            "inverse link `{link_name}` is not supported by the {family} response family"
        )
    }
}
impl std::error::Error for UnsupportedLinkError {}
/// Name of `link` for use in error messages.
///
/// Standard links use their own name. The other variants use a fixed label:
/// `"latent-cloglog"`, `"sas"`, `"beta-logistic"` or `"mixture"`.
#[inline]
fn inverse_link_diagnostic_name(link: &InverseLink) -> String {
    let label: &'static str = match link {
        InverseLink::Standard(standard) => standard.name(),
        InverseLink::LatentCLogLog(_) => "latent-cloglog",
        InverseLink::Sas(_) => "sas",
        InverseLink::BetaLogistic(_) => "beta-logistic",
        InverseLink::Mixture(_) => "mixture",
    };
    label.to_owned()
}
/// Builds a binomial [`LikelihoodSpec`] with `link`, rejecting links that
/// make no sense for a probability.
///
/// # Errors
/// Returns [`UnsupportedLinkError`] (family `"binomial"`) for the standard
/// `Log` and `Identity` links. All other links are cloned into the spec.
#[inline]
pub fn inverse_link_to_binomial_spec(
    link: &InverseLink,
) -> Result<LikelihoodSpec, UnsupportedLinkError> {
    match link {
        InverseLink::Standard(StandardLink::Log | StandardLink::Identity) => {
            Err(UnsupportedLinkError::new("binomial", link))
        }
        InverseLink::Standard(StandardLink::Logit | StandardLink::Probit | StandardLink::CLogLog)
        | InverseLink::LatentCLogLog(_)
        | InverseLink::Sas(_)
        | InverseLink::BetaLogistic(_)
        | InverseLink::Mixture(_) => {
            Ok(LikelihoodSpec::new(ResponseFamily::Binomial, link.clone()))
        }
    }
}
/// How the likelihood's scale / dispersion is handled.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LikelihoodScaleMetadata {
    /// Gaussian scale is profiled (the Gaussian default).
    ProfiledGaussian,
    /// Dispersion fixed at `phi`.
    FixedDispersion { phi: f64 },
    /// Gamma shape held fixed at `shape`.
    FixedGammaShape { shape: f64 },
    /// Gamma shape estimated; `shape` is its current value.
    EstimatedGammaShape { shape: f64 },
    /// No scale information.
    Unspecified,
}
impl LikelihoodScaleMetadata {
    /// Dispersion implied by this metadata, if any.
    ///
    /// `FixedDispersion` reports its `phi`. Both gamma-shape variants report
    /// `1 / shape`, with no guard against `shape == 0`. Profiled and
    /// unspecified scales report `None`.
    #[inline]
    pub const fn fixed_phi(self) -> Option<f64> {
        if let Self::FixedDispersion { phi } = self {
            return Some(phi);
        }
        match self.gamma_shape() {
            Some(shape) => Some(1.0 / shape),
            None => None,
        }
    }
    /// Gamma shape for either gamma-shape variant, otherwise `None`.
    #[inline]
    pub const fn gamma_shape(self) -> Option<f64> {
        match self {
            Self::FixedGammaShape { shape } | Self::EstimatedGammaShape { shape } => Some(shape),
            Self::ProfiledGaussian | Self::FixedDispersion { .. } | Self::Unspecified => None,
        }
    }
    /// `true` only for `EstimatedGammaShape`.
    #[inline]
    pub const fn gamma_shape_is_estimated(self) -> bool {
        matches!(self, Self::EstimatedGammaShape { .. })
    }
}
/// Which normalisation a reported log-likelihood uses.
///
/// NOTE(review): variant meanings are inferred from their names. Presumably
/// `Full` includes every constant, `OmittingResponseConstants` drops terms
/// that depend only on the response, and `UserProvided` defers to the
/// caller. Confirm with the consumers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LogLikelihoodNormalization {
Full,
OmittingResponseConstants,
UserProvided,
}
/// A [`LikelihoodSpec`] together with its scale handling.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GlmLikelihoodSpec {
    /// Family and link.
    pub spec: LikelihoodSpec,
    /// Scale / dispersion handling.
    pub scale: LikelihoodScaleMetadata,
}
impl GlmLikelihoodSpec {
    /// Uses the family's default scale (`LikelihoodSpec::default_scale_metadata`).
    #[inline]
    pub fn canonical(spec: LikelihoodSpec) -> Self {
        Self {
            scale: spec.default_scale_metadata(),
            spec,
        }
    }
    /// Coarse link tag of the wrapped spec.
    #[inline]
    pub fn link_function(&self) -> LinkFunction {
        self.spec.link_function()
    }
    /// Dispersion implied by the scale metadata; see `LikelihoodScaleMetadata::fixed_phi`.
    #[inline]
    pub fn fixed_phi(&self) -> Option<f64> {
        self.scale.fixed_phi()
    }
    /// Gamma shape from the scale metadata, if any.
    #[inline]
    pub fn gamma_shape(&self) -> Option<f64> {
        self.scale.gamma_shape()
    }
    /// Replaces the gamma shape, keeping whether it is fixed or estimated.
    ///
    /// If the scale is not a gamma-shape variant, a Gamma response switches
    /// to `EstimatedGammaShape { shape }`. Any other response is left
    /// untouched.
    #[inline]
    pub fn with_gamma_shape(mut self, shape: f64) -> Self {
        let is_gamma_response = matches!(self.spec.response, ResponseFamily::Gamma);
        self.scale = match self.scale {
            LikelihoodScaleMetadata::FixedGammaShape { .. } => {
                LikelihoodScaleMetadata::FixedGammaShape { shape }
            }
            LikelihoodScaleMetadata::EstimatedGammaShape { .. } => {
                LikelihoodScaleMetadata::EstimatedGammaShape { shape }
            }
            _ if is_gamma_response => LikelihoodScaleMetadata::EstimatedGammaShape { shape },
            unchanged => unchanged,
        };
        self
    }
}
/// How determinant terms involving the ridge are evaluated.
///
/// NOTE(review): going by the names, `Full` uses the whole determinant and
/// `PositivePart` presumably only positive eigenvalues, while `Auto` lets
/// the consumer choose — confirm where `determinant_mode` is read.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RidgeDeterminantMode {
Auto,
Full,
PositivePart,
}
/// Matrix shape of a stabilization ridge; currently only `delta * I`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RidgeMatrixForm {
/// A multiple of the identity matrix.
ScaledIdentity,
}
/// Which objective terms a stabilization ridge participates in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct RidgePolicy {
    /// Whether the ridge does not depend on the smoothing parameters.
    pub rho_independent: bool,
    /// Include the ridge in the quadratic penalty.
    pub include_quadratic_penalty: bool,
    /// Include the ridge in the penalty log-determinant.
    pub include_penalty_logdet: bool,
    /// Include the ridge in the Laplace Hessian.
    pub include_laplacehessian: bool,
    /// How determinants are evaluated.
    pub determinant_mode: RidgeDeterminantMode,
}
impl RidgePolicy {
    /// Every flag enabled, with the given determinant mode.
    const fn all_terms_with(determinant_mode: RidgeDeterminantMode) -> Self {
        Self {
            rho_independent: true,
            include_quadratic_penalty: true,
            include_penalty_logdet: true,
            include_laplacehessian: true,
            determinant_mode,
        }
    }
    /// All terms included; determinant mode `Auto`.
    pub const fn explicit_stabilization_full() -> Self {
        Self::all_terms_with(RidgeDeterminantMode::Auto)
    }
    /// All terms included; determinant mode `Full`.
    pub const fn explicit_stabilization_full_exact() -> Self {
        Self::all_terms_with(RidgeDeterminantMode::Full)
    }
    /// All terms included; determinant mode `PositivePart`.
    pub const fn explicit_stabilization_pospart() -> Self {
        Self::all_terms_with(RidgeDeterminantMode::PositivePart)
    }
}
/// A ridge of size `delta` with its matrix form and policy.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RidgePassport {
    /// Ridge magnitude.
    pub delta: f64,
    /// Matrix shape of the ridge.
    pub matrix_form: RidgeMatrixForm,
    /// Which terms the ridge participates in.
    pub policy: RidgePolicy,
}
impl RidgePassport {
    /// `delta * I` ridge governed by `policy`.
    pub const fn scaled_identity(delta: f64, policy: RidgePolicy) -> Self {
        Self {
            matrix_form: RidgeMatrixForm::ScaledIdentity,
            delta,
            policy,
        }
    }
    /// Ridge added to the penalty log-determinant: `delta` when the policy
    /// includes that term, otherwise `0.0`.
    #[inline]
    pub const fn penalty_logdet_ridge(self) -> f64 {
        match self.policy.include_penalty_logdet {
            true => self.delta,
            false => 0.0,
        }
    }
    /// Ridge added to the Laplace Hessian: `delta` when the policy includes
    /// that term, otherwise `0.0`.
    #[inline]
    pub const fn laplacehessianridge(self) -> f64 {
        match self.policy.include_laplacehessian {
            true => self.delta,
            false => 0.0,
        }
    }
}
/// Matrix inertia: counts of positive, zero and negative eigenvalues, as
/// recorded before/after stabilization in a `StabilizationLedger`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Inertia {
pub positive: usize,
pub zero: usize,
pub negative: usize,
}
/// How a stabilization `delta` was chosen.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationRule {
/// A fixed constant; recorded by `StabilizationLedger::none` and `from_passport`.
FixedConstant,
/// Chosen to reach a target inertia; `spd_floor` is presumably a minimum
/// eigenvalue floor — TODO confirm.
InertiaTarget { spd_floor: f64 },
/// Chosen heuristically.
Heuristic,
/// Supplied by the user; recorded by `StabilizationLedger::explicit_prior`.
UserSpecified,
/// Escalated over repeated attempts; `attempts` counts them.
BackoffEscalation { attempts: usize },
}
/// What a stabilization ridge means for the objective.
///
/// Only `ExplicitPrior` makes the `StabilizationLedger` delta accessors
/// return a non-zero value.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationKind {
/// No stabilization applied.
None,
/// Damping used only inside the solver.
SolverDampingOnly,
/// Numerical perturbation, optionally with a backward-error bound.
NumericalPerturbation {
backward_error_bound: Option<f64>,
},
/// Ridge treated as part of the model (an explicit prior).
ExplicitPrior,
}
/// Record of the stabilization that was applied and how it was chosen.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StabilizationLedger {
/// Meaning of the ridge.
pub kind: StabilizationKind,
/// Ridge magnitude.
pub delta: f64,
/// Matrix shape of the ridge.
pub matrix_form: RidgeMatrixForm,
/// Rule that chose `delta`.
pub chosen_by: StabilizationRule,
/// Inertia before stabilization, if measured (the constructors leave it `None`).
pub inertia_before: Option<Inertia>,
/// Inertia after stabilization, if measured (the constructors leave it `None`).
pub inertia_after: Option<Inertia>,
}
impl StabilizationLedger {
pub const fn none() -> Self {
Self {
kind: StabilizationKind::None,
delta: 0.0,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
}
}
pub const fn solver_damping(delta: f64, chosen_by: StabilizationRule) -> Self {
Self {
kind: StabilizationKind::SolverDampingOnly,
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
}
}
pub const fn numerical_perturbation(
delta: f64,
chosen_by: StabilizationRule,
backward_error_bound: Option<f64>,
) -> Self {
Self {
kind: StabilizationKind::NumericalPerturbation {
backward_error_bound,
},
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
}
}
pub const fn explicit_prior(delta: f64, matrix_form: RidgeMatrixForm) -> Self {
Self {
kind: StabilizationKind::ExplicitPrior,
delta,
matrix_form,
chosen_by: StabilizationRule::UserSpecified,
inertia_before: None,
inertia_after: None,
}
}
pub const fn from_passport(passport: RidgePassport) -> Self {
let any_included = passport.policy.include_quadratic_penalty
|| passport.policy.include_laplacehessian
|| passport.policy.include_penalty_logdet;
let kind = if any_included {
StabilizationKind::ExplicitPrior
} else {
StabilizationKind::NumericalPerturbation {
backward_error_bound: None,
}
};
Self {
kind,
delta: passport.delta,
matrix_form: passport.matrix_form,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
}
}
#[inline]
pub const fn quadratic_delta(&self) -> f64 {
match self.kind {
StabilizationKind::ExplicitPrior => self.delta,
StabilizationKind::None
| StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => 0.0,
}
}
#[inline]
pub const fn laplace_hessian_delta(&self) -> f64 {
match self.kind {
StabilizationKind::ExplicitPrior => self.delta,
StabilizationKind::None
| StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => 0.0,
}
}
#[inline]
pub const fn penalty_logdet_delta(&self) -> f64 {
match self.kind {
StabilizationKind::ExplicitPrior => self.delta,
StabilizationKind::None
| StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => 0.0,
}
}
}
// Defines a `#[repr(transparent)]` newtype `$name` around `Array1<f64>`.
// The generated type gets `new`/`zeros` constructors, `Deref`, `DerefMut` and
// `AsRef` to the inner array, and `From` conversions in both directions.
//
// Optional trailing identifiers add extra method sets via the internal
// `@extra` arm. The only one defined here is `exp`, which adds an
// element-wise `exp()` returning a new array.
macro_rules! array1_f64_newtype {
($name:ident $(, $extra:ident)*) => {
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq)]
pub struct $name(pub Array1<f64>);
impl $name {
/// Wraps `values` without copying.
#[inline]
pub fn new(values: Array1<f64>) -> Self {
Self(values)
}
/// `len` zeros.
#[inline]
pub fn zeros(len: usize) -> Self {
Self(Array1::zeros(len))
}
}
// Deref/DerefMut expose the full `Array1` API on the wrapper.
impl Deref for $name {
type Target = Array1<f64>;
#[inline]
fn deref(&self) -> &Self::Target { &self.0 }
}
impl DerefMut for $name {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
}
impl AsRef<Array1<f64>> for $name {
#[inline]
fn as_ref(&self) -> &Array1<f64> { &self.0 }
}
impl From<Array1<f64>> for $name {
#[inline]
fn from(values: Array1<f64>) -> Self { Self(values) }
}
impl From<$name> for Array1<f64> {
#[inline]
fn from(values: $name) -> Self { values.0 }
}
// One `@extra` expansion per requested extension identifier.
$( array1_f64_newtype!(@extra $name $extra); )*
};
// Internal arm: the `exp` extension.
(@extra $name:ident exp) => {
impl $name {
/// Element-wise `exp`, returned as a new array.
#[inline]
pub fn exp(&self) -> Array1<f64> { self.0.mapv(f64::exp) }
}
};
}
// Coefficient vector wrapper.
array1_f64_newtype!(Coefficients);
// Linear-predictor vector wrapper.
array1_f64_newtype!(LinearPredictor);
// Log-scale smoothing parameters; the `exp` extension maps them element-wise
// off the log scale.
array1_f64_newtype!(LogSmoothingParams, exp);
/// Typed index of a smooth term.
///
/// The raw value `usize::MAX` is reserved for
/// [`SmoothTermIdx::placeholder`].
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct SmoothTermIdx(usize);
impl SmoothTermIdx {
    /// Raw value that marks a placeholder index.
    const PLACEHOLDER_RAW: usize = usize::MAX;
    /// Wraps a raw index. No check is made against the placeholder value.
    #[inline]
    pub const fn new(idx: usize) -> Self {
        SmoothTermIdx(idx)
    }
    /// Placeholder index (raw value `usize::MAX`).
    #[inline]
    pub const fn placeholder() -> Self {
        SmoothTermIdx(Self::PLACEHOLDER_RAW)
    }
    /// Raw index value.
    #[inline]
    pub const fn get(self) -> usize {
        let SmoothTermIdx(raw) = self;
        raw
    }
    /// Whether this is the placeholder.
    #[inline]
    pub const fn is_placeholder(self) -> bool {
        self.get() == Self::PLACEHOLDER_RAW
    }
}
impl std::fmt::Display for SmoothTermIdx {
    /// Prints the raw index, including `usize::MAX` for a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SmoothTermIdx(raw) = *self;
        write!(f, "{raw}")
    }
}
/// Typed index of a penalty, so it cannot be confused with other index kinds.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct PenaltyIdx(usize);
impl PenaltyIdx {
    /// Wraps a raw index.
    #[inline]
    pub const fn new(idx: usize) -> Self {
        PenaltyIdx(idx)
    }
    /// Raw index value.
    #[inline]
    pub const fn get(self) -> usize {
        let PenaltyIdx(raw) = self;
        raw
    }
}
impl std::fmt::Display for PenaltyIdx {
    /// Prints the raw index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let PenaltyIdx(raw) = *self;
        write!(f, "{raw}")
    }
}
/// Typed index of a basis, so it cannot be confused with other index kinds.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BasisIdx(usize);
impl BasisIdx {
    /// Wraps a raw index.
    #[inline]
    pub const fn new(idx: usize) -> Self {
        BasisIdx(idx)
    }
    /// Raw index value.
    #[inline]
    pub const fn get(self) -> usize {
        let BasisIdx(raw) = self;
        raw
    }
}
impl std::fmt::Display for BasisIdx {
    /// Prints the raw index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let BasisIdx(raw) = *self;
        write!(f, "{raw}")
    }
}
/// Typed column index, so it cannot be confused with a row index.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ColIdx(usize);
impl ColIdx {
    /// Wraps a raw index.
    #[inline]
    pub const fn new(idx: usize) -> Self {
        ColIdx(idx)
    }
    /// Raw index value.
    #[inline]
    pub const fn get(self) -> usize {
        let ColIdx(raw) = self;
        raw
    }
}
impl std::fmt::Display for ColIdx {
    /// Prints the raw index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ColIdx(raw) = *self;
        write!(f, "{raw}")
    }
}
/// Typed row index, so it cannot be confused with a column index.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct RowIdx(usize);
impl RowIdx {
    /// Wraps a raw index.
    #[inline]
    pub const fn new(idx: usize) -> Self {
        RowIdx(idx)
    }
    /// Raw index value.
    #[inline]
    pub const fn get(self) -> usize {
        let RowIdx(raw) = self;
        raw
    }
}
impl std::fmt::Display for RowIdx {
    /// Prints the raw index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RowIdx(raw) = *self;
        write!(f, "{raw}")
    }
}
/// Borrowed, read-only view of log smoothing parameters (the view
/// counterpart of `LogSmoothingParams`).
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct LogSmoothingParamsView<'a>(pub ArrayView1<'a, f64>);
impl<'a> LogSmoothingParamsView<'a> {
    /// Wraps an existing view without copying.
    pub fn new(values: ArrayView1<'a, f64>) -> Self {
        LogSmoothingParamsView(values)
    }
    /// Element-wise `exp`, returned as a newly allocated array.
    pub fn exp(&self) -> Array1<f64> {
        self.0.mapv(|log_value| log_value.exp())
    }
}
impl<'a> Deref for LogSmoothingParamsView<'a> {
    type Target = ArrayView1<'a, f64>;
    /// Exposes the underlying view.
    fn deref(&self) -> &Self::Target {
        let Self(view) = self;
        view
    }
}