use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use datafusion::arrow::array::{Array, ArrayRef, AsArray, GenericStringBuilder, ListBuilder};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use regex::Regex;
use super::string_utils::{scalar_to_str, STRING_TYPES};
/// Scalar UDF implementation backing `hamelin_regexp_extract_all`.
///
/// The struct carries no state other than the argument signature that is
/// computed once in [`RegexpExtractAllUdf::new`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct RegexpExtractAllUdf {
// Accepted argument type combinations and volatility of the function.
signature: Signature,
}
impl Default for RegexpExtractAllUdf {
/// Equivalent to [`RegexpExtractAllUdf::new`].
fn default() -> Self {
Self::new()
}
}
impl RegexpExtractAllUdf {
    /// Creates the UDF with its full argument signature.
    ///
    /// For every ordered pair of types in `STRING_TYPES` two exact signatures
    /// are registered: `(string, pattern)` and `(string, pattern, Int64)`,
    /// where the trailing `Int64` is the capture-group index. The function is
    /// declared `Immutable`.
    pub fn new() -> Self {
        let mut accepted = Vec::new();
        for input_ty in &STRING_TYPES {
            for pattern_ty in &STRING_TYPES {
                // Two-argument form first, then the same prefix extended with
                // the group index, so the registration order is preserved.
                let without_group = vec![input_ty.clone(), pattern_ty.clone()];
                let mut with_group = without_group.clone();
                with_group.push(DataType::Int64);

                accepted.push(TypeSignature::Exact(without_group));
                accepted.push(TypeSignature::Exact(with_group));
            }
        }

        let signature = Signature::new(TypeSignature::OneOf(accepted), Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for RegexpExtractAllUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is exposed.
    fn name(&self) -> &str {
        "hamelin_regexp_extract_all"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `List<Utf8>` with a nullable item field,
    /// regardless of which string types were passed in.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let item = Field::new_list_field(DataType::Utf8, true);
        Ok(DataType::List(Arc::new(item)))
    }

    /// Evaluates `regexp_extract_all(string, pattern[, group])`.
    ///
    /// * `group` must be a scalar `Int64`; when it is NULL or negative the
    ///   result is a single NULL list scalar. When omitted it defaults to 0
    ///   (the whole match).
    /// * `string` and `pattern` may each be a scalar or an array; all four
    ///   combinations are handled. A NULL string or NULL pattern yields a
    ///   NULL list for that row.
    ///
    /// # Errors
    /// Returns an execution error for a wrong argument count, a non-`Int64`
    /// or non-scalar group argument, a non-string array input, or a pattern
    /// that fails to compile.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;
        let argc = args.len();
        if !(2..=3).contains(&argc) {
            return exec_err!("regexp_extract_all expects 2 or 3 arguments, got {}", argc);
        }

        // Single NULL list scalar, used by every "no result" scalar path below.
        let null_scalar =
            || ColumnarValue::Scalar(ScalarValue::new_null_list(DataType::Utf8, true, 1));

        let group_idx = match args.get(2) {
            None => 0,
            Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(g)))) if *g >= 0 => *g as usize,
            // NULL or negative group index.
            Some(ColumnarValue::Scalar(ScalarValue::Int64(_))) => return Ok(null_scalar()),
            Some(_) => return exec_err!("regexp_extract_all group must be an integer"),
        };

        match (&args[0], &args[1]) {
            (ColumnarValue::Scalar(input), ColumnarValue::Scalar(pattern)) => {
                let (Some(text), Some(pat)) = (scalar_to_str(input)?, scalar_to_str(pattern)?)
                else {
                    return Ok(null_scalar());
                };
                let re = compile_regex(pat)?;
                let Some(found) = extract_all_matches(text, &re, group_idx) else {
                    return Ok(null_scalar());
                };
                let items: Vec<ScalarValue> = found
                    .into_iter()
                    .map(|m| ScalarValue::Utf8(Some(m)))
                    .collect();
                let list = ScalarValue::new_list(&items, &DataType::Utf8, true);
                Ok(ColumnarValue::Scalar(ScalarValue::List(list)))
            }
            (ColumnarValue::Array(inputs), ColumnarValue::Scalar(pattern)) => {
                let Some(pat) = scalar_to_str(pattern)? else {
                    // NULL pattern: every row becomes a NULL list.
                    let nulls = create_null_list_array(inputs.len())?;
                    return Ok(ColumnarValue::Array(nulls));
                };
                // One pattern for the whole column, so compile it once.
                let re = compile_regex(pat)?;
                let extracted = match inputs.data_type() {
                    DataType::Utf8 => {
                        extract_array_with_regex(inputs.as_string::<i32>(), &re, group_idx)
                    }
                    DataType::LargeUtf8 => {
                        extract_array_with_regex(inputs.as_string::<i64>(), &re, group_idx)
                    }
                    DataType::Utf8View => {
                        extract_array_with_regex(inputs.as_string_view(), &re, group_idx)
                    }
                    other => {
                        return exec_err!(
                            "regexp_extract_all expects string array, got {}",
                            other
                        )
                    }
                };
                Ok(ColumnarValue::Array(extracted))
            }
            (ColumnarValue::Array(inputs), ColumnarValue::Array(pattern_arr)) => {
                // Outer match picks the concrete string array type, the macro
                // picks the concrete pattern array type.
                let extracted = match inputs.data_type() {
                    DataType::Utf8 => {
                        let strings = inputs.as_string::<i32>();
                        dispatch_pattern_array!(pattern_arr, |patterns| extract_arrays(
                            strings, patterns, group_idx
                        ))
                    }
                    DataType::LargeUtf8 => {
                        let strings = inputs.as_string::<i64>();
                        dispatch_pattern_array!(pattern_arr, |patterns| extract_arrays(
                            strings, patterns, group_idx
                        ))
                    }
                    DataType::Utf8View => {
                        let strings = inputs.as_string_view();
                        dispatch_pattern_array!(pattern_arr, |patterns| extract_arrays(
                            strings, patterns, group_idx
                        ))
                    }
                    other => {
                        return exec_err!("regexp_extract_all expects string array, got {}", other)
                    }
                }?;
                Ok(ColumnarValue::Array(extracted))
            }
            (ColumnarValue::Scalar(input), ColumnarValue::Array(pattern_arr)) => {
                let text = scalar_to_str(input)?;
                let extracted = match pattern_arr.data_type() {
                    DataType::Utf8 => extract_scalar_with_patterns(
                        text,
                        pattern_arr.as_string::<i32>(),
                        group_idx,
                    ),
                    DataType::LargeUtf8 => extract_scalar_with_patterns(
                        text,
                        pattern_arr.as_string::<i64>(),
                        group_idx,
                    ),
                    DataType::Utf8View => extract_scalar_with_patterns(
                        text,
                        pattern_arr.as_string_view(),
                        group_idx,
                    ),
                    other => {
                        return exec_err!(
                            "regexp_extract_all expects string array for pattern, got {}",
                            other
                        )
                    }
                }?;
                Ok(ColumnarValue::Array(extracted))
            }
        }
    }
}
/// Downcasts a pattern array to its concrete string array type and applies
/// `$func` to it.
///
/// `$array` is matched on its Arrow data type (`Utf8`, `LargeUtf8` or
/// `Utf8View`); the expansion evaluates to whatever `$func` returns for the
/// downcast array. For any other data type the expansion executes
/// `return exec_err!(..)`, i.e. it returns from the *enclosing function*, so
/// the macro may only be used inside a function returning a DataFusion
/// `Result`.
macro_rules! dispatch_pattern_array {
($array:expr, $func:expr) => {
match $array.data_type() {
DataType::Utf8 => $func($array.as_string::<i32>()),
DataType::LargeUtf8 => $func($array.as_string::<i64>()),
DataType::Utf8View => $func($array.as_string_view()),
other => {
return exec_err!(
"regexp_extract_all expects string array for pattern, got {}",
other
)
}
}
};
}
// Bring the macro into the module's path-based scope; the invocations in
// `invoke_with_args` appear textually before the `macro_rules!` definition
// and therefore cannot rely on textual macro scoping.
use dispatch_pattern_array;
/// Compiles `pattern` into a [`Regex`].
///
/// # Errors
/// A pattern that fails to compile is reported as a DataFusion execution
/// error whose message contains both the pattern and the regex error.
fn compile_regex(pattern: &str) -> Result<Regex> {
    match Regex::new(pattern) {
        Ok(compiled) => Ok(compiled),
        Err(e) => {
            let message = format!("Invalid regex pattern '{}': {}", pattern, e);
            Err(datafusion::common::DataFusionError::Execution(message))
        }
    }
}
/// Collects every non-overlapping match of `re` in `s`.
///
/// * `group_idx == 0` returns the text of each whole match.
/// * `group_idx >= re.captures_len()` (the group does not exist in the
///   pattern) returns `None`.
/// * Otherwise returns the text of capture group `group_idx` for each match
///   in which that group participated; matches where it did not are skipped.
fn extract_all_matches(s: &str, re: &Regex, group_idx: usize) -> Option<Vec<String>> {
    if group_idx == 0 {
        let whole: Vec<String> = re.find_iter(s).map(|m| m.as_str().to_string()).collect();
        return Some(whole);
    }
    if group_idx >= re.captures_len() {
        return None;
    }

    let mut captured = Vec::new();
    for caps in re.captures_iter(s) {
        if let Some(group) = caps.get(group_idx) {
            captured.push(group.as_str().to_string());
        }
    }
    Some(captured)
}
/// Applies one pre-compiled regex to every element of a string array.
///
/// Produces a list array with one entry per input row: the extracted matches
/// for that row, or NULL when the input row is NULL or when
/// `extract_all_matches` yields `None` for the requested group.
fn extract_array_with_regex<T>(strings: &T, re: &Regex, group_idx: usize) -> ArrayRef
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let mut lists = ListBuilder::new(GenericStringBuilder::<i32>::new());
    for row in strings {
        let Some(found) = row.and_then(|text| extract_all_matches(text, re, group_idx)) else {
            lists.append_null();
            continue;
        };
        let values = lists.values();
        for item in found {
            values.append_value(item);
        }
        // Close the current list entry (possibly empty) as a valid row.
        lists.append(true);
    }
    Arc::new(lists.finish())
}
fn extract_arrays<S, P>(strings: &S, patterns: &P, group_idx: usize) -> Result<ArrayRef>
where
S: Array + 'static,
P: Array + 'static,
for<'a> &'a S: IntoIterator<Item = Option<&'a str>>,
for<'a> &'a P: IntoIterator<Item = Option<&'a str>>,
{
let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new());
let mut regex_cache: HashMap<String, Regex> = HashMap::new();
for (opt_s, opt_p) in strings.into_iter().zip(patterns.into_iter()) {
match (opt_s, opt_p) {
(Some(s), Some(p)) => {
let re = match regex_cache.entry(p.to_string()) {
Entry::Occupied(e) => e.into_mut(),
Entry::Vacant(e) => e.insert(compile_regex(p)?),
};
match extract_all_matches(s, re, group_idx) {
Some(matches) => {
let values_builder = builder.values();
for m in matches {
values_builder.append_value(m);
}
builder.append(true);
}
None => builder.append_null(),
}
}
_ => builder.append_null(),
}
}
Ok(Arc::new(builder.finish()))
}
fn extract_scalar_with_patterns<T>(
string: Option<&str>,
patterns: &T,
group_idx: usize,
) -> Result<ArrayRef>
where
T: Array + 'static,
for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
let mut builder = ListBuilder::new(GenericStringBuilder::<i32>::new());
let mut regex_cache: HashMap<String, Regex> = HashMap::new();
for opt_p in patterns {
match (string, opt_p) {
(Some(s), Some(p)) => {
let re = match regex_cache.entry(p.to_string()) {
Entry::Occupied(e) => e.into_mut(),
Entry::Vacant(e) => e.insert(compile_regex(p)?),
};
match extract_all_matches(s, re, group_idx) {
Some(matches) => {
let values_builder = builder.values();
for m in matches {
values_builder.append_value(m);
}
builder.append(true);
}
None => builder.append_null(),
}
}
_ => builder.append_null(),
}
}
Ok(Arc::new(builder.finish()))
}
/// Builds a list-of-string array of length `len` in which every entry is NULL.
fn create_null_list_array(len: usize) -> Result<ArrayRef> {
    let mut nulls = ListBuilder::new(GenericStringBuilder::<i32>::new());
    (0..len).for_each(|_| nulls.append_null());
    let array: ArrayRef = Arc::new(nulls.finish());
    Ok(array)
}
/// Lazily-initialised shared instance of the UDF.
static REGEXP_EXTRACT_ALL_UDF: OnceLock<ScalarUDF> = OnceLock::new();

/// Returns the `hamelin_regexp_extract_all` UDF.
///
/// The underlying `ScalarUDF` is constructed on first use and stored in a
/// process-wide `OnceLock`; every call hands back a clone of that instance.
pub fn regexp_extract_all_udf() -> ScalarUDF {
    let build = || ScalarUDF::new_from_impl(RegexpExtractAllUdf::new());
    let shared: &ScalarUDF = REGEXP_EXTRACT_ALL_UDF.get_or_init(build);
    shared.clone()
}