// Source: datafusion_functions/unicode/find_in_set.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{
21    ArrayAccessor, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray,
22};
23use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
24use arrow_buffer::NullBuffer;
25
26use crate::utils::utf8_to_int_type;
27use datafusion_common::{
28    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
29};
30use datafusion_expr::TypeSignature::Exact;
31use datafusion_expr::{
32    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33    Volatility,
34};
35use datafusion_macros::user_doc;
36
37#[user_doc(
38    doc_section(label = "String Functions"),
39    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
40    syntax_example = "find_in_set(str, strlist)",
41    sql_example = r#"```sql
42> select find_in_set('b', 'a,b,c,d');
43+----------------------------------------+
44| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
45+----------------------------------------+
46| 2                                      |
47+----------------------------------------+
48```"#,
49    argument(name = "str", description = "String expression to find in strlist."),
50    argument(
51        name = "strlist",
52        description = "A string list is a string composed of substrings separated by , characters."
53    )
54)]
55#[derive(Debug, PartialEq, Eq, Hash)]
56pub struct FindInSetFunc {
57    signature: Signature,
58}
59
60impl Default for FindInSetFunc {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl FindInSetFunc {
67    pub fn new() -> Self {
68        use DataType::*;
69        Self {
70            signature: Signature::one_of(
71                vec![
72                    Exact(vec![Utf8View, Utf8View]),
73                    Exact(vec![Utf8, Utf8]),
74                    Exact(vec![LargeUtf8, LargeUtf8]),
75                ],
76                Volatility::Immutable,
77            ),
78        }
79    }
80}
81
82impl ScalarUDFImpl for FindInSetFunc {
83    fn name(&self) -> &str {
84        "find_in_set"
85    }
86
87    fn signature(&self) -> &Signature {
88        &self.signature
89    }
90
91    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92        utf8_to_int_type(&arg_types[0], "find_in_set")
93    }
94
95    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96        let return_field = args.return_field;
97        let [string, str_list] = take_function_args(self.name(), args.args)?;
98
99        match (string, str_list) {
100            // both inputs are scalars
101            (
102                ColumnarValue::Scalar(
103                    ScalarValue::Utf8View(string)
104                    | ScalarValue::Utf8(string)
105                    | ScalarValue::LargeUtf8(string),
106                ),
107                ColumnarValue::Scalar(
108                    ScalarValue::Utf8View(str_list)
109                    | ScalarValue::Utf8(str_list)
110                    | ScalarValue::LargeUtf8(str_list),
111                ),
112            ) => {
113                let res = match (string, str_list) {
114                    (Some(string), Some(str_list)) => {
115                        let position = str_list
116                            .split(',')
117                            .position(|s| s == string)
118                            .map_or(0, |idx| idx + 1);
119
120                        Some(position as i32)
121                    }
122                    _ => None,
123                };
124                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
125            }
126
127            // `string` is an array, `str_list` is scalar
128            (
129                ColumnarValue::Array(str_array),
130                ColumnarValue::Scalar(
131                    ScalarValue::Utf8View(str_list_literal)
132                    | ScalarValue::Utf8(str_list_literal)
133                    | ScalarValue::LargeUtf8(str_list_literal),
134                ),
135            ) => {
136                match str_list_literal {
137                    // find_in_set(column_a, null) = null
138                    None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
139                        return_field.data_type(),
140                    )?)),
141                    Some(str_list_literal) => {
142                        let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
143                        let result = match str_array.data_type() {
144                            DataType::Utf8 => {
145                                let string_array = str_array.as_string::<i32>();
146                                find_in_set_right_literal::<Int32Type, _>(
147                                    string_array,
148                                    &str_list,
149                                )
150                            }
151                            DataType::LargeUtf8 => {
152                                let string_array = str_array.as_string::<i64>();
153                                find_in_set_right_literal::<Int64Type, _>(
154                                    string_array,
155                                    &str_list,
156                                )
157                            }
158                            DataType::Utf8View => {
159                                let string_array = str_array.as_string_view();
160                                find_in_set_right_literal::<Int32Type, _>(
161                                    string_array,
162                                    &str_list,
163                                )
164                            }
165                            other => {
166                                exec_err!(
167                                    "Unsupported data type {other:?} for function find_in_set"
168                                )
169                            }
170                        };
171                        Ok(ColumnarValue::Array(Arc::new(result?)))
172                    }
173                }
174            }
175
176            // `string` is scalar, `str_list` is an array
177            (
178                ColumnarValue::Scalar(
179                    ScalarValue::Utf8View(string_literal)
180                    | ScalarValue::Utf8(string_literal)
181                    | ScalarValue::LargeUtf8(string_literal),
182                ),
183                ColumnarValue::Array(str_list_array),
184            ) => {
185                match string_literal {
186                    // find_in_set(null, column_b) = null
187                    None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
188                        return_field.data_type(),
189                    )?)),
190                    Some(string) => {
191                        let result = match str_list_array.data_type() {
192                            DataType::Utf8 => {
193                                let str_list = str_list_array.as_string::<i32>();
194                                find_in_set_left_literal::<Int32Type, _>(
195                                    &string, str_list,
196                                )
197                            }
198                            DataType::LargeUtf8 => {
199                                let str_list = str_list_array.as_string::<i64>();
200                                find_in_set_left_literal::<Int64Type, _>(
201                                    &string, str_list,
202                                )
203                            }
204                            DataType::Utf8View => {
205                                let str_list = str_list_array.as_string_view();
206                                find_in_set_left_literal::<Int32Type, _>(
207                                    &string, str_list,
208                                )
209                            }
210                            other => {
211                                exec_err!(
212                                    "Unsupported data type {other:?} for function find_in_set"
213                                )
214                            }
215                        };
216                        Ok(ColumnarValue::Array(Arc::new(result?)))
217                    }
218                }
219            }
220
221            // both inputs are arrays
222            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
223                let res = find_in_set(&base_array, &exp_array)?;
224
225                Ok(ColumnarValue::Array(res))
226            }
227            _ => {
228                internal_err!("Invalid argument types for `find_in_set` function")
229            }
230        }
231    }
232
233    fn documentation(&self) -> Option<&Documentation> {
234        self.doc()
235    }
236}
237
238/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist`
239/// consisting of N substrings. A string list is a string composed of substrings separated by `,`
240/// characters.
241fn find_in_set(str: &ArrayRef, str_list: &ArrayRef) -> Result<ArrayRef> {
242    match str.data_type() {
243        DataType::Utf8 => {
244            let string_array = str.as_string::<i32>();
245            let str_list_array = str_list.as_string::<i32>();
246            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
247        }
248        DataType::LargeUtf8 => {
249            let string_array = str.as_string::<i64>();
250            let str_list_array = str_list.as_string::<i64>();
251            find_in_set_general::<Int64Type, _>(string_array, str_list_array)
252        }
253        DataType::Utf8View => {
254            let string_array = str.as_string_view();
255            let str_list_array = str_list.as_string_view();
256            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
257        }
258        other => {
259            exec_err!("Unsupported data type {other:?} for function find_in_set")
260        }
261    }
262}
263
264fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result<ArrayRef>
265where
266    T: ArrowPrimitiveType,
267    T::Native: OffsetSizeTrait,
268    V: ArrayAccessor<Item = &'a str> + Copy,
269{
270    let len = string_array.len();
271    let nulls = NullBuffer::union(string_array.nulls(), str_list_array.nulls());
272    let zero = T::Native::from_usize(0).unwrap();
273
274    let values: Vec<T::Native> = (0..len)
275        .map(|i| {
276            if nulls.as_ref().is_some_and(|n| n.is_null(i)) {
277                return zero;
278            }
279            let string = string_array.value(i);
280            let str_list = str_list_array.value(i);
281            let position = str_list
282                .split(',')
283                .position(|s| s == string)
284                .map_or(0, |idx| idx + 1);
285            T::Native::from_usize(position).unwrap()
286        })
287        .collect();
288
289    Ok(Arc::new(PrimitiveArray::<T>::new(values.into(), nulls)) as ArrayRef)
290}
291
292fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result<ArrayRef>
293where
294    T: ArrowPrimitiveType,
295    T::Native: OffsetSizeTrait,
296    V: ArrayAccessor<Item = &'a str> + Copy,
297{
298    let len = str_list_array.len();
299    let nulls = str_list_array.nulls().cloned();
300    let zero = T::Native::from_usize(0).unwrap();
301
302    let values: Vec<T::Native> = (0..len)
303        .map(|i| {
304            if nulls.as_ref().is_some_and(|n| n.is_null(i)) {
305                return zero;
306            }
307            let str_list = str_list_array.value(i);
308            let position = str_list
309                .split(',')
310                .position(|s| s == string)
311                .map_or(0, |idx| idx + 1);
312            T::Native::from_usize(position).unwrap()
313        })
314        .collect();
315
316    Ok(Arc::new(PrimitiveArray::<T>::new(values.into(), nulls)) as ArrayRef)
317}
318
319fn find_in_set_right_literal<'a, T, V>(
320    string_array: V,
321    str_list: &[&str],
322) -> Result<ArrayRef>
323where
324    T: ArrowPrimitiveType,
325    T::Native: OffsetSizeTrait,
326    V: ArrayAccessor<Item = &'a str> + Copy,
327{
328    let len = string_array.len();
329    let nulls = string_array.nulls().cloned();
330    let zero = T::Native::from_usize(0).unwrap();
331
332    let values: Vec<T::Native> = (0..len)
333        .map(|i| {
334            if nulls.as_ref().is_some_and(|n| n.is_null(i)) {
335                return zero;
336            }
337            let string = string_array.value(i);
338            let position = str_list
339                .iter()
340                .position(|s| *s == string)
341                .map_or(0, |idx| idx + 1);
342            T::Native::from_usize(position).unwrap()
343        })
344        .collect();
345
346    Ok(Arc::new(PrimitiveArray::<T>::new(values.into(), nulls)) as ArrayRef)
347}
348
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType::Int32, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Scalar/scalar cases with `Utf8` and `Utf8View` arguments: matches at the
    /// first, middle and last position (including multi-byte characters), no
    /// match, an empty needle, an empty list, and a NULL on either side.
    #[test]
    fn test_functions() -> Result<()> {
        // "a" is the first element of "a,b,c".
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        // Multi-byte elements: positions count list elements, not bytes or chars.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "a,Д,🔥"
                )))),
            ],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        // Not present -> 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Elements may contain spaces; only ',' separates them.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Apache Software Foundation"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Github,Apache Software Foundation,DataFusion"
                )))),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        // Empty needle does not match any non-empty element.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Empty list: a non-empty needle is not found.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // NULL list -> NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        // NULL needle -> NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    /// Defines a `#[test]` named `$test_name` that invokes `find_in_set` with
    /// `$args` and compares the result, expanded to an array and downcast to
    /// `Int32Array`, against `$expected`.
    ///
    /// The row count (`number_rows`, also used to expand a scalar result) is
    /// the length of the last array argument, or 1 when all arguments are
    /// scalars.
    macro_rules! test_find_in_set {
        ($test_name:ident, $args:expr, $expected:expr) => {
            #[test]
            fn $test_name() -> Result<()> {
                let fis = crate::unicode::find_in_set();

                let args = $args;
                let expected = $expected;

                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
                // Each array argument overwrites the accumulator, so the last
                // array's length wins; all-scalar input defaults to one row.
                let cardinality = args
                    .iter()
                    .fold(Option::<usize>::None, |acc, arg| match arg {
                        ColumnarValue::Scalar(_) => acc,
                        ColumnarValue::Array(a) => Some(a.len()),
                    })
                    .unwrap_or(1);
                let return_type = fis.return_type(&type_array)?;
                // One nullable field per argument, named arg_0, arg_1, ...
                let arg_fields = args
                    .iter()
                    .enumerate()
                    .map(|(idx, a)| {
                        Field::new(format!("arg_{idx}"), a.data_type(), true).into()
                    })
                    .collect::<Vec<_>>();
                let result = fis.invoke_with_args(ScalarFunctionArgs {
                    args,
                    arg_fields,
                    number_rows: cardinality,
                    return_field: Field::new("f", return_type, true).into(),
                    config_options: Arc::new(ConfigOptions::default()),
                });
                assert!(result.is_ok());

                let result = result?
                    .to_array(cardinality)
                    .expect("Failed to convert to array");
                let result = result
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("Failed to convert to type");
                assert_eq!(*result, expected);

                Ok(())
            }
        };
    }

    // Array needle, scalar list: "" and "a" are absent, b/c/d at 1..=3.
    test_find_in_set!(
        test_find_in_set_with_scalar_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "", "a", "b", "c", "d"
            ]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
        ],
        Int32Array::from(vec![0, 0, 1, 2, 3])
    );
    // Scalar needle, array of lists.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_2,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "ApacheSoftware".to_string()
            ))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,b,c",
                "ApacheSoftware,Github,DataFusion",
                ""
            ]))),
        ],
        Int32Array::from(vec![0, 1, 0])
    );
    // All-NULL needle array -> all-NULL output.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_3,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    // All-NULL list array -> all-NULL output.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_4,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
}