Skip to main content

anofox_ml_preprocessing/
simple_imputer.rs

1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2};
3use std::collections::HashMap;
4
5/// Strategy used to compute the fill value for missing (NaN) entries.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
7pub enum ImputeStrategy {
8    /// Replace NaN with the column mean (computed from non-NaN values).
9    Mean,
10    /// Replace NaN with the column median (computed from non-NaN values).
11    Median,
12    /// Replace NaN with the most frequent value in the column.
13    MostFrequent,
14    /// Replace NaN with a fixed `fill_value`.
15    Constant,
16}
17
18/// Parameters for `SimpleImputer` (unfitted state).
19///
20/// Fills missing values (`NaN`) in each column with a statistic computed from
21/// the non-missing values during [`FitUnsupervised::fit`].
22#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
23#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
24pub struct SimpleImputer<F: Float> {
25    strategy: ImputeStrategy,
26    fill_value: Option<F>,
27}
28
29impl<F: Float> SimpleImputer<F> {
30    /// Create a new `SimpleImputer` with the default [`ImputeStrategy::Mean`].
31    pub fn new() -> Self {
32        Self {
33            strategy: ImputeStrategy::Mean,
34            fill_value: None,
35        }
36    }
37
38    /// Set the imputation strategy.
39    pub fn with_strategy(mut self, strategy: ImputeStrategy) -> Self {
40        self.strategy = strategy;
41        self
42    }
43
44    /// Set the fill value used when strategy is [`ImputeStrategy::Constant`].
45    pub fn with_fill_value(mut self, value: F) -> Self {
46        self.fill_value = Some(value);
47        self
48    }
49}
50
51impl<F: Float> Default for SimpleImputer<F> {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57/// Fitted `SimpleImputer` — holds one fill value per column.
58#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
59#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
60pub struct FittedSimpleImputer<F: Float> {
61    fill_values: Array1<F>,
62}
63
64impl<F: Float> FittedSimpleImputer<F> {
65    /// Return the per-column fill values learned during fit.
66    pub fn fill_values(&self) -> &Array1<F> {
67        &self.fill_values
68    }
69}
70
71/// Compute the mean of non-NaN values in `values`. Returns `None` if all are NaN.
72fn column_mean<F: Float>(values: &[F]) -> Option<F> {
73    let mut sum = F::zero();
74    let mut count = 0usize;
75    for &v in values {
76        if !v.is_nan() {
77            sum = sum + v;
78            count += 1;
79        }
80    }
81    if count == 0 {
82        None
83    } else {
84        Some(sum / F::from_usize(count).unwrap())
85    }
86}
87
88/// Compute the median of non-NaN values in `values`. Returns `None` if all are NaN.
89fn column_median<F: Float>(values: &[F]) -> Option<F> {
90    let mut valid: Vec<F> = values.iter().copied().filter(|v| !v.is_nan()).collect();
91    if valid.is_empty() {
92        return None;
93    }
94    valid.sort_by(|a, b| a.partial_cmp(b).unwrap());
95    let n = valid.len();
96    if n % 2 == 1 {
97        Some(valid[n / 2])
98    } else {
99        Some((valid[n / 2 - 1] + valid[n / 2]) / F::from_f64(2.0).unwrap())
100    }
101}
102
103/// Compute the most frequent non-NaN value in `values`. Returns `None` if all are NaN.
104/// Ties are broken by taking the smallest value.
105fn column_most_frequent<F: Float>(values: &[F]) -> Option<F> {
106    let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
107    for &v in values {
108        if v.is_nan() {
109            continue;
110        }
111        // Use bit representation as hash key for exact equality.
112        let bits = v.to_f64().unwrap().to_bits();
113        counts
114            .entry(bits)
115            .and_modify(|e| e.1 += 1)
116            .or_insert((v, 1));
117    }
118    if counts.is_empty() {
119        return None;
120    }
121    // Pick highest count, break ties by smallest value.
122    counts
123        .values()
124        .max_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.partial_cmp(&a.0).unwrap()))
125        .map(|&(v, _)| v)
126}
127
128impl<F: Float> FitUnsupervised<F> for SimpleImputer<F> {
129    type Fitted = FittedSimpleImputer<F>;
130
131    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
132        if x.is_empty() {
133            return Err(RustMlError::EmptyInput("input array is empty".into()));
134        }
135
136        if self.strategy == ImputeStrategy::Constant {
137            let fill = self.fill_value.unwrap_or_else(F::zero);
138            let fill_values = Array1::from_elem(x.ncols(), fill);
139            return Ok(FittedSimpleImputer { fill_values });
140        }
141
142        let ncols = x.ncols();
143        let mut fill_values = Array1::<F>::zeros(ncols);
144
145        for j in 0..ncols {
146            let col: Vec<F> = x.column(j).to_vec();
147            let computed = match self.strategy {
148                ImputeStrategy::Mean => column_mean(&col),
149                ImputeStrategy::Median => column_median(&col),
150                ImputeStrategy::MostFrequent => column_most_frequent(&col),
151                ImputeStrategy::Constant => unreachable!(),
152            };
153            match computed {
154                Some(v) => fill_values[j] = v,
155                None => {
156                    return Err(RustMlError::InvalidParameter(format!(
157                        "column {} contains only NaN values",
158                        j
159                    )));
160                }
161            }
162        }
163
164        Ok(FittedSimpleImputer { fill_values })
165    }
166}
167
168impl<F: Float> Transform<F> for FittedSimpleImputer<F> {
169    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
170        if x.ncols() != self.fill_values.len() {
171            return Err(RustMlError::ShapeMismatch(format!(
172                "expected {} features, got {}",
173                self.fill_values.len(),
174                x.ncols()
175            )));
176        }
177
178        let mut result = x.to_owned();
179        for mut row in result.rows_mut() {
180            for (j, val) in row.iter_mut().enumerate() {
181                if val.is_nan() {
182                    *val = self.fill_values[j];
183                }
184            }
185        }
186        Ok(result)
187    }
188}
189
/// Unit tests for `SimpleImputer` / `FittedSimpleImputer`: each strategy,
/// error paths, and the learned per-column fill values.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_mean_strategy_basic() {
        let x = array![[1.0, f64::NAN], [2.0, 4.0], [3.0, 6.0],];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Column 0: no NaN, unchanged
        assert_abs_diff_eq!(result[[0, 0]], 1.0);
        assert_abs_diff_eq!(result[[1, 0]], 2.0);
        assert_abs_diff_eq!(result[[2, 0]], 3.0);
        // Column 1: NaN replaced with mean of 4.0 and 6.0 = 5.0
        assert_abs_diff_eq!(result[[0, 1]], 5.0);
        assert_abs_diff_eq!(result[[1, 1]], 4.0);
        assert_abs_diff_eq!(result[[2, 1]], 6.0);
    }

    #[test]
    fn test_median_strategy() {
        let x = array![[f64::NAN, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Median);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Learned fill values: median of [2, 4, 6] and of [1, 3, 5, 7].
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0);
        assert_abs_diff_eq!(fitted.fill_values()[1], 4.0);

        // Column 0: valid = [2, 4, 6], median = 4.0
        assert_abs_diff_eq!(result[[0, 0]], 4.0);
        // Column 1: no NaN
        assert_abs_diff_eq!(result[[0, 1]], 1.0);
    }

    #[test]
    fn test_most_frequent_strategy() {
        let x = array![[1.0, f64::NAN], [2.0, 3.0], [2.0, 3.0], [3.0, 5.0],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::MostFrequent);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Column 0: most frequent = 2.0 (appears twice)
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0);
        // Column 1: most frequent among [3,3,5] = 3.0
        assert_abs_diff_eq!(fitted.fill_values()[1], 3.0);

        assert_abs_diff_eq!(result[[0, 0]], 1.0); // not NaN, unchanged
        assert_abs_diff_eq!(result[[0, 1]], 3.0); // NaN replaced with 3.0
    }

    #[test]
    fn test_most_frequent_tie_breaks_to_smallest() {
        // Both 1.0 and 2.0 occur twice; the smaller value must win.
        let x = array![[2.0], [1.0], [2.0], [1.0], [f64::NAN],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::MostFrequent);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 1.0);

        let result = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(result[[4, 0]], 1.0);
    }

    #[test]
    fn test_constant_strategy() {
        let x = array![[f64::NAN, 1.0], [2.0, f64::NAN],];
        let imputer = SimpleImputer::<f64>::new()
            .with_strategy(ImputeStrategy::Constant)
            .with_fill_value(-999.0);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(result[[0, 0]], -999.0);
        assert_abs_diff_eq!(result[[0, 1]], 1.0);
        assert_abs_diff_eq!(result[[1, 0]], 2.0);
        assert_abs_diff_eq!(result[[1, 1]], -999.0);
    }

    #[test]
    fn test_constant_strategy_allows_all_nan_column() {
        // The constant strategy does not need any valid value in a column.
        let x = array![[1.0, f64::NAN], [2.0, f64::NAN],];
        let imputer = SimpleImputer::<f64>::new()
            .with_strategy(ImputeStrategy::Constant)
            .with_fill_value(7.0);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(result[[0, 1]], 7.0);
        assert_abs_diff_eq!(result[[1, 1]], 7.0);
        assert_abs_diff_eq!(result[[0, 0]], 1.0);
        assert_abs_diff_eq!(result[[1, 0]], 2.0);
    }

    #[test]
    fn test_mixed_nan_positions() {
        let x = array![
            [f64::NAN, 2.0, f64::NAN],
            [1.0, f64::NAN, 6.0],
            [3.0, 4.0, f64::NAN],
            [5.0, 6.0, 8.0],
        ];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Column 0: valid = [1,3,5], mean = 3.0
        assert_abs_diff_eq!(result[[0, 0]], 3.0);
        // Column 1: valid = [2,4,6], mean = 4.0
        assert_abs_diff_eq!(result[[1, 1]], 4.0);
        // Column 2: valid = [6,8], mean = 7.0
        assert_abs_diff_eq!(result[[0, 2]], 7.0);
        assert_abs_diff_eq!(result[[2, 2]], 7.0);
        // Non-NaN values unchanged
        assert_abs_diff_eq!(result[[3, 0]], 5.0);
        assert_abs_diff_eq!(result[[3, 1]], 6.0);
        assert_abs_diff_eq!(result[[3, 2]], 8.0);
    }

    #[test]
    fn test_all_nan_column_error() {
        let x = array![[1.0, f64::NAN], [2.0, f64::NAN], [3.0, f64::NAN],];
        let imputer = SimpleImputer::<f64>::new();
        let result = FitUnsupervised::<f64>::fit(&imputer, &x);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = format!("{}", err);
        assert!(
            msg.contains("column 1"),
            "error should mention column index: {}",
            msg
        );
        assert!(msg.contains("NaN"), "error should mention NaN: {}", msg);
    }

    #[test]
    fn test_empty_input_error() {
        // Zero rows: there is nothing to compute a statistic from.
        let x = ndarray::Array2::<f64>::zeros((0, 2));
        let imputer = SimpleImputer::<f64>::new();
        let result = FitUnsupervised::<f64>::fit(&imputer, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_no_nan_passthrough() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Should be identical to input
        for (a, b) in x.iter().zip(result.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x_fit = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0],];
        let x_transform = array![[1.0, 2.0], [3.0, 4.0],];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x_fit).unwrap();
        let result = fitted.transform(&x_transform);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("3") && msg.contains("2"),
            "error should mention expected and actual: {}",
            msg
        );
    }

    #[test]
    fn test_f32_support() {
        let x = array![[1.0f32, f32::NAN], [3.0f32, 4.0f32], [5.0f32, 6.0f32],];
        let imputer = SimpleImputer::<f32>::new();
        let fitted = FitUnsupervised::<f32>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Column 1: mean of 4.0 and 6.0 = 5.0
        assert_abs_diff_eq!(result[[0, 1]], 5.0f32, epsilon = 1e-6);
        // Non-NaN unchanged
        assert_abs_diff_eq!(result[[0, 0]], 1.0f32, epsilon = 1e-6);
    }

    #[test]
    fn test_constant_strategy_default_fill_value() {
        // When Constant strategy is used without specifying fill_value, default to 0.
        let x = array![[f64::NAN, 1.0], [2.0, f64::NAN],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Constant);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(result[[0, 0]], 0.0);
        assert_abs_diff_eq!(result[[1, 1]], 0.0);
    }

    #[test]
    fn test_median_even_count() {
        // Even number of non-NaN values: median is average of two middle values.
        let x = array![[1.0], [3.0], [5.0], [7.0],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Median);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        // Median of [1,3,5,7] = (3+5)/2 = 4.0
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0);
    }
}