use ndarray::{Array1, ArrayView1};
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut};
pub use crate::hull::PeeledHull;
/// Settings for a penalized spline "wiggle" component.
///
/// NOTE(review): field meanings below are inferred from names; the code that
/// consumes this config is not visible here — confirm against the basis builder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WigglePenaltyConfig {
/// Spline degree (3 in `cubic_triple_operator_default`).
pub degree: usize,
/// Number of internal knots (8 in the cubic preset).
pub num_internal_knots: usize,
/// Orders of the penalty operators (`[1, 2, 3]` in the cubic preset);
/// presumably one penalty per listed order.
pub penalty_orders: Vec<usize>,
/// Whether a "double penalty" is requested — presumably an extra
/// null-space penalty; TODO confirm.
pub double_penalty: bool,
/// Small tolerance tied to monotonicity (1e-4 in the cubic preset); its
/// exact use is not visible here.
pub monotonicity_eps: f64,
}
impl WigglePenaltyConfig {
pub fn cubic_triple_operator_default() -> Self {
Self {
degree: 3,
num_internal_knots: 8,
penalty_orders: vec![1, 2, 3],
double_penalty: true,
monotonicity_eps: 1e-4,
}
}
}
/// Identifier for the link used by a likelihood family or inverse link.
///
/// `LinkFunction::name` returns the lowercase tag of each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkFunction {
/// Logistic link (`"logit"`).
Logit,
/// Probit link (`"probit"`).
Probit,
/// Complementary log-log link (`"cloglog"`).
CLogLog,
/// `"sas"` link (presumably sinh-arcsinh), parameterized by `SasLinkState`.
Sas,
/// `"beta-logistic"` link; `InverseLink::BetaLogistic` carries a `SasLinkState`.
BetaLogistic,
/// Identity link (`"identity"`), used by the Gaussian and Royston-Parmar families.
Identity,
/// Log link (`"log"`), used by the Poisson and Gamma families.
Log,
}
impl LinkFunction {
#[inline]
pub fn name(self) -> &'static str {
match self {
Self::Logit => "logit",
Self::Probit => "probit",
Self::CLogLog => "cloglog",
Self::Sas => "sas",
Self::BetaLogistic => "beta-logistic",
Self::Identity => "identity",
Self::Log => "log",
}
}
}
/// A single component link that can be blended inside a mixture inverse link
/// (see `MixtureLinkState::components`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkComponent {
/// `"probit"` component.
Probit,
/// `"logit"` component.
Logit,
/// `"cloglog"` component.
CLogLog,
/// `"loglog"` component.
LogLog,
/// `"cauchit"` component.
Cauchit,
}
impl LinkComponent {
#[inline]
pub fn name(self) -> &'static str {
match self {
Self::Probit => "probit",
Self::Logit => "logit",
Self::CLogLog => "cloglog",
Self::LogLog => "loglog",
Self::Cauchit => "cauchit",
}
}
}
/// Initial specification of a blended (mixture) inverse link.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkSpec {
/// Component links to blend.
pub components: Vec<LinkComponent>,
/// Starting values for the mixing parameters `rho`.
/// NOTE(review): the length/ordering convention relative to `components`
/// is not visible here — confirm.
pub initial_rho: Array1<f64>,
}
/// Current parameter state of a blended (mixture) inverse link.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkState {
/// Component links being blended.
pub components: Vec<LinkComponent>,
/// Unconstrained mixing parameters (assumed; confirm with the fitting code).
pub rho: Array1<f64>,
/// Mixture weights, presumably derived from `rho`. This type does not
/// enforce any relationship between `rho` and `pi`.
pub pi: Array1<f64>,
}
/// Initial values for the two-parameter SAS-style link.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkSpec {
/// Starting `epsilon` parameter.
pub initial_epsilon: f64,
/// Starting `delta`, given on the log scale.
pub initial_log_delta: f64,
}
/// Parameter state shared by the `Sas` and `BetaLogistic` inverse links.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkState {
/// `epsilon` parameter.
pub epsilon: f64,
/// Log of `delta`.
pub log_delta: f64,
/// `delta`, presumably equal to `exp(log_delta)`. Both are stored and this
/// type does not enforce their consistency.
pub delta: f64,
}
/// Parameter state for the latent complementary log-log inverse link.
///
/// `LatentCLogLogState::new` validates `latent_sd` (finite and `>= 0`).
/// The field is public, so struct-literal construction bypasses that check.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct LatentCLogLogState {
/// Standard deviation of the latent term. It appears in
/// `InverseLink::saved_string` as `latent-cloglog(sd=...)`.
pub latent_sd: f64,
}
impl LatentCLogLogState {
    /// Creates a state after checking that `latent_sd` is usable.
    ///
    /// # Errors
    /// Returns a descriptive message when `latent_sd` is NaN, infinite, or negative.
    /// `-0.0` is accepted.
    #[inline]
    pub fn new(latent_sd: f64) -> Result<Self, String> {
        let acceptable = latent_sd.is_finite() && latent_sd >= 0.0;
        if acceptable {
            Ok(Self { latent_sd })
        } else {
            Err(format!(
                "latent cloglog standard deviation must be finite and >= 0, got {latent_sd}"
            ))
        }
    }
}
/// A concrete inverse link, possibly carrying link-specific parameter state.
///
/// `InverseLink::link_function` maps each variant to its `LinkFunction` tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InverseLink {
/// A link with no extra parameters.
Standard(LinkFunction),
/// Complementary log-log with a latent standard deviation.
LatentCLogLog(LatentCLogLogState),
/// SAS link with its `epsilon`/`delta` state.
Sas(SasLinkState),
/// Beta-logistic link; reuses `SasLinkState` for its two parameters.
BetaLogistic(SasLinkState),
/// Blend of several component links.
Mixture(MixtureLinkState),
}
impl InverseLink {
#[inline]
pub fn link_function(&self) -> LinkFunction {
match self {
Self::Standard(link) => *link,
Self::LatentCLogLog(_) => LinkFunction::CLogLog,
Self::Sas(_) => LinkFunction::Sas,
Self::BetaLogistic(_) => LinkFunction::BetaLogistic,
Self::Mixture(_) => LinkFunction::Logit,
}
}
#[inline]
pub fn mixture_state(&self) -> Option<&MixtureLinkState> {
match self {
Self::Mixture(state) => Some(state),
_ => None,
}
}
#[inline]
pub fn sas_state(&self) -> Option<&SasLinkState> {
match self {
Self::Sas(state) | Self::BetaLogistic(state) => Some(state),
_ => None,
}
}
#[inline]
pub fn latent_cloglog_state(&self) -> Option<&LatentCLogLogState> {
match self {
Self::LatentCLogLog(state) => Some(state),
_ => None,
}
}
pub fn saved_string(&self) -> String {
match self {
Self::Standard(link) => link.name().to_string(),
Self::LatentCLogLog(state) => format!("latent-cloglog(sd={})", state.latent_sd),
Self::Sas(_) => "sas".to_string(),
Self::BetaLogistic(_) => "beta-logistic".to_string(),
Self::Mixture(state) => {
let names = state
.components
.iter()
.map(|component| component.name())
.collect::<Vec<_>>()
.join(",");
format!("blended({names})")
}
}
}
}
/// Prior placed on `rho` parameters. Its `Default` is
/// `Normal { mean: 0.0, sd: 3.0 }`.
///
/// NOTE(review): which `rho` (mixing vs. smoothing) this applies to is not
/// visible here — confirm with the caller.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RhoPrior {
/// No prior contribution (presumably an improper flat prior).
Flat,
/// Normal prior with the given mean and standard deviation.
Normal { mean: f64, sd: f64 },
}
impl Default for RhoPrior {
fn default() -> Self {
Self::Normal { mean: 0.0, sd: 3.0 }
}
}
/// Response distribution paired with its link.
///
/// Covers the GLM families (see `GlmLikelihoodFamily`) plus the
/// survival-specific Royston-Parmar model and the latent cloglog binomial.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LikelihoodFamily {
/// Gaussian response, identity link.
GaussianIdentity,
/// Binomial response, logit link (the only family with Firth support).
BinomialLogit,
/// Binomial response, probit link.
BinomialProbit,
/// Binomial response, cloglog link.
BinomialCLogLog,
/// Binomial response, latent cloglog link (reports `LinkFunction::CLogLog`).
BinomialLatentCLogLog,
/// Binomial response, SAS link.
BinomialSas,
/// Binomial response, beta-logistic link.
BinomialBetaLogistic,
/// Binomial response, blended inverse link (reports `LinkFunction::Logit`).
BinomialMixture,
/// Poisson response, log link.
PoissonLog,
/// Gamma response, log link.
GammaLog,
/// Royston-Parmar survival model (identity link, unspecified scale).
RoystonParmar,
}
impl LikelihoodFamily {
    /// Link function implied by this family.
    ///
    /// Mixtures report `Logit` and both cloglog variants report `CLogLog`.
    /// This enum carries no link parameters.
    #[inline]
    pub fn link_function(self) -> LinkFunction {
        match self {
            Self::GaussianIdentity => LinkFunction::Identity,
            Self::RoystonParmar => LinkFunction::Identity,
            Self::PoissonLog => LinkFunction::Log,
            Self::GammaLog => LinkFunction::Log,
            Self::BinomialLogit => LinkFunction::Logit,
            Self::BinomialMixture => LinkFunction::Logit,
            Self::BinomialProbit => LinkFunction::Probit,
            Self::BinomialCLogLog => LinkFunction::CLogLog,
            Self::BinomialLatentCLogLog => LinkFunction::CLogLog,
            Self::BinomialSas => LinkFunction::Sas,
            Self::BinomialBetaLogistic => LinkFunction::BetaLogistic,
        }
    }

