1use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::{Field, FieldRef};
21use arrow::{
22 compute::can_cast_types,
23 datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27 ListCoercion, base_type, coerced_fixed_size_list_to_list,
28};
29use datafusion_common::{
30 Result, exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35 signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36 type_coercion::binary::comparison_coercion_numeric,
37 type_coercion::binary::string_coercion,
38};
39use itertools::Itertools as _;
40use std::sync::Arc;
41
42pub trait UDFCoercionExt {
45 fn name(&self) -> &str;
47 fn signature(&self) -> &Signature;
50 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>>;
53}
54
55impl UDFCoercionExt for ScalarUDF {
56 fn name(&self) -> &str {
57 self.name()
58 }
59
60 fn signature(&self) -> &Signature {
61 self.signature()
62 }
63
64 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
65 self.coerce_types(arg_types)
66 }
67}
68
69impl UDFCoercionExt for AggregateUDF {
70 fn name(&self) -> &str {
71 self.name()
72 }
73
74 fn signature(&self) -> &Signature {
75 self.signature()
76 }
77
78 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
79 self.coerce_types(arg_types)
80 }
81}
82
83impl UDFCoercionExt for WindowUDF {
84 fn name(&self) -> &str {
85 self.name()
86 }
87
88 fn signature(&self) -> &Signature {
89 self.signature()
90 }
91
92 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
93 self.coerce_types(arg_types)
94 }
95}
96
97pub fn fields_with_udf<F: UDFCoercionExt>(
105 current_fields: &[FieldRef],
106 func: &F,
107) -> Result<Vec<FieldRef>> {
108 let signature = func.signature();
109 let type_signature = &signature.type_signature;
110
111 if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
112 if type_signature.supports_zero_argument() {
113 return Ok(vec![]);
114 } else if type_signature.used_to_support_zero_arguments() {
115 return plan_err!(
117 "'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
118 func.name()
119 );
120 } else {
121 return plan_err!("'{}' does not support zero arguments", func.name());
122 }
123 }
124 let current_types = current_fields
125 .iter()
126 .map(|f| f.data_type())
127 .cloned()
128 .collect::<Vec<_>>();
129
130 let valid_types = get_valid_types_with_udf(type_signature, ¤t_types, func)?;
131 if valid_types
132 .iter()
133 .any(|data_type| data_type == ¤t_types)
134 {
135 return Ok(current_fields.to_vec());
136 }
137
138 let updated_types =
139 try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?;
140
141 Ok(current_fields
142 .iter()
143 .zip(updated_types)
144 .map(|(current_field, new_type)| {
145 current_field.as_ref().clone().with_data_type(new_type)
146 })
147 .map(Arc::new)
148 .collect())
149}
150
151#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
159pub fn data_types_with_scalar_udf(
160 current_types: &[DataType],
161 func: &ScalarUDF,
162) -> Result<Vec<DataType>> {
163 let current_fields = current_types
164 .iter()
165 .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
166 .collect::<Vec<_>>();
167 Ok(fields_with_udf(¤t_fields, func)?
168 .iter()
169 .map(|f| f.data_type().clone())
170 .collect())
171}
172
173#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
181pub fn fields_with_aggregate_udf(
182 current_fields: &[FieldRef],
183 func: &AggregateUDF,
184) -> Result<Vec<FieldRef>> {
185 fields_with_udf(current_fields, func)
186}
187
188#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
196pub fn fields_with_window_udf(
197 current_fields: &[FieldRef],
198 func: &WindowUDF,
199) -> Result<Vec<FieldRef>> {
200 fields_with_udf(current_fields, func)
201}
202
203#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
211pub fn data_types(
212 function_name: impl AsRef<str>,
213 current_types: &[DataType],
214 signature: &Signature,
215) -> Result<Vec<DataType>> {
216 let type_signature = &signature.type_signature;
217
218 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
219 if type_signature.supports_zero_argument() {
220 return Ok(vec![]);
221 } else if type_signature.used_to_support_zero_arguments() {
222 return plan_err!(
224 "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
225 function_name.as_ref()
226 );
227 } else {
228 return plan_err!(
229 "Function '{}' has signature {type_signature:?} which does not support zero arguments",
230 function_name.as_ref()
231 );
232 }
233 }
234
235 let valid_types =
236 get_valid_types(function_name.as_ref(), type_signature, current_types)?;
237 if valid_types
238 .iter()
239 .any(|data_type| data_type == current_types)
240 {
241 return Ok(current_types.to_vec());
242 }
243
244 try_coerce_types(
245 function_name.as_ref(),
246 valid_types,
247 current_types,
248 type_signature,
249 )
250}
251
252fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
253 match type_signature {
254 TypeSignature::OneOf(type_signatures) => {
255 type_signatures.iter().all(is_well_supported_signature)
256 }
257 TypeSignature::UserDefined
258 | TypeSignature::Numeric(_)
259 | TypeSignature::String(_)
260 | TypeSignature::Coercible(_)
261 | TypeSignature::Any(_)
262 | TypeSignature::Nullary
263 | TypeSignature::Comparable(_) => true,
264 TypeSignature::Variadic(_)
265 | TypeSignature::VariadicAny
266 | TypeSignature::Uniform(_, _)
267 | TypeSignature::Exact(_)
268 | TypeSignature::ArraySignature(_) => false,
269 }
270}
271
272fn try_coerce_types(
273 function_name: &str,
274 valid_types: Vec<Vec<DataType>>,
275 current_types: &[DataType],
276 type_signature: &TypeSignature,
277) -> Result<Vec<DataType>> {
278 let mut valid_types = valid_types;
279
280 if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
282 if !type_signature.is_one_of() {
285 assert_eq!(valid_types.len(), 1);
286 }
287
288 let valid_types = valid_types.swap_remove(0);
289 if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
290 return Ok(t);
291 }
292 } else {
293 for valid_types in valid_types {
297 if let Some(types) = maybe_data_types(&valid_types, current_types) {
298 return Ok(types);
299 }
300 }
301 }
302
303 plan_err!(
305 "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed",
306 current_types.iter().join(", ")
307 )
308}
309
310fn get_valid_types_with_udf<F: UDFCoercionExt>(
311 signature: &TypeSignature,
312 current_types: &[DataType],
313 func: &F,
314) -> Result<Vec<Vec<DataType>>> {
315 let valid_types = match signature {
316 TypeSignature::UserDefined => match func.coerce_types(current_types) {
317 Ok(coerced_types) => vec![coerced_types],
318 Err(e) => {
319 return exec_err!(
320 "Function '{}' user-defined coercion failed with {:?}",
321 func.name(),
322 e.strip_backtrace()
323 );
324 }
325 },
326 TypeSignature::OneOf(signatures) => {
327 let mut res = vec![];
328 let mut errors = vec![];
329 for sig in signatures {
330 match get_valid_types_with_udf(sig, current_types, func) {
331 Ok(valid_types) => {
332 res.extend(valid_types);
333 }
334 Err(e) => {
335 errors.push(e.to_string());
336 }
337 }
338 }
339
340 if res.is_empty() {
342 return internal_err!(
343 "Function '{}' failed to match any signature, errors: {}",
344 func.name(),
345 errors.join(",")
346 );
347 } else {
348 res
349 }
350 }
351 _ => get_valid_types(func.name(), signature, current_types)?,
352 };
353
354 Ok(valid_types)
355}
356
357fn get_valid_types(
359 function_name: &str,
360 signature: &TypeSignature,
361 current_types: &[DataType],
362) -> Result<Vec<Vec<DataType>>> {
363 fn array_valid_types(
364 function_name: &str,
365 current_types: &[DataType],
366 arguments: &[ArrayFunctionArgument],
367 array_coercion: Option<&ListCoercion>,
368 ) -> Result<Vec<Vec<DataType>>> {
369 if current_types.len() != arguments.len() {
370 return Ok(vec![vec![]]);
371 }
372
373 let mut large_list = false;
374 let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
375 let mut list_sizes = Vec::with_capacity(arguments.len());
376 let mut element_types = Vec::with_capacity(arguments.len());
377 let mut nested_item_nullability = Vec::with_capacity(arguments.len());
378 for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
379 match argument {
380 ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
381 nested_item_nullability.push(None);
382 }
383 ArrayFunctionArgument::Element => {
384 element_types.push(current_type.clone());
385 nested_item_nullability.push(None);
386 }
387 ArrayFunctionArgument::Array => match current_type {
388 DataType::Null => {
389 element_types.push(DataType::Null);
390 nested_item_nullability.push(None);
391 }
392 DataType::List(field) => {
393 element_types.push(field.data_type().clone());
394 nested_item_nullability.push(Some(field.is_nullable()));
395 fixed_size = false;
396 }
397 DataType::LargeList(field) => {
398 element_types.push(field.data_type().clone());
399 nested_item_nullability.push(Some(field.is_nullable()));
400 large_list = true;
401 fixed_size = false;
402 }
403 DataType::FixedSizeList(field, size) => {
404 element_types.push(field.data_type().clone());
405 nested_item_nullability.push(Some(field.is_nullable()));
406 list_sizes.push(*size)
407 }
408 arg_type => {
409 plan_err!("{function_name} does not support type {arg_type}")?
410 }
411 },
412 }
413 }
414
415 debug_assert_eq!(nested_item_nullability.len(), arguments.len());
416
417 let Some(element_type) = type_union_resolution(&element_types) else {
418 return Ok(vec![vec![]]);
419 };
420
421 if !fixed_size {
422 list_sizes.clear()
423 };
424
425 let mut list_sizes = list_sizes.into_iter();
426 let valid_types = arguments
427 .iter()
428 .zip(current_types.iter())
429 .zip(nested_item_nullability)
430 .map(|((argument_type, current_type), is_nested_item_nullable)| {
431 match argument_type {
432 ArrayFunctionArgument::Index => DataType::Int64,
433 ArrayFunctionArgument::String => DataType::Utf8,
434 ArrayFunctionArgument::Element => element_type.clone(),
435 ArrayFunctionArgument::Array => {
436 if current_type.is_null() {
437 DataType::Null
438 } else if large_list {
439 DataType::new_large_list(
440 element_type.clone(),
441 is_nested_item_nullable.unwrap_or(true),
442 )
443 } else if let Some(size) = list_sizes.next() {
444 DataType::new_fixed_size_list(
445 element_type.clone(),
446 size,
447 is_nested_item_nullable.unwrap_or(true),
448 )
449 } else {
450 DataType::new_list(
451 element_type.clone(),
452 is_nested_item_nullable.unwrap_or(true),
453 )
454 }
455 }
456 }
457 });
458
459 Ok(vec![valid_types.collect()])
460 }
461
462 fn recursive_array(array_type: &DataType) -> Option<DataType> {
463 match array_type {
464 DataType::List(_)
465 | DataType::LargeList(_)
466 | DataType::FixedSizeList(_, _) => {
467 let array_type = coerced_fixed_size_list_to_list(array_type);
468 Some(array_type)
469 }
470 _ => None,
471 }
472 }
473
474 fn function_length_check(
475 function_name: &str,
476 length: usize,
477 expected_length: usize,
478 ) -> Result<()> {
479 if length != expected_length {
480 return plan_err!(
481 "Function '{function_name}' expects {expected_length} arguments but received {length}"
482 );
483 }
484 Ok(())
485 }
486
487 let valid_types = match signature {
488 TypeSignature::Variadic(valid_types) => valid_types
489 .iter()
490 .map(|valid_type| vec![valid_type.clone(); current_types.len()])
491 .collect(),
492 TypeSignature::String(number) => {
493 function_length_check(function_name, current_types.len(), *number)?;
494
495 let mut new_types = Vec::with_capacity(current_types.len());
496 for data_type in current_types.iter() {
497 let logical_data_type: NativeType = data_type.into();
498 if logical_data_type == NativeType::String {
499 new_types.push(data_type.to_owned());
500 } else if logical_data_type == NativeType::Null {
501 new_types.push(DataType::Utf8);
503 } else {
504 return plan_err!(
505 "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}"
506 );
507 }
508 }
509
510 fn find_common_type(
512 function_name: &str,
513 lhs_type: &DataType,
514 rhs_type: &DataType,
515 ) -> Result<DataType> {
516 match (lhs_type, rhs_type) {
517 (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
518 find_common_type(function_name, lhs, rhs)
519 }
520 (DataType::Dictionary(_, v), other)
521 | (other, DataType::Dictionary(_, v)) => {
522 find_common_type(function_name, v, other)
523 }
524 _ => {
525 if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
526 Ok(coerced_type)
527 } else {
528 plan_err!(
529 "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
530 )
531 }
532 }
533 }
534 }
535
536 let mut coerced_type = new_types.first().unwrap().to_owned();
538 for t in new_types.iter().skip(1) {
539 coerced_type = find_common_type(function_name, &coerced_type, t)?;
540 }
541
542 fn base_type_or_default_type(data_type: &DataType) -> DataType {
543 if let DataType::Dictionary(_, v) = data_type {
544 base_type_or_default_type(v)
545 } else {
546 data_type.to_owned()
547 }
548 }
549
550 vec![vec![base_type_or_default_type(&coerced_type); *number]]
551 }
552 TypeSignature::Numeric(number) => {
553 function_length_check(function_name, current_types.len(), *number)?;
554
555 let mut valid_type = current_types.first().unwrap().to_owned();
557 for t in current_types.iter().skip(1) {
558 let logical_data_type: NativeType = t.into();
559 if logical_data_type == NativeType::Null {
560 continue;
561 }
562
563 if !logical_data_type.is_numeric() {
564 return plan_err!(
565 "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
566 );
567 }
568
569 if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
570 valid_type = coerced_type;
571 } else {
572 return plan_err!(
573 "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
574 );
575 }
576 }
577
578 let logical_data_type: NativeType = valid_type.clone().into();
579 if logical_data_type == NativeType::Null {
583 valid_type = DataType::Float64;
584 } else if !logical_data_type.is_numeric() {
585 return plan_err!(
586 "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
587 );
588 }
589
590 vec![vec![valid_type; *number]]
591 }
592 TypeSignature::Comparable(num) => {
593 function_length_check(function_name, current_types.len(), *num)?;
594 let mut target_type = current_types[0].to_owned();
595 for data_type in current_types.iter().skip(1) {
596 if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
597 target_type = dt;
598 } else {
599 return plan_err!(
600 "For function '{function_name}' {target_type} and {data_type} is not comparable"
601 );
602 }
603 }
604 if target_type.is_null() {
606 vec![vec![DataType::Utf8View; *num]]
607 } else {
608 vec![vec![target_type; *num]]
609 }
610 }
611 TypeSignature::Coercible(param_types) => {
612 function_length_check(function_name, current_types.len(), param_types.len())?;
613
614 let mut new_types = Vec::with_capacity(current_types.len());
615 for (current_type, param) in current_types.iter().zip(param_types.iter()) {
616 let current_native_type: NativeType = current_type.into();
617
618 if param
619 .desired_type()
620 .matches_native_type(¤t_native_type)
621 {
622 let casted_type = param
623 .desired_type()
624 .default_casted_type(¤t_native_type, current_type)?;
625
626 new_types.push(casted_type);
627 } else if param
628 .allowed_source_types()
629 .iter()
630 .any(|t| t.matches_native_type(¤t_native_type))
631 {
632 let default_casted_type = param.default_casted_type().unwrap();
634 let casted_type =
635 default_casted_type.default_cast_for(current_type)?;
636 new_types.push(casted_type);
637 } else {
638 let hint = if matches!(current_native_type, NativeType::Binary) {
639 "\n\nHint: Binary types are not automatically coerced to String. Use CAST(column AS VARCHAR) to convert Binary data to String."
640 } else {
641 ""
642 };
643 return plan_err!(
644 "Function '{function_name}' requires {}, but received {} (DataType: {}).{hint}",
645 param.desired_type(),
646 current_native_type,
647 current_type
648 );
649 }
650 }
651
652 vec![new_types]
653 }
654 TypeSignature::Uniform(number, valid_types) => {
655 if *number == 0 {
656 return plan_err!(
657 "The function '{function_name}' expected at least one argument"
658 );
659 }
660
661 valid_types
662 .iter()
663 .map(|valid_type| vec![valid_type.clone(); *number])
664 .collect()
665 }
666 TypeSignature::UserDefined => {
667 return internal_err!(
668 "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
669 );
670 }
671 TypeSignature::VariadicAny => {
672 if current_types.is_empty() {
673 return plan_err!(
674 "Function '{function_name}' expected at least one argument but received 0"
675 );
676 }
677 vec![current_types.to_vec()]
678 }
679 TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
680 TypeSignature::ArraySignature(function_signature) => match function_signature {
681 ArrayFunctionSignature::Array {
682 arguments,
683 array_coercion,
684 } => array_valid_types(
685 function_name,
686 current_types,
687 arguments,
688 array_coercion.as_ref(),
689 )?,
690 ArrayFunctionSignature::RecursiveArray => {
691 if current_types.len() != 1 {
692 return Ok(vec![vec![]]);
693 }
694 recursive_array(¤t_types[0])
695 .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
696 }
697 ArrayFunctionSignature::MapArray => {
698 if current_types.len() != 1 {
699 return Ok(vec![vec![]]);
700 }
701
702 match ¤t_types[0] {
703 DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
704 _ => vec![vec![]],
705 }
706 }
707 },
708 TypeSignature::Nullary => {
709 if !current_types.is_empty() {
710 return plan_err!(
711 "The function '{function_name}' expected zero argument but received {}",
712 current_types.len()
713 );
714 }
715 vec![vec![]]
716 }
717 TypeSignature::Any(number) => {
718 if current_types.is_empty() {
719 return plan_err!(
720 "The function '{function_name}' expected at least one argument but received 0"
721 );
722 }
723
724 if current_types.len() != *number {
725 return plan_err!(
726 "The function '{function_name}' expected {number} arguments but received {}",
727 current_types.len()
728 );
729 }
730 vec![current_types.to_vec()]
731 }
732 TypeSignature::OneOf(types) => types
733 .iter()
734 .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
735 .flatten()
736 .collect::<Vec<_>>(),
737 };
738
739 Ok(valid_types)
740}
741
742fn maybe_data_types(
749 valid_types: &[DataType],
750 current_types: &[DataType],
751) -> Option<Vec<DataType>> {
752 if valid_types.len() != current_types.len() {
753 return None;
754 }
755
756 let mut new_type = Vec::with_capacity(valid_types.len());
757 for (i, valid_type) in valid_types.iter().enumerate() {
758 let current_type = ¤t_types[i];
759
760 if current_type == valid_type {
761 new_type.push(current_type.clone())
762 } else {
763 if let Some(coerced_type) = coerced_from(valid_type, current_type) {
767 new_type.push(coerced_type)
768 } else {
769 return None;
771 }
772 }
773 }
774 Some(new_type)
775}
776
777fn maybe_data_types_without_coercion(
781 valid_types: &[DataType],
782 current_types: &[DataType],
783) -> Option<Vec<DataType>> {
784 if valid_types.len() != current_types.len() {
785 return None;
786 }
787
788 let mut new_type = Vec::with_capacity(valid_types.len());
789 for (i, valid_type) in valid_types.iter().enumerate() {
790 let current_type = ¤t_types[i];
791
792 if current_type == valid_type {
793 new_type.push(current_type.clone())
794 } else if can_cast_types(current_type, valid_type) {
795 new_type.push(valid_type.clone())
797 } else {
798 return None;
799 }
800 }
801 Some(new_type)
802}
803
804#[deprecated(since = "53.0.0", note = "Unused internal function")]
809pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
810 if type_into == type_from {
811 return true;
812 }
813 if let Some(coerced) = coerced_from(type_into, type_from) {
814 return coerced == *type_into;
815 }
816 false
817}
818
/// Legacy coercion rules. Returns the type that an argument of `type_from`
/// should take to fill a signature slot declared as `type_into`, or `None` if
/// these rules do not allow the coercion.
///
/// Usually the result is `type_into` itself. The wildcard targets are the
/// exception, since they take their concrete parts from `type_from`:
/// - `FixedSizeList(_, FIXED_SIZE_LIST_WILDCARD)` adopts the source list size
///   and the recursively coerced element type.
/// - `Timestamp(unit, Some(TIMEZONE_WILDCARD))` adopts the source timezone, or
///   `"+00"` when the source has none.
///
/// Arms are tried top to bottom, so their order matters. For example, the
/// dictionary arms run before every concrete-type arm.
fn coerced_from<'a>(
    type_into: &'a DataType,
    type_from: &'a DataType,
) -> Option<DataType> {
    use self::DataType::*;

    match (type_into, type_from) {
        // Dictionary source: allowed if its value type can be coerced to the target.
        (_, Dictionary(_, value_type))
            if coerced_from(type_into, value_type).is_some() =>
        {
            Some(type_into.clone())
        }
        // Dictionary target: allowed if the source can be coerced to its value type.
        (Dictionary(_, value_type), _)
            if coerced_from(value_type, type_from).is_some() =>
        {
            Some(type_into.clone())
        }
        // Integers: only NULL and types listed for each target are accepted.
        (Int8, Null | Int8) => Some(type_into.clone()),
        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
            Some(type_into.clone())
        }
        (UInt8, Null | UInt8) => Some(type_into.clone()),
        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
        (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => {
            Some(type_into.clone())
        }
        (
            Float32,
            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
            | Float16 | Float32,
        ) => Some(type_into.clone()),
        // Float64 additionally accepts every decimal width.
        (
            Float64,
            Null
            | Int8
            | Int16
            | Int32
            | Int64
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float16
            | Float32
            | Float64
            | Decimal32(_, _)
            | Decimal64(_, _)
            | Decimal128(_, _)
            | Decimal256(_, _),
        ) => Some(type_into.clone()),
        // Timezone-less nanosecond timestamp: any timezone-less timestamp, dates and strings.
        (
            Timestamp(TimeUnit::Nanosecond, None),
            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
        ) => Some(type_into.clone()),
        (Interval(_), Null | Utf8 | LargeUtf8) => Some(type_into.clone()),
        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
        // A Utf8/LargeUtf8 slot accepts any source type.
        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
        // A Null slot accepts any source that arrow can cast to Null.
        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),

        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),

        // List slots: accept a source whose base type is Null, or one with the
        // same list nesting depth (as reported by `list_ndims`).
        (List(_) | LargeList(_), _)
            if base_type(type_from).is_null()
                || list_ndims(type_from) == list_ndims(type_into) =>
        {
            Some(type_into.clone())
        }
        // Wildcard-size FixedSizeList: take the source size and coerce the
        // element type recursively. The target field is reused when the
        // element type is unchanged.
        (
            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
            FixedSizeList(f_from, size_from),
        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
            Some(data_type) if &data_type != f_into.data_type() => {
                let new_field =
                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
                Some(FixedSizeList(new_field, *size_from))
            }
            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
            _ => None,
        },
        // Wildcard-timezone timestamp: keep the source timezone, or use "+00"
        // for timezone-less sources, dates, strings and NULL.
        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
            match type_from {
                Timestamp(_, Some(from_tz)) => {
                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
                }
                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
                    Some(Timestamp(*unit, Some("+00".into())))
                }
                _ => None,
            }
        }
        // Concrete timezone: any timestamp, date, string or NULL keeps the declared type.
        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
            Some(type_into.clone())
        }
        _ => None,
    }
}
934
935#[cfg(test)]
936mod tests {
937 use crate::Volatility;
938
939 use super::*;
940 use arrow::datatypes::Field;
941 use datafusion_common::{
942 assert_contains,
943 types::{logical_binary, logical_int64},
944 };
945 use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
946
947 #[test]
948 fn test_string_conversion() {
949 let cases = vec![
950 (DataType::Utf8View, DataType::Utf8),
951 (DataType::Utf8View, DataType::LargeUtf8),
952 ];
953
954 for case in cases {
955 assert_eq!(coerced_from(&case.0, &case.1), Some(case.0));
956 }
957 }
958
959 #[test]
960 fn test_maybe_data_types() {
961 let cases = vec![
963 (
965 vec![DataType::UInt8, DataType::UInt16],
966 vec![DataType::UInt8, DataType::UInt16],
967 Some(vec![DataType::UInt8, DataType::UInt16]),
968 ),
969 (
971 vec![DataType::UInt16, DataType::UInt16],
972 vec![DataType::UInt8, DataType::UInt16],
973 Some(vec![DataType::UInt16, DataType::UInt16]),
974 ),
975 (vec![], vec![], Some(vec![])),
977 (
979 vec![DataType::Boolean, DataType::UInt16],
980 vec![DataType::UInt8, DataType::UInt16],
981 None,
982 ),
983 (
985 vec![DataType::Boolean, DataType::UInt32],
986 vec![DataType::Boolean, DataType::UInt16],
987 Some(vec![DataType::Boolean, DataType::UInt32]),
988 ),
989 (
991 vec![
992 DataType::Timestamp(TimeUnit::Nanosecond, None),
993 DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
994 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
995 ],
996 vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
997 Some(vec![
998 DataType::Timestamp(TimeUnit::Nanosecond, None),
999 DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
1000 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
1001 ]),
1002 ),
1003 ];
1004
1005 for case in cases {
1006 assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
1007 }
1008 }
1009
1010 #[test]
1011 fn test_get_valid_types_numeric() -> Result<()> {
1012 let get_valid_types_flatten =
1013 |function_name: &str,
1014 signature: &TypeSignature,
1015 current_types: &[DataType]| {
1016 get_valid_types(function_name, signature, current_types)
1017 .unwrap()
1018 .into_iter()
1019 .flatten()
1020 .collect::<Vec<_>>()
1021 };
1022
1023 let got = get_valid_types_flatten(
1025 "test",
1026 &TypeSignature::Numeric(1),
1027 &[DataType::Int32],
1028 );
1029 assert_eq!(got, [DataType::Int32]);
1030
1031 let got = get_valid_types_flatten(
1033 "test",
1034 &TypeSignature::Numeric(2),
1035 &[DataType::Int32, DataType::Int64],
1036 );
1037 assert_eq!(got, [DataType::Int64, DataType::Int64]);
1038
1039 let got = get_valid_types_flatten(
1041 "test",
1042 &TypeSignature::Numeric(3),
1043 &[DataType::Int32, DataType::Int64, DataType::Float64],
1044 );
1045 assert_eq!(
1046 got,
1047 [DataType::Float64, DataType::Float64, DataType::Float64]
1048 );
1049
1050 let got = get_valid_types(
1052 "test",
1053 &TypeSignature::Numeric(2),
1054 &[DataType::Int32, DataType::Utf8],
1055 )
1056 .unwrap_err();
1057 assert_contains!(
1058 got.to_string(),
1059 "Function 'test' expects NativeType::Numeric but received NativeType::String"
1060 );
1061
1062 let got = get_valid_types_flatten(
1064 "test",
1065 &TypeSignature::Numeric(1),
1066 &[DataType::Null],
1067 );
1068 assert_eq!(got, [DataType::Float64]);
1069
1070 let got = get_valid_types(
1072 "test",
1073 &TypeSignature::Numeric(1),
1074 &[DataType::Timestamp(TimeUnit::Second, None)],
1075 )
1076 .unwrap_err();
1077 assert_contains!(
1078 got.to_string(),
1079 "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(s)"
1080 );
1081
1082 Ok(())
1083 }
1084
1085 #[test]
1086 fn test_get_valid_types_one_of() -> Result<()> {
1087 let signature =
1088 TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);
1089
1090 let invalid_types = get_valid_types(
1091 "test",
1092 &signature,
1093 &[DataType::Int32, DataType::Int32, DataType::Int32],
1094 )?;
1095 assert_eq!(invalid_types.len(), 0);
1096
1097 let args = vec![DataType::Int32, DataType::Int32];
1098 let valid_types = get_valid_types("test", &signature, &args)?;
1099 assert_eq!(valid_types.len(), 1);
1100 assert_eq!(valid_types[0], args);
1101
1102 let args = vec![DataType::Int32];
1103 let valid_types = get_valid_types("test", &signature, &args)?;
1104 assert_eq!(valid_types.len(), 1);
1105 assert_eq!(valid_types[0], args);
1106
1107 Ok(())
1108 }
1109
1110 #[test]
1111 fn test_get_valid_types_length_check() -> Result<()> {
1112 let signature = TypeSignature::Numeric(1);
1113
1114 let err = get_valid_types("test", &signature, &[]).unwrap_err();
1115 assert_contains!(
1116 err.to_string(),
1117 "Function 'test' expects 1 arguments but received 0"
1118 );
1119
1120 let err = get_valid_types(
1121 "test",
1122 &signature,
1123 &[DataType::Int32, DataType::Int32, DataType::Int32],
1124 )
1125 .unwrap_err();
1126 assert_contains!(
1127 err.to_string(),
1128 "Function 'test' expects 1 arguments but received 3"
1129 );
1130
1131 Ok(())
1132 }
1133
1134 struct MockUdf(Signature);
1135
1136 impl UDFCoercionExt for MockUdf {
1137 fn name(&self) -> &str {
1138 "test"
1139 }
1140 fn signature(&self) -> &Signature {
1141 &self.0
1142 }
1143 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
1144 unimplemented!()
1145 }
1146 }
1147
1148 #[test]
1149 fn test_fixed_list_wildcard_coerce() -> Result<()> {
1150 let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
1151 let current_fields = vec![Arc::new(Field::new(
1153 "t",
1154 DataType::FixedSizeList(Arc::clone(&inner), 2),
1155 true,
1156 ))];
1157
1158 let signature = Signature::exact(
1159 vec![DataType::FixedSizeList(
1160 Arc::clone(&inner),
1161 FIXED_SIZE_LIST_WILDCARD,
1162 )],
1163 Volatility::Stable,
1164 );
1165
1166 let coerced_fields = fields_with_udf(¤t_fields, &MockUdf(signature))?;
1167 assert_eq!(coerced_fields, current_fields);
1168
1169 let signature = Signature::exact(
1171 vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
1172 Volatility::Stable,
1173 );
1174 let coerced_fields = fields_with_udf(¤t_fields, &MockUdf(signature));
1175 assert!(coerced_fields.is_err());
1176
1177 let signature = Signature::exact(
1179 vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
1180 Volatility::Stable,
1181 );
1182 let coerced_fields =
1183 fields_with_udf(¤t_fields, &MockUdf(signature)).unwrap();
1184 assert_eq!(coerced_fields, current_fields);
1185
1186 Ok(())
1187 }
1188
1189 #[test]
1190 fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
1191 let type_into = DataType::FixedSizeList(
1192 Arc::new(Field::new_list_field(
1193 DataType::FixedSizeList(
1194 Arc::new(Field::new_list_field(DataType::Int32, false)),
1195 FIXED_SIZE_LIST_WILDCARD,
1196 ),
1197 false,
1198 )),
1199 FIXED_SIZE_LIST_WILDCARD,
1200 );
1201
1202 let type_from = DataType::FixedSizeList(
1203 Arc::new(Field::new_list_field(
1204 DataType::FixedSizeList(
1205 Arc::new(Field::new_list_field(DataType::Int8, false)),
1206 4,
1207 ),
1208 false,
1209 )),
1210 3,
1211 );
1212
1213 assert_eq!(
1214 coerced_from(&type_into, &type_from),
1215 Some(DataType::FixedSizeList(
1216 Arc::new(Field::new_list_field(
1217 DataType::FixedSizeList(
1218 Arc::new(Field::new_list_field(DataType::Int32, false)),
1219 4,
1220 ),
1221 false,
1222 )),
1223 3,
1224 ))
1225 );
1226
1227 Ok(())
1228 }
1229
1230 #[test]
1231 fn test_coerced_from_dictionary() {
1232 let type_into =
1233 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1234 let type_from = DataType::Int64;
1235 assert_eq!(coerced_from(&type_into, &type_from), None);
1236
1237 let type_from =
1238 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1239 let type_into = DataType::Int64;
1240 assert_eq!(
1241 coerced_from(&type_into, &type_from),
1242 Some(type_into.clone())
1243 );
1244 }
1245
#[test]
fn test_get_valid_types_array_and_array() -> Result<()> {
    // Two array arguments; fixed-size lists are coerced to regular lists.
    let signature = Signature::arrays(
        2,
        Some(ListCoercion::FixedSizedListToList),
        Volatility::Immutable,
    );

    // Each case: (argument types, the type both arguments are expected to be
    // coerced to).
    let cases = [
        (
            [
                DataType::new_list(DataType::Int32, true),
                DataType::new_large_list(DataType::Float64, true),
            ],
            DataType::new_large_list(DataType::Float64, true),
        ),
        (
            [
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int32, 5, true),
            ],
            DataType::new_list(DataType::Int64, true),
        ),
        (
            [
                DataType::new_fixed_size_list(DataType::Null, 3, true),
                DataType::new_large_list(DataType::Utf8, true),
            ],
            DataType::new_large_list(DataType::Utf8, true),
        ),
    ];

    for (arg_types, unified) in cases {
        let valid =
            get_valid_types("array_and_array", &signature.type_signature, &arg_types)?;
        assert_eq!(valid, vec![vec![unified.clone(), unified]]);
    }

    Ok(())
}
1293
#[test]
fn test_get_valid_types_array_and_element() -> Result<()> {
    let signature = Signature::array_and_element(Volatility::Immutable);

    // Each case: (argument types, expected coerced argument types).
    let cases = [
        (
            [DataType::new_list(DataType::Int32, true), DataType::Float64],
            [DataType::new_list(DataType::Float64, true), DataType::Float64],
        ),
        (
            [DataType::new_large_list(DataType::Int32, true), DataType::Null],
            [DataType::new_large_list(DataType::Int32, true), DataType::Int32],
        ),
        (
            [
                DataType::new_fixed_size_list(DataType::Null, 3, true),
                DataType::Utf8,
            ],
            [DataType::new_list(DataType::Utf8, true), DataType::Utf8],
        ),
    ];

    for (arg_types, expected) in cases {
        let valid = get_valid_types(
            "array_and_element",
            &signature.type_signature,
            &arg_types,
        )?;
        assert_eq!(valid, vec![Vec::from(expected)]);
    }

    Ok(())
}
1335
#[test]
fn test_get_valid_types_element_and_array() -> Result<()> {
    let signature = Signature::element_and_array(Volatility::Immutable);

    // Element: LargeList<Null>; array: List<List<Int64>>.
    let element = DataType::new_large_list(DataType::Null, false);
    let array = DataType::new_list(DataType::new_list(DataType::Int64, true), true);

    // Expected: the element becomes LargeList<Int64>, and the array's item
    // type is replaced by that same LargeList<Int64>.
    let expected_element = DataType::new_large_list(DataType::Int64, true);
    let expected_array = DataType::new_list(expected_element.clone(), true);

    let valid = get_valid_types(
        "element_and_array",
        &signature.type_signature,
        &[element, array],
    )?;
    assert_eq!(valid, vec![vec![expected_element, expected_array]]);

    Ok(())
}
1355
#[test]
fn test_coercible_nulls() -> Result<()> {
    // Runs one nullable `Null`-typed field through a single-argument coercible
    // signature built from `coercion` and returns the coerced argument types.
    let coerce_null = |coercion: Coercion| -> Result<Vec<DataType>> {
        let udf = MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable));
        let coerced =
            fields_with_udf(&[Field::new("field", DataType::Null, true).into()], &udf)?;
        Ok(coerced.iter().map(|field| field.data_type().clone()).collect())
    };

    // Each case: (coercion, expected coerced type of the single argument).
    // Native(Int64) coercions turn Null into Int64; Integer-class coercions
    // leave it as Null.
    let cases = [
        (
            Coercion::new_exact(TypeSignatureClass::Native(logical_int64())),
            DataType::Int64,
        ),
        (
            Coercion::new_implicit(
                TypeSignatureClass::Native(logical_int64()),
                vec![],
                NativeType::Int64,
            ),
            DataType::Int64,
        ),
        (
            Coercion::new_exact(TypeSignatureClass::Integer),
            DataType::Null,
        ),
        (
            Coercion::new_implicit(
                TypeSignatureClass::Integer,
                vec![],
                NativeType::Int64,
            ),
            DataType::Null,
        ),
    ];

    for (coercion, expected) in cases {
        assert_eq!(vec![expected], coerce_null(coercion)?);
    }

    Ok(())
}
1392
#[test]
fn test_coercible_dictionary() -> Result<()> {
    let dictionary =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int64));

    // Coerces one nullable field of type `dictionary` through a single-argument
    // coercible signature and returns the coerced argument types. The input
    // field is built from the same `dictionary` value used in the expectations
    // below, so the fixture is defined exactly once.
    let dictionary_input = |coercion: Coercion| -> Result<Vec<DataType>> {
        fields_with_udf(
            &[Field::new("field", dictionary.clone(), true).into()],
            &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
        )
        .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
    };

    // Native(Int64) coercions produce plain Int64.
    let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Native(
        logical_int64(),
    )))?;
    assert_eq!(vec![DataType::Int64], output);

    let output = dictionary_input(Coercion::new_implicit(
        TypeSignatureClass::Native(logical_int64()),
        vec![],
        NativeType::Int64,
    ))?;
    assert_eq!(vec![DataType::Int64], output);

    // Integer-class coercions leave the dictionary type unchanged.
    let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
    assert_eq!(vec![dictionary.clone()], output);

    let output = dictionary_input(Coercion::new_implicit(
        TypeSignatureClass::Integer,
        vec![],
        NativeType::Int64,
    ))?;
    // Last use of `dictionary`: the closure's borrow has ended, so move it.
    assert_eq!(vec![dictionary], output);

    Ok(())
}
1439
#[test]
fn test_coercible_run_end_encoded() -> Result<()> {
    let run_end_encoded = DataType::RunEndEncoded(
        Field::new("run_ends", DataType::Int16, false).into(),
        Field::new("values", DataType::Int64, true).into(),
    );

    // Coerces one nullable field of type `run_end_encoded` through a
    // single-argument coercible signature and returns the coerced argument
    // types. The input field is built from the same `run_end_encoded` value
    // used in the expectations below, so the fixture is defined exactly once.
    let run_end_encoded_input = |coercion: Coercion| -> Result<Vec<DataType>> {
        fields_with_udf(
            &[Field::new("field", run_end_encoded.clone(), true).into()],
            &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
        )
        .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
    };

    // Native(Int64) coercions produce plain Int64.
    let output = run_end_encoded_input(Coercion::new_exact(
        TypeSignatureClass::Native(logical_int64()),
    ))?;
    assert_eq!(vec![DataType::Int64], output);

    let output = run_end_encoded_input(Coercion::new_implicit(
        TypeSignatureClass::Native(logical_int64()),
        vec![],
        NativeType::Int64,
    ))?;
    assert_eq!(vec![DataType::Int64], output);

    // Integer-class coercions leave the run-end-encoded type unchanged.
    let output =
        run_end_encoded_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
    assert_eq!(vec![run_end_encoded.clone()], output);

    let output = run_end_encoded_input(Coercion::new_implicit(
        TypeSignatureClass::Integer,
        vec![],
        NativeType::Int64,
    ))?;
    // Last use of `run_end_encoded`: the closure's borrow has ended, so move it.
    assert_eq!(vec![run_end_encoded], output);

    Ok(())
}
1489
#[test]
fn test_get_valid_types_coercible_binary() -> Result<()> {
    // A single argument that must exactly match the native binary type.
    let binary_exact = Coercion::new_exact(TypeSignatureClass::Native(logical_binary()));
    let signature = Signature::coercible(vec![binary_exact], Volatility::Immutable);

    // Binary, BinaryView and LargeBinary are each accepted unchanged.
    let binary_types = [DataType::Binary, DataType::BinaryView, DataType::LargeBinary];
    for data_type in binary_types {
        let arg_types = std::slice::from_ref(&data_type);
        let valid = get_valid_types("", &signature.type_signature, arg_types)?;
        assert_eq!(valid, vec![vec![data_type]]);
    }

    Ok(())
}
1513
#[test]
fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
    // Two array arguments with no list coercion requested.
    let signature = Signature::arrays(2, None, Volatility::Immutable);

    // Each case: (argument types, expected coerced argument types). An empty
    // expectation means the signature yields a single empty candidate.
    let cases = [
        (
            [
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int32, 5, true),
            ],
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int64, 5, true),
            ],
        ),
        (
            [
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_list(DataType::Int32, true),
            ],
            vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ],
        ),
        (
            [
                DataType::new_fixed_size_list(DataType::Utf8, 3, true),
                DataType::new_list(DataType::new_list(DataType::Int32, true), true),
            ],
            vec![],
        ),
        (
            [
                DataType::new_fixed_size_list(DataType::Int64, 3, false),
                DataType::new_list(DataType::Int32, false),
            ],
            vec![
                DataType::new_list(DataType::Int64, false),
                DataType::new_list(DataType::Int64, false),
            ],
        ),
    ];

    for (arg_types, expected) in cases {
        let valid = get_valid_types(
            "fixed_size_arrays",
            &signature.type_signature,
            &arg_types,
        )?;
        assert_eq!(valid, vec![expected]);
    }

    Ok(())
}
1566}