Skip to main content

augurs_forecaster/transforms/
interpolate.rs

1/*!
2Contains an interpolation iterator adapter.
3
4The adapter can be used to fill in missing values in a time series
5using interpolation, similar to the `interpolate` method on a
6`Series` in the `pandas` or `polars` libraries.
7*/
8
9use std::{
10    collections::VecDeque,
11    iter::repeat_with,
12    ops::{Add, Div, Mul, Sub},
13};
14
15use super::{Error, Transformer};
16
17/// A type that can be used to interpolate between values.
18pub trait Interpolater {
19    /// Interpolate between two values.
20    ///
21    /// The `low` and `high` values are the start and end of the range to interpolate,
22    /// and `n` is the number of values to interpolate.
23    ///
24    /// The return value is an iterator that yields `n` values between `low` and `high`.
25    /// It should return a half-open range, i.e. it should include `low` but not `high`.
26    /// It should return exactly `n` values, so if `n` is `1` it should return an iterator
27    /// that yields only `low`, and if `n` is `0` it should return an empty iterator.
28    fn interpolate<T: Interpolatable>(&self, low: T, high: T, n: usize) -> impl Iterator<Item = T>;
29}
30
/// An [`Interpolater`] that fills gaps with evenly spaced points on the straight
/// line from the last defined value to the next one.
///
/// # Example
///
/// ```
/// use augurs_forecaster::transforms::interpolate::*;
/// let got = LinearInterpolator::default().interpolate(1.0, 2.0, 4).collect::<Vec<_>>();
/// assert_eq!(got, vec![1.0, 1.25, 1.5, 1.75]);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct LinearInterpolator {
    // Private unit field: prevents struct-literal construction outside this
    // module, so callers go through `new` or `default`.
    _priv: (),
}

