1use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::{DataType, DecimalType};
21use arrow::error::ArrowError;
22use datafusion_common::{DataFusionError, Result, ScalarValue};
23use datafusion_expr::ColumnarValue;
24use datafusion_expr::function::Hint;
25use std::sync::Arc;
26
27macro_rules! get_optimal_return_type {
37 ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
38 pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
39 Ok(match arg_type {
40 DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
42 DataType::Utf8 | DataType::Binary => $utf8Type,
44 DataType::Utf8View | DataType::BinaryView => $utf8Type,
46 DataType::Null => DataType::Null,
47 DataType::Dictionary(_, value_type) => match **value_type {
48 DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
49 DataType::Utf8 | DataType::Binary => $utf8Type,
50 DataType::Null => DataType::Null,
51 _ => {
52 return datafusion_common::exec_err!(
53 "The {} function can only accept strings, but got {:?}.",
54 name.to_uppercase(),
55 **value_type
56 );
57 }
58 },
59 data_type => {
60 return datafusion_common::exec_err!(
61 "The {} function can only accept strings, but got {:?}.",
62 name.to_uppercase(),
63 data_type
64 );
65 }
66 })
67 }
68 };
69}
70
71get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
73
74get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
76
77pub fn make_scalar_function<F>(
81 inner: F,
82 hints: Vec<Hint>,
83) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
84where
85 F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
86{
87 move |args: &[ColumnarValue]| {
88 let len = args
91 .iter()
92 .fold(Option::<usize>::None, |acc, arg| match arg {
93 ColumnarValue::Scalar(_) => acc,
94 ColumnarValue::Array(a) => Some(a.len()),
95 });
96
97 let is_scalar = len.is_none();
98
99 let inferred_length = len.unwrap_or(1);
100 let args = args
101 .iter()
102 .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
103 .map(|(arg, hint)| {
104 let expansion_len = match hint {
107 Hint::AcceptsSingular => 1,
108 Hint::Pad => inferred_length,
109 };
110 arg.to_array(expansion_len)
111 })
112 .collect::<Result<Vec<_>>>()?;
113
114 let result = (inner)(&args);
115 if is_scalar {
116 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
118 result.map(ColumnarValue::Scalar)
119 } else {
120 result.map(ColumnarValue::Array)
121 }
122 }
123}
124
125pub fn calculate_binary_math<L, R, O, F>(
132 left: &dyn Array,
133 right: &ColumnarValue,
134 fun: F,
135) -> Result<Arc<PrimitiveArray<O>>>
136where
137 L: ArrowPrimitiveType,
138 R: ArrowPrimitiveType,
139 O: ArrowPrimitiveType,
140 F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
141 R::Native: TryFrom<ScalarValue>,
142{
143 let left = left.as_primitive::<L>();
144 let right = right.cast_to(&R::DATA_TYPE, None)?;
145 let result = match right {
146 ColumnarValue::Scalar(scalar) => {
147 if scalar.is_null() {
148 PrimitiveArray::<O>::new_null(1)
151 } else {
152 let right = R::Native::try_from(scalar.clone()).map_err(|_| {
153 DataFusionError::NotImplemented(format!(
154 "Cannot convert scalar value {} to {}",
155 &scalar,
156 R::DATA_TYPE
157 ))
158 })?;
159 left.try_unary::<_, O, _>(|lvalue| fun(lvalue, right))?
160 }
161 }
162 ColumnarValue::Array(right) => {
163 let right = right.as_primitive::<R>();
164 try_binary::<_, _, _, O>(left, right, &fun)?
165 }
166 };
167 Ok(Arc::new(result) as _)
168}
169
170pub fn calculate_binary_decimal_math<L, R, O, F>(
178 left: &dyn Array,
179 right: &ColumnarValue,
180 fun: F,
181 precision: u8,
182 scale: i8,
183) -> Result<Arc<PrimitiveArray<O>>>
184where
185 L: DecimalType,
186 R: ArrowPrimitiveType,
187 O: DecimalType,
188 F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
189 R::Native: TryFrom<ScalarValue>,
190{
191 let result_array = calculate_binary_math::<L, R, O, F>(left, right, fun)?;
192 Ok(Arc::new(
193 result_array
194 .as_ref()
195 .clone()
196 .with_precision_and_scale(precision, scale)?,
197 ))
198}
199
200pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
202 if scale < 0 {
203 Err(ArrowError::ComputeError(
204 "Negative scale is not supported".into(),
205 ))
206 } else if scale == 0 {
207 Ok(value)
208 } else {
209 match i128::from(10).checked_pow(scale as u32) {
210 Some(divisor) => Ok(value / divisor),
211 None => Err(ArrowError::ComputeError(format!(
212 "Cannot get a power of {scale}"
213 ))),
214 }
215 }
216}
217
218pub fn decimal32_to_i32(value: i32, scale: i8) -> Result<i32, ArrowError> {
219 if scale < 0 {
220 Err(ArrowError::ComputeError(
221 "Negative scale is not supported".into(),
222 ))
223 } else if scale == 0 {
224 Ok(value)
225 } else {
226 match 10_i32.checked_pow(scale as u32) {
227 Some(divisor) => Ok(value / divisor),
228 None => Err(ArrowError::ComputeError(format!(
229 "Cannot get a power of {scale}"
230 ))),
231 }
232 }
233}
234
235pub fn decimal64_to_i64(value: i64, scale: i8) -> Result<i64, ArrowError> {
236 if scale < 0 {
237 Err(ArrowError::ComputeError(
238 "Negative scale is not supported".into(),
239 ))
240 } else if scale == 0 {
241 Ok(value)
242 } else {
243 match i64::from(10).checked_pow(scale as u32) {
244 Some(divisor) => Ok(value / divisor),
245 None => Err(ArrowError::ComputeError(format!(
246 "Cannot get a power of {scale}"
247 ))),
248 }
249 }
250}
251
/// Test-only helpers shared by scalar-function unit tests, plus unit tests for
/// the helpers in this file.
#[cfg(test)]
pub mod test {
    /// Invokes a scalar UDF and checks its return type and its first output row.
    ///
    /// Arguments:
    /// - `$FUNC`: the UDF under test; it must provide `return_field_from_args`
    ///   and `invoke_with_args`.
    /// - `$ARGS`: the argument `ColumnarValue`s. This `expr` fragment is
    ///   substituted (and so re-evaluated) at every use below; its last use moves
    ///   it into `ScalarFunctionArgs::args`.
    /// - `$EXPECTED`: a `Result<Option<$EXPECTED_TYPE>>`. `Ok(Some(v))` expects
    ///   row 0 to equal `v`, `Ok(None)` expects row 0 to be null, and `Err(e)`
    ///   expects the return-field computation or the invocation to fail.
    /// - `$EXPECTED_TYPE`: native Rust type of the expected value.
    /// - `$EXPECTED_DATA_TYPE`: Arrow `DataType` the result must have.
    /// - `$ARRAY_TYPE`: concrete array type the result is downcast to.
    /// - `$CONFIG_OPTIONS` (optional): config passed to the function; the
    ///   six-argument form uses `ConfigOptions::default()`.
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;

