1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23 PrimitiveArray,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29 Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34 Volatility,
35};
36use datafusion_macros::user_doc;
37
38#[user_doc(
39 doc_section(label = "String Functions"),
40 description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
41 syntax_example = "find_in_set(str, strlist)",
42 sql_example = r#"```sql
43> select find_in_set('b', 'a,b,c,d');
44+----------------------------------------+
45| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
46+----------------------------------------+
47| 2 |
48+----------------------------------------+
49```"#,
50 argument(name = "str", description = "String expression to find in strlist."),
51 argument(
52 name = "strlist",
53 description = "A string list is a string composed of substrings separated by , characters."
54 )
55)]
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct FindInSetFunc {
58 signature: Signature,
59}
60
61impl Default for FindInSetFunc {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl FindInSetFunc {
68 pub fn new() -> Self {
69 use DataType::*;
70 Self {
71 signature: Signature::one_of(
72 vec![
73 Exact(vec![Utf8View, Utf8View]),
74 Exact(vec![Utf8, Utf8]),
75 Exact(vec![LargeUtf8, LargeUtf8]),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn name(&self) -> &str {
89 "find_in_set"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 utf8_to_int_type(&arg_types[0], "find_in_set")
98 }
99
100 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101 let return_field = args.return_field;
102 let [string, str_list] = take_function_args(self.name(), args.args)?;
103
104 match (string, str_list) {
105 (
107 ColumnarValue::Scalar(
108 ScalarValue::Utf8View(string)
109 | ScalarValue::Utf8(string)
110 | ScalarValue::LargeUtf8(string),
111 ),
112 ColumnarValue::Scalar(
113 ScalarValue::Utf8View(str_list)
114 | ScalarValue::Utf8(str_list)
115 | ScalarValue::LargeUtf8(str_list),
116 ),
117 ) => {
118 let res = match (string, str_list) {
119 (Some(string), Some(str_list)) => {
120 let position = str_list
121 .split(',')
122 .position(|s| s == string)
123 .map_or(0, |idx| idx + 1);
124
125 Some(position as i32)
126 }
127 _ => None,
128 };
129 Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
130 }
131
132 (
134 ColumnarValue::Array(str_array),
135 ColumnarValue::Scalar(
136 ScalarValue::Utf8View(str_list_literal)
137 | ScalarValue::Utf8(str_list_literal)
138 | ScalarValue::LargeUtf8(str_list_literal),
139 ),
140 ) => {
141 match str_list_literal {
142 None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
144 return_field.data_type(),
145 )?)),
146 Some(str_list_literal) => {
147 let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
148 let result = match str_array.data_type() {
149 DataType::Utf8 => {
150 let string_array = str_array.as_string::<i32>();
151 find_in_set_right_literal::<Int32Type, _>(
152 string_array,
153 &str_list,
154 )
155 }
156 DataType::LargeUtf8 => {
157 let string_array = str_array.as_string::<i64>();
158 find_in_set_right_literal::<Int64Type, _>(
159 string_array,
160 &str_list,
161 )
162 }
163 DataType::Utf8View => {
164 let string_array = str_array.as_string_view();
165 find_in_set_right_literal::<Int32Type, _>(
166 string_array,
167 &str_list,
168 )
169 }
170 other => {
171 exec_err!(
172 "Unsupported data type {other:?} for function find_in_set"
173 )
174 }
175 };
176 Ok(ColumnarValue::Array(Arc::new(result?)))
177 }
178 }
179 }
180
181 (
183 ColumnarValue::Scalar(
184 ScalarValue::Utf8View(string_literal)
185 | ScalarValue::Utf8(string_literal)
186 | ScalarValue::LargeUtf8(string_literal),
187 ),
188 ColumnarValue::Array(str_list_array),
189 ) => {
190 match string_literal {
191 None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
193 return_field.data_type(),
194 )?)),
195 Some(string) => {
196 let result = match str_list_array.data_type() {
197 DataType::Utf8 => {
198 let str_list = str_list_array.as_string::<i32>();
199 find_in_set_left_literal::<Int32Type, _>(
200 &string, str_list,
201 )
202 }
203 DataType::LargeUtf8 => {
204 let str_list = str_list_array.as_string::<i64>();
205 find_in_set_left_literal::<Int64Type, _>(
206 &string, str_list,
207 )
208 }
209 DataType::Utf8View => {
210 let str_list = str_list_array.as_string_view();
211 find_in_set_left_literal::<Int32Type, _>(
212 &string, str_list,
213 )
214 }
215 other => {
216 exec_err!(
217 "Unsupported data type {other:?} for function find_in_set"
218 )
219 }
220 };
221 Ok(ColumnarValue::Array(Arc::new(result?)))
222 }
223 }
224 }
225
226 (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
228 let res = find_in_set(&base_array, &exp_array)?;
229
230 Ok(ColumnarValue::Array(res))
231 }
232 _ => {
233 internal_err!("Invalid argument types for `find_in_set` function")
234 }
235 }
236 }
237
238 fn documentation(&self) -> Option<&Documentation> {
239 self.doc()
240 }
241}
242
243fn find_in_set(str: &ArrayRef, str_list: &ArrayRef) -> Result<ArrayRef> {
247 match str.data_type() {
248 DataType::Utf8 => {
249 let string_array = str.as_string::<i32>();
250 let str_list_array = str_list.as_string::<i32>();
251 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
252 }
253 DataType::LargeUtf8 => {
254 let string_array = str.as_string::<i64>();
255 let str_list_array = str_list.as_string::<i64>();
256 find_in_set_general::<Int64Type, _>(string_array, str_list_array)
257 }
258 DataType::Utf8View => {
259 let string_array = str.as_string_view();
260 let str_list_array = str_list.as_string_view();
261 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
262 }
263 other => {
264 exec_err!("Unsupported data type {other:?} for function find_in_set")
265 }
266 }
267}
268
269fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result<ArrayRef>
270where
271 T: ArrowPrimitiveType,
272 T::Native: OffsetSizeTrait,
273 V: ArrayAccessor<Item = &'a str>,
274{
275 let string_iter = ArrayIter::new(string_array);
276 let str_list_iter = ArrayIter::new(str_list_array);
277
278 let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
279
280 string_iter
281 .zip(str_list_iter)
282 .for_each(
283 |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
284 (Some(string), Some(str_list)) => {
285 let position = str_list
286 .split(',')
287 .position(|s| s == string)
288 .map_or(0, |idx| idx + 1);
289 builder.append_value(T::Native::from_usize(position).unwrap());
290 }
291 _ => builder.append_null(),
292 },
293 );
294
295 Ok(Arc::new(builder.finish()) as ArrayRef)
296}
297
298fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result<ArrayRef>
299where
300 T: ArrowPrimitiveType,
301 T::Native: OffsetSizeTrait,
302 V: ArrayAccessor<Item = &'a str>,
303{
304 let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
305
306 let str_list_iter = ArrayIter::new(str_list_array);
307
308 str_list_iter.for_each(|str_list_opt| match str_list_opt {
309 Some(str_list) => {
310 let position = str_list
311 .split(',')
312 .position(|s| s == string)
313 .map_or(0, |idx| idx + 1);
314 builder.append_value(T::Native::from_usize(position).unwrap());
315 }
316 None => builder.append_null(),
317 });
318
319 Ok(Arc::new(builder.finish()) as ArrayRef)
320}
321
322fn find_in_set_right_literal<'a, T, V>(
323 string_array: V,
324 str_list: &[&str],
325) -> Result<ArrayRef>
326where
327 T: ArrowPrimitiveType,
328 T::Native: OffsetSizeTrait,
329 V: ArrayAccessor<Item = &'a str>,
330{
331 let mut builder = PrimitiveArray::<T>::builder(string_array.len());
332
333 let string_iter = ArrayIter::new(string_array);
334
335 string_iter.for_each(|string_opt| match string_opt {
336 Some(string) => {
337 let position = str_list
338 .iter()
339 .position(|s| *s == string)
340 .map_or(0, |idx| idx + 1);
341 builder.append_value(T::Native::from_usize(position).unwrap());
342 }
343 None => builder.append_null(),
344 });
345
346 Ok(Arc::new(builder.finish()) as ArrayRef)
347}
348
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType::Int32, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    // Scalar/scalar invocations covering:
    // - a plain match and a multi-byte UTF-8 match;
    // - a miss and multi-word list items;
    // - an empty needle and an empty list;
    // - NULL in either argument.
    #[test]
    fn test_functions() -> Result<()> {
        // First item matches -> position 1.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        // Positions count items, not bytes: multi-byte characters are fine.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "a,Д,🔥"
                )))),
            ],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        // Not present -> 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Items may contain spaces; only ',' separates them.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Apache Software Foundation"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Github,Apache Software Foundation,DataFusion"
                )))),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        // Empty needle against a list without empty items -> 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Non-empty needle against an empty list -> 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // NULL list -> NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        // NULL needle -> NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    // Generates a #[test] that invokes the registered `find_in_set` UDF with
    // `$args` and compares the result, materialized as an Int32Array, with
    // `$expected`. The row count is the length of the last array argument,
    // or 1 when all arguments are scalars.
    macro_rules! test_find_in_set {
        ($test_name:ident, $args:expr, $expected:expr) => {
            #[test]
            fn $test_name() -> Result<()> {
                let fis = crate::unicode::find_in_set();

                let args = $args;
                let expected = $expected;

                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
                let cardinality = args
                    .iter()
                    .fold(Option::<usize>::None, |acc, arg| match arg {
                        ColumnarValue::Scalar(_) => acc,
                        ColumnarValue::Array(a) => Some(a.len()),
                    })
                    .unwrap_or(1);
                let return_type = fis.return_type(&type_array)?;
                let arg_fields = args
                    .iter()
                    .enumerate()
                    .map(|(idx, a)| {
                        Field::new(format!("arg_{idx}"), a.data_type(), true).into()
                    })
                    .collect::<Vec<_>>();
                let result = fis.invoke_with_args(ScalarFunctionArgs {
                    args,
                    arg_fields,
                    number_rows: cardinality,
                    return_field: Field::new("f", return_type, true).into(),
                    config_options: Arc::new(ConfigOptions::default()),
                });
                assert!(result.is_ok());

                // Scalar results are broadcast to `cardinality` rows here.
                let result = result?
                    .to_array(cardinality)
                    .expect("Failed to convert to array");
                let result = result
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("Failed to convert to type");
                assert_eq!(*result, expected);

                Ok(())
            }
        };
    }

    // Array needles, scalar list (right-literal path).
    test_find_in_set!(
        test_find_in_set_with_scalar_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "", "a", "b", "c", "d"
            ]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
        ],
        Int32Array::from(vec![0, 0, 1, 2, 3])
    );
    // Scalar needle, array of lists (left-literal path).
    test_find_in_set!(
        test_find_in_set_with_scalar_args_2,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "ApacheSoftware".to_string()
            ))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,b,c",
                "ApacheSoftware,Github,DataFusion",
                ""
            ]))),
        ],
        Int32Array::from(vec![0, 1, 0])
    );
    // All-NULL needle array -> all-NULL result.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_3,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    // All-NULL list array -> all-NULL result.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_4,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
}