use crate::traits::FloatScalar;
use super::{InterpError, find_interval, validate_sorted};
/// Piecewise cubic Hermite interpolant over `N` knots, stored in fixed-size arrays.
///
/// Instances are normally obtained through `HermiteInterp::new`, which checks the knots.
#[derive(Clone, Debug)]
pub struct HermiteInterp<T, const N: usize> {
    /// Knot abscissae; `new` passes them through `validate_sorted`.
    xs: [T; N],
    /// Function value at each knot.
    ys: [T; N],
    /// First derivative (slope) at each knot.
    dys: [T; N],
}
impl<T: FloatScalar, const N: usize> HermiteInterp<T, N> {
    /// Builds a cubic Hermite interpolant from knots `xs`, values `ys` and slopes `dys`.
    ///
    /// On each segment `[xs[i], xs[i + 1]]` the interpolant is the cubic that matches
    /// `ys` and `dys` at both end points.
    ///
    /// # Errors
    ///
    /// * [`InterpError::TooFewPoints`] if `N < 2`.
    /// * Any error returned by `validate_sorted` for `xs` (presumably when the knots are not
    ///   in ascending order — see that function for the exact rule).
    pub fn new(xs: [T; N], ys: [T; N], dys: [T; N]) -> Result<Self, InterpError> {
        if N < 2 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;
        Ok(Self { xs, ys, dys })
    }

    /// Evaluates the interpolant at `x`.
    ///
    /// The segment is selected by `find_interval`; how points outside the knot range are
    /// treated (e.g. extrapolating an end cubic) is decided by that function.
    pub fn eval(&self, x: T) -> T {
        let (i, h, t) = self.segment(x);
        self.value(i, h, t)
    }

    /// Evaluates the interpolant and its first derivative with respect to `x`.
    ///
    /// Returns `(value, derivative)`; `value` is bit-identical to [`Self::eval`].
    pub fn eval_derivative(&self, x: T) -> (T, T) {
        let (i, h, t) = self.segment(x);
        let val = self.value(i, h, t);
        let [dh00, dh10, dh01, dh11] = Self::basis_derivative(t);
        // d/dx = (1/h) * d/dt. The slope terms carry a factor `h` in the value formula, which
        // cancels exactly here, so they are added unscaled. Computing `h * dy / h` instead can
        // round, underflow or overflow (e.g. fail to return `dys[i]` at a knot).
        let dval = (dh00 * self.ys[i] + dh01 * self.ys[i + 1]) / h
            + dh10 * self.dys[i]
            + dh11 * self.dys[i + 1];
        (val, dval)
    }

    /// Returns the knot abscissae.
    pub fn xs(&self) -> &[T; N] {
        &self.xs
    }

    /// Returns the function values at the knots.
    pub fn ys(&self) -> &[T; N] {
        &self.ys
    }

    /// Returns the slopes at the knots.
    pub fn dys(&self) -> &[T; N] {
        &self.dys
    }

    /// Locates the segment for `x` and returns `(i, h, t)`, where `h = xs[i + 1] - xs[i]`
    /// and `t = (x - xs[i]) / h` is the local coordinate.
    ///
    /// Relies on `find_interval` returning an `i` with `i + 1 < N`; otherwise the indexing
    /// below panics.
    fn segment(&self, x: T) -> (usize, T, T) {
        let i = find_interval(&self.xs, x);
        let h = self.xs[i + 1] - self.xs[i];
        (i, h, (x - self.xs[i]) / h)
    }

    /// Combines the Hermite basis at local coordinate `t` with the data of segment `i`.
    fn value(&self, i: usize, h: T, t: T) -> T {
        let [h00, h10, h01, h11] = Self::basis(t);
        // Slopes are scaled by `h` because the basis is expressed in `t`, not `x`.
        h00 * self.ys[i]
            + h10 * h * self.dys[i]
            + h01 * self.ys[i + 1]
            + h11 * h * self.dys[i + 1]
    }

    /// Cubic Hermite basis `[h00, h10, h01, h11]` evaluated at `t`.
    fn basis(t: T) -> [T; 4] {
        let one = T::one();
        let two = one + one;
        let three = two + one;
        let t2 = t * t;
        let t3 = t2 * t;
        [
            two * t3 - three * t2 + one,
            t3 - two * t2 + t,
            (T::zero() - two) * t3 + three * t2,
            t3 - t2,
        ]
    }

