1use anofox_ml_core::{FitUnsupervised, Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2};
3use std::collections::HashMap;
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
7pub enum ImputeStrategy {
8 Mean,
10 Median,
12 MostFrequent,
14 Constant,
16}
17
18#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
23#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
24pub struct SimpleImputer<F: Float> {
25 strategy: ImputeStrategy,
26 fill_value: Option<F>,
27}
28
29impl<F: Float> SimpleImputer<F> {
30 pub fn new() -> Self {
32 Self {
33 strategy: ImputeStrategy::Mean,
34 fill_value: None,
35 }
36 }
37
38 pub fn with_strategy(mut self, strategy: ImputeStrategy) -> Self {
40 self.strategy = strategy;
41 self
42 }
43
44 pub fn with_fill_value(mut self, value: F) -> Self {
46 self.fill_value = Some(value);
47 self
48 }
49}
50
51impl<F: Float> Default for SimpleImputer<F> {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
59#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
60pub struct FittedSimpleImputer<F: Float> {
61 fill_values: Array1<F>,
62}
63
64impl<F: Float> FittedSimpleImputer<F> {
65 pub fn fill_values(&self) -> &Array1<F> {
67 &self.fill_values
68 }
69}
70
71fn column_mean<F: Float>(values: &[F]) -> Option<F> {
73 let mut sum = F::zero();
74 let mut count = 0usize;
75 for &v in values {
76 if !v.is_nan() {
77 sum = sum + v;
78 count += 1;
79 }
80 }
81 if count == 0 {
82 None
83 } else {
84 Some(sum / F::from_usize(count).unwrap())
85 }
86}
87
88fn column_median<F: Float>(values: &[F]) -> Option<F> {
90 let mut valid: Vec<F> = values.iter().copied().filter(|v| !v.is_nan()).collect();
91 if valid.is_empty() {
92 return None;
93 }
94 valid.sort_by(|a, b| a.partial_cmp(b).unwrap());
95 let n = valid.len();
96 if n % 2 == 1 {
97 Some(valid[n / 2])
98 } else {
99 Some((valid[n / 2 - 1] + valid[n / 2]) / F::from_f64(2.0).unwrap())
100 }
101}
102
103fn column_most_frequent<F: Float>(values: &[F]) -> Option<F> {
106 let mut counts: HashMap<u64, (F, usize)> = HashMap::new();
107 for &v in values {
108 if v.is_nan() {
109 continue;
110 }
111 let bits = v.to_f64().unwrap().to_bits();
113 counts
114 .entry(bits)
115 .and_modify(|e| e.1 += 1)
116 .or_insert((v, 1));
117 }
118 if counts.is_empty() {
119 return None;
120 }
121 counts
123 .values()
124 .max_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.partial_cmp(&a.0).unwrap()))
125 .map(|&(v, _)| v)
126}
127
128impl<F: Float> FitUnsupervised<F> for SimpleImputer<F> {
129 type Fitted = FittedSimpleImputer<F>;
130
131 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
132 if x.is_empty() {
133 return Err(RustMlError::EmptyInput("input array is empty".into()));
134 }
135
136 if self.strategy == ImputeStrategy::Constant {
137 let fill = self.fill_value.unwrap_or_else(F::zero);
138 let fill_values = Array1::from_elem(x.ncols(), fill);
139 return Ok(FittedSimpleImputer { fill_values });
140 }
141
142 let ncols = x.ncols();
143 let mut fill_values = Array1::<F>::zeros(ncols);
144
145 for j in 0..ncols {
146 let col: Vec<F> = x.column(j).to_vec();
147 let computed = match self.strategy {
148 ImputeStrategy::Mean => column_mean(&col),
149 ImputeStrategy::Median => column_median(&col),
150 ImputeStrategy::MostFrequent => column_most_frequent(&col),
151 ImputeStrategy::Constant => unreachable!(),
152 };
153 match computed {
154 Some(v) => fill_values[j] = v,
155 None => {
156 return Err(RustMlError::InvalidParameter(format!(
157 "column {} contains only NaN values",
158 j
159 )));
160 }
161 }
162 }
163
164 Ok(FittedSimpleImputer { fill_values })
165 }
166}
167
168impl<F: Float> Transform<F> for FittedSimpleImputer<F> {
169 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
170 if x.ncols() != self.fill_values.len() {
171 return Err(RustMlError::ShapeMismatch(format!(
172 "expected {} features, got {}",
173 self.fill_values.len(),
174 x.ncols()
175 )));
176 }
177
178 let mut result = x.to_owned();
179 for mut row in result.rows_mut() {
180 for (j, val) in row.iter_mut().enumerate() {
181 if val.is_nan() {
182 *val = self.fill_values[j];
183 }
184 }
185 }
186 Ok(result)
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_mean_strategy_basic() {
        let x = array![[1.0, f64::NAN], [2.0, 4.0], [3.0, 6.0],];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Column 0 has no NaN and passes through unchanged.
        assert_abs_diff_eq!(result[[0, 0]], 1.0);
        assert_abs_diff_eq!(result[[1, 0]], 2.0);
        assert_abs_diff_eq!(result[[2, 0]], 3.0);
        // The NaN in column 1 becomes mean(4, 6) = 5.
        assert_abs_diff_eq!(result[[0, 1]], 5.0);
        assert_abs_diff_eq!(result[[1, 1]], 4.0);
        assert_abs_diff_eq!(result[[2, 1]], 6.0);
    }

    #[test]
    fn test_median_strategy() {
        let x = array![[f64::NAN, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Median);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // median(2, 4, 6) = 4.
        assert_abs_diff_eq!(result[[0, 0]], 4.0);
        // Column 1 has no NaN.
        assert_abs_diff_eq!(result[[0, 1]], 1.0);
    }

    #[test]
    fn test_most_frequent_strategy() {
        let x = array![[1.0, f64::NAN], [2.0, 3.0], [2.0, 3.0], [3.0, 5.0],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::MostFrequent);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // Column 0 contains no NaN, so check its learned value directly: 2 occurs twice.
        assert_abs_diff_eq!(fitted.fill_values()[0], 2.0);
        assert_abs_diff_eq!(result[[0, 0]], 1.0);
        // Column 1: 3 occurs twice, 5 once.
        assert_abs_diff_eq!(result[[0, 1]], 3.0);
    }

    #[test]
    fn test_most_frequent_tie_picks_smallest() {
        let x = array![[3.0], [1.0], [3.0], [1.0], [f64::NAN],];
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::MostFrequent);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();

        assert_abs_diff_eq!(fitted.fill_values()[0], 1.0);
        let result = fitted.transform(&x).unwrap();
        assert_abs_diff_eq!(result[[4, 0]], 1.0);
    }

    #[test]
    fn test_constant_strategy() {
        let x = array![[f64::NAN, 1.0], [2.0, f64::NAN],];
        let imputer = SimpleImputer::<f64>::new()
            .with_strategy(ImputeStrategy::Constant)
            .with_fill_value(-999.0);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(result[[0, 0]], -999.0);
        assert_abs_diff_eq!(result[[0, 1]], 1.0);
        assert_abs_diff_eq!(result[[1, 0]], 2.0);
        assert_abs_diff_eq!(result[[1, 1]], -999.0);
    }

    #[test]
    fn test_mixed_nan_positions() {
        let x = array![
            [f64::NAN, 2.0, f64::NAN],
            [1.0, f64::NAN, 6.0],
            [3.0, 4.0, f64::NAN],
            [5.0, 6.0, 8.0],
        ];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // mean(1, 3, 5) = 3
        assert_abs_diff_eq!(result[[0, 0]], 3.0);
        // mean(2, 4, 6) = 4
        assert_abs_diff_eq!(result[[1, 1]], 4.0);
        // mean(6, 8) = 7
        assert_abs_diff_eq!(result[[0, 2]], 7.0);
        assert_abs_diff_eq!(result[[2, 2]], 7.0);
        // A row without NaN is untouched.
        assert_abs_diff_eq!(result[[3, 0]], 5.0);
        assert_abs_diff_eq!(result[[3, 1]], 6.0);
        assert_abs_diff_eq!(result[[3, 2]], 8.0);
    }

    #[test]
    fn test_all_nan_column_error() {
        let x = array![[1.0, f64::NAN], [2.0, f64::NAN], [3.0, f64::NAN],];
        let imputer = SimpleImputer::<f64>::new();
        let result = FitUnsupervised::<f64>::fit(&imputer, &x);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = format!("{}", err);
        assert!(
            msg.contains("column 1"),
            "error should mention column index: {}",
            msg
        );
        assert!(msg.contains("NaN"), "error should mention NaN: {}", msg);
    }

    #[test]
    fn test_no_nan_passthrough() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        for (a, b) in x.iter().zip(result.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-15);
        }
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x_fit = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0],];
        let x_transform = array![[1.0, 2.0], [3.0, 4.0],];
        let imputer = SimpleImputer::<f64>::new();
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x_fit).unwrap();
        let result = fitted.transform(&x_transform);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("3") && msg.contains("2"),
            "error should mention expected and actual: {}",
            msg
        );
    }

    #[test]
    fn test_f32_support() {
        let x = array![[1.0f32, f32::NAN], [3.0f32, 4.0f32], [5.0f32, 6.0f32],];
        let imputer = SimpleImputer::<f32>::new();
        let fitted = FitUnsupervised::<f32>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        // mean(4, 6) = 5
        assert_abs_diff_eq!(result[[0, 1]], 5.0f32, epsilon = 1e-6);
        assert_abs_diff_eq!(result[[0, 0]], 1.0f32, epsilon = 1e-6);
    }

    #[test]
    fn test_constant_strategy_default_fill_value() {
        let x = array![[f64::NAN, 1.0], [2.0, f64::NAN],];
        // No explicit fill value: Constant falls back to zero.
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Constant);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_abs_diff_eq!(result[[0, 0]], 0.0);
        assert_abs_diff_eq!(result[[1, 1]], 0.0);
    }

    #[test]
    fn test_median_even_count() {
        let x = array![[1.0], [3.0], [5.0], [7.0],];
        // Even count: average of the two middle values, (3 + 5) / 2 = 4.
        let imputer = SimpleImputer::<f64>::new().with_strategy(ImputeStrategy::Median);
        let fitted = FitUnsupervised::<f64>::fit(&imputer, &x).unwrap();
        assert_abs_diff_eq!(fitted.fill_values()[0], 4.0);
    }
}