Skip to main content

feature_factory/transformers/
numerical.rs

1//! ## Numerical Transformation Transformers
2//!
3//! This module provides transformers for applying mathematical transformations to numerical features.
4//!
5//! ### Available Transformers
6//!
7//! - [`LogTransformer`]: Applies the natural logarithm transformation (requires positive values).
8//! - [`LogCpTransformer`]: Applies a logarithmic transformation with a constant (requires values + constant > 0).
9//! - [`ReciprocalTransformer`]: Applies the reciprocal transformation (requires non-zero values).
10//! - [`PowerTransformer`]: Applies a power transformation with a specified exponent.
11//! - [`BoxCoxTransformer`]: Applies the Box-Cox transformation (requires positive values).
12//! - [`YeoJohnsonTransformer`]: Applies the Yeo-Johnson transformation (supports all real numbers).
13//! - [`ArcsinTransformer`]: Applies the arcsine transformation (commonly used for proportions).
14//!
15//! Each transformer returns a new DataFrame with transformed features.
16//! Errors are returned as [`FeatureFactoryError`], and results are wrapped in [`FeatureFactoryResult`].
17
18use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
19use crate::impl_transformer;
20use datafusion::dataframe::DataFrame;
21use datafusion::functions_aggregate::approx_percentile_cont::approx_percentile_cont;
22use datafusion::scalar::ScalarValue;
23use datafusion_expr::{col, lit, Expr};
24use datafusion_functions::math;
25use std::ops::{Add, Div, Neg, Sub};
26
27/// Wrapper function wrapping math's natural logarithm UDF.
28fn ln_expr(e: Expr) -> Expr {
29    math::ln().call(vec![e])
30}
31
32/// Wrapper function wrapping math's power UDF.
33fn power_expr(e: Expr, p: f64) -> Expr {
34    math::power().call(vec![e, lit(p)])
35}
36
37/// Wrapper function wrapping math's square root UDF.
38fn sqrt_expr(e: Expr) -> Expr {
39    math::sqrt().call(vec![e])
40}
41
42/// Wrapper function wrapping math's arcsine UDF.
43fn asin_expr(e: Expr) -> Expr {
44    math::asin().call(vec![e])
45}
46
47/// Helper function to compute the minimum value in a numeric column using approximate percentiles (p=0).
48async fn compute_min(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<f64> {
49    let min_df = df
50        .clone()
51        .aggregate(
52            vec![],
53            vec![approx_percentile_cont(col(col_name), lit(0.0), None).alias("min")],
54        )
55        .map_err(FeatureFactoryError::from)?;
56    let batches = min_df.collect().await.map_err(FeatureFactoryError::from)?;
57    if let Some(batch) = batches.first() {
58        let array = batch.column(0);
59        let scalar = ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
60        if let ScalarValue::Float64(Some(val)) = scalar {
61            Ok(val)
62        } else {
63            Err(FeatureFactoryError::DataFusionError(
64                datafusion::error::DataFusionError::Plan(format!(
65                    "Failed to compute min for column {}",
66                    col_name
67                )),
68            ))
69        }
70    } else {
71        Err(FeatureFactoryError::DataFusionError(
72            datafusion::error::DataFusionError::Plan("No data found".to_string()),
73        ))
74    }
75}
76
77/// Helper function to compute the maximum value in a numeric column using approximate percentiles (p=1).
78async fn compute_max(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<f64> {
79    let max_df = df
80        .clone()
81        .aggregate(
82            vec![],
83            vec![approx_percentile_cont(col(col_name), lit(1.0), None).alias("max")],
84        )
85        .map_err(FeatureFactoryError::from)?;
86    let batches = max_df.collect().await.map_err(FeatureFactoryError::from)?;
87    if let Some(batch) = batches.first() {
88        let array = batch.column(0);
89        let scalar = ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
90        if let ScalarValue::Float64(Some(val)) = scalar {
91            Ok(val)
92        } else {
93            Err(FeatureFactoryError::DataFusionError(
94                datafusion::error::DataFusionError::Plan(format!(
95                    "Failed to compute max for column {}",
96                    col_name
97                )),
98            ))
99        }
100    } else {
101        Err(FeatureFactoryError::DataFusionError(
102            datafusion::error::DataFusionError::Plan("No data found".to_string()),
103        ))
104    }
105}
106
107/// Applies natural logarithm transformation to the values in the columns.
108/// Needs all values to be positive.
109pub struct LogTransformer {
110    pub columns: Vec<String>,
111}
112
113impl LogTransformer {
114    pub fn new(columns: Vec<String>) -> Self {
115        Self { columns }
116    }
117
118    /// Stateless transformer: fit does nothing.
119    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
120        Ok(())
121    }
122
123    /// Validates that each target column exists, is Float64, and that the minimum value is > 0.
124    fn validate(&self, df: &DataFrame) -> FeatureFactoryResult<()> {
125        for col_name in &self.columns {
126            let field = df.schema().field_with_name(None, col_name).map_err(|_| {
127                FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
128            })?;
129            if field.data_type() != &datafusion::arrow::datatypes::DataType::Float64 {
130                return Err(FeatureFactoryError::InvalidParameter(format!(
131                    "LogTransformer requires column '{}' to be Float64",
132                    col_name
133                )));
134            }
135            // Compute min value.
136            let min_val = futures::executor::block_on(compute_min(df, col_name))?;
137            if min_val <= 0.0 {
138                return Err(FeatureFactoryError::InvalidParameter(format!(
139                    "LogTransformer requires all values in column '{}' to be positive, found min {}",
140                    col_name, min_val
141                )));
142            }
143        }
144        Ok(())
145    }
146
147    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
148        self.validate(&df)?;
149        let exprs: Vec<Expr> = df
150            .schema()
151            .fields()
152            .iter()
153            .map(|field| {
154                let name = field.name();
155                if self.columns.contains(name) {
156                    ln_expr(col(name)).alias(name)
157                } else {
158                    col(name)
159                }
160            })
161            .collect();
162        df.select(exprs).map_err(FeatureFactoryError::from)
163    }
164
165    fn inherent_is_stateful(&self) -> bool {
166        false
167    }
168}
169
170/// Applies logarithmic transformation with a constant to the values in the columns.
171/// Transformation: log(x + constant). Requires (min + constant) > 0.
172pub struct LogCpTransformer {
173    pub columns: Vec<String>,
174    pub constant: f64,
175}
176
177impl LogCpTransformer {
178    pub fn new(columns: Vec<String>, constant: f64) -> Self {
179        Self { columns, constant }
180    }
181
182    /// Stateless transformer: fit does nothing.
183    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
184        Ok(())
185    }
186
187    /// Validates that each target column exists, is Float64, and that (min + constant) > 0.
188    fn validate(&self, df: &DataFrame) -> FeatureFactoryResult<()> {
189        for col_name in &self.columns {
190            let field = df.schema().field_with_name(None, col_name).map_err(|_| {
191                FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
192            })?;
193            if field.data_type() != &datafusion::arrow::datatypes::DataType::Float64 {
194                return Err(FeatureFactoryError::InvalidParameter(format!(
195                    "LogCpTransformer requires column '{}' to be Float64",
196                    col_name
197                )));
198            }
199            let min_val = futures::executor::block_on(compute_min(df, col_name))?;
200            if min_val + self.constant <= 0.0 {
201                return Err(FeatureFactoryError::InvalidParameter(format!(
202                    "LogCpTransformer requires (min + constant) > 0 for column '{}', but min {} + constant {} = {}",
203                    col_name, min_val, self.constant, min_val + self.constant
204                )));
205            }
206        }
207        Ok(())
208    }
209
210    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
211        self.validate(&df)?;
212        let exprs: Vec<Expr> = df
213            .schema()
214            .fields()
215            .iter()
216            .map(|field| {
217                let name = field.name();
218                if self.columns.contains(name) {
219                    ln_expr(col(name).add(lit(self.constant))).alias(name)
220                } else {
221                    col(name)
222                }
223            })
224            .collect();
225        df.select(exprs).map_err(FeatureFactoryError::from)
226    }
227
228    fn inherent_is_stateful(&self) -> bool {
229        false
230    }
231}
232
233/// Applies reciprocal transformation (1/x) to the values in the columns.
234/// Requires that no value is zero.
235pub struct ReciprocalTransformer {
236    pub columns: Vec<String>,
237}
238
239impl ReciprocalTransformer {
240    pub fn new(columns: Vec<String>) -> Self {
241        Self { columns }
242    }
243
244    /// Stateless transformer: fit does nothing.
245    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
246        Ok(())
247    }
248
249    /// Validates that each target column exists, is Float64, and that no value is zero.
250    fn validate(&self, df: &DataFrame) -> FeatureFactoryResult<()> {
251        for col_name in &self.columns {
252            let field = df.schema().field_with_name(None, col_name).map_err(|_| {
253                FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
254            })?;
255            if field.data_type() != &datafusion::arrow::datatypes::DataType::Float64 {
256                return Err(FeatureFactoryError::InvalidParameter(format!(
257                    "ReciprocalTransformer requires column '{}' to be Float64",
258                    col_name
259                )));
260            }
261            let min_val = futures::executor::block_on(compute_min(df, col_name))?;
262            let max_val = futures::executor::block_on(compute_max(df, col_name))?;
263            if min_val <= 0.0 && max_val >= 0.0 {
264                return Err(FeatureFactoryError::InvalidParameter(format!(
265                    "ReciprocalTransformer requires column '{}' to have no zero values (found range [{}, {}])",
266                    col_name, min_val, max_val
267                )));
268            }
269        }
270        Ok(())
271    }
272
273    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
274        self.validate(&df)?;
275        let exprs: Vec<Expr> = df
276            .schema()
277            .fields()
278            .iter()
279            .map(|field| {
280                let name = field.name();
281                if self.columns.contains(name) {
282                    lit(1.0).div(col(name)).alias(name)
283                } else {
284                    col(name)
285                }
286            })
287            .collect();
288        df.select(exprs).map_err(FeatureFactoryError::from)
289    }
290
291    fn inherent_is_stateful(&self) -> bool {
292        false
293    }
294}
295
296/// Applies power transformation to the values in the columns (x^power).
297pub struct PowerTransformer {
298    pub columns: Vec<String>,
299    pub power: f64,
300}
301
302impl PowerTransformer {
303    pub fn new(columns: Vec<String>, power: f64) -> Self {
304        Self { columns, power }
305    }
306
307    /// Stateless transformer: fit does nothing.
308    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
309        Ok(())
310    }
311
312    /// Validates that each target column exists.
313    fn validate(&self, df: &DataFrame) -> FeatureFactoryResult<()> {
314        for col_name in &self.columns {
315            df.schema().field_with_name(None, col_name).map_err(|_| {
316                FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
317            })?;
318        }
319        Ok(())
320    }
321
322    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
323        self.validate(&df)?;
324        let exprs: Vec<Expr> = df
325            .schema()
326            .fields()
327            .iter()
328            .map(|field| {
329                let name = field.name();
330                if self.columns.contains(name) {
331                    power_expr(col(name), self.power).alias(name)
332                } else {
333                    col(name)
334                }
335            })
336            .collect();
337        df.select(exprs).map_err(FeatureFactoryError::from)
338    }
339
340    fn inherent_is_stateful(&self) -> bool {
341        false
342    }
343}
344
345/// Applies Box–Cox transformation to the values in the columns.
346/// Transformation: (x^lambda - 1) / lambda for lambda != 0, else ln(x)
347/// Needs all values to be positive.
348pub struct BoxCoxTransformer {
349    pub columns: Vec<String>,
350    pub lambda: f64,
351}
352
353impl BoxCoxTransformer {
354    pub fn new(columns: Vec<String>, lambda: f64) -> Self {
355        Self { columns, lambda }
356    }
357
358    /// Stateless transformer: fit does nothing.
359    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
360        Ok(())
361    }
362
363    /// Validates that each target column exists, is Float64, and that all values are positive.
364    fn validate(&self, df: &DataFrame) -> FeatureFactoryResult<()> {
365        for col_name in &self.columns {
366            let field = df.schema().field_with_name(None, col_name).map_err(|_| {
367                FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
368            })?;
369            if field.data_type() != &datafusion::arrow::datatypes::DataType::Float64 {
370                return Err(FeatureFactoryError::InvalidParameter(format!(
371                    "BoxCoxTransformer requires column '{}' to be Float64",
372                    col_name
373                )));
374            }
375            let min_val = futures::executor::block_on(compute_min(df, col_name))?;
376            if min_val <= 0.0 {
377                return Err(FeatureFactoryError::InvalidParameter(format!(
378                    "BoxCoxTransformer requires all values in column '{}' to be positive, found min {}",
379                    col_name, min_val
380                )));
381            }
382        }
383        Ok(())
384    }
385
386    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
387        self.validate(&df)?;
388        let exprs: Vec<Expr> = df
389            .schema()
390            .fields()
391            .iter()
392            .map(|field| {
393                let name = field.name();
394                if self.columns.contains(name) {
395                    let expr = if (self.lambda - 0.0).abs() > 1e-6 {
396                        power_expr(col(name), self.lambda)
397                            .sub(lit(1.0))
398                            .div(lit(self.lambda))
399                    } else {
400                        ln_expr(col(name))
401                    };
402                    expr.alias(name)
403                } else {
404                    col(name)
405                }
406            })
407            .collect();
408        df.select(exprs).map_err(FeatureFactoryError::from)
409    }
410
411    fn inherent_is_stateful(&self) -> bool {
412        false
413    }
414}
415
416/// Applies Yeo–Johnson transformation to the values in the columns.
417/// For x >= 0: ( (x + 1)^lambda - 1) / lambda for lambda != 0, else ln(x + 1)
418/// and for x < 0: -((1 - x)^(2 - lambda) - 1) / (2 - lambda) for lambda != 2, else -ln(1 - x)
419pub struct YeoJohnsonTransformer {
420    pub columns: Vec<String>,
421    pub lambda: f64,
422}
423
424impl YeoJohnsonTransformer {
425    pub fn new(columns: Vec<String>, lambda: f64) -> Self {
426        Self { columns, lambda }
427    }
428
429    /// Stateless transformer: fit does nothing.
430    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
431        Ok(())
432    }
433
434    /// Validates that each target column exists.
435    fn validate(&self, df: &DataFrame) -> FeatureFactoryResult<()> {
436        for col_name in &self.columns {
437            df.schema().field_with_name(None, col_name).map_err(|_| {
438                FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
439            })?;
440        }
441        Ok(())
442    }
443
444    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
445        self.validate(&df)?;
446        let exprs: Vec<Expr> = df
447            .schema()
448            .fields()
449            .iter()
450            .map(|field| {
451                let name = field.name();
452                if self.columns.contains(name) {
453                    let pos_expr = if (self.lambda - 0.0).abs() > 1e-6 {
454                        power_expr(col(name).add(lit(1.0)), self.lambda)
455                            .sub(lit(1.0))
456                            .div(lit(self.lambda))
457                    } else {
458                        ln_expr(col(name).add(lit(1.0)))
459                    };
460                    let neg_expr = if (self.lambda - 2.0).abs() > 1e-6 {
461                        power_expr(lit(1.0).sub(col(name)), 2.0 - self.lambda)
462                            .sub(lit(1.0))
463                            .div(lit(2.0 - self.lambda))
464                            .neg()
465                    } else {
466                        ln_expr(lit(1.0).sub(col(name))).neg()
467                    };
468                    let case_expr = Expr::Case(datafusion_expr::expr::Case {
469                        expr: None,
470                        when_then_expr: vec![(
471                            Box::new(col(name).gt_eq(lit(0.0))),
472                            Box::new(pos_expr),
473                        )],
474                        else_expr: Some(Box::new(neg_expr)),
475                    });
476                    case_expr.alias(name)
477                } else {
478                    col(name)
479                }
480            })
481            .collect();
482        df.select(exprs).map_err(FeatureFactoryError::from)
483    }
484
485    fn inherent_is_stateful(&self) -> bool {
486        false
487    }
488}
489
490/// Applies an arcsine transformation defined as asin(sqrt(x)) to the values in the columns.
491/// Needs all values to be between 0 and 1.
492pub struct ArcsinTransformer {
493    pub columns: Vec<String>,
494}
495
496impl ArcsinTransformer {
497    pub fn new(columns: Vec<String>) -> Self {
498        Self { columns }
499    }
500
501    /// Stateless transformer: fit does nothing.
502    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
503        Ok(())
504    }
505
506    /// Validates that each target column exists, is Float64, and that all values are between 0 and 1.
507    fn validate(&self, df: &DataFrame) -> FeatureFactoryResult<()> {
508        for col_name in &self.columns {
509            let field = df.schema().field_with_name(None, col_name).map_err(|_| {
510                FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
511            })?;
512            if field.data_type() != &datafusion::arrow::datatypes::DataType::Float64 {
513                return Err(FeatureFactoryError::InvalidParameter(format!(
514                    "ArcsinTransformer requires column '{}' to be Float64",
515                    col_name
516                )));
517            }
518            let min_val = futures::executor::block_on(compute_min(df, col_name))?;
519            let max_val = futures::executor::block_on(compute_max(df, col_name))?;
520            if min_val < 0.0 || max_val > 1.0 {
521                return Err(FeatureFactoryError::InvalidParameter(format!(
522                    "ArcsinTransformer requires all values in column '{}' to be between 0 and 1, found range [{}, {}]",
523                    col_name, min_val, max_val
524                )));
525            }
526        }
527        Ok(())
528    }
529
530    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
531        self.validate(&df)?;
532        let exprs: Vec<Expr> = df
533            .schema()
534            .fields()
535            .iter()
536            .map(|field| {
537                let name = field.name();
538                if self.columns.contains(name) {
539                    asin_expr(sqrt_expr(col(name))).alias(name)
540                } else {
541                    col(name)
542                }
543            })
544            .collect();
545        df.select(exprs).map_err(FeatureFactoryError::from)
546    }
547
548    fn inherent_is_stateful(&self) -> bool {
549        false
550    }
551}
552
// Implement the Transformer trait for every transformer defined in this module
// by expanding the crate-level `impl_transformer!` macro once per type.
// NOTE(review): the macro body is not visible here; it presumably forwards to each
// type's inherent `fit`, `transform` and `inherent_is_stateful` methods — confirm.
impl_transformer!(LogTransformer);
impl_transformer!(LogCpTransformer);
impl_transformer!(ReciprocalTransformer);
impl_transformer!(PowerTransformer);
impl_transformer!(BoxCoxTransformer);
impl_transformer!(YeoJohnsonTransformer);
impl_transformer!(ArcsinTransformer);