use core::marker::PhantomData;
use alloc::boxed::Box;
use qtty::{Quantity, Unit};
use crate::data::Provenance;
use crate::grid::algo;
use crate::grid::{GridError, OutOfRange};
/// A 3-D sample grid over axes `x`, `y`, `z`, interpolated trilinearly.
///
/// `X`, `Y` and `Z` are the units of the three axes and `V` is the unit of the
/// stored values; they only exist at the type level (see `_phantom`).
#[derive(Debug, Clone)]
pub struct Grid3D<X, Y, Z, V>
where
    X: Unit,
    Y: Unit,
    Z: Unit,
    V: Unit,
{
    /// Coordinates along the `x` axis.
    x_axis: crate::grid::Axis,
    /// Coordinates along the `y` axis.
    y_axis: crate::grid::Axis,
    /// Coordinates along the `z` axis.
    z_axis: crate::grid::Axis,
    /// Flat sample buffer: `z` outermost, then `y`, with `x` varying fastest.
    values: Box<[f64]>,
    /// Policy applied by the default query methods when a point lies outside
    /// the axis bounds.
    out_of_range: OutOfRange,
    /// Optional provenance metadata, attached via `with_provenance`.
    provenance: Option<Provenance>,
    /// Carries the unit parameters without storing any values of those types.
    _phantom: PhantomData<(X, Y, Z, V)>,
}

