// hamelin_datafusion 0.7.3
//
// Translate Hamelin TypedAST to DataFusion LogicalPlans
// Documentation
//! DataFusion translations for aggregate functions.

use datafusion::functions_aggregate::expr_fn as agg_fn;

use datafusion::logical_expr::expr::Case as DFCase;
use datafusion::logical_expr::{lit, Expr as DFExpr};

use hamelin_lib::func::defs::{
    AggAll, AggAny, AggMax, AggMin, AnyValue, ApproxDistinct, ApproxPercentile, ArrayAgg, Avg,
    CountAny, CountDistinct, CountIf, CountStar, MapAgg, MultimapAgg, Stddev, Sum,
};

use super::DataFusionTranslationRegistry;

pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // count(*) -> count()
    registry.register::<CountStar>(|_params| Ok(agg_fn::count(datafusion::logical_expr::lit(1))));

    // count(x) -> count(x)
    registry.register::<CountAny>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::count(x))
    });

    // count_distinct(x) -> count(distinct x)
    registry.register::<CountDistinct>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::count_distinct(x))
    });

    // sum(x) -> sum(x)
    registry.register::<Sum>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::sum(x))
    });

    // avg(x) -> avg(x)
    registry.register::<Avg>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::avg(x))
    });

    // min(x) -> min(x)
    registry.register::<AggMin>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::min(x))
    });

    // max(x) -> max(x)
    registry.register::<AggMax>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::max(x))
    });

    // array_agg(x) -> sliding_array_agg(x)
    // We use our custom sliding_array_agg which supports retract_batch for sliding windows.
    // DataFusion's built-in array_agg doesn't support sliding window frames.
    registry.register::<ArrayAgg>(|mut params| {
        let x = params.take()?.expr;
        Ok(crate::udf::sliding_array_agg_udaf().call(vec![x]))
    });

    // stdev(x) -> stddev_samp(x)
    registry.register::<Stddev>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::stddev(x))
    });

    // map_agg(key, value) -> hamelin_map_agg(key, value)
    registry.register::<MapAgg>(|mut params| {
        let key = params.take()?.expr;
        let value = params.take()?.expr;
        Ok(crate::udf::map_agg_udaf().call(vec![key, value]))
    });

    // multimap_agg(key, value) -> hamelin_multimap_agg(key, value)
    registry.register::<MultimapAgg>(|mut params| {
        let key = params.take()?.expr;
        let value = params.take()?.expr;
        Ok(crate::udf::multimap_agg_udaf().call(vec![key, value]))
    });

    // approx_distinct(x) -> approx_distinct(to_json_string(to_variant(x)))
    // DataFusion's approx_distinct doesn't support complex types (struct, array, map, variant),
    // so we convert to Variant then to JSON string to get a hashable representation.
    registry.register::<ApproxDistinct>(|mut params| {
        let x = params.take()?.expr;
        let variant_x = crate::udf::cast_to_variant_udf().call(vec![x]);
        let json_x = crate::udf::variant_to_json_udf().call(vec![variant_x]);
        Ok(agg_fn::approx_distinct(json_x))
    });

    // any_value(x) -> hamelin_any_value(x)
    // Custom UDF with native GroupsAccumulator to avoid per-row ScalarValue::compact overhead
    registry.register::<AnyValue>(|mut params| {
        let x = params.take()?.expr;
        Ok(crate::udf::any_value_udaf().call(vec![x]))
    });

    // all(x) -> bool_and(x)
    registry.register::<AggAll>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::bool_and(x))
    });

    // any(x) -> bool_or(x)
    registry.register::<AggAny>(|mut params| {
        let x = params.take()?.expr;
        Ok(agg_fn::bool_or(x))
    });

    // count_if(condition) -> sum(CASE WHEN condition THEN 1 ELSE 0 END)
    // DataFusion doesn't have count_if, so we use sum with a case expression
    registry.register::<CountIf>(|mut params| {
        let condition = params.take()?.expr;
        let case_expr = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(Box::new(condition), Box::new(lit(1i64)))],
            else_expr: Some(Box::new(lit(0i64))),
        });
        Ok(agg_fn::sum(case_expr))
    });

    // approx_percentile(x, percentile) -> approx_percentile_cont_udaf().call([x, percentile])
    registry.register::<ApproxPercentile>(|mut params| {
        use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;
        let x = params.take()?.expr;
        let percentile = params.take()?.expr;
        Ok(approx_percentile_cont_udaf().call(vec![x, percentile]))
    });
}