datafusion_functions_nested/
set_ops.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, GenericListArray, LargeListArray, ListArray, OffsetSizeTrait,
23    new_null_array,
24};
25use arrow::buffer::{NullBuffer, OffsetBuffer};
26use arrow::compute;
27use arrow::datatypes::DataType::{LargeList, List, Null};
28use arrow::datatypes::{DataType, Field, FieldRef};
29use arrow::row::{RowConverter, SortField};
30use datafusion_common::cast::{as_large_list_array, as_list_array};
31use datafusion_common::utils::ListCoercion;
32use datafusion_common::{
33    Result, assert_eq_or_internal_err, exec_err, internal_err, utils::take_function_args,
34};
35use datafusion_expr::{
36    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
37};
38use datafusion_macros::user_doc;
39use itertools::Itertools;
40use std::any::Any;
41use std::collections::HashSet;
42use std::fmt::{Display, Formatter};
43use std::sync::Arc;
44
45// Create static instances of ScalarUDFs for each function
46make_udf_expr_and_func!(
47    ArrayUnion,
48    array_union,
49    array1 array2,
50    "returns an array of the elements in the union of array1 and array2 without duplicates.",
51    array_union_udf
52);
53
54make_udf_expr_and_func!(
55    ArrayIntersect,
56    array_intersect,
57    first_array second_array,
58    "returns an array of the elements in the intersection of array1 and array2.",
59    array_intersect_udf
60);
61
62make_udf_expr_and_func!(
63    ArrayDistinct,
64    array_distinct,
65    array,
66    "returns distinct values from the array after removing duplicates.",
67    array_distinct_udf
68);
69
70#[user_doc(
71    doc_section(label = "Array Functions"),
72    description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.",
73    syntax_example = "array_union(array1, array2)",
74    sql_example = r#"```sql
75> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
76+----------------------------------------------------+
77| array_union([1, 2, 3, 4], [5, 6, 3, 4]);           |
78+----------------------------------------------------+
79| [1, 2, 3, 4, 5, 6]                                 |
80+----------------------------------------------------+
81> select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
82+----------------------------------------------------+
83| array_union([1, 2, 3, 4], [5, 6, 7, 8]);           |
84+----------------------------------------------------+
85| [1, 2, 3, 4, 5, 6, 7, 8]                           |
86+----------------------------------------------------+
87```"#,
88    argument(
89        name = "array1",
90        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
91    ),
92    argument(
93        name = "array2",
94        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
95    )
96)]
97#[derive(Debug, PartialEq, Eq, Hash)]
98pub struct ArrayUnion {
99    signature: Signature,
100    aliases: Vec<String>,
101}
102
103impl Default for ArrayUnion {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl ArrayUnion {
110    pub fn new() -> Self {
111        Self {
112            signature: Signature::arrays(
113                2,
114                Some(ListCoercion::FixedSizedListToList),
115                Volatility::Immutable,
116            ),
117            aliases: vec![String::from("list_union")],
118        }
119    }
120}
121
122impl ScalarUDFImpl for ArrayUnion {
123    fn as_any(&self) -> &dyn Any {
124        self
125    }
126
127    fn name(&self) -> &str {
128        "array_union"
129    }
130
131    fn signature(&self) -> &Signature {
132        &self.signature
133    }
134
135    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136        let [array1, array2] = take_function_args(self.name(), arg_types)?;
137        match (array1, array2) {
138            (Null, Null) => Ok(DataType::new_list(Null, true)),
139            (Null, dt) => Ok(dt.clone()),
140            (dt, Null) => Ok(dt.clone()),
141            (dt, _) => Ok(dt.clone()),
142        }
143    }
144
145    fn invoke_with_args(
146        &self,
147        args: datafusion_expr::ScalarFunctionArgs,
148    ) -> Result<ColumnarValue> {
149        make_scalar_function(array_union_inner)(&args.args)
150    }
151
152    fn aliases(&self) -> &[String] {
153        &self.aliases
154    }
155
156    fn documentation(&self) -> Option<&Documentation> {
157        self.doc()
158    }
159}
160
161#[user_doc(
162    doc_section(label = "Array Functions"),
163    description = "Returns an array of elements in the intersection of array1 and array2.",
164    syntax_example = "array_intersect(array1, array2)",
165    sql_example = r#"```sql
166> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);
167+----------------------------------------------------+
168| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);       |
169+----------------------------------------------------+
170| [3, 4]                                             |
171+----------------------------------------------------+
172> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);
173+----------------------------------------------------+
174| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);       |
175+----------------------------------------------------+
176| []                                                 |
177+----------------------------------------------------+
178```"#,
179    argument(
180        name = "array1",
181        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
182    ),
183    argument(
184        name = "array2",
185        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
186    )
187)]
188#[derive(Debug, PartialEq, Eq, Hash)]
189pub(super) struct ArrayIntersect {
190    signature: Signature,
191    aliases: Vec<String>,
192}
193
194impl ArrayIntersect {
195    pub fn new() -> Self {
196        Self {
197            signature: Signature::arrays(
198                2,
199                Some(ListCoercion::FixedSizedListToList),
200                Volatility::Immutable,
201            ),
202            aliases: vec![String::from("list_intersect")],
203        }
204    }
205}
206
207impl ScalarUDFImpl for ArrayIntersect {
208    fn as_any(&self) -> &dyn Any {
209        self
210    }
211
212    fn name(&self) -> &str {
213        "array_intersect"
214    }
215
216    fn signature(&self) -> &Signature {
217        &self.signature
218    }
219
220    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
221        let [array1, array2] = take_function_args(self.name(), arg_types)?;
222        match (array1, array2) {
223            (Null, Null) => Ok(DataType::new_list(Null, true)),
224            (Null, dt) => Ok(dt.clone()),
225            (dt, Null) => Ok(dt.clone()),
226            (dt, _) => Ok(dt.clone()),
227        }
228    }
229
230    fn invoke_with_args(
231        &self,
232        args: datafusion_expr::ScalarFunctionArgs,
233    ) -> Result<ColumnarValue> {
234        make_scalar_function(array_intersect_inner)(&args.args)
235    }
236
237    fn aliases(&self) -> &[String] {
238        &self.aliases
239    }
240
241    fn documentation(&self) -> Option<&Documentation> {
242        self.doc()
243    }
244}
245
246#[user_doc(
247    doc_section(label = "Array Functions"),
248    description = "Returns distinct values from the array after removing duplicates.",
249    syntax_example = "array_distinct(array)",
250    sql_example = r#"```sql
251> select array_distinct([1, 3, 2, 3, 1, 2, 4]);
252+---------------------------------+
253| array_distinct(List([1,2,3,4])) |
254+---------------------------------+
255| [1, 2, 3, 4]                    |
256+---------------------------------+
257```"#,
258    argument(
259        name = "array",
260        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
261    )
262)]
263#[derive(Debug, PartialEq, Eq, Hash)]
264pub(super) struct ArrayDistinct {
265    signature: Signature,
266    aliases: Vec<String>,
267}
268
269impl ArrayDistinct {
270    pub fn new() -> Self {
271        Self {
272            signature: Signature::array(Volatility::Immutable),
273            aliases: vec!["list_distinct".to_string()],
274        }
275    }
276}
277
278impl ScalarUDFImpl for ArrayDistinct {
279    fn as_any(&self) -> &dyn Any {
280        self
281    }
282
283    fn name(&self) -> &str {
284        "array_distinct"
285    }
286
287    fn signature(&self) -> &Signature {
288        &self.signature
289    }
290
291    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
292        Ok(arg_types[0].clone())
293    }
294
295    fn invoke_with_args(
296        &self,
297        args: datafusion_expr::ScalarFunctionArgs,
298    ) -> Result<ColumnarValue> {
299        make_scalar_function(array_distinct_inner)(&args.args)
300    }
301
302    fn aliases(&self) -> &[String] {
303        &self.aliases
304    }
305
306    fn documentation(&self) -> Option<&Documentation> {
307        self.doc()
308    }
309}
310
311/// array_distinct SQL function
312/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
313fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
314    let [array] = take_function_args("array_distinct", args)?;
315    match array.data_type() {
316        Null => Ok(Arc::clone(array)),
317        List(field) => {
318            let array = as_list_array(&array)?;
319            general_array_distinct(array, field)
320        }
321        LargeList(field) => {
322            let array = as_large_list_array(&array)?;
323            general_array_distinct(array, field)
324        }
325        arg_type => exec_err!("array_distinct does not support type {arg_type}"),
326    }
327}
328
/// The set operation applied by `general_set_op` / `generic_set_lists`.
///
/// `Display` renders the SQL function name, which is used in error messages.
#[derive(Clone, Copy, Debug, PartialEq)]
enum SetOp {
    /// Distinct elements present in either input list.
    Union,
    /// Distinct elements present in both input lists.
    Intersect,
}

impl Display for SetOp {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let function_name = match self {
            Self::Union => "array_union",
            Self::Intersect => "array_intersect",
        };
        f.write_str(function_name)
    }
}
343
344fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
345    l: &GenericListArray<OffsetSize>,
346    r: &GenericListArray<OffsetSize>,
347    field: Arc<Field>,
348    set_op: SetOp,
349) -> Result<ArrayRef> {
350    if l.is_empty() || l.value_type().is_null() {
351        let field = Arc::new(Field::new_list_field(r.value_type(), true));
352        return general_array_distinct::<OffsetSize>(r, &field);
353    } else if r.is_empty() || r.value_type().is_null() {
354        let field = Arc::new(Field::new_list_field(l.value_type(), true));
355        return general_array_distinct::<OffsetSize>(l, &field);
356    }
357
358    assert_eq_or_internal_err!(
359        l.value_type(),
360        r.value_type(),
361        "{set_op:?} is not implemented for '{l:?}' and '{r:?}'"
362    );
363
364    let mut offsets = vec![OffsetSize::usize_as(0)];
365    let mut new_arrays = vec![];
366    let mut new_null_buf = vec![];
367    let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
368    for (first_arr, second_arr) in l.iter().zip(r.iter()) {
369        let mut ele_should_be_null = false;
370
371        let l_values = if let Some(first_arr) = first_arr {
372            converter.convert_columns(&[first_arr])?
373        } else {
374            ele_should_be_null = true;
375            converter.empty_rows(0, 0)
376        };
377
378        let r_values = if let Some(second_arr) = second_arr {
379            converter.convert_columns(&[second_arr])?
380        } else {
381            ele_should_be_null = true;
382            converter.empty_rows(0, 0)
383        };
384
385        let l_iter = l_values.iter().sorted().dedup();
386        let values_set: HashSet<_> = l_iter.clone().collect();
387        let mut rows = if set_op == SetOp::Union {
388            l_iter.collect()
389        } else {
390            vec![]
391        };
392
393        for r_val in r_values.iter().sorted().dedup() {
394            match set_op {
395                SetOp::Union => {
396                    if !values_set.contains(&r_val) {
397                        rows.push(r_val);
398                    }
399                }
400                SetOp::Intersect => {
401                    if values_set.contains(&r_val) {
402                        rows.push(r_val);
403                    }
404                }
405            }
406        }
407
408        let last_offset = match offsets.last() {
409            Some(offset) => *offset,
410            None => return internal_err!("offsets should not be empty"),
411        };
412
413        offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
414        let arrays = converter.convert_rows(rows)?;
415        let array = match arrays.first() {
416            Some(array) => Arc::clone(array),
417            None => {
418                return internal_err!("{set_op}: failed to get array from rows");
419            }
420        };
421
422        new_null_buf.push(!ele_should_be_null);
423        new_arrays.push(array);
424    }
425
426    let offsets = OffsetBuffer::new(offsets.into());
427    let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect();
428    let values = compute::concat(&new_arrays_ref)?;
429    let arr = GenericListArray::<OffsetSize>::try_new(
430        field,
431        offsets,
432        values,
433        Some(NullBuffer::new(new_null_buf.into())),
434    )?;
435    Ok(Arc::new(arr))
436}
437
438fn general_set_op(
439    array1: &ArrayRef,
440    array2: &ArrayRef,
441    set_op: SetOp,
442) -> Result<ArrayRef> {
443    fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result<ArrayRef> {
444        let field = Arc::new(Field::new_list_field(data_type.clone(), true));
445        let values = new_null_array(data_type, len);
446        if large {
447            Ok(Arc::new(LargeListArray::try_new(
448                field,
449                OffsetBuffer::new_zeroed(len),
450                values,
451                None,
452            )?))
453        } else {
454            Ok(Arc::new(ListArray::try_new(
455                field,
456                OffsetBuffer::new_zeroed(len),
457                values,
458                None,
459            )?))
460        }
461    }
462
463    match (array1.data_type(), array2.data_type()) {
464        (Null, Null) => Ok(Arc::new(ListArray::new_null(
465            Arc::new(Field::new_list_field(Null, true)),
466            array1.len(),
467        ))),
468        (Null, List(field)) => {
469            if set_op == SetOp::Intersect {
470                return empty_array(field.data_type(), array1.len(), false);
471            }
472            let array = as_list_array(&array2)?;
473            general_array_distinct::<i32>(array, field)
474        }
475        (List(field), Null) => {
476            if set_op == SetOp::Intersect {
477                return empty_array(field.data_type(), array1.len(), false);
478            }
479            let array = as_list_array(&array1)?;
480            general_array_distinct::<i32>(array, field)
481        }
482        (Null, LargeList(field)) => {
483            if set_op == SetOp::Intersect {
484                return empty_array(field.data_type(), array1.len(), true);
485            }
486            let array = as_large_list_array(&array2)?;
487            general_array_distinct::<i64>(array, field)
488        }
489        (LargeList(field), Null) => {
490            if set_op == SetOp::Intersect {
491                return empty_array(field.data_type(), array1.len(), true);
492            }
493            let array = as_large_list_array(&array1)?;
494            general_array_distinct::<i64>(array, field)
495        }
496        (List(field), List(_)) => {
497            let array1 = as_list_array(&array1)?;
498            let array2 = as_list_array(&array2)?;
499            generic_set_lists::<i32>(array1, array2, Arc::clone(field), set_op)
500        }
501        (LargeList(field), LargeList(_)) => {
502            let array1 = as_large_list_array(&array1)?;
503            let array2 = as_large_list_array(&array2)?;
504            generic_set_lists::<i64>(array1, array2, Arc::clone(field), set_op)
505        }
506        (data_type1, data_type2) => {
507            internal_err!(
508                "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
509            )
510        }
511    }
512}
513
514fn array_union_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
515    let [array1, array2] = take_function_args("array_union", args)?;
516    general_set_op(array1, array2, SetOp::Union)
517}
518
519fn array_intersect_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
520    let [array1, array2] = take_function_args("array_intersect", args)?;
521    general_set_op(array1, array2, SetOp::Intersect)
522}
523
524fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
525    array: &GenericListArray<OffsetSize>,
526    field: &FieldRef,
527) -> Result<ArrayRef> {
528    if array.is_empty() {
529        return Ok(Arc::new(array.clone()) as ArrayRef);
530    }
531    let dt = array.value_type();
532    let mut offsets = Vec::with_capacity(array.len());
533    offsets.push(OffsetSize::usize_as(0));
534    let mut new_arrays = Vec::with_capacity(array.len());
535    let converter = RowConverter::new(vec![SortField::new(dt)])?;
536    // distinct for each list in ListArray
537    for arr in array.iter() {
538        let last_offset: OffsetSize = offsets.last().copied().unwrap();
539        let Some(arr) = arr else {
540            // Add same offset for null
541            offsets.push(last_offset);
542            continue;
543        };
544        let values = converter.convert_columns(&[arr])?;
545        // sort elements in list and remove duplicates
546        let rows = values.iter().sorted().dedup().collect::<Vec<_>>();
547        offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
548        let arrays = converter.convert_rows(rows)?;
549        let array = match arrays.first() {
550            Some(array) => Arc::clone(array),
551            None => {
552                return internal_err!("array_distinct: failed to get array from rows");
553            }
554        };
555        new_arrays.push(array);
556    }
557    if new_arrays.is_empty() {
558        return Ok(Arc::new(array.clone()) as ArrayRef);
559    }
560    let offsets = OffsetBuffer::new(offsets.into());
561    let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
562    let values = compute::concat(&new_arrays_ref)?;
563    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
564        Arc::clone(field),
565        offsets,
566        values,
567        // Keep the list nulls
568        array.nulls().cloned(),
569    )?))
570}
571
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Array, Int32Array, ListArray},
        buffer::OffsetBuffer,
        datatypes::{DataType, Field},
    };
    use datafusion_common::{DataFusionError, config::ConfigOptions};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

    use crate::set_ops::array_distinct_udf;

    /// `array_distinct` must return exactly the type announced by `return_type`
    /// for both nullable and non-nullable element fields, and must actually
    /// remove the duplicates.
    #[test]
    fn test_array_distinct_inner_nullability_result_type_match_return_type()
    -> Result<(), DataFusionError> {
        let udf = array_distinct_udf();

        for inner_nullable in [true, false] {
            let inner_field = Field::new_list_field(DataType::Int32, inner_nullable);
            let input_field =
                Field::new_list("input", Arc::new(inner_field.clone()), true);

            // [[1, 1, 2]]
            let input_array = ListArray::new(
                inner_field.into(),
                OffsetBuffer::new(vec![0, 3].into()),
                Arc::new(Int32Array::new(vec![1, 1, 2].into(), None)),
                None,
            );

            let input_array = ColumnarValue::Array(Arc::new(input_array));

            let result = udf.invoke_with_args(ScalarFunctionArgs {
                args: vec![input_array],
                arg_fields: vec![input_field.clone().into()],
                number_rows: 1,
                return_field: input_field.clone().into(),
                config_options: Arc::new(ConfigOptions::default()),
            })?;

            assert_eq!(
                result.data_type(),
                udf.return_type(&[input_field.data_type().clone()])?
            );

            // [[1, 1, 2]] -> [[1, 2]]
            let ColumnarValue::Array(result_array) = result else {
                panic!("array_distinct on an array input should return an array");
            };
            let list = result_array
                .as_any()
                .downcast_ref::<ListArray>()
                .expect("array_distinct on a List input should return a ListArray");
            assert_eq!(list.len(), 1);
            let distinct = list.value(0);
            let distinct = distinct
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("element type should stay Int32");
            assert_eq!(distinct, &Int32Array::from(vec![1, 2]));
        }
        Ok(())
    }
}