use crate::index_arr_fixed_dims;
use crunchy::unroll;
use num_traits::{Float, NumCast};
/// Evaluate a multilinear interpolation on a regular grid at many points.
///
/// `dims`, `starts`, and `steps` describe the grid one axis at a time (point
/// count, first coordinate, spacing), `vals` holds the flattened grid values,
/// and `obs[j][i]` is the coordinate along axis `j` of observation point `i`.
/// Results are written to `out`, one per observation point.
///
/// Only 1 through 8 dimensions are dispatched here; for more, construct a
/// [`MultilinearRegular`] directly.
///
/// # Errors
///
/// Returns `"Dimension mismatch"` if `starts`, `steps`, or `obs` do not have one
/// entry per element of `dims`; the dispatch error message if `dims.len()` is
/// not one of the supported counts; and any error from
/// [`MultilinearRegular::new`] or [`MultilinearRegular::interp`].
pub fn interpn<T: Float>(
    dims: &[usize],
    starts: &[T],
    steps: &[T],
    vals: &[T],
    obs: &[&[T]],
    out: &mut [T],
) -> Result<(), &'static str> {
    let ndims = dims.len();
    let lengths_agree = [starts.len(), steps.len(), obs.len()]
        .iter()
        .all(|&len| len == ndims);
    if !lengths_agree {
        return Err("Dimension mismatch");
    }

    crate::dispatch_ndims!(
        ndims,
        "Dimension exceeds maximum (8). Use interpolator struct directly for higher dimensions.",
        [1, 2, 3, 4, 5, 6, 7, 8],
        |N| {
            // The slice-to-array conversions rely on `dispatch_ndims!` selecting
            // the arm with `N == ndims` (assumed from its use here); every slice
            // was length-checked against `ndims` above.
            let interpolator = MultilinearRegular::<'_, T, N>::new(
                dims.try_into().unwrap(),
                starts.try_into().unwrap(),
                steps.try_into().unwrap(),
                vals,
            )?;
            interpolator.interp(obs.try_into().unwrap(), out)
        }
    )?;
    Ok(())
}
/// Allocating variant of [`interpn`]: interpolates at the points in `obs` and
/// returns the results in a newly allocated `Vec`.
///
/// The output length is taken from the first coordinate slice (`obs[0]`), or is
/// zero when `obs` is empty. In that case the call is left to [`interpn`] to
/// reject: it reports a length mismatch when `dims` is non-empty, and
/// presumably a dispatch error for zero dimensions.
///
/// # Errors
///
/// Returns any error produced by [`interpn`].
#[cfg(feature = "std")]
pub fn interpn_alloc<T: Float>(
    dims: &[usize],
    starts: &[T],
    steps: &[T],
    vals: &[T],
    obs: &[&[T]],
) -> Result<Vec<T>, &'static str> {
    // `first()` rather than `obs[0]`: an empty `obs` must surface as an error from
    // `interpn`, not as an index-out-of-bounds panic here.
    let nobs = obs.first().map_or(0, |coords| coords.len());
    let mut out = vec![T::zero(); nobs];
    interpn(dims, starts, steps, vals, obs, &mut out)?;
    Ok(out)
}
/// Check, per axis, whether any observation lies outside a regular grid's extent.
///
/// Along axis `i` the grid spans the closed interval between `starts[i]` and
/// `starts[i] + steps[i] * (dims[i] - 1)`. The endpoints are ordered with
/// `min`/`max`, so the sign of `steps[i]` does not matter. `out[i]` is set to
/// `true` when at least one `x` in `obs[i]` satisfies `x - lo <= -atol` or
/// `x - hi >= atol`, and to `false` otherwise.
///
/// Both comparisons are non-strict. A point exactly `atol` outside the grid is
/// therefore flagged, and with `atol == 0` a point lying exactly on a boundary
/// is flagged as well.
///
/// # Errors
///
/// Returns `"Dimension mismatch"` in either of these cases:
/// - `starts`, `steps`, `obs`, or `out` does not have one entry per element of
///   `dims`;
/// - any axis has zero grid points.
pub fn check_bounds<T: Float>(
    dims: &[usize],
    starts: &[T],
    steps: &[T],
    obs: &[&[T]],
    atol: T,
    out: &mut [bool],
) -> Result<(), &'static str> {
    let n = dims.len();
    // Validate every per-axis slice up front so the indexing below cannot panic.
    let lengths_match = [starts.len(), steps.len(), obs.len(), out.len()]
        .iter()
        .all(|&len| len == n);
    // A zero-length axis has no last grid point; `dims[i] - 1` would underflow.
    if !lengths_match || dims.contains(&0) {
        return Err("Dimension mismatch");
    }
    for i in 0..n {
        out[i] = match <T as NumCast>::from(dims[i] - 1) {
            Some(last_elem) => {
                let first = starts[i];
                let last = starts[i] + steps[i] * last_elem;
                let lo = first.min(last);
                let hi = first.max(last);
                obs[i]
                    .iter()
                    .any(|&x| (x - lo) <= -atol || (x - hi) >= atol)
            }
            // The last grid index is not representable in `T`, so the grid extent
            // cannot be computed; conservatively flag the axis as out of bounds.
            None => true,
        };
    }
    Ok(())
}
/// Multilinear interpolator over an `N`-dimensional regular (evenly spaced) grid.
///
/// Grid values are borrowed from the caller rather than copied. Points outside
/// the grid are linearly extrapolated from the nearest edge cell.
#[derive(Debug, Clone)]
pub struct MultilinearRegular<'a, T: Float, const N: usize> {
    /// Number of grid points along each axis.
    dims: [usize; N],
    /// Coordinate of the first grid point along each axis.
    starts: [T; N],
    /// Spacing between consecutive grid points along each axis.
    steps: [T; N],
    /// Flattened grid values. Interpolation derives per-axis strides from `dims`,
    /// with the last axis at stride 1 (presumably row-major; confirm against
    /// `index_arr_fixed_dims`).
    vals: &'a [T],
}
impl<'a, T: Float, const N: usize> MultilinearRegular<'a, T, N> {
    /// Build an interpolator for an `N`-dimensional regular grid.
    ///
    /// * `dims[i]` — number of grid points along axis `i`.
    /// * `starts[i]` — coordinate of the first grid point along axis `i`.
    /// * `steps[i]` — spacing between grid points along axis `i`.
    /// * `vals` — flattened grid values, borrowed for the interpolator's lifetime.
    ///
    /// # Errors
    ///
    /// Returns whatever error `crate::validate_regular_grid` reports for `dims`,
    /// `steps`, and `vals`.
    pub fn new(
        dims: [usize; N],
        starts: [T; N],
        steps: [T; N],
        vals: &'a [T],
    ) -> Result<Self, &'static str> {
        crate::validate_regular_grid(&dims, &steps, vals)?;
        Ok(Self {
            dims,
            starts,
            steps,
            vals,
        })
    }

    /// Interpolate at a batch of points, writing one result per point into `out`.
    ///
    /// `x[j][i]` is the coordinate along axis `j` of point `i`; the number of
    /// points is `out.len()`.
    ///
    /// # Errors
    ///
    /// Returns `"Dimension mismatch"` if any `x[j].len()` differs from
    /// `out.len()`, and propagates the first error from [`Self::interp_one`].
    /// On such an error, the entries of `out` before the failing point have
    /// already been written.
    pub fn interp(&self, x: &[&[T]; N], out: &mut [T]) -> Result<(), &'static str> {
        let n = out.len();
        for i in 0..N {
            if x[i].len() != n {
                return Err("Dimension mismatch");
            }
        }
        // Reused buffer holding the coordinates of one point, gathered across axes.
        let mut tmp = [T::zero(); N];
        for i in 0..n {
            (0..N).for_each(|j| tmp[j] = x[j][i]);
            out[i] = self.interp_one(tmp)?;
        }
        Ok(())
    }

    /// Interpolate (or extrapolate) the grid at a single point `x`.
    ///
    /// Along each axis, `get_loc` picks the cell's lower grid index and clamps
    /// it to the first/last cell. The fractional offsets in `dts` are not
    /// clamped, so points outside the grid are linearly extrapolated from the
    /// nearest edge cell.
    ///
    /// The `2^N` corner values of the cell are visited in binary-counter order
    /// and reduced one axis at a time. As a result, `store` only ever holds
    /// `2 * N` partial results rather than all `2^N` corners.
    ///
    /// # Errors
    ///
    /// Returns `"Unrepresentable coordinate value"` in either of these cases:
    /// - a cell index cannot be computed (see `get_loc`);
    /// - a cell index cannot be converted back to `T`.
    #[inline]
    pub fn interp_one(&self, x: [T; N]) -> Result<T, &'static str> {
        // origin[k]: lower grid index of the cell along axis k.
        // dts[k]: position within the cell along axis k (0 at the lower grid point,
        //         1 at the upper one; outside [0, 1] when extrapolating).
        // dimprod[k]: stride of axis k in `vals`, i.e. the product of dims[k+1..].
        let mut origin = [0_usize; N];
        let mut dts = [T::zero(); N];
        let mut dimprod = [1_usize; N];
        // Scratch: grid index of the corner currently being read.
        let mut loc = [0_usize; N];
        // store[j] holds the pair of partial results awaiting reduction along axis j.
        let mut store = [[T::zero(); FP]; N];
        let mut acc = 1;
        for i in 0..N {
            // Strides are accumulated from the last axis backwards, so the last
            // axis gets stride 1.
            if i > 0 {
                acc *= self.dims[N - i];
            }
            dimprod[N - i - 1] = acc;
            origin[i] = self.get_loc(x[i], i)?;
            let origin_f =
                <T as NumCast>::from(origin[i]).ok_or("Unrepresentable coordinate value")?;
            // Coordinate of the cell's lower grid point. The `fma` feature uses a
            // fused multiply-add (single rounding) instead of separate mul and add.
            #[cfg(not(feature = "fma"))]
            let index_zero_loc = self.starts[i] + self.steps[i] * origin_f;
            #[cfg(feature = "fma")]
            let index_zero_loc = self.steps[i].mul_add(origin_f, self.starts[i]);
            dts[i] = (x[i] - index_zero_loc) / self.steps[i];
        }
        // Two grid points per axis in a multilinear cell, so 2^N corners in total.
        // (`FP` is an item, so it is also visible to the `store` declaration above.)
        const FP: usize = 2;
        let nverts = const { FP.pow(N as u32) };
        // Loop body for corner number `$i`. Bit k of `$i` selects the lower (0) or
        // upper (1) grid point along axis k.
        macro_rules! unroll_vertices_body {
            ($i:ident) => {
                for j in 0..N {
                    if j == 0 {
                        // Read this corner's grid value into slot `$i % 2` of level 0.
                        for k in 0..N {
                            let offset: usize = ($i & (1 << k)) >> k;
                            loc[k] = origin[k] + offset;
                        }
                        let store_ind: usize = $i % FP;
                        store[0][store_ind] = index_arr_fixed_dims(loc, dimprod, self.vals);
                    } else {
                        // Each time a multiple of 2^j corners has been read, level j-1
                        // holds a complete pair. Interpolate that pair along axis j-1
                        // into slot `p` of level j. Levels are visited in increasing
                        // j, so a pair completed at level j in this same iteration is
                        // consumed at level j+1.
                        let q: usize = FP.pow(j as u32);
                        let level: bool = ($i + 1).is_multiple_of(q);
                        if level {
                            let p: usize = (($i + 1) / q).saturating_sub(1) % FP;
                            let ind: usize = j.saturating_sub(1);
                            let y0 = store[ind][0];
                            let dy = store[ind][1] - y0;
                            let t = dts[ind];
                            #[cfg(not(feature = "fma"))]
                            let interped = y0 + t * dy;
                            #[cfg(feature = "fma")]
                            let interped = t.mul_add(dy, y0);
                            store[j][p] = interped;
                        }
                    }
                }
            };
        }
        // For N <= 6 (at most 64 corners), unroll the corner loop via
        // `crunchy::unroll!` with its declared bound of 64 iterations. Larger N
        // uses an ordinary loop.
        if N <= 6 {
            unroll! {
                for i < 64 in 0..nverts {
                    unroll_vertices_body!(i)
                }
            }
        } else {
            for i in 0..nverts {
                unroll_vertices_body!(i)
            }
        }
        // Final reduction: interpolate the last remaining pair along axis N-1.
        let y0 = store[N - 1][0];
        let dy = store[N - 1][1] - y0;
        let t = dts[N - 1];
        #[cfg(not(feature = "fma"))]
        let interped = y0 + t * dy;
        #[cfg(feature = "fma")]
        let interped = t.mul_add(dy, y0);
        Ok(interped)
    }

    /// Lower grid index of the cell used for coordinate `v` along axis `dim`.
    ///
    /// The raw index `floor((v - starts[dim]) / steps[dim])` is clamped to
    /// `[0, dims[dim] - 2]` (or to 0 when `dims[dim] < 2`). Values outside the
    /// grid therefore map to the first or last cell.
    ///
    /// # Errors
    ///
    /// Returns `"Unrepresentable coordinate value"` if the raw index does not fit
    /// in an `isize` (e.g. `v` is NaN or extremely far outside the grid).
    #[inline]
    fn get_loc(&self, v: T, dim: usize) -> Result<usize, &'static str> {
        let floc = ((v - self.starts[dim]) / self.steps[dim]).floor();
        let iloc = <isize as NumCast>::from(floc).ok_or("Unrepresentable coordinate value")?;
        let n = self.dims[dim] as isize;
        let dimmax = n.saturating_sub(2).max(0);
        let loc: usize = iloc.max(0).min(dimmax) as usize;
        Ok(loc)
    }
}
#[cfg(test)]
mod test {
    use super::interpn;
    use crate::{MultilinearRegular, utils::*};

