1use arrow::array::{
21 Array, ArrayRef, BooleanArray, Datum, GenericListArray, OffsetSizeTrait, Scalar,
22};
23use arrow::buffer::BooleanBuffer;
24use arrow::datatypes::DataType;
25use arrow::row::{RowConverter, Rows, SortField};
26use datafusion_common::cast::as_generic_list_array;
27use datafusion_common::utils::string_utils::string_array_to_vec;
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{exec_err, Result, ScalarValue};
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34use datafusion_physical_expr_common::datum::compare_with_eq;
35use itertools::Itertools;
36
37use crate::utils::make_scalar_function;
38
39use std::any::Any;
40use std::sync::Arc;
41
42make_udf_expr_and_func!(ArrayHas,
44 array_has,
45 haystack_array element, "returns true, if the element appears in the first array, otherwise false.", array_has_udf );
49make_udf_expr_and_func!(ArrayHasAll,
50 array_has_all,
51 haystack_array needle_array, "returns true if each element of the second array appears in the first array; otherwise, it returns false.", array_has_all_udf );
55make_udf_expr_and_func!(ArrayHasAny,
56 array_has_any,
57 haystack_array needle_array, "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", array_has_any_udf );
61
62#[user_doc(
63 doc_section(label = "Array Functions"),
64 description = "Returns true if the array contains the element.",
65 syntax_example = "array_has(array, element)",
66 sql_example = r#"```sql
67> select array_has([1, 2, 3], 2);
68+-----------------------------+
69| array_has(List([1,2,3]), 2) |
70+-----------------------------+
71| true |
72+-----------------------------+
73```"#,
74 argument(
75 name = "array",
76 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
77 ),
78 argument(
79 name = "element",
80 description = "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators."
81 )
82)]
83#[derive(Debug)]
84pub struct ArrayHas {
85 signature: Signature,
86 aliases: Vec<String>,
87}
88
89impl Default for ArrayHas {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl ArrayHas {
96 pub fn new() -> Self {
97 Self {
98 signature: Signature::array_and_element(Volatility::Immutable),
99 aliases: vec![
100 String::from("list_has"),
101 String::from("array_contains"),
102 String::from("list_contains"),
103 ],
104 }
105 }
106}
107
108impl ScalarUDFImpl for ArrayHas {
109 fn as_any(&self) -> &dyn Any {
110 self
111 }
112 fn name(&self) -> &str {
113 "array_has"
114 }
115
116 fn signature(&self) -> &Signature {
117 &self.signature
118 }
119
120 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
121 Ok(DataType::Boolean)
122 }
123
124 fn invoke_with_args(
125 &self,
126 args: datafusion_expr::ScalarFunctionArgs,
127 ) -> Result<ColumnarValue> {
128 let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?;
129 match &second_arg {
130 ColumnarValue::Array(array_needle) => {
131 let haystack = first_arg.to_array(array_needle.len())?;
133 let array = array_has_inner_for_array(&haystack, array_needle)?;
134 Ok(ColumnarValue::Array(array))
135 }
136 ColumnarValue::Scalar(scalar_needle) => {
137 if scalar_needle.is_null() {
140 return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
141 }
142
143 let haystack = first_arg.to_array(1)?;
145 let needle = scalar_needle.to_array_of_size(1)?;
146 let needle = Scalar::new(needle);
147 let array = array_has_inner_for_scalar(&haystack, &needle)?;
148 if let ColumnarValue::Scalar(_) = &first_arg {
149 let scalar_value = ScalarValue::try_from_array(&array, 0)?;
151 Ok(ColumnarValue::Scalar(scalar_value))
152 } else {
153 Ok(ColumnarValue::Array(array))
154 }
155 }
156 }
157 }
158
159 fn aliases(&self) -> &[String] {
160 &self.aliases
161 }
162
163 fn documentation(&self) -> Option<&Documentation> {
164 self.doc()
165 }
166}
167
168fn array_has_inner_for_scalar(
169 haystack: &ArrayRef,
170 needle: &dyn Datum,
171) -> Result<ArrayRef> {
172 match haystack.data_type() {
173 DataType::List(_) => array_has_dispatch_for_scalar::<i32>(haystack, needle),
174 DataType::LargeList(_) => array_has_dispatch_for_scalar::<i64>(haystack, needle),
175 _ => exec_err!(
176 "array_has does not support type '{:?}'.",
177 haystack.data_type()
178 ),
179 }
180}
181
182fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result<ArrayRef> {
183 match haystack.data_type() {
184 DataType::List(_) => array_has_dispatch_for_array::<i32>(haystack, needle),
185 DataType::LargeList(_) => array_has_dispatch_for_array::<i64>(haystack, needle),
186 _ => exec_err!(
187 "array_has does not support type '{:?}'.",
188 haystack.data_type()
189 ),
190 }
191}
192
193fn array_has_dispatch_for_array<O: OffsetSizeTrait>(
194 haystack: &ArrayRef,
195 needle: &ArrayRef,
196) -> Result<ArrayRef> {
197 let haystack = as_generic_list_array::<O>(haystack)?;
198 let mut boolean_builder = BooleanArray::builder(haystack.len());
199
200 for (i, arr) in haystack.iter().enumerate() {
201 if arr.is_none() || needle.is_null(i) {
202 boolean_builder.append_null();
203 continue;
204 }
205 let arr = arr.unwrap();
206 let is_nested = arr.data_type().is_nested();
207 let needle_row = Scalar::new(needle.slice(i, 1));
208 let eq_array = compare_with_eq(&arr, &needle_row, is_nested)?;
209 boolean_builder.append_value(eq_array.true_count() > 0);
210 }
211
212 Ok(Arc::new(boolean_builder.finish()))
213}
214
215fn array_has_dispatch_for_scalar<O: OffsetSizeTrait>(
216 haystack: &ArrayRef,
217 needle: &dyn Datum,
218) -> Result<ArrayRef> {
219 let haystack = as_generic_list_array::<O>(haystack)?;
220 let values = haystack.values();
221 let is_nested = values.data_type().is_nested();
222 let offsets = haystack.value_offsets();
223 if values.len() == 0 {
226 return Ok(Arc::new(BooleanArray::new(
227 BooleanBuffer::new_unset(haystack.len()),
228 None,
229 )));
230 }
231 let eq_array = compare_with_eq(values, needle, is_nested)?;
232 let mut final_contained = vec![None; haystack.len()];
233 for (i, offset) in offsets.windows(2).enumerate() {
234 let start = offset[0].to_usize().unwrap();
235 let end = offset[1].to_usize().unwrap();
236 let length = end - start;
237 if length == 0 {
239 continue;
240 }
241 let sliced_array = eq_array.slice(start, length);
242 final_contained[i] = Some(sliced_array.true_count() > 0);
243 }
244
245 Ok(Arc::new(BooleanArray::from(final_contained)))
246}
247
248fn array_has_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
249 match args[0].data_type() {
250 DataType::List(_) => {
251 array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::All)
252 }
253 DataType::LargeList(_) => {
254 array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::All)
255 }
256 _ => exec_err!(
257 "array_has does not support type '{:?}'.",
258 args[0].data_type()
259 ),
260 }
261}
262
263fn array_has_any_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
264 match args[0].data_type() {
265 DataType::List(_) => {
266 array_has_all_and_any_dispatch::<i32>(&args[0], &args[1], ComparisonType::Any)
267 }
268 DataType::LargeList(_) => {
269 array_has_all_and_any_dispatch::<i64>(&args[0], &args[1], ComparisonType::Any)
270 }
271 _ => exec_err!(
272 "array_has does not support type '{:?}'.",
273 args[0].data_type()
274 ),
275 }
276}
277
278#[user_doc(
279 doc_section(label = "Array Functions"),
280 description = "Returns true if all elements of sub-array exist in array.",
281 syntax_example = "array_has_all(array, sub-array)",
282 sql_example = r#"```sql
283> select array_has_all([1, 2, 3, 4], [2, 3]);
284+--------------------------------------------+
285| array_has_all(List([1,2,3,4]), List([2,3])) |
286+--------------------------------------------+
287| true |
288+--------------------------------------------+
289```"#,
290 argument(
291 name = "array",
292 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
293 ),
294 argument(
295 name = "sub-array",
296 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
297 )
298)]
299#[derive(Debug)]
300pub struct ArrayHasAll {
301 signature: Signature,
302 aliases: Vec<String>,
303}
304
305impl Default for ArrayHasAll {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311impl ArrayHasAll {
312 pub fn new() -> Self {
313 Self {
314 signature: Signature::any(2, Volatility::Immutable),
315 aliases: vec![String::from("list_has_all")],
316 }
317 }
318}
319
320impl ScalarUDFImpl for ArrayHasAll {
321 fn as_any(&self) -> &dyn Any {
322 self
323 }
324 fn name(&self) -> &str {
325 "array_has_all"
326 }
327
328 fn signature(&self) -> &Signature {
329 &self.signature
330 }
331
332 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
333 Ok(DataType::Boolean)
334 }
335
336 fn invoke_with_args(
337 &self,
338 args: datafusion_expr::ScalarFunctionArgs,
339 ) -> Result<ColumnarValue> {
340 make_scalar_function(array_has_all_inner)(&args.args)
341 }
342
343 fn aliases(&self) -> &[String] {
344 &self.aliases
345 }
346
347 fn documentation(&self) -> Option<&Documentation> {
348 self.doc()
349 }
350}
351
352#[user_doc(
353 doc_section(label = "Array Functions"),
354 description = "Returns true if any elements exist in both arrays.",
355 syntax_example = "array_has_any(array, sub-array)",
356 sql_example = r#"```sql
357> select array_has_any([1, 2, 3], [3, 4]);
358+------------------------------------------+
359| array_has_any(List([1,2,3]), List([3,4])) |
360+------------------------------------------+
361| true |
362+------------------------------------------+
363```"#,
364 argument(
365 name = "array",
366 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
367 ),
368 argument(
369 name = "sub-array",
370 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
371 )
372)]
373#[derive(Debug)]
374pub struct ArrayHasAny {
375 signature: Signature,
376 aliases: Vec<String>,
377}
378
379impl Default for ArrayHasAny {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
385impl ArrayHasAny {
386 pub fn new() -> Self {
387 Self {
388 signature: Signature::any(2, Volatility::Immutable),
389 aliases: vec![String::from("list_has_any"), String::from("arrays_overlap")],
390 }
391 }
392}
393
394impl ScalarUDFImpl for ArrayHasAny {
395 fn as_any(&self) -> &dyn Any {
396 self
397 }
398 fn name(&self) -> &str {
399 "array_has_any"
400 }
401
402 fn signature(&self) -> &Signature {
403 &self.signature
404 }
405
406 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
407 Ok(DataType::Boolean)
408 }
409
410 fn invoke_with_args(
411 &self,
412 args: datafusion_expr::ScalarFunctionArgs,
413 ) -> Result<ColumnarValue> {
414 make_scalar_function(array_has_any_inner)(&args.args)
415 }
416
417 fn aliases(&self) -> &[String] {
418 &self.aliases
419 }
420
421 fn documentation(&self) -> Option<&Documentation> {
422 self.doc()
423 }
424}
425
/// Selects the quantifier applied by the shared `array_has_all` /
/// `array_has_any` kernels.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ComparisonType {
    /// Every needle element must occur in the haystack (`array_has_all`).
    All,
    /// At least one needle element must occur in the haystack (`array_has_any`).
    Any,
}
434
435fn array_has_all_and_any_dispatch<O: OffsetSizeTrait>(
436 haystack: &ArrayRef,
437 needle: &ArrayRef,
438 comparison_type: ComparisonType,
439) -> Result<ArrayRef> {
440 let haystack = as_generic_list_array::<O>(haystack)?;
441 let needle = as_generic_list_array::<O>(needle)?;
442 match needle.data_type() {
443 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
444 array_has_all_and_any_string_internal::<O>(haystack, needle, comparison_type)
445 }
446 _ => general_array_has_for_all_and_any::<O>(haystack, needle, comparison_type),
447 }
448}
449
450fn array_has_all_and_any_string_internal<O: OffsetSizeTrait>(
452 array: &GenericListArray<O>,
453 needle: &GenericListArray<O>,
454 comparison_type: ComparisonType,
455) -> Result<ArrayRef> {
456 let mut boolean_builder = BooleanArray::builder(array.len());
457 for (arr, sub_arr) in array.iter().zip(needle.iter()) {
458 match (arr, sub_arr) {
459 (Some(arr), Some(sub_arr)) => {
460 let haystack_array = string_array_to_vec(&arr);
461 let needle_array = string_array_to_vec(&sub_arr);
462 boolean_builder.append_value(array_has_string_kernel(
463 haystack_array,
464 needle_array,
465 comparison_type,
466 ));
467 }
468 (_, _) => {
469 boolean_builder.append_null();
470 }
471 }
472 }
473
474 Ok(Arc::new(boolean_builder.finish()))
475}
476
477fn array_has_string_kernel(
478 haystack: Vec<Option<&str>>,
479 needle: Vec<Option<&str>>,
480 comparison_type: ComparisonType,
481) -> bool {
482 match comparison_type {
483 ComparisonType::All => needle
484 .iter()
485 .dedup()
486 .all(|x| haystack.iter().dedup().any(|y| y == x)),
487 ComparisonType::Any => needle
488 .iter()
489 .dedup()
490 .any(|x| haystack.iter().dedup().any(|y| y == x)),
491 }
492}
493
494fn general_array_has_for_all_and_any<O: OffsetSizeTrait>(
496 haystack: &GenericListArray<O>,
497 needle: &GenericListArray<O>,
498 comparison_type: ComparisonType,
499) -> Result<ArrayRef> {
500 let mut boolean_builder = BooleanArray::builder(haystack.len());
501 let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?;
502
503 for (arr, sub_arr) in haystack.iter().zip(needle.iter()) {
504 if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
505 let arr_values = converter.convert_columns(&[arr])?;
506 let sub_arr_values = converter.convert_columns(&[sub_arr])?;
507 boolean_builder.append_value(general_array_has_all_and_any_kernel(
508 arr_values,
509 sub_arr_values,
510 comparison_type,
511 ));
512 } else {
513 boolean_builder.append_null();
514 }
515 }
516
517 Ok(Arc::new(boolean_builder.finish()))
518}
519
520fn general_array_has_all_and_any_kernel(
521 haystack_rows: Rows,
522 needle_rows: Rows,
523 comparison_type: ComparisonType,
524) -> bool {
525 match comparison_type {
526 ComparisonType::All => needle_rows.iter().all(|needle_row| {
527 haystack_rows
528 .iter()
529 .any(|haystack_row| haystack_row == needle_row)
530 }),
531 ComparisonType::Any => needle_rows.iter().any(|needle_row| {
532 haystack_rows
533 .iter()
534 .any(|haystack_row| haystack_row == needle_row)
535 }),
536 }
537}