use super::Saturation;
use crate::{
index_arr_fixed_dims,
interp_math::{dot4, hermite_basis},
mul_add,
};
use crunchy::unroll;
use num_traits::Float;
/// Evaluate multicubic interpolation on a rectilinear grid at a batch of
/// observation points, writing the results into a caller-provided buffer.
///
/// The number of dimensions is taken from `grids.len()` and dispatched to a
/// const-generic [`MulticubicRectilinear`] for the dimensionalities listed in
/// the dispatch table below (1 through 8).
///
/// # Arguments
/// * `grids` - one coordinate slice per dimension.
/// * `vals` - values on the grid; its length must equal the product of the
///   grid lengths.
/// * `linearize_extrapolation` - forwarded to the interpolator; selects linear
///   rather than polynomial extension outside the grid.
/// * `obs` - one coordinate slice per dimension; every slice must have the
///   same length as `out`.
/// * `out` - destination for the interpolated values.
///
/// # Errors
/// * `"Dimension mismatch"` if `obs.len() != grids.len()`, or for any
///   mismatch reported by [`MulticubicRectilinear::new`] /
///   [`MulticubicRectilinear::interp`].
/// * The dispatch error message if the number of dimensions is not in the
///   dispatch table.
/// * Any other validation error from [`MulticubicRectilinear::new`].
pub fn interpn<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    linearize_extrapolation: bool,
    obs: &[&[T]],
    out: &mut [T],
) -> Result<(), &'static str> {
    let ndims = grids.len();

    // Without this guard the `obs.try_into().unwrap()` below would panic when
    // the caller passes a different number of observation axes than grid axes.
    // A function returning `Result` should report that as an error instead.
    if obs.len() != ndims {
        return Err("Dimension mismatch");
    }

    crate::dispatch_ndims!(
        ndims,
        "Dimension exceeds maximum (8). Use interpolator struct directly for higher dimensions.",
        [1, 2, 3, 4, 5, 6, 7, 8],
        |N| {
            // Both conversions are infallible here: `grids.len() == N` by
            // dispatch and `obs.len() == grids.len()` by the guard above.
            MulticubicRectilinear::<'_, T, N>::new(
                grids.try_into().unwrap(),
                vals,
                linearize_extrapolation,
            )?
            .interp(obs.try_into().unwrap(), out)
        }
    )
}
/// Allocating variant of [`interpn`]: evaluates the interpolation and returns
/// the results in a freshly allocated `Vec`.
///
/// The output length is taken from the first observation axis (`obs[0]`).
/// All remaining validation is delegated to [`interpn`].
///
/// # Errors
/// * `"Dimension mismatch"` if `obs` is empty, so there is no axis to size
///   the output from.
/// * Any error returned by [`interpn`].
#[cfg(feature = "std")]
pub fn interpn_alloc<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    linearize_extrapolation: bool,
    obs: &[&[T]],
) -> Result<Vec<T>, &'static str> {
    // `obs[0]` would panic on an empty `obs`; report it as an error instead.
    let n_obs = obs.first().ok_or("Dimension mismatch")?.len();

    let mut out = vec![T::zero(); n_obs];
    interpn(grids, vals, linearize_extrapolation, obs, &mut out)?;
    Ok(out)
}
pub use crate::multilinear::rectilinear::check_bounds;
/// Multicubic interpolator over an `N`-dimensional rectilinear grid.
///
/// Borrows the grid coordinates and grid values for lifetime `'a`; nothing is
/// copied except the per-axis lengths.
pub struct MulticubicRectilinear<'a, T: Float, const N: usize> {
// Coordinate values along each axis, one slice per dimension.
grids: &'a [&'a [T]],
// Number of grid points on each axis (the length of each entry of `grids`).
dims: [usize; N],
// Values on the grid, flattened; `new` checks that the length equals the
// product of `dims`.
vals: &'a [T],
// Forwarded to the weight computation, where it only affects observation
// points that fall outside the grid.
linearize_extrapolation: bool,
}
impl<'a, T: Float, const N: usize> MulticubicRectilinear<'a, T, N> {
/// Build an interpolator that borrows `grids` and `vals`.
///
/// # Errors
/// * `"Dimension mismatch"` if `vals.len()` differs from the product of the
///   grid lengths.
/// * `"All grids must have at least 4 entries"` if any axis is shorter than
///   the 4-point cubic footprint.
/// * `"All grids must be monotonically increasing"` if the first two entries
///   of any axis are not strictly increasing. Only the first pair is
///   inspected; this is a cheap sanity check, not a full validation.
pub fn new(
    grids: &'a [&'a [T]; N],
    vals: &'a [T],
    linearize_extrapolation: bool,
) -> Result<Self, &'static str> {
    // Per-axis grid sizes.
    let dims: [usize; N] = core::array::from_fn(|axis| grids[axis].len());

    // The value array must cover the full tensor-product grid.
    if dims.iter().product::<usize>() != vals.len() {
        return Err("Dimension mismatch");
    }

    // Each axis must be able to host a 4-point footprint. This check also
    // makes the `[0]`/`[1]` indexing below safe.
    if !dims.iter().all(|&len| len >= 4) {
        return Err("All grids must have at least 4 entries");
    }

    // Written as a negated `>` so that a NaN in either entry is rejected.
    if grids.iter().any(|axis| !(axis[1] > axis[0])) {
        return Err("All grids must be monotonically increasing");
    }

    Ok(Self {
        grids,
        dims,
        vals,
        linearize_extrapolation,
    })
}
/// Interpolate at a batch of observation points.
///
/// `x` holds one coordinate slice per dimension; point `i` is
/// `(x[0][i], x[1][i], ..., x[N-1][i])` and its result is written to `out[i]`.
///
/// # Errors
/// * `"Dimension mismatch"` if any coordinate slice has a length different
///   from `out.len()`.
/// * Any error from [`Self::interp_one`].
pub fn interp(&self, x: &[&[T]; N], out: &mut [T]) -> Result<(), &'static str> {
    let n_points = out.len();
    if x.iter().any(|axis| axis.len() != n_points) {
        return Err("Dimension mismatch");
    }

    for (i, slot) in out.iter_mut().enumerate() {
        // Gather the i-th coordinate from every axis into one point.
        let point: [T; N] = core::array::from_fn(|axis| x[axis][i]);
        *slot = self.interp_one(point)?;
    }
    Ok(())
}
/// Interpolate the value at a single observation point `x`.
///
/// For every axis the 4-point footprint origin, the saturation state and the
/// four interpolation weights are computed. The `4^N` footprint vertices are
/// then visited in order and contracted one axis at a time with `dot4`.
///
/// # Errors
/// Propagates any error from `get_loc`.
pub fn interp_one(&self, x: [T; N]) -> Result<T, &'static str> {
// Per-axis footprint origin, saturation flag and weights.
// `FP` is declared further down; block-level items are visible throughout the block.
let mut origin = [0_usize; N]; let mut sat = [Saturation::None; N]; let mut weights = [[T::zero(); FP]; N];
// Cumulative products of the trailing axis lengths (the last axis gets 1).
let mut dimprod = [1_usize; N];
// Scratch: multi-index of the vertex currently being read.
let mut loc = [0_usize; N];
// Scratch: `store[j]` collects four partial results at contraction level `j`.
let mut store = [[T::zero(); FP]; N];
let mut acc = 1;
for i in 0..N {
// Fill `dimprod` from the last axis backwards.
if i > 0 {
acc *= self.dims[N - i];
}
dimprod[N - i - 1] = acc;
// Locate the footprint on axis `i` and compute its weights.
(origin[i], sat[i]) = self.get_loc(x[i], i)?;
let grid_cell = &self.grids[i][origin[i]..origin[i] + 4];
weights[i] = interp_weights(
// The slice is exactly 4 long, so the array conversion cannot fail.
grid_cell.try_into().unwrap(),
x[i],
sat[i],
self.linearize_extrapolation,
);
}
// FP = footprint size per axis; `nverts` = FP^N vertices in the full footprint.
const FP: usize = 4; let nverts = const { FP.pow(N as u32) };
// Body executed once per vertex index `$i`. Written as a macro so the same
// code can be used in both the unrolled and the plain loop below.
macro_rules! unroll_vertices_body {
($i:ident) => {
for j in 0..N {
if j == 0 {
// Level 0: decode the per-axis offsets (two bits per axis) from
// the vertex index and read the grid value for that vertex.
for k in 0..N {
let offset: usize = ($i & (3 << (2 * k))) >> (2 * k);
loc[k] = origin[k] + offset;
}
let store_ind: usize = $i % FP;
// NOTE(review): `index_arr_fixed_dims` presumably reads `vals` at the
// flat index given by `loc` and `dimprod` — confirm against its definition.
store[0][store_ind] = index_arr_fixed_dims(loc, dimprod, self.vals);
} else {
// Level j > 0: whenever a group of FP^j vertices is complete,
// contract the finished level-(j-1) buffer with the weights of
// axis j-1 and store the result in slot `p` of level j.
let q: usize = FP.pow(j as u32);
let level: bool = ($i + 1).is_multiple_of(q);
let p: usize = (($i + 1) / q).saturating_sub(1) % FP;
let ind: usize = j.saturating_sub(1);
if level {
store[j][p] = dot4(weights[ind], store[ind]);
}
}
}
};
}
// Default build: unroll the vertex loop for up to 64 vertices (N <= 3),
// otherwise fall back to a plain loop.
#[cfg(not(feature = "deep-unroll"))]
if N <= 3 {
unroll! {
for i < 64 in 0..nverts { unroll_vertices_body!(i);
}
}
} else {
for i in 0..nverts {
unroll_vertices_body!(i);
}
}
// With the "deep-unroll" feature: unroll for up to 256 vertices (N <= 4).
#[cfg(feature = "deep-unroll")]
if N <= 4 {
unroll! {
for i < 256 in 0..nverts { unroll_vertices_body!(i);
}
}
} else {
for i in 0..nverts {
unroll_vertices_body!(i);
}
}
// Final contraction over the last axis.
Ok(dot4(weights[N - 1], store[N - 1]))
}
/// Locate the 4-point footprint for coordinate `v` on axis `dim`.
///
/// Returns the index of the lowest grid point of the footprint, clamped so
/// the footprint stays inside the grid, together with a [`Saturation`] flag
/// describing where `v` sits relative to that footprint:
///
/// * `OutsideLow` / `OutsideHigh` - no grid entry is below `v` / every grid
///   entry is below `v`.
/// * `InsideLow` / `InsideHigh` - `v` lies in the first / last grid interval.
/// * `None` - `v` lies in an interval with a neighbour on both sides.
///
/// The current implementation never returns `Err`.
#[inline]
fn get_loc(&self, v: T, dim: usize) -> Result<(usize, Saturation), &'static str> {
    let axis = self.grids[dim];
    let n_pts = self.dims[dim] as isize;

    // Number of grid entries strictly below `v`, shifted by 2 so that an
    // interior point ends up in the middle interval of its footprint.
    let unclamped: isize = axis.partition_point(|g| *g < v) as isize - 2;

    // Highest origin that still leaves room for four points.
    let last_origin = n_pts.saturating_sub(4).max(0);
    let loc = unclamped.clamp(0, last_origin) as usize;

    let saturation = match unclamped {
        -2 => Saturation::OutsideLow,
        -1 => Saturation::InsideLow,
        i if i == n_pts - 2 => Saturation::OutsideHigh,
        i if i == n_pts - 3 => Saturation::InsideHigh,
        _ => Saturation::None,
    };

    Ok((loc, saturation))
}
}
/// Compute the four weights applied to the values at `grid_cell` for an
/// observation coordinate `x`.
///
/// * `Saturation::None`: `x` is treated as lying between `grid_cell[1]` and
///   `grid_cell[2]`; the weights combine the Hermite basis with
///   centred-difference slope estimates at both of those nodes.
/// * `InsideLow` / `OutsideLow`: handled by [`low_weights`] using the first
///   three nodes.
/// * `InsideHigh` / `OutsideHigh`: handled by [`high_weights`] using the last
///   three nodes.
///
/// `linearize_extrapolation` is only honoured for the two `Outside*` states.
#[inline]
fn interp_weights<T: Float>(
    grid_cell: &[T; 4],
    x: T,
    sat: Saturation,
    linearize_extrapolation: bool,
) -> [T; 4] {
    let [g0, g1, g2, g3] = *grid_cell;

    match sat {
        Saturation::None => {
            let unit = T::one();
            let (left, mid, right) = (g1 - g0, g2 - g1, g3 - g2);

            // Normalised position within the middle interval.
            let [b00, b10, b01, b11] = hermite_basis((x - g1) / mid);

            // Slope stencils at the two middle nodes, scaled by `mid`.
            let slope_lo = centered_difference_weights(left / mid, unit);
            let slope_hi = centered_difference_weights(unit, right / mid);

            [
                b10 * slope_lo[0],
                b00 + b10 * slope_lo[1] + b11 * slope_hi[0],
                b01 + b10 * slope_lo[2] + b11 * slope_hi[1],
                b11 * slope_hi[2],
            ]
        }
        Saturation::InsideLow | Saturation::OutsideLow => {
            let (left, mid) = (g1 - g0, g2 - g1);
            // Reversed parameter: 0 at `g1`, 1 at `g0`, above 1 below the grid.
            let t = -(x - g1) / left;
            let linearize = linearize_extrapolation && matches!(sat, Saturation::OutsideLow);
            low_weights(t, mid / left, linearize)
        }
        Saturation::InsideHigh | Saturation::OutsideHigh => {
            let (mid, right) = (g2 - g1, g3 - g2);
            // 0 at `g2`, 1 at `g3`, above 1 beyond the grid.
            let t = (x - g2) / right;
            let linearize = linearize_extrapolation && matches!(sat, Saturation::OutsideHigh);
            high_weights(t, mid / right, linearize)
        }
    }
}
/// Weights for the lowest grid interval and the region below it.
///
/// `t` is 0 at the second footprint node and 1 at the first (the grid edge).
/// `h12_over_h01` is the ratio of the second interval width to the first.
/// The weight on the fourth node is always zero.
///
/// With `linearize_extrapolation` the weights are linear in `t`, anchored at
/// `t == 1`; otherwise they are built from the Hermite basis.
#[inline]
fn low_weights<T: Float>(t: T, h12_over_h01: T, linearize_extrapolation: bool) -> [T; 4] {
    let one = T::one();
    let two = one + one;
    // Slope stencil at the second node, in units of the first interval.
    let [k_a, k_b, k_c] = centered_difference_weights(one, h12_over_h01);

    if linearize_extrapolation {
        // Distance past the edge node along the reversed parameter.
        let s = t - one;
        return [
            one + s * (two + k_a),
            s * (-two + k_b),
            s * k_c,
            T::zero(),
        ];
    }

    let [b00, b10, b01, b11] = hermite_basis(t);
    let slope_mix = b11 - b10;
    [
        b01 + two * b11 + slope_mix * k_a,
        b00 - two * b11 + slope_mix * k_b,
        slope_mix * k_c,
        T::zero(),
    ]
}
/// Weights for the highest grid interval and the region above it.
///
/// `t` is 0 at the third footprint node and 1 at the fourth (the grid edge).
/// `h12_over_h23` is the ratio of the second interval width to the third.
/// The weight on the first node is always zero.
///
/// With `linearize_extrapolation` the weights are linear in `t`, anchored at
/// `t == 1`; otherwise they are built from the Hermite basis.
#[inline]
fn high_weights<T: Float>(t: T, h12_over_h23: T, linearize_extrapolation: bool) -> [T; 4] {
    let one = T::one();
    let two = one + one;
    // Slope stencil at the third node, in units of the last interval.
    let [k_a, k_b, k_c] = centered_difference_weights(h12_over_h23, one);

    if linearize_extrapolation {
        // Distance past the edge node.
        let s = t - one;
        return [
            T::zero(),
            -s * k_a,
            s * (-two - k_b),
            one + s * (two - k_c),
        ];
    }

    let [b00, b10, b01, b11] = hermite_basis(t);
    let slope_mix = b10 - b11;
    [
        T::zero(),
        slope_mix * k_a,
        b00 - two * b11 + slope_mix * k_b,
        b01 + two * b11 + slope_mix * k_c,
    ]
}
/// Three-point first-derivative stencil for unevenly spaced nodes.
///
/// Given the spacing `h01` to the left neighbour and `h12` to the right
/// neighbour, returns the coefficients applied to the (left, centre, right)
/// values. For equal unit spacing this is `[-1/2, 0, 1/2]`.
#[inline]
fn centered_difference_weights<T: Float>(h01: T, h12: T) -> [T; 3] {
    let span = h01 + h12;
    // Fraction of the total span taken by each side.
    let left_frac = h01 / span;
    let right_frac = h12 / span;

    let w_left = -right_frac / h01;
    let w_center = mul_add(right_frac, T::one() / h01, -left_frac / h12);
    let w_right = left_frac / h12;

    [w_left, w_center, w_right]
}
#[cfg(test)]
mod test {
use super::{MulticubicRectilinear, interpn};
use crate::testing::*;
use crate::utils::*;
/// Panic unless three samples taken at equally spaced coordinates lie on a
/// straight line, i.e. their second difference vanishes to within `1e-12`.
fn assert_linear_extrapolation(values: [f64; 3]) {
    let [first, middle, last] = values;
    let second_difference = last - 2.0 * middle + first;
    assert!(
        second_difference.abs() < 1e-12,
        "extrapolated values are not linear: {values:?}"
    );
}
/// A field that is the sum of the coordinates must be reproduced to within
/// `1e-10` in 1 to 4 dimensions, both with and without linearized
/// extrapolation. Observation points span a wider nominal range than the
/// grid, so extrapolation is exercised as well.
#[test]
fn test_interp_extrap_1d_to_4d_linear() {
let mut rng = rng_fixed_seed();
for ndims in 1..=4 {
println!("Testing in {ndims} dims");
// Smallest grid the interpolator accepts: 4 points per axis.
let dims: Vec<usize> = vec![4; ndims];
// One coordinate vector per axis, nominally -5*i ..= 5*(i+1), perturbed by
// (dx - 0.5) / 10 per point; the perturbed grid is asserted to stay
// strictly increasing.
let xs: Vec<Vec<f64>> = (0..ndims)
.map(|i| {
let mut x = linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i]);
let dx = randn::<f64>(&mut rng, x.len());
(0..x.len()).for_each(|i| x[i] += (dx[i] - 0.5) / 10.0);
(0..x.len() - 1).for_each(|i| assert!(x[i + 1] > x[i]));
x
})
.collect();
let grids: Vec<&[f64]> = xs.iter().map(|x| &x[..]).collect();
// Grid values: the sum of the coordinates at every grid point.
// NOTE(review): `meshgrid` presumably returns one coordinate vector per
// grid point in the order the interpolator expects — confirm.
let grid = meshgrid((0..ndims).map(|i| &xs[i]).collect());
let u: Vec<f64> = grid.iter().map(|x| x.iter().sum()).collect();
// Observation coordinates span -7*i ..= 7*(i+1) with dims[i] + 2 points per axis.
let xobs: Vec<Vec<f64>> = (0..ndims)
.map(|i| linspace(-7.0 * (i as f64), 7.0 * ((i + 1) as f64), dims[i] + 2))
.collect();
let gridobs = meshgrid((0..ndims).map(|i| &xobs[i]).collect());
// Transpose from per-point coordinates to one slice per axis, the layout `interpn` takes.
let gridobs_t: Vec<Vec<f64>> = (0..ndims)
.map(|i| gridobs.iter().map(|x| x[i]).collect())
.collect(); let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|x| &x[..]).collect();
// Expected values: the same linear function evaluated at the observation points.
let uobs: Vec<f64> = gridobs.iter().map(|x| x.iter().sum()).collect(); let mut out = vec![0.0; uobs.len()];
// Linearized extrapolation.
interpn(&grids, &u, true, &xobsslice, &mut out[..]).unwrap();
(0..uobs.len()).for_each(|i| assert!((out[i] - uobs[i]).abs() < 1e-10));
// Non-linearized extrapolation.
interpn(&grids, &u, false, &xobsslice, &mut out[..]).unwrap();
(0..uobs.len()).for_each(|i| assert!((out[i] - uobs[i]).abs() < 1e-10));
}
}
/// With linearized extrapolation enabled, values sampled at equally spaced
/// coordinates at and beyond either edge of a 1-D grid must lie on a line,
/// even though the tabulated function is a cubic.
#[test]
fn test_linearized_extrapolation_is_linear_outside_grid() {
    let x = [-1.0_f64, -0.2, 0.35, 1.4, 2.8, 4.0];
    let vals: Vec<f64> = x
        .iter()
        .map(|&v| v * v * v - 0.5 * v * v + 2.0 * v)
        .collect();
    let grids = [&x[..]];
    let interp = MulticubicRectilinear::new(&grids, &vals, true).unwrap();
    let eval = |coord: f64| interp.interp_one([coord]).unwrap();

    // Low side: the first grid point and two equally spaced points below it.
    assert_linear_extrapolation([eval(-1.0), eval(-1.6), eval(-2.2)]);

    // High side: the last grid point and two equally spaced points above it.
    assert_linear_extrapolation([eval(4.0), eval(4.7), eval(5.4)]);
}
/// A field that is the sum of squared coordinates must be reproduced to
/// within `3e-10` in 1 to 6 dimensions with non-linearized extrapolation.
/// Observation points span a wider nominal range than the grid.
#[test]
fn test_interp_extrap_1d_to_6d_quadratic() {
let mut rng = rng_fixed_seed();
for ndims in 1..=6 {
println!("Testing in {ndims} dims");
// Smallest grid the interpolator accepts: 4 points per axis.
let dims: Vec<usize> = vec![4; ndims];
// One coordinate vector per axis, nominally -5*i ..= 5*(i+1), perturbed by
// (dx - 0.5) / 10 per point; the perturbed grid is asserted to stay
// strictly increasing.
let xs: Vec<Vec<f64>> = (0..ndims)
.map(|i| {
let mut x = linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i]);
let dx = randn::<f64>(&mut rng, x.len());
(0..x.len()).for_each(|i| x[i] += (dx[i] - 0.5) / 10.0);
(0..x.len() - 1).for_each(|i| assert!(x[i + 1] > x[i]));
x
})
.collect();
let grids: Vec<&[f64]> = xs.iter().map(|x| &x[..]).collect();
let grid = meshgrid((0..ndims).map(|i| &xs[i]).collect());
// Grid values: the sum of squared coordinates at every grid point.
let u: Vec<f64> = (0..grid.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += grid[i][j] * grid[i][j];
}
v
})
.collect();
// Observation coordinates span -7*i ..= 7*(i+1) with dims[i] + 2 points per axis.
let xobs: Vec<Vec<f64>> = (0..ndims)
.map(|i| linspace(-7.0 * (i as f64), 7.0 * ((i + 1) as f64), dims[i] + 2))
.collect();
let gridobs = meshgrid((0..ndims).map(|i| &xobs[i]).collect());
// Transpose from per-point coordinates to one slice per axis, the layout `interpn` takes.
let gridobs_t: Vec<Vec<f64>> = (0..ndims)
.map(|i| gridobs.iter().map(|x| x[i]).collect())
.collect(); let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|x| &x[..]).collect();
// Expected values: the same quadratic evaluated at the observation points.
let uobs: Vec<f64> = (0..gridobs.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += gridobs[i][j] * gridobs[i][j];
}
v
})
.collect(); let mut out = vec![0.0; uobs.len()];
// Only the non-linearized mode is checked here.
interpn(&grids, &u, false, &xobsslice, &mut out[..]).unwrap();
(0..uobs.len()).for_each(|i| assert!((out[i] - uobs[i]).abs() < 3e-10));
}
}
/// A sum-of-sines field sampled on a 10-point-per-axis grid must be
/// approximated to within `2e-2 * ndims`. Observation points stay within the
/// nominal (unperturbed) range of each axis.
///
/// NOTE(review): the loop range `1..3` covers 1 and 2 dimensions only, while
/// the test name mentions 3d — confirm whether `1..=3` was intended.
#[test]
fn test_interp_1d_to_3d_sine() {
let mut rng = rng_fixed_seed();
for ndims in 1..3 {
println!("Testing in {ndims} dims");
let dims: Vec<usize> = vec![10; ndims];
// One coordinate vector per axis, nominally -5*i ..= 5*(i+1), perturbed by
// (dx - 0.5) / 10 per point; the perturbed grid is asserted to stay
// strictly increasing.
let xs: Vec<Vec<f64>> = (0..ndims)
.map(|i| {
let mut x = linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i]);
let dx = randn::<f64>(&mut rng, x.len());
(0..x.len()).for_each(|i| x[i] += (dx[i] - 0.5) / 10.0);
(0..x.len() - 1).for_each(|i| assert!(x[i + 1] > x[i]));
x
})
.collect();
let grids: Vec<&[f64]> = xs.iter().map(|x| &x[..]).collect();
let grid = meshgrid((0..ndims).map(|i| &xs[i]).collect());
// Grid values: sum over axes of sin(x * 6.28 / 10).
let u: Vec<f64> = (0..grid.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += (grid[i][j] * 6.28 / 10.0).sin();
}
v
})
.collect();
// Observation coordinates span the nominal grid range with dims[i] + 1 points per axis.
let xobs: Vec<Vec<f64>> = (0..ndims)
.map(|i| linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i] + 1))
.collect();
let gridobs = meshgrid((0..ndims).map(|i| &xobs[i]).collect());
// Transpose from per-point coordinates to one slice per axis, the layout `interpn` takes.
let gridobs_t: Vec<Vec<f64>> = (0..ndims)
.map(|i| gridobs.iter().map(|x| x[i]).collect())
.collect(); let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|x| &x[..]).collect();
// Expected values: the same function evaluated at the observation points.
let uobs: Vec<f64> = (0..gridobs.len())
.map(|i| {
let mut v = 0.0;
for j in 0..ndims {
v += (gridobs[i][j] * 6.28 / 10.0).sin();
}
v
})
.collect(); let mut out = vec![0.0; uobs.len()];
interpn(&grids, &u, false, &xobsslice, &mut out[..]).unwrap();
// The tolerance scales with the number of summed sine terms.
let tol = 2e-2 * f64::from(ndims as u32);
(0..uobs.len()).for_each(|i| {
let err = out[i] - uobs[i];
assert!(err.abs() < tol);
});
}
}
}