1use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::FieldRef;
21use arrow::{
22 compute::can_cast_types,
23 datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27 base_type, coerced_fixed_size_list_to_list, ListCoercion,
28};
29use datafusion_common::{
30 exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35 signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36 type_coercion::binary::comparison_coercion_numeric,
37 type_coercion::binary::string_coercion,
38};
39use itertools::Itertools as _;
40use std::sync::Arc;
41
42pub fn data_types_with_scalar_udf(
50 current_types: &[DataType],
51 func: &ScalarUDF,
52) -> Result<Vec<DataType>> {
53 let signature = func.signature();
54 let type_signature = &signature.type_signature;
55
56 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
57 if type_signature.supports_zero_argument() {
58 return Ok(vec![]);
59 } else if type_signature.used_to_support_zero_arguments() {
60 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
62 } else {
63 return plan_err!("'{}' does not support zero arguments", func.name());
64 }
65 }
66
67 let valid_types =
68 get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
69
70 if valid_types
71 .iter()
72 .any(|data_type| data_type == current_types)
73 {
74 return Ok(current_types.to_vec());
75 }
76
77 try_coerce_types(func.name(), valid_types, current_types, type_signature)
78}
79
80pub fn fields_with_aggregate_udf(
88 current_fields: &[FieldRef],
89 func: &AggregateUDF,
90) -> Result<Vec<FieldRef>> {
91 let signature = func.signature();
92 let type_signature = &signature.type_signature;
93
94 if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
95 if type_signature.supports_zero_argument() {
96 return Ok(vec![]);
97 } else if type_signature.used_to_support_zero_arguments() {
98 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
100 } else {
101 return plan_err!("'{}' does not support zero arguments", func.name());
102 }
103 }
104 let current_types = current_fields
105 .iter()
106 .map(|f| f.data_type())
107 .cloned()
108 .collect::<Vec<_>>();
109
110 let valid_types =
111 get_valid_types_with_aggregate_udf(type_signature, ¤t_types, func)?;
112 if valid_types
113 .iter()
114 .any(|data_type| data_type == ¤t_types)
115 {
116 return Ok(current_fields.to_vec());
117 }
118
119 let updated_types =
120 try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?;
121
122 Ok(current_fields
123 .iter()
124 .zip(updated_types)
125 .map(|(current_field, new_type)| {
126 current_field.as_ref().clone().with_data_type(new_type)
127 })
128 .map(Arc::new)
129 .collect())
130}
131
132pub fn fields_with_window_udf(
140 current_fields: &[FieldRef],
141 func: &WindowUDF,
142) -> Result<Vec<FieldRef>> {
143 let signature = func.signature();
144 let type_signature = &signature.type_signature;
145
146 if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
147 if type_signature.supports_zero_argument() {
148 return Ok(vec![]);
149 } else if type_signature.used_to_support_zero_arguments() {
150 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
152 } else {
153 return plan_err!("'{}' does not support zero arguments", func.name());
154 }
155 }
156
157 let current_types = current_fields
158 .iter()
159 .map(|f| f.data_type())
160 .cloned()
161 .collect::<Vec<_>>();
162 let valid_types =
163 get_valid_types_with_window_udf(type_signature, ¤t_types, func)?;
164 if valid_types
165 .iter()
166 .any(|data_type| data_type == ¤t_types)
167 {
168 return Ok(current_fields.to_vec());
169 }
170
171 let updated_types =
172 try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?;
173
174 Ok(current_fields
175 .iter()
176 .zip(updated_types)
177 .map(|(current_field, new_type)| {
178 current_field.as_ref().clone().with_data_type(new_type)
179 })
180 .map(Arc::new)
181 .collect())
182}
183
184pub fn data_types(
192 function_name: impl AsRef<str>,
193 current_types: &[DataType],
194 signature: &Signature,
195) -> Result<Vec<DataType>> {
196 let type_signature = &signature.type_signature;
197
198 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
199 if type_signature.supports_zero_argument() {
200 return Ok(vec![]);
201 } else if type_signature.used_to_support_zero_arguments() {
202 return plan_err!(
204 "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
205 function_name.as_ref()
206 );
207 } else {
208 return plan_err!(
209 "Function '{}' has signature {type_signature:?} which does not support zero arguments",
210 function_name.as_ref()
211 );
212 }
213 }
214
215 let valid_types =
216 get_valid_types(function_name.as_ref(), type_signature, current_types)?;
217 if valid_types
218 .iter()
219 .any(|data_type| data_type == current_types)
220 {
221 return Ok(current_types.to_vec());
222 }
223
224 try_coerce_types(
225 function_name.as_ref(),
226 valid_types,
227 current_types,
228 type_signature,
229 )
230}
231
232fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
233 if let TypeSignature::OneOf(signatures) = type_signature {
234 return signatures.iter().all(is_well_supported_signature);
235 }
236
237 matches!(
238 type_signature,
239 TypeSignature::UserDefined
240 | TypeSignature::Numeric(_)
241 | TypeSignature::String(_)
242 | TypeSignature::Coercible(_)
243 | TypeSignature::Any(_)
244 | TypeSignature::Nullary
245 | TypeSignature::Comparable(_)
246 )
247}
248
249fn try_coerce_types(
250 function_name: &str,
251 valid_types: Vec<Vec<DataType>>,
252 current_types: &[DataType],
253 type_signature: &TypeSignature,
254) -> Result<Vec<DataType>> {
255 let mut valid_types = valid_types;
256
257 if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
259 if !type_signature.is_one_of() {
262 assert_eq!(valid_types.len(), 1);
263 }
264
265 let valid_types = valid_types.swap_remove(0);
266 if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
267 return Ok(t);
268 }
269 } else {
270 for valid_types in valid_types {
274 if let Some(types) = maybe_data_types(&valid_types, current_types) {
275 return Ok(types);
276 }
277 }
278 }
279
280 plan_err!(
282 "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed",
283 current_types.iter().join(", ")
284 )
285}
286
287fn get_valid_types_with_scalar_udf(
288 signature: &TypeSignature,
289 current_types: &[DataType],
290 func: &ScalarUDF,
291) -> Result<Vec<Vec<DataType>>> {
292 match signature {
293 TypeSignature::UserDefined => match func.coerce_types(current_types) {
294 Ok(coerced_types) => Ok(vec![coerced_types]),
295 Err(e) => exec_err!(
296 "Function '{}' user-defined coercion failed with {:?}",
297 func.name(),
298 e.strip_backtrace()
299 ),
300 },
301 TypeSignature::OneOf(signatures) => {
302 let mut res = vec![];
303 let mut errors = vec![];
304 for sig in signatures {
305 match get_valid_types_with_scalar_udf(sig, current_types, func) {
306 Ok(valid_types) => {
307 res.extend(valid_types);
308 }
309 Err(e) => {
310 errors.push(e.to_string());
311 }
312 }
313 }
314
315 if res.is_empty() {
317 internal_err!(
318 "Function '{}' failed to match any signature, errors: {}",
319 func.name(),
320 errors.join(",")
321 )
322 } else {
323 Ok(res)
324 }
325 }
326 _ => get_valid_types(func.name(), signature, current_types),
327 }
328}
329
330fn get_valid_types_with_aggregate_udf(
331 signature: &TypeSignature,
332 current_types: &[DataType],
333 func: &AggregateUDF,
334) -> Result<Vec<Vec<DataType>>> {
335 let valid_types = match signature {
336 TypeSignature::UserDefined => match func.coerce_types(current_types) {
337 Ok(coerced_types) => vec![coerced_types],
338 Err(e) => {
339 return exec_err!(
340 "Function '{}' user-defined coercion failed with {:?}",
341 func.name(),
342 e.strip_backtrace()
343 )
344 }
345 },
346 TypeSignature::OneOf(signatures) => signatures
347 .iter()
348 .filter_map(|t| {
349 get_valid_types_with_aggregate_udf(t, current_types, func).ok()
350 })
351 .flatten()
352 .collect::<Vec<_>>(),
353 _ => get_valid_types(func.name(), signature, current_types)?,
354 };
355
356 Ok(valid_types)
357}
358
359fn get_valid_types_with_window_udf(
360 signature: &TypeSignature,
361 current_types: &[DataType],
362 func: &WindowUDF,
363) -> Result<Vec<Vec<DataType>>> {
364 let valid_types = match signature {
365 TypeSignature::UserDefined => match func.coerce_types(current_types) {
366 Ok(coerced_types) => vec![coerced_types],
367 Err(e) => {
368 return exec_err!(
369 "Function '{}' user-defined coercion failed with {:?}",
370 func.name(),
371 e.strip_backtrace()
372 )
373 }
374 },
375 TypeSignature::OneOf(signatures) => signatures
376 .iter()
377 .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
378 .flatten()
379 .collect::<Vec<_>>(),
380 _ => get_valid_types(func.name(), signature, current_types)?,
381 };
382
383 Ok(valid_types)
384}
385
386fn get_valid_types(
388 function_name: &str,
389 signature: &TypeSignature,
390 current_types: &[DataType],
391) -> Result<Vec<Vec<DataType>>> {
392 fn array_valid_types(
393 function_name: &str,
394 current_types: &[DataType],
395 arguments: &[ArrayFunctionArgument],
396 array_coercion: Option<&ListCoercion>,
397 ) -> Result<Vec<Vec<DataType>>> {
398 if current_types.len() != arguments.len() {
399 return Ok(vec![vec![]]);
400 }
401
402 let mut large_list = false;
403 let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
404 let mut list_sizes = Vec::with_capacity(arguments.len());
405 let mut element_types = Vec::with_capacity(arguments.len());
406 let mut nested_item_nullability = Vec::with_capacity(arguments.len());
407 for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
408 match argument {
409 ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
410 nested_item_nullability.push(None);
411 }
412 ArrayFunctionArgument::Element => {
413 element_types.push(current_type.clone());
414 nested_item_nullability.push(None);
415 }
416 ArrayFunctionArgument::Array => match current_type {
417 DataType::Null => {
418 element_types.push(DataType::Null);
419 nested_item_nullability.push(None);
420 }
421 DataType::List(field) => {
422 element_types.push(field.data_type().clone());
423 nested_item_nullability.push(Some(field.is_nullable()));
424 fixed_size = false;
425 }
426 DataType::LargeList(field) => {
427 element_types.push(field.data_type().clone());
428 nested_item_nullability.push(Some(field.is_nullable()));
429 large_list = true;
430 fixed_size = false;
431 }
432 DataType::FixedSizeList(field, size) => {
433 element_types.push(field.data_type().clone());
434 nested_item_nullability.push(Some(field.is_nullable()));
435 list_sizes.push(*size)
436 }
437 arg_type => {
438 plan_err!("{function_name} does not support type {arg_type}")?
439 }
440 },
441 }
442 }
443
444 debug_assert_eq!(nested_item_nullability.len(), arguments.len());
445
446 let Some(element_type) = type_union_resolution(&element_types) else {
447 return Ok(vec![vec![]]);
448 };
449
450 if !fixed_size {
451 list_sizes.clear()
452 };
453
454 let mut list_sizes = list_sizes.into_iter();
455 let valid_types = arguments
456 .iter()
457 .zip(current_types.iter())
458 .zip(nested_item_nullability)
459 .map(|((argument_type, current_type), is_nested_item_nullable)| {
460 match argument_type {
461 ArrayFunctionArgument::Index => DataType::Int64,
462 ArrayFunctionArgument::String => DataType::Utf8,
463 ArrayFunctionArgument::Element => element_type.clone(),
464 ArrayFunctionArgument::Array => {
465 if current_type.is_null() {
466 DataType::Null
467 } else if large_list {
468 DataType::new_large_list(
469 element_type.clone(),
470 is_nested_item_nullable.unwrap_or(true),
471 )
472 } else if let Some(size) = list_sizes.next() {
473 DataType::new_fixed_size_list(
474 element_type.clone(),
475 size,
476 is_nested_item_nullable.unwrap_or(true),
477 )
478 } else {
479 DataType::new_list(
480 element_type.clone(),
481 is_nested_item_nullable.unwrap_or(true),
482 )
483 }
484 }
485 }
486 });
487
488 Ok(vec![valid_types.collect()])
489 }
490
491 fn recursive_array(array_type: &DataType) -> Option<DataType> {
492 match array_type {
493 DataType::List(_)
494 | DataType::LargeList(_)
495 | DataType::FixedSizeList(_, _) => {
496 let array_type = coerced_fixed_size_list_to_list(array_type);
497 Some(array_type)
498 }
499 _ => None,
500 }
501 }
502
503 fn function_length_check(
504 function_name: &str,
505 length: usize,
506 expected_length: usize,
507 ) -> Result<()> {
508 if length != expected_length {
509 return plan_err!(
510 "Function '{function_name}' expects {expected_length} arguments but received {length}"
511 );
512 }
513 Ok(())
514 }
515
516 let valid_types = match signature {
517 TypeSignature::Variadic(valid_types) => valid_types
518 .iter()
519 .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
520 .collect(),
521 TypeSignature::String(number) => {
522 function_length_check(function_name, current_types.len(), *number)?;
523
524 let mut new_types = Vec::with_capacity(current_types.len());
525 for data_type in current_types.iter() {
526 let logical_data_type: NativeType = data_type.into();
527 if logical_data_type == NativeType::String {
528 new_types.push(data_type.to_owned());
529 } else if logical_data_type == NativeType::Null {
530 new_types.push(DataType::Utf8);
532 } else {
533 return plan_err!(
534 "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}"
535 );
536 }
537 }
538
539 fn find_common_type(
541 function_name: &str,
542 lhs_type: &DataType,
543 rhs_type: &DataType,
544 ) -> Result<DataType> {
545 match (lhs_type, rhs_type) {
546 (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
547 find_common_type(function_name, lhs, rhs)
548 }
549 (DataType::Dictionary(_, v), other)
550 | (other, DataType::Dictionary(_, v)) => {
551 find_common_type(function_name, v, other)
552 }
553 _ => {
554 if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
555 Ok(coerced_type)
556 } else {
557 plan_err!(
558 "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
559 )
560 }
561 }
562 }
563 }
564
565 let mut coerced_type = new_types.first().unwrap().to_owned();
567 for t in new_types.iter().skip(1) {
568 coerced_type = find_common_type(function_name, &coerced_type, t)?;
569 }
570
571 fn base_type_or_default_type(data_type: &DataType) -> DataType {
572 if let DataType::Dictionary(_, v) = data_type {
573 base_type_or_default_type(v)
574 } else {
575 data_type.to_owned()
576 }
577 }
578
579 vec![vec![base_type_or_default_type(&coerced_type); *number]]
580 }
581 TypeSignature::Numeric(number) => {
582 function_length_check(function_name, current_types.len(), *number)?;
583
584 let mut valid_type = current_types.first().unwrap().to_owned();
586 for t in current_types.iter().skip(1) {
587 let logical_data_type: NativeType = t.into();
588 if logical_data_type == NativeType::Null {
589 continue;
590 }
591
592 if !logical_data_type.is_numeric() {
593 return plan_err!(
594 "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
595 );
596 }
597
598 if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
599 valid_type = coerced_type;
600 } else {
601 return plan_err!(
602 "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
603 );
604 }
605 }
606
607 let logical_data_type: NativeType = valid_type.clone().into();
608 if logical_data_type == NativeType::Null {
612 valid_type = DataType::Float64;
613 } else if !logical_data_type.is_numeric() {
614 return plan_err!(
615 "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
616 );
617 }
618
619 vec![vec![valid_type; *number]]
620 }
621 TypeSignature::Comparable(num) => {
622 function_length_check(function_name, current_types.len(), *num)?;
623 let mut target_type = current_types[0].to_owned();
624 for data_type in current_types.iter().skip(1) {
625 if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
626 target_type = dt;
627 } else {
628 return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
629 }
630 }
631 if target_type.is_null() {
633 vec![vec![DataType::Utf8View; *num]]
634 } else {
635 vec![vec![target_type; *num]]
636 }
637 }
638 TypeSignature::Coercible(param_types) => {
639 function_length_check(function_name, current_types.len(), param_types.len())?;
640
641 let mut new_types = Vec::with_capacity(current_types.len());
642 for (current_type, param) in current_types.iter().zip(param_types.iter()) {
643 let current_native_type: NativeType = current_type.into();
644
645 if param.desired_type().matches_native_type(¤t_native_type) {
646 let casted_type = param.desired_type().default_casted_type(
647 ¤t_native_type,
648 current_type,
649 )?;
650
651 new_types.push(casted_type);
652 } else if param
653 .allowed_source_types()
654 .iter()
655 .any(|t| t.matches_native_type(¤t_native_type)) {
656 let default_casted_type = param.default_casted_type().unwrap();
658 let casted_type = default_casted_type.default_cast_for(current_type)?;
659 new_types.push(casted_type);
660 } else {
661 return internal_err!(
662 "Expect {} but received NativeType::{}, DataType: {}",
663 param.desired_type(),
664 current_native_type,
665 current_type
666 );
667 }
668 }
669
670 vec![new_types]
671 }
672 TypeSignature::Uniform(number, valid_types) => {
673 if *number == 0 {
674 return plan_err!("The function '{function_name}' expected at least one argument");
675 }
676
677 valid_types
678 .iter()
679 .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
680 .collect()
681 }
682 TypeSignature::UserDefined => {
683 return internal_err!(
684 "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
685 )
686 }
687 TypeSignature::VariadicAny => {
688 if current_types.is_empty() {
689 return plan_err!(
690 "Function '{function_name}' expected at least one argument but received 0"
691 );
692 }
693 vec![current_types.to_vec()]
694 }
695 TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
696 TypeSignature::ArraySignature(ref function_signature) => match function_signature {
697 ArrayFunctionSignature::Array { arguments, array_coercion, } => {
698 array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
699 }
700 ArrayFunctionSignature::RecursiveArray => {
701 if current_types.len() != 1 {
702 return Ok(vec![vec![]]);
703 }
704 recursive_array(¤t_types[0])
705 .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
706 }
707 ArrayFunctionSignature::MapArray => {
708 if current_types.len() != 1 {
709 return Ok(vec![vec![]]);
710 }
711
712 match ¤t_types[0] {
713 DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
714 _ => vec![vec![]],
715 }
716 }
717 },
718 TypeSignature::Nullary => {
719 if !current_types.is_empty() {
720 return plan_err!(
721 "The function '{function_name}' expected zero argument but received {}",
722 current_types.len()
723 );
724 }
725 vec![vec![]]
726 }
727 TypeSignature::Any(number) => {
728 if current_types.is_empty() {
729 return plan_err!(
730 "The function '{function_name}' expected at least one argument but received 0"
731 );
732 }
733
734 if current_types.len() != *number {
735 return plan_err!(
736 "The function '{function_name}' expected {number} arguments but received {}",
737 current_types.len()
738 );
739 }
740 vec![(0..*number).map(|i| current_types[i].clone()).collect()]
741 }
742 TypeSignature::OneOf(types) => types
743 .iter()
744 .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
745 .flatten()
746 .collect::<Vec<_>>(),
747 };
748
749 Ok(valid_types)
750}
751
752fn maybe_data_types(
759 valid_types: &[DataType],
760 current_types: &[DataType],
761) -> Option<Vec<DataType>> {
762 if valid_types.len() != current_types.len() {
763 return None;
764 }
765
766 let mut new_type = Vec::with_capacity(valid_types.len());
767 for (i, valid_type) in valid_types.iter().enumerate() {
768 let current_type = ¤t_types[i];
769
770 if current_type == valid_type {
771 new_type.push(current_type.clone())
772 } else {
773 if let Some(coerced_type) = coerced_from(valid_type, current_type) {
777 new_type.push(coerced_type)
778 } else {
779 return None;
781 }
782 }
783 }
784 Some(new_type)
785}
786
787fn maybe_data_types_without_coercion(
791 valid_types: &[DataType],
792 current_types: &[DataType],
793) -> Option<Vec<DataType>> {
794 if valid_types.len() != current_types.len() {
795 return None;
796 }
797
798 let mut new_type = Vec::with_capacity(valid_types.len());
799 for (i, valid_type) in valid_types.iter().enumerate() {
800 let current_type = ¤t_types[i];
801
802 if current_type == valid_type {
803 new_type.push(current_type.clone())
804 } else if can_cast_types(current_type, valid_type) {
805 new_type.push(valid_type.clone())
807 } else {
808 return None;
809 }
810 }
811 Some(new_type)
812}
813
814pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
819 if type_into == type_from {
820 return true;
821 }
822 if let Some(coerced) = coerced_from(type_into, type_from) {
823 return coerced == *type_into;
824 }
825 false
826}
827
/// Coerces an argument of type `type_from` towards `type_into` using the
/// function-argument coercion rules.
///
/// Returns the type the argument should be cast to, or `None` if `type_from` cannot
/// be coerced to `type_into`. The result is usually `type_into` itself. For
/// "wildcard" targets the wildcard is filled in from `type_from`:
/// - a `FixedSizeList` with `FIXED_SIZE_LIST_WILDCARD` size takes the source size;
/// - a `Timestamp` with `TIMEZONE_WILDCARD` time zone takes the source time zone.
///
/// Match arms are tried in order, so earlier rules take precedence.
fn coerced_from<'a>(
    type_into: &'a DataType,
    type_from: &'a DataType,
) -> Option<DataType> {
    use self::DataType::*;

    match (type_into, type_from) {
        // A dictionary-encoded source coerces if its value type does.
        (_, Dictionary(_, value_type))
            if coerced_from(type_into, value_type).is_some() =>
        {
            Some(type_into.clone())
        }
        // A dictionary target accepts any source that coerces to its value type.
        (Dictionary(_, value_type), _)
            if coerced_from(value_type, type_from).is_some() =>
        {
            Some(type_into.clone())
        }
        // Integer targets accept Null and narrower integers whose values always fit.
        (Int8, Null | Int8) => Some(type_into.clone()),
        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
            Some(type_into.clone())
        }
        (UInt8, Null | UInt8) => Some(type_into.clone()),
        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
        // Float32 accepts Null, any integer and Float32.
        (
            Float32,
            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
            | Float32,
        ) => Some(type_into.clone()),
        // Float64 additionally accepts Float64 and every decimal width.
        (
            Float64,
            Null
            | Int8
            | Int16
            | Int32
            | Int64
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float32
            | Float64
            | Decimal32(_, _)
            | Decimal64(_, _)
            | Decimal128(_, _)
            | Decimal256(_, _),
        ) => Some(type_into.clone()),
        // Time-zone-less nanosecond timestamps accept Null, tz-less timestamps of
        // any unit, Date32 and strings.
        (
            Timestamp(TimeUnit::Nanosecond, None),
            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
        ) => Some(type_into.clone()),
        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
        // Any source type is accepted by Utf8 / LargeUtf8 targets.
        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
        // A Null target only accepts sources that arrow can cast to it.
        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),

        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),

        // List targets accept sources whose innermost base type is Null, or that
        // have the same list nesting depth (`list_ndims`).
        (List(_) | LargeList(_), _)
            if base_type(type_from).is_null()
                || list_ndims(type_from) == list_ndims(type_into) =>
        {
            Some(type_into.clone())
        }
        // Wildcard-size FixedSizeList: take the source size and coerce the item
        // type recursively, which lets nested wildcards resolve too.
        (
            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
            FixedSizeList(f_from, size_from),
        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
            // The item coercion resolved to a different type (e.g. a filled-in
            // nested wildcard): rebuild the item field with that type.
            Some(data_type) if &data_type != f_into.data_type() => {
                let new_field =
                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
                Some(FixedSizeList(new_field, *size_from))
            }
            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
            _ => None,
        },
        // Wildcard time zone: keep the source's time zone if it has one. Otherwise,
        // for Null, Date32, strings and tz-less timestamps, use "+00".
        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
            match type_from {
                Timestamp(_, Some(from_tz)) => {
                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
                }
                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
                    Some(Timestamp(*unit, Some("+00".into())))
                }
                _ => None,
            }
        }
        // A concrete time-zone target accepts Null, any timestamp, Date32 and strings.
        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
            Some(type_into.clone())
        }
        _ => None,
    }
}
939
940#[cfg(test)]
941mod tests {
942 use crate::Volatility;
943
944 use super::*;
945 use arrow::datatypes::Field;
946 use datafusion_common::assert_contains;
947
948 #[test]
949 fn test_string_conversion() {
950 let cases = vec![
951 (DataType::Utf8View, DataType::Utf8, true),
952 (DataType::Utf8View, DataType::LargeUtf8, true),
953 ];
954
955 for case in cases {
956 assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
957 }
958 }
959
960 #[test]
961 fn test_maybe_data_types() {
962 let cases = vec![
964 (
966 vec![DataType::UInt8, DataType::UInt16],
967 vec![DataType::UInt8, DataType::UInt16],
968 Some(vec![DataType::UInt8, DataType::UInt16]),
969 ),
970 (
972 vec![DataType::UInt16, DataType::UInt16],
973 vec![DataType::UInt8, DataType::UInt16],
974 Some(vec![DataType::UInt16, DataType::UInt16]),
975 ),
976 (vec![], vec![], Some(vec![])),
978 (
980 vec![DataType::Boolean, DataType::UInt16],
981 vec![DataType::UInt8, DataType::UInt16],
982 None,
983 ),
984 (
986 vec![DataType::Boolean, DataType::UInt32],
987 vec![DataType::Boolean, DataType::UInt16],
988 Some(vec![DataType::Boolean, DataType::UInt32]),
989 ),
990 (
992 vec![
993 DataType::Timestamp(TimeUnit::Nanosecond, None),
994 DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
995 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
996 ],
997 vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
998 Some(vec![
999 DataType::Timestamp(TimeUnit::Nanosecond, None),
1000 DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
1001 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
1002 ]),
1003 ),
1004 ];
1005
1006 for case in cases {
1007 assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
1008 }
1009 }
1010
1011 #[test]
1012 fn test_get_valid_types_numeric() -> Result<()> {
1013 let get_valid_types_flatten =
1014 |function_name: &str,
1015 signature: &TypeSignature,
1016 current_types: &[DataType]| {
1017 get_valid_types(function_name, signature, current_types)
1018 .unwrap()
1019 .into_iter()
1020 .flatten()
1021 .collect::<Vec<_>>()
1022 };
1023
1024 let got = get_valid_types_flatten(
1026 "test",
1027 &TypeSignature::Numeric(1),
1028 &[DataType::Int32],
1029 );
1030 assert_eq!(got, [DataType::Int32]);
1031
1032 let got = get_valid_types_flatten(
1034 "test",
1035 &TypeSignature::Numeric(2),
1036 &[DataType::Int32, DataType::Int64],
1037 );
1038 assert_eq!(got, [DataType::Int64, DataType::Int64]);
1039
1040 let got = get_valid_types_flatten(
1042 "test",
1043 &TypeSignature::Numeric(3),
1044 &[DataType::Int32, DataType::Int64, DataType::Float64],
1045 );
1046 assert_eq!(
1047 got,
1048 [DataType::Float64, DataType::Float64, DataType::Float64]
1049 );
1050
1051 let got = get_valid_types(
1053 "test",
1054 &TypeSignature::Numeric(2),
1055 &[DataType::Int32, DataType::Utf8],
1056 )
1057 .unwrap_err();
1058 assert_contains!(
1059 got.to_string(),
1060 "Function 'test' expects NativeType::Numeric but received NativeType::String"
1061 );
1062
1063 let got = get_valid_types_flatten(
1065 "test",
1066 &TypeSignature::Numeric(1),
1067 &[DataType::Null],
1068 );
1069 assert_eq!(got, [DataType::Float64]);
1070
1071 let got = get_valid_types(
1073 "test",
1074 &TypeSignature::Numeric(1),
1075 &[DataType::Timestamp(TimeUnit::Second, None)],
1076 )
1077 .unwrap_err();
1078 assert_contains!(
1079 got.to_string(),
1080 "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
1081 );
1082
1083 Ok(())
1084 }
1085
1086 #[test]
1087 fn test_get_valid_types_one_of() -> Result<()> {
1088 let signature =
1089 TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);
1090
1091 let invalid_types = get_valid_types(
1092 "test",
1093 &signature,
1094 &[DataType::Int32, DataType::Int32, DataType::Int32],
1095 )?;
1096 assert_eq!(invalid_types.len(), 0);
1097
1098 let args = vec![DataType::Int32, DataType::Int32];
1099 let valid_types = get_valid_types("test", &signature, &args)?;
1100 assert_eq!(valid_types.len(), 1);
1101 assert_eq!(valid_types[0], args);
1102
1103 let args = vec![DataType::Int32];
1104 let valid_types = get_valid_types("test", &signature, &args)?;
1105 assert_eq!(valid_types.len(), 1);
1106 assert_eq!(valid_types[0], args);
1107
1108 Ok(())
1109 }
1110
1111 #[test]
1112 fn test_get_valid_types_length_check() -> Result<()> {
1113 let signature = TypeSignature::Numeric(1);
1114
1115 let err = get_valid_types("test", &signature, &[]).unwrap_err();
1116 assert_contains!(
1117 err.to_string(),
1118 "Function 'test' expects 1 arguments but received 0"
1119 );
1120
1121 let err = get_valid_types(
1122 "test",
1123 &signature,
1124 &[DataType::Int32, DataType::Int32, DataType::Int32],
1125 )
1126 .unwrap_err();
1127 assert_contains!(
1128 err.to_string(),
1129 "Function 'test' expects 1 arguments but received 3"
1130 );
1131
1132 Ok(())
1133 }
1134
1135 #[test]
1136 fn test_fixed_list_wildcard_coerce() -> Result<()> {
1137 let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
1138 let current_types = vec![
1139 DataType::FixedSizeList(Arc::clone(&inner), 2), ];
1141
1142 let signature = Signature::exact(
1143 vec![DataType::FixedSizeList(
1144 Arc::clone(&inner),
1145 FIXED_SIZE_LIST_WILDCARD,
1146 )],
1147 Volatility::Stable,
1148 );
1149
1150 let coerced_data_types = data_types("test", ¤t_types, &signature)?;
1151 assert_eq!(coerced_data_types, current_types);
1152
1153 let signature = Signature::exact(
1155 vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
1156 Volatility::Stable,
1157 );
1158 let coerced_data_types = data_types("test", ¤t_types, &signature);
1159 assert!(coerced_data_types.is_err());
1160
1161 let signature = Signature::exact(
1163 vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
1164 Volatility::Stable,
1165 );
1166 let coerced_data_types = data_types("test", ¤t_types, &signature).unwrap();
1167 assert_eq!(coerced_data_types, current_types);
1168
1169 Ok(())
1170 }
1171
1172 #[test]
1173 fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
1174 let type_into = DataType::FixedSizeList(
1175 Arc::new(Field::new_list_field(
1176 DataType::FixedSizeList(
1177 Arc::new(Field::new_list_field(DataType::Int32, false)),
1178 FIXED_SIZE_LIST_WILDCARD,
1179 ),
1180 false,
1181 )),
1182 FIXED_SIZE_LIST_WILDCARD,
1183 );
1184
1185 let type_from = DataType::FixedSizeList(
1186 Arc::new(Field::new_list_field(
1187 DataType::FixedSizeList(
1188 Arc::new(Field::new_list_field(DataType::Int8, false)),
1189 4,
1190 ),
1191 false,
1192 )),
1193 3,
1194 );
1195
1196 assert_eq!(
1197 coerced_from(&type_into, &type_from),
1198 Some(DataType::FixedSizeList(
1199 Arc::new(Field::new_list_field(
1200 DataType::FixedSizeList(
1201 Arc::new(Field::new_list_field(DataType::Int32, false)),
1202 4,
1203 ),
1204 false,
1205 )),
1206 3,
1207 ))
1208 );
1209
1210 Ok(())
1211 }
1212
1213 #[test]
1214 fn test_coerced_from_dictionary() {
1215 let type_into =
1216 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1217 let type_from = DataType::Int64;
1218 assert_eq!(coerced_from(&type_into, &type_from), None);
1219
1220 let type_from =
1221 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1222 let type_into = DataType::Int64;
1223 assert_eq!(
1224 coerced_from(&type_into, &type_from),
1225 Some(type_into.clone())
1226 );
1227 }
1228
1229 #[test]
1230 fn test_get_valid_types_array_and_array() -> Result<()> {
1231 let function = "array_and_array";
1232 let signature = Signature::arrays(
1233 2,
1234 Some(ListCoercion::FixedSizedListToList),
1235 Volatility::Immutable,
1236 );
1237
1238 let data_types = vec![
1239 DataType::new_list(DataType::Int32, true),
1240 DataType::new_large_list(DataType::Float64, true),
1241 ];
1242 assert_eq!(
1243 get_valid_types(function, &signature.type_signature, &data_types)?,
1244 vec![vec![
1245 DataType::new_large_list(DataType::Float64, true),
1246 DataType::new_large_list(DataType::Float64, true),
1247 ]]
1248 );
1249
1250 let data_types = vec![
1251 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1252 DataType::new_fixed_size_list(DataType::Int32, 5, true),
1253 ];
1254 assert_eq!(
1255 get_valid_types(function, &signature.type_signature, &data_types)?,
1256 vec![vec![
1257 DataType::new_list(DataType::Int64, true),
1258 DataType::new_list(DataType::Int64, true),
1259 ]]
1260 );
1261
1262 let data_types = vec![
1263 DataType::new_fixed_size_list(DataType::Null, 3, true),
1264 DataType::new_large_list(DataType::Utf8, true),
1265 ];
1266 assert_eq!(
1267 get_valid_types(function, &signature.type_signature, &data_types)?,
1268 vec![vec![
1269 DataType::new_large_list(DataType::Utf8, true),
1270 DataType::new_large_list(DataType::Utf8, true),
1271 ]]
1272 );
1273
1274 Ok(())
1275 }
1276
1277 #[test]
1278 fn test_get_valid_types_array_and_element() -> Result<()> {
1279 let function = "array_and_element";
1280 let signature = Signature::array_and_element(Volatility::Immutable);
1281
1282 let data_types =
1283 vec![DataType::new_list(DataType::Int32, true), DataType::Float64];
1284 assert_eq!(
1285 get_valid_types(function, &signature.type_signature, &data_types)?,
1286 vec![vec![
1287 DataType::new_list(DataType::Float64, true),
1288 DataType::Float64,
1289 ]]
1290 );
1291
1292 let data_types = vec![
1293 DataType::new_large_list(DataType::Int32, true),
1294 DataType::Null,
1295 ];
1296 assert_eq!(
1297 get_valid_types(function, &signature.type_signature, &data_types)?,
1298 vec![vec![
1299 DataType::new_large_list(DataType::Int32, true),
1300 DataType::Int32,
1301 ]]
1302 );
1303
1304 let data_types = vec![
1305 DataType::new_fixed_size_list(DataType::Null, 3, true),
1306 DataType::Utf8,
1307 ];
1308 assert_eq!(
1309 get_valid_types(function, &signature.type_signature, &data_types)?,
1310 vec![vec![
1311 DataType::new_list(DataType::Utf8, true),
1312 DataType::Utf8,
1313 ]]
1314 );
1315
1316 Ok(())
1317 }
1318
1319 #[test]
1320 fn test_get_valid_types_element_and_array() -> Result<()> {
1321 let function = "element_and_array";
1322 let signature = Signature::element_and_array(Volatility::Immutable);
1323
1324 let data_types = vec![
1325 DataType::new_large_list(DataType::Null, false),
1326 DataType::new_list(DataType::new_list(DataType::Int64, true), true),
1327 ];
1328 assert_eq!(
1329 get_valid_types(function, &signature.type_signature, &data_types)?,
1330 vec![vec![
1331 DataType::new_large_list(DataType::Int64, true),
1332 DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
1333 ]]
1334 );
1335
1336 Ok(())
1337 }
1338
1339 #[test]
1340 fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
1341 let function = "fixed_size_arrays";
1342 let signature = Signature::arrays(2, None, Volatility::Immutable);
1343
1344 let data_types = vec![
1345 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1346 DataType::new_fixed_size_list(DataType::Int32, 5, true),
1347 ];
1348 assert_eq!(
1349 get_valid_types(function, &signature.type_signature, &data_types)?,
1350 vec![vec![
1351 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1352 DataType::new_fixed_size_list(DataType::Int64, 5, true),
1353 ]]
1354 );
1355
1356 let data_types = vec![
1357 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1358 DataType::new_list(DataType::Int32, true),
1359 ];
1360 assert_eq!(
1361 get_valid_types(function, &signature.type_signature, &data_types)?,
1362 vec![vec![
1363 DataType::new_list(DataType::Int64, true),
1364 DataType::new_list(DataType::Int64, true),
1365 ]]
1366 );
1367
1368 let data_types = vec![
1369 DataType::new_fixed_size_list(DataType::Utf8, 3, true),
1370 DataType::new_list(DataType::new_list(DataType::Int32, true), true),
1371 ];
1372 assert_eq!(
1373 get_valid_types(function, &signature.type_signature, &data_types)?,
1374 vec![vec![]]
1375 );
1376
1377 let data_types = vec![
1378 DataType::new_fixed_size_list(DataType::Int64, 3, false),
1379 DataType::new_list(DataType::Int32, false),
1380 ];
1381 assert_eq!(
1382 get_valid_types(function, &signature.type_signature, &data_types)?,
1383 vec![vec![
1384 DataType::new_list(DataType::Int64, false),
1385 DataType::new_list(DataType::Int64, false),
1386 ]]
1387 );
1388
1389 Ok(())
1390 }
1391}