hamelin_datafusion 0.7.5

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! DataFusion translations for string functions.

use datafusion::logical_expr::{BinaryExpr, Expr as DFExpr, Operator as DFOperator};
use datafusion_functions::string::expr_fn as string_fn;
use datafusion_functions::unicode::expr_fn as unicode_fn;
use datafusion_functions::unicode::{strpos, substr, substring};

use hamelin_lib::func::defs::{
    CidrContains, Contains, EndsWith, IsIpv4, IsIpv6, Lower, Replace2, Replace3, StartsWith,
    StringConcat, StringLen, Substr2, Substr3, Upper, Uuid, Uuid5,
};

use super::DataFusionTranslationRegistry;
use crate::udf::uuid5_udf;

pub fn register(registry: &mut DataFusionTranslationRegistry) {
    // left + right -> concat(left, right)
    registry.register::<StringConcat>(|mut params| {
        let left = params.take()?.expr;
        let right = params.take()?.expr;
        Ok(DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(left),
            DFOperator::StringConcat,
            Box::new(right),
        )))
    });

    // replace(string, pattern) -> replace(string, pattern, '')
    registry.register::<Replace2>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        let empty = datafusion::logical_expr::lit("");
        Ok(string_fn::replace(string, pattern, empty))
    });

    // replace(string, pattern, replacement) -> replace(string, pattern, replacement)
    registry.register::<Replace3>(|mut params| {
        let string = params.take()?.expr;
        let pattern = params.take()?.expr;
        let replacement = params.take()?.expr;
        Ok(string_fn::replace(string, pattern, replacement))
    });

    // substr(string, start) -> substr(string, adjusted_start)
    // Hamelin uses 0-based indexing, DataFusion uses 1-based
    // For negative indices: convert to positive using len(string) + start + 1
    // Formula: CASE WHEN start >= 0 THEN start + 1 ELSE character_length(string) + start + 1 END
    registry.register::<Substr2>(|mut params| {
        use datafusion::logical_expr::expr::Case as DFCase;
        use datafusion::logical_expr::lit;

        let string = params.take_by_name("string")?.expr;
        let start = params.take_by_name("start")?.expr;

        // CASE WHEN start >= 0 THEN start + 1 ELSE character_length(string) + start + 1 END
        let condition = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(start.clone()),
            DFOperator::GtEq,
            Box::new(lit(0i64)),
        ));
        // Positive case: start + 1
        let then_expr = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(start.clone()),
            DFOperator::Plus,
            Box::new(lit(1i64)),
        ));
        // Negative case: character_length(string) + start + 1
        let else_expr = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(DFExpr::BinaryExpr(BinaryExpr::new(
                Box::new(unicode_fn::character_length(string.clone())),
                DFOperator::Plus,
                Box::new(start),
            ))),
            DFOperator::Plus,
            Box::new(lit(1i64)),
        ));
        let start_adjusted = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(Box::new(condition), Box::new(then_expr))],
            else_expr: Some(Box::new(else_expr)),
        });

        Ok(substr().call(vec![string, start_adjusted]))
    });

    // substr(string, start, end) -> substring(string, adjusted_start, length)
    // Hamelin: 0-based indexing, end is exclusive, negative indices count from end
    // DataFusion: 1-based indexing, takes length (not end index)
    //
    // Both start and end must be resolved to absolute 0-based indices before
    // computing length, since they may be in different coordinate systems
    // (e.g. start=-2, end=3).
    registry.register::<Substr3>(|mut params| {
        use datafusion::logical_expr::expr::Case as DFCase;
        use datafusion::logical_expr::lit;

        let string = params.take_by_name("string")?.expr;
        let start = params.take_by_name("start")?.expr;
        let end = params.take_by_name("end")?.expr;

        let char_len = unicode_fn::character_length(string.clone());

        // Resolve start to absolute 0-based index:
        // CASE WHEN start >= 0 THEN start ELSE character_length(string) + start END
        let resolved_start = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(
                Box::new(DFExpr::BinaryExpr(BinaryExpr::new(
                    Box::new(start.clone()),
                    DFOperator::GtEq,
                    Box::new(lit(0i64)),
                ))),
                Box::new(start.clone()),
            )],
            else_expr: Some(Box::new(DFExpr::BinaryExpr(BinaryExpr::new(
                Box::new(char_len.clone()),
                DFOperator::Plus,
                Box::new(start),
            )))),
        });

        // Resolve end to absolute 0-based index:
        // CASE WHEN end >= 0 THEN end ELSE character_length(string) + end END
        let resolved_end = DFExpr::Case(DFCase {
            expr: None,
            when_then_expr: vec![(
                Box::new(DFExpr::BinaryExpr(BinaryExpr::new(
                    Box::new(end.clone()),
                    DFOperator::GtEq,
                    Box::new(lit(0i64)),
                ))),
                Box::new(end.clone()),
            )],
            else_expr: Some(Box::new(DFExpr::BinaryExpr(BinaryExpr::new(
                Box::new(char_len),
                DFOperator::Plus,
                Box::new(end),
            )))),
        });

        // Convert resolved_start to 1-based for DataFusion: resolved_start + 1
        let start_adjusted = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(resolved_start.clone()),
            DFOperator::Plus,
            Box::new(lit(1i64)),
        ));

        // length = resolved_end - resolved_start (both in absolute 0-based space)
        let length = DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(resolved_end),
            DFOperator::Minus,
            Box::new(resolved_start),
        ));

        Ok(substring().call(vec![string, start_adjusted, length]))
    });

    // starts_with(string, prefix) -> starts_with(string, prefix)
    registry.register::<StartsWith>(|mut params| {
        let string = params.take()?.expr;
        let prefix = params.take()?.expr;
        Ok(string_fn::starts_with(string, prefix))
    });

    // ends_with(string, suffix) -> ends_with(string, suffix)
    registry.register::<EndsWith>(|mut params| {
        let string = params.take()?.expr;
        let suffix = params.take()?.expr;
        Ok(string_fn::ends_with(string, suffix))
    });

    // contains(string, substring) -> strpos(string, substring) > 0
    registry.register::<Contains>(|mut params| {
        let string = params.take()?.expr;
        let substr = params.take()?.expr;
        // strpos returns 0 if not found, > 0 if found
        let strpos_result = strpos().call(vec![string, substr]);
        Ok(DFExpr::BinaryExpr(BinaryExpr::new(
            Box::new(strpos_result),
            DFOperator::Gt,
            Box::new(datafusion::logical_expr::lit(0i64)),
        )))
    });

    // cidr_contains(cidr, ip) -> hamelin_cidr_contains(cidr, ip)
    registry.register::<CidrContains>(|mut params| {
        let cidr = params.take()?.expr;
        let ip = params.take()?.expr;
        Ok(crate::udf::cidr_contains_udf().call(vec![cidr, ip]))
    });

    // is_ipv4(ip) -> hamelin_is_ipv4(ip)
    registry.register::<IsIpv4>(|mut params| {
        let ip = params.take()?.expr;
        Ok(crate::udf::is_ipv4_udf().call(vec![ip]))
    });

    // is_ipv6(ip) -> hamelin_is_ipv6(ip)
    registry.register::<IsIpv6>(|mut params| {
        let ip = params.take()?.expr;
        Ok(crate::udf::is_ipv6_udf().call(vec![ip]))
    });

    // lower(string) -> lower(string)
    registry.register::<Lower>(|mut params| {
        let string = params.take()?.expr;
        Ok(string_fn::lower(string))
    });

    // upper(string) -> upper(string)
    registry.register::<Upper>(|mut params| {
        let string = params.take()?.expr;
        Ok(string_fn::upper(string))
    });

    // len(string) -> character_length(string)
    registry.register::<StringLen>(|mut params| {
        let string = params.take()?.expr;
        Ok(unicode_fn::character_length(string))
    });

    // uuid() -> uuid()
    // DataFusion has a built-in uuid() function
    registry.register::<Uuid>(|_params| Ok(string_fn::uuid()));

    // uuid5(name) -> hamelin_uuid5(name)
    registry.register::<Uuid5>(|mut params| {
        let name = params.take()?.expr;
        Ok(uuid5_udf().call(vec![name]))
    });
}