use std::{
collections::VecDeque,
iter::repeat_with,
ops::{Add, Div, Mul, Sub},
};
use super::{Error, Transformer};
/// Strategy for producing values between two known endpoints.
///
/// [`Interpolate`] calls this to replace a run of NaN items that sits between
/// two non-NaN items; [`LinearInterpolator`] is the implementation provided in
/// this file.
pub trait Interpolater {
/// Returns an iterator of `T` values computed from `low`, `high` and `n`.
///
/// NOTE(review): the contract is not spelled out here. `Interpolate` asks for
/// `n + 1` points and discards the first one, which presumes implementors
/// yield `low` first and stop short of `high` (as `LinearInterpolator`
/// does) — confirm before adding another implementation.
fn interpolate<T: Interpolatable>(&self, low: T, high: T, n: usize) -> impl Iterator<Item = T>;
}
/// Interpolation strategy that spaces points evenly between two endpoints.
///
/// The type holds no configuration. Its single private unit field stops code
/// outside this module from building it with a struct literal, so construction
/// goes through [`LinearInterpolator::new`] or `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct LinearInterpolator {
    _priv: (),
}

impl LinearInterpolator {
    /// Creates a `LinearInterpolator`; the result is the same value that
    /// `Default::default()` produces.
    pub fn new() -> Self {
        Self { _priv: () }
    }
}
impl Interpolater for LinearInterpolator {
    /// Yields `n` evenly spaced values starting at `low` and stopping one step
    /// short of `high`.
    ///
    /// The i-th item (counting from zero) is `low + i * (high - low) / n`, so
    /// `low` is included and `high` is not. With `n == 0` the iterator is empty.
    fn interpolate<T: Interpolatable>(&self, low: T, high: T, n: usize) -> impl Iterator<Item = T> {
        // The step is only meaningful when at least one point is requested.
        // Skipping the division for `n == 0` keeps `Interpolatable` types whose
        // `Div` traps on a zero divisor (integer- or fixed-point-backed types)
        // from panicking on a request that produces no items anyway.
        let step = if n == 0 {
            T::default()
        } else {
            (high - low) / T::from_usize(n)
        };
        (0..n).map(move |i| low + T::from_usize(i) * step)
    }
}
impl Transformer for LinearInterpolator {
fn fit(&mut self, _data: &[f64]) -> Result<(), Error> {
Ok(())
}
fn transform(&self, data: &mut [f64]) -> Result<(), Error> {
let interpolated: Vec<_> = data.iter().copied().interpolate(*self).collect();
data.copy_from_slice(&interpolated);
Ok(())
}
fn inverse_transform(&self, _data: &mut [f64]) -> Result<(), Error> {
Ok(())
}
}
/// Iterator adaptor that yields the items of `inner`, substituting values for
/// runs of NaN items.
///
/// Built by [`InterpolateExt::interpolate`]. A NaN run that lies between two
/// non-NaN items is replaced with values obtained from `interpolator`; a run
/// with no non-NaN item before it, or none after it, is yielded as NaN.
#[derive(Debug, Clone)]
pub struct Interpolate<T: Iterator, I> {
/// The wrapped source iterator.
inner: T,
/// The most recent non-NaN item yielded; NaN until one has been seen.
low: T::Item,
/// The non-NaN item that ended the current NaN run. It is read ahead from
/// `inner` and held here until the queued values in `buf` have been yielded.
high: Option<T::Item>,
/// Values already computed for the current NaN run, waiting to be yielded.
buf: VecDeque<T::Item>,
/// Strategy used to compute the replacement values for a NaN run.
interpolator: I,
}
impl<T, I> Iterator for Interpolate<T, I>
where
    T: Iterator,
    T::Item: Interpolatable,
    I: Interpolater,
{
    type Item = T::Item;

    /// Yields the next output item.
    ///
    /// Non-NaN items pass through unchanged. When a NaN is met, the whole run
    /// of consecutive NaNs is consumed from the source together with the
    /// non-NaN item that ends it (if any); the replacement values are queued
    /// and handed out on subsequent calls, followed by that terminating item.
    fn next(&mut self) -> Option<Self::Item> {
        // Queued replacement values always go out first.
        if let Some(queued) = self.buf.pop_front() {
            return Some(queued);
        }

        // Then the read-ahead item that closed the last NaN run; it becomes
        // the lower anchor for whatever run comes next.
        if let Some(anchor) = self.high.take() {
            self.low = anchor;
            return Some(anchor);
        }

        let current = self.inner.next()?;
        if !current.is_nan() {
            self.low = current;
            return Some(current);
        }

        // `current` opens a NaN run: measure it and capture the item that
        // ends it, if the source has one.
        let mut gap_len: usize = 1;
        for candidate in self.inner.by_ref() {
            if !candidate.is_nan() {
                self.high = Some(candidate);
                break;
            }
            gap_len += 1;
        }

        // No lower anchor yet (the run is at the very start): emit NaNs.
        if self.low.is_nan() {
            self.buf = repeat_with(Self::Item::nan).take(gap_len - 1).collect();
            return Some(self.low);
        }

        match self.high {
            Some(upper) => {
                // Request one point more than the run length and drop the
                // first, which is the lower anchor that was already yielded.
                let mut fill = self
                    .interpolator
                    .interpolate(self.low, upper, gap_len + 1)
                    .take(gap_len + 1)
                    .skip(1);
                let head = fill.next();
                self.buf = fill.collect();
                head
            }
            None => {
                // The source ended inside the run: nothing to interpolate
                // towards, so the run stays NaN.
                self.buf = repeat_with(Self::Item::nan).take(gap_len - 1).collect();
                Some(T::Item::nan())
            }
        }
    }
}
/// Extension trait that adds an `interpolate` adaptor to every iterator.
pub trait InterpolateExt: Iterator {
    /// Wraps this iterator in an [`Interpolate`] adaptor that uses `method`
    /// to replace runs of NaN items.
    ///
    /// The adaptor starts with a NaN lower anchor, no pending upper anchor
    /// and an empty queue of replacement values.
    fn interpolate<I>(self, method: I) -> Interpolate<Self, I>
    where
        Self: Sized,
        Self::Item: Interpolatable + Sized,
        I: Interpolater,
    {
        let initial_low = Self::Item::nan();
        Interpolate {
            interpolator: method,
            inner: self,
            low: initial_low,
            buf: VecDeque::new(),
            high: None,
        }
    }
}

