datafusion_spark/function/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
#[cfg(test)]
pub mod test {
    /// Invokes a scalar UDF once and checks either its result or its error.
    ///
    /// Arguments:
    /// * `$FUNC` - the `ScalarUDFImpl` under test
    /// * `$ARGS` - the arguments (a `Vec<ColumnarValue>`) passed to the function;
    ///   the expression is evaluated exactly once
    /// * `$EXPECTED` - a `datafusion_common::Result<Option<$EXPECTED_TYPE>>`:
    ///   `Ok(Some(v))` expects row 0 of the output to equal `v`, `Ok(None)`
    ///   expects row 0 to be null, and `Err(e)` expects the function to fail
    /// * `$EXPECTED_TYPE` - the Rust type of the expected value
    /// * `$EXPECTED_DATA_TYPE` - the data type expected both from
    ///   `return_field_from_args` and from the produced array
    /// * `$ARRAY_TYPE` - the concrete array type the output is downcast to
    /// * `$CONFIG_OPTIONS` - config options passed to the function; when
    ///   omitted, `Arc::new(ConfigOptions::default())` is used
    ///
    /// Only row 0 of the output is compared against the expected value.
    ///
    /// In the error case, a failure of `return_field_from_args` is compared
    /// to the expected error with `assert_contains!`; otherwise the function
    /// is invoked and the expected message (backtrace stripped) must start
    /// with the actual error message.
    macro_rules! test_scalar_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
            let expected: datafusion_common::Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;
            // Bind the arguments once so the `$ARGS` expression is not
            // re-evaluated for every use below.
            let args: Vec<datafusion_expr::ColumnarValue> = $ARGS;

            // One field per argument; a field is nullable when the argument
            // is a null scalar or an array containing at least one null.
            let arg_fields: Vec<arrow::datatypes::FieldRef> = args
                .iter()
                .enumerate()
                .map(|(idx, arg)| {
                    let nullable = match arg {
                        datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(),
                        datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0,
                    };

                    std::sync::Arc::new(arrow::datatypes::Field::new(
                        format!("arg_{idx}"),
                        arg.data_type(),
                        nullable,
                    ))
                })
                .collect::<Vec<_>>();

            // Row count: the length of the last array argument, or 1 when
            // every argument is a scalar.
            let cardinality = args
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    datafusion_expr::ColumnarValue::Scalar(_) => acc,
                    datafusion_expr::ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            // Scalar arguments are forwarded to `return_field_from_args`;
            // array arguments are represented as `None`.
            let scalar_arguments = args
                .iter()
                .map(|arg| match arg {
                    datafusion_expr::ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                    datafusion_expr::ColumnarValue::Array(_) => None,
                })
                .collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments
                .iter()
                .map(|arg| arg.as_ref())
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &scalar_arguments_refs,
            });

            match expected {
                Ok(expected) => {
                    let return_field = match return_field {
                        Ok(field) => field,
                        Err(err) => {
                            panic!("Expected return_field to be Ok but got Err: {err}")
                        }
                    };
                    assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE);

                    let col_value = match func.invoke_with_args(
                        datafusion_expr::ScalarFunctionArgs {
                            args,
                            number_rows: cardinality,
                            return_field,
                            arg_fields,
                            config_options: $CONFIG_OPTIONS,
                        },
                    ) {
                        Ok(col_value) => col_value,
                        Err(err) => panic!("function returned an error: {err}"),
                    };

                    let array = match col_value.to_array(cardinality) {
                        Ok(array) => array,
                        Err(err) => panic!("Failed to convert to array: {err}"),
                    };

                    let result = array
                        .as_any()
                        .downcast_ref::<$ARRAY_TYPE>()
                        .expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // value is correct
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => match return_field {
                    Err(error) => {
                        datafusion_common::assert_contains!(
                            expected_error.strip_backtrace(),
                            error.strip_backtrace()
                        );
                    }
                    Ok(return_field) => {
                        // invoke is expected error - cannot use .expect_err() due to Debug not being implemented
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                            args,
                            number_rows: cardinality,
                            return_field,
                            arg_fields,
                            config_options: $CONFIG_OPTIONS,
                        }) {
                            Ok(_) => panic!("expected error"),
                            Err(error) => {
                                let expected_msg = expected_error.strip_backtrace();
                                let actual_msg = error.strip_backtrace();
                                assert!(
                                    expected_msg.starts_with(&actual_msg),
                                    "expected error {expected_msg:?} does not start with actual error {actual_msg:?}"
                                );
                            }
                        }
                    }
                },
            };
        };

        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            test_scalar_function!(
                $FUNC,
                $ARGS,
                $EXPECTED,
                $EXPECTED_TYPE,
                $EXPECTED_DATA_TYPE,
                $ARRAY_TYPE,
                std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
            )
        };
    }

    pub(crate) use test_scalar_function;
}