use crate::types::check_if_inbounds;
use crate::{Accelerator, Cache, DynInterp2dType};
use crate::{DomainError, InterpolationError};
/// A 2D interpolation algorithm that can be instantiated for a concrete grid.
///
/// Implementors act as factories: [`build`](Interp2dType::build) takes the grid data and
/// produces an [`Interpolation2d`] object, which is what actually performs evaluations.
pub trait Interp2dType<T> {
    /// The interpolator produced by [`build`](Interp2dType::build). It is required to be
    /// `Send + Sync`.
    type Interpolation2d: Interpolation2d<T> + Send + Sync;
    /// Builds an interpolator from the x grid `xa`, the y grid `ya` and the grid values `za`.
    ///
    /// `za` is presumably flattened in the layout addressed by [`z_idx`]; TODO confirm
    /// against the implementations.
    ///
    /// # Errors
    ///
    /// Returns an [`InterpolationError`] when the implementation rejects the input. The
    /// exact conditions are implementation-specific and are not visible here.
    fn build(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
    ) -> Result<Self::Interpolation2d, InterpolationError>;
    /// Name of this interpolation type.
    #[doc(alias = "gsl_interp_name")]
    fn name(&self) -> &str;
    /// Minimum number of data points this type needs. This is presumably per axis, as with
    /// GSL's `gsl_interp2d_type_min_size`; confirm against the implementations.
    #[doc(alias = "gsl_interp_min_size")]
    fn min_size(&self) -> usize;
}
/// A built 2D interpolator, as produced by [`Interp2dType::build`].
///
/// Every method receives:
/// - the grid data: `xa` and `ya` are the grid coordinates along each axis, and `za` is the
///   flattened grid values. These are presumably the same slices that were given to
///   `build` (TODO confirm), with `za` in the layout addressed by [`z_idx`].
/// - the evaluation point `(x, y)`.
/// - one [`Accelerator`] per axis (`xacc`, `yacc`) and a [`Cache`]. These are mutable
///   state the implementation may update between calls.
// Each method takes more arguments than clippy's default `too_many_arguments` limit, and
// the lint level set on the trait applies to all of them.
#[allow(clippy::too_many_arguments)]
pub trait Interpolation2d<T> {
    /// Interpolates the grid at `(x, y)` after checking that the point lies inside the grid.
    ///
    /// Runs `check_if_inbounds` on `(xa, x)` and then on `(ya, y)`, propagating the first
    /// failure with `?`. It then delegates to
    /// [`eval_extrap`](Interpolation2d::eval_extrap).
    ///
    /// # Errors
    ///
    /// - [`DomainError`] if `check_if_inbounds` rejects `x` or `y`. This is presumably when
    ///   they fall outside the range spanned by `xa`/`ya`.
    /// - Any error returned by `eval_extrap`.
    #[doc(alias = "gsl_interp2d_eval")]
    #[doc(alias = "gsl_interp2d_eval_e")]
    fn eval(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>
    where
        T: PartialOrd + Clone,
    {
        // `x`/`y` are cloned because they are still needed for the delegated call below.
        check_if_inbounds(xa, x.clone())?;
        check_if_inbounds(ya, y.clone())?;
        self.eval_extrap(xa, ya, za, x, y, xacc, yacc, cache)
    }
    /// Interpolates at `(x, y)` without the bounds check that [`eval`](Interpolation2d::eval)
    /// performs.
    ///
    /// Behavior outside the grid is implementation-defined. Presumably it extrapolates, as
    /// the name and the GSL alias suggest.
    #[doc(alias = "gsl_interp2d_eval_extrap")]
    #[doc(alias = "gsl_interp2d_eval_extrap_e")]
    fn eval_extrap(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>;
    /// Partial derivative ∂z/∂x of the interpolant at `(x, y)`.
    ///
    /// There is no default bounds check here; each implementation decides how to handle
    /// points outside the grid.
    #[doc(alias = "gsl_interp2d_eval_deriv_x")]
    #[doc(alias = "gsl_interp2d_eval_deriv_x_e")]
    fn eval_deriv_x(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>;
    /// Partial derivative ∂z/∂y of the interpolant at `(x, y)`.
    ///
    /// There is no default bounds check here; each implementation decides how to handle
    /// points outside the grid.
    #[doc(alias = "gsl_interp2d_eval_deriv_y")]
    #[doc(alias = "gsl_interp2d_eval_deriv_y_e")]
    fn eval_deriv_y(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>;
    /// Second partial derivative ∂²z/∂x² of the interpolant at `(x, y)`.
    #[doc(alias = "gsl_interp2d_eval_deriv_xx")]
    #[doc(alias = "gsl_interp2d_eval_deriv_xx_e")]
    fn eval_deriv_xx(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>;
    /// Second partial derivative ∂²z/∂y² of the interpolant at `(x, y)`.
    #[doc(alias = "gsl_interp2d_eval_deriv_yy")]
    #[doc(alias = "gsl_interp2d_eval_deriv_yy_e")]
    fn eval_deriv_yy(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>;
    /// Mixed partial derivative ∂²z/∂x∂y of the interpolant at `(x, y)`.
    #[doc(alias = "gsl_interp2d_eval_deriv_xy")]
    #[doc(alias = "gsl_interp2d_eval_deriv_xy_e")]
    fn eval_deriv_xy(
        &self,
        xa: &[T],
        ya: &[T],
        za: &[T],
        x: T,
        y: T,
        xacc: &mut Accelerator,
        yacc: &mut Accelerator,
        cache: &mut Cache<T>,
    ) -> Result<T, DomainError>;
}
/// Maps the grid coordinate `(xi, yi)` to its offset in a flattened z array.
///
/// The array is stored one y-row at a time. All `xlen` values for `yi = 0` come first, then
/// those for `yi = 1`, and so on, so the offset is `yi * xlen + xi`.
///
/// # Errors
///
/// Returns [`DomainError`] when `xi >= xlen` or `yi >= ylen`.
#[doc(alias = "gsl_interp2d_idx")]
pub fn z_idx(xi: usize, yi: usize, xlen: usize, ylen: usize) -> Result<usize, DomainError> {
    let inside_grid = xi < xlen && yi < ylen;
    if !inside_grid {
        return Err(DomainError);
    }
    Ok(xi + xlen * yi)
}
/// Stores `z` at grid position `(i, j)` of the flattened value array `za`.
///
/// `i` indexes the x axis (`0..xlen`) and `j` indexes the y axis (`0..ylen`). The element is
/// located with [`z_idx`], i.e. at offset `j * xlen + i`.
///
/// # Errors
///
/// Returns [`DomainError`], leaving `za` untouched, when
/// - `i >= xlen` or `j >= ylen`, or
/// - `za` is too short to contain the computed offset.
#[doc(alias = "gsl_interp2d_set")]
pub fn z_set<T>(
    za: &mut [T],
    z: T,
    i: usize,
    j: usize,
    xlen: usize,
    ylen: usize,
) -> Result<(), DomainError>
where
    T: crate::Num,
{
    // `z_idx` already rejects out-of-range (i, j). `get_mut` additionally turns a `za` that
    // is shorter than `xlen * ylen` into an error instead of an index panic.
    let slot = za.get_mut(z_idx(i, j, xlen, ylen)?).ok_or(DomainError)?;
    *slot = z;
    Ok(())
}
/// Reads the value at grid position `(i, j)` of the flattened value array `za`.
///
/// `i` indexes the x axis (`0..xlen`) and `j` indexes the y axis (`0..ylen`). The element is
/// located with [`z_idx`], i.e. at offset `j * xlen + i`.
///
/// # Errors
///
/// Returns [`DomainError`] when
/// - `i >= xlen` or `j >= ylen`, or
/// - `za` is too short to contain the computed offset.
#[doc(alias = "gsl_interp2d_get")]
pub fn z_get<T>(za: &[T], i: usize, j: usize, xlen: usize, ylen: usize) -> Result<T, DomainError>
where
    T: crate::Num,
{
    // `z_idx` already validates (i, j). `get` guards against a `za` shorter than
    // `xlen * ylen`, which would otherwise panic on indexing.
    za.get(z_idx(i, j, xlen, ylen)?)
        .copied()
        .ok_or(DomainError)
}
/// Looks up a 2D interpolation type by name and returns it as a [`DynInterp2dType`].
///
/// The recognised names are `"bilinear"` and `"bicubic"`. Matching is case-insensitive
/// because the name is lower-cased before comparison, so `"BiCubic"` selects `Bicubic`.
///
/// # Errors
///
/// Returns [`InterpolationError::InvalidType`] for any other name. The error carries the
/// caller's original, un-normalised string.
pub fn make_interp2d_type<T>(typ: &str) -> Result<DynInterp2dType<T>, InterpolationError>
where
    T: crate::Num + ndarray_linalg::Lapack,
{
    use crate::*;
    let normalized = typ.to_lowercase();
    let selected = match normalized.as_str() {
        "bilinear" => DynInterp2dType::new(Bilinear),
        "bicubic" => DynInterp2dType::new(Bicubic),
        _ => return Err(InterpolationError::InvalidType(typ.into())),
    };
    Ok(selected)
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::*;

    #[test]
    fn test_z_idx() {
        let shape = (4, 3);
        assert_eq!(z_idx(1, 2, shape.0, shape.1).unwrap(), 9);
        // Last valid cell, then the first index past the end of each axis.
        assert_eq!(z_idx(3, 2, shape.0, shape.1).unwrap(), 11);
        assert!(matches!(z_idx(4, 0, shape.0, shape.1), Err(DomainError)));
        assert!(matches!(z_idx(0, 3, shape.0, shape.1), Err(DomainError)));
        assert!(matches!(z_idx(10, 200, shape.0, shape.1), Err(DomainError)));
    }

    #[test]
    fn test_set() {
        let xa = [0.0, 1.0];
        let ya = [0.0, 2.0];
        #[rustfmt::skip]
        let mut za = [
            0.0, 1.0,
            1.0, 0.5,
        ];
        let za00 = 100.0;
        let za01 = 300.0;
        let za10 = 200.0;
        let za11 = 400.0;
        let xlen = xa.len();
        let ylen = ya.len();
        z_set(&mut za, za00, 0, 0, xlen, ylen).unwrap();
        z_set(&mut za, za01, 0, 1, xlen, ylen).unwrap();
        z_set(&mut za, za10, 1, 0, xlen, ylen).unwrap();
        z_set(&mut za, za11, 1, 1, xlen, ylen).unwrap();
        assert_eq!(za, [100.0, 200.0, 300.0, 400.0]);
        assert!(matches!(
            z_set(&mut za, za11, 10, 10000, xlen, ylen),
            Err(DomainError)
        ));
    }

    #[test]
    fn test_z_get() {
        let xa = [0.0, 1.0, 2.0];
        let ya = [0.0, 1.0, 2.0, 3.0];
        #[rustfmt::skip]
        let za = [
            0.0, 1.0, 2.0,
            3.0, 4.0, 5.0,
            6.0, 5.0, 4.0,
            3.0, 99.0, 1.0,
        ];
        let (i, j) = (1, 3);
        let value = z_get(&za, i, j, xa.len(), ya.len()).unwrap();
        let expected = 99.0;
        assert_eq!(value, expected);
        assert!(matches!(
            z_get(&za, 10, 2000, xa.len(), ya.len()),
            Err(DomainError)
        ));
    }

    #[test]
    fn test_dyn_interp_type() {
        let xa = [0.0, 1.0, 2.0, 3.0];
        let ya = [0.0, 2.0, 4.0, 6.0];
        // The grid samples the plane z = x + y, which bicubic interpolation must reproduce.
        #[rustfmt::skip]
        let za = [
            0.0, 1.0, 2.0, 3.0,
            2.0, 3.0, 4.0, 5.0,
            4.0, 5.0, 6.0, 7.0,
            6.0, 7.0, 8.0, 9.0,
        ];
        let mut xacc = Accelerator::new();
        let mut yacc = Accelerator::new();
        let mut cache = Cache::new();
        let x = 0.5;
        let y = 1.0;
        let interp2d_type = DynInterp2dType::new(Bicubic);
        let interp2d = interp2d_type.build(&xa, &ya, &za).unwrap();
        let z: f64 = interp2d
            .eval(&xa, &ya, &za, x, y, &mut xacc, &mut yacc, &mut cache)
            .unwrap();
        assert!((z - (x + y)).abs() < 1e-10, "expected {}, got {z}", x + y);
    }

    #[test]
    fn test_make_interp2d_type() {
        make_interp2d_type::<f64>("bilinear").unwrap();
        make_interp2d_type::<f64>("bicubic").unwrap();
        // Lookup is case-insensitive.
        make_interp2d_type::<f64>("BILINEAR").unwrap();
        make_interp2d_type::<f64>("BiCubic").unwrap();
        assert!(make_interp2d_type::<f64>("wrong").is_err());
    }
}