use gamlss_core::PredictorBlock;
use crate::cyclic::{CyclicSplineDesign, CyclicSplineSpec};
use crate::row_basis::SplineRowBasis;
use crate::{SplineError, SplineOrder};
/// Specification of a periodic spline basis on the real line.
///
/// An input `x` is mapped to the phase `(x - origin) / period`, and the phase
/// is evaluated with a [`CyclicSplineSpec`].
///
/// NOTE(review): this assumes the cyclic basis is periodic with unit period
/// in phase, so that the resulting basis repeats every `period` units of `x`.
/// Confirm against `CyclicSplineSpec`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PeriodicSplineSpec {
    /// Cyclic basis evaluated on the phase.
    cyclic: CyclicSplineSpec,
    /// Length of one period in `x` units; validated finite and strictly positive.
    period: f64,
    /// Value of `x` mapped to phase zero; validated finite.
    origin: f64,
}
impl PeriodicSplineSpec {
    /// Creates a periodic spline specification.
    ///
    /// # Errors
    ///
    /// Returns [`SplineError::InvalidPeriod`] if `period` is not finite or not
    /// strictly positive, or if `origin` is not finite. Any error from
    /// [`CyclicSplineSpec::new`] for `n_basis`/`order` is propagated.
    pub fn new(
        n_basis: usize,
        order: SplineOrder,
        period: f64,
        origin: f64,
    ) -> Result<Self, SplineError> {
        // `!is_finite()` also rejects NaN, which `period <= 0.0` alone would let through.
        if !period.is_finite() || period <= 0.0 || !origin.is_finite() {
            return Err(SplineError::InvalidPeriod);
        }
        Ok(Self {
            cyclic: CyclicSplineSpec::new(n_basis, order)?,
            period,
            origin,
        })
    }
    /// Evaluates the basis at every value of `x`.
    ///
    /// Each value is converted to the phase `(x - origin) / period`, and the
    /// phases are passed to the cyclic design. A copy of `x` is kept in the
    /// returned design.
    ///
    /// # Errors
    ///
    /// Returns [`SplineError::NonFiniteValue`] if any `x` is NaN or infinite.
    /// The same error is returned if a finite `x` has a phase that overflows
    /// to infinity, e.g. `x` and `origin` of opposite sign near `f64::MAX`, or
    /// a very small `period`. Any error from the cyclic design is propagated.
    pub fn design(&self, x: &[f64]) -> Result<PeriodicSplineDesign, SplineError> {
        // Validate the phase rather than `x` alone. `origin` is finite and
        // `period` is finite and positive, so a non-finite `x` always yields a
        // non-finite phase. Checking the phase also catches finite inputs whose
        // subtraction or division overflows, which would otherwise reach the
        // cyclic design unchecked.
        let phi = x
            .iter()
            .map(|&value| {
                let phase = (value - self.origin) / self.period;
                if phase.is_finite() {
                    Ok(phase)
                } else {
                    Err(SplineError::NonFiniteValue)
                }
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(PeriodicSplineDesign {
            x: x.to_vec(),
            phase_design: self.cyclic.design(&phi)?,
            spec: *self,
        })
    }
    /// Number of basis functions (coefficients) of the underlying cyclic basis.
    #[must_use]
    #[inline(always)]
    pub fn n_basis(&self) -> usize {
        self.cyclic.n_basis()
    }
    /// Length of one period in `x` units.
    #[must_use]
    #[inline(always)]
    pub fn period(&self) -> f64 {
        self.period
    }
    /// Value of `x` that maps to phase zero.
    #[must_use]
    #[inline(always)]
    pub fn origin(&self) -> f64 {
        self.origin
    }
}
/// A [`PeriodicSplineSpec`] evaluated at a concrete set of inputs.
///
/// Stores the raw inputs, the cyclic design evaluated at their phases
/// `(x - origin) / period`, and the spec that produced it.
#[derive(Debug, Clone, PartialEq)]
pub struct PeriodicSplineDesign {
    /// Raw inputs, copied from the slice passed to the spec's `design`.
    x: Vec<f64>,
    /// Cyclic design evaluated at the phases of `x`.
    phase_design: CyclicSplineDesign,
    /// Spec that produced this design.
    spec: PeriodicSplineSpec,
}
impl PeriodicSplineDesign {
    /// Builds a [`PeriodicSplineSpec`] and evaluates it at `x` in one step.
    ///
    /// # Errors
    ///
    /// Any error from [`PeriodicSplineSpec::new`] or [`PeriodicSplineSpec::design`].
    pub fn new(
        x: &[f64],
        n_basis: usize,
        order: SplineOrder,
        period: f64,
        origin: f64,
    ) -> Result<Self, SplineError> {
        let spec = PeriodicSplineSpec::new(n_basis, order, period, origin)?;
        spec.design(x)
    }
    /// Raw inputs this design was evaluated at.
    #[must_use]
    #[inline(always)]
    pub fn x(&self) -> &[f64] {
        self.x.as_slice()
    }
    /// Number of basis functions, as reported by the stored spec.
    #[must_use]
    #[inline(always)]
    pub fn n_basis(&self) -> usize {
        let Self { spec, .. } = self;
        spec.n_basis()
    }
    /// Copy of the spec that produced this design.
    #[must_use]
    #[inline(always)]
    pub fn spec(&self) -> PeriodicSplineSpec {
        self.spec
    }
    /// Row derivative of the linear predictor, rescaled from phase to `x` units.
    ///
    /// The phase is `(x - origin) / period`, so `d(phase)/dx = 1 / period`.
    /// The cyclic design's row derivative is divided by `period` accordingly.
    /// NOTE(review): this presumes the cyclic design's `eta_derivative_row` is
    /// the derivative with respect to phase. Confirm.
    ///
    /// Debug builds assert `row < x().len()` and `beta.len() == n_basis()`.
    #[must_use]
    #[inline]
    pub fn eta_derivative_row(&self, row: usize, beta: &[f64]) -> f64 {
        debug_assert!(row < self.x.len());
        debug_assert_eq!(beta.len(), self.n_basis());
        let d_eta_d_phase = self.phase_design.eta_derivative_row(row, beta);
        d_eta_d_phase / self.spec.period()
    }
}
impl PredictorBlock for PeriodicSplineDesign {
#[inline(always)]
fn nrows(&self) -> usize {
PredictorBlock::nrows(&self.phase_design)
}
#[inline(always)]
fn nparams(&self) -> usize {
PredictorBlock::nparams(&self.phase_design)
}
#[inline(always)]
fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
self.phase_design.eta_row(row, beta)
}
#[inline]
fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]) {
self.phase_design.add_gradient(scores, beta, grad);
}
#[inline]
fn add_weighted_gradient(
&self,
scores: &[f64],
multiplier: &[f64],
beta: &[f64],
grad: &mut [f64],
) {
self.phase_design
.add_weighted_gradient(scores, multiplier, beta, grad);
}
}
impl SplineRowBasis for PeriodicSplineDesign {
#[inline(always)]
fn nrows(&self) -> usize {
SplineRowBasis::nrows(&self.phase_design)
}
#[inline(always)]
fn nparams(&self) -> usize {
SplineRowBasis::nparams(&self.phase_design)
}
#[inline(always)]
fn for_each_row_basis(&self, row: usize, f: impl FnMut(usize, f64)) {
self.phase_design.for_each_row_basis(row, f);
}
}