1use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::DataType;
21use arrow::error::ArrowError;
22use datafusion_common::{DataFusionError, Result, ScalarValue};
23use datafusion_expr::function::Hint;
24use datafusion_expr::ColumnarValue;
25use std::sync::Arc;
26
/// Generates a function `$FUNC(arg_type, name)` that picks the return type of a
/// string function from the type of its (string-like) argument.
///
/// - `LargeUtf8` / `LargeBinary` map to `$largeUtf8Type`.
/// - `Utf8` / `Binary` / `Utf8View` / `BinaryView` map to `$utf8Type`.
/// - `Null` stays `Null`.
/// - A `Dictionary` is resolved by its value type, which may be `LargeUtf8`,
///   `LargeBinary`, `Utf8`, `Binary` or `Null` (view types are *not* accepted
///   inside a dictionary).
///
/// Any other type yields an execution error naming the function (`name`,
/// upper-cased) and the offending type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        /// Resolves the return type for a string-like argument type.
        ///
        /// # Errors
        /// Returns an execution error when `arg_type` (or the value type of a
        /// dictionary `arg_type`) is not one of the supported string/binary types.
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            // Every supported type returns early; whatever falls through is the
            // type reported in the error (for dictionaries, the value type).
            let rejected: &DataType = match arg_type {
                DataType::LargeUtf8 | DataType::LargeBinary => return Ok($largeUtf8Type),
                DataType::Utf8
                | DataType::Binary
                | DataType::Utf8View
                | DataType::BinaryView => return Ok($utf8Type),
                DataType::Null => return Ok(DataType::Null),
                DataType::Dictionary(_, value_type) => match &**value_type {
                    DataType::LargeUtf8 | DataType::LargeBinary => {
                        return Ok($largeUtf8Type)
                    }
                    DataType::Utf8 | DataType::Binary => return Ok($utf8Type),
                    DataType::Null => return Ok(DataType::Null),
                    unsupported_value => unsupported_value,
                },
                unsupported => unsupported,
            };
            datafusion_common::exec_err!(
                "The {} function can only accept strings, but got {:?}.",
                name.to_uppercase(),
                rejected
            )
        }
    };
}
70
// `utf8_to_str_type`: return type for string-producing functions. LargeUtf8 /
// LargeBinary inputs (or dictionaries of them) give LargeUtf8, the other
// supported string/binary inputs give Utf8, and Null stays Null.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

