hamelin_datafusion 0.6.12

Translate Hamelin TypedAST to DataFusion LogicalPlans
Documentation
//! regexp_split UDF for DataFusion.
//!
//! Splits a string by a regex pattern, returning an array of substrings.
//! This is similar to PostgreSQL's `regexp_split_to_array` function.

use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};

use datafusion::arrow::array::{Array, ArrayRef, AsArray, GenericStringBuilder, ListBuilder};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
    Volatility,
};
use regex::Regex;

use super::string_utils::{scalar_to_str, STRING_TYPES};

/// UDF that splits a string by a regex pattern.
///
/// `regexp_split(string, pattern)` -> `array<string>`
///
/// Returns an array of substrings split by matches of the pattern.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RegexpSplitUdf {
    signature: Signature,
}

impl Default for RegexpSplitUdf {
    fn default() -> Self {
        Self::new()
    }
}

impl RegexpSplitUdf {
    pub fn new() -> Self {
        let mut sigs = Vec::new();
        for s1 in &STRING_TYPES {
            for s2 in &STRING_TYPES {
                sigs.push(TypeSignature::Exact(vec![s1.clone(), s2.clone()]));
            }
        }
        Self {
            signature: Signature::new(TypeSignature::OneOf(sigs), Volatility::Immutable),
        }
    }
}

impl ScalarUDFImpl for RegexpSplitUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is registered.
    fn name(&self) -> &str {
        "hamelin_regexp_split"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `List<Utf8>` with nullable items, whatever string
    /// encodings the arguments use.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let item_field = Field::new_list_field(DataType::Utf8, true);
        Ok(DataType::List(Arc::new(item_field)))
    }

    /// Evaluate `regexp_split(string, pattern)` for each combination of scalar
    /// and array arguments.
    ///
    /// A NULL string or NULL pattern yields a NULL list entry. A pattern that
    /// fails to compile is reported as an execution error.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;
        let [string_arg, pattern_arg] = args.as_slice() else {
            return exec_err!(
                "regexp_split expects exactly 2 arguments, got {}",
                args.len()
            );
        };

        match (string_arg, pattern_arg) {
            // Scalar string, scalar pattern: produce a single list value.
            (ColumnarValue::Scalar(string_val), ColumnarValue::Scalar(pattern_val)) => {
                let string = scalar_to_str(string_val)?;
                let pattern = scalar_to_str(pattern_val)?;
                let (Some(s), Some(p)) = (string, pattern) else {
                    let null_list = ScalarValue::new_null_list(DataType::Utf8, true, 1);
                    return Ok(ColumnarValue::Scalar(null_list));
                };
                let re = compile_regex(p)?;
                let pieces: Vec<ScalarValue> = re
                    .split(s)
                    .map(|piece| ScalarValue::Utf8(Some(piece.to_string())))
                    .collect();
                let list = ScalarValue::new_list(&pieces, &DataType::Utf8, true);
                Ok(ColumnarValue::Scalar(ScalarValue::List(list)))
            }

            // Array of strings, one scalar pattern (the common case): compile once.
            (ColumnarValue::Array(string_arr), ColumnarValue::Scalar(pattern_val)) => {
                let Some(p) = scalar_to_str(pattern_val)? else {
                    let nulls = create_null_list_array(string_arr.len())?;
                    return Ok(ColumnarValue::Array(nulls));
                };
                let re = compile_regex(p)?;
                let split = match string_arr.data_type() {
                    DataType::Utf8 => split_array_with_regex(string_arr.as_string::<i32>(), &re),
                    DataType::LargeUtf8 => {
                        split_array_with_regex(string_arr.as_string::<i64>(), &re)
                    }
                    DataType::Utf8View => {
                        split_array_with_regex(string_arr.as_string_view(), &re)
                    }
                    other => {
                        return exec_err!("regexp_split expects string array, got {}", other)
                    }
                };
                Ok(ColumnarValue::Array(split))
            }

            // Array of strings, array of patterns: the regex is chosen per row.
            (ColumnarValue::Array(string_arr), ColumnarValue::Array(pattern_arr)) => {
                let split = match string_arr.data_type() {
                    DataType::Utf8 => {
                        let strings = string_arr.as_string::<i32>();
                        dispatch_pattern_array!(pattern_arr, |patterns| split_arrays(
                            strings, patterns
                        ))
                    }
                    DataType::LargeUtf8 => {
                        let strings = string_arr.as_string::<i64>();
                        dispatch_pattern_array!(pattern_arr, |patterns| split_arrays(
                            strings, patterns
                        ))
                    }
                    DataType::Utf8View => {
                        let strings = string_arr.as_string_view();
                        dispatch_pattern_array!(pattern_arr, |patterns| split_arrays(
                            strings, patterns
                        ))
                    }
                    other => return exec_err!("regexp_split expects string array, got {}", other),
                };
                Ok(ColumnarValue::Array(split?))
            }

            // Scalar string, array of patterns (less common).
            (ColumnarValue::Scalar(string_val), ColumnarValue::Array(pattern_arr)) => {
                let string = scalar_to_str(string_val)?;
                let split = dispatch_pattern_array!(pattern_arr, |patterns| {
                    split_scalar_with_patterns(string, patterns)
                });
                Ok(ColumnarValue::Array(split?))
            }
        }
    }
}

