augurs_forecaster/transforms/
interpolate.rs1use std::{
10 collections::VecDeque,
11 iter::repeat_with,
12 ops::{Add, Div, Mul, Sub},
13};
14
15use super::{Error, Transformer};
16
17pub trait Interpolater {
19 fn interpolate<T: Interpolatable>(&self, low: T, high: T, n: usize) -> impl Iterator<Item = T>;
29}
30
/// Linear interpolation between two known values.
///
/// This type is zero-sized and stateless. Construct it with
/// [`LinearInterpolator::new`] or [`Default::default`]. The private field
/// prevents construction through a struct literal outside this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct LinearInterpolator {
    _priv: (),
}

impl LinearInterpolator {
    /// Creates a new linear interpolator (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Self { _priv: () }
    }
}
53
54impl Interpolater for LinearInterpolator {
55 fn interpolate<T: Interpolatable>(&self, low: T, high: T, n: usize) -> impl Iterator<Item = T> {
56 let diff = high - low;
57 let step = diff / (T::from_usize(n));
58 (0..n).map(move |i| low + T::from_usize(i) * step)
59 }
60}
61
62impl Transformer for LinearInterpolator {
63 fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
64 Ok(())
65 }
66
67 fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
68 let interpolated: Vec<_> = data.iter().copied().interpolate(*self).collect();
69 data.copy_from_slice(&interpolated);
70 Ok(())
71 }
72
73 fn inverse_transform(&self, _data: &mut [f64]) -> Result<(), Error> {
74 Ok(())
75 }
76}
77
/// Iterator adapter that fills runs of NaN values using an [`Interpolater`].
///
/// Created by [`InterpolateExt::interpolate`].
///
/// - A run of NaNs with a defined value on both sides is replaced by
///   interpolated values.
/// - A run at the start or end of the input, which lacks a defined neighbour
///   on one side, is passed through as NaN.
#[derive(Debug, Clone)]
pub struct Interpolate<T, I>
where
    T: Iterator,
{
    /// Source iterator.
    inner: T,
    /// Most recent defined value emitted; NaN until one has been seen.
    low: T::Item,
    /// Defined value that ended the current NaN run; emitted once `buf` drains.
    high: Option<T::Item>,
    /// Values already computed for the current NaN run, waiting to be emitted.
    buf: VecDeque<T::Item>,
    /// Strategy used to fill NaN runs.
    interpolator: I,
}
103
104impl<T, I> Iterator for Interpolate<T, I>
105where
106 T: Iterator,
107 T::Item: Interpolatable,
108 I: Interpolater,
109{
110 type Item = T::Item;
111
112 fn next(&mut self) -> Option<Self::Item> {
113 if !self.buf.is_empty() {
115 return self.buf.pop_front();
116 }
117
118 if let Some(high) = self.high.take() {
121 self.low = high;
122 return Some(high);
123 }
124
125 let next = self.inner.next();
126 match next {
127 Some(x) if x.is_nan() => {
128 let mut n: usize = 1;
130 for h in self.inner.by_ref() {
131 if h.is_nan() {
132 n += 1;
133 continue;
134 }
135 self.high = Some(h);
137 break;
138 }
139
140 if self.low.is_nan() {
141 self.buf = repeat_with(Self::Item::nan).take(n - 1).collect();
143 return Some(self.low);
144 }
145
146 if let Some(high) = self.high {
147 let mut iter = self
150 .interpolator
151 .interpolate(self.low, high, n + 1)
155 .take(n + 1)
159 .skip(1);
161 let first = iter.next();
162 self.buf = iter.collect();
163 first
164 } else {
165 self.buf = repeat_with(Self::Item::nan).take(n - 1).collect();
168 Some(T::Item::nan())
169 }
170 }
171 Some(x) => {
172 self.low = x;
175 Some(x)
176 }
177 None => None,
179 }
180 }
181}
182
183pub trait InterpolateExt: Iterator {
185 fn interpolate<I>(self, method: I) -> Interpolate<Self, I>
201 where
202 Self: Sized,
203 Self::Item: Interpolatable + Sized,
204 I: Interpolater,
205 {
206 Interpolate {
207 inner: self,
208 low: Self::Item::nan(),
209 high: None,
210 buf: VecDeque::new(),
211 interpolator: method,
212 }
213 }
214}
215
216impl<T> InterpolateExt for T where T: Iterator {}
217
/// Numeric types that can be interpolated and that use NaN to mark a missing
/// observation.
pub trait Interpolatable:
    Add<Self, Output = Self>
    + Div<Self, Output = Self>
    + Mul<Self, Output = Self>
    + Sub<Self, Output = Self>
    + Copy
    + Default
    + Sized
{
    /// Returns the NaN value used to mark a missing observation.
    fn nan() -> Self;

    /// Returns `true` if this value is NaN, i.e. missing.
    fn is_nan(&self) -> bool;

    /// Converts a count or index to `Self`. May lose precision for very large
    /// values.
    fn from_usize(x: usize) -> Self;
}

