use super::*;
use ndarray::prelude::*;
mod strategies;
#[cfg(test)]
mod tests;
/// Grid coordinates and sample values for N-dimensional interpolation.
///
/// `grid[i]` holds the coordinates along axis `i`, and `values` holds the samples at the
/// grid nodes. [`InterpDataND::validate`] checks that every `grid[i]` is non-empty,
/// non-decreasing, and has length `values.shape()[i]`. A data set whose `values` contains
/// exactly one element is treated as zero-dimensional (see [`InterpDataND::ndim`]).
///
/// `D` is the ndarray storage type, so the same struct can borrow its arrays
/// ([`InterpDataNDViewed`]) or own them ([`InterpDataNDOwned`]). With the `serde` feature,
/// deserialization requires owned storage (`D: DataOwned`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(
feature = "serde",
serde(bound(
serialize = "D::Elem: Serialize",
deserialize = "
D: DataOwned,
D::Elem: Deserialize<'de>,
"
))
)]
pub struct InterpDataND<D>
where
D: Data + RawDataClone + Clone,
D::Elem: PartialEq + Debug,
{
/// One 1-D coordinate array per dimension.
pub grid: Vec<ArrayBase<D, Ix1>>,
/// Sample values; axis `i` is expected to have length `grid[i].len()`.
pub values: ArrayBase<D, IxDyn>,
}
/// [`InterpDataND`] whose arrays borrow their elements (`ViewRepr`).
pub type InterpDataNDViewed<T> = InterpDataND<ViewRepr<T>>;
/// [`InterpDataND`] whose arrays own their elements (`OwnedRepr`).
pub type InterpDataNDOwned<T> = InterpDataND<OwnedRepr<T>>;
impl<D> PartialEq for InterpDataND<D>
where
    D: Data + RawDataClone + Clone,
    D::Elem: PartialEq + Debug,
    ArrayBase<D, Ix1>: PartialEq,
{
    /// Two data sets are equal when their grids and their values are both equal.
    ///
    /// The grids are compared first; `values` is only compared when the grids match.
    fn eq(&self, other: &Self) -> bool {
        if self.grid != other.grid {
            return false;
        }
        self.values == other.values
    }
}
impl<D> InterpDataND<D>
where
    D: Data + RawDataClone + Clone,
    D::Elem: PartialEq + Debug,
{
    /// Builds a data set from per-axis `grid` arrays and an N-dimensional `values` array.
    ///
    /// # Errors
    /// Returns the [`ValidateError`] reported by [`InterpDataND::validate`] if the inputs
    /// are inconsistent.
    pub fn new(
        grid: Vec<ArrayBase<D, Ix1>>,
        values: ArrayBase<D, IxDyn>,
    ) -> Result<Self, ValidateError>
    where
        D::Elem: PartialOrd,
    {
        let candidate = Self { grid, values };
        candidate.validate().map(|()| candidate)
    }

    /// Checks that the grid and values are mutually consistent.
    ///
    /// # Errors
    /// - [`ValidateError::Other`] when the number of grid arrays differs from
    ///   [`InterpDataND::ndim`]. A zero-dimensional data set may still carry a grid
    ///   vector made only of empty arrays.
    /// - [`ValidateError::EmptyGrid`] when an axis has no coordinates.
    /// - [`ValidateError::Monotonicity`] when an axis is not non-decreasing (`<=`
    ///   between neighbours; incomparable values such as NaN also fail).
    /// - [`ValidateError::IncompatibleShapes`] when an axis length differs from the
    ///   matching dimension of `values`.
    pub fn validate(&self) -> Result<(), ValidateError>
    where
        D::Elem: PartialOrd,
    {
        let n = self.ndim();
        let empty_zero_dim_grid = n == 0 && self.grid.iter().all(|axis| axis.is_empty());
        if !empty_zero_dim_grid && self.grid.len() != n {
            return Err(ValidateError::Other(format!(
                "grid length {} does not match dimensionality {}",
                self.grid.len(),
                n,
            )));
        }
        // For n > 0 the guard above ensures grid.len() == n == values.ndim(), so the zip
        // pairs every axis with its dimension; `take(n)` skips the loop entirely when n == 0.
        let shape = self.values.shape();
        for (i, (axis, &len_along_axis)) in self.grid.iter().zip(shape).take(n).enumerate() {
            let axis_len = axis.len();
            if axis_len == 0 {
                return Err(ValidateError::EmptyGrid(i));
            }
            let sorted = axis.windows(2).into_iter().all(|pair| pair[0] <= pair[1]);
            if !sorted {
                return Err(ValidateError::Monotonicity(i));
            }
            if axis_len != len_along_axis {
                return Err(ValidateError::IncompatibleShapes(i));
            }
        }
        Ok(())
    }

    /// Effective dimensionality: `0` when `values` holds exactly one element, otherwise
    /// the number of axes of `values`.
    pub fn ndim(&self) -> usize {
        match self.values.len() {
            1 => 0,
            _ => self.values.ndim(),
        }
    }

    /// Borrows every grid axis and the values as array views.
    pub fn view(&self) -> InterpDataNDViewed<&D::Elem> {
        let grid = self.grid.iter().map(|axis| axis.view()).collect();
        let values = self.values.view();
        InterpDataNDViewed { grid, values }
    }

    /// Converts every array into owned storage (via ndarray's `into_owned`).
    pub fn into_owned(self) -> InterpDataNDOwned<D::Elem>
    where
        D::Elem: Clone,
    {
        let Self { grid, values } = self;
        let grid = grid.into_iter().map(|axis| axis.into_owned()).collect();
        InterpDataNDOwned {
            grid,
            values: values.into_owned(),
        }
    }
}
/// N-dimensional interpolator: an [`InterpDataND`] data set paired with an interpolation
/// strategy `S` and an extrapolation policy for points outside the grid.
///
/// Build it with [`InterpND::new`], which validates the data, checks the extrapolation
/// setting via `check_extrapolate`, and runs `strategy.init` on the data.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(
feature = "serde",
serde(bound(
serialize = "
D::Elem: Serialize,
S: Serialize,
",
deserialize = "
D: DataOwned,
D::Elem: Deserialize<'de>,
S: Deserialize<'de>
"
))
)]
pub struct InterpND<D, S>
where
D: Data + RawDataClone + Clone,
D::Elem: PartialEq + Debug,
S: StrategyND<D> + Clone,
{
/// Grid coordinates and sample values.
pub data: InterpDataND<D>,
/// Strategy used to evaluate points (after any clamping/wrapping).
pub strategy: S,
/// Policy for out-of-bounds points. With the `serde` feature, a missing field
/// deserializes to `Extrapolate::default()`.
#[cfg_attr(feature = "serde", serde(default))]
pub extrapolate: Extrapolate<D::Elem>,
}
/// [`InterpND`] over borrowed (`ViewRepr`) arrays.
pub type InterpNDViewed<T, S> = InterpND<ViewRepr<T>, S>;
/// [`InterpND`] over owned (`OwnedRepr`) arrays.
pub type InterpNDOwned<T, S> = InterpND<OwnedRepr<T>, S>;
// Macro-generated impls; the macros are defined elsewhere.
// NOTE(review): `extrapolate_impl!` presumably provides `check_extrapolate`, which is
// called by `new`, `validate` and `set_extrapolate` below — confirm against the macro.
extrapolate_impl!(InterpND, StrategyND);
// NOTE(review): presumably generates `PartialEq` for `InterpND` — confirm.
partialeq_impl!(InterpND, InterpDataND, StrategyND);
impl<D, S> InterpND<D, S>
where
    D: Data + RawDataClone + Clone,
    D::Elem: PartialOrd + Debug,
    S: StrategyND<D> + Clone,
{
    /// Builds an interpolator from `grid`, `values`, a `strategy` and an `extrapolate`
    /// policy.
    ///
    /// The steps run in this order, stopping at the first failure:
    /// 1. The data are validated through [`InterpDataND::new`].
    /// 2. The extrapolation setting is checked with `check_extrapolate`.
    /// 3. `strategy.init` is run on the data.
    ///
    /// # Errors
    /// Propagates the first [`ValidateError`] produced by any of those steps.
    pub fn new(
        grid: Vec<ArrayBase<D, Ix1>>,
        values: ArrayBase<D, IxDyn>,
        strategy: S,
        extrapolate: Extrapolate<D::Elem>,
    ) -> Result<Self, ValidateError> {
        let mut interpolator = Self {
            data: InterpDataND::new(grid, values)?,
            strategy,
            extrapolate,
        };
        interpolator.check_extrapolate(&interpolator.extrapolate)?;
        // `init` only ever sees data that already passed validation.
        interpolator.strategy.init(&interpolator.data)?;
        Ok(interpolator)
    }

    /// Returns an interpolator that borrows this one's arrays.
    ///
    /// The strategy and extrapolation setting are cloned, since `self` is only borrowed.
    pub fn view(&self) -> InterpNDViewed<&D::Elem, S>
    where
        S: for<'a> StrategyND<ViewRepr<&'a D::Elem>>,
        D::Elem: Clone,
    {
        InterpNDViewed {
            data: self.data.view(),
            strategy: self.strategy.clone(),
            extrapolate: self.extrapolate.clone(),
        }
    }

    /// Converts into an interpolator with owned array storage.
    ///
    /// Array elements are cloned only if ndarray's `into_owned` needs to (e.g. for
    /// views). The strategy and extrapolation setting are moved, since `self` is consumed.
    pub fn into_owned(self) -> InterpNDOwned<D::Elem, S>
    where
        S: StrategyND<OwnedRepr<D::Elem>>,
        D::Elem: Clone,
    {
        let Self {
            data,
            strategy,
            extrapolate,
        } = self;
        InterpNDOwned {
            data: data.into_owned(),
            strategy,
            extrapolate,
        }
    }
}
impl<D, S> Interpolator<D::Elem> for InterpND<D, S>
where
D: Data + RawDataClone + Clone,
D::Elem: Num + Euclid + PartialOrd + Debug + Copy,
S: StrategyND<D> + Clone,
{
#[inline]
fn ndim(&self) -> usize {
self.data.ndim()
}
fn validate(&mut self) -> Result<(), ValidateError> {
self.check_extrapolate(&self.extrapolate)?;
self.data.validate()?;
self.strategy.init(&self.data)?;
Ok(())
}
fn interpolate(&self, point: &[D::Elem]) -> Result<D::Elem, InterpolateError> {
let n = self.ndim();
if point.len() != n {
return Err(InterpolateError::PointLength(n));
}
let mut errors = Vec::new();
for dim in 0..n {
if !(self.data.grid[dim].first().unwrap()..=self.data.grid[dim].last().unwrap())
.contains(&&point[dim])
{
match &self.extrapolate {
Extrapolate::Enable => {}
Extrapolate::Fill(value) => return Ok(*value),
Extrapolate::Clamp => {
let clamped_point: Vec<_> = point
.iter()
.enumerate()
.map(|(dim, pt)| {
*clamp(
pt,
self.data.grid[dim].first().unwrap(),
self.data.grid[dim].last().unwrap(),
)
})
.collect();
return self.strategy.interpolate(&self.data, &clamped_point);
}
Extrapolate::Wrap => {
let wrapped_point: Vec<_> = point
.iter()
.enumerate()
.map(|(dim, pt)| {
wrap(
*pt,
*self.data.grid[dim].first().unwrap(),
*self.data.grid[dim].last().unwrap(),
)
})
.collect();
return self.strategy.interpolate(&self.data, &wrapped_point);
}
Extrapolate::Error => {
errors.push(format!(
"\n point[{dim}] = {:?} is out of bounds for grid[{dim}] = {:?}",
point[dim], self.data.grid[dim],
));
}
};
}
}
if !errors.is_empty() {
return Err(InterpolateError::ExtrapolateError(errors.join("")));
}
self.strategy.interpolate(&self.data, point)
}
fn set_extrapolate(&mut self, extrapolate: Extrapolate<D::Elem>) -> Result<(), ValidateError> {
self.check_extrapolate(&extrapolate)?;
self.extrapolate = extrapolate;
Ok(())
}
}
impl<D> InterpND<D, Box<dyn StrategyND<D>>>
where
    D: Data + RawDataClone + Clone,
    D::Elem: PartialEq + Debug,
{
    /// Swaps in a new boxed strategy.
    ///
    /// After the swap, the extrapolation setting is re-checked against the new strategy
    /// and `init` is run on the current data, mirroring what [`InterpND::new`] does. If
    /// either step fails, the previous strategy is restored, so the interpolator is never
    /// left holding a rejected strategy. This matches `set_extrapolate`, which only
    /// commits after its check passes.
    ///
    /// # Errors
    /// Returns the [`ValidateError`] from `check_extrapolate` or from the strategy's `init`.
    pub fn set_strategy(&mut self, strategy: Box<dyn StrategyND<D>>) -> Result<(), ValidateError> {
        // `check_extrapolate` inspects `self.strategy`, so the new one must be in place
        // before checking; keep the old one around to roll back on failure.
        let previous = std::mem::replace(&mut self.strategy, strategy);
        let outcome = self.check_and_init_strategy();
        if outcome.is_err() {
            self.strategy = previous;
        }
        outcome
    }

    /// Re-checks the extrapolation setting for the current strategy, then initialises it.
    fn check_and_init_strategy(&mut self) -> Result<(), ValidateError> {
        self.check_extrapolate(&self.extrapolate)?;
        self.strategy.init(&self.data)?;
        Ok(())
    }
}
impl<D> InterpND<D, strategy::enums::StrategyNDEnum>
where
    D: Data + RawDataClone + Clone,
    D::Elem: Num + PartialOrd + Copy + Debug,
{
    /// Swaps in a new strategy, given as anything convertible into `StrategyNDEnum`.
    ///
    /// After the swap, the extrapolation setting is re-checked against the new strategy
    /// and `init` is run on the current data, mirroring what [`InterpND::new`] does. If
    /// either step fails, the previous strategy is restored, so the interpolator is never
    /// left holding a rejected strategy. This matches `set_extrapolate`, which only
    /// commits after its check passes.
    ///
    /// # Errors
    /// Returns the [`ValidateError`] from `check_extrapolate` or from the strategy's `init`.
    pub fn set_strategy(
        &mut self,
        strategy: impl Into<strategy::enums::StrategyNDEnum>,
    ) -> Result<(), ValidateError> {
        // `check_extrapolate` inspects `self.strategy`, so the new one must be in place
        // before checking; keep the old one around to roll back on failure.
        let previous = std::mem::replace(&mut self.strategy, strategy.into());
        let outcome = self.check_and_init_strategy();
        if outcome.is_err() {
            self.strategy = previous;
        }
        outcome
    }

    /// Re-checks the extrapolation setting for the current strategy, then initialises it.
    fn check_and_init_strategy(&mut self) -> Result<(), ValidateError> {
        self.check_extrapolate(&self.extrapolate)?;
        self.strategy.init(&self.data)?;
        Ok(())
    }
}