// oxits 0.1.0
//
// Time series classification and transformation library for Rust.
// Documentation
use crate::core::traits::Transformer;

/// Interpolation strategy for imputing missing values.
///
/// Selects how each NaN entry of a sample is filled from the non-NaN entries
/// of that same sample.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InterpolationStrategy {
    /// Piecewise-linear interpolation between the surrounding known values.
    /// Gaps before the first or after the last known value are extrapolated
    /// along the line through the first two or last two known values.
    Linear,
    /// Value at the closest known position; when two known positions are
    /// equally close, the earlier one wins.
    Nearest,
    /// Value at the closest known position before the gap; gaps before the
    /// first known value take the first known value.
    Previous,
    /// Value at the closest known position after the gap; gaps after the
    /// last known value take the last known value.
    Next,
}

/// Configuration for `InterpolationImputer`.
#[derive(Debug, Clone)]
pub struct InterpolationImputerConfig {
    /// Strategy used to fill every NaN entry of every sample.
    pub strategy: InterpolationStrategy,
}

impl InterpolationImputerConfig {
    /// Builds a configuration that uses [`InterpolationStrategy::Linear`].
    pub fn new() -> Self {
        let strategy = InterpolationStrategy::Linear;
        Self { strategy }
    }
}

impl Default for InterpolationImputerConfig {
    /// Same as [`InterpolationImputerConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Imputer that fills the NaN entries of each sample independently, using the
/// strategy selected in [`InterpolationImputerConfig`].
pub struct InterpolationImputer;

impl Transformer for InterpolationImputer {
    type Config = InterpolationImputerConfig;

    /// Returns a copy of `x` in which every NaN entry has been imputed.
    ///
    /// Samples are processed independently; non-NaN entries are copied through
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty, or if any sample containing a NaN has fewer than
    /// two non-NaN values.
    fn transform(config: &Self::Config, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        assert!(!x.is_empty(), "Input must have at least one sample");

        let strategy = config.strategy;
        let mut imputed = Vec::with_capacity(x.len());
        for sample in x {
            imputed.push(impute_single(sample, strategy));
        }
        imputed
    }
}

/// Imputes the NaN entries of a single series `x` using `strategy`.
///
/// Non-NaN entries are copied through unchanged. A series without any NaN
/// entries (including an empty series) is returned as an owned copy.
///
/// # Panics
///
/// Panics if `x` contains at least one NaN and fewer than two non-NaN values.
fn impute_single(x: &[f64], strategy: InterpolationStrategy) -> Vec<f64> {
    // (position, value) for every observed entry, in ascending position order.
    let known: Vec<(usize, f64)> = x
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, value)| !value.is_nan())
        .collect();

    // Nothing is missing: hand back an untouched copy.
    if known.len() == x.len() {
        return x.to_vec();
    }

    assert!(
        known.len() >= 2,
        "At least 2 non-missing values required for interpolation"
    );

    // Positions as f64 knots for the numeric interpolators, paired with their values.
    let (positions, observed): (Vec<f64>, Vec<f64>) = known
        .iter()
        .map(|&(pos, value)| (pos as f64, value))
        .unzip();

    x.iter()
        .enumerate()
        .map(|(pos, &value)| {
            if !value.is_nan() {
                return value;
            }
            let at = pos as f64;
            match strategy {
                InterpolationStrategy::Linear => {
                    linear_interp_extrapolate(at, &positions, &observed)
                }
                InterpolationStrategy::Nearest => nearest_interp(at, &positions, &observed),
                InterpolationStrategy::Previous => previous_interp(pos, &known),
                InterpolationStrategy::Next => next_interp(pos, &known),
            }
        })
        .collect()
}

