ndarray_interp/interp2d/strategies/
bilinear.rs1use std::{fmt::Debug, ops::Sub};
2
3use ndarray::{Data, Dimension, RemoveAxis, Zip};
4use num_traits::{Num, NumCast};
5
6use crate::{interp1d::Linear, InterpolateError};
7
8use super::{Interp2DStrategy, Interp2DStrategyBuilder};
9
/// Bilinear interpolation strategy for 2D interpolation.
///
/// For each query point, the value is first interpolated linearly along the x axis
/// on two neighbouring y rows. It is then interpolated linearly along the y axis
/// between those two intermediate results.
///
/// A fresh strategy (`Bilinear::new()` / `Bilinear::default()`) rejects query points
/// outside the data range. Call [`Bilinear::extrapolate`] with `true` to skip those
/// range checks.
///
/// The type is a small plain value. It derives the common standard traits so that a
/// configured strategy can be copied, compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Bilinear {
    /// When `true`, out-of-range query points are not rejected with
    /// `InterpolateError::OutOfBounds`.
    extrapolate: bool,
}
14
15impl Bilinear {
16 pub fn new() -> Self {
17 Bilinear { extrapolate: false }
18 }
19
20 pub fn extrapolate(mut self, yes: bool) -> Self {
21 self.extrapolate = yes;
22 self
23 }
24}
25
26impl Default for Bilinear {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl<Sd, Sx, Sy, D> Interp2DStrategyBuilder<Sd, Sx, Sy, D> for Bilinear
33where
34 Sd: Data,
35 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
36 Sx: Data<Elem = Sd::Elem>,
37 Sy: Data<Elem = Sd::Elem>,
38 D: Dimension + RemoveAxis,
39 D::Smaller: RemoveAxis,
40{
41 const MINIMUM_DATA_LENGHT: usize = 2;
42
43 type FinishedStrat = Self;
44
45 fn build(
46 self,
47 _x: &ndarray::ArrayBase<Sx, ndarray::Ix1>,
48 _y: &ndarray::ArrayBase<Sy, ndarray::Ix1>,
49 _data: &ndarray::ArrayBase<Sd, D>,
50 ) -> Result<Self::FinishedStrat, crate::BuilderError> {
51 Ok(self)
52 }
53}
54
55impl<Sd, Sx, Sy, D> Interp2DStrategy<Sd, Sx, Sy, D> for Bilinear
56where
57 Sd: Data,
58 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
59 Sx: Data<Elem = Sd::Elem>,
60 Sy: Data<Elem = Sd::Elem>,
61 D: Dimension + RemoveAxis,
62 D::Smaller: RemoveAxis,
63{
64 fn interp_into(
65 &self,
66 interpolator: &crate::interp2d::Interp2D<Sd, Sx, Sy, D, Self>,
67 target: ndarray::ArrayViewMut<'_, <Sd>::Elem, <D::Smaller as Dimension>::Smaller>,
68 x: <Sx>::Elem,
69 y: <Sy>::Elem,
70 ) -> Result<(), crate::InterpolateError> {
71 if !self.extrapolate && !interpolator.is_in_x_range(x) {
72 return Err(InterpolateError::OutOfBounds(format!(
73 "x = {x:?} is not in range"
74 )));
75 }
76 if !self.extrapolate && !interpolator.is_in_y_range(y) {
77 return Err(InterpolateError::OutOfBounds(format!(
78 "y = {y:?} is not in range"
79 )));
80 }
81
82 let (x_idx, y_idx) = interpolator.get_index_left_of(x, y);
83 let (x1, y1, z11) = interpolator.index_point(x_idx, y_idx);
84 let (_, _, z12) = interpolator.index_point(x_idx, y_idx + 1);
85 let (_, _, z21) = interpolator.index_point(x_idx + 1, y_idx);
86 let (x2, y2, z22) = interpolator.index_point(x_idx + 1, y_idx + 1);
87
88 Zip::from(z11)
89 .and(z12)
90 .and(z21)
91 .and(z22)
92 .and(target)
93 .for_each(|&z11, &z12, &z21, &z22, z| {
94 let z1 = Linear::calc_frac((x1, z11), (x2, z21), x);
95 let z2 = Linear::calc_frac((x1, z12), (x2, z22), x);
96 *z = Linear::calc_frac((y1, z1), (y2, z2), y)
97 });
98 Ok(())
99 }
100}