use crate::index_arr_fixed_dims;
use num_traits::Float;
/// Evaluate nearest-neighbor interpolation on a rectilinear grid of 1 to 8 dimensions,
/// writing one interpolated value per observation point into `out`.
///
/// * `grids` - one coordinate slice per dimension.
/// * `vals` - values at the grid points, flattened into a single slice.
/// * `obs` - one coordinate slice per dimension; `obs[d][i]` is the `d`-th coordinate of
///   observation point `i`.
/// * `out` - output buffer; every `obs[d]` must have the same length as `out`.
///
/// # Errors
///
/// * `"Dimension mismatch"` if `obs` has a different number of dimensions than `grids`,
///   or if any `obs[d]` differs in length from `out`.
/// * Presumably `"Dimension exceeds maximum (8)."` when the dimension count is not one of
///   the dispatched sizes (this is the message handed to `dispatch_ndims!`; confirm the
///   macro's behavior for 0 dimensions).
/// * Any error produced while validating the grid in [`NearestRectilinear::new`].
pub fn interpn<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    obs: &[&[T]],
    out: &mut [T],
) -> Result<(), &'static str> {
    let ndims = grids.len();
    // `ndims` is taken from `grids`, so the observation columns are the only thing that
    // can disagree with it here.
    if obs.len() != ndims {
        return Err("Dimension mismatch");
    }
    // Monomorphize on the dimension count so the fixed-size interpolator can be used.
    // The `try_into().unwrap()` conversions rely on `dispatch_ndims!` only invoking the
    // closure with `N == ndims` (== `grids.len()` == `obs.len()`).
    crate::dispatch_ndims!(
        ndims,
        "Dimension exceeds maximum (8).",
        [1, 2, 3, 4, 5, 6, 7, 8],
        |N| {
            NearestRectilinear::<'_, T, N>::new(grids.try_into().unwrap(), vals)?
                .interp(obs.try_into().unwrap(), out)
        }
    )?;
    Ok(())
}
/// Allocating variant of [`interpn`]: returns a new `Vec` with one interpolated value per
/// observation point instead of writing into a caller-provided buffer.
///
/// The output length is taken from the first observation column; all other columns must
/// match it.
///
/// # Errors
///
/// Same as [`interpn`]. An empty `obs` no longer panics; the dimension check in
/// [`interpn`] reports it instead.
#[cfg(feature = "std")]
pub fn interpn_alloc<T: Float>(
    grids: &[&[T]],
    vals: &[T],
    obs: &[&[T]],
) -> Result<Vec<T>, &'static str> {
    // `obs[0]` would panic on an empty `obs`; fall back to a zero-length buffer so the
    // dimension validation in `interpn` can report the problem as an `Err`.
    let npoints = obs.first().map_or(0, |col| col.len());
    let mut out = vec![T::zero(); npoints];
    interpn(grids, vals, obs, &mut out)?;
    Ok(out)
}
pub use crate::multilinear::rectilinear::check_bounds;
/// Nearest-neighbor interpolator on an `N`-dimensional rectilinear grid.
///
/// The grid coordinates and the flattened values are borrowed, not copied. Only the
/// per-dimension grid sizes are stored by value. Every field is a reference or a
/// `usize` array, so the interpolator is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct NearestRectilinear<'a, T: Float, const N: usize> {
    /// One coordinate slice per dimension.
    grids: &'a [&'a [T]; N],
    /// Grid size along each dimension, as returned by `crate::validate_rectilinear_grid`.
    /// Used for clamping cell indices and for computing flat-index strides.
    dims: [usize; N],
    /// Grid values flattened into one slice. The interpolator indexes it with strides in
    /// which the last dimension is contiguous.
    vals: &'a [T],
}
impl<'a, T: Float, const N: usize> NearestRectilinear<'a, T, N> {
    /// Build an interpolator over `grids` (one coordinate slice per dimension) and the
    /// flattened grid values `vals`.
    ///
    /// The dimension count `N` must lie in `1..=8`; this is checked at compile time.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `crate::validate_rectilinear_grid`.
    pub fn new(grids: &'a [&'a [T]; N], vals: &'a [T]) -> Result<Self, &'static str> {
        // Evaluated during monomorphization, so an unsupported `N` fails to compile.
        const {
            assert!(N > 0 && N < 9, "Flattened method defined for 1-8 dimensions.");
        }
        let dims = crate::validate_rectilinear_grid(grids, vals)?;
        Ok(Self { vals, grids, dims })
    }