impl LinearInterpolator {
    /// Build a linear interpolator; equivalent to `LinearInterpolator::default()`.
    pub fn new() -> Self {
        Self { _priv: () }
    }
}
53
54impl Interpolater for LinearInterpolator {
55    fn interpolate<T: Interpolatable>(&self, low: T, high: T, n: usize) -> impl Iterator<Item = T> {
56        let diff = high - low;
57        let step = diff / (T::from_usize(n));
58        (0..n).map(move |i| low + T::from_usize(i) * step)
59    }
60}
61
62impl Transformer for LinearInterpolator {
63    fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
64        Ok(())
65    }
66
67    fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
68        let interpolated: Vec<_> = data.iter().copied().interpolate(*self).collect();
69        data.copy_from_slice(&interpolated);
70        Ok(())
71    }
72
73    fn inverse_transform(&self, _data: &mut [f64]) -> Result<(), Error> {
74        Ok(())
75    }
76}
77
/// An iterator adapter that fills NaN values in its input by interpolation.
///
/// Created by [`InterpolateExt::interpolate`]. Each run of NaNs that has a
/// defined value on both sides is replaced with values produced by the
/// interpolator `I` (for example [`LinearInterpolator`]). Defined values are
/// yielded unchanged.
///
/// If the first or last values in the input are NaN, there is no bound to
/// interpolate from, so the iterator yields NaN values at the start or end of
/// the output, respectively.
///
/// # Example
/// ```
/// use augurs_forecaster::transforms::interpolate::*;
/// let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, 2.0];
/// let interp: Vec<_> = x.into_iter().interpolate(LinearInterpolator::default()).collect();
/// assert_eq!(interp, vec![1.0, 1.25, 1.5, 1.75, 2.0]);
/// ```
#[derive(Debug, Clone)]
pub struct Interpolate<T: Iterator, I> {
    /// The wrapped input iterator.
    inner: T,
    /// The most recent defined value yielded; NaN until one has been seen.
    low: T::Item,
    /// The defined value that ended the latest NaN run, yielded once `buf` is drained.
    high: Option<T::Item>,
    /// Fill values computed ahead of time, yielded before any more input is read.
    buf: VecDeque<T::Item>,
    /// The strategy used to fill NaN runs.
    interpolator: I,
}
103
104impl<T, I> Iterator for Interpolate<T, I>
105where
106    T: Iterator,
107    T::Item: Interpolatable,
108    I: Interpolater,
109{
110    type Item = T::Item;
111
112    fn next(&mut self) -> Option<Self::Item> {
113        // If we have values in the buffer, use them first.
114        if !self.buf.is_empty() {
115            return self.buf.pop_front();
116        }
117
118        // If we have a high value from the previous iteration, use it and
119        // reset the high value to None, so that we don't use it again.
120        if let Some(high) = self.high.take() {
121            self.low = high;
122            return Some(high);
123        }
124
125        let next = self.inner.next();
126        match next {
127            Some(x) if x.is_nan() => {
128                // Count the number of NaNs we see, starting with this one (`x`).
129                let mut n: usize = 1;
130                for h in self.inner.by_ref() {
131                    if h.is_nan() {
132                        n += 1;
133                        continue;
134                    }
135                    // h is not NaN.
136                    self.high = Some(h);
137                    break;
138                }
139
140                if self.low.is_nan() {
141                    // We've seen NaNs at the start.
142                    self.buf = repeat_with(Self::Item::nan).take(n - 1).collect();
143                    return Some(self.low);
144                }
145
146                if let Some(high) = self.high {
147                    // Here we've seen NaNs in between some defined values, so we
148                    // can interpolate.
149                    let mut iter = self
150                        .interpolator
151                        // We need to interpolate `n + 1` values, because `n` doesn't
152                        // include the last non-NaN value which `interpolate` expects
153                        // to be included.
154                        .interpolate(self.low, high, n + 1)
155                        // Limit the number of values we yield to `n` since we know we need
156                        // that many NaNs but can't ensure that downstream implementors of
157                        // `Interpolater` respect that.
158                        .take(n + 1)
159                        // Skip the first value, which is the low value we've already seen.
160                        .skip(1);
161                    let first = iter.next();
162                    self.buf = iter.collect();
163                    first
164                } else {
165                    // We've seen NaNs at the end. Fill the buffer with NaNs to be
166                    // used by any subsequent calls to `next`.
167                    self.buf = repeat_with(Self::Item::nan).take(n - 1).collect();
168                    Some(T::Item::nan())
169                }
170            }
171            Some(x) => {
172                // We've seen a defined value, so we can store it as the low value
173                // for the next iteration and yield it.
174                self.low = x;
175                Some(x)
176            }
177            // We've reached the end of the input.
178            None => None,
179        }
180    }
181}
182
183/// An extension trait for iterators that adds the `interpolation` method.
184pub trait InterpolateExt: Iterator {
185    /// Interpolate between NaN values in the input.
186    ///
187    /// Returns an iterator that yields the same number of values as the input,
188    /// but with any NaN values replaced by linearly interpolated values.
189    ///
190    /// If the first or last value in the input is NaN, the iterator will
191    /// yield NaN values at the start or end of the output, respectively.
192    ///
193    /// # Example
194    /// ```
195    /// use augurs_forecaster::transforms::interpolate::*;
196    /// let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, 2.0];
197    /// let interp: Vec<_> = x.into_iter().interpolate(LinearInterpolator::default()).collect();
198    /// assert_eq!(interp, vec![1.0, 1.25, 1.5, 1.75, 2.0]);
199    /// ```
200    fn interpolate<I>(self, method: I) -> Interpolate<Self, I>
201    where
202        Self: Sized,
203        Self::Item: Interpolatable + Sized,
204        I: Interpolater,
205    {
206        Interpolate {
207            inner: self,
208            low: Self::Item::nan(),
209            high: None,
210            buf: VecDeque::new(),
211            interpolator: method,
212        }
213    }
214}
215
216impl<T> InterpolateExt for T where T: Iterator {}
217
/// Numeric types whose missing values can be filled by an interpolator.
///
/// Besides basic arithmetic, a type needs a NaN value (used to mark missing
/// entries) and a conversion from `usize` (used when computing step sizes and
/// offsets). Implementations are provided for `f32` and `f64`; other types can
/// implement it as needed.
pub trait Interpolatable:
    Add<Self, Output = Self>
    + Div<Self, Output = Self>
    + Mul<Self, Output = Self>
    + Sub<Self, Output = Self>
    + Copy
    + Default
    + Sized
{
    /// The NaN value of this type.
    fn nan() -> Self;

    /// Whether this value is NaN.
    fn is_nan(&self) -> bool;

    /// Convert `x` into this type; large values may lose precision.
    fn from_usize(x: usize) -> Self;
}

