datafusion_functions_nested/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! array function utils
19
20use std::sync::Arc;
21
22use arrow::datatypes::{DataType, Field, Fields};
23
24use arrow::array::{
25    Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array,
26};
27use arrow::buffer::OffsetBuffer;
28use datafusion_common::cast::{
29    as_fixed_size_list_array, as_large_list_array, as_list_array,
30};
31use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
32
33use datafusion_expr::ColumnarValue;
34
35pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
36    let data_type = args[0].data_type();
37    if !args.iter().all(|arg| {
38        arg.data_type().equals_datatype(data_type)
39            || arg.data_type().equals_datatype(&DataType::Null)
40    }) {
41        let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
42        return plan_err!("{name} received incompatible types: '{types:?}'.");
43    }
44
45    Ok(())
46}
47
48/// array function wrapper that differentiates between scalar (length 1) and array.
49pub(crate) fn make_scalar_function<F>(
50    inner: F,
51) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
52where
53    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
54{
55    move |args: &[ColumnarValue]| {
56        // first, identify if any of the arguments is an Array. If yes, store its `len`,
57        // as any scalar will need to be converted to an array of len `len`.
58        let len = args
59            .iter()
60            .fold(Option::<usize>::None, |acc, arg| match arg {
61                ColumnarValue::Scalar(_) => acc,
62                ColumnarValue::Array(a) => Some(a.len()),
63            });
64
65        let is_scalar = len.is_none();
66
67        let args = ColumnarValue::values_to_arrays(args)?;
68
69        let result = (inner)(&args);
70
71        if is_scalar {
72            // If all inputs are scalar, keeps output as scalar
73            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
74            result.map(ColumnarValue::Scalar)
75        } else {
76            result.map(ColumnarValue::Array)
77        }
78    }
79}
80
81pub(crate) fn align_array_dimensions<O: OffsetSizeTrait>(
82    args: Vec<ArrayRef>,
83) -> Result<Vec<ArrayRef>> {
84    let args_ndim = args
85        .iter()
86        .map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
87        .collect::<Vec<_>>();
88    let max_ndim = args_ndim.iter().max().unwrap_or(&0);
89
90    // Align the dimensions of the arrays
91    let aligned_args: Result<Vec<ArrayRef>> = args
92        .into_iter()
93        .zip(args_ndim.iter())
94        .map(|(array, ndim)| {
95            if ndim < max_ndim {
96                let mut aligned_array = Arc::clone(&array);
97                for _ in 0..(max_ndim - ndim) {
98                    let data_type = aligned_array.data_type().to_owned();
99                    let array_lengths = vec![1; aligned_array.len()];
100                    let offsets = OffsetBuffer::<O>::from_lengths(array_lengths);
101
102                    aligned_array = Arc::new(GenericListArray::<O>::try_new(
103                        Arc::new(Field::new_list_field(data_type, true)),
104                        offsets,
105                        aligned_array,
106                        None,
107                    )?)
108                }
109                Ok(aligned_array)
110            } else {
111                Ok(Arc::clone(&array))
112            }
113        })
114        .collect();
115
116    aligned_args
117}
118
119/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
120///
121/// # Arguments
122///
123/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
124///
125/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
126///
127/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
128///
129/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
130///
131/// # Returns
132///
133/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
134///
135/// # Example
136///
137/// ```text
138/// compare_element_to_list(
139///     [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
140///     [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
141///
142///     [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
143///     [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
144/// )
145/// ```
146pub(crate) fn compare_element_to_list(
147    list_array_row: &dyn Array,
148    element_array: &dyn Array,
149    row_index: usize,
150    eq: bool,
151) -> Result<BooleanArray> {
152    if list_array_row.data_type() != element_array.data_type() {
153        return exec_err!(
154            "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
155            list_array_row.data_type(),
156            element_array.data_type()
157        );
158    }
159
160    let indices = UInt32Array::from(vec![row_index as u32]);
161    let element_array_row = arrow::compute::take(element_array, &indices, None)?;
162
163    // Compute all positions in list_row_array (that is itself an
164    // array) that are equal to `from_array_row`
165    let res = match element_array_row.data_type() {
166        // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
167        DataType::List(_) => {
168            // compare each element of the from array
169            let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
170            let list_array_row_inner = as_list_array(list_array_row)?;
171
172            list_array_row_inner
173                .iter()
174                // compare element by element the current row of list_array
175                .map(|row| {
176                    row.map(|row| {
177                        if eq {
178                            row.eq(&element_array_row_inner)
179                        } else {
180                            row.ne(&element_array_row_inner)
181                        }
182                    })
183                })
184                .collect::<BooleanArray>()
185        }
186        DataType::LargeList(_) => {
187            // compare each element of the from array
188            let element_array_row_inner =
189                as_large_list_array(&element_array_row)?.value(0);
190            let list_array_row_inner = as_large_list_array(list_array_row)?;
191
192            list_array_row_inner
193                .iter()
194                // compare element by element the current row of list_array
195                .map(|row| {
196                    row.map(|row| {
197                        if eq {
198                            row.eq(&element_array_row_inner)
199                        } else {
200                            row.ne(&element_array_row_inner)
201                        }
202                    })
203                })
204                .collect::<BooleanArray>()
205        }
206        _ => {
207            let element_arr = Scalar::new(element_array_row);
208            // use not_distinct so we can compare NULL
209            if eq {
210                arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
211            } else {
212                arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
213            }
214        }
215    };
216
217    Ok(res)
218}
219
220/// Returns the length of each array dimension
221pub(crate) fn compute_array_dims(
222    arr: Option<ArrayRef>,
223) -> Result<Option<Vec<Option<u64>>>> {
224    let mut value = match arr {
225        Some(arr) => arr,
226        None => return Ok(None),
227    };
228    if value.is_empty() {
229        return Ok(None);
230    }
231    let mut res = vec![Some(value.len() as u64)];
232
233    loop {
234        match value.data_type() {
235            DataType::List(_) => {
236                value = as_list_array(&value)?.value(0);
237                res.push(Some(value.len() as u64));
238            }
239            DataType::LargeList(_) => {
240                value = as_large_list_array(&value)?.value(0);
241                res.push(Some(value.len() as u64));
242            }
243            DataType::FixedSizeList(..) => {
244                value = as_fixed_size_list_array(&value)?.value(0);
245                res.push(Some(value.len() as u64));
246            }
247            _ => return Ok(Some(res)),
248        }
249    }
250}
251
252pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
253    match data_type {
254        DataType::Map(field, _) => {
255            let field_data_type = field.data_type();
256            match field_data_type {
257                DataType::Struct(fields) => Ok(fields),
258                _ => {
259                    internal_err!("Expected a Struct type, got {:?}", field_data_type)
260                }
261            }
262        }
263        _ => internal_err!("Expected a Map type, got {:?}", data_type),
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ListArray;
    use arrow::datatypes::Int64Type;
    use datafusion_common::utils::{list_ndims, SingleRowListArrayBuilder};

    /// Wraps `inner` as the single row of a new `ListArray`, adding one list dimension.
    fn nest(inner: &ArrayRef) -> ArrayRef {
        Arc::new(SingleRowListArrayBuilder::new(Arc::clone(inner)).build_list_array())
    }

    /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt`
    #[test]
    fn test_align_array_dimensions() {
        let one_dim_a: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
            ]));
        let one_dim_b: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(6), Some(7), Some(8)]),
            ]));

        let two_dim_a = nest(&one_dim_a);
        let two_dim_b = nest(&one_dim_b);

        // 1-D aligned against 2-D: the depth must match, the contents must not.
        let aligned = align_array_dimensions::<i32>(vec![
            Arc::clone(&one_dim_a),
            Arc::clone(&two_dim_b),
        ])
        .unwrap();

        assert_ne!(
            as_list_array(&aligned[0]).unwrap(),
            as_list_array(&two_dim_a).unwrap()
        );
        assert_eq!(
            list_ndims(aligned[0].data_type()),
            list_ndims(two_dim_a.data_type())
        );

        // 1-D aligned against 3-D: two wrapping steps are required.
        let three_dim_a = nest(&two_dim_a);
        let three_dim_b = nest(&two_dim_b);
        let aligned =
            align_array_dimensions::<i32>(vec![one_dim_a, three_dim_b]).unwrap();

        assert_ne!(
            as_list_array(&aligned[0]).unwrap(),
            as_list_array(&three_dim_a).unwrap()
        );
        assert_eq!(
            list_ndims(aligned[0].data_type()),
            list_ndims(three_dim_a.data_type())
        );
    }
}