use crate::interpolation::{InterpolationIndex, InterpolationValue, Interpolator};
use RustQuant_error::RustQuantError;
/// Piecewise-linear interpolator over a set of `(x, y)` knots.
///
/// `xs` holds the knot positions and `ys` the matching values. The constructor
/// sorts the knots by x and `add_point` inserts in order, so `xs` is expected to
/// stay sorted ascending and to have the same length as `ys`.
pub struct LinearInterpolator<
    IndexType: InterpolationIndex<DeltaDiv = ValueType>,
    ValueType: InterpolationValue,
> {
    /// Knot positions (x-values).
    pub xs: Vec<IndexType>,
    /// Knot values (y-values), aligned index-by-index with `xs`.
    pub ys: Vec<ValueType>,
    /// `false` after construction; set to `true` by `fit`.
    pub fitted: bool,
}
impl<IndexType, ValueType> LinearInterpolator<IndexType, ValueType>
where
    IndexType: InterpolationIndex<DeltaDiv = ValueType>,
    ValueType: InterpolationValue,
{
    /// Builds an interpolator from knot positions `xs` and values `ys`.
    ///
    /// The `(x, y)` pairs are sorted by x with a stable sort, so pairs sharing the
    /// same x keep their input order. The returned interpolator has
    /// `fitted == false`.
    ///
    /// # Errors
    ///
    /// Returns [`RustQuantError::UnequalLength`] when `xs` and `ys` differ in length.
    ///
    /// # Panics
    ///
    /// Panics if the sort compares two x-values that have no ordering
    /// (e.g. a NaN among float `xs`).
    pub fn new(
        xs: Vec<IndexType>,
        ys: Vec<ValueType>,
    ) -> Result<LinearInterpolator<IndexType, ValueType>, RustQuantError> {
        if xs.len() != ys.len() {
            return Err(RustQuantError::UnequalLength);
        }

        // Sort x and y together so each value stays attached to its position.
        let mut knots: Vec<(IndexType, ValueType)> = xs.into_iter().zip(ys).collect();
        knots.sort_by(|(x_a, _), (x_b, _)| x_a.partial_cmp(x_b).unwrap());

        let (sorted_xs, sorted_ys): (Vec<IndexType>, Vec<ValueType>) =
            knots.into_iter().unzip();

        Ok(Self {
            xs: sorted_xs,
            ys: sorted_ys,
            fitted: false,
        })
    }
}
impl<IndexType, ValueType> Interpolator<IndexType, ValueType>
    for LinearInterpolator<IndexType, ValueType>
where
    IndexType: InterpolationIndex<DeltaDiv = ValueType>,
    ValueType: InterpolationValue,
{
    /// Marks the interpolator as fitted.
    ///
    /// Linear interpolation needs no precomputation, so this only sets `fitted`
    /// and always returns `Ok(())`.
    fn fit(&mut self) -> Result<(), RustQuantError> {
        self.fitted = true;
        Ok(())
    }

    /// Returns the first and last stored x-values.
    ///
    /// These are the smallest and largest x-values, since `xs` is kept sorted by
    /// the constructor and by `add_point`.
    ///
    /// # Panics
    ///
    /// Panics if the interpolator holds no points.
    fn range(&self) -> (IndexType, IndexType) {
        match (self.xs.first(), self.xs.last()) {
            (Some(&lo), Some(&hi)) => (lo, hi),
            _ => panic!("LinearInterpolator::range called on an interpolator with no points"),
        }
    }

    /// Inserts a new `(x, y)` knot, keeping `xs` sorted.
    ///
    /// A point whose x equals an existing x is inserted before the existing one.
    fn add_point(&mut self, point: (IndexType, ValueType)) {
        let idx = self.xs.partition_point(|&x| x < point.0);
        self.xs.insert(idx, point.0);
        self.ys.insert(idx, point.1);
    }

    /// Linearly interpolates the value at `point`.
    ///
    /// If `point` coincides with a stored x-value, the stored y-value is returned
    /// directly (the leftmost one if x-values repeat). Otherwise the result lies on
    /// the straight line through the two neighbouring knots.
    ///
    /// # Errors
    ///
    /// Returns [`RustQuantError::OutsideOfRange`] if the interpolator holds no
    /// points, if `point` lies outside `[first x, last x]`, or if `point` cannot be
    /// compared with those bounds (e.g. a NaN).
    fn interpolate(&self, point: IndexType) -> Result<ValueType, RustQuantError> {
        // An empty interpolator has no range, so every point is out of range.
        let (lo, hi) = match (self.xs.first(), self.xs.last()) {
            (Some(&lo), Some(&hi)) => (lo, hi),
            _ => return Err(RustQuantError::OutsideOfRange),
        };

        // Written as a negated conjunction so that incomparable points (e.g. NaN),
        // for which every comparison is false, are rejected instead of panicking.
        if !(lo <= point && point <= hi) {
            return Err(RustQuantError::OutsideOfRange);
        }

        // Index of the first knot with x >= point; it is in bounds because point <= hi.
        let idx_r = self.xs.partition_point(|&x| x < point);

        // Exact hit on a knot: return its value without interpolating. This also
        // covers point == lo, the only case in which idx_r == 0.
        if self.xs[idx_r] == point {
            return Ok(self.ys[idx_r]);
        }

        let idx_l = idx_r - 1;
        let (x_l, x_r) = (self.xs[idx_l], self.xs[idx_r]);
        let (y_l, y_r) = (self.ys[idx_l], self.ys[idx_r]);

        // Here x_l < point < x_r, so the denominator x_r - x_l is non-zero.
        let weight = (point - x_l) / (x_r - x_l);
        Ok(y_l + (y_r - y_l) * weight)
    }
}
#[cfg(test)]
mod tests_linear_interpolation {
    use super::*;
    use time::macros::date;
    use RustQuant_utils::{assert_approx_equal, RUSTQUANT_EPSILON};