/// Result of locating a query point: `(ix, tx, iy, ty, iz, tz)`, i.e. a cell
/// index and an interpolation weight for each of the `x`, `y`, `z` axes.
type Locate3 = (usize, f64, usize, f64, usize, f64);
impl<X: Unit, Y: Unit, Z: Unit, V: Unit> Grid3D<X, Y, Z, V> {
pub fn from_raw_row_major(
xs: &[f64],
ys: &[f64],
zs: &[f64],
values: &[f64],
oor: OutOfRange,
) -> Result<Self, GridError> {
let x_axis = crate::grid::Axis::NonUniform(xs.to_vec().into_boxed_slice());
x_axis.validate_for_axis("x")?;
let y_axis = crate::grid::Axis::NonUniform(ys.to_vec().into_boxed_slice());
y_axis.validate_for_axis("y")?;
let z_axis = crate::grid::Axis::NonUniform(zs.to_vec().into_boxed_slice());
z_axis.validate_for_axis("z")?;
let expected = x_axis.len() * y_axis.len() * z_axis.len();
if expected != values.len() {
return Err(GridError::ShapeMismatch {
expected,
got: values.len(),
});
}
Ok(Self {
x_axis,
y_axis,
z_axis,
values: values.to_vec().into_boxed_slice(),
out_of_range: oor,
provenance: None,
_phantom: PhantomData,
})
}
#[must_use]
pub fn interp_at(&self, x: Quantity<X>, y: Quantity<Y>, z: Quantity<Z>) -> Quantity<V> {
match self.locate_query(x.value(), y.value(), z.value(), false) {
Ok(Some((ix, tx, iy, ty, iz, tz))) => {
Quantity::new(self.interpolate(ix, tx, iy, ty, iz, tz))
}
Ok(None) => Quantity::zero(),
Err(_) => {
let (ix, tx) = self.x_axis.locate(x.value());
let (iy, ty) = self.y_axis.locate(y.value());
let (iz, tz) = self.z_axis.locate(z.value());
Quantity::new(self.interpolate(ix, tx, iy, ty, iz, tz))
}
}
}
pub fn try_interp_at(
&self,
x: Quantity<X>,
y: Quantity<Y>,
z: Quantity<Z>,
) -> Result<Quantity<V>, GridError> {
match self.locate_query(x.value(), y.value(), z.value(), true)? {
Some((ix, tx, iy, ty, iz, tz)) => {
Ok(Quantity::new(self.interpolate(ix, tx, iy, ty, iz, tz)))
}
None => Ok(Quantity::zero()),
}
}
pub fn interp_at_with(
&self,
x: Quantity<X>,
y: Quantity<Y>,
z: Quantity<Z>,
oor_x: OutOfRange,
oor_y: OutOfRange,
oor_z: OutOfRange,
) -> Result<Quantity<V>, GridError> {
let xv = x.value();
let yv = y.value();
let zv = z.value();
let (x_lo, x_hi) = self.x_axis.bounds();
let (y_lo, y_hi) = self.y_axis.bounds();
let (z_lo, z_hi) = self.z_axis.bounds();
if !algo::check_oor(xv, x_lo, x_hi, oor_x, "x")? {
return Ok(Quantity::zero());
}
if !algo::check_oor(yv, y_lo, y_hi, oor_y, "y")? {
return Ok(Quantity::zero());
}
if !algo::check_oor(zv, z_lo, z_hi, oor_z, "z")? {
return Ok(Quantity::zero());
}
let (ix, tx) = self.x_axis.locate(xv);
let (iy, ty) = self.y_axis.locate(yv);
let (iz, tz) = self.z_axis.locate(zv);
Ok(Quantity::new(self.interpolate(ix, tx, iy, ty, iz, tz)))
}
#[must_use]
pub fn nx(&self) -> usize {
self.x_axis.len()
}
#[must_use]
pub fn ny(&self) -> usize {
self.y_axis.len()
}
#[must_use]
pub fn nz(&self) -> usize {
self.z_axis.len()
}
#[must_use]
pub fn len(&self) -> usize {
self.values.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.values.is_empty()
}
#[must_use]
pub fn with_provenance(mut self, provenance: Provenance) -> Self {
self.provenance = Some(provenance);
self
}
#[must_use]
pub fn provenance(&self) -> Option<&Provenance> {
self.provenance.as_ref()
}
#[must_use]
pub fn x_bounds(&self) -> (Quantity<X>, Quantity<X>) {
let (lo, hi) = self.x_axis.bounds();
(Quantity::<X>::new(lo), Quantity::<X>::new(hi))
}
#[must_use]
pub fn y_bounds(&self) -> (Quantity<Y>, Quantity<Y>) {
let (lo, hi) = self.y_axis.bounds();
(Quantity::<Y>::new(lo), Quantity::<Y>::new(hi))
}
#[must_use]
pub fn z_bounds(&self) -> (Quantity<Z>, Quantity<Z>) {
let (lo, hi) = self.z_axis.bounds();
(Quantity::<Z>::new(lo), Quantity::<Z>::new(hi))
}
#[must_use]
#[allow(clippy::type_complexity)]
pub fn domain(
&self,
) -> (
(Quantity<X>, Quantity<X>),
(Quantity<Y>, Quantity<Y>),
(Quantity<Z>, Quantity<Z>),
) {
(self.x_bounds(), self.y_bounds(), self.z_bounds())
}
fn interpolate(&self, ix: usize, tx: f64, iy: usize, ty: f64, iz: usize, tz: f64) -> f64 {
let nx = self.nx();
let ny = self.ny();
let idx = |iz: usize, iy: usize, ix: usize| (iz * ny + iy) * nx + ix;
let v000 = self.values[idx(iz, iy, ix)];
let v100 = self.values[idx(iz, iy, ix + 1)];
let v010 = self.values[idx(iz, iy + 1, ix)];
let v110 = self.values[idx(iz, iy + 1, ix + 1)];
let v001 = self.values[idx(iz + 1, iy, ix)];
let v101 = self.values[idx(iz + 1, iy, ix + 1)];
let v011 = self.values[idx(iz + 1, iy + 1, ix)];
let v111 = self.values[idx(iz + 1, iy + 1, ix + 1)];
algo::trilinear_unit(v000, v100, v010, v110, v001, v101, v011, v111, tx, ty, tz)
}
fn locate_query(
&self,
x: f64,
y: f64,
z: f64,
strict_error: bool,
) -> Result<Option<Locate3>, GridError> {
let x_in = self.x_axis.contains(x);
let y_in = self.y_axis.contains(y);
let z_in = self.z_axis.contains(z);
if x_in && y_in && z_in {
let (ix, tx) = self.x_axis.locate(x);
let (iy, ty) = self.y_axis.locate(y);
let (iz, tz) = self.z_axis.locate(z);
return Ok(Some((ix, tx, iy, ty, iz, tz)));
}
match self.out_of_range {
OutOfRange::ClampToEndpoints => {
let (ix, tx) = self.x_axis.locate(x);
let (iy, ty) = self.y_axis.locate(y);
let (iz, tz) = self.z_axis.locate(z);
Ok(Some((ix, tx, iy, ty, iz, tz)))
}
OutOfRange::Zero => Ok(None),
OutOfRange::Error if strict_error => {
if !x_in {
let (lo, hi) = self.x_axis.bounds();
return Err(GridError::OutOfRange {
axis: "x",
value: x,
lo,
hi,
});
}
if !y_in {
let (lo, hi) = self.y_axis.bounds();
return Err(GridError::OutOfRange {
axis: "y",
value: y,
lo,
hi,
});
}
let (lo, hi) = self.z_axis.bounds();
Err(GridError::OutOfRange {
axis: "z",
value: z,
lo,
hi,
})
}
OutOfRange::Error => {
let (ix, tx) = self.x_axis.locate(x);
let (iy, ty) = self.y_axis.locate(y);
let (iz, tz) = self.z_axis.locate(z);
Ok(Some((ix, tx, iy, ty, iz, tz)))
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use qtty::unit::{Kilometer, Nanometer, Radian, Ratio};

    type TestGrid = Grid3D<Nanometer, Radian, Kilometer, Ratio>;

    /// A single 2x2x2 cell whose samples are 0, 2, ..., 14 in storage order.
    fn unit_cell() -> TestGrid {
        TestGrid::from_raw_row_major(
            &[400.0, 500.0],
            &[0.0, 1.0],
            &[0.0, 2.0],
            &[0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0],
            OutOfRange::ClampToEndpoints,
        )
        .unwrap()
    }

    #[test]
    fn trilinear_midpoint_works() {
        let centre = unit_cell().interp_at(
            Quantity::<Nanometer>::new(450.0),
            Quantity::<Radian>::new(0.5),
            Quantity::<Kilometer>::new(1.0),
        );
        // The cell centre is the mean of all eight corner samples.
        assert_eq!(centre.value(), 7.0);
    }

    #[test]
    fn storage_is_z_outer_y_mid_x_inner() {
        let grid = unit_cell();
        // Stepping along x, y, z should jump by 1, 2 and 4 storage slots.
        let corners = [
            ((400.0, 0.0, 0.0), 0.0),
            ((500.0, 0.0, 0.0), 2.0),
            ((400.0, 1.0, 0.0), 4.0),
            ((400.0, 0.0, 2.0), 8.0),
        ];
        for &((x, y, z), expected) in &corners {
            let got = grid
                .interp_at(Quantity::new(x), Quantity::new(y), Quantity::new(z))
                .value();
            assert_eq!(got, expected);
        }
    }
}