use super::{Saturation, centered_difference_nonuniform, normalized_hermite_spline};
use crate::index_arr_fixed_dims;
use crunchy::unroll;
use num_traits::Float;
/// Evaluate a multicubic interpolant on a rectilinear grid at a batch of points.
///
/// Convenience wrapper that builds a [`MulticubicRectilinear`] with `N = grids.len()`
/// (dimension counts 1 through 8 are dispatched) and writes one value per observation
/// point into `out`.
///
/// # Arguments
///
/// * `grids` - coordinate array for each dimension.
/// * `vals` - values on the grid; see [`MulticubicRectilinear::new`] for requirements.
/// * `linearize_extrapolation` - forwarded to [`MulticubicRectilinear::new`].
/// * `obs` - one coordinate slice per dimension; each slice must have `out.len()` entries.
/// * `out` - destination for the interpolated values.
///
/// # Errors
///
/// * `"Dimension mismatch"` if `obs` does not contain exactly one slice per grid dimension.
/// * The error passed to `dispatch_ndims!` when `grids.len()` is not a dispatched
///   dimension count.
/// * Any error from [`MulticubicRectilinear::new`] or [`MulticubicRectilinear::interp`].
pub fn interpn<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    linearize_extrapolation: bool,
    obs: &[&[T]],
    out: &mut [T],
) -> Result<(), &'static str> {
    let ndims = grids.len();
    // `obs` is converted to a fixed-size `[_; N]` below. Reject a length mismatch here
    // instead of panicking on that conversion.
    if obs.len() != ndims {
        return Err("Dimension mismatch");
    }
    crate::dispatch_ndims!(
        ndims,
        "Dimension exceeds maximum (8). Use interpolator struct directly for higher dimensions.",
        [1, 2, 3, 4, 5, 6, 7, 8],
        |N| {
            // Both conversions succeed: `grids.len() == ndims == N`, and `obs.len()` was
            // checked against `ndims` above.
            MulticubicRectilinear::<'_, T, N>::new(
                grids.try_into().unwrap(),
                vals,
                linearize_extrapolation,
            )?
            .interp(obs.try_into().unwrap(), out)
        }
    )
}
/// Allocating variant of [`interpn`] that returns a new `Vec` of interpolated values.
///
/// The output length is taken from the first observation slice.
///
/// # Errors
///
/// Same as [`interpn`].
#[cfg(feature = "std")]
pub fn interpn_alloc<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    linearize_extrapolation: bool,
    obs: &[&[T]],
) -> Result<Vec<T>, &'static str> {
    // An empty `obs` yields an empty buffer and defers validation to `interpn`,
    // instead of indexing out of bounds here.
    let n = obs.first().map_or(0, |o| o.len());
    let mut out = vec![T::zero(); n];
    interpn(grids, vals, linearize_extrapolation, obs, &mut out)?;
    Ok(out)
}
pub use crate::multilinear::rectilinear::check_bounds;
/// Multicubic interpolator over an `N`-dimensional rectilinear grid.
///
/// Borrows the grid coordinates and values; construct it with
/// [`MulticubicRectilinear::new`]. All fields are borrows, arrays or flags, so the
/// interpolator is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct MulticubicRectilinear<'a, T: Float, const N: usize> {
    /// Coordinates along each dimension (one slice per dimension).
    grids: &'a [&'a [T]],
    /// Number of grid points along each dimension.
    dims: [usize; N],
    /// Flattened grid values; the length equals the product of `dims`.
    vals: &'a [T],
    /// When `true`, points outside the grid are extrapolated linearly instead of by
    /// continuing the boundary spline segment.
    linearize_extrapolation: bool,
}
impl<'a, T: Float, const N: usize> MulticubicRectilinear<'a, T, N> {
    /// Build an interpolator over an `N`-dimensional rectilinear grid.
    ///
    /// # Arguments
    ///
    /// * `grids` - coordinates along each dimension; each grid needs at least 4 entries.
    /// * `vals` - grid values; `vals.len()` must equal the product of the grid lengths.
    /// * `linearize_extrapolation` - when `true`, points beyond the first/last grid entry
    ///   are extrapolated linearly using the boundary slope. When `false`, the boundary
    ///   spline segment is evaluated past the edge.
    ///
    /// # Errors
    ///
    /// * `"At least one dimension is required"` if `N == 0`.
    /// * `"Dimension mismatch"` if `vals.len()` differs from the product of grid lengths.
    /// * `"All grids must have at least 4 entries"` if any grid is shorter than 4.
    /// * `"All grids must be monotonically increasing"` if any grid's second entry is not
    ///   greater than its first. Only the first interval is inspected, so this is a cheap
    ///   sanity check rather than a full monotonicity test.
    pub fn new(
        grids: &'a [&'a [T]; N],
        vals: &'a [T],
        linearize_extrapolation: bool,
    ) -> Result<Self, &'static str> {
        // `interp_one` indexes `grids[N - 1]`, which cannot work with zero dimensions.
        if N == 0 {
            return Err("At least one dimension is required");
        }

        let dims: [usize; N] = core::array::from_fn(|i| grids[i].len());
        let nvals: usize = dims.iter().product();
        if vals.len() != nvals {
            return Err("Dimension mismatch");
        }

        // A cubic stencil needs 4 points along every dimension.
        if dims.iter().any(|&d| d < 4) {
            return Err("All grids must have at least 4 entries");
        }

        // Cheap check: only the first interval of each grid is inspected.
        if !grids.iter().all(|&g| g[1] > g[0]) {
            return Err("All grids must be monotonically increasing");
        }

        Ok(Self {
            grids,
            dims,
            vals,
            linearize_extrapolation,
        })
    }

    /// Evaluate the interpolant at a batch of points.
    ///
    /// `x[d][i]` is the coordinate along dimension `d` of the `i`-th point. The result for
    /// point `i` is written to `out[i]`.
    ///
    /// # Errors
    ///
    /// * `"Dimension mismatch"` if any `x[d]` has a different length than `out`.
    /// * Any error from [`Self::interp_one`]. In that case, outputs for earlier points have
    ///   already been written.
    pub fn interp(&self, x: &[&[T]; N], out: &mut [T]) -> Result<(), &'static str> {
        let n = out.len();
        if x.iter().any(|coords| coords.len() != n) {
            return Err("Dimension mismatch");
        }

        let mut point = [T::zero(); N];
        for (i, o) in out.iter_mut().enumerate() {
            for (p, coords) in point.iter_mut().zip(x.iter()) {
                *p = coords[i];
            }
            *o = self.interp_one(point)?;
        }
        Ok(())
    }

    /// Evaluate the interpolant at a single point `x`.
    ///
    /// How it works:
    ///
    /// 1. A 4-point stencil is located along each dimension.
    /// 2. The `4^N` grid values of the stencil are visited.
    /// 3. 1-D interpolations (`interp_inner`) are applied one dimension at a time,
    ///    starting with dimension 0.
    ///
    /// Only `N x 4` scratch values are held at any time.
    ///
    /// # Errors
    ///
    /// Propagates errors from `get_loc`.
    pub fn interp_one(&self, x: [T; N]) -> Result<T, &'static str> {
        // Points per dimension in a cubic stencil.
        const FP: usize = 4;

        // Per-dimension stencil start index and position classification.
        let mut origin = [0_usize; N];
        let mut sat = [Saturation::None; N];
        // Flat-index stride of each dimension; the last dimension has stride 1.
        // NOTE(review): presumably `index_arr_fixed_dims` computes sum(loc[k] * dimprod[k]).
        let mut dimprod = [1_usize; N];
        // Multi-index of the stencil vertex currently being read.
        let mut loc = [0_usize; N];
        // store[0] collects raw values along dimension 0. store[j] (j >= 1) collects
        // values already reduced along dimensions 0..j.
        let mut store = [[T::zero(); FP]; N];

        let mut acc = 1;
        for i in 0..N {
            if i > 0 {
                acc *= self.dims[N - i];
            }
            dimprod[N - i - 1] = acc;
            (origin[i], sat[i]) = self.get_loc(x[i], i)?;
        }

        // Number of stencil vertices.
        let nverts = const { FP.pow(N as u32) };

        // Work for stencil vertex `$i`. Base-4 digit `k` of `$i` is the offset along
        // dimension `k`; dimension 0 is the least-significant digit.
        macro_rules! unroll_vertices_body {
            ($i:ident) => {
                for j in 0..N {
                    if j == 0 {
                        // Read the grid value at this vertex into the dimension-0 buffer.
                        for k in 0..N {
                            let offset: usize = ($i & (3 << (2 * k))) >> (2 * k);
                            loc[k] = origin[k] + offset;
                        }
                        let store_ind: usize = $i % FP;
                        store[0][store_ind] = index_arr_fixed_dims(loc, dimprod, self.vals);
                    } else {
                        // When `$i` is the last of a run of 4^j consecutive vertices,
                        // buffer `j - 1` is complete. It is reduced along dimension `j - 1`
                        // into slot `p` of buffer `j`, where `p` is the base-4 digit of
                        // `$i` for dimension `j`.
                        let q: usize = FP.pow(j as u32);
                        let level: bool = ($i + 1).is_multiple_of(q);
                        let p: usize = (($i + 1) / q).saturating_sub(1) % FP;
                        let ind: usize = j.saturating_sub(1);
                        if level {
                            let grid_cell = &self.grids[ind][origin[ind]..origin[ind] + 4];
                            let interped = interp_inner::<T>(
                                store[ind],
                                grid_cell.try_into().unwrap(),
                                x[ind],
                                sat[ind],
                                self.linearize_extrapolation,
                            );
                            store[j][p] = interped;
                        }
                    }
                }
            };
        }

        // Fully unroll the vertex loop for small N (at most 64 vertices).
        #[cfg(not(feature = "deep-unroll"))]
        if N <= 3 {
            unroll! {
                for i < 64 in 0..nverts {
                    unroll_vertices_body!(i);
                }
            }
        } else {
            for i in 0..nverts {
                unroll_vertices_body!(i);
            }
        }

        // With `deep-unroll`, also unroll N == 4 (at most 256 vertices).
        #[cfg(feature = "deep-unroll")]
        if N <= 4 {
            unroll! {
                for i < 256 in 0..nverts {
                    unroll_vertices_body!(i);
                }
            }
        } else {
            for i in 0..nverts {
                unroll_vertices_body!(i);
            }
        }

        // store[N - 1] now holds values reduced along every dimension but the last;
        // finish along the last one.
        let grid_cell = &self.grids[N - 1][origin[N - 1]..origin[N - 1] + 4];
        let interped = interp_inner::<T>(
            store[N - 1],
            grid_cell.try_into().unwrap(),
            x[N - 1],
            sat[N - 1],
            self.linearize_extrapolation,
        );
        Ok(interped)
    }

    /// Locate the 4-point stencil for coordinate `v` along dimension `dim`.
    ///
    /// Returns the index of the first stencil point, clamped so the stencil stays inside
    /// the grid, and where `v` lies (with `n` the grid length):
    ///
    /// * `OutsideLow`: `v <= grid[0]`
    /// * `InsideLow`: `grid[0] < v <= grid[1]`
    /// * `InsideHigh`: `grid[n - 2] < v <= grid[n - 1]`
    /// * `OutsideHigh`: `v > grid[n - 1]`
    /// * `None`: otherwise, with `v` between stencil points 1 and 2.
    ///
    /// Assumes the grid is sorted ascending, as `partition_point` requires. This function
    /// currently never returns `Err`.
    #[inline]
    fn get_loc(&self, v: T, dim: usize) -> Result<(usize, Saturation), &'static str> {
        let grid = self.grids[dim];
        let n = self.dims[dim] as isize;
        // Count of grid entries strictly below `v`, shifted by 2 so that in the interior
        // `v` falls between stencil points 1 and 2.
        let iloc = grid.partition_point(|x| *x < v) as isize - 2;
        let dimmax = n.saturating_sub(4).max(0);
        let loc = iloc.clamp(0, dimmax) as usize;

        let saturation = match iloc {
            -2 => Saturation::OutsideLow,
            -1 => Saturation::InsideLow,
            _ if iloc == n - 2 => Saturation::OutsideHigh,
            _ if iloc == n - 3 => Saturation::InsideHigh,
            _ => Saturation::None,
        };
        Ok((loc, saturation))
    }
}
/// Evaluate a 1-D cubic Hermite interpolant on a 4-point stencil.
///
/// `vals[i]` is the value at `grid_cell[i]`, and `sat` says which part of the stencil `x`
/// is in:
///
/// * `None`: the spline runs from point 1 to point 2. End slopes come from centered
///   differences around points 1 and 2, in units of the interval width.
/// * `InsideLow` / `OutsideLow`: the spline starts at point 1 and runs toward point 0.
///   The parameter increases as `x` decreases.
/// * `InsideHigh` / `OutsideHigh`: the spline starts at point 2 and runs toward point 3.
///
/// In both boundary cases the far-end slope is `2 * dy - k0`. For the `Outside*` variants
/// with `linearize_extrapolation`, the value continues linearly from the end point with
/// that slope.
#[inline]
fn interp_inner<T: Float>(
    vals: [T; 4],
    grid_cell: &[T; 4],
    x: T,
    sat: Saturation,
    linearize_extrapolation: bool,
) -> T {
    let one = T::one();
    let two = one + one;
    let [v0, v1, v2, v3] = vals;
    let [g0, g1, g2, g3] = *grid_cell;

    match sat {
        Saturation::None => {
            let h12 = g2 - g1;
            let k0 = centered_difference_nonuniform(v0, v1, v2, (g1 - g0) / h12, one);
            let k1 = centered_difference_nonuniform(v1, v2, v3, one, (g3 - g2) / h12);
            normalized_hermite_spline((x - g1) / h12, v1, v2 - v1, k0, k1)
        }
        Saturation::InsideLow | Saturation::OutsideLow => {
            let h01 = g1 - g0;
            let dy = v0 - v1;
            // Negated because the spline parameter runs toward decreasing x.
            let k0 = -centered_difference_nonuniform(v0, v1, v2, one, (g2 - g1) / h01);
            let k1 = two * dy - k0;
            let t = -(x - g1) / h01;
            if linearize_extrapolation && matches!(sat, Saturation::OutsideLow) {
                v0 + k1 * (t - one)
            } else {
                normalized_hermite_spline(t, v1, dy, k0, k1)
            }
        }
        Saturation::InsideHigh | Saturation::OutsideHigh => {
            let h23 = g3 - g2;
            let dy = v3 - v2;
            let k0 = centered_difference_nonuniform(v1, v2, v3, (g2 - g1) / h23, one);
            let k1 = two * dy - k0;
            let t = (x - g2) / h23;
            if linearize_extrapolation && matches!(sat, Saturation::OutsideHigh) {
                v3 + k1 * (t - one)
            } else {
                normalized_hermite_spline(t, v2, dy, k0, k1)
            }
        }
    }
}
// Unit tests for `interpn`: compare its output against analytic functions evaluated on
// grids whose points are jittered by a fixed-seed RNG. Helpers (`linspace`, `meshgrid`,
// `randn`, `rng_fixed_seed`) come from `crate::testing` and `crate::utils`.
#[cfg(test)]
mod test {
use super::interpn;
use crate::testing::*;
use crate::utils::*;
// A function linear in every coordinate (the sum of coordinates) must be reproduced to
// 1e-10, with and without `linearize_extrapolation`, for 1-4 dims on 4-point grids.
// Observations span [-7i, 7(i+1)] in dimension i, beyond the grid's jittered
// [-5i, 5(i+1)], so extrapolation is exercised as well.
#[test]
fn test_interp_extrap_1d_to_4d_linear() {
let mut rng = rng_fixed_seed();
for ndims in 1..=4 {
println!("Testing in {ndims} dims");
let dims: Vec<usize> = vec![4; ndims];
let xs: Vec<Vec<f64>> = (0..ndims)
.map(|i| {
let mut x = linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i]);
// Jitter grid points to make the grid nonuniform. NOTE(review): `randn` presumably
// draws normal samples, so the `- 0.5` shift (typical for uniform draws) looks
// unintended. The assert below guards against the jitter reordering points.
let dx = randn::<f64>(&mut rng, x.len());
(0..x.len()).for_each(|i| x[i] += (dx[i] - 0.5) / 10.0);
(0..x.len() - 1).for_each(|i| assert!(x[i + 1] > x[i]));
x
})
.collect();
let grids: Vec<&[f64]> = xs.iter().map(|x| &x[..]).collect();
let grid = meshgrid((0..ndims).map(|i| &xs[i]).collect());
let u: Vec<f64> = grid.iter().map(|x| x.iter().sum()).collect();
let xobs: Vec<Vec<f64>> = (0..ndims)
.map(|i| linspace(-7.0 * (i as f64), 7.0 * ((i + 1) as f64), dims[i] + 2))
.collect();
let gridobs = meshgrid((0..ndims).map(|i| &xobs[i]).collect());
// Transpose the point list into one coordinate slice per dimension, as `interpn` expects.
let gridobs_t: Vec<Vec<f64>> = (0..ndims)
.map(|i| gridobs.iter().map(|x| x[i]).collect())
.collect(); let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|x| &x[..]).collect();
let uobs: Vec<f64> = gridobs.iter().map(|x| x.iter().sum()).collect(); let mut out = vec![0.0; uobs.len()];
interpn(&grids, &u, true, &xobsslice, &mut out[..]).unwrap();
(0..uobs.len()).for_each(|i| assert!((out[i] - uobs[i]).abs() < 1e-10));
interpn(&grids, &u, false, &xobsslice, &mut out[..]).unwrap();
(0..uobs.len()).for_each(|i| assert!((out[i] - uobs[i]).abs() < 1e-10));
}
}
// A sum of squared coordinates (quadratic in each dimension) must be reproduced to 3e-10,
// including outside the grid, with `linearize_extrapolation = false`, for 1-6 dims on
// 4-point grids.
#[test]
fn test_interp_extrap_1d_to_6d_quadratic() {
let mut rng = rng_fixed_seed();
for ndims in 1..=6 {
println!("Testing in {ndims} dims");
let dims: Vec<usize> = vec![4; ndims];
let xs: Vec<Vec<f64>> = (0..ndims)
.map(|i| {
let mut x = linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i]);
let dx = randn::<f64>(&mut rng, x.len());
(0..x.len()).for_each(|i| x[i] += (dx[i] - 0.5) / 10.0);
(0..x.len() - 1).for_each(|i| assert!(x[i + 1] > x[i]));
x
})
.collect();
let grids: Vec<&[f64]> = xs.iter().map(|x| &x[..]).collect();
let grid = meshgrid((0..ndims).map(|i| &xs[i]).collect());
let u: Vec<f64> = (0..grid.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += grid[i][j] * grid[i][j];
}
v
})
.collect();
let xobs: Vec<Vec<f64>> = (0..ndims)
.map(|i| linspace(-7.0 * (i as f64), 7.0 * ((i + 1) as f64), dims[i] + 2))
.collect();
let gridobs = meshgrid((0..ndims).map(|i| &xobs[i]).collect());
let gridobs_t: Vec<Vec<f64>> = (0..ndims)
.map(|i| gridobs.iter().map(|x| x[i]).collect())
.collect(); let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|x| &x[..]).collect();
let uobs: Vec<f64> = (0..gridobs.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += gridobs[i][j] * gridobs[i][j];
}
v
})
.collect(); let mut out = vec![0.0; uobs.len()];
interpn(&grids, &u, false, &xobsslice, &mut out[..]).unwrap();
(0..uobs.len()).for_each(|i| assert!((out[i] - uobs[i]).abs() < 3e-10));
}
}
// Smooth (sine) data on 10-point grids, observed inside the grid range. The absolute
// error tolerance is 2e-2 per dimension.
// NOTE(review): `for ndims in 1..3` exercises only 1 and 2 dims despite the test name.
// Confirm whether `1..=3` was intended; the third dimension's coarser spacing may need
// a looser tolerance.
#[test]
fn test_interp_1d_to_3d_sine() {
let mut rng = rng_fixed_seed();
for ndims in 1..3 {
println!("Testing in {ndims} dims");
let dims: Vec<usize> = vec![10; ndims];
let xs: Vec<Vec<f64>> = (0..ndims)
.map(|i| {
let mut x = linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i]);
let dx = randn::<f64>(&mut rng, x.len());
(0..x.len()).for_each(|i| x[i] += (dx[i] - 0.5) / 10.0);
(0..x.len() - 1).for_each(|i| assert!(x[i + 1] > x[i]));
x
})
.collect();
let grids: Vec<&[f64]> = xs.iter().map(|x| &x[..]).collect();
let grid = meshgrid((0..ndims).map(|i| &xs[i]).collect());
let u: Vec<f64> = (0..grid.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += (grid[i][j] * 6.28 / 10.0).sin();
}
v
})
.collect();
let xobs: Vec<Vec<f64>> = (0..ndims)
.map(|i| linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i] + 1))
.collect();
let gridobs = meshgrid((0..ndims).map(|i| &xobs[i]).collect());
let gridobs_t: Vec<Vec<f64>> = (0..ndims)
.map(|i| gridobs.iter().map(|x| x[i]).collect())
.collect(); let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|x| &x[..]).collect();
let uobs: Vec<f64> = (0..gridobs.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += (gridobs[i][j] * 6.28 / 10.0).sin();
}
v
})
.collect(); let mut out = vec![0.0; uobs.len()];
interpn(&grids, &u, false, &xobsslice, &mut out[..]).unwrap();
let tol = 2e-2 * f64::from(ndims as u32);
(0..uobs.len()).for_each(|i| {
let err = out[i] - uobs[i];
assert!(err.abs() < tol);
});
}
}
}