datafusion_spark/function/utils.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
17
#[cfg(test)]
pub mod test {
    /// Runs a scalar function once and checks its return field and the first
    /// row of its output.
    ///
    /// Parameters:
    /// * `$FUNC`: the function under test; `return_field_from_args` and
    ///   `invoke_with_args` are called on it
    /// * `$ARGS`: a `Vec<ColumnarValue>` holding the arguments to pass to the
    ///   function (evaluated exactly once)
    /// * `$EXPECTED`: a `Result<Option<$EXPECTED_TYPE>>`:
    ///   * `Ok(Some(v))`: row 0 of the output must equal `v`
    ///   * `Ok(None)`: row 0 of the output must be null
    ///   * `Err(e)`: resolving the return field or invoking the function must
    ///     fail; messages are compared after stripping backtraces
    /// * `$EXPECTED_TYPE`: the type of the expected value
    /// * `$EXPECTED_DATA_TYPE`: the data type expected of both the return
    ///   field and the produced array
    /// * `$ARRAY_TYPE`: the concrete array type the output is downcast to
    /// * `$CONFIG_OPTIONS`: config options to pass to the function; when
    ///   omitted, `Arc::new(ConfigOptions::default())` is used
    ///
    /// The number of rows passed to the function is the length of the last
    /// array argument, or 1 when every argument is a scalar.
    macro_rules! test_scalar_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
            let expected: datafusion_common::Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;
            // Bind the arguments once so that `$ARGS` is not re-evaluated at
            // every use site below.
            let args: Vec<datafusion_expr::ColumnarValue> = $ARGS;

            let arg_fields: Vec<arrow::datatypes::FieldRef> = args
                .iter()
                .enumerate()
                .map(|(idx, arg)| {
                    // A field is marked nullable only when the supplied value
                    // actually contains a null.
                    let nullable = match arg {
                        datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(),
                        datafusion_expr::ColumnarValue::Array(array) => array.null_count() > 0,
                    };

                    std::sync::Arc::new(arrow::datatypes::Field::new(
                        format!("arg_{idx}"),
                        arg.data_type(),
                        nullable,
                    ))
                })
                .collect();

            // Length of the last array argument; 1 when all arguments are scalars.
            let cardinality = args
                .iter()
                .rev()
                .find_map(|arg| match arg {
                    datafusion_expr::ColumnarValue::Scalar(_) => None,
                    datafusion_expr::ColumnarValue::Array(array) => Some(array.len()),
                })
                .unwrap_or(1);

            let scalar_arguments = args
                .iter()
                .map(|arg| match arg {
                    datafusion_expr::ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                    datafusion_expr::ColumnarValue::Array(_) => None,
                })
                .collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments
                .iter()
                .map(|arg| arg.as_ref())
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &scalar_arguments_refs,
            });

            match expected {
                Ok(expected) => {
                    let return_field = match return_field {
                        Ok(field) => field,
                        Err(err) => panic!("Expected return_field to be Ok but got Err: {err}"),
                    };
                    assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE);

                    let col_value = match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                        args,
                        number_rows: cardinality,
                        return_field,
                        arg_fields,
                        config_options: $CONFIG_OPTIONS,
                    }) {
                        Ok(col_value) => col_value,
                        Err(err) => panic!("function returned an error: {err}"),
                    };

                    let array = match col_value.to_array(cardinality) {
                        Ok(array) => array,
                        Err(err) => panic!("Failed to convert to array: {err}"),
                    };

                    let result = array
                        .as_any()
                        .downcast_ref::<$ARRAY_TYPE>()
                        .expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only the first row is compared against the expected value.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => match return_field {
                    // Resolving the return field already failed: the expected
                    // message must contain the actual one.
                    Err(error) => {
                        datafusion_common::assert_contains!(
                            expected_error.strip_backtrace(),
                            error.strip_backtrace()
                        );
                    }
                    Ok(return_field) => {
                        // invoke is expected error - cannot use .expect_err() due to Debug not being implemented
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                            args,
                            number_rows: cardinality,
                            return_field,
                            arg_fields,
                            config_options: $CONFIG_OPTIONS,
                        }) {
                            Ok(_) => panic!("expected error"),
                            Err(error) => {
                                let expected_message = expected_error.strip_backtrace();
                                let actual_message = error.strip_backtrace();
                                assert!(
                                    expected_message.starts_with(&actual_message),
                                    "expected error {expected_message:?} does not start with actual error {actual_message:?}"
                                );
                            }
                        }
                    }
                },
            };
        };

        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            test_scalar_function!(
                $FUNC,
                $ARGS,
                $EXPECTED,
                $EXPECTED_TYPE,
                $EXPECTED_DATA_TYPE,
                $ARRAY_TYPE,
                std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
            )
        };
    }

    pub(crate) use test_scalar_function;
}