    /// `(machine tag, human-readable label)` for this family.
    ///
    /// Keeping both in one table keeps them in sync.
    fn name_pair(self) -> (&'static str, &'static str) {
        match self {
            Self::GaussianIdentity => ("gaussian", "Gaussian Identity"),
            Self::BinomialLogit => ("binomial-logit", "Binomial Logit"),
            Self::BinomialProbit => ("binomial-probit", "Binomial Probit"),
            Self::BinomialCLogLog => ("binomial-cloglog", "Binomial CLogLog"),
            Self::BinomialLatentCLogLog => ("latent-cloglog-binomial", "Latent CLogLog Binomial"),
            Self::BinomialSas => ("binomial-sas", "Binomial SAS"),
            Self::BinomialBetaLogistic => ("binomial-beta-logistic", "Binomial Beta-Logistic"),
            Self::BinomialMixture => (
                "binomial-blended-inverse-link",
                "Binomial Blended Inverse-Link",
            ),
            Self::PoissonLog => ("poisson-log", "Poisson Log"),
            Self::GammaLog => ("gamma-log", "Gamma Log"),
            Self::RoystonParmar => ("royston-parmar", "Royston Parmar"),
        }
    }

    /// Machine-readable tag, e.g. `"binomial-logit"`.
    #[inline]
    pub fn name(self) -> &'static str {
        self.name_pair().0
    }

    /// Human-readable label, e.g. `"Binomial Logit"`.
    #[inline]
    pub fn pretty_name(self) -> &'static str {
        self.name_pair().1
    }

    /// Whether Firth bias reduction is available. Only `BinomialLogit` supports it.
    #[inline]
    pub fn supports_firth(self) -> bool {
        self == Self::BinomialLogit
    }

    /// Whether the response is binomial, whatever the link.
    #[inline]
    pub(crate) fn is_binomial(self) -> bool {
        match self {
            Self::BinomialLogit
            | Self::BinomialProbit
            | Self::BinomialCLogLog
            | Self::BinomialLatentCLogLog
            | Self::BinomialSas
            | Self::BinomialBetaLogistic
            | Self::BinomialMixture => true,
            Self::GaussianIdentity | Self::PoissonLog | Self::GammaLog | Self::RoystonParmar => {
                false
            }
        }
    }

    /// Default scale handling for this family:
    /// - Gaussian: profiled;
    /// - Gamma: estimated shape starting at 1.0;
    /// - binomial/Poisson: `phi = 1.0`;
    /// - Royston-Parmar: unspecified.
    #[inline]
    pub fn default_scale_metadata(self) -> LikelihoodScaleMetadata {
        let unit_phi = LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 };
        match self {
            Self::GaussianIdentity => LikelihoodScaleMetadata::ProfiledGaussian,
            Self::GammaLog => LikelihoodScaleMetadata::EstimatedGammaShape { shape: 1.0 },
            Self::RoystonParmar => LikelihoodScaleMetadata::Unspecified,
            Self::PoissonLog
            | Self::BinomialLogit
            | Self::BinomialProbit
            | Self::BinomialCLogLog
            | Self::BinomialLatentCLogLog
            | Self::BinomialSas
            | Self::BinomialBetaLogistic
            | Self::BinomialMixture => unit_phi,
        }
    }
}
/// The subset of `LikelihoodFamily` that are GLM likelihoods.
///
/// It has no `BinomialLatentCLogLog` or `RoystonParmar`. Converting from
/// `LikelihoodFamily` maps the former to `BinomialCLogLog` and rejects the latter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GlmLikelihoodFamily {
/// Gaussian response, identity link.
GaussianIdentity,
/// Binomial response, logit link.
BinomialLogit,
/// Binomial response, probit link.
BinomialProbit,
/// Binomial response, cloglog link.
BinomialCLogLog,
/// Binomial response, SAS link.
BinomialSas,
/// Binomial response, beta-logistic link.
BinomialBetaLogistic,
/// Binomial response, blended inverse link.
BinomialMixture,
/// Poisson response, log link.
PoissonLog,
/// Gamma response, log link.
GammaLog,
}
impl GlmLikelihoodFamily {
#[inline]
pub fn link_function(self) -> LinkFunction {
LikelihoodFamily::from(self).link_function()
}
#[inline]
pub fn supports_firth(self) -> bool {
LikelihoodFamily::from(self).supports_firth()
}
}
/// How the scale/dispersion of a likelihood is handled.
///
/// `fixed_phi` returns:
/// - `phi` for `FixedDispersion`;
/// - `1 / shape` for both gamma-shape variants;
/// - nothing otherwise.
///
/// `gamma_shape` returns the shape for the gamma variants only.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LikelihoodScaleMetadata {
/// Gaussian scale handled by profiling; no fixed `phi` is reported.
ProfiledGaussian,
/// Dispersion held at `phi` (1.0 by default for binomial and Poisson families).
FixedDispersion { phi: f64 },
/// Gamma shape held fixed at `shape`.
FixedGammaShape { shape: f64 },
/// Gamma shape that is estimated; `shape` is presumably the current or
/// initial value (1.0 by default).
EstimatedGammaShape { shape: f64 },
/// No scale information (the default for `RoystonParmar`).
Unspecified,
}
impl LikelihoodScaleMetadata {
    /// Dispersion implied by the metadata.
    ///
    /// - `FixedDispersion`: `phi` as stored.
    /// - Gamma-shape variants: `1 / shape`.
    /// - `ProfiledGaussian` / `Unspecified`: `None`.
    #[inline]
    pub fn fixed_phi(self) -> Option<f64> {
        if let Self::FixedDispersion { phi } = self {
            return Some(phi);
        }
        self.gamma_shape().map(|shape| 1.0 / shape)
    }