    /// Interpolate a batch of points, writing one value per point into `out`.
    ///
    /// `x[d][i]` is the `d`-th coordinate of point `i`. Every `x[d]` must have the same
    /// length as `out`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Dimension mismatch")` before writing anything if a coordinate slice
    /// has the wrong length. An error from [`Self::interp_one`] is returned immediately.
    /// In that case the entries of `out` before the failing point have already been
    /// written.
    pub fn interp(&self, x: &[&[T]; N], out: &mut [T]) -> Result<(), &'static str> {
        let npoints = out.len();
        if x.iter().any(|coords| coords.len() != npoints) {
            return Err("Dimension mismatch");
        }

        // Scratch buffer reused for every point to gather its coordinates across dimensions.
        let mut point = [T::zero(); N];
        for (k, dst) in out.iter_mut().enumerate() {
            point
                .iter_mut()
                .zip(x.iter())
                .for_each(|(p, coords)| *p = coords[k]);
            *dst = self.interp_one(point)?;
        }
        Ok(())
    }

    /// Return the grid value nearest to the point `x`.
    ///
    /// Along each dimension, the point is first placed in a grid cell and then snapped to
    /// the closer endpoint of that cell. A point exactly at the midpoint snaps to the lower
    /// endpoint. Cell indices are clamped to the grid, so a coordinate outside the grid
    /// snaps to the first or last grid point of that dimension.
    ///
    /// NOTE(review): a NaN coordinate fails every comparison, so it resolves to grid
    /// index 1 in that dimension instead of propagating NaN. Confirm this is intended.
    #[inline]
    pub fn interp_one(&self, x: [T; N]) -> Result<T, &'static str> {
        // Row-major strides: the last dimension is contiguous, and each earlier stride is
        // the product of the grid sizes of all later dimensions.
        let mut strides = [1_usize; N];
        for d in (0..N.saturating_sub(1)).rev() {
            strides[d] = strides[d + 1] * self.dims[d + 1];
        }

        let half = T::one() / (T::one() + T::one());
        let mut nearest = [0_usize; N];
        for (d, (slot, &coord)) in nearest.iter_mut().zip(x.iter()).enumerate() {
            let lo = self.get_loc(coord, d)?;
            let grid = self.grids[d];
            let (left, right) = (grid[lo], grid[lo + 1]);
            // Fractional position of `coord` within [left, right]. It is < 0 or > 1 when
            // extrapolating past the edge cells.
            let frac = (coord - left) / (right - left);
            *slot = if frac <= half { lo } else { lo + 1 };
        }

        Ok(index_arr_fixed_dims(nearest, strides, self.vals))
    }

    /// Index `i` of the grid cell `[grid[i], grid[i + 1]]` used to locate `v` along
    /// dimension `dim`.
    ///
    /// The index is the last grid point strictly below `v`, clamped so that it is never
    /// negative and never past the last cell. This assumes the grid is sorted ascending,
    /// presumably guaranteed by `validate_rectilinear_grid`. This currently always
    /// returns `Ok`.
    #[inline]
    fn get_loc(&self, v: T, dim: usize) -> Result<usize, &'static str> {
        let below = self.grids[dim].partition_point(|g| *g < v) as isize - 1;
        let last_cell = (self.dims[dim] as isize).saturating_sub(2).max(0);
        Ok(below.clamp(0, last_cell) as usize)
    }
}
#[cfg(test)]
mod test {
    use super::{NearestRectilinear, interpn};
    use crate::testing::*;
    use crate::utils::*;

