use datafusion::common::ScalarValue;
use datafusion::logical_expr::expr::NullTreatment;
use datafusion::logical_expr::expr_fn::ExprFunctionExt;
use datafusion::logical_expr::Expr as DFExpr;
use datafusion_functions_window::expr_fn as window_fn;
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
use datafusion_functions_window::nth_value::nth_value_udwf;
use hamelin_lib::func::defs::{
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, PercentRank, Rank, RowNumber,
};
use super::DataFusionTranslationRegistry;
/// Returns the value of `expr` when it is a non-null boolean literal.
///
/// Every other expression yields `None`. This includes column references, compound
/// expressions, literals of other types, and a boolean literal whose value is NULL.
fn extract_bool_literal(expr: &DFExpr) -> Option<bool> {
    if let DFExpr::Literal(ScalarValue::Boolean(flag), _) = expr {
        // `flag` is `&Option<bool>`; a NULL boolean literal therefore maps to `None`.
        *flag
    } else {
        None
    }
}
/// Converts an `ignore_nulls` argument into a DataFusion [`NullTreatment`].
///
/// `true` maps to `IgnoreNulls` and `false` maps to `RespectNulls`. Anything that is not a
/// non-null boolean literal yields `None`, meaning no explicit treatment is requested.
fn get_null_treatment(ignore_nulls: &DFExpr) -> Option<NullTreatment> {
    let ignore = extract_bool_literal(ignore_nulls)?;
    let treatment = if ignore {
        NullTreatment::IgnoreNulls
    } else {
        NullTreatment::RespectNulls
    };
    Some(treatment)
}
pub fn register(registry: &mut DataFusionTranslationRegistry) {
registry.register::<RowNumber>(|_params| Ok(window_fn::row_number()));
registry.register::<Rank>(|_params| Ok(window_fn::rank()));
registry.register::<DenseRank>(|_params| Ok(window_fn::dense_rank()));
registry.register::<Lag>(|mut params| {
let expression = params.take()?.expr;
let offset = params.take()?.expr;
let ignore_nulls = params.take()?.expr;
let default = datafusion::logical_expr::lit(ScalarValue::Null);
let base = lag_udwf().call(vec![expression, offset, default]);
match get_null_treatment(&ignore_nulls) {
Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
None => Ok(base),
}
});
registry.register::<Lead>(|mut params| {
let expression = params.take()?.expr;
let offset = params.take()?.expr;
let ignore_nulls = params.take()?.expr;
let default = datafusion::logical_expr::lit(ScalarValue::Null);
let base = lead_udwf().call(vec![expression, offset, default]);
match get_null_treatment(&ignore_nulls) {
Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
None => Ok(base),
}
});
registry.register::<FirstValue>(|mut params| {
let expression = params.take()?.expr;
let ignore_nulls = params.take()?.expr;
let base = window_fn::first_value(expression);
match get_null_treatment(&ignore_nulls) {
Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
None => Ok(base),
}
});
registry.register::<LastValue>(|mut params| {
let expression = params.take()?.expr;
let ignore_nulls = params.take()?.expr;
let base = window_fn::last_value(expression);
match get_null_treatment(&ignore_nulls) {
Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
None => Ok(base),
}
});
registry.register::<NthValue>(|mut params| {
let expression = params.take()?.expr;
let n = params.take()?.expr;
let ignore_nulls = params.take()?.expr;
let base = nth_value_udwf().call(vec![expression, n]);
match get_null_treatment(&ignore_nulls) {
Some(nt) => base.null_treatment(Some(nt)).build().map_err(Into::into),
None => Ok(base),
}
});
registry.register::<CumeDist>(|_params| Ok(window_fn::cume_dist()));
registry.register::<PercentRank>(|_params| Ok(window_fn::percent_rank()));
}