use crate::{find_closest_neighbours_indices, is_monotonic};
use ndarray::{ArrayBase, Axis, Data, Ix2, OwnedRepr};
use std::ops::{Add, Div, Mul, Sub};
/// Bilinear interpolator over a rectilinear 2-D grid.
///
/// `z[[i, j]]` holds the sample at `(x[i], y[j])`; `Interp2D::new` checks that
/// axis 0 of `z` has `x.len()` entries and axis 1 has `y.len()` entries.
/// `S` is the ndarray storage type, so `z` may be an owned array or a view.
#[derive(Debug)]
pub struct Interp2D<C1, C2, Z, S>
where
S: Data<Elem = Z>,
{
/// Grid coordinates along the first variable (axis 0 of `z`).
x: Vec<C1>,
/// Grid coordinates along the second variable (axis 1 of `z`).
y: Vec<C2>,
/// Sample values; element `[i, j]` corresponds to `(x[i], y[j])`.
z: ArrayBase<S, Ix2>,
}
/// Manual `Clone` so that the bound is on `ArrayBase<S, Ix2>: Clone`
/// (which both owned arrays and views satisfy) rather than on `Z`/`S`
/// as `#[derive(Clone)]` would require.
impl<C1, C2, Z, S> Clone for Interp2D<C1, C2, Z, S>
where
    C1: Clone,
    C2: Clone,
    S: Data<Elem = Z>,
    ArrayBase<S, Ix2>: Clone,
{
    /// Clones both coordinate vectors and the `z` array (via its own `Clone`).
    fn clone(&self) -> Self {
        let Self { x, y, z } = self;
        Self {
            x: x.clone(),
            y: y.clone(),
            z: z.clone(),
        }
    }
}
/// Bilinear interpolation on the unit square.
///
/// `x00`, `x10`, `x01` and `x11` are the values at the corners
/// `(0, 0)`, `(1, 0)`, `(0, 1)` and `(1, 1)`; `alpha` is the fractional
/// position along the first coordinate and `beta` along the second.
/// The corners are reproduced exactly at `alpha, beta ∈ {0, 1}`.
fn interpolate2d_bilinear<C, Z>(x00: Z, x10: Z, x01: Z, x11: Z, alpha: C, beta: C) -> Z
where
    C: Copy + Add<Output = C> + Sub<Output = C>,
    Z: Copy + Mul<C, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
{
    // Expanded form: base + linear terms + the mixed (twist) term.
    let slope_alpha = x10 - x00;
    let slope_beta = x01 - x00;
    let twist = x00 + x11 - x10 - x01;
    x00 + slope_alpha * alpha + slope_beta * beta + twist * alpha * beta
}
#[test]
fn test_interpolate2d() {
    let tol = 1e-6f64;
    // (corner values (v00, v10, v01, v11), alpha, beta, expected)
    let cases: [((f64, f64, f64, f64), f64, f64, f64); 5] = [
        ((1.0, 2.0, 3.0, 4.0), 0.0, 0.0, 1.0),
        ((1.0, 2.0, 3.0, 4.0), 1.0, 0.0, 2.0),
        ((1.0, 2.0, 3.0, 4.0), 0.0, 1.0, 3.0),
        ((1.0, 2.0, 3.0, 4.0), 1.0, 1.0, 4.0),
        ((0.0, 1.0, 1.0, 0.0), 0.5, 0.5, 0.5),
    ];
    for &((v00, v10, v01, v11), alpha, beta, expected) in cases.iter() {
        let got = interpolate2d_bilinear(v00, v10, v01, v11, alpha, beta);
        assert!((got - expected).abs() < tol);
    }
}
/// Bilinearly interpolates at `(x, y)` inside the cell `[x0, x1] × [y0, y1]`.
///
/// `v00`, `v10`, `v01`, `v11` are the values at `(x0, y0)`, `(x1, y0)`,
/// `(x0, y1)` and `(x1, y1)`. The query point is normalised to fractional
/// offsets within the cell, and the result is delegated to `interpolate2d_bilinear`.
///
/// If `x1 == x0` (or `y1 == y0`) the normalisation divides by zero; for
/// floating-point coordinates that yields a non-finite offset.
fn interp2d<C1, C2, Z>(
    (x, y): (C1, C2),
    (x0, x1): (C1, C1),
    (y0, y1): (C2, C2),
    (v00, v10, v01, v11): (Z, Z, Z, Z),
) -> Z
where
    C1: Copy + Sub<Output = C1> + Div,
    C2: Copy + Sub<Output = C2> + Div<Output = <C1 as Div>::Output>,
    Z: Copy + Mul<<C1 as Div>::Output, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
    <C1 as Div>::Output:
        Copy + Add<Output = <C1 as Div>::Output> + Sub<Output = <C1 as Div>::Output>,
{
    // Fractional position of the query point within the cell along each axis.
    let alpha = (x - x0) / (x1 - x0);
    let beta = (y - y0) / (y1 - y0);
    interpolate2d_bilinear(v00, v10, v01, v11, alpha, beta)
}
impl<C1, C2, Z, S> Interp2D<C1, C2, Z, S>
where
    S: Data<Elem = Z>,
{
    /// Builds an interpolator over the grid `x × y` with sample values `z`.
    ///
    /// `z[[i, j]]` is the value at `(x[i], y[j])`: axis 0 of `z` runs along `x`
    /// and axis 1 along `y`. `z` may be an owned array or a view.
    ///
    /// # Panics
    ///
    /// Panics if `z`'s axis lengths do not match `x.len()` / `y.len()`, if either
    /// coordinate vector is empty, or if either is rejected by `is_monotonic`.
    pub fn new(x: Vec<C1>, y: Vec<C2>, z: ArrayBase<S, Ix2>) -> Self
    where
        C1: PartialOrd,
        C2: PartialOrd,
    {
        assert_eq!(z.len_of(Axis(0)), x.len(), "x-axis length mismatch.");
        assert_eq!(z.len_of(Axis(1)), y.len(), "y-axis length mismatch.");
        // Non-emptiness is what makes the unchecked indexing in `bounds` safe.
        assert!(!x.is_empty(), "x values must not be empty.");
        assert!(!y.is_empty(), "y values must not be empty.");
        assert!(is_monotonic(&x), "x values must be monotonic.");
        assert!(is_monotonic(&y), "y values must be monotonic.");
        Self { x, y, z }
    }

    /// Returns `((x_first, x_last), (y_first, y_last))`.
    ///
    /// These are the first and last grid coordinates of each axis; they are the
    /// minimum and maximum only when the axis is increasing.
    pub fn bounds(&self) -> ((C1, C1), (C2, C2))
    where
        C1: Copy,
        C2: Copy,
    {
        (
            (self.x[0], self.x[self.x.len() - 1]),
            (self.y[0], self.y[self.y.len() - 1]),
        )
    }

    /// Returns `true` when `x0 <= x <= x1` and `y0 <= y <= y1` (inclusive),
    /// where the limits come from [`Self::bounds`].
    ///
    /// NOTE(review): this assumes increasing axes. If `is_monotonic` also
    /// accepts decreasing sequences, every point on such an axis is reported as
    /// out of bounds — TODO confirm against `is_monotonic`.
    pub fn is_within_bounds(&self, (x, y): (C1, C2)) -> bool
    where
        C1: PartialOrd + Copy,
        C2: PartialOrd + Copy,
    {
        let ((x0, x1), (y0, y1)) = self.bounds();
        x0 <= x && x <= x1 && y0 <= y && y <= y1
    }

    /// Grid coordinates along the first variable (axis 0 of `z`).
    pub fn xs(&self) -> &Vec<C1> {
        &self.x
    }

    /// Grid coordinates along the second variable (axis 1 of `z`).
    pub fn ys(&self) -> &Vec<C2> {
        &self.y
    }

    /// The sample values; `z()[[i, j]]` corresponds to `(xs()[i], ys()[j])`.
    pub fn z(&self) -> &ArrayBase<S, Ix2> {
        &self.z
    }

    /// Returns a new interpolator on the same grid whose values are `f` applied
    /// element-wise to `z`. The result always owns its value array.
    pub fn map_values<Z2>(&self, f: impl Fn(&Z) -> Z2) -> Interp2D<C1, C2, Z2, OwnedRepr<Z2>>
    where
        C1: Clone,
        C2: Clone,
    {
        Interp2D {
            x: self.x.clone(),
            y: self.y.clone(),
            z: self.z.map(f),
        }
    }

    /// Consumes `self` and applies `f` to every x coordinate, keeping `y` and `z`.
    ///
    /// # Panics
    ///
    /// Panics if the mapped coordinates are rejected by `is_monotonic`.
    pub fn map_x_axis<Xnew>(self, f: impl Fn(C1) -> Xnew) -> Interp2D<Xnew, C2, Z, S>
    where
        Xnew: PartialOrd,
    {
        let xnew: Vec<Xnew> = self.x.into_iter().map(f).collect();
        assert!(is_monotonic(&xnew), "mapped x values must be monotonic.");
        Interp2D {
            x: xnew,
            y: self.y,
            z: self.z,
        }
    }

    /// Consumes `self` and applies `f` to every y coordinate, keeping `x` and `z`.
    ///
    /// # Panics
    ///
    /// Panics if the mapped coordinates are rejected by `is_monotonic`.
    pub fn map_y_axis<Ynew>(self, f: impl Fn(C2) -> Ynew) -> Interp2D<C1, Ynew, Z, S>
    where
        Ynew: PartialOrd,
    {
        let ynew: Vec<Ynew> = self.y.into_iter().map(f).collect();
        assert!(is_monotonic(&ynew), "mapped y values must be monotonic.");
        Interp2D {
            x: self.x,
            y: ynew,
            z: self.z,
        }
    }
}
impl<C1, C2, Z, S> Interp2D<C1, C2, Z, S>
where
    C1: Copy + Sub<Output = C1> + Div + PartialOrd,
    C2: Copy + Sub<Output = C2> + Div<Output = <C1 as Div>::Output> + PartialOrd,
    Z: Copy + Mul<<C1 as Div>::Output, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
    <C1 as Div>::Output:
        Copy + Add<Output = <C1 as Div>::Output> + Sub<Output = <C1 as Div>::Output>,
    S: Data<Elem = Z> + Clone,
{
    /// Evaluates the interpolant at `(x, y)`.
    ///
    /// The two neighbouring grid indices along each axis come from
    /// `find_closest_neighbours_indices` (presumably the indices bracketing the
    /// query, or the nearest edge pair outside the grid — TODO confirm). The
    /// four corresponding `z` values are then bilinearly interpolated. No bounds
    /// check is made here; see `eval_no_extrapolation`.
    pub fn eval(&self, (x, y): (C1, C2)) -> Z {
        let (i_lo, i_hi) = find_closest_neighbours_indices(&self.x, x);
        let (j_lo, j_hi) = find_closest_neighbours_indices(&self.y, y);

        let x_cell = (self.x[i_lo], self.x[i_hi]);
        let y_cell = (self.y[j_lo], self.y[j_hi]);
        let grid = &self.z;
        let corners = (
            grid[[i_lo, j_lo]],
            grid[[i_hi, j_lo]],
            grid[[i_lo, j_hi]],
            grid[[i_hi, j_hi]],
        );

        interp2d((x, y), x_cell, y_cell, corners)
    }

    /// Like [`Self::eval`], but returns `None` when `xy` lies outside the
    /// inclusive range reported by [`Self::is_within_bounds`].
    pub fn eval_no_extrapolation(&self, xy: (C1, C2)) -> Option<Z> {
        if !self.is_within_bounds(xy) {
            return None;
        }
        Some(self.eval(xy))
    }
}
impl<C1, C2, Z, S> Interp2D<C1, C2, Z, S>
where
    Z: Clone,
    S: Data<Elem = Z> + Clone,
    OwnedRepr<Z>: Data<Elem = Z>,
{
    /// Exchanges the roles of the two variables.
    ///
    /// The returned interpolator's first axis is the old `y` and its second axis
    /// is the old `x`. Its value array is an owned copy of the transpose of `z`,
    /// so `result.z()[[j, i]] == self.z()[[i, j]]`.
    pub fn swap_variables(self) -> Interp2D<C2, C1, Z, OwnedRepr<Z>> {
        let transposed = self.z.t().to_owned();
        Interp2D {
            x: self.y,
            y: self.x,
            z: transposed,
        }
    }
}
#[test]
fn test_interp2d_on_view() {
    // 3x3 grid that is zero everywhere except a single peak at the centre node.
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    // Build the interpolator over a borrowed view rather than an owned array.
    let interp = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid.view());
    let at_peak = interp.eval_no_extrapolation((1.0, 4.0)).unwrap();
    assert!((at_peak - 1.0).abs() < 1e-6f64);
}
#[test]
fn test_interp2d() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let interp = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid);
    let tol = 1e-6f64;
    // (query point, expected value): grid node, corner node, and a cell midpoint.
    let checks: [((f64, f64), f64); 3] = [((1.0, 4.0), 1.0), ((0.0, 3.0), 0.0), ((0.5, 3.5), 0.25)];
    for &(point, expected) in checks.iter() {
        let value = interp.eval_no_extrapolation(point).unwrap();
        assert!((value - expected).abs() < tol);
    }
}
#[test]
fn test_map_values() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let original = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid);
    // Doubling every sample should double the value at the peak.
    let doubled = original.map_values(|&value| value * 2.0);
    let at_peak = doubled.eval_no_extrapolation((1.0, 4.0)).unwrap();
    assert!((at_peak - 2.0).abs() < 1e-6f64);
}
#[test]
fn test_map_x_axis() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let shift = 1.0;
    // Shifting the x coordinates moves the peak by the same amount along x.
    let shifted = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid)
        .map_x_axis(|x| x + shift);
    let at_peak = shifted
        .eval_no_extrapolation((1.0 + shift, 4.0))
        .unwrap();
    assert!((at_peak - 1.0).abs() < 1e-6);
}
#[test]
fn test_map_y_axis() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let shift = 1.0;
    // Shifting the y coordinates moves the peak by the same amount along y.
    let shifted = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid)
        .map_y_axis(|y| y + shift);
    let at_peak = shifted
        .eval_no_extrapolation((1.0, 4.0 + shift))
        .unwrap();
    assert!((at_peak - 1.0).abs() < 1e-6);
}