datafusion_spark/function/
functions_nested_utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{Array, ArrayRef};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::ColumnarValue;
21
22/// array function wrapper that differentiates between scalar (length 1) and array.
23pub(crate) fn make_scalar_function<F>(
24    inner: F,
25) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
26where
27    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
28{
29    move |args: &[ColumnarValue]| {
30        // first, identify if any of the arguments is an Array. If yes, store its `len`,
31        // as any scalar will need to be converted to an array of len `len`.
32        let len = args
33            .iter()
34            .fold(Option::<usize>::None, |acc, arg| match arg {
35                ColumnarValue::Scalar(_) => acc,
36                ColumnarValue::Array(a) => Some(a.len()),
37            });
38
39        let is_scalar = len.is_none();
40
41        let args = ColumnarValue::values_to_arrays(args)?;
42
43        let result = (inner)(&args);
44
45        if is_scalar {
46            // If all inputs are scalar, keeps output as scalar
47            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
48            result.map(ColumnarValue::Scalar)
49        } else {
50            result.map(ColumnarValue::Array)
51        }
52    }
53}