use crate::core::traits::Transformer;
/// Rule used to fill a missing (NaN) entry from the known entries of the same sample.
///
/// "Position" below means the index of an entry within its sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterpolationStrategy {
    /// Straight line between the two known entries surrounding the gap. Gaps
    /// before the first or after the last known entry are extrapolated with the
    /// slope of the two outermost known entries on that side.
    Linear,
    /// Value of the known entry whose position is closest to the gap; on a tie
    /// the earlier position wins.
    Nearest,
    /// Value of the last known entry before the gap; a gap that precedes every
    /// known entry takes the first known value.
    Previous,
    /// Value of the first known entry after the gap; a gap that follows every
    /// known entry takes the last known value.
    Next,
}
/// Settings consumed by [`InterpolationImputer`].
#[derive(Debug, Clone)]
pub struct InterpolationImputerConfig {
    /// Rule applied to every missing (NaN) entry.
    pub strategy: InterpolationStrategy,
}
impl InterpolationImputerConfig {
    /// Creates a configuration that uses [`InterpolationStrategy::Linear`].
    pub fn new() -> Self {
        let strategy = InterpolationStrategy::Linear;
        Self { strategy }
    }
}
impl Default for InterpolationImputerConfig {
fn default() -> Self {
Self::new()
}
}
/// Fills missing (NaN) entries of each sample from that sample's known entries.
///
/// This is a unit struct with no fields: the work is done by its `Transformer`
/// implementation, whose `transform` receives the configuration as an argument.
pub struct InterpolationImputer;
impl Transformer for InterpolationImputer {
    type Config = InterpolationImputerConfig;

