// datafusion_spark/function/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#[cfg(test)]
pub mod test {
    /// Invokes a scalar UDF on the given arguments and checks the result type
    /// and the value in row 0, or the error that is produced.
    ///
    /// Arguments:
    /// * `$FUNC` - the `ScalarUDFImpl` under test.
    /// * `$ARGS` - a `Vec<datafusion_expr::ColumnarValue>` passed to the function.
    ///   It is evaluated exactly once.
    /// * `$EXPECTED` - a `datafusion_common::Result<Option<$EXPECTED_TYPE>>`:
    ///   - `Ok(Some(v))`: row 0 of the result must equal `v`;
    ///   - `Ok(None)`: row 0 of the result must be null;
    ///   - `Err(e)`: either `return_field_from_args` must fail with an error whose
    ///     message contains `e`'s message, or (if that succeeds)
    ///     `invoke_with_args` must fail with an error whose message starts with
    ///     `e`'s message. Backtraces are stripped before comparing.
    /// * `$EXPECTED_TYPE` - the Rust type of a single result value (e.g. `i32`, `&str`).
    /// * `$EXPECTED_DATA_TYPE` - the expected Arrow `DataType` of the return field
    ///   and of the result array.
    /// * `$ARRAY_TYPE` - the concrete array type the result is downcast to.
    ///
    /// The row count passed to the function is the length of the last array
    /// argument, or 1 when every argument is a scalar. Each argument field is
    /// named `arg_{idx}` and is nullable when the scalar is null or the array
    /// contains at least one null.
    macro_rules! test_scalar_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident $(,)?) => {
            let expected: datafusion_common::Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;
            // Evaluate the argument expression once; every use below borrows or
            // moves this single value instead of re-running `$ARGS`.
            let args: Vec<datafusion_expr::ColumnarValue> = $ARGS;

            let arg_fields: Vec<arrow::datatypes::FieldRef> = args
                .iter()
                .enumerate()
                .map(|(idx, arg)| {
                    let nullable = match arg {
                        datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(),
                        datafusion_expr::ColumnarValue::Array(array) => array.null_count() > 0,
                    };
                    std::sync::Arc::new(arrow::datatypes::Field::new(
                        format!("arg_{idx}"),
                        arg.data_type(),
                        nullable,
                    ))
                })
                .collect();

            // Row count: length of the last array argument, or 1 if all are scalars.
            let cardinality = args
                .iter()
                .rev()
                .find_map(|arg| match arg {
                    datafusion_expr::ColumnarValue::Scalar(_) => None,
                    datafusion_expr::ColumnarValue::Array(array) => Some(array.len()),
                })
                .unwrap_or(1);

            // Only scalar arguments are visible to `return_field_from_args`.
            let scalar_arguments = args
                .iter()
                .map(|arg| match arg {
                    datafusion_expr::ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                    datafusion_expr::ColumnarValue::Array(_) => None,
                })
                .collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments
                .iter()
                .map(|arg| arg.as_ref())
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &scalar_arguments_refs,
            });

            match expected {
                Ok(expected) => {
                    let return_field = return_field.unwrap_or_else(|error| {
                        panic!("return_field_from_args returned an error: {error}")
                    });
                    assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE);

                    let value = match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                        args,
                        number_rows: cardinality,
                        return_field,
                        arg_fields,
                    }) {
                        Ok(value) => value,
                        Err(error) => panic!("function returned an error: {error}"),
                    };

                    let array_ref = value
                        .to_array(cardinality)
                        .expect("Failed to convert to array");
                    let array = array_ref
                        .as_any()
                        .downcast_ref::<$ARRAY_TYPE>()
                        .expect("Failed to convert to type");
                    assert_eq!(array.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only row 0 is checked.
                    match expected {
                        Some(v) => assert_eq!(array.value(0), v),
                        None => assert!(array.is_null(0)),
                    };
                }
                Err(expected_error) => {
                    let expected_message = expected_error.strip_backtrace();
                    match return_field {
                        // Planning failed: the actual error must contain the expected message.
                        Err(error) => {
                            datafusion_common::assert_contains!(
                                error.strip_backtrace(),
                                expected_message
                            );
                        }
                        // Planning succeeded, so the expected error must come from execution.
                        Ok(return_field) => {
                            match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                                args,
                                number_rows: cardinality,
                                return_field,
                                arg_fields,
                            }) {
                                Ok(_) => panic!(
                                    "expected error {expected_message:?}, but invoke_with_args succeeded"
                                ),
                                Err(error) => {
                                    let actual_message = error.strip_backtrace();
                                    assert!(
                                        actual_message.starts_with(&expected_message),
                                        "unexpected error from invoke_with_args.\n\nExpected prefix:\n{expected_message}\n\nActual:\n{actual_message}"
                                    );
                                }
                            }
                        }
                    }
                }
            };
        };
    }

    pub(crate) use test_scalar_function;
}