/// Dispatch a pattern array to one of the three concrete string types.
///
/// Expands to a `match` on `$array.data_type()`:
/// - `Utf8` evaluates `$func($array.as_string::<i32>())`.
/// - `LargeUtf8` evaluates `$func($array.as_string::<i64>())`.
/// - `Utf8View` evaluates `$func($array.as_string_view())`.
///
/// `$func` is an `expr` fragment pasted separately into each arm. A closure
/// argument is therefore a distinct closure per arm, so each copy can take a
/// different concrete array type.
///
/// Any other data type executes `return exec_err!(...)`. That `return` leaves
/// the *enclosing function*, not the macro, so the call site must be inside a
/// function returning `Result`.
macro_rules! dispatch_pattern_array {
    ($array:expr, $func:expr) => {
        match $array.data_type() {
            DataType::Utf8 => $func($array.as_string::<i32>()),
            DataType::LargeUtf8 => $func($array.as_string::<i64>()),
            DataType::Utf8View => $func($array.as_string_view()),
            other => {
                return exec_err!(
                    "regexp_split expects string array for pattern, got {}",
                    other
                )
            }
        }
    };
}
// Bring the macro into path-based (item) scope. `invoke_with_args` above uses it
// before this textual definition, which plain `macro_rules!` scoping does not allow.
use dispatch_pattern_array;

/// Compile `pattern` into a [`Regex`].
///
/// # Errors
/// A pattern the `regex` crate rejects becomes a DataFusion execution error.
/// The error message quotes the pattern and the parser's explanation.
fn compile_regex(pattern: &str) -> Result<Regex> {
    match Regex::new(pattern) {
        Ok(re) => Ok(re),
        Err(e) => {
            let message = format!("Invalid regex pattern '{}': {}", pattern, e);
            Err(datafusion::common::DataFusionError::Execution(message))
        }
    }
}

/// Split an array of strings using a single compiled regex.
fn split_array_with_regex<T>(strings: &T, re: &Regex) -> ArrayRef
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new());

    for opt_s in strings {
        match opt_s {
            None => builder.append_null(),
            Some(s) => {
                let values_builder = builder.values();
                for part in re.split(s) {
                    values_builder.append_value(part);
                }
                builder.append(true);
            }
        }
    }

    Arc::new(builder.finish())
}

/// Split arrays element-wise (different pattern per row).
fn split_arrays<S, P>(strings: &S, patterns: &P) -> Result<ArrayRef>
where
    S: Array + 'static,
    P: Array + 'static,
    for<'a> &'a S: IntoIterator<Item = Option<&'a str>>,
    for<'a> &'a P: IntoIterator<Item = Option<&'a str>>,
{
    let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new());
    let mut regex_cache: HashMap<String, Regex> = HashMap::new();

    for (opt_s, opt_p) in strings.into_iter().zip(patterns.into_iter()) {
        match (opt_s, opt_p) {
            (Some(s), Some(p)) => {
                let re = match regex_cache.entry(p.to_string()) {
                    Entry::Occupied(e) => e.into_mut(),
                    Entry::Vacant(e) => {
                        let compiled = compile_regex(p)?;
                        e.insert(compiled)
                    }
                };

                let values_builder = builder.values();
                for part in re.split(s) {
                    values_builder.append_value(part);
                }
                builder.append(true);
            }
            _ => builder.append_null(),
        }
    }

    Ok(Arc::new(builder.finish()))
}

/// Split a scalar string with an array of patterns.
fn split_scalar_with_patterns<T>(string: Option<&str>, patterns: &T) -> Result<ArrayRef>
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new());
    let mut regex_cache: HashMap<String, Regex> = HashMap::new();

    for opt_p in patterns {
        match (string, opt_p) {
            (Some(s), Some(p)) => {
                let re = match regex_cache.entry(p.to_string()) {
                    Entry::Occupied(e) => e.into_mut(),
                    Entry::Vacant(e) => {
                        let compiled = compile_regex(p)?;
                        e.insert(compiled)
                    }
                };

                let values_builder = builder.values();
                for part in re.split(s) {
                    values_builder.append_value(part);
                }
                builder.append(true);
            }
            _ => builder.append_null(),
        }
    }

    Ok(Arc::new(builder.finish()))
}

/// Create a null list array of the given length.
fn create_null_list_array(len: usize) -> Result<ArrayRef> {
    let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new());
    for _ in 0..len {
        builder.append_null();
    }
    Ok(Arc::new(builder.finish()))
}

/// Lazily initialised, process-wide instance backing [`regexp_split_udf`].
static REGEXP_SPLIT_UDF: OnceLock<ScalarUDF> = OnceLock::new();

/// Get the `hamelin_regexp_split` UDF.
///
/// The UDF is constructed on the first call. Every call returns a clone of
/// that shared instance.
pub fn regexp_split_udf() -> ScalarUDF {
    let shared =
        REGEXP_SPLIT_UDF.get_or_init(|| ScalarUDF::new_from_impl(RegexpSplitUdf::new()));
    shared.clone()
}