use std::sync::Arc;
use crate::strings::{StringViewArrayBuilder, append_view};
use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait,
StringArrayType, StringViewArray, make_view,
};
use arrow::buffer::{NullBuffer, ScalarBuffer};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_int64_array;
use datafusion_common::types::{
NativeType, logical_int32, logical_int64, logical_string,
};
use datafusion_common::{Result, exec_err};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Extracts a substring of a specified number of characters from a specific starting position in a string.",
    syntax_example = "substr(str, start_pos[, length])",
    alternative_syntax = "substring(str from start_pos for length)",
    sql_example = r#"```sql
> select substr('datafusion', 5, 3);
+----------------------------------------------+
| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
+----------------------------------------------+
| fus |
+----------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "start_pos",
        description = "Character position to start the substring at. The first character in the string has a position of 1. If the start position is less than 1, it is treated as if it is before the start of the string and the (absolute) number of characters before position 1 is subtracted from `length` (if given). For example, `substr('abc', -3, 6)` returns `'ab'`."
    ),
    argument(
        name = "length",
        description = "Number of characters to extract. If not specified, returns the rest of the string after the start position."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF implementing SQL `substr` (also exposed as `substring`).
///
/// The user-facing SQL documentation lives in the `#[user_doc]` attribute above;
/// `documentation()` returns it via `self.doc()`.
pub struct SubstrFunc {
    /// Accepted argument lists: `(string, int64)` or `(string, int64, int64)`,
    /// with `Int32` position/length arguments implicitly coerced to `Int64`.
    signature: Signature,
    /// Alternative names returned by `aliases()` (set to `substring` in `new`).
    aliases: Vec<String>,
}
impl Default for SubstrFunc {
fn default() -> Self {
Self::new()
}
}
impl SubstrFunc {
    /// Creates the `substr` UDF.
    ///
    /// Two call shapes are accepted: `substr(str, start_pos)` and
    /// `substr(str, start_pos, length)`. `str` must be a string type. The position and
    /// length arguments are `Int64`, with `Int32` inputs implicitly coerced. The
    /// function is also registered under the alias `substring`.
    pub fn new() -> Self {
        let string_arg = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
        let int_arg = Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            vec![TypeSignatureClass::Native(logical_int32())],
            NativeType::Int64,
        );

        // (str, start_pos)
        let without_length =
            TypeSignature::Coercible(vec![string_arg.clone(), int_arg.clone()]);
        // (str, start_pos, length)
        let with_length =
            TypeSignature::Coercible(vec![string_arg, int_arg.clone(), int_arg]);

        let parameter_names = Vec::from(["str", "start_pos", "length"].map(String::from));
        let signature =
            Signature::one_of(vec![without_length, with_length], Volatility::Immutable)
                .with_parameter_names(parameter_names)
                .expect("valid parameter names");

        Self {
            signature,
            aliases: vec!["substring".to_string()],
        }
    }
}
impl ScalarUDFImpl for SubstrFunc {
    fn name(&self) -> &str {
        "substr"
    }