/// Linearly interpolates the knots `(xs, ys)` at `x`, extrapolating beyond
/// either end of the knot range.
///
/// `xs` is expected to be sorted ascending without duplicates; the caller passes
/// strictly increasing sample positions. Points at or left of `xs[0]` follow the
/// line through the first two knots. Points at or right of the last knot follow
/// the line through the last two knots. A single knot yields its value as a
/// constant.
///
/// # Panics
///
/// Panics if `xs` is empty.
fn linear_interp_extrapolate(x: f64, xs: &[f64], ys: &[f64]) -> f64 {
    // Slope of the segment running from knot `a` to knot `b`.
    let slope = |a: usize, b: usize| (ys[b] - ys[a]) / (xs[b] - xs[a]);

    if x <= xs[0] {
        return match xs.len() {
            1 => ys[0],
            _ => ys[0] + slope(0, 1) * (x - xs[0]),
        };
    }

    let last = xs.len() - 1;
    if x >= xs[last] {
        return match last {
            0 => ys[0],
            _ => ys[last] + slope(last - 1, last) * (x - xs[last]),
        };
    }

    // Strictly inside the knot range: binary-search the bracketing pair.
    let upper = xs.partition_point(|&knot| knot < x);
    let lower = upper - 1;
    let t = (x - xs[lower]) / (xs[upper] - xs[lower]);
    ys[lower] + t * (ys[upper] - ys[lower])
}

/// Nearest-neighbour interpolation: returns the `ys` value whose knot in `xs`
/// is closest to `x`.
///
/// `xs` must be sorted ascending without duplicates; the caller passes strictly
/// increasing sample positions. That ordering allows a binary search (O(log k)
/// per call) instead of scanning every knot, which previously made imputing a
/// whole sample O(n·k).
///
/// When `x` is equidistant from two knots, the left (lower) knot wins. Points
/// outside the knot range take the value of the nearer end knot.
///
/// # Panics
///
/// Panics if `ys` is empty.
fn nearest_interp(x: f64, xs: &[f64], ys: &[f64]) -> f64 {
    // First knot at or to the right of `x`; its left neighbour is the only
    // other candidate, because every other knot lies farther away.
    let right = xs.partition_point(|&xi| xi < x);
    let idx = if right == 0 {
        0
    } else if right == xs.len() {
        right - 1
    } else {
        let left = right - 1;
        // Strict `<` keeps the left knot on ties.
        if xs[right] - x < x - xs[left] {
            right
        } else {
            left
        }
    };
    ys[idx]
}

/// Previous-value interpolation: returns the value of the last known sample
/// positioned at or before `i`.
///
/// `known` must be sorted by position in ascending order, as `impute_single`
/// builds it. That ordering allows a binary search (O(log k)) instead of a
/// reverse linear scan per missing point.
///
/// Positions before the first known sample take the first known value
/// (constant extrapolation).
///
/// # Panics
///
/// Panics if `known` is empty.
fn previous_interp(i: usize, known: &[(usize, f64)]) -> f64 {
    // Number of known samples positioned at or before `i`.
    let at_or_before = known.partition_point(|&(ki, _)| ki <= i);
    // `saturating_sub` maps "none before `i`" onto the first known value.
    known[at_or_before.saturating_sub(1)].1
}

