use gamlss_core::PredictorBlock;
use crate::{SplineError, SplineOrder, SplineRowBasis};
/// Truncated power spline basis: an optional intercept column, the monomials
/// `x^1 ..= x^degree`, and one `(x - knot)^degree` column per knot that is
/// active only when `x > knot`.
#[derive(Debug, Clone, PartialEq)]
pub struct TruncatedPowerBasis {
/// Knot locations; `new` rejects non-finite or non-strictly-increasing knots.
knots: Vec<f64>,
/// Spline order; its `degree()` is the exponent of both the monomial and the
/// truncated-power columns.
order: SplineOrder,
/// Whether column 0 is a constant (`1.0`) intercept column.
include_intercept: bool,
/// Total column count, computed once in `new` via `coefficient_count`.
n_basis: usize,
}
impl TruncatedPowerBasis {
    /// Creates a basis from explicit knot locations.
    ///
    /// Column layout (see [`Self::for_each_basis`]): an optional intercept, then
    /// `x^1 ..= x^degree`, then one `(x - knot)^degree` column per knot.
    ///
    /// # Errors
    ///
    /// * [`SplineError::InvalidKnots`] if a knot is non-finite or the knots are not
    ///   strictly increasing.
    /// * [`SplineError::ParameterOverflow`] if the column count overflows `usize`.
    pub fn new(
        knots: Vec<f64>,
        order: SplineOrder,
        include_intercept: bool,
    ) -> Result<Self, SplineError> {
        validate_strict_knots(&knots)?;
        let n_basis = coefficient_count(knots.len(), order, include_intercept)?;
        Ok(Self {
            knots,
            order,
            include_intercept,
            n_basis,
        })
    }

    /// Creates a basis with `n_knots` knots evenly spaced strictly inside the
    /// range of `x`.
    ///
    /// Knot `k` (for `k = 1..=n_knots`) is placed at
    /// `min + k * (max - min) / (n_knots + 1)`, so neither endpoint of the data
    /// range is itself a knot. `n_knots == 0` yields a purely polynomial basis.
    ///
    /// # Errors
    ///
    /// * [`SplineError::EmptyInput`] if `x` is empty.
    /// * [`SplineError::NonFiniteValue`] if any value of `x` is NaN or infinite.
    /// * [`SplineError::InvalidRange`] if every value of `x` is equal.
    /// * [`SplineError::ParameterOverflow`] if `n_knots + 1` overflows.
    /// * Any error from [`Self::new`], e.g. if floating-point rounding makes two
    ///   adjacent knots coincide.
    pub fn uniform_from_data(
        x: &[f64],
        n_knots: usize,
        order: SplineOrder,
        include_intercept: bool,
    ) -> Result<Self, SplineError> {
        if x.is_empty() {
            return Err(SplineError::EmptyInput);
        }
        let mut min = f64::INFINITY;
        let mut max = f64::NEG_INFINITY;
        for value in x.iter().copied() {
            if !value.is_finite() {
                return Err(SplineError::NonFiniteValue);
            }
            min = min.min(value);
            max = max.max(value);
        }
        if min >= max {
            return Err(SplineError::InvalidRange);
        }
        // `n_knots` interior knots split the range into `n_knots + 1` equal gaps.
        let denominator = n_knots
            .checked_add(1)
            .ok_or(SplineError::ParameterOverflow)?;
        let step = (max - min) / denominator as f64;
        let knots = (1..=n_knots)
            .map(|index| min + step * index as f64)
            .collect::<Vec<_>>();
        Self::new(knots, order, include_intercept)
    }

    /// Pairs this basis with a copy of `x`, producing a design whose row `i` is
    /// the basis evaluated at `x[i]`.
    ///
    /// # Errors
    ///
    /// [`SplineError::NonFiniteValue`] if any value of `x` is NaN or infinite.
    pub fn design(&self, x: &[f64]) -> Result<TruncatedPowerDesign, SplineError> {
        if x.iter().any(|value| !value.is_finite()) {
            return Err(SplineError::NonFiniteValue);
        }
        Ok(TruncatedPowerDesign {
            x: x.to_vec(),
            basis: self.clone(),
        })
    }

