use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
use ndarray::{Array1, Array2};
use std::collections::HashMap;
/// Strategy used by [`SimpleImputer`] to choose the replacement for NaN entries.
///
/// The statistic-based strategies (`Mean`, `Median`, `MostFrequent`) are
/// computed per column from the non-NaN values of the array passed to `fit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ImputeStrategy {
/// Arithmetic mean of the column's non-NaN values.
Mean,
/// Median of the column's non-NaN values; for an even count the two middle
/// values are averaged.
Median,
/// Most common non-NaN value of the column; ties go to the smallest value.
MostFrequent,
/// A fixed value: the one given to `with_fill_value`, or zero when none was given.
Constant,
}
/// Unfitted imputer configuration.
///
/// Holds the chosen [`ImputeStrategy`] and, for the `Constant` strategy, the
/// value that should replace NaN entries. Fitting produces a
/// [`FittedSimpleImputer`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct SimpleImputer<F: Float> {
/// How the per-column replacement value is determined during `fit`.
strategy: ImputeStrategy,
/// Replacement used by the `Constant` strategy; `fit` falls back to zero
/// when this is `None`. The other strategies do not read it.
fill_value: Option<F>,
}
impl<F: Float> SimpleImputer<F> {
    /// Creates an imputer that uses [`ImputeStrategy::Mean`] and carries no
    /// explicit fill value.
    pub fn new() -> Self {
        let strategy = ImputeStrategy::Mean;
        let fill_value = None;
        Self {
            strategy,
            fill_value,
        }
    }

