Skip to main content

scry_learn/preprocess/
imputer.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Missing-value imputation.
3//!
4//! [`SimpleImputer`] replaces `NaN` values in a [`Dataset`] with a
5//! statistic computed from the non-missing entries of each feature column.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use scry_learn::preprocess::{SimpleImputer, Strategy, Transformer};
11//!
12//! let mut imputer = SimpleImputer::new().strategy(Strategy::Mean);
13//! imputer.fit_transform(&mut dataset)?;
14//! ```
15
16use crate::dataset::Dataset;
17use crate::error::{Result, ScryLearnError};
18use crate::preprocess::Transformer;
19
/// Strategy for computing the replacement value per feature.
///
/// The default strategy is [`Strategy::Mean`].
///
/// Strategies can be compared with `==`. Because the comparison is derived
/// from `f64`, `Strategy::Constant(f64::NAN)` never compares equal to itself.
#[derive(Clone, Debug, Default, PartialEq)]
#[non_exhaustive]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Strategy {
    /// Replace with the arithmetic mean of non-`NaN` values.
    #[default]
    Mean,
    /// Replace with the median of non-`NaN` values.
    Median,
    /// Replace with the most frequent non-`NaN` value (mode).
    MostFrequent,
    /// Replace with a user-specified constant.
    Constant(f64),
}
35
36/// Imputes missing (`NaN`) values in each feature column.
37///
38/// During [`fit`](Transformer::fit), the imputer computes one fill value
39/// per feature using the chosen [`Strategy`]. During
40/// [`transform`](Transformer::transform), every `NaN` in a column is
41/// replaced with that value.
42///
43/// # Example
44///
45/// ```ignore
46/// let mut imp = SimpleImputer::new().strategy(Strategy::Median);
47/// imp.fit_transform(&mut ds)?;
48/// ```
49#[derive(Clone, Debug)]
50#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
51#[non_exhaustive]
52pub struct SimpleImputer {
53    strategy: Strategy,
54    fill_values: Vec<f64>,
55    fitted: bool,
56    #[cfg_attr(feature = "serde", serde(default))]
57    _schema_version: u32,
58}
59
60impl SimpleImputer {
61    /// Create a new unfitted imputer (defaults to [`Strategy::Mean`]).
62    pub fn new() -> Self {
63        Self {
64            strategy: Strategy::default(),
65            fill_values: Vec::new(),
66            fitted: false,
67            _schema_version: crate::version::SCHEMA_VERSION,
68        }
69    }
70
71    /// Set the imputation strategy.
72    pub fn strategy(mut self, strategy: Strategy) -> Self {
73        self.strategy = strategy;
74        self
75    }
76
77    /// Return the per-feature fill values computed during `fit`.
78    ///
79    /// # Panics
80    ///
81    /// Panics if the imputer has not been fitted.
82    pub fn fill_values(&self) -> &[f64] {
83        &self.fill_values
84    }
85}
86
87impl Default for SimpleImputer {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93// ── helpers ──────────────────────────────────────────────────────
94
/// Arithmetic mean of the entries of `col` that are not `NaN`.
///
/// Returns `0.0` when `col` is empty or consists solely of `NaN` values.
fn mean_ignore_nan(col: &[f64]) -> f64 {
    let mut total = 0.0_f64;
    let mut seen = 0_usize;
    for &value in col {
        if value.is_nan() {
            continue;
        }
        total += value;
        seen += 1;
    }
    match seen {
        // Nothing to average: fall back to zero rather than producing NaN.
        0 => 0.0,
        n => total / n as f64,
    }
}
107
/// Median of the entries of `col` that are not `NaN`.
///
/// For an even number of valid entries the result is the midpoint of the two
/// central values. Returns `0.0` when there is no valid entry at all.
fn median_ignore_nan(col: &[f64]) -> f64 {
    // Work on a private copy so the caller's column is left untouched.
    let mut present = col.to_vec();
    present.retain(|value| !value.is_nan());
    present.sort_unstable_by(f64::total_cmp);

    let n = present.len();
    let upper = n / 2;
    match (n, n % 2) {
        (0, _) => 0.0,
        (_, 1) => present[upper],
        _ => f64::midpoint(present[upper - 1], present[upper]),
    }
}
122
/// Compute the most frequent value (mode) ignoring NaN.
///
/// Ties are broken by choosing the smallest value. `0.0` and `-0.0` are
/// counted as the same value and reported as `0.0`. Returns `0.0` when `col`
/// holds no non-`NaN` entry.
fn mode_ignore_nan(col: &[f64]) -> f64 {
    use std::collections::HashMap;

    // Keyed by bit pattern because `f64` is not `Hash`; the tuple keeps the
    // value alongside its occurrence count.
    let mut counts: HashMap<u64, (f64, usize)> = HashMap::new();
    for &raw in col {
        if raw.is_nan() {
            continue;
        }
        // `-0.0 == 0.0` numerically but their bit patterns differ; fold both
        // onto `0.0` so zeros of either sign share one counter.
        let v = if raw == 0.0 { 0.0 } else { raw };
        counts
            .entry(v.to_bits())
            .and_modify(|(_, c)| *c += 1)
            .or_insert((v, 1));
    }

    // Highest count wins; on equal counts the reversed value comparison makes
    // the smaller value the maximum, so the result does not depend on the
    // map's iteration order. An empty map falls through to `0.0`.
    counts
        .into_values()
        .max_by(|(v1, c1), (v2, c2)| c1.cmp(c2).then_with(|| v2.total_cmp(v1)))
        .map_or(0.0, |(v, _)| v)
}
147
148impl Transformer for SimpleImputer {
149    fn fit(&mut self, data: &Dataset) -> Result<()> {
150        data.validate_no_inf()?;
151        if data.n_samples() == 0 {
152            return Err(ScryLearnError::EmptyDataset);
153        }
154
155        self.fill_values = Vec::with_capacity(data.n_features());
156
157        for col in &data.features {
158            let fill = match &self.strategy {
159                Strategy::Mean => mean_ignore_nan(col),
160                Strategy::Median => median_ignore_nan(col),
161                Strategy::MostFrequent => mode_ignore_nan(col),
162                Strategy::Constant(v) => *v,
163            };
164            self.fill_values.push(fill);
165        }
166        self.fitted = true;
167        Ok(())
168    }
169
170    fn transform(&self, data: &mut Dataset) -> Result<()> {
171        crate::version::check_schema_version(self._schema_version)?;
172        if !self.fitted {
173            return Err(ScryLearnError::NotFitted);
174        }
175        for (j, col) in data.features.iter_mut().enumerate() {
176            let fill = self.fill_values[j];
177            for x in col.iter_mut() {
178                if x.is_nan() {
179                    *x = fill;
180                }
181            }
182        }
183        data.sync_matrix();
184        Ok(())
185    }
186
187    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
188        Err(ScryLearnError::InvalidParameter(
189            "SimpleImputer is not invertible".into(),
190        ))
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing imputed values.
    const TOL: f64 = 1e-10;

    /// Assert that `actual` lies within [`TOL`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// Two-feature dataset with exactly one `NaN` per column:
    /// column `a` at row 1 and column `b` at row 2.
    fn ds_with_nan() -> Dataset {
        let col_a = vec![1.0, f64::NAN, 3.0, 4.0];
        let col_b = vec![10.0, 20.0, f64::NAN, 40.0];
        Dataset::new(
            vec![col_a, col_b],
            vec![0.0; 4],
            vec!["a".into(), "b".into()],
            "y",
        )
    }

    /// Fit and apply an imputer using `strategy` on [`ds_with_nan`].
    fn imputed(strategy: Strategy) -> Dataset {
        let mut ds = ds_with_nan();
        let mut imp = SimpleImputer::new().strategy(strategy);
        imp.fit_transform(&mut ds).unwrap();
        ds
    }

    #[test]
    fn test_imputer_mean() {
        let ds = imputed(Strategy::Mean);

        // col a: mean(1,3,4) = 8/3 ≈ 2.6667
        let filled_a = ds.features[0][1];
        assert!(!filled_a.is_nan());
        assert_close(filled_a, 8.0 / 3.0);

        // col b: mean(10,20,40) = 70/3 ≈ 23.3333
        let filled_b = ds.features[1][2];
        assert!(!filled_b.is_nan());
        assert_close(filled_b, 70.0 / 3.0);
    }

    #[test]
    fn test_imputer_median() {
        let ds = imputed(Strategy::Median);

        // col a: sorted valid = [1,3,4], median = 3
        assert_close(ds.features[0][1], 3.0);
        // col b: sorted valid = [10,20,40], median = 20
        assert_close(ds.features[1][2], 20.0);
    }

    #[test]
    fn test_imputer_most_frequent() {
        let column = vec![1.0, 1.0, f64::NAN, 3.0, 1.0];
        let mut ds = Dataset::new(vec![column], vec![0.0; 5], vec!["a".into()], "y");
        let mut imp = SimpleImputer::new().strategy(Strategy::MostFrequent);
        imp.fit_transform(&mut ds).unwrap();

        // mode of [1,1,3,1] = 1
        assert_close(ds.features[0][2], 1.0);
    }

    #[test]
    fn test_imputer_constant() {
        let ds = imputed(Strategy::Constant(-999.0));

        assert_close(ds.features[0][1], -999.0);
        assert_close(ds.features[1][2], -999.0);
    }

    #[test]
    fn test_imputer_not_fitted() {
        // Transforming before fitting must be rejected.
        let mut ds = ds_with_nan();
        let imp = SimpleImputer::new();
        assert!(imp.transform(&mut ds).is_err());
    }

    #[test]
    fn test_imputer_inverse_transform_err() {
        // Even a fitted imputer cannot be inverted.
        let mut ds = ds_with_nan();
        let mut imp = SimpleImputer::new();
        imp.fit(&ds).unwrap();
        assert!(imp.inverse_transform(&mut ds).is_err());
    }
}
276}