    /// Gamma shape, for either the fixed or the estimated gamma variant.
    #[inline]
    pub fn gamma_shape(self) -> Option<f64> {
        match self {
            Self::FixedGammaShape { shape } => Some(shape),
            Self::EstimatedGammaShape { shape } => Some(shape),
            Self::ProfiledGaussian | Self::FixedDispersion { .. } | Self::Unspecified => None,
        }
    }

    /// True only for `EstimatedGammaShape`.
    #[inline]
    pub fn gamma_shape_is_estimated(self) -> bool {
        matches!(self, Self::EstimatedGammaShape { .. })
    }
}
/// Which normalizing constants a reported log-likelihood includes.
///
/// NOTE(review): the semantics below are inferred from the variant names —
/// confirm against the code that evaluates the log-likelihood.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LogLikelihoodNormalization {
/// Presumably all constants, giving a fully normalized log-likelihood.
Full,
/// Presumably drops terms that depend only on the response.
OmittingResponseConstants,
/// Presumably normalization supplied by the caller.
UserProvided,
}
/// A GLM family together with its scale metadata.
///
/// `GlmLikelihoodSpec::canonical` builds one with default scale handling.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct GlmLikelihoodSpec {
/// Response family and link.
pub family: GlmLikelihoodFamily,
/// Dispersion / gamma-shape handling.
pub scale: LikelihoodScaleMetadata,
}
impl GlmLikelihoodSpec {
#[inline]
pub fn canonical(family: GlmLikelihoodFamily) -> Self {
let scale = match family {
GlmLikelihoodFamily::GaussianIdentity => LikelihoodScaleMetadata::ProfiledGaussian,
GlmLikelihoodFamily::GammaLog => {
LikelihoodScaleMetadata::EstimatedGammaShape { shape: 1.0 }
}
GlmLikelihoodFamily::BinomialLogit
| GlmLikelihoodFamily::BinomialProbit
| GlmLikelihoodFamily::BinomialCLogLog
| GlmLikelihoodFamily::BinomialSas
| GlmLikelihoodFamily::BinomialBetaLogistic
| GlmLikelihoodFamily::BinomialMixture
| GlmLikelihoodFamily::PoissonLog => {
LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 }
}
};
Self { family, scale }
}
#[inline]
pub fn link_function(self) -> LinkFunction {
self.family.link_function()
}
#[inline]
pub fn response_family(self) -> LikelihoodFamily {
self.family.into()
}
#[inline]
pub fn fixed_phi(self) -> Option<f64> {
self.scale.fixed_phi()
}
#[inline]
pub fn gamma_shape(self) -> Option<f64> {
self.scale.gamma_shape()
}
#[inline]
pub fn with_gamma_shape(mut self, shape: f64) -> Self {
self.scale = match self.scale {
LikelihoodScaleMetadata::FixedGammaShape { .. } => {
LikelihoodScaleMetadata::FixedGammaShape { shape }
}
LikelihoodScaleMetadata::EstimatedGammaShape { .. } => {
LikelihoodScaleMetadata::EstimatedGammaShape { shape }
}
_ if self.family == GlmLikelihoodFamily::GammaLog => {
LikelihoodScaleMetadata::EstimatedGammaShape { shape }
}
other => other,
};
self
}
}
impl TryFrom<LikelihoodFamily> for GlmLikelihoodFamily {
type Error = &'static str;
fn try_from(value: LikelihoodFamily) -> Result<Self, Self::Error> {
match value {
LikelihoodFamily::GaussianIdentity => Ok(Self::GaussianIdentity),
LikelihoodFamily::BinomialLogit => Ok(Self::BinomialLogit),
LikelihoodFamily::BinomialProbit => Ok(Self::BinomialProbit),
LikelihoodFamily::BinomialCLogLog | LikelihoodFamily::BinomialLatentCLogLog => {
Ok(Self::BinomialCLogLog)
}
LikelihoodFamily::BinomialSas => Ok(Self::BinomialSas),
LikelihoodFamily::BinomialBetaLogistic => Ok(Self::BinomialBetaLogistic),
LikelihoodFamily::BinomialMixture => Ok(Self::BinomialMixture),
LikelihoodFamily::PoissonLog => Ok(Self::PoissonLog),
LikelihoodFamily::GammaLog => Ok(Self::GammaLog),
LikelihoodFamily::RoystonParmar => {
Err("RoystonParmar is survival-specific and not a GLM likelihood")
}
}
}
}
impl From<GlmLikelihoodFamily> for LikelihoodFamily {
fn from(value: GlmLikelihoodFamily) -> Self {
match value {
GlmLikelihoodFamily::GaussianIdentity => Self::GaussianIdentity,
GlmLikelihoodFamily::BinomialLogit => Self::BinomialLogit,
GlmLikelihoodFamily::BinomialProbit => Self::BinomialProbit,
GlmLikelihoodFamily::BinomialCLogLog => Self::BinomialCLogLog,
GlmLikelihoodFamily::BinomialSas => Self::BinomialSas,
GlmLikelihoodFamily::BinomialBetaLogistic => Self::BinomialBetaLogistic,
GlmLikelihoodFamily::BinomialMixture => Self::BinomialMixture,
GlmLikelihoodFamily::PoissonLog => Self::PoissonLog,
GlmLikelihoodFamily::GammaLog => Self::GammaLog,
}
}
}
/// How the ridge-affected log-determinant is evaluated.
///
/// NOTE(review): semantics are inferred from names. `Full` presumably uses
/// the full determinant, `PositivePart` only the positive eigenvalues, and
/// `Auto` picks between them. Confirm against the log-det implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RidgeDeterminantMode {
/// Mode chosen automatically (used by `RidgePolicy::explicit_stabilization_full`).
Auto,
/// Full determinant (used by `explicit_stabilization_full_exact`).
Full,
/// Positive-part determinant (used by `explicit_stabilization_pospart`).
PositivePart,
}
/// Structure of the ridge matrix. Only a scaled identity (presumably
/// `delta * I`) is currently representable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RidgeMatrixForm {
/// Ridge proportional to the identity matrix.
ScaledIdentity,
}
/// Which parts of the objective an explicit ridge enters, and how its
/// log-determinant is computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct RidgePolicy {
/// Whether the ridge is independent of the smoothing parameters `rho`
/// (meaning inferred from the name — confirm).
pub rho_independent: bool,
/// Include the ridge in the quadratic penalty term.
pub include_quadratic_penalty: bool,
/// Include the ridge in the penalty log-determinant.
/// `RidgePassport::penalty_logdet_ridge` returns `delta` only when this is set.
pub include_penalty_logdet: bool,
/// Include the ridge in the Laplace Hessian.
/// `RidgePassport::laplacehessianridge` returns `delta` only when this is set.
pub include_laplacehessian: bool,
/// Log-determinant evaluation mode.
pub determinant_mode: RidgeDeterminantMode,
}
impl RidgePolicy {
pub fn explicit_stabilization_full() -> Self {
Self {
rho_independent: true,
include_quadratic_penalty: true,
include_penalty_logdet: true,
include_laplacehessian: true,
determinant_mode: RidgeDeterminantMode::Auto,
}
}
pub fn explicit_stabilization_full_exact() -> Self {
Self {
determinant_mode: RidgeDeterminantMode::Full,
..Self::explicit_stabilization_full()
}
}
pub fn explicit_stabilization_pospart() -> Self {
Self {
determinant_mode: RidgeDeterminantMode::PositivePart,
..Self::explicit_stabilization_full()
}
}
}
/// A concrete ridge (magnitude and matrix form) plus the policy saying where
/// it is included.
///
/// `StabilizationLedger::from_passport` converts it into a ledger entry.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RidgePassport {
/// Ridge magnitude.
pub delta: f64,
/// Ridge matrix structure.
pub matrix_form: RidgeMatrixForm,
/// Where the ridge is included and how its log-det is computed.
pub policy: RidgePolicy,
}
impl RidgePassport {
    /// Passport for a scaled-identity ridge of magnitude `delta` under `policy`.
    pub fn scaled_identity(delta: f64, policy: RidgePolicy) -> Self {
        let matrix_form = RidgeMatrixForm::ScaledIdentity;
        Self {
            delta,
            matrix_form,
            policy,
        }
    }

