use gamlss_core::{Link, PredictorBlock, Softplus};
use crate::SplineError;
use crate::ispline::ISplineBasis;
/// Sign applied to the I-spline term of a [`MonotoneISplineDesign`].
///
/// `Increasing` multiplies the I-spline term by `+1.0`; `Decreasing`
/// multiplies it by `-1.0`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MonotoneDirection {
    /// I-spline term enters the predictor with sign `+1.0`.
    Increasing,
    /// I-spline term enters the predictor with sign `-1.0`.
    Decreasing,
}
impl MonotoneDirection {
fn sign(self) -> f64 {
match self {
Self::Increasing => 1.0,
Self::Decreasing => -1.0,
}
}
}
/// Predictor block whose linear predictor for row `i` is
/// `beta[0] + sign * sum_j Softplus::inverse(beta[j + 1]) * I_j(x[i])`,
/// where `I_j(x)` is the `j`-th value produced by the I-spline basis.
#[derive(Clone, Debug, PartialEq)]
pub struct MonotoneISplineDesign {
    /// Covariate values, one per row; `new` rejects non-finite entries.
    x: Vec<f64>,
    /// I-spline basis evaluated at each covariate value.
    basis: ISplineBasis,
    /// Sign of the I-spline term in the predictor.
    direction: MonotoneDirection,
}
impl MonotoneISplineDesign {
    /// Creates a design over the covariate values `x`, using `basis` for the
    /// I-spline term and `direction` to choose that term's sign.
    ///
    /// `x` is copied into the design.
    ///
    /// # Errors
    ///
    /// Returns [`SplineError::NonFiniteValue`] if any element of `x` is NaN or
    /// infinite.
    pub fn new(
        x: &[f64],
        basis: ISplineBasis,
        direction: MonotoneDirection,
    ) -> Result<Self, SplineError> {
        if x.iter().any(|value| !value.is_finite()) {
            return Err(SplineError::NonFiniteValue);
        }
        Ok(Self {
            x: x.to_vec(),
            basis,
            direction,
        })
    }

    /// Returns the I-spline basis used by this design.
    #[must_use]
    pub fn basis(&self) -> &ISplineBasis {
        &self.basis
    }

    /// Returns the covariate values, one per row.
    #[must_use]
    pub fn x(&self) -> &[f64] {
        &self.x
    }

    /// Returns the number of increment coefficients (`beta[1..]`), which equals
    /// the number of I-spline basis functions.
    #[must_use]
    pub fn n_increments(&self) -> usize {
        self.basis.n_basis()
    }

    /// Returns the sign direction of the I-spline term.
    #[must_use]
    pub fn direction(&self) -> MonotoneDirection {
        self.direction
    }

    /// Evaluates `sign * sum_j Softplus::inverse(beta[j + 1]) * D_j(x[row])`,
    /// where `D_j` is the `j`-th value returned by
    /// `ISplineBasis::evaluate_derivative`.
    ///
    /// This is the I-spline term of `eta_row` with basis values replaced by
    /// their derivatives. It is presumably the slope of `eta_row` in `x`
    /// (confirm against `ISplineBasis`). The intercept `beta[0]` does not
    /// contribute.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of range for `x` or if `beta` is empty. Debug
    /// builds also check `beta.len() == self.nparams()`.
    #[must_use]
    pub fn eta_derivative_row(&self, row: usize, beta: &[f64]) -> f64 {
        debug_assert!(row < self.x.len());
        debug_assert_eq!(beta.len(), self.nparams());
        let sign = self.direction.sign();
        self.basis
            .evaluate_derivative(self.x[row])
            .iter()
            .zip(&beta[1..])
            .map(|(basis, beta)| sign * Softplus::inverse(*beta) * basis)
            .sum()
    }

    /// Adds `sum_i w_i * d eta_row(i) / d beta` to `grad`, where `w_i` is the
    /// `i`-th item yielded by `row_weights` (rows are visited from 0 upward).
    ///
    /// The chain-rule factor `Softplus::derivative_inverse(beta[j + 1])` depends
    /// only on `beta`. It is therefore evaluated once per call rather than once
    /// per row and basis function. Each term is still formed as
    /// `((w * sign) * basis) * slope`, matching the per-row formulation exactly.
    fn accumulate_gradient<W>(&self, row_weights: W, beta: &[f64], grad: &mut [f64])
    where
        W: IntoIterator<Item = f64>,
    {
        let sign = self.direction.sign();
        // `skip(1)` drops the intercept without panicking on an empty `beta`.
        let link_slopes: Vec<f64> = beta
            .iter()
            .skip(1)
            .map(|&coef| Softplus::derivative_inverse(coef))
            .collect();
        for (row, weight) in row_weights.into_iter().enumerate() {
            // The intercept enters eta with unit slope.
            grad[0] += weight;
            let basis_values = self.basis.evaluate(self.x[row]);
            for ((slot, basis_value), slope) in grad[1..]
                .iter_mut()
                .zip(basis_values)
                .zip(&link_slopes)
            {
                *slot += weight * sign * basis_value * slope;
            }
        }
    }
}

impl PredictorBlock for MonotoneISplineDesign {
    /// One row per covariate value.
    fn nrows(&self) -> usize {
        self.x.len()
    }

    /// One intercept plus one increment coefficient per I-spline basis function.
    fn nparams(&self) -> usize {
        1 + self.basis.n_basis()
    }

    /// Evaluates `beta[0] + sign * sum_j Softplus::inverse(beta[j + 1]) * I_j(x[row])`,
    /// where `I_j` is the `j`-th value returned by `ISplineBasis::evaluate`.
    fn eta_row(&self, row: usize, beta: &[f64]) -> f64 {
        debug_assert!(row < self.x.len());
        debug_assert_eq!(beta.len(), self.nparams());
        let sign = self.direction.sign();
        beta[0]
            + self
                .basis
                .evaluate(self.x[row])
                .iter()
                .zip(&beta[1..])
                .map(|(basis, beta)| sign * Softplus::inverse(*beta) * basis)
                .sum::<f64>()
    }

    /// Adds `sum_i scores[i] * d eta_row(i) / d beta` to `grad`.
    fn add_gradient(&self, scores: &[f64], beta: &[f64], grad: &mut [f64]) {
        debug_assert_eq!(scores.len(), self.x.len());
        debug_assert_eq!(beta.len(), self.nparams());
        debug_assert_eq!(grad.len(), self.nparams());
        self.accumulate_gradient(scores.iter().copied(), beta, grad);
    }

    /// Adds `sum_i scores[i] * multiplier[i] * d eta_row(i) / d beta` to `grad`.
    fn add_weighted_gradient(
        &self,
        scores: &[f64],
        multiplier: &[f64],
        beta: &[f64],
        grad: &mut [f64],
    ) {
        debug_assert_eq!(scores.len(), self.x.len());
        debug_assert_eq!(multiplier.len(), self.x.len());
        debug_assert_eq!(beta.len(), self.nparams());
        debug_assert_eq!(grad.len(), self.nparams());
        let weighted = scores
            .iter()
            .zip(multiplier)
            .map(|(&score, &factor)| score * factor);
        self.accumulate_gradient(weighted, beta, grad);
    }
}