/// Next-value interpolation: returns the value of the first known sample
/// positioned at or after `i`.
///
/// `known` must be sorted by position in ascending order, as `impute_single`
/// builds it. That ordering allows a binary search (O(log k)) instead of a
/// forward linear scan per missing point.
///
/// Positions after the last known sample take the last known value (constant
/// extrapolation).
///
/// # Panics
///
/// Panics if `known` is empty.
fn next_interp(i: usize, known: &[(usize, f64)]) -> f64 {
    // Index of the first known sample positioned at or after `i`
    // (equals `known.len()` when every known sample lies before `i`).
    let first_at_or_after = known.partition_point(|&(ki, _)| ki < i);
    // Clamp past-the-end onto the last known value.
    let idx = first_at_or_after.min(known.len().saturating_sub(1));
    known[idx].1
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_nan() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, 2.0, 3.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_sample() {
        // A sample with no entries has nothing to impute and is returned empty.
        let config = InterpolationImputerConfig::new();
        let x: Vec<Vec<f64>> = vec![vec![]];
        let result = InterpolationImputer::transform(&config, &x);
        assert!(result[0].is_empty());
    }

    #[test]
    fn test_default_strategy_is_linear() {
        assert_eq!(
            InterpolationImputerConfig::default().strategy,
            InterpolationStrategy::Linear
        );
    }

    #[test]
    fn test_linear_middle() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, f64::NAN, 3.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert!((result[0][1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_linear_multiple_gaps() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![0.0, f64::NAN, f64::NAN, 3.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert!((result[0][1] - 1.0).abs() < 1e-10);
        assert!((result[0][2] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_linear_extrapolate_left() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![f64::NAN, 2.0, 4.0]];
        let result = InterpolationImputer::transform(&config, &x);
        // Extrapolate left: slope = 2.0, value at 0 = 2.0 - 2.0 = 0.0
        assert!((result[0][0] - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_linear_extrapolate_right() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, 3.0, f64::NAN]];
        let result = InterpolationImputer::transform(&config, &x);
        // Extrapolate right: slope = 2.0, value at 2 = 3.0 + 2.0 = 5.0
        assert!((result[0][2] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_nearest() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Nearest,
        };
        let x = vec![vec![1.0, f64::NAN, f64::NAN, 10.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert!((result[0][1] - 1.0).abs() < 1e-10); // closer to index 0
        assert!((result[0][2] - 10.0).abs() < 1e-10); // closer to index 3
    }

    #[test]
    fn test_nearest_tie_prefers_earlier() {
        // Index 1 is equidistant from indices 0 and 2; the earlier one wins.
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Nearest,
        };
        let x = vec![vec![1.0, f64::NAN, 3.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_eq!(result, vec![vec![1.0, 1.0, 3.0]]);
    }

    #[test]
    fn test_nearest_extrapolates_edges() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Nearest,
        };
        let x = vec![vec![f64::NAN, 2.0, 5.0, f64::NAN]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_eq!(result, vec![vec![2.0, 2.0, 5.0, 5.0]]);
    }

    #[test]
    fn test_previous() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Previous,
        };
        let x = vec![vec![1.0, f64::NAN, f64::NAN, 10.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert!((result[0][1] - 1.0).abs() < 1e-10);
        assert!((result[0][2] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_previous_leading_gap_uses_first_known() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Previous,
        };
        let x = vec![vec![f64::NAN, 2.0, f64::NAN, 4.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_eq!(result, vec![vec![2.0, 2.0, 2.0, 4.0]]);
    }

    #[test]
    fn test_next() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Next,
        };
        let x = vec![vec![1.0, f64::NAN, f64::NAN, 10.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert!((result[0][1] - 10.0).abs() < 1e-10);
        assert!((result[0][2] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_next_trailing_gap_uses_last_known() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Next,
        };
        let x = vec![vec![1.0, f64::NAN, 3.0, f64::NAN]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_eq!(result, vec![vec![1.0, 3.0, 3.0, 3.0]]);
    }

    #[test]
    fn test_multiple_samples() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, f64::NAN, 3.0], vec![10.0, f64::NAN, 30.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert!((result[0][1] - 2.0).abs() < 1e-10);
        assert!((result[1][1] - 20.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "At least 2 non-missing values required for interpolation")]
    fn test_too_few_known_values_panics() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![f64::NAN, 1.0, f64::NAN]];
        let _ = InterpolationImputer::transform(&config, &x);
    }

    #[test]
    #[should_panic(expected = "Input must have at least one sample")]
    fn test_empty_input_panics() {
        let config = InterpolationImputerConfig::new();
        let x: Vec<Vec<f64>> = Vec::new();
        let _ = InterpolationImputer::transform(&config, &x);
    }
}