Skip to main content

datafusion_functions/string/
split_part.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::utf8_to_str_type;
19use arrow::array::{
20    ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringArrayType,
21    StringViewArray,
22};
23use arrow::array::{AsArray, GenericStringBuilder};
24use arrow::datatypes::DataType;
25use datafusion_common::ScalarValue;
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::types::{NativeType, logical_int64, logical_string};
28use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err};
29use datafusion_expr::{
30    Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility,
31};
32use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
33use datafusion_macros::user_doc;
34use std::any::Any;
35use std::sync::Arc;
36
// Scalar UDF type backing the SQL `split_part(str, delimiter, pos)` function.
//
// The `user_doc` attribute below holds the user-facing description, syntax,
// SQL example and argument descriptions.
// NOTE(review): `documentation()` returns `self.doc()`, which is presumably
// generated from this attribute — confirm against `datafusion_macros`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Splits a string based on a specified delimiter and returns the substring in the specified position.",
    syntax_example = "split_part(str, delimiter, pos)",
    sql_example = r#"```sql
> select split_part('1.2.3.4.5', '.', 3);
+--------------------------------------------------+
| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
+--------------------------------------------------+
| 3                                                |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "delimiter", description = "String or character to split on."),
    argument(
        name = "pos",
        description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SplitPartFunc {
    // Argument signature (two string arguments followed by an Int64 position)
    // and volatility; constructed in `SplitPartFunc::new`.
    signature: Signature,
}
60
61impl Default for SplitPartFunc {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl SplitPartFunc {
68    pub fn new() -> Self {
69        Self {
70            signature: Signature::coercible(
71                vec![
72                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74                    Coercion::new_implicit(
75                        TypeSignatureClass::Native(logical_int64()),
76                        vec![TypeSignatureClass::Integer],
77                        NativeType::Int64,
78                    ),
79                ],
80                Volatility::Immutable,
81            ),
82        }
83    }
84}
85
86impl ScalarUDFImpl for SplitPartFunc {
87    fn as_any(&self) -> &dyn Any {
88        self
89    }
90
91    fn name(&self) -> &str {
92        "split_part"
93    }
94
95    fn signature(&self) -> &Signature {
96        &self.signature
97    }
98
99    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100        utf8_to_str_type(&arg_types[0], "split_part")
101    }
102
103    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104        let ScalarFunctionArgs { args, .. } = args;
105
106        // First, determine if any of the arguments is an Array
107        let len = args.iter().find_map(|arg| match arg {
108            ColumnarValue::Array(a) => Some(a.len()),
109            _ => None,
110        });
111
112        let inferred_length = len.unwrap_or(1);
113        let is_scalar = len.is_none();
114
115        // Convert all ColumnarValues to ArrayRefs
116        let args = args
117            .iter()
118            .map(|arg| match arg {
119                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length),
120                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
121            })
122            .collect::<Result<Vec<_>>>()?;
123
124        // Unpack the ArrayRefs from the arguments
125        let n_array = as_int64_array(&args[2])?;
126        let result = match (args[0].data_type(), args[1].data_type()) {
127            (DataType::Utf8View, DataType::Utf8View) => {
128                split_part_impl::<&StringViewArray, &StringViewArray, i32>(
129                    &args[0].as_string_view(),
130                    &args[1].as_string_view(),
131                    n_array,
132                )
133            }
134            (DataType::Utf8View, DataType::Utf8) => {
135                split_part_impl::<&StringViewArray, &GenericStringArray<i32>, i32>(
136                    &args[0].as_string_view(),
137                    &args[1].as_string::<i32>(),
138                    n_array,
139                )
140            }
141            (DataType::Utf8View, DataType::LargeUtf8) => {
142                split_part_impl::<&StringViewArray, &GenericStringArray<i64>, i32>(
143                    &args[0].as_string_view(),
144                    &args[1].as_string::<i64>(),
145                    n_array,
146                )
147            }
148            (DataType::Utf8, DataType::Utf8View) => {
149                split_part_impl::<&GenericStringArray<i32>, &StringViewArray, i32>(
150                    &args[0].as_string::<i32>(),
151                    &args[1].as_string_view(),
152                    n_array,
153                )
154            }
155            (DataType::LargeUtf8, DataType::Utf8View) => {
156                split_part_impl::<&GenericStringArray<i64>, &StringViewArray, i64>(
157                    &args[0].as_string::<i64>(),
158                    &args[1].as_string_view(),
159                    n_array,
160                )
161            }
162            (DataType::Utf8, DataType::Utf8) => {
163                split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i32>, i32>(
164                    &args[0].as_string::<i32>(),
165                    &args[1].as_string::<i32>(),
166                    n_array,
167                )
168            }
169            (DataType::LargeUtf8, DataType::LargeUtf8) => {
170                split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i64>, i64>(
171                    &args[0].as_string::<i64>(),
172                    &args[1].as_string::<i64>(),
173                    n_array,
174                )
175            }
176            (DataType::Utf8, DataType::LargeUtf8) => {
177                split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i64>, i32>(
178                    &args[0].as_string::<i32>(),
179                    &args[1].as_string::<i64>(),
180                    n_array,
181                )
182            }
183            (DataType::LargeUtf8, DataType::Utf8) => {
184                split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i32>, i64>(
185                    &args[0].as_string::<i64>(),
186                    &args[1].as_string::<i32>(),
187                    n_array,
188                )
189            }
190            _ => exec_err!("Unsupported combination of argument types for split_part"),
191        };
192        if is_scalar {
193            // If all inputs are scalar, keep the output as scalar
194            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
195            result.map(ColumnarValue::Scalar)
196        } else {
197            result.map(ColumnarValue::Array)
198        }
199    }
200
201    fn documentation(&self) -> Option<&Documentation> {
202        self.doc()
203    }
204}
205
206fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>(
207    string_array: &StringArrType,
208    delimiter_array: &DelimiterArrType,
209    n_array: &Int64Array,
210) -> Result<ArrayRef>
211where
212    StringArrType: StringArrayType<'a>,
213    DelimiterArrType: StringArrayType<'a>,
214    StringArrayLen: OffsetSizeTrait,
215{
216    let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
217
218    string_array
219        .iter()
220        .zip(delimiter_array.iter())
221        .zip(n_array.iter())
222        .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> {
223            match (string, delimiter, n) {
224                (Some(string), Some(delimiter), Some(n)) => {
225                    let result = match n.cmp(&0) {
226                        std::cmp::Ordering::Greater => {
227                            // Positive index: use nth() to avoid collecting all parts
228                            // This stops iteration as soon as we find the nth element
229                            let idx: usize = (n - 1).try_into().map_err(|_| {
230                                exec_datafusion_err!(
231                                    "split_part index {n} exceeds maximum supported value"
232                                )
233                            })?;
234
235                            if delimiter.is_empty() {
236                                // Match PostgreSQL split_part behavior for empty delimiter:
237                                // treat the input as a single field ("ab" -> ["ab"]),
238                                // rather than Rust's split("") result (["", "a", "b", ""]).
239                                (n == 1).then_some(string)
240                            } else {
241                                string.split(delimiter).nth(idx)
242                            }
243                        }
244                        std::cmp::Ordering::Less => {
245                            // Negative index: use rsplit().nth() to efficiently get from the end
246                            // rsplit iterates in reverse, so -1 means first from rsplit (index 0)
247                            let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| {
248                                exec_datafusion_err!(
249                                    "split_part index {n} exceeds minimum supported value"
250                                )
251                            })?;
252                            if delimiter.is_empty() {
253                                // Match PostgreSQL split_part behavior for empty delimiter:
254                                // treat the input as a single field ("ab" -> ["ab"]),
255                                // rather than Rust's split("") result (["", "a", "b", ""]).
256                                (n == -1).then_some(string)
257                            } else {
258                                string.rsplit(delimiter).nth(idx)
259                            }
260                        }
261                        std::cmp::Ordering::Equal => {
262                            return exec_err!("field position must not be zero");
263                        }
264                    };
265                    builder.append_value(result.unwrap_or(""));
266                }
267                _ => builder.append_null(),
268            }
269            Ok(())
270        })?;
271
272    Ok(Arc::new(builder.finish()) as ArrayRef)
273}
274
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::split_part::SplitPartFunc;
    use crate::utils::test::test_function;

    /// Runs a single all-scalar `split_part(string, delimiter, pos)` case
    /// through `test_function!`, with `Utf8` inputs and a `Utf8` result.
    macro_rules! check_split_part {
        ($string:expr, $delimiter:expr, $pos:expr, $expected:expr) => {
            test_function!(
                SplitPartFunc::new(),
                vec![
                    ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                        $string
                    )))),
                    ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                        $delimiter
                    )))),
                    ColumnarValue::Scalar(ScalarValue::Int64(Some($pos))),
                ],
                $expected,
                &str,
                Utf8,
                StringArray
            );
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Multi-character delimiter: in range, out of range, from the end,
        // zero position (error) and the most negative position.
        check_split_part!("abc~@~def~@~ghi", "~@~", 2, Ok(Some("def")));
        check_split_part!("abc~@~def~@~ghi", "~@~", 20, Ok(Some("")));
        check_split_part!("abc~@~def~@~ghi", "~@~", -1, Ok(Some("ghi")));
        check_split_part!(
            "abc~@~def~@~ghi",
            "~@~",
            0,
            exec_err!("field position must not be zero")
        );
        check_split_part!("abc~@~def~@~ghi", "~@~", i64::MIN, Ok(Some("")));

        // Edge cases with delimiters
        check_split_part!("a,b", ",", 1, Ok(Some("a")));
        check_split_part!("a,b", ",", 3, Ok(Some("")));
        check_split_part!("a,b", "", 1, Ok(Some("a,b")));
        check_split_part!("a,b", "", 2, Ok(Some("")));
        check_split_part!("a,b", " ", 1, Ok(Some("a,b")));
        check_split_part!("a,b", " ", 2, Ok(Some("")));

        // Edge cases with delimiters with negative n
        check_split_part!("a,b", "", -1, Ok(Some("a,b")));
        check_split_part!("a,b", " ", -1, Ok(Some("a,b")));
        check_split_part!("a,b", "", -2, Ok(Some("")));

        Ok(())
    }
}