    /// Knot locations, strictly increasing.
    #[must_use]
    #[inline(always)]
    pub fn knots(&self) -> &[f64] {
        &self.knots
    }

    /// Spline order this basis was built with.
    #[must_use]
    #[inline(always)]
    pub fn order(&self) -> SplineOrder {
        self.order
    }

    /// Whether column 0 is a constant intercept column.
    #[must_use]
    #[inline(always)]
    pub fn include_intercept(&self) -> bool {
        self.include_intercept
    }

    /// Total number of basis columns.
    #[must_use]
    #[inline(always)]
    pub fn n_basis(&self) -> usize {
        self.n_basis
    }

    /// Evaluates every basis column at `x`, returning a dense vector of length
    /// [`Self::n_basis`].
    #[must_use]
    pub fn evaluate(&self, x: f64) -> Vec<f64> {
        let mut values = vec![0.0; self.n_basis];
        self.evaluate_into(x, &mut values);
        values
    }

    /// Writes every basis column evaluated at `x` into `out`, zeroing inactive
    /// columns.
    ///
    /// # Panics
    ///
    /// In debug builds if `out.len() != self.n_basis()`; in any build if `out` is
    /// too short to hold an active column index.
    #[inline]
    pub fn evaluate_into(&self, x: f64, out: &mut [f64]) {
        debug_assert_eq!(out.len(), self.n_basis);
        out.fill(0.0);
        self.for_each_basis(x, |index, weight| out[index] = weight);
    }

    /// Calls `f(column, value)` for the basis columns that are active at `x`.
    ///
    /// Columns are laid out as `1` (only when the intercept is included), then
    /// `x^1 ..= x^degree`, then `(x - knot)^degree` for each knot in order. The
    /// intercept is always reported; monomials are skipped when they evaluate to
    /// exactly `0.0`; knot columns are reported only when `x > knot`. Columns that
    /// are not reported have value zero.
    #[inline]
    pub fn for_each_basis(&self, x: f64, mut f: impl FnMut(usize, f64)) {
        let degree = self.order.degree();
        if self.include_intercept {
            f(0, 1.0);
        }
        // Monomial columns start immediately after the optional intercept column.
        let mut offset = usize::from(self.include_intercept);
        for power in 1..=degree {
            let weight = pow_usize(x, power);
            if weight != 0.0 {
                f(offset + power - 1, weight);
            }
        }
        offset += degree;
        for (knot_offset, knot) in self.knots.iter().copied().enumerate() {
            if x > knot {
                f(offset + knot_offset, pow_usize(x - knot, degree));
            }
        }
    }

    /// Evaluates the derivative with respect to `x` of every basis column,
    /// returning a dense vector of length [`Self::n_basis`].
    #[must_use]
    pub fn evaluate_derivative(&self, x: f64) -> Vec<f64> {
        let mut values = vec![0.0; self.n_basis];
        self.evaluate_derivative_into(x, &mut values);
        values
    }

    /// Writes the derivative with respect to `x` of every basis column into
    /// `out`, zeroing columns whose derivative is not reported.
    ///
    /// # Panics
    ///
    /// In debug builds if `out.len() != self.n_basis()`; in any build if `out` is
    /// too short to hold an active column index.
    #[inline]
    pub fn evaluate_derivative_into(&self, x: f64, out: &mut [f64]) {
        debug_assert_eq!(out.len(), self.n_basis);
        out.fill(0.0);
        self.for_each_derivative_basis(x, |index, weight| out[index] = weight);
    }

