// hamelin_datafusion 0.7.5
//
// Translate Hamelin TypedAST to DataFusion LogicalPlans
// Documentation
//! DataFusion translations for array functions.

use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::expr::Cast as DFCast;
use datafusion::logical_expr::{BinaryExpr, Expr as DFExpr, Operator as DFOperator};
use datafusion_functions_nested::expr_fn as array_fn;
use datafusion_functions_nested::string::string_to_array;

use hamelin_lib::func::defs::{
    ArrayAll, ArrayAny, ArrayAvg, ArrayConcat, ArrayDistinct, ArrayJoin2, ArrayJoin3, ArrayMax,
    ArrayMin, ArrayOrMapLen, ArraySum, FilterNull, Flatten, GetArray, Sequence2, Sequence3, Slice,
    Split,
};
use hamelin_lib::types::Type;

use super::DataFusionTranslationRegistry;

pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // get(array, index) -> array_element(array, adjusted_index)
    // Hamelin uses 0-based indexing, DataFusion uses 1-based
    // For negative indices, both count from end the same way, so no adjustment needed
    // Formula: CASE WHEN index >= 0 THEN index + 1 ELSE index END
    registry.register::<GetArray>(|mut params| {
        use datafusion::logical_expr::expr::Case as DFCase;
        use datafusion::logical_expr::lit;

        let array = params.take_by_name("array")?.expr;
        let index = params.take_by_name("index")?.expr;

        // CASE WHEN index >= 0 THEN index + 1 ELSE index END
        let condition = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(index.clone()),
            DFOperator::GtEq,
            Box::new(lit(0i64)),
        ));
        let then_expr = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(index.clone()),
            DFOperator::Plus,
            Box::new(lit(1i64)),
        ));
        let index_adjusted = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(Box::new(condition), Box::new(then_expr))],
            else_expr: Some(Box::new(index)),
        });

        Ok(array_fn::array_element(array, index_adjusted))
    });

    // left + right -> array_concat(left, right)
    registry.register::<ArrayConcat>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        Ok(array_fn::array_concat(vec![left, right]))
    });

    // filter_null(x) -> array_remove_all(x, NULL)
    // Note: This isn't a perfect match - array_remove_all removes a specific value
    // A more accurate implementation would use transform + filter
    registry.register::<FilterNull>(|mut params| {
        let array = params.take()?.expr;
        // Use array_remove_all with NULL
        Ok(array_fn::array_remove_all(
            array,
            datafusion::logical_expr::lit(datafusion::common::ScalarValue::Null),
        ))
    });

    // len(array) -> array_length(x)   (top-level element count)
    // len(map)   -> cardinality(x)    (entry count)
    // cardinality counts ALL elements recursively for nested arrays, so we
    // must use array_length (dimension-1 length) for arrays.
    registry.register::<ArrayOrMapLen>(|mut params| {
        let x = params.take()?;
        let inner = match x.typ.as_ref() {
            Type::Array(_) => array_fn::array_length(x.expr),
            _ => array_fn::cardinality(x.expr),
        };
        // Cast to Int64 for arithmetic compatibility
        Ok(DFExpr::Cast(DFCast::new(Box::new(inner), DataType::Int64)))
    });

    // array_distinct(x) -> array_distinct(x)
    registry.register::<ArrayDistinct>(|mut params| {
        let x = params.take()?.expr;
        Ok(array_fn::array_distinct(x))
    });

    // slice(array, start, end) -> array_slice(array, adjusted_start, adjusted_end)
    // Hamelin: 0-based indexing, end exclusive, negative indices from end
    // DataFusion: 1-based indexing, end inclusive, negative indices from end
    //
    // For start:
    //   - Positive: add 1 (0-based → 1-based)
    //   - Negative: keep as-is (both count from end the same way)
    //   Formula: CASE WHEN start >= 0 THEN start + 1 ELSE start END
    //
    // For end:
    //   - Positive: keep as-is (exclusive→inclusive cancels 1-base adjustment)
    //   - Negative: subtract 1 (exclusive → inclusive)
    //   Formula: CASE WHEN end >= 0 THEN end ELSE end - 1 END
    //   Special case: end = 0 in Hamelin means "empty slice", but DataFusion end=0 is invalid
    //   Actually, Hamelin end is exclusive so end=0 means "take nothing before index 0" = empty
    registry.register::<Slice>(|mut params| {
        use datafusion::logical_expr::expr::Case as DFCase;
        use datafusion::logical_expr::lit;

        let array = params.take_by_name("array")?.expr;
        let start = params.take_by_name("start")?.expr;
        let end = params.take_by_name("end")?.expr;

        // start: CASE WHEN start >= 0 THEN start + 1 ELSE start END
        let start_condition = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(start.clone()),
            DFOperator::GtEq,
            Box::new(lit(0i64)),
        ));
        let start_then = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(start.clone()),
            DFOperator::Plus,
            Box::new(lit(1i64)),
        ));
        let start_adjusted = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(Box::new(start_condition), Box::new(start_then))],
            else_expr: Some(Box::new(start)),
        });

        // end: CASE WHEN end >= 0 THEN end ELSE end - 1 END
        // This converts Hamelin's exclusive end to DataFusion's inclusive end for negative indices
        // Example: Hamelin end=-1 (exclusive, last element) → DataFusion end=-2 (inclusive, 2nd from end)
        let end_condition = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(end.clone()),
            DFOperator::GtEq,
            Box::new(lit(0i64)),
        ));
        let end_else = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(end.clone()),
            DFOperator::Minus,
            Box::new(lit(1i64)),
        ));
        let end_adjusted = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(Box::new(end_condition), Box::new(end.clone()))],
            else_expr: Some(Box::new(end_else)),
        });

        Ok(array_fn::array_slice(
            array,
            start_adjusted,
            end_adjusted,
            None,
        ))
    });

    // split(string, delimiter) -> string_to_array(string, delimiter)
    registry.register::<Split>(|mut params| {
        let string = params.take()?.expr;
        let delimiter = params.take()?.expr;
        Ok(string_to_array(
            string,
            delimiter,
            datafusion::logical_expr::lit(datafusion::common::ScalarValue::Null),
        ))
    });

    // array_join(array, delimiter) -> array_to_string(array, delimiter)
    registry.register::<ArrayJoin2>(|mut params| {
        let array = params.take()?.expr;
        let delimiter = params.take()?.expr;
        Ok(array_fn::array_to_string(array, delimiter))
    });

    // array_join(array, delimiter, null_replacement) -> array_to_string(array, delimiter, null_replacement)
    // DataFusion's array_to_string supports a third argument for null string replacement
    registry.register::<ArrayJoin3>(|mut params| {
        let array = params.take()?.expr;
        let delimiter = params.take()?.expr;
        let null_replacement = params.take()?.expr;
        Ok(
            datafusion_functions_nested::string::array_to_string_udf().call(vec![
                array,
                delimiter,
                null_replacement,
            ]),
        )
    });

    // flatten(x) -> flatten(x)
    registry.register::<Flatten>(|mut params| {
        let x = params.take()?.expr;
        Ok(array_fn::flatten(x))
    });

    // any(array<boolean>) -> three-valued logic matching bool_or:
    //   true if any true, null if any null (and no true), false otherwise
    // CASE WHEN array_has(arr, true) THEN true
    //      WHEN cardinality(array_remove_all(array_remove_all(arr, true), false)) > 0 THEN NULL
    //      ELSE false END
    registry.register::<ArrayAny>(|mut params| {
        use datafusion::common::ScalarValue;
        use datafusion::logical_expr::expr::Case as DFCase;
        use datafusion::logical_expr::lit;

        let array = params.take()?.expr;

        let has_true = array_fn::array_has(array.clone(), lit(true));
        let nulls_remain = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(DFExpr::Cast(DFCast::new(
                Box::new(array_fn::cardinality(array_fn::array_remove_all(
                    array_fn::array_remove_all(array, lit(true)),
                    lit(false),
                ))),
                DataType::Int64,
            ))),
            DFOperator::Gt,
            Box::new(lit(0i64)),
        ));

        Ok(DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![
                (Box::new(has_true), Box::new(lit(true))),
                (
                    Box::new(nulls_remain),
                    Box::new(lit(ScalarValue::Boolean(None))),
                ),
            ],
            else_expr: Some(Box::new(lit(false))),
        }))
    });

    // all(array<boolean>) -> three-valued logic matching bool_and:
    //   false if any false, null if any null (and no false), true otherwise
    // CASE WHEN array_has(arr, false) THEN false
    //      WHEN cardinality(array_remove_all(array_remove_all(arr, true), false)) > 0 THEN NULL
    //      ELSE true END
    registry.register::<ArrayAll>(|mut params| {
        use datafusion::common::ScalarValue;
        use datafusion::logical_expr::expr::Case as DFCase;
        use datafusion::logical_expr::lit;

        let array = params.take()?.expr;

        let has_false = array_fn::array_has(array.clone(), lit(false));
        let nulls_remain = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(DFExpr::Cast(DFCast::new(
                Box::new(array_fn::cardinality(array_fn::array_remove_all(
                    array_fn::array_remove_all(array, lit(true)),
                    lit(false),
                ))),
                DataType::Int64,
            ))),
            DFOperator::Gt,
            Box::new(lit(0i64)),
        ));

        Ok(DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![
                (Box::new(has_false), Box::new(lit(false))),
                (
                    Box::new(nulls_remain),
                    Box::new(lit(ScalarValue::Boolean(None))),
                ),
            ],
            else_expr: Some(Box::new(lit(true))),
        }))
    });

    // max(array) -> array_max(array) (DataFusion doesn't have direct equivalent)
    registry.register::<ArrayMax>(|mut params| {
        let array = params.take()?.expr;
        Ok(array_fn::array_max(array))
    });

    // min(array) -> array_min(array)
    registry.register::<ArrayMin>(|mut params| {
        let array = params.take()?.expr;
        Ok(array_fn::array_min(array))
    });

    // sum(array) -> hamelin_array_sum(array)
    registry.register::<ArraySum>(|mut params| {
        let array = params.take()?.expr;
        Ok(crate::udf::array_sum_udf().call(vec![array]))
    });

    // avg(array) -> hamelin_array_avg(array)
    registry.register::<ArrayAvg>(|mut params| {
        let array = params.take()?.expr;
        Ok(crate::udf::array_avg_udf().call(vec![array]))
    });

    // sequence(start, stop) -> range(start, stop+1, 1)
    // DataFusion's range is exclusive on the end, Hamelin's sequence is inclusive
    registry.register::<Sequence2>(|mut params| {
        let start = params.take()?.expr;
        let stop = params.take()?.expr;

        // Add 1 to stop to make it inclusive (DataFusion range is exclusive)
        let stop_inclusive = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(stop),
            DFOperator::Plus,
            Box::new(datafusion::logical_expr::lit(1i64)),
        ));

        Ok(array_fn::range(
            start,
            stop_inclusive,
            datafusion::logical_expr::lit(1i64),
        ))
    });

    // sequence(start, stop, step) -> range(start, stop + sign(step), step)
    // DataFusion's range is exclusive on the end, Hamelin's sequence is inclusive.
    // We adjust by sign(step) so positive steps add 1 and negative steps subtract 1.
    registry.register::<Sequence3>(|mut params| {
        let start = params.take()?.expr;
        let stop = params.take()?.expr;
        let step = params.take()?.expr;

        let sign_step = DFExpr::Cast(DFCast::new(
            Box::new(datafusion_functions::math::expr_fn::signum(step.clone())),
            DataType::Int64,
        ));
        let stop_inclusive = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(stop),
            DFOperator::Plus,
            Box::new(sign_step),
        ));

        Ok(array_fn::range(start, stop_inclusive, step))
    });
}