ndarray_interp/interp2d/strategies/
bilinear.rs

1use std::{fmt::Debug, ops::Sub};
2
3use ndarray::{Data, Dimension, RemoveAxis, Zip};
4use num_traits::{Num, NumCast};
5
6use crate::{interp1d::Linear, InterpolateError};
7
8use super::{Interp2DStrategy, Interp2DStrategyBuilder};
9
/// Bilinear interpolation strategy for two-dimensional interpolators.
///
/// Create it with [`Bilinear::new`] (or `Default`) and opt into
/// extrapolation with [`Bilinear::extrapolate`].
///
/// The type is a plain one-flag configuration value, so it is cheap to copy
/// and can be compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Bilinear {
    /// When `false`, a query outside the data range is rejected with an
    /// out-of-bounds error instead of being extrapolated.
    extrapolate: bool,
}
14
15impl Bilinear {
16    pub fn new() -> Self {
17        Bilinear { extrapolate: false }
18    }
19
20    pub fn extrapolate(mut self, yes: bool) -> Self {
21        self.extrapolate = yes;
22        self
23    }
24}
25
26impl Default for Bilinear {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl<Sd, Sx, Sy, D> Interp2DStrategyBuilder<Sd, Sx, Sy, D> for Bilinear
33where
34    Sd: Data,
35    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
36    Sx: Data<Elem = Sd::Elem>,
37    Sy: Data<Elem = Sd::Elem>,
38    D: Dimension + RemoveAxis,
39    D::Smaller: RemoveAxis,
40{
41    const MINIMUM_DATA_LENGHT: usize = 2;
42
43    type FinishedStrat = Self;
44
45    fn build(
46        self,
47        _x: &ndarray::ArrayBase<Sx, ndarray::Ix1>,
48        _y: &ndarray::ArrayBase<Sy, ndarray::Ix1>,
49        _data: &ndarray::ArrayBase<Sd, D>,
50    ) -> Result<Self::FinishedStrat, crate::BuilderError> {
51        Ok(self)
52    }
53}
54
55impl<Sd, Sx, Sy, D> Interp2DStrategy<Sd, Sx, Sy, D> for Bilinear
56where
57    Sd: Data,
58    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
59    Sx: Data<Elem = Sd::Elem>,
60    Sy: Data<Elem = Sd::Elem>,
61    D: Dimension + RemoveAxis,
62    D::Smaller: RemoveAxis,
63{
64    fn interp_into(
65        &self,
66        interpolator: &crate::interp2d::Interp2D<Sd, Sx, Sy, D, Self>,
67        target: ndarray::ArrayViewMut<'_, <Sd>::Elem, <D::Smaller as Dimension>::Smaller>,
68        x: <Sx>::Elem,
69        y: <Sy>::Elem,
70    ) -> Result<(), crate::InterpolateError> {
71        if !self.extrapolate && !interpolator.is_in_x_range(x) {
72            return Err(InterpolateError::OutOfBounds(format!(
73                "x = {x:?} is not in range"
74            )));
75        }
76        if !self.extrapolate && !interpolator.is_in_y_range(y) {
77            return Err(InterpolateError::OutOfBounds(format!(
78                "y = {y:?} is not in range"
79            )));
80        }
81
82        let (x_idx, y_idx) = interpolator.get_index_left_of(x, y);
83        let (x1, y1, z11) = interpolator.index_point(x_idx, y_idx);
84        let (_, _, z12) = interpolator.index_point(x_idx, y_idx + 1);
85        let (_, _, z21) = interpolator.index_point(x_idx + 1, y_idx);
86        let (x2, y2, z22) = interpolator.index_point(x_idx + 1, y_idx + 1);
87
88        Zip::from(z11)
89            .and(z12)
90            .and(z21)
91            .and(z22)
92            .and(target)
93            .for_each(|&z11, &z12, &z21, &z22, z| {
94                let z1 = Linear::calc_frac((x1, z11), (x2, z21), x);
95                let z2 = Linear::calc_frac((x1, z12), (x2, z22), x);
96                *z = Linear::calc_frac((y1, z1), (y2, z2), y)
97            });
98        Ok(())
99    }
100}