1use gamlss_core::{
2 Gamlss, Identity, Log, Logit, Mu, ParameterBlock, ParameterBlocks, Precision, Rate, Scale,
3 Shape, Sigma,
4};
5use gamlss_family::{
6 DefaultBeta, DefaultGamma, DefaultInverseGaussian, DefaultLogNormal, DefaultNormal,
7 DefaultWeibull,
8};
9
10use crate::{
11 BuiltModel, Col, DataView, FormulaError, FormulaPenalty, ModelSchema, NumericResponse,
12 PredictionDesign, ResponseSchema, TermExpr,
13 compile::{ResponseDomain, fit_terms, predictor_from_fitted_terms, required_response},
14 penalty::prediction_penalty,
15 predictor::FormulaPredictorBlock,
16 schema::terms_for,
17};
18
19pub type FormulaBlock<P, L> = ParameterBlock<P, L, FormulaPredictorBlock, FormulaPenalty>;
21
22pub type NormalBlocks = (FormulaBlock<Mu, Identity>, FormulaBlock<Sigma, Log>);
24pub type GammaBlocks = (FormulaBlock<Shape, Log>, FormulaBlock<Rate, Log>);
26pub type LogNormalBlocks = (FormulaBlock<Mu, Identity>, FormulaBlock<Sigma, Log>);
28pub type WeibullBlocks = (FormulaBlock<Shape, Log>, FormulaBlock<Scale, Log>);
30pub type InverseGaussianBlocks = (FormulaBlock<Mu, Log>, FormulaBlock<Shape, Log>);
32pub type BetaBlocks = (FormulaBlock<Mu, Logit>, FormulaBlock<Precision, Log>);
34
35pub type CompiledNormal<'a> = Gamlss<DefaultNormal, NormalBlocks, NumericResponse<'a>>;
37pub type CompiledGamma<'a> = Gamlss<DefaultGamma, GammaBlocks, NumericResponse<'a>>;
39pub type CompiledLogNormal<'a> = Gamlss<DefaultLogNormal, LogNormalBlocks, NumericResponse<'a>>;
41pub type CompiledWeibull<'a> = Gamlss<DefaultWeibull, WeibullBlocks, NumericResponse<'a>>;
43pub type CompiledInverseGaussian<'a> =
45 Gamlss<DefaultInverseGaussian, InverseGaussianBlocks, NumericResponse<'a>>;
46pub type CompiledBeta<'a> = Gamlss<DefaultBeta, BetaBlocks, NumericResponse<'a>>;
48
49pub type BuiltNormal<'a> = BuiltModel<CompiledNormal<'a>>;
51pub type BuiltGamma<'a> = BuiltModel<CompiledGamma<'a>>;
53pub type BuiltLogNormal<'a> = BuiltModel<CompiledLogNormal<'a>>;
55pub type BuiltWeibull<'a> = BuiltModel<CompiledWeibull<'a>>;
57pub type BuiltInverseGaussian<'a> = BuiltModel<CompiledInverseGaussian<'a>>;
59pub type BuiltBeta<'a> = BuiltModel<CompiledBeta<'a>>;
61
/// Generates a two-parameter formula specification builder plus its prediction helpers.
///
/// Each invocation expands to:
///
/// * a builder struct `$spec` holding the response column, optional weights and one
///   optional `TermExpr` per distribution parameter;
/// * `$spec::build`, which fits both parameters' terms against a `DataView` and returns a
///   `$built` model;
/// * prediction helpers on `BuiltModel<$compiled>` that rebuild the parameter blocks for new
///   data from the fitted term metadata stored in the model schema.
macro_rules! define_spec {
    (
        $(#[$meta:meta])*
        $spec:ident, $built:ident, $compiled:ident, $blocks:ident, $family:ty;
        family_name = $family_name:literal, domain = $domain:expr;
        first = $first_field:ident, $first_method:ident, $first_name:literal, $first_param:ty, $first_link:ty;
        second = $second_field:ident, $second_method:ident, $second_name:literal, $second_param:ty, $second_link:ty
    ) => {
        $(#[$meta])*
        #[doc = ""]
        #[doc = concat!(
            "Formula specification for the ", $family_name, " family. Set the response, ",
            "optional weights and the `", $first_name, "` / `", $second_name,
            "` terms, then call [`build`](Self::build)."
        )]
        #[derive(Debug, Clone, PartialEq)]
        pub struct $spec {
            response: Option<Col<f64>>,
            weights: Option<Col<f64>>,
            $first_field: Option<TermExpr>,
            $second_field: Option<TermExpr>,
            // Name of the first parameter whose terms were supplied more than once; reported
            // by `build` as `FormulaError::DuplicateParameter`.
            duplicate_parameter: Option<&'static str>,
        }

        impl $spec {
            /// Creates an empty specification: no response, weights or terms are set.
            #[must_use]
            pub fn new() -> Self {
                Self {
                    duplicate_parameter: None,
                    $second_field: None,
                    $first_field: None,
                    weights: None,
                    response: None,
                }
            }

            /// Sets the response column, replacing any previously set one.
            #[must_use]
            pub fn response(self, response: Col<f64>) -> Self {
                Self {
                    response: Some(response),
                    ..self
                }
            }

            /// Sets the weights column, replacing any previously set one.
            #[must_use]
            pub fn weights(self, weights: Col<f64>) -> Self {
                Self {
                    weights: Some(weights),
                    ..self
                }
            }

            #[doc = concat!("Sets the terms of the `", $first_name, "` parameter.")]
            #[doc = ""]
            #[doc = concat!(
                "Only the first call takes effect; any further call makes ",
                "[`build`](Self::build) fail with `FormulaError::DuplicateParameter`."
            )]
            #[must_use]
            pub fn $first_method(mut self, terms: impl Into<TermExpr>) -> Self {
                Self::store_terms(
                    &mut self.$first_field,
                    &mut self.duplicate_parameter,
                    $first_name,
                    terms,
                );
                self
            }

            #[doc = concat!("Sets the terms of the `", $second_name, "` parameter.")]
            #[doc = ""]
            #[doc = concat!(
                "Only the first call takes effect; any further call makes ",
                "[`build`](Self::build) fail with `FormulaError::DuplicateParameter`."
            )]
            #[must_use]
            pub fn $second_method(mut self, terms: impl Into<TermExpr>) -> Self {
                Self::store_terms(
                    &mut self.$second_field,
                    &mut self.duplicate_parameter,
                    $second_name,
                    terms,
                );
                self
            }

            /// Stores `terms` in an empty `slot`. If the slot is already occupied, the new terms
            /// are dropped unconverted and `name` is remembered as a duplicate, unless an
            /// earlier duplicate was already recorded.
            fn store_terms(
                slot: &mut Option<TermExpr>,
                duplicate: &mut Option<&'static str>,
                name: &'static str,
                terms: impl Into<TermExpr>,
            ) {
                if slot.is_some() {
                    duplicate.get_or_insert(name);
                } else {
                    *slot = Some(terms.into());
                }
            }

            /// Fits the specification against `data` and assembles the model.
            ///
            /// # Errors
            ///
            /// Returns `FormulaError::EmptyData` when `data` has no rows, and
            /// `FormulaError::DuplicateParameter` when a parameter's terms were set more than
            /// once. Errors from resolving the response (`required_response`), fitting either
            /// parameter's terms (`fit_terms`), or constructing the `Gamlss` model are
            /// propagated.
            pub fn build<'a, D>(&self, data: &'a D) -> Result<$built<'a>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                let nrows = data.nrows();
                if nrows == 0 {
                    return Err(FormulaError::EmptyData);
                }
                if let Some(name) = self.duplicate_parameter {
                    return Err(FormulaError::DuplicateParameter(name));
                }

                // The family name and response domain are forwarded so the helper can check the
                // response column for this family (presumably a domain check — confirm in
                // `compile::required_response`).
                let (response_col, response) = required_response(
                    $family_name,
                    $domain,
                    data,
                    &self.response,
                    &self.weights,
                )?;
                let first_fit = fit_terms($first_name, self.$first_field.clone(), data)?;
                let second_fit = fit_terms($second_name, self.$second_field.clone(), data)?;

                let blocks: $blocks = ParameterBlocks::new((
                    ParameterBlock::<$first_param, $first_link, _, _>::from_predictor(
                        first_fit.predictor,
                        first_fit.penalty,
                        0,
                    ),
                    ParameterBlock::<$second_param, $second_link, _, _>::from_predictor(
                        second_fit.predictor,
                        second_fit.penalty,
                        0,
                    ),
                ));
                let model: $compiled<'a> =
                    Gamlss::try_new_with_observations(<$family>::new(), blocks, response)?;
                let layout = model.parameter_layout();

                // The fitted term metadata is kept in the schema so that prediction can rebuild
                // identical designs on new data.
                let schema = ModelSchema {
                    response: ResponseSchema {
                        col: response_col,
                        weights: self.weights.clone(),
                        nrows,
                    },
                    parameters: vec![first_fit.terms, second_fit.terms],
                };

                Ok(BuiltModel::new(model, schema, layout))
            }
        }

        impl Default for $spec {
            fn default() -> Self {
                Self::new()
            }
        }

        impl BuiltModel<$compiled<'_>> {
            /// Rebuilds both parameter blocks for `data` from the fitted terms stored in the
            /// model schema.
            ///
            /// # Errors
            ///
            /// Propagates any error from `predictor_from_fitted_terms` for either parameter.
            pub fn prediction_blocks<D>(&self, data: &D) -> Result<$blocks, FormulaError>
            where
                D: DataView + ?Sized,
            {
                let first_terms = terms_for(self.schema(), $first_name);
                let second_terms = terms_for(self.schema(), $second_name);
                let first_design = predictor_from_fitted_terms($first_name, first_terms, data)?;
                let second_design = predictor_from_fitted_terms($second_name, second_terms, data)?;

                Ok(ParameterBlocks::new((
                    ParameterBlock::<$first_param, $first_link, _, _>::from_predictor(
                        first_design,
                        prediction_penalty(first_terms),
                        0,
                    ),
                    ParameterBlock::<$second_param, $second_link, _, _>::from_predictor(
                        second_design,
                        prediction_penalty(second_terms),
                        0,
                    ),
                )))
            }

            /// Wraps [`prediction_blocks`](Self::prediction_blocks) for `data` in a reusable
            /// `PredictionDesign`.
            ///
            /// # Errors
            ///
            /// Same as [`prediction_blocks`](Self::prediction_blocks).
            pub fn prediction_design<D>(
                &self,
                data: &D,
            ) -> Result<PredictionDesign<$blocks>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                self.prediction_blocks(data).map(PredictionDesign::new)
            }

            /// Evaluates the family's `Theta` values for a prebuilt design.
            ///
            /// `theta` is forwarded unchanged to the model's `predict_theta_with_blocks`
            /// (presumably the flattened coefficient vector in `parameter_layout` order —
            /// confirm against `gamlss_core`).
            ///
            /// # Errors
            ///
            /// Converts any error from `predict_theta_with_blocks` into a `FormulaError`.
            pub fn predict_theta_with_design(
                &self,
                theta: &[f64],
                design: &PredictionDesign<$blocks>,
            ) -> Result<Vec<<$family as gamlss_core::Family>::Theta>, FormulaError> {
                self.model()
                    .predict_theta_with_blocks(theta, design.blocks())
                    .map_err(FormulaError::from)
            }

            /// Builds the prediction design for `data` and evaluates it with `theta`.
            ///
            /// # Errors
            ///
            /// Propagates errors from [`prediction_design`](Self::prediction_design) and
            /// [`predict_theta_with_design`](Self::predict_theta_with_design).
            pub fn predict_theta<D>(
                &self,
                theta: &[f64],
                data: &D,
            ) -> Result<Vec<<$family as gamlss_core::Family>::Theta>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                self.prediction_design(data)
                    .and_then(|design| self.predict_theta_with_design(theta, &design))
            }
        }
    };
}
245
246define_spec!(
247 NormalSpec, BuiltNormal, CompiledNormal, NormalBlocks, DefaultNormal;
249 family_name = "normal", domain = ResponseDomain::Finite;
250 first = mu_terms, mu, "mu", Mu, Identity;
251 second = sigma_terms, sigma, "sigma", Sigma, Log
252);
253
254define_spec!(
255 GammaSpec, BuiltGamma, CompiledGamma, GammaBlocks, DefaultGamma;
257 family_name = "gamma", domain = ResponseDomain::Positive;
258 first = shape_terms, shape, "shape", Shape, Log;
259 second = rate_terms, rate, "rate", Rate, Log
260);
261
262define_spec!(
263 LogNormalSpec, BuiltLogNormal, CompiledLogNormal, LogNormalBlocks, DefaultLogNormal;
265 family_name = "log-normal", domain = ResponseDomain::Positive;
266 first = mu_terms, mu, "mu", Mu, Identity;
267 second = sigma_terms, sigma, "sigma", Sigma, Log
268);
269
270define_spec!(
271 WeibullSpec, BuiltWeibull, CompiledWeibull, WeibullBlocks, DefaultWeibull;
273 family_name = "weibull", domain = ResponseDomain::Positive;
274 first = shape_terms, shape, "shape", Shape, Log;
275 second = scale_terms, scale, "scale", Scale, Log
276);
277
278define_spec!(
279 InverseGaussianSpec, BuiltInverseGaussian, CompiledInverseGaussian, InverseGaussianBlocks, DefaultInverseGaussian;
281 family_name = "inverse Gaussian", domain = ResponseDomain::Positive;
282 first = mu_terms, mu, "mu", Mu, Log;
283 second = shape_terms, shape, "shape", Shape, Log
284);
285
286define_spec!(
287 BetaSpec, BuiltBeta, CompiledBeta, BetaBlocks, DefaultBeta;
289 family_name = "beta", domain = ResponseDomain::Unit;
290 first = mu_terms, mu, "mu", Mu, Logit;
291 second = precision_terms, precision, "precision", Precision, Log
292);
293
294#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
296pub struct ModelSpec;
297
298impl ModelSpec {
299 #[must_use]
301 pub fn normal() -> NormalSpec {
302 NormalSpec::new()
303 }
304
305 #[must_use]
307 pub fn gamma() -> GammaSpec {
308 GammaSpec::new()
309 }
310
311 #[must_use]
313 pub fn log_normal() -> LogNormalSpec {
314 LogNormalSpec::new()
315 }
316
317 #[must_use]
319 pub fn weibull() -> WeibullSpec {
320 WeibullSpec::new()
321 }
322
323 #[must_use]
325 pub fn inverse_gaussian() -> InverseGaussianSpec {
326 InverseGaussianSpec::new()
327 }
328
329 #[must_use]
331 pub fn beta() -> BetaSpec {
332 BetaSpec::new()
333 }
334}
335
336#[must_use]
338pub fn normal() -> NormalSpec {
339 ModelSpec::normal()
340}
341
342#[must_use]
344pub fn gamma() -> GammaSpec {
345 ModelSpec::gamma()
346}
347
348#[must_use]
350pub fn log_normal() -> LogNormalSpec {
351 ModelSpec::log_normal()
352}
353
354#[must_use]
356pub fn weibull() -> WeibullSpec {
357 ModelSpec::weibull()
358}
359
360#[must_use]
362pub fn inverse_gaussian() -> InverseGaussianSpec {
363 ModelSpec::inverse_gaussian()
364}
365
366#[must_use]
368pub fn beta() -> BetaSpec {
369 ModelSpec::beta()
370}