ndarray_interp/interp1d/strategies/linear.rs

1use std::{fmt::Debug, ops::Sub};
2
3use ndarray::{ArrayBase, ArrayViewMut, Data, Dimension, Ix1, RemoveAxis, Zip};
4use num_traits::{Num, NumCast};
5
6use crate::{interp1d::Interp1D, BuilderError, InterpolateError};
7
8use super::{Interp1DStrategy, Interp1DStrategyBuilder};
9
/// Linear interpolation strategy.
///
/// Connects neighbouring data points with straight line segments. By default,
/// queries outside the range of the sample points are rejected; use
/// [`Linear::extrapolate`] to allow extrapolation instead.
///
/// The strategy is a small, plain configuration value, so it is `Copy` and can be
/// compared; a configured strategy can be reused for several interpolators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Linear {
    /// `true` if queries outside the sample range are extrapolated instead of
    /// reported as out of bounds.
    extrapolate: bool,
}
15
16impl Linear {
17    /// create a linear interpolation stratgy
18    pub fn new() -> Self {
19        Self { extrapolate: false }
20    }
21
22    /// does the strategy extrapolate? Default is `false`
23    pub fn extrapolate(mut self, extrapolate: bool) -> Self {
24        self.extrapolate = extrapolate;
25        self
26    }
27
28    /// linearly interpolate/exrapolate between two points
29    pub(crate) fn calc_frac<T>((x1, y1): (T, T), (x2, y2): (T, T), x: T) -> T
30    where
31        T: Num + Copy,
32    {
33        let b = y1;
34        let m = (y2 - y1) / (x2 - x1);
35        m * (x - x1) + b
36    }
37}
38
39impl Default for Linear {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl<Sd, Sx, D> Interp1DStrategyBuilder<Sd, Sx, D> for Linear
46where
47    Sd: Data,
48    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
49    Sx: Data<Elem = Sd::Elem>,
50    D: Dimension + RemoveAxis,
51{
52    const MINIMUM_DATA_LENGHT: usize = 2;
53    type FinishedStrat = Linear;
54    fn build<Sx2>(
55        self,
56        _x: &ArrayBase<Sx2, Ix1>,
57        _data: &ArrayBase<Sd, D>,
58    ) -> Result<Self::FinishedStrat, BuilderError>
59    where
60        Sx2: Data<Elem = Sd::Elem>,
61    {
62        Ok(self)
63    }
64}
65
66impl<Sd, Sx, D> Interp1DStrategy<Sd, Sx, D> for Linear
67where
68    Sd: Data,
69    Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
70    Sx: Data<Elem = Sd::Elem>,
71    D: Dimension + RemoveAxis,
72{
73    fn interp_into(
74        &self,
75        interpolator: &Interp1D<Sd, Sx, D, Self>,
76        target: ArrayViewMut<'_, <Sd>::Elem, <D as Dimension>::Smaller>,
77        x: Sx::Elem,
78    ) -> Result<(), InterpolateError> {
79        let this = interpolator;
80        if !self.extrapolate && !this.is_in_range(x) {
81            return Err(InterpolateError::OutOfBounds(format!(
82                "x = {x:#?} is not in range",
83            )));
84        }
85
86        // find the relevant index
87        let idx = this.get_index_left_of(x);
88
89        // lookup the data
90        let (x1, y1) = this.index_point(idx);
91        let (x2, y2) = this.index_point(idx + 1);
92
93        // do interpolation
94        Zip::from(y1).and(y2).and(target).for_each(|&y1, &y2, t| {
95            *t = Self::calc_frac((x1, y1), (x2, y2), x);
96        });
97        Ok(())
98    }
99}