datafusion_functions_nested/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! array function utils
19
20use std::sync::Arc;
21
22use arrow::datatypes::{DataType, Field, Fields};
23
24use arrow::array::{
25    Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar,
26    UInt32Array,
27};
28use arrow::buffer::OffsetBuffer;
29use datafusion_common::cast::{as_large_list_array, as_list_array};
30use datafusion_common::{
31    exec_err, internal_datafusion_err, internal_err, plan_err, Result, ScalarValue,
32};
33
34use datafusion_expr::ColumnarValue;
35use datafusion_functions::{downcast_arg, downcast_named_arg};
36
37pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
38    let data_type = args[0].data_type();
39    if !args.iter().all(|arg| {
40        arg.data_type().equals_datatype(data_type)
41            || arg.data_type().equals_datatype(&DataType::Null)
42    }) {
43        let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
44        return plan_err!("{name} received incompatible types: '{types:?}'.");
45    }
46
47    Ok(())
48}
49
50/// array function wrapper that differentiates between scalar (length 1) and array.
51pub(crate) fn make_scalar_function<F>(
52    inner: F,
53) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
54where
55    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
56{
57    move |args: &[ColumnarValue]| {
58        // first, identify if any of the arguments is an Array. If yes, store its `len`,
59        // as any scalar will need to be converted to an array of len `len`.
60        let len = args
61            .iter()
62            .fold(Option::<usize>::None, |acc, arg| match arg {
63                ColumnarValue::Scalar(_) => acc,
64                ColumnarValue::Array(a) => Some(a.len()),
65            });
66
67        let is_scalar = len.is_none();
68
69        let args = ColumnarValue::values_to_arrays(args)?;
70
71        let result = (inner)(&args);
72
73        if is_scalar {
74            // If all inputs are scalar, keeps output as scalar
75            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
76            result.map(ColumnarValue::Scalar)
77        } else {
78            result.map(ColumnarValue::Array)
79        }
80    }
81}
82
83pub(crate) fn align_array_dimensions<O: OffsetSizeTrait>(
84    args: Vec<ArrayRef>,
85) -> Result<Vec<ArrayRef>> {
86    let args_ndim = args
87        .iter()
88        .map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
89        .collect::<Vec<_>>();
90    let max_ndim = args_ndim.iter().max().unwrap_or(&0);
91
92    // Align the dimensions of the arrays
93    let aligned_args: Result<Vec<ArrayRef>> = args
94        .into_iter()
95        .zip(args_ndim.iter())
96        .map(|(array, ndim)| {
97            if ndim < max_ndim {
98                let mut aligned_array = Arc::clone(&array);
99                for _ in 0..(max_ndim - ndim) {
100                    let data_type = aligned_array.data_type().to_owned();
101                    let array_lengths = vec![1; aligned_array.len()];
102                    let offsets = OffsetBuffer::<O>::from_lengths(array_lengths);
103
104                    aligned_array = Arc::new(GenericListArray::<O>::try_new(
105                        Arc::new(Field::new_list_field(data_type, true)),
106                        offsets,
107                        aligned_array,
108                        None,
109                    )?)
110                }
111                Ok(aligned_array)
112            } else {
113                Ok(Arc::clone(&array))
114            }
115        })
116        .collect();
117
118    aligned_args
119}
120
121/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
122///
123/// # Arguments
124///
125/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
126///
127/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
128///
129/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
130///
131/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
132///
133/// # Returns
134///
135/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
136///
137/// # Example
138///
139/// ```text
140/// compare_element_to_list(
141///     [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
142///     [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
143///
144///     [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
145///     [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
146/// )
147/// ```
148pub(crate) fn compare_element_to_list(
149    list_array_row: &dyn Array,
150    element_array: &dyn Array,
151    row_index: usize,
152    eq: bool,
153) -> Result<BooleanArray> {
154    if list_array_row.data_type() != element_array.data_type() {
155        return exec_err!(
156            "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
157            list_array_row.data_type(),
158            element_array.data_type()
159        );
160    }
161
162    let indices = UInt32Array::from(vec![row_index as u32]);
163    let element_array_row = arrow::compute::take(element_array, &indices, None)?;
164
165    // Compute all positions in list_row_array (that is itself an
166    // array) that are equal to `from_array_row`
167    let res = match element_array_row.data_type() {
168        // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
169        DataType::List(_) => {
170            // compare each element of the from array
171            let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
172            let list_array_row_inner = as_list_array(list_array_row)?;
173
174            list_array_row_inner
175                .iter()
176                // compare element by element the current row of list_array
177                .map(|row| {
178                    row.map(|row| {
179                        if eq {
180                            row.eq(&element_array_row_inner)
181                        } else {
182                            row.ne(&element_array_row_inner)
183                        }
184                    })
185                })
186                .collect::<BooleanArray>()
187        }
188        DataType::LargeList(_) => {
189            // compare each element of the from array
190            let element_array_row_inner =
191                as_large_list_array(&element_array_row)?.value(0);
192            let list_array_row_inner = as_large_list_array(list_array_row)?;
193
194            list_array_row_inner
195                .iter()
196                // compare element by element the current row of list_array
197                .map(|row| {
198                    row.map(|row| {
199                        if eq {
200                            row.eq(&element_array_row_inner)
201                        } else {
202                            row.ne(&element_array_row_inner)
203                        }
204                    })
205                })
206                .collect::<BooleanArray>()
207        }
208        _ => {
209            let element_arr = Scalar::new(element_array_row);
210            // use not_distinct so we can compare NULL
211            if eq {
212                arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
213            } else {
214                arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
215            }
216        }
217    };
218
219    Ok(res)
220}
221
222/// Returns the length of each array dimension
223pub(crate) fn compute_array_dims(
224    arr: Option<ArrayRef>,
225) -> Result<Option<Vec<Option<u64>>>> {
226    let mut value = match arr {
227        Some(arr) => arr,
228        None => return Ok(None),
229    };
230    if value.is_empty() {
231        return Ok(None);
232    }
233    let mut res = vec![Some(value.len() as u64)];
234
235    loop {
236        match value.data_type() {
237            DataType::List(..) => {
238                value = downcast_arg!(value, ListArray).value(0);
239                res.push(Some(value.len() as u64));
240            }
241            _ => return Ok(Some(res)),
242        }
243    }
244}
245
246pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
247    match data_type {
248        DataType::Map(field, _) => {
249            let field_data_type = field.data_type();
250            match field_data_type {
251                DataType::Struct(fields) => Ok(fields),
252                _ => {
253                    internal_err!("Expected a Struct type, got {:?}", field_data_type)
254                }
255            }
256        }
257        _ => internal_err!("Expected a Map type, got {:?}", data_type),
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Int64Type;
    use datafusion_common::utils::{list_ndims, SingleRowListArrayBuilder};

    /// Wraps `inner` as the single row of a new `ListArray`, adding one level of nesting.
    fn nest(inner: &ArrayRef) -> ArrayRef {
        Arc::new(SingleRowListArrayBuilder::new(Arc::clone(inner)).build_list_array())
    }

    /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt`
    #[test]
    fn test_align_array_dimensions() {
        // Two one-dimensional list arrays with a different number of rows.
        let flat_two_rows: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
            ]));
        let flat_one_row: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(6), Some(7), Some(8)]),
            ]));

        // Two- and three-dimensional variants built by nesting the arrays above.
        let nested_two_rows = nest(&flat_two_rows);
        let nested_one_row = nest(&flat_one_row);
        let deep_two_rows = nest(&nested_two_rows);
        let deep_one_row = nest(&nested_one_row);

        // Each case aligns the 1-D array with a deeper one; `reference` is the
        // 1-D array nested to that same depth as a single row.
        let cases = [
            (nested_one_row, nested_two_rows),
            (deep_one_row, deep_two_rows),
        ];
        for (deeper, reference) in cases {
            let aligned =
                align_array_dimensions::<i32>(vec![Arc::clone(&flat_two_rows), deeper])
                    .unwrap();

            // Same depth as the reference, but not the same array: alignment
            // wraps every element separately instead of the array as a whole.
            let expected = as_list_array(&reference).unwrap();
            assert_ne!(as_list_array(&aligned[0]).unwrap(), expected);
            assert_eq!(
                list_ndims(aligned[0].data_type()),
                list_ndims(reference.data_type())
            );
        }
    }
}
315}