// Implements `Interpolatable` for primitive floating-point types by
// delegating to their inherent `NAN` constant and `is_nan` method.
macro_rules! impl_interpolatable_float {
    ($($float:ty),* $(,)?) => {
        $(
            impl Interpolatable for $float {
                fn nan() -> Self {
                    <$float>::NAN
                }

                fn is_nan(&self) -> bool {
                    <$float>::is_nan(*self)
                }

                fn from_usize(x: usize) -> Self {
                    // `as` rounds to the nearest representable value.
                    x as $float
                }
            }
        )*
    };
}

impl_interpolatable_float!(f32, f64);
265
#[cfg(test)]
mod test {
    use super::*;

    /// Returns `true` if `a` and `b` are both NaN, or are equal within a small
    /// tolerance. The tolerance scales with magnitude and is absolute for
    /// values below 1.
    fn approx_eq(a: f32, b: f32) -> bool {
        if a.is_nan() || b.is_nan() {
            return a.is_nan() && b.is_nan();
        }
        // Exact equality also covers matching infinities, whose difference is NaN.
        if a == b {
            return true;
        }
        let scale = a.abs().max(b.abs()).max(1.0);
        (a - b).abs() <= 4.0 * f32::EPSILON * scale
    }

    /// Asserts that both slices have the same length and are element-wise
    /// [`approx_eq`], reporting the first mismatching index.
    fn assert_all_approx_eq(actual: &[f32], expected: &[f32]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                approx_eq(*a, *e),
                "mismatch at index {i}: got {a}, expected {e}\n  actual:   {actual:?}\n  expected: {expected:?}"
            );
        }
    }

    #[test]
    fn linear_interpolator() {
        let got = LinearInterpolator::default()
            .interpolate(1.0, 2.0, 4)
            .collect::<Vec<_>>();
        assert_eq!(got, vec![1.0, 1.25, 1.5, 1.75]);
    }

    #[test]
    fn all_nan() {
        let x = vec![f32::NAN, f32::NAN, f32::NAN];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &x);
    }

    #[test]
    fn empty() {
        let x: Vec<f32> = vec![];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &x);
    }

    #[test]
    fn all_defined() {
        let x = vec![1.0, 2.0, 3.0];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &x);
    }

    #[test]
    fn nans_in_middle() {
        let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, 2.0];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &[1.0, 1.25, 1.5, 1.75, 2.0]);
    }

    #[test]
    fn nans_at_start() {
        let x = vec![f32::NAN, f32::NAN, 1.0, f32::NAN, f32::NAN, f32::NAN, 2.0];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &[f32::NAN, f32::NAN, 1.0, 1.25, 1.5, 1.75, 2.0]);
    }

    #[test]
    fn nans_at_end() {
        let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, 2.0, f32::NAN, f32::NAN];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &[1.0, 1.25, 1.5, 1.75, 2.0, f32::NAN, f32::NAN]);
    }

    #[test]
    fn one_nan() {
        let x = vec![0.0, 1.0, f32::NAN, 2.0, 3.0];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &[0.0, 1.0, 1.5, 2.0, 3.0]);
    }

    #[test]
    fn adjacent_gaps() {
        let x = vec![1.0, f32::NAN, 2.0, f32::NAN, 3.0];
        let interp: Vec<_> = x
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &[1.0, 1.5, 2.0, 2.5, 3.0]);
    }

    #[test]
    fn one_value() {
        let x = vec![1.0];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &x);
    }

    #[test]
    fn one_value_amongst_nans() {
        let x = vec![f32::NAN, f32::NAN, 1.0, f32::NAN, f32::NAN];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &x);
    }

    #[test]
    fn one_value_before_nans() {
        let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, f32::NAN];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &x);
    }

    #[test]
    fn one_value_after_nans() {
        let x = vec![f32::NAN, f32::NAN, f32::NAN, f32::NAN, 1.0];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(&interp, &x);
    }

    #[test]
    fn everything() {
        let x = vec![
            f32::NAN,
            f32::NAN,
            1.0,
            f32::NAN,
            f32::NAN,
            f32::NAN,
            2.0,
            f32::NAN,
            f32::NAN,
        ];
        let interp: Vec<_> = x
            .clone()
            .into_iter()
            .interpolate(LinearInterpolator::default())
            .collect();
        assert_all_approx_eq(
            &interp,
            &[
                f32::NAN,
                f32::NAN,
                1.0,
                1.25,
                1.5,
                1.75,
                2.0,
                f32::NAN,
                f32::NAN,
            ],
        );
    }

    #[test]
    fn f64_values() {
        let x = vec![0.0_f64, f64::NAN, 10.0];
        let interp: Vec<f64> = x
            .into_iter()
            .interpolate(LinearInterpolator::new())
            .collect();
        assert_eq!(interp, vec![0.0, 5.0, 10.0]);
    }

    #[test]
    fn transform_fills_in_place() {
        let mut data = [1.0, f64::NAN, f64::NAN, 4.0, f64::NAN];
        assert!(LinearInterpolator::new().transform(&mut data).is_ok());
        assert_eq!(&data[..4], &[1.0, 2.0, 3.0, 4.0]);
        assert!(data[4].is_nan());
    }
}