use crate::ext_qtty::{Quantity, Scalar, Unit};
use crate::interp::OutOfRange;
use crate::provenance::Provenance;
use core::cmp::Ordering;
use super::{algo, AxisDirection, TableError};
/// An axis-aligned, half-open rectangle that maps every point inside it to one fixed value.
///
/// Each side is optional: a `None` bound leaves that side unbounded. Lower bounds are
/// inclusive and upper bounds are exclusive, so adjacent regions sharing an edge never both
/// claim a point on that edge.
#[derive(Debug, Clone, Copy)]
pub struct ConstantRegion<S: Scalar> {
    /// Smallest x included in the region (`None` = unbounded below).
    pub x_min_inclusive: Option<S>,
    /// x values must be strictly below this bound (`None` = unbounded above).
    pub x_max_exclusive: Option<S>,
    /// Smallest y included in the region (`None` = unbounded below).
    pub y_min_inclusive: Option<S>,
    /// y values must be strictly below this bound (`None` = unbounded above).
    pub y_max_exclusive: Option<S>,
    /// Value reported for every point inside the region.
    pub value: S,
}

impl<S: Scalar> ConstantRegion<S> {
    /// Region covering every point with `x < x_max_exclusive` and `y < y_max_exclusive`.
    ///
    /// Both lower bounds are left open.
    pub fn lower_corner(x_max_exclusive: S, y_max_exclusive: S, value: S) -> Self {
        Self {
            value,
            x_min_inclusive: None,
            y_min_inclusive: None,
            x_max_exclusive: Some(x_max_exclusive),
            y_max_exclusive: Some(y_max_exclusive),
        }
    }

    /// Returns `true` when `(x, y)` satisfies all four (optional) bounds.
    #[inline]
    fn contains(&self, x: S, y: S) -> bool {
        // A missing bound never excludes a point. An incomparable pair (e.g. a NaN
        // coordinate) always excludes it, because `partial_cmp` yields `None` there.
        let meets_lower = |v: S, bound: Option<S>| match bound {
            None => true,
            Some(b) => matches!(v.partial_cmp(&b), Some(Ordering::Greater | Ordering::Equal)),
        };
        let below_upper = |v: S, bound: Option<S>| match bound {
            None => true,
            Some(b) => matches!(v.partial_cmp(&b), Some(Ordering::Less)),
        };
        meets_lower(x, self.x_min_inclusive)
            && below_upper(x, self.x_max_exclusive)
            && meets_lower(y, self.y_min_inclusive)
            && below_upper(y, self.y_max_exclusive)
    }
}
/// A 2-D lookup table `V(x, y)` with unit-typed axes, evaluated by bilinear interpolation
/// (see [`Grid2D::interp_at`]).
///
/// Axis samples and table values are converted to `f64` once, at construction. The
/// interpolation kernel (`algo::bilinear`) works on `f64` slices, so keeping them in that form
/// makes per-query cost independent of the table size.
#[derive(Debug, Clone)]
pub struct Grid2D<X: Unit, Y: Unit, V: Unit, S: Scalar = f64> {
    /// x-axis samples (length `nx`).
    xs: Vec<f64>,
    /// y-axis samples (length `ny`). For grids built by `from_raw_row_major_y_descending`
    /// these are stored in ascending order (see `y_reflect_offset`).
    ys: Vec<f64>,
    /// Row-major values: row `iy` occupies `table[iy * nx..(iy + 1) * nx]`.
    table: Vec<f64>,
    nx: usize,
    ny: usize,
    dir_x: AxisDirection,
    dir_y: AxisDirection,
    /// When set, queries look up `offset - y` instead of `y` (mirror about the y-axis midpoint).
    y_reflect_offset: Option<S>,
    /// Constant-value overrides, checked in insertion order before interpolating.
    regions: Vec<ConstantRegion<S>>,
    provenance: Provenance,
    _markers: core::marker::PhantomData<(X, Y, V)>,
}

/// Converts scalar samples into the `f64` representation used by `algo`.
fn to_f64_vec<S: Copy + Into<f64>>(values: &[S]) -> Vec<f64> {
    values.iter().copied().map(Into::into).collect()
}

/// Checks that a flat row-major table with `len` entries fits an `nx` x `ny` grid.
///
/// The multiplication is checked so that huge axis lengths cannot wrap around in release
/// builds and spuriously "match" the table length.
fn check_shape(nx: usize, ny: usize, len: usize) -> Result<(), TableError> {
    if nx.checked_mul(ny) == Some(len) {
        return Ok(());
    }
    Err(TableError::ShapeMismatch {
        expected_x: nx,
        expected_y: ny,
        // The table is flat, so only the number of whole `nx`-wide rows can be reported.
        actual_rows: if nx == 0 { 0 } else { len / nx },
        actual_cols: nx,
    })
}

