use crate::{state_space::kalman::KalmanFilter, TimeSeries};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
/// Strategy used by [`TimeSeriesImputer`] to fill missing (`NaN`) observations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImputationMethod {
    /// Last observation carried forward: repeat the most recent observed value.
    LOCF,
    /// Next observation carried backward: repeat the following observed value.
    NOCB,
    /// Linear interpolation between the nearest observations on either side of a gap.
    Linear,
    /// Spline interpolation; series with fewer than four observations use linear
    /// interpolation instead.
    Spline,
    /// Replace every gap with the mean of all observed values.
    Mean,
    /// Replace every gap with the median of all observed values.
    Median,
    /// Fill gaps from the predicted state of a Kalman filter run over the series.
    KalmanFilter,
    /// Fill each gap with the mean of observations in the same seasonal slot;
    /// requires a period configured via `TimeSeriesImputer::with_seasonal_period`.
    Seasonal,
}
/// Fills missing (`NaN`) observations in a [`TimeSeries`].
///
/// Create one with `TimeSeriesImputer::new`, optionally configure it with the
/// `with_*` builder methods, then call `fit_transform`.
#[derive(Debug, Clone)]
pub struct TimeSeriesImputer {
    /// Strategy applied by `fit_transform`.
    method: ImputationMethod,
    /// Season length used by [`ImputationMethod::Seasonal`]; that method requires it.
    seasonal_period: Option<usize>,
    /// Kalman state dimension used by [`ImputationMethod::KalmanFilter`]; 1 when unset.
    state_dim: Option<usize>,
}
impl TimeSeriesImputer {
pub fn new(method: ImputationMethod) -> Self {
Self {
method,
seasonal_period: None,
state_dim: None,
}
}
pub fn with_seasonal_period(mut self, period: usize) -> Self {
self.seasonal_period = Some(period);
self
}
pub fn with_state_dim(mut self, dim: usize) -> Self {
self.state_dim = Some(dim);
self
}
pub fn fit_transform(&self, series: &TimeSeries) -> Result<TimeSeries> {
match self.method {
ImputationMethod::LOCF => self.locf(series),
ImputationMethod::NOCB => self.nocb(series),
ImputationMethod::Linear => self.linear_interpolation(series),
ImputationMethod::Spline => self.spline_interpolation(series),
ImputationMethod::Mean => self.mean_imputation(series),
ImputationMethod::Median => self.median_imputation(series),
ImputationMethod::KalmanFilter => self.kalman_imputation(series),
ImputationMethod::Seasonal => self.seasonal_imputation(series),
}
}
fn locf(&self, series: &TimeSeries) -> Result<TimeSeries> {
let n = series.len();
let mut imputed_data = Vec::with_capacity(n);
let mut last_valid = None;
for i in 0..n {
let val = series.values.get_item_flat(i)?;
if val.is_nan() {
if let Some(last) = last_valid {
imputed_data.push(last);
} else {
imputed_data.push(0.0);
}
} else {
last_valid = Some(val);
imputed_data.push(val);
}
}
let tensor = Tensor::from_vec(imputed_data, &[n])?;
Ok(TimeSeries::new(tensor))
}
fn nocb(&self, series: &TimeSeries) -> Result<TimeSeries> {
let n = series.len();
let mut imputed_data = vec![0.0f32; n];
let mut next_valid = None;
for i in (0..n).rev() {
let val = series.values.get_item_flat(i)?;
if val.is_nan() {
if let Some(next) = next_valid {
imputed_data[i] = next;
} else {
imputed_data[i] = 0.0;
}
} else {
next_valid = Some(val);
imputed_data[i] = val;
}
}
let tensor = Tensor::from_vec(imputed_data, &[n])?;
Ok(TimeSeries::new(tensor))
}
fn linear_interpolation(&self, series: &TimeSeries) -> Result<TimeSeries> {
let n = series.len();
let mut imputed_data = Vec::with_capacity(n);
let mut valid_points: Vec<(usize, f32)> = Vec::new();
for i in 0..n {
let val = series.values.get_item_flat(i)?;
if !val.is_nan() {
valid_points.push((i, val));
}
}
if valid_points.is_empty() {
return Err(TorshError::InvalidArgument(
"Cannot interpolate: no valid values".to_string(),
));
}
for i in 0..n {
let val = series.values.get_item_flat(i)?;
if val.is_nan() {
let before = valid_points.iter().filter(|(idx, _)| *idx < i).last();
let after = valid_points.iter().find(|(idx, _)| *idx > i);
let interpolated = match (before, after) {
(Some((i0, v0)), Some((i1, v1))) => {
let t = (i - i0) as f32 / (i1 - i0) as f32;
v0 + t * (v1 - v0)
}
(Some((_, v)), None) | (None, Some((_, v))) => {
*v
}
(None, None) => 0.0, };
imputed_data.push(interpolated);
} else {
imputed_data.push(val);
}
}
let tensor = Tensor::from_vec(imputed_data, &[n])?;
Ok(TimeSeries::new(tensor))
}
fn spline_interpolation(&self, series: &TimeSeries) -> Result<TimeSeries> {
use scirs2_core::ndarray::Array1;
let data = series.values.to_vec()?;
let array = Array1::from_vec(data);
let mut x_valid = Vec::new();
let mut y_valid = Vec::new();
for (i, &val) in array.iter().enumerate() {
if !val.is_nan() {
x_valid.push(i as f64);
y_valid.push(val as f64);
}
}
if x_valid.len() < 4 {
return self.linear_interpolation(series);
}
self.linear_interpolation(series)
}
fn mean_imputation(&self, series: &TimeSeries) -> Result<TimeSeries> {
let n = series.len();
let mut sum = 0.0f64;
let mut count = 0;
for i in 0..n {
let val = series.values.get_item_flat(i)?;
if !val.is_nan() {
sum += val as f64;
count += 1;
}
}
if count == 0 {
return Err(TorshError::InvalidArgument(
"Cannot compute mean: no valid values".to_string(),
));
}
let mean = (sum / count as f64) as f32;
let mut imputed_data = Vec::with_capacity(n);
for i in 0..n {
let val = series.values.get_item_flat(i)?;
imputed_data.push(if val.is_nan() { mean } else { val });
}
let tensor = Tensor::from_vec(imputed_data, &[n])?;
Ok(TimeSeries::new(tensor))
}
fn median_imputation(&self, series: &TimeSeries) -> Result<TimeSeries> {
let n = series.len();
let mut valid_values = Vec::new();
for i in 0..n {
let val = series.values.get_item_flat(i)?;
if !val.is_nan() {
valid_values.push(val);
}
}
if valid_values.is_empty() {
return Err(TorshError::InvalidArgument(
"Cannot compute median: no valid values".to_string(),
));
}
valid_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let median = if valid_values.len() % 2 == 0 {
let mid = valid_values.len() / 2;
(valid_values[mid - 1] + valid_values[mid]) / 2.0
} else {
valid_values[valid_values.len() / 2]
};
let mut imputed_data = Vec::with_capacity(n);
for i in 0..n {
let val = series.values.get_item_flat(i)?;
imputed_data.push(if val.is_nan() { median } else { val });
}
let tensor = Tensor::from_vec(imputed_data, &[n])?;
Ok(TimeSeries::new(tensor))
}
fn kalman_imputation(&self, series: &TimeSeries) -> Result<TimeSeries> {
let state_dim = self.state_dim.unwrap_or(1);
let mut kf = KalmanFilter::new(state_dim, 1);
let n = series.len();
let mut imputed_data = Vec::with_capacity(n);
for i in 0..n {
kf.predict()?;
let val = series.values.get_item_flat(i)?;
if val.is_nan() {
let predicted = kf.state().get_item_flat(0)?;
imputed_data.push(predicted);
} else {
let obs = Tensor::from_vec(vec![val], &[1])?;
kf.update(&obs)?;
imputed_data.push(val);
}
}
let tensor = Tensor::from_vec(imputed_data, &[n])?;
Ok(TimeSeries::new(tensor))
}
fn seasonal_imputation(&self, series: &TimeSeries) -> Result<TimeSeries> {
let period = self.seasonal_period.ok_or_else(|| {
TorshError::InvalidArgument(
"Seasonal period must be set for seasonal imputation".to_string(),
)
})?;
let n = series.len();
let mut imputed_data = Vec::with_capacity(n);
let mut seasonal_means = vec![0.0f64; period];
let mut seasonal_counts = vec![0usize; period];
for i in 0..n {
let val = series.values.get_item_flat(i)?;
if !val.is_nan() {
let season_idx = i % period;
seasonal_means[season_idx] += val as f64;
seasonal_counts[season_idx] += 1;
}
}
for i in 0..period {
if seasonal_counts[i] > 0 {
seasonal_means[i] /= seasonal_counts[i] as f64;
}
}
for i in 0..n {
let val = series.values.get_item_flat(i)?;
if val.is_nan() {
let season_idx = i % period;
let seasonal_val = seasonal_means[season_idx] as f32;
imputed_data.push(seasonal_val);
} else {
imputed_data.push(val);
}
}
let tensor = Tensor::from_vec(imputed_data, &[n])?;
Ok(TimeSeries::new(tensor))
}
}
/// Produces several completed copies of a series and pools them element-wise.
///
/// NOTE(review): despite the name, no chained-equation iterations are run. Each
/// imputation is a single-method fill chosen by cycling through `methods`.
#[derive(Debug, Clone)]
pub struct MICEImputer {
    /// Number of completed series returned by `fit_transform`.
    n_imputations: usize,
    /// Iteration budget set via `with_max_iter`; not read anywhere in this file.
    max_iter: usize,
    /// Methods cycled through, one per imputation.
    methods: Vec<ImputationMethod>,
}
impl MICEImputer {
    /// Create an imputer that produces `n_imputations` completed series.
    ///
    /// Imputations cycle through Linear, Mean and KalmanFilter (in that order).
    /// `max_iter` starts at 10.
    pub fn new(n_imputations: usize) -> Self {
        Self {
            n_imputations,
            max_iter: 10,
            methods: vec![
                ImputationMethod::Linear,
                ImputationMethod::Mean,
                ImputationMethod::KalmanFilter,
            ],
        }
    }

