1use crate::utils::utf8_to_str_type;
19use arrow::array::{
20 ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringArrayType,
21 StringViewArray,
22};
23use arrow::array::{AsArray, GenericStringBuilder};
24use arrow::datatypes::DataType;
25use datafusion_common::ScalarValue;
26use datafusion_common::cast::as_int64_array;
27use datafusion_common::types::{NativeType, logical_int64, logical_string};
28use datafusion_common::{DataFusionError, Result, exec_err};
29use datafusion_expr::{
30 Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility,
31};
32use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
33use datafusion_macros::user_doc;
34use std::any::Any;
35use std::sync::Arc;
36
/// Scalar UDF implementing `split_part(str, delimiter, pos)`.
///
/// The `user_doc` attribute below holds the user-facing description, syntax,
/// SQL example and argument docs for the function.
/// NOTE(review): `documentation()` returns `self.doc()`, which is presumably
/// generated by that attribute — confirm against `datafusion_macros`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Splits a string based on a specified delimiter and returns the substring in the specified position.",
    syntax_example = "split_part(str, delimiter, pos)",
    sql_example = r#"```sql
> select split_part('1.2.3.4.5', '.', 3);
+--------------------------------------------------+
| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
+--------------------------------------------------+
| 3 |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "delimiter", description = "String or character to split on."),
    argument(name = "pos", description = "Position of the part to return.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SplitPartFunc {
    /// Argument signature built in `SplitPartFunc::new` and handed out by
    /// `ScalarUDFImpl::signature`.
    signature: Signature,
}
57
// `Default` simply forwards to `SplitPartFunc::new`, so both constructors
// produce a function with the same signature.
impl Default for SplitPartFunc {
    /// Creates the function by delegating to `Self::new()`.
    fn default() -> Self {
        Self::new()
    }
}
63
64impl SplitPartFunc {
65 pub fn new() -> Self {
66 Self {
67 signature: Signature::coercible(
68 vec![
69 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
70 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
71 Coercion::new_implicit(
72 TypeSignatureClass::Native(logical_int64()),
73 vec![TypeSignatureClass::Integer],
74 NativeType::Int64,
75 ),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for SplitPartFunc {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn name(&self) -> &str {
89 "split_part"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 utf8_to_str_type(&arg_types[0], "split_part")
98 }
99
100 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101 let ScalarFunctionArgs { args, .. } = args;
102
103 let len = args.iter().find_map(|arg| match arg {
105 ColumnarValue::Array(a) => Some(a.len()),
106 _ => None,
107 });
108
109 let inferred_length = len.unwrap_or(1);
110 let is_scalar = len.is_none();
111
112 let args = args
114 .iter()
115 .map(|arg| match arg {
116 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length),
117 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
118 })
119 .collect::<Result<Vec<_>>>()?;
120
121 let n_array = as_int64_array(&args[2])?;
123 let result = match (args[0].data_type(), args[1].data_type()) {
124 (DataType::Utf8View, DataType::Utf8View) => {
125 split_part_impl::<&StringViewArray, &StringViewArray, i32>(
126 &args[0].as_string_view(),
127 &args[1].as_string_view(),
128 n_array,
129 )
130 }
131 (DataType::Utf8View, DataType::Utf8) => {
132 split_part_impl::<&StringViewArray, &GenericStringArray<i32>, i32>(
133 &args[0].as_string_view(),
134 &args[1].as_string::<i32>(),
135 n_array,
136 )
137 }
138 (DataType::Utf8View, DataType::LargeUtf8) => {
139 split_part_impl::<&StringViewArray, &GenericStringArray<i64>, i32>(
140 &args[0].as_string_view(),
141 &args[1].as_string::<i64>(),
142 n_array,
143 )
144 }
145 (DataType::Utf8, DataType::Utf8View) => {
146 split_part_impl::<&GenericStringArray<i32>, &StringViewArray, i32>(
147 &args[0].as_string::<i32>(),
148 &args[1].as_string_view(),
149 n_array,
150 )
151 }
152 (DataType::LargeUtf8, DataType::Utf8View) => {
153 split_part_impl::<&GenericStringArray<i64>, &StringViewArray, i64>(
154 &args[0].as_string::<i64>(),
155 &args[1].as_string_view(),
156 n_array,
157 )
158 }
159 (DataType::Utf8, DataType::Utf8) => {
160 split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i32>, i32>(
161 &args[0].as_string::<i32>(),
162 &args[1].as_string::<i32>(),
163 n_array,
164 )
165 }
166 (DataType::LargeUtf8, DataType::LargeUtf8) => {
167 split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i64>, i64>(
168 &args[0].as_string::<i64>(),
169 &args[1].as_string::<i64>(),
170 n_array,
171 )
172 }
173 (DataType::Utf8, DataType::LargeUtf8) => {
174 split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i64>, i32>(
175 &args[0].as_string::<i32>(),
176 &args[1].as_string::<i64>(),
177 n_array,
178 )
179 }
180 (DataType::LargeUtf8, DataType::Utf8) => {
181 split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i32>, i64>(
182 &args[0].as_string::<i64>(),
183 &args[1].as_string::<i32>(),
184 n_array,
185 )
186 }
187 _ => exec_err!("Unsupported combination of argument types for split_part"),
188 };
189 if is_scalar {
190 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
192 result.map(ColumnarValue::Scalar)
193 } else {
194 result.map(ColumnarValue::Array)
195 }
196 }
197
198 fn documentation(&self) -> Option<&Documentation> {
199 self.doc()
200 }
201}
202
203fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>(
204 string_array: &StringArrType,
205 delimiter_array: &DelimiterArrType,
206 n_array: &Int64Array,
207) -> Result<ArrayRef>
208where
209 StringArrType: StringArrayType<'a>,
210 DelimiterArrType: StringArrayType<'a>,
211 StringArrayLen: OffsetSizeTrait,
212{
213 let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
214
215 string_array
216 .iter()
217 .zip(delimiter_array.iter())
218 .zip(n_array.iter())
219 .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> {
220 match (string, delimiter, n) {
221 (Some(string), Some(delimiter), Some(n)) => {
222 let split_string: Vec<&str> = string.split(delimiter).collect();
223 let len = split_string.len();
224
225 let index = match n.cmp(&0) {
226 std::cmp::Ordering::Less => len as i64 + n,
227 std::cmp::Ordering::Equal => {
228 return exec_err!("field position must not be zero");
229 }
230 std::cmp::Ordering::Greater => n - 1,
231 } as usize;
232
233 if index < len {
234 builder.append_value(split_string[index]);
235 } else {
236 builder.append_value("");
237 }
238 }
239 _ => builder.append_null(),
240 }
241 Ok(())
242 })?;
243
244 Ok(Arc::new(builder.finish()) as ArrayRef)
245}
246
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{Result, exec_err};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::split_part::SplitPartFunc;
    use crate::utils::test::test_function;

    /// Builds the all-scalar argument list shared by every case: the input
    /// `"abc~@~def~@~ghi"`, the delimiter `"~@~"`, and the given position.
    fn scalar_args(position: i64) -> Vec<ColumnarValue> {
        let utf8 =
            |text: &str| ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(text))));
        vec![
            utf8("abc~@~def~@~ghi"),
            utf8("~@~"),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(position))),
        ]
    }

    #[test]
    fn test_functions() -> Result<()> {
        // In-range positive position selects the matching field.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(2),
            Ok(Some("def")),
            &str,
            Utf8,
            StringArray
        );
        // Position past the last field yields an empty string.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(20),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Negative position counts from the end.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(-1),
            Ok(Some("ghi")),
            &str,
            Utf8,
            StringArray
        );
        // Position zero is rejected.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(0),
            exec_err!("field position must not be zero"),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}