impl<X: Unit, Y: Unit, V: Unit, S: Scalar + Into<f64> + From<f64>> Grid2D<X, Y, V, S> {
    /// Builds a grid from axis samples and a flat row-major table.
    ///
    /// `table[iy * xs.len() + ix]` is the value at `(xs[ix], ys[iy])`. Both axes are checked
    /// with `algo::validate_axis`, and the direction it reports is kept for each axis.
    ///
    /// # Errors
    /// - [`TableError::ShapeMismatch`] if `table.len() != xs.len() * ys.len()`.
    /// - Any error returned by `algo::validate_axis` for the x axis, then the y axis.
    pub fn from_raw_row_major(xs: Vec<S>, ys: Vec<S>, table: Vec<S>) -> Result<Self, TableError> {
        let nx = xs.len();
        let ny = ys.len();
        check_shape(nx, ny, table.len())?;
        let xs = to_f64_vec(&xs);
        let ys = to_f64_vec(&ys);
        let dir_x = algo::validate_axis("x", &xs)?;
        let dir_y = algo::validate_axis("y", &ys)?;
        Ok(Self {
            xs,
            ys,
            table: to_f64_vec(&table),
            nx,
            ny,
            dir_x,
            dir_y,
            y_reflect_offset: None,
            regions: Vec::new(),
            provenance: Provenance::default(),
            _markers: core::marker::PhantomData,
        })
    }

    /// Builds a grid whose y axis is strictly descending with an exactly uniform step.
    ///
    /// The table rows are NOT reordered: row `iy` still belongs to `ys_desc[iy]`. Instead the
    /// y samples are stored ascending, and every query looks up
    /// `ys_desc[0] + ys_desc[ny - 1] - y`. For a uniform axis this maps `ys_desc[iy]` onto
    /// ascending index `iy`, which selects the matching row.
    ///
    /// # Errors
    /// - [`TableError::ShapeMismatch`] if `table_desc.len() != xs.len() * ys_desc.len()`.
    /// - [`TableError::TooFewSamples`] if `ys_desc` has fewer than two samples.
    /// - [`TableError::NotMonotonic`] if the first step is not negative, or any later step
    ///   differs bit-for-bit from the first.
    /// - Any error returned by `algo::validate_axis` for the x axis.
    pub fn from_raw_row_major_y_descending(
        xs: Vec<S>,
        ys_desc: Vec<S>,
        table_desc: Vec<S>,
    ) -> Result<Self, TableError> {
        let nx = xs.len();
        let ny = ys_desc.len();
        check_shape(nx, ny, table_desc.len())?;
        if ny < 2 {
            return Err(TableError::TooFewSamples { axis: "y", len: ny });
        }
        let mut ys = to_f64_vec(&ys_desc);
        let step = ys[1] - ys[0];
        // `matches!` rather than `step >= 0.0` so that a NaN step is rejected as well.
        if !matches!(step.partial_cmp(&0.0), Some(Ordering::Less)) {
            return Err(TableError::NotMonotonic {
                axis: "y",
                at_index: 1,
            });
        }
        // The reflection trick relies on uniform spacing, so steps are compared bit-exactly.
        if let Some(at_index) =
            (2..ny).find(|&i| (ys[i] - ys[i - 1]).to_bits() != step.to_bits())
        {
            return Err(TableError::NotMonotonic { axis: "y", at_index });
        }
        let xs = to_f64_vec(&xs);
        let dir_x = algo::validate_axis("x", &xs)?;
        // Computed in `S` (not f64) so query-time reflection matches `off - y` evaluated in `S`.
        let off = ys_desc[0] + ys_desc[ny - 1];
        ys.reverse();
        Ok(Self {
            xs,
            ys,
            table: to_f64_vec(&table_desc),
            nx,
            ny,
            dir_x,
            dir_y: AxisDirection::Ascending,
            y_reflect_offset: Some(off),
            regions: Vec::new(),
            provenance: Provenance::default(),
            _markers: core::marker::PhantomData,
        })
    }

    /// Appends a constant-value override region.
    ///
    /// Regions are checked in insertion order and the first one containing the query point
    /// wins. They use the caller's coordinates (before any internal y reflection).
    pub fn with_constant_region(mut self, region: ConstantRegion<S>) -> Self {
        self.regions.push(region);
        self
    }

