use crate::traits::FloatScalar;
use super::{InterpError, find_interval, validate_sorted};
/// Natural cubic spline interpolant over a fixed number `N` of knots.
///
/// The second derivative is zero at the first and last knot (natural boundary
/// conditions). Segment `i` spans `[xs[i], xs[i + 1]]` and is evaluated as
/// `a + b*dx + c*dx^2 + d*dx^3` with `dx = x - xs[i]`.
#[derive(Debug, Clone)]
pub struct CubicSpline<T, const N: usize> {
    /// Knot abscissae, as passed to [`CubicSpline::new`].
    xs: [T; N],
    /// Per-segment coefficients `[a, b, c, d]`. Only the first `N - 1` entries
    /// describe segments; the last entry stays zero because an array of length
    /// `N - 1` cannot be written with stable const generics.
    coeffs: [[T; 4]; N],
}

impl<T: FloatScalar, const N: usize> CubicSpline<T, N> {
    /// Builds the natural cubic spline through the points `(xs[i], ys[i])`.
    ///
    /// # Errors
    ///
    /// * [`InterpError::TooFewPoints`] if `N < 3`.
    /// * Whatever `validate_sorted` reports for `xs`.
    /// * [`InterpError::IllConditioned`] if an interval width
    ///   `xs[i + 1] - xs[i]` is not strictly positive (zero, negative or NaN),
    ///   or if a pivot of the tridiagonal solve is negligible relative to the
    ///   widths of the two adjacent intervals.
    pub fn new(xs: [T; N], ys: [T; N]) -> Result<Self, InterpError> {
        if N < 3 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;

        let two = T::one() + T::one();
        let six = two * (two + T::one());

        // Interval widths h[i] and secant slopes delta[i] for i in 0..N-1;
        // the last slot of each array is unused.
        let mut h = [T::zero(); N];
        let mut delta = [T::zero(); N];
        for i in 0..N - 1 {
            let width = xs[i + 1] - xs[i];
            // `>` is false for zero, negative and NaN widths alike; rejecting
            // them here keeps the division below from producing inf/NaN.
            if width > T::zero() {
                h[i] = width;
                delta[i] = (ys[i + 1] - ys[i]) / width;
            } else {
                return Err(InterpError::IllConditioned);
            }
        }

        // Second derivatives at the knots. m[0] and m[N-1] stay zero (natural
        // boundary); m[1..=N-2] solve the tridiagonal system
        //   h[i-1]*m[i-1] + 2*(h[i-1]+h[i])*m[i] + h[i]*m[i+1] = 6*(delta[i]-delta[i-1])
        // via the Thomas algorithm (forward sweep, then back substitution).
        let mut m = [T::zero(); N];
        let mut cp = [T::zero(); N]; // normalised super-diagonal
        let mut dp = [T::zero(); N]; // normalised right-hand side
        for i in 1..N - 1 {
            let (pivot, rhs) = if i == 1 {
                (two * (h[0] + h[1]), six * (delta[1] - delta[0]))
            } else {
                // Eliminate the sub-diagonal coefficient h[i-1] using row i-1.
                (
                    two * (h[i - 1] + h[i]) - h[i - 1] * cp[i - 1],
                    six * (delta[i] - delta[i - 1]) - h[i - 1] * dp[i - 1],
                )
            };
            // Compare against the local interval scale rather than a bare
            // epsilon: an absolute threshold rejects valid knots whose spacing
            // is merely small (e.g. 1e-8 apart in f32).
            if pivot.abs() < T::epsilon() * (h[i - 1] + h[i]) {
                return Err(InterpError::IllConditioned);
            }
            // cp[N-2] is never read by the back substitution; computing it
            // anyway keeps the sweep uniform.
            cp[i] = h[i] / pivot;
            dp[i] = rhs / pivot;
        }
        m[N - 2] = dp[N - 2];
        for i in (1..N - 2).rev() {
            m[i] = dp[i] - cp[i] * m[i + 1];
        }

        let mut coeffs = [[T::zero(); 4]; N];
        for (i, seg) in coeffs.iter_mut().enumerate().take(N - 1) {
            *seg = [
                ys[i],
                delta[i] - h[i] * (two * m[i] + m[i + 1]) / six,
                m[i] / two,
                (m[i + 1] - m[i]) / (six * h[i]),
            ];
        }
        Ok(Self { xs, coeffs })
    }

    /// Evaluates the spline at `x`.
    ///
    /// The segment is chosen by `find_interval`; for `x` outside
    /// `[xs[0], xs[N - 1]]` the chosen segment's cubic is extrapolated
    /// (assumes `find_interval` clamps to a valid segment index — confirm).
    pub fn eval(&self, x: T) -> T {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        // Horner form of a + b*dx + c*dx^2 + d*dx^3.
        a + dx * (b + dx * (c + dx * d))
    }

    /// Evaluates the spline and its first derivative at `x`, returned as
    /// `(value, derivative)`.
    pub fn eval_derivative(&self, x: T) -> (T, T) {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        let two = T::one() + T::one();
        let three = two + T::one();
        let val = a + dx * (b + dx * (c + dx * d));
        // d/dx of the cubic: b + 2c*dx + 3d*dx^2, in Horner form.
        let dval = b + dx * (two * c + three * d * dx);
        (val, dval)
    }

