1use crate::utils::utf8_to_str_type;
19use arrow::array::{
20 ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringArrayType,
21 StringViewArray,
22};
23use arrow::array::{AsArray, GenericStringBuilder};
24use arrow::datatypes::DataType;
25use datafusion_common::ScalarValue;
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::types::{NativeType, logical_int64, logical_string};
28use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err};
29use datafusion_expr::{
30 Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility,
31};
32use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
33use datafusion_macros::user_doc;
34use std::any::Any;
35use std::sync::Arc;
36
/// Scalar UDF implementing `split_part(str, delimiter, pos)`.
///
/// The `#[user_doc]` attribute below carries the user-facing documentation
/// (description, syntax, SQL example and argument descriptions) that is
/// returned by `documentation()`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Splits a string based on a specified delimiter and returns the substring in the specified position.",
    syntax_example = "split_part(str, delimiter, pos)",
    sql_example = r#"```sql
> select split_part('1.2.3.4.5', '.', 3);
+--------------------------------------------------+
| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
+--------------------------------------------------+
| 3 |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "delimiter", description = "String or character to split on."),
    argument(
        name = "pos",
        description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SplitPartFunc {
    /// Accepted argument types: two string arguments followed by an integer
    /// position that is coerced to `Int64`.
    signature: Signature,
}
60
61impl Default for SplitPartFunc {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl SplitPartFunc {
68 pub fn new() -> Self {
69 Self {
70 signature: Signature::coercible(
71 vec![
72 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74 Coercion::new_implicit(
75 TypeSignatureClass::Native(logical_int64()),
76 vec![TypeSignatureClass::Integer],
77 NativeType::Int64,
78 ),
79 ],
80 Volatility::Immutable,
81 ),
82 }
83 }
84}
85
86impl ScalarUDFImpl for SplitPartFunc {
87 fn as_any(&self) -> &dyn Any {
88 self
89 }
90
91 fn name(&self) -> &str {
92 "split_part"
93 }
94
95 fn signature(&self) -> &Signature {
96 &self.signature
97 }
98
99 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100 utf8_to_str_type(&arg_types[0], "split_part")
101 }
102
103 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104 let ScalarFunctionArgs { args, .. } = args;
105
106 let len = args.iter().find_map(|arg| match arg {
108 ColumnarValue::Array(a) => Some(a.len()),
109 _ => None,
110 });
111
112 let inferred_length = len.unwrap_or(1);
113 let is_scalar = len.is_none();
114
115 let args = args
117 .iter()
118 .map(|arg| match arg {
119 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length),
120 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
121 })
122 .collect::<Result<Vec<_>>>()?;
123
124 let n_array = as_int64_array(&args[2])?;
126 let result = match (args[0].data_type(), args[1].data_type()) {
127 (DataType::Utf8View, DataType::Utf8View) => {
128 split_part_impl::<&StringViewArray, &StringViewArray, i32>(
129 &args[0].as_string_view(),
130 &args[1].as_string_view(),
131 n_array,
132 )
133 }
134 (DataType::Utf8View, DataType::Utf8) => {
135 split_part_impl::<&StringViewArray, &GenericStringArray<i32>, i32>(
136 &args[0].as_string_view(),
137 &args[1].as_string::<i32>(),
138 n_array,
139 )
140 }
141 (DataType::Utf8View, DataType::LargeUtf8) => {
142 split_part_impl::<&StringViewArray, &GenericStringArray<i64>, i32>(
143 &args[0].as_string_view(),
144 &args[1].as_string::<i64>(),
145 n_array,
146 )
147 }
148 (DataType::Utf8, DataType::Utf8View) => {
149 split_part_impl::<&GenericStringArray<i32>, &StringViewArray, i32>(
150 &args[0].as_string::<i32>(),
151 &args[1].as_string_view(),
152 n_array,
153 )
154 }
155 (DataType::LargeUtf8, DataType::Utf8View) => {
156 split_part_impl::<&GenericStringArray<i64>, &StringViewArray, i64>(
157 &args[0].as_string::<i64>(),
158 &args[1].as_string_view(),
159 n_array,
160 )
161 }
162 (DataType::Utf8, DataType::Utf8) => {
163 split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i32>, i32>(
164 &args[0].as_string::<i32>(),
165 &args[1].as_string::<i32>(),
166 n_array,
167 )
168 }
169 (DataType::LargeUtf8, DataType::LargeUtf8) => {
170 split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i64>, i64>(
171 &args[0].as_string::<i64>(),
172 &args[1].as_string::<i64>(),
173 n_array,
174 )
175 }
176 (DataType::Utf8, DataType::LargeUtf8) => {
177 split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i64>, i32>(
178 &args[0].as_string::<i32>(),
179 &args[1].as_string::<i64>(),
180 n_array,
181 )
182 }
183 (DataType::LargeUtf8, DataType::Utf8) => {
184 split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i32>, i64>(
185 &args[0].as_string::<i64>(),
186 &args[1].as_string::<i32>(),
187 n_array,
188 )
189 }
190 _ => exec_err!("Unsupported combination of argument types for split_part"),
191 };
192 if is_scalar {
193 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
195 result.map(ColumnarValue::Scalar)
196 } else {
197 result.map(ColumnarValue::Array)
198 }
199 }
200
201 fn documentation(&self) -> Option<&Documentation> {
202 self.doc()
203 }
204}
205
206fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>(
207 string_array: &StringArrType,
208 delimiter_array: &DelimiterArrType,
209 n_array: &Int64Array,
210) -> Result<ArrayRef>
211where
212 StringArrType: StringArrayType<'a>,
213 DelimiterArrType: StringArrayType<'a>,
214 StringArrayLen: OffsetSizeTrait,
215{
216 let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
217
218 string_array
219 .iter()
220 .zip(delimiter_array.iter())
221 .zip(n_array.iter())
222 .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> {
223 match (string, delimiter, n) {
224 (Some(string), Some(delimiter), Some(n)) => {
225 let result = match n.cmp(&0) {
226 std::cmp::Ordering::Greater => {
227 let idx: usize = (n - 1).try_into().map_err(|_| {
230 exec_datafusion_err!(
231 "split_part index {n} exceeds maximum supported value"
232 )
233 })?;
234
235 if delimiter.is_empty() {
236 (n == 1).then_some(string)
240 } else {
241 string.split(delimiter).nth(idx)
242 }
243 }
244 std::cmp::Ordering::Less => {
245 let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| {
248 exec_datafusion_err!(
249 "split_part index {n} exceeds minimum supported value"
250 )
251 })?;
252 if delimiter.is_empty() {
253 (n == -1).then_some(string)
257 } else {
258 string.rsplit(delimiter).nth(idx)
259 }
260 }
261 std::cmp::Ordering::Equal => {
262 return exec_err!("field position must not be zero");
263 }
264 };
265 builder.append_value(result.unwrap_or(""));
266 }
267 _ => builder.append_null(),
268 }
269 Ok(())
270 })?;
271
272 Ok(Arc::new(builder.finish()) as ArrayRef)
273}
274
// Unit tests for `SplitPartFunc`, driven through the shared `test_function!`
// helper with all-scalar `Utf8` / `Int64` arguments.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::split_part::SplitPartFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        // Positive position within range: second field.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
            ],
            Ok(Some("def")),
            &str,
            Utf8,
            StringArray
        );
        // Positive position past the last field: empty string.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(20))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Negative position: -1 selects the last field.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("ghi")),
            &str,
            Utf8,
            StringArray
        );
        // Position zero is rejected with an error.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
            ],
            exec_err!("field position must not be zero"),
            &str,
            Utf8,
            StringArray
        );
        // Most negative position: expected to yield the empty string.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "abc~@~def~@~ghi"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Delimiter edge cases with positive positions.
        // Single-character delimiter, first field.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ],
            Ok(Some("a")),
            &str,
            Utf8,
            StringArray
        );
        // Single-character delimiter, position past the last field.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Empty delimiter: position 1 returns the whole input.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        // Empty delimiter: any other positive position returns the empty string.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Delimiter that does not occur in the input: position 1 returns the
        // whole input.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        // Delimiter that does not occur in the input: position 2 returns the
        // empty string.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        // Delimiter edge cases with negative positions.
        // Empty delimiter: position -1 returns the whole input.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        // Delimiter that does not occur in the input: position -1 returns the
        // whole input.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))),
            ],
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        // Empty delimiter: position -2 returns the empty string.
        test_function!(
            SplitPartFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
            ],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}