ndarray_interp/interp1d/strategies/
linear.rs1use std::{fmt::Debug, ops::Sub};
2
3use ndarray::{ArrayBase, ArrayViewMut, Data, Dimension, Ix1, RemoveAxis, Zip};
4use num_traits::{Num, NumCast};
5
6use crate::{interp1d::Interp1D, BuilderError, InterpolateError};
7
8use super::{Interp1DStrategy, Interp1DStrategyBuilder};
9
/// Linear interpolation strategy for [`Interp1D`].
///
/// A query point is evaluated on the straight line through two sample points.
/// By default, queries outside the data range are rejected with an
/// out-of-bounds error; call [`Linear::extrapolate`] to allow them.
///
/// The type is a small plain configuration value, so it is cheap to copy and
/// compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Linear {
    /// When `true`, points outside the data range are evaluated instead of
    /// producing `InterpolateError::OutOfBounds`.
    extrapolate: bool,
}
15
16impl Linear {
17 pub fn new() -> Self {
19 Self { extrapolate: false }
20 }
21
22 pub fn extrapolate(mut self, extrapolate: bool) -> Self {
24 self.extrapolate = extrapolate;
25 self
26 }
27
28 pub(crate) fn calc_frac<T>((x1, y1): (T, T), (x2, y2): (T, T), x: T) -> T
30 where
31 T: Num + Copy,
32 {
33 let b = y1;
34 let m = (y2 - y1) / (x2 - x1);
35 m * (x - x1) + b
36 }
37}
38
39impl Default for Linear {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl<Sd, Sx, D> Interp1DStrategyBuilder<Sd, Sx, D> for Linear
46where
47 Sd: Data,
48 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
49 Sx: Data<Elem = Sd::Elem>,
50 D: Dimension + RemoveAxis,
51{
52 const MINIMUM_DATA_LENGHT: usize = 2;
53 type FinishedStrat = Linear;
54 fn build<Sx2>(
55 self,
56 _x: &ArrayBase<Sx2, Ix1>,
57 _data: &ArrayBase<Sd, D>,
58 ) -> Result<Self::FinishedStrat, BuilderError>
59 where
60 Sx2: Data<Elem = Sd::Elem>,
61 {
62 Ok(self)
63 }
64}
65
66impl<Sd, Sx, D> Interp1DStrategy<Sd, Sx, D> for Linear
67where
68 Sd: Data,
69 Sd::Elem: Num + PartialOrd + NumCast + Copy + Debug + Sub + Send,
70 Sx: Data<Elem = Sd::Elem>,
71 D: Dimension + RemoveAxis,
72{
73 fn interp_into(
74 &self,
75 interpolator: &Interp1D<Sd, Sx, D, Self>,
76 target: ArrayViewMut<'_, <Sd>::Elem, <D as Dimension>::Smaller>,
77 x: Sx::Elem,
78 ) -> Result<(), InterpolateError> {
79 let this = interpolator;
80 if !self.extrapolate && !this.is_in_range(x) {
81 return Err(InterpolateError::OutOfBounds(format!(
82 "x = {x:#?} is not in range",
83 )));
84 }
85
86 let idx = this.get_index_left_of(x);
88
89 let (x1, y1) = this.index_point(idx);
91 let (x2, y2) = this.index_point(idx + 1);
92
93 Zip::from(y1).and(y2).and(target).for_each(|&y1, &y2, t| {
95 *t = Self::calc_frac((x1, y1), (x2, y2), x);
96 });
97 Ok(())
98 }
99}