use crate::traits::FloatScalar;
use super::{InterpError, find_interval, validate_sorted};
/// Bilinear interpolant over a fixed-size rectilinear grid of `NX` x `NY` points.
///
/// Samples are stored x-first: `zs[ix][iy]` is the value at `(xs[ix], ys[iy])`.
#[derive(Clone, Debug)]
pub struct BilinearInterp<T, const NX: usize, const NY: usize> {
    /// Grid coordinates along the x axis.
    xs: [T; NX],
    /// Grid coordinates along the y axis.
    ys: [T; NY],
    /// Sample values, indexed `[ix][iy]`.
    zs: [[T; NY]; NX],
}
impl<T: FloatScalar, const NX: usize, const NY: usize> BilinearInterp<T, NX, NY> {
    /// Builds an interpolant from axis coordinates and a table of samples.
    ///
    /// `zs_rows[iy][ix]` is the sample at `(xs[ix], ys[iy])`, so each inner array
    /// is one row of constant `y`. The table is transposed on construction so the
    /// interpolant stores it as `[ix][iy]`.
    ///
    /// # Errors
    ///
    /// - [`InterpError::TooFewPoints`] if either axis has fewer than two points.
    /// - Whatever error `validate_sorted` reports for `xs` or `ys`.
    pub fn new(xs: [T; NX], ys: [T; NY], zs_rows: [[T; NX]; NY]) -> Result<Self, InterpError> {
        // At least one cell (two points per axis) is required to interpolate.
        if NX < 2 || NY < 2 {
            return Err(InterpError::TooFewPoints);
        }
        validate_sorted(&xs)?;
        validate_sorted(&ys)?;

        // Transpose the row-per-y input into x-first storage.
        let mut zs = [[T::zero(); NY]; NX];
        for (iy, row) in zs_rows.iter().enumerate() {
            for (ix, &value) in row.iter().enumerate() {
                zs[ix][iy] = value;
            }
        }
        Ok(Self { xs, ys, zs })
    }

    /// Evaluates the interpolant at `(x, y)`.
    ///
    /// `find_interval` picks the cell on each axis. The result blends the cell's
    /// four corner samples, weighted by the fractional position `(tx, ty)` of
    /// the point inside that cell.
    ///
    /// NOTE(review): the behaviour for points outside the grid depends entirely
    /// on which index `find_interval` returns for them. If that is the edge
    /// cell, the result is a linear extrapolation; confirm against
    /// `find_interval`.
    pub fn eval(&self, x: T, y: T) -> T {
        let i = find_interval(&self.xs, x);
        let j = find_interval(&self.ys, y);

        let (x0, x1) = (self.xs[i], self.xs[i + 1]);
        let (y0, y1) = (self.ys[j], self.ys[j + 1]);
        let tx = (x - x0) / (x1 - x0);
        let ty = (y - y0) / (y1 - y0);
        let sx = T::one() - tx;
        let sy = T::one() - ty;

        // Sample columns at the cell's lower and upper x edges.
        let lo = &self.zs[i];
        let hi = &self.zs[i + 1];
        sx * sy * lo[j]
            + tx * sy * hi[j]
            + sx * ty * lo[j + 1]
            + tx * ty * hi[j + 1]
    }

    /// Returns the x-axis grid coordinates.
    pub fn xs(&self) -> &[T; NX] {
        &self.xs
    }