    /// Returns the knot abscissae.
    pub fn xs(&self) -> &[T; N] {
        &self.xs
    }
}
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Natural cubic spline interpolant over a runtime-sized set of knots.
///
/// Heap-backed counterpart of `CubicSpline`: the second derivative is zero at
/// the first and last knot, and segment `i` spans `[xs[i], xs[i + 1]]`,
/// evaluated as `a + b*dx + c*dx^2 + d*dx^3` with `dx = x - xs[i]`.
#[cfg(feature = "alloc")]
#[derive(Debug, Clone)]
pub struct DynCubicSpline<T> {
    /// Knot abscissae, as passed to [`DynCubicSpline::new`].
    xs: Vec<T>,
    /// Per-segment coefficients `[a, b, c, d]`; holds `xs.len() - 1` entries.
    coeffs: Vec<[T; 4]>,
}

#[cfg(feature = "alloc")]
impl<T: FloatScalar> DynCubicSpline<T> {
    /// Builds the natural cubic spline through the points `(xs[i], ys[i])`.
    ///
    /// # Errors
    ///
    /// * [`InterpError::LengthMismatch`] if `xs` and `ys` differ in length.
    /// * [`InterpError::TooFewPoints`] if fewer than 3 points are given.
    /// * Whatever `validate_sorted` reports for `xs`.
    /// * [`InterpError::IllConditioned`] if an interval width
    ///   `xs[i + 1] - xs[i]` is not strictly positive (zero, negative or NaN),
    ///   or if a pivot of the tridiagonal solve is negligible relative to the
    ///   widths of the two adjacent intervals.
    pub fn new(xs: Vec<T>, ys: Vec<T>) -> Result<Self, InterpError> {
        if xs.len() != ys.len() {
            return Err(InterpError::LengthMismatch);
        }
        if xs.len() < 3 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;

        let n = xs.len();
        let two = T::one() + T::one();
        let six = two * (two + T::one());

        // Interval widths h[i] and secant slopes delta[i], one per interval.
        let mut h = alloc::vec![T::zero(); n - 1];
        let mut delta = alloc::vec![T::zero(); n - 1];
        for i in 0..n - 1 {
            let width = xs[i + 1] - xs[i];
            // `>` is false for zero, negative and NaN widths alike; rejecting
            // them here keeps the division below from producing inf/NaN.
            if width > T::zero() {
                h[i] = width;
                delta[i] = (ys[i + 1] - ys[i]) / width;
            } else {
                return Err(InterpError::IllConditioned);
            }
        }

        // Second derivatives at the knots. m[0] and m[n-1] stay zero (natural
        // boundary); m[1..=n-2] solve the tridiagonal system
        //   h[i-1]*m[i-1] + 2*(h[i-1]+h[i])*m[i] + h[i]*m[i+1] = 6*(delta[i]-delta[i-1])
        // via the Thomas algorithm (forward sweep, then back substitution).
        let mut m = alloc::vec![T::zero(); n];
        let mut cp = alloc::vec![T::zero(); n]; // normalised super-diagonal
        let mut dp = alloc::vec![T::zero(); n]; // normalised right-hand side
        for i in 1..n - 1 {
            let (pivot, rhs) = if i == 1 {
                (two * (h[0] + h[1]), six * (delta[1] - delta[0]))
            } else {
                // Eliminate the sub-diagonal coefficient h[i-1] using row i-1.
                (
                    two * (h[i - 1] + h[i]) - h[i - 1] * cp[i - 1],
                    six * (delta[i] - delta[i - 1]) - h[i - 1] * dp[i - 1],
                )
            };
            // Compare against the local interval scale rather than a bare
            // epsilon: an absolute threshold rejects valid knots whose spacing
            // is merely small (e.g. 1e-8 apart in f32).
            if pivot.abs() < T::epsilon() * (h[i - 1] + h[i]) {
                return Err(InterpError::IllConditioned);
            }
            // cp[n-2] is never read by the back substitution; computing it
            // anyway keeps the sweep uniform.
            cp[i] = h[i] / pivot;
            dp[i] = rhs / pivot;
        }
        m[n - 2] = dp[n - 2];
        for i in (1..n - 2).rev() {
            m[i] = dp[i] - cp[i] * m[i + 1];
        }

        let coeffs: Vec<[T; 4]> = (0..n - 1)
            .map(|i| {
                [
                    ys[i],
                    delta[i] - h[i] * (two * m[i] + m[i + 1]) / six,
                    m[i] / two,
                    (m[i + 1] - m[i]) / (six * h[i]),
                ]
            })
            .collect();
        Ok(Self { xs, coeffs })
    }

    /// Evaluates the spline at `x`.
    ///
    /// The segment is chosen by `find_interval`; for `x` outside the knot
    /// range the chosen segment's cubic is extrapolated (assumes
    /// `find_interval` clamps to a valid segment index — confirm).
    pub fn eval(&self, x: T) -> T {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        // Horner form of a + b*dx + c*dx^2 + d*dx^3.
        a + dx * (b + dx * (c + dx * d))
    }

    /// Evaluates the spline and its first derivative at `x`, returned as
    /// `(value, derivative)`.
    pub fn eval_derivative(&self, x: T) -> (T, T) {
        let i = find_interval(&self.xs, x);
        let dx = x - self.xs[i];
        let [a, b, c, d] = self.coeffs[i];
        let two = T::one() + T::one();
        let three = two + T::one();
        let val = a + dx * (b + dx * (c + dx * d));
        // d/dx of the cubic: b + 2c*dx + 3d*dx^2, in Horner form.
        let dval = b + dx * (two * c + three * d * dx);
        (val, dval)
    }

    /// Returns the knot abscissae.
    pub fn xs(&self) -> &[T] {
        &self.xs
    }
}