use crate::utils::utf8_to_str_type;
use arrow::array::{
Array, ArrayRef, AsArray, ByteView, GenericStringBuilder, Int64Array,
StringArrayType, StringLikeArrayBuilder, StringViewArray, StringViewBuilder,
make_view, new_null_array,
};
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_common::cast::as_int64_array;
use datafusion_common::types::{NativeType, logical_int64, logical_string};
use datafusion_common::{Result, exec_datafusion_err, exec_err};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility,
};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_macros::user_doc;
use memchr::memmem;
use std::sync::Arc;
// Scalar UDF implementing `split_part(str, delimiter, pos)`.
//
// The `user_doc` attribute carries the user-facing documentation for the
// function. The `doc()` method used by `ScalarUDFImpl::documentation` is
// presumably generated by this macro (NOTE(review): confirm in
// `datafusion_macros`).
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Splits a string based on a specified delimiter and returns the substring in the specified position.",
    syntax_example = "split_part(str, delimiter, pos)",
    sql_example = r#"```sql
> select split_part('1.2.3.4.5', '.', 3);
+--------------------------------------------------+
| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
+--------------------------------------------------+
| 3 |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "delimiter", description = "String or character to split on."),
    argument(
        name = "pos",
        description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SplitPartFunc {
    // Coercible argument signature built in `SplitPartFunc::new`:
    // (string, string, integer coerced to Int64).
    signature: Signature,
}
impl Default for SplitPartFunc {
fn default() -> Self {
Self::new()
}
}
impl SplitPartFunc {
    /// Creates the UDF with its coercible signature.
    ///
    /// `str` and `delimiter` must be logical strings. `pos` accepts any
    /// integer type and is implicitly coerced to `Int64`.
    pub fn new() -> Self {
        let string_arg = || Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
        let position_arg = Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            vec![TypeSignatureClass::Integer],
            NativeType::Int64,
        );
        let signature = Signature::coercible(
            vec![string_arg(), string_arg(), position_arg],
            Volatility::Immutable,
        );
        Self { signature }
    }
}
impl ScalarUDFImpl for SplitPartFunc {
fn name(&self) -> &str {
"split_part"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types[0] == DataType::Utf8View {
Ok(DataType::Utf8View)
} else {
utf8_to_str_type(&arg_types[0], "split_part")
}
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs { args, .. } = args;
if let (
ColumnarValue::Array(string_array),
ColumnarValue::Scalar(delim_scalar),
ColumnarValue::Scalar(pos_scalar),
) = (&args[0], &args[1], &args[2])
{
return split_part_scalar(string_array, delim_scalar, pos_scalar);
}
let len = args.iter().find_map(|arg| match arg {
ColumnarValue::Array(a) => Some(a.len()),
_ => None,
});
let inferred_length = len.unwrap_or(1);
let is_scalar = len.is_none();
let args = args
.iter()
.map(|arg| match arg {
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length),
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
})
.collect::<Result<Vec<_>>>()?;
let n_array = as_int64_array(&args[2])?;
macro_rules! split_part_for_delimiter_type {
($str_arr:expr, $builder:expr) => {
match args[1].data_type() {
DataType::Utf8View => split_part_impl(
$str_arr,
&args[1].as_string_view(),
n_array,
$builder,
),
DataType::Utf8 => split_part_impl(
$str_arr,
&args[1].as_string::<i32>(),
n_array,
$builder,
),
DataType::LargeUtf8 => split_part_impl(
$str_arr,
&args[1].as_string::<i64>(),
n_array,
$builder,
),
other => {
exec_err!("Unsupported delimiter type {other:?} for split_part")
}
}
};
}
let result = match args[0].data_type() {
DataType::Utf8View => split_part_for_delimiter_type!(
&args[0].as_string_view(),
StringViewBuilder::with_capacity(inferred_length)
),
DataType::Utf8 => {
let str_arr = &args[0].as_string::<i32>();
split_part_for_delimiter_type!(
str_arr,
GenericStringBuilder::<i32>::with_capacity(
inferred_length,
inferred_length,
)
)
}
DataType::LargeUtf8 => {
let str_arr = &args[0].as_string::<i64>();
split_part_for_delimiter_type!(
str_arr,
GenericStringBuilder::<i64>::with_capacity(
inferred_length,
inferred_length,
)
)
}
other => exec_err!("Unsupported string type {other:?} for split_part"),
};
if is_scalar {
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
result.map(ColumnarValue::Scalar)
} else {
result.map(ColumnarValue::Array)
}
}
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
/// Returns the `n`-th (0-based) field of `string` split on `delimiter`, or
/// `None` when there are fewer than `n + 1` fields.
///
/// A single-byte delimiter is necessarily ASCII, so it can be used as a
/// `char` pattern instead of a substring pattern.
#[inline]
fn split_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> {
    match delimiter.as_bytes() {
        [byte] => string.split(char::from(*byte)).nth(n),
        _ => string.split(delimiter).nth(n),
    }
}
/// Returns the `n`-th (0-based) field of `string` counting from the end,
/// split on `delimiter`, or `None` when there are fewer than `n + 1` fields.
///
/// A single-byte delimiter is necessarily ASCII, so it can be used as a
/// `char` pattern instead of a substring pattern.
#[inline]
fn rsplit_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> {
    match delimiter.as_bytes() {
        [byte] => string.rsplit(char::from(*byte)).nth(n),
        _ => string.rsplit(delimiter).nth(n),
    }
}
/// Fast path for `split_part` when `delimiter` and `pos` are both scalars.
///
/// Behavior:
/// - An empty input yields an empty array of the input type.
/// - A NULL delimiter or position yields an all-null array.
/// - A zero position is an error.
/// - Otherwise the work is dispatched on the string array's type.
///
/// # Errors
/// Returns an execution error if:
/// - the delimiter is not a string scalar;
/// - the position is not an `Int64` scalar;
/// - the position is zero;
/// - the input string type is unsupported.
fn split_part_scalar(
    string_array: &ArrayRef,
    delim_scalar: &ScalarValue,
    pos_scalar: &ScalarValue,
) -> Result<ColumnarValue> {
    let input_type = string_array.data_type();
    if string_array.is_empty() {
        return Ok(ColumnarValue::Array(new_null_array(input_type, 0)));
    }

    let delimiter = delim_scalar.try_as_str().ok_or_else(|| {
        exec_datafusion_err!(
            "Unsupported delimiter type {:?} for split_part",
            delim_scalar.data_type()
        )
    })?;
    let ScalarValue::Int64(position) = pos_scalar else {
        return exec_err!(
            "Unsupported position type {:?} for split_part",
            pos_scalar.data_type()
        );
    };

    // A NULL delimiter or position makes every output row NULL.
    let (Some(delimiter), Some(position)) = (delimiter, *position) else {
        return Ok(ColumnarValue::Array(new_null_array(
            input_type,
            string_array.len(),
        )));
    };
    if position == 0 {
        return exec_err!("field position must not be zero");
    }

    let array = match input_type {
        DataType::Utf8View => {
            split_part_scalar_view(string_array.as_string_view(), delimiter, position)?
        }
        DataType::Utf8 => {
            let strings = string_array.as_string::<i32>();
            let builder =
                GenericStringBuilder::<i32>::with_capacity(strings.len(), strings.len());
            split_part_scalar_impl(strings, delimiter, position, builder)?
        }
        DataType::LargeUtf8 => {
            let strings = string_array.as_string::<i64>();
            let builder =
                GenericStringBuilder::<i64>::with_capacity(strings.len(), strings.len());
            split_part_scalar_impl(strings, delimiter, position, builder)?
        }
        other => return exec_err!("Unsupported string type {other:?} for split_part"),
    };
    Ok(ColumnarValue::Array(array))
}
/// Applies a constant `delimiter` / `position` split to each row of a `Utf8`
/// or `LargeUtf8` array and writes the results through `builder`.
///
/// `position` is 1-based, and negative values count from the end. The only
/// caller (`split_part_scalar`) rejects `position == 0` first. An empty
/// delimiter makes the whole string the only field. Null rows stay null, and
/// missing fields become `""`.
///
/// # Errors
/// Returns an execution error if `|position| - 1` does not fit in `usize`.
fn split_part_scalar_impl<'a, S, B>(
    string_array: S,
    delimiter: &str,
    position: i64,
    builder: B,
) -> Result<ArrayRef>
where
    S: StringArrayType<'a> + Copy,
    B: StringLikeArrayBuilder,
{
    if delimiter.is_empty() {
        let keep_whole = matches!(position, 1 | -1);
        return if keep_whole {
            map_strings(string_array, builder, Some)
        } else {
            map_strings(string_array, builder, |_| None)
        };
    }

    let needle = delimiter.as_bytes();
    let needle_len = needle.len();

    if position > 0 {
        let fields_to_skip: usize = (position - 1).try_into().map_err(|_| {
            exec_datafusion_err!(
                "split_part index {position} exceeds maximum supported value"
            )
        })?;
        // Build the searcher once and reuse it for every row.
        let forward = memmem::Finder::new(needle);
        return map_strings(string_array, builder, |s| {
            split_nth_finder(s, &forward, needle_len, fields_to_skip)
        });
    }

    let fields_to_skip: usize = (position.unsigned_abs() - 1).try_into().map_err(|_| {
        exec_datafusion_err!(
            "split_part index {position} exceeds minimum supported value"
        )
    })?;
    let backward = memmem::FinderRev::new(needle);
    map_strings(string_array, builder, |s| {
        rsplit_nth_finder(s, &backward, needle_len, fields_to_skip)
    })
}
/// Appends `f(s)` for every non-null string `s`, using `""` when `f` returns
/// `None`. Appends a null for every null row, then finishes the builder.
#[inline]
fn map_strings<'a, S, B, F>(string_array: S, mut builder: B, f: F) -> Result<ArrayRef>
where
    S: StringArrayType<'a> + Copy,
    B: StringLikeArrayBuilder,
    F: Fn(&'a str) -> Option<&'a str>,
{
    for value in string_array.iter() {
        if let Some(s) = value {
            builder.append_value(f(s).unwrap_or(""));
        } else {
            builder.append_null();
        }
    }
    Ok(Arc::new(builder.finish()) as ArrayRef)
}
/// Returns the `n`-th (0-based) field of `string`, using a prebuilt forward
/// searcher for a delimiter of `delim_len` bytes. Returns `None` when the
/// string has fewer than `n + 1` fields.
///
/// Matches are non-overlapping and scanned left to right, like `str::split`.
#[inline]
fn split_nth_finder<'a>(
    string: &'a str,
    finder: &memmem::Finder,
    delim_len: usize,
    n: usize,
) -> Option<&'a str> {
    let haystack = string.as_bytes();
    let mut field_start = 0;
    for _ in 0..n {
        let hit = finder.find(&haystack[field_start..])?;
        field_start += hit + delim_len;
    }
    // The requested field ends at the next delimiter, or at the end of input.
    let field_end = finder
        .find(&haystack[field_start..])
        .map_or(haystack.len(), |hit| field_start + hit);
    Some(&string[field_start..field_end])
}
/// Returns the `n`-th (0-based) field of `string` counting from the end,
/// using a prebuilt reverse searcher for a delimiter of `delim_len` bytes.
/// Returns `None` when the string has fewer than `n + 1` fields.
///
/// Matches are scanned right to left, like `str::rsplit`.
#[inline]
fn rsplit_nth_finder<'a>(
    string: &'a str,
    finder: &memmem::FinderRev,
    delim_len: usize,
    n: usize,
) -> Option<&'a str> {
    let haystack = string.as_bytes();
    let mut field_end = haystack.len();
    for _ in 0..n {
        field_end = finder.rfind(&haystack[..field_end])?;
    }
    // The requested field starts right after the previous delimiter, or at 0.
    let field_start = finder
        .rfind(&haystack[..field_end])
        .map_or(0, |hit| hit + delim_len);
    Some(&string[field_start..field_end])
}
/// Constant-delimiter/position `split_part` for `StringViewArray` input.
///
/// No string bytes are copied. Each output view either points at the
/// selected field inside the input's data buffers or is an inline view for
/// short results. The input's data buffers and null buffer are reused for
/// the result.
///
/// `position` is 1-based, and negative values count from the end. Callers are
/// expected to reject `position == 0` (as `split_part_scalar` does); here it
/// would reach `position.unsigned_abs() - 1`, which underflows.
///
/// # Errors
/// Returns an execution error if `|position| - 1` does not fit in `usize`.
fn split_part_scalar_view(
    string_view_array: &StringViewArray,
    delimiter: &str,
    position: i64,
) -> Result<ArrayRef> {
    let len = string_view_array.len();
    let mut views_buf = Vec::with_capacity(len);
    let views = string_view_array.views();
    if delimiter.is_empty() {
        // An empty delimiter yields a single field: the whole string for
        // positions 1 / -1, and an empty string for any other position.
        let empty_view = make_view(b"", 0, 0);
        let return_input = position == 1 || position == -1;
        for i in 0..len {
            if string_view_array.is_null(i) {
                // Placeholder for null slots; masked by `nulls` below.
                views_buf.push(0);
            } else if return_input {
                // Reuse the original view unchanged.
                views_buf.push(views[i]);
            } else {
                views_buf.push(empty_view);
            }
        }
    } else if position > 0 {
        let idx: usize = (position - 1).try_into().map_err(|_| {
            exec_datafusion_err!(
                "split_part index {position} exceeds maximum supported value"
            )
        })?;
        let finder = memmem::Finder::new(delimiter.as_bytes());
        split_view_loop(string_view_array, views, &mut views_buf, |s| {
            split_nth_finder(s, &finder, delimiter.len(), idx)
        });
    } else {
        let idx: usize = (position.unsigned_abs() - 1).try_into().map_err(|_| {
            exec_datafusion_err!(
                "split_part index {position} exceeds minimum supported value"
            )
        })?;
        let finder_rev = memmem::FinderRev::new(delimiter.as_bytes());
        split_view_loop(string_view_array, views, &mut views_buf, |s| {
            rsplit_nth_finder(s, &finder_rev, delimiter.len(), idx)
        });
    }
    let views_buf = ScalarBuffer::from(views_buf);
    let nulls = string_view_array.nulls().cloned();
    // SAFETY: every entry of `views_buf` is one of:
    // - `0` (a zero-length inline view), pushed only for null slots;
    // - an unchanged view copied from the input;
    // - an inline view built by `make_view` from the bytes of a `&str`;
    // - a view from `substr_view` referring to a `&str` subslice of an input
    //   value, in that value's data buffer.
    // All input data buffers are passed through unchanged and in the same
    // order, so buffer indices and offsets stay in bounds. Every referenced
    // byte range comes from a `&str`, so it is valid UTF-8.
    unsafe {
        Ok(Arc::new(StringViewArray::new_unchecked(
            views_buf,
            string_view_array.data_buffers().to_vec(),
            nulls,
        )) as ArrayRef)
    }
}
/// Builds a view for `substr`, a slice of the value described by
/// `original_view` that starts `start_offset` bytes into that value.
///
/// Results of at most 12 bytes are stored inline in the view. Longer results
/// reuse the original view's buffer, shifted by `start_offset`. A substring
/// longer than 12 bytes implies the original value was also longer than 12
/// bytes, so the original view does reference a data buffer.
#[inline]
fn substr_view(original_view: &u128, substr: &str, start_offset: u32) -> u128 {
    let bytes = substr.as_bytes();
    if bytes.len() <= 12 {
        return make_view(bytes, 0, 0);
    }
    let parent = ByteView::from(*original_view);
    make_view(bytes, parent.buffer_index, parent.offset + start_offset)
}
/// Appends one view per row of `string_view_array` to `views_buf`:
/// - `0` for null rows;
/// - the field chosen by `split_fn`, re-expressed as a view into the same
///   storage;
/// - an empty inline view when `split_fn` returns `None`.
///
/// `split_fn` must return a subslice of its argument, because the field's
/// offset is computed by pointer difference.
#[inline(always)]
fn split_view_loop<F>(
    string_view_array: &StringViewArray,
    views: &[u128],
    views_buf: &mut Vec<u128>,
    split_fn: F,
) where
    F: Fn(&str) -> Option<&str>,
{
    let empty_view = make_view(b"", 0, 0);
    views_buf.extend(views.iter().enumerate().map(|(row, raw_view)| {
        if string_view_array.is_null(row) {
            return 0;
        }
        let value = string_view_array.value(row);
        split_fn(value).map_or(empty_view, |field| {
            let start_offset = field.as_ptr() as usize - value.as_ptr() as usize;
            substr_view(raw_view, field, start_offset as u32)
        })
    }));
}
/// Row-by-row `split_part` used when the delimiter or position is not a
/// scalar.
///
/// For each row, a NULL string, delimiter, or position yields NULL.
/// Otherwise the output is the selected field, or `""` when that field does
/// not exist. An empty delimiter makes the whole string the only field.
///
/// # Errors
/// Returns an execution error if:
/// - a non-null row has position 0;
/// - `|n| - 1` does not fit in `usize`.
fn split_part_impl<'a, StringArrType, DelimiterArrType, B>(
    string_array: &StringArrType,
    delimiter_array: &DelimiterArrType,
    n_array: &Int64Array,
    mut builder: B,
) -> Result<ArrayRef>
where
    StringArrType: StringArrayType<'a>,
    DelimiterArrType: StringArrayType<'a>,
    B: StringLikeArrayBuilder,
{
    let rows = string_array
        .iter()
        .zip(delimiter_array.iter())
        .zip(n_array.iter());
    for ((string, delimiter), n) in rows {
        let (Some(string), Some(delimiter), Some(n)) = (string, delimiter, n) else {
            builder.append_null();
            continue;
        };

        let field = if n > 0 {
            let idx: usize = (n - 1).try_into().map_err(|_| {
                exec_datafusion_err!(
                    "split_part index {n} exceeds maximum supported value"
                )
            })?;
            if delimiter.is_empty() {
                (n == 1).then_some(string)
            } else {
                split_nth(string, delimiter, idx)
            }
        } else if n < 0 {
            let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| {
                exec_datafusion_err!(
                    "split_part index {n} exceeds minimum supported value"
                )
            })?;
            if delimiter.is_empty() {
                (n == -1).then_some(string)
            } else {
                rsplit_nth(string, delimiter, idx)
            }
        } else {
            return exec_err!("field position must not be zero");
        };

        builder.append_value(field.unwrap_or(""));
    }
    Ok(Arc::new(builder.finish()) as ArrayRef)
}
#[cfg(test)]
mod tests {
    use arrow::array::{Array, AsArray, StringArray, StringViewArray};
    use arrow::datatypes::DataType::Utf8;
    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::split_part::SplitPartFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
            ],
            Ok(Some("def")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(20))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("ghi")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
            ],
            exec_err!("field position must not be zero"),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ],
            Ok(Some("a")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    #[test]
    fn test_split_part_stringview_sliced() -> Result<()> {
        use super::split_part_scalar_view;
        let strings: StringViewArray = vec![
            Some("skip_this.value"),
            Some("this_is_a_long_prefix.suffix"),
            Some("short.val"),
            Some("another_long_result.rest"),
            None,
        ]
        .into_iter()
        .collect();
        let sliced = strings.slice(1, 4);
        let result = split_part_scalar_view(&sliced, ".", 1)?;
        let result = result.as_string_view();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), "this_is_a_long_prefix");
        assert_eq!(result.value(1), "short");
        assert_eq!(result.value(2), "another_long_result");
        assert!(result.is_null(3));
        Ok(())
    }

    // Covers the StringView fast path beyond position 1:
    // - negative positions;
    // - a missing field;
    // - the empty-delimiter branch.
    // Values longer than 12 bytes live in data buffers, so long fields
    // exercise the buffer/offset arithmetic used to build non-inline views.
    #[test]
    fn test_split_part_stringview_negative_and_long_fields() -> Result<()> {
        use super::split_part_scalar_view;
        let strings: StringViewArray = vec![
            Some("first_long_segment_xx::second_long_segment_yy"),
            Some("a::b::c"),
            Some("no_delimiter_in_this_long_value"),
            None,
        ]
        .into_iter()
        .collect();

        let last = split_part_scalar_view(&strings, "::", -1)?;
        let last = last.as_string_view();
        assert_eq!(last.len(), 4);
        assert_eq!(last.value(0), "second_long_segment_yy");
        assert_eq!(last.value(1), "c");
        assert_eq!(last.value(2), "no_delimiter_in_this_long_value");
        assert!(last.is_null(3));

        let second_last = split_part_scalar_view(&strings, "::", -2)?;
        let second_last = second_last.as_string_view();
        assert_eq!(second_last.value(0), "first_long_segment_xx");
        assert_eq!(second_last.value(1), "b");
        assert_eq!(second_last.value(2), "");
        assert!(second_last.is_null(3));

        let second = split_part_scalar_view(&strings, "::", 2)?;
        let second = second.as_string_view();
        assert_eq!(second.value(0), "second_long_segment_yy");
        assert_eq!(second.value(1), "b");
        assert_eq!(second.value(2), "");
        assert!(second.is_null(3));

        // An empty delimiter makes the whole string the only field.
        let whole = split_part_scalar_view(&strings, "", -1)?;
        let whole = whole.as_string_view();
        assert_eq!(
            whole.value(0),
            "first_long_segment_xx::second_long_segment_yy"
        );
        assert_eq!(whole.value(1), "a::b::c");
        assert!(whole.is_null(3));

        let beyond = split_part_scalar_view(&strings, "", 2)?;
        let beyond = beyond.as_string_view();
        assert_eq!(beyond.value(0), "");
        assert_eq!(beyond.value(2), "");
        assert!(beyond.is_null(3));
        Ok(())
    }
}