hamelin_datafusion 0.7.5

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! DataFusion translations for membership operators (IN, NOT IN).

use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{BinaryExpr, Expr as DFExpr, Operator as DFOperator};
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_functions::core::expr_fn as core_fn;
use datafusion_functions_nested::expr_fn as array_fn;

use hamelin_lib::func::defs::{
    InArray, InMap, InRange, InTimestampInterval, InTimestampTimestamp, InTuple, NotInArray,
    NotInMap, NotInRange, NotInTimestampInterval, NotInTimestampTimestamp, NotInTuple,
    NumericRange, TimestampRange,
};

use super::DataFusionTranslationRegistry;

/// Recovers the element expressions of a translated Hamelin tuple.
///
/// A Hamelin tuple is translated to `struct(e0, e1, e2, ...)`, which DataFusion
/// represents as a `ScalarFunction` whose UDF is named "struct". For such an
/// expression the call's arguments (the original tuple elements) are returned.
///
/// Returns `None` for any expression that is not a `struct()` call.
fn try_extract_tuple_elements(expr: DFExpr) -> Option<Vec<DFExpr>> {
    match expr {
        // Only a call to the "struct" UDF carries the tuple elements as its arguments.
        DFExpr::ScalarFunction(call) if call.func.name() == "struct" => Some(call.args),
        _ => None,
    }
}

pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // left IN array -> array_has(array, left)
    registry.register::<InArray>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        Ok(array_fn::array_has(right, left))
    });

    // left NOT IN array -> NOT array_has(array, left)
    registry.register::<NotInArray>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        Ok(DFExpr::Not(Box::new(array_fn::array_has(right, left))))
    });

    // left IN map -> map_contains_key(map, left)
    registry.register::<InMap>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        // DataFusion doesn't have map_contains_key directly
        // Use: array_has(map_keys(map), key)
        Ok(array_fn::array_has(array_fn::map_keys(right), left))
    });

    // left NOT IN map -> NOT map_contains_key(map, left)
    registry.register::<NotInMap>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        // Use: NOT array_has(map_keys(map), key)
        Ok(DFExpr::Not(Box::new(array_fn::array_has(
            array_fn::map_keys(right),
            left,
        ))))
    });

    // left..right -> named_struct("begin", left, "end", right)
    registry.register::<NumericRange>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        Ok(core_fn::named_struct(vec![
            datafusion::logical_expr::lit("begin"),
            left,
            datafusion::logical_expr::lit("end"),
            right,
        ]))
    });

    // timestamp/interval range
    registry.register::<TimestampRange>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        Ok(core_fn::named_struct(vec![
            datafusion::logical_expr::lit("begin"),
            left,
            datafusion::logical_expr::lit("end"),
            right,
        ]))
    });

    // left IN range -> (begin IS NULL OR left >= begin) AND (end IS NULL OR left < end)
    // Handle NULL bounds: NULL begin = negative infinity, NULL end = positive infinity
    registry.register::<InRange>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        // Extract begin and end from the range struct
        let begin = right.clone().field("begin");
        let end = right.field("end");

        // begin IS NULL OR left >= begin
        let begin_is_null = DFExpr::IsNull(Box::new(begin.clone()));
        let gte_begin = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left.clone()),
            DFOperator::GtEq,
            Box::new(begin),
        ));
        let begin_check = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(begin_is_null),
            DFOperator::Or,
            Box::new(gte_begin),
        ));

        // end IS NULL OR left < end
        let end_is_null = DFExpr::IsNull(Box::new(end.clone()));
        let lt_end = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left),
            DFOperator::Lt,
            Box::new(end),
        ));
        let end_check = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(end_is_null),
            DFOperator::Or,
            Box::new(lt_end),
        ));

        Ok(DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(begin_check),
            DFOperator::And,
            Box::new(end_check),
        )))
    });

    // left NOT IN range -> NOT((begin IS NULL OR left >= begin) AND (end IS NULL OR left < end))
    // Handle NULL bounds: NULL begin = negative infinity, NULL end = positive infinity
    registry.register::<NotInRange>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        let begin = right.clone().field("begin");
        let end = right.field("end");

        // begin IS NULL OR left >= begin
        let begin_is_null = DFExpr::IsNull(Box::new(begin.clone()));
        let gte_begin = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left.clone()),
            DFOperator::GtEq,
            Box::new(begin),
        ));
        let begin_check = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(begin_is_null),
            DFOperator::Or,
            Box::new(gte_begin),
        ));

        // end IS NULL OR left < end
        let end_is_null = DFExpr::IsNull(Box::new(end.clone()));
        let lt_end = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left),
            DFOperator::Lt,
            Box::new(end),
        ));
        let end_check = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(end_is_null),
            DFOperator::Or,
            Box::new(lt_end),
        ));

        let in_range = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(begin_check),
            DFOperator::And,
            Box::new(end_check),
        ));
        Ok(DFExpr::Not(Box::new(in_range)))
    });

    // timestamp IN interval -> timestamp >= now() + interval AND timestamp <= now()
    // The interval is already negative (e.g., -24h), so now() + (-24h) = now() - 24h.
    registry.register::<InTimestampInterval>(|mut params| {
        let timestamp = params.take()?.expr;
        let interval = params.take()?.expr;
        let now = datafusion_functions::datetime::expr_fn::now();
        let begin = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(now.clone()),
            DFOperator::Plus,
            Box::new(interval),
        ));
        let gte_begin = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(timestamp.clone()),
            DFOperator::GtEq,
            Box::new(begin),
        ));
        let lte_now = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(timestamp),
            DFOperator::LtEq,
            Box::new(now),
        ));
        Ok(DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(gte_begin),
            DFOperator::And,
            Box::new(lte_now),
        )))
    });

    // timestamp NOT IN interval
    // The interval is already negative (e.g., -24h), so now() + (-24h) = now() - 24h.
    registry.register::<NotInTimestampInterval>(|mut params| {
        let timestamp = params.take()?.expr;
        let interval = params.take()?.expr;
        let now = datafusion_functions::datetime::expr_fn::now();
        let begin = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(now.clone()),
            DFOperator::Plus,
            Box::new(interval),
        ));
        let gte_begin = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(timestamp.clone()),
            DFOperator::GtEq,
            Box::new(begin),
        ));
        let lte_now = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(timestamp),
            DFOperator::LtEq,
            Box::new(now),
        ));
        let in_interval = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(gte_begin),
            DFOperator::And,
            Box::new(lte_now),
        ));
        Ok(DFExpr::Not(Box::new(in_interval)))
    });

    // timestamp IN timestamp -> left >= right AND left <= now()
    registry.register::<InTimestampTimestamp>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        let now = datafusion_functions::datetime::expr_fn::now();
        let gte_right = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left.clone()),
            DFOperator::GtEq,
            Box::new(right),
        ));
        let lte_now = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left),
            DFOperator::LtEq,
            Box::new(now),
        ));
        Ok(DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(gte_right),
            DFOperator::And,
            Box::new(lte_now),
        )))
    });

    // timestamp NOT IN timestamp -> NOT (left >= right AND left <= now())
    registry.register::<NotInTimestampTimestamp>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        let now = datafusion_functions::datetime::expr_fn::now();
        let gte_right = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left.clone()),
            DFOperator::GtEq,
            Box::new(right),
        ));
        let lte_now = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left),
            DFOperator::LtEq,
            Box::new(now),
        ));
        let in_range = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(gte_right),
            DFOperator::And,
            Box::new(lte_now),
        ));
        Ok(DFExpr::Not(Box::new(in_range)))
    });

    // left IN (a, b, c) -> left IN [a, b, c]
    // The tuple is translated as struct(a, b, c), so we extract the elements
    registry.register::<InTuple>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;

        // Extract tuple elements from the struct() expression
        let list = try_extract_tuple_elements(right).ok_or_else(|| {
            anyhow::anyhow!(
                "IN tuple: expected struct() expression for tuple, got a different expression type"
            )
        })?;

        Ok(DFExpr::InList(InList::new(Box::new(left), list, false)))
    });

    // left NOT IN (a, b, c) -> left NOT IN [a, b, c]
    // The tuple is translated as struct(a, b, c), so we extract the elements
    registry.register::<NotInTuple>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;

        // Extract tuple elements from the struct() expression
        let list = try_extract_tuple_elements(right).ok_or_else(|| {
            anyhow::anyhow!(
                "NOT IN tuple: expected struct() expression for tuple, got a different expression type"
            )
        })?;

        Ok(DFExpr::InList(InList::new(Box::new(left), list, true)))
    });
}