1use std::{any::TypeId, fmt::Debug, ops::Sub};
15
16use ndarray::{
17 Array, Array1, ArrayBase, ArrayView, ArrayViewMut, ArrayViewMut1, Axis, AxisDescription, Data,
18 DimAdd, Dimension, IntoDimension, Ix1, Ix2, OwnedRepr, RemoveAxis, Slice, Zip,
19};
20use num_traits::{cast, Num, NumCast};
21
22use crate::{
23 cast_unchecked,
24 dim_extensions::DimExtension,
25 vector_extensions::{Monotonic, VectorExtensions},
26 BuilderError, InterpolateError,
27};
28
29mod aliases;
30mod strategies;
31pub use aliases::*;
32pub use strategies::{Bilinear, Interp2DStrategy, Interp2DStrategyBuilder};
33
34#[derive(Debug)]
36pub struct Interp2D<Sd, Sx, Sy, D, Strat>
37where
38 Sd: Data,
39 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
40 Sx: Data<Elem = Sd::Elem>,
41 Sy: Data<Elem = Sd::Elem>,
42 D: Dimension,
43{
44 x: ArrayBase<Sx, Ix1>,
45 y: ArrayBase<Sy, Ix1>,
46 data: ArrayBase<Sd, D>,
47 strategy: Strat,
48}
49
50#[derive(Debug)]
52pub struct Interp2DBuilder<Sd, Sx, Sy, D, Strat>
53where
54 Sd: Data,
55 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub,
56 Sx: Data<Elem = Sd::Elem>,
57 Sy: Data<Elem = Sd::Elem>,
58 D: Dimension,
59{
60 x: ArrayBase<Sx, Ix1>,
61 y: ArrayBase<Sy, Ix1>,
62 data: ArrayBase<Sd, D>,
63 strategy: Strat,
64}
65
66impl<Sd, D> Interp2D<Sd, OwnedRepr<Sd::Elem>, OwnedRepr<Sd::Elem>, D, Bilinear>
67where
68 Sd: Data,
69 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
70 D: Dimension,
71{
72 pub fn builder(
74 data: ArrayBase<Sd, D>,
75 ) -> Interp2DBuilder<Sd, OwnedRepr<Sd::Elem>, OwnedRepr<Sd::Elem>, D, Bilinear> {
76 Interp2DBuilder::new(data)
77 }
78}
79
80impl<Sd, Sx, Sy, Strat> Interp2D<Sd, Sx, Sy, Ix2, Strat>
81where
82 Sd: Data,
83 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
84 Sx: Data<Elem = Sd::Elem>,
85 Sy: Data<Elem = Sd::Elem>,
86 Strat: Interp2DStrategy<Sd, Sx, Sy, Ix2>,
87{
88 pub fn interp_scalar(&self, x: Sx::Elem, y: Sy::Elem) -> Result<Sd::Elem, InterpolateError> {
108 let mut buffer = [cast(0.0).unwrap_or_else(|| unimplemented!())];
109 let buf_view = ArrayViewMut1::from(buffer.as_mut_slice()).remove_axis(Axis(0));
110 self.strategy
111 .interp_into(self, buf_view, x, y)
112 .map(|_| buffer[0])
113 }
114}
115
116impl<Sd, Sx, Sy, D, Strat> Interp2D<Sd, Sx, Sy, D, Strat>
117where
118 Sd: Data,
119 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
120 Sx: Data<Elem = Sd::Elem>,
121 Sy: Data<Elem = Sd::Elem>,
122 D: Dimension + RemoveAxis,
123 D::Smaller: RemoveAxis,
124 Strat: Interp2DStrategy<Sd, Sx, Sy, D>,
125{
126 pub fn interp(
133 &self,
134 x: Sx::Elem,
135 y: Sy::Elem,
136 ) -> Result<Array<Sd::Elem, <D::Smaller as Dimension>::Smaller>, InterpolateError> {
137 let dim = self
138 .data
139 .raw_dim()
140 .remove_axis(Axis(0))
141 .remove_axis(Axis(0));
142 let mut target = Array::zeros(dim);
143 self.strategy
144 .interp_into(self, target.view_mut(), x, y)
145 .map(|_| target)
146 }
147
148 #[inline]
160 pub fn interp_into(
161 &self,
162 x: Sx::Elem,
163 y: Sy::Elem,
164 buffer: ArrayViewMut<'_, Sd::Elem, <D::Smaller as Dimension>::Smaller>,
165 ) -> Result<(), InterpolateError> {
166 self.strategy.interp_into(self, buffer, x, y)
167 }
168
169 pub fn interp_array<Sqx, Sqy, Dq>(
176 &self,
177 xs: &ArrayBase<Sqx, Dq>,
178 ys: &ArrayBase<Sqy, Dq>,
179 ) -> Result<
180 Array<Sd::Elem, <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output>,
181 InterpolateError,
182 >
183 where
184 Sqx: Data<Elem = Sd::Elem>,
185 Sqy: Data<Elem = Sy::Elem>,
186 Dq: Dimension + DimAdd<<D::Smaller as Dimension>::Smaller> + 'static,
187 <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output: DimExtension,
188 {
189 assert!(
190 xs.shape() == ys.shape(),
191 "`xs.shape()` and `ys.shape()` do not match"
192 );
193 let dim = self.get_buffer_shape(xs.raw_dim());
194 let mut zs = Array::zeros(dim);
195 self.interp_array_into(xs, ys, zs.view_mut()).map(|_| zs)
196 }
197
198 pub fn interp_array_into<Sqx, Sqy, Dq>(
216 &self,
217 xs: &ArrayBase<Sqx, Dq>,
218 ys: &ArrayBase<Sqy, Dq>,
219 mut buffer: ArrayViewMut<
220 Sd::Elem,
221 <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output,
222 >,
223 ) -> Result<(), InterpolateError>
224 where
225 Sqx: Data<Elem = Sd::Elem>,
226 Sqy: Data<Elem = Sy::Elem>,
227 Dq: Dimension + DimAdd<<D::Smaller as Dimension>::Smaller> + 'static,
228 <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output: DimExtension,
229 {
230 assert!(
231 xs.shape() == ys.shape(),
232 "`xs.shape()` and `ys.shape()` do not match"
233 );
234 if TypeId::of::<Dq>() == TypeId::of::<Ix1>() {
235 let xs_1d = unsafe { cast_unchecked::<&ArrayBase<Sqx, Dq>, &ArrayBase<Sqx, Ix1>>(xs) };
238 let ys_1d = unsafe { cast_unchecked::<&ArrayBase<Sqy, Dq>, &ArrayBase<Sqy, Ix1>>(ys) };
239 let buffer_d = unsafe {
244 cast_unchecked::<
245 ArrayViewMut<
246 Sd::Elem,
247 <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output,
248 >,
249 ArrayViewMut<Sd::Elem, D::Smaller>,
250 >(buffer)
251 };
252 return self.interp_array_into_1d(xs_1d, ys_1d, buffer_d);
253 }
254
255 for (index, &x) in xs.indexed_iter() {
256 let current_dim = index.clone().into_dimension();
257 let y = *ys
258 .get(current_dim.clone())
259 .unwrap_or_else(|| unreachable!());
260 let subview =
261 buffer.slice_each_axis_mut(|AxisDescription { axis: Axis(nr), .. }| {
262 match current_dim.as_array_view().get(nr) {
263 Some(idx) => Slice::from(*idx..*idx + 1),
264 None => Slice::from(..),
265 }
266 });
267
268 let subview = match subview.into_shape_with_order(
269 self.data
270 .raw_dim()
271 .remove_axis(Axis(0))
272 .remove_axis(Axis(0)),
273 ) {
274 Ok(view) => view,
275 Err(err) => {
276 let expect = self.get_buffer_shape(xs.raw_dim()).into_pattern();
277 let got = buffer.dim();
278 panic!("{err} expected: {expect:?}, got: {got:?}")
279 }
280 };
281
282 self.strategy.interp_into(self, subview, x, y)?;
283 }
284 Ok(())
285 }
286
287 fn interp_array_into_1d<Sqx, Sqy>(
288 &self,
289 xs: &ArrayBase<Sqx, Ix1>,
290 ys: &ArrayBase<Sqy, Ix1>,
291 mut buffer: ArrayViewMut<'_, Sd::Elem, D::Smaller>,
292 ) -> Result<(), InterpolateError>
293 where
294 Sqx: Data<Elem = Sd::Elem>,
295 Sqy: Data<Elem = Sd::Elem>,
296 {
297 Zip::from(xs)
298 .and(ys)
299 .and(buffer.axis_iter_mut(Axis(0)))
300 .fold_while(Ok(()), |_, &x, &y, buf| {
301 match self.strategy.interp_into(self, buf, x, y) {
302 Ok(_) => ndarray::FoldWhile::Continue(Ok(())),
303 Err(e) => ndarray::FoldWhile::Done(Err(e)),
304 }
305 })
306 .into_inner()
307 }
308
309 fn get_buffer_shape<Dq>(
311 &self,
312 dq: Dq,
313 ) -> <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output
314 where
315 Dq: Dimension + DimAdd<<D::Smaller as Dimension>::Smaller>,
316 <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output: DimExtension,
317 {
318 let binding = dq.as_array_view();
319 let lenghts = binding.iter().chain(self.data.shape()[2..].iter()).copied();
320 <Dq as DimAdd<<D::Smaller as Dimension>::Smaller>>::Output::new(lenghts)
321 }
322
323 pub fn new_unchecked(
331 x: ArrayBase<Sx, Ix1>,
332 y: ArrayBase<Sy, Ix1>,
333 data: ArrayBase<Sd, D>,
334 strategy: Strat,
335 ) -> Self {
336 Interp2D {
337 x,
338 y,
339 data,
340 strategy,
341 }
342 }
343
344 pub fn index_point(
349 &self,
350 x_idx: usize,
351 y_idx: usize,
352 ) -> (
353 Sx::Elem,
354 Sx::Elem,
355 ArrayView<Sd::Elem, <D::Smaller as Dimension>::Smaller>,
356 ) {
357 (
358 self.x[x_idx],
359 self.y[y_idx],
360 self.data
361 .index_axis(Axis(0), x_idx)
362 .index_axis_move(Axis(0), y_idx),
363 )
364 }
365
366 pub fn get_index_left_of(&self, x: Sx::Elem, y: Sy::Elem) -> (usize, usize) {
371 (self.x.get_lower_index(x), self.y.get_lower_index(y))
372 }
373
374 pub fn is_in_x_range(&self, x: Sx::Elem) -> bool {
375 self.x[0] <= x && x <= self.x[self.x.len() - 1]
376 }
377 pub fn is_in_y_range(&self, y: Sy::Elem) -> bool {
378 self.y[0] <= y && y <= self.y[self.y.len() - 1]
379 }
380}
381
382impl<Sd, D> Interp2DBuilder<Sd, OwnedRepr<Sd::Elem>, OwnedRepr<Sd::Elem>, D, Bilinear>
383where
384 Sd: Data,
385 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub,
386 D: Dimension,
387{
388 pub fn new(data: ArrayBase<Sd, D>) -> Self {
389 let x = Array1::from_iter((0..data.shape()[0]).map(|i| {
390 cast(i).unwrap_or_else(|| {
391 unimplemented!("casting from usize to a number should always work")
392 })
393 }));
394 let y = Array1::from_iter((0..data.shape()[1]).map(|i| {
395 cast(i).unwrap_or_else(|| {
396 unimplemented!("casting from usize to a number should always work")
397 })
398 }));
399 Interp2DBuilder {
400 x,
401 y,
402 data,
403 strategy: Bilinear::new(),
404 }
405 }
406}
407
408impl<Sd, Sx, Sy, D, Strat> Interp2DBuilder<Sd, Sx, Sy, D, Strat>
409where
410 Sd: Data,
411 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
412 Sx: Data<Elem = Sd::Elem>,
413 Sy: Data<Elem = Sd::Elem>,
414 D: Dimension + RemoveAxis,
415 D::Smaller: RemoveAxis,
416 Strat: Interp2DStrategyBuilder<Sd, Sx, Sy, D>,
417{
418 pub fn strategy<NewStrat: Interp2DStrategyBuilder<Sd, Sx, Sy, D>>(
421 self,
422 strategy: NewStrat,
423 ) -> Interp2DBuilder<Sd, Sx, Sy, D, NewStrat> {
424 let Interp2DBuilder { x, y, data, .. } = self;
425 Interp2DBuilder {
426 x,
427 y,
428 data,
429 strategy,
430 }
431 }
432
433 pub fn x<NewSx: Data<Elem = Sd::Elem>>(
436 self,
437 x: ArrayBase<NewSx, Ix1>,
438 ) -> Interp2DBuilder<Sd, NewSx, Sy, D, Strat> {
439 let Interp2DBuilder {
440 y, data, strategy, ..
441 } = self;
442 Interp2DBuilder {
443 x,
444 y,
445 data,
446 strategy,
447 }
448 }
449
450 pub fn y<NewSy: Data<Elem = Sd::Elem>>(
453 self,
454 y: ArrayBase<NewSy, Ix1>,
455 ) -> Interp2DBuilder<Sd, Sx, NewSy, D, Strat> {
456 let Interp2DBuilder {
457 x, data, strategy, ..
458 } = self;
459 Interp2DBuilder {
460 x,
461 y,
462 data,
463 strategy,
464 }
465 }
466
467 pub fn build(self) -> Result<Interp2D<Sd, Sx, Sy, D, Strat::FinishedStrat>, BuilderError> {
469 use self::Monotonic::*;
470 use BuilderError::*;
471 let Interp2DBuilder {
472 x,
473 y,
474 data,
475 strategy: stratgy_builder,
476 } = self;
477 if data.ndim() < 2 {
478 return Err(ShapeError("data dimension needs to be at least 2".into()));
479 }
480 if data.shape()[0] < Strat::MINIMUM_DATA_LENGHT {
481 return Err(NotEnoughData(format!("The 0-dimension has not enough data for the chosen interpolation strategy. Provided: {}, Reqired: {}", data.shape()[0], Strat::MINIMUM_DATA_LENGHT)));
482 }
483 if data.shape()[1] < Strat::MINIMUM_DATA_LENGHT {
484 return Err(NotEnoughData(format!("The 1-dimension has not enough data for the chosen interpolation strategy. Provided: {}, Reqired: {}", data.shape()[1], Strat::MINIMUM_DATA_LENGHT)));
485 }
486 if x.len() != data.shape()[0] {
487 return Err(ShapeError(format!(
488 "Lenghts of x-axis and data-0-axis need to match. Got x: {}, data-0: {}",
489 x.len(),
490 data.shape()[0]
491 )));
492 }
493 if y.len() != data.shape()[1] {
494 return Err(ShapeError(format!(
495 "Lenghts of y-axis and data-1-axis need to match. Got y: {}, data-1: {}",
496 y.len(),
497 data.shape()[1]
498 )));
499 }
500 if !matches!(x.monotonic_prop(), Rising { strict: true }) {
501 return Err(Monotonic(
502 "The x-axis needs to be strictly monotonic rising".into(),
503 ));
504 }
505 if !matches!(y.monotonic_prop(), Rising { strict: true }) {
506 return Err(Monotonic(
507 "The y-axis needs to be strictly monotonic rising".into(),
508 ));
509 }
510
511 let strategy = stratgy_builder.build(&x, &y, &data)?;
512 Ok(Interp2D {
513 x,
514 y,
515 data,
516 strategy,
517 })
518 }
519}
520
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array, Array1, IxDyn};
    use rand::{
        distr::{uniform::SampleUniform, Uniform},
        rngs::StdRng,
        Rng, SeedableRng,
    };

    use super::Interp2D;

    /// `size` samples drawn uniformly from the inclusive `range`, seeded with
    /// `seed` so test data is reproducible.
    fn rand_arr<T: SampleUniform>(size: usize, range: (T, T), seed: u64) -> Array1<T> {
        Array::from_iter(
            StdRng::seed_from_u64(seed)
                .sample_iter(Uniform::new_inclusive(range.0, range.1).unwrap())
                .take(size),
        )
    }

    /// Generates a test that builds a default interpolator over random data of
    /// the given shape and checks that the allocating and `_into` variants of
    /// both point and array queries agree and have the expected rank.
    macro_rules! test_dim {
        ($name:ident, $dim:expr, $shape:expr) => {
            #[test]
            fn $name() {
                let arr = rand_arr(4usize.pow($dim), (0.0, 1.0), 64)
                    .into_shape_with_order($shape)
                    .unwrap();
                let interp = Interp2D::builder(arr).build().unwrap();
                let res = interp.interp(2.2, 2.2).unwrap();
                assert_eq!(res.ndim(), $dim - 2);

                let mut buf = Array::zeros(res.dim());
                interp.interp_into(2.2, 2.2, buf.view_mut()).unwrap();
                assert_abs_diff_eq!(buf, res, epsilon = f64::EPSILON);

                let x_query = array![[0.5, 1.0], [1.5, 2.0]];
                let y_query = array![[1.5, 2.0], [2.5, 3.0]];
                let res = interp.interp_array(&x_query, &y_query).unwrap();
                assert_eq!(res.ndim(), $dim - 2 + x_query.ndim());

                let mut buf = Array::zeros(res.dim());
                interp
                    .interp_array_into(&x_query, &y_query, buf.view_mut())
                    .unwrap();
                assert_abs_diff_eq!(buf, res, epsilon = f64::EPSILON);
            }
        };
    }

    test_dim!(interp2d_2d, 2, (4, 4));
    test_dim!(interp2d_3d, 3, (4, 4, 4));
    test_dim!(interp2d_4d, 4, (4, 4, 4, 4));
    test_dim!(interp2d_5d, 5, (4, 4, 4, 4, 4));
    test_dim!(interp2d_6d, 6, (4, 4, 4, 4, 4, 4));
    test_dim!(interp2d_7d, 7, IxDyn(&[4, 4, 4, 4, 4, 4, 4]));
    test_dim!(interp2d_8d, 8, IxDyn(&[4, 4, 4, 4, 4, 4, 4, 4]));

    #[test]
    fn interp2d_2d_scalar() {
        let arr = rand_arr(4usize.pow(2), (0.0, 1.0), 64)
            .into_shape_with_order((4, 4))
            .unwrap();
        let interp = Interp2D::builder(arr).build().unwrap();
        let scalar: f64 = interp.interp_scalar(2.2, 2.2).unwrap();
        // The allocation-free scalar path must agree with the general path.
        let via_array = interp.interp(2.2, 2.2).unwrap();
        assert_abs_diff_eq!(scalar, via_array[()], epsilon = f64::EPSILON);
    }
}