    /// Ridge contribution to the penalty log-determinant.
    ///
    /// Returns `delta` when the policy includes it, else `0.0`.
    #[inline]
    pub fn penalty_logdet_ridge(self) -> f64 {
        if !self.policy.include_penalty_logdet {
            return 0.0;
        }
        self.delta
    }

    /// Ridge contribution to the Laplace Hessian.
    ///
    /// Returns `delta` when the policy includes it, else `0.0`.
    #[inline]
    pub fn laplacehessianridge(self) -> f64 {
        if !self.policy.include_laplacehessian {
            return 0.0;
        }
        self.delta
    }
}
/// Inertia of a symmetric matrix: counts of positive, zero and negative
/// eigenvalues (assumed from the field names; how they are computed is not
/// visible here).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Inertia {
/// Number of positive eigenvalues.
pub positive: usize,
/// Number of zero eigenvalues.
pub zero: usize,
/// Number of negative eigenvalues.
pub negative: usize,
}
/// How a stabilization `delta` was chosen (stored in `StabilizationLedger::chosen_by`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationRule {
/// A preset constant. Used by `StabilizationLedger::none` and `from_passport`.
FixedConstant,
/// Chosen to reach a target inertia. `spd_floor` is presumably the
/// eigenvalue/pivot floor — confirm.
InertiaTarget { spd_floor: f64 },
/// Chosen by a heuristic.
Heuristic,
/// Supplied by the user. Used by `StabilizationLedger::explicit_prior`.
UserSpecified,
/// Found by escalating over `attempts` retries (inferred from the name).
BackoffEscalation { attempts: usize },
}
/// The role a stabilization `delta` plays.
///
/// `StabilizationLedger::invariants_hold` requires:
/// - `None`: zero delta and no inclusion;
/// - `SolverDampingOnly` / `NumericalPerturbation`: no inclusion;
/// - `ExplicitPrior`: inclusion in the quadratic term, Laplace Hessian and
///   penalty log-det.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationKind {
/// No stabilization applied.
None,
/// Damping used only inside the solver.
SolverDampingOnly,
/// Numerical perturbation, optionally with a backward-error bound.
NumericalPerturbation {
backward_error_bound: Option<f64>,
},
/// Ridge treated as part of the model.
ExplicitPrior,
}
/// Record of a stabilization `delta` and which objective terms it entered.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StabilizationLedger {
/// Role of the stabilization.
pub kind: StabilizationKind,
/// Stabilization magnitude.
pub delta: f64,
/// Matrix structure of the stabilization.
pub matrix_form: RidgeMatrixForm,
/// How `delta` was chosen.
pub chosen_by: StabilizationRule,
/// Inertia before stabilization. All constructors in this file leave it `None`.
pub inertia_before: Option<Inertia>,
/// Inertia after stabilization. All constructors in this file leave it `None`.
pub inertia_after: Option<Inertia>,
/// `delta` enters the quadratic term (see `quadratic_delta`).
pub included_in_quadratic: bool,
/// `delta` enters the Laplace Hessian (see `laplace_hessian_delta`).
pub included_in_laplace_hessian: bool,
/// `delta` enters the penalty log-det (see `penalty_logdet_delta`).
pub included_in_penalty_logdet: bool,
}
impl StabilizationLedger {
pub fn none() -> Self {
Self {
kind: StabilizationKind::None,
delta: 0.0,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
included_in_quadratic: false,
included_in_laplace_hessian: false,
included_in_penalty_logdet: false,
}
}
pub fn solver_damping(delta: f64, chosen_by: StabilizationRule) -> Self {
Self {
kind: StabilizationKind::SolverDampingOnly,
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
included_in_quadratic: false,
included_in_laplace_hessian: false,
included_in_penalty_logdet: false,
}
}
pub fn numerical_perturbation(
delta: f64,
chosen_by: StabilizationRule,
backward_error_bound: Option<f64>,
) -> Self {
Self {
kind: StabilizationKind::NumericalPerturbation {
backward_error_bound,
},
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
included_in_quadratic: false,
included_in_laplace_hessian: false,
included_in_penalty_logdet: false,
}
}
pub fn explicit_prior(delta: f64, matrix_form: RidgeMatrixForm) -> Self {
Self {
kind: StabilizationKind::ExplicitPrior,
delta,
matrix_form,
chosen_by: StabilizationRule::UserSpecified,
inertia_before: None,
inertia_after: None,
included_in_quadratic: true,
included_in_laplace_hessian: true,
included_in_penalty_logdet: true,
}
}
pub fn from_passport(passport: RidgePassport) -> Self {
let any_included = passport.policy.include_quadratic_penalty
|| passport.policy.include_laplacehessian
|| passport.policy.include_penalty_logdet;
let kind = if !any_included {
StabilizationKind::NumericalPerturbation {
backward_error_bound: None,
}
} else {
StabilizationKind::ExplicitPrior
};
Self {
kind,
delta: passport.delta,
matrix_form: passport.matrix_form,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
included_in_quadratic: passport.policy.include_quadratic_penalty,
included_in_laplace_hessian: passport.policy.include_laplacehessian,
included_in_penalty_logdet: passport.policy.include_penalty_logdet,
}
}
#[inline]
pub fn quadratic_delta(&self) -> f64 {
if self.included_in_quadratic {
self.delta
} else {
0.0
}
}
#[inline]
pub fn laplace_hessian_delta(&self) -> f64 {
if self.included_in_laplace_hessian {
self.delta
} else {
0.0
}
}
#[inline]
pub fn penalty_logdet_delta(&self) -> f64 {
if self.included_in_penalty_logdet {
self.delta
} else {
0.0
}
}
pub fn invariants_hold(&self) -> bool {
match self.kind {
StabilizationKind::None => {
self.delta == 0.0
&& !self.included_in_quadratic
&& !self.included_in_laplace_hessian
&& !self.included_in_penalty_logdet
}
StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => {
!self.included_in_quadratic
&& !self.included_in_laplace_hessian
&& !self.included_in_penalty_logdet
}
StabilizationKind::ExplicitPrior => {
self.included_in_quadratic
&& self.included_in_laplace_hessian
&& self.included_in_penalty_logdet
}
}
}
}
/// Transparent newtype over `Array1<f64>` holding a coefficient vector.
///
/// `Deref`/`DerefMut` expose the full `Array1` API on the wrapper.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq)]
pub struct Coefficients(pub Array1<f64>);