    /// Reference nearest-index calculation that mirrors the interpolator's algorithm.
    ///
    /// It finds the clamped cell containing `value`, then snaps to the closer endpoint,
    /// with ties going to the lower endpoint. Because it re-implements the production
    /// logic, the hand-computed tests below provide independent checks.
    fn nearest_rectilinear_index(value: f64, grid: &[f64]) -> usize {
        let iloc = grid.partition_point(|x| *x < value) as isize - 1;
        let n = grid.len() as isize;
        let dimmax = n.saturating_sub(2).max(0);
        let origin = iloc.max(0).min(dimmax) as usize;
        let x0 = grid[origin];
        let x1 = grid[origin + 1];
        let dt = (value - x0) / (x1 - x0);
        if dt <= 0.5 { origin } else { origin + 1 }
    }

    #[test]
    fn test_interp_extrap_2d_small() {
        let (nx, ny) = (3, 2);
        let x = linspace(-1.0, 1.0, nx);
        let y = Vec::from([0.5, 0.6]);
        let grids = [&x[..], &y[..]];
        let xy = meshgrid(Vec::from([&x, &y]));
        let z: Vec<f64> = (0..nx * ny).map(|i| &xy[i][0] + &xy[i][1]).collect();
        let xobs = linspace(-10.0_f64, 10.0, 5);
        let yobs = linspace(-10.0_f64, 10.0, 5);
        let xyobs = meshgrid(Vec::from([&xobs, &yobs]));
        let interpolator: NearestRectilinear<'_, _, 2> =
            NearestRectilinear::new(&grids, &z[..]).unwrap();
        xyobs.iter().for_each(|xyi| {
            let zii = interpolator.interp_one([xyi[0], xyi[1]]).unwrap();
            let ix = nearest_rectilinear_index(xyi[0], &x);
            let iy = nearest_rectilinear_index(xyi[1], &y);
            let expected = x[ix] + y[iy];
            assert!((expected - zii).abs() < 1e-12)
        });
    }

    #[test]
    fn test_interp_extrap_1d_to_8d() {
        let mut rng = rng_fixed_seed();
        for ndims in 1..=8 {
            println!("Testing in {ndims} dims");
            let dims: Vec<usize> = vec![2; ndims];
            let xs: Vec<Vec<f64>> = (0..ndims)
                .map(|i| {
                    let mut x = linspace(-5.0 * (i as f64), 5.0 * ((i + 1) as f64), dims[i]);
                    let dx = randn::<f64>(&mut rng, x.len());
                    // `j` indexes grid points; `i` (the dimension) is not shadowed.
                    (0..x.len()).for_each(|j| x[j] += (dx[j] - 0.5) / 10.0);
                    (0..x.len() - 1).for_each(|j| assert!(x[j + 1] > x[j]));
                    x
                })
                .collect();
            let grids: Vec<&[f64]> = xs.iter().map(|x| &x[..]).collect();
            let grid = meshgrid((0..ndims).map(|i| &xs[i]).collect());
            let u: Vec<f64> = grid.iter().map(|x| x.iter().sum()).collect();
            let xobs: Vec<Vec<f64>> = (0..ndims)
                .map(|i| linspace(-7.0 * (i as f64), 7.0 * ((i + 1) as f64), 3))
                .collect();
            let gridobs = meshgrid((0..ndims).map(|i| &xobs[i]).collect());
            let gridobs_t: Vec<Vec<f64>> = (0..ndims)
                .map(|i| gridobs.iter().map(|x| x[i]).collect())
                .collect();
            let xobsslice: Vec<&[f64]> = gridobs_t.iter().map(|x| &x[..]).collect();
            let expected: Vec<f64> = gridobs
                .iter()
                .map(|point| {
                    (0..ndims)
                        .map(|dim| {
                            let idx = nearest_rectilinear_index(point[dim], &xs[dim]);
                            xs[dim][idx]
                        })
                        .sum()
                })
                .collect();
            let mut out = vec![0.0; expected.len()];
            interpn(&grids, &u, &xobsslice, &mut out[..]).unwrap();
            (0..expected.len()).for_each(|i| assert!((out[i] - expected[i]).abs() < 1e-12));
        }
    }

