datafusion_functions/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::{DataType, DecimalType};
21use arrow::error::ArrowError;
22use datafusion_common::{DataFusionError, Result, ScalarValue};
23use datafusion_expr::ColumnarValue;
24use datafusion_expr::function::Hint;
25use std::sync::Arc;
26
/// Generates a `pub(crate)` function named `$FUNC` that maps the data type of a
/// string function's first argument to the function's return type.
///
/// The generated function returns:
/// * `$largeUtf8Type` for `LargeUtf8` and `LargeBinary` inputs,
/// * `$utf8Type` for `Utf8`, `Binary`, `Utf8View` and `BinaryView` inputs,
/// * `DataType::Null` for `Null` inputs,
/// * for `Dictionary` inputs, the result for the dictionary's value type; only
///   `LargeUtf8`, `LargeBinary`, `Utf8`, `Binary` and `Null` value types are
///   accepted there,
/// * an execution error mentioning the upper-cased function `name` for every
///   other input type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            match arg_type {
                // LargeBinary inputs are automatically coerced to Utf8
                DataType::LargeUtf8 | DataType::LargeBinary => Ok($largeUtf8Type),
                // Binary inputs are automatically coerced to Utf8.
                // Utf8View max offset size is u32::MAX, the same as UTF8,
                // so the view types share the small return type.
                DataType::Utf8
                | DataType::Binary
                | DataType::Utf8View
                | DataType::BinaryView => Ok($utf8Type),
                DataType::Null => Ok(DataType::Null),
                // Dictionaries are classified by the type of their values.
                DataType::Dictionary(_, value_type) => match value_type.as_ref() {
                    DataType::LargeUtf8 | DataType::LargeBinary => Ok($largeUtf8Type),
                    DataType::Utf8 | DataType::Binary => Ok($utf8Type),
                    DataType::Null => Ok(DataType::Null),
                    unsupported => datafusion_common::exec_err!(
                        "The {} function can only accept strings, but got {:?}.",
                        name.to_uppercase(),
                        unsupported
                    ),
                },
                unsupported => datafusion_common::exec_err!(
                    "The {} function can only accept strings, but got {:?}.",
                    name.to_uppercase(),
                    unsupported
                ),
            }
        }
    };
}
70
// `utf8_to_str_type`: returns `LargeUtf8` for large string/binary inputs and
// `Utf8` for the other accepted string/binary inputs.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

