datafusion_functions_nested/
set_ops.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
19
20use crate::make_array::{empty_array_type, make_array_inner};
21use crate::utils::make_scalar_function;
22use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait};
23use arrow::buffer::OffsetBuffer;
24use arrow::compute;
25use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
26use arrow::datatypes::{DataType, Field, FieldRef};
27use arrow::row::{RowConverter, SortField};
28use datafusion_common::cast::{as_large_list_array, as_list_array};
29use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result};
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34use itertools::Itertools;
35use std::any::Any;
36use std::collections::HashSet;
37use std::fmt::{Display, Formatter};
38use std::sync::Arc;
39
40// Create static instances of ScalarUDFs for each function
41make_udf_expr_and_func!(
42    ArrayUnion,
43    array_union,
44    array1 array2,
45    "returns an array of the elements in the union of array1 and array2 without duplicates.",
46    array_union_udf
47);
48
49make_udf_expr_and_func!(
50    ArrayIntersect,
51    array_intersect,
52    first_array second_array,
53    "returns an array of the elements in the intersection of array1 and array2.",
54    array_intersect_udf
55);
56
57make_udf_expr_and_func!(
58    ArrayDistinct,
59    array_distinct,
60    array,
61    "returns distinct values from the array after removing duplicates.",
62    array_distinct_udf
63);
64
65#[user_doc(
66    doc_section(label = "Array Functions"),
67    description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.",
68    syntax_example = "array_union(array1, array2)",
69    sql_example = r#"```sql
70> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
71+----------------------------------------------------+
72| array_union([1, 2, 3, 4], [5, 6, 3, 4]);           |
73+----------------------------------------------------+
74| [1, 2, 3, 4, 5, 6]                                 |
75+----------------------------------------------------+
76> select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
77+----------------------------------------------------+
78| array_union([1, 2, 3, 4], [5, 6, 7, 8]);           |
79+----------------------------------------------------+
80| [1, 2, 3, 4, 5, 6, 7, 8]                           |
81+----------------------------------------------------+
82```"#,
83    argument(
84        name = "array1",
85        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
86    ),
87    argument(
88        name = "array2",
89        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
90    )
91)]
92#[derive(Debug)]
93pub struct ArrayUnion {
94    signature: Signature,
95    aliases: Vec<String>,
96}
97
98impl Default for ArrayUnion {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104impl ArrayUnion {
105    pub fn new() -> Self {
106        Self {
107            signature: Signature::any(2, Volatility::Immutable),
108            aliases: vec![String::from("list_union")],
109        }
110    }
111}
112
113impl ScalarUDFImpl for ArrayUnion {
114    fn as_any(&self) -> &dyn Any {
115        self
116    }
117
118    fn name(&self) -> &str {
119        "array_union"
120    }
121
122    fn signature(&self) -> &Signature {
123        &self.signature
124    }
125
126    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
127        match (&arg_types[0], &arg_types[1]) {
128            (&Null, dt) => Ok(dt.clone()),
129            (dt, Null) => Ok(dt.clone()),
130            (dt, _) => Ok(dt.clone()),
131        }
132    }
133
134    fn invoke_with_args(
135        &self,
136        args: datafusion_expr::ScalarFunctionArgs,
137    ) -> Result<ColumnarValue> {
138        make_scalar_function(array_union_inner)(&args.args)
139    }
140
141    fn aliases(&self) -> &[String] {
142        &self.aliases
143    }
144
145    fn documentation(&self) -> Option<&Documentation> {
146        self.doc()
147    }
148}
149
150#[user_doc(
151    doc_section(label = "Array Functions"),
152    description = "Returns an array of elements in the intersection of array1 and array2.",
153    syntax_example = "array_intersect(array1, array2)",
154    sql_example = r#"```sql
155> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);
156+----------------------------------------------------+
157| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);       |
158+----------------------------------------------------+
159| [3, 4]                                             |
160+----------------------------------------------------+
161> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);
162+----------------------------------------------------+
163| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);       |
164+----------------------------------------------------+
165| []                                                 |
166+----------------------------------------------------+
167```"#,
168    argument(
169        name = "array1",
170        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
171    ),
172    argument(
173        name = "array2",
174        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
175    )
176)]
177#[derive(Debug)]
178pub(super) struct ArrayIntersect {
179    signature: Signature,
180    aliases: Vec<String>,
181}
182
183impl ArrayIntersect {
184    pub fn new() -> Self {
185        Self {
186            signature: Signature::any(2, Volatility::Immutable),
187            aliases: vec![String::from("list_intersect")],
188        }
189    }
190}
191
192impl ScalarUDFImpl for ArrayIntersect {
193    fn as_any(&self) -> &dyn Any {
194        self
195    }
196
197    fn name(&self) -> &str {
198        "array_intersect"
199    }
200
201    fn signature(&self) -> &Signature {
202        &self.signature
203    }
204
205    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
206        match (arg_types[0].clone(), arg_types[1].clone()) {
207            (Null, Null) | (Null, _) => Ok(Null),
208            (_, Null) => Ok(empty_array_type()),
209            (dt, _) => Ok(dt),
210        }
211    }
212
213    fn invoke_with_args(
214        &self,
215        args: datafusion_expr::ScalarFunctionArgs,
216    ) -> Result<ColumnarValue> {
217        make_scalar_function(array_intersect_inner)(&args.args)
218    }
219
220    fn aliases(&self) -> &[String] {
221        &self.aliases
222    }
223
224    fn documentation(&self) -> Option<&Documentation> {
225        self.doc()
226    }
227}
228
229#[user_doc(
230    doc_section(label = "Array Functions"),
231    description = "Returns distinct values from the array after removing duplicates.",
232    syntax_example = "array_distinct(array)",
233    sql_example = r#"```sql
234> select array_distinct([1, 3, 2, 3, 1, 2, 4]);
235+---------------------------------+
236| array_distinct(List([1,2,3,4])) |
237+---------------------------------+
238| [1, 2, 3, 4]                    |
239+---------------------------------+
240```"#,
241    argument(
242        name = "array",
243        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
244    )
245)]
246#[derive(Debug)]
247pub(super) struct ArrayDistinct {
248    signature: Signature,
249    aliases: Vec<String>,
250}
251
252impl ArrayDistinct {
253    pub fn new() -> Self {
254        Self {
255            signature: Signature::array(Volatility::Immutable),
256            aliases: vec!["list_distinct".to_string()],
257        }
258    }
259}
260
261impl ScalarUDFImpl for ArrayDistinct {
262    fn as_any(&self) -> &dyn Any {
263        self
264    }
265
266    fn name(&self) -> &str {
267        "array_distinct"
268    }
269
270    fn signature(&self) -> &Signature {
271        &self.signature
272    }
273
274    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
275        match &arg_types[0] {
276            List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(
277                Field::new_list_field(field.data_type().clone(), true),
278            ))),
279            LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field(
280                field.data_type().clone(),
281                true,
282            )))),
283            _ => exec_err!(
284                "Not reachable, data_type should be List, LargeList or FixedSizeList"
285            ),
286        }
287    }
288
289    fn invoke_with_args(
290        &self,
291        args: datafusion_expr::ScalarFunctionArgs,
292    ) -> Result<ColumnarValue> {
293        make_scalar_function(array_distinct_inner)(&args.args)
294    }
295
296    fn aliases(&self) -> &[String] {
297        &self.aliases
298    }
299
300    fn documentation(&self) -> Option<&Documentation> {
301        self.doc()
302    }
303}
304
305/// array_distinct SQL function
306/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
307fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
308    let [input_array] = take_function_args("array_distinct", args)?;
309
310    // handle null
311    if input_array.data_type() == &Null {
312        return Ok(Arc::clone(input_array));
313    }
314
315    // handle for list & largelist
316    match input_array.data_type() {
317        List(field) => {
318            let array = as_list_array(&input_array)?;
319            general_array_distinct(array, field)
320        }
321        LargeList(field) => {
322            let array = as_large_list_array(&input_array)?;
323            general_array_distinct(array, field)
324        }
325        array_type => exec_err!("array_distinct does not support type '{array_type:?}'"),
326    }
327}
328
/// The set operation applied to each pair of list rows.
///
/// A fieldless enum, so it is cheap to copy and can be passed by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SetOp {
    Union,
    Intersect,
}

impl Display for SetOp {
    /// Formats the operation as the SQL function name it implements
    /// (used in error messages).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SetOp::Union => "array_union",
            SetOp::Intersect => "array_intersect",
        };
        f.write_str(name)
    }
}
343
344fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
345    l: &GenericListArray<OffsetSize>,
346    r: &GenericListArray<OffsetSize>,
347    field: Arc<Field>,
348    set_op: SetOp,
349) -> Result<ArrayRef> {
350    if matches!(l.value_type(), Null) {
351        let field = Arc::new(Field::new_list_field(r.value_type(), true));
352        return general_array_distinct::<OffsetSize>(r, &field);
353    } else if matches!(r.value_type(), Null) {
354        let field = Arc::new(Field::new_list_field(l.value_type(), true));
355        return general_array_distinct::<OffsetSize>(l, &field);
356    }
357
358    // Handle empty array at rhs case
359    // array_union(arr, []) -> arr;
360    // array_intersect(arr, []) -> [];
361    if r.value_length(0).is_zero() {
362        if set_op == SetOp::Union {
363            return Ok(Arc::new(l.clone()) as ArrayRef);
364        } else {
365            return Ok(Arc::new(r.clone()) as ArrayRef);
366        }
367    }
368
369    if l.value_type() != r.value_type() {
370        return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'");
371    }
372
373    let dt = l.value_type();
374
375    let mut offsets = vec![OffsetSize::usize_as(0)];
376    let mut new_arrays = vec![];
377
378    let converter = RowConverter::new(vec![SortField::new(dt)])?;
379    for (first_arr, second_arr) in l.iter().zip(r.iter()) {
380        if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
381            let l_values = converter.convert_columns(&[first_arr])?;
382            let r_values = converter.convert_columns(&[second_arr])?;
383
384            let l_iter = l_values.iter().sorted().dedup();
385            let values_set: HashSet<_> = l_iter.clone().collect();
386            let mut rows = if set_op == SetOp::Union {
387                l_iter.collect::<Vec<_>>()
388            } else {
389                vec![]
390            };
391            for r_val in r_values.iter().sorted().dedup() {
392                match set_op {
393                    SetOp::Union => {
394                        if !values_set.contains(&r_val) {
395                            rows.push(r_val);
396                        }
397                    }
398                    SetOp::Intersect => {
399                        if values_set.contains(&r_val) {
400                            rows.push(r_val);
401                        }
402                    }
403                }
404            }
405
406            let last_offset = match offsets.last().copied() {
407                Some(offset) => offset,
408                None => return internal_err!("offsets should not be empty"),
409            };
410            offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
411            let arrays = converter.convert_rows(rows)?;
412            let array = match arrays.first() {
413                Some(array) => Arc::clone(array),
414                None => {
415                    return internal_err!("{set_op}: failed to get array from rows");
416                }
417            };
418            new_arrays.push(array);
419        }
420    }
421
422    let offsets = OffsetBuffer::new(offsets.into());
423    let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
424    let values = compute::concat(&new_arrays_ref)?;
425    let arr = GenericListArray::<OffsetSize>::try_new(field, offsets, values, None)?;
426    Ok(Arc::new(arr))
427}
428
429fn general_set_op(
430    array1: &ArrayRef,
431    array2: &ArrayRef,
432    set_op: SetOp,
433) -> Result<ArrayRef> {
434    match (array1.data_type(), array2.data_type()) {
435        (Null, List(field)) => {
436            if set_op == SetOp::Intersect {
437                return Ok(new_empty_array(&Null));
438            }
439            let array = as_list_array(&array2)?;
440            general_array_distinct::<i32>(array, field)
441        }
442
443        (List(field), Null) => {
444            if set_op == SetOp::Intersect {
445                return make_array_inner(&[]);
446            }
447            let array = as_list_array(&array1)?;
448            general_array_distinct::<i32>(array, field)
449        }
450        (Null, LargeList(field)) => {
451            if set_op == SetOp::Intersect {
452                return Ok(new_empty_array(&Null));
453            }
454            let array = as_large_list_array(&array2)?;
455            general_array_distinct::<i64>(array, field)
456        }
457        (LargeList(field), Null) => {
458            if set_op == SetOp::Intersect {
459                return make_array_inner(&[]);
460            }
461            let array = as_large_list_array(&array1)?;
462            general_array_distinct::<i64>(array, field)
463        }
464        (Null, Null) => Ok(new_empty_array(&Null)),
465
466        (List(field), List(_)) => {
467            let array1 = as_list_array(&array1)?;
468            let array2 = as_list_array(&array2)?;
469            generic_set_lists::<i32>(array1, array2, Arc::clone(field), set_op)
470        }
471        (LargeList(field), LargeList(_)) => {
472            let array1 = as_large_list_array(&array1)?;
473            let array2 = as_large_list_array(&array2)?;
474            generic_set_lists::<i64>(array1, array2, Arc::clone(field), set_op)
475        }
476        (data_type1, data_type2) => {
477            internal_err!(
478                "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
479            )
480        }
481    }
482}
483
484/// Array_union SQL function
485fn array_union_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
486    let [array1, array2] = take_function_args("array_union", args)?;
487    general_set_op(array1, array2, SetOp::Union)
488}
489
490/// array_intersect SQL function
491fn array_intersect_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
492    let [array1, array2] = take_function_args("array_intersect", args)?;
493    general_set_op(array1, array2, SetOp::Intersect)
494}
495
496fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
497    array: &GenericListArray<OffsetSize>,
498    field: &FieldRef,
499) -> Result<ArrayRef> {
500    if array.is_empty() {
501        return Ok(Arc::new(array.clone()) as ArrayRef);
502    }
503    let dt = array.value_type();
504    let mut offsets = Vec::with_capacity(array.len());
505    offsets.push(OffsetSize::usize_as(0));
506    let mut new_arrays = Vec::with_capacity(array.len());
507    let converter = RowConverter::new(vec![SortField::new(dt)])?;
508    // distinct for each list in ListArray
509    for arr in array.iter() {
510        let last_offset: OffsetSize = offsets.last().copied().unwrap();
511        let Some(arr) = arr else {
512            // Add same offset for null
513            offsets.push(last_offset);
514            continue;
515        };
516        let values = converter.convert_columns(&[arr])?;
517        // sort elements in list and remove duplicates
518        let rows = values.iter().sorted().dedup().collect::<Vec<_>>();
519        offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
520        let arrays = converter.convert_rows(rows)?;
521        let array = match arrays.first() {
522            Some(array) => Arc::clone(array),
523            None => {
524                return internal_err!("array_distinct: failed to get array from rows")
525            }
526        };
527        new_arrays.push(array);
528    }
529    if new_arrays.is_empty() {
530        return Ok(Arc::new(array.clone()) as ArrayRef);
531    }
532    let offsets = OffsetBuffer::new(offsets.into());
533    let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
534    let values = compute::concat(&new_arrays_ref)?;
535    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
536        Arc::clone(field),
537        offsets,
538        values,
539        // Keep the list nulls
540        array.nulls().cloned(),
541    )?))
542}