use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
use crate::impl_transformer;
use datafusion::dataframe::DataFrame;
use datafusion::functions_aggregate::expr_fn::approx_percentile_cont;
use datafusion::logical_expr::{col, lit, Case as DFCase, Expr};
use datafusion::scalar::ScalarValue;
use std::collections::HashMap;
/// Builds an expression that clamps column `col_name` to the optional
/// `lower` and `upper` bounds using a SQL CASE expression.
///
/// When neither bound is supplied the bare column reference is returned
/// unchanged; otherwise each supplied bound contributes one WHEN/THEN
/// branch and the ELSE branch passes the original value through.
fn cap_expr_for(col_name: &str, lower: Option<f64>, upper: Option<f64>) -> Expr {
    let base = col(col_name);
    // Accumulate one branch per bound actually provided.
    let mut branches: Vec<(Box<Expr>, Box<Expr>)> = Vec::with_capacity(2);
    if let Some(l) = lower {
        branches.push((Box::new(base.clone().lt(lit(l))), Box::new(lit(l))));
    }
    if let Some(u) = upper {
        branches.push((Box::new(base.clone().gt(lit(u))), Box::new(lit(u))));
    }
    if branches.is_empty() {
        return base;
    }
    Expr::Case(DFCase {
        expr: None,
        when_then_expr: branches,
        else_expr: Some(Box::new(base)),
    })
}
async fn compute_percentiles_for_column(
df: &DataFrame,
col_name: &str,
lower_percentile: f64,
upper_percentile: f64,
) -> FeatureFactoryResult<(f64, f64)> {
if !(0.0..=1.0).contains(&lower_percentile) {
return Err(FeatureFactoryError::InvalidParameter(format!(
"lower_percentile {} must be between 0 and 1",
lower_percentile
)));
}
if !(0.0..=1.0).contains(&upper_percentile) {
return Err(FeatureFactoryError::InvalidParameter(format!(
"upper_percentile {} must be between 0 and 1",
upper_percentile
)));
}
if lower_percentile >= upper_percentile {
return Err(FeatureFactoryError::InvalidParameter(format!(
"lower_percentile {} must be less than upper_percentile {}",
lower_percentile, upper_percentile
)));
}
let lower_df = df
.clone()
.aggregate(
vec![],
vec![approx_percentile_cont(col(col_name), lit(lower_percentile), None).alias("lower")],
)
.map_err(FeatureFactoryError::DataFusionError)?;
let lower_batches = lower_df
.collect()
.await
.map_err(FeatureFactoryError::DataFusionError)?;
let lower = if let Some(batch) = lower_batches.first() {
let array = batch.column(0);
let scalar = ScalarValue::try_from_array(array, 0).map_err(|e| {
FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(format!(
"Error converting lower percentile for column {}: {}",
col_name, e
)))
})?;
if let ScalarValue::Float64(Some(val)) = scalar {
val
} else {
return Err(FeatureFactoryError::DataFusionError(
datafusion::error::DataFusionError::Plan(format!(
"Failed to compute lower percentile for column {}",
col_name
)),
));
}
} else {
return Err(FeatureFactoryError::DataFusionError(
datafusion::error::DataFusionError::Plan(format!(
"No data found when computing lower percentile for column {}",
col_name
)),
));
};
let upper_df = df
.clone()
.aggregate(
vec![],
vec![approx_percentile_cont(col(col_name), lit(upper_percentile), None).alias("upper")],
)
.map_err(FeatureFactoryError::DataFusionError)?;
let upper_batches = upper_df
.collect()
.await
.map_err(FeatureFactoryError::DataFusionError)?;
let upper = if let Some(batch) = upper_batches.first() {
let array = batch.column(0);
let scalar = ScalarValue::try_from_array(array, 0).map_err(|e| {
FeatureFactoryError::DataFusionError(datafusion::error::DataFusionError::Plan(format!(
"Error converting upper percentile for column {}: {}",
col_name, e
)))
})?;
if let ScalarValue::Float64(Some(val)) = scalar {
val
} else {
return Err(FeatureFactoryError::DataFusionError(
datafusion::error::DataFusionError::Plan(format!(
"Failed to compute upper percentile for column {}",
col_name
)),
));
}
} else {
return Err(FeatureFactoryError::DataFusionError(
datafusion::error::DataFusionError::Plan(format!(
"No data found when computing upper percentile for column {}",
col_name
)),
));
};
Ok((lower, upper))
}
/// Caps outliers in selected columns using user-supplied, per-column bounds.
///
/// Stateless: the caps are provided at construction time rather than learned
/// from data, so `fit` is a no-op.
pub struct ArbitraryOutlierCapper {
    /// Names of the columns to cap; other columns pass through unchanged.
    pub columns: Vec<String>,
    /// Per-column lower bound; values below it are replaced by the bound.
    pub lower_caps: HashMap<String, f64>,
    /// Per-column upper bound; values above it are replaced by the bound.
    pub upper_caps: HashMap<String, f64>,
}
impl ArbitraryOutlierCapper {
    /// Creates a capper from explicit per-column lower/upper caps.
    pub fn new(
        columns: Vec<String>,
        lower_caps: HashMap<String, f64>,
        upper_caps: HashMap<String, f64>,
    ) -> Self {
        Self {
            columns,
            lower_caps,
            upper_caps,
        }
    }