    /// Set the iteration budget.
    ///
    /// NOTE(review): the value is stored but not consulted by `fit_transform`.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Produce `n_imputations` completed copies of `series`.
    ///
    /// Imputation `i` uses `methods[i % methods.len()]` with default
    /// `TimeSeriesImputer` settings.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by an underlying `TimeSeriesImputer`.
    pub fn fit_transform(&self, series: &TimeSeries) -> Result<Vec<TimeSeries>> {
        self.methods
            .iter()
            .cycle()
            .take(self.n_imputations)
            .map(|&method| TimeSeriesImputer::new(method).fit_transform(series))
            .collect()
    }

    /// Pool imputations by taking their element-wise arithmetic mean
    /// (accumulated in `f64`).
    ///
    /// # Errors
    ///
    /// [`TorshError::InvalidArgument`] if `imputations` is empty or if the series do
    /// not all have the same length. Tensor errors are propagated.
    pub fn pool_results(&self, imputations: &[TimeSeries]) -> Result<TimeSeries> {
        let first = imputations.first().ok_or_else(|| {
            TorshError::InvalidArgument("No imputations to pool".to_string())
        })?;
        let n = first.len();
        // Reject mismatched lengths up front rather than silently truncating longer
        // series or failing mid-way on a shorter one.
        if let Some((k, mismatched)) = imputations
            .iter()
            .enumerate()
            .find(|(_, imputed)| imputed.len() != n)
        {
            return Err(TorshError::InvalidArgument(format!(
                "Imputation {} has length {}, expected {}",
                k,
                mismatched.len(),
                n
            )));
        }
        let count = imputations.len() as f64;
        let mut pooled_data = Vec::with_capacity(n);
        for i in 0..n {
            let mut sum = 0.0f64;
            for imputed in imputations {
                let val = imputed.values.get_item_flat(i)?;
                sum += val as f64;
            }
            pooled_data.push((sum / count) as f32);
        }
        let tensor = Tensor::from_vec(pooled_data, &[n])?;
        Ok(TimeSeries::new(tensor))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::Tensor;

    /// `[1, 2, NaN, 4, NaN, 6, 7, 8]`: two interior gaps, observations at both ends.
    fn create_series_with_missing() -> TimeSeries {
        let data = vec![1.0f32, 2.0, f32::NAN, 4.0, f32::NAN, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data, &[8]).expect("Tensor should succeed");
        TimeSeries::new(tensor)
    }

    /// Read every element of `series` in flat order.
    fn values_of(series: &TimeSeries) -> Vec<f32> {
        (0..series.len())
            .map(|i| {
                series
                    .values
                    .get_item_flat(i)
                    .expect("flat item retrieval should succeed for valid index")
            })
            .collect()
    }

    /// Impute the shared fixture with `imputer`, check length and absence of NaN,
    /// and return the imputed values.
    fn impute(imputer: &TimeSeriesImputer) -> Vec<f32> {
        let series = create_series_with_missing();
        let imputed = imputer
            .fit_transform(&series)
            .expect("fit_transform should succeed with valid input");
        assert_eq!(imputed.len(), series.len());
        let values = values_of(&imputed);
        assert!(
            values.iter().all(|v| !v.is_nan()),
            "imputed series still contains NaN: {:?}",
            values
        );
        values
    }

    /// Element-wise comparison with a small absolute tolerance.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-5, "index {}: got {}, expected {}", i, a, e);
        }
    }

    /// Every observed (non-NaN) fixture value must be copied through unchanged.
    fn assert_observed_unchanged(values: &[f32]) {
        let original = values_of(&create_series_with_missing());
        assert_eq!(values.len(), original.len());
        for (i, (&v, &o)) in values.iter().zip(&original).enumerate() {
            if !o.is_nan() {
                assert_eq!(v, o, "observed value at index {} changed", i);
            }
        }
    }

    #[test]
    fn test_locf_imputation() {
        let values = impute(&TimeSeriesImputer::new(ImputationMethod::LOCF));
        assert_close(&values, &[1.0, 2.0, 2.0, 4.0, 4.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_nocb_imputation() {
        let values = impute(&TimeSeriesImputer::new(ImputationMethod::NOCB));
        assert_close(&values, &[1.0, 2.0, 4.0, 4.0, 6.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_linear_interpolation() {
        let values = impute(&TimeSeriesImputer::new(ImputationMethod::Linear));
        assert_close(&values, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_spline_interpolation() {
        // The observed points lie on y = x + 1, which both linear interpolation
        // and a natural cubic spline reproduce exactly.
        let values = impute(&TimeSeriesImputer::new(ImputationMethod::Spline));
        assert_close(&values, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_mean_imputation() {
        let values = impute(&TimeSeriesImputer::new(ImputationMethod::Mean));
        // Observed values 1, 2, 4, 6, 7, 8 sum to 28.
        let mean = 28.0f32 / 6.0;
        assert_close(&values, &[1.0, 2.0, mean, 4.0, mean, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_median_imputation() {
        let values = impute(&TimeSeriesImputer::new(ImputationMethod::Median));
        // Sorted observations 1, 2, 4, 6, 7, 8 -> (4 + 6) / 2.
        assert_close(&values, &[1.0, 2.0, 5.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_kalman_imputation() {
        let values =
            impute(&TimeSeriesImputer::new(ImputationMethod::KalmanFilter).with_state_dim(1));
        assert_observed_unchanged(&values);
    }

    #[test]
    fn test_seasonal_imputation() {
        let values =
            impute(&TimeSeriesImputer::new(ImputationMethod::Seasonal).with_seasonal_period(4));
        // Index 2 shares slot 2 with index 6 (7.0); index 4 shares slot 0 with index 0 (1.0).
        assert_close(&values, &[1.0, 2.0, 7.0, 4.0, 1.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_seasonal_requires_period() {
        let series = create_series_with_missing();
        let result = TimeSeriesImputer::new(ImputationMethod::Seasonal).fit_transform(&series);
        assert!(result.is_err());
    }

    #[test]
    fn test_no_observations_is_an_error() {
        let tensor = Tensor::from_vec(vec![f32::NAN; 4], &[4]).expect("Tensor should succeed");
        let series = TimeSeries::new(tensor);
        for &method in &[
            ImputationMethod::Linear,
            ImputationMethod::Mean,
            ImputationMethod::Median,
        ] {
            let result = TimeSeriesImputer::new(method).fit_transform(&series);
            assert!(result.is_err(), "{:?} should reject an all-NaN series", method);
        }
    }

    #[test]
    fn test_mice_imputation() {
        let series = create_series_with_missing();
        let mice = MICEImputer::new(3).with_max_iter(5);
        let imputations = mice
            .fit_transform(&series)
            .expect("fit_transform should succeed with valid input");
        assert_eq!(imputations.len(), 3);
        for imputed in &imputations {
            assert_eq!(imputed.len(), series.len());
        }
        let pooled = mice
            .pool_results(&imputations)
            .expect("result pooling should succeed");
        assert_eq!(pooled.len(), series.len());
        let pooled_values = values_of(&pooled);
        assert!(pooled_values.iter().all(|v| !v.is_nan()));
        assert_observed_unchanged(&pooled_values);
    }

    #[test]
    fn test_pool_results_rejects_empty() {
        let mice = MICEImputer::new(3);
        assert!(mice.pool_results(&[]).is_err());
    }
}