use crate::exceptions::{FeatureFactoryError, FeatureFactoryResult};
use crate::impl_transformer;
use datafusion::arrow::datatypes::DataType;
use datafusion::dataframe::DataFrame;
use datafusion_expr::{col, ident, lit, Expr};
use datafusion_functions::datetime::{date_part, to_unixtime};
use std::ops::{Div, Sub};
/// Checks that `col_name` exists in `df` and holds datetime values.
///
/// The column is looked up by name with no qualifier. Accepted types are
/// `Timestamp` (any unit / time zone), `Date32`, and `Date64`.
///
/// # Errors
///
/// - [`FeatureFactoryError::MissingColumn`] if the schema lookup for `col_name` fails.
/// - [`FeatureFactoryError::InvalidParameter`] if the column has any other data type.
fn validate_datetime_column(df: &DataFrame, col_name: &str) -> FeatureFactoryResult<()> {
    let Ok(field) = df.schema().field_with_name(None, col_name) else {
        return Err(FeatureFactoryError::MissingColumn(format!(
            "Column '{}' not found",
            col_name
        )));
    };

    let found = field.data_type();
    if matches!(
        found,
        DataType::Timestamp(_, _) | DataType::Date32 | DataType::Date64
    ) {
        return Ok(());
    }

    Err(FeatureFactoryError::InvalidParameter(format!(
        "Column '{}' must be a datetime type (Timestamp, Date32, or Date64), but found {:?}",
        col_name, found
    )))
}
pub struct DatetimeFeatures {
pub columns: Vec<String>,
}
impl DatetimeFeatures {
pub fn new(columns: Vec<String>) -> Self {
Self { columns }
}
pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
Ok(())
}
pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
for col_name in &self.columns {
validate_datetime_column(&df, col_name)?;
}
let mut exprs: Vec<Expr> = df.schema().fields().iter().map(|f| col(f.name())).collect();
for col_name in &self.columns {
let base = col(col_name);
let year_expr = date_part()
.call(vec![lit("year"), base.clone()])
.alias(format!("{}_year", col_name));
let month_expr = date_part()
.call(vec![lit("month"), base.clone()])
.alias(format!("{}_month", col_name));
let day_expr = date_part()
.call(vec![lit("day"), base.clone()])
.alias(format!("{}_day", col_name));
let hour_expr = date_part()
.call(vec![lit("hour"), base.clone()])
.alias(format!("{}_hour", col_name));
let minute_expr = date_part()
.call(vec![lit("minute"), base.clone()])
.alias(format!("{}_minute", col_name));
let second_expr = date_part()
.call(vec![lit("second"), base.clone()])
.alias(format!("{}_second", col_name));
let weekday_expr = date_part()
.call(vec![lit("dow"), base.clone()])
.alias(format!("{}_weekday", col_name));
exprs.push(year_expr);
exprs.push(month_expr);
exprs.push(day_expr);
exprs.push(hour_expr);
exprs.push(minute_expr);
exprs.push(second_expr);
exprs.push(weekday_expr);
}
df.select(exprs)
.map_err(FeatureFactoryError::DataFusionError)
}
fn inherent_is_stateful(&self) -> bool {
false
}
}
/// Unit in which [`DatetimeSubtraction`] reports a datetime difference.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimeUnit {
    Second,
    Minute,
    Hour,
    Day,
}

impl TimeUnit {
    /// Returns the lowercase unit name (`"second"`, `"minute"`, `"hour"`, `"day"`).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            TimeUnit::Second => "second",
            TimeUnit::Minute => "minute",
            TimeUnit::Hour => "hour",
            TimeUnit::Day => "day",
        }
    }
}
/// Builds an expression for `left - right` expressed in `unit`.
///
/// Both operands are converted with `to_unixtime` and subtracted. For `"minute"`,
/// `"hour"`, and `"day"` the difference is divided by 60, 3600, or 86400
/// respectively. For `"second"`, and for any unrecognised unit string, the raw
/// difference is returned undivided.
///
/// NOTE(review): `to_unixtime` appears to produce whole seconds, so sub-second
/// precision is presumably truncated — confirm against the DataFusion docs.
fn timestamp_diff_expr(left: Expr, right: Expr, unit: &str) -> Expr {
    let seconds = to_unixtime()
        .call(vec![left])
        .sub(to_unixtime().call(vec![right]));

    let seconds_per_unit = match unit {
        "minute" => 60.0,
        "hour" => 3600.0,
        "day" => 86400.0,
        // "second" and unknown units: no scaling.
        _ => return seconds,
    };

    seconds.div(lit(seconds_per_unit))
}
pub struct DatetimeSubtraction {
pub new_features: Vec<(String, String, String, TimeUnit)>,
}
impl DatetimeSubtraction {
pub fn new(new_features: Vec<(String, String, String, TimeUnit)>) -> Self {
Self { new_features }
}
pub async fn fit(&mut self, _df: &DataFrame) -> FeatureFactoryResult<()> {
Ok(())
}
pub fn transform(&self, df: DataFrame) -> FeatureFactoryResult<DataFrame> {
for (_, left, right, _) in &self.new_features {
validate_datetime_column(&df, left)?;
validate_datetime_column(&df, right)?;
}
let mut exprs: Vec<Expr> = df.schema().fields().iter().map(|f| col(f.name())).collect();
for (new_name, left, right, unit) in &self.new_features {
let diff_expr =
timestamp_diff_expr(col(left), col(right), unit.as_str()).alias(new_name);
exprs.push(diff_expr);
}
df.select(exprs)
.map_err(FeatureFactoryError::DataFusionError)
}
fn inherent_is_stateful(&self) -> bool {
false
}
}
// NOTE(review): `impl_transformer!` is defined elsewhere in the crate. It presumably
// wires each type's inherent `fit` / `transform` / `inherent_is_stateful` methods into
// the crate's transformer interface — confirm against the macro definition.
impl_transformer!(DatetimeFeatures);
impl_transformer!(DatetimeSubtraction);