ndarray_interp/interp2d/
mod.rs

//! A collection of structs and traits to interpolate data along the first two axes
2//!
3//! # Interpolator
4//!  - [`Interp2D`] The interpolator used with any strategy
5//!  - [`Interp2DBuilder`] Configure the interpolator
6//!
7//! # Traits
8//!  - [`Interp2DStrategy`] The trait used to specialize [`Interp2D`] with the correct strategy
9//!  - [`Interp2DStrategyBuilder`] The trait used to specialize [`Interp2DBuilder`] to initialize the correct strategy
10//!
11//! # Strategies
12//!  - [`Bilinear`] Linear interpolation strategy
13
14use std::{any::TypeId, fmt::Debug, ops::Sub};
15
16use ndarray::{
17    Array, Array1, ArrayBase, ArrayView, ArrayViewMut, ArrayViewMut1, Axis, AxisDescription, Data,
18    DimAdd, Dimension, IntoDimension, Ix1, Ix2, OwnedRepr, RemoveAxis, Slice, Zip,
19};
20use num_traits::{cast, Num, NumCast};
21
22use crate::{
23    cast_unchecked,
24    dim_extensions::DimExtension,
25    vector_extensions::{Monotonic, VectorExtensions},
26    BuilderError, InterpolateError,
27};
28
29mod aliases;
30mod strategies;
31pub use aliases::*;
32pub use strategies::{Bilinear, Interp2DStrategy, Interp2DStrategyBuilder};
33
/// Two dimensional interpolator
///
/// Bundles the `x` and `y` axes, the `data` to interpolate and the
/// interpolation `strategy`. Validated instances are created with
/// [`Interp2D::builder`].
#[derive(Debug)]
pub struct Interp2D<Sd, Sx, Sy, D, Strat>
where
    Sd: Data,
    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
    Sx: Data<Elem = Sd::Elem>,
    Sy: Data<Elem = Sd::Elem>,
    D: Dimension,
{
    /// x-axis coordinates, belonging to axis 0 of `data`
    x: ArrayBase<Sx, Ix1>,
    /// y-axis coordinates, belonging to axis 1 of `data`
    y: ArrayBase<Sy, Ix1>,
    /// the data that is interpolated along its first two axes
    data: ArrayBase<Sd, D>,
    /// the strategy that performs the actual interpolation
    strategy: Strat,
}
49
/// Create and configure a [Interp2D] interpolator.
///
/// The fields mirror those of [Interp2D]; they are only validated when
/// `build` is called.
#[derive(Debug)]
pub struct Interp2DBuilder<Sd, Sx, Sy, D, Strat>
where
    Sd: Data,
    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub,
    Sx: Data<Elem = Sd::Elem>,
    Sy: Data<Elem = Sd::Elem>,
    D: Dimension,
{
    /// x-axis coordinates, belonging to axis 0 of `data`
    x: ArrayBase<Sx, Ix1>,
    /// y-axis coordinates, belonging to axis 1 of `data`
    y: ArrayBase<Sy, Ix1>,
    /// the data that will be interpolated
    data: ArrayBase<Sd, D>,
    /// the strategy (builder) the finished interpolator is configured with
    strategy: Strat,
}
65
impl<Sd, D> Interp2D<Sd, OwnedRepr<Sd::Elem>, OwnedRepr<Sd::Elem>, D, Bilinear>
where
    Sd: Data,
    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
    D: Dimension,
{
    /// Get the [Interp2DBuilder]
    ///
    /// The returned builder starts out with the [`Bilinear`] strategy and
    /// owned `x` and `y` axes, exactly as created by [`Interp2DBuilder::new`].
    pub fn builder(
        data: ArrayBase<Sd, D>,
    ) -> Interp2DBuilder<Sd, OwnedRepr<Sd::Elem>, OwnedRepr<Sd::Elem>, D, Bilinear> {
        // plain delegation, this is only the convenient entry point on `Interp2D`
        Interp2DBuilder::new(data)
    }
}
79
80impl<Sd, Sx, Sy, Strat> Interp2D<Sd, Sx, Sy, Ix2, Strat>
81where
82    Sd: Data,
83    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
84    Sx: Data<Elem = Sd::Elem>,
85    Sy: Data<Elem = Sd::Elem>,
86    Strat: Interp2DStrategy<Sd, Sx, Sy, Ix2>,
87{
88    /// convinient interpolation function for interpolation at one point
89    /// when the data dimension is [`type@Ix2`]
90    ///
91    /// ```rust
92    /// # use ndarray_interp::*;
93    /// # use ndarray_interp::interp2d::*;
94    /// # use ndarray::*;
95    /// # use approx::*;
96    /// let data = array![
97    ///     [1.0, 2.0],
98    ///     [3.0, 4.0],
99    /// ];
100    /// let (qx, qy) = (0.0, 0.5);
101    /// let expected = 1.5;
102    ///
103    /// let interpolator = Interp2D::builder(data).build().unwrap();
104    /// let result = interpolator.interp_scalar(qx, qy).unwrap();
105    /// # assert_eq!(result, expected);
106    /// ```
107    pub fn interp_scalar(&self, x: Sx::Elem, y: Sy::Elem) -> Result<Sd::Elem, InterpolateError> {
108        let mut buffer = [cast(0.0).unwrap_or_else(|| unimplemented!())];
109        let buf_view = ArrayViewMut1::from(buffer.as_mut_slice()).remove_axis(Axis(0));
110        self.strategy
111            .interp_into(self, buf_view, x, y)
112            .map(|_| buffer[0])
113    }
114}
115
116impl<Sd, Sx, Sy, D, Strat> Interp2D<Sd, Sx, Sy, D, Strat>
117where
118    Sd: Data,
119    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
120    Sx: Data<Elem = Sd::Elem>,
121    Sy: Data<Elem = Sd::Elem>,
122    D: Dimension + RemoveAxis,
123    D::Smaller: RemoveAxis,
124    Strat: Interp2DStrategy<Sd, Sx, Sy, D>,
125{
126    /// Calculate the interpolated values at `(x, y)`.
127    /// Returns the interpolated data in an array two dimensions smaller than
128    /// the data dimension.
129    ///
130    /// Concider using [`interp_scalar(x, y)`](Interp2D::interp_scalar)
131    /// when the data dimension is [`type@Ix2`]
132    pub fn interp(
133        &self,
134        x: Sx::Elem,
135        y: Sy::Elem,
136    ) -> Result<Array<Sd::Elem, <D::Smaller as Dimension>::Smaller>, InterpolateError> {
137        let dim = self
138            .data
139            .raw_dim()
140            .remove_axis(Axis(0))
141            .remove_axis(Axis(0));
142        let mut target = Array::zeros(dim);
143        self.strategy
144            .interp_into(self, target.view_mut(), x, y)
145            .map(|_| target)
146    }
147
    /// Calculate the interpolated values at `(x, y)`
    /// and store the result into the provided buffer.
    ///
    /// The provided buffer must have the same shape as the interpolation data
    /// with the first two axes removed.
    ///
    /// This can improve performance compared to [`interp`](Interp2D::interp)
    /// because it does not allocate any memory for the result
    ///
    /// # Panics
    /// When the provided buffer is too small or has the wrong shape
    #[inline]
    pub fn interp_into(
        &self,
        x: Sx::Elem,
        y: Sy::Elem,
        buffer: ArrayViewMut<'_, Sd::Elem, <D::Smaller as Dimension>::Smaller>,
    ) -> Result<(), InterpolateError> {
        // The strategy does all the work; note that its `interp_into` takes
        // the buffer before the coordinates.
        self.strategy.interp_into(self, buffer, x, y)
    }
168
169    /// Calculate the interpolated values at all points in `(xs, ys)`
170    ///
171    /// See [`interp_array_into`](Interp2D::interp_array_into) for dimension information
172    ///
173    /// # panics
174    /// when `xs.shape() != ys.shape()`
175    pub fn interp_array<Sqx, Sqy, Dq>(
176        &self,
177        xs: &ArrayBase<Sqx, Dq>,
178        ys: &ArrayBase<Sqy, Dq>,
179    ) -> Result<
180        Array<Sd::Elem, <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output>,
181        InterpolateError,
182    >
183    where
184        Sqx: Data<Elem = Sd::Elem>,
185        Sqy: Data<Elem = Sy::Elem>,
186        Dq: Dimension + DimAdd<<D::Smaller as Dimension>::Smaller> + 'static,
187        <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output: DimExtension,
188    {
189        assert!(
190            xs.shape() == ys.shape(),
191            "`xs.shape()` and `ys.shape()` do not match"
192        );
193        let dim = self.get_buffer_shape(xs.raw_dim());
194        let mut zs = Array::zeros(dim);
195        self.interp_array_into(xs, ys, zs.view_mut()).map(|_| zs)
196    }
197
    /// Calculate the interpolated values at all points in `(xs, ys)`
    /// and store the result into the provided buffer
    ///
    /// This can improve performance compared to [`interp_array`](Interp2D::interp_array)
    /// because it does not allocate any memory for the result
    ///
    /// # Dimensions
    /// Given the data dimension `N` and the query dimension `M` the return array
    /// will have the dimension `M + N - 2` where the first `M` dimensions correspond
    /// to the query dimensions of `xs` and `ys`
    ///
    /// Let's assume we have a data dimension of `N = (2, 3, 4, 5)` and query this data
    /// with an array of dimension `M = (10)`, the return dimension will be `(10, 4, 5)`.
    /// Given a multi dimensional query of `M = (10, 20)` the return will be `(10, 20, 4, 5)`
    ///
    /// # Panics
    /// when `xs.shape() != ys.shape()` or when the provided buffer is too small or has the wrong shape
    pub fn interp_array_into<Sqx, Sqy, Dq>(
        &self,
        xs: &ArrayBase<Sqx, Dq>,
        ys: &ArrayBase<Sqy, Dq>,
        mut buffer: ArrayViewMut<
            Sd::Elem,
            <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output,
        >,
    ) -> Result<(), InterpolateError>
    where
        Sqx: Data<Elem = Sd::Elem>,
        Sqy: Data<Elem = Sy::Elem>,
        Dq: Dimension + DimAdd<<D::Smaller as Dimension>::Smaller> + 'static,
        <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output: DimExtension,
    {
        assert!(
            xs.shape() == ys.shape(),
            "`xs.shape()` and `ys.shape()` do not match"
        );
        // One dimensional queries are dispatched to `interp_array_into_1d`.
        if TypeId::of::<Dq>() == TypeId::of::<Ix1>() {
            // Safety: We checked that `Dq` has type `Ix1`.
            //    Therefore `&ArrayBase<Sq, Dq>` and `&ArrayBase<Sq, Ix1>` must be the same type.
            let xs_1d = unsafe { cast_unchecked::<&ArrayBase<Sqx, Dq>, &ArrayBase<Sqx, Ix1>>(xs) };
            let ys_1d = unsafe { cast_unchecked::<&ArrayBase<Sqy, Dq>, &ArrayBase<Sqy, Ix1>>(ys) };
            // Safety: `<Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output>` reduces the dimension of `D` by two,
            //    and adds the dimension of `Dq`.
            //    Given that `Dq` has type `Ix1` the resulting dimension will be `D::Smaller` again.
            //    `D` might be of type `IxDyn`. In that case `IxDyn::Smaller` => `IxDyn` and also `Ix1::DimAdd<IxDyn>::Output` => `IxDyn`
            let buffer_d = unsafe {
                cast_unchecked::<
                    ArrayViewMut<
                        Sd::Elem,
                        <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output,
                    >,
                    ArrayViewMut<Sd::Elem, D::Smaller>,
                >(buffer)
            };
            return self.interp_array_into_1d(xs_1d, ys_1d, buffer_d);
        }

        // General case: visit every query point and interpolate into the
        // part of the buffer that belongs to it.
        for (index, &x) in xs.indexed_iter() {
            let current_dim = index.clone().into_dimension();
            // `ys` has the same shape as `xs` (asserted above), so the index is valid
            let y = *ys
                .get(current_dim.clone())
                .unwrap_or_else(|| unreachable!());
            // Axes that are covered by the query index are narrowed to that
            // single index, all remaining (data) axes are kept completely.
            let subview =
                buffer.slice_each_axis_mut(|AxisDescription { axis: Axis(nr), .. }| {
                    match current_dim.as_array_view().get(nr) {
                        Some(idx) => Slice::from(*idx..*idx + 1),
                        None => Slice::from(..),
                    }
                });

            // Reshape to the data shape without its first two axes, which
            // drops the length-one query axes.
            let subview = match subview.into_shape_with_order(
                self.data
                    .raw_dim()
                    .remove_axis(Axis(0))
                    .remove_axis(Axis(0)),
            ) {
                Ok(view) => view,
                Err(err) => {
                    // report the shapes of the whole buffer, not of the subview
                    let expect = self.get_buffer_shape(xs.raw_dim()).into_pattern();
                    let got = buffer.dim();
                    panic!("{err} expected: {expect:?}, got: {got:?}")
                }
            };

            // stop at the first failing query point
            self.strategy.interp_into(self, subview, x, y)?;
        }
        Ok(())
    }
286
287    fn interp_array_into_1d<Sqx, Sqy>(
288        &self,
289        xs: &ArrayBase<Sqx, Ix1>,
290        ys: &ArrayBase<Sqy, Ix1>,
291        mut buffer: ArrayViewMut<'_, Sd::Elem, D::Smaller>,
292    ) -> Result<(), InterpolateError>
293    where
294        Sqx: Data<Elem = Sd::Elem>,
295        Sqy: Data<Elem = Sd::Elem>,
296    {
297        Zip::from(xs)
298            .and(ys)
299            .and(buffer.axis_iter_mut(Axis(0)))
300            .fold_while(Ok(()), |_, &x, &y, buf| {
301                match self.strategy.interp_into(self, buf, x, y) {
302                    Ok(_) => ndarray::FoldWhile::Continue(Ok(())),
303                    Err(e) => ndarray::FoldWhile::Done(Err(e)),
304                }
305            })
306            .into_inner()
307    }
308
309    /// the required shape of the buffer when calling [`interp_array_into`]
310    fn get_buffer_shape<Dq>(
311        &self,
312        dq: Dq,
313    ) -> <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output
314    where
315        Dq: Dimension + DimAdd<<D::Smaller as Dimension>::Smaller>,
316        <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output: DimExtension,
317    {
318        let binding = dq.as_array_view();
319        let lenghts = binding.iter().chain(self.data.shape()[2..].iter()).copied();
320        <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output::new(lenghts)
321    }
322
323    /// Create a interpolator without any data validation. This is fast and cheap.
324    ///
325    /// # Safety
326    /// The following data properties are assumed, but not checked:
327    ///  - `x` and `y` are stricktly monotonic rising
328    ///  - `data.shape()[0] == x.len()`, `data.shape()[1] == y.len()`
329    ///  - the `strategy` is porperly initialized with the data
330    pub fn new_unchecked(
331        x: ArrayBase<Sx, Ix1>,
332        y: ArrayBase<Sy, Ix1>,
333        data: ArrayBase<Sd, D>,
334        strategy: Strat,
335    ) -> Self {
336        Interp2D {
337            x,
338            y,
339            data,
340            strategy,
341        }
342    }
343
344    /// get `(x, y, data)` coordinate at the given index
345    ///
346    /// # panics
347    /// when index out of bounds
348    pub fn index_point(
349        &self,
350        x_idx: usize,
351        y_idx: usize,
352    ) -> (
353        Sx::Elem,
354        Sx::Elem,
355        ArrayView<Sd::Elem, <D::Smaller as Dimension>::Smaller>,
356    ) {
357        (
358            self.x[x_idx],
359            self.y[y_idx],
360            self.data
361                .index_axis(Axis(0), x_idx)
362                .index_axis_move(Axis(0), y_idx),
363        )
364    }
365
366    /// The index of a known value left of, or at x and y.
367    ///
368    /// This will never return the right most index,
369    /// so calling [`index_point(x_idx+1, y_idx+1)`](Interp2D::index_point) is always safe.
370    pub fn get_index_left_of(&self, x: Sx::Elem, y: Sy::Elem) -> (usize, usize) {
371        (self.x.get_lower_index(x), self.y.get_lower_index(y))
372    }
373
374    pub fn is_in_x_range(&self, x: Sx::Elem) -> bool {
375        self.x[0] <= x && x <= self.x[self.x.len() - 1]
376    }
377    pub fn is_in_y_range(&self, y: Sy::Elem) -> bool {
378        self.y[0] <= y && y <= self.y[self.y.len() - 1]
379    }
380}
381
382impl<Sd, D> Interp2DBuilder<Sd, OwnedRepr<Sd::Elem>, OwnedRepr<Sd::Elem>, D, Bilinear>
383where
384    Sd: Data,
385    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub,
386    D: Dimension,
387{
388    pub fn new(data: ArrayBase<Sd, D>) -> Self {
389        let x = Array1::from_iter((0..data.shape()[0]).map(|i| {
390            cast(i).unwrap_or_else(|| {
391                unimplemented!("casting from usize to a number should always work")
392            })
393        }));
394        let y = Array1::from_iter((0..data.shape()[1]).map(|i| {
395            cast(i).unwrap_or_else(|| {
396                unimplemented!("casting from usize to a number should always work")
397            })
398        }));
399        Interp2DBuilder {
400            x,
401            y,
402            data,
403            strategy: Bilinear::new(),
404        }
405    }
406}
407
408impl<Sd, Sx, Sy, D, Strat> Interp2DBuilder<Sd, Sx, Sy, D, Strat>
409where
410    Sd: Data,
411    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
412    Sx: Data<Elem = Sd::Elem>,
413    Sy: Data<Elem = Sd::Elem>,
414    D: Dimension + RemoveAxis,
415    D::Smaller: RemoveAxis,
416    Strat: Interp2DStrategyBuilder<Sd, Sx, Sy, D>,
417{
418    /// Set the interpolation strategy by provideing a [`Interp2DStrategyBuilder`].
419    /// By default [`Bilinear`] is used.
420    pub fn strategy<NewStrat: Interp2DStrategyBuilder<Sd, Sx, Sy, D>>(
421        self,
422        strategy: NewStrat,
423    ) -> Interp2DBuilder<Sd, Sx, Sy, D, NewStrat> {
424        let Interp2DBuilder { x, y, data, .. } = self;
425        Interp2DBuilder {
426            x,
427            y,
428            data,
429            strategy,
430        }
431    }
432
433    /// Add an custom x axis for the data.
434    /// The axis must have the same lenght as the first axis of the data.
435    pub fn x<NewSx: Data<Elem = Sd::Elem>>(
436        self,
437        x: ArrayBase<NewSx, Ix1>,
438    ) -> Interp2DBuilder<Sd, NewSx, Sy, D, Strat> {
439        let Interp2DBuilder {
440            y, data, strategy, ..
441        } = self;
442        Interp2DBuilder {
443            x,
444            y,
445            data,
446            strategy,
447        }
448    }
449
450    /// Add an custom y axis for the data.
451    /// The axis must have the same lenght as the second axis of the data.
452    pub fn y<NewSy: Data<Elem = Sd::Elem>>(
453        self,
454        y: ArrayBase<NewSy, Ix1>,
455    ) -> Interp2DBuilder<Sd, Sx, NewSy, D, Strat> {
456        let Interp2DBuilder {
457            x, data, strategy, ..
458        } = self;
459        Interp2DBuilder {
460            x,
461            y,
462            data,
463            strategy,
464        }
465    }
466
467    /// Validate the input and create the configured [`Interp2D`]
468    pub fn build(self) -> Result<Interp2D<Sd, Sx, Sy, D, Strat::FinishedStrat>, BuilderError> {
469        use self::Monotonic::*;
470        use BuilderError::*;
471        let Interp2DBuilder {
472            x,
473            y,
474            data,
475            strategy: stratgy_builder,
476        } = self;
477        if data.ndim() < 2 {
478            return Err(ShapeError("data dimension needs to be at least 2".into()));
479        }
480        if data.shape()[0] < Strat::MINIMUM_DATA_LENGHT {
481            return Err(NotEnoughData(format!("The 0-dimension has not enough data for the chosen interpolation strategy. Provided: {}, Reqired: {}", data.shape()[0], Strat::MINIMUM_DATA_LENGHT)));
482        }
483        if data.shape()[1] < Strat::MINIMUM_DATA_LENGHT {
484            return Err(NotEnoughData(format!("The 1-dimension has not enough data for the chosen interpolation strategy. Provided: {}, Reqired: {}", data.shape()[1], Strat::MINIMUM_DATA_LENGHT)));
485        }
486        if x.len() != data.shape()[0] {
487            return Err(ShapeError(format!(
488                "Lenghts of x-axis and data-0-axis need to match. Got x: {}, data-0: {}",
489                x.len(),
490                data.shape()[0]
491            )));
492        }
493        if y.len() != data.shape()[1] {
494            return Err(ShapeError(format!(
495                "Lenghts of y-axis and data-1-axis need to match. Got y: {}, data-1: {}",
496                y.len(),
497                data.shape()[1]
498            )));
499        }
500        if !matches!(x.monotonic_prop(), Rising { strict: true }) {
501            return Err(Monotonic(
502                "The x-axis needs to be strictly monotonic rising".into(),
503            ));
504        }
505        if !matches!(y.monotonic_prop(), Rising { strict: true }) {
506            return Err(Monotonic(
507                "The y-axis needs to be strictly monotonic rising".into(),
508            ));
509        }
510
511        let strategy = stratgy_builder.build(&x, &y, &data)?;
512        Ok(Interp2D {
513            x,
514            y,
515            data,
516            strategy,
517        })
518    }
519}
520
521#[cfg(test)]
522mod tests {
523    use approx::assert_abs_diff_eq;
524    use ndarray::{array, Array, Array1, IxDyn};
525    use rand::{
526        distr::{uniform::SampleUniform, Uniform},
527        rngs::StdRng,
528        Rng, SeedableRng,
529    };
530
531    use super::Interp2D;
532
533    fn rand_arr<T: SampleUniform>(size: usize, range: (T, T), seed: u64) -> Array1<T> {
534        Array::from_iter(
535            StdRng::seed_from_u64(seed)
536                .sample_iter(Uniform::new_inclusive(range.0, range.1).unwrap())
537                .take(size),
538        )
539    }
540
    // Generates a test named `$name` that interpolates random data with
    // `$dim` axes (reshaped to `$shape`) and checks
    //  - the number of axes of the results and
    //  - that the `*_into` variants produce the same values as the
    //    allocating variants.
    macro_rules! test_dim {
        ($name:ident, $dim:expr, $shape:expr) => {
            #[test]
            fn $name() {
                // 4^dim values, enough for a shape with length 4 on every axis
                let arr = rand_arr(4usize.pow($dim), (0.0, 1.0), 64)
                    .into_shape_with_order($shape)
                    .unwrap();
                let interp = Interp2D::builder(arr).build().unwrap();
                // a single query point removes the two interpolated axes
                let res = interp.interp(2.2, 2.2).unwrap();
                assert_eq!(res.ndim(), $dim - 2);

                let mut buf = Array::zeros(res.dim());
                interp.interp_into(2.2, 2.2, buf.view_mut()).unwrap();
                assert_abs_diff_eq!(buf, res, epsilon = f64::EPSILON);

                // a query array puts its own axes in front of the remaining data axes
                let x_query = array![[0.5, 1.0], [1.5, 2.0]];
                let y_query = array![[1.5, 2.0], [2.5, 3.0]];
                let res = interp.interp_array(&x_query, &y_query).unwrap();
                assert_eq!(res.ndim(), $dim - 2 + x_query.ndim());

                let mut buf = Array::zeros(res.dim());
                interp
                    .interp_array_into(&x_query, &y_query, buf.view_mut())
                    .unwrap();
                assert_abs_diff_eq!(buf, res, epsilon = f64::EPSILON);
            }
        };
    }

    // tuple shapes for 2 to 6 axes, `IxDyn` shapes for 7 and 8 axes
    test_dim!(interp2d_2d, 2, (4, 4));
    test_dim!(interp2d_3d, 3, (4, 4, 4));
    test_dim!(interp2d_4d, 4, (4, 4, 4, 4));
    test_dim!(interp2d_5d, 5, (4, 4, 4, 4, 4));
    test_dim!(interp2d_6d, 6, (4, 4, 4, 4, 4, 4));
    test_dim!(interp2d_7d, 7, IxDyn(&[4, 4, 4, 4, 4, 4, 4]));
    test_dim!(interp2d_8d, 8, IxDyn(&[4, 4, 4, 4, 4, 4, 4, 4]));
577
578    #[test]
579    fn interp2d_2d_scalar() {
580        let arr = rand_arr(4usize.pow(2), (0.0, 1.0), 64)
581            .into_shape_with_order((4, 4))
582            .unwrap();
583        let _res: f64 = Interp2D::builder(arr) // typecheck f64 as return type
584            .build()
585            .unwrap()
586            .interp_scalar(2.2, 2.2)
587            .unwrap();
588    }
589}