1use arrow::array::ArrayRef;
19use arrow::datatypes::DataType;
20
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::function::Hint;
23use datafusion_expr::ColumnarValue;
24
/// Generates a `pub(crate)` function named `$FUNC` that derives a return type
/// from the type of a string-like argument.
///
/// The generated function has the signature
/// `fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType>` and maps:
///
/// * `LargeUtf8` / `LargeBinary` to `$largeUtf8Type`
/// * `Utf8` / `Binary` to `$utf8Type`
/// * `Utf8View` / `BinaryView` to `$utf8Type`
/// * `Null` to `Null`
/// * `Dictionary` to whatever its value type maps to, except that view types
///   are rejected when they appear as dictionary values
///
/// Every other type yields an execution error that mentions the upper-cased
/// `name` and the offending type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            // A dictionary is classified by the type of its values. View types
            // are accepted only as a direct argument type, never as dictionary
            // values, hence the flag.
            let (string_type, views_allowed) = match arg_type {
                DataType::Dictionary(_, value_type) => (&**value_type, false),
                direct => (direct, true),
            };

            match string_type {
                DataType::LargeUtf8 | DataType::LargeBinary => Ok($largeUtf8Type),
                DataType::Utf8 | DataType::Binary => Ok($utf8Type),
                DataType::Utf8View | DataType::BinaryView if views_allowed => Ok($utf8Type),
                DataType::Null => Ok(DataType::Null),
                unsupported => datafusion_common::exec_err!(
                    "The {} function can only accept strings, but got {:?}.",
                    name.to_uppercase(),
                    unsupported
                ),
            }
        }
    };
}
68
// `utf8_to_str_type`: yields `LargeUtf8` for large string/binary inputs and
// `Utf8` for the regular and view string/binary inputs.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

// `utf8_to_int_type`: yields `Int64` for large string/binary inputs and
// `Int32` for the regular and view string/binary inputs.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
74
75pub fn make_scalar_function<F>(
79 inner: F,
80 hints: Vec<Hint>,
81) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
82where
83 F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
84{
85 move |args: &[ColumnarValue]| {
86 let len = args
89 .iter()
90 .fold(Option::<usize>::None, |acc, arg| match arg {
91 ColumnarValue::Scalar(_) => acc,
92 ColumnarValue::Array(a) => Some(a.len()),
93 });
94
95 let is_scalar = len.is_none();
96
97 let inferred_length = len.unwrap_or(1);
98 let args = args
99 .iter()
100 .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
101 .map(|(arg, hint)| {
102 let expansion_len = match hint {
105 Hint::AcceptsSingular => 1,
106 Hint::Pad => inferred_length,
107 };
108 arg.to_array(expansion_len)
109 })
110 .collect::<Result<Vec<_>>>()?;
111
112 let result = (inner)(&args);
113 if is_scalar {
114 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
116 result.map(ColumnarValue::Scalar)
117 } else {
118 result.map(ColumnarValue::Array)
119 }
120 }
121}
122
#[cfg(test)]
pub mod test {
    /// Invokes a scalar function once and checks the outcome.
    ///
    /// Arguments:
    ///
    /// * `$FUNC`: the function under test; `return_field_from_args` and
    ///   `invoke_with_args` are called on it.
    /// * `$ARGS`: the `ColumnarValue` arguments. The expression is expanded
    ///   several times: it is borrowed through `.iter()` and also moved into
    ///   `ScalarFunctionArgs::args`.
    /// * `$EXPECTED`: a `Result<Option<$EXPECTED_TYPE>>`. `Ok(Some(v))` expects
    ///   row 0 of the output to equal `v`, `Ok(None)` expects row 0 to be null
    ///   and `Err(e)` expects the function to fail.
    /// * `$EXPECTED_TYPE`: the Rust type of the expected value.
    /// * `$EXPECTED_DATA_TYPE`: the `DataType` expected both from the return
    ///   field and from the produced array.
    /// * `$ARRAY_TYPE`: the concrete array type the output is downcast to.
    /// * `$CONFIG_OPTIONS`: optional; passed as `ScalarFunctionArgs::config_options`.
    ///   Defaults to `Arc::new(ConfigOptions::default())`.
    ///
    /// In the error case a failure of `return_field_from_args` is compared with
    /// `assert_contains!`; otherwise the invocation itself must fail and the
    /// expected message must start with the actual one (backtraces stripped).
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;

