use super::{Saturation, normalized_hermite_spline};
use crate::index_arr_fixed_dims;
use crunchy::unroll;
use num_traits::{Float, NumCast};
/// Evaluate multicubic interpolation on a regular grid at a batch of observation points,
/// writing the results into a caller-provided buffer.
///
/// The number of dimensions is taken from `dims.len()` and dispatched to a
/// [`MulticubicRegular`] instance with a matching const dimension count.
///
/// # Arguments
///
/// * `dims` - number of grid points along each axis
/// * `starts` - coordinate of the first grid point along each axis
/// * `steps` - grid spacing along each axis
/// * `vals` - flattened grid values; must hold exactly `dims.iter().product()` entries
/// * `linearize_extrapolation` - extrapolate linearly (instead of cubically) outside the grid
/// * `obs` - one coordinate slice per axis, each the same length as `out`
/// * `out` - destination for the interpolated values
///
/// # Errors
///
/// * `"Dimension mismatch"` if `starts`, `steps` or `obs` do not have one entry per axis,
///   or if the interpolator rejects the sizes of `vals`, `obs` or `out`.
/// * Whatever the dimension dispatch reports for an unsupported `dims.len()`
///   (presumably the "Dimension exceeds maximum" message passed to it below).
/// * Any error raised while constructing or evaluating the interpolator.
pub fn interpn<T: Float>(
    dims: &[usize],
    starts: &[T],
    steps: &[T],
    vals: &[T],
    linearize_extrapolation: bool,
    obs: &[&[T]],
    out: &mut [T],
) -> Result<(), &'static str> {
    let ndims = dims.len();

    // The dispatch below converts every slice into a fixed-size array and unwraps the
    // conversion, so a length disagreement has to be rejected here to surface as an
    // error rather than a panic.
    if starts.len() != ndims || steps.len() != ndims || obs.len() != ndims {
        return Err("Dimension mismatch");
    }

    crate::dispatch_ndims!(
        ndims,
        "Dimension exceeds maximum (8). Use interpolator struct directly for higher dimensions.",
        [1, 2, 3, 4, 5, 6, 7, 8],
        |N| {
            MulticubicRegular::<'_, T, N>::new(
                dims.try_into().unwrap(),
                starts.try_into().unwrap(),
                steps.try_into().unwrap(),
                vals,
                linearize_extrapolation,
            )?
            .interp(obs.try_into().unwrap(), out)
        }
    )
}
/// Evaluate multicubic interpolation on a regular grid, allocating the output vector.
///
/// Thin wrapper around [`interpn`] that sizes the output from the first observation
/// slice. Arguments have the same meaning as in [`interpn`].
///
/// # Errors
///
/// * `"Dimension mismatch"` if `obs` does not contain exactly one slice per axis.
/// * Any error returned by [`interpn`].
#[cfg(feature = "std")]
pub fn interpn_alloc<T: Float>(
    dims: &[usize],
    starts: &[T],
    steps: &[T],
    vals: &[T],
    linearize_extrapolation: bool,
    obs: &[&[T]],
) -> Result<Vec<T>, &'static str> {
    // Validate before touching `obs`, so an empty or mis-sized observation list is
    // reported as an error instead of panicking on an out-of-bounds index.
    if obs.len() != dims.len() {
        return Err("Dimension mismatch");
    }

    // With zero axes there is no observation slice to measure; fall back to an empty
    // buffer and let `interpn` report the unsupported dimensionality.
    let nobs = obs.first().map_or(0, |coords| coords.len());
    let mut out = vec![T::zero(); nobs];

    interpn(
        dims,
        starts,
        steps,
        vals,
        linearize_extrapolation,
        obs,
        &mut out,
    )?;

    Ok(out)
}
pub use crate::multilinear::regular::check_bounds;
/// Cubic interpolator over an `N`-dimensional regular (evenly spaced) grid.
///
/// The grid values are borrowed, not copied; the grid geometry is stored inline as
/// fixed-size arrays with one entry per axis.
pub struct MulticubicRegular<'a, T: Float, const N: usize> {
/// Number of grid points along each axis.
dims: [usize; N],
/// Coordinate of the first grid point along each axis.
starts: [T; N],
/// Spacing between adjacent grid points along each axis.
steps: [T; N],
/// Flattened grid values; the constructor requires `dims.iter().product()` entries.
vals: &'a [T],
/// Forwarded to the 1D kernel to select linear instead of cubic extrapolation
/// for points outside the grid.
linearize_extrapolation: bool,
}
impl<'a, T: Float, const N: usize> MulticubicRegular<'a, T, N> {
    /// Build an interpolator over a regular grid.
    ///
    /// # Arguments
    ///
    /// * `dims` - number of grid points along each axis
    /// * `starts` - coordinate of the first grid point along each axis
    /// * `steps` - grid spacing along each axis
    /// * `vals` - flattened grid values, `dims.iter().product()` entries long
    /// * `linearize_extrapolation` - extrapolate linearly outside the grid
    ///
    /// # Errors
    ///
    /// * `"Dimension mismatch"` if `N == 0`, if `vals.len()` differs from the product of
    ///   `dims`, or if that product does not fit in a `usize`.
    /// * `"All grids must have at least four entries"` if any axis has fewer than 4 points.
    /// * `"All grids must be monotonically increasing"` if any step is not strictly
    ///   positive (this also rejects NaN steps).
    pub fn new(
        dims: [usize; N],
        starts: [T; N],
        steps: [T; N],
        vals: &'a [T],
        linearize_extrapolation: bool,
    ) -> Result<Self, &'static str> {
        // Overflow-checked product: a wrapped product could spuriously equal
        // `vals.len()` and let an inconsistent grid through.
        let nvals = dims
            .iter()
            .try_fold(1_usize, |acc, &d| acc.checked_mul(d))
            .ok_or("Dimension mismatch")?;
        if N == 0 || vals.len() != nvals {
            return Err("Dimension mismatch");
        }

