use crate::Accelerator;
use crate::DomainError;
use crate::DynInterpType;
use crate::InterpType;
use crate::Interpolation;
use crate::InterpolationError;
/// A one-dimensional interpolating spline that owns copies of its data points.
///
/// `I` is the interpolation type (for example a cubic or Akima scheme) that produced
/// `interp`, and `T` is the scalar type of the data.
#[derive(Clone)]
pub struct Spline<I: InterpType<T> + Send + Sync + 'static, T> {
    /// Interpolation object produced by `InterpType::build` from `xa` and `ya`.
    ///
    /// Every evaluation method hands this object the `xa`/`ya` arrays below. Replacing
    /// those arrays without rebuilding `interp` will therefore likely give inconsistent
    /// results.
    pub interp: I::Interpolation,
    /// Owned copy of the x coordinates given to `Spline::new`.
    pub xa: Box<[T]>,
    /// Owned copy of the y coordinates given to `Spline::new`.
    pub ya: Box<[T]>,
    /// Name of the interpolation type, cached when the spline is built.
    name: Box<str>,
    /// Value of `InterpType::min_size`, cached when the spline is built.
    min_size: usize,
}
impl<I, T> Spline<I, T>
where
    I: InterpType<T> + Send + Sync + 'static,
{
    /// Builds a spline of interpolation type `typ` over the points `(xa[i], ya[i])`.
    ///
    /// The interpolation object is built first with `typ.build(xa, ya)`. Only if that
    /// succeeds are `xa` and `ya` copied into the spline, so a rejected input does not
    /// allocate the copies. The type's name and `min_size` are cached so that
    /// [`Self::name`] and [`Self::min_size`] do not need `typ` afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any [`InterpolationError`] returned by `typ.build`. Input validation,
    /// such as length checks, is presumably done there; confirm in the `InterpType`
    /// implementation.
    #[doc(alias = "gsl_spline_init")]
    pub fn new(typ: I, xa: &[T], ya: &[T]) -> Result<Self, InterpolationError>
    where
        T: Clone,
    {
        // Build before copying so invalid input is rejected without cloning the data.
        let interp = typ.build(xa, ya)?;
        Ok(Self {
            interp,
            xa: xa.into(),
            ya: ya.into(),
            name: typ.name().into(),
            min_size: typ.min_size(),
        })
    }

    /// Evaluates the interpolated function at `x`.
    ///
    /// `acc` is forwarded to the interpolation object unchanged. It is presumably an
    /// interval-lookup cache that speeds up repeated nearby queries; confirm.
    ///
    /// # Errors
    ///
    /// Returns any [`DomainError`] reported by the underlying interpolation. This is
    /// presumably raised when `x` lies outside the range of `xa`; confirm.
    #[doc(alias = "gsl_spline_eval")]
    #[doc(alias = "gsl_spline_eval_e")]
    pub fn eval(&self, x: T, acc: &mut Accelerator) -> Result<T, DomainError> {
        self.interp.eval(&self.xa, &self.ya, x, acc)
    }

    /// Evaluates the first derivative of the interpolated function at `x`.
    ///
    /// # Errors
    ///
    /// Returns any [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline_eval_deriv")]
    #[doc(alias = "gsl_spline_eval_deriv_e")]
    pub fn eval_deriv(&self, x: T, acc: &mut Accelerator) -> Result<T, DomainError> {
        self.interp.eval_deriv(&self.xa, &self.ya, x, acc)
    }

    /// Evaluates the second derivative of the interpolated function at `x`.
    ///
    /// # Errors
    ///
    /// Returns any [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline_eval_deriv2")]
    #[doc(alias = "gsl_spline_eval_deriv2_e")]
    pub fn eval_deriv2(&self, x: T, acc: &mut Accelerator) -> Result<T, DomainError> {
        self.interp.eval_deriv2(&self.xa, &self.ya, x, acc)
    }

    /// Evaluates the integral of the interpolated function from `a` to `b`.
    ///
    /// This is the counterpart of GSL's `gsl_spline_eval_integ`.
    ///
    /// # Errors
    ///
    /// Returns any [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline_eval_integ")]
    #[doc(alias = "gsl_spline_eval_integ_e")]
    pub fn eval_integ(&self, a: T, b: T, acc: &mut Accelerator) -> Result<T, DomainError> {
        self.interp.eval_integ(&self.xa, &self.ya, a, b, acc)
    }

    /// Returns the interpolation type's name, as reported by `InterpType::name` at
    /// construction.
    #[doc(alias = "gsl_spline_name")]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the value of `InterpType::min_size` captured at construction.
    ///
    /// For GSL-style types this is the minimum number of data points the interpolation
    /// requires.
    #[doc(alias = "gsl_spline_min_size")]
    pub fn min_size(&self) -> usize {
        self.min_size
    }
}
pub type DynSpline<T> = Spline<DynInterpType<T>, T>;
impl<T> DynSpline<T> {
#[doc(alias = "gsl_spline_init")]
pub fn new_dyn<I>(typ: I, xa: &[T], ya: &[T]) -> Result<Self, InterpolationError>
where
T: Clone,
I: InterpType<T> + Send + Sync + 'static,
I::Interpolation: Send + Sync + 'static,
{
Self::new(DynInterpType::new(typ), xa, ya)
}
}
/// Builds a [`DynSpline`] from an interpolation-type name chosen at run time.
///
/// The name is lower-cased before matching, so matching is case-insensitive.
/// Recognised names:
///
/// * `linear`
/// * `cubic`
/// * `akima`
/// * `cubicperiodic`, `cubic periodic`, `cubic_periodic`, `cubic-periodic`
/// * `akimaperiodic`, `akima periodic`, `akima_periodic`, `akima-periodic`
/// * `steffen`
///
/// # Errors
///
/// * [`InterpolationError::InvalidType`], carrying the original `typ`, if the name is not
///   recognised.
/// * Any [`InterpolationError`] raised while building the chosen interpolation over
///   `xa`/`ya`.
pub fn make_spline<T>(typ: &str, xa: &[T], ya: &[T]) -> Result<DynSpline<T>, InterpolationError>
where
    T: crate::Num + ndarray_linalg::Lapack,
{
    use crate::*;
    // Each arm already yields `Result<DynSpline<T>, InterpolationError>`, so it is
    // returned directly rather than unwrapped with `?` and re-wrapped in `Ok`.
    match typ.to_lowercase().as_str() {
        "linear" => DynSpline::new_dyn(Linear, xa, ya),
        "cubic" => DynSpline::new_dyn(Cubic, xa, ya),
        "akima" => DynSpline::new_dyn(Akima, xa, ya),
        "cubicperiodic" | "cubic periodic" | "cubic_periodic" | "cubic-periodic" => {
            DynSpline::new_dyn(CubicPeriodic, xa, ya)
        }
        "akimaperiodic" | "akima periodic" | "akima_periodic" | "akima-periodic" => {
            DynSpline::new_dyn(AkimaPeriodic, xa, ya)
        }
        "steffen" => DynSpline::new_dyn(Steffen, xa, ya),
        _ => Err(InterpolationError::InvalidType(typ.into())),
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::*;

    /// Five evenly spaced abscissae shared by several tests.
    const XS: [f64; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
    /// Ordinates on the line y = 2x, matching `XS`.
    const YS: [f64; 5] = [0.0, 2.0, 4.0, 6.0, 8.0];

    #[test]
    fn test_spline_creation() {
        let spline = Spline::new(Cubic, &XS, &YS).unwrap();
        let name: &str = spline.name();
        let min_size: usize = spline.min_size();
        let _ = (name, min_size);
    }

    #[test]
    fn test_spline_eval() {
        // Identity data: the cubic spline should reproduce y = x exactly.
        let nodes = [0.0, 1.0, 2.0];
        let values = [0.0, 1.0, 2.0];
        let spline = Spline::new(Cubic, &nodes, &values).unwrap();
        let mut accel = Accelerator::new();
        let at = 0.5;

        let value = spline.eval(at, &mut accel).unwrap();
        let slope = spline.eval_deriv(at, &mut accel).unwrap();
        let curvature = spline.eval_deriv2(at, &mut accel).unwrap();
        let area = spline.eval_integ(0.0, at, &mut accel).unwrap();

        assert_eq!(value, 0.5);
        assert_eq!(slope, 1.0);
        assert_eq!(curvature, 0.0);
        assert_eq!(area, 0.125);
    }

    #[test]
    fn test_dyn_spline() {
        let spline = Spline::new_dyn(Cubic, &XS, &YS).unwrap();
        let mut accel = Accelerator::new();
        let probe = 1.5;
        spline.eval(probe, &mut accel).unwrap();
        spline.eval_deriv(probe, &mut accel).unwrap();
        spline.eval_deriv2(probe, &mut accel).unwrap();
        spline.eval_integ(probe, 2.5, &mut accel).unwrap();
    }

    #[test]
    fn test_make_spline() {
        for kind in ["linear", "cubic", "akima", "akimaperiodic", "steffen"] {
            make_spline(kind, &XS, &YS).unwrap();
        }
        assert!(make_spline("wrong", &XS, &YS).is_err());
    }

    #[test]
    #[should_panic]
    fn test_make_cubic_periodic_spline_panic() {
        make_spline("cubicperiodic", &XS, &YS).unwrap();
    }
}