            let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            // Row count of the invocation: the length of the last array
            // argument, or 1 when every argument is a scalar.
            let cardinality = $ARGS
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    ColumnarValue::Scalar(_) => acc,
                    ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            let scalar_arguments = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                ColumnarValue::Array(_) => None,
            }).collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();

            // A field is marked nullable only if its argument actually holds a null.
            let nullables = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => scalar.is_null(),
                ColumnarValue::Array(a) => a.null_count() > 0,
            }).collect::<Vec<_>>();

            let field_array = data_array.into_iter().zip(nullables).enumerate()
                .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable))
                .map(std::sync::Arc::new)
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &field_array,
                scalar_arguments: &scalar_arguments_refs,
            });
            // Fields handed to the invocation itself; these are always nullable.
            let arg_fields = $ARGS.iter()
                .enumerate()
                .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into())
                .collect::<Vec<_>>();

            match expected {
                Ok(expected) => {
                    assert!(return_field.is_ok());
                    let return_field = return_field.unwrap();
                    let return_type = return_field.data_type();
                    assert_eq!(return_type, &$EXPECTED_DATA_TYPE);

                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{
                        args: $ARGS,
                        arg_fields,
                        number_rows: cardinality,
                        return_field,
                        config_options: $CONFIG_OPTIONS
                    });
                    // The message arguments are only evaluated when the assertion fails.
                    assert!(result.is_ok(), "function returned an error: {}", result.unwrap_err());

                    let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
                    let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only the first row of the output is inspected.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => {
                    match return_field {
                        Ok(return_field) => {
                            // The return field was computed, so the invocation must fail.
                            match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                                args: $ARGS,
                                arg_fields,
                                number_rows: cardinality,
                                return_field,
                                config_options: $CONFIG_OPTIONS,
                            }) {
                                Ok(_) => panic!("expected error"),
                                Err(error) => {
                                    assert!(expected_error
                                        .strip_backtrace()
                                        .starts_with(&error.strip_backtrace()));
                                }
                            }
                        }
                        Err(error) => {
                            datafusion_common::assert_contains!(
                                expected_error.strip_backtrace(),
                                error.strip_backtrace()
                            );
                        }
                    }
                }
            };
        };

        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            test_function!(
                $FUNC,
                $ARGS,
                $EXPECTED,
                $EXPECTED_TYPE,
                $EXPECTED_DATA_TYPE,
                $ARRAY_TYPE,
                std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
            )
        };
    }

    use arrow::datatypes::DataType;
    // NOTE(review): nothing in this module uses the macro, so the re-export is
    // presumably consumed by tests elsewhere in the crate and can look unused
    // from here; confirm before removing the `allow`.
    #[allow(unused_imports)]
    pub(crate) use test_function;

    use super::*;

    #[test]
    fn string_to_int_type() {
        let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        // Binary inputs follow the same width rule as their string counterparts.
        let v = utf8_to_int_type(&DataType::Binary, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeBinary, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        let v = utf8_to_int_type(&DataType::Null, "test").unwrap();
        assert_eq!(v, DataType::Null);

        // Dictionaries are classified by their value type.
        let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::LargeUtf8));
        let v = utf8_to_int_type(&dict, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        // Non-string inputs are rejected, directly or as dictionary values.
        let error = utf8_to_int_type(&DataType::Int32, "test").unwrap_err();
        assert!(error
            .strip_backtrace()
            .contains("The TEST function can only accept strings"));

        let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int64));
        assert!(utf8_to_int_type(&dict, "test").is_err());
    }

    #[test]
    fn string_to_str_type() {
        let v = utf8_to_str_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Utf8);

        let v = utf8_to_str_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Utf8);

        let v = utf8_to_str_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::LargeUtf8);
    }
}