// datafusion_functions/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::DataType;
21use arrow::error::ArrowError;
22use datafusion_common::{DataFusionError, Result, ScalarValue};
23use datafusion_expr::function::Hint;
24use datafusion_expr::ColumnarValue;
25use std::sync::Arc;
26
/// Defines `pub(crate) fn $FUNC(arg_type, name) -> Result<DataType>`, which picks
/// the return type of a string function from the type of its first argument.
///
/// Mapping, applied to the argument type itself or to a `Dictionary`'s value type:
/// * `LargeUtf8` / `LargeBinary` -> `$largeUtf8Type`
/// * `Utf8` / `Binary` -> `$utf8Type`
/// * `Null` -> `DataType::Null`
///
/// `Utf8View` / `BinaryView` also map to `$utf8Type`, but only when they are the
/// argument type itself; as a dictionary value type they are rejected.
///
/// Any other type yields an execution error that names the function (upper-cased
/// `name`) and the offending type (the value type, for dictionaries).
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            // A dictionary is judged by its value type. Remember whether one was
            // unwrapped, because view types are only accepted un-wrapped.
            let (value_type, from_dictionary) = match arg_type {
                DataType::Dictionary(_, inner) => (&**inner, true),
                direct => (direct, false),
            };

            match value_type {
                // Binary inputs share the mapping of their string counterparts
                // (they are coerced to strings).
                DataType::LargeUtf8 | DataType::LargeBinary => Ok($largeUtf8Type),
                DataType::Utf8 | DataType::Binary => Ok($utf8Type),
                // View types have the same u32::MAX max offset as Utf8.
                DataType::Utf8View | DataType::BinaryView if !from_dictionary => {
                    Ok($utf8Type)
                }
                DataType::Null => Ok(DataType::Null),
                unsupported => datafusion_common::exec_err!(
                    "The {} function can only accept strings, but got {:?}.",
                    name.to_uppercase(),
                    unsupported
                ),
            }
        }
    };
}
70
71// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size.
72get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
73
74// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size.
75get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
76
77/// Creates a scalar function implementation for the given function.
78/// * `inner` - the function to be executed
79/// * `hints` - hints to be used when expanding scalars to arrays
80pub fn make_scalar_function<F>(
81    inner: F,
82    hints: Vec<Hint>,
83) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
84where
85    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
86{
87    move |args: &[ColumnarValue]| {
88        // first, identify if any of the arguments is an Array. If yes, store its `len`,
89        // as any scalar will need to be converted to an array of len `len`.
90        let len = args
91            .iter()
92            .fold(Option::<usize>::None, |acc, arg| match arg {
93                ColumnarValue::Scalar(_) => acc,
94                ColumnarValue::Array(a) => Some(a.len()),
95            });
96
97        let is_scalar = len.is_none();
98
99        let inferred_length = len.unwrap_or(1);
100        let args = args
101            .iter()
102            .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
103            .map(|(arg, hint)| {
104                // Decide on the length to expand this scalar to depending
105                // on the given hints.
106                let expansion_len = match hint {
107                    Hint::AcceptsSingular => 1,
108                    Hint::Pad => inferred_length,
109                };
110                arg.to_array(expansion_len)
111            })
112            .collect::<Result<Vec<_>>>()?;
113
114        let result = (inner)(&args);
115        if is_scalar {
116            // If all inputs are scalar, keeps output as scalar
117            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
118            result.map(ColumnarValue::Scalar)
119        } else {
120            result.map(ColumnarValue::Array)
121        }
122    }
123}
124
125/// Computes a binary math function for input arrays using a specified function.
126/// Generic types:
127/// - `L`: Left array primitive type
128/// - `R`: Right array primitive type
129/// - `O`: Output array primitive type
130/// - `F`: Functor computing `fun(l: L, r: R) -> Result<OutputType>`
131pub fn calculate_binary_math<L, R, O, F>(
132    left: &dyn Array,
133    right: &ColumnarValue,
134    fun: F,
135) -> Result<Arc<PrimitiveArray<O>>>
136where
137    R: ArrowPrimitiveType,
138    L: ArrowPrimitiveType,
139    O: ArrowPrimitiveType,
140    F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
141    R::Native: TryFrom<ScalarValue>,
142{
143    let left = left.as_primitive::<L>();
144    let right = right.cast_to(&R::DATA_TYPE, None)?;
145    let result = match right {
146        ColumnarValue::Scalar(scalar) => {
147            let right = R::Native::try_from(scalar.clone()).map_err(|_| {
148                DataFusionError::NotImplemented(format!(
149                    "Cannot convert scalar value {} to {}",
150                    &scalar,
151                    R::DATA_TYPE
152                ))
153            })?;
154            left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))?
155        }
156        ColumnarValue::Array(right) => {
157            let right = right.as_primitive::<R>();
158            try_binary::<_, _, _, O>(left, right, &fun)?
159        }
160    };
161    Ok(Arc::new(result) as _)
162}
163
164/// Converts Decimal128 components (value and scale) to an unscaled i128
165pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
166    if scale < 0 {
167        Err(ArrowError::ComputeError(
168            "Negative scale is not supported".into(),
169        ))
170    } else if scale == 0 {
171        Ok(value)
172    } else {
173        match i128::from(10).checked_pow(scale as u32) {
174            Some(divisor) => Ok(value / divisor),
175            None => Err(ArrowError::ComputeError(format!(
176                "Cannot get a power of {scale}"
177            ))),
178        }
179    }
180}
181
#[cfg(test)]
pub mod test {
    /// Invokes a `ScalarUDFImpl` end to end and checks both the planned return type
    /// and the value in the first output row.
    ///
    /// * `$FUNC` - the `ScalarUDFImpl` under test
    /// * `$ARGS` - the `Vec<ColumnarValue>` arguments. The expression is expanded
    ///   several times, so it should be free of side effects (e.g. a `vec![...]`).
    /// * `$EXPECTED` - a `Result<Option<$EXPECTED_TYPE>>`:
    ///   - `Ok(Some(v))` expects row 0 to equal `v`.
    ///   - `Ok(None)` expects row 0 to be null.
    ///   - `Err(e)` expects failure. If `return_field_from_args` fails, `e`'s message
    ///     must contain that error's message. Otherwise invocation must fail, and
    ///     `e`'s message must start with the invocation error's message.
    /// * `$EXPECTED_TYPE` - the Rust type of the expected value
    /// * `$EXPECTED_DATA_TYPE` - the expected Arrow return type
    /// * `$ARRAY_TYPE` - the concrete array type the result is downcast to
    /// * `$CONFIG_OPTIONS` - config options passed to the function; the six-argument
    ///   form uses `ConfigOptions::default()`
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;

