Skip to main content

gamlss_formula/
terms.rs

1use std::ops::{Add, Range};
2
3use gamlss_spline::{
4    CyclicSplineSpec, ISplineBasis, MonotoneDirection, OpenUniformSplineBasis, SplineOrder,
5};
6
7use crate::{Category, Col};
8
9/// Pre-data term expression for one parameter predictor.
10#[derive(Debug, Clone, PartialEq)]
11pub struct TermExpr {
12    terms: Vec<TermSpec>,
13    default_intercept_if_empty: bool,
14}
15
16impl TermExpr {
17    /// Creates an expression from one term.
18    #[must_use]
19    pub fn new(term: TermSpec) -> Self {
20        Self {
21            terms: vec![term],
22            default_intercept_if_empty: false,
23        }
24    }
25
26    /// Creates an explicit empty expression with no implicit intercept.
27    #[must_use]
28    pub fn empty() -> Self {
29        Self {
30            terms: Vec::new(),
31            default_intercept_if_empty: false,
32        }
33    }
34
35    pub(crate) fn into_terms(mut self) -> Vec<TermSpec> {
36        if self.terms.is_empty() && self.default_intercept_if_empty {
37            self.terms.push(TermSpec::Intercept);
38        }
39        self.terms
40    }
41}
42
43impl Default for TermExpr {
44    fn default() -> Self {
45        Self {
46            terms: Vec::new(),
47            default_intercept_if_empty: true,
48        }
49    }
50}
51
52impl Add for TermExpr {
53    type Output = Self;
54
55    fn add(mut self, mut rhs: Self) -> Self::Output {
56        self.terms.append(&mut rhs.terms);
57        if !rhs.default_intercept_if_empty {
58            self.default_intercept_if_empty = false;
59        }
60        self
61    }
62}
63
64/// Pre-data term specification.
65#[derive(Debug, Clone, PartialEq)]
66pub enum TermSpec {
67    /// Intercept column of ones.
68    Intercept,
69    /// Linear term backed by one numeric column.
70    Linear {
71        /// Source column.
72        col: Col<f64>,
73    },
74    /// Numeric offset added to the predictor without adding coefficients.
75    Offset {
76        /// Source column.
77        col: Col<f64>,
78    },
79    /// Boolean indicator term.
80    Indicator {
81        /// Source column.
82        col: Col<bool>,
83    },
84    /// Treatment-coded categorical term.
85    Factor {
86        /// Source column.
87        col: Col<Category>,
88    },
89    /// Product of two numeric columns.
90    Interaction {
91        /// Left source column.
92        left: Col<f64>,
93        /// Right source column.
94        right: Col<f64>,
95    },
96    /// Open-uniform P-spline term.
97    PSpline(PSplineTerm),
98    /// Cyclic P-spline term.
99    CyclicPSpline(CyclicPSplineTerm),
100    /// Fourier seasonal term.
101    Fourier(FourierTerm),
102    /// Tensor-product open-uniform P-spline term.
103    TensorPSpline(TensorPSplineTerm),
104    /// Hard-monotone I-spline term.
105    Monotone(MonotoneTerm),
106}
107
108/// Creates an intercept term expression.
109#[must_use]
110pub fn intercept() -> TermExpr {
111    TermExpr::new(TermSpec::Intercept)
112}
113
114/// Creates an explicit empty term expression without an intercept.
115#[must_use]
116pub fn no_intercept() -> TermExpr {
117    TermExpr::empty()
118}
119
120/// Creates a linear term expression.
121#[must_use]
122pub fn linear(col: Col<f64>) -> TermExpr {
123    TermExpr::new(TermSpec::Linear { col })
124}
125
126/// Creates a numeric offset term expression.
127#[must_use]
128pub fn offset(col: Col<f64>) -> TermExpr {
129    TermExpr::new(TermSpec::Offset { col })
130}
131
132/// Creates a boolean indicator term expression.
133#[must_use]
134pub fn indicator(col: Col<bool>) -> TermExpr {
135    TermExpr::new(TermSpec::Indicator { col })
136}
137
138/// Creates a categorical factor term expression.
139#[must_use]
140pub fn factor(col: Col<Category>) -> TermExpr {
141    TermExpr::new(TermSpec::Factor { col })
142}
143
144/// Creates a numeric product interaction term expression.
145#[must_use]
146pub fn interaction(left: Col<f64>, right: Col<f64>) -> TermExpr {
147    TermExpr::new(TermSpec::Interaction { left, right })
148}
149
150/// Creates an open-uniform P-spline term expression with default options.
151#[must_use]
152pub fn pspline(col: Col<f64>) -> PSplineTerm {
153    PSplineTerm {
154        col,
155        k: 20,
156        order: SplineOrder::Cubic,
157        lambda: 1.0,
158        penalty_order: 2,
159    }
160}
161
162/// Creates a cyclic P-spline term expression with default options.
163#[must_use]
164pub fn cyclic_pspline(col: Col<f64>) -> CyclicPSplineTerm {
165    CyclicPSplineTerm {
166        col,
167        k: 20,
168        order: SplineOrder::Cubic,
169        lambda: 1.0,
170        penalty_order: 2,
171    }
172}
173
174/// Creates a Fourier term expression with default options.
175#[must_use]
176pub fn fourier(col: Col<f64>) -> FourierTerm {
177    FourierTerm {
178        col,
179        period: 1.0,
180        order: 1,
181        include_intercept: false,
182    }
183}
184
185/// Creates a tensor-product P-spline term expression with default options.
186#[must_use]
187pub fn tensor_pspline(left: Col<f64>, right: Col<f64>) -> TensorPSplineTerm {
188    TensorPSplineTerm {
189        left,
190        right,
191        left_k: 10,
192        right_k: 10,
193        left_order: SplineOrder::Cubic,
194        right_order: SplineOrder::Cubic,
195    }
196}
197
198/// Creates a hard-monotone I-spline term expression with default options.
199#[must_use]
200pub fn monotone(col: Col<f64>) -> MonotoneTerm {
201    MonotoneTerm {
202        col,
203        k: 20,
204        degree: SplineOrder::Cubic.degree(),
205        direction: MonotoneDirection::Increasing,
206    }
207}
208
// Wires a term-builder type into the `TermExpr` algebra.
//
// For `$builder` wrapped by the `TermSpec::$ctor` variant this generates:
//   * `From<$builder> for TermExpr`,
//   * `TermExpr + $builder`,
//   * `$builder + Rhs` for any `Rhs: Into<TermExpr>`.
macro_rules! impl_term_expr {
    ($builder:ty, $ctor:ident) => {
        impl From<$builder> for TermExpr {
            fn from(builder: $builder) -> Self {
                TermExpr::new(TermSpec::$ctor(builder))
            }
        }

        impl Add<$builder> for TermExpr {
            type Output = TermExpr;

            fn add(self, rhs: $builder) -> Self::Output {
                let rhs: TermExpr = rhs.into();
                self + rhs
            }
        }

        impl<Rhs: Into<TermExpr>> Add<Rhs> for $builder {
            type Output = TermExpr;

            fn add(self, rhs: Rhs) -> Self::Output {
                let lhs: TermExpr = self.into();
                lhs + rhs.into()
            }
        }
    };
}
237
238/// Open-uniform P-spline term options.
239#[derive(Debug, Clone, PartialEq)]
240pub struct PSplineTerm {
241    pub(crate) col: Col<f64>,
242    pub(crate) k: usize,
243    pub(crate) order: SplineOrder,
244    pub(crate) lambda: f64,
245    pub(crate) penalty_order: usize,
246}
247
248impl PSplineTerm {
249    /// Returns the source column.
250    #[must_use]
251    pub fn col(&self) -> &Col<f64> {
252        &self.col
253    }
254
255    /// Returns the number of spline basis functions.
256    #[must_use]
257    pub fn n_basis(&self) -> usize {
258        self.k
259    }
260
261    /// Returns the spline order.
262    #[must_use]
263    pub fn spline_order(&self) -> SplineOrder {
264        self.order
265    }
266
267    /// Returns the difference penalty weight.
268    #[must_use]
269    pub fn penalty_lambda(&self) -> f64 {
270        self.lambda
271    }
272
273    /// Returns the finite-difference penalty order.
274    #[must_use]
275    pub fn difference_order(&self) -> usize {
276        self.penalty_order
277    }
278
279    /// Sets the number of spline basis functions.
280    #[must_use]
281    pub fn k(mut self, k: usize) -> Self {
282        self.k = k;
283        self
284    }
285
286    /// Sets the spline order.
287    #[must_use]
288    pub fn order(mut self, order: SplineOrder) -> Self {
289        self.order = order;
290        self
291    }
292
293    /// Sets the difference penalty weight.
294    #[must_use]
295    pub fn lambda(mut self, lambda: f64) -> Self {
296        self.lambda = lambda;
297        self
298    }
299
300    /// Sets the finite-difference penalty order.
301    #[must_use]
302    pub fn penalty_order(mut self, order: usize) -> Self {
303        self.penalty_order = order;
304        self
305    }
306}
307
308/// Cyclic P-spline term options.
309#[derive(Debug, Clone, PartialEq)]
310pub struct CyclicPSplineTerm {
311    pub(crate) col: Col<f64>,
312    pub(crate) k: usize,
313    pub(crate) order: SplineOrder,
314    pub(crate) lambda: f64,
315    pub(crate) penalty_order: usize,
316}
317
318impl CyclicPSplineTerm {
319    /// Sets the number of spline basis functions.
320    #[must_use]
321    pub fn k(mut self, k: usize) -> Self {
322        self.k = k;
323        self
324    }
325
326    /// Sets the spline order.
327    #[must_use]
328    pub fn order(mut self, order: SplineOrder) -> Self {
329        self.order = order;
330        self
331    }
332
333    /// Sets the cyclic difference penalty weight.
334    #[must_use]
335    pub fn lambda(mut self, lambda: f64) -> Self {
336        self.lambda = lambda;
337        self
338    }
339
340    /// Sets the finite-difference penalty order.
341    #[must_use]
342    pub fn penalty_order(mut self, order: usize) -> Self {
343        self.penalty_order = order;
344        self
345    }
346}
347
348/// Fourier term options.
349#[derive(Debug, Clone, PartialEq)]
350pub struct FourierTerm {
351    pub(crate) col: Col<f64>,
352    pub(crate) period: f64,
353    pub(crate) order: usize,
354    pub(crate) include_intercept: bool,
355}
356
357impl FourierTerm {
358    /// Sets the Fourier period.
359    #[must_use]
360    pub fn period(mut self, period: f64) -> Self {
361        self.period = period;
362        self
363    }
364
365    /// Sets the Fourier order.
366    #[must_use]
367    pub fn order(mut self, order: usize) -> Self {
368        self.order = order;
369        self
370    }
371
372    /// Includes an intercept column in this term.
373    #[must_use]
374    pub fn include_intercept(mut self, include: bool) -> Self {
375        self.include_intercept = include;
376        self
377    }
378}
379
380/// Tensor-product P-spline term options.
381#[derive(Debug, Clone, PartialEq)]
382pub struct TensorPSplineTerm {
383    pub(crate) left: Col<f64>,
384    pub(crate) right: Col<f64>,
385    pub(crate) left_k: usize,
386    pub(crate) right_k: usize,
387    pub(crate) left_order: SplineOrder,
388    pub(crate) right_order: SplineOrder,
389}
390
391impl TensorPSplineTerm {
392    /// Sets basis counts for the left and right axes.
393    #[must_use]
394    pub fn k(mut self, left: usize, right: usize) -> Self {
395        self.left_k = left;
396        self.right_k = right;
397        self
398    }
399
400    /// Sets spline orders for the left and right axes.
401    #[must_use]
402    pub fn order(mut self, left: SplineOrder, right: SplineOrder) -> Self {
403        self.left_order = left;
404        self.right_order = right;
405        self
406    }
407}
408
409/// Hard-monotone I-spline term options.
410#[derive(Debug, Clone, PartialEq)]
411pub struct MonotoneTerm {
412    pub(crate) col: Col<f64>,
413    pub(crate) k: usize,
414    pub(crate) degree: usize,
415    pub(crate) direction: MonotoneDirection,
416}
417
418impl MonotoneTerm {
419    /// Sets the number of I-spline basis functions.
420    #[must_use]
421    pub fn k(mut self, k: usize) -> Self {
422        self.k = k;
423        self
424    }
425
426    /// Sets the I-spline degree.
427    #[must_use]
428    pub fn degree(mut self, degree: usize) -> Self {
429        self.degree = degree;
430        self
431    }
432
433    /// Sets the monotonicity direction.
434    #[must_use]
435    pub fn direction(mut self, direction: MonotoneDirection) -> Self {
436        self.direction = direction;
437        self
438    }
439}
440
// Wire every configurable term builder into `TermExpr`: each invocation
// generates the `From` conversion and both `Add` impls for the named type,
// wrapping it in the `TermSpec` variant given as the second argument.
impl_term_expr!(PSplineTerm, PSpline);
impl_term_expr!(CyclicPSplineTerm, CyclicPSpline);
impl_term_expr!(FourierTerm, Fourier);
impl_term_expr!(TensorPSplineTerm, TensorPSpline);
impl_term_expr!(MonotoneTerm, Monotone);
446
447/// Fitted term metadata reusable for prediction.
448#[derive(Debug, Clone, PartialEq)]
449pub enum FittedTerm {
450    /// Intercept term.
451    Intercept {
452        /// Local coefficient range inside the parameter block.
453        range: Range<usize>,
454        /// Coefficient name.
455        coefficient: String,
456    },
457    /// Linear term.
458    Linear {
459        /// Source column.
460        col: Col<f64>,
461        /// Local coefficient range inside the parameter block.
462        range: Range<usize>,
463        /// Coefficient name.
464        coefficient: String,
465    },
466    /// Numeric offset term.
467    Offset {
468        /// Source column.
469        col: Col<f64>,
470    },
471    /// Boolean indicator term.
472    Indicator {
473        /// Source column.
474        col: Col<bool>,
475        /// Local coefficient range inside the parameter block.
476        range: Range<usize>,
477        /// Coefficient name.
478        coefficient: String,
479    },
480    /// Treatment-coded categorical factor.
481    Factor {
482        /// Source column.
483        col: Col<Category>,
484        /// Local coefficient range inside the parameter block.
485        range: Range<usize>,
486        /// Sorted training levels.
487        levels: Vec<String>,
488        /// Baseline level.
489        baseline: String,
490        /// Coefficient names.
491        coefficients: Vec<String>,
492    },
493    /// Numeric product interaction.
494    Interaction {
495        /// Left source column.
496        left: Col<f64>,
497        /// Right source column.
498        right: Col<f64>,
499        /// Local coefficient range inside the parameter block.
500        range: Range<usize>,
501        /// Coefficient name.
502        coefficient: String,
503    },
504    /// P-spline term with fitted basis metadata.
505    PSpline {
506        /// Source column.
507        col: Col<f64>,
508        /// Local coefficient range inside the parameter block.
509        range: Range<usize>,
510        /// Fitted spline basis metadata.
511        basis: OpenUniformSplineBasis,
512        /// Penalty weight.
513        lambda: f64,
514        /// Difference penalty order.
515        penalty_order: usize,
516        /// Coefficient names for this term.
517        coefficients: Vec<String>,
518    },
519    /// Cyclic P-spline term with fitted metadata.
520    CyclicPSpline {
521        /// Source column.
522        col: Col<f64>,
523        /// Local coefficient range inside the parameter block.
524        range: Range<usize>,
525        /// Cyclic spline metadata.
526        spec: CyclicSplineSpec,
527        /// Penalty weight.
528        lambda: f64,
529        /// Difference penalty order.
530        penalty_order: usize,
531        /// Coefficient names for this term.
532        coefficients: Vec<String>,
533    },
534    /// Fourier term with fitted metadata.
535    Fourier {
536        /// Source column.
537        col: Col<f64>,
538        /// Local coefficient range inside the parameter block.
539        range: Range<usize>,
540        /// Period.
541        period: f64,
542        /// Fourier order.
543        order: usize,
544        /// Whether this term includes an intercept coefficient.
545        include_intercept: bool,
546        /// Coefficient names for this term.
547        coefficients: Vec<String>,
548    },
549    /// Tensor-product P-spline metadata.
550    TensorPSpline {
551        /// Left source column.
552        left: Col<f64>,
553        /// Right source column.
554        right: Col<f64>,
555        /// Local coefficient range inside the parameter block.
556        range: Range<usize>,
557        /// Left fitted basis metadata.
558        left_basis: OpenUniformSplineBasis,
559        /// Right fitted basis metadata.
560        right_basis: OpenUniformSplineBasis,
561        /// Coefficient names for this term.
562        coefficients: Vec<String>,
563    },
564    /// Hard-monotone I-spline metadata.
565    Monotone {
566        /// Source column.
567        col: Col<f64>,
568        /// Local coefficient range inside the parameter block.
569        range: Range<usize>,
570        /// Fitted I-spline basis metadata.
571        basis: ISplineBasis,
572        /// Monotonicity direction.
573        direction: MonotoneDirection,
574        /// Coefficient names for this term.
575        coefficients: Vec<String>,
576    },
577}
578
579impl FittedTerm {
580    /// Local coefficient range inside the parameter block.
581    #[must_use]
582    pub fn range(&self) -> Range<usize> {
583        match self {
584            Self::Intercept { range, .. }
585            | Self::Linear { range, .. }
586            | Self::Indicator { range, .. }
587            | Self::Factor { range, .. }
588            | Self::Interaction { range, .. }
589            | Self::PSpline { range, .. }
590            | Self::CyclicPSpline { range, .. }
591            | Self::Fourier { range, .. }
592            | Self::TensorPSpline { range, .. }
593            | Self::Monotone { range, .. } => range.clone(),
594            Self::Offset { .. } => 0..0,
595        }
596    }
597
598    /// Appends coefficient names for this term.
599    pub(crate) fn append_coefficient_names<'a>(&'a self, out: &mut Vec<&'a str>) {
600        match self {
601            Self::Intercept { coefficient, .. }
602            | Self::Linear { coefficient, .. }
603            | Self::Indicator { coefficient, .. }
604            | Self::Interaction { coefficient, .. } => out.push(coefficient.as_str()),
605            Self::PSpline { coefficients, .. }
606            | Self::CyclicPSpline { coefficients, .. }
607            | Self::Fourier { coefficients, .. }
608            | Self::Factor { coefficients, .. }
609            | Self::TensorPSpline { coefficients, .. }
610            | Self::Monotone { coefficients, .. } => {
611                out.extend(coefficients.iter().map(String::as_str));
612            }
613            Self::Offset { .. } => {}
614        }
615    }
616}