1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
23 OffsetSizeTrait, PrimitiveArray,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29 exec_err, internal_err, utils::take_function_args, Result, ScalarValue,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34 Volatility,
35};
36use datafusion_macros::user_doc;
37
/// Scalar UDF implementing `find_in_set(str, strlist)`.
///
/// The user-facing documentation (description, syntax and SQL example) is
/// supplied through the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
    syntax_example = "find_in_set(str, strlist)",
    sql_example = r#"```sql
> select find_in_set('b', 'a,b,c,d');
+----------------------------------------+
| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
+----------------------------------------+
| 2 |
+----------------------------------------+
```"#,
    argument(name = "str", description = "String expression to find in strlist."),
    argument(
        name = "strlist",
        description = "A string list is a string composed of substrings separated by , characters."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FindInSetFunc {
    /// Argument type combinations accepted by the function, together with
    /// its volatility; built in `FindInSetFunc::new`.
    signature: Signature,
}
60
61impl Default for FindInSetFunc {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl FindInSetFunc {
68 pub fn new() -> Self {
69 use DataType::*;
70 Self {
71 signature: Signature::one_of(
72 vec![
73 Exact(vec![Utf8View, Utf8View]),
74 Exact(vec![Utf8, Utf8]),
75 Exact(vec![LargeUtf8, LargeUtf8]),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn name(&self) -> &str {
89 "find_in_set"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 utf8_to_int_type(&arg_types[0], "find_in_set")
98 }
99
100 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101 let ScalarFunctionArgs { args, .. } = args;
102
103 let [string, str_list] = take_function_args(self.name(), args)?;
104
105 match (string, str_list) {
106 (
108 ColumnarValue::Scalar(
109 ScalarValue::Utf8View(string)
110 | ScalarValue::Utf8(string)
111 | ScalarValue::LargeUtf8(string),
112 ),
113 ColumnarValue::Scalar(
114 ScalarValue::Utf8View(str_list)
115 | ScalarValue::Utf8(str_list)
116 | ScalarValue::LargeUtf8(str_list),
117 ),
118 ) => {
119 let res = match (string, str_list) {
120 (Some(string), Some(str_list)) => {
121 let position = str_list
122 .split(',')
123 .position(|s| s == string)
124 .map_or(0, |idx| idx + 1);
125
126 Some(position as i32)
127 }
128 _ => None,
129 };
130 Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
131 }
132
133 (
135 ColumnarValue::Array(str_array),
136 ColumnarValue::Scalar(
137 ScalarValue::Utf8View(str_list_literal)
138 | ScalarValue::Utf8(str_list_literal)
139 | ScalarValue::LargeUtf8(str_list_literal),
140 ),
141 ) => {
142 let result_array = match str_list_literal {
143 None => new_null_array(str_array.data_type(), str_array.len()),
145 Some(str_list_literal) => {
146 let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
147 let result = match str_array.data_type() {
148 DataType::Utf8 => {
149 let string_array = str_array.as_string::<i32>();
150 find_in_set_right_literal::<Int32Type, _>(
151 string_array,
152 str_list,
153 )
154 }
155 DataType::LargeUtf8 => {
156 let string_array = str_array.as_string::<i64>();
157 find_in_set_right_literal::<Int64Type, _>(
158 string_array,
159 str_list,
160 )
161 }
162 DataType::Utf8View => {
163 let string_array = str_array.as_string_view();
164 find_in_set_right_literal::<Int32Type, _>(
165 string_array,
166 str_list,
167 )
168 }
169 other => {
170 exec_err!("Unsupported data type {other:?} for function find_in_set")
171 }
172 };
173 Arc::new(result?)
174 }
175 };
176 Ok(ColumnarValue::Array(result_array))
177 }
178
179 (
181 ColumnarValue::Scalar(
182 ScalarValue::Utf8View(string_literal)
183 | ScalarValue::Utf8(string_literal)
184 | ScalarValue::LargeUtf8(string_literal),
185 ),
186 ColumnarValue::Array(str_list_array),
187 ) => {
188 let res = match string_literal {
189 None => {
191 new_null_array(str_list_array.data_type(), str_list_array.len())
192 }
193 Some(string) => {
194 let result = match str_list_array.data_type() {
195 DataType::Utf8 => {
196 let str_list = str_list_array.as_string::<i32>();
197 find_in_set_left_literal::<Int32Type, _>(string, str_list)
198 }
199 DataType::LargeUtf8 => {
200 let str_list = str_list_array.as_string::<i64>();
201 find_in_set_left_literal::<Int64Type, _>(string, str_list)
202 }
203 DataType::Utf8View => {
204 let str_list = str_list_array.as_string_view();
205 find_in_set_left_literal::<Int32Type, _>(string, str_list)
206 }
207 other => {
208 exec_err!("Unsupported data type {other:?} for function find_in_set")
209 }
210 };
211 Arc::new(result?)
212 }
213 };
214 Ok(ColumnarValue::Array(res))
215 }
216
217 (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
219 let res = find_in_set(base_array, exp_array)?;
220
221 Ok(ColumnarValue::Array(res))
222 }
223 _ => {
224 internal_err!("Invalid argument types for `find_in_set` function")
225 }
226 }
227 }
228
229 fn documentation(&self) -> Option<&Documentation> {
230 self.doc()
231 }
232}
233
234fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result<ArrayRef> {
238 match str.data_type() {
239 DataType::Utf8 => {
240 let string_array = str.as_string::<i32>();
241 let str_list_array = str_list.as_string::<i32>();
242 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
243 }
244 DataType::LargeUtf8 => {
245 let string_array = str.as_string::<i64>();
246 let str_list_array = str_list.as_string::<i64>();
247 find_in_set_general::<Int64Type, _>(string_array, str_list_array)
248 }
249 DataType::Utf8View => {
250 let string_array = str.as_string_view();
251 let str_list_array = str_list.as_string_view();
252 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
253 }
254 other => {
255 exec_err!("Unsupported data type {other:?} for function find_in_set")
256 }
257 }
258}
259
260pub fn find_in_set_general<'a, T, V>(
261 string_array: V,
262 str_list_array: V,
263) -> Result<ArrayRef>
264where
265 T: ArrowPrimitiveType,
266 T::Native: OffsetSizeTrait,
267 V: ArrayAccessor<Item = &'a str>,
268{
269 let string_iter = ArrayIter::new(string_array);
270 let str_list_iter = ArrayIter::new(str_list_array);
271
272 let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
273
274 string_iter
275 .zip(str_list_iter)
276 .for_each(
277 |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
278 (Some(string), Some(str_list)) => {
279 let position = str_list
280 .split(',')
281 .position(|s| s == string)
282 .map_or(0, |idx| idx + 1);
283 builder.append_value(T::Native::from_usize(position).unwrap());
284 }
285 _ => builder.append_null(),
286 },
287 );
288
289 Ok(Arc::new(builder.finish()) as ArrayRef)
290}
291
292fn find_in_set_left_literal<'a, T, V>(
293 string: String,
294 str_list_array: V,
295) -> Result<ArrayRef>
296where
297 T: ArrowPrimitiveType,
298 T::Native: OffsetSizeTrait,
299 V: ArrayAccessor<Item = &'a str>,
300{
301 let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
302
303 let str_list_iter = ArrayIter::new(str_list_array);
304
305 str_list_iter.for_each(|str_list_opt| match str_list_opt {
306 Some(str_list) => {
307 let position = str_list
308 .split(',')
309 .position(|s| s == string)
310 .map_or(0, |idx| idx + 1);
311 builder.append_value(T::Native::from_usize(position).unwrap());
312 }
313 None => builder.append_null(),
314 });
315
316 Ok(Arc::new(builder.finish()) as ArrayRef)
317}
318
319fn find_in_set_right_literal<'a, T, V>(
320 string_array: V,
321 str_list: Vec<&str>,
322) -> Result<ArrayRef>
323where
324 T: ArrowPrimitiveType,
325 T::Native: OffsetSizeTrait,
326 V: ArrayAccessor<Item = &'a str>,
327{
328 let mut builder = PrimitiveArray::<T>::builder(string_array.len());
329
330 let string_iter = ArrayIter::new(string_array);
331
332 string_iter.for_each(|string_opt| match string_opt {
333 Some(string) => {
334 let position = str_list
335 .iter()
336 .position(|s| *s == string)
337 .map_or(0, |idx| idx + 1);
338 builder.append_value(T::Native::from_usize(position).unwrap());
339 }
340 None => builder.append_null(),
341 });
342
343 Ok(Arc::new(builder.finish()) as ArrayRef)
344}
345
// Unit tests for `find_in_set`: scalar/scalar evaluation through the shared
// `test_function!` macro, and array/scalar combinations through the local
// `test_find_in_set!` macro.
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType::Int32, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    // Scalar/scalar cases: found, not found, multi-byte characters, empty
    // strings, and NULL on either side.
    #[test]
    fn test_functions() -> Result<()> {
        // First item of the list.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        // Multi-byte characters are compared item by item.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "a,Д,🔥"
                )))),
            ],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        // A string that is not in the list yields 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Items may contain spaces; only commas separate them.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Apache Software Foundation"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Github,Apache Software Foundation,DataFusion"
                )))),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        // An empty search string is not an item of a list without empty items.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // An empty list does not contain a non-empty search string.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // A NULL list yields NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        // A NULL search string yields NULL.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    // Generates a `#[test]` named `$test_name` that invokes `find_in_set` on
    // `$args` and compares the result, converted to an `Int32Array`, with
    // `$expected`.
    macro_rules! test_find_in_set {
        ($test_name:ident, $args:expr, $expected:expr) => {
            #[test]
            fn $test_name() -> Result<()> {
                let fis = crate::unicode::find_in_set();

                let args = $args;
                let expected = $expected;

                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
                // Row count: the length of the last array argument, or 1 when
                // every argument is a scalar.
                let cardinality = args
                    .iter()
                    .fold(Option::<usize>::None, |acc, arg| match arg {
                        ColumnarValue::Scalar(_) => acc,
                        ColumnarValue::Array(a) => Some(a.len()),
                    })
                    .unwrap_or(1);
                let return_type = fis.return_type(&type_array)?;
                // One nullable field per argument, named `arg_<index>`.
                let arg_fields = args
                    .iter()
                    .enumerate()
                    .map(|(idx, a)| {
                        Field::new(format!("arg_{idx}"), a.data_type(), true).into()
                    })
                    .collect::<Vec<_>>();
                let result = fis.invoke_with_args(ScalarFunctionArgs {
                    args,
                    arg_fields,
                    number_rows: cardinality,
                    return_field: Field::new("f", return_type, true).into(),
                    config_options: Arc::new(ConfigOptions::default()),
                });
                assert!(result.is_ok());

                // Normalise scalar results to an array before comparing.
                let result = result?
                    .to_array(cardinality)
                    .expect("Failed to convert to array");
                let result = result
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("Failed to convert to type");
                assert_eq!(*result, expected);

                Ok(())
            }
        };
    }

    // Array of search strings against a literal list.
    test_find_in_set!(
        test_find_in_set_with_scalar_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "", "a", "b", "c", "d"
            ]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
        ],
        Int32Array::from(vec![0, 0, 1, 2, 3])
    );
    // Literal search string against an array of lists.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_2,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "ApacheSoftware".to_string()
            ))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,b,c",
                "ApacheSoftware,Github,DataFusion",
                ""
            ]))),
        ],
        Int32Array::from(vec![0, 1, 0])
    );
    // NULL search strings produce NULL results.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_3,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    // NULL lists produce NULL results.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_4,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
}