datafusion_functions_nested/
set_ops.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    new_null_array, Array, ArrayRef, GenericListArray, LargeListArray, ListArray,
23    OffsetSizeTrait,
24};
25use arrow::buffer::OffsetBuffer;
26use arrow::compute;
27use arrow::datatypes::DataType::{LargeList, List, Null};
28use arrow::datatypes::{DataType, Field, FieldRef};
29use arrow::row::{RowConverter, SortField};
30use datafusion_common::cast::{as_large_list_array, as_list_array};
31use datafusion_common::utils::ListCoercion;
32use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result};
33use datafusion_expr::{
34    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
35};
36use datafusion_macros::user_doc;
37use itertools::Itertools;
38use std::any::Any;
39use std::collections::HashSet;
40use std::fmt::{Display, Formatter};
41use std::sync::Arc;
42
// Create static instances of ScalarUDFs for each function
//
// NOTE(review): `make_udf_expr_and_func!` is defined outside this file. Judging
// by its arguments, it takes the UDF struct, an expression-builder name, the
// builder's argument names, a doc string for the builder, and the name of the
// UDF accessor (the tests below call `array_distinct_udf()`). Confirm against
// the macro definition.
make_udf_expr_and_func!(
    ArrayUnion,
    array_union,
    array1 array2,
    "returns an array of the elements in the union of array1 and array2 without duplicates.",
    array_union_udf
);
51
// `array_intersect` builder and UDF accessor.
// NOTE(review): the builder's arguments are named `first_array` / `second_array`,
// while the doc string below refers to them as `array1` / `array2`.
make_udf_expr_and_func!(
    ArrayIntersect,
    array_intersect,
    first_array second_array,
    "returns an array of the elements in the intersection of array1 and array2.",
    array_intersect_udf
);
59
// `array_distinct` builder and UDF accessor; the builder takes a single `array`
// argument.
make_udf_expr_and_func!(
    ArrayDistinct,
    array_distinct,
    array,
    "returns distinct values from the array after removing duplicates.",
    array_distinct_udf
);
67
68#[user_doc(
69    doc_section(label = "Array Functions"),
70    description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.",
71    syntax_example = "array_union(array1, array2)",
72    sql_example = r#"```sql
73> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
74+----------------------------------------------------+
75| array_union([1, 2, 3, 4], [5, 6, 3, 4]);           |
76+----------------------------------------------------+
77| [1, 2, 3, 4, 5, 6]                                 |
78+----------------------------------------------------+
79> select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
80+----------------------------------------------------+
81| array_union([1, 2, 3, 4], [5, 6, 7, 8]);           |
82+----------------------------------------------------+
83| [1, 2, 3, 4, 5, 6, 7, 8]                           |
84+----------------------------------------------------+
85```"#,
86    argument(
87        name = "array1",
88        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
89    ),
90    argument(
91        name = "array2",
92        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
93    )
94)]
95#[derive(Debug, PartialEq, Eq, Hash)]
96pub struct ArrayUnion {
97    signature: Signature,
98    aliases: Vec<String>,
99}
100
101impl Default for ArrayUnion {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl ArrayUnion {
108    pub fn new() -> Self {
109        Self {
110            signature: Signature::arrays(
111                2,
112                Some(ListCoercion::FixedSizedListToList),
113                Volatility::Immutable,
114            ),
115            aliases: vec![String::from("list_union")],
116        }
117    }
118}
119
120impl ScalarUDFImpl for ArrayUnion {
121    fn as_any(&self) -> &dyn Any {
122        self
123    }
124
125    fn name(&self) -> &str {
126        "array_union"
127    }
128
129    fn signature(&self) -> &Signature {
130        &self.signature
131    }
132
133    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
134        let [array1, array2] = take_function_args(self.name(), arg_types)?;
135        match (array1, array2) {
136            (Null, Null) => Ok(DataType::new_list(Null, true)),
137            (Null, dt) => Ok(dt.clone()),
138            (dt, Null) => Ok(dt.clone()),
139            (dt, _) => Ok(dt.clone()),
140        }
141    }
142
143    fn invoke_with_args(
144        &self,
145        args: datafusion_expr::ScalarFunctionArgs,
146    ) -> Result<ColumnarValue> {
147        make_scalar_function(array_union_inner)(&args.args)
148    }
149
150    fn aliases(&self) -> &[String] {
151        &self.aliases
152    }
153
154    fn documentation(&self) -> Option<&Documentation> {
155        self.doc()
156    }
157}
158
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns an array of elements in the intersection of array1 and array2.",
    syntax_example = "array_intersect(array1, array2)",
    sql_example = r#"```sql
> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);
+----------------------------------------------------+
| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);       |
+----------------------------------------------------+
| [3, 4]                                             |
+----------------------------------------------------+
> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);
+----------------------------------------------------+
| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);       |
+----------------------------------------------------+
| []                                                 |
+----------------------------------------------------+
```"#,
    argument(
        name = "array1",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "array2",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
/// Scalar UDF `array_intersect` (alias `list_intersect`): the elements common to
/// two arrays.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayIntersect {
    /// Two-array signature built in `ArrayIntersect::new`.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
191
192impl ArrayIntersect {
193    pub fn new() -> Self {
194        Self {
195            signature: Signature::arrays(
196                2,
197                Some(ListCoercion::FixedSizedListToList),
198                Volatility::Immutable,
199            ),
200            aliases: vec![String::from("list_intersect")],
201        }
202    }
203}
204
205impl ScalarUDFImpl for ArrayIntersect {
206    fn as_any(&self) -> &dyn Any {
207        self
208    }
209
210    fn name(&self) -> &str {
211        "array_intersect"
212    }
213
214    fn signature(&self) -> &Signature {
215        &self.signature
216    }
217
218    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
219        let [array1, array2] = take_function_args(self.name(), arg_types)?;
220        match (array1, array2) {
221            (Null, Null) => Ok(DataType::new_list(Null, true)),
222            (Null, dt) => Ok(dt.clone()),
223            (dt, Null) => Ok(dt.clone()),
224            (dt, _) => Ok(dt.clone()),
225        }
226    }
227
228    fn invoke_with_args(
229        &self,
230        args: datafusion_expr::ScalarFunctionArgs,
231    ) -> Result<ColumnarValue> {
232        make_scalar_function(array_intersect_inner)(&args.args)
233    }
234
235    fn aliases(&self) -> &[String] {
236        &self.aliases
237    }
238
239    fn documentation(&self) -> Option<&Documentation> {
240        self.doc()
241    }
242}
243
// NOTE(review): the sql_example output header below shows `List([1,2,3,4])`
// rather than the input actually queried (`[1, 3, 2, 3, 1, 2, 4]`).
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns distinct values from the array after removing duplicates.",
    syntax_example = "array_distinct(array)",
    sql_example = r#"```sql
> select array_distinct([1, 3, 2, 3, 1, 2, 4]);
+---------------------------------+
| array_distinct(List([1,2,3,4])) |
+---------------------------------+
| [1, 2, 3, 4]                    |
+---------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
/// Scalar UDF `array_distinct` (alias `list_distinct`): removes duplicate
/// elements from each list.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayDistinct {
    /// Single-array signature built in `ArrayDistinct::new`.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
266
267impl ArrayDistinct {
268    pub fn new() -> Self {
269        Self {
270            signature: Signature::array(Volatility::Immutable),
271            aliases: vec!["list_distinct".to_string()],
272        }
273    }
274}
275
276impl ScalarUDFImpl for ArrayDistinct {
277    fn as_any(&self) -> &dyn Any {
278        self
279    }
280
281    fn name(&self) -> &str {
282        "array_distinct"
283    }
284
285    fn signature(&self) -> &Signature {
286        &self.signature
287    }
288
289    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
290        Ok(arg_types[0].clone())
291    }
292
293    fn invoke_with_args(
294        &self,
295        args: datafusion_expr::ScalarFunctionArgs,
296    ) -> Result<ColumnarValue> {
297        make_scalar_function(array_distinct_inner)(&args.args)
298    }
299
300    fn aliases(&self) -> &[String] {
301        &self.aliases
302    }
303
304    fn documentation(&self) -> Option<&Documentation> {
305        self.doc()
306    }
307}
308
309/// array_distinct SQL function
310/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
311fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
312    let [array] = take_function_args("array_distinct", args)?;
313    match array.data_type() {
314        Null => Ok(Arc::clone(array)),
315        List(field) => {
316            let array = as_list_array(&array)?;
317            general_array_distinct(array, field)
318        }
319        LargeList(field) => {
320            let array = as_large_list_array(&array)?;
321            general_array_distinct(array, field)
322        }
323        arg_type => exec_err!("array_distinct does not support type {arg_type}"),
324    }
325}
326
/// The set operation performed by a two-array set function.
#[derive(Debug, PartialEq)]
enum SetOp {
    /// Implements `array_union`.
    Union,
    /// Implements `array_intersect`.
    Intersect,
}

