1use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::{Field, FieldRef};
21use arrow::{
22 compute::can_cast_types,
23 datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27 ListCoercion, base_type, coerced_fixed_size_list_to_list,
28};
29use datafusion_common::{
30 Result, exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35 signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36 type_coercion::binary::comparison_coercion_numeric,
37 type_coercion::binary::string_coercion,
38};
39use itertools::Itertools as _;
40use std::sync::Arc;
41
/// Coercion-related surface shared by the function kinds in this module.
///
/// Implemented below for [`ScalarUDF`], [`AggregateUDF`] and [`WindowUDF`] so
/// that [`fields_with_udf`] can be written once, generically.
pub trait UDFCoercionExt {
    /// Name of the function; used when building error messages.
    fn name(&self) -> &str;
    /// Signature the call arguments are checked and coerced against.
    fn signature(&self) -> &Signature;
    /// Function-specific coercion of `arg_types`; invoked for
    /// [`TypeSignature::UserDefined`] signatures.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>>;
}
54
55impl UDFCoercionExt for ScalarUDF {
56 fn name(&self) -> &str {
57 self.name()
58 }
59
60 fn signature(&self) -> &Signature {
61 self.signature()
62 }
63
64 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
65 self.coerce_types(arg_types)
66 }
67}
68
69impl UDFCoercionExt for AggregateUDF {
70 fn name(&self) -> &str {
71 self.name()
72 }
73
74 fn signature(&self) -> &Signature {
75 self.signature()
76 }
77
78 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
79 self.coerce_types(arg_types)
80 }
81}
82
83impl UDFCoercionExt for WindowUDF {
84 fn name(&self) -> &str {
85 self.name()
86 }
87
88 fn signature(&self) -> &Signature {
89 self.signature()
90 }
91
92 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
93 self.coerce_types(arg_types)
94 }
95}
96
97#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
105pub fn data_types_with_scalar_udf(
106 current_types: &[DataType],
107 func: &ScalarUDF,
108) -> Result<Vec<DataType>> {
109 let current_fields = current_types
110 .iter()
111 .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
112 .collect::<Vec<_>>();
113 Ok(fields_with_udf(¤t_fields, func)?
114 .iter()
115 .map(|f| f.data_type().clone())
116 .collect())
117}
118
/// Coerces the argument fields of a call to the aggregate function `func`.
///
/// Thin forwarder to [`fields_with_udf`]; kept only for backwards
/// compatibility.
#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
pub fn fields_with_aggregate_udf(
    current_fields: &[FieldRef],
    func: &AggregateUDF,
) -> Result<Vec<FieldRef>> {
    fields_with_udf(current_fields, func)
}
133
/// Coerces the argument fields of a call to the window function `func`.
///
/// Thin forwarder to [`fields_with_udf`]; kept only for backwards
/// compatibility.
#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
pub fn fields_with_window_udf(
    current_fields: &[FieldRef],
    func: &WindowUDF,
) -> Result<Vec<FieldRef>> {
    fields_with_udf(current_fields, func)
}
148
149pub fn fields_with_udf<F: UDFCoercionExt>(
157 current_fields: &[FieldRef],
158 func: &F,
159) -> Result<Vec<FieldRef>> {
160 let signature = func.signature();
161 let type_signature = &signature.type_signature;
162
163 if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
164 if type_signature.supports_zero_argument() {
165 return Ok(vec![]);
166 } else if type_signature.used_to_support_zero_arguments() {
167 return plan_err!(
169 "'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
170 func.name()
171 );
172 } else {
173 return plan_err!("'{}' does not support zero arguments", func.name());
174 }
175 }
176 let current_types = current_fields
177 .iter()
178 .map(|f| f.data_type())
179 .cloned()
180 .collect::<Vec<_>>();
181
182 let valid_types = get_valid_types_with_udf(type_signature, ¤t_types, func)?;
183 if valid_types
184 .iter()
185 .any(|data_type| data_type == ¤t_types)
186 {
187 return Ok(current_fields.to_vec());
188 }
189
190 let updated_types =
191 try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?;
192
193 Ok(current_fields
194 .iter()
195 .zip(updated_types)
196 .map(|(current_field, new_type)| {
197 current_field.as_ref().clone().with_data_type(new_type)
198 })
199 .map(Arc::new)
200 .collect())
201}
202
203#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
211pub fn data_types(
212 function_name: impl AsRef<str>,
213 current_types: &[DataType],
214 signature: &Signature,
215) -> Result<Vec<DataType>> {
216 let type_signature = &signature.type_signature;
217
218 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
219 if type_signature.supports_zero_argument() {
220 return Ok(vec![]);
221 } else if type_signature.used_to_support_zero_arguments() {
222 return plan_err!(
224 "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
225 function_name.as_ref()
226 );
227 } else {
228 return plan_err!(
229 "Function '{}' has signature {type_signature:?} which does not support zero arguments",
230 function_name.as_ref()
231 );
232 }
233 }
234
235 let valid_types =
236 get_valid_types(function_name.as_ref(), type_signature, current_types)?;
237 if valid_types
238 .iter()
239 .any(|data_type| data_type == current_types)
240 {
241 return Ok(current_types.to_vec());
242 }
243
244 try_coerce_types(
245 function_name.as_ref(),
246 valid_types,
247 current_types,
248 type_signature,
249 )
250}
251
252fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
253 match type_signature {
254 TypeSignature::OneOf(type_signatures) => {
255 type_signatures.iter().all(is_well_supported_signature)
256 }
257 TypeSignature::UserDefined
258 | TypeSignature::Numeric(_)
259 | TypeSignature::String(_)
260 | TypeSignature::Coercible(_)
261 | TypeSignature::Any(_)
262 | TypeSignature::Nullary
263 | TypeSignature::Comparable(_) => true,
264 TypeSignature::Variadic(_)
265 | TypeSignature::VariadicAny
266 | TypeSignature::Uniform(_, _)
267 | TypeSignature::Exact(_)
268 | TypeSignature::ArraySignature(_) => false,
269 }
270}
271
272fn try_coerce_types(
273 function_name: &str,
274 valid_types: Vec<Vec<DataType>>,
275 current_types: &[DataType],
276 type_signature: &TypeSignature,
277) -> Result<Vec<DataType>> {
278 let mut valid_types = valid_types;
279
280 if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
282 if !type_signature.is_one_of() {
285 assert_eq!(valid_types.len(), 1);
286 }
287
288 let valid_types = valid_types.swap_remove(0);
289 if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
290 return Ok(t);
291 }
292 } else {
293 for valid_types in valid_types {
297 if let Some(types) = maybe_data_types(&valid_types, current_types) {
298 return Ok(types);
299 }
300 }
301 }
302
303 plan_err!(
305 "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed",
306 current_types.iter().join(", ")
307 )
308}
309
310fn get_valid_types_with_udf<F: UDFCoercionExt>(
311 signature: &TypeSignature,
312 current_types: &[DataType],
313 func: &F,
314) -> Result<Vec<Vec<DataType>>> {
315 let valid_types = match signature {
316 TypeSignature::UserDefined => match func.coerce_types(current_types) {
317 Ok(coerced_types) => vec![coerced_types],
318 Err(e) => {
319 return exec_err!(
320 "Function '{}' user-defined coercion failed with {:?}",
321 func.name(),
322 e.strip_backtrace()
323 );
324 }
325 },
326 TypeSignature::OneOf(signatures) => {
327 let mut res = vec![];
328 let mut errors = vec![];
329 for sig in signatures {
330 match get_valid_types_with_udf(sig, current_types, func) {
331 Ok(valid_types) => {
332 res.extend(valid_types);
333 }
334 Err(e) => {
335 errors.push(e.to_string());
336 }
337 }
338 }
339
340 if res.is_empty() {
342 return internal_err!(
343 "Function '{}' failed to match any signature, errors: {}",
344 func.name(),
345 errors.join(",")
346 );
347 } else {
348 res
349 }
350 }
351 _ => get_valid_types(func.name(), signature, current_types)?,
352 };
353
354 Ok(valid_types)
355}
356
357fn get_valid_types(
359 function_name: &str,
360 signature: &TypeSignature,
361 current_types: &[DataType],
362) -> Result<Vec<Vec<DataType>>> {
363 fn array_valid_types(
364 function_name: &str,
365 current_types: &[DataType],
366 arguments: &[ArrayFunctionArgument],
367 array_coercion: Option<&ListCoercion>,
368 ) -> Result<Vec<Vec<DataType>>> {
369 if current_types.len() != arguments.len() {
370 return Ok(vec![vec![]]);
371 }
372
373 let mut large_list = false;
374 let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
375 let mut list_sizes = Vec::with_capacity(arguments.len());
376 let mut element_types = Vec::with_capacity(arguments.len());
377 let mut nested_item_nullability = Vec::with_capacity(arguments.len());
378 for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
379 match argument {
380 ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
381 nested_item_nullability.push(None);
382 }
383 ArrayFunctionArgument::Element => {
384 element_types.push(current_type.clone());
385 nested_item_nullability.push(None);
386 }
387 ArrayFunctionArgument::Array => match current_type {
388 DataType::Null => {
389 element_types.push(DataType::Null);
390 nested_item_nullability.push(None);
391 }
392 DataType::List(field) => {
393 element_types.push(field.data_type().clone());
394 nested_item_nullability.push(Some(field.is_nullable()));
395 fixed_size = false;
396 }
397 DataType::LargeList(field) => {
398 element_types.push(field.data_type().clone());
399 nested_item_nullability.push(Some(field.is_nullable()));
400 large_list = true;
401 fixed_size = false;
402 }
403 DataType::FixedSizeList(field, size) => {
404 element_types.push(field.data_type().clone());
405 nested_item_nullability.push(Some(field.is_nullable()));
406 list_sizes.push(*size)
407 }
408 arg_type => {
409 plan_err!("{function_name} does not support type {arg_type}")?
410 }
411 },
412 }
413 }
414
415 debug_assert_eq!(nested_item_nullability.len(), arguments.len());
416
417 let Some(element_type) = type_union_resolution(&element_types) else {
418 return Ok(vec![vec![]]);
419 };
420
421 if !fixed_size {
422 list_sizes.clear()
423 };
424
425 let mut list_sizes = list_sizes.into_iter();
426 let valid_types = arguments
427 .iter()
428 .zip(current_types.iter())
429 .zip(nested_item_nullability)
430 .map(|((argument_type, current_type), is_nested_item_nullable)| {
431 match argument_type {
432 ArrayFunctionArgument::Index => DataType::Int64,
433 ArrayFunctionArgument::String => DataType::Utf8,
434 ArrayFunctionArgument::Element => element_type.clone(),
435 ArrayFunctionArgument::Array => {
436 if current_type.is_null() {
437 DataType::Null
438 } else if large_list {
439 DataType::new_large_list(
440 element_type.clone(),
441 is_nested_item_nullable.unwrap_or(true),
442 )
443 } else if let Some(size) = list_sizes.next() {
444 DataType::new_fixed_size_list(
445 element_type.clone(),
446 size,
447 is_nested_item_nullable.unwrap_or(true),
448 )
449 } else {
450 DataType::new_list(
451 element_type.clone(),
452 is_nested_item_nullable.unwrap_or(true),
453 )
454 }
455 }
456 }
457 });
458
459 Ok(vec![valid_types.collect()])
460 }
461
462 fn recursive_array(array_type: &DataType) -> Option<DataType> {
463 match array_type {
464 DataType::List(_)
465 | DataType::LargeList(_)
466 | DataType::FixedSizeList(_, _) => {
467 let array_type = coerced_fixed_size_list_to_list(array_type);
468 Some(array_type)
469 }
470 _ => None,
471 }
472 }
473
474 fn function_length_check(
475 function_name: &str,
476 length: usize,
477 expected_length: usize,
478 ) -> Result<()> {
479 if length != expected_length {
480 return plan_err!(
481 "Function '{function_name}' expects {expected_length} arguments but received {length}"
482 );
483 }
484 Ok(())
485 }
486
487 let valid_types = match signature {
488 TypeSignature::Variadic(valid_types) => valid_types
489 .iter()
490 .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
491 .collect(),
492 TypeSignature::String(number) => {
493 function_length_check(function_name, current_types.len(), *number)?;
494
495 let mut new_types = Vec::with_capacity(current_types.len());
496 for data_type in current_types.iter() {
497 let logical_data_type: NativeType = data_type.into();
498 if logical_data_type == NativeType::String {
499 new_types.push(data_type.to_owned());
500 } else if logical_data_type == NativeType::Null {
501 new_types.push(DataType::Utf8);
503 } else {
504 return plan_err!(
505 "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}"
506 );
507 }
508 }
509
510 fn find_common_type(
512 function_name: &str,
513 lhs_type: &DataType,
514 rhs_type: &DataType,
515 ) -> Result<DataType> {
516 match (lhs_type, rhs_type) {
517 (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
518 find_common_type(function_name, lhs, rhs)
519 }
520 (DataType::Dictionary(_, v), other)
521 | (other, DataType::Dictionary(_, v)) => {
522 find_common_type(function_name, v, other)
523 }
524 _ => {
525 if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
526 Ok(coerced_type)
527 } else {
528 plan_err!(
529 "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
530 )
531 }
532 }
533 }
534 }
535
536 let mut coerced_type = new_types.first().unwrap().to_owned();
538 for t in new_types.iter().skip(1) {
539 coerced_type = find_common_type(function_name, &coerced_type, t)?;
540 }
541
542 fn base_type_or_default_type(data_type: &DataType) -> DataType {
543 if let DataType::Dictionary(_, v) = data_type {
544 base_type_or_default_type(v)
545 } else {
546 data_type.to_owned()
547 }
548 }
549
550 vec![vec![base_type_or_default_type(&coerced_type); *number]]
551 }
552 TypeSignature::Numeric(number) => {
553 function_length_check(function_name, current_types.len(), *number)?;
554
555 let mut valid_type = current_types.first().unwrap().to_owned();
557 for t in current_types.iter().skip(1) {
558 let logical_data_type: NativeType = t.into();
559 if logical_data_type == NativeType::Null {
560 continue;
561 }
562
563 if !logical_data_type.is_numeric() {
564 return plan_err!(
565 "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
566 );
567 }
568
569 if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
570 valid_type = coerced_type;
571 } else {
572 return plan_err!(
573 "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
574 );
575 }
576 }
577
578 let logical_data_type: NativeType = valid_type.clone().into();
579 if logical_data_type == NativeType::Null {
583 valid_type = DataType::Float64;
584 } else if !logical_data_type.is_numeric() {
585 return plan_err!(
586 "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
587 );
588 }
589
590 vec![vec![valid_type; *number]]
591 }
592 TypeSignature::Comparable(num) => {
593 function_length_check(function_name, current_types.len(), *num)?;
594 let mut target_type = current_types[0].to_owned();
595 for data_type in current_types.iter().skip(1) {
596 if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
597 target_type = dt;
598 } else {
599 return plan_err!(
600 "For function '{function_name}' {target_type} and {data_type} is not comparable"
601 );
602 }
603 }
604 if target_type.is_null() {
606 vec![vec![DataType::Utf8View; *num]]
607 } else {
608 vec![vec![target_type; *num]]
609 }
610 }
611 TypeSignature::Coercible(param_types) => {
612 function_length_check(function_name, current_types.len(), param_types.len())?;
613
614 let mut new_types = Vec::with_capacity(current_types.len());
615 for (current_type, param) in current_types.iter().zip(param_types.iter()) {
616 let current_native_type: NativeType = current_type.into();
617
618 if param
619 .desired_type()
620 .matches_native_type(¤t_native_type)
621 {
622 let casted_type = param
623 .desired_type()
624 .default_casted_type(¤t_native_type, current_type)?;
625
626 new_types.push(casted_type);
627 } else if param
628 .allowed_source_types()
629 .iter()
630 .any(|t| t.matches_native_type(¤t_native_type))
631 {
632 let default_casted_type = param.default_casted_type().unwrap();
634 let casted_type =
635 default_casted_type.default_cast_for(current_type)?;
636 new_types.push(casted_type);
637 } else {
638 return internal_err!(
639 "Expect {} but received NativeType::{}, DataType: {}",
640 param.desired_type(),
641 current_native_type,
642 current_type
643 );
644 }
645 }
646
647 vec![new_types]
648 }
649 TypeSignature::Uniform(number, valid_types) => {
650 if *number == 0 {
651 return plan_err!(
652 "The function '{function_name}' expected at least one argument"
653 );
654 }
655
656 valid_types
657 .iter()
658 .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
659 .collect()
660 }
661 TypeSignature::UserDefined => {
662 return internal_err!(
663 "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
664 );
665 }
666 TypeSignature::VariadicAny => {
667 if current_types.is_empty() {
668 return plan_err!(
669 "Function '{function_name}' expected at least one argument but received 0"
670 );
671 }
672 vec![current_types.to_vec()]
673 }
674 TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
675 TypeSignature::ArraySignature(function_signature) => match function_signature {
676 ArrayFunctionSignature::Array {
677 arguments,
678 array_coercion,
679 } => array_valid_types(
680 function_name,
681 current_types,
682 arguments,
683 array_coercion.as_ref(),
684 )?,
685 ArrayFunctionSignature::RecursiveArray => {
686 if current_types.len() != 1 {
687 return Ok(vec![vec![]]);
688 }
689 recursive_array(¤t_types[0])
690 .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
691 }
692 ArrayFunctionSignature::MapArray => {
693 if current_types.len() != 1 {
694 return Ok(vec![vec![]]);
695 }
696
697 match ¤t_types[0] {
698 DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
699 _ => vec![vec![]],
700 }
701 }
702 },
703 TypeSignature::Nullary => {
704 if !current_types.is_empty() {
705 return plan_err!(
706 "The function '{function_name}' expected zero argument but received {}",
707 current_types.len()
708 );
709 }
710 vec![vec![]]
711 }
712 TypeSignature::Any(number) => {
713 if current_types.is_empty() {
714 return plan_err!(
715 "The function '{function_name}' expected at least one argument but received 0"
716 );
717 }
718
719 if current_types.len() != *number {
720 return plan_err!(
721 "The function '{function_name}' expected {number} arguments but received {}",
722 current_types.len()
723 );
724 }
725 vec![(0..*number).map(|i| current_types[i].clone()).collect()]
726 }
727 TypeSignature::OneOf(types) => types
728 .iter()
729 .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
730 .flatten()
731 .collect::<Vec<_>>(),
732 };
733
734 Ok(valid_types)
735}
736
737fn maybe_data_types(
744 valid_types: &[DataType],
745 current_types: &[DataType],
746) -> Option<Vec<DataType>> {
747 if valid_types.len() != current_types.len() {
748 return None;
749 }
750
751 let mut new_type = Vec::with_capacity(valid_types.len());
752 for (i, valid_type) in valid_types.iter().enumerate() {
753 let current_type = ¤t_types[i];
754
755 if current_type == valid_type {
756 new_type.push(current_type.clone())
757 } else {
758 if let Some(coerced_type) = coerced_from(valid_type, current_type) {
762 new_type.push(coerced_type)
763 } else {
764 return None;
766 }
767 }
768 }
769 Some(new_type)
770}
771
772fn maybe_data_types_without_coercion(
776 valid_types: &[DataType],
777 current_types: &[DataType],
778) -> Option<Vec<DataType>> {
779 if valid_types.len() != current_types.len() {
780 return None;
781 }
782
783 let mut new_type = Vec::with_capacity(valid_types.len());
784 for (i, valid_type) in valid_types.iter().enumerate() {
785 let current_type = ¤t_types[i];
786
787 if current_type == valid_type {
788 new_type.push(current_type.clone())
789 } else if can_cast_types(current_type, valid_type) {
790 new_type.push(valid_type.clone())
792 } else {
793 return None;
794 }
795 }
796 Some(new_type)
797}
798
799pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
804 if type_into == type_from {
805 return true;
806 }
807 if let Some(coerced) = coerced_from(type_into, type_from) {
808 return coerced == *type_into;
809 }
810 false
811}
812
/// Determines the type that `type_from` becomes when coerced towards
/// `type_into`.
///
/// Returns `None` when no rule applies. The result usually equals
/// `type_into`, but wildcard targets (fixed-size list size, timestamp
/// timezone) are resolved using information from `type_from`.
///
/// The arms are evaluated top to bottom and their order is significant.
fn coerced_from<'a>(
    type_into: &'a DataType,
    type_from: &'a DataType,
) -> Option<DataType> {
    use self::DataType::*;

    match (type_into, type_from) {
        // A dictionary source is accepted whenever its value type is.
        (_, Dictionary(_, value_type))
            if coerced_from(type_into, value_type).is_some() =>
        {
            Some(type_into.clone())
        }
        // A dictionary target is accepted whenever its value type accepts the
        // source.
        (Dictionary(_, value_type), _)
            if coerced_from(value_type, type_from).is_some() =>
        {
            Some(type_into.clone())
        }
        // Integer targets accept Null and the listed narrower integers.
        (Int8, Null | Int8) => Some(type_into.clone()),
        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
            Some(type_into.clone())
        }
        (UInt8, Null | UInt8) => Some(type_into.clone()),
        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
        // Float32 accepts Null, every integer type and itself.
        (
            Float32,
            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
            | Float32,
        ) => Some(type_into.clone()),
        // Float64 additionally accepts Float32 and all decimal types.
        (
            Float64,
            Null
            | Int8
            | Int16
            | Int32
            | Int64
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float32
            | Float64
            | Decimal32(_, _)
            | Decimal64(_, _)
            | Decimal128(_, _)
            | Decimal256(_, _),
        ) => Some(type_into.clone()),
        // Nanosecond timestamp without timezone: accepts Null, any timestamp
        // without timezone, Date32 and the Utf8/LargeUtf8 string types.
        (
            Timestamp(TimeUnit::Nanosecond, None),
            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
        ) => Some(type_into.clone()),
        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
        // Any source type is accepted for a Utf8 / LargeUtf8 target.
        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
        // A Null target accepts whatever arrow can cast to Null.
        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),

        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),

        // List targets accept a source whose base type is Null or whose
        // nesting depth matches the target's.
        (List(_) | LargeList(_), _)
            if base_type(type_from).is_null()
                || list_ndims(type_from) == list_ndims(type_into) =>
        {
            Some(type_into.clone())
        }
        // Wildcard-sized fixed-size list target: the element type is coerced
        // recursively and the size is taken from the source.
        (
            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
            FixedSizeList(f_from, size_from),
        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
            // The element coercion resolved to a different type (e.g. a
            // nested wildcard): rebuild the field around it.
            Some(data_type) if &data_type != f_into.data_type() => {
                let new_field =
                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
                Some(FixedSizeList(new_field, *size_from))
            }
            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
            _ => None,
        },
        // Wildcard-timezone timestamp target: keeps the source's timezone when
        // it has one, otherwise resolves the wildcard to "+00".
        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
            match type_from {
                Timestamp(_, Some(from_tz)) => {
                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
                }
                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
                    Some(Timestamp(*unit, Some("+00".into())))
                }
                _ => None,
            }
        }
        // Timestamp target with a concrete timezone.
        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
            Some(type_into.clone())
        }
        _ => None,
    }
}
924
925#[cfg(test)]
926mod tests {
927 use crate::Volatility;
928
929 use super::*;
930 use arrow::datatypes::Field;
931 use datafusion_common::{assert_contains, types::logical_binary};
932 use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
933
934 #[test]
935 fn test_string_conversion() {
936 let cases = vec![
937 (DataType::Utf8View, DataType::Utf8, true),
938 (DataType::Utf8View, DataType::LargeUtf8, true),
939 ];
940
941 for case in cases {
942 assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
943 }
944 }
945
946 #[test]
947 fn test_maybe_data_types() {
948 let cases = vec![
950 (
952 vec![DataType::UInt8, DataType::UInt16],
953 vec![DataType::UInt8, DataType::UInt16],
954 Some(vec![DataType::UInt8, DataType::UInt16]),
955 ),
956 (
958 vec![DataType::UInt16, DataType::UInt16],
959 vec![DataType::UInt8, DataType::UInt16],
960 Some(vec![DataType::UInt16, DataType::UInt16]),
961 ),
962 (vec![], vec![], Some(vec![])),
964 (
966 vec![DataType::Boolean, DataType::UInt16],
967 vec![DataType::UInt8, DataType::UInt16],
968 None,
969 ),
970 (
972 vec![DataType::Boolean, DataType::UInt32],
973 vec![DataType::Boolean, DataType::UInt16],
974 Some(vec![DataType::Boolean, DataType::UInt32]),
975 ),
976 (
978 vec![
979 DataType::Timestamp(TimeUnit::Nanosecond, None),
980 DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
981 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
982 ],
983 vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
984 Some(vec![
985 DataType::Timestamp(TimeUnit::Nanosecond, None),
986 DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
987 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
988 ]),
989 ),
990 ];
991
992 for case in cases {
993 assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
994 }
995 }
996
997 #[test]
998 fn test_get_valid_types_numeric() -> Result<()> {
999 let get_valid_types_flatten =
1000 |function_name: &str,
1001 signature: &TypeSignature,
1002 current_types: &[DataType]| {
1003 get_valid_types(function_name, signature, current_types)
1004 .unwrap()
1005 .into_iter()
1006 .flatten()
1007 .collect::<Vec<_>>()
1008 };
1009
1010 let got = get_valid_types_flatten(
1012 "test",
1013 &TypeSignature::Numeric(1),
1014 &[DataType::Int32],
1015 );
1016 assert_eq!(got, [DataType::Int32]);
1017
1018 let got = get_valid_types_flatten(
1020 "test",
1021 &TypeSignature::Numeric(2),
1022 &[DataType::Int32, DataType::Int64],
1023 );
1024 assert_eq!(got, [DataType::Int64, DataType::Int64]);
1025
1026 let got = get_valid_types_flatten(
1028 "test",
1029 &TypeSignature::Numeric(3),
1030 &[DataType::Int32, DataType::Int64, DataType::Float64],
1031 );
1032 assert_eq!(
1033 got,
1034 [DataType::Float64, DataType::Float64, DataType::Float64]
1035 );
1036
1037 let got = get_valid_types(
1039 "test",
1040 &TypeSignature::Numeric(2),
1041 &[DataType::Int32, DataType::Utf8],
1042 )
1043 .unwrap_err();
1044 assert_contains!(
1045 got.to_string(),
1046 "Function 'test' expects NativeType::Numeric but received NativeType::String"
1047 );
1048
1049 let got = get_valid_types_flatten(
1051 "test",
1052 &TypeSignature::Numeric(1),
1053 &[DataType::Null],
1054 );
1055 assert_eq!(got, [DataType::Float64]);
1056
1057 let got = get_valid_types(
1059 "test",
1060 &TypeSignature::Numeric(1),
1061 &[DataType::Timestamp(TimeUnit::Second, None)],
1062 )
1063 .unwrap_err();
1064 assert_contains!(
1065 got.to_string(),
1066 "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
1067 );
1068
1069 Ok(())
1070 }
1071
1072 #[test]
1073 fn test_get_valid_types_one_of() -> Result<()> {
1074 let signature =
1075 TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);
1076
1077 let invalid_types = get_valid_types(
1078 "test",
1079 &signature,
1080 &[DataType::Int32, DataType::Int32, DataType::Int32],
1081 )?;
1082 assert_eq!(invalid_types.len(), 0);
1083
1084 let args = vec![DataType::Int32, DataType::Int32];
1085 let valid_types = get_valid_types("test", &signature, &args)?;
1086 assert_eq!(valid_types.len(), 1);
1087 assert_eq!(valid_types[0], args);
1088
1089 let args = vec![DataType::Int32];
1090 let valid_types = get_valid_types("test", &signature, &args)?;
1091 assert_eq!(valid_types.len(), 1);
1092 assert_eq!(valid_types[0], args);
1093
1094 Ok(())
1095 }
1096
1097 #[test]
1098 fn test_get_valid_types_length_check() -> Result<()> {
1099 let signature = TypeSignature::Numeric(1);
1100
1101 let err = get_valid_types("test", &signature, &[]).unwrap_err();
1102 assert_contains!(
1103 err.to_string(),
1104 "Function 'test' expects 1 arguments but received 0"
1105 );
1106
1107 let err = get_valid_types(
1108 "test",
1109 &signature,
1110 &[DataType::Int32, DataType::Int32, DataType::Int32],
1111 )
1112 .unwrap_err();
1113 assert_contains!(
1114 err.to_string(),
1115 "Function 'test' expects 1 arguments but received 3"
1116 );
1117
1118 Ok(())
1119 }
1120
1121 #[test]
1122 fn test_fixed_list_wildcard_coerce() -> Result<()> {
1123 struct MockUdf(Signature);
1124
1125 impl UDFCoercionExt for MockUdf {
1126 fn name(&self) -> &str {
1127 "test"
1128 }
1129 fn signature(&self) -> &Signature {
1130 &self.0
1131 }
1132 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
1133 unimplemented!()
1134 }
1135 }
1136
1137 let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
1138 let current_fields = vec![Arc::new(Field::new(
1140 "t",
1141 DataType::FixedSizeList(Arc::clone(&inner), 2),
1142 true,
1143 ))];
1144
1145 let signature = Signature::exact(
1146 vec![DataType::FixedSizeList(
1147 Arc::clone(&inner),
1148 FIXED_SIZE_LIST_WILDCARD,
1149 )],
1150 Volatility::Stable,
1151 );
1152
1153 let coerced_fields = fields_with_udf(¤t_fields, &MockUdf(signature))?;
1154 assert_eq!(coerced_fields, current_fields);
1155
1156 let signature = Signature::exact(
1158 vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
1159 Volatility::Stable,
1160 );
1161 let coerced_fields = fields_with_udf(¤t_fields, &MockUdf(signature));
1162 assert!(coerced_fields.is_err());
1163
1164 let signature = Signature::exact(
1166 vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
1167 Volatility::Stable,
1168 );
1169 let coerced_fields =
1170 fields_with_udf(¤t_fields, &MockUdf(signature)).unwrap();
1171 assert_eq!(coerced_fields, current_fields);
1172
1173 Ok(())
1174 }
1175
1176 #[test]
1177 fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
1178 let type_into = DataType::FixedSizeList(
1179 Arc::new(Field::new_list_field(
1180 DataType::FixedSizeList(
1181 Arc::new(Field::new_list_field(DataType::Int32, false)),
1182 FIXED_SIZE_LIST_WILDCARD,
1183 ),
1184 false,
1185 )),
1186 FIXED_SIZE_LIST_WILDCARD,
1187 );
1188
1189 let type_from = DataType::FixedSizeList(
1190 Arc::new(Field::new_list_field(
1191 DataType::FixedSizeList(
1192 Arc::new(Field::new_list_field(DataType::Int8, false)),
1193 4,
1194 ),
1195 false,
1196 )),
1197 3,
1198 );
1199
1200 assert_eq!(
1201 coerced_from(&type_into, &type_from),
1202 Some(DataType::FixedSizeList(
1203 Arc::new(Field::new_list_field(
1204 DataType::FixedSizeList(
1205 Arc::new(Field::new_list_field(DataType::Int32, false)),
1206 4,
1207 ),
1208 false,
1209 )),
1210 3,
1211 ))
1212 );
1213
1214 Ok(())
1215 }
1216
1217 #[test]
1218 fn test_coerced_from_dictionary() {
1219 let type_into =
1220 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1221 let type_from = DataType::Int64;
1222 assert_eq!(coerced_from(&type_into, &type_from), None);
1223
1224 let type_from =
1225 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1226 let type_into = DataType::Int64;
1227 assert_eq!(
1228 coerced_from(&type_into, &type_from),
1229 Some(type_into.clone())
1230 );
1231 }
1232
/// Checks `get_valid_types` for a two-array signature built with
/// `ListCoercion::FixedSizedListToList`.
///
/// Each case pairs the argument types passed in with the single list of
/// coerced types the call is expected to return.
#[test]
fn test_get_valid_types_array_and_array() -> Result<()> {
    let function = "array_and_array";
    let signature = Signature::arrays(
        2,
        Some(ListCoercion::FixedSizedListToList),
        Volatility::Immutable,
    );

    let cases = [
        // List + LargeList with different numeric element types.
        (
            vec![
                DataType::new_list(DataType::Int32, true),
                DataType::new_large_list(DataType::Float64, true),
            ],
            vec![
                DataType::new_large_list(DataType::Float64, true),
                DataType::new_large_list(DataType::Float64, true),
            ],
        ),
        // Two fixed-size lists of different sizes are expected to become
        // plain lists.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int32, 5, true),
            ],
            vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ],
        ),
        // Fixed-size list of Null next to a LargeList of Utf8.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Null, 3, true),
                DataType::new_large_list(DataType::Utf8, true),
            ],
            vec![
                DataType::new_large_list(DataType::Utf8, true),
                DataType::new_large_list(DataType::Utf8, true),
            ],
        ),
    ];

    for (arg_types, expected) in cases {
        let valid = get_valid_types(function, &signature.type_signature, &arg_types)?;
        assert_eq!(valid, vec![expected]);
    }

    Ok(())
}
1280
/// Checks `get_valid_types` for the `Signature::array_and_element` signature:
/// the first argument is an array type and the second a scalar type.
#[test]
fn test_get_valid_types_array_and_element() -> Result<()> {
    let function = "array_and_element";
    let signature = Signature::array_and_element(Volatility::Immutable);

    // Runs the coercion for one set of argument types.
    let coerce = |arg_types: &[DataType]| {
        get_valid_types(function, &signature.type_signature, arg_types)
    };

    // List<Int32> with a Float64 element.
    assert_eq!(
        coerce(&[DataType::new_list(DataType::Int32, true), DataType::Float64])?,
        vec![vec![
            DataType::new_list(DataType::Float64, true),
            DataType::Float64,
        ]]
    );

    // LargeList<Int32> with a Null element.
    assert_eq!(
        coerce(&[
            DataType::new_large_list(DataType::Int32, true),
            DataType::Null,
        ])?,
        vec![vec![
            DataType::new_large_list(DataType::Int32, true),
            DataType::Int32,
        ]]
    );

    // FixedSizeList<Null, 3> with a Utf8 element.
    assert_eq!(
        coerce(&[
            DataType::new_fixed_size_list(DataType::Null, 3, true),
            DataType::Utf8,
        ])?,
        vec![vec![
            DataType::new_list(DataType::Utf8, true),
            DataType::Utf8,
        ]]
    );

    Ok(())
}
1322
/// Checks `get_valid_types` for the `Signature::element_and_array` signature
/// when the element argument is itself a list type.
#[test]
fn test_get_valid_types_element_and_array() -> Result<()> {
    let function = "element_and_array";
    let signature = Signature::element_and_array(Volatility::Immutable);

    // Element: LargeList<Null>; array: List<List<Int64>>.
    let arg_types = vec![
        DataType::new_large_list(DataType::Null, false),
        DataType::new_list(DataType::new_list(DataType::Int64, true), true),
    ];
    // Expected: element LargeList<Int64>, array List<LargeList<Int64>>.
    let expected = vec![vec![
        DataType::new_large_list(DataType::Int64, true),
        DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
    ]];

    let actual = get_valid_types(function, &signature.type_signature, &arg_types)?;
    assert_eq!(actual, expected);

    Ok(())
}
1342
/// For a one-argument coercible signature built from
/// `Coercion::new_exact(TypeSignatureClass::Native(logical_binary()))`, each of
/// `Binary`, `BinaryView` and `LargeBinary` is expected to come back unchanged
/// from `get_valid_types`.
#[test]
fn test_get_valid_types_coercible_binary() -> Result<()> {
    let exact_binary = Coercion::new_exact(TypeSignatureClass::Native(logical_binary()));
    let signature = Signature::coercible(vec![exact_binary], Volatility::Immutable);

    let binary_types = [
        DataType::Binary,
        DataType::BinaryView,
        DataType::LargeBinary,
    ];
    for binary_type in binary_types {
        let valid = get_valid_types(
            "",
            &signature.type_signature,
            std::slice::from_ref(&binary_type),
        )?;
        assert_eq!(valid, vec![vec![binary_type]]);
    }

    Ok(())
}
1366
/// Checks `get_valid_types` for a two-array signature built with no list
/// coercion (`None`).
///
/// Each case pairs the argument types passed in with the single list of
/// coerced types the call is expected to return.
#[test]
fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
    let function = "fixed_size_arrays";
    let signature = Signature::arrays(2, None, Volatility::Immutable);

    let cases = [
        // Two fixed-size lists: sizes are expected to be kept, element types unified.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int32, 5, true),
            ],
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int64, 5, true),
            ],
        ),
        // Fixed-size list next to a plain list: both are expected to become lists.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_list(DataType::Int32, true),
            ],
            vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ],
        ),
        // Utf8 elements next to a nested Int32 list: an empty type list is expected.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Utf8, 3, true),
                DataType::new_list(DataType::new_list(DataType::Int32, true), true),
            ],
            vec![],
        ),
        // Same as the second case, with the item fields created as non-nullable.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, false),
                DataType::new_list(DataType::Int32, false),
            ],
            vec![
                DataType::new_list(DataType::Int64, false),
                DataType::new_list(DataType::Int64, false),
            ],
        ),
    ];

    for (arg_types, expected) in cases {
        let valid = get_valid_types(function, &signature.type_signature, &arg_types)?;
        assert_eq!(valid, vec![expected]);
    }

    Ok(())
}
1419}