use crate::ext_qtty::{Quantity, Scalar, Unit};
use crate::interp::OutOfRange;
use crate::provenance::Provenance;
use super::{algo, AxisDirection, TableError};
/// A one-dimensional sampled table pairing axis coordinates in unit `X` with
/// sample values in unit `V`, both stored as raw scalars of type `S`.
///
/// The fields are private; [`Grid1D::from_raw`] is the constructor, and it
/// guarantees `xs` and `vs` have the same length.
#[derive(Debug, Clone)]
pub struct Grid1D<X: Unit, V: Unit, S: Scalar = f64> {
    /// Axis coordinates, interpreted in units of `X`.
    xs: Vec<S>,
    /// Sample values paired index-wise with `xs`, interpreted in units of `V`.
    vs: Vec<S>,
    /// Direction of `xs` as reported by `algo::validate_axis` at construction;
    /// forwarded to the interpolator on every query.
    dir_x: AxisDirection,
    /// Provenance record; `Provenance::default()` unless replaced via
    /// `with_provenance`.
    provenance: Provenance,
    /// Ties the unit types `X` and `V` to the grid at the type level only.
    _markers: core::marker::PhantomData<(X, V)>,
}
impl<X: Unit, V: Unit, S: Scalar + Into<f64> + From<f64>> Grid1D<X, V, S> {
    /// Builds a grid from raw axis coordinates `xs` (in units of `X`) and
    /// sample values `vs` (in units of `V`), paired index-wise.
    ///
    /// The axis is checked with `algo::validate_axis`. The direction it
    /// reports is stored, so both ascending and descending axes can be
    /// interpolated later. The grid starts with `Provenance::default()`; use
    /// [`Grid1D::with_provenance`] to attach a different record.
    ///
    /// # Errors
    ///
    /// - [`TableError::ShapeMismatch`] when `xs.len() != vs.len()`. `vs` is
    ///   reported as a single row of `vs.len()` columns.
    /// - Any error returned by `algo::validate_axis`, e.g.
    ///   [`TableError::NotMonotonic`] for an axis containing a repeated value.
    pub fn from_raw(xs: Vec<S>, vs: Vec<S>) -> Result<Self, TableError> {
        if xs.len() != vs.len() {
            return Err(TableError::ShapeMismatch {
                expected_x: xs.len(),
                expected_y: 1,
                actual_rows: 1,
                actual_cols: vs.len(),
            });
        }
        let dir_x = algo::validate_axis("x", &Self::to_f64_vec(&xs))?;
        Ok(Self {
            xs,
            vs,
            dir_x,
            provenance: Provenance::default(),
            _markers: core::marker::PhantomData,
        })
    }

    /// Replaces the grid's provenance record and returns the updated grid.
    #[must_use]
    pub fn with_provenance(mut self, provenance: Provenance) -> Self {
        self.provenance = provenance;
        self
    }

    /// Returns the provenance record attached to this grid.
    pub fn provenance(&self) -> &Provenance {
        &self.provenance
    }

    /// Returns the number of samples (`xs` and `vs` always have equal length).
    pub fn len(&self) -> usize {
        self.xs.len()
    }

    /// Returns `true` if the grid holds no samples.
    pub fn is_empty(&self) -> bool {
        self.xs.is_empty()
    }

    /// Linearly interpolates the sampled value at `x`, returned in units of `V`.
    ///
    /// `oor` selects how queries outside the axis range are handled. It is
    /// forwarded unchanged to `algo::linear_1d`, together with the axis
    /// direction detected at construction.
    ///
    /// # Errors
    ///
    /// Propagates any [`TableError`] returned by `algo::linear_1d`.
    pub fn interp_at(
        &self,
        x: Quantity<X, S>,
        oor: OutOfRange,
    ) -> Result<Quantity<V, S>, TableError> {
        // NOTE(review): both axes are widened to f64 on every call (two O(n)
        // allocations per query). If this is hot, consider caching f64 copies
        // on the struct at construction time.
        let xs_f64 = Self::to_f64_vec(&self.xs);
        let vs_f64 = Self::to_f64_vec(&self.vs);
        let v = algo::linear_1d(&xs_f64, &vs_f64, x.value().into(), oor, self.dir_x)?;
        Ok(Quantity::<V, S>::new(S::from(v)))
    }

    /// Widens stored scalars to `f64`, the representation the `algo` routines take.
    fn to_f64_vec(values: &[S]) -> Vec<f64> {
        values.iter().copied().map(Into::into).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ext_qtty::length::{Meter, Nanometer};

    /// Three-sample grid on a descending axis, shared by the descending-axis tests.
    fn descending_grid() -> Grid1D<Nanometer, Meter> {
        Grid1D::from_raw(vec![10.0_f64, 5.0, 0.0], vec![1.0_f64, 2.0, 3.0]).unwrap()
    }

    #[test]
    fn typed_round_trip() {
        let xs = vec![400.0, 500.0, 600.0];
        let vs = vec![1.0, 2.0, 3.0];
        let g: Grid1D<Nanometer, Meter> = Grid1D::from_raw(xs, vs).unwrap();
        assert_eq!(g.len(), 3);
        assert!(!g.is_empty());
        let q = Quantity::<Nanometer>::new(450.0);
        let v = g.interp_at(q, OutOfRange::Error).unwrap();
        assert_eq!(v.value(), 1.5);
    }

    #[test]
    fn rejects_length_mismatch() {
        let xs = vec![1.0_f64, 2.0, 3.0];
        let vs = vec![1.0_f64, 2.0];
        let g: Result<Grid1D<Nanometer, Meter>, _> = Grid1D::from_raw(xs, vs);
        assert!(matches!(
            g,
            Err(TableError::ShapeMismatch {
                expected_x: 3,
                actual_cols: 2,
                ..
            })
        ));
    }

    #[test]
    fn rejects_non_monotonic() {
        let xs = vec![1.0_f64, 2.0, 2.0];
        let vs = vec![1.0_f64, 2.0, 3.0];
        let g: Result<Grid1D<Nanometer, Meter>, _> = Grid1D::from_raw(xs, vs);
        assert!(matches!(g, Err(TableError::NotMonotonic { .. })));
    }

    #[test]
    fn descending_axis_interpolates_correctly() {
        let v = descending_grid()
            .interp_at(Quantity::<Nanometer>::new(7.5), OutOfRange::Error)
            .unwrap();
        assert_eq!(v.value(), 1.5);
    }

    #[test]
    fn descending_axis_clamp_above_max() {
        let v = descending_grid()
            .interp_at(
                Quantity::<Nanometer>::new(15.0),
                OutOfRange::ClampToEndpoints,
            )
            .unwrap();
        assert_eq!(v.value(), 1.0);
    }
}