    /// Returns the y-axis grid coordinates.
    pub fn ys(&self) -> &[T; NY] {
        &self.ys
    }
}
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Bilinear interpolant over a rectilinear grid whose size is chosen at runtime.
///
/// Samples live in one flat buffer, x-first: the value at `(xs[ix], ys[iy])`
/// is `zs[ix * ny + iy]`.
#[cfg(feature = "alloc")]
#[derive(Clone, Debug)]
pub struct DynBilinearInterp<T> {
    /// Grid coordinates along the x axis.
    xs: Vec<T>,
    /// Grid coordinates along the y axis.
    ys: Vec<T>,
    /// Flat sample buffer, indexed `ix * ny + iy`.
    zs: Vec<T>,
    /// Stride between consecutive x columns in `zs` (the y-axis length).
    ny: usize,
}
#[cfg(feature = "alloc")]
impl<T: FloatScalar> DynBilinearInterp<T> {
    /// Builds an interpolant from axis coordinates and a row-per-`y` table.
    ///
    /// `zs_rows[iy][ix]` is the sample at `(xs[ix], ys[iy])`. The table is copied
    /// into a flat x-first buffer.
    ///
    /// # Errors
    ///
    /// - [`InterpError::TooFewPoints`] if either axis has fewer than two points.
    /// - [`InterpError::LengthMismatch`] if there are not exactly `ys.len()` rows,
    ///   or any row does not hold exactly `xs.len()` samples.
    /// - Whatever error `validate_sorted` reports for `xs` or `ys`.
    pub fn new(
        xs: Vec<T>,
        ys: Vec<T>,
        zs_rows: Vec<Vec<T>>,
    ) -> Result<Self, InterpError> {
        let (nx, ny) = (xs.len(), ys.len());
        if nx < 2 || ny < 2 {
            return Err(InterpError::TooFewPoints);
        }
        let shape_ok = zs_rows.len() == ny && zs_rows.iter().all(|row| row.len() == nx);
        if !shape_ok {
            return Err(InterpError::LengthMismatch);
        }
        validate_sorted(&xs)?;
        validate_sorted(&ys)?;

        // Walk column by column (fixed ix), collecting one sample per row, so
        // the sample for (ix, iy) lands at flat index `ix * ny + iy`.
        let mut zs = Vec::with_capacity(nx * ny);
        zs.extend((0..nx).flat_map(|ix| zs_rows.iter().map(move |row| row[ix])));
        Ok(Self { xs, ys, zs, ny })
    }

    /// Builds an interpolant from an already-flattened, column-major sample buffer.
    ///
    /// Treat the table as rows of constant `y` and columns of constant `x`.
    /// In column-major order each column is contiguous, so the sample at
    /// `(xs[ix], ys[iy])` must be `zs_col_major[ix * ys.len() + iy]`. The buffer
    /// is moved into the interpolant without copying.
    ///
    /// # Errors
    ///
    /// - [`InterpError::TooFewPoints`] if either axis has fewer than two points.
    /// - [`InterpError::LengthMismatch`] if the buffer length is not
    ///   `xs.len() * ys.len()`.
    /// - Whatever error `validate_sorted` reports for `xs` or `ys`.
    pub fn from_slice(
        xs: Vec<T>,
        ys: Vec<T>,
        zs_col_major: Vec<T>,
    ) -> Result<Self, InterpError> {
        let nx = xs.len();
        let ny = ys.len();
        if nx.min(ny) < 2 {
            return Err(InterpError::TooFewPoints);
        }
        if zs_col_major.len() != nx * ny {
            return Err(InterpError::LengthMismatch);
        }
        validate_sorted(&xs)?;
        validate_sorted(&ys)?;
        Ok(Self {
            xs,
            ys,
            zs: zs_col_major,
            ny,
        })
    }

    /// Evaluates the interpolant at `(x, y)`.
    ///
    /// `find_interval` picks the cell on each axis. The result blends the cell's
    /// four corner samples, weighted by the fractional position `(tx, ty)` of
    /// the point inside that cell.
    ///
    /// NOTE(review): the behaviour for out-of-grid points depends on which index
    /// `find_interval` returns for them; confirm against `find_interval`.
    pub fn eval(&self, x: T, y: T) -> T {
        let i = find_interval(&self.xs, x);
        let j = find_interval(&self.ys, y);

        let (x0, x1) = (self.xs[i], self.xs[i + 1]);
        let (y0, y1) = (self.ys[j], self.ys[j + 1]);
        let tx = (x - x0) / (x1 - x0);
        let ty = (y - y0) / (y1 - y0);
        let sx = T::one() - tx;
        let sy = T::one() - ty;

        // Flat offsets of the lower-y corner in the cell's lower (i) and upper
        // (i + 1) x columns; the upper-y corner is the next element.
        let lo = i * self.ny + j;
        let hi = (i + 1) * self.ny + j;
        sx * sy * self.zs[lo]
            + tx * sy * self.zs[hi]
            + sx * ty * self.zs[lo + 1]
            + tx * ty * self.zs[hi + 1]
    }

    /// Returns the x-axis grid coordinates.
    pub fn xs(&self) -> &[T] {
        &self.xs
    }

    /// Returns the y-axis grid coordinates.
    pub fn ys(&self) -> &[T] {
        &self.ys
    }
}