use std::ops::{Add, Range};
use gamlss_spline::{
CyclicSplineSpec, ISplineBasis, MonotoneDirection, OpenUniformSplineBasis, SplineOrder,
};
use crate::{Category, Col};
/// An ordered, additive collection of model terms, combined with `+`.
///
/// `default_intercept_if_empty` decides whether `into_terms` supplies a lone
/// `TermSpec::Intercept` when the collection ends up with no terms at all.
#[derive(Debug, Clone, PartialEq)]
pub struct TermExpr {
    /// Terms in the order they were added.
    terms: Vec<TermSpec>,
    /// When set and `terms` is empty, `into_terms` yields an intercept.
    default_intercept_if_empty: bool,
}

impl TermExpr {
    /// Creates an expression holding exactly `term`, without the implicit
    /// intercept fallback.
    #[must_use]
    pub fn new(term: TermSpec) -> Self {
        Self {
            default_intercept_if_empty: false,
            terms: vec![term],
        }
    }

    /// Creates an expression with no terms and no implicit intercept fallback.
    #[must_use]
    pub fn empty() -> Self {
        Self {
            default_intercept_if_empty: false,
            terms: vec![],
        }
    }

    /// Consumes the expression and returns its terms.
    ///
    /// If no terms were collected and the fallback flag is set, the result is
    /// `[TermSpec::Intercept]`; otherwise the stored terms are returned as-is.
    pub(crate) fn into_terms(self) -> Vec<TermSpec> {
        let Self {
            mut terms,
            default_intercept_if_empty,
        } = self;
        if default_intercept_if_empty && terms.is_empty() {
            terms.push(TermSpec::Intercept);
        }
        terms
    }
}
/// The default expression is empty but requests the implicit intercept: while
/// it stays empty, `into_terms` returns a single `TermSpec::Intercept`.
impl Default for TermExpr {
    fn default() -> Self {
        let terms = Vec::new();
        Self {
            default_intercept_if_empty: true,
            terms,
        }
    }
}
/// Concatenates two expressions, keeping the left operand's terms first.
///
/// The implicit-intercept fallback survives only if both operands request it.
/// Adding any expression built with `TermExpr::new` or `TermExpr::empty`
/// (e.g. `no_intercept()`) therefore switches the fallback off.
impl Add for TermExpr {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        let Self {
            mut terms,
            default_intercept_if_empty: lhs_flag,
        } = self;
        let Self {
            terms: rhs_terms,
            default_intercept_if_empty: rhs_flag,
        } = rhs;
        terms.extend(rhs_terms);
        Self {
            terms,
            default_intercept_if_empty: lhs_flag && rhs_flag,
        }
    }
}
/// One unfitted term of a model formula.
///
/// Normally produced by the constructor functions in this module
/// (`intercept`, `linear`, `pspline`, ...) and gathered into a [`TermExpr`].
#[derive(Debug, Clone, PartialEq)]
pub enum TermSpec {
    /// Intercept term.
    Intercept,
    /// Numeric column entering through a single coefficient.
    Linear { col: Col<f64> },
    /// Numeric column used as an offset; its fitted form owns no coefficients.
    Offset { col: Col<f64> },
    /// Boolean column entering through a single coefficient.
    Indicator { col: Col<bool> },
    /// Categorical column; its fitted form records levels, a baseline level
    /// and a list of coefficients.
    Factor { col: Col<Category> },
    /// Interaction of two numeric columns, fitted with a single coefficient.
    Interaction { left: Col<f64>, right: Col<f64> },
    /// Penalized spline smooth; see [`PSplineTerm`].
    PSpline(PSplineTerm),
    /// Cyclic penalized spline smooth; see [`CyclicPSplineTerm`].
    CyclicPSpline(CyclicPSplineTerm),
    /// Fourier term; see [`FourierTerm`].
    Fourier(FourierTerm),
    /// Tensor-product spline over two columns; see [`TensorPSplineTerm`].
    TensorPSpline(TensorPSplineTerm),
    /// Monotone smooth; see [`MonotoneTerm`].
    Monotone(MonotoneTerm),
}
/// Expression containing just an explicit intercept term.
#[must_use]
pub fn intercept() -> TermExpr {
    let spec = TermSpec::Intercept;
    TermExpr::new(spec)
}

