datafusion_functions/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::ArrayRef;
19use arrow::datatypes::DataType;
20
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::function::Hint;
23use datafusion_expr::ColumnarValue;
24
/// Creates a function to identify the optimal return type of a string function given
/// the type of its first argument.
///
/// The generated function `$FUNC(arg_type, name)` resolves the return type as follows:
///
/// * `LargeUtf8` or `LargeBinary` input yields `$largeUtf8Type`.
/// * `Utf8`, `Binary`, `Utf8View` or `BinaryView` input yields `$utf8Type`.
/// * `Null` input yields `Null`.
/// * `Dictionary` input is resolved from its value type using the rules above,
///   except that view value types are not accepted inside a dictionary.
///
/// Any other input type produces an execution error that mentions `name` in upper case
/// together with the offending type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            // Dictionary-encoded inputs are classified by their value type. Remember
            // that we unwrapped one, because view types are only accepted at top level.
            let (inner, in_dictionary) = match arg_type {
                DataType::Dictionary(_, value_type) => (&**value_type, true),
                other => (other, false),
            };

            match inner {
                // LargeBinary inputs are automatically coerced to Utf8
                DataType::LargeUtf8 | DataType::LargeBinary => Ok($largeUtf8Type),
                // Binary inputs are automatically coerced to Utf8
                DataType::Utf8 | DataType::Binary => Ok($utf8Type),
                // Utf8View max offset size is u32::MAX, the same as UTF8
                DataType::Utf8View | DataType::BinaryView if !in_dictionary => {
                    Ok($utf8Type)
                }
                DataType::Null => Ok(DataType::Null),
                unsupported => {
                    return datafusion_common::exec_err!(
                        "The {} function can only accept strings, but got {:?}.",
                        name.to_uppercase(),
                        unsupported
                    );
                }
            }
        }
    };
}
68
69// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size.
70get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
71
72// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size.
73get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
74
75/// Creates a scalar function implementation for the given function.
76/// * `inner` - the function to be executed
77/// * `hints` - hints to be used when expanding scalars to arrays
78pub fn make_scalar_function<F>(
79    inner: F,
80    hints: Vec<Hint>,
81) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
82where
83    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
84{
85    move |args: &[ColumnarValue]| {
86        // first, identify if any of the arguments is an Array. If yes, store its `len`,
87        // as any scalar will need to be converted to an array of len `len`.
88        let len = args
89            .iter()
90            .fold(Option::<usize>::None, |acc, arg| match arg {
91                ColumnarValue::Scalar(_) => acc,
92                ColumnarValue::Array(a) => Some(a.len()),
93            });
94
95        let is_scalar = len.is_none();
96
97        let inferred_length = len.unwrap_or(1);
98        let args = args
99            .iter()
100            .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
101            .map(|(arg, hint)| {
102                // Decide on the length to expand this scalar to depending
103                // on the given hints.
104                let expansion_len = match hint {
105                    Hint::AcceptsSingular => 1,
106                    Hint::Pad => inferred_length,
107                };
108                arg.to_array(expansion_len)
109            })
110            .collect::<Result<Vec<_>>>()?;
111
112        let result = (inner)(&args);
113        if is_scalar {
114            // If all inputs are scalar, keeps output as scalar
115            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
116            result.map(ColumnarValue::Scalar)
117        } else {
118            result.map(ColumnarValue::Array)
119        }
120    }
121}
122
#[cfg(test)]
pub mod test {
    /// Runs a scalar UDF against a set of arguments and checks the outcome.
    ///
    /// * `$FUNC` - the `ScalarUDFImpl` to test
    /// * `$ARGS` - the arguments (a `Vec<ColumnarValue>`) to pass to the function;
    ///   the expression is evaluated exactly once
    /// * `$EXPECTED` - a `Result<Option<$EXPECTED_TYPE>>`: `Ok(Some(v))` for an expected
    ///   value, `Ok(None)` for an expected null, `Err(e)` for an expected error
    /// * `$EXPECTED_TYPE` - the Rust type of the expected value
    /// * `$EXPECTED_DATA_TYPE` - the expected arrow data type of the result
    /// * `$ARRAY_TYPE` - the concrete array type the result is downcast to
    ///
    /// Only the first row of the result is compared against the expected value.
    /// An expected error may come either from `return_field_from_args` or from
    /// `invoke_with_args`.
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;
            // Bind the arguments once so the caller's expression is not re-evaluated
            // for every use below.
            let args: Vec<ColumnarValue> = $ARGS;

            let data_array = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            // Number of rows: the length of the last array argument, or 1 if all
            // arguments are scalars.
            let cardinality = args
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    ColumnarValue::Scalar(_) => acc,
                    ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            let scalar_arguments = args.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                ColumnarValue::Array(_) => None,
            }).collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();

            // An argument is declared nullable only if it actually contains a null.
            let nullables = args.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => scalar.is_null(),
                ColumnarValue::Array(a) => a.null_count() > 0,
            }).collect::<Vec<_>>();

            let field_array = data_array.into_iter().zip(nullables).enumerate()
                .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable))
                .map(std::sync::Arc::new)
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &field_array,
                scalar_arguments: &scalar_arguments_refs,
            });
            let arg_fields = args.iter()
                .enumerate()
                .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into())
                .collect::<Vec<_>>();

            match expected {
                Ok(expected) => {
                    let return_field = match return_field {
                        Ok(field) => field,
                        Err(error) => panic!("return_field_from_args returned an error: {error}"),
                    };
                    let return_type = return_field.data_type();
                    assert_eq!(return_type, &$EXPECTED_DATA_TYPE);

                    let result = match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args, arg_fields, number_rows: cardinality, return_field}) {
                        Ok(value) => value,
                        Err(error) => panic!("function returned an error: {error}"),
                    };

                    let result = result.to_array(cardinality).expect("Failed to convert to array");
                    let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // value is correct
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => {
                    match return_field {
                        // The error surfaced while resolving the return field.
                        Err(error) => {
                            datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace());
                        }
                        // Otherwise the invocation itself is expected to fail.
                        Ok(return_field) => {
                            match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args, arg_fields, number_rows: cardinality, return_field}) {
                                Ok(_) => panic!("expected error"),
                                Err(error) => {
                                    assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));
                                }
                            }
                        }
                    }
                }
            };
        };
    }

    use arrow::datatypes::DataType;
    // The macro is re-exported for sibling modules; it may be unused within this one.
    #[allow(unused_imports)]
    pub(crate) use test_function;

    use super::*;

    /// `utf8_to_int_type` picks the integer width matching the input's offset width
    /// and rejects non-string inputs.
    #[test]
    fn string_to_int_type() {
        let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        let v = utf8_to_int_type(&DataType::LargeBinary, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        let v = utf8_to_int_type(&DataType::Null, "test").unwrap();
        assert_eq!(v, DataType::Null);

        // Dictionary inputs are resolved from their value type.
        let dictionary = DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::LargeUtf8),
        );
        let v = utf8_to_int_type(&dictionary, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        // Non-string inputs are rejected.
        assert!(utf8_to_int_type(&DataType::Int32, "test").is_err());
    }
}