use crate::index_arr_fixed_dims;
use crunchy::unroll;
use num_traits::Float;
/// Multilinear interpolation on a rectilinear grid with the dimensionality
/// chosen at run time.
///
/// The number of dimensions is taken from `grids.len()` and dispatched to the
/// const-generic [`MultilinearRectilinear`] implementation.
///
/// # Arguments
/// * `grids` - one slice of grid coordinates per dimension.
/// * `vals` - flattened values defined on the grid.
/// * `obs` - one slice of observation coordinates per dimension; every slice
///   must have the same length as `out`.
/// * `out` - receives one interpolated value per observation point.
///
/// # Errors
/// * `"Dimension mismatch"` when `obs` does not hold exactly one slice per
///   grid dimension, or when an `obs` slice and `out` differ in length.
/// * Any error reported while constructing the interpolator from `grids` and
///   `vals`.
/// * The "Dimension exceeds maximum (8)" message is handed to
///   `dispatch_ndims!`, presumably for dimension counts outside `1..=8`
///   (the macro is defined elsewhere -- confirm there).
pub fn interpn<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    obs: &[&[T]],
    out: &mut [T],
) -> Result<(), &'static str> {
    // `grids` defines the dimensionality, so only `obs` needs to be compared
    // against it. This check is also what makes the `try_into().unwrap()`
    // conversions below infallible.
    let ndims = grids.len();
    if obs.len() != ndims {
        return Err("Dimension mismatch");
    }
    crate::dispatch_ndims!(
        ndims,
        "Dimension exceeds maximum (8). Use interpolator struct directly for higher dimensions.",
        [1, 2, 3, 4, 5, 6, 7, 8],
        |N| {
            MultilinearRectilinear::<'_, T, N>::new(grids.try_into().unwrap(), vals)?
                .interp(obs.try_into().unwrap(), out)
        }
    )?;
    Ok(())
}
/// Allocating variant of [`interpn`]: interpolates at the observation points
/// and returns the results in a freshly allocated `Vec`.
///
/// The output length is taken from the first observation slice.
///
/// # Errors
/// Returns whatever [`interpn`] returns for the same inputs. An empty `obs`
/// is reported through that path rather than panicking here.
#[cfg(feature = "std")]
pub fn interpn_alloc<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    obs: &[&[T]],
) -> Result<Vec<T>, &'static str> {
    // `obs` may be empty, so it must not be indexed directly. Size the buffer
    // to zero in that case and let `interpn` perform the actual validation.
    let npts = obs.first().map_or(0, |coords| coords.len());
    let mut out = vec![T::zero(); npts];
    interpn(grids, vals, obs, &mut out)?;
    Ok(out)
}
/// Flag, per dimension, whether any observation point falls outside the grid.
///
/// For dimension `i`, `out[i]` is set to `true` when at least one value `x`
/// in `obs[i]` satisfies `x - lo <= -atol` or `x - hi >= atol`, where `lo` and
/// `hi` are the first and last entries of `grids[i]`. Both comparisons are
/// inclusive, so a point exactly `atol` outside is flagged, and with
/// `atol == 0` a point sitting exactly on a boundary is flagged as well.
///
/// # Errors
/// Returns `"Dimension mismatch"` when `obs` or `out` does not have one entry
/// per grid dimension, or when any grid is empty.
pub fn check_bounds<T: Float>(
    grids: &[&[T]],
    obs: &[&[T]],
    atol: T,
    out: &mut [bool],
) -> Result<(), &'static str> {
    let ndims = grids.len();
    if obs.len() != ndims || out.len() != ndims || grids.iter().any(|g| g.is_empty()) {
        return Err("Dimension mismatch");
    }

    for ((grid, points), flag) in grids.iter().zip(obs.iter()).zip(out.iter_mut()) {
        // Emptiness was rejected above; the fallback mirrors that error.
        let (Some(&lo), Some(&hi)) = (grid.first(), grid.last()) else {
            return Err("Dimension mismatch");
        };
        *flag = points
            .iter()
            .any(|&x| (x - lo) <= -atol || (x - hi) >= atol);
    }
    Ok(())
}
/// Multilinear interpolator over an `N`-dimensional rectilinear grid.
///
/// The struct only borrows its inputs; it is built with `new`, which validates
/// the grids and values through `crate::validate_rectilinear_grid`.
pub struct MultilinearRectilinear<'a, T: Float, const N: usize> {
/// Grid coordinates, one slice per dimension.
grids: &'a [&'a [T]; N],
/// Per-dimension sizes as returned by `crate::validate_rectilinear_grid`.
dims: [usize; N],
/// Flattened values on the grid. `interp_one` builds its index strides so
/// that the last dimension has stride 1 (presumably row-major storage --
/// confirm against `index_arr_fixed_dims`).
vals: &'a [T],
}
impl<'a, T: Float, const N: usize> MultilinearRectilinear<'a, T, N> {
    /// Build an interpolator that borrows `grids` (one coordinate slice per
    /// dimension) and `vals` (the flattened values on that grid).
    ///
    /// # Errors
    /// Propagates the error reported by `crate::validate_rectilinear_grid`
    /// for the given grids and values.
    pub fn new(grids: &'a [&'a [T]; N], vals: &'a [T]) -> Result<Self, &'static str> {
        Ok(Self {
            grids,
            dims: crate::validate_rectilinear_grid(grids, vals)?,
            vals,
        })
    }

