use datafusion::functions_aggregate::expr_fn as agg_fn;
use datafusion::logical_expr::expr::Case as DFCase;
use datafusion::logical_expr::{lit, Expr as DFExpr};
use hamelin_lib::func::defs::{
AggAll, AggAny, AggMax, AggMin, AnyValue, ApproxDistinct, ApproxPercentile, ArrayAgg, Avg,
CountAny, CountDistinct, CountIf, CountStar, MapAgg, MultimapAgg, Stddev, Sum,
};
use super::DataFusionTranslationRegistry;
/// Registers DataFusion translations for Hamelin's aggregate functions.
///
/// Each registration maps a Hamelin aggregate definition (from `hamelin_lib::func::defs`) to a
/// closure that builds the equivalent DataFusion aggregate expression. Arguments are pulled
/// from `params` with `take()` in the order each closure consumes them, and any error returned
/// by `take()` is propagated with `?`.
///
/// Most aggregates map directly onto DataFusion's built-in `functions_aggregate::expr_fn`
/// helpers; `ArrayAgg`, `MapAgg`, `MultimapAgg` and `AnyValue` instead call UDAFs provided by
/// `crate::udf`, and `ApproxPercentile` calls DataFusion's `approx_percentile_cont` UDAF.
pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // COUNT(*): count a non-null literal, so every input row is counted.
    registry.register::<CountStar>(|_params| Ok(agg_fn::count(lit(1))));
    registry.register::<CountAny>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::count(x))
    });
    registry.register::<CountDistinct>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::count_distinct(x))
    });
    registry.register::<Sum>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::sum(x))
    });
    registry.register::<Avg>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::avg(x))
    });
    registry.register::<AggMin>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::min(x))
    });
    registry.register::<AggMax>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::max(x))
    });
    registry.register::<ArrayAgg>(|mut params| {
        let x = params.take()?.expr;
        // Crate-local UDAF rather than DataFusion's array_agg — presumably so it supports
        // sliding window frames, as the name suggests. TODO confirm.
        Ok(crate::udf::sliding_array_agg_udaf().call(vec![x]))
    });
    registry.register::<Stddev>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::stddev(x))
    });
    registry.register::<MapAgg>(|mut params| {
        // Argument order: key first, then value.
        let key = params.take()?.expr;
        let value = params.take()?.expr;
        Ok(crate::udf::map_agg_udaf().call(vec![key, value]))
    });
    registry.register::<MultimapAgg>(|mut params| {
        // Argument order: key first, then value.
        let key = params.take()?.expr;
        let value = params.take()?.expr;
        Ok(crate::udf::multimap_agg_udaf().call(vec![key, value]))
    });
    registry.register::<ApproxDistinct>(|mut params| {
        let x = params.take()?.expr;
        // The input is cast to a variant and serialized to JSON before counting.
        // NOTE(review): presumably this lets approx_distinct handle arbitrary (including
        // nested) input types by hashing a canonical string form — confirm.
        let variant_x = crate::udf::cast_to_variant_udf().call(vec![x]);
        let json_x = crate::udf::variant_to_json_udf().call(vec![variant_x]);
        Ok(agg_fn::approx_distinct(json_x))
    });
    registry.register::<AnyValue>(|mut params| {
        let x = params.take()?.expr;
        Ok(crate::udf::any_value_udaf().call(vec![x]))
    });
    registry.register::<AggAll>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::bool_and(x))
    });
    registry.register::<AggAny>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::bool_or(x))
    });
    registry.register::<CountIf>(|mut params| {
        let condition = params.take()?.expr;
        // COUNT(CASE WHEN condition THEN 1 END): rows where the condition is false or NULL
        // produce NULL, which COUNT skips. Unlike SUM over a 1/0 CASE, COUNT returns 0 rather
        // than NULL when no rows are aggregated (empty group or empty window frame).
        let case_expr = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(Box::new(condition), Box::new(lit(1i64)))],
            else_expr: None,
        });
        Ok(agg_fn::count(case_expr))
    });
    registry.register::<ApproxPercentile>(|mut params| {
        use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;
        // Argument order: the value being summarized, then the requested percentile.
        let x = params.take()?.expr;
        let percentile = params.take()?.expr;
        Ok(approx_percentile_cont_udaf().call(vec![x, percentile]))
    });
}