Skip to main content

feature_factory/transformers/
imputation.rs

1//! ## Missing Value Imputation Transformers
2//!
3//! This module provides transformers for handling missing values in both numeric and categorical columns.
4//!
5//! ### Available Transformers
6//!
7//! - [`MeanMedianImputer`]: Fills missing values in numeric columns using the mean (median is not available yet).
8//! - [`ArbitraryNumberImputer`]: Replaces missing numeric values with a fixed arbitrary number.
9//! - [`EndTailImputer`]: Imputes numeric columns using a percentile value (e.g., tail imputation).
10//! - [`CategoricalImputer`]: Fills missing categorical values using the mode or a predefined default.
11//! - [`AddMissingIndicator`]: Creates Boolean indicator columns to flag missing values.
12//! - [`DropMissingData`]: Removes rows that contain missing values in the specified columns.
13//!
14//! Each transformer returns a new DataFrame with missing values handled accordingly.
15//! Errors are returned as [`FeatureFactoryError`], and results are wrapped in [`FeatureFactoryResult`].
16
17use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
18use crate::impl_transformer;
19use datafusion::dataframe::DataFrame;
20use datafusion::functions_aggregate::expr_fn::{approx_percentile_cont, avg, count};
21use datafusion::logical_expr::{col, lit, not, Case as DFCase, Expr};
22use datafusion::scalar::ScalarValue;
23use std::collections::HashMap;
24
25/// Validates that every column in `target_cols` exists in the DataFrame.
26/// Returns an error if any target column is missing.
27fn validate_columns(df: &DataFrame, target_cols: &[String]) -> FeatureFactoryResult<()> {
28    let schema = df.schema();
29    for col_name in target_cols {
30        if schema.field_with_name(None, col_name).is_err() {
31            return Err(FeatureFactoryError::MissingColumn(format!(
32                "Column '{}' not found in DataFrame",
33                col_name
34            )));
35        }
36    }
37    Ok(())
38}
39
40/// Constructs an expression equivalent to SQL COALESCE(col, fallback).
41/// This is implemented as a CASE expression: if `col` is not null then return it, otherwise return `fallback`.
42fn coalesce_expr_for(name: &str, fallback: Expr) -> Expr {
43    Expr::Case(DFCase {
44        expr: None,
45        when_then_expr: vec![(Box::new(not(col(name).is_null())), Box::new(col(name)))],
46        else_expr: Some(Box::new(fallback)),
47    })
48}
49
50/// Generic helper function to apply a mapping to a set of target columns.
51/// For each field in the DataFrame, if its name is in `target_cols` and a mapping is available via `get_fallback`,
52/// then the column is replaced by a CASE–WHEN expression; otherwise, the original column is retained.
53fn apply_imputation<F>(
54    df: DataFrame,
55    target_cols: &[String],
56    get_fallback: F,
57) -> FeatureFactoryResult<DataFrame>
58where
59    F: Fn(&str) -> Option<Expr>,
60{
61    let exprs: Vec<Expr> = df
62        .schema()
63        .fields()
64        .iter()
65        .map(|field| {
66            let name = field.name();
67            if target_cols.contains(name) {
68                if let Some(fallback_expr) = get_fallback(name) {
69                    coalesce_expr_for(name, fallback_expr).alias(name)
70                } else {
71                    col(name)
72                }
73            } else {
74                col(name)
75            }
76        })
77        .collect();
78    df.select(exprs).map_err(FeatureFactoryError::from)
79}
80
81/// Replaces missing values with the mean ~~(or median)~~ value for numeric columns.
82pub struct MeanMedianImputer {
83    pub columns: Vec<String>,
84    pub strategy: ImputeStrategy,
85    pub impute_values: HashMap<String, f64>,
86    fitted: bool,
87}
88
89#[derive(Debug, Clone, Copy)]
90pub enum ImputeStrategy {
91    Mean,
92    Median, // Not implemented in DF mode.
93}
94
95impl MeanMedianImputer {
96    pub fn new(columns: Vec<String>, strategy: ImputeStrategy) -> Self {
97        Self {
98            columns,
99            strategy,
100            impute_values: HashMap::new(),
101            fitted: false,
102        }
103    }
104
105    /// Fit computes imputation parameters without materializing the input.
106    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
107        validate_columns(df, &self.columns)?;
108        for col_name in &self.columns {
109            match self.strategy {
110                ImputeStrategy::Mean => {
111                    let agg_df = df
112                        .clone()
113                        .aggregate(vec![], vec![avg(col(col_name)).alias("avg")])
114                        .map_err(FeatureFactoryError::from)?;
115                    let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
116                    if let Some(batch) = batches.first() {
117                        if batch.num_rows() > 0 {
118                            let array = batch.column(0);
119                            let scalar = ScalarValue::try_from_array(array, 0)
120                                .map_err(FeatureFactoryError::from)?;
121                            if let ScalarValue::Float64(Some(avg_val)) = scalar {
122                                self.impute_values.insert(col_name.clone(), avg_val);
123                            } else {
124                                return Err(FeatureFactoryError::DataFusionError(
125                                    datafusion::error::DataFusionError::Plan(format!(
126                                        "Failed to compute average for column {}",
127                                        col_name
128                                    )),
129                                ));
130                            }
131                        }
132                    }
133                }
134                ImputeStrategy::Median => {
135                    return Err(FeatureFactoryError::NotImplemented(
136                        "Median imputation not implemented in DF mode".to_string(),
137                    ));
138                }
139            }
140        }
141        self.fitted = true;
142        Ok(())
143    }
144
145    /// Transform applies imputation and returns a modified DataFrame.
146    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
147        if !self.fitted {
148            return Err(FeatureFactoryError::FitNotCalled);
149        }
150        validate_columns(&df, &self.columns)?;
151        apply_imputation(df, &self.columns, |name| {
152            self.impute_values.get(name).map(|&v| lit(v))
153        })
154    }
155
156    // This transformer is stateful.
157    fn inherent_is_stateful(&self) -> bool {
158        true
159    }
160}
161
162/// Replaces missing values with the given number.
163pub struct ArbitraryNumberImputer {
164    pub columns: Vec<String>,
165    pub number: f64,
166}
167
168impl ArbitraryNumberImputer {
169    pub fn new(columns: Vec<String>, number: f64) -> Self {
170        Self { columns, number }
171    }
172
173    /// Stateless transformer: fit does nothing.
174    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
175        Ok(())
176    }
177
178    /// Transform validates inputs and applies imputation.
179    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
180        if !self.number.is_finite() {
181            return Err(FeatureFactoryError::InvalidParameter(format!(
182                "Fixed number {} must be finite",
183                self.number
184            )));
185        }
186        validate_columns(&df, &self.columns)?;
187        apply_imputation(df, &self.columns, |_| Some(lit(self.number)))
188    }
189
190    // This transformer is stateless.
191    fn inherent_is_stateful(&self) -> bool {
192        false
193    }
194}
195
196/// Replaces missing values with a percentile value computed from the data.
197pub struct EndTailImputer {
198    pub columns: Vec<String>,
199    pub percentile: f64,
200    pub impute_values: HashMap<String, f64>,
201    fitted: bool,
202}
203
204impl EndTailImputer {
205    pub fn new(columns: Vec<String>, percentile: f64) -> Self {
206        Self {
207            columns,
208            percentile,
209            impute_values: HashMap::new(),
210            fitted: false,
211        }
212    }
213
214    /// Fit computes the percentile for each column.
215    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
216        validate_columns(df, &self.columns)?;
217        if self.percentile < 0.0 || self.percentile > 1.0 {
218            return Err(FeatureFactoryError::InvalidParameter(format!(
219                "Percentile {} must be between 0 and 1",
220                self.percentile
221            )));
222        }
223        for col_name in &self.columns {
224            let agg_df = df
225                .clone()
226                .aggregate(
227                    vec![],
228                    vec![
229                        approx_percentile_cont(col(col_name), lit(self.percentile), None)
230                            .alias("perc"),
231                    ],
232                )
233                .map_err(FeatureFactoryError::from)?;
234            let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
235            if let Some(batch) = batches.first() {
236                let array = batch.column(0);
237                let scalar =
238                    ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
239                if let ScalarValue::Float64(Some(val)) = scalar {
240                    self.impute_values.insert(col_name.clone(), val);
241                } else {
242                    return Err(FeatureFactoryError::DataFusionError(
243                        datafusion::error::DataFusionError::Plan(format!(
244                            "Failed to compute percentile for column {}",
245                            col_name
246                        )),
247                    ));
248                }
249            }
250        }
251        self.fitted = true;
252        Ok(())
253    }
254
255    /// Transform applies the computed percentile imputation.
256    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
257        if !self.fitted {
258            return Err(FeatureFactoryError::FitNotCalled);
259        }
260        validate_columns(&df, &self.columns)?;
261        apply_imputation(df, &self.columns, |name| {
262            self.impute_values.get(name).map(|&v| lit(v))
263        })
264    }
265
266    // This transformer is stateful.
267    fn inherent_is_stateful(&self) -> bool {
268        true
269    }
270}
271
272/// Replaces missing values with the mode (or a provided default) for categorical columns.
273pub struct CategoricalImputer {
274    pub columns: Vec<String>,
275    pub default: Option<String>,
276    pub impute_values: HashMap<String, String>,
277    fitted: bool,
278}
279
280impl CategoricalImputer {
281    pub fn new(columns: Vec<String>, default: Option<String>) -> Self {
282        Self {
283            columns,
284            default,
285            impute_values: HashMap::new(),
286            fitted: false,
287        }
288    }
289
290    /// Fit computes the mode for each column when no default is provided.
291    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
292        validate_columns(df, &self.columns)?;
293        if self.default.is_some() {
294            self.fitted = true;
295            return Ok(());
296        }
297        for col_name in &self.columns {
298            let grouped = df
299                .clone()
300                .aggregate(vec![col(col_name)], vec![count(col(col_name)).alias("cnt")])
301                .map_err(FeatureFactoryError::from)?
302                .sort(vec![col("cnt").sort(false, false)])
303                .map_err(FeatureFactoryError::from)?
304                .limit(0, Some(1))
305                .map_err(FeatureFactoryError::from)?;
306            let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
307            if let Some(batch) = batches.first() {
308                let array = batch.column(0);
309                let scalar =
310                    ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
311                if let ScalarValue::Utf8(Some(mode_val)) = scalar {
312                    self.impute_values.insert(col_name.clone(), mode_val);
313                } else {
314                    return Err(FeatureFactoryError::DataFusionError(
315                        datafusion::error::DataFusionError::Plan(format!(
316                            "Failed to compute mode for column {}",
317                            col_name
318                        )),
319                    ));
320                }
321            }
322        }
323        self.fitted = true;
324        Ok(())
325    }
326
327    /// Transform applies the categorical imputation.
328    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
329        if !self.fitted {
330            return Err(FeatureFactoryError::FitNotCalled);
331        }
332        validate_columns(&df, &self.columns)?;
333        apply_imputation(df, &self.columns, |name| {
334            if let Some(default_val) = &self.default {
335                Some(lit(default_val.clone()))
336            } else {
337                self.impute_values
338                    .get(name)
339                    .map(|mode_val| lit(mode_val.clone()))
340            }
341        })
342    }
343
344    // This transformer is stateful.
345    fn inherent_is_stateful(&self) -> bool {
346        true
347    }
348}
349
350/// Adds additional Boolean indicator columns for missing values.
351pub struct AddMissingIndicator {
352    pub columns: Vec<String>,
353    pub suffix: String,
354}
355
356impl AddMissingIndicator {
357    pub fn new(columns: Vec<String>, suffix: Option<String>) -> Self {
358        Self {
359            columns,
360            suffix: suffix.unwrap_or_else(|| "_missing".to_string()),
361        }
362    }
363
364    /// Stateless transformer: fit does nothing.
365    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
366        Ok(())
367    }
368
369    /// Transform validates columns and returns the modified DataFrame.
370    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
371        validate_columns(&df, &self.columns)?;
372        let mut exprs = vec![];
373        for field in df.schema().fields() {
374            let name = field.name();
375            exprs.push(col(name));
376            if self.columns.contains(name) {
377                exprs.push(
378                    col(name)
379                        .is_null()
380                        .alias(format!("{}{}", name, self.suffix)),
381                );
382            }
383        }
384        df.select(exprs).map_err(FeatureFactoryError::from)
385    }
386
387    // This transformer is stateless.
388    fn inherent_is_stateful(&self) -> bool {
389        false
390    }
391}
392
393/// Removes rows that contain a missing value in the given columns.
394pub struct DropMissingData {
395    /// Optional list of column names to check for missing values.
396    /// If None, all columns in the DataFrame are checked.
397    pub columns: Option<Vec<String>>,
398}
399
400impl DropMissingData {
401    pub fn new() -> Self {
402        Self { columns: None }
403    }
404
405    pub fn with_columns(columns: Vec<String>) -> Self {
406        Self {
407            columns: Some(columns),
408        }
409    }
410
411    /// Stateless transformer: fit does nothing.
412    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
413        Ok(())
414    }
415
416    /// Transform applies filtering and returns the modified DataFrame.
417    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
418        let target_columns = if let Some(ref cols) = self.columns {
419            cols.clone()
420        } else {
421            df.schema()
422                .fields()
423                .iter()
424                .map(|f| f.name().to_string())
425                .collect()
426        };
427        let predicates: Vec<Expr> = target_columns
428            .iter()
429            .map(|col_name| col(col_name).is_not_null())
430            .collect();
431        let combined = predicates
432            .into_iter()
433            .reduce(|acc, expr| acc.and(expr))
434            .unwrap();
435        df.filter(combined)
436            .map_err(crate::exceptions::FeatureFactoryError::from)
437    }
438
439    // This transformer is stateless.
440    fn inherent_is_stateful(&self) -> bool {
441        false
442    }
443}
444
445impl Default for DropMissingData {
446    fn default() -> Self {
447        Self::new()
448    }
449}
450
451// Implement the Transformer trait for the transformers in this module.
452impl_transformer!(MeanMedianImputer);
453impl_transformer!(ArbitraryNumberImputer);
454impl_transformer!(EndTailImputer);
455impl_transformer!(CategoricalImputer);
456impl_transformer!(AddMissingIndicator);
457impl_transformer!(DropMissingData);