            let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            // Row count: length of the last array argument, or 1 if all are scalars.
            let cardinality = $ARGS
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    ColumnarValue::Scalar(_) => acc,
                    ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            // Scalar arguments are exposed to `return_field_from_args` by value;
            // array arguments appear as `None`.
            let scalar_arguments = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                ColumnarValue::Array(_) => None,
            }).collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();

            // Nullability derived from the actual argument values (used only for
            // the return-field computation).
            let nullables = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => scalar.is_null(),
                ColumnarValue::Array(a) => a.null_count() > 0,
            }).collect::<Vec<_>>();

            let field_array = data_array.into_iter().zip(nullables).enumerate()
                .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable))
                .map(std::sync::Arc::new)
                .collect::<Vec<_>>();

            let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                arg_fields: &field_array,
                scalar_arguments: &scalar_arguments_refs,
            });
            // NOTE(review): the fields passed to `invoke_with_args` are named
            // `f_{idx}` and are always nullable, unlike `field_array` above.
            let arg_fields = $ARGS.iter()
                .enumerate()
                .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into())
                .collect::<Vec<_>>();

            match expected {
                Ok(expected) => {
                    assert_eq!(return_field.is_ok(), true);
                    let return_field = return_field.unwrap();
                    let return_type = return_field.data_type();
                    assert_eq!(return_type, &$EXPECTED_DATA_TYPE);

                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{
                        args: $ARGS,
                        arg_fields,
                        number_rows: cardinality,
                        return_field,
                        config_options: $CONFIG_OPTIONS
                    });
                    // The message (and its `unwrap_err`) is only evaluated when the
                    // assertion fails.
                    assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());

                    let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
                    let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // Only the first row is compared against the expectation.
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => {
                    if let Ok(return_field) = return_field {
                        // The return field was computed, so the error must come from
                        // invoking the function.
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                            args: $ARGS,
                            arg_fields,
                            number_rows: cardinality,
                            return_field,
                            config_options: $CONFIG_OPTIONS,
                        }) {
                            Ok(_) => assert!(false, "expected error"),
                            Err(error) => {
                                // NOTE(review): this checks that the EXPECTED message
                                // starts with the ACTUAL one, which looks inverted;
                                // confirm intent before relying on it.
                                assert!(expected_error
                                    .strip_backtrace()
                                    .starts_with(&error.strip_backtrace()));
                            }
                        }
                    } else if let Err(error) = return_field {
                        // NOTE(review): argument order makes this check that the
                        // expected message contains the actual one; confirm intent.
                        datafusion_common::assert_contains!(
                            expected_error.strip_backtrace(),
                            error.strip_backtrace()
                        );
                    }
                }
            };
        };

        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            test_function!(
                $FUNC,
                $ARGS,
                $EXPECTED,
                $EXPECTED_TYPE,
                $EXPECTED_DATA_TYPE,
                $ARRAY_TYPE,
                std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
            )
        };
    }

    use arrow::datatypes::DataType;
    use itertools::Either;
    pub(crate) use test_function;

    use super::*;

    // `utf8_to_int_type` maps Utf8 and Utf8View to Int32, and LargeUtf8 to Int64.
    #[test]
    fn string_to_int_type() {
        let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::Int64);
    }

    // Each case is (value, scale, expected); `None` means an error is expected.
    #[test]
    fn test_decimal128_to_i128() {
        let cases = [
            (123, 0, Some(123)),
            (1230, 1, Some(123)),
            (123000, 3, Some(123)),
            (1, 0, Some(1)),
            (123, -3, None),
            (123, i8::MAX, None),
            (i128::MAX, 0, Some(i128::MAX)),
            (i128::MAX, 3, Some(i128::MAX / 1000)),
        ];

        for (value, scale, expected) in cases {
            match decimal128_to_i128(value, scale) {
                Ok(actual) => {
                    assert_eq!(
                        actual,
                        expected.expect("Got value but expected none"),
                        "{value} and {scale} vs {expected:?}"
                    );
                }
                Err(_) => assert!(expected.is_none()),
            }
        }
    }

    // Each case is (value, scale, expected). `Left` is the expected integer and
    // `Right` is the exact expected `ComputeError` message.
    #[test]
    fn test_decimal32_to_i32() {
        let cases: [(i32, i8, Either<i32, String>); _] = [
            (123, 0, Either::Left(123)),
            (1230, 1, Either::Left(123)),
            (123000, 3, Either::Left(123)),
            (1234567, 2, Either::Left(12345)),
            (-1234567, 2, Either::Left(-12345)),
            (1, 0, Either::Left(1)),
            (
                123,
                -3,
                Either::Right("Negative scale is not supported".into()),
            ),
            (
                123,
                i8::MAX,
                Either::Right("Cannot get a power of 127".into()),
            ),
            (999999999, 0, Either::Left(999999999)),
            (999999999, 3, Either::Left(999999)),
        ];

        for (value, scale, expected) in cases {
            match decimal32_to_i32(value, scale) {
                Ok(actual) => {
                    let expected_value =
                        expected.left().expect("Got value but expected none");
                    assert_eq!(
                        actual, expected_value,
                        "{value} and {scale} vs {expected_value:?}"
                    );
                }
                Err(ArrowError::ComputeError(msg)) => {
                    assert_eq!(
                        msg,
                        expected.right().expect("Got error but expected value")
                    );
                }
                // Any other error variant passes as long as an error was expected.
                Err(_) => {
                    assert!(expected.is_right())
                }
            }
        }
    }

    // Same case layout as `test_decimal32_to_i32`, for 64-bit decimals.
    #[test]
    fn test_decimal64_to_i64() {
        let cases: [(i64, i8, Either<i64, String>); _] = [
            (123, 0, Either::Left(123)),
            (1234567890, 2, Either::Left(12345678)),
            (-1234567890, 2, Either::Left(-12345678)),
            (
                123,
                -3,
                Either::Right("Negative scale is not supported".into()),
            ),
            (
                123,
                i8::MAX,
                Either::Right("Cannot get a power of 127".into()),
            ),
            (
                999999999999999999i64,
                0,
                Either::Left(999999999999999999i64),
            ),
            (
                999999999999999999i64,
                3,
                Either::Left(999999999999999999i64 / 1000),
            ),
            (
                -999999999999999999i64,
                3,
                Either::Left(-999999999999999999i64 / 1000),
            ),
        ];

        for (value, scale, expected) in cases {
            match decimal64_to_i64(value, scale) {
                Ok(actual) => {
                    let expected_value =
                        expected.left().expect("Got value but expected none");
                    assert_eq!(
                        actual, expected_value,
                        "{value} and {scale} vs {expected_value:?}"
                    );
                }
                Err(ArrowError::ComputeError(msg)) => {
                    assert_eq!(
                        msg,
                        expected.right().expect("Got error but expected value")
                    );
                }
                // Any other error variant passes as long as an error was expected.
                Err(_) => {
                    assert!(expected.is_right())
                }
            }
        }
    }
}