        // The cubic stencil needs four points along every axis.
        if dims.iter().any(|&d| d < 4) {
            return Err("All grids must have at least four entries");
        }

        if !steps.iter().all(|&s| s > T::zero()) {
            return Err("All grids must be monotonically increasing");
        }

        // The arrays are already owned and correctly sized, so they are stored directly.
        Ok(Self {
            dims,
            starts,
            steps,
            vals,
            linearize_extrapolation,
        })
    }

    /// Interpolate at a batch of observation points.
    ///
    /// `x` holds one coordinate slice per axis; point `i` is
    /// `(x[0][i], x[1][i], ..., x[N-1][i])` and its result is written to `out[i]`.
    ///
    /// # Errors
    ///
    /// * `"Dimension mismatch"` if any coordinate slice differs in length from `out`.
    /// * Any error from [`Self::interp_one`]; entries of `out` before the failing
    ///   point have already been written at that time.
    pub fn interp(&self, x: &[&[T]; N], out: &mut [T]) -> Result<(), &'static str> {
        let n = out.len();
        if x.iter().any(|coords| coords.len() != n) {
            return Err("Dimension mismatch");
        }

        // Scratch buffer for one observation point, refilled on every iteration.
        let mut point = [T::zero(); N];
        for (i, slot) in out.iter_mut().enumerate() {
            for (coord, coords) in point.iter_mut().zip(x.iter()) {
                *coord = coords[i];
            }
            *slot = self.interp_one(point)?;
        }

        Ok(())
    }

    /// Interpolate at a single observation point `x`.
    ///
    /// Gathers the `4^N` grid values surrounding the point and collapses them one axis
    /// at a time with the 1D cubic kernel, starting with axis 0.
    ///
    /// # Errors
    ///
    /// `"Unrepresentable coordinate value"` if a grid index derived from `x` cannot be
    /// converted between the float type and an integer.
    pub fn interp_one(&self, x: [T; N]) -> Result<T, &'static str> {
        // Footprint of the cubic stencil along each axis.
        const FP: usize = 4;

        let mut origin = [0_usize; N]; // Lowest grid index of the stencil, per axis
        let mut sat = [Saturation::None; N]; // Position of the point relative to the grid
        let mut dts = [T::zero(); N]; // Normalized offset from stencil point 1, per axis
        let mut dimprod = [1_usize; N]; // Flat-index stride of each axis
        let mut loc = [0_usize; N]; // Scratch: grid index of the vertex being read
        let mut store = [[T::zero(); FP]; N]; // Per-level partial results

        let mut acc = 1;
        for i in 0..N {
            // Strides are accumulated from the last axis backwards: the last axis has
            // stride 1 and each earlier axis multiplies in the size of the one after it.
            if i > 0 {
                acc *= self.dims[N - i];
            }
            dimprod[N - i - 1] = acc;

            (origin[i], sat[i]) = self.get_loc(x[i], i)?;

            // Coordinate of the second stencil point, used as the reference for the
            // normalized offset.
            let index_one_loc = self.starts[i]
                + self.steps[i]
                    * <T as NumCast>::from(origin[i] + 1)
                        .ok_or("Unrepresentable coordinate value")?;
            dts[i] = (x[i] - index_one_loc) / self.steps[i];
        }

        let nverts = const { FP.pow(N as u32) };

        // Body for one stencil vertex `$i`. The vertex index is read as a base-4 number
        // whose digit `k` is the offset along axis `k`. Level 0 stores the raw grid
        // value; whenever a full group of four has been accumulated at level `j - 1`
        // (that is, `$i + 1` is a multiple of `4^j`), it is collapsed with the 1D kernel
        // along axis `j - 1` and the result stored at level `j`.
        macro_rules! unroll_vertices_body {
            ($i:ident) => {
                for j in 0..N {
                    if j == 0 {
                        for k in 0..N {
                            // Extract base-4 digit `k` of the vertex index.
                            let offset: usize = ($i & (3 << (2 * k))) >> (2 * k);
                            loc[k] = origin[k] + offset;
                        }
                        let store_ind: usize = $i % FP;
                        store[0][store_ind] = index_arr_fixed_dims(loc, dimprod, self.vals);
                    } else {
                        let q: usize = FP.pow(j as u32);
                        let level: bool = ($i + 1).is_multiple_of(q);
                        let p: usize = (($i + 1) / q).saturating_sub(1) % FP;
                        let ind: usize = j.saturating_sub(1);
                        if level {
                            let interped = interp_inner::<T>(
                                &store[ind],
                                dts[ind],
                                sat[ind],
                                self.linearize_extrapolation,
                            );
                            store[j][p] = interped;
                        }
                    }
                }
            };
        }

        // Low dimension counts use an unrolled vertex loop; higher counts use a plain loop.
        #[cfg(not(feature = "deep-unroll"))]
        if N <= 3 {
            unroll! {
                for i < 64 in 0..nverts {
                    unroll_vertices_body!(i);
                }
            }
        } else {
            for i in 0..nverts {
                unroll_vertices_body!(i);
            }
        }

        #[cfg(feature = "deep-unroll")]
        if N <= 4 {
            unroll! {
                for i < 256 in 0..nverts {
                    unroll_vertices_body!(i);
                }
            }
        } else {
            for i in 0..nverts {
                unroll_vertices_body!(i);
            }
        }

        // The top level now holds four values along the last axis; collapse them.
        let interped = interp_inner::<T>(
            &store[N - 1],
            dts[N - 1],
            sat[N - 1],
            self.linearize_extrapolation,
        );

        Ok(interped)
    }

    /// Locate the four-point stencil for coordinate `v` along axis `dim`.
    ///
    /// Returns the lowest grid index of the stencil, clamped to `[0, dims[dim] - 4]`,
    /// together with a flag describing where the point sits: in the first or last grid
    /// cell (`InsideLow` / `InsideHigh`), beyond either end of the grid
    /// (`OutsideLow` / `OutsideHigh`), or anywhere else (`None`).
    ///
    /// # Errors
    ///
    /// `"Unrepresentable coordinate value"` if the cell index does not fit in an `isize`
    /// (for example when `v` is NaN).
    #[inline]
    fn get_loc(&self, v: T, dim: usize) -> Result<(usize, Saturation), &'static str> {
        // Index of the grid cell containing `v`, shifted down by one so that it names
        // the first point of a stencil centred on that cell. The subtraction saturates
        // so that a cell index of exactly `isize::MIN` stays on the low side instead of
        // overflowing.
        let floc = ((v - self.starts[dim]) / self.steps[dim]).floor();
        let iloc = <isize as NumCast>::from(floc)
            .ok_or("Unrepresentable coordinate value")?
            .saturating_sub(1);

        let n = self.dims[dim] as isize;
        let dimmax = n.saturating_sub(4).max(0);
        let loc = iloc.max(0).min(dimmax) as usize;

        let saturation = if iloc < -1 {
            Saturation::OutsideLow
        } else if iloc == -1 {
            Saturation::InsideLow
        } else if iloc > (n - 3) {
            Saturation::OutsideHigh
        } else if iloc == (n - 3) {
            Saturation::InsideHigh
        } else {
            Saturation::None
        };

        Ok((loc, saturation))
    }
}
/// One-dimensional cubic Hermite kernel over a four-point stencil.
///
/// `vals` are the four stencil values, `t` is the offset of the evaluation point from
/// stencil point 1 in units of the grid step, and `sat` says which part of the stencil
/// the point falls in:
///
/// * `None` - between points 1 and 2; both end slopes are central differences.
/// * `InsideLow` / `OutsideLow` - at or below point 1; the spline runs from point 1
///   towards point 0 with the parameter negated.
/// * `InsideHigh` / `OutsideHigh` - at or above point 2; the spline runs from point 2
///   towards point 3 with the parameter shifted by one.
///
/// For the two `Outside*` cases, `linearize_extrapolation` replaces the cubic with a
/// straight line through the outermost point using the far-end slope.
#[inline]
fn interp_inner<T: Float>(
    vals: &[T; 4],
    t: T,
    sat: Saturation,
    linearize_extrapolation: bool,
) -> T {
    let one = T::one();
    let two = one + one;

    match sat {
        Saturation::None => {
            let slope_start = (vals[2] - vals[0]) / two;
            let slope_end = (vals[3] - vals[1]) / two;
            normalized_hermite_spline(t, vals[1], vals[2] - vals[1], slope_start, slope_end)
        }
        Saturation::InsideLow | Saturation::OutsideLow => {
            // Mirror the segment so it is traversed from point 1 down to point 0.
            let s = -t;
            let near = vals[1];
            let far = vals[0];
            let dy = vals[0] - vals[1];
            let k0 = -(vals[2] - vals[0]) / two;
            let k1 = far_end_slope(dy, k0);

            if linearize_extrapolation && matches!(sat, Saturation::OutsideLow) {
                linear_tail(far, k1, s)
            } else {
                normalized_hermite_spline(s, near, dy, k0, k1)
            }
        }
        Saturation::InsideHigh | Saturation::OutsideHigh => {
            // Re-reference the parameter to point 2, the start of the last segment.
            let s = t - one;
            let near = vals[2];
            let far = vals[3];
            let dy = vals[3] - vals[2];
            let k0 = (vals[3] - vals[1]) / two;
            let k1 = far_end_slope(dy, k0);

            if linearize_extrapolation && matches!(sat, Saturation::OutsideHigh) {
                linear_tail(far, k1, s)
            } else {
                normalized_hermite_spline(s, near, dy, k0, k1)
            }
        }
    }
}

