datafusion_spark/function/functions_nested_utils.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, ArrayRef};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::ColumnarValue;
21
22/// array function wrapper that differentiates between scalar (length 1) and array.
23pub(crate) fn make_scalar_function<F>(
24 inner: F,
25) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
26where
27 F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
28{
29 move |args: &[ColumnarValue]| {
30 // first, identify if any of the arguments is an Array. If yes, store its `len`,
31 // as any scalar will need to be converted to an array of len `len`.
32 let len = args
33 .iter()
34 .fold(Option::<usize>::None, |acc, arg| match arg {
35 ColumnarValue::Scalar(_) => acc,
36 ColumnarValue::Array(a) => Some(a.len()),
37 });
38
39 let is_scalar = len.is_none();
40
41 let args = ColumnarValue::values_to_arrays(args)?;
42
43 let result = (inner)(&args);
44
45 if is_scalar {
46 // If all inputs are scalar, keeps output as scalar
47 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
48 result.map(ColumnarValue::Scalar)
49 } else {
50 result.map(ColumnarValue::Array)
51 }
52 }
53}