    #[test]
    fn test_interp_hat_func() {
        fn hat_func(x: f64) -> f64 {
            if x <= 1.0 { x } else { 2.0 - x }
        }
        let x = (0..3).map(|x| x as f64).collect::<Vec<f64>>();
        let grids = [&x[..]];
        let y = (0..3).map(|x| hat_func(x as f64)).collect::<Vec<f64>>();
        let obs = linspace(-2.0, 4.0, 100);
        let interpolator: NearestRectilinear<f64, 1> =
            NearestRectilinear::new(&grids, &y).unwrap();
        (0..obs.len()).for_each(|i| {
            let idx = nearest_rectilinear_index(obs[i], &x);
            assert_eq!(y[idx], interpolator.interp_one([obs[i]]).unwrap());
        })
    }

    /// Hand-computed expectations on a non-uniform grid, independent of the reference
    /// helper. Covers extrapolation on both sides, exact grid points, and exact midpoints,
    /// which snap to the lower neighbor.
    #[test]
    fn test_interp_1d_known_values() {
        let x = [0.0, 1.0, 2.0, 4.0, 8.0];
        let grids = [&x[..]];
        let vals = [10.0, 11.0, 12.0, 14.0, 18.0];
        let interpolator: NearestRectilinear<'_, f64, 1> =
            NearestRectilinear::new(&grids, &vals).unwrap();
        let cases = [
            (-5.0, 10.0),
            (0.0, 10.0),
            (0.5, 10.0),
            (0.51, 11.0),
            (1.0, 11.0),
            (3.0, 12.0),
            (3.1, 14.0),
            (6.0, 14.0),
            (100.0, 18.0),
        ];
        for (xi, expected) in cases {
            assert_eq!(interpolator.interp_one([xi]).unwrap(), expected, "x = {xi}");
        }
    }

    /// Checks that the flat index is `ix * ny + iy`, i.e. the last dimension is contiguous.
    #[test]
    fn test_interp_2d_row_major_layout() {
        let x = [0.0, 1.0, 2.0];
        let y = [0.0, 10.0];
        let grids = [&x[..], &y[..]];
        // Each value equals its own flat index, so the result reveals which entry was picked.
        let vals: Vec<f64> = (0..x.len() * y.len()).map(|k| k as f64).collect();
        let interpolator: NearestRectilinear<'_, f64, 2> =
            NearestRectilinear::new(&grids, &vals).unwrap();
        let cases = [
            ([0.2, 4.0], 0.0),   // (ix, iy) = (0, 0)
            ([-3.0, 50.0], 1.0), // (0, 1), extrapolated in both dimensions
            ([1.2, 6.0], 3.0),   // (1, 1)
            ([2.0, 0.0], 4.0),   // (2, 0), exactly on grid points
            ([1.9, 9.0], 5.0),   // (2, 1)
        ];
        for (point, expected) in cases {
            assert_eq!(interpolator.interp_one(point).unwrap(), expected, "point = {point:?}");
        }
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let x = [0.0, 1.0];
        let grids = [&x[..]];
        let vals = [1.0, 2.0];
        let obs = [0.25, 0.75];

        // Output buffer shorter than the observation columns.
        let interpolator: NearestRectilinear<'_, f64, 1> =
            NearestRectilinear::new(&grids, &vals).unwrap();
        let mut short_out = [0.0; 1];
        assert_eq!(
            interpolator.interp(&[&obs[..]], &mut short_out),
            Err("Dimension mismatch")
        );

        // More observation columns than grid dimensions.
        let mut out = [0.0; 2];
        assert_eq!(
            interpn(&grids, &vals, &[&obs[..], &obs[..]], &mut out),
            Err("Dimension mismatch")
        );

        // Matching shapes succeed.
        interpn(&grids, &vals, &[&obs[..]], &mut out).unwrap();
        assert_eq!(out, [1.0, 2.0]);
    }
}