Skip to main content

datafusion_functions/unicode/
find_in_set.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23    PrimitiveArray,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
41    syntax_example = "find_in_set(str, strlist)",
42    sql_example = r#"```sql
43> select find_in_set('b', 'a,b,c,d');
44+----------------------------------------+
45| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
46+----------------------------------------+
47| 2                                      |
48+----------------------------------------+
49```"#,
50    argument(name = "str", description = "String expression to find in strlist."),
51    argument(
52        name = "strlist",
53        description = "A string list is a string composed of substrings separated by , characters."
54    )
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct FindInSetFunc {
58    signature: Signature,
59}
60
61impl Default for FindInSetFunc {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl FindInSetFunc {
68    pub fn new() -> Self {
69        use DataType::*;
70        Self {
71            signature: Signature::one_of(
72                vec![
73                    Exact(vec![Utf8View, Utf8View]),
74                    Exact(vec![Utf8, Utf8]),
75                    Exact(vec![LargeUtf8, LargeUtf8]),
76                ],
77                Volatility::Immutable,
78            ),
79        }
80    }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87
88    fn name(&self) -> &str {
89        "find_in_set"
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95
96    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97        utf8_to_int_type(&arg_types[0], "find_in_set")
98    }
99
100    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101        let return_field = args.return_field;
102        let [string, str_list] = take_function_args(self.name(), args.args)?;
103
104        match (string, str_list) {
105            // both inputs are scalars
106            (
107                ColumnarValue::Scalar(
108                    ScalarValue::Utf8View(string)
109                    | ScalarValue::Utf8(string)
110                    | ScalarValue::LargeUtf8(string),
111                ),
112                ColumnarValue::Scalar(
113                    ScalarValue::Utf8View(str_list)
114                    | ScalarValue::Utf8(str_list)
115                    | ScalarValue::LargeUtf8(str_list),
116                ),
117            ) => {
118                let res = match (string, str_list) {
119                    (Some(string), Some(str_list)) => {
120                        let position = str_list
121                            .split(',')
122                            .position(|s| s == string)
123                            .map_or(0, |idx| idx + 1);
124
125                        Some(position as i32)
126                    }
127                    _ => None,
128                };
129                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
130            }
131
132            // `string` is an array, `str_list` is scalar
133            (
134                ColumnarValue::Array(str_array),
135                ColumnarValue::Scalar(
136                    ScalarValue::Utf8View(str_list_literal)
137                    | ScalarValue::Utf8(str_list_literal)
138                    | ScalarValue::LargeUtf8(str_list_literal),
139                ),
140            ) => {
141                match str_list_literal {
142                    // find_in_set(column_a, null) = null
143                    None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
144                        return_field.data_type(),
145                    )?)),
146                    Some(str_list_literal) => {
147                        let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
148                        let result = match str_array.data_type() {
149                            DataType::Utf8 => {
150                                let string_array = str_array.as_string::<i32>();
151                                find_in_set_right_literal::<Int32Type, _>(
152                                    string_array,
153                                    &str_list,
154                                )
155                            }
156                            DataType::LargeUtf8 => {
157                                let string_array = str_array.as_string::<i64>();
158                                find_in_set_right_literal::<Int64Type, _>(
159                                    string_array,
160                                    &str_list,
161                                )
162                            }
163                            DataType::Utf8View => {
164                                let string_array = str_array.as_string_view();
165                                find_in_set_right_literal::<Int32Type, _>(
166                                    string_array,
167                                    &str_list,
168                                )
169                            }
170                            other => {
171                                exec_err!(
172                                    "Unsupported data type {other:?} for function find_in_set"
173                                )
174                            }
175                        };
176                        Ok(ColumnarValue::Array(Arc::new(result?)))
177                    }
178                }
179            }
180
181            // `string` is scalar, `str_list` is an array
182            (
183                ColumnarValue::Scalar(
184                    ScalarValue::Utf8View(string_literal)
185                    | ScalarValue::Utf8(string_literal)
186                    | ScalarValue::LargeUtf8(string_literal),
187                ),
188                ColumnarValue::Array(str_list_array),
189            ) => {
190                match string_literal {
191                    // find_in_set(null, column_b) = null
192                    None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
193                        return_field.data_type(),
194                    )?)),
195                    Some(string) => {
196                        let result = match str_list_array.data_type() {
197                            DataType::Utf8 => {
198                                let str_list = str_list_array.as_string::<i32>();
199                                find_in_set_left_literal::<Int32Type, _>(
200                                    &string, str_list,
201                                )
202                            }
203                            DataType::LargeUtf8 => {
204                                let str_list = str_list_array.as_string::<i64>();
205                                find_in_set_left_literal::<Int64Type, _>(
206                                    &string, str_list,
207                                )
208                            }
209                            DataType::Utf8View => {
210                                let str_list = str_list_array.as_string_view();
211                                find_in_set_left_literal::<Int32Type, _>(
212                                    &string, str_list,
213                                )
214                            }
215                            other => {
216                                exec_err!(
217                                    "Unsupported data type {other:?} for function find_in_set"
218                                )
219                            }
220                        };
221                        Ok(ColumnarValue::Array(Arc::new(result?)))
222                    }
223                }
224            }
225
226            // both inputs are arrays
227            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
228                let res = find_in_set(&base_array, &exp_array)?;
229
230                Ok(ColumnarValue::Array(res))
231            }
232            _ => {
233                internal_err!("Invalid argument types for `find_in_set` function")
234            }
235        }
236    }
237
238    fn documentation(&self) -> Option<&Documentation> {
239        self.doc()
240    }
241}
242
243/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist`
244/// consisting of N substrings. A string list is a string composed of substrings separated by `,`
245/// characters.
246fn find_in_set(str: &ArrayRef, str_list: &ArrayRef) -> Result<ArrayRef> {
247    match str.data_type() {
248        DataType::Utf8 => {
249            let string_array = str.as_string::<i32>();
250            let str_list_array = str_list.as_string::<i32>();
251            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
252        }
253        DataType::LargeUtf8 => {
254            let string_array = str.as_string::<i64>();
255            let str_list_array = str_list.as_string::<i64>();
256            find_in_set_general::<Int64Type, _>(string_array, str_list_array)
257        }
258        DataType::Utf8View => {
259            let string_array = str.as_string_view();
260            let str_list_array = str_list.as_string_view();
261            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
262        }
263        other => {
264            exec_err!("Unsupported data type {other:?} for function find_in_set")
265        }
266    }
267}
268
269fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result<ArrayRef>
270where
271    T: ArrowPrimitiveType,
272    T::Native: OffsetSizeTrait,
273    V: ArrayAccessor<Item = &'a str>,
274{
275    let string_iter = ArrayIter::new(string_array);
276    let str_list_iter = ArrayIter::new(str_list_array);
277
278    let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
279
280    string_iter
281        .zip(str_list_iter)
282        .for_each(
283            |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
284                (Some(string), Some(str_list)) => {
285                    let position = str_list
286                        .split(',')
287                        .position(|s| s == string)
288                        .map_or(0, |idx| idx + 1);
289                    builder.append_value(T::Native::from_usize(position).unwrap());
290                }
291                _ => builder.append_null(),
292            },
293        );
294
295    Ok(Arc::new(builder.finish()) as ArrayRef)
296}
297
298fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result<ArrayRef>
299where
300    T: ArrowPrimitiveType,
301    T::Native: OffsetSizeTrait,
302    V: ArrayAccessor<Item = &'a str>,
303{
304    let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
305
306    let str_list_iter = ArrayIter::new(str_list_array);
307
308    str_list_iter.for_each(|str_list_opt| match str_list_opt {
309        Some(str_list) => {
310            let position = str_list
311                .split(',')
312                .position(|s| s == string)
313                .map_or(0, |idx| idx + 1);
314            builder.append_value(T::Native::from_usize(position).unwrap());
315        }
316        None => builder.append_null(),
317    });
318
319    Ok(Arc::new(builder.finish()) as ArrayRef)
320}
321
322fn find_in_set_right_literal<'a, T, V>(
323    string_array: V,
324    str_list: &[&str],
325) -> Result<ArrayRef>
326where
327    T: ArrowPrimitiveType,
328    T::Native: OffsetSizeTrait,
329    V: ArrayAccessor<Item = &'a str>,
330{
331    let mut builder = PrimitiveArray::<T>::builder(string_array.len());
332
333    let string_iter = ArrayIter::new(string_array);
334
335    string_iter.for_each(|string_opt| match string_opt {
336        Some(string) => {
337            let position = str_list
338                .iter()
339                .position(|s| *s == string)
340                .map_or(0, |idx| idx + 1);
341            builder.append_value(T::Native::from_usize(position).unwrap());
342        }
343        None => builder.append_null(),
344    });
345
346    Ok(Arc::new(builder.finish()) as ArrayRef)
347}
348
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType::Int32, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Scalar/scalar invocations: matches, misses, empty strings and NULLs.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "a,Д,🔥"
                )))),
            ],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Apache Software Foundation"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Github,Apache Software Foundation,DataFusion"
                )))),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    /// Generates a test that invokes the registered `find_in_set` UDF with
    /// `$args` and compares the result, as an `Int32Array`, with `$expected`.
    macro_rules! test_find_in_set {
        ($test_name:ident, $args:expr, $expected:expr) => {
            #[test]
            fn $test_name() -> Result<()> {
                let fis = crate::unicode::find_in_set();

                let args = $args;
                let expected = $expected;

                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
                // Number of rows: the length of an array argument, or 1 when
                // every argument is a scalar.
                let cardinality = args
                    .iter()
                    .find_map(|arg| match arg {
                        ColumnarValue::Scalar(_) => None,
                        ColumnarValue::Array(a) => Some(a.len()),
                    })
                    .unwrap_or(1);
                let return_type = fis.return_type(&type_array)?;
                let arg_fields = args
                    .iter()
                    .enumerate()
                    .map(|(idx, a)| {
                        Field::new(format!("arg_{idx}"), a.data_type(), true).into()
                    })
                    .collect::<Vec<_>>();
                // `?` surfaces the actual error if the invocation fails.
                let result = fis
                    .invoke_with_args(ScalarFunctionArgs {
                        args,
                        arg_fields,
                        number_rows: cardinality,
                        return_field: Field::new("f", return_type, true).into(),
                        config_options: Arc::new(ConfigOptions::default()),
                    })?
                    .to_array(cardinality)
                    .expect("Failed to convert to array");
                let result = result
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("Failed to convert to type");
                assert_eq!(*result, expected);

                Ok(())
            }
        };
    }

    test_find_in_set!(
        test_find_in_set_with_scalar_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "", "a", "b", "c", "d"
            ]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
        ],
        Int32Array::from(vec![0, 0, 1, 2, 3])
    );
    test_find_in_set!(
        test_find_in_set_with_scalar_args_2,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "ApacheSoftware".to_string()
            ))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,b,c",
                "ApacheSoftware,Github,DataFusion",
                ""
            ]))),
        ],
        Int32Array::from(vec![0, 1, 0])
    );
    test_find_in_set!(
        test_find_in_set_with_scalar_args_3,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    test_find_in_set!(
        test_find_in_set_with_scalar_args_4,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    // Both arguments are arrays: exercises the pairwise code path, including
    // a miss and a NULL on either side.
    test_find_in_set!(
        test_find_in_set_with_array_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                Some("a"),
                Some("d"),
                None,
                Some("a"),
            ]))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                Some("a,b,c"),
                Some("a,b,c"),
                Some("a,b,c"),
                None,
            ]))),
        ],
        Int32Array::from(vec![Some(1), Some(0), None, None])
    );
    test_find_in_set!(
        test_find_in_set_with_array_args_unicode,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["🔥", "Д", "c"]))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,Д,🔥",
                "a,Д,🔥",
                "a,Д,🔥"
            ]))),
        ],
        Int32Array::from(vec![3, 2, 0])
    );
}