// `utf8_to_int_type`: returns `Int64` for large string/binary inputs and
// `Int32` for the other accepted string/binary inputs.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
76
77/// Creates a scalar function implementation for the given function.
78/// * `inner` - the function to be executed
79/// * `hints` - hints to be used when expanding scalars to arrays
80pub fn make_scalar_function<F>(
81    inner: F,
82    hints: Vec<Hint>,
83) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
84where
85    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
86{
87    move |args: &[ColumnarValue]| {
88        // first, identify if any of the arguments is an Array. If yes, store its `len`,
89        // as any scalar will need to be converted to an array of len `len`.
90        let len = args
91            .iter()
92            .fold(Option::<usize>::None, |acc, arg| match arg {
93                ColumnarValue::Scalar(_) => acc,
94                ColumnarValue::Array(a) => Some(a.len()),
95            });
96
97        let is_scalar = len.is_none();
98
99        let inferred_length = len.unwrap_or(1);
100        let args = args
101            .iter()
102            .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
103            .map(|(arg, hint)| {
104                // Decide on the length to expand this scalar to depending
105                // on the given hints.
106                let expansion_len = match hint {
107                    Hint::AcceptsSingular => 1,
108                    Hint::Pad => inferred_length,
109                };
110                arg.to_array(expansion_len)
111            })
112            .collect::<Result<Vec<_>>>()?;
113
114        let result = (inner)(&args);
115        if is_scalar {
116            // If all inputs are scalar, keeps output as scalar
117            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
118            result.map(ColumnarValue::Scalar)
119        } else {
120            result.map(ColumnarValue::Array)
121        }
122    }
123}
124
125/// Computes a binary math function for input arrays using a specified function.
126/// Generic types:
127/// - `L`: Left array primitive type
128/// - `R`: Right array primitive type
129/// - `O`: Output array primitive type
130/// - `F`: Functor computing `fun(l: L, r: R) -> Result<OutputType>`
131pub fn calculate_binary_math<L, R, O, F>(
132    left: &dyn Array,
133    right: &ColumnarValue,
134    fun: F,
135) -> Result<Arc<PrimitiveArray<O>>>
136where
137    L: ArrowPrimitiveType,
138    R: ArrowPrimitiveType,
139    O: ArrowPrimitiveType,
140    F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
141    R::Native: TryFrom<ScalarValue>,
142{
143    let left = left.as_primitive::<L>();
144    let right = right.cast_to(&R::DATA_TYPE, None)?;
145    let result = match right {
146        ColumnarValue::Scalar(scalar) => {
147            if scalar.is_null() {
148                // Null scalar is castable to any numeric, creating a non-null expression.
149                // Provide null array explicitly to make result null
150                PrimitiveArray::<O>::new_null(1)
151            } else {
152                let right = R::Native::try_from(scalar.clone()).map_err(|_| {
153                    DataFusionError::NotImplemented(format!(
154                        "Cannot convert scalar value {} to {}",
155                        &scalar,
156                        R::DATA_TYPE
157                    ))
158                })?;
159                left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))?
160            }
161        }
162        ColumnarValue::Array(right) => {
163            let right = right.as_primitive::<R>();
164            try_binary::<_, _, _, O>(left, right, &fun)?
165        }
166    };
167    Ok(Arc::new(result) as _)
168}
169
170/// Computes a binary math function for input arrays using a specified function
171/// and apply rescaling to given precision and scale.
172/// Generic types:
173/// - `L`: Left array decimal type
174/// - `R`: Right array primitive type
175/// - `O`: Output array decimal type
176/// - `F`: Functor computing `fun(l: L, r: R) -> Result<OutputType>`
177pub fn calculate_binary_decimal_math<L, R, O, F>(
178    left: &dyn Array,
179    right: &ColumnarValue,
180    fun: F,
181    precision: u8,
182    scale: i8,
183) -> Result<Arc<PrimitiveArray<O>>>
184where
185    L: DecimalType,
186    R: ArrowPrimitiveType,
187    O: DecimalType,
188    F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
189    R::Native: TryFrom<ScalarValue>,
190{
191    let result_array = calculate_binary_math::<L, R, O, F>(left, right, fun)?;
192    Ok(Arc::new(
193        result_array
194            .as_ref()
195            .clone()
196            .with_precision_and_scale(precision, scale)?,
197    ))
198}
199
200/// Converts Decimal128 components (value and scale) to an unscaled i128
201pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
202    if scale < 0 {
203        Err(ArrowError::ComputeError(
204            "Negative scale is not supported".into(),
205        ))
206    } else if scale == 0 {
207        Ok(value)
208    } else {
209        match i128::from(10).checked_pow(scale as u32) {
210            Some(divisor) => Ok(value / divisor),
211            None => Err(ArrowError::ComputeError(format!(
212                "Cannot get a power of {scale}"
213            ))),
214        }
215    }
216}
217
218pub fn decimal32_to_i32(value: i32, scale: i8) -> Result<i32, ArrowError> {
219    if scale < 0 {
220        Err(ArrowError::ComputeError(
221            "Negative scale is not supported".into(),
222        ))
223    } else if scale == 0 {
224        Ok(value)
225    } else {
226        match 10_i32.checked_pow(scale as u32) {
227            Some(divisor) => Ok(value / divisor),
228            None => Err(ArrowError::ComputeError(format!(
229                "Cannot get a power of {scale}"
230            ))),
231        }
232    }
233}
234
235pub fn decimal64_to_i64(value: i64, scale: i8) -> Result<i64, ArrowError> {
236    if scale < 0 {
237        Err(ArrowError::ComputeError(
238            "Negative scale is not supported".into(),
239        ))
240    } else if scale == 0 {
241        Ok(value)
242    } else {
243        match i64::from(10).checked_pow(scale as u32) {
244            Some(divisor) => Ok(value / divisor),
245            None => Err(ArrowError::ComputeError(format!(
246                "Cannot get a power of {scale}"
247            ))),
248        }
249    }
250}
251
252#[cfg(test)]
253pub mod test {
    /// Invokes a scalar function with the given arguments and checks either
    /// the value in row 0 of its output or the error it produces.
    ///
    /// $FUNC ScalarUDFImpl to test
    /// $ARGS arguments (vec) to pass to function; this expression is expanded
    ///   several times inside the macro
    /// $EXPECTED a `Result<Option<$EXPECTED_TYPE>>`: `Ok(Some(v))` expects value
    ///   `v`, `Ok(None)` expects a null, `Err(e)` expects an error
    /// $EXPECTED_TYPE is the expected value type
    /// $EXPECTED_DATA_TYPE is the expected result type
    /// $ARRAY_TYPE is the column type after function applied
    /// $CONFIG_OPTIONS config options to pass to function; the second macro arm
    ///   supplies `Arc::new(ConfigOptions::default())` when it is omitted
    macro_rules! test_function {
    ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
        let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
        let func = $FUNC;

        let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
        // Number of rows: the length of the last array argument, or 1 when
        // every argument is a scalar.
        let cardinality = $ARGS
            .iter()
            .fold(Option::<usize>::None, |acc, arg| match arg {
                ColumnarValue::Scalar(_) => acc,
                ColumnarValue::Array(a) => Some(a.len()),
            })
            .unwrap_or(1);

            // Scalar arguments are passed along by value; array arguments become `None`.
            let scalar_arguments = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                ColumnarValue::Array(_) => None,
            }).collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();

            // An argument field is marked nullable only if the argument actually
            // contains a null.
            let nullables = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => scalar.is_null(),
                ColumnarValue::Array(a) => a.null_count() > 0,
            }).collect::<Vec<_>>();

            // Fields used to ask the function for its return field.
            let field_array = data_array.into_iter().zip(nullables).enumerate()
                .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable))
            .map(std::sync::Arc::new)
            .collect::<Vec<_>>();

        let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
            arg_fields: &field_array,
            scalar_arguments: &scalar_arguments_refs,
        });
            // Fields used for the invocation itself; these are always nullable.
            let arg_fields = $ARGS.iter()
            .enumerate()
                .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into())
            .collect::<Vec<_>>();

        match expected {
            Ok(expected) => {
                // The return field must resolve and have the expected data type.
                assert_eq!(return_field.is_ok(), true);
                let return_field = return_field.unwrap();
                let return_type = return_field.data_type();
                assert_eq!(return_type, &$EXPECTED_DATA_TYPE);

                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{
                    args: $ARGS,
                    arg_fields,
                    number_rows: cardinality,
                    return_field,
                        config_options: $CONFIG_OPTIONS
                });
                    assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());

                    let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
                    let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                // value is correct (only row 0 is inspected)
                match expected {
                    Some(v) => assert_eq!(result.value(0), v),
                    None => assert!(result.is_null(0)),
                };
            }
            Err(expected_error) => {
                if let Ok(return_field) = return_field {
                    // invoke is expected error - cannot use .expect_err() due to Debug not being implemented
                    match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                        args: $ARGS,
                        arg_fields,
                        number_rows: cardinality,
                        return_field,
                        config_options: $CONFIG_OPTIONS,
                    }) {
                        Ok(_) => assert!(false, "expected error"),
                        Err(error) => {
                            // The expected message must start with the actual
                            // message (both with backtraces stripped).
                            assert!(expected_error
                                .strip_backtrace()
                                .starts_with(&error.strip_backtrace()));
                        }
                    }
                } else if let Err(error) = return_field {
                    // Resolving the return field already failed; compare that
                    // error with the expected one instead of invoking.
                    datafusion_common::assert_contains!(
                        expected_error.strip_backtrace(),
                        error.strip_backtrace()
                    );
                }
            }
        };
    };

        // Same as the arm above, using default config options.
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            test_function!(
                $FUNC,
                $ARGS,
                $EXPECTED,
                $EXPECTED_TYPE,
                $EXPECTED_DATA_TYPE,
                $ARRAY_TYPE,
                std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
            )
        };
    }
