1use std::sync::Arc;
19
20use arrow::array::{
21 ArrayAccessor, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray,
22};
23use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
24use arrow_buffer::NullBuffer;
25
26use crate::utils::utf8_to_int_type;
27use datafusion_common::{
28 Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
29};
30use datafusion_expr::TypeSignature::Exact;
31use datafusion_expr::{
32 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33 Volatility,
34};
35use datafusion_macros::user_doc;
36
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
    syntax_example = "find_in_set(str, strlist)",
    sql_example = r#"```sql
> select find_in_set('b', 'a,b,c,d');
+----------------------------------------+
| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
+----------------------------------------+
| 2 |
+----------------------------------------+
```"#,
    argument(name = "str", description = "String expression to find in strlist."),
    argument(
        name = "strlist",
        description = "A string list is a string composed of substrings separated by , characters."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF registered under the name `find_in_set`.
///
/// The user-facing documentation (description, syntax and SQL example) is
/// supplied through the `user_doc` attribute above.
pub struct FindInSetFunc {
    /// Accepted argument type combinations and volatility of the function.
    signature: Signature,
}
59
60impl Default for FindInSetFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl FindInSetFunc {
67 pub fn new() -> Self {
68 use DataType::*;
69 Self {
70 signature: Signature::one_of(
71 vec![
72 Exact(vec![Utf8View, Utf8View]),
73 Exact(vec![Utf8, Utf8]),
74 Exact(vec![LargeUtf8, LargeUtf8]),
75 ],
76 Volatility::Immutable,
77 ),
78 }
79 }
80}
81
82impl ScalarUDFImpl for FindInSetFunc {
83 fn name(&self) -> &str {
84 "find_in_set"
85 }
86
87 fn signature(&self) -> &Signature {
88 &self.signature
89 }
90
91 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92 utf8_to_int_type(&arg_types[0], "find_in_set")
93 }
94
95 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96 let return_field = args.return_field;
97 let [string, str_list] = take_function_args(self.name(), args.args)?;
98
99 match (string, str_list) {
100 (
102 ColumnarValue::Scalar(
103 ScalarValue::Utf8View(string)
104 | ScalarValue::Utf8(string)
105 | ScalarValue::LargeUtf8(string),
106 ),
107 ColumnarValue::Scalar(
108 ScalarValue::Utf8View(str_list)
109 | ScalarValue::Utf8(str_list)
110 | ScalarValue::LargeUtf8(str_list),
111 ),
112 ) => {
113 let res = match (string, str_list) {
114 (Some(string), Some(str_list)) => {
115 let position = str_list
116 .split(',')
117 .position(|s| s == string)
118 .map_or(0, |idx| idx + 1);
119
120 Some(position as i32)
121 }
122 _ => None,
123 };
124 Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
125 }
126
127 (
129 ColumnarValue::Array(str_array),
130 ColumnarValue::Scalar(
131 ScalarValue::Utf8View(str_list_literal)
132 | ScalarValue::Utf8(str_list_literal)
133 | ScalarValue::LargeUtf8(str_list_literal),
134 ),
135 ) => {
136 match str_list_literal {
137 None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
139 return_field.data_type(),
140 )?)),
141 Some(str_list_literal) => {
142 let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
143 let result = match str_array.data_type() {
144 DataType::Utf8 => {
145 let string_array = str_array.as_string::<i32>();
146 find_in_set_right_literal::<Int32Type, _>(
147 string_array,
148 &str_list,
149 )
150 }
151 DataType::LargeUtf8 => {
152 let string_array = str_array.as_string::<i64>();
153 find_in_set_right_literal::<Int64Type, _>(
154 string_array,
155 &str_list,
156 )
157 }
158 DataType::Utf8View => {
159 let string_array = str_array.as_string_view();
160 find_in_set_right_literal::<Int32Type, _>(
161 string_array,
162 &str_list,
163 )
164 }
165 other => {
166 exec_err!(
167 "Unsupported data type {other:?} for function find_in_set"
168 )
169 }
170 };
171 Ok(ColumnarValue::Array(Arc::new(result?)))
172 }
173 }
174 }
175
176 (
178 ColumnarValue::Scalar(
179 ScalarValue::Utf8View(string_literal)
180 | ScalarValue::Utf8(string_literal)
181 | ScalarValue::LargeUtf8(string_literal),
182 ),
183 ColumnarValue::Array(str_list_array),
184 ) => {
185 match string_literal {
186 None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
188 return_field.data_type(),
189 )?)),
190 Some(string) => {
191 let result = match str_list_array.data_type() {
192 DataType::Utf8 => {
193 let str_list = str_list_array.as_string::<i32>();
194 find_in_set_left_literal::<Int32Type, _>(
195 &string, str_list,
196 )
197 }
198 DataType::LargeUtf8 => {
199 let str_list = str_list_array.as_string::<i64>();
200 find_in_set_left_literal::<Int64Type, _>(
201 &string, str_list,
202 )
203 }
204 DataType::Utf8View => {
205 let str_list = str_list_array.as_string_view();
206 find_in_set_left_literal::<Int32Type, _>(
207 &string, str_list,
208 )
209 }
210 other => {
211 exec_err!(
212 "Unsupported data type {other:?} for function find_in_set"
213 )
214 }
215 };
216 Ok(ColumnarValue::Array(Arc::new(result?)))
217 }
218 }
219 }
220
221 (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
223 let res = find_in_set(&base_array, &exp_array)?;
224
225 Ok(ColumnarValue::Array(res))
226 }
227 _ => {
228 internal_err!("Invalid argument types for `find_in_set` function")
229 }
230 }
231 }
232
233 fn documentation(&self) -> Option<&Documentation> {
234 self.doc()
235 }
236}
237
238fn find_in_set(str: &ArrayRef, str_list: &ArrayRef) -> Result<ArrayRef> {
242 match str.data_type() {
243 DataType::Utf8 => {
244 let string_array = str.as_string::<i32>();
245 let str_list_array = str_list.as_string::<i32>();
246 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
247 }
248 DataType::LargeUtf8 => {
249 let string_array = str.as_string::<i64>();
250 let str_list_array = str_list.as_string::<i64>();
251 find_in_set_general::<Int64Type, _>(string_array, str_list_array)
252 }
253 DataType::Utf8View => {
254 let string_array = str.as_string_view();
255 let str_list_array = str_list.as_string_view();
256 find_in_set_general::<Int32Type, _>(string_array, str_list_array)
257 }
258 other => {
259 exec_err!("Unsupported data type {other:?} for function find_in_set")
260 }
261 }
262}
263
264fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result<ArrayRef>
265where
266 T: ArrowPrimitiveType,
267 T::Native: OffsetSizeTrait,
268 V: ArrayAccessor<Item = &'a str> + Copy,
269{
270 let len = string_array.len();
271 let nulls = NullBuffer::union(string_array.nulls(), str_list_array.nulls());
272 let zero = T::Native::from_usize(0).unwrap();
273
274 let values: Vec<T::Native> = (0..len)
275 .map(|i| {
276 if nulls.as_ref().is_some_and(|n| n.is_null(i)) {
277 return zero;
278 }
279 let string = string_array.value(i);
280 let str_list = str_list_array.value(i);
281 let position = str_list
282 .split(',')
283 .position(|s| s == string)
284 .map_or(0, |idx| idx + 1);
285 T::Native::from_usize(position).unwrap()
286 })
287 .collect();
288
289 Ok(Arc::new(PrimitiveArray::<T>::new(values.into(), nulls)) as ArrayRef)
290}
291
292fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result<ArrayRef>
293where
294 T: ArrowPrimitiveType,
295 T::Native: OffsetSizeTrait,
296 V: ArrayAccessor<Item = &'a str> + Copy,
297{
298 let len = str_list_array.len();
299 let nulls = str_list_array.nulls().cloned();
300 let zero = T::Native::from_usize(0).unwrap();
301
302 let values: Vec<T::Native> = (0..len)
303 .map(|i| {
304 if nulls.as_ref().is_some_and(|n| n.is_null(i)) {
305 return zero;
306 }
307 let str_list = str_list_array.value(i);
308 let position = str_list
309 .split(',')
310 .position(|s| s == string)
311 .map_or(0, |idx| idx + 1);
312 T::Native::from_usize(position).unwrap()
313 })
314 .collect();
315
316 Ok(Arc::new(PrimitiveArray::<T>::new(values.into(), nulls)) as ArrayRef)
317}
318
319fn find_in_set_right_literal<'a, T, V>(
320 string_array: V,
321 str_list: &[&str],
322) -> Result<ArrayRef>
323where
324 T: ArrowPrimitiveType,
325 T::Native: OffsetSizeTrait,
326 V: ArrayAccessor<Item = &'a str> + Copy,
327{
328 let len = string_array.len();
329 let nulls = string_array.nulls().cloned();
330 let zero = T::Native::from_usize(0).unwrap();
331
332 let values: Vec<T::Native> = (0..len)
333 .map(|i| {
334 if nulls.as_ref().is_some_and(|n| n.is_null(i)) {
335 return zero;
336 }
337 let string = string_array.value(i);
338 let position = str_list
339 .iter()
340 .position(|s| *s == string)
341 .map_or(0, |idx| idx + 1);
342 T::Native::from_usize(position).unwrap()
343 })
344 .collect();
345
346 Ok(Arc::new(PrimitiveArray::<T>::new(values.into(), nulls)) as ArrayRef)
347}
348
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::{DataType::Int32, Field};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    // Scalar/scalar invocations, driven through the shared `test_function!` macro.
    #[test]
    fn test_functions() -> Result<()> {
        // Needle is the first element of the list.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        // Non-ASCII elements; positions count list elements.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("🔥")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "a,Д,🔥"
                )))),
            ],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        // Needle absent from the list yields 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("d")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Elements may contain spaces; only commas separate them.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Apache Software Foundation"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "Github,Apache Software Foundation,DataFusion"
                )))),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        // Empty needle against a list without empty elements yields 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b,c")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Non-empty needle against an empty list yields 0.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
            ],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        // Null list yields null.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        // Null needle yields null.
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("a,b,c")))),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    // Generates a `#[test]` named `$test_name` that invokes the registered
    // `find_in_set` UDF with `$args` and compares the result, converted to an
    // `Int32Array`, against `$expected`.
    macro_rules! test_find_in_set {
        ($test_name:ident, $args:expr, $expected:expr) => {
            #[test]
            fn $test_name() -> Result<()> {
                let fis = crate::unicode::find_in_set();

                let args = $args;
                let expected = $expected;

                let type_array = args.iter().map(|a| a.data_type()).collect::<Vec<_>>();
                // Row count: length of the last array argument, or 1 when
                // every argument is a scalar.
                let cardinality = args
                    .iter()
                    .fold(Option::<usize>::None, |acc, arg| match arg {
                        ColumnarValue::Scalar(_) => acc,
                        ColumnarValue::Array(a) => Some(a.len()),
                    })
                    .unwrap_or(1);
                let return_type = fis.return_type(&type_array)?;
                // One nullable field per argument, named `arg_<index>`.
                let arg_fields = args
                    .iter()
                    .enumerate()
                    .map(|(idx, a)| {
                        Field::new(format!("arg_{idx}"), a.data_type(), true).into()
                    })
                    .collect::<Vec<_>>();
                let result = fis.invoke_with_args(ScalarFunctionArgs {
                    args,
                    arg_fields,
                    number_rows: cardinality,
                    return_field: Field::new("f", return_type, true).into(),
                    config_options: Arc::new(ConfigOptions::default()),
                });
                assert!(result.is_ok());

                // Scalar results are expanded to `cardinality` rows before comparing.
                let result = result?
                    .to_array(cardinality)
                    .expect("Failed to convert to array");
                let result = result
                    .as_any()
                    .downcast_ref::<Int32Array>()
                    .expect("Failed to convert to type");
                assert_eq!(*result, expected);

                Ok(())
            }
        };
    }

    // Array of needles against a literal list.
    test_find_in_set!(
        test_find_in_set_with_scalar_args,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "", "a", "b", "c", "d"
            ]))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some("b,c,d".to_string()))),
        ],
        Int32Array::from(vec![0, 0, 1, 2, 3])
    );
    // Literal needle against an array of lists.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_2,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "ApacheSoftware".to_string()
            ))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "a,b,c",
                "ApacheSoftware,Github,DataFusion",
                ""
            ]))),
        ],
        Int32Array::from(vec![0, 1, 0])
    );
    // All-null needle array against a literal list: every row is null.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_3,
        vec![
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a,b,c".to_string()))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
    // Literal needle against an all-null list array: every row is null.
    test_find_in_set!(
        test_find_in_set_with_scalar_args_4,
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("a".to_string()))),
            ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
        ],
        Int32Array::from(vec![None::<i32>; 3])
    );
}