Skip to main content

datafusion_functions_nested/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! array function utils
19
20use std::sync::Arc;
21
22use arrow::datatypes::{DataType, Field, Fields};
23
24use arrow::array::{
25    Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar,
26};
27use arrow::buffer::OffsetBuffer;
28use datafusion_common::cast::{
29    as_fixed_size_list_array, as_large_list_array, as_large_list_view_array,
30    as_list_array, as_list_view_array,
31};
32use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
33
34use datafusion_expr::ColumnarValue;
35use itertools::Itertools as _;
36
37pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
38    let data_type = args[0].data_type();
39    if !args.iter().all(|arg| {
40        arg.data_type().equals_datatype(data_type)
41            || arg.data_type().equals_datatype(&DataType::Null)
42    }) {
43        let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
44        return plan_err!(
45            "{name} received incompatible types: {}",
46            types.iter().join(", ")
47        );
48    }
49
50    Ok(())
51}
52
53/// array function wrapper that differentiates between scalar (length 1) and array.
54pub(crate) fn make_scalar_function<F>(
55    inner: F,
56) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
57where
58    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
59{
60    move |args: &[ColumnarValue]| {
61        // first, identify if any of the arguments is an Array. If yes, store its `len`,
62        // as any scalar will need to be converted to an array of len `len`.
63        let len = args
64            .iter()
65            .fold(Option::<usize>::None, |acc, arg| match arg {
66                ColumnarValue::Scalar(_) => acc,
67                ColumnarValue::Array(a) => Some(a.len()),
68            });
69
70        let is_scalar = len.is_none();
71
72        let args = ColumnarValue::values_to_arrays(args)?;
73
74        let result = (inner)(&args);
75
76        if is_scalar {
77            // If all inputs are scalar, keeps output as scalar
78            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
79            result.map(ColumnarValue::Scalar)
80        } else {
81            result.map(ColumnarValue::Array)
82        }
83    }
84}
85
86pub(crate) fn align_array_dimensions<O: OffsetSizeTrait>(
87    args: Vec<ArrayRef>,
88) -> Result<Vec<ArrayRef>> {
89    let args_ndim = args
90        .iter()
91        .map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
92        .collect::<Vec<_>>();
93    let max_ndim = args_ndim.iter().max().unwrap_or(&0);
94
95    // Align the dimensions of the arrays
96    let aligned_args: Result<Vec<ArrayRef>> = args
97        .into_iter()
98        .zip(args_ndim.iter())
99        .map(|(array, ndim)| {
100            if ndim < max_ndim {
101                let mut aligned_array = Arc::clone(&array);
102                for _ in 0..(max_ndim - ndim) {
103                    let data_type = aligned_array.data_type().to_owned();
104                    let array_lengths = vec![1; aligned_array.len()];
105                    let offsets = OffsetBuffer::<O>::from_lengths(array_lengths);
106
107                    aligned_array = Arc::new(GenericListArray::<O>::try_new(
108                        Arc::new(Field::new_list_field(data_type, true)),
109                        offsets,
110                        aligned_array,
111                        None,
112                    )?)
113                }
114                Ok(aligned_array)
115            } else {
116                Ok(Arc::clone(&array))
117            }
118        })
119        .collect();
120
121    aligned_args
122}
123
124/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
125///
126/// # Arguments
127///
128/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
129///
130/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
131///
132/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
133///
134/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
135///
136/// # Returns
137///
138/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
139///
140/// # Example
141///
142/// ```text
143/// compare_element_to_list(
144///     [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
145///     [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
146///
147///     [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
148///     [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
149/// )
150/// ```
151pub(crate) fn compare_element_to_list(
152    list_array_row: &dyn Array,
153    element_array: &dyn Array,
154    row_index: usize,
155    eq: bool,
156) -> Result<BooleanArray> {
157    if list_array_row.data_type() != element_array.data_type() {
158        return exec_err!(
159            "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
160            list_array_row.data_type(),
161            element_array.data_type()
162        );
163    }
164
165    let element_array_row = element_array.slice(row_index, 1);
166
167    // Compute all positions in list_row_array (that is itself an
168    // array) that are equal to `from_array_row`
169    let res = match element_array_row.data_type() {
170        // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
171        DataType::List(_) => {
172            // compare each element of the from array
173            let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
174            let list_array_row_inner = as_list_array(list_array_row)?;
175
176            list_array_row_inner
177                .iter()
178                // compare element by element the current row of list_array
179                .map(|row| {
180                    row.map(|row| {
181                        if eq {
182                            row.eq(&element_array_row_inner)
183                        } else {
184                            row.ne(&element_array_row_inner)
185                        }
186                    })
187                })
188                .collect::<BooleanArray>()
189        }
190        DataType::LargeList(_) => {
191            // compare each element of the from array
192            let element_array_row_inner =
193                as_large_list_array(&element_array_row)?.value(0);
194            let list_array_row_inner = as_large_list_array(list_array_row)?;
195
196            list_array_row_inner
197                .iter()
198                // compare element by element the current row of list_array
199                .map(|row| {
200                    row.map(|row| {
201                        if eq {
202                            row.eq(&element_array_row_inner)
203                        } else {
204                            row.ne(&element_array_row_inner)
205                        }
206                    })
207                })
208                .collect::<BooleanArray>()
209        }
210        _ => {
211            let element_arr = Scalar::new(element_array_row);
212            // use not_distinct so we can compare NULL
213            if eq {
214                arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
215            } else {
216                arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
217            }
218        }
219    };
220
221    Ok(res)
222}
223
224/// Returns the length of each array dimension
225pub(crate) fn compute_array_dims(
226    arr: Option<ArrayRef>,
227) -> Result<Option<Vec<Option<u64>>>> {
228    let mut value = match arr {
229        Some(arr) => arr,
230        None => return Ok(None),
231    };
232    if value.is_empty() {
233        return Ok(None);
234    }
235    let mut res = vec![Some(value.len() as u64)];
236
237    loop {
238        match value.data_type() {
239            DataType::List(_) => {
240                value = as_list_array(&value)?.value(0);
241                res.push(Some(value.len() as u64));
242            }
243            DataType::LargeList(_) => {
244                value = as_large_list_array(&value)?.value(0);
245                res.push(Some(value.len() as u64));
246            }
247            DataType::ListView(_) => {
248                value = as_list_view_array(&value)?.value(0);
249                res.push(Some(value.len() as u64));
250            }
251            DataType::LargeListView(_) => {
252                value = as_large_list_view_array(&value)?.value(0);
253                res.push(Some(value.len() as u64));
254            }
255            DataType::FixedSizeList(..) => {
256                value = as_fixed_size_list_array(&value)?.value(0);
257                res.push(Some(value.len() as u64));
258            }
259            _ => return Ok(Some(res)),
260        }
261    }
262}
263
264pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
265    match data_type {
266        DataType::Map(field, _) => {
267            let field_data_type = field.data_type();
268            match field_data_type {
269                DataType::Struct(fields) => Ok(fields),
270                _ => {
271                    internal_err!("Expected a Struct type, got {}", field_data_type)
272                }
273            }
274        }
275        _ => internal_err!("Expected a Map type, got {data_type}"),
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ListArray;
    use arrow::datatypes::Int64Type;
    use datafusion_common::utils::SingleRowListArrayBuilder;

    /// Wraps `inner` as the single row of a new list array.
    fn nest(inner: &ArrayRef) -> ArrayRef {
        Arc::new(SingleRowListArrayBuilder::new(Arc::clone(inner)).build_list_array())
    }

    // Only internal functions are tested here; array-related sql functions
    // are tested in sqllogictest `array.slt`.
    #[test]
    fn test_align_array_dimensions() {
        let ndims =
            |array: &ArrayRef| datafusion_common::utils::list_ndims(array.data_type());

        let array1d_1: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
            ]));
        let array1d_2: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(6), Some(7), Some(8)]),
            ]));

        let array2d_1 = nest(&array1d_1);
        let array2d_2 = nest(&array1d_2);

        // 1-D input next to a 2-D input: the 1-D array must gain one dimension.
        let aligned = align_array_dimensions::<i32>(vec![
            Arc::clone(&array1d_1),
            Arc::clone(&array2d_2),
        ])
        .unwrap();

        assert_ne!(
            as_list_array(&aligned[0]).unwrap(),
            as_list_array(&array2d_1).unwrap()
        );
        assert_eq!(ndims(&aligned[0]), ndims(&array2d_1));

        let array3d_1 = nest(&array2d_1);
        let array3d_2 = nest(&array2d_2);

        // 1-D input next to a 3-D input: the 1-D array must gain two dimensions.
        let aligned =
            align_array_dimensions::<i32>(vec![array1d_1, array3d_2]).unwrap();

        assert_ne!(
            as_list_array(&aligned[0]).unwrap(),
            as_list_array(&array3d_1).unwrap()
        );
        assert_eq!(ndims(&aligned[0]), ndims(&array3d_1));
    }
}
334}