rsl_interpolation/types/
bilinear.rs

1use std::marker::PhantomData;
2
3use crate::Accelerator;
4use crate::DomainError;
5use crate::Interp2dType;
6use crate::Interpolation2d;
7use crate::InterpolationError;
8use crate::interp2d::{acc_indices, partials, xy_grid_indices, z_grid_indices};
9use crate::types::utils::check_if_inbounds;
10use crate::types::utils::check2d_data;
11
/// Minimum data size that `Bilinear::build` hands to `check2d_data` and that `min_size`
/// reports. Presumably the minimum number of grid points per axis (two points span one cell).
const MIN_SIZE: usize = 2;
13
/// Bilinear 2d interpolation type: the 2d counterpart of linear interpolation.
///
/// This is the simplest kind of 2d interpolation. Inside each grid cell, the surface is a
/// weighted combination of the four corner values.
///
/// `Bilinear` itself holds no data. Call [`Interp2dType::build`] on it to validate a grid and
/// obtain a [`BilinearInterp`].
#[doc(alias = "gsl_interp2d_bilinear")]
pub struct Bilinear;
19
20impl<T> Interp2dType<T> for Bilinear
21where
22    T: crate::Num,
23{
24    type Interpolation2d = BilinearInterp<T>;
25
26    /// Constructs a Bilinear Interpolator.
27    ///
28    /// # Example
29    ///
30    /// ```
31    /// # use rsl_interpolation::*;
32    /// #
33    /// # fn main() -> Result<(), InterpolationError>{
34    /// let xa = [0.0, 1.0, 2.0];
35    /// let ya = [0.0, 2.0, 4.0];
36    /// // z = x + y, in column-major order
37    /// let za = [
38    ///     0.0, 1.0, 2.0,
39    ///     2.0, 3.0, 4.0,
40    ///     4.0, 5.0, 6.0,
41    /// ];
42    ///
43    /// let interp = Bilinear.build(&xa, &ya, &za)?;
44    /// # Ok(())
45    /// # }
46    /// ```
47    fn build(&self, xa: &[T], ya: &[T], za: &[T]) -> Result<BilinearInterp<T>, InterpolationError> {
48        check2d_data(xa, ya, za, MIN_SIZE)?;
49
50        Ok(BilinearInterp {
51            _variable_type: PhantomData,
52        })
53    }
54
55    fn name(&self) -> &str {
56        "Bilinear"
57    }
58
59    fn min_size(&self) -> usize {
60        MIN_SIZE
61    }
62}
63
64// ===============================================================================================
65
/// Bilinear interpolator: evaluates the interpolant and its partial derivatives.
///
/// It stores no grid data of its own, only a type marker for `T`. Every evaluation method
/// therefore takes the grid slices (`xa`, `ya`, `za`) and the two accelerators as arguments.
///
/// Obtain one by calling `build` on [`Bilinear`] rather than constructing it directly.
pub struct BilinearInterp<T> {
    /// Ties the interpolator to its numeric type without storing a value of it.
    _variable_type: PhantomData<T>,
}
74
75impl<T> Interpolation2d<T> for BilinearInterp<T>
76where
77    T: crate::Num,
78{
79    fn eval_extrap(
80        &self,
81        xa: &[T],
82        ya: &[T],
83        za: &[T],
84        x: T,
85        y: T,
86        xacc: &mut Accelerator,
87        yacc: &mut Accelerator,
88    ) -> Result<T, DomainError> {
89        let (xi, yi) = acc_indices(xa, ya, x, y, xacc, yacc);
90        let (xlo, xhi, ylo, yhi) = xy_grid_indices(xa, ya, xi, yi);
91        let (zlolo, zlohi, zhilo, zhihi) = z_grid_indices(za, xa.len(), ya.len(), xi, yi)?;
92        let (dx, dy) = partials(xlo, xhi, ylo, yhi);
93
94        debug_assert!(dx > T::zero());
95        debug_assert!(dy > T::zero());
96
97        let t = (x - xlo) / dx;
98        let u = (y - ylo) / dy;
99
100        let one = T::one();
101        let result = (one - t) * (one - u) * zlolo
102            + t * (one - u) * zhilo
103            + (one - t) * u * zlohi
104            + t * u * zhihi;
105        Ok(result)
106    }
107
108    fn eval_deriv_x(
109        &self,
110        xa: &[T],
111        ya: &[T],
112        za: &[T],
113        x: T,
114        y: T,
115        xacc: &mut Accelerator,
116        yacc: &mut Accelerator,
117    ) -> Result<T, DomainError> {
118        check_if_inbounds(xa, x)?;
119        check_if_inbounds(ya, y)?;
120
121        let (xi, yi) = acc_indices(xa, ya, x, y, xacc, yacc);
122        let (xlo, xhi, ylo, yhi) = xy_grid_indices(xa, ya, xi, yi);
123        let (zlolo, zlohi, zhilo, zhihi) = z_grid_indices(za, xa.len(), ya.len(), xi, yi)?;
124        let (dx, dy) = partials(xlo, xhi, ylo, yhi);
125
126        debug_assert!(dx > T::zero());
127        debug_assert!(dy > T::zero());
128
129        let one = T::one();
130        let dt = one / dx;
131        let u = (y - ylo) / dy;
132
133        let result = dt * (-(one - u) * zlolo + (one - u) * zhilo - u * zlohi + u * zhihi);
134        Ok(result)
135    }
136
137    fn eval_deriv_y(
138        &self,
139        xa: &[T],
140        ya: &[T],
141        za: &[T],
142        x: T,
143        y: T,
144        xacc: &mut Accelerator,
145        yacc: &mut Accelerator,
146    ) -> Result<T, DomainError> {
147        check_if_inbounds(xa, x)?;
148        check_if_inbounds(ya, y)?;
149
150        let (xi, yi) = acc_indices(xa, ya, x, y, xacc, yacc);
151        let (xlo, xhi, ylo, yhi) = xy_grid_indices(xa, ya, xi, yi);
152        let (zlolo, zlohi, zhilo, zhihi) = z_grid_indices(za, xa.len(), ya.len(), xi, yi)?;
153        let (dx, dy) = partials(xlo, xhi, ylo, yhi);
154
155        debug_assert!(dx > T::zero());
156        debug_assert!(dy > T::zero());
157
158        let one = T::one();
159        let t = (x - xlo) / dx;
160        let du = one / dy;
161
162        let result = du * (-(one - t) * zlolo - t * zhilo + (one - t) * zlohi + t * zhihi);
163        Ok(result)
164    }
165
166    #[allow(unused_variables)]
167    fn eval_deriv_xx(
168        &self,
169        xa: &[T],
170        ya: &[T],
171        za: &[T],
172        x: T,
173        y: T,
174        xacc: &mut Accelerator,
175        yacc: &mut Accelerator,
176    ) -> Result<T, DomainError> {
177        check_if_inbounds(xa, x)?;
178        check_if_inbounds(ya, y)?;
179
180        Ok(T::zero())
181    }
182
183    #[allow(unused_variables)]
184    fn eval_deriv_yy(
185        &self,
186        xa: &[T],
187        ya: &[T],
188        za: &[T],
189        x: T,
190        y: T,
191        xacc: &mut Accelerator,
192        yacc: &mut Accelerator,
193    ) -> Result<T, DomainError> {
194        check_if_inbounds(xa, x)?;
195        check_if_inbounds(ya, y)?;
196
197        Ok(T::zero())
198    }
199
200    fn eval_deriv_xy(
201        &self,
202        xa: &[T],
203        ya: &[T],
204        za: &[T],
205        x: T,
206        y: T,
207        xacc: &mut Accelerator,
208        yacc: &mut Accelerator,
209    ) -> Result<T, DomainError> {
210        check_if_inbounds(xa, x)?;
211        check_if_inbounds(ya, y)?;
212
213        let (xi, yi) = acc_indices(xa, ya, x, y, xacc, yacc);
214        let (xlo, xhi, ylo, yhi) = xy_grid_indices(xa, ya, xi, yi);
215        let (zlolo, zlohi, zhilo, zhihi) = z_grid_indices(za, xa.len(), ya.len(), xi, yi)?;
216        let (dx, dy) = partials(xlo, xhi, ylo, yhi);
217
218        let one = T::one();
219        let dt = one / dx;
220        let du = one / dy;
221
222        let result = dt * du * (zlolo - zhilo - zlohi + zhihi);
223        Ok(result)
224    }
225}