    /// The sum of coordinates is linear in every axis, so multilinear
    /// interpolation must reproduce it exactly. This holds inside the grid and
    /// when extrapolating past its edges, for 1 through 8 dimensions.
    #[test]
    fn test_interp_extrap_1d_to_8d() {
        for ndims in 1..=8 {
            println!("Testing in {ndims} dims");
            let dims = vec![2_usize; ndims];

            // Grid axes; each dimension spans a different interval.
            let axes: Vec<Vec<f64>> = dims
                .iter()
                .enumerate()
                .map(|(d, &len)| linspace(-5.0 * (d as f64), 5.0 * ((d + 1) as f64), len))
                .collect();
            let grid = meshgrid(axes.iter().collect());
            let vals: Vec<f64> = grid.iter().map(|pt| pt.iter().sum()).collect();
            let starts: Vec<f64> = axes.iter().map(|ax| ax[0]).collect();
            let steps: Vec<f64> = axes.iter().map(|ax| ax[1] - ax[0]).collect();

            // Observation axes reach beyond the grid to exercise extrapolation.
            let obs_axes: Vec<Vec<f64>> = (0..ndims)
                .map(|d| linspace(-7.0 * (d as f64), 7.0 * ((d + 1) as f64), 3))
                .collect();
            let obs_grid = meshgrid(obs_axes.iter().collect());
            // Transpose point-major observations into one coordinate column per axis.
            let obs_cols: Vec<Vec<f64>> = (0..ndims)
                .map(|d| obs_grid.iter().map(|pt| pt[d]).collect())
                .collect();
            let obs_slices: Vec<&[f64]> = obs_cols.iter().map(Vec::as_slice).collect();
            let expected: Vec<f64> = obs_grid.iter().map(|pt| pt.iter().sum()).collect();

            let mut out = vec![0.0; expected.len()];
            interpn(&dims, &starts, &steps, &vals, &obs_slices, &mut out).unwrap();
            for (&got, &want) in out.iter().zip(expected.iter()) {
                println!("{got} {want}");
                assert!((got - want).abs() < 1e-12)
            }
        }
    }

    /// Piecewise-linear data on three points is reproduced exactly. This includes
    /// linear extrapolation of the two end segments outside `[0, 2]`.
    #[test]
    fn test_interp_hat_func() {
        fn hat_func(x: f64) -> f64 {
            match x <= 1.0 {
                true => x,
                false => 2.0 - x,
            }
        }
        let y: Vec<f64> = (0..3).map(|k| hat_func(k as f64)).collect();
        let obs = linspace(-2.0, 4.0, 100);
        let interpolator: MultilinearRegular<f64, 1> =
            MultilinearRegular::new([3], [0.0], [1.0], &y).unwrap();
        for &x in obs.iter() {
            assert_eq!(hat_func(x), interpolator.interp_one([x]).unwrap());
        }
    }
}