    /// Returns a copy of `x` in which every NaN entry has been replaced
    /// according to `config.strategy`.
    ///
    /// Each inner vector is one sample and is imputed independently of the
    /// others; the output has the same number of samples in the same order.
    ///
    /// # Panics
    ///
    /// Panics if `x` is empty, or if a sample that contains at least one NaN
    /// has fewer than two non-NaN entries.
    fn transform(config: &Self::Config, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        assert!(!x.is_empty(), "Input must have at least one sample");

        let strategy = config.strategy;
        let mut imputed = Vec::with_capacity(x.len());
        for sample in x {
            imputed.push(impute_single(sample, strategy));
        }
        imputed
    }
}
/// Imputes the NaN entries of a single sample.
///
/// Non-NaN entries are copied through unchanged. A sample without any NaN
/// (including an empty sample) is returned as a plain copy.
///
/// # Panics
///
/// Panics if the sample contains at least one NaN but fewer than two non-NaN
/// entries, regardless of `strategy`.
fn impute_single(x: &[f64], strategy: InterpolationStrategy) -> Vec<f64> {
    // (position, value) of every observed entry, in ascending position order.
    let observed: Vec<(usize, f64)> = x
        .iter()
        .copied()
        .enumerate()
        .filter(|(_, value)| !value.is_nan())
        .collect();

    // Nothing is missing: no interpolation needed.
    if observed.len() == x.len() {
        return x.to_vec();
    }

    assert!(
        observed.len() >= 2,
        "At least 2 non-missing values required for interpolation"
    );

    // The float-based helpers take positions and values as parallel slices.
    let (positions, levels): (Vec<f64>, Vec<f64>) = observed
        .iter()
        .map(|&(pos, value)| (pos as f64, value))
        .unzip();

    x.iter()
        .enumerate()
        .map(|(i, &value)| {
            if !value.is_nan() {
                return value;
            }
            match strategy {
                InterpolationStrategy::Linear => {
                    linear_interp_extrapolate(i as f64, &positions, &levels)
                }
                InterpolationStrategy::Nearest => nearest_interp(i as f64, &positions, &levels),
                InterpolationStrategy::Previous => previous_interp(i, &observed),
                InterpolationStrategy::Next => next_interp(i, &observed),
            }
        })
        .collect()
}
/// Evaluates at `x` the piecewise-linear curve through the knots `(xs[k], ys[k])`.
///
/// * Between two knots the result is the linear blend of their values.
/// * At or beyond either end the line through the two outermost knots on that
///   side is extended.
/// * With a single knot its value is returned for every `x` on either side.
///
/// `xs` is searched with a binary search, so it must be sorted ascending; the
/// divisions assume neighbouring knots have distinct positions.
///
/// # Panics
///
/// Panics if `xs` is empty or if `ys` is shorter than the knots that are read.
fn linear_interp_extrapolate(x: f64, xs: &[f64], ys: &[f64]) -> f64 {
    let count = xs.len();

    // Line through knots `a` and `b`, evaluated at `x` relative to knot `anchor`.
    let extend_line = |a: usize, b: usize, anchor: usize| -> f64 {
        let slope = (ys[b] - ys[a]) / (xs[b] - xs[a]);
        ys[anchor] + slope * (x - xs[anchor])
    };

    if x <= xs[0] {
        return if count == 1 {
            ys[0]
        } else {
            extend_line(0, 1, 0)
        };
    }
    if x >= xs[count - 1] {
        return if count == 1 {
            ys[0]
        } else {
            extend_line(count - 2, count - 1, count - 1)
        };
    }

    // Strictly inside the knot range: `upper` is the first knot at or after `x`.
    let upper = xs.partition_point(|&knot| knot < x);
    let lower = upper - 1;
    let weight = (x - xs[lower]) / (xs[upper] - xs[lower]);
    ys[lower] + weight * (ys[upper] - ys[lower])
}
/// Returns the value `ys[k]` of the knot whose position `xs[k]` is closest to `x`.
///
/// When `x` is exactly halfway between two knots the earlier knot wins.
///
/// The lookup is a binary search rather than a scan over every knot, because
/// this function is called once per missing entry. It therefore requires `xs`
/// to be sorted in strictly ascending order (as positions collected in index
/// order are) and `x` to be finite.
///
/// # Panics
///
/// Panics if `ys` is empty or shorter than the knot that is selected.
fn nearest_interp(x: f64, xs: &[f64], ys: &[f64]) -> f64 {
    // First knot at or after `x`; everything before it lies strictly below `x`.
    let hi = xs.partition_point(|&xi| xi < x);
    if hi == 0 {
        // `x` is at or before the first knot.
        return ys[0];
    }

    let lo = hi - 1;
    // Past the last knot, or at least as close to the lower neighbour: `<=`
    // keeps the "earlier knot wins" tie-break.
    if hi == xs.len() || x - xs[lo] <= xs[hi] - x {
        ys[lo]
    } else {
        ys[hi]
    }
}
/// Returns the value of the last known entry whose position is at or before `i`.
///
/// If every known entry lies after `i`, the first known value is returned
/// instead, so leading gaps are filled backwards.
///
/// `known` holds `(position, value)` pairs and must be sorted by ascending
/// position (as pairs collected in index order are). The lookup is a binary
/// search rather than a scan, because this function is called once per missing
/// entry.
///
/// # Panics
///
/// Panics if `known` is empty.
fn previous_interp(i: usize, known: &[(usize, f64)]) -> f64 {
    // Number of known entries positioned at or before `i`.
    let upto = known.partition_point(|&(pos, _)| pos <= i);
    known[..upto]
        .last()
        .or_else(|| known.first())
        .map(|&(_, value)| value)
        .expect("`known` must contain at least one entry")
}
/// Returns the value of the first known entry whose position is at or after `i`.
///
/// If every known entry lies before `i`, the last known value is returned
/// instead, so trailing gaps are filled forwards.
///
/// `known` holds `(position, value)` pairs and must be sorted by ascending
/// position (as pairs collected in index order are). The lookup is a binary
/// search rather than a scan, because this function is called once per missing
/// entry.
///
/// # Panics
///
/// Panics if `known` is empty.
fn next_interp(i: usize, known: &[(usize, f64)]) -> f64 {
    // Index of the first known entry positioned at or after `i`.
    let from = known.partition_point(|&(pos, _)| pos < i);
    known
        .get(from)
        .or_else(|| known.last())
        .map(|&(_, value)| value)
        .expect("`known` must contain at least one entry")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing imputed values.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` is within [`TOL`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_no_nan() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, 2.0, 3.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_linear_middle() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, f64::NAN, 3.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][1], 2.0);
    }

    #[test]
    fn test_linear_extrapolate_left() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![f64::NAN, 2.0, 4.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][0], 0.0);
    }

    #[test]
    fn test_linear_extrapolate_right() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, 3.0, f64::NAN]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][2], 5.0);
    }

    #[test]
    fn test_nearest() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Nearest,
        };
        let x = vec![vec![1.0, f64::NAN, f64::NAN, 10.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][1], 1.0);
        assert_close(result[0][2], 10.0);
    }

    #[test]
    fn test_nearest_tie_prefers_earlier() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Nearest,
        };
        // Position 1 is equally far from both known neighbours.
        let x = vec![vec![1.0, f64::NAN, 3.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][1], 1.0);
    }

    #[test]
    fn test_previous() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Previous,
        };
        let x = vec![vec![1.0, f64::NAN, f64::NAN, 10.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][1], 1.0);
        assert_close(result[0][2], 1.0);
    }

    #[test]
    fn test_previous_leading_gap_uses_first_known() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Previous,
        };
        // Position 0 has no earlier known value to carry forward.
        let x = vec![vec![f64::NAN, 2.0, f64::NAN, 4.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][0], 2.0);
        assert_close(result[0][2], 2.0);
    }

    #[test]
    fn test_next() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Next,
        };
        let x = vec![vec![1.0, f64::NAN, f64::NAN, 10.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][1], 10.0);
        assert_close(result[0][2], 10.0);
    }

    #[test]
    fn test_next_trailing_gap_uses_last_known() {
        let config = InterpolationImputerConfig {
            strategy: InterpolationStrategy::Next,
        };
        // Position 3 has no later known value to pull back.
        let x = vec![vec![1.0, f64::NAN, 3.0, f64::NAN]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][1], 3.0);
        assert_close(result[0][3], 3.0);
    }

    #[test]
    fn test_multiple_samples() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, f64::NAN, 3.0], vec![10.0, f64::NAN, 30.0]];
        let result = InterpolationImputer::transform(&config, &x);
        assert_close(result[0][1], 2.0);
        assert_close(result[1][1], 20.0);
    }

    #[test]
    fn test_empty_sample_is_returned_unchanged() {
        let config = InterpolationImputerConfig::new();
        let x = vec![Vec::<f64>::new()];
        let result = InterpolationImputer::transform(&config, &x);
        assert_eq!(result.len(), 1);
        assert!(result[0].is_empty());
    }

    #[test]
    fn test_default_config_is_linear() {
        let config = InterpolationImputerConfig::default();
        assert_eq!(config.strategy, InterpolationStrategy::Linear);
    }

    #[test]
    #[should_panic(expected = "At least 2 non-missing values required for interpolation")]
    fn test_too_few_known_values_panics() {
        let config = InterpolationImputerConfig::new();
        let x = vec![vec![1.0, f64::NAN, f64::NAN]];
        let _ = InterpolationImputer::transform(&config, &x);
    }

    #[test]
    #[should_panic(expected = "Input must have at least one sample")]
    fn test_empty_input_panics() {
        let config = InterpolationImputerConfig::new();
        let x: Vec<Vec<f64>> = Vec::new();
        let _ = InterpolationImputer::transform(&config, &x);
    }
}