1use crate::utils::utf8_to_str_type;
19use arrow::array::{
20 Array, ArrayRef, AsArray, ByteView, GenericStringBuilder, Int64Array,
21 StringArrayType, StringLikeArrayBuilder, StringViewArray, StringViewBuilder,
22 make_view, new_null_array,
23};
24use arrow::buffer::ScalarBuffer;
25use arrow::datatypes::DataType;
26use datafusion_common::ScalarValue;
27use datafusion_common::cast::as_int64_array;
28use datafusion_common::types::{NativeType, logical_int64, logical_string};
29use datafusion_common::{Result, exec_datafusion_err, exec_err};
30use datafusion_expr::{
31 Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility,
32};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35use memchr::memmem;
36use std::sync::Arc;
37
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Splits a string based on a specified delimiter and returns the substring in the specified position.",
    syntax_example = "split_part(str, delimiter, pos)",
    sql_example = r#"```sql
> select split_part('1.2.3.4.5', '.', 3);
+--------------------------------------------------+
| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
+--------------------------------------------------+
| 3                                                |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "delimiter", description = "String or character to split on."),
    argument(
        name = "pos",
        description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string."
    )
)]
/// Scalar UDF implementing `split_part(str, delimiter, pos)`.
///
/// The struct only carries the function's [`Signature`]; the accepted argument
/// types are configured in `SplitPartFunc::new`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SplitPartFunc {
    signature: Signature,
}
61
62impl Default for SplitPartFunc {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl SplitPartFunc {
69 pub fn new() -> Self {
70 Self {
71 signature: Signature::coercible(
72 vec![
73 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
75 Coercion::new_implicit(
76 TypeSignatureClass::Native(logical_int64()),
77 vec![TypeSignatureClass::Integer],
78 NativeType::Int64,
79 ),
80 ],
81 Volatility::Immutable,
82 ),
83 }
84 }
85}
86
87impl ScalarUDFImpl for SplitPartFunc {
88 fn name(&self) -> &str {
89 "split_part"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 if arg_types[0] == DataType::Utf8View {
98 Ok(DataType::Utf8View)
99 } else {
100 utf8_to_str_type(&arg_types[0], "split_part")
101 }
102 }
103
104 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105 let ScalarFunctionArgs { args, .. } = args;
106
107 if let (
109 ColumnarValue::Array(string_array),
110 ColumnarValue::Scalar(delim_scalar),
111 ColumnarValue::Scalar(pos_scalar),
112 ) = (&args[0], &args[1], &args[2])
113 {
114 return split_part_scalar(string_array, delim_scalar, pos_scalar);
115 }
116
117 let len = args.iter().find_map(|arg| match arg {
119 ColumnarValue::Array(a) => Some(a.len()),
120 _ => None,
121 });
122
123 let inferred_length = len.unwrap_or(1);
124 let is_scalar = len.is_none();
125
126 let args = args
128 .iter()
129 .map(|arg| match arg {
130 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length),
131 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
132 })
133 .collect::<Result<Vec<_>>>()?;
134
135 let n_array = as_int64_array(&args[2])?;
137
138 macro_rules! split_part_for_delimiter_type {
140 ($str_arr:expr, $builder:expr) => {
141 match args[1].data_type() {
142 DataType::Utf8View => split_part_impl(
143 $str_arr,
144 &args[1].as_string_view(),
145 n_array,
146 $builder,
147 ),
148 DataType::Utf8 => split_part_impl(
149 $str_arr,
150 &args[1].as_string::<i32>(),
151 n_array,
152 $builder,
153 ),
154 DataType::LargeUtf8 => split_part_impl(
155 $str_arr,
156 &args[1].as_string::<i64>(),
157 n_array,
158 $builder,
159 ),
160 other => {
161 exec_err!("Unsupported delimiter type {other:?} for split_part")
162 }
163 }
164 };
165 }
166
167 let result = match args[0].data_type() {
168 DataType::Utf8View => split_part_for_delimiter_type!(
169 &args[0].as_string_view(),
170 StringViewBuilder::with_capacity(inferred_length)
171 ),
172 DataType::Utf8 => {
173 let str_arr = &args[0].as_string::<i32>();
174 split_part_for_delimiter_type!(
178 str_arr,
179 GenericStringBuilder::<i32>::with_capacity(
180 inferred_length,
181 inferred_length,
182 )
183 )
184 }
185 DataType::LargeUtf8 => {
186 let str_arr = &args[0].as_string::<i64>();
187 split_part_for_delimiter_type!(
189 str_arr,
190 GenericStringBuilder::<i64>::with_capacity(
191 inferred_length,
192 inferred_length,
193 )
194 )
195 }
196 other => exec_err!("Unsupported string type {other:?} for split_part"),
197 };
198 if is_scalar {
199 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
201 result.map(ColumnarValue::Scalar)
202 } else {
203 result.map(ColumnarValue::Array)
204 }
205 }
206
207 fn documentation(&self) -> Option<&Documentation> {
208 self.doc()
209 }
210}
211
/// Returns the `n`-th (0-based) field of `string` split on `delimiter`,
/// counting from the left, or `None` if there are not that many fields.
#[inline]
fn split_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> {
    match *delimiter.as_bytes() {
        // A one-byte UTF-8 string is ASCII, so it can be searched as a `char`,
        // which is cheaper than a substring search.
        [byte] => string.split(char::from(byte)).nth(n),
        _ => string.split(delimiter).nth(n),
    }
}
225
/// Returns the `n`-th (0-based) field of `string` split on `delimiter`,
/// counting from the right, or `None` if there are not that many fields.
#[inline]
fn rsplit_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> {
    match *delimiter.as_bytes() {
        // A one-byte UTF-8 string is ASCII, so it can be searched as a `char`.
        [byte] => string.rsplit(char::from(byte)).nth(n),
        _ => string.rsplit(delimiter).nth(n),
    }
}
239
240fn split_part_scalar(
242 string_array: &ArrayRef,
243 delim_scalar: &ScalarValue,
244 pos_scalar: &ScalarValue,
245) -> Result<ColumnarValue> {
246 if string_array.is_empty() {
248 return Ok(ColumnarValue::Array(new_null_array(
249 string_array.data_type(),
250 0,
251 )));
252 }
253
254 let delimiter = delim_scalar.try_as_str().ok_or_else(|| {
255 exec_datafusion_err!(
256 "Unsupported delimiter type {:?} for split_part",
257 delim_scalar.data_type()
258 )
259 })?;
260
261 let position = match pos_scalar {
262 ScalarValue::Int64(v) => *v,
263 other => {
264 return exec_err!(
265 "Unsupported position type {:?} for split_part",
266 other.data_type()
267 );
268 }
269 };
270
271 let (Some(delimiter), Some(position)) = (delimiter, position) else {
273 return Ok(ColumnarValue::Array(new_null_array(
274 string_array.data_type(),
275 string_array.len(),
276 )));
277 };
278
279 if position == 0 {
280 return exec_err!("field position must not be zero");
281 }
282
283 let result = match string_array.data_type() {
284 DataType::Utf8View => {
285 split_part_scalar_view(string_array.as_string_view(), delimiter, position)
286 }
287 DataType::Utf8 => {
288 let arr = string_array.as_string::<i32>();
289 split_part_scalar_impl(
293 arr,
294 delimiter,
295 position,
296 GenericStringBuilder::<i32>::with_capacity(arr.len(), arr.len()),
297 )
298 }
299 DataType::LargeUtf8 => {
300 let arr = string_array.as_string::<i64>();
301 split_part_scalar_impl(
303 arr,
304 delimiter,
305 position,
306 GenericStringBuilder::<i64>::with_capacity(arr.len(), arr.len()),
307 )
308 }
309 other => exec_err!("Unsupported string type {other:?} for split_part"),
310 }?;
311
312 Ok(ColumnarValue::Array(result))
313}
314
315fn split_part_scalar_impl<'a, S, B>(
319 string_array: S,
320 delimiter: &str,
321 position: i64,
322 builder: B,
323) -> Result<ArrayRef>
324where
325 S: StringArrayType<'a> + Copy,
326 B: StringLikeArrayBuilder,
327{
328 if delimiter.is_empty() {
329 return if position == 1 || position == -1 {
332 map_strings(string_array, builder, Some)
333 } else {
334 map_strings(string_array, builder, |_| None)
335 };
336 }
337
338 let delim_bytes = delimiter.as_bytes();
339 let delim_len = delimiter.len();
340
341 if position > 0 {
342 let idx: usize = (position - 1).try_into().map_err(|_| {
343 exec_datafusion_err!(
344 "split_part index {position} exceeds maximum supported value"
345 )
346 })?;
347 let finder = memmem::Finder::new(delim_bytes);
348 map_strings(string_array, builder, |s| {
349 split_nth_finder(s, &finder, delim_len, idx)
350 })
351 } else {
352 let idx: usize = (position.unsigned_abs() - 1).try_into().map_err(|_| {
353 exec_datafusion_err!(
354 "split_part index {position} exceeds minimum supported value"
355 )
356 })?;
357 let finder_rev = memmem::FinderRev::new(delim_bytes);
358 map_strings(string_array, builder, |s| {
359 rsplit_nth_finder(s, &finder_rev, delim_len, idx)
360 })
361 }
362}
363
364#[inline]
367fn map_strings<'a, S, B, F>(string_array: S, mut builder: B, f: F) -> Result<ArrayRef>
368where
369 S: StringArrayType<'a> + Copy,
370 B: StringLikeArrayBuilder,
371 F: Fn(&'a str) -> Option<&'a str>,
372{
373 for string in string_array.iter() {
374 match string {
375 Some(s) => builder.append_value(f(s).unwrap_or("")),
376 None => builder.append_null(),
377 }
378 }
379 Ok(Arc::new(builder.finish()) as ArrayRef)
380}
381
382#[inline]
384fn split_nth_finder<'a>(
385 string: &'a str,
386 finder: &memmem::Finder,
387 delim_len: usize,
388 n: usize,
389) -> Option<&'a str> {
390 let bytes = string.as_bytes();
391 let mut start = 0;
392 for _ in 0..n {
393 match finder.find(&bytes[start..]) {
394 Some(pos) => start += pos + delim_len,
395 None => return None,
396 }
397 }
398 match finder.find(&bytes[start..]) {
399 Some(pos) => Some(&string[start..start + pos]),
400 None => Some(&string[start..]),
401 }
402}
403
404#[inline]
407fn rsplit_nth_finder<'a>(
408 string: &'a str,
409 finder: &memmem::FinderRev,
410 delim_len: usize,
411 n: usize,
412) -> Option<&'a str> {
413 let bytes = string.as_bytes();
414 let mut end = bytes.len();
415 for _ in 0..n {
416 match finder.rfind(&bytes[..end]) {
417 Some(pos) => end = pos,
418 None => return None,
419 }
420 }
421 match finder.rfind(&bytes[..end]) {
422 Some(pos) => Some(&string[pos + delim_len..end]),
423 None => Some(&string[..end]),
424 }
425}
426
/// Fast path for `Utf8View` input with a constant `delimiter` and `position`.
///
/// The selected parts are not copied into a new buffer. Each output view is
/// built to point into the input's existing data buffers (via
/// `split_view_loop` / `substr_view`), and those buffers are shared with the
/// result. The input's null buffer is reused, so null rows stay null.
///
/// # Errors
/// Returns an error if `position - 1` (or `|position| - 1` for a negative
/// position) does not fit in `usize`.
///
/// NOTE(review): `position` is assumed non-zero because the caller rejects 0.
/// With a non-empty delimiter, a 0 would take the negative branch and
/// underflow `unsigned_abs() - 1`.
fn split_part_scalar_view(
    string_view_array: &StringViewArray,
    delimiter: &str,
    position: i64,
) -> Result<ArrayRef> {
    let len = string_view_array.len();
    let mut views_buf = Vec::with_capacity(len);
    let views = string_view_array.views();

    if delimiter.is_empty() {
        // An empty delimiter does not split anything. Positions 1 and -1 return
        // the whole input (its view is reused unchanged); any other position
        // returns "".
        let empty_view = make_view(b"", 0, 0);
        let return_input = position == 1 || position == -1;
        for i in 0..len {
            if string_view_array.is_null(i) {
                // Placeholder for a null slot; validity comes from `nulls` below.
                views_buf.push(0);
            } else if return_input {
                views_buf.push(views[i]);
            } else {
                views_buf.push(empty_view);
            }
        }
    } else if position > 0 {
        let idx: usize = (position - 1).try_into().map_err(|_| {
            exec_datafusion_err!(
                "split_part index {position} exceeds maximum supported value"
            )
        })?;
        let finder = memmem::Finder::new(delimiter.as_bytes());
        split_view_loop(string_view_array, views, &mut views_buf, |s| {
            split_nth_finder(s, &finder, delimiter.len(), idx)
        });
    } else {
        let idx: usize = (position.unsigned_abs() - 1).try_into().map_err(|_| {
            exec_datafusion_err!(
                "split_part index {position} exceeds minimum supported value"
            )
        })?;
        let finder_rev = memmem::FinderRev::new(delimiter.as_bytes());
        split_view_loop(string_view_array, views, &mut views_buf, |s| {
            rsplit_nth_finder(s, &finder_rev, delimiter.len(), idx)
        });
    }

    let views_buf = ScalarBuffer::from(views_buf);

    // Output row i is null exactly when input row i is null.
    let nulls = string_view_array.nulls().cloned();

    // SAFETY:
    // - `views_buf` holds exactly one view per input row (`len` entries),
    //   matching `nulls`.
    // - Each view is one of:
    //   * 0 for a null row (masked by `nulls`);
    //   * an unmodified input view;
    //   * `make_view(b"", 0, 0)`;
    //   * a view built by `substr_view` from a `&str` subslice of the row's
    //     value. Such a view is valid UTF-8 and, when longer than 12 bytes,
    //     addresses the same data buffer as the original value, at an offset
    //     inside it.
    // - All of the input's data buffers are passed through unchanged, so every
    //   buffer index and offset in those views stays in bounds.
    unsafe {
        Ok(Arc::new(StringViewArray::new_unchecked(
            views_buf,
            string_view_array.data_buffers().to_vec(),
            nulls,
        )) as ArrayRef)
    }
}
492
493#[inline]
496fn substr_view(original_view: &u128, substr: &str, start_offset: u32) -> u128 {
497 if substr.len() > 12 {
498 let view = ByteView::from(*original_view);
499 make_view(
500 substr.as_bytes(),
501 view.buffer_index,
502 view.offset + start_offset,
503 )
504 } else {
505 make_view(substr.as_bytes(), 0, 0)
506 }
507}
508
509#[inline(always)]
512fn split_view_loop<F>(
513 string_view_array: &StringViewArray,
514 views: &[u128],
515 views_buf: &mut Vec<u128>,
516 split_fn: F,
517) where
518 F: Fn(&str) -> Option<&str>,
519{
520 let empty_view = make_view(b"", 0, 0);
521 for (i, raw_view) in views.iter().enumerate() {
522 if string_view_array.is_null(i) {
523 views_buf.push(0);
524 continue;
525 }
526 let string = string_view_array.value(i);
527 match split_fn(string) {
528 Some(substr) => {
529 let start_offset = substr.as_ptr() as usize - string.as_ptr() as usize;
530 views_buf.push(substr_view(raw_view, substr, start_offset as u32));
531 }
532 None => views_buf.push(empty_view),
533 }
534 }
535}
536
537fn split_part_impl<'a, StringArrType, DelimiterArrType, B>(
538 string_array: &StringArrType,
539 delimiter_array: &DelimiterArrType,
540 n_array: &Int64Array,
541 mut builder: B,
542) -> Result<ArrayRef>
543where
544 StringArrType: StringArrayType<'a>,
545 DelimiterArrType: StringArrayType<'a>,
546 B: StringLikeArrayBuilder,
547{
548 for ((string, delimiter), n) in string_array
549 .iter()
550 .zip(delimiter_array.iter())
551 .zip(n_array.iter())
552 {
553 match (string, delimiter, n) {
554 (Some(string), Some(delimiter), Some(n)) => {
555 let result = match n.cmp(&0) {
556 std::cmp::Ordering::Greater => {
557 let idx: usize = (n - 1).try_into().map_err(|_| {
558 exec_datafusion_err!(
559 "split_part index {n} exceeds maximum supported value"
560 )
561 })?;
562 if delimiter.is_empty() {
563 (n == 1).then_some(string)
567 } else {
568 split_nth(string, delimiter, idx)
569 }
570 }
571 std::cmp::Ordering::Less => {
572 let idx: usize =
573 (n.unsigned_abs() - 1).try_into().map_err(|_| {
574 exec_datafusion_err!(
575 "split_part index {n} exceeds minimum supported value"
576 )
577 })?;
578 if delimiter.is_empty() {
579 (n == -1).then_some(string)
583 } else {
584 rsplit_nth(string, delimiter, idx)
585 }
586 }
587 std::cmp::Ordering::Equal => {
588 return exec_err!("field position must not be zero");
589 }
590 };
591 builder.append_value(result.unwrap_or(""));
592 }
593 _ => builder.append_null(),
594 }
595 }
596
597 Ok(Arc::new(builder.finish()) as ArrayRef)
598}
599
#[cfg(test)]
mod tests {
    use arrow::array::{Array, AsArray, StringArray, StringViewArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::split_part::SplitPartFunc;
    use crate::utils::test::test_function;

    /// Builds an all-scalar argument list: (`Utf8` string, `Utf8` delimiter,
    /// `Int64` position).
    fn scalar_args(string: &str, delimiter: &str, position: i64) -> Vec<ColumnarValue> {
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(string.to_string()))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(delimiter.to_string()))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(position))),
        ]
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Multi-character delimiter.
        test_function!(
            SplitPartFunc::new(),
            scalar_args("abc~@~def~@~ghi", "~@~", 2),
            Ok(Some("def")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("abc~@~def~@~ghi", "~@~", 20),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("abc~@~def~@~ghi", "~@~", -1),
            Ok(Some("ghi")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("abc~@~def~@~ghi", "~@~", 0),
            exec_err!("field position must not be zero"),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("abc~@~def~@~ghi", "~@~", i64::MIN),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        // Single-character delimiter.
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", ",", 1),
            Ok(Some("a")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", ",", 3),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        // Empty delimiter, or a delimiter that never occurs: the whole string
        // is the only field.
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", "", 1),
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", "", 2),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", " ", 1),
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", " ", 2),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        // Same cases with negative positions.
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", "", -1),
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", " ", -1),
            Ok(Some("a,b")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            SplitPartFunc::new(),
            scalar_args("a,b", "", -2),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// Views produced from a sliced `Utf8View` input must still address the
    /// right bytes, for both inline (short) and buffer-backed (long) parts.
    #[test]
    fn test_split_part_stringview_sliced() -> Result<()> {
        use super::split_part_scalar_view;

        let strings: StringViewArray = vec![
            Some("skip_this.value"),
            Some("this_is_a_long_prefix.suffix"),
            Some("short.val"),
            Some("another_long_result.rest"),
            None,
        ]
        .into_iter()
        .collect();

        let sliced = strings.slice(1, 4);
        let output = split_part_scalar_view(&sliced, ".", 1)?;
        let output = output.as_string_view();

        let expected = [
            Some("this_is_a_long_prefix"),
            Some("short"),
            Some("another_long_result"),
            None,
        ];
        assert_eq!(output.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(value) => assert_eq!(output.value(row), *value),
                None => assert!(output.is_null(row)),
            }
        }

        Ok(())
    }
}