datafusion_functions/
utils.rs1use arrow::array::ArrayRef;
19use arrow::datatypes::DataType;
20
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::function::Hint;
23use datafusion_expr::ColumnarValue;
24
/// Generates `pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType>`,
/// which chooses a string function's return type from the type of its argument.
///
/// The generated function maps:
/// - `LargeUtf8` / `LargeBinary` to `$largeUtf8Type`;
/// - `Utf8` / `Binary` / `Utf8View` / `BinaryView` to `$utf8Type`;
/// - `Null` to `Null`;
/// - `Dictionary(_, V)` by applying the same rules to the value type `V`
///   (a dictionary whose value type is itself a dictionary is rejected).
///
/// # Errors
/// Any other input type yields an execution error of the form
/// `"The {NAME} function can only accept strings, but got {type:?}."`, where
/// `NAME` is `name` upper-cased.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            Ok(match arg_type {
                // "Large" string/binary inputs map to the wider result type.
                DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
                // Binary inputs are treated like their string counterparts.
                DataType::Utf8 | DataType::Binary => $utf8Type,
                // View types share the non-large result type.
                DataType::Utf8View | DataType::BinaryView => $utf8Type,
                DataType::Null => DataType::Null,
                DataType::Dictionary(_, value_type) => match **value_type {
                    DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
                    DataType::Utf8 | DataType::Binary => $utf8Type,
                    // Kept consistent with the top-level arms: previously a
                    // dictionary of view values was rejected even though the
                    // bare view type was accepted.
                    DataType::Utf8View | DataType::BinaryView => $utf8Type,
                    DataType::Null => DataType::Null,
                    _ => {
                        return datafusion_common::exec_err!(
                            "The {} function can only accept strings, but got {:?}.",
                            name.to_uppercase(),
                            **value_type
                        );
                    }
                },
                data_type => {
                    return datafusion_common::exec_err!(
                        "The {} function can only accept strings, but got {:?}.",
                        name.to_uppercase(),
                        data_type
                    );
                }
            })
        }
    };
}
68
// `utf8_to_str_type`: for an accepted string/binary argument type, returns
// `LargeUtf8` when the argument is `LargeUtf8`/`LargeBinary` (or a dictionary of
// them), `Null` for `Null`, and `Utf8` for the other accepted types.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

// `utf8_to_int_type`: same dispatch as `utf8_to_str_type`, but yields `Int64`
// for the large variants and `Int32` for the other accepted types.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
74
75pub fn make_scalar_function<F>(
79 inner: F,
80 hints: Vec<Hint>,
81) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
82where
83 F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
84{
85 move |args: &[ColumnarValue]| {
86 let len = args
89 .iter()
90 .fold(Option::<usize>::None, |acc, arg| match arg {
91 ColumnarValue::Scalar(_) => acc,
92 ColumnarValue::Array(a) => Some(a.len()),
93 });
94
95 let is_scalar = len.is_none();
96
97 let inferred_length = len.unwrap_or(1);
98 let args = args
99 .iter()
100 .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
101 .map(|(arg, hint)| {
102 let expansion_len = match hint {
105 Hint::AcceptsSingular => 1,
106 Hint::Pad => inferred_length,
107 };
108 arg.to_array(expansion_len)
109 })
110 .collect::<Result<Vec<_>>>()?;
111
112 let result = (inner)(&args);
113 if is_scalar {
114 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
116 result.map(ColumnarValue::Scalar)
117 } else {
118 result.map(ColumnarValue::Array)
119 }
120 }
121}
122
#[cfg(test)]
pub mod test {
    /// Runs a scalar function on the given arguments and checks its return type
    /// and the value (or error) it produces.
    ///
    /// Parameters:
    /// - `$FUNC`: the function under test; it must provide
    ///   `return_field_from_args` and `invoke_with_args`.
    /// - `$ARGS`: an expression yielding the `ColumnarValue` arguments. It is
    ///   substituted at several places in the expansion, so it is re-evaluated
    ///   each time it appears.
    /// - `$EXPECTED`: a `Result<Option<$EXPECTED_TYPE>>`. `Ok(Some(v))` expects
    ///   row 0 of the result to equal `v`; `Ok(None)` expects row 0 to be null;
    ///   `Err(e)` expects planning or invocation to fail with a message that
    ///   matches `e`.
    /// - `$EXPECTED_DATA_TYPE`: the expected return `DataType`, checked both on
    ///   the planned return field and on the produced array.
    /// - `$ARRAY_TYPE`: the concrete array type the result is downcast to.
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;

            let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            // Row count: length of the last array argument, or 1 if every
            // argument is a scalar.
            let cardinality = $ARGS
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    ColumnarValue::Scalar(_) => acc,
                    ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            // Scalar arguments are passed to return-type planning; array
            // arguments are represented as `None`.
            let scalar_arguments = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                ColumnarValue::Array(_) => None,
            }).collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();

            // An argument's field is nullable if it is a null scalar or an
            // array containing at least one null.
            let nullables = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => scalar.is_null(),
                ColumnarValue::Array(a) => a.null_count() > 0,
            }).collect::<Vec<_>>();

            let field_array = data_array.into_iter().zip(nullables).enumerate()
                .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable))
                .map(std::sync::Arc::new)
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &field_array,
                scalar_arguments: &scalar_arguments_refs,
            });
            // Fields passed at invocation time are always marked nullable,
            // unlike `field_array` above.
            let arg_fields = $ARGS.iter()
                .enumerate()
                .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into())
                .collect::<Vec<_>>();

            match expected {
                Ok(expected) => {
                    assert_eq!(return_field.is_ok(), true);
                    let return_field = return_field.unwrap();
                    let return_type = return_field.data_type();
                    assert_eq!(return_type, &$EXPECTED_DATA_TYPE);

                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field});
                    assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());

                    let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
                    let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only row 0 of the result is compared.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => {
                    if return_field.is_err() {
                        // Planning failed: compare the planning error with the
                        // expected one via `assert_contains!` (called with the
                        // expected message first, then the actual one).
                        // NOTE(review): the `Ok(_)` arm is unreachable inside
                        // this `is_err()` branch.
                        match return_field {
                            Ok(_) => assert!(false, "expected error"),
                            Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); }
                        }
                    }
                    else {
                        let return_field = return_field.unwrap();

                        // Planning succeeded, so the error must come from invocation.
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field}) {
                            Ok(_) => assert!(false, "expected error"),
                            Err(error) => {
                                // NOTE(review): this asserts that the *expected*
                                // message starts with the *actual* one; confirm
                                // this direction is intended.
                                assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));
                            }
                        }
                    }
                }
            };
        };
    }

    use arrow::datatypes::DataType;
    // Re-exported so other test modules in the crate can invoke the macro by
    // path; the allow presumably silences the lint when nothing in this module
    // refers to the re-export.
    #[allow(unused_imports)]
    pub(crate) use test_function;

    use super::*;

    // Non-large string types map to Int32; LargeUtf8 maps to Int64.
    #[test]
    fn string_to_int_type() {
        let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::Int64);
    }
}