datafusion_functions_nested/
set_ops.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    new_null_array, Array, ArrayRef, GenericListArray, LargeListArray, ListArray,
23    OffsetSizeTrait,
24};
25use arrow::buffer::OffsetBuffer;
26use arrow::compute;
27use arrow::datatypes::DataType::{LargeList, List, Null};
28use arrow::datatypes::{DataType, Field, FieldRef};
29use arrow::row::{RowConverter, SortField};
30use datafusion_common::cast::{as_large_list_array, as_list_array};
31use datafusion_common::utils::ListCoercion;
32use datafusion_common::{
33    exec_err, internal_err, plan_err, utils::take_function_args, Result,
34};
35use datafusion_expr::{
36    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
37};
38use datafusion_macros::user_doc;
39use itertools::Itertools;
40use std::any::Any;
41use std::collections::HashSet;
42use std::fmt::{Display, Formatter};
43use std::sync::Arc;
44
45// Create static instances of ScalarUDFs for each function
46make_udf_expr_and_func!(
47    ArrayUnion,
48    array_union,
49    array1 array2,
50    "returns an array of the elements in the union of array1 and array2 without duplicates.",
51    array_union_udf
52);
53
54make_udf_expr_and_func!(
55    ArrayIntersect,
56    array_intersect,
57    first_array second_array,
58    "returns an array of the elements in the intersection of array1 and array2.",
59    array_intersect_udf
60);
61
62make_udf_expr_and_func!(
63    ArrayDistinct,
64    array_distinct,
65    array,
66    "returns distinct values from the array after removing duplicates.",
67    array_distinct_udf
68);
69
70#[user_doc(
71    doc_section(label = "Array Functions"),
72    description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.",
73    syntax_example = "array_union(array1, array2)",
74    sql_example = r#"```sql
75> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
76+----------------------------------------------------+
77| array_union([1, 2, 3, 4], [5, 6, 3, 4]);           |
78+----------------------------------------------------+
79| [1, 2, 3, 4, 5, 6]                                 |
80+----------------------------------------------------+
81> select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
82+----------------------------------------------------+
83| array_union([1, 2, 3, 4], [5, 6, 7, 8]);           |
84+----------------------------------------------------+
85| [1, 2, 3, 4, 5, 6, 7, 8]                           |
86+----------------------------------------------------+
87```"#,
88    argument(
89        name = "array1",
90        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
91    ),
92    argument(
93        name = "array2",
94        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
95    )
96)]
97#[derive(Debug)]
98pub struct ArrayUnion {
99    signature: Signature,
100    aliases: Vec<String>,
101}
102
103impl Default for ArrayUnion {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl ArrayUnion {
110    pub fn new() -> Self {
111        Self {
112            signature: Signature::arrays(
113                2,
114                Some(ListCoercion::FixedSizedListToList),
115                Volatility::Immutable,
116            ),
117            aliases: vec![String::from("list_union")],
118        }
119    }
120}
121
122impl ScalarUDFImpl for ArrayUnion {
123    fn as_any(&self) -> &dyn Any {
124        self
125    }
126
127    fn name(&self) -> &str {
128        "array_union"
129    }
130
131    fn signature(&self) -> &Signature {
132        &self.signature
133    }
134
135    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136        let [array1, array2] = take_function_args(self.name(), arg_types)?;
137        match (array1, array2) {
138            (Null, Null) => Ok(DataType::new_list(Null, true)),
139            (Null, dt) => Ok(dt.clone()),
140            (dt, Null) => Ok(dt.clone()),
141            (dt, _) => Ok(dt.clone()),
142        }
143    }
144
145    fn invoke_with_args(
146        &self,
147        args: datafusion_expr::ScalarFunctionArgs,
148    ) -> Result<ColumnarValue> {
149        make_scalar_function(array_union_inner)(&args.args)
150    }
151
152    fn aliases(&self) -> &[String] {
153        &self.aliases
154    }
155
156    fn documentation(&self) -> Option<&Documentation> {
157        self.doc()
158    }
159}
160
161#[user_doc(
162    doc_section(label = "Array Functions"),
163    description = "Returns an array of elements in the intersection of array1 and array2.",
164    syntax_example = "array_intersect(array1, array2)",
165    sql_example = r#"```sql
166> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);
167+----------------------------------------------------+
168| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);       |
169+----------------------------------------------------+
170| [3, 4]                                             |
171+----------------------------------------------------+
172> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);
173+----------------------------------------------------+
174| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);       |
175+----------------------------------------------------+
176| []                                                 |
177+----------------------------------------------------+
178```"#,
179    argument(
180        name = "array1",
181        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
182    ),
183    argument(
184        name = "array2",
185        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
186    )
187)]
188#[derive(Debug)]
189pub(super) struct ArrayIntersect {
190    signature: Signature,
191    aliases: Vec<String>,
192}
193
194impl ArrayIntersect {
195    pub fn new() -> Self {
196        Self {
197            signature: Signature::arrays(
198                2,
199                Some(ListCoercion::FixedSizedListToList),
200                Volatility::Immutable,
201            ),
202            aliases: vec![String::from("list_intersect")],
203        }
204    }
205}
206
207impl ScalarUDFImpl for ArrayIntersect {
208    fn as_any(&self) -> &dyn Any {
209        self
210    }
211
212    fn name(&self) -> &str {
213        "array_intersect"
214    }
215
216    fn signature(&self) -> &Signature {
217        &self.signature
218    }
219
220    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
221        let [array1, array2] = take_function_args(self.name(), arg_types)?;
222        match (array1, array2) {
223            (Null, Null) => Ok(DataType::new_list(Null, true)),
224            (Null, dt) => Ok(dt.clone()),
225            (dt, Null) => Ok(dt.clone()),
226            (dt, _) => Ok(dt.clone()),
227        }
228    }
229
230    fn invoke_with_args(
231        &self,
232        args: datafusion_expr::ScalarFunctionArgs,
233    ) -> Result<ColumnarValue> {
234        make_scalar_function(array_intersect_inner)(&args.args)
235    }
236
237    fn aliases(&self) -> &[String] {
238        &self.aliases
239    }
240
241    fn documentation(&self) -> Option<&Documentation> {
242        self.doc()
243    }
244}
245
246#[user_doc(
247    doc_section(label = "Array Functions"),
248    description = "Returns distinct values from the array after removing duplicates.",
249    syntax_example = "array_distinct(array)",
250    sql_example = r#"```sql
251> select array_distinct([1, 3, 2, 3, 1, 2, 4]);
252+---------------------------------+
253| array_distinct(List([1,2,3,4])) |
254+---------------------------------+
255| [1, 2, 3, 4]                    |
256+---------------------------------+
257```"#,
258    argument(
259        name = "array",
260        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
261    )
262)]
263#[derive(Debug)]
264pub(super) struct ArrayDistinct {
265    signature: Signature,
266    aliases: Vec<String>,
267}
268
269impl ArrayDistinct {
270    pub fn new() -> Self {
271        Self {
272            signature: Signature::array(Volatility::Immutable),
273            aliases: vec!["list_distinct".to_string()],
274        }
275    }
276}
277
278impl ScalarUDFImpl for ArrayDistinct {
279    fn as_any(&self) -> &dyn Any {
280        self
281    }
282
283    fn name(&self) -> &str {
284        "array_distinct"
285    }
286
287    fn signature(&self) -> &Signature {
288        &self.signature
289    }
290
291    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
292        match &arg_types[0] {
293            List(field) => Ok(DataType::new_list(field.data_type().clone(), true)),
294            LargeList(field) => {
295                Ok(DataType::new_large_list(field.data_type().clone(), true))
296            }
297            arg_type => plan_err!("{} does not support type {arg_type}", self.name()),
298        }
299    }
300
301    fn invoke_with_args(
302        &self,
303        args: datafusion_expr::ScalarFunctionArgs,
304    ) -> Result<ColumnarValue> {
305        make_scalar_function(array_distinct_inner)(&args.args)
306    }
307
308    fn aliases(&self) -> &[String] {
309        &self.aliases
310    }
311
312    fn documentation(&self) -> Option<&Documentation> {
313        self.doc()
314    }
315}
316
317/// array_distinct SQL function
318/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
319fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
320    let [array] = take_function_args("array_distinct", args)?;
321    match array.data_type() {
322        Null => Ok(Arc::clone(array)),
323        List(field) => {
324            let array = as_list_array(&array)?;
325            general_array_distinct(array, field)
326        }
327        LargeList(field) => {
328            let array = as_large_list_array(&array)?;
329            general_array_distinct(array, field)
330        }
331        arg_type => exec_err!("array_distinct does not support type {arg_type}"),
332    }
333}
334
/// Binary set operation applied element-wise to two list arguments.
#[derive(Debug, PartialEq)]
enum SetOp {
    /// Distinct elements found in either list.
    Union,
    /// Distinct elements found in both lists.
    Intersect,
}

impl Display for SetOp {
    /// Writes the SQL function name that performs this operation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let function_name = match self {
            Self::Union => "array_union",
            Self::Intersect => "array_intersect",
        };
        f.write_str(function_name)
    }
}
349
350fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
351    l: &GenericListArray<OffsetSize>,
352    r: &GenericListArray<OffsetSize>,
353    field: Arc<Field>,
354    set_op: SetOp,
355) -> Result<ArrayRef> {
356    if l.is_empty() || l.value_type().is_null() {
357        let field = Arc::new(Field::new_list_field(r.value_type(), true));
358        return general_array_distinct::<OffsetSize>(r, &field);
359    } else if r.is_empty() || r.value_type().is_null() {
360        let field = Arc::new(Field::new_list_field(l.value_type(), true));
361        return general_array_distinct::<OffsetSize>(l, &field);
362    }
363
364    if l.value_type() != r.value_type() {
365        return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'");
366    }
367
368    let mut offsets = vec![OffsetSize::usize_as(0)];
369    let mut new_arrays = vec![];
370    let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
371    for (first_arr, second_arr) in l.iter().zip(r.iter()) {
372        let l_values = if let Some(first_arr) = first_arr {
373            converter.convert_columns(&[first_arr])?
374        } else {
375            converter.convert_columns(&[])?
376        };
377
378        let r_values = if let Some(second_arr) = second_arr {
379            converter.convert_columns(&[second_arr])?
380        } else {
381            converter.convert_columns(&[])?
382        };
383
384        let l_iter = l_values.iter().sorted().dedup();
385        let values_set: HashSet<_> = l_iter.clone().collect();
386        let mut rows = if set_op == SetOp::Union {
387            l_iter.collect()
388        } else {
389            vec![]
390        };
391
392        for r_val in r_values.iter().sorted().dedup() {
393            match set_op {
394                SetOp::Union => {
395                    if !values_set.contains(&r_val) {
396                        rows.push(r_val);
397                    }
398                }
399                SetOp::Intersect => {
400                    if values_set.contains(&r_val) {
401                        rows.push(r_val);
402                    }
403                }
404            }
405        }
406
407        let last_offset = match offsets.last() {
408            Some(offset) => *offset,
409            None => return internal_err!("offsets should not be empty"),
410        };
411
412        offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
413        let arrays = converter.convert_rows(rows)?;
414        let array = match arrays.first() {
415            Some(array) => Arc::clone(array),
416            None => {
417                return internal_err!("{set_op}: failed to get array from rows");
418            }
419        };
420
421        new_arrays.push(array);
422    }
423
424    let offsets = OffsetBuffer::new(offsets.into());
425    let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect();
426    let values = compute::concat(&new_arrays_ref)?;
427    let arr = GenericListArray::<OffsetSize>::try_new(field, offsets, values, None)?;
428    Ok(Arc::new(arr))
429}
430
431fn general_set_op(
432    array1: &ArrayRef,
433    array2: &ArrayRef,
434    set_op: SetOp,
435) -> Result<ArrayRef> {
436    fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result<ArrayRef> {
437        let field = Arc::new(Field::new_list_field(data_type.clone(), true));
438        let values = new_null_array(data_type, len);
439        if large {
440            Ok(Arc::new(LargeListArray::try_new(
441                field,
442                OffsetBuffer::new_zeroed(len),
443                values,
444                None,
445            )?))
446        } else {
447            Ok(Arc::new(ListArray::try_new(
448                field,
449                OffsetBuffer::new_zeroed(len),
450                values,
451                None,
452            )?))
453        }
454    }
455
456    match (array1.data_type(), array2.data_type()) {
457        (Null, Null) => Ok(Arc::new(ListArray::new_null(
458            Arc::new(Field::new_list_field(Null, true)),
459            array1.len(),
460        ))),
461        (Null, List(field)) => {
462            if set_op == SetOp::Intersect {
463                return empty_array(field.data_type(), array1.len(), false);
464            }
465            let array = as_list_array(&array2)?;
466            general_array_distinct::<i32>(array, field)
467        }
468        (List(field), Null) => {
469            if set_op == SetOp::Intersect {
470                return empty_array(field.data_type(), array1.len(), false);
471            }
472            let array = as_list_array(&array1)?;
473            general_array_distinct::<i32>(array, field)
474        }
475        (Null, LargeList(field)) => {
476            if set_op == SetOp::Intersect {
477                return empty_array(field.data_type(), array1.len(), true);
478            }
479            let array = as_large_list_array(&array2)?;
480            general_array_distinct::<i64>(array, field)
481        }
482        (LargeList(field), Null) => {
483            if set_op == SetOp::Intersect {
484                return empty_array(field.data_type(), array1.len(), true);
485            }
486            let array = as_large_list_array(&array1)?;
487            general_array_distinct::<i64>(array, field)
488        }
489        (List(field), List(_)) => {
490            let array1 = as_list_array(&array1)?;
491            let array2 = as_list_array(&array2)?;
492            generic_set_lists::<i32>(array1, array2, Arc::clone(field), set_op)
493        }
494        (LargeList(field), LargeList(_)) => {
495            let array1 = as_large_list_array(&array1)?;
496            let array2 = as_large_list_array(&array2)?;
497            generic_set_lists::<i64>(array1, array2, Arc::clone(field), set_op)
498        }
499        (data_type1, data_type2) => {
500            internal_err!(
501                "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
502            )
503        }
504    }
505}
506
507/// Array_union SQL function
508fn array_union_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
509    let [array1, array2] = take_function_args("array_union", args)?;
510    general_set_op(array1, array2, SetOp::Union)
511}
512
513/// array_intersect SQL function
514fn array_intersect_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
515    let [array1, array2] = take_function_args("array_intersect", args)?;
516    general_set_op(array1, array2, SetOp::Intersect)
517}
518
519fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
520    array: &GenericListArray<OffsetSize>,
521    field: &FieldRef,
522) -> Result<ArrayRef> {
523    if array.is_empty() {
524        return Ok(Arc::new(array.clone()) as ArrayRef);
525    }
526    let dt = array.value_type();
527    let mut offsets = Vec::with_capacity(array.len());
528    offsets.push(OffsetSize::usize_as(0));
529    let mut new_arrays = Vec::with_capacity(array.len());
530    let converter = RowConverter::new(vec![SortField::new(dt)])?;
531    // distinct for each list in ListArray
532    for arr in array.iter() {
533        let last_offset: OffsetSize = offsets.last().copied().unwrap();
534        let Some(arr) = arr else {
535            // Add same offset for null
536            offsets.push(last_offset);
537            continue;
538        };
539        let values = converter.convert_columns(&[arr])?;
540        // sort elements in list and remove duplicates
541        let rows = values.iter().sorted().dedup().collect::<Vec<_>>();
542        offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
543        let arrays = converter.convert_rows(rows)?;
544        let array = match arrays.first() {
545            Some(array) => Arc::clone(array),
546            None => {
547                return internal_err!("array_distinct: failed to get array from rows")
548            }
549        };
550        new_arrays.push(array);
551    }
552    if new_arrays.is_empty() {
553        return Ok(Arc::new(array.clone()) as ArrayRef);
554    }
555    let offsets = OffsetBuffer::new(offsets.into());
556    let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
557    let values = compute::concat(&new_arrays_ref)?;
558    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
559        Arc::clone(field),
560        offsets,
561        values,
562        // Keep the list nulls
563        array.nulls().cloned(),
564    )?))
565}