// datafusion_spark/function/utils.rs

/// Test-only helpers for exercising scalar function implementations.
#[cfg(test)]
pub mod test {
    /// Asserts the behaviour of a scalar function for one set of arguments.
    ///
    /// Parameters:
    /// * `$FUNC` – the function under test. It must expose `return_field_from_args` and
    ///   `invoke_with_args` (presumably a `ScalarUDFImpl`; the trait must be in scope at
    ///   the call site).
    /// * `$ARGS` – the `ColumnarValue` arguments, handed to `invoke_with_args` as
    ///   `ScalarFunctionArgs::args`. The expression is evaluated exactly once.
    /// * `$EXPECTED` – a `datafusion_common::Result<Option<$EXPECTED_TYPE>>`:
    ///   - `Ok(Some(v))` expects row 0 of the result to equal `v`;
    ///   - `Ok(None)` expects row 0 to be null;
    ///   - `Err(e)` expects the function to fail with a matching message.
    /// * `$EXPECTED_TYPE` – Rust type of the expected row-0 value.
    /// * `$EXPECTED_DATA_TYPE` – data type that both the return field and the result array
    ///   must report.
    /// * `$ARRAY_TYPE` – concrete array type the result is downcast to. It must provide
    ///   `value(0)` and `is_null(0)`.
    ///
    /// Row count:
    /// * `number_rows` is the length of the last array argument, or 1 when every argument
    ///   is a scalar.
    /// * Only row 0 of the result is compared against `$EXPECTED`.
    ///
    /// Error matching (backtraces are stripped from both messages first):
    /// * If `return_field_from_args` fails, the expected message must contain the actual
    ///   one (checked with `datafusion_common::assert_contains!`).
    /// * If `invoke_with_args` fails, the expected message must start with the actual one.
    macro_rules! test_scalar_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            let expected: datafusion_common::Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;
            // Bind the argument expression once; every step below reuses this value.
            let call_args = $ARGS;

            // One field per argument, named `arg_{idx}`, nullable iff the argument holds a null.
            let arg_fields: Vec<arrow::datatypes::FieldRef> = call_args
                .iter()
                .enumerate()
                .map(|(idx, arg)| {
                    let nullable = match arg {
                        datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(),
                        datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0,
                    };
                    std::sync::Arc::new(arrow::datatypes::Field::new(
                        format!("arg_{idx}"),
                        arg.data_type(),
                        nullable,
                    ))
                })
                .collect();

            // Row count: length of the last array argument, or 1 if all arguments are scalars.
            let cardinality = call_args
                .iter()
                .filter_map(|arg| match arg {
                    datafusion_expr::ColumnarValue::Array(a) => Some(a.len()),
                    datafusion_expr::ColumnarValue::Scalar(_) => None,
                })
                .last()
                .unwrap_or(1);

            // Scalar literals are exposed to return-type planning; array arguments map to `None`.
            let scalar_arguments = call_args
                .iter()
                .map(|arg| match arg {
                    datafusion_expr::ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                    datafusion_expr::ColumnarValue::Array(_) => None,
                })
                .collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments
                .iter()
                .map(|arg| arg.as_ref())
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &scalar_arguments_refs,
            });

            match expected {
                Ok(expected) => {
                    let return_field = return_field.unwrap_or_else(|error| {
                        panic!("return_field_from_args returned an error: {error}")
                    });
                    assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE);

                    let result = func
                        .invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                            args: call_args,
                            number_rows: cardinality,
                            return_field,
                            arg_fields,
                        })
                        .unwrap_or_else(|error| panic!("function returned an error: {error}"));

                    let result = result
                        .to_array(cardinality)
                        .expect("Failed to convert to array");
                    let result = result
                        .as_any()
                        .downcast_ref::<$ARRAY_TYPE>()
                        .expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only the first row is compared against the expectation.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => match return_field {
                    // Failure while computing the return field.
                    Err(error) => {
                        datafusion_common::assert_contains!(
                            expected_error.strip_backtrace(),
                            error.strip_backtrace()
                        );
                    }
                    // Return field computed fine, so the error has to come from invocation.
                    Ok(return_field) => {
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                            args: call_args,
                            number_rows: cardinality,
                            return_field,
                            arg_fields,
                        }) {
                            Ok(_) => panic!("expected error"),
                            Err(error) => {
                                let expected_message = expected_error.strip_backtrace();
                                let actual_message = error.strip_backtrace();
                                assert!(
                                    expected_message.starts_with(&actual_message),
                                    "expected error message to start with the actual one\n  expected: {expected_message}\n  actual:   {actual_message}"
                                );
                            }
                        }
                    }
                },
            };
        };
    }

    pub(crate) use test_scalar_function;
}