use std::any::Any;
use std::sync::Arc;
use arrow::array::{
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
GenericStringBuilder, OffsetSizeTrait, PrimitiveArray,
};
use arrow::datatypes::{DataType, Int32Type, Int64Type};
use crate::utils::{make_scalar_function, utf8_to_str_type};
use datafusion_common::{Result, exec_err, utils::take_function_args};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
// User-facing documentation for `substr_index`. `documentation()` below
// returns `self.doc()`, which is not defined in this file; presumably it is
// generated from this attribute by `#[user_doc]` — confirm in datafusion_macros.
#[user_doc(
doc_section(label = "String Functions"),
description = r#"Returns the substring from str before count occurrences of the delimiter delim.
If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#,
syntax_example = "substr_index(str, delim, count)",
sql_example = r#"```sql
> select substr_index('www.apache.org', '.', 1);
+---------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) |
+---------------------------------------------------------+
| www |
+---------------------------------------------------------+
> select substr_index('www.apache.org', '.', -1);
+----------------------------------------------------------+
| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) |
+----------------------------------------------------------+
| org |
+----------------------------------------------------------+
```"#,
standard_argument(name = "str", prefix = "String"),
argument(
name = "delim",
description = "The string to find in str to split str."
),
argument(
name = "count",
description = "The number of times to search for the delimiter. Can be either a positive or negative number."
)
)]
/// Scalar UDF `substr_index(str, delim, count)` (alias `substring_index`).
///
/// Returns the part of `str` before the `count`-th occurrence of `delim`
/// (counting from the left) when `count > 0`, or after the `|count|`-th
/// occurrence (counting from the right) when `count < 0`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SubstrIndexFunc {
/// Accepted argument types: `(string, string, Int64)` where both strings
/// share one of `Utf8View`, `Utf8` or `LargeUtf8`; built in `new`.
signature: Signature,
/// Alternative SQL names for this function (`substring_index`).
aliases: Vec<String>,
}
impl Default for SubstrIndexFunc {
fn default() -> Self {
Self::new()
}
}
impl SubstrIndexFunc {
    /// Builds the UDF with an immutable signature accepting
    /// `(string, delimiter, Int64 count)`.
    ///
    /// The string and delimiter must share the same string type. Candidates
    /// are listed in the order `Utf8View`, `Utf8`, `LargeUtf8`.
    pub fn new() -> Self {
        use DataType::*;

        // One exact signature per string encoding; the order of this array is
        // the order in which `one_of` considers them.
        let type_signatures = [Utf8View, Utf8, LargeUtf8]
            .into_iter()
            .map(|string_type| Exact(vec![string_type.clone(), string_type, Int64]))
            .collect();

        Self {
            aliases: vec!["substring_index".to_owned()],
            signature: Signature::one_of(type_signatures, Volatility::Immutable),
        }
    }
}
impl ScalarUDFImpl for SubstrIndexFunc {
    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Primary SQL name of the function.
    fn name(&self) -> &str {
        "substr_index"
    }

    /// Extra SQL names, as configured in `new`.
    fn aliases(&self) -> &[String] {
        &self.aliases
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type is derived from the first (string) argument's type via
    /// `utf8_to_str_type`.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let string_type = &arg_types[0];
        utf8_to_str_type(string_type, "substr_index")
    }

    /// Delegates to the array kernel `substr_index` through
    /// `make_scalar_function` (with no argument hints).
    fn invoke_with_args(
        &self,
        args: datafusion_expr::ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        let datafusion_expr::ScalarFunctionArgs {
            args: arguments, ..
        } = args;
        let kernel = make_scalar_function(substr_index, vec![]);
        kernel(&arguments)
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Array-level kernel for `substr_index`.
///
/// Expects exactly three arrays: `[str, delim, count]`. `str` and `delim` must
/// have the same string type (`Utf8`, `LargeUtf8` or `Utf8View`) and `count`
/// must be `Int64`. The UDF signature is relied upon to guarantee this; the
/// downcasts below panic on a mismatch. `Utf8View` input is emitted with
/// 32-bit offsets, the same as `Utf8`.
///
/// Returns an execution error for any other string type.
fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
    let [strings, delimiters, counts] = take_function_args("substr_index", args)?;

    match strings.data_type() {
        DataType::Utf8 => substr_index_general::<Int32Type, _, _>(
            strings.as_string::<i32>(),
            delimiters.as_string::<i32>(),
            counts.as_primitive::<Int64Type>(),
        ),
        DataType::LargeUtf8 => substr_index_general::<Int64Type, _, _>(
            strings.as_string::<i64>(),
            delimiters.as_string::<i64>(),
            counts.as_primitive::<Int64Type>(),
        ),
        DataType::Utf8View => substr_index_general::<Int32Type, _, _>(
            strings.as_string_view(),
            delimiters.as_string_view(),
            counts.as_primitive::<Int64Type>(),
        ),
        other => {
            exec_err!("Unsupported data type {other:?} for function substr_index")
        }
    }
}
/// Computes `substr_index` for a single row whose three inputs are non-null.
///
/// * `n > 0`: everything before the `n`-th occurrence of `delimiter`, counting
///   from the left.
/// * `n < 0`: everything after the `|n|`-th occurrence of `delimiter`, counting
///   from the right.
/// * `n == 0`, an empty `string`, or an empty `delimiter`: `""`.
///
/// If `delimiter` occurs fewer than `|n|` times, `string` is returned whole.
/// Occurrences are non-overlapping, as found by `match_indices` (left) or
/// `rmatch_indices` (right). The result always borrows from `string`.
fn substr_index_value<'s>(string: &'s str, delimiter: &str, n: i64) -> &'s str {
    if n == 0 || string.is_empty() || delimiter.is_empty() {
        return "";
    }

    // `n != 0` here, so `occurrences >= 1` and `occurrences - 1` cannot
    // underflow. `unsigned_abs` avoids the overflow `abs` would hit on
    // `i64::MIN`. Where `usize` is narrower than 64 bits, a huge count
    // saturates, which is harmless because no string can hold that many
    // delimiters.
    let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
    let skip = occurrences - 1;

    let cut = if delimiter.len() == 1 {
        // Fast path. A one-byte `str` is ASCII, and an ASCII byte inside valid
        // UTF-8 can only be that character itself. So scanning raw bytes is
        // safe, and every hit lies on a char boundary.
        let d_byte = delimiter.as_bytes()[0];
        let mut hits = string
            .bytes()
            .enumerate()
            .filter(|&(_, b)| b == d_byte)
            .map(|(idx, _)| idx);
        if n > 0 {
            hits.nth(skip)
        } else {
            // The cut point is just past the one-byte delimiter.
            hits.rev().nth(skip).map(|idx| idx + 1)
        }
    } else if n > 0 {
        string.match_indices(delimiter).nth(skip).map(|(idx, _)| idx)
    } else {
        string
            .rmatch_indices(delimiter)
            .nth(skip)
            .map(|(idx, _)| idx + delimiter.len())
    };

    match cut {
        Some(idx) if n > 0 => &string[..idx],
        Some(idx) => &string[idx..],
        None => string,
    }
}

/// Applies `substr_index` row by row over already-downcast input arrays.
///
/// `T::Native` is the offset type of the output string builder. Callers pass
/// `Int32Type` for `Utf8`/`Utf8View` input and `Int64Type` for `LargeUtf8`.
/// An output row is null whenever any of its three inputs is null. Otherwise
/// it is computed by [`substr_index_value`].
fn substr_index_general<
    'a,
    T: ArrowPrimitiveType,
    V: ArrayAccessor<Item = &'a str>,
    P: ArrayAccessor<Item = i64>,
>(
    string_array: V,
    delimiter_array: V,
    count_array: P,
) -> Result<ArrayRef>
where
    T::Native: OffsetSizeTrait,
{
    let num_rows = string_array.len();
    let mut builder = GenericStringBuilder::<T::Native>::with_capacity(num_rows, 0);

    let rows = ArrayIter::new(string_array)
        .zip(ArrayIter::new(delimiter_array))
        .zip(ArrayIter::new(count_array));
    for ((string, delimiter), n) in rows {
        match (string, delimiter, n) {
            (Some(string), Some(delimiter), Some(n)) => {
                builder.append_value(substr_index_value(string, delimiter, n))
            }
            _ => builder.append_null(),
        }
    }

    Ok(Arc::new(builder.finish()) as ArrayRef)
}
#[cfg(test)]
mod tests {
    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substrindex::SubstrIndexFunc;
    use crate::utils::test::test_function;

    /// `Utf8` scalar cases as `(string, delimiter, count, expected)`.
    ///
    /// Covers both the single-byte-delimiter fast path and the general
    /// `match_indices`/`rmatch_indices` path, in both directions.
    #[test]
    fn test_functions() -> Result<()> {
        let cases: [(&str, &str, i64, &str); 17] = [
            ("www.apache.org", ".", 1, "www"),
            ("www.apache.org", ".", 2, "www.apache"),
            ("www.apache.org", ".", -2, "apache.org"),
            ("www.apache.org", ".", -1, "org"),
            // Zero count, empty string and empty delimiter all yield "".
            ("www.apache.org", ".", 0, ""),
            ("", ".", 1, ""),
            ("www.apache.org", "", 1, ""),
            // Fewer occurrences than requested: the whole string.
            ("www.apache.org", ".", 3, "www.apache.org"),
            ("www.apache.org", ".", -3, "www.apache.org"),
            ("www.apache.org", "#", 1, "www.apache.org"),
            // Extreme counts must not overflow.
            ("www.apache.org", ".", i64::MAX, "www.apache.org"),
            ("www.apache.org", ".", i64::MIN, "www.apache.org"),
            // Multi-character delimiter (general path).
            ("a||b||c", "||", 2, "a||b"),
            ("a||b||c", "||", -2, "b||c"),
            // Multi-byte UTF-8 delimiter.
            ("a€b€c", "€", 1, "a"),
            ("a€b€c", "€", -1, "c"),
            // Single-byte delimiter between multi-byte characters.
            ("é.ü.x", ".", -2, "ü.x"),
        ];

        for (string, delimiter, count, want) in cases {
            test_function!(
                SubstrIndexFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::from(string)),
                    ColumnarValue::Scalar(ScalarValue::from(delimiter)),
                    ColumnarValue::Scalar(ScalarValue::from(count)),
                ],
                Ok(Some(want)),
                &str,
                Utf8,
                StringArray
            );
        }

        Ok(())
    }

    /// A null in any argument produces a null result.
    #[test]
    fn test_nulls() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
                ColumnarValue::Scalar(ScalarValue::from(".")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }

    /// Exercises the `LargeUtf8` and `Utf8View` dispatch arms.
    #[test]
    fn test_other_string_types() -> Result<()> {
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                    "www.apache.org".to_string()
                ))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(".".to_string()))),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("www.apache")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        // `Utf8View` input is returned as `Utf8`.
        test_function!(
            SubstrIndexFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                    "www.apache.org".to_string()
                ))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".to_string()))),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            Ok(Some("org")),
            &str,
            Utf8,
            StringArray
        );
        Ok(())
    }
}