impl Interpolatable for f32 {
    fn nan() -> Self {
        Self::NAN
    }

    fn is_nan(&self) -> bool {
        // Call the inherent method by path so it cannot resolve back to this
        // trait method.
        f32::is_nan(*self)
    }

    fn from_usize(x: usize) -> Self {
        x as f32
    }
}

impl Interpolatable for f64 {
    fn nan() -> Self {
        Self::NAN
    }

    fn is_nan(&self) -> bool {
        // Call the inherent method by path so it cannot resolve back to this
        // trait method.
        f64::is_nan(*self)
    }

    fn from_usize(x: usize) -> Self {
        x as f64
    }
}
265
#[cfg(test)]
mod test {
    use super::*;

    const NAN: f32 = f32::NAN;

    /// Float equality within `f32::EPSILON`, where NaN matches NaN.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a.is_nan() && b.is_nan()) || (a - b).abs() < f32::EPSILON
    }

    /// Fall back to `assert_eq!` (which prints both slices) when the lengths
    /// differ or any element pair is not `approx_eq`.
    fn assert_all_approx_eq(got: &[f32], want: &[f32]) {
        let matches =
            got.len() == want.len() && got.iter().zip(want).all(|(&g, &w)| approx_eq(g, w));
        if !matches {
            assert_eq!(got, want);
        }
    }

    /// Run `input` through the interpolation adapter with a linear interpolator.
    fn fill(input: &[f32]) -> Vec<f32> {
        input
            .iter()
            .copied()
            .interpolate(LinearInterpolator::default())
            .collect()
    }

    #[test]
    fn linear_interpreter() {
        let got: Vec<_> = LinearInterpolator::default()
            .interpolate(1.0, 2.0, 4)
            .collect();
        assert_eq!(got, vec![1.0, 1.25, 1.5, 1.75]);
    }

    #[test]
    fn all_nan() {
        let x = [NAN, NAN, NAN];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn empty() {
        assert_all_approx_eq(&fill(&[]), &[]);
    }

    #[test]
    fn all_defined() {
        let x = [1.0, 2.0, 3.0];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn nans_in_middle() {
        let got = fill(&[1.0, NAN, NAN, NAN, 2.0]);
        assert_all_approx_eq(&got, &[1.0, 1.25, 1.5, 1.75, 2.0]);
    }

    #[test]
    fn nans_at_start() {
        let got = fill(&[NAN, NAN, 1.0, NAN, NAN, NAN, 2.0]);
        assert_all_approx_eq(&got, &[NAN, NAN, 1.0, 1.25, 1.5, 1.75, 2.0]);
    }

    #[test]
    fn nans_at_end() {
        let got = fill(&[1.0, NAN, NAN, NAN, 2.0, NAN, NAN]);
        assert_all_approx_eq(&got, &[1.0, 1.25, 1.5, 1.75, 2.0, NAN, NAN]);
    }

    #[test]
    fn one_nan() {
        let got = fill(&[0.0, 1.0, NAN, 2.0, 3.0]);
        assert_all_approx_eq(&got, &[0.0, 1.0, 1.5, 2.0, 3.0]);
    }

    #[test]
    fn one_value() {
        let x = [1.0];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn one_value_amongst_nans() {
        let x = [NAN, NAN, 1.0, NAN, NAN];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn one_value_before_nans() {
        let x = [1.0, NAN, NAN, NAN, NAN];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn one_value_after_nans() {
        let x = [NAN, NAN, NAN, NAN, 1.0];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn everything() {
        let got = fill(&[NAN, NAN, 1.0, NAN, NAN, NAN, 2.0, NAN, NAN]);
        assert_all_approx_eq(
            &got,
            &[NAN, NAN, 1.0, 1.25, 1.5, 1.75, 2.0, NAN, NAN],
        );
    }
}