use crate::{find_closest_neighbours_indices, is_monotonic};
use std::ops::{Add, Div, Mul, Sub};
/// One-dimensional lookup table pairing axis coordinates `x` with values `z`,
/// evaluated by linear interpolation between neighbouring samples.
///
/// Tables built through `Interp1D::new` have `x` and `z` of equal, non-zero
/// length, and an `x` axis that is accepted by `is_monotonic`.
#[derive(Debug, Clone)]
pub struct Interp1D<C, Z> {
/// Sample coordinates along the axis.
x: Vec<C>,
/// Values stored at the coordinates; `z[i]` belongs to `x[i]`.
z: Vec<Z>,
}
/// Linearly blends two values with weight `alpha`.
///
/// The result is computed as `(z0 - z0 * alpha) + z1 * alpha`, so for ordinary
/// numeric types `alpha == 0` yields `z0` and `alpha == 1` yields `z1`.
/// `alpha` is not clamped; weights outside `[0, 1]` extend the same line.
fn interpolate1d<C, Z>(z0: Z, z1: Z, alpha: C) -> Z
where
    C: Copy + Sub<Output = C>,
    Z: Copy + Mul<C, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
{
    // Remove the share of `z0` that is replaced ...
    let start_share = z0 * alpha;
    let remainder = z0 - start_share;
    // ... and add the equally weighted share of `z1`.
    let end_share = z1 * alpha;
    remainder + end_share
}
/// Checks `interpolate1d` at both end weights and at the midpoint.
#[test]
fn test_interpolate1d() {
    // (alpha, expected result) when blending from 1.0 towards 2.0.
    let cases = [(0.0f64, 1.0f64), (1.0, 2.0), (0.5, 1.5)];
    for &(alpha, expected) in cases.iter() {
        let blended = interpolate1d(1.0f64, 2., alpha);
        assert!((blended - expected).abs() < 1e-6);
    }
}
/// Linearly interpolates between the samples `(x0, v0)` and `(x1, v1)` at `x`.
///
/// The blend weight is `(x - x0) / (x1 - x0)` and is passed to `interpolate1d`
/// unchanged. No check is made that `x0 != x1` or that `x` lies between the
/// two coordinates.
fn interp1d<C, Z>(x: C, (x0, x1): (C, C), (v0, v1): (Z, Z)) -> Z
where
    C: Copy + Sub<Output = C> + Div,
    Z: Copy + Mul<<C as Div>::Output, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
    <C as Div>::Output: Copy + Add<Output = <C as Div>::Output> + Sub<Output = <C as Div>::Output>,
{
    // Width of the interval, then the position of `x` relative to its start.
    let span = x1 - x0;
    let offset = x - x0;
    interpolate1d(v0, v1, offset / span)
}
impl<C, Z> Interp1D<C, Z> {
    /// Builds a table from axis coordinates `x` and their values `z`.
    ///
    /// # Panics
    ///
    /// Panics if `x` and `z` differ in length, if `x` is empty, or if
    /// `is_monotonic(&x)` returns `false`.
    pub fn new(x: Vec<C>, z: Vec<Z>) -> Self
    where
        C: PartialOrd,
    {
        assert_eq!(z.len(), x.len(), "x-axis length mismatch.");
        assert!(!x.is_empty());
        assert!(is_monotonic(&x), "x values must be monotonic.");
        Interp1D { x, z }
    }

    /// Returns the first and the last axis coordinate as `(first, last)`.
    pub fn bounds(&self) -> (C, C)
    where
        C: Copy,
    {
        let first = self.x[0];
        let last = self.x[self.x.len() - 1];
        (first, last)
    }

    /// Reports whether `x` lies in the closed interval spanned by `bounds()`.
    pub fn is_within_bounds(&self, x: C) -> bool
    where
        C: PartialOrd + Copy,
    {
        let (lower, upper) = self.bounds();
        (lower..=upper).contains(&x)
    }