/// Expression with no terms and no implicit intercept fallback.
///
/// Unlike `TermExpr::default()`, it never turns into an intercept when empty,
/// and adding it to a default expression switches that fallback off.
#[must_use]
pub fn no_intercept() -> TermExpr {
    TermExpr::empty()
}
/// Expression with a single linear term on the numeric column `col`.
#[must_use]
pub fn linear(col: Col<f64>) -> TermExpr {
    let spec = TermSpec::Linear { col };
    TermExpr::new(spec)
}
/// Expression with a single offset term on the numeric column `col`.
///
/// Offsets own no coefficients once fitted.
#[must_use]
pub fn offset(col: Col<f64>) -> TermExpr {
    let spec = TermSpec::Offset { col };
    TermExpr::new(spec)
}
/// Expression with a single indicator term on the boolean column `col`.
#[must_use]
pub fn indicator(col: Col<bool>) -> TermExpr {
    let spec = TermSpec::Indicator { col };
    TermExpr::new(spec)
}
/// Expression with a single factor term on the categorical column `col`.
#[must_use]
pub fn factor(col: Col<Category>) -> TermExpr {
    let spec = TermSpec::Factor { col };
    TermExpr::new(spec)
}
/// Expression with a single interaction term between two numeric columns.
#[must_use]
pub fn interaction(left: Col<f64>, right: Col<f64>) -> TermExpr {
    let spec = TermSpec::Interaction { left, right };
    TermExpr::new(spec)
}
/// Starts a P-spline (penalized spline) term on `col`.
///
/// Defaults: 20 basis functions, cubic order, penalty weight `1.0` and a
/// second-order difference penalty. Adjust them with the [`PSplineTerm`]
/// builder methods.
#[must_use]
pub fn pspline(col: Col<f64>) -> PSplineTerm {
    const N_BASIS: usize = 20;
    const LAMBDA: f64 = 1.0;
    const DIFF_ORDER: usize = 2;
    PSplineTerm {
        penalty_order: DIFF_ORDER,
        lambda: LAMBDA,
        order: SplineOrder::Cubic,
        k: N_BASIS,
        col,
    }
}
/// Starts a cyclic P-spline term on `col`.
///
/// Defaults: 20 basis functions, cubic order, penalty weight `1.0` and a
/// second-order difference penalty. Adjust them with the
/// [`CyclicPSplineTerm`] builder methods.
#[must_use]
pub fn cyclic_pspline(col: Col<f64>) -> CyclicPSplineTerm {
    const N_BASIS: usize = 20;
    const LAMBDA: f64 = 1.0;
    const DIFF_ORDER: usize = 2;
    CyclicPSplineTerm {
        penalty_order: DIFF_ORDER,
        lambda: LAMBDA,
        order: SplineOrder::Cubic,
        k: N_BASIS,
        col,
    }
}
/// Starts a Fourier term on `col`.
///
/// Defaults: period `1.0`, order `1`, and `include_intercept` off. Adjust
/// them with the [`FourierTerm`] builder methods.
#[must_use]
pub fn fourier(col: Col<f64>) -> FourierTerm {
    let period = 1.0;
    let order = 1;
    FourierTerm {
        include_intercept: false,
        order,
        period,
        col,
    }
}
/// Starts a tensor-product P-spline term over the columns `left` and `right`.
///
/// Defaults: 10 basis functions and cubic order on each margin. Adjust them
/// with the [`TensorPSplineTerm`] builder methods.
#[must_use]
pub fn tensor_pspline(left: Col<f64>, right: Col<f64>) -> TensorPSplineTerm {
    let (left_k, right_k) = (10, 10);
    let (left_order, right_order) = (SplineOrder::Cubic, SplineOrder::Cubic);
    TensorPSplineTerm {
        left,
        right,
        left_k,
        right_k,
        left_order,
        right_order,
    }
}
/// Starts a monotone smooth term on `col`.
///
/// Defaults: 20 basis functions, the degree of a cubic `SplineOrder`, and an
/// increasing direction. Adjust them with the [`MonotoneTerm`] builder
/// methods.
#[must_use]
pub fn monotone(col: Col<f64>) -> MonotoneTerm {
    let degree = SplineOrder::Cubic.degree();
    MonotoneTerm {
        direction: MonotoneDirection::Increasing,
        degree,
        k: 20,
        col,
    }
}
/// Wires a term builder type into the additive formula syntax.
///
/// `impl_term_expr!(Builder, Variant)` generates:
/// - `From<Builder> for TermExpr`, wrapping the value in `TermSpec::Variant`;
/// - `TermExpr + Builder`, appending the builder's term;
/// - `Builder + rhs` for any `rhs: Into<TermExpr>` (another builder or a
///   `TermExpr`), with the builder's term first.
macro_rules! impl_term_expr {
    ($builder:ty, $spec_variant:ident) => {
        impl From<$builder> for TermExpr {
            fn from(term: $builder) -> Self {
                let spec = TermSpec::$spec_variant(term);
                Self::new(spec)
            }
        }

        impl Add<$builder> for TermExpr {
            type Output = TermExpr;

            fn add(self, rhs: $builder) -> Self::Output {
                let rhs_expr: TermExpr = rhs.into();
                self + rhs_expr
            }
        }

        impl<R> Add<R> for $builder
        where
            R: Into<TermExpr>,
        {
            type Output = TermExpr;

            fn add(self, rhs: R) -> Self::Output {
                let lhs: TermExpr = self.into();
                let rhs: TermExpr = rhs.into();
                lhs + rhs
            }
        }
    };
}
/// Builder for a P-spline (penalized spline) smooth term, created by `pspline`.
///
/// Configure it with the chained setters and read the settings back with the
/// accessor methods.
#[derive(Debug, Clone, PartialEq)]
pub struct PSplineTerm {
    /// Numeric column the smooth is built on.
    pub(crate) col: Col<f64>,
    /// Number of basis functions.
    pub(crate) k: usize,
    /// Spline order of the basis.
    pub(crate) order: SplineOrder,
    /// Weight of the smoothness penalty.
    pub(crate) lambda: f64,
    /// Order of the difference penalty.
    pub(crate) penalty_order: usize,
}

