datafusion_spark/function/array/
spark_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{any::Any, sync::Arc};
19
20use arrow::array::{
21    make_array, new_null_array, Array, ArrayData, ArrayRef, Capacities, GenericListArray,
22    MutableArrayData, NullArray, OffsetSizeTrait,
23};
24use arrow::buffer::OffsetBuffer;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use datafusion_common::utils::SingleRowListArrayBuilder;
27use datafusion_common::{plan_datafusion_err, plan_err, Result};
28use datafusion_expr::type_coercion::binary::comparison_coercion;
29use datafusion_expr::{
30    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31    TypeSignature, Volatility,
32};
33
34use crate::function::functions_nested_utils::make_scalar_function;
35
36const ARRAY_FIELD_DEFAULT_NAME: &str = "element";
37
38#[derive(Debug, PartialEq, Eq, Hash)]
39pub struct SparkArray {
40    signature: Signature,
41    aliases: Vec<String>,
42}
43
44impl Default for SparkArray {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl SparkArray {
51    pub fn new() -> Self {
52        Self {
53            signature: Signature::one_of(
54                vec![TypeSignature::UserDefined, TypeSignature::Nullary],
55                Volatility::Immutable,
56            ),
57            aliases: vec![String::from("spark_make_array")],
58        }
59    }
60}
61
62impl ScalarUDFImpl for SparkArray {
63    fn as_any(&self) -> &dyn Any {
64        self
65    }
66
67    fn name(&self) -> &str {
68        "array"
69    }
70
71    fn signature(&self) -> &Signature {
72        &self.signature
73    }
74
75    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
76        match arg_types.len() {
77            0 => Ok(empty_array_type()),
78            _ => {
79                let mut expr_type = DataType::Null;
80                for arg_type in arg_types {
81                    if !arg_type.equals_datatype(&DataType::Null) {
82                        expr_type = arg_type.clone();
83                        break;
84                    }
85                }
86
87                if expr_type.is_null() {
88                    expr_type = DataType::Int32;
89                }
90
91                Ok(DataType::List(Arc::new(Field::new_list_field(
92                    expr_type, true,
93                ))))
94            }
95        }
96    }
97
98    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
99        let data_types = args
100            .arg_fields
101            .iter()
102            .map(|f| f.data_type())
103            .cloned()
104            .collect::<Vec<_>>();
105        let return_type = self.return_type(&data_types)?;
106        Ok(Arc::new(Field::new(
107            ARRAY_FIELD_DEFAULT_NAME,
108            return_type,
109            false,
110        )))
111    }
112
113    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
114        let ScalarFunctionArgs { args, .. } = args;
115        make_scalar_function(make_array_inner)(args.as_slice())
116    }
117
118    fn aliases(&self) -> &[String] {
119        &self.aliases
120    }
121
122    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
123        let first_type = arg_types.first().ok_or_else(|| {
124            plan_datafusion_err!("Spark array function requires at least one argument")
125        })?;
126        let new_type =
127            arg_types
128                .iter()
129                .skip(1)
130                .try_fold(first_type.clone(), |acc, x| {
131                    // The coerced types found by `comparison_coercion` are not guaranteed to be
132                    // coercible for the arguments. `comparison_coercion` returns more loose
133                    // types that can be coerced to both `acc` and `x` for comparison purpose.
134                    // See `maybe_data_types` for the actual coercion.
135                    let coerced_type = comparison_coercion(&acc, x);
136                    if let Some(coerced_type) = coerced_type {
137                        Ok(coerced_type)
138                    } else {
139                        plan_err!("Coercion from {acc:?} to {x:?} failed.")
140                    }
141                })?;
142        Ok(vec![new_type; arg_types.len()])
143    }
144}
145
146// Empty array is a special case that is useful for many other array functions
147pub(super) fn empty_array_type() -> DataType {
148    DataType::List(Arc::new(Field::new(
149        ARRAY_FIELD_DEFAULT_NAME,
150        DataType::Int32,
151        true,
152    )))
153}
154
155/// `make_array_inner` is the implementation of the `make_array` function.
156/// Constructs an array using the input `data` as `ArrayRef`.
157/// Returns a reference-counted `Array` instance result.
158pub fn make_array_inner(arrays: &[ArrayRef]) -> Result<ArrayRef> {
159    let mut data_type = DataType::Null;
160    for arg in arrays {
161        let arg_data_type = arg.data_type();
162        if !arg_data_type.equals_datatype(&DataType::Null) {
163            data_type = arg_data_type.clone();
164            break;
165        }
166    }
167
168    match data_type {
169        // Either an empty array or all nulls:
170        DataType::Null => {
171            let length = arrays.iter().map(|a| a.len()).sum();
172            // By default Int32
173            let array = new_null_array(&DataType::Int32, length);
174            Ok(Arc::new(
175                SingleRowListArrayBuilder::new(array)
176                    .with_nullable(true)
177                    .build_list_array(),
178            ))
179        }
180        DataType::LargeList(..) => array_array::<i64>(arrays, data_type),
181        _ => array_array::<i32>(arrays, data_type),
182    }
183}
184
185/// Convert one or more [`ArrayRef`] of the same type into a
186/// `ListArray` or 'LargeListArray' depending on the offset size.
187///
188/// # Example (non nested)
189///
190/// Calling `array(col1, col2)` where col1 and col2 are non nested
191/// would return a single new `ListArray`, where each row was a list
192/// of 2 elements:
193///
194/// ```text
195/// ┌─────────┐   ┌─────────┐           ┌──────────────┐
196/// │ ┌─────┐ │   │ ┌─────┐ │           │ ┌──────────┐ │
197/// │ │  A  │ │   │ │  X  │ │           │ │  [A, X]  │ │
198/// │ ├─────┤ │   │ ├─────┤ │           │ ├──────────┤ │
199/// │ │NULL │ │   │ │  Y  │ │──────────▶│ │[NULL, Y] │ │
200/// │ ├─────┤ │   │ ├─────┤ │           │ ├──────────┤ │
201/// │ │  C  │ │   │ │  Z  │ │           │ │  [C, Z]  │ │
202/// │ └─────┘ │   │ └─────┘ │           │ └──────────┘ │
203/// └─────────┘   └─────────┘           └──────────────┘
204///   col1           col2                    output
205/// ```
206///
207/// # Example (nested)
208///
209/// Calling `array(col1, col2)` where col1 and col2 are lists
210/// would return a single new `ListArray`, where each row was a list
211/// of the corresponding elements of col1 and col2.
212///
213/// ``` text
214/// ┌──────────────┐   ┌──────────────┐        ┌─────────────────────────────┐
215/// │ ┌──────────┐ │   │ ┌──────────┐ │        │ ┌────────────────────────┐  │
216/// │ │  [A, X]  │ │   │ │    []    │ │        │ │    [[A, X], []]        │  │
217/// │ ├──────────┤ │   │ ├──────────┤ │        │ ├────────────────────────┤  │
218/// │ │[NULL, Y] │ │   │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │  │
219/// │ ├──────────┤ │   │ ├──────────┤ │        │ ├────────────────────────│  │
220/// │ │  [C, Z]  │ │   │ │   NULL   │ │        │ │    [[C, Z], NULL]      │  │
221/// │ └──────────┘ │   │ └──────────┘ │        │ └────────────────────────┘  │
222/// └──────────────┘   └──────────────┘        └─────────────────────────────┘
223///      col1               col2                         output
224/// ```
225fn array_array<O: OffsetSizeTrait>(
226    args: &[ArrayRef],
227    data_type: DataType,
228) -> Result<ArrayRef> {
229    // do not accept 0 arguments.
230    if args.is_empty() {
231        return plan_err!("Array requires at least one argument");
232    }
233
234    let mut data = vec![];
235    let mut total_len = 0;
236    for arg in args {
237        let arg_data = if arg.as_any().is::<NullArray>() {
238            ArrayData::new_empty(&data_type)
239        } else {
240            arg.to_data()
241        };
242        total_len += arg_data.len();
243        data.push(arg_data);
244    }
245
246    let mut offsets: Vec<O> = Vec::with_capacity(total_len);
247    offsets.push(O::usize_as(0));
248
249    let capacity = Capacities::Array(total_len);
250    let data_ref = data.iter().collect::<Vec<_>>();
251    let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity);
252
253    let num_rows = args[0].len();
254    for row_idx in 0..num_rows {
255        for (arr_idx, arg) in args.iter().enumerate() {
256            if !arg.as_any().is::<NullArray>()
257                && !arg.is_null(row_idx)
258                && arg.is_valid(row_idx)
259            {
260                mutable.extend(arr_idx, row_idx, row_idx + 1);
261            } else {
262                mutable.extend_nulls(1);
263            }
264        }
265        offsets.push(O::usize_as(mutable.len()));
266    }
267    let data = mutable.freeze();
268
269    Ok(Arc::new(GenericListArray::<O>::try_new(
270        Arc::new(Field::new(ARRAY_FIELD_DEFAULT_NAME, data_type, true)),
271        OffsetBuffer::new(offsets.into()),
272        make_array(data),
273        None,
274    )?))
275}