use crate::Accelerator;
use crate::Cache;
use crate::DynInterp2dType;
use crate::Interp2dType;
use crate::Interpolation2d;
use crate::{DomainError, InterpolationError};
/// A two-dimensional spline that owns both its interpolation object and the data it was built
/// from.
///
/// `I` is the interpolation type used to build the spline and `T` is the element type of the
/// grid data. The evaluation methods forward the stored `xa`, `ya` and `za` to `interp`, so
/// callers do not have to pass the data arrays on every call.
#[derive(Clone)]
pub struct Spline2d<I, T>
where
I: Interp2dType<T> + Send + Sync + 'static,
{
/// Interpolation object produced by `I::build` from the data below.
pub interp: I::Interpolation2d,
/// Owned copy of the `xa` slice passed to the constructor.
pub xa: Box<[T]>,
/// Owned copy of the `ya` slice passed to the constructor.
pub ya: Box<[T]>,
/// Owned copy of the `za` slice passed to the constructor.
// NOTE(review): the tests use `za.len() == xa.len() * ya.len()`; the element ordering
// expected by the interpolation type is not visible here — confirm against `Interp2dType`.
pub za: Box<[T]>,
/// Value of `typ.name()` captured at construction time.
name: Box<str>,
/// Value of `typ.min_size()` captured at construction time.
min_size: usize,
}
impl<I, T> Spline2d<I, T>
where
    I: Interp2dType<T> + Send + Sync + 'static,
{
    /// Builds a spline of interpolation type `typ` over the data `xa`, `ya`, `za`.
    ///
    /// The interpolation object is created with `typ.build(xa, ya, za)`. The three slices are
    /// then copied into the spline, together with the type's `name()` and `min_size()`, so the
    /// returned value is self-contained.
    ///
    /// # Errors
    ///
    /// Returns the [`InterpolationError`] reported by `typ.build` if the interpolation object
    /// cannot be constructed from the given data.
    #[doc(alias = "gsl_spline2d_init")]
    pub fn new(typ: I, xa: &[T], ya: &[T], za: &[T]) -> Result<Self, InterpolationError>
    where
        T: Clone,
    {
        Ok(Self {
            // Built first so that invalid data is rejected before any copy is made.
            interp: typ.build(xa, ya, za)?,
            xa: xa.into(),
            ya: ya.into(),
            za: za.into(),
            name: typ.name().into(),
            min_size: typ.min_size(),
        })
    }

    /// Evaluates the spline at the point `(x, y)`.
    ///
    /// Forwards to the `eval` method of the underlying interpolation object, supplying the
    /// stored data arrays. `xacc`, `yacc` and `cache` are passed through unchanged.
    ///
    /// # Errors
    ///
    /// Returns the [`DomainError`] reported by the underlying interpolation.
    // NOTE(review): presumably raised when `(x, y)` lies outside the grid — confirm against
    // the `Interpolation2d` implementation.
    #[doc(alias = "gsl_spline2d_eval")]
    #[doc(alias = "gsl_spline2d_eval_e")]
    pub fn eval(
        &self,
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>
    where
        T: crate::Num,
    {
        self.interp
            .eval(&self.xa, &self.ya, &self.za, x, y, xacc, yacc, cache)
    }

    /// Evaluates the spline at the point `(x, y)` through the extrapolating variant.
    ///
    /// Forwards to the `eval_extrap` method of the underlying interpolation object, supplying
    /// the stored data arrays. `xacc`, `yacc` and `cache` are passed through unchanged.
    ///
    /// # Errors
    ///
    /// Returns the [`DomainError`] reported by the underlying interpolation, if any.
    // NOTE(review): the GSL counterpart performs no bounds checking and extrapolates outside
    // the grid; whether this implementation can still fail is not visible here.
    #[doc(alias = "gsl_spline2d_eval_extrap")]
    #[doc(alias = "gsl_spline2d_eval_extrap_e")]
    pub fn eval_extrap(
        &self,
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError> {
        self.interp
            .eval_extrap(&self.xa, &self.ya, &self.za, x, y, xacc, yacc, cache)
    }

    /// Evaluates the first derivative with respect to `x` at the point `(x, y)`.
    ///
    /// Forwards to the `eval_deriv_x` method of the underlying interpolation object, supplying
    /// the stored data arrays. `xacc`, `yacc` and `cache` are passed through unchanged.
    ///
    /// # Errors
    ///
    /// Returns the [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline2d_eval_deriv_x")]
    #[doc(alias = "gsl_spline2d_eval_deriv_x_e")]
    pub fn eval_deriv_x(
        &self,
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError> {
        self.interp
            .eval_deriv_x(&self.xa, &self.ya, &self.za, x, y, xacc, yacc, cache)
    }

    /// Evaluates the first derivative with respect to `y` at the point `(x, y)`.
    ///
    /// Forwards to the `eval_deriv_y` method of the underlying interpolation object, supplying
    /// the stored data arrays. `xacc`, `yacc` and `cache` are passed through unchanged.
    ///
    /// # Errors
    ///
    /// Returns the [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline2d_eval_deriv_y")]
    #[doc(alias = "gsl_spline2d_eval_deriv_y_e")]
    pub fn eval_deriv_y(
        &self,
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError> {
        self.interp
            .eval_deriv_y(&self.xa, &self.ya, &self.za, x, y, xacc, yacc, cache)
    }

    /// Evaluates the second derivative with respect to `x` at the point `(x, y)`.
    ///
    /// Forwards to the `eval_deriv_xx` method of the underlying interpolation object,
    /// supplying the stored data arrays. `xacc`, `yacc` and `cache` are passed through
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns the [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline2d_eval_deriv_xx")]
    #[doc(alias = "gsl_spline2d_eval_deriv_xx_e")]
    pub fn eval_deriv_xx(
        &self,
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError> {
        self.interp
            .eval_deriv_xx(&self.xa, &self.ya, &self.za, x, y, xacc, yacc, cache)
    }

    /// Evaluates the second derivative with respect to `y` at the point `(x, y)`.
    ///
    /// Forwards to the `eval_deriv_yy` method of the underlying interpolation object,
    /// supplying the stored data arrays. `xacc`, `yacc` and `cache` are passed through
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns the [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline2d_eval_deriv_yy")]
    #[doc(alias = "gsl_spline2d_eval_deriv_yy_e")]
    pub fn eval_deriv_yy(
        &self,
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError> {
        self.interp
            .eval_deriv_yy(&self.xa, &self.ya, &self.za, x, y, xacc, yacc, cache)
    }

    /// Evaluates the mixed second derivative with respect to `x` and `y` at the point `(x, y)`.
    ///
    /// Forwards to the `eval_deriv_xy` method of the underlying interpolation object,
    /// supplying the stored data arrays. `xacc`, `yacc` and `cache` are passed through
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns the [`DomainError`] reported by the underlying interpolation.
    #[doc(alias = "gsl_spline2d_eval_deriv_xy")]
    #[doc(alias = "gsl_spline2d_eval_deriv_xy_e")]
    pub fn eval_deriv_xy(
        &self,
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError> {
        self.interp
            .eval_deriv_xy(&self.xa, &self.ya, &self.za, x, y, xacc, yacc, cache)
    }

    /// Returns the name of the interpolation type, as reported by `typ.name()` when the spline
    /// was constructed.
    #[doc(alias = "gsl_spline2d_name")]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the value reported by `typ.min_size()` when the spline was constructed.
    // NOTE(review): by analogy with GSL this is the minimum number of points the
    // interpolation type requires — confirm against `Interp2dType::min_size`.
    #[doc(alias = "gsl_spline2d_min_size")]
    pub fn min_size(&self) -> usize {
        self.min_size
    }
}
/// A [`Spline2d`] whose interpolation type is a [`DynInterp2dType`], so the concrete
/// interpolation type does not appear in the spline's type.
pub type DynSpline2d<T> = Spline2d<DynInterp2dType<T>, T>;
impl<T> DynSpline2d<T> {
    /// Builds a [`DynSpline2d`] from any concrete interpolation type.
    ///
    /// `typ` is wrapped with `DynInterp2dType::new` and the wrapper is then handed to
    /// [`Spline2d::new`] together with the data slices.
    ///
    /// # Errors
    ///
    /// Returns whatever [`InterpolationError`] [`Spline2d::new`] reports for the given data.
    #[doc(alias = "gsl_spline2d_init")]
    pub fn new_dyn<I>(typ: I, xa: &[T], ya: &[T], za: &[T]) -> Result<Self, InterpolationError>
    where
        T: Clone,
        I: Interp2dType<T> + Send + Sync + 'static,
        I::Interpolation2d: Send + Sync + 'static,
    {
        let erased = DynInterp2dType::new(typ);
        Self::new(erased, xa, ya, za)
    }
}
/// Builds a [`DynSpline2d`] whose interpolation type is selected by name at runtime.
///
/// `typ` is lowercased before matching, so the lookup is case-insensitive. The recognised
/// names are `"bilinear"` and `"bicubic"`. The data slices are forwarded unchanged to
/// [`DynSpline2d::new_dyn`].
///
/// # Errors
///
/// Returns [`InterpolationError::InvalidType`] carrying the original `typ` string when the
/// name is not recognised; otherwise returns whatever error [`DynSpline2d::new_dyn`] reports
/// for the given data.
pub fn make_spline2d<T>(
    typ: &str,
    xa: &[T],
    ya: &[T],
    za: &[T],
) -> Result<DynSpline2d<T>, InterpolationError>
where
    T: crate::Num + ndarray_linalg::Lapack,
{
    use crate::*;

    // `new_dyn` already yields `Result<DynSpline2d<T>, InterpolationError>`, so each arm
    // returns it as-is instead of unwrapping with `?` and re-wrapping in `Ok`.
    match typ.to_lowercase().as_str() {
        "bilinear" => DynSpline2d::new_dyn(Bilinear, xa, ya, za),
        "bicubic" => DynSpline2d::new_dyn(Bicubic, xa, ya, za),
        _ => Err(InterpolationError::InvalidType(typ.into())),
    }
}
#[cfg(test)]
mod test {
    use crate::tests::build_comparator;
    use crate::*;

    /// Shared fixture: the plane `z = 1 + 0.1 * (x + y)` sampled on the 4x4 grid
    /// `{0, 1, 2, 3} x {0, 1, 2, 3}`. Returned as `(xa, ya, za)`.
    fn plane_grid() -> ([f64; 4], [f64; 4], [f64; 16]) {
        let axis = [0.0, 1.0, 2.0, 3.0];
        #[rustfmt::skip]
        let za = [
            1.0, 1.1, 1.2, 1.3,
            1.1, 1.2, 1.3, 1.4,
            1.2, 1.3, 1.4, 1.5,
            1.3, 1.4, 1.5, 1.6,
        ];
        (axis, axis, za)
    }

    /// A spline can be constructed and exposes its name and minimum size with the expected types.
    #[test]
    fn test_spline2d_creation() {
        let xs = [0.0, 1.0, 2.0];
        let ys = [0.0, 2.0, 4.0];
        let zs = [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0, 6.0];

        let spline = Spline2d::new(Bilinear, &xs, &ys, &zs).unwrap();

        let _name: &str = spline.name();
        let _min_size: usize = spline.min_size();
    }

    /// Value, derivatives and extrapolation of a bicubic spline over a plane match the
    /// analytic results.
    #[test]
    fn test_spline2d_eval() {
        let comp = build_comparator::<f64>();
        let mut xacc = Accelerator::new();
        let mut yacc = Accelerator::new();
        let mut cache = Cache::new();

        let (xa, ya, za) = plane_grid();
        let spline = Spline2d::new(Bicubic, &xa, &ya, &za).unwrap();

        let (px, py) = (1.5, 1.5);
        let value = spline
            .eval(px, py, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        let d_dx = spline
            .eval_deriv_x(px, py, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        let d_dy = spline
            .eval_deriv_y(px, py, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        let d_dxx = spline
            .eval_deriv_xx(px, py, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        let d_dyy = spline
            .eval_deriv_yy(px, py, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        let d_dxy = spline
            .eval_deriv_xy(px, py, &mut xacc, &mut yacc, &mut cache)
            .unwrap();

        // A plane has constant first derivatives and vanishing second derivatives.
        assert!(comp.is_close(value, 1.3));
        assert!(comp.is_close(d_dx, 0.1));
        assert!(comp.is_close(d_dy, 0.1));
        assert!(comp.is_close(d_dxx, 0.0));
        assert!(comp.is_close(d_dyy, 0.0));
        assert!(comp.is_close(d_dxy, 0.0));

        // (4, 4) lies outside the sampled grid.
        let extrapolated = spline
            .eval_extrap(4.0, 4.0, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        assert!(comp.is_close(extrapolated, 1.8));
    }

    /// Every evaluation method succeeds on a spline built through `new_dyn`.
    #[test]
    fn test_dyn_spline2d() {
        let (xa, ya, za) = plane_grid();
        let spline = Spline2d::new_dyn(Bicubic, &xa, &ya, &za).unwrap();

        let mut xacc = Accelerator::new();
        let mut yacc = Accelerator::new();
        let mut cache = Cache::new();

        spline
            .eval(1.5, 2.5, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        spline
            .eval_deriv_x(1.5, 2.5, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        spline
            .eval_deriv_y(1.5, 2.5, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        spline
            .eval_deriv_xx(1.5, 2.5, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        spline
            .eval_deriv_yy(1.5, 2.5, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        spline
            .eval_deriv_xy(1.5, 2.5, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
    }

    /// Known type names build a spline; an unknown name is rejected.
    #[test]
    fn test_make_spline2d() {
        let (xa, ya, za) = plane_grid();

        make_spline2d("bilinear", &xa, &ya, &za).unwrap();
        make_spline2d("bicubic", &xa, &ya, &za).unwrap();

        let unknown = make_spline2d("wrong", &xa, &ya, &za);
        assert!(unknown.is_err());
    }
}