impl PSplineTerm {
    /// Column the smooth is built on.
    #[must_use]
    pub fn col(&self) -> &Col<f64> {
        &self.col
    }

    /// Configured number of basis functions (`k`).
    #[must_use]
    pub fn n_basis(&self) -> usize {
        self.k
    }

    /// Configured spline order.
    #[must_use]
    pub fn spline_order(&self) -> SplineOrder {
        self.order
    }

    /// Configured penalty weight (`lambda`).
    #[must_use]
    pub fn penalty_lambda(&self) -> f64 {
        self.lambda
    }

    /// Configured order of the difference penalty.
    #[must_use]
    pub fn difference_order(&self) -> usize {
        self.penalty_order
    }

    /// Sets the number of basis functions.
    #[must_use]
    pub fn k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Sets the spline order.
    #[must_use]
    pub fn order(self, order: SplineOrder) -> Self {
        Self { order, ..self }
    }

    /// Sets the penalty weight.
    #[must_use]
    pub fn lambda(self, lambda: f64) -> Self {
        Self { lambda, ..self }
    }

    /// Sets the order of the difference penalty.
    #[must_use]
    pub fn penalty_order(self, order: usize) -> Self {
        Self {
            penalty_order: order,
            ..self
        }
    }
}
/// Builder for a cyclic P-spline smooth term, created by `cyclic_pspline`.
///
/// It has the same settings as `PSplineTerm` and now exposes the same
/// read accessors, so callers can inspect both kinds of spline term uniformly.
#[derive(Debug, Clone, PartialEq)]
pub struct CyclicPSplineTerm {
    /// Numeric column the smooth is built on.
    pub(crate) col: Col<f64>,
    /// Number of basis functions.
    pub(crate) k: usize,
    /// Spline order of the basis.
    pub(crate) order: SplineOrder,
    /// Weight of the smoothness penalty.
    pub(crate) lambda: f64,
    /// Order of the difference penalty.
    pub(crate) penalty_order: usize,
}

impl CyclicPSplineTerm {
    /// Column the smooth is built on.
    #[must_use]
    pub fn col(&self) -> &Col<f64> {
        &self.col
    }

    /// Configured number of basis functions (`k`).
    #[must_use]
    pub fn n_basis(&self) -> usize {
        self.k
    }

    /// Configured spline order.
    #[must_use]
    pub fn spline_order(&self) -> SplineOrder {
        self.order
    }

    /// Configured penalty weight (`lambda`).
    #[must_use]
    pub fn penalty_lambda(&self) -> f64 {
        self.lambda
    }

