1use std::ops::{Add, Range};
2
3use gamlss_spline::{
4 CyclicSplineSpec, ISplineBasis, MonotoneDirection, OpenUniformSplineBasis, SplineOrder,
5};
6
7use crate::{Category, Col};
8
9#[derive(Debug, Clone, PartialEq)]
11pub struct TermExpr {
12 terms: Vec<TermSpec>,
13 default_intercept_if_empty: bool,
14}
15
16impl TermExpr {
17 #[must_use]
19 pub fn new(term: TermSpec) -> Self {
20 Self {
21 terms: vec![term],
22 default_intercept_if_empty: false,
23 }
24 }
25
26 #[must_use]
28 pub fn empty() -> Self {
29 Self {
30 terms: Vec::new(),
31 default_intercept_if_empty: false,
32 }
33 }
34
35 pub(crate) fn into_terms(mut self) -> Vec<TermSpec> {
36 if self.terms.is_empty() && self.default_intercept_if_empty {
37 self.terms.push(TermSpec::Intercept);
38 }
39 self.terms
40 }
41}
42
43impl Default for TermExpr {
44 fn default() -> Self {
45 Self {
46 terms: Vec::new(),
47 default_intercept_if_empty: true,
48 }
49 }
50}
51
52impl Add for TermExpr {
53 type Output = Self;
54
55 fn add(mut self, mut rhs: Self) -> Self::Output {
56 self.terms.append(&mut rhs.terms);
57 if !rhs.default_intercept_if_empty {
58 self.default_intercept_if_empty = false;
59 }
60 self
61 }
62}
63
64#[derive(Debug, Clone, PartialEq)]
66pub enum TermSpec {
67 Intercept,
69 Linear {
71 col: Col<f64>,
73 },
74 Offset {
76 col: Col<f64>,
78 },
79 Indicator {
81 col: Col<bool>,
83 },
84 Factor {
86 col: Col<Category>,
88 },
89 Interaction {
91 left: Col<f64>,
93 right: Col<f64>,
95 },
96 PSpline(PSplineTerm),
98 CyclicPSpline(CyclicPSplineTerm),
100 Fourier(FourierTerm),
102 TensorPSpline(TensorPSplineTerm),
104 Monotone(MonotoneTerm),
106}
107
108#[must_use]
110pub fn intercept() -> TermExpr {
111 TermExpr::new(TermSpec::Intercept)
112}
113
114#[must_use]
116pub fn no_intercept() -> TermExpr {
117 TermExpr::empty()
118}
119
120#[must_use]
122pub fn linear(col: Col<f64>) -> TermExpr {
123 TermExpr::new(TermSpec::Linear { col })
124}
125
126#[must_use]
128pub fn offset(col: Col<f64>) -> TermExpr {
129 TermExpr::new(TermSpec::Offset { col })
130}
131
132#[must_use]
134pub fn indicator(col: Col<bool>) -> TermExpr {
135 TermExpr::new(TermSpec::Indicator { col })
136}
137
138#[must_use]
140pub fn factor(col: Col<Category>) -> TermExpr {
141 TermExpr::new(TermSpec::Factor { col })
142}
143
144#[must_use]
146pub fn interaction(left: Col<f64>, right: Col<f64>) -> TermExpr {
147 TermExpr::new(TermSpec::Interaction { left, right })
148}
149
150#[must_use]
152pub fn pspline(col: Col<f64>) -> PSplineTerm {
153 PSplineTerm {
154 col,
155 k: 20,
156 order: SplineOrder::Cubic,
157 lambda: 1.0,
158 penalty_order: 2,
159 }
160}
161
162#[must_use]
164pub fn cyclic_pspline(col: Col<f64>) -> CyclicPSplineTerm {
165 CyclicPSplineTerm {
166 col,
167 k: 20,
168 order: SplineOrder::Cubic,
169 lambda: 1.0,
170 penalty_order: 2,
171 }
172}
173
174#[must_use]
176pub fn fourier(col: Col<f64>) -> FourierTerm {
177 FourierTerm {
178 col,
179 period: 1.0,
180 order: 1,
181 include_intercept: false,
182 }
183}
184
185#[must_use]
187pub fn tensor_pspline(left: Col<f64>, right: Col<f64>) -> TensorPSplineTerm {
188 TensorPSplineTerm {
189 left,
190 right,
191 left_k: 10,
192 right_k: 10,
193 left_order: SplineOrder::Cubic,
194 right_order: SplineOrder::Cubic,
195 }
196}
197
198#[must_use]
200pub fn monotone(col: Col<f64>) -> MonotoneTerm {
201 MonotoneTerm {
202 col,
203 k: 20,
204 degree: SplineOrder::Cubic.degree(),
205 direction: MonotoneDirection::Increasing,
206 }
207}
208
/// Wires a term-builder type into `TermExpr` arithmetic.
///
/// Given a builder type and the `TermSpec` variant that wraps it, this emits:
/// - `From<builder> for TermExpr`,
/// - `TermExpr + builder`,
/// - `builder + T` for any `T: Into<TermExpr>`.
macro_rules! impl_term_expr {
    ($builder:ty, $case:ident) => {
        impl From<$builder> for TermExpr {
            fn from(builder: $builder) -> Self {
                let spec = TermSpec::$case(builder);
                Self::new(spec)
            }
        }

        impl Add<$builder> for TermExpr {
            type Output = TermExpr;

            fn add(self, rhs: $builder) -> Self::Output {
                // Convert first so the `TermExpr + TermExpr` impl is selected.
                let rhs: TermExpr = rhs.into();
                self + rhs
            }
        }

        impl<T> Add<T> for $builder
        where
            T: Into<TermExpr>,
        {
            type Output = TermExpr;

            fn add(self, rhs: T) -> Self::Output {
                let lhs = TermExpr::from(self);
                let rhs: TermExpr = rhs.into();
                lhs + rhs
            }
        }
    };
}
237
238#[derive(Debug, Clone, PartialEq)]
240pub struct PSplineTerm {
241 pub(crate) col: Col<f64>,
242 pub(crate) k: usize,
243 pub(crate) order: SplineOrder,
244 pub(crate) lambda: f64,
245 pub(crate) penalty_order: usize,
246}
247
248impl PSplineTerm {
249 #[must_use]
251 pub fn col(&self) -> &Col<f64> {
252 &self.col
253 }
254
255 #[must_use]
257 pub fn n_basis(&self) -> usize {
258 self.k
259 }
260
261 #[must_use]
263 pub fn spline_order(&self) -> SplineOrder {
264 self.order
265 }
266
267 #[must_use]
269 pub fn penalty_lambda(&self) -> f64 {
270 self.lambda
271 }
272
273 #[must_use]
275 pub fn difference_order(&self) -> usize {
276 self.penalty_order
277 }
278
279 #[must_use]
281 pub fn k(mut self, k: usize) -> Self {
282 self.k = k;
283 self
284 }
285
286 #[must_use]
288 pub fn order(mut self, order: SplineOrder) -> Self {
289 self.order = order;
290 self
291 }
292
293 #[must_use]
295 pub fn lambda(mut self, lambda: f64) -> Self {
296 self.lambda = lambda;
297 self
298 }
299
300 #[must_use]
302 pub fn penalty_order(mut self, order: usize) -> Self {
303 self.penalty_order = order;
304 self
305 }
306}
307
308#[derive(Debug, Clone, PartialEq)]
310pub struct CyclicPSplineTerm {
311 pub(crate) col: Col<f64>,
312 pub(crate) k: usize,
313 pub(crate) order: SplineOrder,
314 pub(crate) lambda: f64,
315 pub(crate) penalty_order: usize,
316}
317
318impl CyclicPSplineTerm {
319 #[must_use]
321 pub fn k(mut self, k: usize) -> Self {
322 self.k = k;
323 self
324 }
325
326 #[must_use]
328 pub fn order(mut self, order: SplineOrder) -> Self {
329 self.order = order;
330 self
331 }
332
333 #[must_use]
335 pub fn lambda(mut self, lambda: f64) -> Self {
336 self.lambda = lambda;
337 self
338 }
339
340 #[must_use]
342 pub fn penalty_order(mut self, order: usize) -> Self {
343 self.penalty_order = order;
344 self
345 }
346}
347
348#[derive(Debug, Clone, PartialEq)]
350pub struct FourierTerm {
351 pub(crate) col: Col<f64>,
352 pub(crate) period: f64,
353 pub(crate) order: usize,
354 pub(crate) include_intercept: bool,
355}
356
357impl FourierTerm {
358 #[must_use]
360 pub fn period(mut self, period: f64) -> Self {
361 self.period = period;
362 self
363 }
364
365 #[must_use]
367 pub fn order(mut self, order: usize) -> Self {
368 self.order = order;
369 self
370 }
371
372 #[must_use]
374 pub fn include_intercept(mut self, include: bool) -> Self {
375 self.include_intercept = include;
376 self
377 }
378}
379
380#[derive(Debug, Clone, PartialEq)]
382pub struct TensorPSplineTerm {
383 pub(crate) left: Col<f64>,
384 pub(crate) right: Col<f64>,
385 pub(crate) left_k: usize,
386 pub(crate) right_k: usize,
387 pub(crate) left_order: SplineOrder,
388 pub(crate) right_order: SplineOrder,
389}
390
391impl TensorPSplineTerm {
392 #[must_use]
394 pub fn k(mut self, left: usize, right: usize) -> Self {
395 self.left_k = left;
396 self.right_k = right;
397 self
398 }
399
400 #[must_use]
402 pub fn order(mut self, left: SplineOrder, right: SplineOrder) -> Self {
403 self.left_order = left;
404 self.right_order = right;
405 self
406 }
407}
408
409#[derive(Debug, Clone, PartialEq)]
411pub struct MonotoneTerm {
412 pub(crate) col: Col<f64>,
413 pub(crate) k: usize,
414 pub(crate) degree: usize,
415 pub(crate) direction: MonotoneDirection,
416}
417
418impl MonotoneTerm {
419 #[must_use]
421 pub fn k(mut self, k: usize) -> Self {
422 self.k = k;
423 self
424 }
425
426 #[must_use]
428 pub fn degree(mut self, degree: usize) -> Self {
429 self.degree = degree;
430 self
431 }
432
433 #[must_use]
435 pub fn direction(mut self, direction: MonotoneDirection) -> Self {
436 self.direction = direction;
437 self
438 }
439}
440
441impl_term_expr!(PSplineTerm, PSpline);
442impl_term_expr!(CyclicPSplineTerm, CyclicPSpline);
443impl_term_expr!(FourierTerm, Fourier);
444impl_term_expr!(TensorPSplineTerm, TensorPSpline);
445impl_term_expr!(MonotoneTerm, Monotone);
446
447#[derive(Debug, Clone, PartialEq)]
449pub enum FittedTerm {
450 Intercept {
452 range: Range<usize>,
454 coefficient: String,
456 },
457 Linear {
459 col: Col<f64>,
461 range: Range<usize>,
463 coefficient: String,
465 },
466 Offset {
468 col: Col<f64>,
470 },
471 Indicator {
473 col: Col<bool>,
475 range: Range<usize>,
477 coefficient: String,
479 },
480 Factor {
482 col: Col<Category>,
484 range: Range<usize>,
486 levels: Vec<String>,
488 baseline: String,
490 coefficients: Vec<String>,
492 },
493 Interaction {
495 left: Col<f64>,
497 right: Col<f64>,
499 range: Range<usize>,
501 coefficient: String,
503 },
504 PSpline {
506 col: Col<f64>,
508 range: Range<usize>,
510 basis: OpenUniformSplineBasis,
512 lambda: f64,
514 penalty_order: usize,
516 coefficients: Vec<String>,
518 },
519 CyclicPSpline {
521 col: Col<f64>,
523 range: Range<usize>,
525 spec: CyclicSplineSpec,
527 lambda: f64,
529 penalty_order: usize,
531 coefficients: Vec<String>,
533 },
534 Fourier {
536 col: Col<f64>,
538 range: Range<usize>,
540 period: f64,
542 order: usize,
544 include_intercept: bool,
546 coefficients: Vec<String>,
548 },
549 TensorPSpline {
551 left: Col<f64>,
553 right: Col<f64>,
555 range: Range<usize>,
557 left_basis: OpenUniformSplineBasis,
559 right_basis: OpenUniformSplineBasis,
561 coefficients: Vec<String>,
563 },
564 Monotone {
566 col: Col<f64>,
568 range: Range<usize>,
570 basis: ISplineBasis,
572 direction: MonotoneDirection,
574 coefficients: Vec<String>,
576 },
577}
578
579impl FittedTerm {
580 #[must_use]
582 pub fn range(&self) -> Range<usize> {
583 match self {
584 Self::Intercept { range, .. }
585 | Self::Linear { range, .. }
586 | Self::Indicator { range, .. }
587 | Self::Factor { range, .. }
588 | Self::Interaction { range, .. }
589 | Self::PSpline { range, .. }
590 | Self::CyclicPSpline { range, .. }
591 | Self::Fourier { range, .. }
592 | Self::TensorPSpline { range, .. }
593 | Self::Monotone { range, .. } => range.clone(),
594 Self::Offset { .. } => 0..0,
595 }
596 }
597
598 pub(crate) fn append_coefficient_names<'a>(&'a self, out: &mut Vec<&'a str>) {
600 match self {
601 Self::Intercept { coefficient, .. }
602 | Self::Linear { coefficient, .. }
603 | Self::Indicator { coefficient, .. }
604 | Self::Interaction { coefficient, .. } => out.push(coefficient.as_str()),
605 Self::PSpline { coefficients, .. }
606 | Self::CyclicPSpline { coefficients, .. }
607 | Self::Fourier { coefficients, .. }
608 | Self::Factor { coefficients, .. }
609 | Self::TensorPSpline { coefficients, .. }
610 | Self::Monotone { coefficients, .. } => {
611 out.extend(coefficients.iter().map(String::as_str));
612 }
613 Self::Offset { .. } => {}
614 }
615 }
616}