Skip to main content

feature_factory/transformers/
discretization.rs

//! ## Continuous Variable Discretization Transformers
//!
//! This module provides transformers to convert continuous variables into categorical ones by binning them into discrete intervals.
//!
//! ### Available Transformers
//!
//! - [`ArbitraryDiscretizer`]: Discretizes based on user-defined intervals.
//! - [`EqualFrequencyDiscretizer`]: Splits a column into bins containing approximately equal numbers of values.
//! - [`EqualWidthDiscretizer`]: Divides a column into bins of equal width.
//! - [`GeometricWidthDiscretizer`]: Uses a geometric progression to determine bin boundaries.
//!
//! Each transformer returns a new DataFrame with the transformed columns.
//! Errors are returned as [`FeatureFactoryError`], and results are wrapped in [`FeatureFactoryResult`].
14
15use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
16use crate::impl_transformer;
17use datafusion::dataframe::DataFrame;
18use datafusion::functions_aggregate::expr_fn::approx_percentile_cont;
19use datafusion::logical_expr::{col, lit, Case as DFCase, Expr};
20use datafusion::scalar::ScalarValue;
21use std::collections::HashMap;
22
23/// Validates that a column exists and is numeric (Float64 or Int64).
24fn validate_numeric_column(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<()> {
25    let field = df.schema().field_with_name(None, col_name).map_err(|_| {
26        FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
27    })?;
28    match field.data_type() {
29        datafusion::arrow::datatypes::DataType::Float64
30        | datafusion::arrow::datatypes::DataType::Int64 => Ok(()),
31        dt => Err(FeatureFactoryError::InvalidParameter(format!(
32            "Column '{}' must be numeric (Float64 or Int64), but found {:?}",
33            col_name, dt
34        ))),
35    }
36}
37
38/// Helper function to build a CASE expression for intervals.
39/// For each interval in `intervals` (tuple: lower, upper, label), it generates a condition:
40/// For the first (n-1) intervals, the condition is:
41/// `WHEN col >= lower AND col < upper THEN label`
42/// For the last interval, the condition is:
43/// `WHEN col >= lower AND col <= upper THEN label`
44/// If none match, returns NULL.
45fn build_interval_case_expr(col_name: &str, intervals: &[(f64, f64, String)]) -> Expr {
46    let n = intervals.len();
47    let when_then_expr = intervals
48        .iter()
49        .enumerate()
50        .map(|(i, (lower, upper, label))| {
51            let condition = if i == n - 1 {
52                col(col_name)
53                    .gt_eq(lit(*lower))
54                    .and(col(col_name).lt_eq(lit(*upper)))
55            } else {
56                col(col_name)
57                    .gt_eq(lit(*lower))
58                    .and(col(col_name).lt(lit(*upper)))
59            };
60            (Box::new(condition), Box::new(lit(label.clone())))
61        })
62        .collect::<Vec<_>>();
63    Expr::Case(DFCase {
64        expr: None,
65        when_then_expr,
66        else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
67    })
68}
69
70/// Generic helper function that applies an interval mapping to each target column in a DataFrame.
71/// For each column in `target_cols`, if a mapping exists in `mapping` then a CASE expression
72/// is built; otherwise, the original column is retained.
73fn apply_interval_mapping(
74    df: DataFrame,
75    target_cols: &[String],
76    mapping: &HashMap<String, Vec<(f64, f64, String)>>,
77) -> FeatureFactoryResult<DataFrame> {
78    let exprs: Vec<Expr> = df
79        .schema()
80        .fields()
81        .iter()
82        .map(|field| {
83            let name = field.name();
84            if target_cols.contains(name) {
85                if let Some(intervals) = mapping.get(name) {
86                    build_interval_case_expr(name, intervals).alias(name)
87                } else {
88                    col(name)
89                }
90            } else {
91                col(name)
92            }
93        })
94        .collect();
95    df.select(exprs).map_err(FeatureFactoryError::from)
96}
97
98/// Helper function to compute the min and max of a column using approximate percentiles.
99/// It uses p=0 for min and p=1 for max.
100async fn compute_min_max(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<(f64, f64)> {
101    validate_numeric_column(df, col_name)?;
102    let min_df = df
103        .clone()
104        .aggregate(
105            vec![],
106            vec![approx_percentile_cont(col(col_name), lit(0.0), None).alias("min")],
107        )
108        .map_err(FeatureFactoryError::from)?;
109    let min_batches = min_df.collect().await.map_err(FeatureFactoryError::from)?;
110    let min_val = if let Some(batch) = min_batches.first() {
111        let array = batch.column(0);
112        let scalar = ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
113        if let ScalarValue::Float64(Some(val)) = scalar {
114            val
115        } else {
116            return Err(FeatureFactoryError::DataFusionError(
117                datafusion::error::DataFusionError::Plan(format!(
118                    "Failed to compute min for column {}",
119                    col_name
120                )),
121            ));
122        }
123    } else {
124        return Err(FeatureFactoryError::DataFusionError(
125            datafusion::error::DataFusionError::Plan("No data found".to_string()),
126        ));
127    };
128
129    let max_df = df
130        .clone()
131        .aggregate(
132            vec![],
133            vec![approx_percentile_cont(col(col_name), lit(1.0), None).alias("max")],
134        )
135        .map_err(FeatureFactoryError::from)?;
136    let max_batches = max_df.collect().await.map_err(FeatureFactoryError::from)?;
137    let max_val = if let Some(batch) = max_batches.first() {
138        let array = batch.column(0);
139        let scalar = ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
140        if let ScalarValue::Float64(Some(val)) = scalar {
141            val
142        } else {
143            return Err(FeatureFactoryError::DataFusionError(
144                datafusion::error::DataFusionError::Plan(format!(
145                    "Failed to compute max for column {}",
146                    col_name
147                )),
148            ));
149        }
150    } else {
151        return Err(FeatureFactoryError::DataFusionError(
152            datafusion::error::DataFusionError::Plan("No data found".to_string()),
153        ));
154    };
155
156    Ok((min_val, max_val))
157}
158
159/// Discretizes a column into arbitrary intervals defined by the user.
160pub struct ArbitraryDiscretizer {
161    pub columns: Vec<String>,
162    pub intervals: HashMap<String, Vec<(f64, f64, String)>>,
163}
164
165impl ArbitraryDiscretizer {
166    pub fn new(columns: Vec<String>, intervals: HashMap<String, Vec<(f64, f64, String)>>) -> Self {
167        Self { columns, intervals }
168    }
169
170    /// Stateless transformer: fit does nothing.
171    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
172        Ok(())
173    }
174
175    /// Transform validates that each target column is numeric and that the intervals are valid,
176    /// then applies the interval mapping.
177    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
178        for col_name in &self.columns {
179            validate_numeric_column(&df, col_name)?;
180        }
181        for (col, intervals) in &self.intervals {
182            for (lower, upper, _) in intervals {
183                if lower >= upper {
184                    return Err(FeatureFactoryError::InvalidParameter(format!(
185                        "For column '{}', lower bound {} is not less than upper bound {}",
186                        col, lower, upper
187                    )));
188                }
189            }
190        }
191        apply_interval_mapping(df, &self.columns, &self.intervals)
192    }
193
194    // This transformer is stateless.
195    fn inherent_is_stateful(&self) -> bool {
196        false
197    }
198}
199
200/// Splits a column into bins containing approximately equal numbers of values from the column.
201pub struct EqualFrequencyDiscretizer {
202    pub columns: Vec<String>,
203    pub bins: usize,
204    pub mapping: HashMap<String, Vec<(f64, f64, String)>>,
205    fitted: bool,
206}
207
208impl EqualFrequencyDiscretizer {
209    pub fn new(columns: Vec<String>, bins: usize) -> Self {
210        Self {
211            columns,
212            bins,
213            mapping: HashMap::new(),
214            fitted: false,
215        }
216    }
217
218    /// Fit computes equal-frequency intervals and stores the mapping.
219    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
220        if self.bins < 1 {
221            return Err(FeatureFactoryError::InvalidParameter(
222                "Number of bins must be at least 1".to_string(),
223            ));
224        }
225        for col_name in &self.columns {
226            validate_numeric_column(df, col_name)?;
227            let mut boundaries = Vec::with_capacity(self.bins + 1);
228            for i in 0..=self.bins {
229                let p = i as f64 / self.bins as f64;
230                let agg_df = df
231                    .clone()
232                    .aggregate(
233                        vec![],
234                        vec![approx_percentile_cont(col(col_name), lit(p), None).alias("q")],
235                    )
236                    .map_err(FeatureFactoryError::from)?;
237                let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
238                if let Some(batch) = batches.first() {
239                    let array = batch.column(0);
240                    let scalar =
241                        ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
242                    if let ScalarValue::Float64(Some(val)) = scalar {
243                        boundaries.push(val);
244                    } else {
245                        return Err(FeatureFactoryError::DataFusionError(
246                            datafusion::error::DataFusionError::Plan(format!(
247                                "Failed to compute percentile for column {}",
248                                col_name
249                            )),
250                        ));
251                    }
252                }
253            }
254            if let (Some(first), Some(last)) = (boundaries.first(), boundaries.last()) {
255                if (first - last).abs() < 1e-6 {
256                    return Err(FeatureFactoryError::InvalidParameter(format!(
257                        "Column {} appears to be constant; cannot discretize into equal-frequency bins",
258                        col_name
259                    )));
260                }
261            }
262            let intervals = boundaries
263                .windows(2)
264                .map(|pair| {
265                    let lower = pair[0];
266                    let upper = pair[1];
267                    let label = format!("[{:.2}, {:.2})", lower, upper);
268                    (lower, upper, label)
269                })
270                .collect::<Vec<_>>();
271            self.mapping.insert(col_name.clone(), intervals);
272        }
273        self.fitted = true;
274        Ok(())
275    }
276
277    /// Transform applies the equal-frequency discretization.
278    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
279        if !self.fitted {
280            return Err(FeatureFactoryError::FitNotCalled);
281        }
282        apply_interval_mapping(df, &self.columns, &self.mapping)
283    }
284
285    // This transformer is stateful.
286    fn inherent_is_stateful(&self) -> bool {
287        true
288    }
289}
290
291/// Splits a column into bins of equal width.
292pub struct EqualWidthDiscretizer {
293    pub columns: Vec<String>,
294    pub bins: usize,
295    pub mapping: HashMap<String, Vec<(f64, f64, String)>>,
296    fitted: bool,
297}
298
299impl EqualWidthDiscretizer {
300    pub fn new(columns: Vec<String>, bins: usize) -> Self {
301        Self {
302            columns,
303            bins,
304            mapping: HashMap::new(),
305            fitted: false,
306        }
307    }
308
309    /// Fit computes the min and max and then builds equal-width intervals.
310    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
311        if self.bins < 1 {
312            return Err(FeatureFactoryError::InvalidParameter(
313                "Number of bins must be at least 1".to_string(),
314            ));
315        }
316        for col_name in &self.columns {
317            validate_numeric_column(df, col_name)?;
318            let (min_val, max_val) = compute_min_max(df, col_name).await?;
319            if (max_val - min_val).abs() < 1e-6 {
320                return Err(FeatureFactoryError::InvalidParameter(format!(
321                    "Column {} is constant (min == max), cannot discretize into equal-width bins",
322                    col_name
323                )));
324            }
325            let width = (max_val - min_val) / self.bins as f64;
326            let intervals = (0..self.bins)
327                .map(|i| {
328                    let lower = min_val + i as f64 * width;
329                    let upper = if i == self.bins - 1 {
330                        max_val
331                    } else {
332                        min_val + (i as f64 + 1.0) * width
333                    };
334                    let label = format!("[{:.2}, {:.2})", lower, upper);
335                    (lower, upper, label)
336                })
337                .collect::<Vec<_>>();
338            self.mapping.insert(col_name.clone(), intervals);
339        }
340        self.fitted = true;
341        Ok(())
342    }
343
344    /// Transform applies the equal-width discretization.
345    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
346        if !self.fitted {
347            return Err(FeatureFactoryError::FitNotCalled);
348        }
349        apply_interval_mapping(df, &self.columns, &self.mapping)
350    }
351
352    // This transformer is stateful.
353    fn inherent_is_stateful(&self) -> bool {
354        true
355    }
356}
357
358/// Uses a geometric progression to determine bin boundaries.
359pub struct GeometricWidthDiscretizer {
360    pub columns: Vec<String>,
361    pub bins: usize,
362    pub mapping: HashMap<String, Vec<(f64, f64, String)>>,
363    fitted: bool,
364}
365
366impl GeometricWidthDiscretizer {
367    pub fn new(columns: Vec<String>, bins: usize) -> Self {
368        Self {
369            columns,
370            bins,
371            mapping: HashMap::new(),
372            fitted: false,
373        }
374    }
375
376    /// Fit computes min and max and then generates geometric intervals.
377    /// Returns an error if any column has non-positive values.
378    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
379        if self.bins < 1 {
380            return Err(FeatureFactoryError::InvalidParameter(
381                "Number of bins must be at least 1".to_string(),
382            ));
383        }
384        for col_name in &self.columns {
385            validate_numeric_column(df, col_name)?;
386            let (min_val, max_val) = compute_min_max(df, col_name).await?;
387            if min_val <= 0.0 {
388                return Err(FeatureFactoryError::DataFusionError(
389                    datafusion::error::DataFusionError::Plan(format!(
390                        "Column {} has non-positive values, cannot apply geometric discretization",
391                        col_name
392                    )),
393                ));
394            }
395            let ratio = (max_val / min_val).powf(1.0 / self.bins as f64);
396            let intervals = (0..self.bins)
397                .map(|i| {
398                    let lower = min_val * ratio.powi(i as i32);
399                    let upper = if i == self.bins - 1 {
400                        max_val
401                    } else {
402                        min_val * ratio.powi((i + 1) as i32)
403                    };
404                    let label = format!("[{:.2}, {:.2})", lower, upper);
405                    (lower, upper, label)
406                })
407                .collect::<Vec<_>>();
408            self.mapping.insert(col_name.clone(), intervals);
409        }
410        self.fitted = true;
411        Ok(())
412    }
413
414    /// Transform applies the geometric-width discretization.
415    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
416        if !self.fitted {
417            return Err(FeatureFactoryError::FitNotCalled);
418        }
419        apply_interval_mapping(df, &self.columns, &self.mapping)
420    }
421
422    // This transformer is stateful.
423    fn inherent_is_stateful(&self) -> bool {
424        true
425    }
426}
427
428// Implement the Transformer trait for the transformers in this module.
429impl_transformer!(ArbitraryDiscretizer);
430impl_transformer!(EqualFrequencyDiscretizer);
431impl_transformer!(EqualWidthDiscretizer);
432impl_transformer!(GeometricWidthDiscretizer);