    /// Configured order of the difference penalty.
    #[must_use]
    pub fn difference_order(&self) -> usize {
        self.penalty_order
    }

    /// Sets the number of basis functions.
    #[must_use]
    pub fn k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Sets the spline order.
    #[must_use]
    pub fn order(mut self, order: SplineOrder) -> Self {
        self.order = order;
        self
    }

    /// Sets the penalty weight.
    #[must_use]
    pub fn lambda(mut self, lambda: f64) -> Self {
        self.lambda = lambda;
        self
    }

    /// Sets the order of the difference penalty.
    #[must_use]
    pub fn penalty_order(mut self, order: usize) -> Self {
        self.penalty_order = order;
        self
    }
}
/// Builder for a Fourier term on a numeric column, created by `fourier`.
#[derive(Debug, Clone, PartialEq)]
pub struct FourierTerm {
    /// Numeric column the term is built on.
    pub(crate) col: Col<f64>,
    /// Period of the basis (presumably in the units of `col` — confirm).
    pub(crate) period: f64,
    /// Fourier order (presumably the number of harmonics — confirm).
    pub(crate) order: usize,
    /// Whether the term includes its own intercept component.
    pub(crate) include_intercept: bool,
}

impl FourierTerm {
    /// Sets the period.
    #[must_use]
    pub fn period(self, period: f64) -> Self {
        Self { period, ..self }
    }

    /// Sets the Fourier order.
    #[must_use]
    pub fn order(self, order: usize) -> Self {
        Self { order, ..self }
    }

    /// Turns the term's own intercept component on or off.
    #[must_use]
    pub fn include_intercept(self, include: bool) -> Self {
        Self {
            include_intercept: include,
            ..self
        }
    }
}
/// Builder for a tensor-product P-spline over two numeric columns, created by
/// `tensor_pspline`. Each margin has its own basis size and spline order.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorPSplineTerm {
    /// First (left) margin column.
    pub(crate) left: Col<f64>,
    /// Second (right) margin column.
    pub(crate) right: Col<f64>,
    /// Number of basis functions on the left margin.
    pub(crate) left_k: usize,
    /// Number of basis functions on the right margin.
    pub(crate) right_k: usize,
    /// Spline order on the left margin.
    pub(crate) left_order: SplineOrder,
    /// Spline order on the right margin.
    pub(crate) right_order: SplineOrder,
}

impl TensorPSplineTerm {
    /// Sets the number of basis functions for the left and right margins.
    #[must_use]
    pub fn k(self, left: usize, right: usize) -> Self {
        Self {
            left_k: left,
            right_k: right,
            ..self
        }
    }

    /// Sets the spline order for the left and right margins.
    #[must_use]
    pub fn order(self, left: SplineOrder, right: SplineOrder) -> Self {
        Self {
            left_order: left,
            right_order: right,
            ..self
        }
    }
}
/// Builder for a monotone smooth term, created by `monotone`.
///
/// Its fitted counterpart (`FittedTerm::Monotone`) stores an I-spline basis
/// and the chosen direction.
#[derive(Debug, Clone, PartialEq)]
pub struct MonotoneTerm {
    /// Numeric column the smooth is built on.
    pub(crate) col: Col<f64>,
    /// Basis size (`k`).
    pub(crate) k: usize,
    /// Degree of the basis.
    pub(crate) degree: usize,
    /// Monotonicity direction.
    pub(crate) direction: MonotoneDirection,
}

impl MonotoneTerm {
    /// Sets the basis size.
    #[must_use]
    pub fn k(self, k: usize) -> Self {
        Self { k, ..self }
    }

    /// Sets the basis degree.
    #[must_use]
    pub fn degree(self, degree: usize) -> Self {
        Self { degree, ..self }
    }