impl Coefficients {
    /// Wraps `values` (moved, not copied).
    pub fn new(values: Array1<f64>) -> Self {
        values.into()
    }

    /// Zero-filled vector of length `len`.
    pub fn zeros(len: usize) -> Self {
        Self::new(Array1::zeros(len))
    }
}

impl Deref for Coefficients {
    type Target = Array1<f64>;

    fn deref(&self) -> &Array1<f64> {
        let Self(inner) = self;
        inner
    }
}

impl DerefMut for Coefficients {
    fn deref_mut(&mut self) -> &mut Array1<f64> {
        let Self(inner) = self;
        inner
    }
}

impl AsRef<Array1<f64>> for Coefficients {
    fn as_ref(&self) -> &Array1<f64> {
        self.deref()
    }
}

impl From<Array1<f64>> for Coefficients {
    fn from(values: Array1<f64>) -> Self {
        Coefficients(values)
    }
}

impl From<Coefficients> for Array1<f64> {
    fn from(values: Coefficients) -> Self {
        let Coefficients(inner) = values;
        inner
    }
}
/// Transparent newtype over `Array1<f64>` holding linear-predictor values.
///
/// `Deref`/`DerefMut` expose the full `Array1` API on the wrapper.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq)]
pub struct LinearPredictor(pub Array1<f64>);