    /// Calls `f(column, value)` with the derivative with respect to `x` of each
    /// active basis column, using the same column layout as
    /// [`Self::for_each_basis`].
    ///
    /// The intercept column (constant) is never reported. Monomial derivatives are
    /// skipped when exactly `0.0`. Knot columns are reported only for `x > knot`,
    /// so for degree 1 the value at `x == knot` is the left-hand derivative (zero).
    /// For degree 0 the knot columns are step functions and no knot derivative is
    /// reported.
    #[inline]
    pub fn for_each_derivative_basis(&self, x: f64, mut f: impl FnMut(usize, f64)) {
        let degree = self.order.degree();
        // The intercept's derivative is zero, so it only shifts the column offset.
        let offset = usize::from(self.include_intercept);
        for power in 1..=degree {
            let weight = if power == 1 {
                1.0
            } else {
                power as f64 * pow_usize(x, power - 1)
            };
            if weight != 0.0 {
                f(offset + power - 1, weight);
            }
        }
        // For degree 0 the knot columns are steps with zero derivative away from the
        // knot. Computing `degree - 1` would underflow: a debug-build panic, or in
        // release a `usize::MAX`-iteration `pow_usize` call. So stop here.
        let Some(knot_power) = degree.checked_sub(1) else {
            return;
        };
        let knot_offset_base = offset + degree;
        for (knot_offset, knot) in self.knots.iter().copied().enumerate() {
            if x > knot {
                f(
                    knot_offset_base + knot_offset,
                    degree as f64 * pow_usize(x - knot, knot_power),
                );
            }
        }
    }
}
/// Design matrix for a [`TruncatedPowerBasis`]: row `i` is the basis evaluated at
/// `x[i]`. Rows are computed on demand from the stored `x` values rather than
/// stored densely.
#[derive(Debug, Clone, PartialEq)]
pub struct TruncatedPowerDesign {
/// Covariate value for each row; `TruncatedPowerBasis::design` rejects
/// non-finite values.
x: Vec<f64>,
/// Basis evaluated at each row's covariate value.
basis: TruncatedPowerBasis,
}
impl TruncatedPowerDesign {
    /// Builds a design over `x` using a basis whose `n_knots` knots are spread
    /// uniformly inside the range of `x`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TruncatedPowerBasis::uniform_from_data`] or
    /// [`TruncatedPowerBasis::design`].
    pub fn uniform_from_data(
        x: &[f64],
        n_knots: usize,
        order: SplineOrder,
        include_intercept: bool,
    ) -> Result<Self, SplineError> {
        let basis =
            TruncatedPowerBasis::uniform_from_data(x, n_knots, order, include_intercept)?;
        basis.design(x)
    }

    /// Basis used to expand every row.
    #[must_use]
    #[inline(always)]
    pub fn basis(&self) -> &TruncatedPowerBasis {
        &self.basis
    }

    /// Covariate value of every row, in row order.
    #[must_use]
    #[inline(always)]
    pub fn x(&self) -> &[f64] {
        &self.x
    }

    /// Number of columns (coefficients) in the design.
    #[must_use]
    #[inline(always)]
    pub fn n_basis(&self) -> usize {
        self.basis.n_basis()
    }

    /// Slope of the linear predictor with respect to `x` at row `row`:
    /// `sum_j beta[j] * dB_j/dx` evaluated at `x[row]`.
    ///
    /// `row` and `beta.len()` are checked only in debug builds.
    #[must_use]
    #[inline]
    pub fn eta_derivative_row(&self, row: usize, beta: &[f64]) -> f64 {
        debug_assert!(row < self.x.len());
        debug_assert_eq!(beta.len(), self.basis.n_basis());
        let point = self.x[row];
        let mut slope = 0.0;
        self.basis
            .for_each_derivative_basis(point, |column, derivative| {
                slope += beta[column] * derivative;
            });
        slope
    }
}
impl PredictorBlock for TruncatedPowerDesign {
    /// One row per stored covariate value.
    #[inline(always)]
    fn nrows(&self) -> usize {
        self.x.len()
    }

    /// One parameter per basis column.
    #[inline(always)]
    fn nparams(&self) -> usize {
        self.basis.n_basis()
    }

