// Source file: gamlss_formula/spec.rs

1use gamlss_core::{
2    Gamlss, Identity, Log, Logit, Mu, ParameterBlock, ParameterBlocks, Precision, Rate, Scale,
3    Shape, Sigma,
4};
5use gamlss_family::{
6    DefaultBeta, DefaultGamma, DefaultInverseGaussian, DefaultLogNormal, DefaultNormal,
7    DefaultWeibull,
8};
9
10use crate::{
11    BuiltModel, Col, DataView, FormulaError, FormulaPenalty, ModelSchema, NumericResponse,
12    PredictionDesign, ResponseSchema, TermExpr,
13    compile::{ResponseDomain, fit_terms, predictor_from_fitted_terms, required_response},
14    penalty::prediction_penalty,
15    predictor::FormulaPredictorBlock,
16    schema::terms_for,
17};
18
19/// Formula block type used by supported family specs.
20pub type FormulaBlock<P, L> = ParameterBlock<P, L, FormulaPredictorBlock, FormulaPenalty>;
21
22/// Blocks for a normal formula model.
23pub type NormalBlocks = (FormulaBlock<Mu, Identity>, FormulaBlock<Sigma, Log>);
24/// Blocks for a gamma formula model.
25pub type GammaBlocks = (FormulaBlock<Shape, Log>, FormulaBlock<Rate, Log>);
26/// Blocks for a log-normal formula model.
27pub type LogNormalBlocks = (FormulaBlock<Mu, Identity>, FormulaBlock<Sigma, Log>);
28/// Blocks for a Weibull formula model.
29pub type WeibullBlocks = (FormulaBlock<Shape, Log>, FormulaBlock<Scale, Log>);
30/// Blocks for an inverse Gaussian formula model.
31pub type InverseGaussianBlocks = (FormulaBlock<Mu, Log>, FormulaBlock<Shape, Log>);
32/// Blocks for a beta formula model.
33pub type BetaBlocks = (FormulaBlock<Mu, Logit>, FormulaBlock<Precision, Log>);
34
35/// Compiled normal formula model.
36pub type CompiledNormal<'a> = Gamlss<DefaultNormal, NormalBlocks, NumericResponse<'a>>;
37/// Compiled gamma formula model.
38pub type CompiledGamma<'a> = Gamlss<DefaultGamma, GammaBlocks, NumericResponse<'a>>;
39/// Compiled log-normal formula model.
40pub type CompiledLogNormal<'a> = Gamlss<DefaultLogNormal, LogNormalBlocks, NumericResponse<'a>>;
41/// Compiled Weibull formula model.
42pub type CompiledWeibull<'a> = Gamlss<DefaultWeibull, WeibullBlocks, NumericResponse<'a>>;
43/// Compiled inverse Gaussian formula model.
44pub type CompiledInverseGaussian<'a> =
45    Gamlss<DefaultInverseGaussian, InverseGaussianBlocks, NumericResponse<'a>>;
46/// Compiled beta formula model.
47pub type CompiledBeta<'a> = Gamlss<DefaultBeta, BetaBlocks, NumericResponse<'a>>;
48
49/// Built normal formula model.
50pub type BuiltNormal<'a> = BuiltModel<CompiledNormal<'a>>;
51/// Built gamma formula model.
52pub type BuiltGamma<'a> = BuiltModel<CompiledGamma<'a>>;
53/// Built log-normal formula model.
54pub type BuiltLogNormal<'a> = BuiltModel<CompiledLogNormal<'a>>;
55/// Built Weibull formula model.
56pub type BuiltWeibull<'a> = BuiltModel<CompiledWeibull<'a>>;
57/// Built inverse Gaussian formula model.
58pub type BuiltInverseGaussian<'a> = BuiltModel<CompiledInverseGaussian<'a>>;
59/// Built beta formula model.
60pub type BuiltBeta<'a> = BuiltModel<CompiledBeta<'a>>;
61
// Generates one two-parameter family specification.
//
// Arguments, in order:
// - optional attributes (typically the doc comment) for the generated struct;
// - `$spec`: name of the generated builder struct;
// - `$built`, `$compiled`, `$blocks`: the family's built-model, compiled-model
//   and parameter-block type aliases;
// - `$family`: family type, constructed with `<$family>::new()`;
// - `family_name` / `domain`: forwarded to `required_response`;
// - `first` / `second`: for each distribution parameter, the struct field, the
//   setter method name, the parameter name (passed to `fit_terms`, `terms_for`
//   and `predictor_from_fitted_terms`, and reported in
//   `FormulaError::DuplicateParameter`), the parameter type and the link type.
//
// The expansion contains the struct, its builder methods and `build`, a
// `Default` impl, and prediction helpers on `BuiltModel<$compiled<'a>>`.
macro_rules! define_spec {
    (
        $(#[$meta:meta])*
        $spec:ident, $built:ident, $compiled:ident, $blocks:ident, $family:ty;
        family_name = $family_name:literal, domain = $domain:expr;
        first = $first_field:ident, $first_method:ident, $first_name:literal, $first_param:ty, $first_link:ty;
        second = $second_field:ident, $second_method:ident, $second_name:literal, $second_param:ty, $second_link:ty
    ) => {
        $(#[$meta])*
        #[derive(Debug, Clone, PartialEq)]
        pub struct $spec {
            // Response column; `None` until `response` is called.
            response: Option<Col<f64>>,
            // Optional observation-weight column.
            weights: Option<Col<f64>>,
            // Terms for the first distribution parameter, once set.
            $first_field: Option<TermExpr>,
            // Terms for the second distribution parameter, once set.
            $second_field: Option<TermExpr>,
            // Name of the first parameter whose setter was called again after
            // its terms were already stored; `build` reports it as
            // `FormulaError::DuplicateParameter`.
            duplicate_parameter: Option<&'static str>,
        }

        impl $spec {
            /// Creates an empty model specification.
            ///
            /// No response, weights or parameter terms are set.
            #[must_use]
            pub fn new() -> Self {
                Self {
                    response: None,
                    weights: None,
                    $first_field: None,
                    $second_field: None,
                    duplicate_parameter: None,
                }
            }

            /// Sets the response column, replacing any previously set column.
            #[must_use]
            pub fn response(mut self, response: Col<f64>) -> Self {
                self.response = Some(response);
                self
            }

            /// Sets optional observation weights, replacing any previously set
            /// column.
            #[must_use]
            pub fn weights(mut self, weights: Col<f64>) -> Self {
                self.weights = Some(weights);
                self
            }

            #[doc = concat!("Sets the terms for the `", $first_name, "` distribution parameter.")]
            ///
            /// Only the first call stores its terms. A repeated call keeps the
            /// stored terms and records the parameter name, so that
            /// [`build`](Self::build) fails with
            /// `FormulaError::DuplicateParameter`.
            #[must_use]
            pub fn $first_method(mut self, terms: impl Into<TermExpr>) -> Self {
                if self.$first_field.is_some() {
                    // Keep the earliest recorded duplicate; later ones are ignored.
                    self.duplicate_parameter.get_or_insert($first_name);
                } else {
                    self.$first_field = Some(terms.into());
                }
                self
            }

            #[doc = concat!("Sets the terms for the `", $second_name, "` distribution parameter.")]
            ///
            /// Only the first call stores its terms. A repeated call keeps the
            /// stored terms and records the parameter name, so that
            /// [`build`](Self::build) fails with
            /// `FormulaError::DuplicateParameter`.
            #[must_use]
            pub fn $second_method(mut self, terms: impl Into<TermExpr>) -> Self {
                if self.$second_field.is_some() {
                    // Keep the earliest recorded duplicate; later ones are ignored.
                    self.duplicate_parameter.get_or_insert($second_name);
                } else {
                    self.$second_field = Some(terms.into());
                }
                self
            }

            /// Builds a typed core model plus formula metadata.
            ///
            /// The returned model bundles the compiled core model, a
            /// `ModelSchema` (response column, weights, row count and the
            /// fitted terms of both parameters) and the model's parameter
            /// layout.
            ///
            /// # Errors
            ///
            /// Returns `FormulaError::EmptyData` when `data` has no rows, and
            /// `FormulaError::DuplicateParameter` when a parameter setter was
            /// called again after its terms were set. Errors from resolving
            /// the response, fitting either parameter's terms, or constructing
            /// the core model are propagated as `FormulaError`.
            pub fn build<'a, D>(&self, data: &'a D) -> Result<$built<'a>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                // Checked first, so empty data is reported even when a
                // duplicate parameter was also recorded.
                if data.nrows() == 0 {
                    return Err(FormulaError::EmptyData);
                }
                if let Some(parameter) = self.duplicate_parameter {
                    return Err(FormulaError::DuplicateParameter(parameter));
                }

                let (response_col, response) = required_response(
                    $family_name,
                    $domain,
                    data,
                    &self.response,
                    &self.weights,
                )?;
                let first_build = fit_terms($first_name, self.$first_field.clone(), data)?;
                let second_build = fit_terms($second_name, self.$second_field.clone(), data)?;

                // NOTE(review): the trailing `0` is handed to `from_predictor`
                // for both blocks; its meaning is defined by `ParameterBlock`
                // and is not visible here — confirm before changing it.
                let first = ParameterBlock::<$first_param, $first_link, _, _>::from_predictor(
                    first_build.predictor,
                    first_build.penalty,
                    0,
                );
                let second = ParameterBlock::<$second_param, $second_link, _, _>::from_predictor(
                    second_build.predictor,
                    second_build.penalty,
                    0,
                );
                let blocks: $blocks = ParameterBlocks::new((first, second));
                let model: $compiled<'a> =
                    Gamlss::try_new_with_observations(<$family>::new(), blocks, response)?;
                let layout = model.parameter_layout();
                let schema = ModelSchema {
                    response: ResponseSchema {
                        col: response_col,
                        weights: self.weights.clone(),
                        nrows: data.nrows(),
                    },
                    // Same order as the blocks: first parameter, then second.
                    parameters: vec![first_build.terms, second_build.terms],
                };

                Ok(BuiltModel::new(model, schema, layout))
            }
        }

        impl Default for $spec {
            /// Equivalent to `new`.
            fn default() -> Self {
                Self::new()
            }
        }

        impl<'a> BuiltModel<$compiled<'a>> {
            /// Rebuilds compatible typed prediction blocks from fitted term metadata.
            ///
            /// The terms of each parameter are looked up in the stored schema
            /// and passed, together with `data`, to
            /// `predictor_from_fitted_terms`; penalties come from
            /// `prediction_penalty`.
            ///
            /// # Errors
            ///
            /// Propagates any error from rebuilding either parameter's
            /// predictor.
            pub fn prediction_blocks<D>(&self, data: &D) -> Result<$blocks, FormulaError>
            where
                D: DataView + ?Sized,
            {
                let first_terms = terms_for(self.schema(), $first_name);
                let second_terms = terms_for(self.schema(), $second_name);
                let first_x = predictor_from_fitted_terms($first_name, first_terms, data)?;
                let second_x = predictor_from_fitted_terms($second_name, second_terms, data)?;

                let first = ParameterBlock::<$first_param, $first_link, _, _>::from_predictor(
                    first_x,
                    prediction_penalty(first_terms),
                    0,
                );
                let second = ParameterBlock::<$second_param, $second_link, _, _>::from_predictor(
                    second_x,
                    prediction_penalty(second_terms),
                    0,
                );

                Ok(ParameterBlocks::new((first, second)))
            }

            /// Builds reusable prediction design from fitted term metadata.
            ///
            /// Wraps the result of [`prediction_blocks`](Self::prediction_blocks)
            /// in a `PredictionDesign`.
            ///
            /// # Errors
            ///
            /// Propagates any error from
            /// [`prediction_blocks`](Self::prediction_blocks).
            pub fn prediction_design<D>(
                &self,
                data: &D,
            ) -> Result<PredictionDesign<$blocks>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                Ok(PredictionDesign::new(self.prediction_blocks(data)?))
            }

            /// Predicts natural-scale distribution parameters with reusable prediction design.
            ///
            /// `theta` and the design's blocks are forwarded to the core
            /// model's `predict_theta_with_blocks`.
            ///
            /// # Errors
            ///
            /// Propagates any error from the core model's prediction as
            /// `FormulaError`.
            pub fn predict_theta_with_design(
                &self,
                theta: &[f64],
                design: &PredictionDesign<$blocks>,
            ) -> Result<Vec<<$family as gamlss_core::Family>::Theta>, FormulaError> {
                Ok(self.model().predict_theta_with_blocks(theta, design.blocks())?)
            }

            /// Predicts natural-scale distribution parameters for new rows.
            ///
            /// Builds a one-off prediction design from `data` and then calls
            /// [`predict_theta_with_design`](Self::predict_theta_with_design).
            ///
            /// # Errors
            ///
            /// Propagates any error from building the design or from the
            /// prediction itself.
            pub fn predict_theta<D>(
                &self,
                theta: &[f64],
                data: &D,
            ) -> Result<Vec<<$family as gamlss_core::Family>::Theta>, FormulaError>
            where
                D: DataView + ?Sized,
            {
                let design = self.prediction_design(data)?;
                self.predict_theta_with_design(theta, &design)
            }
        }
    };
}
245
246define_spec!(
247    /// Typed normal model specification.
248    NormalSpec, BuiltNormal, CompiledNormal, NormalBlocks, DefaultNormal;
249    family_name = "normal", domain = ResponseDomain::Finite;
250    first = mu_terms, mu, "mu", Mu, Identity;
251    second = sigma_terms, sigma, "sigma", Sigma, Log
252);
253
254define_spec!(
255    /// Typed gamma model specification.
256    GammaSpec, BuiltGamma, CompiledGamma, GammaBlocks, DefaultGamma;
257    family_name = "gamma", domain = ResponseDomain::Positive;
258    first = shape_terms, shape, "shape", Shape, Log;
259    second = rate_terms, rate, "rate", Rate, Log
260);
261
262define_spec!(
263    /// Typed log-normal model specification.
264    LogNormalSpec, BuiltLogNormal, CompiledLogNormal, LogNormalBlocks, DefaultLogNormal;
265    family_name = "log-normal", domain = ResponseDomain::Positive;
266    first = mu_terms, mu, "mu", Mu, Identity;
267    second = sigma_terms, sigma, "sigma", Sigma, Log
268);
269
270define_spec!(
271    /// Typed Weibull model specification.
272    WeibullSpec, BuiltWeibull, CompiledWeibull, WeibullBlocks, DefaultWeibull;
273    family_name = "weibull", domain = ResponseDomain::Positive;
274    first = shape_terms, shape, "shape", Shape, Log;
275    second = scale_terms, scale, "scale", Scale, Log
276);
277
278define_spec!(
279    /// Typed inverse Gaussian model specification.
280    InverseGaussianSpec, BuiltInverseGaussian, CompiledInverseGaussian, InverseGaussianBlocks, DefaultInverseGaussian;
281    family_name = "inverse Gaussian", domain = ResponseDomain::Positive;
282    first = mu_terms, mu, "mu", Mu, Log;
283    second = shape_terms, shape, "shape", Shape, Log
284);
285
286define_spec!(
287    /// Typed beta model specification.
288    BetaSpec, BuiltBeta, CompiledBeta, BetaBlocks, DefaultBeta;
289    family_name = "beta", domain = ResponseDomain::Unit;
290    first = mu_terms, mu, "mu", Mu, Logit;
291    second = precision_terms, precision, "precision", Precision, Log
292);
293
294/// Entry point namespace for typed model specifications.
295#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
296pub struct ModelSpec;
297
298impl ModelSpec {
299    /// Creates a normal model spec.
300    #[must_use]
301    pub fn normal() -> NormalSpec {
302        NormalSpec::new()
303    }
304
305    /// Creates a gamma model spec.
306    #[must_use]
307    pub fn gamma() -> GammaSpec {
308        GammaSpec::new()
309    }
310
311    /// Creates a log-normal model spec.
312    #[must_use]
313    pub fn log_normal() -> LogNormalSpec {
314        LogNormalSpec::new()
315    }
316
317    /// Creates a Weibull model spec.
318    #[must_use]
319    pub fn weibull() -> WeibullSpec {
320        WeibullSpec::new()
321    }
322
323    /// Creates an inverse Gaussian model spec.
324    #[must_use]
325    pub fn inverse_gaussian() -> InverseGaussianSpec {
326        InverseGaussianSpec::new()
327    }
328
329    /// Creates a beta model spec.
330    #[must_use]
331    pub fn beta() -> BetaSpec {
332        BetaSpec::new()
333    }
334}
335
336/// Creates a normal model spec.
337#[must_use]
338pub fn normal() -> NormalSpec {
339    ModelSpec::normal()
340}
341
342/// Creates a gamma model spec.
343#[must_use]
344pub fn gamma() -> GammaSpec {
345    ModelSpec::gamma()
346}
347
348/// Creates a log-normal model spec.
349#[must_use]
350pub fn log_normal() -> LogNormalSpec {
351    ModelSpec::log_normal()
352}
353
354/// Creates a Weibull model spec.
355#[must_use]
356pub fn weibull() -> WeibullSpec {
357    ModelSpec::weibull()
358}
359
360/// Creates an inverse Gaussian model spec.
361#[must_use]
362pub fn inverse_gaussian() -> InverseGaussianSpec {
363    ModelSpec::inverse_gaussian()
364}
365
366/// Creates a beta model spec.
367#[must_use]
368pub fn beta() -> BetaSpec {
369    ModelSpec::beta()
370}