// `utf8_to_int_type`: return type for integer-producing string functions.
// LargeUtf8 / LargeBinary inputs (or dictionaries of them) give Int64, the other
// supported string/binary inputs give Int32, and Null stays Null.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
76
77pub fn make_scalar_function<F>(
81 inner: F,
82 hints: Vec<Hint>,
83) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
84where
85 F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
86{
87 move |args: &[ColumnarValue]| {
88 let len = args
91 .iter()
92 .fold(Option::<usize>::None, |acc, arg| match arg {
93 ColumnarValue::Scalar(_) => acc,
94 ColumnarValue::Array(a) => Some(a.len()),
95 });
96
97 let is_scalar = len.is_none();
98
99 let inferred_length = len.unwrap_or(1);
100 let args = args
101 .iter()
102 .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
103 .map(|(arg, hint)| {
104 let expansion_len = match hint {
107 Hint::AcceptsSingular => 1,
108 Hint::Pad => inferred_length,
109 };
110 arg.to_array(expansion_len)
111 })
112 .collect::<Result<Vec<_>>>()?;
113
114 let result = (inner)(&args);
115 if is_scalar {
116 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
118 result.map(ColumnarValue::Scalar)
119 } else {
120 result.map(ColumnarValue::Array)
121 }
122 }
123}
124
125pub fn calculate_binary_math<L, R, O, F>(
132 left: &dyn Array,
133 right: &ColumnarValue,
134 fun: F,
135) -> Result<Arc<PrimitiveArray<O>>>
136where
137 R: ArrowPrimitiveType,
138 L: ArrowPrimitiveType,
139 O: ArrowPrimitiveType,
140 F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
141 R::Native: TryFrom<ScalarValue>,
142{
143 let left = left.as_primitive::<L>();
144 let right = right.cast_to(&R::DATA_TYPE, None)?;
145 let result = match right {
146 ColumnarValue::Scalar(scalar) => {
147 let right = R::Native::try_from(scalar.clone()).map_err(|_| {
148 DataFusionError::NotImplemented(format!(
149 "Cannot convert scalar value {} to {}",
150 &scalar,
151 R::DATA_TYPE
152 ))
153 })?;
154 left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))?
155 }
156 ColumnarValue::Array(right) => {
157 let right = right.as_primitive::<R>();
158 try_binary::<_, _, _, O>(left, right, &fun)?
159 }
160 };
161 Ok(Arc::new(result) as _)
162}
163
164pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
166 if scale < 0 {
167 Err(ArrowError::ComputeError(
168 "Negative scale is not supported".into(),
169 ))
170 } else if scale == 0 {
171 Ok(value)
172 } else {
173 match i128::from(10).checked_pow(scale as u32) {
174 Some(divisor) => Ok(value / divisor),
175 None => Err(ArrowError::ComputeError(format!(
176 "Cannot get a power of {scale}"
177 ))),
178 }
179 }
180}
181
#[cfg(test)]
pub mod test {
    /// Runs a scalar UDF end to end and checks its return type and first output row.
    ///
    /// Arguments:
    /// - `$FUNC`: the UDF under test. The macro calls `return_field_from_args` and
    ///   `invoke_with_args` on it.
    /// - `$ARGS`: the input `ColumnarValue`s. The expression is evaluated once and
    ///   moved into the invocation.
    /// - `$EXPECTED`: `Ok(Some(v))` if row 0 must equal `v`, `Ok(None)` if row 0
    ///   must be null, or `Err(e)` if the function must fail with error `e`.
    /// - `$EXPECTED_TYPE`: Rust type of `v`, as returned by `$ARRAY_TYPE::value`.
    /// - `$EXPECTED_DATA_TYPE`: expected Arrow return type.
    /// - `$ARRAY_TYPE`: concrete array type the output is downcast to.
    /// - `$CONFIG_OPTIONS` (optional): `Arc<ConfigOptions>` passed to the call;
    ///   defaults to `ConfigOptions::default()`.
    ///
    /// Only row 0 of the output is compared. For the error case:
    /// - if return-type resolution fails, its message must be contained in `e`;
    /// - otherwise `e`'s message must start with the invocation's error message.
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;
            // Evaluate the argument expression exactly once: it is inspected
            // several times below and finally moved into the invocation.
            let call_args = $ARGS;

            let data_array = call_args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            // Row count: length of the last array argument, or 1 if all are scalars.
            let cardinality = call_args
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    ColumnarValue::Scalar(_) => acc,
                    ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            let scalar_arguments = call_args
                .iter()
                .map(|arg| match arg {
                    ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                    ColumnarValue::Array(_) => None,
                })
                .collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments
                .iter()
                .map(|arg| arg.as_ref())
                .collect::<Vec<_>>();

            let nullables = call_args
                .iter()
                .map(|arg| match arg {
                    ColumnarValue::Scalar(scalar) => scalar.is_null(),
                    ColumnarValue::Array(a) => a.null_count() > 0,
                })
                .collect::<Vec<_>>();