/// Blanket implementation: anything that is an `Iterator` gets the adaptor.
impl<It: Iterator> InterpolateExt for It {}
/// Numeric-like value that the interpolation code in this file can work with.
///
/// Implementors must support the four arithmetic operators on `Self`, be
/// `Copy` and `Default`, and have a "missing" value that [`Interpolatable::nan`]
/// produces and [`Interpolatable::is_nan`] recognises. Implementations for
/// `f32` and `f64` are provided in this file.
pub trait Interpolatable:
Add<Self, Output = Self>
+ Div<Self, Output = Self>
+ Mul<Self, Output = Self>
+ Sub<Self, Output = Self>
+ Copy
+ Default
+ Sized
{
/// Returns the value used to represent a missing item.
fn nan() -> Self;
/// Reports whether this value is a missing item.
fn is_nan(&self) -> bool;
/// Converts a count or index into `Self` so it can take part in arithmetic.
fn from_usize(x: usize) -> Self;
}
impl Interpolatable for f32 {
    /// Returns `f32::NAN`.
    fn nan() -> Self {
        Self::NAN
    }

    /// Reports whether the value is NaN, using the inherent `f32::is_nan`.
    fn is_nan(&self) -> bool {
        let value: f32 = *self;
        // The receiver is a plain `f32`, so this picks the inherent method
        // rather than calling back into this trait method.
        value.is_nan()
    }