    /// Returns the imputer with its strategy replaced by `strategy`.
    pub fn with_strategy(self, strategy: ImputeStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Returns the imputer with `value` stored as its fill value.
    pub fn with_fill_value(self, value: F) -> Self {
        Self {
            fill_value: Some(value),
            ..self
        }
    }
}
impl<F: Float> Default for SimpleImputer<F> {
fn default() -> Self {
Self::new()
}
}
/// Result of fitting a [`SimpleImputer`]: the replacement value learned (or
/// configured) for each column.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedSimpleImputer<F: Float> {
/// One replacement value per column of the array that was passed to `fit`;
/// `transform` writes entry `j` into every NaN cell of column `j`.
fill_values: Array1<F>,
}
impl<F: Float> FittedSimpleImputer<F> {
    /// Borrows the per-column replacement values, indexed by column.
    pub fn fill_values(&self) -> &Array1<F> {
        let Self { fill_values } = self;
        fill_values
    }
}
/// Averages the non-NaN entries of `values`.
///
/// Entries are summed in slice order. Returns `None` when the slice holds no
/// non-NaN entry (including when it is empty).
fn column_mean<F: Float>(values: &[F]) -> Option<F> {
    let (total, seen) = values
        .iter()
        .copied()
        .filter(|v| !v.is_nan())
        .fold((F::zero(), 0usize), |(acc, n), v| (acc + v, n + 1));

    match seen {
        0 => None,
        n => Some(total / F::from_usize(n).unwrap()),
    }
}
/// Computes the median of the non-NaN entries of `values`.
///
/// With an even number of valid entries the two middle values are averaged.
/// Returns `None` when the slice holds no non-NaN entry.
fn column_median<F: Float>(values: &[F]) -> Option<F> {
    let mut sorted: Vec<F> = Vec::new();
    sorted.extend(values.iter().copied().filter(|v| !v.is_nan()));
    // NaNs were removed above, so every pair of elements is comparable.
    sorted.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());

    let mid = sorted.len() / 2;
    // `get` yields `None` exactly when there is no valid entry at all.
    let upper = *sorted.get(mid)?;
    if sorted.len() % 2 == 1 {
        return Some(upper);
    }
    let lower = sorted[mid - 1];
    Some((lower + upper) / F::from_f64(2.0).unwrap())
}
/// Returns the most frequently occurring non-NaN value in `values`.
///
/// Ties between equally frequent values are resolved in favour of the
/// smallest value. `0.0` and `-0.0` compare equal and are therefore counted
/// as one value, reported as `0.0`. Returns `None` when the slice holds no
/// non-NaN entry.
fn column_most_frequent<F: Float>(values: &[F]) -> Option<F> {
    // Occurrences are keyed by the f64 bit pattern of each value; the value
    // itself is stored next to its count so it can be returned unchanged.
    let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
    for &raw in values {
        if raw.is_nan() {
            continue;
        }
        // Fold -0.0 onto 0.0. They are numerically equal but have different
        // bit patterns, which would otherwise split their counts across two
        // map entries and leave the winner of a 0.0/-0.0 tie up to the map's
        // iteration order.
        let v = if raw == F::zero() { F::zero() } else { raw };
        let bits = v
            .to_f64()
            .expect("a non-NaN float value converts to f64")
            .to_bits();
        counts
            .entry(bits)
            .and_modify(|e| e.1 += 1)
            .or_insert((v, 1));
    }

    // Highest count wins; on equal counts the reversed value comparison makes
    // the smaller value rank higher. NaNs never enter the map, so the
    // comparison cannot fail. An empty map yields `None`.
    counts
        .values()
        .max_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.partial_cmp(&a.0).unwrap()))
        .map(|&(v, _)| v)
}
impl<F: Float> FitUnsupervised<F> for SimpleImputer<F> {
type Fitted = FittedSimpleImputer<F>;
fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
if x.is_empty() {
return Err(RustMlError::EmptyInput("input array is empty".into()));
}
if self.strategy == ImputeStrategy::Constant {
let fill = self.fill_value.unwrap_or_else(F::zero);
let fill_values = Array1::from_elem(x.ncols(), fill);
return Ok(FittedSimpleImputer { fill_values });
}
let ncols = x.ncols();
let mut fill_values = Array1::<F>::zeros(ncols);
for j in 0..ncols {
let col: Vec<F> = x.column(j).to_vec();
let computed = match self.strategy {
ImputeStrategy::Mean => column_mean(&col),
ImputeStrategy::Median => column_median(&col),
ImputeStrategy::MostFrequent => column_most_frequent(&col),
ImputeStrategy::Constant => unreachable!(),
};
match computed {
Some(v) => fill_values[j] = v,
None => {
return Err(RustMlError::InvalidParameter(format!(
"column {} contains only NaN values",
j
)));
}
}
}
Ok(FittedSimpleImputer { fill_values })
}
}
impl<F: Float> Transform<F> for FittedSimpleImputer<F> {
    /// Returns a copy of `x` in which every NaN cell has been replaced by the
    /// fill value of its column; all other cells are left as they are.
    ///
    /// # Errors
    ///
    /// `RustMlError::ShapeMismatch` when the number of columns of `x` differs
    /// from the number of stored fill values.
    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
        let expected = self.fill_values.len();
        let actual = x.ncols();
        if actual != expected {
            return Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                expected, actual
            )));
        }

        let mut imputed = x.to_owned();
        for mut row in imputed.rows_mut() {
            // Row length equals the number of fill values (checked above), so
            // zipping pairs each cell with the fill value of its column.
            for (cell, &fill) in row.iter_mut().zip(self.fill_values.iter()) {
                if cell.is_nan() {
                    *cell = fill;
                }
            }
        }
        Ok(imputed)
    }
}
// Unit tests for the imputer: one test per strategy plus error paths,
// pass-through behaviour and `f32` support.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use ndarray::array;
// Default (mean) strategy: column 1 has valid values 4 and 6, so its NaN
// becomes 5; cells that are not NaN stay unchanged.
#[test]
fn test_mean_strategy_basic() {
let x = array![[1.0, f64::NAN], [2.0, 4.0], [3.0, 6.0],];
let imputer = SimpleImputer::<f64>::new();
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(result[[0, 0]], 1.0);
assert_abs_diff_eq!(result[[1, 0]], 2.0);
assert_abs_diff_eq!(result[[2, 0]], 3.0);
assert_abs_diff_eq!(result[[0, 1]], 5.0);
assert_abs_diff_eq!(result[[1, 1]], 4.0);
assert_abs_diff_eq!(result[[2, 1]], 6.0);
}
// Median strategy: column 0 has valid values 2, 4, 6, so its NaN becomes 4.
#[test]
fn test_median_strategy() {
let x = array![[f64::NAN, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0],];
let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Median);
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(result[[0, 0]], 4.0);
assert_abs_diff_eq!(result[[0, 1]], 1.0);
}
// Most-frequent strategy: column 1 has valid values 3, 3, 5, so its NaN
// becomes 3; cell [0, 0] is not NaN and keeps its value 1.
#[test]
fn test_most_frequent_strategy() {
let x = array![[1.0, f64::NAN], [2.0, 3.0], [2.0, 3.0], [3.0, 5.0],];
let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::MostFrequent);
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(result[[0, 0]], 1.0); assert_abs_diff_eq!(result[[0, 1]], 3.0); }
// Constant strategy with an explicit fill value: every NaN becomes -999.
#[test]
fn test_constant_strategy() {
let x = array![[f64::NAN, 1.0], [2.0, f64::NAN],];
let imputer = SimpleImputer::<f64>::new()
.with_strategy(ImputeStrategy::Constant)
.with_fill_value(-999.0);
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(result[[0, 0]], -999.0);
assert_abs_diff_eq!(result[[0, 1]], 1.0);
assert_abs_diff_eq!(result[[1, 0]], 2.0);
assert_abs_diff_eq!(result[[1, 1]], -999.0);
}
// NaNs at different rows of each column. Expected means of the valid values:
// column 0 -> (1 + 3 + 5) / 3 = 3, column 1 -> (2 + 4 + 6) / 3 = 4,
// column 2 -> (6 + 8) / 2 = 7. The fully valid last row must be untouched.
#[test]
fn test_mixed_nan_positions() {
let x = array![
[f64::NAN, 2.0, f64::NAN],
[1.0, f64::NAN, 6.0],
[3.0, 4.0, f64::NAN],
[5.0, 6.0, 8.0],
];
let imputer = SimpleImputer::<f64>::new();
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(result[[0, 0]], 3.0);
assert_abs_diff_eq!(result[[1, 1]], 4.0);
assert_abs_diff_eq!(result[[0, 2]], 7.0);
assert_abs_diff_eq!(result[[2, 2]], 7.0);
assert_abs_diff_eq!(result[[3, 0]], 5.0);
assert_abs_diff_eq!(result[[3, 1]], 6.0);
assert_abs_diff_eq!(result[[3, 2]], 8.0);
}
// A column made up solely of NaN cannot yield a mean: `fit` must fail and
// the message must name the offending column index and mention NaN.
#[test]
fn test_all_nan_column_error() {
let x = array![[1.0, f64::NAN], [2.0, f64::NAN], [3.0, f64::NAN],];
let imputer = SimpleImputer::<f64>::new();
let result = FitUnsupervised::<f64>::fit(&imputer, &x);
assert!(result.is_err());
let err = result.unwrap_err();
let msg = format!("{}", err);
assert!(
msg.contains("column 1"),
"error should mention column index: {}",
msg
);
assert!(msg.contains("NaN"), "error should mention NaN: {}", msg);
}
// Input without any NaN must come back element-for-element unchanged.
#[test]
fn test_no_nan_passthrough() {
let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];
let imputer = SimpleImputer::<f64>::new();
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
for (a, b) in x.iter().zip(result.iter()) {
assert_abs_diff_eq!(a, b, epsilon = 1e-15);
}
}
// Fitting on 3 columns and transforming 2 columns must fail, and the message
// must contain both the expected and the actual column count.
#[test]
fn test_shape_mismatch_on_transform() {
let x_fit = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0],];
let x_transform = array![[1.0, 2.0], [3.0, 4.0],];
let imputer = SimpleImputer::<f64>::new();
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x_fit).unwrap();
let result = fitted.transform(&x_transform);
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("3") && msg.contains("2"),
"error should mention expected and actual: {}",
msg
);
}
// The imputer is generic over the float type: same mean check with `f32`
// (column 1 has valid values 4 and 6, so its NaN becomes 5).
#[test]
fn test_f32_support() {
let x = array![[1.0f32, f32::NAN], [3.0f32, 4.0f32], [5.0f32, 6.0f32],];
let imputer = SimpleImputer::<f32>::new();
let fitted = FitUnsupervised::<f32>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(result[[0, 1]], 5.0f32, epsilon = 1e-6);
assert_abs_diff_eq!(result[[0, 0]], 1.0f32, epsilon = 1e-6);
}
// Constant strategy without `with_fill_value`: NaNs are replaced by zero.
#[test]
fn test_constant_strategy_default_fill_value() {
let x = array![[f64::NAN, 1.0], [2.0, f64::NAN],];
let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Constant);
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
let result = fitted.transform(&x).unwrap();
assert_abs_diff_eq!(result[[0, 0]], 0.0);
assert_abs_diff_eq!(result[[1, 1]], 0.0);
}
// Even number of valid values: the median is the average of the two middle
// values, here (3 + 5) / 2 = 4.
#[test]
fn test_median_even_count() {
let x = array![[1.0], [3.0], [5.0], [7.0],];
let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Median);
let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
assert_abs_diff_eq!(fitted.fill_values()[0], 4.0);
}
}