1use crate::{find_closest_neighbours_indices, is_monotonic};
28use ndarray::{ArrayBase, Axis, Data, Ix2, OwnedRepr};
29use std::ops::{Add, Div, Mul, Sub};
30
31#[derive(Debug)]
50pub struct Interp2D<C1, C2, Z, S>
51where
52 S: Data<Elem = Z>,
53{
54 x: Vec<C1>,
55 y: Vec<C2>,
56 z: ArrayBase<S, Ix2>,
57}
58
59impl<C1, C2, Z, S> Clone for Interp2D<C1, C2, Z, S>
60where
61 C1: Clone,
62 C2: Clone,
63 S: Data<Elem = Z>,
64 ArrayBase<S, Ix2>: Clone,
65{
66 fn clone(&self) -> Self {
67 Self {
68 x: self.x.clone(),
69 y: self.y.clone(),
70 z: self.z.clone(),
71 }
72 }
73}
74
/// Bilinear blend of the four corner values of a unit cell.
///
/// `x00`, `x10`, `x01` and `x11` are the values at the normalised corners `(0, 0)`,
/// `(1, 0)`, `(0, 1)` and `(1, 1)`; `alpha` and `beta` are the normalised positions along
/// the first and second axis. No clamping is applied, so `alpha`/`beta` outside `[0, 1]`
/// evaluate the same bilinear polynomial beyond the cell.
fn interpolate2d_bilinear<C, Z>(x00: Z, x10: Z, x01: Z, x11: Z, alpha: C, beta: C) -> Z
where
    C: Copy + Add<Output = C> + Sub<Output = C>,
    Z: Copy + Mul<C, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
{
    // Expanded form f = a + b*alpha + c*beta + d*alpha*beta, summed left to right.
    let along_alpha = (x10 - x00) * alpha;
    let along_beta = (x01 - x00) * beta;
    let twist = (x00 + x11 - x10 - x01) * alpha * beta;
    x00 + along_alpha + along_beta + twist
}
99
#[test]
fn test_interpolate2d() {
    let tol = 1e-6f64;

    // (corner values (v00, v10, v01, v11), alpha, beta, expected)
    let cases: [((f64, f64, f64, f64), f64, f64, f64); 5] = [
        ((1.0, 2.0, 3.0, 4.0), 0.0, 0.0, 1.0),
        ((1.0, 2.0, 3.0, 4.0), 1.0, 0.0, 2.0),
        ((1.0, 2.0, 3.0, 4.0), 0.0, 1.0, 3.0),
        ((1.0, 2.0, 3.0, 4.0), 1.0, 1.0, 4.0),
        ((0.0, 1.0, 1.0, 0.0), 0.5, 0.5, 0.5),
    ];

    for ((v00, v10, v01, v11), alpha, beta, expected) in cases {
        let got = interpolate2d_bilinear(v00, v10, v01, v11, alpha, beta);
        assert!((got - expected).abs() < tol);
    }
}
111
112fn interp2d<C1, C2, Z>(
115 (x, y): (C1, C2),
116 (x0, x1): (C1, C1),
117 (y0, y1): (C2, C2),
118 (v00, v10, v01, v11): (Z, Z, Z, Z),
119) -> Z
120where
121 C1: Copy + Sub<Output = C1> + Div,
122 C2: Copy + Sub<Output = C2> + Div<Output = <C1 as Div>::Output>,
123 Z: Copy + Mul<<C1 as Div>::Output, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
124 <C1 as Div>::Output:
125 Copy + Add<Output = <C1 as Div>::Output> + Sub<Output = <C1 as Div>::Output>,
126{
127 let dx = x1 - x0;
128 let dy = y1 - y0;
129
130 let alpha = (x - x0) / dx;
131 let beta = (y - y0) / dy;
132
133 interpolate2d_bilinear(v00, v10, v01, v11, alpha, beta)
134}
135
136impl<C1, C2, Z, S> Interp2D<C1, C2, Z, S>
137where
138 S: Data<Elem = Z>,
139{
140 pub fn new(x: Vec<C1>, y: Vec<C2>, z: ArrayBase<S, Ix2>) -> Self
155 where
156 C1: PartialOrd,
157 C2: PartialOrd,
158 {
159 assert_eq!(z.len_of(Axis(0)), x.len(), "x-axis length mismatch.");
160 assert_eq!(z.len_of(Axis(1)), y.len(), "y-axis length mismatch.");
161 assert!(!x.is_empty());
162 assert!(!y.is_empty());
163
164 assert!(is_monotonic(&x), "x values must be monotonic.");
165 assert!(is_monotonic(&y), "x values must be monotonic.");
166
167 Self { x, y, z }
168 }
169 pub fn bounds(&self) -> ((C1, C1), (C2, C2))
172 where
173 C1: Copy,
174 C2: Copy,
175 {
176 (
177 (self.x[0], self.x[self.x.len() - 1]),
178 (self.y[0], self.y[self.y.len() - 1]),
179 )
180 }
181
182 pub fn is_within_bounds(&self, (x, y): (C1, C2)) -> bool
184 where
185 C1: PartialOrd + Copy,
186 C2: PartialOrd + Copy,
187 {
188 let ((x0, x1), (y0, y1)) = self.bounds();
189 x0 <= x && x <= x1 && y0 <= y && y <= y1
190 }
191
192 pub fn xs(&self) -> &Vec<C1> {
194 &self.x
195 }
196
197 pub fn ys(&self) -> &Vec<C2> {
199 &self.y
200 }
201
202 pub fn z(&self) -> &ArrayBase<S, Ix2> {
204 &self.z
205 }
206
207 pub fn map_values<Z2>(&self, f: impl Fn(&Z) -> Z2) -> Interp2D<C1, C2, Z2, OwnedRepr<Z2>>
209 where
210 C1: PartialOrd + Clone,
211 C2: PartialOrd + Clone,
212 {
213 Interp2D {
214 x: self.x.clone(),
215 y: self.y.clone(),
216 z: self.z.map(f),
217 }
218 }
219
220 pub fn map_x_axis<Xnew>(self, f: impl Fn(C1) -> Xnew) -> Interp2D<Xnew, C2, Z, S>
223 where
224 Xnew: PartialOrd,
225 {
226 let xnew = self.x.into_iter().map(f).collect();
227 assert!(is_monotonic(&xnew));
228 Interp2D {
229 x: xnew,
230 y: self.y,
231 z: self.z,
232 }
233 }
234 pub fn map_y_axis<Ynew>(self, f: impl Fn(C2) -> Ynew) -> Interp2D<C1, Ynew, Z, S>
237 where
238 Ynew: PartialOrd,
239 {
240 let ynew = self.y.into_iter().map(f).collect();
241 assert!(is_monotonic(&ynew));
242 Interp2D {
243 x: self.x,
244 y: ynew,
245 z: self.z,
246 }
247 }
248}
249
250impl<C1, C2, Z, S> Interp2D<C1, C2, Z, S>
251where
252 C1: Copy + Sub<Output = C1> + Div + PartialOrd,
253 C2: Copy + Sub<Output = C2> + Div<Output = <C1 as Div>::Output> + PartialOrd,
254 Z: Copy + Mul<<C1 as Div>::Output, Output = Z> + Add<Output = Z> + Sub<Output = Z>,
255 <C1 as Div>::Output:
256 Copy + Add<Output = <C1 as Div>::Output> + Sub<Output = <C1 as Div>::Output>,
257 S: Data<Elem = Z> + Clone,
258{
259 pub fn eval(&self, (x, y): (C1, C2)) -> Z {
264 let (x0, x1) = find_closest_neighbours_indices(&self.x, x);
266 let (y0, y1) = find_closest_neighbours_indices(&self.y, y);
267
268 interp2d(
269 (x, y),
270 (self.x[x0], self.x[x1]),
271 (self.y[y0], self.y[y1]),
272 (
273 self.z[[x0, y0]],
274 self.z[[x1, y0]],
275 self.z[[x0, y1]],
276 self.z[[x1, y1]],
277 ),
278 )
279 }
280
281 pub fn eval_no_extrapolation(&self, xy: (C1, C2)) -> Option<Z> {
284 if self.is_within_bounds(xy) {
285 Some(self.eval(xy))
286 } else {
287 None
288 }
289 }
290}
291
292impl<C1, C2, Z, S> Interp2D<C1, C2, Z, S>
293where
294 Z: Clone,
295 S: Data<Elem = Z> + Clone,
296 OwnedRepr<Z>: Data<Elem = Z>,
297{
298 pub fn swap_variables(self) -> Interp2D<C2, C1, Z, OwnedRepr<Z>> {
300 let Self { x, y, z } = self;
301 Interp2D {
302 x: y,
303 y: x,
304 z: z.t().to_owned(),
305 }
306 }
307}
308
#[test]
fn test_interp2d_on_view() {
    let values = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];

    // Build the interpolator on a borrowed view rather than owned data.
    let interp = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], values.view());

    // The only non-zero sample sits on grid node (1.0, 4.0).
    let peak = interp.eval_no_extrapolation((1.0, 4.0)).unwrap();
    assert!((peak - 1.0).abs() < 1e-6f64);
}
322
#[test]
fn test_interp2d() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let xs = vec![0.0, 1.0, 2.0];
    let ys = vec![3.0, 4.0, 5.0];

    let interp = Interp2D::new(xs, ys, grid);

    let tol = 1e-6f64;

    // Grid nodes reproduce the stored samples.
    assert!((interp.eval_no_extrapolation((1.0, 4.0)).unwrap() - 1.0).abs() < tol);
    assert!((interp.eval_no_extrapolation((0.0, 3.0)).unwrap() - 0.0).abs() < tol);

    // Cell centres next to the peak get a quarter of it from the bilinear blend.
    assert!((interp.eval_no_extrapolation((0.5, 3.5)).unwrap() - 0.25).abs() < tol);
    assert!((interp.eval_no_extrapolation((1.5, 4.5)).unwrap() - 0.25).abs() < tol);

    // Points outside the grid bounds are rejected instead of being evaluated.
    assert!(interp.eval_no_extrapolation((2.5, 4.0)).is_none());
    assert!(interp.eval_no_extrapolation((1.0, 2.5)).is_none());
    assert!(interp.eval_no_extrapolation((-0.5, 5.5)).is_none());
}
345
#[test]
fn test_map_values() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let interp = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid);

    // Scaling every sample by two must scale the interpolated peak by two as well.
    let doubled = interp.map_values(|&v| 2. * v);
    let at_peak = doubled.eval_no_extrapolation((1.0, 4.0)).unwrap();

    assert!((at_peak - 2.0).abs() < 1e-6);
}
361
#[test]
fn test_map_x_axis() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let interp = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid);

    let shift = 1.0;
    let shifted = interp.map_x_axis(|x| x + shift);

    // The peak that was at x = 1.0 must now be found at x = 1.0 + shift.
    let peak = shifted
        .eval_no_extrapolation((1.0 + shift, 4.0))
        .unwrap();
    assert!((peak - 1.0).abs() < 1e-6);
}
385
#[test]
fn test_map_y_axis() {
    let grid = ndarray::array![[0.0f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
    let interp = Interp2D::new(vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0], grid);

    let shift = 1.0;
    let shifted = interp.map_y_axis(|y| y + shift);

    // The peak that was at y = 4.0 must now be found at y = 4.0 + shift.
    let peak = shifted
        .eval_no_extrapolation((1.0, 4.0 + shift))
        .unwrap();
    assert!((peak - 1.0).abs() < 1e-6);
}