use crate::{index_arr_fixed_dims, mul_add};
use num_traits::{Float, NumCast};
/// Nearest-neighbor interpolation on a regular grid with 1 to 8 dimensions.
///
/// Convenience entry point that dispatches on `dims.len()` to a
/// [`NearestRegular`] whose dimension count is a compile-time constant.
///
/// # Arguments
/// * `dims` - number of grid points in each dimension.
/// * `starts` - first grid coordinate in each dimension.
/// * `steps` - grid spacing in each dimension.
/// * `vals` - flattened grid values.
/// * `obs` - one coordinate slice per dimension; `obs[j][i]` is coordinate `j`
///   of observation point `i`.
/// * `out` - receives one interpolated value per observation point.
///
/// # Errors
/// * `"Dimension mismatch"` when `starts`, `steps`, or `obs` do not have exactly
///   one entry per dimension.
/// * `"Dimension exceeds maximum (8)."` is handed to `dispatch_ndims!` as its error
///   for unsupported dimension counts (presumably anything outside 1..=8 — confirm
///   against the macro).
/// * Any error from grid validation or from [`NearestRegular::interp`].
pub fn interpn<T: Float>(
    dims: &[usize],
    starts: &[T],
    steps: &[T],
    vals: &[T],
    obs: &[&[T]],
    out: &mut [T],
) -> Result<(), &'static str> {
    let ndims = dims.len();
    let lengths_match = [starts.len(), steps.len(), obs.len()]
        .iter()
        .all(|&len| len == ndims);
    if !lengths_match {
        return Err("Dimension mismatch");
    }
    crate::dispatch_ndims!(
        ndims,
        "Dimension exceeds maximum (8).",
        [1, 2, 3, 4, 5, 6, 7, 8],
        |N| {
            // The slice -> array conversions cannot fail: every length was
            // checked against `ndims` above, and `N == ndims` in this arm.
            let interpolator = NearestRegular::<'_, T, N>::new(
                dims.try_into().unwrap(),
                starts.try_into().unwrap(),
                steps.try_into().unwrap(),
                vals,
            )?;
            interpolator.interp(obs.try_into().unwrap(), out)
        }
    )?;
    Ok(())
}
/// Allocating variant of [`interpn`] that returns a freshly allocated output vector.
///
/// The output length is taken from the first observation slice. An empty `obs`
/// no longer panics here: a zero-length buffer is passed on, and [`interpn`]
/// rejects the call ("Dimension mismatch" unless `dims` is also empty).
///
/// # Errors
/// Same as [`interpn`].
#[cfg(feature = "std")]
pub fn interpn_alloc<T: Float>(
    dims: &[usize],
    starts: &[T],
    steps: &[T],
    vals: &[T],
    obs: &[&[T]],
) -> Result<Vec<T>, &'static str> {
    // `obs` can be empty (zero or mismatched dimensions); indexing `obs[0]`
    // would panic before `interpn` gets a chance to report the error.
    let npts = obs.first().map_or(0, |o| o.len());
    let mut out = vec![T::zero(); npts];
    interpn(dims, starts, steps, vals, obs, &mut out)?;
    Ok(out)
}
pub use crate::multilinear::regular::check_bounds;
/// Nearest-neighbor interpolator on a grid with uniform spacing in each of `N`
/// dimensions.
///
/// The grid geometry is stored inline as fixed-size arrays; the grid values are
/// borrowed, not copied.
pub struct NearestRegular<'a, T: Float, const N: usize> {
    /// Number of grid points along each dimension.
    dims: [usize; N],
    /// Coordinate of the first grid point along each dimension.
    starts: [T; N],
    /// Spacing between consecutive grid points along each dimension.
    steps: [T; N],
    /// Flattened grid values. `interp_one` builds strides with the last dimension
    /// varying fastest (row-major); assumes `vals` is laid out that way.
    vals: &'a [T],
}