    /// Sets the monotonicity direction.
    #[must_use]
    pub fn direction(self, direction: MonotoneDirection) -> Self {
        Self { direction, ..self }
    }
}
// Hook every smooth-term builder into `From`/`+`, so builders and `TermExpr`
// values can be mixed freely in formula expressions.
impl_term_expr!(CyclicPSplineTerm, CyclicPSpline);
impl_term_expr!(FourierTerm, Fourier);
impl_term_expr!(MonotoneTerm, Monotone);
impl_term_expr!(PSplineTerm, PSpline);
impl_term_expr!(TensorPSplineTerm, TensorPSpline);
/// Fitted counterpart of a [`TermSpec`].
///
/// Every variant except `Offset` records a `range` (returned by
/// [`FittedTerm::range`]; presumably the span of the model's coefficient
/// vector owned by the term — confirm against the fitter) together with the
/// coefficient name(s) for that span.
#[derive(Debug, Clone, PartialEq)]
pub enum FittedTerm {
    /// Intercept and its coefficient name.
    Intercept {
        range: Range<usize>,
        coefficient: String,
    },
    /// Linear term on a numeric column.
    Linear {
        col: Col<f64>,
        range: Range<usize>,
        coefficient: String,
    },
    /// Offset column; it owns no coefficients and `range()` reports `0..0`.
    Offset { col: Col<f64> },
    /// Indicator term on a boolean column.
    Indicator {
        col: Col<bool>,
        range: Range<usize>,
        coefficient: String,
    },
    /// Categorical term with its level labels and baseline level.
    Factor {
        col: Col<Category>,
        range: Range<usize>,
        levels: Vec<String>,
        baseline: String,
        coefficients: Vec<String>,
    },
    /// Interaction of two numeric columns.
    Interaction {
        left: Col<f64>,
        right: Col<f64>,
        range: Range<usize>,
        coefficient: String,
    },
    /// P-spline smooth with its open uniform basis and penalty settings.
    PSpline {
        col: Col<f64>,
        range: Range<usize>,
        basis: OpenUniformSplineBasis,
        lambda: f64,
        penalty_order: usize,
        coefficients: Vec<String>,
    },
    /// Cyclic P-spline smooth with its cyclic spline specification.
    CyclicPSpline {
        col: Col<f64>,
        range: Range<usize>,
        spec: CyclicSplineSpec,
        lambda: f64,
        penalty_order: usize,
        coefficients: Vec<String>,
    },
    /// Fourier term with its period, order and intercept flag.
    Fourier {
        col: Col<f64>,
        range: Range<usize>,
        period: f64,
        order: usize,
        include_intercept: bool,
        coefficients: Vec<String>,
    },
    /// Tensor-product spline with one open uniform basis per margin.
    TensorPSpline {
        left: Col<f64>,
        right: Col<f64>,
        range: Range<usize>,
        left_basis: OpenUniformSplineBasis,
        right_basis: OpenUniformSplineBasis,
        coefficients: Vec<String>,
    },
    /// Monotone smooth with its I-spline basis and direction.
    Monotone {
        col: Col<f64>,
        range: Range<usize>,
        basis: ISplineBasis,
        direction: MonotoneDirection,
        coefficients: Vec<String>,
    },
}
impl FittedTerm {
    /// Index span owned by this term; `Offset` owns nothing and reports `0..0`.
    #[must_use]
    pub fn range(&self) -> Range<usize> {
        let owned = match self {
            Self::Offset { .. } => None,
            Self::Intercept { range, .. }
            | Self::Linear { range, .. }
            | Self::Indicator { range, .. }
            | Self::Factor { range, .. }
            | Self::Interaction { range, .. }
            | Self::PSpline { range, .. }
            | Self::CyclicPSpline { range, .. }
            | Self::Fourier { range, .. }
            | Self::TensorPSpline { range, .. }
            | Self::Monotone { range, .. } => Some(range),
        };
        owned.cloned().unwrap_or(0..0)
    }

    /// Appends this term's coefficient names to `out`, borrowing them from `self`.
    ///
    /// Single-coefficient terms contribute one name, multi-coefficient terms
    /// contribute all of theirs in stored order, and `Offset` contributes none.
    pub(crate) fn append_coefficient_names<'a>(&'a self, out: &mut Vec<&'a str>) {
        let names: &'a [String] = match self {
            Self::Offset { .. } => &[],
            Self::Intercept { coefficient, .. }
            | Self::Linear { coefficient, .. }
            | Self::Indicator { coefficient, .. }
            | Self::Interaction { coefficient, .. } => std::slice::from_ref(coefficient),
            Self::PSpline { coefficients, .. }
            | Self::CyclicPSpline { coefficients, .. }
            | Self::Fourier { coefficients, .. }
            | Self::Factor { coefficients, .. }
            | Self::TensorPSpline { coefficients, .. }
            | Self::Monotone { coefficients, .. } => coefficients.as_slice(),
        };
        out.extend(names.iter().map(String::as_str));
    }
}