hamelin_datafusion 0.7.8

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! DataFusion translations for regex functions.

use datafusion::functions::regex::expr_fn as regex_fn;

use hamelin_lib::func::defs::{
    RegexpCount, RegexpExtract2, RegexpExtract3, RegexpExtractAll2, RegexpExtractAll3, RegexpLike,
    RegexpPosition2, RegexpPosition3, RegexpPosition4, RegexpReplace2, RegexpReplace3, RegexpSplit,
};

use super::DataFusionTranslationRegistry;

pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // regexp_count(string, pattern) -> regexp_count(string, pattern, start, flags)
    registry.register::<RegexpCount>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        Ok(regex_fn::regexp_count(string, pattern, None, None))
    });

    // regexp_like(string, pattern) -> regexp_like(string, pattern, None)
    registry.register::<RegexpLike>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        Ok(regex_fn::regexp_like(string, pattern, None))
    });

    // regexp_extract(string, pattern) -> regexp_match(string, '(' || pattern || ')')[1]
    //
    // Trino's regexp_extract(string, pattern) returns the ENTIRE matched substring when no
    // group is specified. DataFusion's regexp_match returns only captured groups, not the
    // full match. To emulate Trino behavior, we wrap the pattern in a capturing group.
    registry.register::<RegexpExtract2>(|mut params| {
        use datafusion::logical_expr::{lit, BinaryExpr, Expr, Operator};

        let string = params.take()?.expr;
        let pattern = params.take()?.expr;

        // Wrap pattern in capturing group: '(' || pattern || ')'
        let inner_concat = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(lit("(")),
            op: Operator::StringConcat,
            right: Box::new(pattern),
        });
        let wrapped_pattern = Expr::BinaryExpr(BinaryExpr {
            left: Box::new(inner_concat),
            op: Operator::StringConcat,
            right: Box::new(lit(")")),
        });

        let matches = regex_fn::regexp_match(string, wrapped_pattern, None);
        // Get first captured group (index 1 in 1-based), which is now the full match
        Ok(datafusion_functions_nested::expr_fn::array_element(
            matches,
            lit(1i64),
        ))
    });

    // regexp_extract(string, pattern, group) -> regexp_match(string, pattern)[group]
    registry.register::<RegexpExtract3>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        let group = params.take()?.expr;
        let matches = regex_fn::regexp_match(string, pattern, None);
        // Both Hamelin and DataFusion use 1-based group indices
        Ok(datafusion_functions_nested::expr_fn::array_element(
            matches, group,
        ))
    });

    // regexp_extract_all(string, pattern) -> hamelin_regexp_extract_all(string, pattern)
    registry.register::<RegexpExtractAll2>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        Ok(crate::udf::regexp_extract_all_udf().call(vec![string, pattern]))
    });

    // regexp_extract_all(string, pattern, group) -> hamelin_regexp_extract_all(string, pattern, group)
    registry.register::<RegexpExtractAll3>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        let group = params.take()?.expr;
        Ok(crate::udf::regexp_extract_all_udf().call(vec![string, pattern, group]))
    });

    // regexp_replace(string, pattern, replacement) -> regexp_replace(string, pattern, replacement)
    registry.register::<RegexpReplace2>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        // Default replacement is empty string
        Ok(regex_fn::regexp_replace(
            string,
            pattern,
            datafusion::logical_expr::lit(""),
            Some(datafusion::logical_expr::lit("g")), // global replace
        ))
    });

    registry.register::<RegexpReplace3>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        let replacement = params.take()?.expr;
        Ok(regex_fn::regexp_replace(
            string,
            pattern,
            replacement,
            Some(datafusion::logical_expr::lit("g")), // global replace
        ))
    });

    // regexp_split(string, pattern) -> hamelin_regexp_split(string, pattern)
    registry.register::<RegexpSplit>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        Ok(crate::udf::regexp_split_udf().call(vec![string, pattern]))
    });

    // regexp_position(string, pattern) -> regexp_instr(string, pattern)
    registry.register::<RegexpPosition2>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        Ok(regex_fn::regexp_instr(
            string, pattern, None, None, None, None, None,
        ))
    });

    // regexp_position(string, pattern, start) -> regexp_instr(string, pattern, start)
    registry.register::<RegexpPosition3>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        let start = params.take()?.expr;
        Ok(regex_fn::regexp_instr(
            string,
            pattern,
            Some(start),
            None,
            None,
            None,
            None,
        ))
    });

    // regexp_position(string, pattern, start, occurrence) -> regexp_instr(string, pattern, start, occurrence)
    registry.register::<RegexpPosition4>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        let start = params.take()?.expr;
        let occurrence = params.take()?.expr;
        Ok(regex_fn::regexp_instr(
            string,
            pattern,
            Some(start),
            Some(occurrence),
            None,
            None,
            None,
        ))
    });
}