datafusion_spark/function/array/
spark_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{any::Any, sync::Arc};
19
20use arrow::array::{
21    make_array, new_null_array, Array, ArrayData, ArrayRef, Capacities, GenericListArray,
22    MutableArrayData, NullArray, OffsetSizeTrait,
23};
24use arrow::buffer::OffsetBuffer;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use datafusion_common::utils::SingleRowListArrayBuilder;
27use datafusion_common::{internal_err, plan_datafusion_err, plan_err, Result};
28use datafusion_expr::type_coercion::binary::comparison_coercion;
29use datafusion_expr::{
30    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31    TypeSignature, Volatility,
32};
33
34use crate::function::functions_nested_utils::make_scalar_function;
35
/// Name given to the element field of every list type produced in this file
/// (both the declared return type and the arrays that are actually built).
const ARRAY_FIELD_DEFAULT_NAME: &str = "element";
37
/// Scalar UDF registered under the name `array` that packs its arguments
/// into a list, one list per input row.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArray {
    /// Accepted argument signatures; handed out by `signature()`.
    signature: Signature,
    /// Alternative names for the function; handed out by `aliases()`.
    aliases: Vec<String>,
}
43
44impl Default for SparkArray {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl SparkArray {
51    pub fn new() -> Self {
52        Self {
53            signature: Signature::one_of(
54                vec![TypeSignature::UserDefined, TypeSignature::Nullary],
55                Volatility::Immutable,
56            ),
57            aliases: vec![String::from("spark_make_array")],
58        }
59    }
60}
61
62impl ScalarUDFImpl for SparkArray {
63    fn as_any(&self) -> &dyn Any {
64        self
65    }
66
67    fn name(&self) -> &str {
68        "array"
69    }
70
71    fn signature(&self) -> &Signature {
72        &self.signature
73    }
74
75    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
76        internal_err!("return_field_from_args should be used instead")
77    }
78
79    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
80        let data_types = args
81            .arg_fields
82            .iter()
83            .map(|f| f.data_type())
84            .cloned()
85            .collect::<Vec<_>>();
86
87        let mut expr_type = DataType::Null;
88        for arg_type in &data_types {
89            if !arg_type.equals_datatype(&DataType::Null) {
90                expr_type = arg_type.clone();
91                break;
92            }
93        }
94
95        if expr_type.is_null() {
96            expr_type = DataType::Int32;
97        }
98
99        let return_type = DataType::List(Arc::new(Field::new(
100            ARRAY_FIELD_DEFAULT_NAME,
101            expr_type,
102            true,
103        )));
104
105        Ok(Arc::new(Field::new(
106            "this_field_name_is_irrelevant",
107            return_type,
108            false,
109        )))
110    }
111
112    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113        let ScalarFunctionArgs { args, .. } = args;
114        make_scalar_function(make_array_inner)(args.as_slice())
115    }
116
117    fn aliases(&self) -> &[String] {
118        &self.aliases
119    }
120
121    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
122        let first_type = arg_types.first().ok_or_else(|| {
123            plan_datafusion_err!("Spark array function requires at least one argument")
124        })?;
125        let new_type =
126            arg_types
127                .iter()
128                .skip(1)
129                .try_fold(first_type.clone(), |acc, x| {
130                    // The coerced types found by `comparison_coercion` are not guaranteed to be
131                    // coercible for the arguments. `comparison_coercion` returns more loose
132                    // types that can be coerced to both `acc` and `x` for comparison purpose.
133                    // See `maybe_data_types` for the actual coercion.
134                    let coerced_type = comparison_coercion(&acc, x);
135                    if let Some(coerced_type) = coerced_type {
136                        Ok(coerced_type)
137                    } else {
138                        plan_err!("Coercion from {acc} to {x} failed.")
139                    }
140                })?;
141        Ok(vec![new_type; arg_types.len()])
142    }
143}
144
145/// `make_array_inner` is the implementation of the `make_array` function.
146/// Constructs an array using the input `data` as `ArrayRef`.
147/// Returns a reference-counted `Array` instance result.
148pub fn make_array_inner(arrays: &[ArrayRef]) -> Result<ArrayRef> {
149    let mut data_type = DataType::Null;
150    for arg in arrays {
151        let arg_data_type = arg.data_type();
152        if !arg_data_type.equals_datatype(&DataType::Null) {
153            data_type = arg_data_type.clone();
154            break;
155        }
156    }
157
158    match data_type {
159        // Either an empty array or all nulls:
160        DataType::Null => {
161            let length = arrays.iter().map(|a| a.len()).sum();
162            // By default Int32
163            let array = new_null_array(&DataType::Int32, length);
164            Ok(Arc::new(
165                SingleRowListArrayBuilder::new(array)
166                    .with_nullable(true)
167                    .with_field_name(Some(ARRAY_FIELD_DEFAULT_NAME.to_string()))
168                    .build_list_array(),
169            ))
170        }
171        _ => array_array::<i32>(arrays, data_type),
172    }
173}
174
175/// Convert one or more [`ArrayRef`] of the same type into a
176/// `ListArray` or 'LargeListArray' depending on the offset size.
177///
178/// # Example (non nested)
179///
180/// Calling `array(col1, col2)` where col1 and col2 are non nested
181/// would return a single new `ListArray`, where each row was a list
182/// of 2 elements:
183///
184/// ```text
185/// ┌─────────┐   ┌─────────┐           ┌──────────────┐
186/// │ ┌─────┐ │   │ ┌─────┐ │           │ ┌──────────┐ │
187/// │ │  A  │ │   │ │  X  │ │           │ │  [A, X]  │ │
188/// │ ├─────┤ │   │ ├─────┤ │           │ ├──────────┤ │
189/// │ │NULL │ │   │ │  Y  │ │──────────▶│ │[NULL, Y] │ │
190/// │ ├─────┤ │   │ ├─────┤ │           │ ├──────────┤ │
191/// │ │  C  │ │   │ │  Z  │ │           │ │  [C, Z]  │ │
192/// │ └─────┘ │   │ └─────┘ │           │ └──────────┘ │
193/// └─────────┘   └─────────┘           └──────────────┘
194///   col1           col2                    output
195/// ```
196///
197/// # Example (nested)
198///
199/// Calling `array(col1, col2)` where col1 and col2 are lists
200/// would return a single new `ListArray`, where each row was a list
201/// of the corresponding elements of col1 and col2.
202///
203/// ``` text
204/// ┌──────────────┐   ┌──────────────┐        ┌─────────────────────────────┐
205/// │ ┌──────────┐ │   │ ┌──────────┐ │        │ ┌────────────────────────┐  │
206/// │ │  [A, X]  │ │   │ │    []    │ │        │ │    [[A, X], []]        │  │
207/// │ ├──────────┤ │   │ ├──────────┤ │        │ ├────────────────────────┤  │
208/// │ │[NULL, Y] │ │   │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │  │
209/// │ ├──────────┤ │   │ ├──────────┤ │        │ ├────────────────────────│  │
210/// │ │  [C, Z]  │ │   │ │   NULL   │ │        │ │    [[C, Z], NULL]      │  │
211/// │ └──────────┘ │   │ └──────────┘ │        │ └────────────────────────┘  │
212/// └──────────────┘   └──────────────┘        └─────────────────────────────┘
213///      col1               col2                         output
214/// ```
215fn array_array<O: OffsetSizeTrait>(
216    args: &[ArrayRef],
217    data_type: DataType,
218) -> Result<ArrayRef> {
219    // do not accept 0 arguments.
220    if args.is_empty() {
221        return plan_err!("Array requires at least one argument");
222    }
223
224    let mut data = vec![];
225    let mut total_len = 0;
226    for arg in args {
227        let arg_data = if arg.as_any().is::<NullArray>() {
228            ArrayData::new_empty(&data_type)
229        } else {
230            arg.to_data()
231        };
232        total_len += arg_data.len();
233        data.push(arg_data);
234    }
235
236    let mut offsets: Vec<O> = Vec::with_capacity(total_len);
237    offsets.push(O::usize_as(0));
238
239    let capacity = Capacities::Array(total_len);
240    let data_ref = data.iter().collect::<Vec<_>>();
241    let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity);
242
243    let num_rows = args[0].len();
244    for row_idx in 0..num_rows {
245        for (arr_idx, arg) in args.iter().enumerate() {
246            if !arg.as_any().is::<NullArray>()
247                && !arg.is_null(row_idx)
248                && arg.is_valid(row_idx)
249            {
250                mutable.extend(arr_idx, row_idx, row_idx + 1);
251            } else {
252                mutable.extend_nulls(1);
253            }
254        }
255        offsets.push(O::usize_as(mutable.len()));
256    }
257    let data = mutable.freeze();
258
259    Ok(Arc::new(GenericListArray::<O>::try_new(
260        Arc::new(Field::new(ARRAY_FIELD_DEFAULT_NAME, data_type, true)),
261        OffsetBuffer::new(offsets.into()),
262        make_array(data),
263        None,
264    )?))
265}