sklears_impute/
simple.rs

1//! Simple imputation methods
2//!
3//! This module provides basic imputation strategies including mean, median,
4//! mode, constant, and time series imputation methods.
5
6use crate::core::{ImputationError, ImputationResult, Imputer};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
8use scirs2_core::random::Random;
9use sklears_core::{
10    error::{Result as SklResult, SklearsError},
11    traits::{Estimator, Fit, Transform, Untrained},
12    types::Float,
13};
14use std::collections::HashMap;
15
16/// Simple Imputer
17///
18/// Imputation transformer for completing missing values using simple strategies.
19/// The imputer replaces missing values using the mean, median, most frequent value,
20/// a constant value, or time series imputation along each column.
21///
22/// # Parameters
23///
24/// * `missing_values` - The placeholder for missing values (NaN by default)
25/// * `strategy` - Imputation strategy ('mean', 'median', 'most_frequent', 'constant', 'forward_fill', 'backward_fill', 'random_sampling')
26/// * `fill_value` - Fill value to use when strategy is 'constant'
27/// * `copy` - Whether to make a copy of the input data
28///
29/// # Examples
30///
31/// ```
32/// use sklears_impute::SimpleImputer;
33/// use sklears_core::traits::{Transform, Fit};
34/// use scirs2_core::ndarray::array;
35///
36/// let X = array![[1.0, 2.0], [f64::NAN, 3.0], [7.0, 6.0]];
37///
38/// let imputer = SimpleImputer::new()
39///     .strategy("mean".to_string());
40/// let fitted = imputer.fit(&X.view(), &()).unwrap();
41/// let X_imputed = fitted.transform(&X.view()).unwrap();
42/// ```
#[derive(Debug, Clone)]
pub struct SimpleImputer<S = Untrained> {
    state: S,                // typestate: `Untrained` before fit, `SimpleImputerTrained` after
    missing_values: f64,     // sentinel treated as missing (NaN by default)
    strategy: String,        // "mean", "median", "most_frequent", "constant", "forward_fill", "backward_fill", or "random_sampling"
    fill_value: Option<f64>, // value used by the "constant" strategy (None falls back to 0.0)
    copy: bool,              // whether transform should copy the input before imputing
}
51
/// Trained state for SimpleImputer
#[derive(Debug, Clone)]
pub struct SimpleImputerTrained {
    statistics: Array1<f64>,     // one fill statistic per feature, computed during fit
    valid_values: Vec<Vec<f64>>, // observed (non-missing) values per feature; used by "random_sampling"
}
58
59impl SimpleImputer<Untrained> {
60    /// Create a new SimpleImputer instance
61    pub fn new() -> Self {
62        Self {
63            state: Untrained,
64            missing_values: f64::NAN,
65            strategy: "mean".to_string(),
66            fill_value: None,
67            copy: true,
68        }
69    }
70
71    /// Set the missing values placeholder
72    pub fn missing_values(mut self, missing_values: f64) -> Self {
73        self.missing_values = missing_values;
74        self
75    }
76
77    /// Set the imputation strategy
78    pub fn strategy(mut self, strategy: String) -> Self {
79        self.strategy = strategy;
80        self
81    }
82
83    /// Set the fill value for constant strategy
84    pub fn fill_value(mut self, fill_value: Option<f64>) -> Self {
85        self.fill_value = fill_value;
86        self
87    }
88
89    /// Set whether to copy the input data
90    pub fn copy(mut self, copy: bool) -> Self {
91        self.copy = copy;
92        self
93    }
94
95    fn is_missing(&self, value: f64) -> bool {
96        if self.missing_values.is_nan() {
97            value.is_nan()
98        } else {
99            (value - self.missing_values).abs() < f64::EPSILON
100        }
101    }
102}
103
104impl Default for SimpleImputer<Untrained> {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
impl Estimator for SimpleImputer<Untrained> {
    // This estimator carries no external configuration object.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `()` is a zero-sized type, so this borrows a promoted constant.
        &()
    }
}
119
120impl Fit<ArrayView2<'_, Float>, ()> for SimpleImputer<Untrained> {
121    type Fitted = SimpleImputer<SimpleImputerTrained>;
122
123    #[allow(non_snake_case)]
124    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
125        let X = X.mapv(|x| x);
126        let (_, n_features) = X.dim();
127        let mut statistics = Vec::new();
128        let mut all_valid_values = Vec::new();
129
130        for feature_idx in 0..n_features {
131            let column = X.column(feature_idx);
132            let valid_values: Vec<f64> = column
133                .iter()
134                .filter(|&&x| !self.is_missing(x))
135                .cloned()
136                .collect();
137
138            if valid_values.is_empty() {
139                return Err(SklearsError::InvalidInput(format!(
140                    "All values are missing in feature {feature_idx}"
141                )));
142            }
143
144            let statistic = match self.strategy.as_str() {
145                "mean" => {
146                    let sum: f64 = valid_values.iter().sum();
147                    sum / valid_values.len() as f64
148                }
149                "median" => {
150                    let mut sorted_values = valid_values.clone();
151                    sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
152                    let len = sorted_values.len();
153                    if len % 2 == 0 {
154                        (sorted_values[len / 2 - 1] + sorted_values[len / 2]) / 2.0
155                    } else {
156                        sorted_values[len / 2]
157                    }
158                }
159                "most_frequent" => {
160                    let mut counts = HashMap::new();
161                    for &value in &valid_values {
162                        *counts.entry(value.to_bits()).or_insert(0) += 1;
163                    }
164                    let most_frequent_bits = counts
165                        .into_iter()
166                        .max_by_key(|&(_, count)| count)
167                        .unwrap()
168                        .0;
169                    f64::from_bits(most_frequent_bits)
170                }
171                "constant" => self.fill_value.unwrap_or(0.0),
172                "forward_fill" | "backward_fill" => {
173                    // For time series strategies, we'll store the mean as fallback
174                    // The actual forward/backward fill will be done in transform
175                    let sum: f64 = valid_values.iter().sum();
176                    sum / valid_values.len() as f64
177                }
178                "random_sampling" => {
179                    // For random sampling, we'll store the mean as fallback
180                    // The actual random sampling will be done in transform
181                    let sum: f64 = valid_values.iter().sum();
182                    sum / valid_values.len() as f64
183                }
184                _ => {
185                    return Err(SklearsError::InvalidInput(format!(
186                        "Unknown strategy: {}",
187                        self.strategy
188                    )));
189                }
190            };
191
192            statistics.push(statistic);
193            all_valid_values.push(valid_values.clone());
194        }
195
196        Ok(SimpleImputer {
197            state: SimpleImputerTrained {
198                statistics: Array1::from(statistics),
199                valid_values: all_valid_values,
200            },
201            missing_values: self.missing_values,
202            strategy: self.strategy,
203            fill_value: self.fill_value,
204            copy: self.copy,
205        })
206    }
207}
208
209impl Transform<ArrayView2<'_, Float>, Array2<Float>> for SimpleImputer<SimpleImputerTrained> {
210    #[allow(non_snake_case)]
211    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
212        let X = X.mapv(|x| x);
213        let (n_samples, n_features) = X.dim();
214
215        if n_features != self.state.statistics.len() {
216            return Err(SklearsError::InvalidInput(format!(
217                "Number of features {} does not match training features {}",
218                n_features,
219                self.state.statistics.len()
220            )));
221        }
222
223        let mut X_imputed = if self.copy { X.clone() } else { X };
224
225        match self.strategy.as_str() {
226            "forward_fill" => {
227                for feature_idx in 0..n_features {
228                    let mut last_valid = None;
229                    for sample_idx in 0..n_samples {
230                        let value = X_imputed[[sample_idx, feature_idx]];
231                        if self.is_missing(value) {
232                            if let Some(fill_value) = last_valid {
233                                X_imputed[[sample_idx, feature_idx]] = fill_value;
234                            } else {
235                                // No previous valid value, use mean as fallback
236                                X_imputed[[sample_idx, feature_idx]] =
237                                    self.state.statistics[feature_idx];
238                            }
239                        } else {
240                            last_valid = Some(value);
241                        }
242                    }
243                }
244            }
245            "backward_fill" => {
246                for feature_idx in 0..n_features {
247                    let mut next_valid = None;
248                    for sample_idx in (0..n_samples).rev() {
249                        let value = X_imputed[[sample_idx, feature_idx]];
250                        if self.is_missing(value) {
251                            if let Some(fill_value) = next_valid {
252                                X_imputed[[sample_idx, feature_idx]] = fill_value;
253                            } else {
254                                // No next valid value, use mean as fallback
255                                X_imputed[[sample_idx, feature_idx]] =
256                                    self.state.statistics[feature_idx];
257                            }
258                        } else {
259                            next_valid = Some(value);
260                        }
261                    }
262                }
263            }
264            "random_sampling" => {
265                let mut rng = Random::default();
266                for feature_idx in 0..n_features {
267                    let valid_values = &self.state.valid_values[feature_idx];
268                    if !valid_values.is_empty() {
269                        for sample_idx in 0..n_samples {
270                            if self.is_missing(X_imputed[[sample_idx, feature_idx]]) {
271                                let random_idx = rng.gen_range(0..valid_values.len());
272                                let random_value = &valid_values[random_idx];
273                                X_imputed[[sample_idx, feature_idx]] = *random_value;
274                            }
275                        }
276                    }
277                }
278            }
279            _ => {
280                // Standard strategies: mean, median, most_frequent, constant
281                for feature_idx in 0..n_features {
282                    let fill_value = self.state.statistics[feature_idx];
283                    for sample_idx in 0..n_samples {
284                        if self.is_missing(X_imputed[[sample_idx, feature_idx]]) {
285                            X_imputed[[sample_idx, feature_idx]] = fill_value;
286                        }
287                    }
288                }
289            }
290        }
291
292        Ok(X_imputed.mapv(|x| x as Float))
293    }
294}
295
296impl SimpleImputer<SimpleImputerTrained> {
297    fn is_missing(&self, value: f64) -> bool {
298        if self.missing_values.is_nan() {
299            value.is_nan()
300        } else {
301            (value - self.missing_values).abs() < f64::EPSILON
302        }
303    }
304}
305
306/// Missing Indicator
307///
308/// Binary indicator for missing values.
309///
310/// # Parameters
311///
312/// * `missing_values` - The placeholder for missing values (NaN by default)
313/// * `features` - Which features to generate indicators for ('missing-only' or 'all')
314/// * `sparse` - Whether to return sparse indicators
315/// * `error_on_new` - Whether to raise an error when a new feature is completely missing during transform
316///
317/// # Examples
318///
319/// ```
320/// use sklears_impute::MissingIndicator;
321/// use sklears_core::traits::{Transform, Fit};
322/// use scirs2_core::ndarray::array;
323///
324/// let X = array![[1.0, 2.0], [f64::NAN, 3.0], [7.0, 6.0]];
325///
326/// let indicator = MissingIndicator::new();
327/// let fitted = indicator.fit(&X.view(), &()).unwrap();
328/// let indicators = fitted.transform(&X.view()).unwrap();
329/// ```
#[derive(Debug, Clone)]
pub struct MissingIndicator<S = Untrained> {
    state: S,            // typestate: `Untrained` before fit, `MissingIndicatorTrained` after
    missing_values: f64, // sentinel treated as missing (NaN by default)
    features: String,    // "missing-only" or "all"
    sparse: bool,        // requested sparse output — NOTE(review): not consulted by the visible transform; confirm intended use
    error_on_new: bool,  // error when a feature unseen-as-missing during fit has missing values at transform
}
338
/// Trained state for MissingIndicator
#[derive(Debug, Clone)]
pub struct MissingIndicatorTrained {
    features_: Vec<usize>,  // column indices (ascending, as built in fit) that get an indicator column
    n_features_in_: usize,  // number of columns seen during fit
}
345
346impl MissingIndicator<Untrained> {
347    /// Create a new MissingIndicator instance
348    pub fn new() -> Self {
349        Self {
350            state: Untrained,
351            missing_values: f64::NAN,
352            features: "missing-only".to_string(),
353            sparse: false,
354            error_on_new: true,
355        }
356    }
357
358    /// Set the missing values placeholder
359    pub fn missing_values(mut self, missing_values: f64) -> Self {
360        self.missing_values = missing_values;
361        self
362    }
363
364    /// Set which features to generate indicators for
365    pub fn features(mut self, features: String) -> Self {
366        self.features = features;
367        self
368    }
369
370    /// Set whether to return sparse indicators
371    pub fn sparse(mut self, sparse: bool) -> Self {
372        self.sparse = sparse;
373        self
374    }
375
376    /// Set whether to raise an error on new missing features
377    pub fn error_on_new(mut self, error_on_new: bool) -> Self {
378        self.error_on_new = error_on_new;
379        self
380    }
381
382    fn is_missing(&self, value: f64) -> bool {
383        if self.missing_values.is_nan() {
384            value.is_nan()
385        } else {
386            (value - self.missing_values).abs() < f64::EPSILON
387        }
388    }
389}
390
391impl Default for MissingIndicator<Untrained> {
392    fn default() -> Self {
393        Self::new()
394    }
395}
396
impl Estimator for MissingIndicator<Untrained> {
    // This estimator carries no external configuration object.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `()` is a zero-sized type, so this borrows a promoted constant.
        &()
    }
}
406
407impl Fit<ArrayView2<'_, Float>, ()> for MissingIndicator<Untrained> {
408    type Fitted = MissingIndicator<MissingIndicatorTrained>;
409
410    #[allow(non_snake_case)]
411    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
412        let X = X.mapv(|x| x);
413        let (_, n_features) = X.dim();
414
415        let features_ = match self.features.as_str() {
416            "missing-only" => {
417                // Only include features that have missing values
418                let mut selected_features = Vec::new();
419                for feature_idx in 0..n_features {
420                    let column = X.column(feature_idx);
421                    if column.iter().any(|&x| self.is_missing(x)) {
422                        selected_features.push(feature_idx);
423                    }
424                }
425                selected_features
426            }
427            "all" => (0..n_features).collect(),
428            _ => {
429                return Err(SklearsError::InvalidInput(format!(
430                    "Unknown features option: {}",
431                    self.features
432                )));
433            }
434        };
435
436        Ok(MissingIndicator {
437            state: MissingIndicatorTrained {
438                features_,
439                n_features_in_: n_features,
440            },
441            missing_values: self.missing_values,
442            features: self.features,
443            sparse: self.sparse,
444            error_on_new: self.error_on_new,
445        })
446    }
447}
448
449impl Transform<ArrayView2<'_, Float>, Array2<Float>> for MissingIndicator<MissingIndicatorTrained> {
450    #[allow(non_snake_case)]
451    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
452        let X = X.mapv(|x| x);
453        let (n_samples, n_features) = X.dim();
454
455        if n_features != self.state.n_features_in_ {
456            return Err(SklearsError::InvalidInput(format!(
457                "Number of features {} does not match training features {}",
458                n_features, self.state.n_features_in_
459            )));
460        }
461
462        if self.error_on_new {
463            // Check for new missing features
464            for feature_idx in 0..n_features {
465                if !self.state.features_.contains(&feature_idx) {
466                    let column = X.column(feature_idx);
467                    if column.iter().any(|&x| self.is_missing(x)) {
468                        return Err(SklearsError::InvalidInput(format!(
469                            "Feature {} has missing values but was not seen during fit",
470                            feature_idx
471                        )));
472                    }
473                }
474            }
475        }
476
477        let n_indicator_features = self.state.features_.len();
478        let mut indicators = Array2::<f64>::zeros((n_samples, n_indicator_features));
479
480        for (indicator_idx, &feature_idx) in self.state.features_.iter().enumerate() {
481            let column = X.column(feature_idx);
482            for (sample_idx, &value) in column.iter().enumerate() {
483                if self.is_missing(value) {
484                    indicators[[sample_idx, indicator_idx]] = 1.0;
485                }
486            }
487        }
488
489        Ok(indicators.mapv(|x| x as Float))
490    }
491}
492
493impl MissingIndicator<MissingIndicatorTrained> {
494    fn is_missing(&self, value: f64) -> bool {
495        if self.missing_values.is_nan() {
496            value.is_nan()
497        } else {
498            (value - self.missing_values).abs() < f64::EPSILON
499        }
500    }
501}
502
503// Implement the Imputer trait for SimpleImputer
504impl Imputer for SimpleImputer<Untrained> {
505    #[allow(non_snake_case)]
506    fn fit_transform(
507        &self,
508        X: &scirs2_core::ndarray::ArrayView2<f64>,
509    ) -> ImputationResult<scirs2_core::ndarray::Array2<f64>> {
510        // Convert from f64 array to Float array for sklears-core compatibility
511        let X_float = X.mapv(|x| x as Float);
512        let X_view = X_float.view();
513
514        // Use the sklears-core fit and transform pattern
515        let fitted = self.clone().fit(&X_view, &()).map_err(|e| {
516            ImputationError::ProcessingError(format!("Failed to fit imputer: {}", e))
517        })?;
518
519        let result = fitted.transform(&X_view).map_err(|e| {
520            ImputationError::ProcessingError(format!("Failed to transform data: {}", e))
521        })?;
522
523        // Convert back to f64 for the imputation interface
524        Ok(result.mapv(|x| x))
525    }
526}