    /// Derivatives with respect to `t` of the basis returned by [`Self::basis`].
    fn basis_derivative(t: T) -> [T; 4] {
        let one = T::one();
        let two = one + one;
        let three = two + one;
        let four = two + two;
        let six = three + three;
        let t2 = t * t;
        [
            six * t2 - six * t,
            three * t2 - four * t + one,
            (T::zero() - six) * t2 + six * t,
            three * t2 - two * t,
        ]
    }
}
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Heap-backed piecewise cubic Hermite interpolant; only compiled with the `alloc` feature.
///
/// `DynHermiteInterp::new` enforces that the three vectors have equal length.
#[cfg(feature = "alloc")]
#[derive(Clone, Debug)]
pub struct DynHermiteInterp<T> {
    /// Knot abscissae; `new` passes them through `validate_sorted`.
    xs: Vec<T>,
    /// Function value at each knot.
    ys: Vec<T>,
    /// First derivative (slope) at each knot.
    dys: Vec<T>,
}
#[cfg(feature = "alloc")]
impl<T: FloatScalar> DynHermiteInterp<T> {
    /// Builds a cubic Hermite interpolant from knots `xs`, values `ys` and slopes `dys`.
    ///
    /// On each segment `[xs[i], xs[i + 1]]` the interpolant is the cubic that matches
    /// `ys` and `dys` at both end points.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// * [`InterpError::LengthMismatch`] if `xs`, `ys` and `dys` differ in length.
    /// * [`InterpError::TooFewPoints`] if there are fewer than 2 knots.
    /// * Any error returned by `validate_sorted` for `xs` (presumably when the knots are not
    ///   in ascending order — see that function for the exact rule).
    pub fn new(xs: Vec<T>, ys: Vec<T>, dys: Vec<T>) -> Result<Self, InterpError> {
        if xs.len() != ys.len() || xs.len() != dys.len() {
            return Err(InterpError::LengthMismatch);
        }
        if xs.len() < 2 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;
        Ok(Self { xs, ys, dys })
    }

    /// Evaluates the interpolant at `x`.
    ///
    /// The segment is selected by `find_interval`; how points outside the knot range are
    /// treated (e.g. extrapolating an end cubic) is decided by that function.
    pub fn eval(&self, x: T) -> T {
        let (i, h, t) = self.segment(x);
        self.value(i, h, t)
    }

    /// Evaluates the interpolant and its first derivative with respect to `x`.
    ///
    /// Returns `(value, derivative)`; `value` is bit-identical to [`Self::eval`].
    pub fn eval_derivative(&self, x: T) -> (T, T) {
        let (i, h, t) = self.segment(x);
        let val = self.value(i, h, t);
        let [dh00, dh10, dh01, dh11] = Self::basis_derivative(t);
        // d/dx = (1/h) * d/dt. The slope terms carry a factor `h` in the value formula, which
        // cancels exactly here, so they are added unscaled. Computing `h * dy / h` instead can
        // round, underflow or overflow (e.g. fail to return `dys[i]` at a knot).
        let dval = (dh00 * self.ys[i] + dh01 * self.ys[i + 1]) / h
            + dh10 * self.dys[i]
            + dh11 * self.dys[i + 1];
        (val, dval)
    }

    /// Returns the knot abscissae.
    pub fn xs(&self) -> &[T] {
        &self.xs
    }

    /// Returns the function values at the knots.
    pub fn ys(&self) -> &[T] {
        &self.ys
    }

    /// Returns the slopes at the knots.
    pub fn dys(&self) -> &[T] {
        &self.dys
    }

    /// Locates the segment for `x` and returns `(i, h, t)`, where `h = xs[i + 1] - xs[i]`
    /// and `t = (x - xs[i]) / h` is the local coordinate.
    ///
    /// Relies on `find_interval` returning an `i` with `i + 1 < xs.len()`; otherwise the
    /// indexing below panics.
    fn segment(&self, x: T) -> (usize, T, T) {
        let i = find_interval(&self.xs, x);
        let h = self.xs[i + 1] - self.xs[i];
        (i, h, (x - self.xs[i]) / h)
    }

    /// Combines the Hermite basis at local coordinate `t` with the data of segment `i`.
    fn value(&self, i: usize, h: T, t: T) -> T {
        let [h00, h10, h01, h11] = Self::basis(t);
        // Slopes are scaled by `h` because the basis is expressed in `t`, not `x`.
        h00 * self.ys[i]
            + h10 * h * self.dys[i]
            + h01 * self.ys[i + 1]
            + h11 * h * self.dys[i + 1]
    }

    /// Cubic Hermite basis `[h00, h10, h01, h11]` evaluated at `t`.
    fn basis(t: T) -> [T; 4] {
        let one = T::one();
        let two = one + one;
        let three = two + one;
        let t2 = t * t;
        let t3 = t2 * t;
        [
            two * t3 - three * t2 + one,
            t3 - two * t2 + t,
            (T::zero() - two) * t3 + three * t2,
            t3 - t2,
        ]
    }

    /// Derivatives with respect to `t` of the basis returned by [`Self::basis`].
    fn basis_derivative(t: T) -> [T; 4] {
        let one = T::one();
        let two = one + one;
        let three = two + one;
        let four = two + two;
        let six = three + three;
        let t2 = t * t;
        [
            six * t2 - six * t,
            three * t2 - four * t + one,
            (T::zero() - six) * t2 + six * t,
            three * t2 - two * t,
        ]
    }
}