    fn aliases(&self) -> &[String] {
        self.aliases.as_slice()
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Every kernel produces string views, so the result type is `Utf8View`
    /// regardless of which string type was passed in.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8View)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // Adapts the array-level `substr` kernel to scalar/array `ColumnarValue` inputs.
        let kernel = make_scalar_function(substr, vec![]);
        kernel(&args.args)
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
}
/// Array-level entry point for `substr`: dispatches on the physical type of the
/// string argument.
///
/// `args[0]` is the string array. `args[1..]` holds the start positions and, optionally,
/// the character counts, and is forwarded unchanged to the type-specific kernel. The
/// result is always a `Utf8View` array.
///
/// # Errors
/// Returns an execution error if `args[0]` is not `Utf8`, `LargeUtf8` or `Utf8View`,
/// and propagates any error returned by the kernels.
fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
    let string_arg = &args[0];
    let position_args = &args[1..];
    match string_arg.data_type() {
        DataType::Utf8 => {
            generic_string_substr(string_arg.as_string::<i32>(), position_args)
        }
        DataType::LargeUtf8 => {
            generic_string_substr(string_arg.as_string::<i64>(), position_args)
        }
        DataType::Utf8View => {
            string_view_substr(string_arg.as_string_view(), position_args)
        }
        other => exec_err!(
            // A `\` line continuation drops the next line's leading whitespace, so the
            // space must precede the backslash for the message to read "substr, expected".
            "Unsupported data type {other:?} for function substr, \
             expected Utf8View, Utf8 or LargeUtf8."
        ),
    }
}
pub fn get_true_start_end(
input: &str,
start: i64,
count: Option<i64>,
is_input_ascii_only: bool,
) -> Result<(usize, usize)> {
if let Some(count) = count
&& count < 0
{
return exec_err!("negative count not allowed: {count}");
}
let Some(start) = start.checked_sub(1) else {
return exec_err!("start position overflow: {start}");
};
let end = match count {
Some(count) => start.saturating_add(count),
None => input.len() as i64,
};
let start = start.clamp(0, input.len() as i64) as usize;
let end = end.clamp(0, input.len() as i64) as usize;
if is_input_ascii_only {
return Ok((start, end));
}
let mut byte_start = input.len();
let mut byte_end = input.len();
for (char_idx, (byte_idx, _)) in input.char_indices().enumerate() {
if char_idx == start {
byte_start = byte_idx;
if count.is_none() {
break;
}
}
if char_idx == end {
byte_end = byte_idx;
break;
}
}
Ok((byte_start, byte_end))
}
/// Decides whether `get_true_start_end` may treat every string in `string_array` as
/// ASCII, so that character positions can be used directly as byte offsets.
///
/// `is_ascii` scans the whole array, which only pays off when rows extract long
/// prefixes. When a `count` array is present, the first `N_SAMPLE` rows are sampled
/// (null start/count treated as 0). If the average `start + count` over the rows actually
/// sampled is at most `SHORT_PREFIX_THRESHOLD`, the scan is skipped and `false` is
/// returned, so the per-row character walk is used.
///
/// Either answer produces the same substrings. A `true` result is only ever returned
/// after `is_ascii` confirmed it, so this decision affects performance only.
pub fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
    string_array: &V,
    start: &Int64Array,
    count: Option<&Int64Array>,
) -> bool {
    const SHORT_PREFIX_THRESHOLD: f64 = 32.0;
    const N_SAMPLE: usize = 10;

    let is_short_prefix = match count {
        Some(count) => {
            // Count the rows really sampled: averaging over N_SAMPLE when the array is
            // shorter would underestimate the prefix length and wrongly skip the
            // fast path.
            let (n_sampled, total_prefix_len) = start
                .iter()
                .zip(count.iter())
                .take(N_SAMPLE)
                .map(|(start, count)| start.unwrap_or(0).saturating_add(count.unwrap_or(0)))
                .fold((0usize, 0i64), |(n, total), len| {
                    (n + 1, total.saturating_add(len))
                });
            // An empty sample is not a short prefix; is_ascii() is then trivially cheap.
            n_sampled > 0
                && (total_prefix_len as f64 / n_sampled as f64) <= SHORT_PREFIX_THRESHOLD
        }
        None => false,
    };

    !is_short_prefix && string_array.is_ascii()
}
/// Computes `substr` for `Utf8View` input without copying any string bytes.
///
/// Each output view is derived from the matching input view via `append_view`, and the
/// result shares the input's data buffers. `args[0]` holds the `Int64` start positions
/// and the optional `args[1]` the `Int64` character counts. A row is null when its
/// string, start, or count is null.
fn string_view_substr(
    string_view_array: &StringViewArray,
    args: &[ArrayRef],
) -> Result<ArrayRef> {
    let start_array = as_int64_array(&args[0])?;
    let count_array_opt = match args.get(1) {
        Some(counts) => Some(as_int64_array(counts)?),
        None => None,
    };
    let ascii_only = enable_ascii_fast_path(&string_view_array, start_array, count_array_opt);
    let nulls = NullBuffer::union_many([
        string_view_array.nulls(),
        start_array.nulls(),
        count_array_opt.and_then(|counts| counts.nulls()),
    ]);

    let mut views: Vec<u128> = Vec::with_capacity(string_view_array.len());
    for (row, raw_view) in string_view_array.views().iter().enumerate() {
        let row_is_null = nulls.as_ref().is_some_and(|n| n.is_null(row));
        if row_is_null {
            // Zero-length placeholder view; the row is masked by `nulls`.
            views.push(0);
            continue;
        }
        let value = string_view_array.value(row);
        let (byte_start, byte_end) = get_true_start_end(
            value,
            start_array.value(row),
            count_array_opt.map(|counts| counts.value(row)),
            ascii_only,
        )?;
        append_view(
            &mut views,
            raw_view,
            &value[byte_start..byte_end],
            byte_start as u32,
        );
    }

    // SAFETY: null rows hold zero-length views. Every other view is produced by
    // `append_view` from a view of `string_view_array` and a sub-slice of that row's
    // value. The slice was taken at char boundaries, since `&value[..]` would have
    // panicked otherwise, so it is valid UTF-8. All data buffers are passed through
    // unchanged. NOTE(review): this relies on `append_view` re-pointing the original
    // view at `byte_start` within the same buffer; see `crate::strings::append_view`.
    let array = unsafe {
        StringViewArray::new_unchecked(
            ScalarBuffer::from(views),
            string_view_array.data_buffers().to_vec(),
            nulls,
        )
    };
    Ok(Arc::new(array) as ArrayRef)
}
/// Returns `true` when every value offset used by `string_array` is at most
/// `i32::MAX`, so each one can be stored in a 32-bit string-view offset.
///
/// Offsets are monotonic, so only the last one needs checking. An array without offsets
/// trivially fits.
fn values_fit_in_i32<T: OffsetSizeTrait>(string_array: &GenericStringArray<T>) -> bool {
    string_array
        .offsets()
        .last()
        .is_none_or(|end| end.as_usize() <= i32::MAX as usize)
}
/// Pushes a view of `substr` onto `views_buf`. Long values reference data buffer 0 at
/// `byte_offset`.
///
/// Returns `true` if the view is stored out of line, meaning it points into the buffer.
/// In that case the caller must attach the source buffer as data buffer 0.
///
/// # Panics
/// Panics if `byte_offset` does not fit in a `u32`; callers validate this beforehand.
#[inline]
fn append_view_from_buffer(
    views_buf: &mut Vec<u128>,
    substr: &str,
    byte_offset: usize,
) -> bool {
    // Values up to this many bytes are stored inline in the view itself.
    const MAX_INLINE_LEN: usize = 12;

    let bytes = substr.as_bytes();
    let offset: u32 = byte_offset
        .try_into()
        .expect("validated string buffer offset fits in i32");
    views_buf.push(make_view(bytes, 0, offset));
    bytes.len() > MAX_INLINE_LEN
}
/// Computes `substr` for `Utf8`/`LargeUtf8` input, producing a `Utf8View` array whose
/// long views point directly into the input's value buffer, so no string bytes are
/// copied.
///
/// Falls back to [`generic_string_substr_copy`] when the offsets in use exceed
/// `i32::MAX` and so cannot be stored in a view. `args[0]` holds the `Int64` start
/// positions and the optional `args[1]` the `Int64` character counts.
fn generic_string_substr<T: OffsetSizeTrait>(
    string_array: &GenericStringArray<T>,
    args: &[ArrayRef],
) -> Result<ArrayRef> {
    if !values_fit_in_i32(string_array) {
        return generic_string_substr_copy(string_array, args);
    }

    let start_array = as_int64_array(&args[0])?;
    let count_array_opt = match args.get(1) {
        Some(counts) => Some(as_int64_array(counts)?),
        None => None,
    };
    let ascii_only = enable_ascii_fast_path(&string_array, start_array, count_array_opt);
    let nulls = NullBuffer::union_many([
        string_array.nulls(),
        start_array.nulls(),
        count_array_opt.and_then(|counts| counts.nulls()),
    ]);

    let row_count = string_array.len();
    let mut views: Vec<u128> = Vec::with_capacity(row_count);
    let mut references_buffer = false;
    // `value_offsets()` has `row_count + 1` entries; the first `row_count` are the
    // absolute positions where each row starts in the values buffer.
    let row_offsets = &string_array.value_offsets()[..row_count];
    for (row, row_offset) in row_offsets.iter().enumerate() {
        if nulls.as_ref().is_some_and(|n| n.is_null(row)) {
            // Zero-length placeholder view; the row is masked by `nulls`.
            views.push(0);
            continue;
        }
        let value = string_array.value(row);
        let (byte_start, byte_end) = get_true_start_end(
            value,
            start_array.value(row),
            count_array_opt.map(|counts| counts.value(row)),
            ascii_only,
        )?;
        references_buffer |= append_view_from_buffer(
            &mut views,
            &value[byte_start..byte_end],
            row_offset.as_usize() + byte_start,
        );
    }

    // Only attach the values buffer if some view actually points into it.
    let data_buffers: Vec<_> = references_buffer
        .then(|| string_array.values().clone())
        .into_iter()
        .collect();
    let views = ScalarBuffer::from(views);

    // SAFETY: null rows hold zero-length views. Short substrings are stored inline.
    // Long ones reference buffer 0 at `row_offset + byte_start`, which lies inside the
    // values buffer; `values_fit_in_i32` checked that it fits in a view offset. Buffer 0
    // is attached whenever such a view exists. Substrings were sliced at char
    // boundaries, since `&value[..]` would have panicked otherwise, so they are valid
    // UTF-8.
    let array = unsafe { StringViewArray::new_unchecked(views, data_buffers, nulls) };
    Ok(Arc::new(array) as ArrayRef)
}
/// Computes `substr` for `Utf8`/`LargeUtf8` input by copying each resulting substring
/// into a freshly built `Utf8View` array.
///
/// This is the fallback used when the input's value offsets cannot be referenced from
/// views. `args[0]` holds the `Int64` start positions and the optional `args[1]` the
/// `Int64` character counts.
fn generic_string_substr_copy<T: OffsetSizeTrait>(
    string_array: &GenericStringArray<T>,
    args: &[ArrayRef],
) -> Result<ArrayRef> {
    let start_array = as_int64_array(&args[0])?;
    let count_array_opt = match args.get(1) {
        Some(counts) => Some(as_int64_array(counts)?),
        None => None,
    };
    let ascii_only = enable_ascii_fast_path(&string_array, start_array, count_array_opt);
    let nulls = NullBuffer::union_many([
        string_array.nulls(),
        start_array.nulls(),
        count_array_opt.and_then(|counts| counts.nulls()),
    ]);

    let row_count = string_array.len();
    let mut builder = StringViewArrayBuilder::with_capacity(row_count);
    for row in 0..row_count {
        let row_is_null = nulls.as_ref().is_some_and(|n| n.is_null(row));
        if row_is_null {
            builder.append_placeholder();
        } else {
            let value = string_array.value(row);
            let (byte_start, byte_end) = get_true_start_end(
                value,
                start_array.value(row),
                count_array_opt.map(|counts| counts.value(row)),
                ascii_only,
            )?;
            builder.append_value(&value[byte_start..byte_end]);
        }
    }

    let array = builder.finish(nulls)?;
    Ok(Arc::new(array) as ArrayRef)
}
#[cfg(test)]
mod tests {
    //! Unit tests for `substr`. The scalar cases go through `test_function!`; the
    //! array-level cases call the `substr` kernel directly to cover the Utf8,
    //! LargeUtf8 and Utf8View code paths, slicing, and null propagation.
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray,
        StringViewArray,
    };
    use arrow::datatypes::DataType::Utf8View;

    use datafusion_common::{Result, ScalarValue, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substr::SubstrFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this és longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some(" é")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "this is longer than 12B"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some(" is longer than 12B")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "joséésoj"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("ésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
            ],
            Ok(Some("joséésoj")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("lphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-3i64)),
            ],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(30i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("ph")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(Some("phabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(0i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(10i64)),
            ],
            Ok(Some("alph")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(4i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
            ],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
                ColumnarValue::Scalar(ScalarValue::from(20i64)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(3i64)),
                ColumnarValue::Scalar(ScalarValue::Int64(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
            ],
            exec_err!("negative count not allowed: -1"),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
                ColumnarValue::Scalar(ScalarValue::from(5i64)),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
            ],
            Ok(Some("és")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abc")),
                ColumnarValue::Scalar(ScalarValue::from(i64::MIN)),
            ],
            exec_err!("start position overflow: -9223372036854775808"),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("overflow")),
                ColumnarValue::Scalar(ScalarValue::from(i64::MIN)),
                ColumnarValue::Scalar(ScalarValue::from(1i64)),
            ],
            exec_err!("start position overflow: -9223372036854775808"),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            SubstrFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("large count")),
                ColumnarValue::Scalar(ScalarValue::from(2i64)),
                ColumnarValue::Scalar(ScalarValue::from(i64::MAX)),
            ],
            Ok(Some("arge count")),
            &str,
            Utf8View,
            StringViewArray
        );
        Ok(())
    }

    /// Sliced Utf8 input: output views must use absolute offsets into the shared
    /// values buffer, not offsets relative to the slice.
    #[test]
    fn test_sliced_string_array_array_args() -> Result<()> {
        let string_array = Arc::new(StringArray::from(vec![
            "skipped_prefix_value",
            "alphabet_long_string",
            "joséésojanother_long",
        ])) as ArrayRef;
        let string_array = string_array.slice(1, 2);
        let start_array = Arc::new(Int64Array::from(vec![3, 5])) as ArrayRef;
        let count_array = Arc::new(Int64Array::from(vec![15, 14])) as ArrayRef;
        let result = super::substr(&[string_array, start_array, count_array])?;
        let result = result.as_string_view();
        assert_eq!(result.value(0), "phabet_long_str");
        assert_eq!(result.value(1), "ésojanother_lo");
        Ok(())
    }

    /// Same as the Utf8 slice test, but through the `Utf8View` kernel.
    #[test]
    fn test_sliced_string_view_array_array_args() -> Result<()> {
        let string_array = Arc::new(StringViewArray::from(vec![
            "skipped_prefix_value",
            "alphabet_long_string",
            "joséésojanother_long",
        ])) as ArrayRef;
        let string_array = string_array.slice(1, 2);
        let start_array = Arc::new(Int64Array::from(vec![3, 5])) as ArrayRef;
        let count_array = Arc::new(Int64Array::from(vec![15, 14])) as ArrayRef;
        let result = super::substr(&[string_array, start_array, count_array])?;
        let result = result.as_string_view();
        assert_eq!(result.value(0), "phabet_long_str");
        assert_eq!(result.value(1), "ésojanother_lo");
        Ok(())
    }

    /// LargeUtf8 input without a count goes through the i64-offset generic kernel.
    #[test]
    fn test_large_utf8_array_args() -> Result<()> {
        let string_array =
            Arc::new(LargeStringArray::from(vec!["joséésoj", "alphabet"])) as ArrayRef;
        let start_array = Arc::new(Int64Array::from(vec![5, 3])) as ArrayRef;
        let result = super::substr(&[string_array, start_array])?;
        let result = result.as_string_view();
        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "ésoj");
        assert_eq!(result.value(1), "phabet");
        Ok(())
    }

    /// A null in the string or count array yields a null output row.
    #[test]
    fn test_null_propagation_array_args() -> Result<()> {
        let string_array = Arc::new(StringArray::from(vec![
            Some("alphabet"),
            Some("alphabet"),
            None,
        ])) as ArrayRef;
        let start_array = Arc::new(Int64Array::from(vec![2, 2, 1])) as ArrayRef;
        let count_array =
            Arc::new(Int64Array::from(vec![Some(3), None, Some(1)])) as ArrayRef;
        let result = super::substr(&[string_array, start_array, count_array])?;
        let result = result.as_string_view();
        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), "lph");
        assert!(result.is_null(1));
        assert!(result.is_null(2));
        Ok(())
    }
}