    /// Replaces the grid's provenance metadata.
    pub fn with_provenance(mut self, provenance: Provenance) -> Self {
        self.provenance = provenance;
        self
    }

    /// Provenance metadata attached to this grid.
    pub fn provenance(&self) -> &Provenance {
        &self.provenance
    }

    /// Number of x samples.
    pub fn nx(&self) -> usize {
        self.nx
    }

    /// Number of y samples.
    pub fn ny(&self) -> usize {
        self.ny
    }

    /// Evaluates the grid at `(x, y)`.
    ///
    /// If any constant region contains the point (first match in insertion order), its value
    /// is returned directly. Otherwise the table is bilinearly interpolated by
    /// `algo::bilinear`, using `oor_x` / `oor_y` as the out-of-range policies for each axis.
    ///
    /// # Errors
    /// Propagates any error from `algo::bilinear`. Presumably this happens for out-of-range
    /// coordinates under an erroring [`OutOfRange`] policy; confirm against `algo`.
    pub fn interp_at(
        &self,
        x: Quantity<X, S>,
        y: Quantity<Y, S>,
        oor_x: OutOfRange,
        oor_y: OutOfRange,
    ) -> Result<Quantity<V, S>, TableError> {
        let xv = x.value();
        let yv = y.value();
        if let Some(region) = self.regions.iter().find(|r| r.contains(xv, yv)) {
            return Ok(Quantity::<V, S>::new(region.value));
        }
        let y_internal = match self.y_reflect_offset {
            Some(off) => off - yv,
            None => yv,
        };
        // Only row views are built per query (`ny` slice headers). The data was converted
        // to f64 once, at construction.
        let rows: Vec<&[f64]> = (0..self.ny)
            .map(|iy| &self.table[iy * self.nx..(iy + 1) * self.nx])
            .collect();
        let v = algo::bilinear(
            &self.xs,
            &self.ys,
            &rows,
            xv.into(),
            y_internal.into(),
            oor_x,
            oor_y,
            self.dir_x,
            self.dir_y,
        )?;
        Ok(Quantity::<V, S>::new(S::from(v)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ext_qtty::length::{Meter, Nanometer};

    #[test]
    fn typed_corners_recover_table_values() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0];
        let table = vec![10.0, 20.0, 30.0, 100.0, 200.0, 300.0];
        let g: Grid2D<Nanometer, Meter, Meter> = Grid2D::from_raw_row_major(xs, ys, table).unwrap();
        let v = g
            .interp_at(
                Quantity::<Nanometer>::new(1.0),
                Quantity::<Meter>::new(1.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_eq!(v.value(), 200.0);
    }

    #[test]
    fn accessors_report_grid_dimensions() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0];
        let table = vec![0.0; 6];
        let g: Grid2D<Nanometer, Meter, Meter> = Grid2D::from_raw_row_major(xs, ys, table).unwrap();
        assert_eq!(g.nx(), 3);
        assert_eq!(g.ny(), 2);
    }

    #[test]
    fn rejects_shape_mismatch() {
        let xs = vec![0.0, 1.0];
        let ys = vec![0.0, 1.0];
        // 2 x 2 grid needs four values.
        let table = vec![1.0, 2.0, 3.0];
        let r: Result<Grid2D<Nanometer, Meter, Meter>, _> =
            Grid2D::from_raw_row_major(xs, ys, table);
        assert!(matches!(r, Err(TableError::ShapeMismatch { .. })));
    }

    #[test]
    fn accepts_descending_y_axis() {
        let xs = vec![0.0_f64, 1.0, 2.0];
        let ys = vec![10.0_f64, 5.0, 0.0];
        let table = vec![1.0_f64; 9];
        let g: Grid2D<Nanometer, Meter, Meter> = Grid2D::from_raw_row_major(xs, ys, table).unwrap();
        assert_eq!(g.dir_y, AxisDirection::Descending);
        assert_eq!(g.dir_x, AxisDirection::Ascending);
    }

    #[test]
    fn grid2d_descending_y_nsb_parity_bit_for_bit() {
        let xs: Vec<f64> = vec![0.0, 5.0, 10.0, 15.0, 20.0];
        let ys: Vec<f64> = vec![20.0, 15.0, 10.0, 5.0, 0.0];
        let mut table_flat = vec![0.0_f64; 25];
        for iy in 0..5usize {
            for ix in 0..5usize {
                table_flat[iy * 5 + ix] = (iy + 1) as f64 * 100.0 + (ix + 1) as f64 * 10.0;
            }
        }
        let g: Grid2D<Nanometer, Meter, Meter> =
            Grid2D::from_raw_row_major(xs.clone(), ys.clone(), table_flat.clone()).unwrap();
        let xq = 7.5_f64;
        let yq = 7.5_f64;
        let got = g
            .interp_at(
                Quantity::<Nanometer>::new(xq),
                Quantity::<Meter>::new(yq),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        let ix0 = 1usize;
        let iy0 = 2usize;
        let bt = (xq - xs[ix0]) / (xs[ix0 + 1] - xs[ix0]);
        let lt = (yq - ys[iy0]) / (ys[iy0 + 1] - ys[iy0]);
        let expected = super::super::algo::bilinear_unit(
            table_flat[iy0 * 5 + ix0],
            table_flat[iy0 * 5 + ix0 + 1],
            table_flat[(iy0 + 1) * 5 + ix0],
            table_flat[(iy0 + 1) * 5 + ix0 + 1],
            bt,
            lt,
        );
        assert_eq!(
            got.value().to_bits(),
            expected.to_bits(),
            "Grid2D(descending y) must match bilinear_unit bit-for-bit: got={}, expected={}",
            got.value(),
            expected
        );
    }

    #[test]
    fn constant_region_short_circuits() {
        let xs = vec![0.0_f64, 10.0, 20.0];
        let ys = vec![0.0_f64, 10.0, 20.0];
        let table = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let g: Grid2D<Nanometer, Meter, Meter> = Grid2D::from_raw_row_major(xs, ys, table)
            .unwrap()
            .with_constant_region(ConstantRegion::lower_corner(5.0, 5.0, 999.0));
        let v_in = g
            .interp_at(
                Quantity::<Nanometer>::new(2.0),
                Quantity::<Meter>::new(2.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_eq!(v_in.value(), 999.0);
        // The upper bound is exclusive, so x = 5 falls outside the region.
        let v_b = g
            .interp_at(
                Quantity::<Nanometer>::new(5.0),
                Quantity::<Meter>::new(2.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_ne!(v_b.value(), 999.0);
        // Centre of the cell with corners 5, 6, 8, 9.
        let v_out = g
            .interp_at(
                Quantity::<Nanometer>::new(15.0),
                Quantity::<Meter>::new(15.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_eq!(v_out.value(), 7.0);
    }

    #[test]
    fn constant_regions_first_match_wins() {
        let xs = vec![0.0_f64, 10.0];
        let ys = vec![0.0_f64, 10.0];
        let table = vec![1.0, 2.0, 3.0, 4.0];
        let g: Grid2D<Nanometer, Meter, Meter> = Grid2D::from_raw_row_major(xs, ys, table)
            .unwrap()
            .with_constant_region(ConstantRegion::lower_corner(5.0, 5.0, 100.0))
            .with_constant_region(ConstantRegion::lower_corner(8.0, 8.0, 200.0));
        let v = g
            .interp_at(
                Quantity::<Nanometer>::new(2.0),
                Quantity::<Meter>::new(2.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_eq!(v.value(), 100.0);
        let v = g
            .interp_at(
                Quantity::<Nanometer>::new(6.0),
                Quantity::<Meter>::new(6.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_eq!(v.value(), 200.0);
    }

    #[test]
    fn constant_region_uses_caller_y_on_descending_grid() {
        let xs = vec![0.0_f64, 10.0];
        let ys_desc = vec![10.0_f64, 0.0];
        let table = vec![1.0, 2.0, 3.0, 4.0];
        let g: Grid2D<Nanometer, Meter, Meter> =
            Grid2D::from_raw_row_major_y_descending(xs, ys_desc, table)
                .unwrap()
                .with_constant_region(ConstantRegion::lower_corner(5.0, 5.0, 999.0));
        let inside = g
            .interp_at(
                Quantity::<Nanometer>::new(2.0),
                Quantity::<Meter>::new(2.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_eq!(inside.value(), 999.0);
        // Caller y = 8 is outside the region even though its reflected internal y (2) is not.
        let outside = g
            .interp_at(
                Quantity::<Nanometer>::new(2.0),
                Quantity::<Meter>::new(8.0),
                OutOfRange::Error,
                OutOfRange::Error,
            )
            .unwrap();
        assert_ne!(outside.value(), 999.0);
        assert!((outside.value() - 1.6).abs() < 1e-9);
    }

    #[test]
    fn y_descending_rejects_non_uniform() {
        let xs = vec![0.0_f64, 1.0];
        // Steps are -3 then -7.
        let ys = vec![10.0_f64, 7.0, 0.0];
        let table = vec![1.0_f64; 6];
        let r: Result<Grid2D<Nanometer, Meter, Meter>, _> =
            Grid2D::from_raw_row_major_y_descending(xs, ys, table);
        assert!(matches!(r, Err(TableError::NotMonotonic { .. })));
    }

    #[test]
    fn y_descending_rejects_ascending() {
        let xs = vec![0.0_f64, 1.0];
        let ys = vec![0.0_f64, 5.0, 10.0];
        let table = vec![1.0_f64; 6];
        let r: Result<Grid2D<Nanometer, Meter, Meter>, _> =
            Grid2D::from_raw_row_major_y_descending(xs, ys, table);
        assert!(matches!(r, Err(TableError::NotMonotonic { .. })));
    }

    #[test]
    fn y_descending_rejects_too_few_samples() {
        let xs = vec![0.0_f64, 1.0];
        let ys = vec![5.0_f64];
        let table = vec![1.0_f64; 2];
        let r: Result<Grid2D<Nanometer, Meter, Meter>, _> =
            Grid2D::from_raw_row_major_y_descending(xs, ys, table);
        assert!(matches!(
            r,
            Err(TableError::TooFewSamples { axis: "y", len: 1 })
        ));
    }

    #[test]
    fn y_descending_rejects_shape_mismatch() {
        let xs = vec![0.0_f64, 1.0];
        let ys = vec![10.0_f64, 5.0, 0.0];
        let table = vec![1.0_f64; 5];
        let r: Result<Grid2D<Nanometer, Meter, Meter>, _> =
            Grid2D::from_raw_row_major_y_descending(xs, ys, table);
        assert!(matches!(r, Err(TableError::ShapeMismatch { .. })));
    }

    #[test]
    fn y_descending_round_trip() {
        let xs = vec![0.0_f64, 5.0, 10.0];
        let ys_desc = vec![20.0_f64, 15.0, 10.0, 5.0, 0.0];
        let mut table = Vec::with_capacity(15);
        for iy in 0..5 {
            for ix in 0..3 {
                table.push((iy * 10 + ix) as f64);
            }
        }
        let g: Grid2D<Nanometer, Meter, Meter> =
            Grid2D::from_raw_row_major_y_descending(xs.clone(), ys_desc.clone(), table.clone())
                .unwrap();
        for (iy, &yv) in ys_desc.iter().enumerate() {
            for ix in 0..3 {
                let v = g
                    .interp_at(
                        Quantity::<Nanometer>::new(xs[ix]),
                        Quantity::<Meter>::new(yv),
                        OutOfRange::Error,
                        OutOfRange::Error,
                    )
                    .unwrap();
                assert_eq!(v.value(), table[iy * 3 + ix]);
            }
        }
    }

    #[test]
    fn y_descending_matches_legacy_bit_for_bit() {
        let xs: Vec<f64> = (0..=18).map(|i| i as f64 * 5.0).collect();
        let ys_desc: Vec<f64> = (0..37).map(|i| 180.0 - i as f64 * 5.0).collect();
        let mut table = Vec::with_capacity(37 * 19);
        for iy in 0..37 {
            for ix in 0..19 {
                table.push((iy as f64) * 100.0 + (ix as f64) * 7.5);
            }
        }
        let g: Grid2D<Nanometer, Meter, Meter> =
            Grid2D::from_raw_row_major_y_descending(xs.clone(), ys_desc.clone(), table.clone())
                .unwrap();
        let dl = 27.3_f64;
        let beta = 12.7_f64;
        let b0 = (beta / 5.0).floor() as usize;
        let bt = (beta - 5.0 * b0 as f64) / 5.0;
        let l0_idx = ((180.0 - dl.ceil()) / 5.0).floor() as isize;
        let l0 = l0_idx.clamp(0, 35) as usize;
        let lt = (180.0 - dl - 5.0 * l0 as f64) / 5.0;
        let row_lo = l0;
        let row_hi = l0 + 1;
        let expected = super::super::algo::bilinear_unit(
            table[row_lo * 19 + b0],
            table[row_lo * 19 + b0 + 1],
            table[row_hi * 19 + b0],
            table[row_hi * 19 + b0 + 1],
            bt,
            lt,
        );
        let got = g
            .interp_at(
                Quantity::<Nanometer>::new(beta),
                Quantity::<Meter>::new(dl),
                OutOfRange::ClampToEndpoints,
                OutOfRange::ClampToEndpoints,
            )
            .unwrap();
        assert_eq!(
            got.value().to_bits(),
            expected.to_bits(),
            "y-descending Grid2D must match legacy bit-for-bit: got={}, expected={}",
            got.value(),
            expected
        );
    }
}