    #[test]
    fn test_linear_interpolation() {
        let xs = vec![1., 2., 3., 4., 5.];
        let ys = vec![1., 2., 3., 4., 5.];
        let mut interpolator = LinearInterpolator::new(xs, ys).unwrap();
        interpolator.fit().unwrap();
        assert_approx_equal!(
            2.5,
            interpolator.interpolate(2.5).unwrap(),
            RUSTQUANT_EPSILON
        );
        assert_approx_equal!(
            3.5,
            interpolator.interpolate(3.5).unwrap(),
            RUSTQUANT_EPSILON
        );
    }

    #[test]
    fn test_linear_interpolation_out_of_range() {
        let xs = vec![1., 2., 3., 4., 5.];
        let ys = vec![1., 2., 3., 4., 5.];
        let mut interpolator = LinearInterpolator::new(xs, ys).unwrap();
        interpolator.fit().unwrap();
        assert!(interpolator.interpolate(6.).is_err());
    }

    #[test]
    fn test_linear_interpolation_below_range() {
        let xs = vec![1., 2., 3., 4., 5.];
        let ys = vec![1., 2., 3., 4., 5.];
        let interpolator = LinearInterpolator::new(xs, ys).unwrap();
        assert!(interpolator.interpolate(0.).is_err());
    }

    #[test]
    fn test_linear_interpolation_at_knots() {
        let xs = vec![1., 2., 3.];
        let ys = vec![10., 40., 90.];
        let interpolator = LinearInterpolator::new(xs, ys).unwrap();
        // Both endpoints and an interior knot return the stored value.
        assert_approx_equal!(10., interpolator.interpolate(1.).unwrap(), RUSTQUANT_EPSILON);
        assert_approx_equal!(40., interpolator.interpolate(2.).unwrap(), RUSTQUANT_EPSILON);
        assert_approx_equal!(90., interpolator.interpolate(3.).unwrap(), RUSTQUANT_EPSILON);
    }

    #[test]
    fn test_linear_interpolation_sorts_input() {
        let interpolator =
            LinearInterpolator::new(vec![3., 1., 2.], vec![30., 10., 20.]).unwrap();
        assert_eq!(interpolator.xs, vec![1., 2., 3.]);
        assert_eq!(interpolator.ys, vec![10., 20., 30.]);
        assert_eq!(interpolator.range(), (1., 3.));
        assert_approx_equal!(15., interpolator.interpolate(1.5).unwrap(), RUSTQUANT_EPSILON);
    }

    #[test]
    fn test_linear_interpolation_unequal_lengths() {
        let result = LinearInterpolator::new(vec![1., 2., 3.], vec![1., 2.]);
        assert!(matches!(result, Err(RustQuantError::UnequalLength)));
    }

    #[test]
    fn test_linear_interpolation_add_point() {
        let mut interpolator = LinearInterpolator::new(vec![1., 3.], vec![1., 3.]).unwrap();
        interpolator.add_point((2., 5.));
        assert_eq!(interpolator.xs, vec![1., 2., 3.]);
        assert_eq!(interpolator.ys, vec![1., 5., 3.]);
        // Between (1, 1) and the new knot (2, 5).
        assert_approx_equal!(3., interpolator.interpolate(1.5).unwrap(), RUSTQUANT_EPSILON);
    }

    #[test]
    fn test_linear_interpolation_fit_sets_flag() {
        let mut interpolator = LinearInterpolator::new(vec![1., 2.], vec![1., 2.]).unwrap();
        assert!(!interpolator.fitted);
        interpolator.fit().unwrap();
        assert!(interpolator.fitted);
    }

    #[test]
    fn test_linear_interpolation_dates() {
        let now = time::OffsetDateTime::now_utc();
        let xs = vec![
            now,
            now + time::Duration::days(1),
            now + time::Duration::days(2),
            now + time::Duration::days(3),
            now + time::Duration::days(4),
        ];
        let ys = vec![1., 2., 3., 4., 5.];
        let mut interpolator = LinearInterpolator::new(xs.clone(), ys).unwrap();
        interpolator.fit().unwrap();
        assert_approx_equal!(
            2.5,
            interpolator
                .interpolate(xs[1] + time::Duration::hours(12))
                .unwrap(),
            RUSTQUANT_EPSILON
        );
    }

    #[test]
    fn test_linear_interpolation_dates_textbook() {
        let d_1m = date!(1990 - 06 - 16);
        let d_2m = date!(1990 - 07 - 17);
        let r_1m = 0.9870;
        let r_2m = 0.9753;
        let dates = vec![d_1m, d_2m];
        let rates = vec![r_1m, r_2m];
        let interpolator = LinearInterpolator::new(dates, rates).unwrap();
        let d = date!(1990 - 06 - 20);
        assert_approx_equal!(
            interpolator.interpolate(d).unwrap(),
            0.9854903225806452,
            RUSTQUANT_EPSILON
        );
    }
}