scry_learn/preprocess/
imputer.rs1use crate::dataset::Dataset;
17use crate::error::{Result, ScryLearnError};
18use crate::preprocess::Transformer;
19
/// Rule that decides which value replaces missing (`NaN`) entries in a column.
///
/// The value is computed per feature column when the imputer is fitted.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Strategy {
    /// Arithmetic mean of the column's non-`NaN` values. This is the default.
    #[default]
    Mean,
    /// Median of the column's non-`NaN` values.
    Median,
    /// Most frequently occurring non-`NaN` value of the column.
    MostFrequent,
    /// A fixed, caller-supplied value used for every column.
    Constant(f64),
}
35
36#[derive(Clone, Debug)]
50#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
51#[non_exhaustive]
52pub struct SimpleImputer {
53 strategy: Strategy,
54 fill_values: Vec<f64>,
55 fitted: bool,
56 #[cfg_attr(feature = "serde", serde(default))]
57 _schema_version: u32,
58}
59
60impl SimpleImputer {
61 pub fn new() -> Self {
63 Self {
64 strategy: Strategy::default(),
65 fill_values: Vec::new(),
66 fitted: false,
67 _schema_version: crate::version::SCHEMA_VERSION,
68 }
69 }
70
71 pub fn strategy(mut self, strategy: Strategy) -> Self {
73 self.strategy = strategy;
74 self
75 }
76
77 pub fn fill_values(&self) -> &[f64] {
83 &self.fill_values
84 }
85}
86
87impl Default for SimpleImputer {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
/// Arithmetic mean of the non-`NaN` entries of `col`.
///
/// Returns `0.0` when `col` is empty or contains only `NaN`.
fn mean_ignore_nan(col: &[f64]) -> f64 {
    let mut total = 0.0_f64;
    let mut seen = 0_usize;
    for &value in col {
        if value.is_nan() {
            continue;
        }
        total += value;
        seen += 1;
    }

    match seen {
        // Nothing usable to average over.
        0 => 0.0,
        n => total / n as f64,
    }
}
107
/// Median of the non-`NaN` entries of `col`.
///
/// For an even number of usable values the result is the midpoint of the two
/// central values. Returns `0.0` when `col` is empty or contains only `NaN`.
fn median_ignore_nan(col: &[f64]) -> f64 {
    let mut present: Vec<f64> = Vec::with_capacity(col.len());
    present.extend(col.iter().copied().filter(|v| !v.is_nan()));
    present.sort_unstable_by(f64::total_cmp);

    match present.len() {
        0 => 0.0,
        n if n % 2 == 1 => present[n / 2],
        n => f64::midpoint(present[n / 2 - 1], present[n / 2]),
    }
}
122
/// Most frequent non-`NaN` value of `col`.
///
/// * Ties are broken in favour of the smallest value, so the result does not
///   depend on hash-map iteration order.
/// * `0.0` and `-0.0` compare equal and are counted as the same value; the
///   result is reported as `+0.0` in that case.
/// * Returns `0.0` when `col` is empty or contains only `NaN`.
fn mode_ignore_nan(col: &[f64]) -> f64 {
    use std::collections::HashMap;

    // Key on the bit pattern because `f64` is neither `Hash` nor `Eq`.
    let mut counts: HashMap<u64, usize> = HashMap::new();
    for &raw in col {
        if raw.is_nan() {
            continue;
        }
        // `-0.0 == 0.0` numerically but their bit patterns differ; fold both
        // onto `+0.0` so they are tallied together instead of splitting the
        // count of "zero" across two keys.
        let value = if raw == 0.0 { 0.0 } else { raw };
        *counts.entry(value.to_bits()).or_insert(0) += 1;
    }

    counts
        .into_iter()
        .map(|(bits, count)| (f64::from_bits(bits), count))
        // Highest count wins; on equal counts the reversed value comparison
        // makes the smaller value the "maximum".
        .max_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.total_cmp(&a.0)))
        .map_or(0.0, |(value, _)| value)
}
147
148impl Transformer for SimpleImputer {
149 fn fit(&mut self, data: &Dataset) -> Result<()> {
150 data.validate_no_inf()?;
151 if data.n_samples() == 0 {
152 return Err(ScryLearnError::EmptyDataset);
153 }
154
155 self.fill_values = Vec::with_capacity(data.n_features());
156
157 for col in &data.features {
158 let fill = match &self.strategy {
159 Strategy::Mean => mean_ignore_nan(col),
160 Strategy::Median => median_ignore_nan(col),
161 Strategy::MostFrequent => mode_ignore_nan(col),
162 Strategy::Constant(v) => *v,
163 };
164 self.fill_values.push(fill);
165 }
166 self.fitted = true;
167 Ok(())
168 }
169
170 fn transform(&self, data: &mut Dataset) -> Result<()> {
171 crate::version::check_schema_version(self._schema_version)?;
172 if !self.fitted {
173 return Err(ScryLearnError::NotFitted);
174 }
175 for (j, col) in data.features.iter_mut().enumerate() {
176 let fill = self.fill_values[j];
177 for x in col.iter_mut() {
178 if x.is_nan() {
179 *x = fill;
180 }
181 }
182 }
183 data.sync_matrix();
184 Ok(())
185 }
186
187 fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
188 Err(ScryLearnError::InvalidParameter(
189 "SimpleImputer is not invertible".into(),
190 ))
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing imputed values.
    const TOL: f64 = 1e-10;

    /// Two features, four samples, exactly one `NaN` in each column.
    fn ds_with_nan() -> Dataset {
        Dataset::new(
            vec![
                vec![1.0, f64::NAN, 3.0, 4.0],
                vec![10.0, 20.0, f64::NAN, 40.0],
            ],
            vec![0.0; 4],
            vec!["a".into(), "b".into()],
            "y",
        )
    }

    /// Fits an imputer with `strategy` on `ds`, applies it, and hands the
    /// imputed dataset back.
    fn impute(strategy: Strategy, mut ds: Dataset) -> Dataset {
        let mut imp = SimpleImputer::new().strategy(strategy);
        imp.fit_transform(&mut ds).unwrap();
        ds
    }

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_imputer_mean() {
        let ds = impute(Strategy::Mean, ds_with_nan());

        // Column "a": mean of 1, 3, 4.
        assert!(!ds.features[0][1].is_nan());
        assert_close(ds.features[0][1], 8.0 / 3.0);

        // Column "b": mean of 10, 20, 40.
        assert!(!ds.features[1][2].is_nan());
        assert_close(ds.features[1][2], 70.0 / 3.0);
    }

    #[test]
    fn test_imputer_median() {
        let ds = impute(Strategy::Median, ds_with_nan());

        // Column "a": median of 1, 3, 4.
        assert_close(ds.features[0][1], 3.0);
        // Column "b": median of 10, 20, 40.
        assert_close(ds.features[1][2], 20.0);
    }

    #[test]
    fn test_imputer_most_frequent() {
        let ds = impute(
            Strategy::MostFrequent,
            Dataset::new(
                vec![vec![1.0, 1.0, f64::NAN, 3.0, 1.0]],
                vec![0.0; 5],
                vec!["a".into()],
                "y",
            ),
        );

        // 1.0 occurs three times, 3.0 once.
        assert_close(ds.features[0][2], 1.0);
    }

    #[test]
    fn test_imputer_constant() {
        let ds = impute(Strategy::Constant(-999.0), ds_with_nan());

        assert_close(ds.features[0][1], -999.0);
        assert_close(ds.features[1][2], -999.0);
    }

    #[test]
    fn test_imputer_not_fitted() {
        let imp = SimpleImputer::new();
        let mut ds = ds_with_nan();
        assert!(imp.transform(&mut ds).is_err());
    }

    #[test]
    fn test_imputer_inverse_transform_err() {
        let mut ds = ds_with_nan();
        let mut imp = SimpleImputer::new();
        imp.fit(&ds).unwrap();
        assert!(imp.inverse_transform(&mut ds).is_err());
    }
}