365
366    use arrow::datatypes::DataType;
367    use itertools::Either;
368    pub(crate) use test_function;
369
370    use super::*;
371
372    #[test]
373    fn string_to_int_type() {
374        let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
375        assert_eq!(v, DataType::Int32);
376
377        let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
378        assert_eq!(v, DataType::Int32);
379
380        let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
381        assert_eq!(v, DataType::Int64);
382    }
383
384    #[test]
385    fn test_decimal128_to_i128() {
386        let cases = [
387            (123, 0, Some(123)),
388            (1230, 1, Some(123)),
389            (123000, 3, Some(123)),
390            (1, 0, Some(1)),
391            (123, -3, None),
392            (123, i8::MAX, None),
393            (i128::MAX, 0, Some(i128::MAX)),
394            (i128::MAX, 3, Some(i128::MAX / 1000)),
395        ];
396
397        for (value, scale, expected) in cases {
398            match decimal128_to_i128(value, scale) {
399                Ok(actual) => {
400                    assert_eq!(
401                        actual,
402                        expected.expect("Got value but expected none"),
403                        "{value} and {scale} vs {expected:?}"
404                    );
405                }
406                Err(_) => assert!(expected.is_none()),
407            }
408        }
409    }
410
411    #[test]
412    fn test_decimal32_to_i32() {
413        let cases: [(i32, i8, Either<i32, String>); _] = [
414            (123, 0, Either::Left(123)),
415            (1230, 1, Either::Left(123)),
416            (123000, 3, Either::Left(123)),
417            (1234567, 2, Either::Left(12345)),
418            (-1234567, 2, Either::Left(-12345)),
419            (1, 0, Either::Left(1)),
420            (
421                123,
422                -3,
423                Either::Right("Negative scale is not supported".into()),
424            ),
425            (
426                123,
427                i8::MAX,
428                Either::Right("Cannot get a power of 127".into()),
429            ),
430            (999999999, 0, Either::Left(999999999)),
431            (999999999, 3, Either::Left(999999)),
432        ];
433
434        for (value, scale, expected) in cases {
435            match decimal32_to_i32(value, scale) {
436                Ok(actual) => {
437                    let expected_value =
438                        expected.left().expect("Got value but expected none");
439                    assert_eq!(
440                        actual, expected_value,
441                        "{value} and {scale} vs {expected_value:?}"
442                    );
443                }
444                Err(ArrowError::ComputeError(msg)) => {
445                    assert_eq!(
446                        msg,
447                        expected.right().expect("Got error but expected value")
448                    );
449                }
450                Err(_) => {
451                    assert!(expected.is_right())
452                }
453            }
454        }
455    }
456
457    #[test]
458    fn test_decimal64_to_i64() {
459        let cases: [(i64, i8, Either<i64, String>); _] = [
460            (123, 0, Either::Left(123)),
461            (1234567890, 2, Either::Left(12345678)),
462            (-1234567890, 2, Either::Left(-12345678)),
463            (
464                123,
465                -3,
466                Either::Right("Negative scale is not supported".into()),
467            ),
468            (
469                123,
470                i8::MAX,
471                Either::Right("Cannot get a power of 127".into()),
472            ),
473            (
474                999999999999999999i64,
475                0,
476                Either::Left(999999999999999999i64),
477            ),
478            (
479                999999999999999999i64,
480                3,
481                Either::Left(999999999999999999i64 / 1000),
482            ),
483            (
484                -999999999999999999i64,
485                3,
486                Either::Left(-999999999999999999i64 / 1000),
487            ),
488        ];
489
490        for (value, scale, expected) in cases {
491            match decimal64_to_i64(value, scale) {
492                Ok(actual) => {
493                    let expected_value =
494                        expected.left().expect("Got value but expected none");
495                    assert_eq!(
496                        actual, expected_value,
497                        "{value} and {scale} vs {expected_value:?}"
498                    );
499                }
500                Err(ArrowError::ComputeError(msg)) => {
501                    assert_eq!(
502                        msg,
503                        expected.right().expect("Got error but expected value")
504                    );
505                }
506                Err(_) => {
507                    assert!(expected.is_right())
508                }
509            }
510        }
511    }
512}