impl Display for SetOp {
    /// Writes the SQL function name of the operation (used in error messages).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let function_name = match self {
            Self::Union => "array_union",
            Self::Intersect => "array_intersect",
        };
        f.write_str(function_name)
    }
}
341
342fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
343    l: &GenericListArray<OffsetSize>,
344    r: &GenericListArray<OffsetSize>,
345    field: Arc<Field>,
346    set_op: SetOp,
347) -> Result<ArrayRef> {
348    if l.is_empty() || l.value_type().is_null() {
349        let field = Arc::new(Field::new_list_field(r.value_type(), true));
350        return general_array_distinct::<OffsetSize>(r, &field);
351    } else if r.is_empty() || r.value_type().is_null() {
352        let field = Arc::new(Field::new_list_field(l.value_type(), true));
353        return general_array_distinct::<OffsetSize>(l, &field);
354    }
355
356    if l.value_type() != r.value_type() {
357        return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'");
358    }
359
360    let mut offsets = vec![OffsetSize::usize_as(0)];
361    let mut new_arrays = vec![];
362    let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
363    for (first_arr, second_arr) in l.iter().zip(r.iter()) {
364        let l_values = if let Some(first_arr) = first_arr {
365            converter.convert_columns(&[first_arr])?
366        } else {
367            converter.convert_columns(&[])?
368        };
369
370        let r_values = if let Some(second_arr) = second_arr {
371            converter.convert_columns(&[second_arr])?
372        } else {
373            converter.convert_columns(&[])?
374        };
375
376        let l_iter = l_values.iter().sorted().dedup();
377        let values_set: HashSet<_> = l_iter.clone().collect();
378        let mut rows = if set_op == SetOp::Union {
379            l_iter.collect()
380        } else {
381            vec![]
382        };
383
384        for r_val in r_values.iter().sorted().dedup() {
385            match set_op {
386                SetOp::Union => {
387                    if !values_set.contains(&r_val) {
388                        rows.push(r_val);
389                    }
390                }
391                SetOp::Intersect => {
392                    if values_set.contains(&r_val) {
393                        rows.push(r_val);
394                    }
395                }
396            }
397        }
398
399        let last_offset = match offsets.last() {
400            Some(offset) => *offset,
401            None => return internal_err!("offsets should not be empty"),
402        };
403
404        offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
405        let arrays = converter.convert_rows(rows)?;
406        let array = match arrays.first() {
407            Some(array) => Arc::clone(array),
408            None => {
409                return internal_err!("{set_op}: failed to get array from rows");
410            }
411        };
412
413        new_arrays.push(array);
414    }
415
416    let offsets = OffsetBuffer::new(offsets.into());
417    let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect();
418    let values = compute::concat(&new_arrays_ref)?;
419    let arr = GenericListArray::<OffsetSize>::try_new(field, offsets, values, None)?;
420    Ok(Arc::new(arr))
421}
422
423fn general_set_op(
424    array1: &ArrayRef,
425    array2: &ArrayRef,
426    set_op: SetOp,
427) -> Result<ArrayRef> {
428    fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result<ArrayRef> {
429        let field = Arc::new(Field::new_list_field(data_type.clone(), true));
430        let values = new_null_array(data_type, len);
431        if large {
432            Ok(Arc::new(LargeListArray::try_new(
433                field,
434                OffsetBuffer::new_zeroed(len),
435                values,
436                None,
437            )?))
438        } else {
439            Ok(Arc::new(ListArray::try_new(
440                field,
441                OffsetBuffer::new_zeroed(len),
442                values,
443                None,
444            )?))
445        }
446    }
447
448    match (array1.data_type(), array2.data_type()) {
449        (Null, Null) => Ok(Arc::new(ListArray::new_null(
450            Arc::new(Field::new_list_field(Null, true)),
451            array1.len(),
452        ))),
453        (Null, List(field)) => {
454            if set_op == SetOp::Intersect {
455                return empty_array(field.data_type(), array1.len(), false);
456            }
457            let array = as_list_array(&array2)?;
458            general_array_distinct::<i32>(array, field)
459        }
460        (List(field), Null) => {
461            if set_op == SetOp::Intersect {
462                return empty_array(field.data_type(), array1.len(), false);
463            }
464            let array = as_list_array(&array1)?;
465            general_array_distinct::<i32>(array, field)
466        }
467        (Null, LargeList(field)) => {
468            if set_op == SetOp::Intersect {
469                return empty_array(field.data_type(), array1.len(), true);
470            }
471            let array = as_large_list_array(&array2)?;
472            general_array_distinct::<i64>(array, field)
473        }
474        (LargeList(field), Null) => {
475            if set_op == SetOp::Intersect {
476                return empty_array(field.data_type(), array1.len(), true);
477            }
478            let array = as_large_list_array(&array1)?;
479            general_array_distinct::<i64>(array, field)
480        }
481        (List(field), List(_)) => {
482            let array1 = as_list_array(&array1)?;
483            let array2 = as_list_array(&array2)?;
484            generic_set_lists::<i32>(array1, array2, Arc::clone(field), set_op)
485        }
486        (LargeList(field), LargeList(_)) => {
487            let array1 = as_large_list_array(&array1)?;
488            let array2 = as_large_list_array(&array2)?;
489            generic_set_lists::<i64>(array1, array2, Arc::clone(field), set_op)
490        }
491        (data_type1, data_type2) => {
492            internal_err!(
493                "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
494            )
495        }
496    }
497}
498
499/// Array_union SQL function
500fn array_union_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
501    let [array1, array2] = take_function_args("array_union", args)?;
502    general_set_op(array1, array2, SetOp::Union)
503}
504
505/// array_intersect SQL function
506fn array_intersect_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
507    let [array1, array2] = take_function_args("array_intersect", args)?;
508    general_set_op(array1, array2, SetOp::Intersect)
509}
510
511fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
512    array: &GenericListArray<OffsetSize>,
513    field: &FieldRef,
514) -> Result<ArrayRef> {
515    if array.is_empty() {
516        return Ok(Arc::new(array.clone()) as ArrayRef);
517    }
518    let dt = array.value_type();
519    let mut offsets = Vec::with_capacity(array.len());
520    offsets.push(OffsetSize::usize_as(0));
521    let mut new_arrays = Vec::with_capacity(array.len());
522    let converter = RowConverter::new(vec![SortField::new(dt)])?;
523    // distinct for each list in ListArray
524    for arr in array.iter() {
525        let last_offset: OffsetSize = offsets.last().copied().unwrap();
526        let Some(arr) = arr else {
527            // Add same offset for null
528            offsets.push(last_offset);
529            continue;
530        };
531        let values = converter.convert_columns(&[arr])?;
532        // sort elements in list and remove duplicates
533        let rows = values.iter().sorted().dedup().collect::<Vec<_>>();
534        offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
535        let arrays = converter.convert_rows(rows)?;
536        let array = match arrays.first() {
537            Some(array) => Arc::clone(array),
538            None => {
539                return internal_err!("array_distinct: failed to get array from rows")
540            }
541        };
542        new_arrays.push(array);
543    }
544    if new_arrays.is_empty() {
545        return Ok(Arc::new(array.clone()) as ArrayRef);
546    }
547    let offsets = OffsetBuffer::new(offsets.into());
548    let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
549    let values = compute::concat(&new_arrays_ref)?;
550    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
551        Arc::clone(field),
552        offsets,
553        values,
554        // Keep the list nulls
555        array.nulls().cloned(),
556    )?))
557}
558
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, ListArray},
        buffer::OffsetBuffer,
        datatypes::{DataType, Field},
    };
    use datafusion_common::{config::ConfigOptions, DataFusionError};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

    use crate::set_ops::array_distinct_udf;

    /// The array produced by `array_distinct` must have the data type reported
    /// by `return_type`, whether the list's element field is nullable or not.
    #[test]
    fn test_array_distinct_inner_nullability_result_type_match_return_type(
    ) -> Result<(), DataFusionError> {
        let udf = array_distinct_udf();

        for item_nullable in [true, false] {
            let item = Field::new_list_field(DataType::Int32, item_nullable);
            let list_field = Field::new_list("input", Arc::new(item.clone()), true);

            // A single row: [1, 1, 2]
            let elements = Arc::new(Int32Array::new(vec![1, 1, 2].into(), None));
            let list = ListArray::new(
                Arc::new(item),
                OffsetBuffer::new(vec![0, 3].into()),
                elements,
                None,
            );

            let call = ScalarFunctionArgs {
                args: vec![ColumnarValue::Array(Arc::new(list))],
                arg_fields: vec![Arc::new(list_field.clone())],
                number_rows: 1,
                return_field: Arc::new(list_field.clone()),
                config_options: Arc::new(ConfigOptions::default()),
            };
            let output = udf.invoke_with_args(call)?;

            let expected = udf.return_type(&[list_field.data_type().clone()])?;
            assert_eq!(output.data_type(), expected);
        }
        Ok(())
    }
}