Skip to main content

feature_factory/transformers/
categorical.rs

1//! ## Categorical Encoding Transformers
2//!
3//! This module provides transformers (or encoders) that convert categorical features into numeric values.
4//!
5//! ### Available Transformers
6//!
7//! - [`OneHotEncoder`]: Expands each categorical column into multiple binary columns, one per distinct category.
8//! - [`CountFrequencyEncoder`]: Replaces each category with its count (or frequency).
9//! - [`OrdinalEncoder`]: Replaces each category with an ordinal (ordered integer) value.
10//! - [`MeanEncoder`]: Replaces each category with the mean of a target variable.
11//! - [`WoEEncoder`]: Replaces each category with its weight of evidence, calculated as the logarithm of the ratio of probabilities of “good” outcomes to “bad” outcomes.
12//! - [`RareLabelEncoder`]: Groups infrequent categories into a single “rare” label.
13//!
14//! Each transformer returns a new DataFrame with the applied encodings.
15//! Errors are returned as `FeatureFactoryError`, and results are wrapped in `FeatureFactoryResult`.
16
17use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
18use crate::impl_transformer;
19use arrow::array::Array;
20use arrow::datatypes::DataType;
21use datafusion::dataframe::DataFrame;
22use datafusion::functions_aggregate::expr_fn::{avg, count};
23use datafusion::logical_expr::{col, lit, Case as DFCase, Expr};
24use std::collections::HashMap;
25
26/// Validates that a column exists and is of Utf8 type.
27fn validate_string_column(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<()> {
28    let field = df.schema().field_with_name(None, col_name).map_err(|_| {
29        FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
30    })?;
31    if field.data_type() != &DataType::Utf8 {
32        return Err(FeatureFactoryError::InvalidParameter(format!(
33            "Column '{}' must be of type Utf8, but found {:?}",
34            col_name,
35            field.data_type()
36        )));
37    }
38    Ok(())
39}
40
41/// Validates that all columns in `cols` exist and are of Utf8 type.
42fn validate_string_columns(df: &DataFrame, cols: &[String]) -> FeatureFactoryResult<()> {
43    for col in cols {
44        validate_string_column(df, col)?;
45    }
46    Ok(())
47}
48
49/// Validates that a column exists and is numeric (Float64 or Int64).
50fn validate_numeric_column(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<()> {
51    let field = df.schema().field_with_name(None, col_name).map_err(|_| {
52        FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
53    })?;
54    match field.data_type() {
55        DataType::Float64 | DataType::Int64 => Ok(()),
56        dt => Err(FeatureFactoryError::InvalidParameter(format!(
57            "Column '{}' must be numeric (Float64 or Int64), but found {:?}",
58            col_name, dt
59        ))),
60    }
61}
62
/// Makes a category value safe to embed in a generated column name.
///
/// Every character for which `char::is_alphanumeric` is false is swapped for a single
/// underscore; alphanumeric characters (including non-ASCII ones) pass through unchanged.
/// Distinct inputs can therefore map to the same output (e.g. `"a b"` and `"a-b"`).
fn sanitize_category(cat: &str) -> String {
    cat.chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}
68
69/// Helper function to build a CASE WHEN expression given a mapping from category strings to values.
70/// For each pair, the expression generated is:
71/// `WHEN <col> = lit(<category>) THEN lit(<encoded_value>)`
72/// If provided, `default` is used as the ELSE branch; otherwise, the original column is returned.
73fn build_case_expr<T: Clone + 'static + datafusion::logical_expr::Literal>(
74    col_name: &str,
75    mapping: &[(String, T)],
76    default: Option<Expr>,
77) -> Expr {
78    let when_then_expr = mapping
79        .iter()
80        .map(|(cat, val)| {
81            (
82                Box::new(col(col_name).eq(lit(cat.clone()))),
83                Box::new(lit(val.clone())),
84            )
85        })
86        .collect();
87    Expr::Case(DFCase {
88        expr: None,
89        when_then_expr,
90        else_expr: default.map(Box::new),
91    })
92}
93
94/// Extract distinct string values for a given column from a DataFrame.
95async fn extract_distinct_values(
96    df: &DataFrame,
97    col_name: &str,
98) -> FeatureFactoryResult<Vec<String>> {
99    // Validate that the column is of string type.
100    validate_string_column(df, col_name)?;
101    let distinct_df = df.clone().select(vec![col(col_name)])?.distinct()?;
102    let batches = distinct_df
103        .collect()
104        .await
105        .map_err(FeatureFactoryError::from)?;
106    let mut values = Vec::new();
107    for batch in batches {
108        let array = batch
109            .column(0)
110            .as_any()
111            .downcast_ref::<datafusion::arrow::array::StringArray>()
112            .ok_or_else(|| {
113                FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
114                    format!("Expected Utf8 array for column {}", col_name),
115                ))
116            })?;
117        for i in 0..array.len() {
118            if !array.is_null(i) {
119                values.push(array.value(i).to_string());
120            }
121        }
122    }
123    Ok(values)
124}
125
126/// Extract a mapping (category -> count) for a given column by aggregating counts.
127async fn extract_count_mapping(
128    df: &DataFrame,
129    col_name: &str,
130) -> FeatureFactoryResult<HashMap<String, i64>> {
131    validate_string_column(df, col_name)?;
132    let grouped = df
133        .clone()
134        .aggregate(vec![col(col_name)], vec![count(col(col_name)).alias("cnt")])
135        .map_err(FeatureFactoryError::from)?;
136    let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
137    let mut map = HashMap::new();
138    for batch in batches {
139        let cat_array = batch
140            .column(0)
141            .as_any()
142            .downcast_ref::<datafusion::arrow::array::StringArray>()
143            .ok_or_else(|| {
144                FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
145                    format!("Expected Utf8 array for column {}", col_name),
146                ))
147            })?;
148        let count_array = batch
149            .column(1)
150            .as_any()
151            .downcast_ref::<datafusion::arrow::array::Int64Array>()
152            .ok_or_else(|| {
153                FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
154                    "Expected Int64 array".into(),
155                ))
156            })?;
157        for i in 0..batch.num_rows() {
158            if !cat_array.is_null(i) {
159                map.insert(cat_array.value(i).to_string(), count_array.value(i));
160            }
161        }
162    }
163    Ok(map)
164}
165
166/// Generic helper function to apply a mapping to each target column in a DataFrame.
167/// For each field, if the column is in `target_cols` and a mapping is available via `mapping_fn`,
168/// then the function replaces the column with a CASE–WHEN expression; otherwise, the original
169/// column is retained. The `default_fn` closure produces a default expression for a given column name.
170fn apply_mapping<T: Clone + 'static + datafusion::logical_expr::Literal>(
171    df: DataFrame,
172    target_cols: &[String],
173    mapping_fn: impl Fn(&str) -> Option<Vec<(String, T)>>,
174    default_fn: impl Fn(&str) -> Option<Expr>,
175) -> FeatureFactoryResult<DataFrame> {
176    let exprs: Vec<Expr> = df
177        .schema()
178        .fields()
179        .iter()
180        .map(|field| {
181            let name = field.name();
182            if target_cols.contains(name) {
183                if let Some(map) = mapping_fn(name) {
184                    build_case_expr(name, &map, default_fn(name)).alias(name)
185                } else {
186                    col(name)
187                }
188            } else {
189                col(name)
190            }
191        })
192        .collect();
193    df.select(exprs).map_err(FeatureFactoryError::from)
194}
195
196/// Expands each categorical column into multiple binary columns, one per distinct category.
197pub struct OneHotEncoder {
198    pub columns: Vec<String>,
199    /// Mapping from column name to list of distinct category values.
200    pub categories: HashMap<String, Vec<String>>,
201    fitted: bool,
202}
203
204impl OneHotEncoder {
205    /// Create a new OneHotEncoder for the specified columns.
206    pub fn new(columns: Vec<String>) -> Self {
207        Self {
208            columns,
209            categories: HashMap::new(),
210            fitted: false,
211        }
212    }
213
214    /// Fit computes and stores distinct category values.
215    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
216        validate_string_columns(df, &self.columns)?;
217        for col_name in &self.columns {
218            let values = extract_distinct_values(df, col_name).await?;
219            self.categories.insert(col_name.clone(), values);
220        }
221        self.fitted = true;
222        Ok(())
223    }
224
225    /// Transform applies one-hot encoding and returns a new DataFrame.
226    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
227        if !self.fitted {
228            return Err(FeatureFactoryError::FitNotCalled);
229        }
230        let mut exprs = vec![];
231        for field in df.schema().fields() {
232            exprs.push(col(field.name()));
233        }
234        for col_name in &self.columns {
235            if let Some(cats) = self.categories.get(col_name) {
236                for cat in cats {
237                    let safe_cat = sanitize_category(cat);
238                    let new_col_name = format!("{}_{}", col_name, safe_cat);
239                    let case_expr = Expr::Case(DFCase {
240                        expr: None,
241                        when_then_expr: vec![(
242                            Box::new(col(col_name).eq(lit(cat.clone()))),
243                            Box::new(lit(1_i32)),
244                        )],
245                        else_expr: Some(Box::new(lit(0_i32))),
246                    })
247                    .alias(new_col_name);
248                    exprs.push(case_expr);
249                }
250            }
251        }
252        df.select(exprs).map_err(FeatureFactoryError::from)
253    }
254
255    // This transformer is stateful.
256    fn inherent_is_stateful(&self) -> bool {
257        true
258    }
259}
260
261/// Replaces each category in a column with its frequency.
262pub struct CountFrequencyEncoder {
263    pub columns: Vec<String>,
264    /// Mapping from column to (category -> count)
265    pub mapping: HashMap<String, HashMap<String, i64>>,
266    fitted: bool,
267}
268
269impl CountFrequencyEncoder {
270    /// Create a new CountFrequencyEncoder for the specified columns.
271    pub fn new(columns: Vec<String>) -> Self {
272        Self {
273            columns,
274            mapping: HashMap::new(),
275            fitted: false,
276        }
277    }
278
279    /// Fit computes counts for each category.
280    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
281        validate_string_columns(df, &self.columns)?;
282        for col_name in &self.columns {
283            let map = extract_count_mapping(df, col_name).await?;
284            self.mapping.insert(col_name.clone(), map);
285        }
286        self.fitted = true;
287        Ok(())
288    }
289
290    /// Transform replaces each category with its count.
291    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
292        if !self.fitted {
293            return Err(FeatureFactoryError::FitNotCalled);
294        }
295        apply_mapping(
296            df,
297            &self.columns,
298            |name| {
299                self.mapping.get(name).map(|m| {
300                    m.iter()
301                        .map(|(k, &v)| (k.clone(), v))
302                        .collect::<Vec<(String, i64)>>()
303                })
304            },
305            |_| Some(lit(0_i64)),
306        )
307    }
308
309    // This transformer is stateful.
310    fn inherent_is_stateful(&self) -> bool {
311        true
312    }
313}
314
315/// Replaces each category with an ordinal (ordered integer) value.
316/// Categories are sorted alphabetically and assigned increasing integers starting at 0.
317pub struct OrdinalEncoder {
318    pub columns: Vec<String>,
319    /// Mapping from column to (category -> ordinal index)
320    pub mapping: HashMap<String, HashMap<String, i64>>,
321    fitted: bool,
322}
323
324impl OrdinalEncoder {
325    /// Create a new OrdinalEncoder for the specified columns.
326    pub fn new(columns: Vec<String>) -> Self {
327        Self {
328            columns,
329            mapping: HashMap::new(),
330            fitted: false,
331        }
332    }
333
334    /// Fit computes the ordinal mapping.
335    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
336        validate_string_columns(df, &self.columns)?;
337        for col_name in &self.columns {
338            let mut values = extract_distinct_values(df, col_name).await?;
339            values.sort();
340            let mapping = values
341                .into_iter()
342                .enumerate()
343                .map(|(i, cat)| (cat, i as i64))
344                .collect();
345            self.mapping.insert(col_name.clone(), mapping);
346        }
347        self.fitted = true;
348        Ok(())
349    }
350
351    /// Transform replaces each category with its ordinal index.
352    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
353        if !self.fitted {
354            return Err(FeatureFactoryError::FitNotCalled);
355        }
356        apply_mapping(
357            df,
358            &self.columns,
359            |name| {
360                self.mapping.get(name).map(|m| {
361                    m.iter()
362                        .map(|(k, &v)| (k.clone(), v))
363                        .collect::<Vec<(String, i64)>>()
364                })
365            },
366            |_| Some(lit(0_i64)),
367        )
368    }
369
370    // This transformer is stateful.
371    fn inherent_is_stateful(&self) -> bool {
372        true
373    }
374}
375
376/// Replaces each category with the mean of a target variable.
377pub struct MeanEncoder {
378    pub columns: Vec<String>,
379    pub target: String,
380    /// Mapping from column to (category -> mean)
381    pub mapping: HashMap<String, HashMap<String, f64>>,
382    fitted: bool,
383}
384
385impl MeanEncoder {
386    /// Create a new MeanEncoder for the specified columns and target.
387    pub fn new(columns: Vec<String>, target: String) -> Self {
388        Self {
389            columns,
390            target,
391            mapping: HashMap::new(),
392            fitted: false,
393        }
394    }
395
396    /// Fit computes the mean for each category.
397    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
398        validate_string_columns(df, &self.columns)?;
399        validate_numeric_column(df, &self.target)?;
400        for col_name in &self.columns {
401            let agg_df = df
402                .clone()
403                .aggregate(
404                    vec![col(col_name)],
405                    vec![avg(col(&self.target)).alias("mean")],
406                )
407                .map_err(FeatureFactoryError::from)?;
408            let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
409            let mut map = HashMap::new();
410            for batch in batches {
411                let cat_array = batch
412                    .column(0)
413                    .as_any()
414                    .downcast_ref::<datafusion::arrow::array::StringArray>()
415                    .ok_or_else(|| {
416                        FeatureFactoryError::DataFusionError(
417                            datafusion::error::DataFusionError::Plan(format!(
418                                "Expected Utf8 array for column {}",
419                                col_name
420                            )),
421                        )
422                    })?;
423                let mean_array = batch
424                    .column(1)
425                    .as_any()
426                    .downcast_ref::<datafusion::arrow::array::Float64Array>()
427                    .ok_or_else(|| {
428                        FeatureFactoryError::DataFusionError(
429                            datafusion::error::DataFusionError::Plan(
430                                "Expected Float64 array".into(),
431                            ),
432                        )
433                    })?;
434                for i in 0..batch.num_rows() {
435                    if !cat_array.is_null(i) {
436                        map.insert(cat_array.value(i).to_string(), mean_array.value(i));
437                    }
438                }
439            }
440            self.mapping.insert(col_name.clone(), map);
441        }
442        self.fitted = true;
443        Ok(())
444    }
445
446    /// Transform replaces each category with its mean.
447    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
448        if !self.fitted {
449            return Err(FeatureFactoryError::FitNotCalled);
450        }
451        apply_mapping(
452            df,
453            &self.columns,
454            |name| {
455                self.mapping.get(name).map(|m| {
456                    m.iter()
457                        .map(|(k, &v)| (k.clone(), v))
458                        .collect::<Vec<(String, f64)>>()
459                })
460            },
461            |_| Some(lit(0.0_f64)),
462        )
463    }
464
465    // This transformer is stateful.
466    fn inherent_is_stateful(&self) -> bool {
467        true
468    }
469}
470
471/// Replaces each category with its weight of evidence (WoE).
472/// WoE is computed as ln((good_rate)/(bad_rate)), assuming a binary target.
473pub struct WoEEncoder {
474    pub columns: Vec<String>,
475    pub target: String,
476    /// Mapping from column to (category -> WoE)
477    pub mapping: HashMap<String, HashMap<String, f64>>,
478    fitted: bool,
479}
480
481impl WoEEncoder {
482    /// Create a new WoEEncoder for the specified columns and target.
483    pub fn new(columns: Vec<String>, target: String) -> Self {
484        Self {
485            columns,
486            target,
487            mapping: HashMap::new(),
488            fitted: false,
489        }
490    }
491
492    /// Fit computes the WoE for each category.
493    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
494        validate_string_columns(df, &self.columns)?;
495        validate_numeric_column(df, &self.target)?;
496        let overall_df = df
497            .clone()
498            .aggregate(vec![], vec![count(col(&self.target)).alias("total")])
499            .map_err(FeatureFactoryError::from)?;
500        let overall_batches = overall_df
501            .collect()
502            .await
503            .map_err(FeatureFactoryError::from)?;
504        let _total = if let Some(batch) = overall_batches.first() {
505            let total_array = batch
506                .column(0)
507                .as_any()
508                .downcast_ref::<datafusion::arrow::array::Int64Array>()
509                .ok_or_else(|| {
510                    FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
511                        "Expected Int64 array".into(),
512                    ))
513                })?;
514            total_array.value(0) as f64
515        } else {
516            return Err(FeatureFactoryError::DataFusionError(
517                datafusion::error::DataFusionError::Plan("No data found".into()),
518            ));
519        };
520
521        for col_name in &self.columns {
522            let grouped = df
523                .clone()
524                .aggregate(
525                    vec![col(col_name), col(&self.target)],
526                    vec![count(lit(1)).alias("cnt")],
527                )
528                .map_err(FeatureFactoryError::from)?;
529            let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
530            let mut cat_counts: HashMap<String, (f64, f64)> = HashMap::new(); // (good, bad)
531            for batch in batches {
532                let cat_array = batch
533                    .column(0)
534                    .as_any()
535                    .downcast_ref::<datafusion::arrow::array::StringArray>()
536                    .ok_or_else(|| {
537                        FeatureFactoryError::DataFusionError(
538                            datafusion::error::DataFusionError::Plan(format!(
539                                "Expected Utf8 array for column {}",
540                                col_name
541                            )),
542                        )
543                    })?;
544                let target_array = batch
545                    .column(1)
546                    .as_any()
547                    .downcast_ref::<datafusion::arrow::array::Int64Array>()
548                    .ok_or_else(|| {
549                        FeatureFactoryError::DataFusionError(
550                            datafusion::error::DataFusionError::Plan("Expected Int64 array".into()),
551                        )
552                    })?;
553                let count_array = batch
554                    .column(2)
555                    .as_any()
556                    .downcast_ref::<datafusion::arrow::array::Int64Array>()
557                    .ok_or_else(|| {
558                        FeatureFactoryError::DataFusionError(
559                            datafusion::error::DataFusionError::Plan("Expected Int64 array".into()),
560                        )
561                    })?;
562                for i in 0..batch.num_rows() {
563                    if !cat_array.is_null(i) {
564                        let cat = cat_array.value(i).to_string();
565                        let target_val = target_array.value(i);
566                        let cnt = count_array.value(i) as f64;
567                        let entry = cat_counts.entry(cat).or_insert((0.0, 0.0));
568                        if target_val == 1 {
569                            entry.0 += cnt;
570                        } else {
571                            entry.1 += cnt;
572                        }
573                    }
574                }
575            }
576            let mut mapping = HashMap::new();
577            for (cat, (good, bad)) in cat_counts {
578                let woe = ((good + 1e-6) / (bad + 1e-6)).ln();
579                mapping.insert(cat, woe);
580            }
581            self.mapping.insert(col_name.clone(), mapping);
582        }
583        self.fitted = true;
584        Ok(())
585    }
586
587    /// Transform replaces each category with its computed WoE.
588    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
589        if !self.fitted {
590            return Err(FeatureFactoryError::FitNotCalled);
591        }
592        apply_mapping(
593            df,
594            &self.columns,
595            |name| {
596                self.mapping.get(name).map(|m| {
597                    m.iter()
598                        .map(|(k, &v)| (k.clone(), v))
599                        .collect::<Vec<(String, f64)>>()
600                })
601            },
602            |_| Some(lit(0.0_f64)),
603        )
604    }
605
606    // This transformer is stateful.
607    fn inherent_is_stateful(&self) -> bool {
608        true
609    }
610}
611
612/// Groups infrequent categories into a single “rare” label.
613pub struct RareLabelEncoder {
614    pub columns: Vec<String>,
615    pub threshold: f64, // frequency threshold (between 0 and 1)
616    /// Mapping from column to (category -> encoded label)
617    pub mapping: HashMap<String, HashMap<String, String>>,
618    fitted: bool,
619}
620
621impl RareLabelEncoder {
622    /// Create a new RareLabelEncoder for the specified columns and threshold.
623    pub fn new(columns: Vec<String>, threshold: f64) -> Self {
624        Self {
625            columns,
626            threshold,
627            mapping: HashMap::new(),
628            fitted: false,
629        }
630    }
631
632    /// Fit computes frequencies and marks infrequent categories as "rare".
633    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
634        if self.threshold < 0.0 || self.threshold > 1.0 {
635            return Err(FeatureFactoryError::InvalidParameter(format!(
636                "Threshold {} must be between 0 and 1",
637                self.threshold
638            )));
639        }
640        validate_string_columns(df, &self.columns)?;
641        let total_df = df
642            .clone()
643            .aggregate(vec![], vec![count(lit(1)).alias("total")])
644            .map_err(FeatureFactoryError::from)?;
645        let total_batches = total_df
646            .collect()
647            .await
648            .map_err(FeatureFactoryError::from)?;
649        let total = if let Some(batch) = total_batches.first() {
650            let total_array = batch
651                .column(0)
652                .as_any()
653                .downcast_ref::<datafusion::arrow::array::Int64Array>()
654                .ok_or_else(|| {
655                    FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
656                        "Expected Int64 array".into(),
657                    ))
658                })?;
659            total_array.value(0) as f64
660        } else {
661            return Err(FeatureFactoryError::DataFusionError(
662                datafusion::error::DataFusionError::Plan("No data found".into()),
663            ));
664        };
665
666        for col_name in &self.columns {
667            let grouped = df
668                .clone()
669                .aggregate(vec![col(col_name)], vec![count(col(col_name)).alias("cnt")])
670                .map_err(FeatureFactoryError::from)?;
671            let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
672            let mut map = HashMap::new();
673            for batch in batches {
674                let cat_array = batch
675                    .column(0)
676                    .as_any()
677                    .downcast_ref::<datafusion::arrow::array::StringArray>()
678                    .ok_or_else(|| {
679                        FeatureFactoryError::DataFusionError(
680                            datafusion::error::DataFusionError::Plan(format!(
681                                "Expected Utf8 array for column {}",
682                                col_name
683                            )),
684                        )
685                    })?;
686                let cnt_array = batch
687                    .column(1)
688                    .as_any()
689                    .downcast_ref::<datafusion::arrow::array::Int64Array>()
690                    .ok_or_else(|| {
691                        FeatureFactoryError::DataFusionError(
692                            datafusion::error::DataFusionError::Plan("Expected Int64 array".into()),
693                        )
694                    })?;
695                for i in 0..batch.num_rows() {
696                    if !cat_array.is_null(i) {
697                        let cat = cat_array.value(i).to_string();
698                        let cnt = cnt_array.value(i) as f64;
699                        let freq = cnt / total;
700                        let encoded = if freq < self.threshold {
701                            "rare".to_string()
702                        } else {
703                            cat.clone()
704                        };
705                        map.insert(cat, encoded);
706                    }
707                }
708            }
709            self.mapping.insert(col_name.clone(), map);
710        }
711        self.fitted = true;
712        Ok(())
713    }
714
715    /// Transform replaces each category with its encoded label.
716    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
717        if !self.fitted {
718            return Err(FeatureFactoryError::FitNotCalled);
719        }
720        apply_mapping(
721            df,
722            &self.columns,
723            |name| {
724                self.mapping.get(name).map(|m| {
725                    m.iter()
726                        .map(|(k, v)| (k.clone(), v.clone()))
727                        .collect::<Vec<(String, String)>>()
728                })
729            },
730            |name| Some(col(name)),
731        )
732    }
733
734    // This transformer is stateful.
735    fn inherent_is_stateful(&self) -> bool {
736        true
737    }
738}
739
740// Implement the Transformer trait for the transformers in this module.
741impl_transformer!(OneHotEncoder);
742impl_transformer!(CountFrequencyEncoder);
743impl_transformer!(OrdinalEncoder);
744impl_transformer!(MeanEncoder);
745impl_transformer!(WoEEncoder);
746impl_transformer!(RareLabelEncoder);