impl LinearPredictor {
    /// Wraps `values` (moved, not copied).
    pub fn new(values: Array1<f64>) -> Self {
        values.into()
    }

    /// Zero-filled vector of length `len`.
    pub fn zeros(len: usize) -> Self {
        Self::new(Array1::zeros(len))
    }
}

impl Deref for LinearPredictor {
    type Target = Array1<f64>;

    fn deref(&self) -> &Array1<f64> {
        let Self(inner) = self;
        inner
    }
}

impl DerefMut for LinearPredictor {
    fn deref_mut(&mut self) -> &mut Array1<f64> {
        let Self(inner) = self;
        inner
    }
}

impl AsRef<Array1<f64>> for LinearPredictor {
    fn as_ref(&self) -> &Array1<f64> {
        self.deref()
    }
}

impl From<Array1<f64>> for LinearPredictor {
    fn from(values: Array1<f64>) -> Self {
        LinearPredictor(values)
    }
}

impl From<LinearPredictor> for Array1<f64> {
    fn from(values: LinearPredictor) -> Self {
        let LinearPredictor(inner) = values;
        inner
    }
}
/// Transparent newtype over `Array1<f64>` holding smoothing parameters on the
/// log scale.
///
/// `Deref`/`DerefMut` expose the full `Array1` API on the wrapper.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq)]
pub struct LogSmoothingParams(pub Array1<f64>);

