use ndarray::{Array1, ArrayView1};
use serde::{Deserialize, Serialize};
use std::ops::{Deref, DerefMut};
pub use crate::hull::PeeledHull;
/// Settings for a wiggle penalty.
///
/// NOTE(review): the field descriptions below are inferred from the field
/// names and from the values set in `cubic_triple_operator_default`; confirm
/// them against the code that consumes this configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WigglePenaltyConfig {
/// Degree setting; presumably the spline basis degree (3 in the default).
pub degree: usize,
/// Number of internal knots (8 in the default).
pub num_internal_knots: usize,
/// Penalty orders to apply (`[1, 2, 3]` in the default).
pub penalty_orders: Vec<usize>,
/// Whether a double penalty is enabled (`true` in the default).
pub double_penalty: bool,
/// Epsilon associated with monotonicity (`1e-4` in the default); how it is
/// applied is not visible in this file.
pub monotonicity_eps: f64,
}
impl WigglePenaltyConfig {
pub fn cubic_triple_operator_default() -> Self {
Self {
degree: 3,
num_internal_knots: 8,
penalty_orders: vec![1, 2, 3],
double_penalty: true,
monotonicity_eps: 1e-4,
}
}
}
/// Identifier for a link function.
///
/// `LinkFunction::name` maps each variant to a lowercase label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkFunction {
/// Labelled `"logit"`.
Logit,
/// Labelled `"probit"`.
Probit,
/// Labelled `"cloglog"`.
CLogLog,
/// Labelled `"sas"`.
Sas,
/// Labelled `"beta-logistic"`.
BetaLogistic,
/// Labelled `"identity"`.
Identity,
/// Labelled `"log"`.
Log,
}
impl LinkFunction {
#[inline]
pub const fn name(self) -> &'static str {
match self {
Self::Logit => "logit",
Self::Probit => "probit",
Self::CLogLog => "cloglog",
Self::Sas => "sas",
Self::BetaLogistic => "beta-logistic",
Self::Identity => "identity",
Self::Log => "log",
}
}
}
/// Identifier for one component link; used as the element type of the
/// `components` lists in `MixtureLinkSpec` and `MixtureLinkState`.
///
/// `LinkComponent::name` maps each variant to a lowercase label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LinkComponent {
/// Labelled `"probit"`.
Probit,
/// Labelled `"logit"`.
Logit,
/// Labelled `"cloglog"`.
CLogLog,
/// Labelled `"loglog"`.
LogLog,
/// Labelled `"cauchit"`.
Cauchit,
}
impl LinkComponent {
#[inline]
pub const fn name(self) -> &'static str {
match self {
Self::Probit => "probit",
Self::Logit => "logit",
Self::CLogLog => "cloglog",
Self::LogLog => "loglog",
Self::Cauchit => "cauchit",
}
}
}
/// Specification of a mixture link: its component links and an initial
/// `rho` vector.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkSpec {
/// Component links of the mixture.
pub components: Vec<LinkComponent>,
/// Initial `rho` values.
/// NOTE(review): the expected length relative to `components` is not
/// visible in this file — confirm.
pub initial_rho: Array1<f64>,
}
/// State of a mixture link: component links plus the `rho` and `pi` vectors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MixtureLinkState {
/// Component links of the mixture.
pub components: Vec<LinkComponent>,
/// Current `rho` values.
pub rho: Array1<f64>,
/// `pi` values.
/// NOTE(review): presumably the mixture weights derived from `rho`; the
/// mapping is not visible in this file — confirm.
pub pi: Array1<f64>,
}
/// Initial parameter values for a SAS-type link.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkSpec {
/// Initial value of `epsilon`.
pub initial_epsilon: f64,
/// Initial value of `log_delta`.
pub initial_log_delta: f64,
}
/// Parameter state shared by the `InverseLink::Sas` and
/// `InverseLink::BetaLogistic` variants.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SasLinkState {
/// The `epsilon` parameter.
pub epsilon: f64,
/// The `log_delta` parameter.
pub log_delta: f64,
/// The `delta` parameter.
/// NOTE(review): presumably `exp(log_delta)`; nothing in this file enforces
/// that relation — confirm.
pub delta: f64,
}
/// State carried by the latent cloglog link.
///
/// `LatentCLogLogState::new` validates the value; the field is public, so
/// direct construction bypasses that check.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct LatentCLogLogState {
/// Latent standard deviation; `new` requires it to be finite and `>= 0`.
pub latent_sd: f64,
}
impl LatentCLogLogState {
#[inline]
pub fn new(latent_sd: f64) -> Result<Self, String> {
if !latent_sd.is_finite() || latent_sd < 0.0 {
return Err(format!(
"latent cloglog standard deviation must be finite and >= 0, got {latent_sd}"
));
}
Ok(Self { latent_sd })
}
}
/// An inverse link together with any state it carries.
///
/// `InverseLink::link_function` reports the `LinkFunction` associated with
/// each variant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum InverseLink {
/// A link identified by its `LinkFunction` alone, with no extra state.
Standard(LinkFunction),
/// Latent cloglog link; reported as `LinkFunction::CLogLog`.
LatentCLogLog(LatentCLogLogState),
/// SAS link; reported as `LinkFunction::Sas`.
Sas(SasLinkState),
/// Beta-logistic link; reuses `SasLinkState` for its parameters and is
/// reported as `LinkFunction::BetaLogistic`.
BetaLogistic(SasLinkState),
/// Mixture link; reported as `LinkFunction::Logit`.
Mixture(MixtureLinkState),
}
impl InverseLink {
#[inline]
pub const fn link_function(&self) -> LinkFunction {
match self {
Self::Standard(link) => *link,
Self::LatentCLogLog(_) => LinkFunction::CLogLog,
Self::Sas(_) => LinkFunction::Sas,
Self::BetaLogistic(_) => LinkFunction::BetaLogistic,
Self::Mixture(_) => LinkFunction::Logit,
}
}
#[inline]
pub const fn mixture_state(&self) -> Option<&MixtureLinkState> {
match self {
Self::Mixture(state) => Some(state),
_ => None,
}
}
#[inline]
pub const fn sas_state(&self) -> Option<&SasLinkState> {
match self {
Self::Sas(state) | Self::BetaLogistic(state) => Some(state),
_ => None,
}
}
#[inline]
pub const fn latent_cloglog_state(&self) -> Option<&LatentCLogLogState> {
match self {
Self::LatentCLogLog(state) => Some(state),
_ => None,
}
}
}
/// Prior specification for `rho`.
///
/// The `Default` implementation yields `Normal { mean: 0.0, sd: 3.0 }`.
/// NOTE(review): what `rho` parameterizes and how each prior is evaluated is
/// not visible in this file — confirm against the consumer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RhoPrior {
/// Flat prior with no parameters.
Flat,
/// Normal prior described by a mean and a standard deviation.
Normal {
mean: f64,
sd: f64,
},
/// Gamma prior described by a shape and a rate.
/// NOTE(review): the name suggests it is placed on a precision — confirm.
GammaPrecision {
shape: f64,
rate: f64,
},
/// A list of priors.
/// NOTE(review): presumably one prior per `rho` coordinate — confirm.
Independent(Vec<RhoPrior>),
}
impl Default for RhoPrior {
fn default() -> Self {
Self::Normal { mean: 0.0, sd: 3.0 }
}
}
/// Response distribution family, including any family-specific parameter.
///
/// `ResponseFamily::name` maps each variant to a lowercase label.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ResponseFamily {
/// Labelled `"gaussian"`.
Gaussian,
/// Labelled `"binomial"`.
Binomial,
/// Labelled `"poisson"`.
Poisson,
/// Labelled `"tweedie"`; carries the parameter `p`.
/// NOTE(review): presumably the Tweedie power, cf. `is_valid_tweedie_power`
/// — nothing in this file validates it on construction.
Tweedie { p: f64 },
/// Labelled `"negative-binomial"`; carries the parameter `theta`.
NegativeBinomial { theta: f64 },
/// Labelled `"beta"`; carries `phi`, which `LikelihoodSpec` reports as the
/// fixed dispersion.
Beta { phi: f64 },
/// Labelled `"gamma"`.
Gamma,
/// Labelled `"royston-parmar"`.
RoystonParmar,
}
impl ResponseFamily {
#[inline]
pub const fn name(&self) -> &'static str {
match self {
Self::Gaussian => "gaussian",
Self::Binomial => "binomial",
Self::Poisson => "poisson",
Self::Tweedie { .. } => "tweedie",
Self::NegativeBinomial { .. } => "negative-binomial",
Self::Beta { .. } => "beta",
Self::Gamma => "gamma",
Self::RoystonParmar => "royston-parmar",
}
}
}
/// A likelihood described by a response family paired with an inverse link.
///
/// Both fields are public and no constructor checks that the pairing is
/// meaningful.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LikelihoodSpec {
/// Response distribution family.
pub response: ResponseFamily,
/// Inverse link paired with the response family.
pub link: InverseLink,
}
impl LikelihoodSpec {
#[inline]
pub const fn new(response: ResponseFamily, link: InverseLink) -> Self {
Self { response, link }
}
#[inline]
pub const fn gaussian_identity() -> Self {
Self::new(
ResponseFamily::Gaussian,
InverseLink::Standard(LinkFunction::Identity),
)
}
#[inline]
pub const fn binomial_logit() -> Self {
Self::new(
ResponseFamily::Binomial,
InverseLink::Standard(LinkFunction::Logit),
)
}
#[inline]
pub const fn binomial_probit() -> Self {
Self::new(
ResponseFamily::Binomial,
InverseLink::Standard(LinkFunction::Probit),
)
}
#[inline]
pub const fn binomial_cloglog() -> Self {
Self::new(
ResponseFamily::Binomial,
InverseLink::Standard(LinkFunction::CLogLog),
)
}
#[inline]
pub const fn binomial_latent_cloglog(state: LatentCLogLogState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::LatentCLogLog(state))
}
#[inline]
pub const fn binomial_sas(state: SasLinkState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::Sas(state))
}
#[inline]
pub const fn binomial_beta_logistic(state: SasLinkState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::BetaLogistic(state))
}
#[inline]
pub fn binomial_mixture(state: MixtureLinkState) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::Mixture(state))
}
#[inline]
pub const fn binomial_link(link: LinkFunction) -> Self {
Self::new(ResponseFamily::Binomial, InverseLink::Standard(link))
}
#[inline]
pub const fn poisson_log() -> Self {
Self::new(
ResponseFamily::Poisson,
InverseLink::Standard(LinkFunction::Log),
)
}
#[inline]
pub const fn tweedie_log(p: f64) -> Self {
Self::new(
ResponseFamily::Tweedie { p },
InverseLink::Standard(LinkFunction::Log),
)
}
#[inline]
pub const fn negative_binomial_log(theta: f64) -> Self {
Self::new(
ResponseFamily::NegativeBinomial { theta },
InverseLink::Standard(LinkFunction::Log),
)
}
#[inline]
pub const fn beta_logit(phi: f64) -> Self {
Self::new(
ResponseFamily::Beta { phi },
InverseLink::Standard(LinkFunction::Logit),
)
}
#[inline]
pub const fn gamma_log() -> Self {
Self::new(
ResponseFamily::Gamma,
InverseLink::Standard(LinkFunction::Log),
)
}
#[inline]
pub const fn royston_parmar() -> Self {
Self::new(
ResponseFamily::RoystonParmar,
InverseLink::Standard(LinkFunction::Identity),
)
}
#[inline]
pub const fn link_function(&self) -> LinkFunction {
self.link.link_function()
}
#[inline]
pub const fn is_binomial(&self) -> bool {
matches!(self.response, ResponseFamily::Binomial)
}
#[inline]
pub const fn is_gaussian_identity(&self) -> bool {
matches!(self.response, ResponseFamily::Gaussian)
&& matches!(self.link, InverseLink::Standard(LinkFunction::Identity))
}
#[inline]
pub const fn is_royston_parmar(&self) -> bool {
matches!(self.response, ResponseFamily::RoystonParmar)
}
#[inline]
pub const fn is_latent_cloglog(&self) -> bool {
matches!(self.response, ResponseFamily::Binomial)
&& matches!(self.link, InverseLink::LatentCLogLog(_))
}
#[inline]
pub const fn is_binomial_mixture(&self) -> bool {
matches!(self.response, ResponseFamily::Binomial)
&& matches!(self.link, InverseLink::Mixture(_))
}
#[inline]
pub const fn is_binomial_sas(&self) -> bool {
matches!(self.response, ResponseFamily::Binomial)
&& matches!(self.link, InverseLink::Sas(_))
}
#[inline]
pub const fn is_binomial_beta_logistic(&self) -> bool {
matches!(self.response, ResponseFamily::Binomial)
&& matches!(self.link, InverseLink::BetaLogistic(_))
}
#[inline]
pub fn default_scale_metadata(&self) -> LikelihoodScaleMetadata {
match &self.response {
ResponseFamily::Gaussian => LikelihoodScaleMetadata::ProfiledGaussian,
ResponseFamily::Gamma => LikelihoodScaleMetadata::EstimatedGammaShape { shape: 1.0 },
ResponseFamily::Binomial
| ResponseFamily::Poisson
| ResponseFamily::Tweedie { .. }
| ResponseFamily::NegativeBinomial { .. } => {
LikelihoodScaleMetadata::FixedDispersion { phi: 1.0 }
}
ResponseFamily::Beta { phi } => LikelihoodScaleMetadata::FixedDispersion { phi: *phi },
ResponseFamily::RoystonParmar => LikelihoodScaleMetadata::Unspecified,
}
}
#[inline]
pub fn pretty_name(&self) -> &'static str {
match (&self.response, &self.link) {
(ResponseFamily::Gaussian, _) => "Gaussian Identity",
(ResponseFamily::Poisson, _) => "Poisson Log",
(ResponseFamily::Tweedie { .. }, _) => "Tweedie Log",
(ResponseFamily::NegativeBinomial { .. }, _) => "Negative-Binomial Log",
(ResponseFamily::Beta { .. }, _) => "Beta Regression Logit",
(ResponseFamily::Gamma, _) => "Gamma Log",
(ResponseFamily::RoystonParmar, _) => "Royston Parmar",
(ResponseFamily::Binomial, InverseLink::Standard(LinkFunction::Logit)) => {
"Binomial Logit"
}
(ResponseFamily::Binomial, InverseLink::Standard(LinkFunction::Probit)) => {
"Binomial Probit"
}
(ResponseFamily::Binomial, InverseLink::Standard(LinkFunction::CLogLog)) => {
"Binomial CLogLog"
}
(ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => "Latent CLogLog Binomial",
(ResponseFamily::Binomial, InverseLink::Sas(_)) => "Binomial SAS",
(ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => "Binomial Beta-Logistic",
(ResponseFamily::Binomial, InverseLink::Mixture(_)) => "Binomial Blended Inverse-Link",
(ResponseFamily::Binomial, InverseLink::Standard(_)) => "Binomial Logit",
}
}
#[inline]
pub fn name(&self) -> &'static str {
match (&self.response, &self.link) {
(ResponseFamily::Gaussian, _) => "gaussian",
(ResponseFamily::Poisson, _) => "poisson-log",
(ResponseFamily::Tweedie { .. }, _) => "tweedie-log",
(ResponseFamily::NegativeBinomial { .. }, _) => "negative-binomial-log",
(ResponseFamily::Beta { .. }, _) => "beta-regression-logit",
(ResponseFamily::Gamma, _) => "gamma-log",
(ResponseFamily::RoystonParmar, _) => "royston-parmar",
(ResponseFamily::Binomial, InverseLink::Standard(LinkFunction::Logit)) => {
"binomial-logit"
}
(ResponseFamily::Binomial, InverseLink::Standard(LinkFunction::Probit)) => {
"binomial-probit"
}
(ResponseFamily::Binomial, InverseLink::Standard(LinkFunction::CLogLog)) => {
"binomial-cloglog"
}
(ResponseFamily::Binomial, InverseLink::LatentCLogLog(_)) => "latent-cloglog-binomial",
(ResponseFamily::Binomial, InverseLink::Sas(_)) => "binomial-sas",
(ResponseFamily::Binomial, InverseLink::BetaLogistic(_)) => "binomial-beta-logistic",
(ResponseFamily::Binomial, InverseLink::Mixture(_)) => "binomial-blended-inverse-link",
(ResponseFamily::Binomial, InverseLink::Standard(_)) => "binomial-logit",
}
}
#[inline]
pub const fn supports_firth(&self) -> bool {
matches!(self.response, ResponseFamily::Binomial)
&& matches!(self.link, InverseLink::Standard(LinkFunction::Logit))
}
#[inline]
pub const fn fixed_dispersion(&self) -> Option<f64> {
match self.response {
ResponseFamily::Gaussian | ResponseFamily::Gamma | ResponseFamily::RoystonParmar => {
None
}
ResponseFamily::Binomial
| ResponseFamily::Poisson
| ResponseFamily::Tweedie { .. }
| ResponseFamily::NegativeBinomial { .. } => Some(1.0),
ResponseFamily::Beta { phi } => Some(phi),
}
}
}
/// Returns `true` when `p` is finite and lies strictly between 1 and 2.
///
/// NaN, infinities, and both endpoints are rejected.
#[inline]
pub const fn is_valid_tweedie_power(p: f64) -> bool {
    if !p.is_finite() {
        return false;
    }
    1.0 < p && p < 2.0
}
/// Describes how the scale of a likelihood is represented.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LikelihoodScaleMetadata {
/// Default for the Gaussian family; `fixed_phi` returns `None`.
/// NOTE(review): presumably the scale is profiled out during fitting —
/// confirm.
ProfiledGaussian,
/// Dispersion fixed at `phi`; `fixed_phi` returns it unchanged.
FixedDispersion { phi: f64 },
/// Gamma shape held at `shape`; `fixed_phi` returns `1.0 / shape`.
FixedGammaShape { shape: f64 },
/// Gamma shape flagged as estimated (see `gamma_shape_is_estimated`);
/// `fixed_phi` returns `1.0 / shape` for the stored value.
EstimatedGammaShape { shape: f64 },
/// No scale information; `fixed_phi` returns `None`.
Unspecified,
}
impl LikelihoodScaleMetadata {
#[inline]
pub const fn fixed_phi(self) -> Option<f64> {
match self {
Self::FixedDispersion { phi } => Some(phi),
Self::FixedGammaShape { shape } | Self::EstimatedGammaShape { shape } => {
Some(1.0 / shape)
}
Self::ProfiledGaussian | Self::Unspecified => None,
}
}
#[inline]
pub const fn gamma_shape(self) -> Option<f64> {
match self {
Self::FixedGammaShape { shape } | Self::EstimatedGammaShape { shape } => Some(shape),
_ => None,
}
}
#[inline]
pub const fn gamma_shape_is_estimated(self) -> bool {
matches!(self, Self::EstimatedGammaShape { .. })
}
}
/// Tag describing how a log-likelihood value is normalized.
///
/// NOTE(review): nothing in this file consumes this enum; the variant
/// descriptions below are inferred from their names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LogLikelihoodNormalization {
/// Presumably the full log-likelihood including all constants.
Full,
/// Presumably omits constants that depend only on the response.
OmittingResponseConstants,
/// Presumably a value supplied by the user with unknown normalization.
UserProvided,
}
/// A likelihood specification together with its scale metadata.
///
/// `GlmLikelihoodSpec::canonical` fills `scale` from
/// `LikelihoodSpec::default_scale_metadata`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GlmLikelihoodSpec {
/// Response family and inverse link.
pub spec: LikelihoodSpec,
/// Scale metadata attached to `spec`.
pub scale: LikelihoodScaleMetadata,
}
impl GlmLikelihoodSpec {
#[inline]
pub fn canonical(spec: LikelihoodSpec) -> Self {
let scale = spec.default_scale_metadata();
Self { spec, scale }
}
#[inline]
pub fn link_function(&self) -> LinkFunction {
self.spec.link_function()
}
#[inline]
pub fn fixed_phi(&self) -> Option<f64> {
self.scale.fixed_phi()
}
#[inline]
pub fn gamma_shape(&self) -> Option<f64> {
self.scale.gamma_shape()
}
#[inline]
pub fn with_gamma_shape(mut self, shape: f64) -> Self {
self.scale = match self.scale {
LikelihoodScaleMetadata::FixedGammaShape { .. } => {
LikelihoodScaleMetadata::FixedGammaShape { shape }
}
LikelihoodScaleMetadata::EstimatedGammaShape { .. } => {
LikelihoodScaleMetadata::EstimatedGammaShape { shape }
}
other => match &self.spec.response {
ResponseFamily::Gamma => LikelihoodScaleMetadata::EstimatedGammaShape { shape },
_ => other,
},
};
self
}
}
/// Selects how a determinant is treated under a ridge policy.
///
/// NOTE(review): the computation these modes select is not visible in this
/// file; the variant descriptions are inferred from their names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RidgeDeterminantMode {
/// Used by `RidgePolicy::explicit_stabilization_full`; presumably lets the
/// implementation choose.
Auto,
/// Used by `RidgePolicy::explicit_stabilization_full_exact`.
Full,
/// Used by `RidgePolicy::explicit_stabilization_pospart`; presumably
/// restricts to the positive part of the spectrum.
PositivePart,
}
/// Form of the ridge matrix; only one form is currently defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RidgeMatrixForm {
/// NOTE(review): presumably `delta * I` — confirm against the code that
/// applies the ridge.
ScaledIdentity,
}
/// Flags describing where a ridge term participates.
///
/// Every constructor in this file sets all four boolean flags to `true` and
/// varies only `determinant_mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct RidgePolicy {
/// NOTE(review): presumably marks the ridge as independent of `rho`; not
/// read anywhere in this file — confirm.
pub rho_independent: bool,
/// Whether the ridge is included in the quadratic penalty.
pub include_quadratic_penalty: bool,
/// Whether the ridge is included in the penalty log-determinant; gates
/// `RidgePassport::penalty_logdet_ridge`.
pub include_penalty_logdet: bool,
/// Whether the ridge is included in the Laplace Hessian; gates
/// `RidgePassport::laplacehessianridge`.
pub include_laplacehessian: bool,
/// Determinant handling mode.
pub determinant_mode: RidgeDeterminantMode,
}
impl RidgePolicy {
pub const fn explicit_stabilization_full() -> Self {
Self {
rho_independent: true,
include_quadratic_penalty: true,
include_penalty_logdet: true,
include_laplacehessian: true,
determinant_mode: RidgeDeterminantMode::Auto,
}
}
pub const fn explicit_stabilization_full_exact() -> Self {
Self {
rho_independent: true,
include_quadratic_penalty: true,
include_penalty_logdet: true,
include_laplacehessian: true,
determinant_mode: RidgeDeterminantMode::Full,
}
}
pub const fn explicit_stabilization_pospart() -> Self {
Self {
rho_independent: true,
include_quadratic_penalty: true,
include_penalty_logdet: true,
include_laplacehessian: true,
determinant_mode: RidgeDeterminantMode::PositivePart,
}
}
}
/// A ridge magnitude together with its matrix form and inclusion policy.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RidgePassport {
/// Ridge magnitude; returned by the accessor methods when the matching
/// policy flag is set.
pub delta: f64,
/// Form of the ridge matrix.
pub matrix_form: RidgeMatrixForm,
/// Flags describing where the ridge participates.
pub policy: RidgePolicy,
}
impl RidgePassport {
pub const fn scaled_identity(delta: f64, policy: RidgePolicy) -> Self {
Self {
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
policy,
}
}
#[inline]
pub const fn penalty_logdet_ridge(self) -> f64 {
if self.policy.include_penalty_logdet {
self.delta
} else {
0.0
}
}
#[inline]
pub const fn laplacehessianridge(self) -> f64 {
if self.policy.include_laplacehessian {
self.delta
} else {
0.0
}
}
}
/// Three counts labelled positive, zero and negative.
///
/// NOTE(review): presumably the inertia of a symmetric matrix (counts of
/// positive, zero and negative eigenvalues) — confirm against the producer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Inertia {
/// Count of positive entries.
pub positive: usize,
/// Count of zero entries.
pub zero: usize,
/// Count of negative entries.
pub negative: usize,
}
/// Records which rule chose a stabilization magnitude.
///
/// NOTE(review): the selection logic itself is not in this file; variant
/// descriptions beyond what the constructors show are inferred from names.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationRule {
/// Used by `StabilizationLedger::none` and
/// `StabilizationLedger::from_passport`.
FixedConstant,
/// Carries an `spd_floor` value; presumably a positive-definiteness floor.
InertiaTarget { spd_floor: f64 },
/// Presumably a heuristic choice.
Heuristic,
/// Used by `StabilizationLedger::explicit_prior`.
UserSpecified,
/// Carries the number of `attempts`; presumably of an escalating backoff.
BackoffEscalation { attempts: usize },
}
/// Category of stabilization recorded in a `StabilizationLedger`.
///
/// `StabilizationLedger::invariants_hold` defines which inclusion flags are
/// consistent with each kind.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum StabilizationKind {
/// No stabilization; consistent only with `delta == 0.0` and no inclusion
/// flags set.
None,
/// Consistent only with no inclusion flags set.
SolverDampingOnly,
/// Consistent only with no inclusion flags set; optionally carries a
/// backward error bound.
NumericalPerturbation {
backward_error_bound: Option<f64>,
},
/// Consistent only with all three inclusion flags set.
ExplicitPrior,
}
/// Record of a stabilization term: its kind, magnitude, how it was chosen,
/// and which quantities it is included in.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StabilizationLedger {
/// Category of stabilization.
pub kind: StabilizationKind,
/// Stabilization magnitude; the `*_delta` accessors return it when the
/// matching inclusion flag is set.
pub delta: f64,
/// Form of the stabilization matrix.
pub matrix_form: RidgeMatrixForm,
/// Rule that chose `delta`.
pub chosen_by: StabilizationRule,
/// Inertia before stabilization, when recorded; every constructor in this
/// file sets it to `None`.
pub inertia_before: Option<Inertia>,
/// Inertia after stabilization, when recorded; every constructor in this
/// file sets it to `None`.
pub inertia_after: Option<Inertia>,
/// Gates `quadratic_delta`.
pub included_in_quadratic: bool,
/// Gates `laplace_hessian_delta`.
pub included_in_laplace_hessian: bool,
/// Gates `penalty_logdet_delta`.
pub included_in_penalty_logdet: bool,
}
impl StabilizationLedger {
pub const fn none() -> Self {
Self {
kind: StabilizationKind::None,
delta: 0.0,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
included_in_quadratic: false,
included_in_laplace_hessian: false,
included_in_penalty_logdet: false,
}
}
pub const fn solver_damping(delta: f64, chosen_by: StabilizationRule) -> Self {
Self {
kind: StabilizationKind::SolverDampingOnly,
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
included_in_quadratic: false,
included_in_laplace_hessian: false,
included_in_penalty_logdet: false,
}
}
pub const fn numerical_perturbation(
delta: f64,
chosen_by: StabilizationRule,
backward_error_bound: Option<f64>,
) -> Self {
Self {
kind: StabilizationKind::NumericalPerturbation {
backward_error_bound,
},
delta,
matrix_form: RidgeMatrixForm::ScaledIdentity,
chosen_by,
inertia_before: None,
inertia_after: None,
included_in_quadratic: false,
included_in_laplace_hessian: false,
included_in_penalty_logdet: false,
}
}
pub const fn explicit_prior(delta: f64, matrix_form: RidgeMatrixForm) -> Self {
Self {
kind: StabilizationKind::ExplicitPrior,
delta,
matrix_form,
chosen_by: StabilizationRule::UserSpecified,
inertia_before: None,
inertia_after: None,
included_in_quadratic: true,
included_in_laplace_hessian: true,
included_in_penalty_logdet: true,
}
}
pub const fn from_passport(passport: RidgePassport) -> Self {
let any_included = passport.policy.include_quadratic_penalty
|| passport.policy.include_laplacehessian
|| passport.policy.include_penalty_logdet;
let kind = if !any_included {
StabilizationKind::NumericalPerturbation {
backward_error_bound: None,
}
} else {
StabilizationKind::ExplicitPrior
};
Self {
kind,
delta: passport.delta,
matrix_form: passport.matrix_form,
chosen_by: StabilizationRule::FixedConstant,
inertia_before: None,
inertia_after: None,
included_in_quadratic: passport.policy.include_quadratic_penalty,
included_in_laplace_hessian: passport.policy.include_laplacehessian,
included_in_penalty_logdet: passport.policy.include_penalty_logdet,
}
}
#[inline]
pub const fn quadratic_delta(&self) -> f64 {
if self.included_in_quadratic {
self.delta
} else {
0.0
}
}
#[inline]
pub const fn laplace_hessian_delta(&self) -> f64 {
if self.included_in_laplace_hessian {
self.delta
} else {
0.0
}
}
#[inline]
pub const fn penalty_logdet_delta(&self) -> f64 {
if self.included_in_penalty_logdet {
self.delta
} else {
0.0
}
}
pub const fn invariants_hold(&self) -> bool {
match self.kind {
StabilizationKind::None => {
self.delta == 0.0
&& !self.included_in_quadratic
&& !self.included_in_laplace_hessian
&& !self.included_in_penalty_logdet
}
StabilizationKind::SolverDampingOnly
| StabilizationKind::NumericalPerturbation { .. } => {
!self.included_in_quadratic
&& !self.included_in_laplace_hessian
&& !self.included_in_penalty_logdet
}
StabilizationKind::ExplicitPrior => {
self.included_in_quadratic
&& self.included_in_laplace_hessian
&& self.included_in_penalty_logdet
}
}
}
}
// Generates a public tuple-struct newtype around `Array1<f64>`.
//
// Usage: `array1_f64_newtype!(Name)` or `array1_f64_newtype!(Name, extra, ...)`.
// Each generated type gets `new`, `zeros`, `Deref`/`DerefMut` to the inner
// array, `AsRef<Array1<f64>>`, and `From` conversions in both directions.
// Each trailing `extra` identifier is forwarded to an internal `@extra` rule;
// the only one defined is `exp`. An unknown extra matches no rule and fails
// to compile.
macro_rules! array1_f64_newtype {
// Main entry point: the type name followed by optional extras.
($name:ident $(, $extra:ident)*) => {
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq)]
pub struct $name(pub Array1<f64>);
impl $name {
// Wraps an existing array without copying.
#[inline]
pub fn new(values: Array1<f64>) -> Self {
Self(values)
}
// Creates a zero-filled array of length `len`.
#[inline]
pub fn zeros(len: usize) -> Self {
Self(Array1::zeros(len))
}
}
// Deref/DerefMut expose the inner array's methods on the newtype.
impl Deref for $name {
type Target = Array1<f64>;
#[inline]
fn deref(&self) -> &Self::Target { &self.0 }
}
impl DerefMut for $name {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
}
impl AsRef<Array1<f64>> for $name {
#[inline]
fn as_ref(&self) -> &Array1<f64> { &self.0 }
}
impl From<Array1<f64>> for $name {
#[inline]
fn from(values: Array1<f64>) -> Self { Self(values) }
}
impl From<$name> for Array1<f64> {
#[inline]
fn from(values: $name) -> Self { values.0 }
}
// Expand each requested extra through the internal `@extra` rule below.
$( array1_f64_newtype!(@extra $name $extra); )*
};
// Internal rule: adds `exp`, which returns a new array holding the
// elementwise exponential of the wrapped values.
(@extra $name:ident exp) => {
impl $name {
#[inline]
pub fn exp(&self) -> Array1<f64> { self.0.mapv(f64::exp) }
}
};
}
// Newtype over `Array1<f64>` named for coefficient vectors.
array1_f64_newtype!(Coefficients);
// Newtype over `Array1<f64>` named for linear-predictor vectors.
array1_f64_newtype!(LinearPredictor);
// Newtype over `Array1<f64>` named for log smoothing parameters; also gets
// the elementwise `exp` extra.
array1_f64_newtype!(LogSmoothingParams, exp);
/// Borrowed, copyable wrapper around an `ArrayView1<f64>`; the view-based
/// counterpart of the owned `LogSmoothingParams` newtype.
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct LogSmoothingParamsView<'a>(pub ArrayView1<'a, f64>);

impl<'a> LogSmoothingParamsView<'a> {
    /// Wraps an existing one-dimensional view.
    pub fn new(values: ArrayView1<'a, f64>) -> Self {
        LogSmoothingParamsView(values)
    }

    /// Returns a newly allocated array holding the elementwise exponential
    /// of the viewed values.
    pub fn exp(&self) -> Array1<f64> {
        let Self(view) = self;
        view.mapv(f64::exp)
    }
}

impl<'a> Deref for LogSmoothingParamsView<'a> {
    type Target = ArrayView1<'a, f64>;

    /// Exposes the wrapped view so its methods can be called directly.
    fn deref(&self) -> &Self::Target {
        let Self(view) = self;
        view
    }
}