    /// Linear predictor for one row: `sum_j beta[j] * B_j(x[row])`, summed over
    /// the columns the basis reports as active.
    #[inline]
    fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
        debug_assert!(row < self.x.len());
        debug_assert_eq!(beta.len(), self.basis.n_basis());
        let point = self.x[row];
        let mut eta = 0.0;
        self.basis.for_each_basis(point, |column, basis_value| {
            eta += beta[column] * basis_value;
        });
        eta
    }

    /// Adds `sum_i scores[i] * B_j(x[i])` into `grad[j]` for every column `j`.
    /// The third argument is not used by this design.
    #[inline]
    fn add_gradient(&self, scores: &[f64], _: &[f64], grad: &mut [f64]) {
        debug_assert_eq!(scores.len(), self.x.len());
        debug_assert_eq!(grad.len(), self.basis.n_basis());
        for (row, &score) in scores.iter().enumerate() {
            self.for_each_row_basis(row, |column, basis_value| {
                grad[column] += score * basis_value;
            });
        }
    }

    /// Like `add_gradient`, but each row's score is first scaled by
    /// `multiplier[row]`. The fourth argument is not used by this design.
    #[inline]
    fn add_weighted_gradient(
        &self,
        scores: &[f64],
        multiplier: &[f64],
        _: &[f64],
        grad: &mut [f64],
    ) {
        debug_assert_eq!(scores.len(), self.x.len());
        debug_assert_eq!(multiplier.len(), self.x.len());
        debug_assert_eq!(grad.len(), self.basis.n_basis());
        for (row, (&score, &scale)) in scores.iter().zip(multiplier.iter()).enumerate() {
            // Same product as `score * scale * basis_value`, which evaluates left to right.
            let scaled_score = score * scale;
            self.for_each_row_basis(row, |column, basis_value| {
                grad[column] += scaled_score * basis_value;
            });
        }
    }
}
impl SplineRowBasis for TruncatedPowerDesign {
    /// Number of rows, one per stored covariate value.
    #[inline(always)]
    fn nrows(&self) -> usize {
        self.x.len()
    }

    /// Number of basis columns.
    #[inline(always)]
    fn nparams(&self) -> usize {
        self.basis.n_basis()
    }

    /// Calls `f(column, value)` for the columns that
    /// `TruncatedPowerBasis::for_each_basis` reports at `x[row]`.
    #[inline]
    fn for_each_row_basis(&self, row: usize, f: impl FnMut(usize, f64)) {
        debug_assert!(row < self.x.len());
        let point = self.x[row];
        self.basis.for_each_basis(point, f);
    }
}
/// Checks that every knot is finite and that the knots are strictly increasing.
///
/// An empty or single-knot slice is valid.
///
/// # Errors
///
/// [`SplineError::InvalidKnots`] if any knot is NaN or infinite, or if some knot
/// is not strictly greater than its predecessor.
fn validate_strict_knots(knots: &[f64]) -> Result<(), SplineError> {
    let all_finite = knots.iter().all(|knot| knot.is_finite());
    // Only consulted when all knots are finite, so `<` is the exact negation of `>=`.
    let strictly_increasing = knots.windows(2).all(|pair| pair[0] < pair[1]);
    if all_finite && strictly_increasing {
        Ok(())
    } else {
        Err(SplineError::InvalidKnots)
    }
}
/// Number of basis columns:
/// `degree + (1 if include_intercept else 0) + n_knots`.
///
/// # Errors
///
/// [`SplineError::ParameterOverflow`] if the sum does not fit in `usize`.
fn coefficient_count(
    n_knots: usize,
    order: SplineOrder,
    include_intercept: bool,
) -> Result<usize, SplineError> {
    let intercept_columns = usize::from(include_intercept);
    let leading_columns = order.degree().checked_add(intercept_columns);
    match leading_columns.and_then(|leading| n_knots.checked_add(leading)) {
        Some(total) => Ok(total),
        None => Err(SplineError::ParameterOverflow),
    }
}
/// Raises `value` to the non-negative integer `power` by repeated multiplication.
///
/// Multiplies left to right starting from `1.0`, so `power == 0` yields `1.0`
/// (even for NaN) and any other result equals `value * value * ... * value`
/// evaluated left to right. Performs `power` multiplications.
#[inline]
fn pow_usize(value: f64, power: usize) -> f64 {
    std::iter::repeat(value).take(power).product()
}