use gamlss_core::{
Gamlss, Identity, Log, Logit, Mu, ParameterBlock, ParameterBlocks, Precision, Rate, Scale,
Shape, Sigma,
};
use gamlss_family::{
DefaultBeta, DefaultGamma, DefaultInverseGaussian, DefaultLogNormal, DefaultNormal,
DefaultWeibull,
};
use crate::{
BuiltModel, Col, DataView, FormulaError, FormulaPenalty, ModelSchema, NumericResponse,
PredictionDesign, ResponseSchema, TermExpr,
compile::{ResponseDomain, fit_terms, predictor_from_fitted_terms, required_response},
penalty::prediction_penalty,
predictor::FormulaPredictorBlock,
schema::terms_for,
};
pub type FormulaBlock<P, L> = ParameterBlock<P, L, FormulaPredictorBlock, FormulaPenalty>;
pub type NormalBlocks = (FormulaBlock<Mu, Identity>, FormulaBlock<Sigma, Log>);
pub type GammaBlocks = (FormulaBlock<Shape, Log>, FormulaBlock<Rate, Log>);
pub type LogNormalBlocks = (FormulaBlock<Mu, Identity>, FormulaBlock<Sigma, Log>);
pub type WeibullBlocks = (FormulaBlock<Shape, Log>, FormulaBlock<Scale, Log>);
pub type InverseGaussianBlocks = (FormulaBlock<Mu, Log>, FormulaBlock<Shape, Log>);
pub type BetaBlocks = (FormulaBlock<Mu, Logit>, FormulaBlock<Precision, Log>);
pub type CompiledNormal<'a> = Gamlss<DefaultNormal, NormalBlocks, NumericResponse<'a>>;
pub type CompiledGamma<'a> = Gamlss<DefaultGamma, GammaBlocks, NumericResponse<'a>>;
pub type CompiledLogNormal<'a> = Gamlss<DefaultLogNormal, LogNormalBlocks, NumericResponse<'a>>;
pub type CompiledWeibull<'a> = Gamlss<DefaultWeibull, WeibullBlocks, NumericResponse<'a>>;
pub type CompiledInverseGaussian<'a> =
Gamlss<DefaultInverseGaussian, InverseGaussianBlocks, NumericResponse<'a>>;
pub type CompiledBeta<'a> = Gamlss<DefaultBeta, BetaBlocks, NumericResponse<'a>>;
pub type BuiltNormal<'a> = BuiltModel<CompiledNormal<'a>>;
pub type BuiltGamma<'a> = BuiltModel<CompiledGamma<'a>>;
pub type BuiltLogNormal<'a> = BuiltModel<CompiledLogNormal<'a>>;
pub type BuiltWeibull<'a> = BuiltModel<CompiledWeibull<'a>>;
pub type BuiltInverseGaussian<'a> = BuiltModel<CompiledInverseGaussian<'a>>;
pub type BuiltBeta<'a> = BuiltModel<CompiledBeta<'a>>;
/// Generates a two-parameter formula specification builder plus the prediction helpers on the
/// corresponding `BuiltModel`.
///
/// Per invocation this emits:
/// - `$spec`: a builder holding the response, optional weights, and one optional term formula
///   per distribution parameter, with a `build` method that compiles a `$compiled` model;
/// - `impl Default for $spec` (delegating to `new`);
/// - inherent prediction methods on `BuiltModel<$compiled<'_>>`.
macro_rules! define_spec {
    (
        $(#[$meta:meta])*
        $spec:ident, $built:ident, $compiled:ident, $blocks:ident, $family:ty;
        family_name = $family_name:literal, domain = $domain:expr;
        first = $first_field:ident, $first_method:ident, $first_name:literal, $first_param:ty, $first_link:ty;
        second = $second_field:ident, $second_method:ident, $second_name:literal, $second_param:ty, $second_link:ty
    ) => {
        $(#[$meta])*
        #[doc = concat!(
            "Formula specification for the `", $family_name, "` family, with one term formula for `",
            $first_name, "` and one for `", $second_name, "`."
        )]
        ///
        /// Set the response (and optionally weights) plus the per-parameter formulas, then call
        /// `build` to compile the model against a data source.
        #[derive(Debug, Clone, PartialEq)]
        pub struct $spec {
            /// Response column; forwarded to `required_response` at build time.
            response: Option<Col<f64>>,
            /// Optional observation-weights column.
            weights: Option<Col<f64>>,
            #[doc = concat!("Term formula for `", $first_name, "`, if one was supplied.")]
            $first_field: Option<TermExpr>,
            #[doc = concat!("Term formula for `", $second_name, "`, if one was supplied.")]
            $second_field: Option<TermExpr>,
            /// Name of the first parameter whose formula was supplied more than once; reported
            /// as an error by `build`.
            duplicate_parameter: Option<&'static str>,
        }

        impl $spec {
            /// Creates an empty specification: no response, no weights, no formulas.
            #[must_use]
            pub fn new() -> Self {
                Self {
                    duplicate_parameter: None,
                    $second_field: None,
                    $first_field: None,
                    weights: None,
                    response: None,
                }
            }

            /// Sets the response column, replacing any previously set one.
            #[must_use]
            pub fn response(self, response: Col<f64>) -> Self {
                Self {
                    response: Some(response),
                    ..self
                }
            }

            /// Sets the observation-weights column, replacing any previously set one.
            #[must_use]
            pub fn weights(self, weights: Col<f64>) -> Self {
                Self {
                    weights: Some(weights),
                    ..self
                }
            }

            #[doc = concat!("Sets the term formula for `", $first_name, "`.")]
            ///
            /// Only the first call takes effect. A repeated call discards its formula and records
            /// the parameter name, so that `build` later fails with
            /// `FormulaError::DuplicateParameter`.
            #[must_use]
            pub fn $first_method(mut self, terms: impl Into<TermExpr>) -> Self {
                if self.$first_field.is_none() {
                    self.$first_field = Some(terms.into());
                } else {
                    // Keep the earliest duplicate if several were recorded.
                    self.duplicate_parameter.get_or_insert($first_name);
                }
                self
            }

            #[doc = concat!("Sets the term formula for `", $second_name, "`.")]
            ///
            /// Only the first call takes effect. A repeated call discards its formula and records
            /// the parameter name, so that `build` later fails with
            /// `FormulaError::DuplicateParameter`.
            #[must_use]
            pub fn $second_method(mut self, terms: impl Into<TermExpr>) -> Self {
                if self.$second_field.is_none() {
                    self.$second_field = Some(terms.into());
                } else {
                    // Keep the earliest duplicate if several were recorded.
                    self.duplicate_parameter.get_or_insert($second_name);
                }
                self
            }

            /// Resolves the response, fits both parameter formulas on `data`, and compiles the
            /// model together with the schema needed for later prediction.
            ///
            /// # Errors
            ///
            /// Checked in this order:
            /// - `FormulaError::EmptyData` if `data` has no rows;
            /// - `FormulaError::DuplicateParameter` if a parameter formula was set twice;
            /// - errors from `required_response` (response/weights resolution);
            /// - errors from `fit_terms`, first parameter before second;
            /// - errors from constructing the underlying `Gamlss` model.
            pub fn build<'a, D>(&self, data: &'a D) -> Result<$built<'a>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                let nrows = data.nrows();
                if nrows == 0 {
                    return Err(FormulaError::EmptyData);
                }
                if let Some(parameter) = self.duplicate_parameter {
                    return Err(FormulaError::DuplicateParameter(parameter));
                }

                let (response_col, response) = required_response(
                    $family_name,
                    $domain,
                    data,
                    &self.response,
                    &self.weights,
                )?;

                let first_fit = fit_terms($first_name, self.$first_field.clone(), data)?;
                let second_fit = fit_terms($second_name, self.$second_field.clone(), data)?;

                // Both blocks are created with offset 0; the tuple is handed to
                // `ParameterBlocks::new` to form the model's block set.
                let blocks: $blocks = ParameterBlocks::new((
                    ParameterBlock::<$first_param, $first_link, _, _>::from_predictor(
                        first_fit.predictor,
                        first_fit.penalty,
                        0,
                    ),
                    ParameterBlock::<$second_param, $second_link, _, _>::from_predictor(
                        second_fit.predictor,
                        second_fit.penalty,
                        0,
                    ),
                ));

                let model: $compiled<'a> =
                    Gamlss::try_new_with_observations(<$family>::new(), blocks, response)?;
                let layout = model.parameter_layout();

                // The fitted terms are kept so prediction can rebuild designs on new data.
                let parameters = vec![first_fit.terms, second_fit.terms];
                let schema = ModelSchema {
                    response: ResponseSchema {
                        col: response_col,
                        weights: self.weights.clone(),
                        nrows,
                    },
                    parameters,
                };
                Ok(BuiltModel::new(model, schema, layout))
            }
        }

        impl Default for $spec {
            /// Equivalent to `new`.
            fn default() -> Self {
                Self::new()
            }
        }

        impl BuiltModel<$compiled<'_>> {
            /// Rebuilds both parameter blocks for `data` from the terms stored in this model's
            /// schema at build time.
            ///
            /// # Errors
            ///
            /// Propagates errors from `predictor_from_fitted_terms`, first parameter before
            /// second.
            pub fn prediction_blocks<D>(&self, data: &D) -> Result<$blocks, FormulaError>
            where
                D: DataView + ?Sized,
            {
                let first_terms = terms_for(self.schema(), $first_name);
                let second_terms = terms_for(self.schema(), $second_name);
                let first_x = predictor_from_fitted_terms($first_name, first_terms, data)?;
                let second_x = predictor_from_fitted_terms($second_name, second_terms, data)?;
                Ok(ParameterBlocks::new((
                    ParameterBlock::<$first_param, $first_link, _, _>::from_predictor(
                        first_x,
                        prediction_penalty(first_terms),
                        0,
                    ),
                    ParameterBlock::<$second_param, $second_link, _, _>::from_predictor(
                        second_x,
                        prediction_penalty(second_terms),
                        0,
                    ),
                )))
            }

            /// Wraps `prediction_blocks(data)` in a reusable `PredictionDesign`.
            ///
            /// # Errors
            ///
            /// Same as `prediction_blocks`.
            pub fn prediction_design<D>(
                &self,
                data: &D,
            ) -> Result<PredictionDesign<$blocks>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                let blocks = self.prediction_blocks(data)?;
                Ok(PredictionDesign::new(blocks))
            }

            /// Evaluates the family's `Theta` values on a precomputed design by delegating to
            /// the model's `predict_theta_with_blocks`.
            ///
            /// `theta` is presumably the packed coefficient vector in this model's parameter
            /// layout; confirm against `Gamlss::predict_theta_with_blocks`.
            ///
            /// # Errors
            ///
            /// Converts any error from `predict_theta_with_blocks` into `FormulaError`.
            pub fn predict_theta_with_design(
                &self,
                theta: &[f64],
                design: &PredictionDesign<$blocks>,
            ) -> Result<Vec<<$family as gamlss_core::Family>::Theta>, FormulaError>
            {
                let predicted = self
                    .model()
                    .predict_theta_with_blocks(theta, design.blocks())?;
                Ok(predicted)
            }

            /// Builds a prediction design for `data`, then evaluates `Theta` on it.
            ///
            /// # Errors
            ///
            /// Errors from `prediction_design` or `predict_theta_with_design`.
            pub fn predict_theta<D>(
                &self,
                theta: &[f64],
                data: &D,
            ) -> Result<Vec<<$family as gamlss_core::Family>::Theta>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                self.predict_theta_with_design(theta, &self.prediction_design(data)?)
            }
        }
    };
}
// Normal family: generates `NormalSpec` (builder methods `mu` / `sigma`) and the prediction
// methods on `BuiltNormal`. `mu` uses the identity link and `sigma` the log link.
// `ResponseDomain::Finite` is forwarded to `required_response` at build time; presumably it
// rejects non-finite responses (confirm in `compile`).
define_spec!(
NormalSpec, BuiltNormal, CompiledNormal, NormalBlocks, DefaultNormal;
family_name = "normal", domain = ResponseDomain::Finite;
first = mu_terms, mu, "mu", Mu, Identity;
second = sigma_terms, sigma, "sigma", Sigma, Log
);
// Gamma family: generates `GammaSpec` (builder methods `shape` / `rate`) and the prediction
// methods on `BuiltGamma`. Both parameters use the log link.
// `ResponseDomain::Positive` is forwarded to `required_response`; presumably it rejects
// non-positive responses (confirm in `compile`).
define_spec!(
GammaSpec, BuiltGamma, CompiledGamma, GammaBlocks, DefaultGamma;
family_name = "gamma", domain = ResponseDomain::Positive;
first = shape_terms, shape, "shape", Shape, Log;
second = rate_terms, rate, "rate", Rate, Log
);
// Log-normal family: generates `LogNormalSpec` (builder methods `mu` / `sigma`) and the
// prediction methods on `BuiltLogNormal`. `mu` uses the identity link and `sigma` the log link.
// `ResponseDomain::Positive` is forwarded to `required_response`; presumably it rejects
// non-positive responses (confirm in `compile`).
define_spec!(
LogNormalSpec, BuiltLogNormal, CompiledLogNormal, LogNormalBlocks, DefaultLogNormal;
family_name = "log-normal", domain = ResponseDomain::Positive;
first = mu_terms, mu, "mu", Mu, Identity;
second = sigma_terms, sigma, "sigma", Sigma, Log
);
// Weibull family: generates `WeibullSpec` (builder methods `shape` / `scale`) and the
// prediction methods on `BuiltWeibull`. Both parameters use the log link.
// `ResponseDomain::Positive` is forwarded to `required_response`; presumably it rejects
// non-positive responses (confirm in `compile`).
define_spec!(
WeibullSpec, BuiltWeibull, CompiledWeibull, WeibullBlocks, DefaultWeibull;
family_name = "weibull", domain = ResponseDomain::Positive;
first = shape_terms, shape, "shape", Shape, Log;
second = scale_terms, scale, "scale", Scale, Log
);
// Inverse Gaussian family: generates `InverseGaussianSpec` (builder methods `mu` / `shape`)
// and the prediction methods on `BuiltInverseGaussian`. Both parameters use the log link.
// `ResponseDomain::Positive` is forwarded to `required_response`; presumably it rejects
// non-positive responses (confirm in `compile`).
define_spec!(
InverseGaussianSpec, BuiltInverseGaussian, CompiledInverseGaussian, InverseGaussianBlocks, DefaultInverseGaussian;
family_name = "inverse Gaussian", domain = ResponseDomain::Positive;
first = mu_terms, mu, "mu", Mu, Log;
second = shape_terms, shape, "shape", Shape, Log
);
// Beta family: generates `BetaSpec` (builder methods `mu` / `precision`) and the prediction
// methods on `BuiltBeta`. `mu` uses the logit link and `precision` the log link.
// `ResponseDomain::Unit` is forwarded to `required_response`; presumably it restricts
// responses to the unit interval (confirm open vs. closed bounds in `compile`).
define_spec!(
BetaSpec, BuiltBeta, CompiledBeta, BetaBlocks, DefaultBeta;
family_name = "beta", domain = ResponseDomain::Unit;
first = mu_terms, mu, "mu", Mu, Logit;
second = precision_terms, precision, "precision", Precision, Log
);
/// Zero-sized namespace for starting a model specification by distribution family.
///
/// Every associated function returns a fresh, empty family-specific builder. The module-level
/// functions with the same names are shorthands for these.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ModelSpec;