    /// No-op: the caps are supplied at construction, nothing is learned.
    pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
        Ok(())
    }

    /// Applies the configured caps to each target column, leaving all other
    /// columns untouched. A listed column with no cap entry passes through.
    ///
    /// # Errors
    /// Propagates any DataFusion error raised by the projection.
    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
        let exprs: Vec<Expr> = df
            .schema()
            .fields()
            .iter()
            .map(|field| {
                let name = field.name();
                // Compare without allocating a fresh String per field
                // (`contains(&name.to_string())` allocated on every column).
                if self.columns.iter().any(|c| c == name) {
                    let lower = self.lower_caps.get(name).copied();
                    let upper = self.upper_caps.get(name).copied();
                    cap_expr_for(name, lower, upper).alias(name)
                } else {
                    col(name)
                }
            })
            .collect();
        df.select(exprs)
            .map_err(FeatureFactoryError::DataFusionError)
    }

    /// This transformer carries no fitted state.
    fn inherent_is_stateful(&self) -> bool {
        false
    }
}
/// Caps (winsorizes) each target column at percentile-derived thresholds
/// learned from data during `fit`.
pub struct Winsorizer {
    /// Columns to winsorize; other columns pass through unchanged.
    pub columns: Vec<String>,
    /// Lower percentile in [0, 1] used to learn each lower threshold.
    pub lower_percentile: f64,
    /// Upper percentile in [0, 1] used to learn each upper threshold.
    pub upper_percentile: f64,
    /// Learned per-column (lower, upper) thresholds, populated by `fit`.
    pub thresholds: HashMap<String, (f64, f64)>,
    // Set once `fit` completes; `transform` refuses to run before then.
    fitted: bool,
}
impl Winsorizer {
    /// Creates an unfitted winsorizer for the given columns and percentile
    /// bounds; call `fit` before `transform`.
    pub fn new(columns: Vec<String>, lower_percentile: f64, upper_percentile: f64) -> Self {
        Self {
            columns,
            lower_percentile,
            upper_percentile,
            thresholds: HashMap::new(),
            fitted: false,
        }
    }

    /// Validates the percentile bounds, then learns and stores the per-column
    /// (lower, upper) thresholds from `df`.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when a percentile lies outside [0, 1] or
    /// the bounds are not strictly ordered; propagates DataFusion errors
    /// from the percentile computation.
    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
        // Same range checks as `compute_percentiles_for_column`, kept here so
        // invalid parameters are reported even when `columns` is empty.
        if !(0.0..=1.0).contains(&self.lower_percentile) {
            return Err(FeatureFactoryError::InvalidParameter(format!(
                "lower_percentile {} must be between 0 and 1",
                self.lower_percentile
            )));
        }
        if !(0.0..=1.0).contains(&self.upper_percentile) {
            return Err(FeatureFactoryError::InvalidParameter(format!(
                "upper_percentile {} must be between 0 and 1",
                self.upper_percentile
            )));
        }
        if self.lower_percentile >= self.upper_percentile {
            return Err(FeatureFactoryError::InvalidParameter(format!(
                "lower_percentile {} must be less than upper_percentile {}",
                self.lower_percentile, self.upper_percentile
            )));
        }
        for col_name in &self.columns {
            let (lower, upper) = compute_percentiles_for_column(
                df,
                col_name,
                self.lower_percentile,
                self.upper_percentile,
            )
            .await?;
            self.thresholds.insert(col_name.clone(), (lower, upper));
        }
        self.fitted = true;
        Ok(())
    }

    /// Caps each fitted target column to its learned [lower, upper] range;
    /// unlisted columns (or listed columns with no learned threshold) pass
    /// through unchanged.
    ///
    /// # Errors
    /// Returns `FitNotCalled` when `fit` has not completed successfully;
    /// propagates any DataFusion error raised by the projection.
    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
        if !self.fitted {
            return Err(FeatureFactoryError::FitNotCalled);
        }
        let exprs: Vec<Expr> = df
            .schema()
            .fields()
            .iter()
            .map(|field| {
                let name = field.name();
                // Compare without allocating a fresh String per field
                // (`contains(&name.to_string())` allocated on every column).
                if self.columns.iter().any(|c| c == name) {
                    if let Some(&(lower, upper)) = self.thresholds.get(name) {
                        return cap_expr_for(name, Some(lower), Some(upper)).alias(name);
                    }
                }
                col(name)
            })
            .collect();
        df.select(exprs)
            .map_err(FeatureFactoryError::DataFusionError)
    }

    /// Stateful: thresholds are learned during `fit`.
    fn inherent_is_stateful(&self) -> bool {
        true
    }
}
/// Removes rows whose value in any target column falls outside the
/// percentile-derived thresholds learned during `fit`.
pub struct OutlierTrimmer {
    /// Columns whose values are checked against the learned thresholds.
    pub columns: Vec<String>,
    /// Lower percentile in [0, 1] used to learn each lower threshold.
    pub lower_percentile: f64,
    /// Upper percentile in [0, 1] used to learn each upper threshold.
    pub upper_percentile: f64,
    /// Learned per-column (lower, upper) thresholds, populated by `fit`.
    pub thresholds: HashMap<String, (f64, f64)>,
    // Set once `fit` completes; `transform` refuses to run before then.
    fitted: bool,
}
impl OutlierTrimmer {
    /// Creates an unfitted trimmer for the given columns and percentile
    /// bounds; call `fit` before `transform`.
    pub fn new(columns: Vec<String>, lower_percentile: f64, upper_percentile: f64) -> Self {
        Self {
            columns,
            lower_percentile,
            upper_percentile,
            thresholds: HashMap::new(),
            fitted: false,
        }
    }