impl LogSmoothingParams {
    /// Wraps `values` (moved, not copied).
    pub fn new(values: Array1<f64>) -> Self {
        Self(values)
    }
}

impl Deref for LogSmoothingParams {
    type Target = Array1<f64>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for LogSmoothingParams {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

// Mirrors the `AsRef<Array1<f64>>` impls on `Coefficients` and
// `LinearPredictor`, so generic `AsRef<Array1<f64>>` consumers accept all
// three newtypes.
impl AsRef<Array1<f64>> for LogSmoothingParams {
    fn as_ref(&self) -> &Array1<f64> {
        &self.0
    }
}

impl From<Array1<f64>> for LogSmoothingParams {
    fn from(values: Array1<f64>) -> Self {
        Self(values)
    }
}

impl From<LogSmoothingParams> for Array1<f64> {
    fn from(values: LogSmoothingParams) -> Self {
        values.0
    }
}
/// Borrowed, `Copy` view of log-scale smoothing parameters.
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct LogSmoothingParamsView<'a>(pub ArrayView1<'a, f64>);

impl<'a> LogSmoothingParamsView<'a> {
    /// Wraps an existing view.
    pub fn new(values: ArrayView1<'a, f64>) -> Self {
        LogSmoothingParamsView(values)
    }

    /// Element-wise `exp`, i.e. the parameters mapped back off the log scale,
    /// returned as a new owned array.
    pub fn exp(&self) -> Array1<f64> {
        self.0.mapv(|log_value| log_value.exp())
    }
}

impl<'a> Deref for LogSmoothingParamsView<'a> {
    type Target = ArrayView1<'a, f64>;

    fn deref(&self) -> &ArrayView1<'a, f64> {
        let Self(inner) = self;
        inner
    }
}