datafusion_functions/unicode/
find_in_set.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23    PrimitiveArray, new_null_array,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
41    syntax_example = "find_in_set(str, strlist)",
42    sql_example = r#"```sql
43> select find_in_set('b', 'a,b,c,d');
44+----------------------------------------+
45| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
46+----------------------------------------+
47| 2                                      |
48+----------------------------------------+
49```"#,
50    argument(name = "str", description = "String expression to find in strlist."),
51    argument(
52        name = "strlist",
53        description = "A string list is a string composed of substrings separated by , characters."
54    )
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct FindInSetFunc {
58    signature: Signature,
59}
60
61impl Default for FindInSetFunc {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl FindInSetFunc {
68    pub fn new() -> Self {
69        use DataType::*;
70        Self {
71            signature: Signature::one_of(
72                vec![
73                    Exact(vec![Utf8View, Utf8View]),
74                    Exact(vec![Utf8, Utf8]),
75                    Exact(vec![LargeUtf8, LargeUtf8]),
76                ],
77                Volatility::Immutable,
78            ),
79        }
80    }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87
88    fn name(&self) -> &str {
89        "find_in_set"
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95
96    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97        utf8_to_int_type(&arg_types[0], "find_in_set")
98    }
99
100    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101        let ScalarFunctionArgs { args, .. } = args;
102
103        let [string, str_list] = take_function_args(self.name(), args)?;
104
105        match (string, str_list) {
106            // both inputs are scalars
107            (
108                ColumnarValue::Scalar(
109                    ScalarValue::Utf8View(string)
110                    | ScalarValue::Utf8(string)
111                    | ScalarValue::LargeUtf8(string),
112                ),
113                ColumnarValue::Scalar(
114                    ScalarValue::Utf8View(str_list)
115                    | ScalarValue::Utf8(str_list)
116                    | ScalarValue::LargeUtf8(str_list),
117                ),
118            ) => {
119                let res = match (string, str_list) {
120                    (Some(string), Some(str_list)) => {
121                        let position = str_list
122                            .split(',')
123                            .position(|s| s == string)
124                            .map_or(0, |idx| idx + 1);
125
126                        Some(position as i32)
127                    }
128                    _ => None,
129                };
130                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
131            }
132
133            // `string` is an array, `str_list` is scalar
134            (
135                ColumnarValue::Array(str_array),
136                ColumnarValue::Scalar(
137                    ScalarValue::Utf8View(str_list_literal)
138                    | ScalarValue::Utf8(str_list_literal)
139                    | ScalarValue::LargeUtf8(str_list_literal),
140                ),
141            ) => {
142                let result_array = match str_list_literal {
143                    // find_in_set(column_a, null) = null
144                    None => new_null_array(str_array.data_type(), str_array.len()),
145                    Some(str_list_literal) => {
146                        let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
147                        let result = match str_array.data_type() {
148                            DataType::Utf8 => {
149                                let string_array = str_array.as_string::<i32>();
150                                find_in_set_right_literal::<Int32Type, _>(
151                                    string_array,
152                                    &str_list,
153                                )
154                            }
155                            DataType::LargeUtf8 => {
156                                let string_array = str_array.as_string::<i64>();
157                                find_in_set_right_literal::<Int64Type, _>(
158                                    string_array,
159                                    &str_list,
160                                )
161                            }
162                            DataType::Utf8View => {
163                                let string_array = str_array.as_string_view();
164                                find_in_set_right_literal::<Int32Type, _>(
165                                    string_array,
166                                    &str_list,
167                                )
168                            }
169                            other => {
170                                exec_err!(
171                                    "Unsupported data type {other:?} for function find_in_set"
172                                )
173                            }
174                        };
175                        Arc::new(result?)
176                    }
177                };
178                Ok(ColumnarValue::Array(result_array))
179            }
180
181            // `string` is scalar, `str_list` is an array
182            (
183                ColumnarValue::Scalar(
184                    ScalarValue::Utf8View(string_literal)
185                    | ScalarValue::Utf8(string_literal)
186                    | ScalarValue::LargeUtf8(string_literal),
187                ),
188                ColumnarValue::Array(str_list_array),
189            ) => {
190                let res = match string_literal {
191                    // find_in_set(null, column_b) = null
192                    None => {
193                        new_null_array(str_list_array.data_type(), str_list_array.len())
194                    }
195                    Some(string) => {
196                        let result = match str_list_array.data_type() {
197                            DataType::Utf8 => {
198                                let str_list = str_list_array.as_string::<i32>();
199                                find_in_set_left_literal::<Int32Type, _>(
200                                    &string, str_list,
201                                )
202                            }
203                            DataType::LargeUtf8 => {
204                                let str_list = str_list_array.as_string::<i64>();
205                                find_in_set_left_literal::<Int64Type, _>(
206                                    &string, str_list,
207                                )
208                            }
209                            DataType::Utf8View => {
210                                let str_list = str_list_array.as_string_view();
211                                find_in_set_left_literal::<Int32Type, _>(
212                                    &string, str_list,
213                                )
214                            }
215                            other => {
216                                exec_err!(
217                                    "Unsupported data type {other:?} for function find_in_set"
218                                )
219                            }
220                        };
221                        Arc::new(result?)
222                    }
223                };
224                Ok(ColumnarValue::Array(res))
225            }
226
227            // both inputs are arrays
228            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
229                let res = find_in_set(&base_array, &exp_array)?;
230
231                Ok(ColumnarValue::Array(res))
232            }
233            _ => {
234                internal_err!("Invalid argument types for `find_in_set` function")
235            }
236        }
237    }
238
239    fn documentation(&self) -> Option<&Documentation> {
240        self.doc()
241    }
242}
243
244/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist`
245/// consisting of N substrings. A string list is a string composed of substrings separated by `,`
246/// characters.
247fn find_in_set(str: &ArrayRef, str_list: &ArrayRef) -> Result<ArrayRef> {
248    match str.data_type() {
249        DataType::Utf8 => {
250            let string_array = str.as_string::<i32>();
251            let str_list_array = str_list.as_string::<i32>();
252            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
253        }
254        DataType::LargeUtf8 => {
255            let string_array = str.as_string::<i64>();
256            let str_list_array = str_list.as_string::<i64>();
257            find_in_set_general::<Int64Type, _>(string_array, str_list_array)
258        }
259        DataType::Utf8View => {
260            let string_array = str.as_string_view();
261            let str_list_array = str_list.as_string_view();
262            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
263        }
264        other => {
265            exec_err!("Unsupported data type {other:?} for function find_in_set")
266        }
267    }
268}
269
270fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result<ArrayRef>
271where
272    T: ArrowPrimitiveType,
273    T::Native: OffsetSizeTrait,
274    V: ArrayAccessor<Item = &'a str>,
275{
276    let string_iter = ArrayIter::new(string_array);
277    let str_list_iter = ArrayIter::new(str_list_array);
278
279    let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
280
281    string_iter
282        .zip(str_list_iter)
283        .for_each(
284            |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
285                (Some(string), Some(str_list)) => {
286                    let position = str_list
287                        .split(',')
288                        .position(|s| s == string)
289                        .map_or(0, |idx| idx + 1);
290                    builder.append_value(T::Native::from_usize(position).unwrap());
291                }
292                _ => builder.append_null(),
293            },
294        );
295
296    Ok(Arc::new(builder.finish()) as ArrayRef)
297}
298
299fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result<ArrayRef>
300where
301    T: ArrowPrimitiveType,
302    T::Native: OffsetSizeTrait,
303    V: ArrayAccessor<Item = &'a str>,
304{
305    let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
306
307    let str_list_iter = ArrayIter::new(str_list_array);
308
309    str_list_iter.for_each(|str_list_opt| match str_list_opt {
310        Some(str_list) => {
311            let position = str_list
312                .split(',')
313                .position(|s| s == string)
314                .map_or(0, |idx| idx + 1);
315            builder.append_value(T::Native::from_usize(position).unwrap());
316        }
317        None => builder.append_null(),
318    });
319
320    Ok(Arc::new(builder.finish()) as ArrayRef)
321}
322
323fn find_in_set_right_literal<'a, T, V>(
324    string_array: V,
325    str_list: &[&str],
326) -> Result<ArrayRef>
327where
328    T: ArrowPrimitiveType,
329    T::Native: OffsetSizeTrait,
330    V: ArrayAccessor<Item = &'a str>,
331{
332    let mut builder = PrimitiveArray::<T>::builder(string_array.len());
333
334    let string_iter = ArrayIter::new(string_array);
335
336    string_iter.for_each(|string_opt| match string_opt {
337        Some(string) => {
338            let position = str_list
339                .iter()
340                .position(|s| *s == string)
341                .map_or(0, |idx| idx + 1);
342            builder.append_value(T::Native::from_usize(position).unwrap());
343        }
344        None => builder.append_null(),
345    });
346
347    Ok(Arc::new(builder.finish()) as ArrayRef)
348}
349
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType::Int32, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Builds a `Utf8` scalar argument holding `value`.
    fn utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(value))))
    }

    /// Builds a `Utf8View` scalar argument; `None` stands for SQL NULL.
    fn utf8_view(value: Option<&str>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(value.map(String::from)))
    }

    /// Scalar/scalar invocations: matches, misses, empty inputs and NULLs.
    #[test]
    fn test_functions() -> Result<()> {
        // first entry matches
        test_function!(
            FindInSetFunc::new(),
            vec![utf8("a"), utf8("a,b,c")],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        // non-ASCII entries
        test_function!(
            FindInSetFunc::new(),
            vec![utf8_view(Some("🔥")), utf8_view(Some("a,Д,🔥"))],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        // no entry matches
        test_function!(
            FindInSetFunc::new(),
            vec![utf8_view(Some("d")), utf8_view(Some("a,b,c"))],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // entries containing spaces
        test_function!(
            FindInSetFunc::new(),
            vec![
                utf8_view(Some("Apache Software Foundation")),
                utf8_view(Some("Github,Apache Software Foundation,DataFusion")),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        // empty search string
        test_function!(
            FindInSetFunc::new(),
            vec![utf8(""), utf8("a,b,c")],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // empty list
        test_function!(
            FindInSetFunc::new(),
            vec![utf8("a"), utf8("")],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // NULL list
        test_function!(
            FindInSetFunc::new(),
            vec![utf8_view(Some("a")), utf8_view(None)],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        // NULL search string
        test_function!(
            FindInSetFunc::new(),
            vec![utf8_view(None), utf8_view(Some("a,b,c"))],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    /// Invokes the registered `find_in_set` UDF with `args` and asserts that
    /// the result, materialised as an `Int32Array`, equals `expected`.
    fn check_find_in_set(args: Vec<ColumnarValue>, expected: Int32Array) -> Result<()> {
        let udf = crate::unicode::find_in_set();

        let arg_types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
        // Row count is the length of the last array argument, or 1 when all
        // arguments are scalars.
        let number_rows = args
            .iter()
            .rev()
            .find_map(|arg| match arg {
                ColumnarValue::Array(array) => Some(array.len()),
                ColumnarValue::Scalar(_) => None,
            })
            .unwrap_or(1);
        let return_type = udf.return_type(&arg_types)?;
        let arg_fields = args
            .iter()
            .enumerate()
            .map(|(idx, arg)| {
                Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
            })
            .collect::<Vec<_>>();

        let outcome = udf.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Field::new("f", return_type, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        });
        assert!(outcome.is_ok());

        let array = outcome?
            .to_array(number_rows)
            .expect("Failed to convert to array");
        let actual = array
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("Failed to convert to type");
        assert_eq!(*actual, expected);

        Ok(())
    }

    /// Array of strings searched in a literal list.
    #[test]
    fn test_find_in_set_with_scalar_args() -> Result<()> {
        check_find_in_set(
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec![
                    "", "a", "b", "c", "d",
                ]))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
            ],
            Int32Array::from(vec![0, 0, 1, 2, 3]),
        )
    }

    /// Literal string searched in an array of lists.
    #[test]
    fn test_find_in_set_with_scalar_args_2() -> Result<()> {
        check_find_in_set(
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                    "ApacheSoftware".to_string(),
                ))),
                ColumnarValue::Array(Arc::new(StringArray::from(vec![
                    "a,b,c",
                    "ApacheSoftware,Github,DataFusion",
                    "",
                ]))),
            ],
            Int32Array::from(vec![0, 1, 0]),
        )
    }

    /// All-NULL string array against a literal list.
    #[test]
    fn test_find_in_set_with_scalar_args_3() -> Result<()> {
        check_find_in_set(
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
            ],
            Int32Array::from(vec![None::<i32>; 3]),
        )
    }

    /// Literal string against an all-NULL array of lists.
    #[test]
    fn test_find_in_set_with_scalar_args_4() -> Result<()> {
        check_find_in_set(
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
                ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ],
            Int32Array::from(vec![None::<i32>; 3]),
        )
    }
}