1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23 PrimitiveArray, new_null_array,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29 Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34 Volatility,
35};
36use datafusion_macros::user_doc;
37
// Scalar UDF implementing SQL `find_in_set(str, strlist)`.
//
// The `user_doc` attribute below holds the user-facing documentation; it
// presumably generates the `doc()` accessor that `documentation()` returns
// (TODO confirm against the `datafusion_macros::user_doc` implementation).
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
    syntax_example = "find_in_set(str, strlist)",
    sql_example = r#"```sql
> select find_in_set('b', 'a,b,c,d');
+----------------------------------------+
| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
+----------------------------------------+
| 2                                      |
+----------------------------------------+
```"#,
    argument(name = "str", description = "String expression to find in strlist."),
    argument(
        name = "strlist",
        description = "A string list is a string composed of substrings separated by , characters."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FindInSetFunc {
    // Accepted argument types; built in `FindInSetFunc::new` as a `one_of`
    // over same-typed (Utf8View, Utf8, LargeUtf8) pairs.
    signature: Signature,
}
60
61impl Default for FindInSetFunc {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl FindInSetFunc {
68 pub fn new() -> Self {
69 use DataType::*;
70 Self {
71 signature: Signature::one_of(
72 vec![
73 Exact(vec![Utf8View, Utf8View]),
74 Exact(vec![Utf8, Utf8]),
75 Exact(vec![LargeUtf8, LargeUtf8]),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn name(&self) -> &str {
89 "find_in_set"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 utf8_to_int_type(&arg_types[0], "find_in_set")
98 }
99
100 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101 let ScalarFunctionArgs { args, .. } = args;
102
103 let [string, str_list] = take_function_args(self.name(), args)?;
104
105 match (string, str_list) {
106 (
108 ColumnarValue::Scalar(
109 ScalarValue::Utf8View(string)
110 | ScalarValue::Utf8(string)
111 | ScalarValue::LargeUtf8(string),
112 ),
113 ColumnarValue::Scalar(
114 ScalarValue::Utf8View(str_list)
115 | ScalarValue::Utf8(str_list)
116 | ScalarValue::LargeUtf8(str_list),
117 ),
118 ) => {
119 let res = match (string, str_list) {
120 (Some(string), Some(str_list)) => {
121 let position = str_list
122 .split(',')
123 .position(|s| s == string)
124 .map_or(0, |idx| idx + 1);
125
126 Some(position as i32)
127 }
128 _ => None,
129 };
130 Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
131 }
132
133 (
135 ColumnarValue::Array(str_array),
136 ColumnarValue::Scalar(
137 ScalarValue::Utf8View(str_list_literal)
138 | ScalarValue::Utf8(str_list_literal)
139 | ScalarValue::LargeUtf8(str_list_literal),
140 ),
141 ) => {
142 let result_array = match str_list_literal {
143 None => new_null_array(str_array.data_type(), str_array.len()),
145 Some(str_list_literal) => {
146 let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
147 let result = match str_array.data_type() {
148 DataType::Utf8 => {
149 let string_array = str_array.as_string::<i32>();
150 find_in_set_right_literal::<Int32Type, _>(
151 string_array,
152 &str_list,
153 )
154 }
155 DataType::LargeUtf8 => {
156 let string_array = str_array.as_string::<i64>();
157 find_in_set_right_literal::<Int64Type, _>(
158 string_array,
159 &str_list,
160 )
161 }
162 DataType::Utf8View => {
163 let string_array = str_array.as_string_view();
164 find_in_set_right_literal::<Int32Type, _>(
165 string_array,
166 &str_list,
167 )
168 }
169 other => {
170 exec_err!(
171 "Unsupported data type {other:?} for function find_in_set"
172 )
173 }
174 };
175 Arc::new(result?)
176 }
177 };
178 Ok(ColumnarValue::Array(result_array))
179 }
180
181 (
183 ColumnarValue::Scalar(
184 ScalarValue::Utf8View(string_literal)
185 | ScalarValue::Utf8(string_literal)
186 | ScalarValue::LargeUtf8(string_literal),
187 ),
188 ColumnarValue::Array(str_list_array),
189 ) => {
190 let res = match string_literal {
191 None => {
193 new_null_array(str_list_array.data_type(), str_list_array.len())
194 }
195 Some(string) => {
196 let result = match str_list_array.data_type() {
197 DataType::Utf8 => {
198 let str_list = str_list_array.as_string::<i32>();
199 find_in_set_left_literal::<Int32Type, _>(
200 &string, str_list,
201 )
202 }
203 DataType::LargeUtf8 => {
204 let str_list = str_list_array.as_string::<i64>();
205 find_in_set_left_literal::<Int64Type, _>(
206 &string, str_list,
207 )
208 }
209 DataType::Utf8View => {
210 let str_list = str_list_array.as_string_view();
211 find_in_set_left_literal::<Int32Type, _>(
212 &string, str_list,
213 )
214 }
215 other => {
216 exec_err!(
217 "Unsupported data type {other:?} for function find_in_set"
218 )
219 }
220 };
221 Arc::new(result?)
222 }
223 };
224 Ok(ColumnarValue::Array(res))
225 }
226
227 (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
229 let res = find_in_set(&base_array, &exp_array)?;
230
231 Ok(ColumnarValue::Array(res))
232 }
233 _ => {
234 internal_err!("Invalid argument types for `find_in_set` function")
235 }
236 }
237 }
238
239 fn documentation(&self) -> Option<&Documentation> {
240 self.doc()
241 }
242}
243
244fn find_in_set(str: &ArrayRef, str_list: &ArrayRef) -> Result<ArrayRef> {
248 match str.data_type() {
249 DataType::Utf8 => {
250 let string_array = str.as_string::<i32>();
251 let str_list_array = str_list.as_string::<i32>();
252 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
253 }
254 DataType::LargeUtf8 => {
255 let string_array = str.as_string::<i64>();
256 let str_list_array = str_list.as_string::<i64>();
257 find_in_set_general::<Int64Type, _>(string_array, str_list_array)
258 }
259 DataType::Utf8View => {
260 let string_array = str.as_string_view();
261 let str_list_array = str_list.as_string_view();
262 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
263 }
264 other => {
265 exec_err!("Unsupported data type {other:?} for function find_in_set")
266 }
267 }
268}
269
270fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result<ArrayRef>
271where
272 T: ArrowPrimitiveType,
273 T::Native: OffsetSizeTrait,
274 V: ArrayAccessor<Item = &'a str>,
275{
276 let string_iter = ArrayIter::new(string_array);
277 let str_list_iter = ArrayIter::new(str_list_array);
278
279 let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
280
281 string_iter
282 .zip(str_list_iter)
283 .for_each(
284 |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
285 (Some(string), Some(str_list)) => {
286 let position = str_list
287 .split(',')
288 .position(|s| s == string)
289 .map_or(0, |idx| idx + 1);
290 builder.append_value(T::Native::from_usize(position).unwrap());
291 }
292 _ => builder.append_null(),
293 },
294 );
295
296 Ok(Arc::new(builder.finish()) as ArrayRef)
297}
298
299fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result<ArrayRef>
300where
301 T: ArrowPrimitiveType,
302 T::Native: OffsetSizeTrait,
303 V: ArrayAccessor<Item = &'a str>,
304{
305 let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
306
307 let str_list_iter = ArrayIter::new(str_list_array);
308
309 str_list_iter.for_each(|str_list_opt| match str_list_opt {
310 Some(str_list) => {
311 let position = str_list
312 .split(',')
313 .position(|s| s == string)
314 .map_or(0, |idx| idx + 1);
315 builder.append_value(T::Native::from_usize(position).unwrap());
316 }
317 None => builder.append_null(),
318 });
319
320 Ok(Arc::new(builder.finish()) as ArrayRef)
321}
322
323fn find_in_set_right_literal<'a, T, V>(
324 string_array: V,
325 str_list: &[&str],
326) -> Result<ArrayRef>
327where
328 T: ArrowPrimitiveType,
329 T::Native: OffsetSizeTrait,
330 V: ArrayAccessor<Item = &'a str>,
331{
332 let mut builder = PrimitiveArray::<T>::builder(string_array.len());
333
334 let string_iter = ArrayIter::new(string_array);
335
336 string_iter.for_each(|string_opt| match string_opt {
337 Some(string) => {
338 let position = str_list
339 .iter()
340 .position(|s| *s == string)
341 .map_or(0, |idx| idx + 1);
342 builder.append_value(T::Native::from_usize(position).unwrap());
343 }
344 None => builder.append_null(),
345 });
346
347 Ok(Arc::new(builder.finish()) as ArrayRef)
348}
349
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType::Int32, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Scalar/scalar cases: hits, misses, multi-byte characters, empty
    /// strings and NULL arguments.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "a,Д,🔥"
                )))),
            ],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Apache Software Foundation"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Github,Apache Software Foundation,DataFusion"
                )))),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    /// Generates a `#[test]` that invokes the registered `find_in_set` UDF on
    /// `$args` (mixing arrays and scalars) and compares the result, expanded to
    /// the input cardinality, with the `Int32Array` `$expected`.
    macro_rules! test_find_in_set {
        ($test_name:ident, $args:expr, $expected:expr) => {
            #[test]
            fn $test_name() -> Result<()> {
                let fis = crate::unicode::find_in_set();

                let args = $args;
                let expected = $expected;

                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
                // Row count comes from the (last) array argument; all-scalar
                // invocations evaluate a single row.
                let cardinality = args
                    .iter()
                    .fold(Option::<usize>::None, |acc, arg| match arg {
                        ColumnarValue::Scalar(_) => acc,
                        ColumnarValue::Array(a) => Some(a.len()),
                    })
                    .unwrap_or(1);
                let return_type = fis.return_type(&type_array)?;
                let arg_fields = args
                    .iter()
                    .enumerate()
                    .map(|(idx, a)| {
                        Field::new(format!("arg_{idx}"), a.data_type(), true).into()
                    })
                    .collect::<Vec<_>>();
                // Propagate the error itself (instead of a bare
                // `assert!(result.is_ok())`) so failures show the cause.
                let result = fis.invoke_with_args(ScalarFunctionArgs {
                    args,
                    arg_fields,
                    number_rows: cardinality,
                    return_field: Field::new("f", return_type, true).into(),
                    config_options: Arc::new(ConfigOptions::default()),
                })?;

                let result = result
                    .to_array(cardinality)
                    .expect("Failed to convert to array");
                let result = result
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("Failed to convert to type");
                assert_eq!(*result, expected);

                Ok(())
            }
        };
    }

    test_find_in_set!(
        test_find_in_set_with_scalar_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "", "a", "b", "c", "d"
            ]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
        ],
        Int32Array::from(vec![0, 0, 1, 2, 3])
    );
    test_find_in_set!(
        test_find_in_set_with_scalar_args_2,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "ApacheSoftware".to_string()
            ))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,b,c",
                "ApacheSoftware,Github,DataFusion",
                ""
            ]))),
        ],
        Int32Array::from(vec![0, 1, 0])
    );
    test_find_in_set!(
        test_find_in_set_with_scalar_args_3,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    test_find_in_set!(
        test_find_in_set_with_scalar_args_4,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
}