hamelin_datafusion 0.6.12

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! DataFusion translations for math functions.

use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::expr::{Case as DFCase, Cast};
use datafusion::logical_expr::{lit, Expr as DFExpr, Operator};
use datafusion_functions::math::expr_fn as math_fn;

use hamelin_lib::func::defs::{
    Abs, Cbrt, Ceil, Degrees, Euler, Exp, Floor, Ln, Log, Log10, Log2, Pi, Pow, Radians, Round1,
    Round2, Sign, Sqrt, Truncate, WidthBucket2, WidthBucket4,
};

use super::DataFusionTranslationRegistry;

pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // abs(x) -> abs(x)
    registry.register::<Abs>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::abs(x))
    });

    // cbrt(x) -> cbrt(x)
    registry.register::<Cbrt>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::cbrt(x))
    });

    // ceil(x) -> ceil(x)
    registry.register::<Ceil>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::ceil(x))
    });

    // degrees(x) -> degrees(x)
    registry.register::<Degrees>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::degrees(x))
    });

    // e() -> lit(E)
    registry.register::<Euler>(|_params| Ok(datafusion::logical_expr::lit(std::f64::consts::E)));

    // exp(x) -> exp(x)
    registry.register::<Exp>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::exp(x))
    });

    // floor(x) -> floor(x)
    registry.register::<Floor>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::floor(x))
    });

    // ln(x) -> ln(x)
    registry.register::<Ln>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::ln(x))
    });

    // log(base, x) -> log(base, x)
    registry.register::<Log>(|mut params| {
        let base = params.take()?.expr;
        let x = params.take()?.expr;
        Ok(math_fn::log(base, x))
    });

    // log10(x) -> log10(x)
    registry.register::<Log10>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::log10(x))
    });

    // log2(x) -> log2(x)
    registry.register::<Log2>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::log2(x))
    });

    // pi() -> pi()
    registry.register::<Pi>(|_params| Ok(math_fn::pi()));

    // pow(x, p) -> power(x, p)
    registry.register::<Pow>(|mut params| {
        let x = params.take()?.expr;
        let p = params.take()?.expr;
        Ok(math_fn::power(x, p))
    });

    // radians(x) -> radians(x)
    registry.register::<Radians>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::radians(x))
    });

    // round(x) -> round(x)
    registry.register::<Round1>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::round(vec![x]))
    });

    // round(x, d) -> round(x, d)
    registry.register::<Round2>(|mut params| {
        let x = params.take()?.expr;
        let d = params.take()?.expr;
        Ok(math_fn::round(vec![x, d]))
    });

    // sign(x) -> signum(x)
    registry.register::<Sign>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::signum(x))
    });

    // sqrt(x) -> sqrt(x)
    registry.register::<Sqrt>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::sqrt(x))
    });

    // truncate(x) -> trunc(x)
    registry.register::<Truncate>(|mut params| {
        let x = params.take()?.expr;
        Ok(math_fn::trunc(vec![x]))
    });

    // width_bucket(x, bound1, bound2, n) ->
    //   CASE
    //     WHEN x < bound1 THEN 0
    //     WHEN x >= bound2 THEN n + 1
    //     ELSE FLOOR((x - bound1) / (bound2 - bound1) * n) + 1
    //   END
    registry.register::<WidthBucket4>(|mut params| {
        let x = params.take()?.expr;
        let bound1 = params.take()?.expr;
        let bound2 = params.take()?.expr;
        let n = params.take()?.expr;

        // x < bound1
        let cond_below = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(x.clone()),
            op: Operator::Lt,
            right: Box::new(bound1.clone()),
        });

        // x >= bound2
        let cond_above = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(x.clone()),
            op: Operator::GtEq,
            right: Box::new(bound2.clone()),
        });

        // n + 1
        let n_plus_one = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(n.clone()),
            op: Operator::Plus,
            right: Box::new(lit(1i64)),
        });

        // (x - bound1)
        let x_minus_bound1 = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(x),
            op: Operator::Minus,
            right: Box::new(bound1.clone()),
        });

        // (bound2 - bound1)
        let range = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(bound2),
            op: Operator::Minus,
            right: Box::new(bound1),
        });

        // Cast numerator to Float64 to avoid integer division truncation
        let x_minus_bound1_f64 =
            DFExpr::Cast(Cast::new(Box::new(x_minus_bound1), DataType::Float64));

        // (x - bound1) / (bound2 - bound1)
        let normalized = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(x_minus_bound1_f64),
            op: Operator::Divide,
            right: Box::new(range),
        });

        // (x - bound1) / (bound2 - bound1) * n
        let scaled = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(normalized),
            op: Operator::Multiply,
            right: Box::new(n),
        });

        // FLOOR(...) + 1
        let bucket = DFExpr::BinaryExpr(datafusion::logical_expr::BinaryExpr {
            left: Box::new(math_fn::floor(scaled)),
            op: Operator::Plus,
            right: Box::new(lit(1i64)),
        });

        Ok(DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![
                (Box::new(cond_below), Box::new(lit(0i64))),
                (Box::new(cond_above), Box::new(n_plus_one)),
            ],
            else_expr: Some(Box::new(bucket)),
        }))
    });

    // width_bucket(x, bins) -> hamelin_width_bucket(x, bins)
    registry.register::<WidthBucket2>(|mut params| {
        let x = params.take()?.expr;
        let bins = params.take()?.expr;
        Ok(crate::udf::width_bucket_array_udf().call(vec![x, bins]))
    });
}