hamelin_sql 0.7.1

SQL generation utilities for Hamelin query language
Documentation
//! Translation implementations for string operators and functions

use crate::utils::{
    direct_function_translation, hamelin_string_index_to_sql_with_negatives,
    string_negative_index_to_positive, to_sql_concat,
};
use crate::TranslationRegistry;
use hamelin_lib::func::defs::{
    CidrContains, Contains, EndsWith, IsIpv4, IsIpv6, Lower, Replace2, Replace3, StartsWith,
    StringConcat, StringLen, Substr2, Substr3, Upper, Uuid, Uuid5,
};
use hamelin_lib::sql::expression::apply::{BinaryOperatorApply, FunctionCallApply};
use hamelin_lib::sql::expression::literal::{IntegerLiteral, StringLiteral};
use hamelin_lib::sql::expression::operator::Operator;
use hamelin_lib::sql::expression::{Cast, SQLExpression, TryCast};
use hamelin_lib::sql::types::{SQLBaseType, SQLType};

/// Register all string operator and function translations.
pub fn register(registry: &mut TranslationRegistry) {
    // StringConcat: string + string -> SQL CONCAT
    registry.register::<StringConcat>(to_sql_concat);

    // replace(string, pattern) - pass through
    registry.register::<Replace2>(direct_function_translation);

    // replace(string, pattern, replacement) - pass through
    registry.register::<Replace3>(direct_function_translation);

    // substr(string, start) - convert 0-based Hamelin index to 1-based SQL
    registry.register::<Substr2>(|_, mut bindings| {
        let string = bindings.take()?;
        let start = bindings.take()?;
        // Convert 0-based Hamelin index to 1-based SQL index
        let sql_start = hamelin_string_index_to_sql_with_negatives(start.sql, string.sql.clone());
        Ok(FunctionCallApply::with_two("substr", string.sql, sql_start).into())
    });

    // substr(string, start, end) - convert indices
    registry.register::<Substr3>(|_, mut bindings| {
        let string = bindings.take()?;
        let start = bindings.take()?;
        let end = bindings.take()?;

        // Convert negative indices to positive in Hamelin space
        let positive_end = string_negative_index_to_positive(end.sql.clone(), string.sql.clone());
        let positive_start =
            string_negative_index_to_positive(start.sql.clone(), string.sql.clone());

        // Calculate length as positive_end - positive_start (both in Hamelin zero-based space)
        let length = BinaryOperatorApply::new(Operator::Minus, positive_end, positive_start).into();

        // Convert start to SQL space: add 1 for 1-based indexing
        let sql_start = hamelin_string_index_to_sql_with_negatives(start.sql, string.sql.clone());

        Ok(FunctionCallApply::with_three("substr", string.sql, sql_start, length).into())
    });

    // starts_with(string, prefix) - pass through
    registry.register::<StartsWith>(direct_function_translation);

    // ends_with(string, suffix) - translate to LIKE '%suffix'
    registry.register::<EndsWith>(|_, mut bindings| {
        let string = bindings.take()?;
        let suffix = bindings.take()?;
        Ok(BinaryOperatorApply::new(
            Operator::Like,
            string.sql,
            BinaryOperatorApply::new(Operator::Concat, StringLiteral::new("%").into(), suffix.sql)
                .into(),
        )
        .into())
    });

    // contains(string, substring) - translate to LIKE '%substring%'
    registry.register::<Contains>(|_, mut bindings| {
        let string = bindings.take()?;
        let substring = bindings.take()?;
        Ok(BinaryOperatorApply::new(
            Operator::Like,
            string.sql,
            BinaryOperatorApply::new(
                Operator::Concat,
                StringLiteral::new("%").into(),
                BinaryOperatorApply::new(
                    Operator::Concat,
                    substring.sql,
                    StringLiteral::new("%").into(),
                )
                .into(),
            )
            .into(),
        )
        .into())
    });

    // cidr_contains(cidr, ip) - translate to try(contains(cidr, try_cast(ip as ipaddress)))
    registry.register::<CidrContains>(|_, mut bindings| {
        let cidr = bindings.take()?;
        let ip = bindings.take()?;

        Ok(FunctionCallApply::with_one(
            "try",
            FunctionCallApply::with_two(
                "contains",
                cidr.sql,
                TryCast::new(ip.sql, SQLType::SQLBaseType(SQLBaseType::IpAddress)).into(),
            )
            .into(),
        )
        .into())
    });

    // is_ipv4(ip) - cast to ipaddress then to varchar and check for '.'
    registry.register::<IsIpv4>(|_, mut bindings| {
        let ip = bindings.take()?;

        Ok(BinaryOperatorApply::new(
            Operator::Like,
            Cast::new(
                TryCast::new(ip.sql, SQLType::SQLBaseType(SQLBaseType::IpAddress)).into(),
                SQLType::SQLBaseType(SQLBaseType::VarChar),
            )
            .into(),
            StringLiteral::new("%.%").into(),
        )
        .into())
    });

    // is_ipv6(ip) - cast to ipaddress then to varchar and check for ':'
    registry.register::<IsIpv6>(|_, mut bindings| {
        let ip = bindings.take()?;

        Ok(BinaryOperatorApply::new(
            Operator::Like,
            Cast::new(
                TryCast::new(ip.sql, SQLType::SQLBaseType(SQLBaseType::IpAddress)).into(),
                SQLType::SQLBaseType(SQLBaseType::VarChar),
            )
            .into(),
            StringLiteral::new("%:%").into(),
        )
        .into())
    });

    // lower(string) - pass through
    registry.register::<Lower>(direct_function_translation);

    // upper(string) - pass through
    registry.register::<Upper>(direct_function_translation);

    // len(string) - translate to length()
    registry.register::<StringLen>(|_, mut bindings| {
        let x = bindings.take()?;
        Ok(FunctionCallApply::with_one("length", x.sql).into())
    });

    // uuid() - generate a random UUID string
    // Trino/Presto: uuid() returns a UUID type, cast to varchar for STRING
    registry.register::<Uuid>(|_, _bindings| {
        let uuid_call = FunctionCallApply::with_no_arguments("uuid");
        Ok(Cast::new(uuid_call.into(), SQLBaseType::VarChar.into()).into())
    });

    // uuid5(name) - RFC 4122 UUID v5 with default namespace (NAMESPACE_OID)
    registry.register::<Uuid5>(|_, mut bindings| {
        let name = bindings.take()?;
        // Default namespace: NAMESPACE_OID (6ba7b812-9dad-11d1-80b4-00c04fd430c8) as varbinary
        let namespace_bytes: SQLExpression = FunctionCallApply::with_one(
            "from_hex",
            StringLiteral::new("6ba7b8129dad11d180b400c04fd430c8").into(),
        )
        .into();
        let to_utf8_name = FunctionCallApply::with_one("to_utf8", name.sql).into();
        let input =
            FunctionCallApply::with_positional("concat", vec![namespace_bytes, to_utf8_name]);
        let sha1_hex: SQLExpression = FunctionCallApply::with_one(
            "to_hex",
            FunctionCallApply::with_one("sha1", input.into()).into(),
        )
        .into();

        // RFC 4122 §4.3 layout over sha1_hex (40-char lowercase hex, 1-indexed):
        // Bytes 0-5  (chars  1-12): sha1[0..5]  — unchanged
        // Byte  6    (chars 13-14): '5' || sha1_hex[14]  — version nibble forced to 5
        // Byte  7    (chars 15-16): sha1[7]     — unchanged
        // Byte  8    (chars 17-18): (sha1[8] & 0x3f) | 0x80  — variant bits
        // Bytes 9-15 (chars 19-32): sha1[9..15] — unchanged
        let part1 = FunctionCallApply::with_three(
            "substr",
            sha1_hex.clone(),
            IntegerLiteral::new("1").into(),
            IntegerLiteral::new("12").into(),
        )
        .into();
        let part2 = BinaryOperatorApply::new(
            Operator::Concat,
            StringLiteral::new("5").into(),
            FunctionCallApply::with_three(
                "substr",
                sha1_hex.clone(),
                IntegerLiteral::new("14").into(),
                IntegerLiteral::new("1").into(),
            )
            .into(),
        )
        .into();
        let part3 = FunctionCallApply::with_three(
            "substr",
            sha1_hex.clone(),
            IntegerLiteral::new("15").into(),
            IntegerLiteral::new("2").into(),
        )
        .into();

        // byte8 as int: from_big_endian_32(concat(from_hex('000000'), from_hex(substr(sha1_hex, 17, 2))))
        let sha1_hex_17_2 = FunctionCallApply::with_three(
            "substr",
            sha1_hex.clone(),
            IntegerLiteral::new("17").into(),
            IntegerLiteral::new("2").into(),
        )
        .into();
        let one_byte_8 = FunctionCallApply::with_one("from_hex", sha1_hex_17_2).into();
        let four_byte_8 = FunctionCallApply::with_positional(
            "concat",
            vec![
                FunctionCallApply::with_one("from_hex", StringLiteral::new("000000").into()).into(),
                one_byte_8,
            ],
        )
        .into();
        let byte8_int: SQLExpression =
            FunctionCallApply::with_one("from_big_endian_32", four_byte_8).into();
        let byte8_masked =
            FunctionCallApply::with_two("bitwise_and", byte8_int, IntegerLiteral::new("63").into())
                .into();
        let byte8_val = FunctionCallApply::with_two(
            "bitwise_or",
            byte8_masked,
            IntegerLiteral::new("128").into(),
        )
        .into();
        let byte8_hex = FunctionCallApply::with_three(
            "lpad",
            FunctionCallApply::with_two("to_base", byte8_val, IntegerLiteral::new("16").into())
                .into(),
            IntegerLiteral::new("2").into(),
            StringLiteral::new("0").into(),
        )
        .into();

        // Bytes 9-15: sha1[9..15] unchanged — substr(sha1_hex, 19, 14) (7 bytes = 14 hex chars)
        let bytes_9_15: SQLExpression = FunctionCallApply::with_three(
            "substr",
            sha1_hex.clone(),
            IntegerLiteral::new("19").into(),
            IntegerLiteral::new("14").into(),
        )
        .into();

        let full_hex_parts = vec![part1, part2, part3, byte8_hex, bytes_9_15];
        let full_hex = FunctionCallApply::with_positional("concat", full_hex_parts).into();
        let lower_hex = FunctionCallApply::with_one("lower", full_hex).into();
        let pattern = StringLiteral::new("^(.{8})(.{4})(.{4})(.{4})(.{12})$").into();
        let replacement = StringLiteral::new("$1-$2-$3-$4-$5").into();
        Ok(FunctionCallApply::with_three("regexp_replace", lower_hex, pattern, replacement).into())
    });
}