1use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
18use crate::impl_transformer;
19use arrow::array::Array;
20use arrow::datatypes::DataType;
21use datafusion::dataframe::DataFrame;
22use datafusion::functions_aggregate::expr_fn::{avg, count};
23use datafusion::logical_expr::{col, lit, Case as DFCase, Expr};
24use std::collections::HashMap;
25
26fn validate_string_column(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<()> {
28 let field = df.schema().field_with_name(None, col_name).map_err(|_| {
29 FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
30 })?;
31 if field.data_type() != &DataType::Utf8 {
32 return Err(FeatureFactoryError::InvalidParameter(format!(
33 "Column '{}' must be of type Utf8, but found {:?}",
34 col_name,
35 field.data_type()
36 )));
37 }
38 Ok(())
39}
40
41fn validate_string_columns(df: &DataFrame, cols: &[String]) -> FeatureFactoryResult<()> {
43 for col in cols {
44 validate_string_column(df, col)?;
45 }
46 Ok(())
47}
48
49fn validate_numeric_column(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<()> {
51 let field = df.schema().field_with_name(None, col_name).map_err(|_| {
52 FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
53 })?;
54 match field.data_type() {
55 DataType::Float64 | DataType::Int64 => Ok(()),
56 dt => Err(FeatureFactoryError::InvalidParameter(format!(
57 "Column '{}' must be numeric (Float64 or Int64), but found {:?}",
58 col_name, dt
59 ))),
60 }
61}
62
/// Turns a category value into a string that is safe to embed in a column name.
///
/// Every character for which `char::is_alphanumeric` is false is replaced by a
/// single `_`; alphanumeric characters (including non-ASCII ones) are kept.
/// Distinct inputs can therefore map to the same output (`"a b"` and `"a-b"`
/// both become `"a_b"`).
fn sanitize_category(cat: &str) -> String {
    cat.chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}
68
69fn build_case_expr<T: Clone + 'static + datafusion::logical_expr::Literal>(
74 col_name: &str,
75 mapping: &[(String, T)],
76 default: Option<Expr>,
77) -> Expr {
78 let when_then_expr = mapping
79 .iter()
80 .map(|(cat, val)| {
81 (
82 Box::new(col(col_name).eq(lit(cat.clone()))),
83 Box::new(lit(val.clone())),
84 )
85 })
86 .collect();
87 Expr::Case(DFCase {
88 expr: None,
89 when_then_expr,
90 else_expr: default.map(Box::new),
91 })
92}
93
94async fn extract_distinct_values(
96 df: &DataFrame,
97 col_name: &str,
98) -> FeatureFactoryResult<Vec<String>> {
99 validate_string_column(df, col_name)?;
101 let distinct_df = df.clone().select(vec![col(col_name)])?.distinct()?;
102 let batches = distinct_df
103 .collect()
104 .await
105 .map_err(FeatureFactoryError::from)?;
106 let mut values = Vec::new();
107 for batch in batches {
108 let array = batch
109 .column(0)
110 .as_any()
111 .downcast_ref::<datafusion::arrow::array::StringArray>()
112 .ok_or_else(|| {
113 FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
114 format!("Expected Utf8 array for column {}", col_name),
115 ))
116 })?;
117 for i in 0..array.len() {
118 if !array.is_null(i) {
119 values.push(array.value(i).to_string());
120 }
121 }
122 }
123 Ok(values)
124}
125
126async fn extract_count_mapping(
128 df: &DataFrame,
129 col_name: &str,
130) -> FeatureFactoryResult<HashMap<String, i64>> {
131 validate_string_column(df, col_name)?;
132 let grouped = df
133 .clone()
134 .aggregate(vec![col(col_name)], vec![count(col(col_name)).alias("cnt")])
135 .map_err(FeatureFactoryError::from)?;
136 let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
137 let mut map = HashMap::new();
138 for batch in batches {
139 let cat_array = batch
140 .column(0)
141 .as_any()
142 .downcast_ref::<datafusion::arrow::array::StringArray>()
143 .ok_or_else(|| {
144 FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
145 format!("Expected Utf8 array for column {}", col_name),
146 ))
147 })?;
148 let count_array = batch
149 .column(1)
150 .as_any()
151 .downcast_ref::<datafusion::arrow::array::Int64Array>()
152 .ok_or_else(|| {
153 FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
154 "Expected Int64 array".into(),
155 ))
156 })?;
157 for i in 0..batch.num_rows() {
158 if !cat_array.is_null(i) {
159 map.insert(cat_array.value(i).to_string(), count_array.value(i));
160 }
161 }
162 }
163 Ok(map)
164}
165
166fn apply_mapping<T: Clone + 'static + datafusion::logical_expr::Literal>(
171 df: DataFrame,
172 target_cols: &[String],
173 mapping_fn: impl Fn(&str) -> Option<Vec<(String, T)>>,
174 default_fn: impl Fn(&str) -> Option<Expr>,
175) -> FeatureFactoryResult<DataFrame> {
176 let exprs: Vec<Expr> = df
177 .schema()
178 .fields()
179 .iter()
180 .map(|field| {
181 let name = field.name();
182 if target_cols.contains(name) {
183 if let Some(map) = mapping_fn(name) {
184 build_case_expr(name, &map, default_fn(name)).alias(name)
185 } else {
186 col(name)
187 }
188 } else {
189 col(name)
190 }
191 })
192 .collect();
193 df.select(exprs).map_err(FeatureFactoryError::from)
194}
195
196pub struct OneHotEncoder {
198 pub columns: Vec<String>,
199 pub categories: HashMap<String, Vec<String>>,
201 fitted: bool,
202}
203
204impl OneHotEncoder {
205 pub fn new(columns: Vec<String>) -> Self {
207 Self {
208 columns,
209 categories: HashMap::new(),
210 fitted: false,
211 }
212 }
213
214 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
216 validate_string_columns(df, &self.columns)?;
217 for col_name in &self.columns {
218 let values = extract_distinct_values(df, col_name).await?;
219 self.categories.insert(col_name.clone(), values);
220 }
221 self.fitted = true;
222 Ok(())
223 }
224
225 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
227 if !self.fitted {
228 return Err(FeatureFactoryError::FitNotCalled);
229 }
230 let mut exprs = vec![];
231 for field in df.schema().fields() {
232 exprs.push(col(field.name()));
233 }
234 for col_name in &self.columns {
235 if let Some(cats) = self.categories.get(col_name) {
236 for cat in cats {
237 let safe_cat = sanitize_category(cat);
238 let new_col_name = format!("{}_{}", col_name, safe_cat);
239 let case_expr = Expr::Case(DFCase {
240 expr: None,
241 when_then_expr: vec![(
242 Box::new(col(col_name).eq(lit(cat.clone()))),
243 Box::new(lit(1_i32)),
244 )],
245 else_expr: Some(Box::new(lit(0_i32))),
246 })
247 .alias(new_col_name);
248 exprs.push(case_expr);
249 }
250 }
251 }
252 df.select(exprs).map_err(FeatureFactoryError::from)
253 }
254
255 fn inherent_is_stateful(&self) -> bool {
257 true
258 }
259}
260
261pub struct CountFrequencyEncoder {
263 pub columns: Vec<String>,
264 pub mapping: HashMap<String, HashMap<String, i64>>,
266 fitted: bool,
267}
268
269impl CountFrequencyEncoder {
270 pub fn new(columns: Vec<String>) -> Self {
272 Self {
273 columns,
274 mapping: HashMap::new(),
275 fitted: false,
276 }
277 }
278
279 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
281 validate_string_columns(df, &self.columns)?;
282 for col_name in &self.columns {
283 let map = extract_count_mapping(df, col_name).await?;
284 self.mapping.insert(col_name.clone(), map);
285 }
286 self.fitted = true;
287 Ok(())
288 }
289
290 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
292 if !self.fitted {
293 return Err(FeatureFactoryError::FitNotCalled);
294 }
295 apply_mapping(
296 df,
297 &self.columns,
298 |name| {
299 self.mapping.get(name).map(|m| {
300 m.iter()
301 .map(|(k, &v)| (k.clone(), v))
302 .collect::<Vec<(String, i64)>>()
303 })
304 },
305 |_| Some(lit(0_i64)),
306 )
307 }
308
309 fn inherent_is_stateful(&self) -> bool {
311 true
312 }
313}
314
315pub struct OrdinalEncoder {
318 pub columns: Vec<String>,
319 pub mapping: HashMap<String, HashMap<String, i64>>,
321 fitted: bool,
322}
323
324impl OrdinalEncoder {
325 pub fn new(columns: Vec<String>) -> Self {
327 Self {
328 columns,
329 mapping: HashMap::new(),
330 fitted: false,
331 }
332 }
333
334 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
336 validate_string_columns(df, &self.columns)?;
337 for col_name in &self.columns {
338 let mut values = extract_distinct_values(df, col_name).await?;
339 values.sort();
340 let mapping = values
341 .into_iter()
342 .enumerate()
343 .map(|(i, cat)| (cat, i as i64))
344 .collect();
345 self.mapping.insert(col_name.clone(), mapping);
346 }
347 self.fitted = true;
348 Ok(())
349 }
350
351 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
353 if !self.fitted {
354 return Err(FeatureFactoryError::FitNotCalled);
355 }
356 apply_mapping(
357 df,
358 &self.columns,
359 |name| {
360 self.mapping.get(name).map(|m| {
361 m.iter()
362 .map(|(k, &v)| (k.clone(), v))
363 .collect::<Vec<(String, i64)>>()
364 })
365 },
366 |_| Some(lit(0_i64)),
367 )
368 }
369
370 fn inherent_is_stateful(&self) -> bool {
372 true
373 }
374}
375
376pub struct MeanEncoder {
378 pub columns: Vec<String>,
379 pub target: String,
380 pub mapping: HashMap<String, HashMap<String, f64>>,
382 fitted: bool,
383}
384
385impl MeanEncoder {
386 pub fn new(columns: Vec<String>, target: String) -> Self {
388 Self {
389 columns,
390 target,
391 mapping: HashMap::new(),
392 fitted: false,
393 }
394 }
395
396 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
398 validate_string_columns(df, &self.columns)?;
399 validate_numeric_column(df, &self.target)?;
400 for col_name in &self.columns {
401 let agg_df = df
402 .clone()
403 .aggregate(
404 vec![col(col_name)],
405 vec![avg(col(&self.target)).alias("mean")],
406 )
407 .map_err(FeatureFactoryError::from)?;
408 let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
409 let mut map = HashMap::new();
410 for batch in batches {
411 let cat_array = batch
412 .column(0)
413 .as_any()
414 .downcast_ref::<datafusion::arrow::array::StringArray>()
415 .ok_or_else(|| {
416 FeatureFactoryError::DataFusionError(
417 datafusion::error::DataFusionError::Plan(format!(
418 "Expected Utf8 array for column {}",
419 col_name
420 )),
421 )
422 })?;
423 let mean_array = batch
424 .column(1)
425 .as_any()
426 .downcast_ref::<datafusion::arrow::array::Float64Array>()
427 .ok_or_else(|| {
428 FeatureFactoryError::DataFusionError(
429 datafusion::error::DataFusionError::Plan(
430 "Expected Float64 array".into(),
431 ),
432 )
433 })?;
434 for i in 0..batch.num_rows() {
435 if !cat_array.is_null(i) {
436 map.insert(cat_array.value(i).to_string(), mean_array.value(i));
437 }
438 }
439 }
440 self.mapping.insert(col_name.clone(), map);
441 }
442 self.fitted = true;
443 Ok(())
444 }
445
446 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
448 if !self.fitted {
449 return Err(FeatureFactoryError::FitNotCalled);
450 }
451 apply_mapping(
452 df,
453 &self.columns,
454 |name| {
455 self.mapping.get(name).map(|m| {
456 m.iter()
457 .map(|(k, &v)| (k.clone(), v))
458 .collect::<Vec<(String, f64)>>()
459 })
460 },
461 |_| Some(lit(0.0_f64)),
462 )
463 }
464
465 fn inherent_is_stateful(&self) -> bool {
467 true
468 }
469}
470
471pub struct WoEEncoder {
474 pub columns: Vec<String>,
475 pub target: String,
476 pub mapping: HashMap<String, HashMap<String, f64>>,
478 fitted: bool,
479}
480
481impl WoEEncoder {
482 pub fn new(columns: Vec<String>, target: String) -> Self {
484 Self {
485 columns,
486 target,
487 mapping: HashMap::new(),
488 fitted: false,
489 }
490 }
491
492 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
494 validate_string_columns(df, &self.columns)?;
495 validate_numeric_column(df, &self.target)?;
496 let overall_df = df
497 .clone()
498 .aggregate(vec![], vec![count(col(&self.target)).alias("total")])
499 .map_err(FeatureFactoryError::from)?;
500 let overall_batches = overall_df
501 .collect()
502 .await
503 .map_err(FeatureFactoryError::from)?;
504 let _total = if let Some(batch) = overall_batches.first() {
505 let total_array = batch
506 .column(0)
507 .as_any()
508 .downcast_ref::<datafusion::arrow::array::Int64Array>()
509 .ok_or_else(|| {
510 FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
511 "Expected Int64 array".into(),
512 ))
513 })?;
514 total_array.value(0) as f64
515 } else {
516 return Err(FeatureFactoryError::DataFusionError(
517 datafusion::error::DataFusionError::Plan("No data found".into()),
518 ));
519 };
520
521 for col_name in &self.columns {
522 let grouped = df
523 .clone()
524 .aggregate(
525 vec![col(col_name), col(&self.target)],
526 vec![count(lit(1)).alias("cnt")],
527 )
528 .map_err(FeatureFactoryError::from)?;
529 let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
530 let mut cat_counts: HashMap<String, (f64, f64)> = HashMap::new(); for batch in batches {
532 let cat_array = batch
533 .column(0)
534 .as_any()
535 .downcast_ref::<datafusion::arrow::array::StringArray>()
536 .ok_or_else(|| {
537 FeatureFactoryError::DataFusionError(
538 datafusion::error::DataFusionError::Plan(format!(
539 "Expected Utf8 array for column {}",
540 col_name
541 )),
542 )
543 })?;
544 let target_array = batch
545 .column(1)
546 .as_any()
547 .downcast_ref::<datafusion::arrow::array::Int64Array>()
548 .ok_or_else(|| {
549 FeatureFactoryError::DataFusionError(
550 datafusion::error::DataFusionError::Plan("Expected Int64 array".into()),
551 )
552 })?;
553 let count_array = batch
554 .column(2)
555 .as_any()
556 .downcast_ref::<datafusion::arrow::array::Int64Array>()
557 .ok_or_else(|| {
558 FeatureFactoryError::DataFusionError(
559 datafusion::error::DataFusionError::Plan("Expected Int64 array".into()),
560 )
561 })?;
562 for i in 0..batch.num_rows() {
563 if !cat_array.is_null(i) {
564 let cat = cat_array.value(i).to_string();
565 let target_val = target_array.value(i);
566 let cnt = count_array.value(i) as f64;
567 let entry = cat_counts.entry(cat).or_insert((0.0, 0.0));
568 if target_val == 1 {
569 entry.0 += cnt;
570 } else {
571 entry.1 += cnt;
572 }
573 }
574 }
575 }
576 let mut mapping = HashMap::new();
577 for (cat, (good, bad)) in cat_counts {
578 let woe = ((good + 1e-6) / (bad + 1e-6)).ln();
579 mapping.insert(cat, woe);
580 }
581 self.mapping.insert(col_name.clone(), mapping);
582 }
583 self.fitted = true;
584 Ok(())
585 }
586
587 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
589 if !self.fitted {
590 return Err(FeatureFactoryError::FitNotCalled);
591 }
592 apply_mapping(
593 df,
594 &self.columns,
595 |name| {
596 self.mapping.get(name).map(|m| {
597 m.iter()
598 .map(|(k, &v)| (k.clone(), v))
599 .collect::<Vec<(String, f64)>>()
600 })
601 },
602 |_| Some(lit(0.0_f64)),
603 )
604 }
605
606 fn inherent_is_stateful(&self) -> bool {
608 true
609 }
610}
611
612pub struct RareLabelEncoder {
614 pub columns: Vec<String>,
615 pub threshold: f64, pub mapping: HashMap<String, HashMap<String, String>>,
618 fitted: bool,
619}
620
621impl RareLabelEncoder {
622 pub fn new(columns: Vec<String>, threshold: f64) -> Self {
624 Self {
625 columns,
626 threshold,
627 mapping: HashMap::new(),
628 fitted: false,
629 }
630 }
631
632 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
634 if self.threshold < 0.0 || self.threshold > 1.0 {
635 return Err(FeatureFactoryError::InvalidParameter(format!(
636 "Threshold {} must be between 0 and 1",
637 self.threshold
638 )));
639 }
640 validate_string_columns(df, &self.columns)?;
641 let total_df = df
642 .clone()
643 .aggregate(vec![], vec![count(lit(1)).alias("total")])
644 .map_err(FeatureFactoryError::from)?;
645 let total_batches = total_df
646 .collect()
647 .await
648 .map_err(FeatureFactoryError::from)?;
649 let total = if let Some(batch) = total_batches.first() {
650 let total_array = batch
651 .column(0)
652 .as_any()
653 .downcast_ref::<datafusion::arrow::array::Int64Array>()
654 .ok_or_else(|| {
655 FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(
656 "Expected Int64 array".into(),
657 ))
658 })?;
659 total_array.value(0) as f64
660 } else {
661 return Err(FeatureFactoryError::DataFusionError(
662 datafusion::error::DataFusionError::Plan("No data found".into()),
663 ));
664 };
665
666 for col_name in &self.columns {
667 let grouped = df
668 .clone()
669 .aggregate(vec![col(col_name)], vec![count(col(col_name)).alias("cnt")])
670 .map_err(FeatureFactoryError::from)?;
671 let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
672 let mut map = HashMap::new();
673 for batch in batches {
674 let cat_array = batch
675 .column(0)
676 .as_any()
677 .downcast_ref::<datafusion::arrow::array::StringArray>()
678 .ok_or_else(|| {
679 FeatureFactoryError::DataFusionError(
680 datafusion::error::DataFusionError::Plan(format!(
681 "Expected Utf8 array for column {}",
682 col_name
683 )),
684 )
685 })?;
686 let cnt_array = batch
687 .column(1)
688 .as_any()
689 .downcast_ref::<datafusion::arrow::array::Int64Array>()
690 .ok_or_else(|| {
691 FeatureFactoryError::DataFusionError(
692 datafusion::error::DataFusionError::Plan("Expected Int64 array".into()),
693 )
694 })?;
695 for i in 0..batch.num_rows() {
696 if !cat_array.is_null(i) {
697 let cat = cat_array.value(i).to_string();
698 let cnt = cnt_array.value(i) as f64;
699 let freq = cnt / total;
700 let encoded = if freq < self.threshold {
701 "rare".to_string()
702 } else {
703 cat.clone()
704 };
705 map.insert(cat, encoded);
706 }
707 }
708 }
709 self.mapping.insert(col_name.clone(), map);
710 }
711 self.fitted = true;
712 Ok(())
713 }
714
715 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
717 if !self.fitted {
718 return Err(FeatureFactoryError::FitNotCalled);
719 }
720 apply_mapping(
721 df,
722 &self.columns,
723 |name| {
724 self.mapping.get(name).map(|m| {
725 m.iter()
726 .map(|(k, v)| (k.clone(), v.clone()))
727 .collect::<Vec<(String, String)>>()
728 })
729 },
730 |name| Some(col(name)),
731 )
732 }
733
734 fn inherent_is_stateful(&self) -> bool {
736 true
737 }
738}
739
// Apply the crate-level `impl_transformer!` macro to every encoder in this file.
// NOTE(review): the macro is defined outside this file; presumably it implements the
// crate's shared transformer trait by delegating to each type's inherent `fit`,
// `transform` and `inherent_is_stateful` methods — confirm against the macro definition.
impl_transformer!(OneHotEncoder);
impl_transformer!(CountFrequencyEncoder);
impl_transformer!(OrdinalEncoder);
impl_transformer!(MeanEncoder);
impl_transformer!(WoEEncoder);
impl_transformer!(RareLabelEncoder);