Skip to main content

feature_factory/transformers/
feature_selection.rs

1//! ## Feature Selection Transformers
2//!
3//! This module provides transformers for selecting the most relevant features based on some criteria.
4//!
5//! ### Available Transformers
6//!
7//! - [`DropFeatures`]: Removes specific features from the dataset.
8//! - [`DropConstantFeatures`]: Eliminates constant and quasi-constant features.
9//! - [`DropDuplicateFeatures`]: Removes duplicate columns.
10//! - [`DropCorrelatedFeatures`]: Drops highly correlated features to reduce redundancy.
11//! - [`SmartCorrelatedSelection`]: Retains the best feature from correlated groups based on relevance.
12//! - [`DropHighPSIFeatures`]: Discards features with a high Population Stability Index (PSI).
13//! - [`SelectByInformationValue`]: Selects features based on Information Value (IV) for binary classification tasks.
14//! - [`SelectBySingleFeaturePerformance`]: Chooses features based on absolute correlation with a binary target.
15//! - [`SelectByTargetMeanPerformance`]: Selects features based on variations in target mean across bins.
16//! - [`MRMR`]: Uses Maximum Relevance Minimum Redundancy (MRMR) algorithm for feature selection.
17//!
18//! ### Assumptions
19//!
20//! - The DataFrame is fully materialized (`collect()`) for computing statistics.
21//! - Numeric columns are expected to be of Arrow’s `Float64` type.
22//! - Target-dependent methods assume a binary target column (values `0` and `1`).
23//!
24//! Each transformer returns a new DataFrame with the selected features.
25//! Errors are returned as [`FeatureFactoryError`], and results are wrapped in [`FeatureFactoryResult`].
26
27use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
28use crate::impl_transformer;
29use datafusion::arrow::array::{as_primitive_array, Array, StringArray};
30use datafusion::arrow::datatypes::{DataType, Float64Type};
31use datafusion::dataframe::DataFrame;
32use datafusion::logical_expr::{col, Expr};
33use rayon::prelude::*;
34use std::collections::{HashMap, HashSet};
35use std::sync::Arc;
36
37/// Helper function that checks if a DataFusion data type is numeric (only handling Float64 here).
38fn is_numeric(dt: &DataType) -> bool {
39    matches!(dt, DataType::Float64)
40}
41
42/// Removes the specified columns from the DataFrame.
43pub struct DropFeatures {
44    pub features: Vec<String>,
45}
46
47impl DropFeatures {
48    pub fn new(features: Vec<String>) -> Self {
49        Self { features }
50    }
51
52    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
53        Ok(())
54    }
55
56    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
57        let available_exprs: Vec<Expr> = df
58            .schema()
59            .fields()
60            .iter()
61            .filter_map(|field| {
62                if !self.features.contains(field.name()) {
63                    Some(col(field.name()))
64                } else {
65                    None
66                }
67            })
68            .collect();
69
70        if available_exprs.is_empty() {
71            return Err(FeatureFactoryError::InvalidParameter(
72                "Dropping these features would result in an empty DataFrame.".to_string(),
73            ));
74        }
75        df.select(available_exprs)
76            .map_err(FeatureFactoryError::from)
77    }
78
79    fn inherent_is_stateful(&self) -> bool {
80        false
81    }
82}
83
84/// Removes features that are constant or nearly constant (where variance is below a threshold).
85pub struct DropConstantFeatures {
86    pub numeric_threshold: f64,
87    pub categorical_threshold: usize,
88    pub drop_columns: HashSet<String>,
89    fitted: bool,
90}
91
92impl DropConstantFeatures {
93    pub fn new(numeric_threshold: f64, categorical_threshold: usize) -> Self {
94        Self {
95            numeric_threshold,
96            categorical_threshold,
97            drop_columns: HashSet::new(),
98            fitted: false,
99        }
100    }
101
102    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
103        let schema = df.schema();
104        let batches = df.clone().collect().await?;
105        if batches.is_empty() {
106            return Err(FeatureFactoryError::InvalidParameter(
107                "DataFrame is empty.".to_string(),
108            ));
109        }
110        let batch = &batches[0];
111
112        for field in schema.fields() {
113            let name = field.name();
114            if is_numeric(field.data_type()) {
115                let array =
116                    as_primitive_array::<Float64Type>(batch.column_by_name(name).ok_or_else(
117                        || FeatureFactoryError::MissingColumn(format!("Column {} not found", name)),
118                    )?);
119                let n = array.len() as f64;
120                let sum: f64 = array.iter().flatten().par_bridge().sum();
121                let mean = sum / n;
122                let sum_sq: f64 = array.iter().flatten().par_bridge().map(|v| v * v).sum();
123                let variance = sum_sq / n - mean * mean;
124                if variance < self.numeric_threshold {
125                    self.drop_columns.insert(name.to_string());
126                }
127            } else {
128                let string_array = batch
129                    .column_by_name(name)
130                    .ok_or_else(|| {
131                        FeatureFactoryError::MissingColumn(format!("Column {} not found", name))
132                    })?
133                    .as_any()
134                    .downcast_ref::<StringArray>()
135                    .ok_or_else(|| {
136                        FeatureFactoryError::DataFusionError(
137                            datafusion::error::DataFusionError::Plan(format!(
138                                "Expected Utf8 array for column {}",
139                                name
140                            )),
141                        )
142                    })?;
143                let mut distinct = HashSet::new();
144                for i in 0..string_array.len() {
145                    if !string_array.is_null(i) {
146                        distinct.insert(string_array.value(i).to_string());
147                    }
148                }
149                if distinct.len() <= self.categorical_threshold {
150                    self.drop_columns.insert(name.to_string());
151                }
152            }
153        }
154        self.fitted = true;
155        Ok(())
156    }
157
158    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
159        if !self.fitted {
160            return Err(FeatureFactoryError::FitNotCalled);
161        }
162        let keep_exprs: Vec<Expr> = df
163            .schema()
164            .fields()
165            .iter()
166            .filter_map(|field| {
167                if !self.drop_columns.contains(field.name()) {
168                    Some(col(field.name()))
169                } else {
170                    None
171                }
172            })
173            .collect();
174
175        if keep_exprs.is_empty() {
176            return Err(FeatureFactoryError::InvalidParameter(
177                "All features were dropped by DropConstantFeatures.".to_string(),
178            ));
179        }
180        df.select(keep_exprs).map_err(FeatureFactoryError::from)
181    }
182
183    fn inherent_is_stateful(&self) -> bool {
184        true
185    }
186}
187
188/// Removes duplicate features by comparing values in each column.
189pub struct DropDuplicateFeatures {
190    pub drop_columns: HashSet<String>,
191    fitted: bool,
192}
193
194impl Default for DropDuplicateFeatures {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
200impl DropDuplicateFeatures {
201    pub fn new() -> Self {
202        Self {
203            drop_columns: HashSet::new(),
204            fitted: false,
205        }
206    }
207
208    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
209        let batches = df.clone().collect().await?;
210        if batches.is_empty() {
211            return Err(FeatureFactoryError::InvalidParameter(
212                "Empty DataFrame".to_string(),
213            ));
214        }
215        let batch = &batches[0];
216        let schema = batch.schema();
217        let mut seen: Vec<(String, Arc<dyn Array>)> = Vec::new();
218        for field in schema.fields() {
219            let name = field.name().clone();
220            let array = batch.column_by_name(&name).unwrap();
221            let mut is_duplicate = false;
222            for (_seen_name, seen_array) in &seen {
223                if array == seen_array {
224                    self.drop_columns.insert(name.clone());
225                    is_duplicate = true;
226                    break;
227                }
228            }
229            if !is_duplicate {
230                seen.push((name, array.clone()));
231            }
232        }
233        self.fitted = true;
234        Ok(())
235    }
236
237    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
238        if !self.fitted {
239            return Err(FeatureFactoryError::FitNotCalled);
240        }
241        let keep_exprs: Vec<Expr> = df
242            .schema()
243            .fields()
244            .iter()
245            .filter_map(|field| {
246                if !self.drop_columns.contains(field.name()) {
247                    Some(col(field.name()))
248                } else {
249                    None
250                }
251            })
252            .collect();
253        if keep_exprs.is_empty() {
254            return Err(FeatureFactoryError::InvalidParameter(
255                "All features were dropped by DropDuplicateFeatures.".to_string(),
256            ));
257        }
258        df.select(keep_exprs).map_err(FeatureFactoryError::from)
259    }
260
261    fn inherent_is_stateful(&self) -> bool {
262        true
263    }
264}
265
266/// Removes one feature from each highly correlated pair (using Pearson correlation).
267pub struct DropCorrelatedFeatures {
268    pub threshold: f64,
269    pub drop_columns: HashSet<String>,
270    fitted: bool,
271}
272
273impl DropCorrelatedFeatures {
274    pub fn new(threshold: f64) -> Self {
275        Self {
276            threshold,
277            drop_columns: HashSet::new(),
278            fitted: false,
279        }
280    }
281
282    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
283        let batches = df.clone().collect().await?;
284        if batches.is_empty() {
285            return Err(FeatureFactoryError::InvalidParameter(
286                "Empty DataFrame".to_string(),
287            ));
288        }
289        let batch = &batches[0];
290        let schema = df.schema();
291        let numeric_fields: Vec<_> = schema
292            .fields()
293            .iter()
294            .filter(|f| is_numeric(f.data_type()))
295            .collect();
296        let mut data: HashMap<String, Vec<f64>> = HashMap::new();
297        for field in &numeric_fields {
298            let name = field.name();
299            let array = as_primitive_array::<Float64Type>(batch.column_by_name(name).unwrap());
300            let vec: Vec<f64> = array.iter().flatten().collect();
301            data.insert(name.to_string(), vec);
302        }
303        let mut to_drop = HashSet::new();
304        let names: Vec<_> = data.keys().cloned().collect();
305        for i in 0..names.len() {
306            for j in (i + 1)..names.len() {
307                let x = &data[&names[i]];
308                let y = &data[&names[j]];
309                if x.len() != y.len() || x.is_empty() {
310                    continue;
311                }
312                let n_f = x.len() as f64;
313                let mean_x = x.iter().sum::<f64>() / n_f;
314                let mean_y = y.iter().sum::<f64>() / n_f;
315                let cov: f64 = x
316                    .iter()
317                    .zip(y.iter())
318                    .map(|(a, b)| (a - mean_x) * (b - mean_y))
319                    .sum();
320                let var_x: f64 = x.iter().map(|a| (a - mean_x).powi(2)).sum();
321                let var_y: f64 = y.iter().map(|b| (b - mean_y).powi(2)).sum();
322                if var_x == 0.0 || var_y == 0.0 {
323                    continue;
324                }
325                let corr = cov / ((var_x).sqrt() * (var_y).sqrt());
326                if corr.abs() > self.threshold {
327                    if var_x < var_y {
328                        to_drop.insert(names[i].clone());
329                    } else {
330                        to_drop.insert(names[j].clone());
331                    }
332                }
333            }
334        }
335        self.drop_columns = to_drop;
336        self.fitted = true;
337        Ok(())
338    }
339
340    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
341        if !self.fitted {
342            return Err(FeatureFactoryError::FitNotCalled);
343        }
344        let keep_exprs: Vec<Expr> = df
345            .schema()
346            .fields()
347            .iter()
348            .filter_map(|f| {
349                if !self.drop_columns.contains(f.name()) {
350                    Some(col(f.name()))
351                } else {
352                    None
353                }
354            })
355            .collect();
356        if keep_exprs.is_empty() {
357            return Err(FeatureFactoryError::InvalidParameter(
358                "All features were dropped by DropCorrelatedFeatures.".to_string(),
359            ));
360        }
361        df.select(keep_exprs).map_err(FeatureFactoryError::from)
362    }
363
364    fn inherent_is_stateful(&self) -> bool {
365        true
366    }
367}
368
369/// Groups correlated features and keeps the one with the highest variance from each group.
370pub struct SmartCorrelatedSelection {
371    pub threshold: f64,
372    pub selected_features: HashSet<String>,
373    fitted: bool,
374}
375
376impl SmartCorrelatedSelection {
377    pub fn new(threshold: f64) -> Self {
378        Self {
379            threshold,
380            selected_features: HashSet::new(),
381            fitted: false,
382        }
383    }
384
385    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
386        let batches = df.clone().collect().await?;
387        if batches.is_empty() {
388            return Err(FeatureFactoryError::InvalidParameter(
389                "Empty DataFrame".to_string(),
390            ));
391        }
392        let batch = &batches[0];
393        let schema = df.schema();
394        let numeric_fields: Vec<_> = schema
395            .fields()
396            .iter()
397            .filter(|f| is_numeric(f.data_type()))
398            .collect();
399        let mut stats: Vec<(String, f64, Vec<f64>)> = Vec::new();
400        for field in &numeric_fields {
401            let name = field.name();
402            let array = as_primitive_array::<Float64Type>(batch.column_by_name(name).unwrap());
403            let vec: Vec<f64> = array.iter().flatten().collect();
404            let n = vec.len() as f64;
405            let mean = vec.iter().sum::<f64>() / n;
406            let var = vec.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
407            stats.push((name.to_string(), var, vec));
408        }
409        let mut candidates: HashSet<String> =
410            stats.iter().map(|(name, _, _)| name.clone()).collect();
411        let mut selected: Vec<String> = Vec::<String>::new();
412        for i in 0..stats.len() {
413            for j in (i + 1)..stats.len() {
414                let (ref name_i, var_i, ref x) = stats[i];
415                let (ref name_j, var_j, ref y) = stats[j];
416                if !candidates.contains(name_i) || !candidates.contains(name_j) {
417                    continue;
418                }
419                if x.len() != y.len() || x.is_empty() {
420                    continue;
421                }
422                let n_f = x.len() as f64;
423                let mean_i = x.iter().sum::<f64>() / n_f;
424                let mean_j = y.iter().sum::<f64>() / n_f;
425                let cov: f64 = x
426                    .iter()
427                    .zip(y.iter())
428                    .map(|(a, b)| (a - mean_i) * (b - mean_j))
429                    .sum();
430                let sxx: f64 = x.iter().map(|a| (a - mean_i).powi(2)).sum();
431                let syy: f64 = y.iter().map(|b| (b - mean_j).powi(2)).sum();
432                if sxx == 0.0 || syy == 0.0 {
433                    continue;
434                }
435                let corr = cov / (sxx.sqrt() * syy.sqrt());
436                if corr.abs() > self.threshold {
437                    if var_i < var_j {
438                        candidates.remove(name_i);
439                    } else {
440                        candidates.remove(name_j);
441                    }
442                }
443            }
444        }
445        selected.extend(candidates.into_iter());
446        self.selected_features = selected.into_iter().collect();
447        self.fitted = true;
448        Ok(())
449    }
450
451    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
452        if !self.fitted {
453            return Err(FeatureFactoryError::FitNotCalled);
454        }
455        let keep_exprs: Vec<Expr> = df
456            .schema()
457            .fields()
458            .iter()
459            .filter_map(|f| {
460                if is_numeric(f.data_type()) {
461                    if self.selected_features.contains(f.name()) {
462                        Some(col(f.name()))
463                    } else {
464                        None
465                    }
466                } else {
467                    Some(col(f.name()))
468                }
469            })
470            .collect();
471        if keep_exprs.is_empty() {
472            return Err(FeatureFactoryError::InvalidParameter(
473                "No features selected by SmartCorrelatedSelection.".to_string(),
474            ));
475        }
476        df.select(keep_exprs).map_err(FeatureFactoryError::from)
477    }
478
479    fn inherent_is_stateful(&self) -> bool {
480        true
481    }
482}
483
484/// Drops features that their Population Stability Index (PSI) is larger than a threshold.
485pub struct DropHighPSIFeatures {
486    pub reference: DataFrame,
487    pub psi_threshold: f64,
488    pub drop_columns: HashSet<String>,
489    fitted: bool,
490}
491
492impl DropHighPSIFeatures {
493    pub fn new(reference: DataFrame, psi_threshold: f64) -> Self {
494        Self {
495            reference,
496            psi_threshold,
497            drop_columns: HashSet::new(),
498            fitted: false,
499        }
500    }
501
502    fn compute_psi(ref_vals: &[f64], curr_vals: &[f64], bins: &[f64]) -> f64 {
503        let mut psi = 0.0;
504        let total_ref = ref_vals.len() as f64;
505        let total_curr = curr_vals.len() as f64;
506        for i in 0..bins.len() - 1 {
507            let lower = bins[i];
508            let upper = bins[i + 1];
509            let count_ref = ref_vals
510                .par_iter()
511                .filter(|v| **v >= lower && **v < upper)
512                .count() as f64;
513            let count_curr = curr_vals
514                .par_iter()
515                .filter(|v| **v >= lower && **v < upper)
516                .count() as f64;
517            let pct_ref = (count_ref / total_ref).max(0.0001);
518            let pct_curr = (count_curr / total_curr).max(0.0001);
519            psi += (pct_ref - pct_curr) * (pct_ref / pct_curr).ln();
520        }
521        psi
522    }
523
524    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
525        let ref_batches = self.reference.clone().collect().await?;
526        let curr_batches = df.clone().collect().await?;
527        if ref_batches.is_empty() || curr_batches.is_empty() {
528            return Err(FeatureFactoryError::InvalidParameter(
529                "Empty DataFrame".to_string(),
530            ));
531        }
532        let ref_batch = &ref_batches[0];
533        let curr_batch = &curr_batches[0];
534        let schema = df.schema();
535        for field in schema.fields() {
536            if is_numeric(field.data_type()) {
537                let name = field.name();
538                let ref_array =
539                    as_primitive_array::<Float64Type>(ref_batch.column_by_name(name).ok_or_else(
540                        || FeatureFactoryError::MissingColumn(format!("Column {} missing", name)),
541                    )?);
542                let curr_array =
543                    as_primitive_array::<Float64Type>(curr_batch.column_by_name(name).ok_or_else(
544                        || FeatureFactoryError::MissingColumn(format!("Column {} missing", name)),
545                    )?);
546                let ref_vals: Vec<f64> = ref_array.iter().flatten().par_bridge().collect();
547                let curr_vals: Vec<f64> = curr_array.iter().flatten().par_bridge().collect();
548                let mut sorted = ref_vals.clone();
549                sorted.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
550                let mut bins = Vec::new();
551                for i in 0..11 {
552                    let idx = ((sorted.len() - 1) as f64 * i as f64 / 10.0).round() as usize;
553                    bins.push(sorted[idx]);
554                }
555                let psi = Self::compute_psi(&ref_vals, &curr_vals, &bins);
556                if psi > self.psi_threshold {
557                    self.drop_columns.insert(name.to_string());
558                }
559            }
560        }
561        self.fitted = true;
562        Ok(())
563    }
564
565    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
566        if !self.fitted {
567            return Err(FeatureFactoryError::FitNotCalled);
568        }
569        let keep_exprs: Vec<Expr> = df
570            .schema()
571            .fields()
572            .iter()
573            .filter_map(|f| {
574                if !self.drop_columns.contains(f.name()) {
575                    Some(col(f.name()))
576                } else {
577                    None
578                }
579            })
580            .collect();
581        if keep_exprs.is_empty() {
582            return Err(FeatureFactoryError::InvalidParameter(
583                "All features dropped by DropHighPSIFeatures.".to_string(),
584            ));
585        }
586        df.select(keep_exprs).map_err(FeatureFactoryError::from)
587    }
588
589    fn inherent_is_stateful(&self) -> bool {
590        true
591    }
592}
593
594/// Computes Information Value (IV) for each feature relative to a binary target and selects the best.
595pub struct SelectByInformationValue {
596    pub target: String,
597    pub iv_threshold: f64,
598    pub selected_features: HashSet<String>,
599    fitted: bool,
600}
601
602impl SelectByInformationValue {
603    pub fn new(target: String, iv_threshold: f64) -> Self {
604        Self {
605            target,
606            iv_threshold,
607            selected_features: HashSet::new(),
608            fitted: false,
609        }
610    }
611
612    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
613        let batches = df.clone().collect().await?;
614        if batches.is_empty() {
615            return Err(FeatureFactoryError::InvalidParameter(
616                "Empty DataFrame".to_string(),
617            ));
618        }
619        let batch = &batches[0];
620        let schema = df.schema();
621        let target_array =
622            as_primitive_array::<Float64Type>(batch.column_by_name(&self.target).ok_or_else(
623                || FeatureFactoryError::MissingColumn(format!("Target {} missing", self.target)),
624            )?);
625        let target_vals: Vec<f64> = target_array.iter().flatten().par_bridge().collect();
626        let total_good = target_vals.iter().filter(|&&v| v == 1.0).count() as f64;
627        let total_bad = target_vals.iter().filter(|&&v| v == 0.0).count() as f64;
628        let mut selected = HashSet::new();
629        for field in schema.fields() {
630            let name = field.name();
631            if name == &self.target {
632                continue;
633            }
634            let col_array = batch.column_by_name(name).ok_or_else(|| {
635                FeatureFactoryError::MissingColumn(format!("Column {} missing", name))
636            })?;
637            let mut iv = 0.0;
638            if is_numeric(field.data_type()) {
639                let array = as_primitive_array::<Float64Type>(col_array);
640                let mut vals: Vec<f64> = array.iter().flatten().par_bridge().collect();
641                if vals.is_empty() {
642                    continue;
643                }
644                vals.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
645                let mut bins = Vec::new();
646                for i in 0..11 {
647                    let idx = ((vals.len() - 1) as f64 * i as f64 / 10.0).round() as usize;
648                    bins.push(vals[idx]);
649                }
650                for i in 0..bins.len() - 1 {
651                    let lower = bins[i];
652                    let upper = bins[i + 1];
653                    let mut good = 0.0;
654                    let mut bad = 0.0;
655                    for (j, v_opt) in array.iter().enumerate() {
656                        if let Some(v) = v_opt {
657                            if v >= lower && v < upper {
658                                if target_vals[j] == 1.0 {
659                                    good += 1.0;
660                                } else {
661                                    bad += 1.0;
662                                }
663                            }
664                        }
665                    }
666                    let pct_good = (good / total_good).max(0.0001);
667                    let pct_bad = (bad / total_bad).max(0.0001);
668                    iv += (pct_good - pct_bad) * (pct_good / pct_bad).ln();
669                }
670            } else {
671                let string_array = col_array
672                    .as_any()
673                    .downcast_ref::<StringArray>()
674                    .ok_or_else(|| {
675                        FeatureFactoryError::DataFusionError(
676                            datafusion::error::DataFusionError::Plan(format!(
677                                "Expected Utf8 array for column {}",
678                                name
679                            )),
680                        )
681                    })?;
682                let mut counts: HashMap<String, (f64, f64)> = HashMap::new();
683                for (j, v_opt) in string_array.iter().enumerate() {
684                    if let Some(v) = v_opt {
685                        let key = v.to_string();
686                        let entry = counts.entry(key).or_insert((0.0, 0.0));
687                        if target_vals[j] == 1.0 {
688                            entry.0 += 1.0;
689                        } else {
690                            entry.1 += 1.0;
691                        }
692                    }
693                }
694                for (_k, (good, bad)) in counts.iter() {
695                    let pct_good = (*good / total_good).max(0.0001);
696                    let pct_bad = (*bad / total_bad).max(0.0001);
697                    iv += (pct_good - pct_bad) * (pct_good / pct_bad).ln();
698                }
699            }
700            if iv >= self.iv_threshold {
701                selected.insert(name.to_string());
702            }
703        }
704        self.selected_features = selected;
705        self.fitted = true;
706        Ok(())
707    }
708
709    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
710        if !self.fitted {
711            return Err(FeatureFactoryError::FitNotCalled);
712        }
713        let keep_exprs: Vec<Expr> = df
714            .schema()
715            .fields()
716            .iter()
717            .filter_map(|f| {
718                if f.name() == &self.target || self.selected_features.contains(f.name()) {
719                    Some(col(f.name()))
720                } else {
721                    None
722                }
723            })
724            .collect();
725        if keep_exprs.is_empty() {
726            return Err(FeatureFactoryError::InvalidParameter(
727                "No features passed the IV threshold.".to_string(),
728            ));
729        }
730        df.select(keep_exprs).map_err(FeatureFactoryError::from)
731    }
732
733    fn inherent_is_stateful(&self) -> bool {
734        true
735    }
736}
737
738/// Selects numeric features based on absolute correlation with a binary target.
739/// Note: To preserve order between feature and target values, this transformer uses sequential iterators so it may be relatively slow.
740pub struct SelectBySingleFeaturePerformance {
741    pub target: String,
742    pub correlation_threshold: f64,
743    pub selected_features: HashSet<String>,
744    fitted: bool,
745}
746
747impl SelectBySingleFeaturePerformance {
748    pub fn new(target: String, correlation_threshold: f64) -> Self {
749        Self {
750            target,
751            correlation_threshold,
752            selected_features: HashSet::new(),
753            fitted: false,
754        }
755    }
756
757    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
758        let batches = df.clone().collect().await?;
759        if batches.is_empty() {
760            return Err(FeatureFactoryError::InvalidParameter(
761                "Empty DataFrame".to_string(),
762            ));
763        }
764        let batch = &batches[0];
765        let target_array =
766            as_primitive_array::<Float64Type>(batch.column_by_name(&self.target).ok_or_else(
767                || FeatureFactoryError::MissingColumn(format!("Target {} missing", self.target)),
768            )?);
769        // Use sequential iteration to preserve order.
770        let target_vals: Vec<f64> = target_array.iter().flatten().collect();
771        let mut selected = HashSet::new();
772        for field in df.schema().fields() {
773            let name = field.name();
774            if name == &self.target || !is_numeric(field.data_type()) {
775                continue;
776            }
777            let array = as_primitive_array::<Float64Type>(batch.column_by_name(name).unwrap());
778            let x: Vec<f64> = array.iter().flatten().collect();
779            if x.len() != target_vals.len() || x.is_empty() {
780                continue;
781            }
782            let n = x.len() as f64;
783            let mean_x = x.iter().sum::<f64>() / n;
784            let mean_y = target_vals.iter().sum::<f64>() / n;
785            let cov: f64 = x
786                .iter()
787                .zip(target_vals.iter())
788                .map(|(a, b)| (a - mean_x) * (b - mean_y))
789                .sum();
790            let var_x: f64 = x.iter().map(|a| (a - mean_x).powi(2)).sum();
791            let var_y: f64 = target_vals.iter().map(|b| (b - mean_y).powi(2)).sum();
792            if var_x == 0.0 || var_y == 0.0 {
793                continue;
794            }
795            let corr = cov / (var_x.sqrt() * var_y.sqrt());
796            if corr.abs() >= self.correlation_threshold {
797                selected.insert(name.to_string());
798            }
799        }
800        self.selected_features = selected;
801        self.fitted = true;
802        Ok(())
803    }
804
805    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
806        if !self.fitted {
807            return Err(FeatureFactoryError::FitNotCalled);
808        }
809        let mut keep_exprs: Vec<Expr> = vec![col(&self.target)];
810        for field in df.schema().fields() {
811            if self.selected_features.contains(field.name()) {
812                keep_exprs.push(col(field.name()));
813            }
814        }
815        if keep_exprs.is_empty() {
816            return Err(FeatureFactoryError::InvalidParameter(
817                "No features passed single feature performance selection.".to_string(),
818            ));
819        }
820        df.select(keep_exprs).map_err(FeatureFactoryError::from)
821    }
822
823    fn inherent_is_stateful(&self) -> bool {
824        true
825    }
826}
827
828/// Selects features based on the difference in target mean across different bins.
829pub struct SelectByTargetMeanPerformance {
830    pub target: String,
831    pub mean_diff_threshold: f64,
832    pub selected_features: HashSet<String>,
833    fitted: bool,
834}
835
836impl SelectByTargetMeanPerformance {
837    pub fn new(target: String, mean_diff_threshold: f64) -> Self {
838        Self {
839            target,
840            mean_diff_threshold,
841            selected_features: HashSet::new(),
842            fitted: false,
843        }
844    }
845
846    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
847        let batches = df.clone().collect().await?;
848        if batches.is_empty() {
849            return Err(FeatureFactoryError::InvalidParameter(
850                "Empty DataFrame".to_string(),
851            ));
852        }
853        let batch = &batches[0];
854        let target_array =
855            as_primitive_array::<Float64Type>(batch.column_by_name(&self.target).ok_or_else(
856                || FeatureFactoryError::MissingColumn(format!("Target {} missing", self.target)),
857            )?);
858        let target_vals: Vec<f64> = target_array.iter().flatten().collect();
859        let mut selected = HashSet::new();
860        for field in df.schema().fields() {
861            let name = field.name();
862            if name == &self.target || !is_numeric(field.data_type()) {
863                continue;
864            }
865            let array = as_primitive_array::<Float64Type>(batch.column_by_name(name).unwrap());
866            let mut vals: Vec<f64> = array.iter().flatten().collect();
867            if vals.is_empty() {
868                continue;
869            }
870            vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
871            let median = vals[vals.len() / 2];
872            let mut group1 = Vec::new();
873            let mut group2 = Vec::new();
874            for (j, v_opt) in array.iter().enumerate() {
875                if let Some(val) = v_opt {
876                    if val < median {
877                        group1.push(target_vals[j]);
878                    } else {
879                        group2.push(target_vals[j]);
880                    }
881                }
882            }
883            let mean1 = if !group1.is_empty() {
884                group1.iter().sum::<f64>() / group1.len() as f64
885            } else {
886                0.0
887            };
888            let mean2 = if !group2.is_empty() {
889                group2.iter().sum::<f64>() / group2.len() as f64
890            } else {
891                0.0
892            };
893            if (mean1 - mean2).abs() >= self.mean_diff_threshold {
894                selected.insert(name.to_string());
895            }
896        }
897        self.selected_features = selected;
898        self.fitted = true;
899        Ok(())
900    }
901
902    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
903        if !self.fitted {
904            return Err(FeatureFactoryError::FitNotCalled);
905        }
906        let mut keep_exprs: Vec<Expr> = vec![col(&self.target)];
907        for field in df.schema().fields() {
908            if self.selected_features.contains(field.name()) {
909                keep_exprs.push(col(field.name()));
910            }
911        }
912        if keep_exprs.is_empty() {
913            return Err(FeatureFactoryError::InvalidParameter(
914                "No features selected by target mean performance.".to_string(),
915            ));
916        }
917        df.select(keep_exprs).map_err(FeatureFactoryError::from)
918    }
919
920    fn inherent_is_stateful(&self) -> bool {
921        true
922    }
923}
924
925/// Selects features using MRMR algorithm based on feature-target relevance and redundancy.
926pub struct MRMR {
927    pub target: String,
928    pub relevance_threshold: f64,
929    pub redundancy_threshold: f64,
930    pub selected_features: HashSet<String>,
931    fitted: bool,
932}
933
934impl MRMR {
935    pub fn new(target: String, relevance_threshold: f64, redundancy_threshold: f64) -> Self {
936        Self {
937            target,
938            relevance_threshold,
939            redundancy_threshold,
940            selected_features: HashSet::new(),
941            fitted: false,
942        }
943    }
944
945    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
946        let batches = df.clone().collect().await?;
947        if batches.is_empty() {
948            return Err(FeatureFactoryError::InvalidParameter(
949                "Empty DataFrame".to_string(),
950            ));
951        }
952        let batch = &batches[0];
953        let target_array =
954            as_primitive_array::<Float64Type>(batch.column_by_name(&self.target).ok_or_else(
955                || FeatureFactoryError::MissingColumn(format!("Target {} missing", self.target)),
956            )?);
957        let target_vals: Vec<f64> = target_array.iter().flatten().collect();
958        let schema = df.schema();
959        let mut candidates = Vec::new();
960        for field in schema.fields() {
961            let name = field.name();
962            if name == &self.target || !is_numeric(field.data_type()) {
963                continue;
964            }
965            let array = as_primitive_array::<Float64Type>(batch.column_by_name(name).unwrap());
966            let x: Vec<f64> = array.iter().flatten().collect();
967            if x.len() != target_vals.len() || x.is_empty() {
968                continue;
969            }
970            let n = x.len() as f64;
971            let mean_x = x.iter().sum::<f64>() / n;
972            let mean_y = target_vals.iter().sum::<f64>() / n;
973            let cov: f64 = x
974                .iter()
975                .zip(target_vals.iter())
976                .map(|(a, b)| (a - mean_x) * (b - mean_y))
977                .sum();
978            let var_x: f64 = x.iter().map(|a| (a - mean_x).powi(2)).sum();
979            let var_y: f64 = target_vals.iter().map(|b| (b - mean_y).powi(2)).sum();
980            if var_x == 0.0 || var_y == 0.0 {
981                continue;
982            }
983            let corr = cov / (var_x.sqrt() * var_y.sqrt());
984            if corr.abs() >= self.relevance_threshold {
985                candidates.push((name.to_string(), corr.abs()));
986            }
987        }
988        let mut selected = Vec::<String>::new();
989        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
990        for (feat, _) in candidates {
991            let mut redundant = false;
992            for sel in &selected {
993                let array_feat =
994                    as_primitive_array::<Float64Type>(batch.column_by_name(&feat).unwrap());
995                let array_sel =
996                    as_primitive_array::<Float64Type>(batch.column_by_name(sel).unwrap());
997                let x: Vec<f64> = array_feat.iter().flatten().collect();
998                let y: Vec<f64> = array_sel.iter().flatten().collect();
999                if x.len() != y.len() || x.is_empty() {
1000                    continue;
1001                }
1002                let n = x.len() as f64;
1003                let mean_x = x.iter().sum::<f64>() / n;
1004                let mean_y = y.iter().sum::<f64>() / n;
1005                let cov: f64 = x
1006                    .iter()
1007                    .zip(y.iter())
1008                    .map(|(a, b)| (a - mean_x) * (b - mean_y))
1009                    .sum();
1010                let var_x: f64 = x.iter().map(|a| (a - mean_x).powi(2)).sum();
1011                let var_y: f64 = y.iter().map(|b| (b - mean_y).powi(2)).sum();
1012                if var_x == 0.0 || var_y == 0.0 {
1013                    continue;
1014                }
1015                let corr = cov / (var_x.sqrt() * var_y.sqrt());
1016                if corr.abs() > self.redundancy_threshold {
1017                    redundant = true;
1018                    break;
1019                }
1020            }
1021            if !redundant {
1022                selected.push(feat);
1023            }
1024        }
1025        self.selected_features = selected.into_iter().collect();
1026        self.fitted = true;
1027        Ok(())
1028    }
1029
1030    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
1031        if !self.fitted {
1032            return Err(FeatureFactoryError::FitNotCalled);
1033        }
1034        let mut keep_exprs: Vec<Expr> = vec![col(&self.target)];
1035        for field in df.schema().fields() {
1036            if self.selected_features.contains(field.name()) {
1037                keep_exprs.push(col(field.name()));
1038            }
1039        }
1040        if keep_exprs.is_empty() {
1041            return Err(FeatureFactoryError::InvalidParameter(
1042                "No features selected by MRMR.".to_string(),
1043            ));
1044        }
1045        df.select(keep_exprs).map_err(FeatureFactoryError::from)
1046    }
1047
1048    fn inherent_is_stateful(&self) -> bool {
1049        true
1050    }
1051}
1052
1053// Implement the Transformer trait for all the above feature selection transformers.
1054impl_transformer!(DropFeatures);
1055impl_transformer!(DropConstantFeatures);
1056impl_transformer!(DropDuplicateFeatures);
1057impl_transformer!(DropCorrelatedFeatures);
1058impl_transformer!(SmartCorrelatedSelection);
1059impl_transformer!(DropHighPSIFeatures);
1060impl_transformer!(SelectByInformationValue);
1061impl_transformer!(SelectBySingleFeaturePerformance);
1062impl_transformer!(SelectByTargetMeanPerformance);
1063impl_transformer!(MRMR);