// feature_factory/transformers/imputation.rs
use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
use crate::impl_transformer;
use datafusion::arrow::datatypes::DataType;
use datafusion::dataframe::DataFrame;
use datafusion::functions_aggregate::expr_fn::{approx_percentile_cont, avg, count};
use datafusion::logical_expr::{col, ident, lit, not, Case as DFCase, Expr};
use datafusion::scalar::ScalarValue;
use std::collections::HashMap;
24
25fn validate_columns(df: &DataFrame, target_cols: &[String]) -> FeatureFactoryResult<()> {
28 let schema = df.schema();
29 for col_name in target_cols {
30 if schema.field_with_name(None, col_name).is_err() {
31 return Err(FeatureFactoryError::MissingColumn(format!(
32 "Column '{}' not found in DataFrame",
33 col_name
34 )));
35 }
36 }
37 Ok(())
38}
39
40fn coalesce_expr_for(name: &str, fallback: Expr) -> Expr {
43 Expr::Case(DFCase {
44 expr: None,
45 when_then_expr: vec![(Box::new(not(col(name).is_null())), Box::new(col(name)))],
46 else_expr: Some(Box::new(fallback)),
47 })
48}
49
50fn apply_imputation<F>(
54 df: DataFrame,
55 target_cols: &[String],
56 get_fallback: F,
57) -> FeatureFactoryResult<DataFrame>
58where
59 F: Fn(&str) -> Option<Expr>,
60{
61 let exprs: Vec<Expr> = df
62 .schema()
63 .fields()
64 .iter()
65 .map(|field| {
66 let name = field.name();
67 if target_cols.contains(name) {
68 if let Some(fallback_expr) = get_fallback(name) {
69 coalesce_expr_for(name, fallback_expr).alias(name)
70 } else {
71 col(name)
72 }
73 } else {
74 col(name)
75 }
76 })
77 .collect();
78 df.select(exprs).map_err(FeatureFactoryError::from)
79}
80
81pub struct MeanMedianImputer {
83 pub columns: Vec<String>,
84 pub strategy: ImputeStrategy,
85 pub impute_values: HashMap<String, f64>,
86 fitted: bool,
87}
88
/// Statistic used by `MeanMedianImputer` to derive each column's fill value.
#[derive(Debug, Clone, Copy)]
pub enum ImputeStrategy {
    /// Column mean, computed with DataFusion's `avg` aggregate.
    Mean,
    /// Column median; `MeanMedianImputer::fit` currently rejects it as not implemented.
    Median,
}
94
95impl MeanMedianImputer {
96 pub fn new(columns: Vec<String>, strategy: ImputeStrategy) -> Self {
97 Self {
98 columns,
99 strategy,
100 impute_values: HashMap::new(),
101 fitted: false,
102 }
103 }
104
105 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
107 validate_columns(df, &self.columns)?;
108 for col_name in &self.columns {
109 match self.strategy {
110 ImputeStrategy::Mean => {
111 let agg_df = df
112 .clone()
113 .aggregate(vec![], vec![avg(col(col_name)).alias("avg")])
114 .map_err(FeatureFactoryError::from)?;
115 let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
116 if let Some(batch) = batches.first() {
117 if batch.num_rows() > 0 {
118 let array = batch.column(0);
119 let scalar = ScalarValue::try_from_array(array, 0)
120 .map_err(FeatureFactoryError::from)?;
121 if let ScalarValue::Float64(Some(avg_val)) = scalar {
122 self.impute_values.insert(col_name.clone(), avg_val);
123 } else {
124 return Err(FeatureFactoryError::DataFusionError(
125 datafusion::error::DataFusionError::Plan(format!(
126 "Failed to compute average for column {}",
127 col_name
128 )),
129 ));
130 }
131 }
132 }
133 }
134 ImputeStrategy::Median => {
135 return Err(FeatureFactoryError::NotImplemented(
136 "Median imputation not implemented in DF mode".to_string(),
137 ));
138 }
139 }
140 }
141 self.fitted = true;
142 Ok(())
143 }
144
145 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
147 if !self.fitted {
148 return Err(FeatureFactoryError::FitNotCalled);
149 }
150 validate_columns(&df, &self.columns)?;
151 apply_imputation(df, &self.columns, |name| {
152 self.impute_values.get(name).map(|&v| lit(v))
153 })
154 }
155
156 fn inherent_is_stateful(&self) -> bool {
158 true
159 }
160}
161
/// Fills nulls in the selected columns with one fixed, caller-chosen number.
pub struct ArbitraryNumberImputer {
    /// Names of the columns to impute.
    pub columns: Vec<String>,
    /// Constant written in place of nulls; `transform` requires it to be finite.
    pub number: f64,
}
167
168impl ArbitraryNumberImputer {
169 pub fn new(columns: Vec<String>, number: f64) -> Self {
170 Self { columns, number }
171 }
172
173 pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
175 Ok(())
176 }
177
178 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
180 if !self.number.is_finite() {
181 return Err(FeatureFactoryError::InvalidParameter(format!(
182 "Fixed number {} must be finite",
183 self.number
184 )));
185 }
186 validate_columns(&df, &self.columns)?;
187 apply_imputation(df, &self.columns, |_| Some(lit(self.number)))
188 }
189
190 fn inherent_is_stateful(&self) -> bool {
192 false
193 }
194}
195
/// Fills nulls with a value taken from a chosen percentile of each column's distribution.
pub struct EndTailImputer {
    /// Names of the columns to impute.
    pub columns: Vec<String>,
    /// Fraction in `[0, 1]` selecting the percentile; `fit` validates it.
    pub percentile: f64,
    /// Learned fill value per column, keyed by column name.
    pub impute_values: HashMap<String, f64>,
    /// `true` once `fit` has completed successfully; `transform` refuses to run otherwise.
    fitted: bool,
}
203
204impl EndTailImputer {
205 pub fn new(columns: Vec<String>, percentile: f64) -> Self {
206 Self {
207 columns,
208 percentile,
209 impute_values: HashMap::new(),
210 fitted: false,
211 }
212 }
213
214 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
216 validate_columns(df, &self.columns)?;
217 if self.percentile < 0.0 || self.percentile > 1.0 {
218 return Err(FeatureFactoryError::InvalidParameter(format!(
219 "Percentile {} must be between 0 and 1",
220 self.percentile
221 )));
222 }
223 for col_name in &self.columns {
224 let agg_df = df
225 .clone()
226 .aggregate(
227 vec![],
228 vec![
229 approx_percentile_cont(col(col_name), lit(self.percentile), None)
230 .alias("perc"),
231 ],
232 )
233 .map_err(FeatureFactoryError::from)?;
234 let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
235 if let Some(batch) = batches.first() {
236 let array = batch.column(0);
237 let scalar =
238 ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
239 if let ScalarValue::Float64(Some(val)) = scalar {
240 self.impute_values.insert(col_name.clone(), val);
241 } else {
242 return Err(FeatureFactoryError::DataFusionError(
243 datafusion::error::DataFusionError::Plan(format!(
244 "Failed to compute percentile for column {}",
245 col_name
246 )),
247 ));
248 }
249 }
250 }
251 self.fitted = true;
252 Ok(())
253 }
254
255 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
257 if !self.fitted {
258 return Err(FeatureFactoryError::FitNotCalled);
259 }
260 validate_columns(&df, &self.columns)?;
261 apply_imputation(df, &self.columns, |name| {
262 self.impute_values.get(name).map(|&v| lit(v))
263 })
264 }
265
266 fn inherent_is_stateful(&self) -> bool {
268 true
269 }
270}
271
/// Fills nulls in string columns with either a fixed default or each column's mode.
pub struct CategoricalImputer {
    /// Names of the columns to impute.
    pub columns: Vec<String>,
    /// When set, used for every column instead of a learned mode.
    pub default: Option<String>,
    /// Learned mode per column, keyed by column name (unused while `default` is set).
    pub impute_values: HashMap<String, String>,
    /// `true` once `fit` has completed successfully; `transform` refuses to run otherwise.
    fitted: bool,
}
279
280impl CategoricalImputer {
281 pub fn new(columns: Vec<String>, default: Option<String>) -> Self {
282 Self {
283 columns,
284 default,
285 impute_values: HashMap::new(),
286 fitted: false,
287 }
288 }
289
290 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
292 validate_columns(df, &self.columns)?;
293 if self.default.is_some() {
294 self.fitted = true;
295 return Ok(());
296 }
297 for col_name in &self.columns {
298 let grouped = df
299 .clone()
300 .aggregate(vec![col(col_name)], vec![count(col(col_name)).alias("cnt")])
301 .map_err(FeatureFactoryError::from)?
302 .sort(vec![col("cnt").sort(false, false)])
303 .map_err(FeatureFactoryError::from)?
304 .limit(0, Some(1))
305 .map_err(FeatureFactoryError::from)?;
306 let batches = grouped.collect().await.map_err(FeatureFactoryError::from)?;
307 if let Some(batch) = batches.first() {
308 let array = batch.column(0);
309 let scalar =
310 ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
311 if let ScalarValue::Utf8(Some(mode_val)) = scalar {
312 self.impute_values.insert(col_name.clone(), mode_val);
313 } else {
314 return Err(FeatureFactoryError::DataFusionError(
315 datafusion::error::DataFusionError::Plan(format!(
316 "Failed to compute mode for column {}",
317 col_name
318 )),
319 ));
320 }
321 }
322 }
323 self.fitted = true;
324 Ok(())
325 }
326
327 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
329 if !self.fitted {
330 return Err(FeatureFactoryError::FitNotCalled);
331 }
332 validate_columns(&df, &self.columns)?;
333 apply_imputation(df, &self.columns, |name| {
334 if let Some(default_val) = &self.default {
335 Some(lit(default_val.clone()))
336 } else {
337 self.impute_values
338 .get(name)
339 .map(|mode_val| lit(mode_val.clone()))
340 }
341 })
342 }
343
344 fn inherent_is_stateful(&self) -> bool {
346 true
347 }
348}
349
/// Appends a boolean "is null" indicator column after each selected column.
pub struct AddMissingIndicator {
    /// Names of the columns that get an indicator.
    pub columns: Vec<String>,
    /// Appended to a column's name to form its indicator's name.
    pub suffix: String,
}
355
356impl AddMissingIndicator {
357 pub fn new(columns: Vec<String>, suffix: Option<String>) -> Self {
358 Self {
359 columns,
360 suffix: suffix.unwrap_or_else(|| "_missing".to_string()),
361 }
362 }
363
364 pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
366 Ok(())
367 }
368
369 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
371 validate_columns(&df, &self.columns)?;
372 let mut exprs = vec![];
373 for field in df.schema().fields() {
374 let name = field.name();
375 exprs.push(col(name));
376 if self.columns.contains(name) {
377 exprs.push(
378 col(name)
379 .is_null()
380 .alias(format!("{}{}", name, self.suffix)),
381 );
382 }
383 }
384 df.select(exprs).map_err(FeatureFactoryError::from)
385 }
386
387 fn inherent_is_stateful(&self) -> bool {
389 false
390 }
391}
392
/// Drops every row that has a null in any of the checked columns.
pub struct DropMissingData {
    /// Columns to check; `None` means every column of the input.
    pub columns: Option<Vec<String>>,
}
399
400impl DropMissingData {
401 pub fn new() -> Self {
402 Self { columns: None }
403 }
404
405 pub fn with_columns(columns: Vec<String>) -> Self {
406 Self {
407 columns: Some(columns),
408 }
409 }
410
411 pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
413 Ok(())
414 }
415
416 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
418 let target_columns = if let Some(ref cols) = self.columns {
419 cols.clone()
420 } else {
421 df.schema()
422 .fields()
423 .iter()
424 .map(|f| f.name().to_string())
425 .collect()
426 };
427 let predicates: Vec<Expr> = target_columns
428 .iter()
429 .map(|col_name| col(col_name).is_not_null())
430 .collect();
431 let combined = predicates
432 .into_iter()
433 .reduce(|acc, expr| acc.and(expr))
434 .unwrap();
435 df.filter(combined)
436 .map_err(crate::exceptions::FeatureFactoryError::from)
437 }
438
439 fn inherent_is_stateful(&self) -> bool {
441 false
442 }
443}
444
445impl Default for DropMissingData {
446 fn default() -> Self {
447 Self::new()
448 }
449}
450
451impl_transformer!(MeanMedianImputer);
453impl_transformer!(ArbitraryNumberImputer);
454impl_transformer!(EndTailImputer);
455impl_transformer!(CategoricalImputer);
456impl_transformer!(AddMissingIndicator);
457impl_transformer!(DropMissingData);