Skip to main content

datafusion_functions_nested/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! array function utils
19
20use std::sync::Arc;
21
22use arrow::datatypes::{DataType, Field, Fields};
23
24use arrow::array::{
25    Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar,
26};
27use arrow::buffer::OffsetBuffer;
28use datafusion_common::cast::{
29    as_fixed_size_list_array, as_large_list_array, as_list_array,
30};
31use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
32
33use datafusion_expr::ColumnarValue;
34use itertools::Itertools as _;
35
36pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
37    let data_type = args[0].data_type();
38    if !args.iter().all(|arg| {
39        arg.data_type().equals_datatype(data_type)
40            || arg.data_type().equals_datatype(&DataType::Null)
41    }) {
42        let types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
43        return plan_err!(
44            "{name} received incompatible types: {}",
45            types.iter().join(", ")
46        );
47    }
48
49    Ok(())
50}
51
52/// array function wrapper that differentiates between scalar (length 1) and array.
53pub(crate) fn make_scalar_function<F>(
54    inner: F,
55) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
56where
57    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
58{
59    move |args: &[ColumnarValue]| {
60        // first, identify if any of the arguments is an Array. If yes, store its `len`,
61        // as any scalar will need to be converted to an array of len `len`.
62        let len = args
63            .iter()
64            .fold(Option::<usize>::None, |acc, arg| match arg {
65                ColumnarValue::Scalar(_) => acc,
66                ColumnarValue::Array(a) => Some(a.len()),
67            });
68
69        let is_scalar = len.is_none();
70
71        let args = ColumnarValue::values_to_arrays(args)?;
72
73        let result = (inner)(&args);
74
75        if is_scalar {
76            // If all inputs are scalar, keeps output as scalar
77            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
78            result.map(ColumnarValue::Scalar)
79        } else {
80            result.map(ColumnarValue::Array)
81        }
82    }
83}
84
85pub(crate) fn align_array_dimensions<O: OffsetSizeTrait>(
86    args: Vec<ArrayRef>,
87) -> Result<Vec<ArrayRef>> {
88    let args_ndim = args
89        .iter()
90        .map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
91        .collect::<Vec<_>>();
92    let max_ndim = args_ndim.iter().max().unwrap_or(&0);
93
94    // Align the dimensions of the arrays
95    let aligned_args: Result<Vec<ArrayRef>> = args
96        .into_iter()
97        .zip(args_ndim.iter())
98        .map(|(array, ndim)| {
99            if ndim < max_ndim {
100                let mut aligned_array = Arc::clone(&array);
101                for _ in 0..(max_ndim - ndim) {
102                    let data_type = aligned_array.data_type().to_owned();
103                    let array_lengths = vec![1; aligned_array.len()];
104                    let offsets = OffsetBuffer::<O>::from_lengths(array_lengths);
105
106                    aligned_array = Arc::new(GenericListArray::<O>::try_new(
107                        Arc::new(Field::new_list_field(data_type, true)),
108                        offsets,
109                        aligned_array,
110                        None,
111                    )?)
112                }
113                Ok(aligned_array)
114            } else {
115                Ok(Arc::clone(&array))
116            }
117        })
118        .collect();
119
120    aligned_args
121}
122
123/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
124///
125/// # Arguments
126///
127/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
128///
129/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
130///
131/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
132///
133/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
134///
135/// # Returns
136///
137/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
138///
139/// # Example
140///
141/// ```text
142/// compare_element_to_list(
143///     [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
144///     [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
145///
146///     [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
147///     [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
148/// )
149/// ```
150pub(crate) fn compare_element_to_list(
151    list_array_row: &dyn Array,
152    element_array: &dyn Array,
153    row_index: usize,
154    eq: bool,
155) -> Result<BooleanArray> {
156    if list_array_row.data_type() != element_array.data_type() {
157        return exec_err!(
158            "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
159            list_array_row.data_type(),
160            element_array.data_type()
161        );
162    }
163
164    let element_array_row = element_array.slice(row_index, 1);
165
166    // Compute all positions in list_row_array (that is itself an
167    // array) that are equal to `from_array_row`
168    let res = match element_array_row.data_type() {
169        // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
170        DataType::List(_) => {
171            // compare each element of the from array
172            let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
173            let list_array_row_inner = as_list_array(list_array_row)?;
174
175            list_array_row_inner
176                .iter()
177                // compare element by element the current row of list_array
178                .map(|row| {
179                    row.map(|row| {
180                        if eq {
181                            row.eq(&element_array_row_inner)
182                        } else {
183                            row.ne(&element_array_row_inner)
184                        }
185                    })
186                })
187                .collect::<BooleanArray>()
188        }
189        DataType::LargeList(_) => {
190            // compare each element of the from array
191            let element_array_row_inner =
192                as_large_list_array(&element_array_row)?.value(0);
193            let list_array_row_inner = as_large_list_array(list_array_row)?;
194
195            list_array_row_inner
196                .iter()
197                // compare element by element the current row of list_array
198                .map(|row| {
199                    row.map(|row| {
200                        if eq {
201                            row.eq(&element_array_row_inner)
202                        } else {
203                            row.ne(&element_array_row_inner)
204                        }
205                    })
206                })
207                .collect::<BooleanArray>()
208        }
209        _ => {
210            let element_arr = Scalar::new(element_array_row);
211            // use not_distinct so we can compare NULL
212            if eq {
213                arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
214            } else {
215                arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
216            }
217        }
218    };
219
220    Ok(res)
221}
222
223/// Returns the length of each array dimension
224pub(crate) fn compute_array_dims(
225    arr: Option<ArrayRef>,
226) -> Result<Option<Vec<Option<u64>>>> {
227    let mut value = match arr {
228        Some(arr) => arr,
229        None => return Ok(None),
230    };
231    if value.is_empty() {
232        return Ok(None);
233    }
234    let mut res = vec![Some(value.len() as u64)];
235
236    loop {
237        match value.data_type() {
238            DataType::List(_) => {
239                value = as_list_array(&value)?.value(0);
240                res.push(Some(value.len() as u64));
241            }
242            DataType::LargeList(_) => {
243                value = as_large_list_array(&value)?.value(0);
244                res.push(Some(value.len() as u64));
245            }
246            DataType::FixedSizeList(..) => {
247                value = as_fixed_size_list_array(&value)?.value(0);
248                res.push(Some(value.len() as u64));
249            }
250            _ => return Ok(Some(res)),
251        }
252    }
253}
254
255pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
256    match data_type {
257        DataType::Map(field, _) => {
258            let field_data_type = field.data_type();
259            match field_data_type {
260                DataType::Struct(fields) => Ok(fields),
261                _ => {
262                    internal_err!("Expected a Struct type, got {}", field_data_type)
263                }
264            }
265        }
266        _ => internal_err!("Expected a Map type, got {data_type}"),
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::ListArray;
    use arrow::datatypes::Int64Type;
    use datafusion_common::utils::SingleRowListArrayBuilder;

    /// Wraps `inner` into a list array built by `SingleRowListArrayBuilder`,
    /// adding one list dimension.
    fn nest_once(inner: ArrayRef) -> ArrayRef {
        Arc::new(SingleRowListArrayBuilder::new(inner).build_list_array())
    }

    /// Checks that `aligned` has the same number of list dimensions as
    /// `reference` while its contents are not equal to `reference`.
    fn assert_same_depth_but_different(aligned: &ArrayRef, reference: &ArrayRef) {
        assert_ne!(
            as_list_array(aligned).unwrap(),
            as_list_array(reference).unwrap()
        );
        assert_eq!(
            datafusion_common::utils::list_ndims(aligned.data_type()),
            datafusion_common::utils::list_ndims(reference.data_type())
        );
    }

    /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt`
    #[test]
    fn test_align_array_dimensions() {
        let array1d_1: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                Some(vec![Some(4), Some(5)]),
            ]));
        let array1d_2: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
                Some(vec![Some(6), Some(7), Some(8)]),
            ]));

        let array2d_1 = nest_once(Arc::clone(&array1d_1));
        let array2d_2 = nest_once(Arc::clone(&array1d_2));

        // Aligning a 1-d list with a 2-d list promotes the former to 2-d.
        let res = align_array_dimensions::<i32>(vec![
            Arc::clone(&array1d_1),
            Arc::clone(&array2d_2),
        ])
        .unwrap();
        assert_same_depth_but_different(&res[0], &array2d_1);

        let array3d_1 = nest_once(array2d_1);
        let array3d_2 = nest_once(array2d_2);

        // Aligning a 1-d list with a 3-d list promotes the former to 3-d.
        let res = align_array_dimensions::<i32>(vec![array1d_1, array3d_2]).unwrap();
        assert_same_depth_but_different(&res[0], &array3d_1);
    }
}
325}