            let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            // Row count: length of the last array argument, or 1 if all are scalars.
            let cardinality = $ARGS
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    ColumnarValue::Scalar(_) => acc,
                    ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            let scalar_arguments = $ARGS
                .iter()
                .map(|arg| match arg {
                    ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                    ColumnarValue::Array(_) => None,
                })
                .collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments
                .iter()
                .map(|arg| arg.as_ref())
                .collect::<Vec<_>>();

            // For return-type planning, an argument is nullable only if it
            // actually contains a null.
            let nullables = $ARGS
                .iter()
                .map(|arg| match arg {
                    ColumnarValue::Scalar(scalar) => scalar.is_null(),
                    ColumnarValue::Array(a) => a.null_count() > 0,
                })
                .collect::<Vec<_>>();

            let field_array = data_array
                .into_iter()
                .zip(nullables)
                .enumerate()
                .map(|(idx, (data_type, nullable))| {
                    arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)
                })
                .map(std::sync::Arc::new)
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &field_array,
                scalar_arguments: &scalar_arguments_refs,
            });
            // The fields handed to `invoke_with_args` are always declared nullable.
            let arg_fields = $ARGS
                .iter()
                .enumerate()
                .map(|(idx, arg)| {
                    arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()
                })
                .collect::<Vec<_>>();

            match expected {
                Ok(expected) => {
                    let return_field = match return_field {
                        Ok(field) => field,
                        Err(e) => panic!("return_field_from_args returned an error: {e}"),
                    };
                    let return_type = return_field.data_type();
                    assert_eq!(return_type, &$EXPECTED_DATA_TYPE);

                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                        args: $ARGS,
                        arg_fields,
                        number_rows: cardinality,
                        return_field,
                        config_options: $CONFIG_OPTIONS,
                    });
                    let result = match result {
                        Ok(value) => value,
                        Err(e) => panic!("function returned an error: {e}"),
                    };

                    let result = result.to_array(cardinality).expect("Failed to convert to array");
                    let result = result
                        .as_any()
                        .downcast_ref::<$ARRAY_TYPE>()
                        .expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only the first row is compared against the expectation.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => match return_field {
                    // Planning succeeded, so invocation itself must fail; the result
                    // is matched explicitly instead of using `.expect_err()`.
                    Ok(return_field) => {
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                            args: $ARGS,
                            arg_fields,
                            number_rows: cardinality,
                            return_field,
                            config_options: $CONFIG_OPTIONS,
                        }) {
                            Ok(_) => panic!("expected error"),
                            Err(error) => {
                                assert!(expected_error
                                    .strip_backtrace()
                                    .starts_with(&error.strip_backtrace()));
                            }
                        }
                    }
                    Err(error) => {
                        datafusion_common::assert_contains!(
                            expected_error.strip_backtrace(),
                            error.strip_backtrace()
                        );
                    }
                },
            };
        };

        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            test_function!(
                $FUNC,
                $ARGS,
                $EXPECTED,
                $EXPECTED_TYPE,
                $EXPECTED_DATA_TYPE,
                $ARRAY_TYPE,
                std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
            )
        };
    }

    use arrow::datatypes::DataType;
    // Re-exported so tests in other modules can import the macro; nothing in this
    // module invokes it, which would otherwise trigger `unused_imports`.
    #[allow(unused_imports)]
    pub(crate) use test_function;

    use super::*;

    #[test]
    fn string_to_int_type() {
        let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::Int64);
    }

    #[test]
    fn string_to_str_type() {
        let dict =
            |value: DataType| DataType::Dictionary(Box::new(DataType::Int32), Box::new(value));

        let cases = [
            (DataType::Utf8, DataType::Utf8),
            (DataType::Binary, DataType::Utf8),
            (DataType::Utf8View, DataType::Utf8),
            (DataType::BinaryView, DataType::Utf8),
            (DataType::LargeUtf8, DataType::LargeUtf8),
            (DataType::LargeBinary, DataType::LargeUtf8),
            (DataType::Null, DataType::Null),
            (dict(DataType::Utf8), DataType::Utf8),
            (dict(DataType::LargeUtf8), DataType::LargeUtf8),
            (dict(DataType::Null), DataType::Null),
        ];
        for (input, expected) in cases {
            let actual = utf8_to_str_type(&input, "test").unwrap();
            assert_eq!(actual, expected, "input: {input:?}");
        }

        let err = utf8_to_str_type(&DataType::Int32, "test").unwrap_err();
        datafusion_common::assert_contains!(
            err.to_string(),
            "The TEST function can only accept strings, but got Int32."
        );

        let err = utf8_to_str_type(&dict(DataType::Int64), "test").unwrap_err();
        datafusion_common::assert_contains!(
            err.to_string(),
            "The TEST function can only accept strings, but got Int64."
        );
    }

    #[test]
    fn test_decimal128_to_i128() {
        let cases = [
            (123, 0, Some(123)),
            (1230, 1, Some(123)),
            (123000, 3, Some(123)),
            (1, 0, Some(1)),
            (123, -3, None),
            (123, i8::MAX, None),
            (i128::MAX, 0, Some(i128::MAX)),
            (i128::MAX, 3, Some(i128::MAX / 1000)),
        ];

        for (value, scale, expected) in cases {
            match decimal128_to_i128(value, scale) {
                Ok(actual) => {
                    assert_eq!(
                        actual,
                        expected.expect("Got value but expected none"),
                        "{value} and {scale} vs {expected:?}"
                    );
                }
                Err(_) => assert!(expected.is_none()),
            }
        }
    }
}