use arrow::{array::PrimitiveArray, datatypes::ArrowPrimitiveType};
use num_traits::NumCast;
#[derive(Debug, PartialEq)]
pub struct PreLookup(usize, f64);
pub fn find_index_ref<T>(array: &PrimitiveArray<T>, value: T::Native, after_event: bool) -> usize
where
T: ArrowPrimitiveType,
T::Native: NumCast,
{
let mut row = 0;
for i in 1..array.len() {
if array.value(i) >= value {
break;
}
row = i;
}
if after_event {
while row < array.len() - 2 {
if array.value(row + 1) > value {
break;
}
row += 1;
}
}
row
}
pub fn find_index<T>(array: &PrimitiveArray<T>, value: T::Native, after_event: bool) -> usize
where
T: ArrowPrimitiveType,
T::Native: NumCast,
{
let row = array
.values()
.iter()
.position(|&x| x >= value)
.unwrap_or(array.len())
.saturating_sub(1);
let row = if after_event {
array.values()[row..]
.iter()
.position(|&x| x > value)
.map(|i| (i + row).saturating_sub(1))
.unwrap_or(row)
} else {
row
};
row
}
impl PreLookup {
pub fn new<T>(array: &PrimitiveArray<T>, value: T::Native, after_event: bool) -> Self
where
T: ArrowPrimitiveType,
T::Native: NumCast,
{
let index = find_index_ref(array, value, after_event);
if index == array.len() - 1 {
return Self(index, 1.0);
}
let t0: f64 = NumCast::from(array.value(index)).unwrap();
let t1: f64 = NumCast::from(array.value(index + 1)).unwrap();
let value: f64 = NumCast::from(value).unwrap();
let fraction = (value - t0) / (t1 - t0);
Self(index, fraction)
}
#[allow(dead_code)]
pub fn interpolate<T, A>(&self, array: &PrimitiveArray<T>) -> T::Native
where
T: ArrowPrimitiveType,
T::Native: NumCast,
A: Interpolate,
{
A::interpolate(self, array)
}
#[allow(dead_code)]
pub fn next_index(&self) -> usize {
if self.1 < 1.0 { self.0 } else { self.0 + 1 }
}
}
pub trait Interpolate {
fn interpolate<T>(pre: &PreLookup, array: &PrimitiveArray<T>) -> T::Native
where
T: ArrowPrimitiveType,
T::Native: NumCast;
}
#[allow(dead_code)]
pub struct Nearest;
impl Interpolate for Nearest {
fn interpolate<T>(pre: &PreLookup, array: &PrimitiveArray<T>) -> T::Native
where
T: ArrowPrimitiveType,
T::Native: NumCast,
{
let (index, fraction) = (pre.0, pre.1);
let index = if fraction < 0.5 { index } else { index + 1 }.min(array.len() - 1);
array.value(index)
}
}
pub struct Linear;
impl Interpolate for Linear {
fn interpolate<T>(pre: &PreLookup, array: &PrimitiveArray<T>) -> T::Native
where
T: ArrowPrimitiveType,
T::Native: NumCast,
{
let (index, fraction) = (pre.0.min(array.len() - 2), pre.1);
let t0: f64 = NumCast::from(array.value(index)).unwrap();
let t1: f64 = NumCast::from(array.value(index + 1)).unwrap();
NumCast::from(t0 + fraction * (t1 - t0)).unwrap()
}
}
#[allow(dead_code)]
pub struct Cubic;
impl Interpolate for Cubic {
fn interpolate<T>(pre: &PreLookup, array: &PrimitiveArray<T>) -> T::Native
where
T: ArrowPrimitiveType,
T::Native: NumCast,
{
let (index, fraction) = (pre.0.min(array.len() - 4), pre.1);
let t0: f64 = NumCast::from(array.value(index)).unwrap();
let t1: f64 = NumCast::from(array.value(index + 1)).unwrap();
let t2: f64 = NumCast::from(array.value(index + 2)).unwrap();
let t3: f64 = NumCast::from(array.value(index + 3)).unwrap();
let a = -0.5 * t0 + 1.5 * t1 - 1.5 * t2 + 0.5 * t3;
let b = t0 - 2.5 * t1 + 2.0 * t2 - 0.5 * t3;
let c = -0.5 * t0 + 0.5 * t2;
let d = t1;
NumCast::from(a * fraction.powi(3) + b * fraction.powi(2) + c * fraction + d).unwrap()
}
}
#[cfg(test)]
mod tests {
use super::{Interpolate, Linear, PreLookup};
use arrow::{array::PrimitiveArray, datatypes::Int32Type};
#[test]
fn test_pre_lookup() {
let array = PrimitiveArray::from(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
assert_eq!(PreLookup::new(&array, -1.0, false), PreLookup(0, -1.0));
assert_eq!(PreLookup::new(&array, 0.0, false), PreLookup(0, 0.0));
assert_eq!(PreLookup::new(&array, 0.5, false), PreLookup(0, 0.5));
assert_eq!(PreLookup::new(&array, 1.0, false), PreLookup(0, 1.0));
assert_eq!(PreLookup::new(&array, 1.5, false), PreLookup(1, 0.5));
assert_eq!(PreLookup::new(&array, 2.0, false), PreLookup(1, 1.0));
assert_eq!(PreLookup::new(&array, 2.5, false), PreLookup(2, 0.5));
assert_eq!(PreLookup::new(&array, 3.0, false), PreLookup(2, 1.0));
assert_eq!(PreLookup::new(&array, 3.5, false), PreLookup(3, 0.5));
assert_eq!(PreLookup::new(&array, 4.0, false), PreLookup(3, 1.0));
assert_eq!(PreLookup::new(&array, 5.0, false), PreLookup(4, 1.0));
}
#[test]
fn test_interpolation() {
let time = PrimitiveArray::from(vec![0.0, 2.0, 3.0]);
let value1 = PrimitiveArray::from(vec![0.0, 2.0, 4.0]);
let value2 = PrimitiveArray::<Int32Type>::from(vec![1, 3, 5]);
let pl0 = PreLookup::new(&time, 0.0, false);
let pl1 = PreLookup::new(&time, 1.0, false);
let pl15 = PreLookup::new(&time, 1.5, false);
assert_eq!(Linear::interpolate(&pl0, &value1), 0.0);
assert_eq!(Linear::interpolate(&pl0, &value2), 1);
assert_eq!(Linear::interpolate(&pl1, &value1), 1.0);
assert_eq!(Linear::interpolate(&pl1, &value2), 2);
assert_eq!(Linear::interpolate(&pl15, &value1), 1.5);
assert_eq!(Linear::interpolate(&pl15, &value2), 2);
}
}