impl ModelSpec {
    /// Empty [`NormalSpec`]: `mu` (identity link) and `sigma` (log link).
    #[must_use]
    pub fn normal() -> NormalSpec {
        NormalSpec::default()
    }

    /// Empty [`GammaSpec`]: `shape` and `rate`, both with the log link.
    #[must_use]
    pub fn gamma() -> GammaSpec {
        GammaSpec::default()
    }

    /// Empty [`LogNormalSpec`]: `mu` (identity link) and `sigma` (log link).
    #[must_use]
    pub fn log_normal() -> LogNormalSpec {
        LogNormalSpec::default()
    }

    /// Empty [`WeibullSpec`]: `shape` and `scale`, both with the log link.
    #[must_use]
    pub fn weibull() -> WeibullSpec {
        WeibullSpec::default()
    }

    /// Empty [`InverseGaussianSpec`]: `mu` and `shape`, both with the log link.
    #[must_use]
    pub fn inverse_gaussian() -> InverseGaussianSpec {
        InverseGaussianSpec::default()
    }

    /// Empty [`BetaSpec`]: `mu` (logit link) and `precision` (log link).
    #[must_use]
    pub fn beta() -> BetaSpec {
        BetaSpec::default()
    }
}
#[must_use]
pub fn normal() -> NormalSpec {
ModelSpec::normal()
}
#[must_use]
pub fn gamma() -> GammaSpec {
ModelSpec::gamma()
}
#[must_use]
pub fn log_normal() -> LogNormalSpec {
ModelSpec::log_normal()
}
#[must_use]
pub fn weibull() -> WeibullSpec {
ModelSpec::weibull()
}
#[must_use]
pub fn inverse_gaussian() -> InverseGaussianSpec {
ModelSpec::inverse_gaussian()
}
#[must_use]
pub fn beta() -> BetaSpec {
ModelSpec::beta()
}