    /// Validates the percentile bounds, then learns and stores the per-column
    /// (lower, upper) thresholds from `df`.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when a percentile lies outside [0, 1] or
    /// the bounds are not strictly ordered; propagates DataFusion errors
    /// from the percentile computation.
    pub async fn fit(&mut self, df: &DataFrame) -> FeatureFactoryResult<()> {
        // Same range checks as `compute_percentiles_for_column`, kept here so
        // invalid parameters are reported even when `columns` is empty.
        if !(0.0..=1.0).contains(&self.lower_percentile) {
            return Err(FeatureFactoryError::InvalidParameter(format!(
                "lower_percentile {} must be between 0 and 1",
                self.lower_percentile
            )));
        }
        if !(0.0..=1.0).contains(&self.upper_percentile) {
            return Err(FeatureFactoryError::InvalidParameter(format!(
                "upper_percentile {} must be between 0 and 1",
                self.upper_percentile
            )));
        }
        if self.lower_percentile >= self.upper_percentile {
            return Err(FeatureFactoryError::InvalidParameter(format!(
                "lower_percentile {} must be less than upper_percentile {}",
                self.lower_percentile, self.upper_percentile
            )));
        }
        for col_name in &self.columns {
            let (lower, upper) = compute_percentiles_for_column(
                df,
                col_name,
                self.lower_percentile,
                self.upper_percentile,
            )
            .await?;
            self.thresholds.insert(col_name.clone(), (lower, upper));
        }
        self.fitted = true;
        Ok(())
    }

    /// Filters out rows where any fitted target column lies outside its
    /// learned [lower, upper] range. With no fitted thresholds the input is
    /// returned unchanged.
    ///
    /// # Errors
    /// Returns `FitNotCalled` when `fit` has not completed successfully;
    /// propagates any DataFusion error raised by the filter.
    pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
        if !self.fitted {
            return Err(FeatureFactoryError::FitNotCalled);
        }
        // One inclusive range predicate per fitted target column.
        let predicates: Vec<Expr> = df
            .schema()
            .fields()
            .iter()
            .filter_map(|field| {
                let name = field.name();
                // Compare without allocating a fresh String per field
                // (`contains(&name.to_string())` allocated on every column).
                if self.columns.iter().any(|c| c == name) {
                    self.thresholds.get(name).map(|&(lower, upper)| {
                        col(name).gt_eq(lit(lower)).and(col(name).lt_eq(lit(upper)))
                    })
                } else {
                    None
                }
            })
            .collect();
        // `reduce` returns None exactly when there are no predicates, so the
        // old post-`is_empty` "Failed to combine predicates" error branch was
        // unreachable; fold the two checks into one match instead.
        let combined = match predicates.into_iter().reduce(|acc, expr| acc.and(expr)) {
            Some(pred) => pred,
            None => return Ok(df),
        };
        df.filter(combined)
            .map_err(FeatureFactoryError::DataFusionError)
    }

    /// Stateful: thresholds are learned during `fit`.
    fn inherent_is_stateful(&self) -> bool {
        true
    }
}
// NOTE(review): `impl_transformer!` is defined elsewhere in the crate;
// presumably it generates the shared transformer-trait impl that delegates
// to the inherent `fit`/`transform`/`inherent_is_stateful` methods above —
// confirm against its definition in `crate`.
impl_transformer!(ArbitraryOutlierCapper);
impl_transformer!(Winsorizer);
impl_transformer!(OutlierTrimmer);