1use std::{any::TypeId, fmt::Debug, ops::Sub};
16
17use ndarray::{
18 Array, ArrayBase, ArrayView, ArrayViewMut, ArrayViewMut1, Axis, AxisDescription, Data, DimAdd,
19 Dimension, IntoDimension, Ix1, OwnedRepr, RemoveAxis, Slice, Zip,
20};
21use num_traits::{cast, Num, NumCast};
22
23use crate::{
24 cast_unchecked,
25 dim_extensions::DimExtension,
26 vector_extensions::{Monotonic, VectorExtensions},
27 BuilderError, InterpolateError,
28};
29
30mod aliases;
31mod strategies;
32pub use aliases::*;
33pub use strategies::cubic_spline;
34pub use strategies::linear::Linear;
35pub use strategies::{Interp1DStrategy, Interp1DStrategyBuilder};
36
37#[derive(Debug)]
39pub struct Interp1D<Sd, Sx, D, Strat>
40where
41 Sd: Data,
42 Sd::Elem: Num + Debug + Send,
43 Sx: Data<Elem = Sd::Elem>,
44 D: Dimension,
45 Strat: Interp1DStrategy<Sd, Sx, D>,
46{
47 x: ArrayBase<Sx, Ix1>,
49 data: ArrayBase<Sd, D>,
50 strategy: Strat,
51}
52
53#[derive(Debug)]
60pub struct Interp1DBuilder<Sd, Sx, D, Strat>
61where
62 Sd: Data,
63 Sd::Elem: Num + Debug,
64 Sx: Data<Elem = Sd::Elem>,
65 D: Dimension,
66{
67 x: ArrayBase<Sx, Ix1>,
68 data: ArrayBase<Sd, D>,
69 strategy: Strat,
70}
71
72impl<Sd, D> Interp1D<Sd, OwnedRepr<Sd::Elem>, D, Linear>
73where
74 Sd: Data,
75 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Send,
76 D: Dimension + RemoveAxis,
77{
78 pub fn builder(data: ArrayBase<Sd, D>) -> Interp1DBuilder<Sd, OwnedRepr<Sd::Elem>, D, Linear> {
80 Interp1DBuilder::new(data)
81 }
82}
83
84impl<Sd, Sx, Strat> Interp1D<Sd, Sx, Ix1, Strat>
85where
86 Sd: Data,
87 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
88 Sx: Data<Elem = Sd::Elem>,
89 Strat: Interp1DStrategy<Sd, Sx, Ix1>,
90{
91 pub fn interp_scalar(&self, x: Sx::Elem) -> Result<Sd::Elem, InterpolateError> {
109 let mut buffer: [Sd::Elem; 1] = [cast(0.0).unwrap_or_else(|| unimplemented!())];
110 let buf_view = ArrayViewMut1::from(buffer.as_mut_slice()).remove_axis(Axis(0));
111 self.strategy
112 .interp_into(self, buf_view, x)
113 .map(|_| buffer[0])
114 }
115}
116
117impl<Sd, Sx, D, Strat> Interp1D<Sd, Sx, D, Strat>
118where
119 Sd: Data,
120 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
121 Sx: Data<Elem = Sd::Elem>,
122 D: Dimension + RemoveAxis,
123 Strat: Interp1DStrategy<Sd, Sx, D>,
124{
125 pub fn interp(&self, x: Sx::Elem) -> Result<Array<Sd::Elem, D::Smaller>, InterpolateError> {
151 let dim = self.data.raw_dim().remove_axis(Axis(0));
152 let mut target: Array<Sd::Elem, _> = Array::zeros(dim);
153 self.strategy
154 .interp_into(self, target.view_mut(), x)
155 .map(|_| target)
156 }
157
158 pub fn interp_into(
170 &self,
171 x: Sx::Elem,
172 buffer: ArrayViewMut<'_, Sd::Elem, D::Smaller>,
173 ) -> Result<(), InterpolateError> {
174 self.strategy.interp_into(self, buffer, x)
175 }
176
177 pub fn interp_array<Sq, Dq>(
198 &self,
199 xs: &ArrayBase<Sq, Dq>,
200 ) -> Result<Array<Sd::Elem, <Dq as DimAdd<D::Smaller>>::Output>, InterpolateError>
201 where
202 Sq: Data<Elem = Sd::Elem>,
203 Dq: Dimension + DimAdd<D::Smaller> + 'static,
204 <Dq as DimAdd<D::Smaller>>::Output: DimExtension,
205 {
206 let dim = self.get_buffer_shape(xs.raw_dim());
207 debug_assert_eq!(dim.ndim(), self.data.ndim() + xs.ndim() - 1);
208
209 let mut ys = Array::zeros(dim);
210 self.interp_array_into(xs, ys.view_mut()).map(|_| ys)
211 }
212
213 pub fn interp_array_into<Sq, Dq>(
273 &self,
274 xs: &ArrayBase<Sq, Dq>,
275 mut buffer: ArrayViewMut<Sd::Elem, <Dq as DimAdd<D::Smaller>>::Output>,
276 ) -> Result<(), InterpolateError>
277 where
278 Sq: Data<Elem = Sd::Elem>,
279 Dq: Dimension + DimAdd<D::Smaller> + 'static,
280 <Dq as DimAdd<D::Smaller>>::Output: DimExtension,
281 {
282 if TypeId::of::<Dq>() == TypeId::of::<Ix1>() {
284 let xs_1d = unsafe { cast_unchecked::<&ArrayBase<Sq, Dq>, &ArrayBase<Sq, Ix1>>(xs) };
287 let buffer_d = unsafe {
292 cast_unchecked::<
293 ArrayViewMut<Sd::Elem, <Dq as DimAdd<D::Smaller>>::Output>,
294 ArrayViewMut<Sd::Elem, D>,
295 >(buffer)
296 };
297 return self.interp_array_into_1d(xs_1d, buffer_d);
298 }
299
300 for (index, &x) in xs.indexed_iter() {
302 let current_dim = index.clone().into_dimension();
303 let subview =
304 buffer.slice_each_axis_mut(|AxisDescription { axis: Axis(nr), .. }| {
305 match current_dim.as_array_view().get(nr) {
306 Some(idx) => Slice::from(*idx..*idx + 1),
307 None => Slice::from(..),
308 }
309 });
310
311 let subview =
312 match subview.into_shape_with_order(self.data.raw_dim().remove_axis(Axis(0))) {
313 Ok(view) => view,
314 Err(err) => {
315 let expect = self.get_buffer_shape(xs.raw_dim()).into_pattern();
316 let got = buffer.dim();
317 panic!("{err} expected: {expect:?}, got: {got:?}")
318 }
319 };
320
321 self.strategy.interp_into(self, subview, x)?;
322 }
323 Ok(())
324 }
325
326 fn interp_array_into_1d<Sq>(
327 &self,
328 xs: &ArrayBase<Sq, Ix1>,
329 mut buffer: ArrayViewMut<'_, Sd::Elem, D>,
330 ) -> Result<(), InterpolateError>
331 where
332 Sq: Data<Elem = Sd::Elem>,
333 {
334 Zip::from(xs)
335 .and(buffer.axis_iter_mut(Axis(0)))
336 .fold_while(Ok(()), |_, &x, buf| {
337 match self.strategy.interp_into(self, buf, x) {
338 Ok(_) => ndarray::FoldWhile::Continue(Ok(())),
339 Err(e) => ndarray::FoldWhile::Done(Err(e)),
340 }
341 })
342 .into_inner()
343 }
344
345 fn get_buffer_shape<Dq>(&self, dq: Dq) -> <Dq as DimAdd<D::Smaller>>::Output
347 where
348 Dq: Dimension + DimAdd<D::Smaller>,
349 <Dq as DimAdd<D::Smaller>>::Output: DimExtension,
350 {
351 let binding = dq.as_array_view();
352 let lenghts = binding.iter().chain(self.data.shape()[1..].iter()).copied();
353 <Dq as DimAdd<D::Smaller>>::Output::new(lenghts)
354 }
355
356 pub fn new_unchecked(x: ArrayBase<Sx, Ix1>, data: ArrayBase<Sd, D>, strategy: Strat) -> Self {
364 Interp1D { x, data, strategy }
365 }
366
367 pub fn index_point(&self, index: usize) -> (Sx::Elem, ArrayView<Sd::Elem, D::Smaller>) {
372 let view = self.data.index_axis(Axis(0), index);
373 (self.x[index], view)
374 }
375
376 pub fn get_index_left_of(&self, x: Sx::Elem) -> usize {
381 self.x.get_lower_index(x)
382 }
383
384 pub fn is_in_range(&self, x: Sx::Elem) -> bool {
385 self.x[0] <= x && x <= self.x[self.x.len() - 1]
386 }
387}
388
389impl<Sd, D> Interp1DBuilder<Sd, OwnedRepr<Sd::Elem>, D, Linear>
390where
391 Sd: Data,
392 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug,
393 D: Dimension,
394{
395 pub fn new(data: ArrayBase<Sd, D>) -> Self {
400 let len = data.shape()[0];
401 Interp1DBuilder {
402 x: Array::from_iter((0..len).map(|n| {
403 cast(n).unwrap_or_else(|| {
404 unimplemented!("casting from usize to a number should always work")
405 })
406 })),
407 data,
408 strategy: Linear::new(),
409 }
410 }
411}
412
413impl<Sd, Sx, D, Strat> Interp1DBuilder<Sd, Sx, D, Strat>
414where
415 Sd: Data,
416 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Send,
417 Sx: Data<Elem = Sd::Elem>,
418 D: Dimension + RemoveAxis,
419 Strat: Interp1DStrategyBuilder<Sd, Sx, D>,
420{
421 pub fn x<NewSx>(self, x: ArrayBase<NewSx, Ix1>) -> Interp1DBuilder<Sd, NewSx, D, Strat>
425 where
426 NewSx: Data<Elem = Sd::Elem>,
427 {
428 let Interp1DBuilder { data, strategy, .. } = self;
429 Interp1DBuilder { x, data, strategy }
430 }
431
432 pub fn strategy<NewStrat>(self, strategy: NewStrat) -> Interp1DBuilder<Sd, Sx, D, NewStrat>
435 where
436 NewStrat: Interp1DStrategyBuilder<Sd, Sx, D>,
437 {
438 let Interp1DBuilder { x, data, .. } = self;
439 Interp1DBuilder { x, data, strategy }
440 }
441
442 pub fn build(self) -> Result<Interp1D<Sd, Sx, D, Strat::FinishedStrat>, BuilderError> {
444 use self::Monotonic::*;
445 use BuilderError::*;
446
447 let Interp1DBuilder { x, data, strategy } = self;
448
449 if data.ndim() < 1 {
450 return Err(ShapeError(
451 "data dimension is 0, needs to be at least 1".into(),
452 ));
453 }
454 if data.shape()[0] < Strat::MINIMUM_DATA_LENGHT {
455 return Err(NotEnoughData(format!(
456 "The chosen Interpolation strategy needs at least {} data points",
457 Strat::MINIMUM_DATA_LENGHT
458 )));
459 }
460 if !matches!(x.monotonic_prop(), Rising { strict: true }) {
461 return Err(Monotonic(
462 "Values in the x axis need to be strictly monotonic rising".into(),
463 ));
464 }
465 if x.len() != data.shape()[0] {
466 return Err(BuilderError::ShapeError(format!(
467 "Lengths of x and data axis need to match. Got x: {:}, data: {:}",
468 x.len(),
469 data.shape()[0],
470 )));
471 }
472
473 let strategy = strategy.build(&x, &data)?;
474
475 Ok(Interp1D { x, data, strategy })
476 }
477}
478
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use ndarray::{array, Array, Array1, IxDyn};
    use rand::{
        distr::{uniform::SampleUniform, Uniform},
        rngs::StdRng,
        Rng, SeedableRng,
    };

    use super::Interp1D;

    /// Deterministic array of `size` samples drawn uniformly from the closed
    /// interval `range`, using an RNG seeded with `seed`.
    fn rand_arr<T: SampleUniform>(size: usize, range: (T, T), seed: u64) -> Array1<T> {
        let (low, high) = range;
        let distribution = Uniform::new_inclusive(low, high).unwrap();
        let samples = StdRng::seed_from_u64(seed)
            .sample_iter(distribution)
            .take(size);
        Array::from_iter(samples)
    }

    /// Linear interpolator over `4^$dim` random values reshaped to `$shape`.
    macro_rules! get_interp {
        ($dim:expr, $shape:expr) => {{
            let len = 4usize.pow($dim);
            let data = rand_arr(len, (0.0, 1.0), 64)
                .into_shape_with_order($shape)
                .unwrap();
            Interp1D::builder(data).build().unwrap()
        }};
    }

    /// Checks that the owned and in-place APIs agree for data of rank `$dim`.
    macro_rules! test_dim {
        ($name:ident, $dim:expr, $shape:expr) => {
            #[test]
            fn $name() {
                let interp = get_interp!($dim, $shape);

                // Single query point.
                let point = interp.interp(2.2).unwrap();
                assert_eq!(point.ndim(), $dim - 1);
                let mut point_buf = Array::zeros(point.dim());
                interp.interp_into(2.2, point_buf.view_mut()).unwrap();
                assert_abs_diff_eq!(point_buf, point, epsilon = f64::EPSILON);

                // Two-dimensional array of query points.
                let query = array![[0.5, 1.0], [1.5, 2.0]];
                let batch = interp.interp_array(&query).unwrap();
                assert_eq!(batch.ndim(), $dim - 1 + query.ndim());
                let mut batch_buf = Array::zeros(batch.dim());
                interp.interp_array_into(&query, batch_buf.view_mut()).unwrap();
                assert_abs_diff_eq!(batch_buf, batch, epsilon = f64::EPSILON);
            }
        };
    }

    test_dim!(interp1d_1d, 1, 4);
    test_dim!(interp1d_2d, 2, (4, 4));
    test_dim!(interp1d_3d, 3, (4, 4, 4));
    test_dim!(interp1d_4d, 4, (4, 4, 4, 4));
    test_dim!(interp1d_5d, 5, (4, 4, 4, 4, 4));
    test_dim!(interp1d_6d, 6, (4, 4, 4, 4, 4, 4));
    test_dim!(interp1d_7d, 7, IxDyn(&[4, 4, 4, 4, 4, 4, 4]));

    #[test]
    fn interp1d_1d_scalar() {
        let data = rand_arr(4, (0.0, 1.0), 64);
        let interp = Interp1D::builder(data).build().unwrap();
        let _res: f64 = interp.interp_scalar(2.2).unwrap();
    }

    #[test]
    #[should_panic(expected = "expected: [4], got: [3]")]
    fn interp1d_2d_into_too_small() {
        let interp = get_interp!(2, (4, 4));
        let mut buf = Array::zeros(3);
        let _ = interp.interp_into(2.2, buf.view_mut());
    }

    #[test]
    #[should_panic(expected = "expected: [4], got: [5]")]
    fn interp1d_2d_into_too_big() {
        let interp = get_interp!(2, (4, 4));
        let mut buf = Array::zeros(5);
        let _ = interp.interp_into(2.2, buf.view_mut());
    }

    /// Runs `interp_array_into` for two query points on a `(4, 4)` interpolator,
    /// writing into a buffer of `buffer_shape`; shared by the shape tests below.
    fn array_into_2d_with_buffer(buffer_shape: (usize, usize)) {
        let interp = get_interp!(2, (4, 4));
        let mut buf = Array::zeros(buffer_shape);
        let _ = interp.interp_array_into(&array![2.2, 2.4], buf.view_mut());
    }

    #[test]
    #[should_panic(expected = "expected: [2], got: [1]")]
    fn interp1d_2d_array_into_too_small1() {
        array_into_2d_with_buffer((1, 4));
    }

    #[test]
    #[should_panic]
    fn interp1d_2d_array_into_too_small2() {
        array_into_2d_with_buffer((2, 3));
    }

    #[test]
    #[should_panic]
    fn interp1d_2d_array_into_too_big1() {
        array_into_2d_with_buffer((3, 4));
    }

    #[test]
    #[should_panic]
    fn interp1d_2d_array_into_too_big2() {
        array_into_2d_with_buffer((2, 5));
    }
}