use std::fmt;
/// A pluggable rule for producing a value at one position of a series that has gaps.
///
/// Implementors must be `Send + Sync`, so a handler can be shared across threads.
pub trait MissingDataHandler<T>: Send + Sync {
    /// Produces a value for position `index` of `series`, or `None` when the handler
    /// cannot supply one.
    fn handle(&self, series: &[Option<T>], index: usize) -> Option<T>;

    /// Presumably reports whether `handle` needs to inspect other slots of `series`
    /// (NOTE(review): confirm against callers). Defaults to `true`.
    fn requires_context(&self) -> bool { true }

    /// Presumably the number of neighbouring slots the handler inspects, when bounded
    /// (NOTE(review): confirm against callers). Defaults to `None`.
    fn window_size(&self) -> Option<usize> { None }
}
/// Built-in ways of filling a missing (`None`) slot in a series.
///
/// Strategies compose through [`MissingDataStrategy::Fallback`]; see
/// [`MissingDataStrategy::with_fallback`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MissingDataStrategy {
    /// Interpolate between the nearest observed values before and after the gap.
    LinearInterpolation,
    /// Use the closest observed value before the gap.
    ForwardFill,
    /// Use the closest observed value after the gap.
    BackwardFill,
    /// Nearest-neighbour fill from the observed values around the gap.
    NearestNeighbor,
    /// Mean of the observed values within `window_size / 2` slots on either side of the gap
    /// (the gap itself is excluded).
    MeanImputation {
        /// Nominal window width. Only `window_size / 2` slots per side are used, so an even
        /// width and the next odd width behave identically.
        window_size: usize,
    },
    /// Median of the observed values within `window_size / 2` slots on either side of the gap
    /// (the gap itself is excluded).
    MedianImputation {
        /// Nominal window width. Only `window_size / 2` slots per side are used, so an even
        /// width and the next odd width behave identically.
        window_size: usize,
    },
    /// Fill with `T::from(0.0)`, regardless of the surrounding data.
    ZeroFill,
    /// Never produce a value; the handler returns `None`.
    Drop,
    /// Try `primary`; if it yields no value, try `fallback`.
    Fallback {
        primary: Box<MissingDataStrategy>,
        fallback: Box<MissingDataStrategy>,
    },
}
impl MissingDataStrategy {
    /// Wraps `self` so that `fallback` is tried whenever `self` produces no value.
    ///
    /// Returns a [`MissingDataStrategy::Fallback`] whose `primary` is `self`. Calls can be
    /// chained: `a.with_fallback(b).with_fallback(c)` tries `a`, then `b`, then `c`.
    ///
    /// The receiver is consumed and a new strategy is returned. Discarding the result is
    /// therefore always a mistake, hence `#[must_use]`.
    #[must_use = "with_fallback consumes the strategy and returns the combined one"]
    pub fn with_fallback(self, fallback: MissingDataStrategy) -> Self {
        MissingDataStrategy::Fallback {
            primary: Box::new(self),
            fallback: Box::new(fallback),
        }
    }
}
/// Reasons an imputation attempt can fail.
#[derive(Debug, Clone, PartialEq)]
pub enum ImputationError {
    /// Imputation is not possible at the boundary position `index`.
    BoundaryImpossible { index: usize },
    /// No observed values were found in the window of size `window` around `index`.
    NoValidValues { index: usize, window: usize },
    /// Every imputation strategy tried at `index` failed.
    AllStrategiesFailed { index: usize },
}
impl fmt::Display for ImputationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ImputationError::BoundaryImpossible { index } => {
write!(f, "Cannot impute at boundary index {}", index)
}
ImputationError::NoValidValues { index, window } => {
write!(
f,
"No valid values found in window of size {} around index {}",
window, index
)
}
ImputationError::AllStrategiesFailed { index } => {
write!(f, "All imputation strategies failed at index {}", index)
}
}
}
}
impl std::error::Error for ImputationError {}
impl MissingDataStrategy {
fn find_prev_value<T: Copy>(series: &[Option<T>], index: usize) -> Option<T> {
series[..index]
.iter()
.rev()
.find_map(|&v| v)
}
fn find_next_value<T: Copy>(series: &[Option<T>], index: usize) -> Option<T> {
series.get(index + 1..)
.and_then(|slice| slice.iter().find_map(|&v| v))
}
fn handle_linear_interpolation<T>(series: &[Option<T>], index: usize) -> Option<T>
where
T: Copy + std::ops::Add<Output = T> + std::ops::Div<Output = T> + From<f64>,
{
if index == 0 || index >= series.len() - 1 {
return None;
}
let prev_val = Self::find_prev_value(series, index);
let next_val = Self::find_next_value(series, index);
match (prev_val, next_val) {
(Some(p), Some(n)) => Some((p + n) / T::from(2.0)),
_ => None,
}
}
fn handle_forward_fill<T: Copy>(series: &[Option<T>], index: usize) -> Option<T> {
Self::find_prev_value(series, index)
}
fn handle_backward_fill<T: Copy>(series: &[Option<T>], index: usize) -> Option<T> {
Self::find_next_value(series, index)
}
fn handle_nearest_neighbor<T: Copy>(series: &[Option<T>], index: usize) -> Option<T> {
let (prev_val, prev_dist) = series[..index]
.iter()
.enumerate()
.rev()
.find_map(|(i, &v)| v.map(|val| (val, index - i)))
.unwrap_or((None?, usize::MAX));
let (next_val, next_dist) = series.get(index + 1..)
.and_then(|slice| {
slice.iter()
.enumerate()
.find_map(|(offset, &v)| v.map(|val| (val, offset + 1)))
})
.unwrap_or((None?, usize::MAX));
match (prev_dist, next_dist) {
(usize::MAX, usize::MAX) => None,
(_, usize::MAX) => Some(prev_val),
(usize::MAX, _) => Some(next_val),
(pd, nd) if pd <= nd => Some(prev_val),
_ => Some(next_val),
}
}
fn handle_mean_imputation<T>(
series: &[Option<T>],
index: usize,
window_size: usize,
) -> Option<T>
where
T: Copy + std::ops::Add<Output = T> + std::ops::Div<Output = T> + From<f64>,
{
let start = index.saturating_sub(window_size / 2);
let end = (index + window_size / 2 + 1).min(series.len());
let (sum, count) = series[start..end]
.iter()
.enumerate()
.filter(|(i, _)| start + i != index)
.filter_map(|(_, &val)| val)
.fold((None, 0), |(acc, cnt), v| {
(Some(acc.map_or(v, |s| s + v)), cnt + 1)
});
if count > 0 {
sum.map(|s| s / T::from(count as f64))
} else {
None
}
}
fn handle_median_imputation<T>(
series: &[Option<T>],
index: usize,
window_size: usize,
) -> Option<T>
where
T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Div<Output = T> + From<f64>,
{
let start = index.saturating_sub(window_size / 2);
let end = (index + window_size / 2 + 1).min(series.len());
let mut values: Vec<T> = series[start..end]
.iter()
.enumerate()
.filter(|(i, _)| start + i != index)
.filter_map(|(_, &val)| val)
.collect();
if values.is_empty() {
return None;
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mid = values.len() / 2;
if values.len() % 2 == 0 {
Some((values[mid - 1] + values[mid]) / T::from(2.0))
} else {
Some(values[mid])
}
}
}
/// Routes each [`MissingDataStrategy`] variant to its imputation routine.
///
/// `requires_context` and `window_size` are not overridden here, so every variant reports
/// the trait defaults.
impl<T> MissingDataHandler<T> for MissingDataStrategy
where
    T: Copy + PartialOrd + std::ops::Add<Output = T> + std::ops::Div<Output = T> + From<f64>,
{
    /// Imputes position `index` of `series` with the strategy selected by `self`.
    ///
    /// `ZeroFill` ignores the series and yields `T::from(0.0)`. `Drop` always yields `None`.
    /// `Fallback` evaluates `fallback` only when `primary` yields `None`.
    fn handle(&self, series: &[Option<T>], index: usize) -> Option<T> {
        match *self {
            Self::LinearInterpolation => Self::handle_linear_interpolation(series, index),
            Self::ForwardFill => Self::handle_forward_fill(series, index),
            Self::BackwardFill => Self::handle_backward_fill(series, index),
            Self::NearestNeighbor => Self::handle_nearest_neighbor(series, index),
            Self::MeanImputation { window_size: width } => {
                Self::handle_mean_imputation(series, index, width)
            }
            Self::MedianImputation { window_size: width } => {
                Self::handle_median_imputation(series, index, width)
            }
            Self::ZeroFill => Some(T::from(0.0)),
            Self::Drop => None,
            Self::Fallback { ref primary, ref fallback } => {
                let first_attempt: Option<T> = primary.handle(series, index);
                first_attempt.or_else(|| fallback.handle(series, index))
            }
        }
    }
}