    /// Converts with an `as` cast.
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
impl Interpolatable for f64 {
    /// Returns `f64::NAN`.
    fn nan() -> Self {
        Self::NAN
    }

    /// Reports whether the value is NaN, using the inherent `f64::is_nan`.
    fn is_nan(&self) -> bool {
        let value: f64 = *self;
        // The receiver is a plain `f64`, so this picks the inherent method
        // rather than calling back into this trait method.
        value.is_nan()
    }

    /// Converts with an `as` cast.
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `input` through the NaN-filling adaptor with a linear interpolator.
    fn fill(input: &[f32]) -> Vec<f32> {
        input
            .iter()
            .copied()
            .interpolate(LinearInterpolator::default())
            .collect()
    }

    /// True when both values are NaN or they differ by less than `f32::EPSILON`.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a.is_nan() && b.is_nan()) || (a - b).abs() < f32::EPSILON
    }

    /// Compares two slices element-wise with `approx_eq`; on a length or value
    /// mismatch it defers to `assert_eq!` so the failure prints both slices.
    fn assert_all_approx_eq(a: &[f32], b: &[f32]) {
        let matches =
            a.len() == b.len() && a.iter().zip(b).all(|(&left, &right)| approx_eq(left, right));
        if !matches {
            assert_eq!(a, b);
        }
    }

    #[test]
    fn linear_interpreter() {
        let got: Vec<_> = LinearInterpolator::default()
            .interpolate(1.0, 2.0, 4)
            .collect();
        assert_eq!(got, vec![1.0, 1.25, 1.5, 1.75]);
    }

    #[test]
    fn all_nan() {
        let x = vec![f32::NAN, f32::NAN, f32::NAN];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn empty() {
        let x: Vec<f32> = vec![];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn all_defined() {
        let x = vec![1.0, 2.0, 3.0];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn nans_in_middle() {
        let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, 2.0];
        assert_all_approx_eq(&fill(&x), &[1.0, 1.25, 1.5, 1.75, 2.0]);
    }

    #[test]
    fn nans_at_start() {
        let x = vec![f32::NAN, f32::NAN, 1.0, f32::NAN, f32::NAN, f32::NAN, 2.0];
        assert_all_approx_eq(
            &fill(&x),
            &[f32::NAN, f32::NAN, 1.0, 1.25, 1.5, 1.75, 2.0],
        );
    }

    #[test]
    fn nans_at_end() {
        let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, 2.0, f32::NAN, f32::NAN];
        assert_all_approx_eq(
            &fill(&x),
            &[1.0, 1.25, 1.5, 1.75, 2.0, f32::NAN, f32::NAN],
        );
    }

    #[test]
    fn one_nan() {
        let x = vec![0.0, 1.0, f32::NAN, 2.0, 3.0];
        assert_all_approx_eq(&fill(&x), &[0.0, 1.0, 1.5, 2.0, 3.0]);
    }

    #[test]
    fn one_value() {
        let x = vec![1.0];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn one_value_amongst_nans() {
        let x = vec![f32::NAN, f32::NAN, 1.0, f32::NAN, f32::NAN];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn one_value_before_nans() {
        let x = vec![1.0, f32::NAN, f32::NAN, f32::NAN, f32::NAN];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn one_value_after_nans() {
        let x = vec![f32::NAN, f32::NAN, f32::NAN, f32::NAN, 1.0];
        assert_all_approx_eq(&fill(&x), &x);
    }

    #[test]
    fn everything() {
        let x = vec![
            f32::NAN,
            f32::NAN,
            1.0,
            f32::NAN,
            f32::NAN,
            f32::NAN,
            2.0,
            f32::NAN,
            f32::NAN,
        ];
        let expected = [
            f32::NAN,
            f32::NAN,
            1.0,
            1.25,
            1.5,
            1.75,
            2.0,
            f32::NAN,
            f32::NAN,
        ];
        assert_all_approx_eq(&fill(&x), &expected);
    }
}