    /// Interpolate at `out.len()` observation points.
    ///
    /// `x` holds one coordinate slice per dimension, so the `k`-th point is
    /// `(x[0][k], x[1][k], ...)`. Results are written to `out[k]`.
    ///
    /// # Errors
    /// Returns `"Dimension mismatch"` when any slice in `x` differs in length
    /// from `out`.
    pub fn interp(&self, x: &[&[T]; N], out: &mut [T]) -> Result<(), &'static str> {
        let npts = out.len();
        if x.iter().any(|coords| coords.len() != npts) {
            return Err("Dimension mismatch");
        }

        // Scratch buffer holding the coordinates of the current point.
        let mut point = [T::zero(); N];
        for (k, result) in out.iter_mut().enumerate() {
            for (coord, coords) in point.iter_mut().zip(x.iter()) {
                *coord = coords[k];
            }
            *result = self.interp_one(point)?;
        }
        Ok(())
    }

    /// Interpolate at a single point `x`.
    ///
    /// The grid cell is located per dimension with `get_loc`, which clamps the
    /// cell index to the grid while the interpolation weight is left
    /// unclamped; points outside the grid are therefore extrapolated linearly
    /// from the nearest edge cell.
    ///
    /// Indexes `origin + 1` along every dimension, so each grid needs at least
    /// two points (presumably enforced by `validate_rectilinear_grid` --
    /// confirm there).
    #[inline]
    pub fn interp_one(&self, x: [T; N]) -> Result<T, &'static str> {
        // Two grid points bound the cell along every dimension.
        const FP: usize = 2;
        // Number of corners of an N-dimensional cell.
        let nverts = const { FP.pow(N as u32) };

        let mut origin = [0_usize; N]; // lower corner of the cell, per dimension
        let mut dimprod = [1_usize; N]; // index strides; last dimension has stride 1
        let mut loc = [0_usize; N]; // grid index of the vertex being fetched
        let mut store = [[T::zero(); FP]; N]; // pairs of partial results, one pair per level

        // Strides: cumulative product of the trailing dimension sizes.
        let mut stride = 1;
        for back in 0..N {
            if back > 0 {
                stride *= self.dims[N - back];
            }
            dimprod[N - back - 1] = stride;
        }
        for (axis, cell) in origin.iter_mut().enumerate() {
            *cell = self.get_loc(x[axis], axis)?;
        }

        // Body executed once per cell vertex. Bit `k` of the vertex number is
        // the offset from `origin` along dimension `k`. Level 0 fetches the
        // vertex value; level `j >= 1` fires whenever `2^j` vertices have been
        // consumed and collapses the finished pair of level `j - 1` along
        // dimension `j - 1`.
        //
        // NOTE: no local in here may be called `i`; the unrolled path binds
        // the vertex number as a `const i`.
        macro_rules! unroll_vertices_body {
            ($i:ident) => {
                for j in 0..N {
                    match j {
                        0 => {
                            for k in 0..N {
                                let bit: usize = ($i & (1 << k)) >> k;
                                loc[k] = origin[k] + bit;
                            }
                            let slot: usize = $i % FP;
                            store[0][slot] = index_arr_fixed_dims(loc, dimprod, self.vals);
                        }
                        _ => {
                            let span: usize = FP.pow(j as u32);
                            let filled: bool = ($i + 1).is_multiple_of(span);
                            if filled {
                                let slot: usize = (($i + 1) / span).saturating_sub(1) % FP;
                                let axis: usize = j.saturating_sub(1);
                                let x0 = self.grids[axis][origin[axis]];
                                let x1 = self.grids[axis][origin[axis] + 1];
                                let t = (x[axis] - x0) / (x1 - x0);
                                let y0 = store[axis][0];
                                let dy = store[axis][1] - y0;
                                #[cfg(not(feature = "fma"))]
                                let interped = y0 + t * dy;
                                #[cfg(feature = "fma")]
                                let interped = t.mul_add(dy, y0);
                                store[j][slot] = interped;
                            }
                        }
                    }
                }
            };
        }

        // Up to 6 dimensions (at most 64 vertices) the vertex loop is unrolled;
        // beyond that a plain loop is used.
        if N <= 6 {
            unroll! {
                for i < 64 in 0..nverts {
                    unroll_vertices_body!(i)
                }
            }
        } else {
            for i in 0..nverts {
                unroll_vertices_body!(i)
            }
        }

        // The last level's pair is collapsed along the final dimension.
        let last = N - 1;
        let x0 = self.grids[last][origin[last]];
        let x1 = self.grids[last][origin[last] + 1];
        let t = (x[last] - x0) / (x1 - x0);
        let y0 = store[last][0];
        let dy = store[last][1] - y0;
        #[cfg(not(feature = "fma"))]
        let interped = y0 + t * dy;
        #[cfg(feature = "fma")]
        let interped = t.mul_add(dy, y0);
        Ok(interped)
    }

    /// Index of the lower grid point of the cell used for coordinate `v` along
    /// dimension `dim`.
    ///
    /// The index of the last grid value strictly below `v` is clamped to
    /// `[0, dims[dim] - 2]`, so out-of-range coordinates map to an edge cell.
    /// The current implementation always returns `Ok`.
    #[inline]
    fn get_loc(&self, v: T, dim: usize) -> Result<usize, &'static str> {
        let below = self.grids[dim].partition_point(|g| *g < v) as isize;
        let last_cell = (self.dims[dim] as isize).saturating_sub(2).max(0);
        Ok((below - 1).clamp(0, last_cell) as usize)
    }
}
#[cfg(test)]
mod test {
    use super::{MultilinearRectilinear, interpn};
    use crate::testing::*;
    use crate::utils::*;

