use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use datafusion::arrow::array::{Array, ArrayRef, AsArray, GenericStringBuilder, ListBuilder};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use regex::Regex;
use super::string_utils::{scalar_to_str, STRING_TYPES};
/// Scalar UDF behind `hamelin_regexp_split`: splits a string on every match of
/// a regular expression and returns the pieces as a list of strings.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RegexpSplitUdf {
    /// Accepted `(string, pattern)` argument types, built by `RegexpSplitUdf::new`.
    signature: Signature,
}
impl Default for RegexpSplitUdf {
fn default() -> Self {
Self::new()
}
}
impl RegexpSplitUdf {
pub fn new() -> Self {
let mut sigs = Vec::new();
for s1 in &STRING_TYPES {
for s2 in &STRING_TYPES {
sigs.push(TypeSignature::Exact(vec![s1.clone(), s2.clone()]));
}
}
Self {
signature: Signature::new(TypeSignature::OneOf(sigs), Volatility::Immutable),
}
}
}
/// DataFusion scalar-UDF implementation of `hamelin_regexp_split`.
///
/// Splits a string on every match of a regex pattern and returns the pieces
/// as a `List<Utf8>`. Either argument may independently be a scalar or an
/// array.
impl ScalarUDFImpl for RegexpSplitUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Function name reported to DataFusion.
    fn name(&self) -> &str {
        "hamelin_regexp_split"
    }

    /// The `(string, pattern)` signature built in `RegexpSplitUdf::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always `List<Utf8>` with nullable items; the input types are ignored.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::List(Arc::new(Field::new_list_field(
            DataType::Utf8,
            true,
        ))))
    }

    /// Splits `args[0]` (the string) on every match of the regex `args[1]`.
    ///
    /// The output shape follows the inputs. Two scalars yield a scalar list;
    /// if either argument is an array, the result is a list array. A null
    /// string or a null pattern yields a null list for that row.
    ///
    /// # Errors
    ///
    /// Returns an execution error if:
    /// - the argument count is not 2;
    /// - an array argument is not `Utf8`, `LargeUtf8` or `Utf8View`;
    /// - a non-null pattern fails to compile (via `compile_regex`).
    ///
    /// Errors from `scalar_to_str` are propagated unchanged.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;
        // Defensive check: the declared signature only lists two-argument forms.
        if args.len() != 2 {
            return exec_err!(
                "regexp_split expects exactly 2 arguments, got {}",
                args.len()
            );
        }
        match (&args[0], &args[1]) {
            // Both inputs constant: produce a single list scalar.
            (ColumnarValue::Scalar(string_val), ColumnarValue::Scalar(pattern_val)) => {
                let string = scalar_to_str(string_val)?;
                let pattern = scalar_to_str(pattern_val)?;
                match (string, pattern) {
                    (Some(s), Some(p)) => {
                        let re = compile_regex(p)?;
                        // Collect the pieces as Utf8 scalars, then pack them into a list scalar.
                        let parts: Vec<&str> = re.split(s).collect();
                        let scalars: Vec<ScalarValue> = parts
                            .into_iter()
                            .map(|s| ScalarValue::Utf8(Some(s.to_string())))
                            .collect();
                        Ok(ColumnarValue::Scalar(ScalarValue::List(
                            ScalarValue::new_list(&scalars, &DataType::Utf8, true),
                        )))
                    }
                    // Either input is null: return a null list scalar.
                    _ => Ok(ColumnarValue::Scalar(ScalarValue::new_null_list(
                        DataType::Utf8,
                        true,
                        1,
                    ))),
                }
            }
            // Per-row strings, constant pattern: compile once and reuse it for every row.
            (ColumnarValue::Array(string_arr), ColumnarValue::Scalar(pattern_val)) => {
                let pattern = scalar_to_str(pattern_val)?;
                match pattern {
                    Some(p) => {
                        let re = compile_regex(p)?;
                        let result = match string_arr.data_type() {
                            DataType::Utf8 => {
                                split_array_with_regex(string_arr.as_string::<i32>(), &re)
                            }
                            DataType::LargeUtf8 => {
                                split_array_with_regex(string_arr.as_string::<i64>(), &re)
                            }
                            DataType::Utf8View => {
                                split_array_with_regex(string_arr.as_string_view(), &re)
                            }
                            other => {
                                return exec_err!(
                                    "regexp_split expects string array, got {}",
                                    other
                                )
                            }
                        };
                        Ok(ColumnarValue::Array(result))
                    }
                    // Null pattern: every output row is null.
                    None => {
                        let result = create_null_list_array(string_arr.len())?;
                        Ok(ColumnarValue::Array(result))
                    }
                }
            }
            // Per-row strings and per-row patterns.
            // `dispatch_pattern_array!` (defined further down in this file) picks the
            // concrete pattern array type. For a non-string pattern array it executes
            // `return exec_err!(...)`, which exits this function directly.
            // The trailing `?` propagates errors from `split_arrays`, e.g. an invalid regex.
            (ColumnarValue::Array(string_arr), ColumnarValue::Array(pattern_arr)) => {
                let result = match string_arr.data_type() {
                    DataType::Utf8 => {
                        let strings = string_arr.as_string::<i32>();
                        dispatch_pattern_array!(pattern_arr, |patterns| split_arrays(
                            strings, patterns
                        ))
                    }
                    DataType::LargeUtf8 => {
                        let strings = string_arr.as_string::<i64>();
                        dispatch_pattern_array!(pattern_arr, |patterns| split_arrays(
                            strings, patterns
                        ))
                    }
                    DataType::Utf8View => {
                        let strings = string_arr.as_string_view();
                        dispatch_pattern_array!(pattern_arr, |patterns| split_arrays(
                            strings, patterns
                        ))
                    }
                    other => return exec_err!("regexp_split expects string array, got {}", other),
                }?;
                Ok(ColumnarValue::Array(result))
            }
            // Constant string, per-row patterns.
            (ColumnarValue::Scalar(string_val), ColumnarValue::Array(pattern_arr)) => {
                let string = scalar_to_str(string_val)?;
                let result = match pattern_arr.data_type() {
                    DataType::Utf8 => {
                        split_scalar_with_patterns(string, pattern_arr.as_string::<i32>())
                    }
                    DataType::LargeUtf8 => {
                        split_scalar_with_patterns(string, pattern_arr.as_string::<i64>())
                    }
                    DataType::Utf8View => {
                        split_scalar_with_patterns(string, pattern_arr.as_string_view())
                    }
                    other => {
                        return exec_err!(
                            "regexp_split expects string array for pattern, got {}",
                            other
                        )
                    }
                }?;
                Ok(ColumnarValue::Array(result))
            }
        }
    }
}
/// Calls `$func` with `$array` downcast to its concrete string-array type.
///
/// `$func` is expanded separately in every arm, so a closure literal passed
/// here is type-checked independently for the `i32`-offset, `i64`-offset and
/// view string arrays. Do not hoist it into a single binding.
///
/// Any other data type makes the *enclosing function* return an execution
/// error via `return exec_err!(...)`.
macro_rules! dispatch_pattern_array {
    ($array:expr, $func:expr) => {
        match $array.data_type() {
            DataType::Utf8View => $func($array.as_string_view()),
            DataType::LargeUtf8 => $func($array.as_string::<i64>()),
            DataType::Utf8 => $func($array.as_string::<i32>()),
            unsupported => {
                return exec_err!(
                    "regexp_split expects string array for pattern, got {}",
                    unsupported
                )
            }
        }
    };
}
use dispatch_pattern_array;
/// Compiles `pattern` into a [`Regex`].
///
/// # Errors
///
/// A syntax error (or any other `regex` build failure) becomes a DataFusion
/// execution error whose message quotes the offending pattern.
fn compile_regex(pattern: &str) -> Result<Regex> {
    let compiled = Regex::new(pattern);
    compiled.map_err(|err| {
        let message = format!("Invalid regex pattern '{pattern}': {err}");
        datafusion::common::DataFusionError::Execution(message)
    })
}
/// Splits every row of `strings` on the single, pre-compiled regex `re`.
///
/// Returns a `List<Utf8>` array with one entry per input row. A null input
/// row produces a null list; a non-null row produces the pieces yielded by
/// [`Regex::split`].
fn split_array_with_regex<T>(strings: &T, re: &Regex) -> ArrayRef
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let mut lists = ListBuilder::new(GenericStringBuilder::<i32>::new());
    for row in strings {
        let Some(text) = row else {
            lists.append_null();
            continue;
        };
        let items = lists.values();
        for piece in re.split(text) {
            items.append_value(piece);
        }
        lists.append(true);
    }
    Arc::new(lists.finish())
}
fn split_arrays<S, P>(strings: &S, patterns: &P) -> Result<ArrayRef>
where
S: Array + 'static,
P: Array + 'static,
for<'a> &'a S: IntoIterator<Item = Option<&'a str>>,
for<'a> &'a P: IntoIterator<Item = Option<&'a str>>,
{
let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new());
let mut regex_cache: HashMap<String, Regex> = HashMap::new();
for (opt_s, opt_p) in strings.into_iter().zip(patterns.into_iter()) {
match (opt_s, opt_p) {
(Some(s), Some(p)) => {
let re = match regex_cache.entry(p.to_string()) {
Entry::Occupied(e) => e.into_mut(),
Entry::Vacant(e) => {
let compiled = compile_regex(p)?;
e.insert(compiled)
}
};
let values_builder = builder.values();
for part in re.split(s) {
values_builder.append_value(part);
}
builder.append(true);
}
_ => builder.append_null(),
}
}
Ok(Arc::new(builder.finish()))
}
/// Splits one constant `string` on each regex in `patterns`.
///
/// Returns a `List<Utf8>` array with one entry per pattern row. A row is null
/// when the string is null or that row's pattern is null. Compiled regexes
/// are cached by pattern text.
///
/// # Errors
///
/// Returns an execution error if a non-null pattern is not a valid regex.
/// Patterns are only compiled when `string` is non-null.
fn split_scalar_with_patterns<T>(string: Option<&str>, patterns: &T) -> Result<ArrayRef>
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let mut out = ListBuilder::new(GenericStringBuilder::<i32>::new());

    // Null haystack: every row is null and no pattern needs compiling.
    let Some(haystack) = string else {
        for _ in patterns {
            out.append_null();
        }
        return Ok(Arc::new(out.finish()));
    };

    let mut compiled: HashMap<String, Regex> = HashMap::new();
    for maybe_pattern in patterns {
        let Some(pattern) = maybe_pattern else {
            out.append_null();
            continue;
        };
        let re = match compiled.entry(pattern.to_string()) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(slot) => slot.insert(compile_regex(pattern)?),
        };
        let items = out.values();
        for piece in re.split(haystack) {
            items.append_value(piece);
        }
        out.append(true);
    }
    Ok(Arc::new(out.finish()))
}
/// Builds a `List<Utf8>` array of `len` rows, every one of which is null.
fn create_null_list_array(len: usize) -> Result<ArrayRef> {
    let mut nulls = ListBuilder::new(GenericStringBuilder::<i32>::new());
    (0..len).for_each(|_| nulls.append_null());
    let array: ArrayRef = Arc::new(nulls.finish());
    Ok(array)
}
static REGEXP_SPLIT_UDF: OnceLock<ScalarUDF> = OnceLock::new();

/// Returns the `hamelin_regexp_split` [`ScalarUDF`].
///
/// The UDF is built on the first call and stored in a `static` `OnceLock`.
/// Every call returns a clone of that stored instance.
pub fn regexp_split_udf() -> ScalarUDF {
    let shared = REGEXP_SPLIT_UDF.get_or_init(|| {
        let udf_impl = RegexpSplitUdf::new();
        ScalarUDF::new_from_impl(udf_impl)
    });
    shared.clone()
}