impl<'a, T: Float, const N: usize> NearestRegular<'a, T, N> {
    /// Build an interpolator after validating the grid.
    ///
    /// `N` is restricted to 1..=8 at compile time.
    ///
    /// # Errors
    /// Any error reported by `crate::validate_regular_grid` for the given
    /// `dims`, `steps`, and `vals`.
    pub fn new(
        dims: [usize; N],
        starts: [T; N],
        steps: [T; N],
        vals: &'a [T],
    ) -> Result<Self, &'static str> {
        const {
            assert!(
                N > 0 && N < 9,
                "Flattened method defined for 1-8 dimensions. For higher dimensions, use recursive method."
            );
        }
        crate::validate_regular_grid(&dims, &steps, vals)?;
        Ok(Self {
            dims,
            starts,
            steps,
            vals,
        })
    }

    /// Interpolate at many observation points.
    ///
    /// `x[j][i]` is coordinate `j` of observation point `i`; the result for that
    /// point is written to `out[i]`.
    ///
    /// # Errors
    /// `"Dimension mismatch"` if any `x[j]` differs in length from `out`, or any
    /// error from [`Self::interp_one`]. On error, `out` may be partially written.
    pub fn interp(&self, x: &[&[T]; N], out: &mut [T]) -> Result<(), &'static str> {
        let n = out.len();
        if x.iter().any(|xj| xj.len() != n) {
            return Err("Dimension mismatch");
        }
        for (i, slot) in out.iter_mut().enumerate() {
            // Gather the i-th point from the per-dimension coordinate slices.
            let point: [T; N] = core::array::from_fn(|j| x[j][i]);
            *slot = self.interp_one(point)?;
        }
        Ok(())
    }

    /// Interpolate at a single point by returning the value at the nearest grid node.
    ///
    /// Out-of-grid coordinates resolve to the edge node of that dimension (constant
    /// extrapolation). A coordinate at the midpoint between two nodes resolves to the
    /// lower one.
    ///
    /// # Errors
    /// `"Unrepresentable coordinate value"` if a coordinate cannot be mapped to a grid
    /// index, e.g. NaN or a value whose index does not fit in `isize`.
    #[inline]
    pub fn interp_one(&self, x: [T; N]) -> Result<T, &'static str> {
        // Per-call stack storage for strides and the selected node index.
        let mut dimprod = [1_usize; N];
        let mut loc = [0_usize; N];
        let half = T::one() / (T::one() + T::one());

        let mut acc = 1;
        for i in 0..N {
            // Row-major strides: dimprod[N-1] == 1, and each earlier entry is the
            // product of the sizes of all later dimensions.
            if i > 0 {
                acc *= self.dims[N - i];
            }
            dimprod[N - i - 1] = acc;

            // Lower node of the cell containing x[i], and its coordinate.
            let origin = self.get_loc(x[i], i)?;
            let origin_f =
                <T as NumCast>::from(origin).ok_or("Unrepresentable coordinate value")?;
            let index_zero_loc = mul_add(self.steps[i], origin_f, self.starts[i]);

            // Fractional position within the cell; round up only past the midpoint.
            let dt = (x[i] - index_zero_loc) / self.steps[i];
            let offset = if dt <= half { 0 } else { 1 };

            // A single-point dimension has origin 0 and no upper neighbor; clamp so
            // the index never runs past the last node. No-op when dims[i] >= 2.
            loc[i] = (origin + offset).min(self.dims[i].saturating_sub(1));
        }

        Ok(index_arr_fixed_dims(loc, dimprod, self.vals))
    }

    /// Index of the lower node of the cell containing `v` along dimension `dim`.
    ///
    /// The result is clamped to `[0, dims[dim] - 2]` (0 when the dimension has fewer
    /// than two points), so the upper neighbor exists whenever the dimension has at
    /// least two points.
    ///
    /// # Errors
    /// `"Unrepresentable coordinate value"` if the floored index is NaN or does not
    /// fit in `isize`.
    #[inline]
    fn get_loc(&self, v: T, dim: usize) -> Result<usize, &'static str> {
        let floc = ((v - self.starts[dim]) / self.steps[dim]).floor();
        let iloc = <isize as NumCast>::from(floc).ok_or("Unrepresentable coordinate value")?;
        // `isize::saturating_sub` saturates at isize::MIN rather than 0, so the
        // upper bound must be floored at 0 explicitly.
        let dimmax = (self.dims[dim] as isize).saturating_sub(2).max(0);
        Ok(iloc.clamp(0, dimmax) as usize)
    }
}
#[cfg(test)]
mod test {
    use super::interpn;
    use crate::{NearestRegular, utils::*};