    /// A linear function (`z = x + y`) sampled on a 3x2 grid must be
    /// reproduced to within 1e-12 at observation points spanning [-10, 10],
    /// far outside the grid.
    #[test]
    fn test_interp_extrap_2d_small() {
        // Grid and values
        let (nx, ny) = (3, 2);
        let xgrid = linspace(-1.0, 1.0, nx);
        let ygrid = Vec::from([0.5, 0.6]);
        let grids = [&xgrid[..], &ygrid[..]];
        let mesh = meshgrid(Vec::from([&xgrid, &ygrid]));
        let z: Vec<f64> = (0..nx * ny).map(|i| &mesh[i][0] + &mesh[i][1]).collect();

        // Observation points and the values expected there
        let xobs = linspace(-10.0_f64, 10.0, 5);
        let yobs = linspace(-10.0_f64, 10.0, 5);
        let obs_mesh = meshgrid(Vec::from([&xobs, &yobs]));
        let nobs = xobs.len() * yobs.len();
        let expected: Vec<f64> = (0..nobs)
            .map(|i| &obs_mesh[i][0] + &obs_mesh[i][1])
            .collect();

        let interpolator: MultilinearRectilinear<'_, _, 2> =
            MultilinearRectilinear::new(&grids, &z[..]).unwrap();

        for (point, want) in obs_mesh.iter().zip(expected.iter()) {
            let got = interpolator.interp_one([point[0], point[1]]).unwrap();
            assert!((*want - got).abs() < 1e-12);
        }
    }