            // Fields used for return-type resolution carry the observed nullability.
            let field_array = data_array
                .into_iter()
                .zip(nullables)
                .enumerate()
                .map(|(idx, (data_type, nullable))| {
                    arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)
                })
                .map(std::sync::Arc::new)
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &field_array,
                scalar_arguments: &scalar_arguments_refs,
            });
            // Fields passed to the invocation itself are always marked nullable.
            let arg_fields = call_args
                .iter()
                .enumerate()
                .map(|(idx, arg)| {
                    arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()
                })
                .collect::<Vec<_>>();

            match expected {
                Ok(expected) => {
                    let return_field = match return_field {
                        Ok(field) => field,
                        Err(error) => {
                            panic!("return_field_from_args returned an error: {error}")
                        }
                    };
                    assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE);

                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                        args: call_args,
                        arg_fields,
                        number_rows: cardinality,
                        return_field,
                        config_options: $CONFIG_OPTIONS,
                    });
                    let result = match result {
                        Ok(value) => value,
                        Err(error) => panic!("function returned an error: {error}"),
                    };

                    let result = result.to_array(cardinality).expect("Failed to convert to array");
                    let result = result
                        .as_any()
                        .downcast_ref::<$ARRAY_TYPE>()
                        .expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only the first row is compared against the expectation.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => match return_field {
                    Ok(return_field) => {
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                            args: call_args,
                            arg_fields,
                            number_rows: cardinality,
                            return_field,
                            config_options: $CONFIG_OPTIONS,
                        }) {
                            Ok(_) => panic!("expected error"),
                            Err(error) => {
                                let expected_message = expected_error.strip_backtrace();
                                let actual_message = error.strip_backtrace();
                                assert!(
                                    expected_message.starts_with(&actual_message),
                                    "expected error {expected_message:?} to start with actual error {actual_message:?}"
                                );
                            }
                        }
                    }
                    Err(error) => {
                        datafusion_common::assert_contains!(
                            expected_error.strip_backtrace(),
                            error.strip_backtrace()
                        );
                    }
                },
            };
        };

        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            test_function!(
                $FUNC,
                $ARGS,
                $EXPECTED,
                $EXPECTED_TYPE,
                $EXPECTED_DATA_TYPE,
                $ARRAY_TYPE,
                std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
            )
        };
    }

    use arrow::datatypes::DataType;
    // The macro is re-exported for use by other modules in the crate; the allow
    // presumably silences the lint in builds where no module imports it.
    #[allow(unused_imports)]
    pub(crate) use test_function;

    use super::*;

    #[test]
    fn string_to_int_type() {
        for input in [
            DataType::Utf8,
            DataType::Utf8View,
            DataType::Binary,
            DataType::BinaryView,
        ] {
            let v = utf8_to_int_type(&input, "test").unwrap();
            assert_eq!(v, DataType::Int32, "input {input:?}");
        }

        for input in [DataType::LargeUtf8, DataType::LargeBinary] {
            let v = utf8_to_int_type(&input, "test").unwrap();
            assert_eq!(v, DataType::Int64, "input {input:?}");
        }

        assert_eq!(
            utf8_to_int_type(&DataType::Null, "test").unwrap(),
            DataType::Null
        );

        let dict_utf8 =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
        assert_eq!(utf8_to_int_type(&dict_utf8, "test").unwrap(), DataType::Int32);

        let dict_large =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::LargeUtf8));
        assert_eq!(utf8_to_int_type(&dict_large, "test").unwrap(), DataType::Int64);

        let err = utf8_to_int_type(&DataType::Int32, "test").unwrap_err();
        assert!(err
            .strip_backtrace()
            .contains("The TEST function can only accept strings"));
    }

    #[test]
    fn string_to_str_type() {
        assert_eq!(
            utf8_to_str_type(&DataType::Utf8, "test").unwrap(),
            DataType::Utf8
        );
        assert_eq!(
            utf8_to_str_type(&DataType::Utf8View, "test").unwrap(),
            DataType::Utf8
        );
        assert_eq!(
            utf8_to_str_type(&DataType::LargeUtf8, "test").unwrap(),
            DataType::LargeUtf8
        );
        assert!(utf8_to_str_type(&DataType::Float64, "test").is_err());
    }

    #[test]
    fn test_decimal128_to_i128() {
        let cases = [
            (123, 0, Some(123)),
            (1230, 1, Some(123)),
            (123000, 3, Some(123)),
            (1, 0, Some(1)),
            (-1234, 2, Some(-12)),
            (123, -3, None),
            (123, i8::MAX, None),
            (1, 39, None),
            (i128::MAX, 0, Some(i128::MAX)),
            (i128::MAX, 3, Some(i128::MAX / 1000)),
            (i128::MAX, 38, Some(1)),
        ];

        for (value, scale, expected) in cases {
            match decimal128_to_i128(value, scale) {
                Ok(actual) => {
                    assert_eq!(
                        actual,
                        expected.expect("Got value but expected none"),
                        "{value} and {scale} vs {expected:?}"
                    );
                }
                Err(_) => assert!(expected.is_none(), "{value} and {scale} vs {expected:?}"),
            }
        }
    }
}