datafusion_functions/unicode/
find_in_set.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
23    OffsetSizeTrait, PrimitiveArray,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29    exec_err, internal_err, utils::take_function_args, Result, ScalarValue,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39    doc_section(label = "String Functions"),
40    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
41    syntax_example = "find_in_set(str, strlist)",
42    sql_example = r#"```sql
43> select find_in_set('b', 'a,b,c,d');
44+----------------------------------------+
45| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
46+----------------------------------------+
47| 2                                      |
48+----------------------------------------+
49```"#,
50    argument(name = "str", description = "String expression to find in strlist."),
51    argument(
52        name = "strlist",
53        description = "A string list is a string composed of substrings separated by , characters."
54    )
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct FindInSetFunc {
58    signature: Signature,
59}
60
61impl Default for FindInSetFunc {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl FindInSetFunc {
68    pub fn new() -> Self {
69        use DataType::*;
70        Self {
71            signature: Signature::one_of(
72                vec![
73                    Exact(vec![Utf8View, Utf8View]),
74                    Exact(vec![Utf8, Utf8]),
75                    Exact(vec![LargeUtf8, LargeUtf8]),
76                ],
77                Volatility::Immutable,
78            ),
79        }
80    }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87
88    fn name(&self) -> &str {
89        "find_in_set"
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95
96    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97        utf8_to_int_type(&arg_types[0], "find_in_set")
98    }
99
100    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101        let ScalarFunctionArgs { args, .. } = args;
102
103        let [string, str_list] = take_function_args(self.name(), args)?;
104
105        match (string, str_list) {
106            // both inputs are scalars
107            (
108                ColumnarValue::Scalar(
109                    ScalarValue::Utf8View(string)
110                    | ScalarValue::Utf8(string)
111                    | ScalarValue::LargeUtf8(string),
112                ),
113                ColumnarValue::Scalar(
114                    ScalarValue::Utf8View(str_list)
115                    | ScalarValue::Utf8(str_list)
116                    | ScalarValue::LargeUtf8(str_list),
117                ),
118            ) => {
119                let res = match (string, str_list) {
120                    (Some(string), Some(str_list)) => {
121                        let position = str_list
122                            .split(',')
123                            .position(|s| s == string)
124                            .map_or(0, |idx| idx + 1);
125
126                        Some(position as i32)
127                    }
128                    _ => None,
129                };
130                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
131            }
132
133            // `string` is an array, `str_list` is scalar
134            (
135                ColumnarValue::Array(str_array),
136                ColumnarValue::Scalar(
137                    ScalarValue::Utf8View(str_list_literal)
138                    | ScalarValue::Utf8(str_list_literal)
139                    | ScalarValue::LargeUtf8(str_list_literal),
140                ),
141            ) => {
142                let result_array = match str_list_literal {
143                    // find_in_set(column_a, null) = null
144                    None => new_null_array(str_array.data_type(), str_array.len()),
145                    Some(str_list_literal) => {
146                        let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
147                        let result = match str_array.data_type() {
148                            DataType::Utf8 => {
149                                let string_array = str_array.as_string::<i32>();
150                                find_in_set_right_literal::<Int32Type, _>(
151                                    string_array,
152                                    str_list,
153                                )
154                            }
155                            DataType::LargeUtf8 => {
156                                let string_array = str_array.as_string::<i64>();
157                                find_in_set_right_literal::<Int64Type, _>(
158                                    string_array,
159                                    str_list,
160                                )
161                            }
162                            DataType::Utf8View => {
163                                let string_array = str_array.as_string_view();
164                                find_in_set_right_literal::<Int32Type, _>(
165                                    string_array,
166                                    str_list,
167                                )
168                            }
169                            other => {
170                                exec_err!("Unsupported data type {other:?} for function find_in_set")
171                            }
172                        };
173                        Arc::new(result?)
174                    }
175                };
176                Ok(ColumnarValue::Array(result_array))
177            }
178
179            // `string` is scalar, `str_list` is an array
180            (
181                ColumnarValue::Scalar(
182                    ScalarValue::Utf8View(string_literal)
183                    | ScalarValue::Utf8(string_literal)
184                    | ScalarValue::LargeUtf8(string_literal),
185                ),
186                ColumnarValue::Array(str_list_array),
187            ) => {
188                let res = match string_literal {
189                    // find_in_set(null, column_b) = null
190                    None => {
191                        new_null_array(str_list_array.data_type(), str_list_array.len())
192                    }
193                    Some(string) => {
194                        let result = match str_list_array.data_type() {
195                            DataType::Utf8 => {
196                                let str_list = str_list_array.as_string::<i32>();
197                                find_in_set_left_literal::<Int32Type, _>(string, str_list)
198                            }
199                            DataType::LargeUtf8 => {
200                                let str_list = str_list_array.as_string::<i64>();
201                                find_in_set_left_literal::<Int64Type, _>(string, str_list)
202                            }
203                            DataType::Utf8View => {
204                                let str_list = str_list_array.as_string_view();
205                                find_in_set_left_literal::<Int32Type, _>(string, str_list)
206                            }
207                            other => {
208                                exec_err!("Unsupported data type {other:?} for function find_in_set")
209                            }
210                        };
211                        Arc::new(result?)
212                    }
213                };
214                Ok(ColumnarValue::Array(res))
215            }
216
217            // both inputs are arrays
218            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
219                let res = find_in_set(base_array, exp_array)?;
220
221                Ok(ColumnarValue::Array(res))
222            }
223            _ => {
224                internal_err!("Invalid argument types for `find_in_set` function")
225            }
226        }
227    }
228
229    fn documentation(&self) -> Option<&Documentation> {
230        self.doc()
231    }
232}
233
234/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist`
235/// consisting of N substrings. A string list is a string composed of substrings separated by `,`
236/// characters.
237fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result<ArrayRef> {
238    match str.data_type() {
239        DataType::Utf8 => {
240            let string_array = str.as_string::<i32>();
241            let str_list_array = str_list.as_string::<i32>();
242            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
243        }
244        DataType::LargeUtf8 => {
245            let string_array = str.as_string::<i64>();
246            let str_list_array = str_list.as_string::<i64>();
247            find_in_set_general::<Int64Type, _>(string_array, str_list_array)
248        }
249        DataType::Utf8View => {
250            let string_array = str.as_string_view();
251            let str_list_array = str_list.as_string_view();
252            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
253        }
254        other => {
255            exec_err!("Unsupported data type {other:?} for function find_in_set")
256        }
257    }
258}
259
260pub fn find_in_set_general<'a, T, V>(
261    string_array: V,
262    str_list_array: V,
263) -> Result<ArrayRef>
264where
265    T: ArrowPrimitiveType,
266    T::Native: OffsetSizeTrait,
267    V: ArrayAccessor<Item = &'a str>,
268{
269    let string_iter = ArrayIter::new(string_array);
270    let str_list_iter = ArrayIter::new(str_list_array);
271
272    let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
273
274    string_iter
275        .zip(str_list_iter)
276        .for_each(
277            |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
278                (Some(string), Some(str_list)) => {
279                    let position = str_list
280                        .split(',')
281                        .position(|s| s == string)
282                        .map_or(0, |idx| idx + 1);
283                    builder.append_value(T::Native::from_usize(position).unwrap());
284                }
285                _ => builder.append_null(),
286            },
287        );
288
289    Ok(Arc::new(builder.finish()) as ArrayRef)
290}
291
292fn find_in_set_left_literal<'a, T, V>(
293    string: String,
294    str_list_array: V,
295) -> Result<ArrayRef>
296where
297    T: ArrowPrimitiveType,
298    T::Native: OffsetSizeTrait,
299    V: ArrayAccessor<Item = &'a str>,
300{
301    let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
302
303    let str_list_iter = ArrayIter::new(str_list_array);
304
305    str_list_iter.for_each(|str_list_opt| match str_list_opt {
306        Some(str_list) => {
307            let position = str_list
308                .split(',')
309                .position(|s| s == string)
310                .map_or(0, |idx| idx + 1);
311            builder.append_value(T::Native::from_usize(position).unwrap());
312        }
313        None => builder.append_null(),
314    });
315
316    Ok(Arc::new(builder.finish()) as ArrayRef)
317}
318
319fn find_in_set_right_literal<'a, T, V>(
320    string_array: V,
321    str_list: Vec<&str>,
322) -> Result<ArrayRef>
323where
324    T: ArrowPrimitiveType,
325    T::Native: OffsetSizeTrait,
326    V: ArrayAccessor<Item = &'a str>,
327{
328    let mut builder = PrimitiveArray::<T>::builder(string_array.len());
329
330    let string_iter = ArrayIter::new(string_array);
331
332    string_iter.for_each(|string_opt| match string_opt {
333        Some(string) => {
334            let position = str_list
335                .iter()
336                .position(|s| *s == string)
337                .map_or(0, |idx| idx + 1);
338            builder.append_value(T::Native::from_usize(position).unwrap());
339        }
340        None => builder.append_null(),
341    });
342
343    Ok(Arc::new(builder.finish()) as ArrayRef)
344}
345
346#[cfg(test)]
347mod tests {
348    use crate::unicode::find_in_set::FindInSetFunc;
349    use crate::utils::test::test_function;
350    use arrow::array::{Array, Int32Array, StringArray};
351    use arrow::datatypes::{DataType::Int32, Field};
352    use datafusion_common::config::ConfigOptions;
353    use datafusion_common::{Result, ScalarValue};
354    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
355    use std::sync::Arc;
356
357    #[test]
358    fn test_functions() -> Result<()> {
359        test_function!(
360            FindInSetFunc::new(),
361            vec![
362                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
363                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
364            ],
365            Ok(Some(1)),
366            i32,
367            Int32,
368            Int32Array
369        );
370        test_function!(
371            FindInSetFunc::new(),
372            vec![
373                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
374                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
375                    "a,Д,🔥"
376                )))),
377            ],
378            Ok(Some(3)),
379            i32,
380            Int32,
381            Int32Array
382        );
383        test_function!(
384            FindInSetFunc::new(),
385            vec![
386                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
387                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
388            ],
389            Ok(Some(0)),
390            i32,
391            Int32,
392            Int32Array
393        );
394        test_function!(
395            FindInSetFunc::new(),
396            vec![
397                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
398                    "Apache Software Foundation"
399                )))),
400                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
401                    "Github,Apache Software Foundation,DataFusion"
402                )))),
403            ],
404            Ok(Some(2)),
405            i32,
406            Int32,
407            Int32Array
408        );
409        test_function!(
410            FindInSetFunc::new(),
411            vec![
412                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
413                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
414            ],
415            Ok(Some(0)),
416            i32,
417            Int32,
418            Int32Array
419        );
420        test_function!(
421            FindInSetFunc::new(),
422            vec![
423                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
424                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
425            ],
426            Ok(Some(0)),
427            i32,
428            Int32,
429            Int32Array
430        );
431        test_function!(
432            FindInSetFunc::new(),
433            vec![
434                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
435                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
436            ],
437            Ok(None),
438            i32,
439            Int32,
440            Int32Array
441        );
442        test_function!(
443            FindInSetFunc::new(),
444            vec![
445                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
446                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
447            ],
448            Ok(None),
449            i32,
450            Int32,
451            Int32Array
452        );
453
454        Ok(())
455    }
456
457    macro_rules! test_find_in_set {
458        ($test_name:ident, $args:expr, $expected:expr) => {
459            #[test]
460            fn $test_name() -> Result<()> {
461                let fis = crate::unicode::find_in_set();
462
463                let args = $args;
464                let expected = $expected;
465
466                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
467                let cardinality = args
468                    .iter()
469                    .fold(Option::<usize>::None, |acc, arg| match arg {
470                        ColumnarValue::Scalar(_) => acc,
471                        ColumnarValue::Array(a) => Some(a.len()),
472                    })
473                    .unwrap_or(1);
474                let return_type = fis.return_type(&type_array)?;
475                let arg_fields = args
476                    .iter()
477                    .enumerate()
478                    .map(|(idx, a)| {
479                        Field::new(format!("arg_{idx}"), a.data_type(), true).into()
480                    })
481                    .collect::<Vec<_>>();
482                let result = fis.invoke_with_args(ScalarFunctionArgs {
483                    args,
484                    arg_fields,
485                    number_rows: cardinality,
486                    return_field: Field::new("f", return_type, true).into(),
487                    config_options: Arc::new(ConfigOptions::default()),
488                });
489                assert!(result.is_ok());
490
491                let result = result?
492                    .to_array(cardinality)
493                    .expect("Failed to convert to array");
494                let result = result
495                    .as_any()
496                    .downcast_ref::<Int32Array>()
497                    .expect("Failed to convert to type");
498                assert_eq!(*result, expected);
499
500                Ok(())
501            }
502        };
503    }
504
505    test_find_in_set!(
506        test_find_in_set_with_scalar_args,
507        vec![
508            ColumnarValue::Array(Arc::new(StringArray::from(vec![
509                "", "a", "b", "c", "d"
510            ]))),
511            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
512        ],
513        Int32Array::from(vec![0, 0, 1, 2, 3])
514    );
515    test_find_in_set!(
516        test_find_in_set_with_scalar_args_2,
517        vec![
518            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
519                "ApacheSoftware".to_string()
520            ))),
521            ColumnarValue::Array(Arc::new(StringArray::from(vec![
522                "a,b,c",
523                "ApacheSoftware,Github,DataFusion",
524                ""
525            ]))),
526        ],
527        Int32Array::from(vec![0, 1, 0])
528    );
529    test_find_in_set!(
530        test_find_in_set_with_scalar_args_3,
531        vec![
532            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
533            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
534        ],
535        Int32Array::from(vec![None::<i32>; 3])
536    );
537    test_find_in_set!(
538        test_find_in_set_with_scalar_args_4,
539        vec![
540            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
541            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
542        ],
543        Int32Array::from(vec![None::<i32>; 3])
544    );
545}