    /// Independent 1-D oracle for the nearest node index of `value` on a grid of
    /// `dim` points starting at `start` with spacing `step`.
    ///
    /// Takes the lower cell corner (clamped so an upper neighbor exists), rounds up
    /// only when strictly past the midpoint, and never exceeds the last index.
    fn nearest_regular_index(value: f64, start: f64, step: f64, dim: usize) -> usize {
        let max_origin = (dim as isize).saturating_sub(2).max(0);
        let raw_origin = ((value - start) / step).floor() as isize;
        let origin = raw_origin.max(0).min(max_origin) as usize;
        let frac = (value - (start + step * origin as f64)) / step;
        if frac <= 0.5 {
            origin
        } else {
            (origin + 1).min(dim - 1)
        }
    }

    #[test]
    fn test_interp_extrap_1d_to_8d() {
        for n in 1..=8 {
            println!("Testing in {n} dims");

            // Two-point grid per dimension; each dimension spans a different range.
            let dims: Vec<usize> = vec![2; n];
            let axes: Vec<Vec<f64>> = (0..n)
                .map(|d| linspace(-5.0 * (d as f64), 5.0 * ((d + 1) as f64), dims[d]))
                .collect();
            let starts: Vec<f64> = axes.iter().map(|axis| axis[0]).collect();
            let steps: Vec<f64> = axes.iter().map(|axis| axis[1] - axis[0]).collect();

            // Grid values: the sum of each grid point's coordinates.
            let grid = meshgrid(axes.iter().collect());
            let u: Vec<f64> = grid.iter().map(|point| point.iter().sum()).collect();

            // Observation axes reach beyond the grid (above it always, below it for d > 0).
            let obs_axes: Vec<Vec<f64>> = (0..n)
                .map(|d| linspace(-7.0 * (d as f64), 7.0 * ((d + 1) as f64), 3))
                .collect();
            let obs_points = meshgrid(obs_axes.iter().collect());
            let obs_by_dim: Vec<Vec<f64>> = (0..n)
                .map(|d| obs_points.iter().map(|point| point[d]).collect())
                .collect();
            let obs_slices: Vec<&[f64]> = obs_by_dim.iter().map(Vec::as_slice).collect();

            // Expected: sum over dimensions of the nearest grid coordinate.
            let expected: Vec<f64> = obs_points
                .iter()
                .map(|point| {
                    (0..n)
                        .map(|d| {
                            let idx =
                                nearest_regular_index(point[d], starts[d], steps[d], dims[d]);
                            starts[d] + steps[d] * idx as f64
                        })
                        .sum()
                })
                .collect();

            let mut out = vec![0.0; expected.len()];
            interpn(&dims, &starts, &steps, &u, &obs_slices, &mut out[..]).unwrap();
            for (&outi, &expecti) in out.iter().zip(expected.iter()) {
                println!("{outi} {expecti}");
                assert!((outi - expecti).abs() < 1e-12)
            }
        }
    }

    #[test]
    fn test_interp_hat_func() {
        fn hat_func(x: f64) -> f64 {
            if x <= 1.0 { x } else { 2.0 - x }
        }
        let y: Vec<f64> = (0..3).map(|x| hat_func(x as f64)).collect();
        let obs = linspace(-2.0, 4.0, 100);
        let interpolator: NearestRegular<f64, 1> =
            NearestRegular::new([3], [0.0], [1.0], &y).unwrap();
        for &xi in obs.iter() {
            let idx = nearest_regular_index(xi, 0.0, 1.0, y.len());
            assert_eq!(y[idx], interpolator.interp_one([xi]).unwrap());
        }
    }
}