    /// For 1 through 8 dimensions, a sum-of-coordinates function sampled on a
    /// randomly perturbed 2-point-per-axis grid must be reproduced to within
    /// 1e-12 by `interpn`, including at points outside the grid.
    #[test]
    fn test_interp_extrap_1d_to_8d() {
        let mut rng = rng_fixed_seed();
        for ndims in 1..=8 {
            println!("Testing in {ndims} dims");
            let dims: Vec<usize> = vec![2; ndims];

            // Perturbed grid coordinates, checked to stay strictly increasing
            let mut xs: Vec<Vec<f64>> = Vec::with_capacity(ndims);
            for (axis, &npts) in dims.iter().enumerate() {
                let mut x = linspace(-5.0 * (axis as f64), 5.0 * ((axis + 1) as f64), npts);
                let dx = randn::<f64>(&mut rng, x.len());
                for k in 0..x.len() {
                    x[k] += (dx[k] - 0.5) / 10.0;
                }
                for k in 0..x.len() - 1 {
                    assert!(x[k + 1] > x[k]);
                }
                xs.push(x);
            }
            let grids: Vec<&[f64]> = xs.iter().map(|x| x.as_slice()).collect();
            let grid = meshgrid(xs.iter().collect());
            let u: Vec<f64> = grid.iter().map(|pt| pt.iter().sum()).collect();

            // Observation points, transposed into one coordinate vector per axis
            let xobs: Vec<Vec<f64>> = (0..ndims)
                .map(|axis| linspace(-7.0 * (axis as f64), 7.0 * ((axis + 1) as f64), 3))
                .collect();
            let gridobs = meshgrid(xobs.iter().collect());
            let columns: Vec<Vec<f64>> = (0..ndims)
                .map(|axis| gridobs.iter().map(|pt| pt[axis]).collect())
                .collect();
            let xobsslice: Vec<&[f64]> = columns.iter().map(|c| c.as_slice()).collect();
            let uobs: Vec<f64> = gridobs.iter().map(|pt| pt.iter().sum()).collect();

            let mut out = vec![0.0; uobs.len()];
            interpn(&grids, &u, &xobsslice, &mut out[..]).unwrap();

            for (got, want) in out.iter().zip(uobs.iter()) {
                assert!((*got - *want).abs() < 1e-12);
            }
        }
    }

    /// A piecewise-linear "hat" sampled at its three knots must be matched
    /// exactly, both between the knots and when extrapolating past them.
    #[test]
    fn test_interp_hat_func() {
        fn hat_func(x: f64) -> f64 {
            if x <= 1.0 { x } else { 2.0 - x }
        }

        let x: Vec<f64> = (0..3).map(|k| k as f64).collect();
        let grids = [&x[..]];
        let y: Vec<f64> = (0..3).map(|k| hat_func(k as f64)).collect();
        let obs = linspace(-2.0, 4.0, 100);

        let interpolator: MultilinearRectilinear<f64, 1> =
            MultilinearRectilinear::new(&grids, &y).unwrap();

        for k in 0..obs.len() {
            assert_eq!(hat_func(obs[k]), interpolator.interp_one([obs[k]]).unwrap());
        }
    }
}