/// Slope at the outer end of an edge segment, `2 * dy - k0`, fused when `fma` is enabled.
#[inline]
fn far_end_slope<T: Float>(dy: T, k0: T) -> T {
    let two = T::one() + T::one();
    #[cfg(not(feature = "fma"))]
    {
        two * dy - k0
    }
    #[cfg(feature = "fma")]
    {
        two.mul_add(dy, -k0)
    }
}

/// Straight-line continuation `y1 + k1 * (t - 1)` past the outermost stencil point,
/// fused when `fma` is enabled.
#[inline]
fn linear_tail<T: Float>(y1: T, k1: T, t: T) -> T {
    let one = T::one();
    #[cfg(not(feature = "fma"))]
    {
        y1 + k1 * (t - one)
    }
    #[cfg(feature = "fma")]
    {
        k1.mul_add(t - one, y1)
    }
}
#[cfg(test)]
mod test {
    use super::interpn;
    use crate::utils::*;

    /// Sample coordinates for every axis: axis `i` spans `[lo * i, hi * (i + 1)]`
    /// with `npts` points.
    fn axis_coords(ndims: usize, lo: f64, hi: f64, npts: usize) -> Vec<Vec<f64>> {
        (0..ndims)
            .map(|i| linspace(lo * (i as f64), hi * ((i + 1) as f64), npts))
            .collect()
    }

