Skip to main content

datafusion_functions_nested/
set_ops.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, GenericListArray, OffsetSizeTrait, new_empty_array, new_null_array,
23};
24use arrow::buffer::{NullBuffer, OffsetBuffer};
25use arrow::datatypes::DataType::{LargeList, List, Null};
26use arrow::datatypes::{DataType, Field, FieldRef};
27use arrow::row::{RowConverter, SortField};
28use datafusion_common::cast::{as_large_list_array, as_list_array};
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{
31    Result, assert_eq_or_internal_err, exec_err, internal_err, utils::take_function_args,
32};
33use datafusion_expr::{
34    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
35};
36use datafusion_macros::user_doc;
37use hashbrown::HashSet;
38use std::any::Any;
39use std::fmt::{Display, Formatter};
40use std::sync::Arc;
41
42// Create static instances of ScalarUDFs for each function
43make_udf_expr_and_func!(
44    ArrayUnion,
45    array_union,
46    array1 array2,
47    "returns an array of the elements in the union of array1 and array2 without duplicates.",
48    array_union_udf
49);
50
51make_udf_expr_and_func!(
52    ArrayIntersect,
53    array_intersect,
54    first_array second_array,
55    "returns an array of the elements in the intersection of array1 and array2.",
56    array_intersect_udf
57);
58
59make_udf_expr_and_func!(
60    ArrayDistinct,
61    array_distinct,
62    array,
63    "returns distinct values from the array after removing duplicates.",
64    array_distinct_udf
65);
66
67#[user_doc(
68    doc_section(label = "Array Functions"),
69    description = "Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.",
70    syntax_example = "array_union(array1, array2)",
71    sql_example = r#"```sql
72> select array_union([1, 2, 3, 4], [5, 6, 3, 4]);
73+----------------------------------------------------+
74| array_union([1, 2, 3, 4], [5, 6, 3, 4]);           |
75+----------------------------------------------------+
76| [1, 2, 3, 4, 5, 6]                                 |
77+----------------------------------------------------+
78> select array_union([1, 2, 3, 4], [5, 6, 7, 8]);
79+----------------------------------------------------+
80| array_union([1, 2, 3, 4], [5, 6, 7, 8]);           |
81+----------------------------------------------------+
82| [1, 2, 3, 4, 5, 6, 7, 8]                           |
83+----------------------------------------------------+
84```"#,
85    argument(
86        name = "array1",
87        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
88    ),
89    argument(
90        name = "array2",
91        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
92    )
93)]
94#[derive(Debug, PartialEq, Eq, Hash)]
95pub struct ArrayUnion {
96    signature: Signature,
97    aliases: Vec<String>,
98}
99
100impl Default for ArrayUnion {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106impl ArrayUnion {
107    pub fn new() -> Self {
108        Self {
109            signature: Signature::arrays(
110                2,
111                Some(ListCoercion::FixedSizedListToList),
112                Volatility::Immutable,
113            ),
114            aliases: vec![String::from("list_union")],
115        }
116    }
117}
118
119impl ScalarUDFImpl for ArrayUnion {
120    fn as_any(&self) -> &dyn Any {
121        self
122    }
123
124    fn name(&self) -> &str {
125        "array_union"
126    }
127
128    fn signature(&self) -> &Signature {
129        &self.signature
130    }
131
132    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
133        let [array1, array2] = take_function_args(self.name(), arg_types)?;
134        match (array1, array2) {
135            (Null, Null) => Ok(DataType::new_list(Null, true)),
136            (Null, dt) | (dt, Null) => Ok(dt.clone()),
137            (dt, _) => Ok(dt.clone()),
138        }
139    }
140
141    fn invoke_with_args(
142        &self,
143        args: datafusion_expr::ScalarFunctionArgs,
144    ) -> Result<ColumnarValue> {
145        make_scalar_function(array_union_inner)(&args.args)
146    }
147
148    fn aliases(&self) -> &[String] {
149        &self.aliases
150    }
151
152    fn documentation(&self) -> Option<&Documentation> {
153        self.doc()
154    }
155}
156
157#[user_doc(
158    doc_section(label = "Array Functions"),
159    description = "Returns an array of elements in the intersection of array1 and array2.",
160    syntax_example = "array_intersect(array1, array2)",
161    sql_example = r#"```sql
162> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);
163+----------------------------------------------------+
164| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]);       |
165+----------------------------------------------------+
166| [3, 4]                                             |
167+----------------------------------------------------+
168> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);
169+----------------------------------------------------+
170| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]);       |
171+----------------------------------------------------+
172| []                                                 |
173+----------------------------------------------------+
174```"#,
175    argument(
176        name = "array1",
177        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
178    ),
179    argument(
180        name = "array2",
181        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
182    )
183)]
184#[derive(Debug, PartialEq, Eq, Hash)]
185pub struct ArrayIntersect {
186    signature: Signature,
187    aliases: Vec<String>,
188}
189
190impl Default for ArrayIntersect {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
196impl ArrayIntersect {
197    pub fn new() -> Self {
198        Self {
199            signature: Signature::arrays(
200                2,
201                Some(ListCoercion::FixedSizedListToList),
202                Volatility::Immutable,
203            ),
204            aliases: vec![String::from("list_intersect")],
205        }
206    }
207}
208
209impl ScalarUDFImpl for ArrayIntersect {
210    fn as_any(&self) -> &dyn Any {
211        self
212    }
213
214    fn name(&self) -> &str {
215        "array_intersect"
216    }
217
218    fn signature(&self) -> &Signature {
219        &self.signature
220    }
221
222    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
223        let [array1, array2] = take_function_args(self.name(), arg_types)?;
224        match (array1, array2) {
225            (Null, Null) => Ok(DataType::new_list(Null, true)),
226            (Null, dt) | (dt, Null) => Ok(dt.clone()),
227            (dt, _) => Ok(dt.clone()),
228        }
229    }
230
231    fn invoke_with_args(
232        &self,
233        args: datafusion_expr::ScalarFunctionArgs,
234    ) -> Result<ColumnarValue> {
235        make_scalar_function(array_intersect_inner)(&args.args)
236    }
237
238    fn aliases(&self) -> &[String] {
239        &self.aliases
240    }
241
242    fn documentation(&self) -> Option<&Documentation> {
243        self.doc()
244    }
245}
246
247#[user_doc(
248    doc_section(label = "Array Functions"),
249    description = "Returns distinct values from the array after removing duplicates.",
250    syntax_example = "array_distinct(array)",
251    sql_example = r#"```sql
252> select array_distinct([1, 3, 2, 3, 1, 2, 4]);
253+---------------------------------+
254| array_distinct(List([1,2,3,4])) |
255+---------------------------------+
256| [1, 2, 3, 4]                    |
257+---------------------------------+
258```"#,
259    argument(
260        name = "array",
261        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
262    )
263)]
264#[derive(Debug, PartialEq, Eq, Hash)]
265pub struct ArrayDistinct {
266    signature: Signature,
267    aliases: Vec<String>,
268}
269
270impl ArrayDistinct {
271    pub fn new() -> Self {
272        Self {
273            signature: Signature::array(Volatility::Immutable),
274            aliases: vec!["list_distinct".to_string()],
275        }
276    }
277}
278
279impl Default for ArrayDistinct {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
285impl ScalarUDFImpl for ArrayDistinct {
286    fn as_any(&self) -> &dyn Any {
287        self
288    }
289
290    fn name(&self) -> &str {
291        "array_distinct"
292    }
293
294    fn signature(&self) -> &Signature {
295        &self.signature
296    }
297
298    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
299        Ok(arg_types[0].clone())
300    }
301
302    fn invoke_with_args(
303        &self,
304        args: datafusion_expr::ScalarFunctionArgs,
305    ) -> Result<ColumnarValue> {
306        make_scalar_function(array_distinct_inner)(&args.args)
307    }
308
309    fn aliases(&self) -> &[String] {
310        &self.aliases
311    }
312
313    fn documentation(&self) -> Option<&Documentation> {
314        self.doc()
315    }
316}
317
318/// array_distinct SQL function
319/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
320fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
321    let [array] = take_function_args("array_distinct", args)?;
322    match array.data_type() {
323        Null => Ok(Arc::clone(array)),
324        List(field) => {
325            let array = as_list_array(&array)?;
326            general_array_distinct(array, field)
327        }
328        LargeList(field) => {
329            let array = as_large_list_array(&array)?;
330            general_array_distinct(array, field)
331        }
332        arg_type => exec_err!("array_distinct does not support type {arg_type}"),
333    }
334}
335
/// The two binary set operations implemented in this module.
#[derive(Clone, Copy, Debug, PartialEq)]
enum SetOp {
    /// Distinct elements found in either list (`array_union`).
    Union,
    /// Distinct elements found in both lists (`array_intersect`).
    Intersect,
}

/// Formats the operation as the SQL function name it implements, so error
/// messages name the user-visible function.
impl Display for SetOp {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let sql_name = match self {
            Self::Union => "array_union",
            Self::Intersect => "array_intersect",
        };
        f.write_str(sql_name)
    }
}
350
351fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
352    l: &GenericListArray<OffsetSize>,
353    r: &GenericListArray<OffsetSize>,
354    field: Arc<Field>,
355    set_op: SetOp,
356) -> Result<ArrayRef> {
357    if l.is_empty() || l.value_type().is_null() {
358        let field = Arc::new(Field::new_list_field(r.value_type(), true));
359        return general_array_distinct::<OffsetSize>(r, &field);
360    } else if r.is_empty() || r.value_type().is_null() {
361        let field = Arc::new(Field::new_list_field(l.value_type(), true));
362        return general_array_distinct::<OffsetSize>(l, &field);
363    }
364
365    assert_eq_or_internal_err!(
366        l.value_type(),
367        r.value_type(),
368        "{set_op:?} is not implemented for '{l:?}' and '{r:?}'"
369    );
370
371    // Convert all values to rows in batch for performance.
372    let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
373    let rows_l = converter.convert_columns(&[Arc::clone(l.values())])?;
374    let rows_r = converter.convert_columns(&[Arc::clone(r.values())])?;
375
376    match set_op {
377        SetOp::Union => generic_set_loop::<OffsetSize, true>(
378            l, r, &rows_l, &rows_r, field, &converter,
379        ),
380        SetOp::Intersect => generic_set_loop::<OffsetSize, false>(
381            l, r, &rows_l, &rows_r, field, &converter,
382        ),
383    }
384}
385
386/// Inner loop for set operations, parameterized by const generic to
387/// avoid branching inside the hot loop.
388fn generic_set_loop<OffsetSize: OffsetSizeTrait, const IS_UNION: bool>(
389    l: &GenericListArray<OffsetSize>,
390    r: &GenericListArray<OffsetSize>,
391    rows_l: &arrow::row::Rows,
392    rows_r: &arrow::row::Rows,
393    field: Arc<Field>,
394    converter: &RowConverter,
395) -> Result<ArrayRef> {
396    let l_offsets = l.value_offsets();
397    let r_offsets = r.value_offsets();
398
399    let mut result_offsets = Vec::with_capacity(l.len() + 1);
400    result_offsets.push(OffsetSize::usize_as(0));
401    let initial_capacity = if IS_UNION {
402        // Union can include all elements from both sides
403        rows_l.num_rows()
404    } else {
405        // Intersect result is bounded by the smaller side
406        rows_l.num_rows().min(rows_r.num_rows())
407    };
408
409    let mut final_rows = Vec::with_capacity(initial_capacity);
410
411    // Reuse hash sets across iterations
412    let mut seen = HashSet::new();
413    let mut lookup_set = HashSet::new();
414    for i in 0..l.len() {
415        let last_offset = *result_offsets.last().unwrap();
416
417        if l.is_null(i) || r.is_null(i) {
418            result_offsets.push(last_offset);
419            continue;
420        }
421
422        let l_start = l_offsets[i].as_usize();
423        let l_end = l_offsets[i + 1].as_usize();
424        let r_start = r_offsets[i].as_usize();
425        let r_end = r_offsets[i + 1].as_usize();
426
427        seen.clear();
428
429        if IS_UNION {
430            for idx in l_start..l_end {
431                let row = rows_l.row(idx);
432                if seen.insert(row) {
433                    final_rows.push(row);
434                }
435            }
436            for idx in r_start..r_end {
437                let row = rows_r.row(idx);
438                if seen.insert(row) {
439                    final_rows.push(row);
440                }
441            }
442        } else {
443            let l_len = l_end - l_start;
444            let r_len = r_end - r_start;
445
446            // Select shorter side for lookup, longer side for probing
447            let (lookup_rows, lookup_range, probe_rows, probe_range) = if l_len < r_len {
448                (rows_l, l_start..l_end, rows_r, r_start..r_end)
449            } else {
450                (rows_r, r_start..r_end, rows_l, l_start..l_end)
451            };
452            lookup_set.clear();
453            lookup_set.reserve(lookup_range.len());
454
455            // Build lookup table
456            for idx in lookup_range {
457                lookup_set.insert(lookup_rows.row(idx));
458            }
459
460            // Probe and emit distinct intersected rows
461            for idx in probe_range {
462                let row = probe_rows.row(idx);
463                if lookup_set.contains(&row) && seen.insert(row) {
464                    final_rows.push(row);
465                }
466            }
467        }
468        result_offsets.push(last_offset + OffsetSize::usize_as(seen.len()));
469    }
470
471    let final_values = if final_rows.is_empty() {
472        new_empty_array(&l.value_type())
473    } else {
474        let arrays = converter.convert_rows(final_rows)?;
475        Arc::clone(&arrays[0])
476    };
477
478    let arr = GenericListArray::<OffsetSize>::try_new(
479        field,
480        OffsetBuffer::new(result_offsets.into()),
481        final_values,
482        NullBuffer::union(l.nulls(), r.nulls()),
483    )?;
484    Ok(Arc::new(arr))
485}
486
487fn general_set_op(
488    array1: &ArrayRef,
489    array2: &ArrayRef,
490    set_op: SetOp,
491) -> Result<ArrayRef> {
492    let len = array1.len();
493    match (array1.data_type(), array2.data_type()) {
494        (Null, Null) => Ok(new_null_array(&DataType::new_list(Null, true), len)),
495        (Null, dt @ List(_))
496        | (Null, dt @ LargeList(_))
497        | (dt @ List(_), Null)
498        | (dt @ LargeList(_), Null) => Ok(new_null_array(dt, len)),
499        (List(field), List(_)) => {
500            let array1 = as_list_array(&array1)?;
501            let array2 = as_list_array(&array2)?;
502            generic_set_lists::<i32>(array1, array2, Arc::clone(field), set_op)
503        }
504        (LargeList(field), LargeList(_)) => {
505            let array1 = as_large_list_array(&array1)?;
506            let array2 = as_large_list_array(&array2)?;
507            generic_set_lists::<i64>(array1, array2, Arc::clone(field), set_op)
508        }
509        (data_type1, data_type2) => {
510            internal_err!(
511                "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
512            )
513        }
514    }
515}
516
517fn array_union_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
518    let [array1, array2] = take_function_args("array_union", args)?;
519    general_set_op(array1, array2, SetOp::Union)
520}
521
522fn array_intersect_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
523    let [array1, array2] = take_function_args("array_intersect", args)?;
524    general_set_op(array1, array2, SetOp::Intersect)
525}
526
527fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
528    array: &GenericListArray<OffsetSize>,
529    field: &FieldRef,
530) -> Result<ArrayRef> {
531    if array.is_empty() {
532        return Ok(Arc::new(array.clone()) as ArrayRef);
533    }
534    let value_offsets = array.value_offsets();
535    let dt = array.value_type();
536    let mut offsets = Vec::with_capacity(array.len() + 1);
537    offsets.push(OffsetSize::usize_as(0));
538
539    // Convert all values to row format in a single batch for performance
540    let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
541    let rows = converter.convert_columns(&[Arc::clone(array.values())])?;
542    let mut final_rows = Vec::with_capacity(rows.num_rows());
543    let mut seen = HashSet::new();
544    for i in 0..array.len() {
545        let last_offset = *offsets.last().unwrap();
546
547        // Null list entries produce no output; just carry forward the offset.
548        if array.is_null(i) {
549            offsets.push(last_offset);
550            continue;
551        }
552
553        let start = value_offsets[i].as_usize();
554        let end = value_offsets[i + 1].as_usize();
555        seen.clear();
556        seen.reserve(end - start);
557
558        // Walk the sub-array and keep only the first occurrence of each value.
559        for idx in start..end {
560            let row = rows.row(idx);
561            if seen.insert(row) {
562                final_rows.push(row);
563            }
564        }
565        offsets.push(last_offset + OffsetSize::usize_as(seen.len()));
566    }
567
568    // Convert all collected distinct rows back
569    let final_values = if final_rows.is_empty() {
570        new_empty_array(&dt)
571    } else {
572        let arrays = converter.convert_rows(final_rows)?;
573        Arc::clone(&arrays[0])
574    };
575
576    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
577        Arc::clone(field),
578        OffsetBuffer::new(offsets.into()),
579        final_values,
580        // Keep the list nulls
581        array.nulls().cloned(),
582    )?))
583}
584
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, ListArray},
        buffer::OffsetBuffer,
        datatypes::{DataType, Field},
    };
    use datafusion_common::{DataFusionError, config::ConfigOptions};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

    use crate::set_ops::array_distinct_udf;

    /// `array_distinct` must produce exactly the type reported by
    /// `return_type`, whether or not the list's item field is nullable.
    #[test]
    fn test_array_distinct_inner_nullability_result_type_match_return_type()
    -> Result<(), DataFusionError> {
        let udf = array_distinct_udf();

        for inner_nullable in [true, false] {
            let item = Field::new_list_field(DataType::Int32, inner_nullable);
            let list_field = Field::new_list("input", Arc::new(item.clone()), true);

            // A single row holding [1, 1, 2].
            let list = ListArray::new(
                item.into(),
                OffsetBuffer::new(vec![0, 3].into()),
                Arc::new(Int32Array::new(vec![1, 1, 2].into(), None)),
                None,
            );

            let call_args = ScalarFunctionArgs {
                args: vec![ColumnarValue::Array(Arc::new(list))],
                arg_fields: vec![list_field.clone().into()],
                number_rows: 1,
                return_field: list_field.clone().into(),
                config_options: Arc::new(ConfigOptions::default()),
            };
            let result = udf.invoke_with_args(call_args)?;

            let expected = udf.return_type(&[list_field.data_type().clone()])?;
            assert_eq!(result.data_type(), expected);
        }
        Ok(())
    }
}