feature_factory/transformers/
discretization.rs1use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
16use crate::impl_transformer;
17use datafusion::dataframe::DataFrame;
18use datafusion::functions_aggregate::expr_fn::approx_percentile_cont;
19use datafusion::logical_expr::{col, lit, Case as DFCase, Expr};
20use datafusion::scalar::ScalarValue;
21use std::collections::HashMap;
22
23fn validate_numeric_column(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<()> {
25 let field = df.schema().field_with_name(None, col_name).map_err(|_| {
26 FeatureFactoryError::MissingColumn(format!("Column '{}' not found", col_name))
27 })?;
28 match field.data_type() {
29 datafusion::arrow::datatypes::DataType::Float64
30 | datafusion::arrow::datatypes::DataType::Int64 => Ok(()),
31 dt => Err(FeatureFactoryError::InvalidParameter(format!(
32 "Column '{}' must be numeric (Float64 or Int64), but found {:?}",
33 col_name, dt
34 ))),
35 }
36}
37
38fn build_interval_case_expr(col_name: &str, intervals: &[(f64, f64, String)]) -> Expr {
46 let n = intervals.len();
47 let when_then_expr = intervals
48 .iter()
49 .enumerate()
50 .map(|(i, (lower, upper, label))| {
51 let condition = if i == n - 1 {
52 col(col_name)
53 .gt_eq(lit(*lower))
54 .and(col(col_name).lt_eq(lit(*upper)))
55 } else {
56 col(col_name)
57 .gt_eq(lit(*lower))
58 .and(col(col_name).lt(lit(*upper)))
59 };
60 (Box::new(condition), Box::new(lit(label.clone())))
61 })
62 .collect::<Vec<_>>();
63 Expr::Case(DFCase {
64 expr: None,
65 when_then_expr,
66 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
67 })
68}
69
70fn apply_interval_mapping(
74 df: DataFrame,
75 target_cols: &[String],
76 mapping: &HashMap<String, Vec<(f64, f64, String)>>,
77) -> FeatureFactoryResult<DataFrame> {
78 let exprs: Vec<Expr> = df
79 .schema()
80 .fields()
81 .iter()
82 .map(|field| {
83 let name = field.name();
84 if target_cols.contains(name) {
85 if let Some(intervals) = mapping.get(name) {
86 build_interval_case_expr(name, intervals).alias(name)
87 } else {
88 col(name)
89 }
90 } else {
91 col(name)
92 }
93 })
94 .collect();
95 df.select(exprs).map_err(FeatureFactoryError::from)
96}
97
98async fn compute_min_max(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<(f64, f64)> {
101 validate_numeric_column(df, col_name)?;
102 let min_df = df
103 .clone()
104 .aggregate(
105 vec![],
106 vec![approx_percentile_cont(col(col_name), lit(0.0), None).alias("min")],
107 )
108 .map_err(FeatureFactoryError::from)?;
109 let min_batches = min_df.collect().await.map_err(FeatureFactoryError::from)?;
110 let min_val = if let Some(batch) = min_batches.first() {
111 let array = batch.column(0);
112 let scalar = ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
113 if let ScalarValue::Float64(Some(val)) = scalar {
114 val
115 } else {
116 return Err(FeatureFactoryError::DataFusionError(
117 datafusion::error::DataFusionError::Plan(format!(
118 "Failed to compute min for column {}",
119 col_name
120 )),
121 ));
122 }
123 } else {
124 return Err(FeatureFactoryError::DataFusionError(
125 datafusion::error::DataFusionError::Plan("No data found".to_string()),
126 ));
127 };
128
129 let max_df = df
130 .clone()
131 .aggregate(
132 vec![],
133 vec![approx_percentile_cont(col(col_name), lit(1.0), None).alias("max")],
134 )
135 .map_err(FeatureFactoryError::from)?;
136 let max_batches = max_df.collect().await.map_err(FeatureFactoryError::from)?;
137 let max_val = if let Some(batch) = max_batches.first() {
138 let array = batch.column(0);
139 let scalar = ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
140 if let ScalarValue::Float64(Some(val)) = scalar {
141 val
142 } else {
143 return Err(FeatureFactoryError::DataFusionError(
144 datafusion::error::DataFusionError::Plan(format!(
145 "Failed to compute max for column {}",
146 col_name
147 )),
148 ));
149 }
150 } else {
151 return Err(FeatureFactoryError::DataFusionError(
152 datafusion::error::DataFusionError::Plan("No data found".to_string()),
153 ));
154 };
155
156 Ok((min_val, max_val))
157}
158
159pub struct ArbitraryDiscretizer {
161 pub columns: Vec<String>,
162 pub intervals: HashMap<String, Vec<(f64, f64, String)>>,
163}
164
165impl ArbitraryDiscretizer {
166 pub fn new(columns: Vec<String>, intervals: HashMap<String, Vec<(f64, f64, String)>>) -> Self {
167 Self { columns, intervals }
168 }
169
170 pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
172 Ok(())
173 }
174
175 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
178 for col_name in &self.columns {
179 validate_numeric_column(&df, col_name)?;
180 }
181 for (col, intervals) in &self.intervals {
182 for (lower, upper, _) in intervals {
183 if lower >= upper {
184 return Err(FeatureFactoryError::InvalidParameter(format!(
185 "For column '{}', lower bound {} is not less than upper bound {}",
186 col, lower, upper
187 )));
188 }
189 }
190 }
191 apply_interval_mapping(df, &self.columns, &self.intervals)
192 }
193
194 fn inherent_is_stateful(&self) -> bool {
196 false
197 }
198}
199
200pub struct EqualFrequencyDiscretizer {
202 pub columns: Vec<String>,
203 pub bins: usize,
204 pub mapping: HashMap<String, Vec<(f64, f64, String)>>,
205 fitted: bool,
206}
207
208impl EqualFrequencyDiscretizer {
209 pub fn new(columns: Vec<String>, bins: usize) -> Self {
210 Self {
211 columns,
212 bins,
213 mapping: HashMap::new(),
214 fitted: false,
215 }
216 }
217
218 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
220 if self.bins < 1 {
221 return Err(FeatureFactoryError::InvalidParameter(
222 "Number of bins must be at least 1".to_string(),
223 ));
224 }
225 for col_name in &self.columns {
226 validate_numeric_column(df, col_name)?;
227 let mut boundaries = Vec::with_capacity(self.bins + 1);
228 for i in 0..=self.bins {
229 let p = i as f64 / self.bins as f64;
230 let agg_df = df
231 .clone()
232 .aggregate(
233 vec![],
234 vec![approx_percentile_cont(col(col_name), lit(p), None).alias("q")],
235 )
236 .map_err(FeatureFactoryError::from)?;
237 let batches = agg_df.collect().await.map_err(FeatureFactoryError::from)?;
238 if let Some(batch) = batches.first() {
239 let array = batch.column(0);
240 let scalar =
241 ScalarValue::try_from_array(array, 0).map_err(FeatureFactoryError::from)?;
242 if let ScalarValue::Float64(Some(val)) = scalar {
243 boundaries.push(val);
244 } else {
245 return Err(FeatureFactoryError::DataFusionError(
246 datafusion::error::DataFusionError::Plan(format!(
247 "Failed to compute percentile for column {}",
248 col_name
249 )),
250 ));
251 }
252 }
253 }
254 if let (Some(first), Some(last)) = (boundaries.first(), boundaries.last()) {
255 if (first - last).abs() < 1e-6 {
256 return Err(FeatureFactoryError::InvalidParameter(format!(
257 "Column {} appears to be constant; cannot discretize into equal-frequency bins",
258 col_name
259 )));
260 }
261 }
262 let intervals = boundaries
263 .windows(2)
264 .map(|pair| {
265 let lower = pair[0];
266 let upper = pair[1];
267 let label = format!("[{:.2}, {:.2})", lower, upper);
268 (lower, upper, label)
269 })
270 .collect::<Vec<_>>();
271 self.mapping.insert(col_name.clone(), intervals);
272 }
273 self.fitted = true;
274 Ok(())
275 }
276
277 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
279 if !self.fitted {
280 return Err(FeatureFactoryError::FitNotCalled);
281 }
282 apply_interval_mapping(df, &self.columns, &self.mapping)
283 }
284
285 fn inherent_is_stateful(&self) -> bool {
287 true
288 }
289}
290
291pub struct EqualWidthDiscretizer {
293 pub columns: Vec<String>,
294 pub bins: usize,
295 pub mapping: HashMap<String, Vec<(f64, f64, String)>>,
296 fitted: bool,
297}
298
299impl EqualWidthDiscretizer {
300 pub fn new(columns: Vec<String>, bins: usize) -> Self {
301 Self {
302 columns,
303 bins,
304 mapping: HashMap::new(),
305 fitted: false,
306 }
307 }
308
309 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
311 if self.bins < 1 {
312 return Err(FeatureFactoryError::InvalidParameter(
313 "Number of bins must be at least 1".to_string(),
314 ));
315 }
316 for col_name in &self.columns {
317 validate_numeric_column(df, col_name)?;
318 let (min_val, max_val) = compute_min_max(df, col_name).await?;
319 if (max_val - min_val).abs() < 1e-6 {
320 return Err(FeatureFactoryError::InvalidParameter(format!(
321 "Column {} is constant (min == max), cannot discretize into equal-width bins",
322 col_name
323 )));
324 }
325 let width = (max_val - min_val) / self.bins as f64;
326 let intervals = (0..self.bins)
327 .map(|i| {
328 let lower = min_val + i as f64 * width;
329 let upper = if i == self.bins - 1 {
330 max_val
331 } else {
332 min_val + (i as f64 + 1.0) * width
333 };
334 let label = format!("[{:.2}, {:.2})", lower, upper);
335 (lower, upper, label)
336 })
337 .collect::<Vec<_>>();
338 self.mapping.insert(col_name.clone(), intervals);
339 }
340 self.fitted = true;
341 Ok(())
342 }
343
344 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
346 if !self.fitted {
347 return Err(FeatureFactoryError::FitNotCalled);
348 }
349 apply_interval_mapping(df, &self.columns, &self.mapping)
350 }
351
352 fn inherent_is_stateful(&self) -> bool {
354 true
355 }
356}
357
358pub struct GeometricWidthDiscretizer {
360 pub columns: Vec<String>,
361 pub bins: usize,
362 pub mapping: HashMap<String, Vec<(f64, f64, String)>>,
363 fitted: bool,
364}
365
366impl GeometricWidthDiscretizer {
367 pub fn new(columns: Vec<String>, bins: usize) -> Self {
368 Self {
369 columns,
370 bins,
371 mapping: HashMap::new(),
372 fitted: false,
373 }
374 }
375
376 pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
379 if self.bins < 1 {
380 return Err(FeatureFactoryError::InvalidParameter(
381 "Number of bins must be at least 1".to_string(),
382 ));
383 }
384 for col_name in &self.columns {
385 validate_numeric_column(df, col_name)?;
386 let (min_val, max_val) = compute_min_max(df, col_name).await?;
387 if min_val <= 0.0 {
388 return Err(FeatureFactoryError::DataFusionError(
389 datafusion::error::DataFusionError::Plan(format!(
390 "Column {} has non-positive values, cannot apply geometric discretization",
391 col_name
392 )),
393 ));
394 }
395 let ratio = (max_val / min_val).powf(1.0 / self.bins as f64);
396 let intervals = (0..self.bins)
397 .map(|i| {
398 let lower = min_val * ratio.powi(i as i32);
399 let upper = if i == self.bins - 1 {
400 max_val
401 } else {
402 min_val * ratio.powi((i + 1) as i32)
403 };
404 let label = format!("[{:.2}, {:.2})", lower, upper);
405 (lower, upper, label)
406 })
407 .collect::<Vec<_>>();
408 self.mapping.insert(col_name.clone(), intervals);
409 }
410 self.fitted = true;
411 Ok(())
412 }
413
414 pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
416 if !self.fitted {
417 return Err(FeatureFactoryError::FitNotCalled);
418 }
419 apply_interval_mapping(df, &self.columns, &self.mapping)
420 }
421
422 fn inherent_is_stateful(&self) -> bool {
424 true
425 }
426}
427
// Register each discretizer with the crate's transformer machinery. `impl_transformer!` is
// imported from the crate root (`crate::impl_transformer`). Presumably it implements the shared
// transformer trait by delegating to each type's inherent `fit`, `transform`, and
// `inherent_is_stateful` methods; confirm against the macro definition.
impl_transformer!(ArbitraryDiscretizer);
impl_transformer!(EqualFrequencyDiscretizer);
impl_transformer!(EqualWidthDiscretizer);
impl_transformer!(GeometricWidthDiscretizer);