    /// First coordinate and spacing of every axis, in the form `interpn` expects.
    fn starts_and_steps(axes: &[Vec<f64>]) -> (Vec<f64>, Vec<f64>) {
        let starts = axes.iter().map(|axis| axis[0]).collect();
        let steps = axes.iter().map(|axis| axis[1] - axis[0]).collect();
        (starts, steps)
    }

    /// Assert that every entry of `actual` is within `tol` of the matching `expected` entry.
    fn assert_all_within(actual: &[f64], expected: &[f64], tol: f64) {
        for (i, want) in expected.iter().enumerate() {
            assert!((actual[i] - want).abs() < tol);
        }
    }

    /// A linear field must be reproduced to rounding error both inside and outside the
    /// grid, with cubic as well as linearized extrapolation.
    #[test]
    fn test_interp_extrap_1d_to_4d_linear() {
        for ndims in 1..=4 {
            println!("Testing in {ndims} dims");
            let npts = 4;
            let dims: Vec<usize> = vec![npts; ndims];

            // Grid and field values
            let xs = axis_coords(ndims, -5.0, 5.0, npts);
            let grid = meshgrid(xs.iter().collect());
            let u: Vec<f64> = grid.iter().map(|pt| pt.iter().sum()).collect();
            let (starts, steps) = starts_and_steps(&xs);

            // Observation points extend past the grid on both sides
            let xobs = axis_coords(ndims, -7.0, 7.0, npts + 2);
            let gridobs = meshgrid(xobs.iter().collect());
            let gridobs_t: Vec<Vec<f64>> = (0..ndims)
                .map(|axis| gridobs.iter().map(|pt| pt[axis]).collect())
                .collect();
            let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|col| &col[..]).collect();
            let uobs: Vec<f64> = gridobs.iter().map(|pt| pt.iter().sum()).collect();
            let mut out = vec![0.0; uobs.len()];

            for linearize in [false, true] {
                interpn(
                    &dims,
                    &starts,
                    &steps,
                    &u,
                    linearize,
                    &xobsslice,
                    &mut out[..],
                )
                .unwrap();
                assert_all_within(&out, &uobs, 1e-12);
            }
        }
    }

    /// A quadratic field must be reproduced inside and outside the grid when
    /// extrapolation is left cubic.
    #[test]
    fn test_interp_extrap_1d_to_6d_quadratic() {
        for ndims in 1..=6 {
            println!("Testing in {ndims} dims");
            let npts = 4;
            let dims: Vec<usize> = vec![npts; ndims];

            // Grid and field values
            let xs = axis_coords(ndims, -5.0, 5.0, npts);
            let grid = meshgrid(xs.iter().collect());
            let u: Vec<f64> = grid
                .iter()
                .map(|pt| (0..ndims).fold(0.0, |acc, j| acc + pt[j] * pt[j]))
                .collect();
            let (starts, steps) = starts_and_steps(&xs);

            // Observation points extend past the grid on both sides
            let xobs = axis_coords(ndims, -7.0, 7.0, npts + 2);
            let gridobs = meshgrid(xobs.iter().collect());
            let gridobs_t: Vec<Vec<f64>> = (0..ndims)
                .map(|axis| gridobs.iter().map(|pt| pt[axis]).collect())
                .collect();
            let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|col| &col[..]).collect();
            let uobs: Vec<f64> = gridobs
                .iter()
                .map(|pt| (0..ndims).fold(0.0, |acc, j| acc + pt[j] * pt[j]))
                .collect();
            let mut out = vec![0.0; uobs.len()];

            interpn(&dims, &starts, &steps, &u, false, &xobsslice, &mut out[..]).unwrap();
            assert_all_within(&out, &uobs, 1e-10);
        }
    }

    /// A sum of sines sampled on a ten-point grid must be approximated within a loose
    /// tolerance at observation points spanning the grid extent.
    #[test]
    fn test_interp_1d_to_3d_sine() {
        // NOTE(review): the half-open range `1..3` exercises only 1 and 2 dimensions,
        // although the test name mentions 3d. Kept as-is; confirm whether the stated
        // tolerance holds in 3 dimensions before widening it.
        for ndims in 1..3 {
            println!("Testing in {ndims} dims");
            let npts = 10;
            let dims: Vec<usize> = vec![npts; ndims];

            // Grid and field values
            let xs = axis_coords(ndims, -5.0, 5.0, npts);
            let grid = meshgrid(xs.iter().collect());
            let u: Vec<f64> = grid
                .iter()
                .map(|pt| (0..ndims).fold(0.0, |acc, j| acc + (pt[j] * 6.28 / 10.0).sin()))
                .collect();
            let (starts, steps) = starts_and_steps(&xs);

            // Observation points cover the same extent as the grid, more finely
            let xobs = axis_coords(ndims, -5.0, 5.0, npts + 2);
            let gridobs = meshgrid(xobs.iter().collect());
            let gridobs_t: Vec<Vec<f64>> = (0..ndims)
                .map(|axis| gridobs.iter().map(|pt| pt[axis]).collect())
                .collect();
            let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|col| &col[..]).collect();
            let uobs: Vec<f64> = gridobs
                .iter()
                .map(|pt| (0..ndims).fold(0.0, |acc, j| acc + (pt[j] * 6.28 / 10.0).sin()))
                .collect();
            let mut out = vec![0.0; uobs.len()];

            interpn(&dims, &starts, &steps, &u, false, &xobsslice, &mut out[..]).unwrap();

            // The allowed error grows with the number of summed one-dimensional terms.
            let tol = 2e-2 * f64::from(ndims as u32);
            assert_all_within(&out, &uobs, tol);
        }
    }
}