datafusion_functions_nested/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! array function utils
19
20use std::sync::Arc;
21
22use arrow::datatypes::{DataType, Field, Fields};
23
24use arrow::array::{
25    Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array,
26};
27use arrow::buffer::OffsetBuffer;
28use datafusion_common::cast::{
29    as_fixed_size_list_array, as_large_list_array, as_list_array,
30};
31use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
32
33use datafusion_expr::ColumnarValue;
34use itertools::Itertools as _;
35
36pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
37    let data_type = args[0].data_type();
38    if !args.iter().all(|arg| {
39        arg.data_type().equals_datatype(data_type)
40            || arg.data_type().equals_datatype(&DataType::Null)
41    }) {
42        let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
43        return plan_err!(
44            "{name} received incompatible types: {}",
45            types.iter().join(", ")
46        );
47    }
48
49    Ok(())
50}
51
52/// array function wrapper that differentiates between scalar (length 1) and array.
53pub(crate) fn make_scalar_function<F>(
54    inner: F,
55) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
56where
57    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
58{
59    move |args: &[ColumnarValue]| {
60        // first, identify if any of the arguments is an Array. If yes, store its `len`,
61        // as any scalar will need to be converted to an array of len `len`.
62        let len = args
63            .iter()
64            .fold(Option::<usize>::None, |acc, arg| match arg {
65                ColumnarValue::Scalar(_) => acc,
66                ColumnarValue::Array(a) => Some(a.len()),
67            });
68
69        let is_scalar = len.is_none();
70
71        let args = ColumnarValue::values_to_arrays(args)?;
72
73        let result = (inner)(&args);
74
75        if is_scalar {
76            // If all inputs are scalar, keeps output as scalar
77            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
78            result.map(ColumnarValue::Scalar)
79        } else {
80            result.map(ColumnarValue::Array)
81        }
82    }
83}
84
85pub(crate) fn align_array_dimensions<O: OffsetSizeTrait>(
86    args: Vec<ArrayRef>,
87) -> Result<Vec<ArrayRef>> {
88    let args_ndim = args
89        .iter()
90        .map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
91        .collect::<Vec<_>>();
92    let max_ndim = args_ndim.iter().max().unwrap_or(&0);
93
94    // Align the dimensions of the arrays
95    let aligned_args: Result<Vec<ArrayRef>> = args
96        .into_iter()
97        .zip(args_ndim.iter())
98        .map(|(array, ndim)| {
99            if ndim < max_ndim {
100                let mut aligned_array = Arc::clone(&array);
101                for _ in 0..(max_ndim - ndim) {
102                    let data_type = aligned_array.data_type().to_owned();
103                    let array_lengths = vec![1; aligned_array.len()];
104                    let offsets = OffsetBuffer::<O>::from_lengths(array_lengths);
105
106                    aligned_array = Arc::new(GenericListArray::<O>::try_new(
107                        Arc::new(Field::new_list_field(data_type, true)),
108                        offsets,
109                        aligned_array,
110                        None,
111                    )?)
112                }
113                Ok(aligned_array)
114            } else {
115                Ok(Arc::clone(&array))
116            }
117        })
118        .collect();
119
120    aligned_args
121}
122
123/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
124///
125/// # Arguments
126///
127/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
128///
129/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
130///
131/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
132///
133/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
134///
135/// # Returns
136///
137/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
138///
139/// # Example
140///
141/// ```text
142/// compare_element_to_list(
143///     [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
144///     [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
145///
146///     [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
147///     [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
148/// )
149/// ```
150pub(crate) fn compare_element_to_list(
151    list_array_row: &dyn Array,
152    element_array: &dyn Array,
153    row_index: usize,
154    eq: bool,
155) -> Result<BooleanArray> {
156    if list_array_row.data_type() != element_array.data_type() {
157        return exec_err!(
158            "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
159            list_array_row.data_type(),
160            element_array.data_type()
161        );
162    }
163
164    let indices = UInt32Array::from(vec![row_index as u32]);
165    let element_array_row = arrow::compute::take(element_array, &indices, None)?;
166
167    // Compute all positions in list_row_array (that is itself an
168    // array) that are equal to `from_array_row`
169    let res = match element_array_row.data_type() {
170        // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
171        DataType::List(_) => {
172            // compare each element of the from array
173            let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
174            let list_array_row_inner = as_list_array(list_array_row)?;
175
176            list_array_row_inner
177                .iter()
178                // compare element by element the current row of list_array
179                .map(|row| {
180                    row.map(|row| {
181                        if eq {
182                            row.eq(&element_array_row_inner)
183                        } else {
184                            row.ne(&element_array_row_inner)
185                        }
186                    })
187                })
188                .collect::<BooleanArray>()
189        }
190        DataType::LargeList(_) => {
191            // compare each element of the from array
192            let element_array_row_inner =
193                as_large_list_array(&element_array_row)?.value(0);
194            let list_array_row_inner = as_large_list_array(list_array_row)?;
195
196            list_array_row_inner
197                .iter()
198                // compare element by element the current row of list_array
199                .map(|row| {
200                    row.map(|row| {
201                        if eq {
202                            row.eq(&element_array_row_inner)
203                        } else {
204                            row.ne(&element_array_row_inner)
205                        }
206                    })
207                })
208                .collect::<BooleanArray>()
209        }
210        _ => {
211            let element_arr = Scalar::new(element_array_row);
212            // use not_distinct so we can compare NULL
213            if eq {
214                arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
215            } else {
216                arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
217            }
218        }
219    };
220
221    Ok(res)
222}
223
224/// Returns the length of each array dimension
225pub(crate) fn compute_array_dims(
226    arr: Option<ArrayRef>,
227) -> Result<Option<Vec<Option<u64>>>> {
228    let mut value = match arr {
229        Some(arr) => arr,
230        None => return Ok(None),
231    };
232    if value.is_empty() {
233        return Ok(None);
234    }
235    let mut res = vec![Some(value.len() as u64)];
236
237    loop {
238        match value.data_type() {
239            DataType::List(_) => {
240                value = as_list_array(&value)?.value(0);
241                res.push(Some(value.len() as u64));
242            }
243            DataType::LargeList(_) => {
244                value = as_large_list_array(&value)?.value(0);
245                res.push(Some(value.len() as u64));
246            }
247            DataType::FixedSizeList(..) => {
248                value = as_fixed_size_list_array(&value)?.value(0);
249                res.push(Some(value.len() as u64));
250            }
251            _ => return Ok(Some(res)),
252        }
253    }
254}
255
256pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
257    match data_type {
258        DataType::Map(field, _) => {
259            let field_data_type = field.data_type();
260            match field_data_type {
261                DataType::Struct(fields) => Ok(fields),
262                _ => {
263                    internal_err!("Expected a Struct type, got {:?}", field_data_type)
264                }
265            }
266        }
267        _ => internal_err!("Expected a Map type, got {data_type}"),
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ListArray;
    use arrow::datatypes::Int64Type;
    use datafusion_common::utils::SingleRowListArrayBuilder;

    /// Tests only the internal helpers; array-related SQL functions are tested in
    /// sqllogictest `array.slt`.
    #[test]
    fn test_align_array_dimensions() {
        // 1-dimensional list arrays: `array1d_1` has 2 rows, `array1d_2` has 1 row.
        let array1d_1: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
            ]));
        let array1d_2: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(6), Some(7), Some(8)]),
            ]));

        // 2-dimensional arrays, each wrapping a whole 1-d array via
        // `SingleRowListArrayBuilder` (presumably as a single row, per its name).
        let array2d_1: ArrayRef = Arc::new(
            SingleRowListArrayBuilder::new(Arc::clone(&array1d_1)).build_list_array(),
        );
        let array2d_2 = Arc::new(
            SingleRowListArrayBuilder::new(Arc::clone(&array1d_2)).build_list_array(),
        );

        // Align a 1-d array with a 2-d array: the 1-d argument gains one level.
        let res = align_array_dimensions::<i32>(vec![
            array1d_1.to_owned(),
            array2d_2.to_owned(),
        ])
        .unwrap();

        // `align_array_dimensions` wraps each row of `array1d_1` in its own
        // single-element list (offsets of length 1), so `res[0]` keeps
        // `array1d_1`'s row count and is expected to differ from `array2d_1`.
        // Only the nesting depth has to match.
        let expected = as_list_array(&array2d_1).unwrap();
        let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type());
        assert_ne!(as_list_array(&res[0]).unwrap(), expected);
        assert_eq!(
            datafusion_common::utils::list_ndims(res[0].data_type()),
            expected_dim
        );

        // Same check across a two-level gap: a 1-d array aligned with a 3-d array.
        let array3d_1: ArrayRef =
            Arc::new(SingleRowListArrayBuilder::new(array2d_1).build_list_array());
        let array3d_2: ArrayRef =
            Arc::new(SingleRowListArrayBuilder::new(array2d_2).build_list_array());
        let res = align_array_dimensions::<i32>(vec![array1d_1, array3d_2]).unwrap();

        let expected = as_list_array(&array3d_1).unwrap();
        let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type());
        assert_ne!(as_list_array(&res[0]).unwrap(), expected);
        assert_eq!(
            datafusion_common::utils::list_ndims(res[0].data_type()),
            expected_dim
        );
    }
}