    /// Borrows the axis coordinates.
    pub fn xs(&self) -> &Vec<C> {
        &self.x
    }

    /// Borrows the values stored at the axis coordinates.
    pub fn z(&self) -> &Vec<Z> {
        &self.z
    }

    /// Returns a new table with the same (cloned) axis and every value passed
    /// through `f`.
    pub fn map_values<Z2>(&self, f: impl Fn(&Z) -> Z2) -> Interp1D<C, Z2>
    where
        C: PartialOrd + Clone,
    {
        let x = self.x.clone();
        let z: Vec<Z2> = self.z.iter().map(f).collect();
        Interp1D { x, z }
    }

    /// Consumes the table and returns one whose axis coordinates were passed
    /// through `f`; the values are moved over unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the transformed axis is rejected by `is_monotonic`.
    pub fn map_axis<Xnew>(self, f: impl Fn(C) -> Xnew) -> Interp1D<Xnew, Z>
    where
        Xnew: PartialOrd,
    {
        let Interp1D { x, z } = self;
        let xnew: Vec<Xnew> = x.into_iter().map(f).collect();
        assert!(is_monotonic(&xnew));
        Interp1D { x: xnew, z }
    }
}
impl<C, Z> Interp1D<C, Z>
where
    C: Copy + Sub<Output = C> + Div + PartialOrd,
    Z: Copy + Mul<<C as Div>::Output, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
    <C as Div>::Output: Copy + Add<Output = <C as Div>::Output> + Sub<Output = <C as Div>::Output>,
{
    /// Evaluates the table at `x` by linear interpolation between the two
    /// samples whose indices are returned by `find_closest_neighbours_indices`.
    ///
    /// No bounds check is performed here; what happens for `x` outside
    /// `bounds()` depends on the indices that helper returns.
    pub fn eval(&self, x: C) -> Z {
        let (lo, hi) = find_closest_neighbours_indices(&self.x, x);
        let coords = (self.x[lo], self.x[hi]);
        let values = (self.z[lo], self.z[hi]);
        interp1d(x, coords, values)
    }

    /// Evaluates the table at `x`, returning `None` when `x` is not within
    /// `bounds()` and `Some(self.eval(x))` otherwise.
    pub fn eval_no_extrapolation(&self, x: C) -> Option<Z> {
        if !self.is_within_bounds(x) {
            return None;
        }
        Some(self.eval(x))
    }
}
/// Evaluates a three-sample table at an interior sample, at the last sample
/// and halfway between two samples.
#[test]
fn test_interp1d() {
    let interp = Interp1D::new(vec![0.0f64, 1.0, 2.0], vec![0.0f64, 1.0, 0.0]);
    // (query coordinate, expected interpolated value)
    let cases = [(1.0f64, 1.0f64), (2.0, 0.0), (1.5, 0.5)];
    for &(at, expected) in cases.iter() {
        let value = interp.eval_no_extrapolation(at).unwrap();
        assert!((value - expected).abs() < 1e-6);
    }
}
/// Doubling every stored value must double the interpolated result.
#[test]
fn test_interp1d_map_values() {
    let doubled =
        Interp1D::new(vec![0.0f64, 1.0, 2.0], vec![0.0f64, 1.0, 0.0]).map_values(|z| 2.0f64 * z);
    let at_middle = doubled.eval_no_extrapolation(1.0).unwrap();
    assert!((at_middle - 2.0).abs() < 1e-6);
}
/// Stretching the axis by two moves the middle sample from `1.0` to `2.0`.
#[test]
fn test_interp1d_map_axis() {
    let stretched =
        Interp1D::new(vec![0.0f64, 1.0, 2.0], vec![0.0f64, 1.0, 0.0]).map_axis(|x| 2.0f64 * x);
    let at_middle = stretched.eval_no_extrapolation(2.0).unwrap();
    assert!((at_middle - 1.0).abs() < 1e-6);
}