1use super::binary::{binary_numeric_coercion, comparison_coercion};
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::{
21 compute::can_cast_types,
22 datatypes::{DataType, Field, TimeUnit},
23};
24use datafusion_common::types::LogicalType;
25use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
26use datafusion_common::{
27 exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType,
28 utils::list_ndims, Result,
29};
30use datafusion_expr_common::signature::ArrayFunctionArgument;
31use datafusion_expr_common::{
32 signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
33 type_coercion::binary::comparison_coercion_numeric,
34 type_coercion::binary::string_coercion,
35};
36use std::sync::Arc;
37
38pub fn data_types_with_scalar_udf(
46 current_types: &[DataType],
47 func: &ScalarUDF,
48) -> Result<Vec<DataType>> {
49 let signature = func.signature();
50 let type_signature = &signature.type_signature;
51
52 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
53 if type_signature.supports_zero_argument() {
54 return Ok(vec![]);
55 } else if type_signature.used_to_support_zero_arguments() {
56 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
58 } else {
59 return plan_err!("'{}' does not support zero arguments", func.name());
60 }
61 }
62
63 let valid_types =
64 get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
65
66 if valid_types
67 .iter()
68 .any(|data_type| data_type == current_types)
69 {
70 return Ok(current_types.to_vec());
71 }
72
73 try_coerce_types(func.name(), valid_types, current_types, type_signature)
74}
75
76pub fn data_types_with_aggregate_udf(
84 current_types: &[DataType],
85 func: &AggregateUDF,
86) -> Result<Vec<DataType>> {
87 let signature = func.signature();
88 let type_signature = &signature.type_signature;
89
90 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
91 if type_signature.supports_zero_argument() {
92 return Ok(vec![]);
93 } else if type_signature.used_to_support_zero_arguments() {
94 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
96 } else {
97 return plan_err!("'{}' does not support zero arguments", func.name());
98 }
99 }
100
101 let valid_types =
102 get_valid_types_with_aggregate_udf(type_signature, current_types, func)?;
103 if valid_types
104 .iter()
105 .any(|data_type| data_type == current_types)
106 {
107 return Ok(current_types.to_vec());
108 }
109
110 try_coerce_types(func.name(), valid_types, current_types, type_signature)
111}
112
113pub fn data_types_with_window_udf(
121 current_types: &[DataType],
122 func: &WindowUDF,
123) -> Result<Vec<DataType>> {
124 let signature = func.signature();
125 let type_signature = &signature.type_signature;
126
127 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
128 if type_signature.supports_zero_argument() {
129 return Ok(vec![]);
130 } else if type_signature.used_to_support_zero_arguments() {
131 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
133 } else {
134 return plan_err!("'{}' does not support zero arguments", func.name());
135 }
136 }
137
138 let valid_types =
139 get_valid_types_with_window_udf(type_signature, current_types, func)?;
140 if valid_types
141 .iter()
142 .any(|data_type| data_type == current_types)
143 {
144 return Ok(current_types.to_vec());
145 }
146
147 try_coerce_types(func.name(), valid_types, current_types, type_signature)
148}
149
150pub fn data_types(
158 function_name: impl AsRef<str>,
159 current_types: &[DataType],
160 signature: &Signature,
161) -> Result<Vec<DataType>> {
162 let type_signature = &signature.type_signature;
163
164 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
165 if type_signature.supports_zero_argument() {
166 return Ok(vec![]);
167 } else if type_signature.used_to_support_zero_arguments() {
168 return plan_err!(
170 "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
171 function_name.as_ref()
172 );
173 } else {
174 return plan_err!(
175 "Function '{}' has signature {type_signature:?} which does not support zero arguments",
176 function_name.as_ref()
177 );
178 }
179 }
180
181 let valid_types =
182 get_valid_types(function_name.as_ref(), type_signature, current_types)?;
183 if valid_types
184 .iter()
185 .any(|data_type| data_type == current_types)
186 {
187 return Ok(current_types.to_vec());
188 }
189
190 try_coerce_types(
191 function_name.as_ref(),
192 valid_types,
193 current_types,
194 type_signature,
195 )
196}
197
198fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
199 if let TypeSignature::OneOf(signatures) = type_signature {
200 return signatures.iter().all(is_well_supported_signature);
201 }
202
203 matches!(
204 type_signature,
205 TypeSignature::UserDefined
206 | TypeSignature::Numeric(_)
207 | TypeSignature::String(_)
208 | TypeSignature::Coercible(_)
209 | TypeSignature::Any(_)
210 | TypeSignature::Nullary
211 | TypeSignature::Comparable(_)
212 )
213}
214
215fn try_coerce_types(
216 function_name: &str,
217 valid_types: Vec<Vec<DataType>>,
218 current_types: &[DataType],
219 type_signature: &TypeSignature,
220) -> Result<Vec<DataType>> {
221 let mut valid_types = valid_types;
222
223 if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
225 if !type_signature.is_one_of() {
228 assert_eq!(valid_types.len(), 1);
229 }
230
231 let valid_types = valid_types.swap_remove(0);
232 if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
233 return Ok(t);
234 }
235 } else {
236 for valid_types in valid_types {
240 if let Some(types) = maybe_data_types(&valid_types, current_types) {
241 return Ok(types);
242 }
243 }
244 }
245
246 plan_err!(
248 "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {current_types:?} to the signature {type_signature:?} failed"
249 )
250}
251
252fn get_valid_types_with_scalar_udf(
253 signature: &TypeSignature,
254 current_types: &[DataType],
255 func: &ScalarUDF,
256) -> Result<Vec<Vec<DataType>>> {
257 match signature {
258 TypeSignature::UserDefined => match func.coerce_types(current_types) {
259 Ok(coerced_types) => Ok(vec![coerced_types]),
260 Err(e) => exec_err!(
261 "Function '{}' user-defined coercion failed with {:?}",
262 func.name(),
263 e.strip_backtrace()
264 ),
265 },
266 TypeSignature::OneOf(signatures) => {
267 let mut res = vec![];
268 let mut errors = vec![];
269 for sig in signatures {
270 match get_valid_types_with_scalar_udf(sig, current_types, func) {
271 Ok(valid_types) => {
272 res.extend(valid_types);
273 }
274 Err(e) => {
275 errors.push(e.to_string());
276 }
277 }
278 }
279
280 if res.is_empty() {
282 internal_err!(
283 "Function '{}' failed to match any signature, errors: {}",
284 func.name(),
285 errors.join(",")
286 )
287 } else {
288 Ok(res)
289 }
290 }
291 _ => get_valid_types(func.name(), signature, current_types),
292 }
293}
294
295fn get_valid_types_with_aggregate_udf(
296 signature: &TypeSignature,
297 current_types: &[DataType],
298 func: &AggregateUDF,
299) -> Result<Vec<Vec<DataType>>> {
300 let valid_types = match signature {
301 TypeSignature::UserDefined => match func.coerce_types(current_types) {
302 Ok(coerced_types) => vec![coerced_types],
303 Err(e) => {
304 return exec_err!(
305 "Function '{}' user-defined coercion failed with {:?}",
306 func.name(),
307 e.strip_backtrace()
308 )
309 }
310 },
311 TypeSignature::OneOf(signatures) => signatures
312 .iter()
313 .filter_map(|t| {
314 get_valid_types_with_aggregate_udf(t, current_types, func).ok()
315 })
316 .flatten()
317 .collect::<Vec<_>>(),
318 _ => get_valid_types(func.name(), signature, current_types)?,
319 };
320
321 Ok(valid_types)
322}
323
324fn get_valid_types_with_window_udf(
325 signature: &TypeSignature,
326 current_types: &[DataType],
327 func: &WindowUDF,
328) -> Result<Vec<Vec<DataType>>> {
329 let valid_types = match signature {
330 TypeSignature::UserDefined => match func.coerce_types(current_types) {
331 Ok(coerced_types) => vec![coerced_types],
332 Err(e) => {
333 return exec_err!(
334 "Function '{}' user-defined coercion failed with {:?}",
335 func.name(),
336 e.strip_backtrace()
337 )
338 }
339 },
340 TypeSignature::OneOf(signatures) => signatures
341 .iter()
342 .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
343 .flatten()
344 .collect::<Vec<_>>(),
345 _ => get_valid_types(func.name(), signature, current_types)?,
346 };
347
348 Ok(valid_types)
349}
350
351fn get_valid_types(
353 function_name: &str,
354 signature: &TypeSignature,
355 current_types: &[DataType],
356) -> Result<Vec<Vec<DataType>>> {
357 fn array_valid_types(
358 function_name: &str,
359 current_types: &[DataType],
360 arguments: &[ArrayFunctionArgument],
361 array_coercion: Option<&ListCoercion>,
362 ) -> Result<Vec<Vec<DataType>>> {
363 if current_types.len() != arguments.len() {
364 return Ok(vec![vec![]]);
365 }
366
367 let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| {
368 if *arg == ArrayFunctionArgument::Array {
369 Some(idx)
370 } else {
371 None
372 }
373 });
374 let Some(array_idx) = array_idx else {
375 return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument"));
376 };
377 let Some(array_type) = array(¤t_types[array_idx]) else {
378 return Ok(vec![vec![]]);
379 };
380
381 let mut new_base_type = datafusion_common::utils::base_type(&array_type);
384 for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
385 match argument_type {
386 ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => {
387 new_base_type =
388 coerce_array_types(function_name, current_type, &new_base_type)?;
389 }
390 ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {}
391 }
392 }
393 let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
394 &array_type,
395 &new_base_type,
396 array_coercion,
397 );
398
399 let new_elem_type = match new_array_type {
400 DataType::List(ref field)
401 | DataType::LargeList(ref field)
402 | DataType::FixedSizeList(ref field, _) => field.data_type(),
403 _ => return Ok(vec![vec![]]),
404 };
405
406 let mut valid_types = Vec::with_capacity(arguments.len());
407 for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
408 let valid_type = match argument_type {
409 ArrayFunctionArgument::Element => new_elem_type.clone(),
410 ArrayFunctionArgument::Index => DataType::Int64,
411 ArrayFunctionArgument::String => DataType::Utf8,
412 ArrayFunctionArgument::Array => {
413 let Some(current_type) = array(current_type) else {
414 return Ok(vec![vec![]]);
415 };
416 let new_type =
417 datafusion_common::utils::coerced_type_with_base_type_only(
418 ¤t_type,
419 &new_base_type,
420 array_coercion,
421 );
422 if new_type != new_array_type {
424 return Ok(vec![vec![]]);
425 }
426 new_type
427 }
428 };
429 valid_types.push(valid_type);
430 }
431
432 Ok(vec![valid_types])
433 }
434
435 fn array(array_type: &DataType) -> Option<DataType> {
436 match array_type {
437 DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
438 DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
439 DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field(
440 DataType::Int64,
441 true,
442 )))),
443 _ => None,
444 }
445 }
446
447 fn coerce_array_types(
448 function_name: &str,
449 current_type: &DataType,
450 base_type: &DataType,
451 ) -> Result<DataType> {
452 let current_base_type = datafusion_common::utils::base_type(current_type);
453 let new_base_type = comparison_coercion(base_type, ¤t_base_type);
454 new_base_type.ok_or_else(|| {
455 internal_datafusion_err!(
456 "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}"
457 )
458 })
459 }
460
461 fn recursive_array(array_type: &DataType) -> Option<DataType> {
462 match array_type {
463 DataType::List(_)
464 | DataType::LargeList(_)
465 | DataType::FixedSizeList(_, _) => {
466 let array_type = coerced_fixed_size_list_to_list(array_type);
467 Some(array_type)
468 }
469 _ => None,
470 }
471 }
472
473 fn function_length_check(
474 function_name: &str,
475 length: usize,
476 expected_length: usize,
477 ) -> Result<()> {
478 if length != expected_length {
479 return plan_err!(
480 "Function '{function_name}' expects {expected_length} arguments but received {length}"
481 );
482 }
483 Ok(())
484 }
485
486 let valid_types = match signature {
487 TypeSignature::Variadic(valid_types) => valid_types
488 .iter()
489 .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
490 .collect(),
491 TypeSignature::String(number) => {
492 function_length_check(function_name, current_types.len(), *number)?;
493
494 let mut new_types = Vec::with_capacity(current_types.len());
495 for data_type in current_types.iter() {
496 let logical_data_type: NativeType = data_type.into();
497 if logical_data_type == NativeType::String {
498 new_types.push(data_type.to_owned());
499 } else if logical_data_type == NativeType::Null {
500 new_types.push(DataType::Utf8);
502 } else {
503 return plan_err!(
504 "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
505 );
506 }
507 }
508
509 fn find_common_type(
511 function_name: &str,
512 lhs_type: &DataType,
513 rhs_type: &DataType,
514 ) -> Result<DataType> {
515 match (lhs_type, rhs_type) {
516 (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
517 find_common_type(function_name, lhs, rhs)
518 }
519 (DataType::Dictionary(_, v), other)
520 | (other, DataType::Dictionary(_, v)) => {
521 find_common_type(function_name, v, other)
522 }
523 _ => {
524 if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
525 Ok(coerced_type)
526 } else {
527 plan_err!(
528 "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
529 )
530 }
531 }
532 }
533 }
534
535 let mut coerced_type = new_types.first().unwrap().to_owned();
537 for t in new_types.iter().skip(1) {
538 coerced_type = find_common_type(function_name, &coerced_type, t)?;
539 }
540
541 fn base_type_or_default_type(data_type: &DataType) -> DataType {
542 if let DataType::Dictionary(_, v) = data_type {
543 base_type_or_default_type(v)
544 } else {
545 data_type.to_owned()
546 }
547 }
548
549 vec![vec![base_type_or_default_type(&coerced_type); *number]]
550 }
551 TypeSignature::Numeric(number) => {
552 function_length_check(function_name, current_types.len(), *number)?;
553
554 let mut valid_type = current_types.first().unwrap().to_owned();
556 for t in current_types.iter().skip(1) {
557 let logical_data_type: NativeType = t.into();
558 if logical_data_type == NativeType::Null {
559 continue;
560 }
561
562 if !logical_data_type.is_numeric() {
563 return plan_err!(
564 "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
565 );
566 }
567
568 if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
569 valid_type = coerced_type;
570 } else {
571 return plan_err!(
572 "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
573 );
574 }
575 }
576
577 let logical_data_type: NativeType = valid_type.clone().into();
578 if logical_data_type == NativeType::Null {
582 valid_type = DataType::Float64;
583 } else if !logical_data_type.is_numeric() {
584 return plan_err!(
585 "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
586 );
587 }
588
589 vec![vec![valid_type; *number]]
590 }
591 TypeSignature::Comparable(num) => {
592 function_length_check(function_name, current_types.len(), *num)?;
593 let mut target_type = current_types[0].to_owned();
594 for data_type in current_types.iter().skip(1) {
595 if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
596 target_type = dt;
597 } else {
598 return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
599 }
600 }
601 if target_type.is_null() {
603 vec![vec![DataType::Utf8View; *num]]
604 } else {
605 vec![vec![target_type; *num]]
606 }
607 }
608 TypeSignature::Coercible(param_types) => {
609 function_length_check(function_name, current_types.len(), param_types.len())?;
610
611 let mut new_types = Vec::with_capacity(current_types.len());
612 for (current_type, param) in current_types.iter().zip(param_types.iter()) {
613 let current_native_type: NativeType = current_type.into();
614
615 if param.desired_type().matches_native_type(¤t_native_type) {
616 let casted_type = param.desired_type().default_casted_type(
617 ¤t_native_type,
618 current_type,
619 )?;
620
621 new_types.push(casted_type);
622 } else if param
623 .allowed_source_types()
624 .iter()
625 .any(|t| t.matches_native_type(¤t_native_type)) {
626 let default_casted_type = param.default_casted_type().unwrap();
628 let casted_type = default_casted_type.default_cast_for(current_type)?;
629 new_types.push(casted_type);
630 } else {
631 return internal_err!(
632 "Expect {} but received {}, DataType: {}",
633 param.desired_type(),
634 current_native_type,
635 current_type
636 );
637 }
638 }
639
640 vec![new_types]
641 }
642 TypeSignature::Uniform(number, valid_types) => {
643 if *number == 0 {
644 return plan_err!("The function '{function_name}' expected at least one argument");
645 }
646
647 valid_types
648 .iter()
649 .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
650 .collect()
651 }
652 TypeSignature::UserDefined => {
653 return internal_err!(
654 "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
655 )
656 }
657 TypeSignature::VariadicAny => {
658 if current_types.is_empty() {
659 return plan_err!(
660 "Function '{function_name}' expected at least one argument but received 0"
661 );
662 }
663 vec![current_types.to_vec()]
664 }
665 TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
666 TypeSignature::ArraySignature(ref function_signature) => match function_signature {
667 ArrayFunctionSignature::Array { arguments, array_coercion, } => {
668 array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
669 }
670 ArrayFunctionSignature::RecursiveArray => {
671 if current_types.len() != 1 {
672 return Ok(vec![vec![]]);
673 }
674 recursive_array(¤t_types[0])
675 .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
676 }
677 ArrayFunctionSignature::MapArray => {
678 if current_types.len() != 1 {
679 return Ok(vec![vec![]]);
680 }
681
682 match ¤t_types[0] {
683 DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
684 _ => vec![vec![]],
685 }
686 }
687 },
688 TypeSignature::Nullary => {
689 if !current_types.is_empty() {
690 return plan_err!(
691 "The function '{function_name}' expected zero argument but received {}",
692 current_types.len()
693 );
694 }
695 vec![vec![]]
696 }
697 TypeSignature::Any(number) => {
698 if current_types.is_empty() {
699 return plan_err!(
700 "The function '{function_name}' expected at least one argument but received 0"
701 );
702 }
703
704 if current_types.len() != *number {
705 return plan_err!(
706 "The function '{function_name}' expected {number} arguments but received {}",
707 current_types.len()
708 );
709 }
710 vec![(0..*number).map(|i| current_types[i].clone()).collect()]
711 }
712 TypeSignature::OneOf(types) => types
713 .iter()
714 .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
715 .flatten()
716 .collect::<Vec<_>>(),
717 };
718
719 Ok(valid_types)
720}
721
722fn maybe_data_types(
729 valid_types: &[DataType],
730 current_types: &[DataType],
731) -> Option<Vec<DataType>> {
732 if valid_types.len() != current_types.len() {
733 return None;
734 }
735
736 let mut new_type = Vec::with_capacity(valid_types.len());
737 for (i, valid_type) in valid_types.iter().enumerate() {
738 let current_type = ¤t_types[i];
739
740 if current_type == valid_type {
741 new_type.push(current_type.clone())
742 } else {
743 if let Some(coerced_type) = coerced_from(valid_type, current_type) {
747 new_type.push(coerced_type)
748 } else {
749 return None;
751 }
752 }
753 }
754 Some(new_type)
755}
756
757fn maybe_data_types_without_coercion(
761 valid_types: &[DataType],
762 current_types: &[DataType],
763) -> Option<Vec<DataType>> {
764 if valid_types.len() != current_types.len() {
765 return None;
766 }
767
768 let mut new_type = Vec::with_capacity(valid_types.len());
769 for (i, valid_type) in valid_types.iter().enumerate() {
770 let current_type = ¤t_types[i];
771
772 if current_type == valid_type {
773 new_type.push(current_type.clone())
774 } else if can_cast_types(current_type, valid_type) {
775 new_type.push(valid_type.clone())
777 } else {
778 return None;
779 }
780 }
781 Some(new_type)
782}
783
784pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
789 if type_into == type_from {
790 return true;
791 }
792 if let Some(coerced) = coerced_from(type_into, type_from) {
793 return coerced == *type_into;
794 }
795 false
796}
797
798fn coerced_from<'a>(
805 type_into: &'a DataType,
806 type_from: &'a DataType,
807) -> Option<DataType> {
808 use self::DataType::*;
809
810 match (type_into, type_from) {
812 (_, Dictionary(_, value_type))
814 if coerced_from(type_into, value_type).is_some() =>
815 {
816 Some(type_into.clone())
817 }
818 (Dictionary(_, value_type), _)
819 if coerced_from(value_type, type_from).is_some() =>
820 {
821 Some(type_into.clone())
822 }
823 (Int8, Null | Int8) => Some(type_into.clone()),
825 (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
826 (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
827 (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
828 Some(type_into.clone())
829 }
830 (UInt8, Null | UInt8) => Some(type_into.clone()),
831 (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
832 (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
833 (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
834 (
835 Float32,
836 Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
837 | Float32,
838 ) => Some(type_into.clone()),
839 (
840 Float64,
841 Null
842 | Int8
843 | Int16
844 | Int32
845 | Int64
846 | UInt8
847 | UInt16
848 | UInt32
849 | UInt64
850 | Float32
851 | Float64
852 | Decimal128(_, _),
853 ) => Some(type_into.clone()),
854 (
855 Timestamp(TimeUnit::Nanosecond, None),
856 Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
857 ) => Some(type_into.clone()),
858 (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
859 (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
861 (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
863 (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
864
865 (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
866
867 (List(_) | LargeList(_), _)
870 if datafusion_common::utils::base_type(type_from).eq(&Null)
871 || list_ndims(type_from) == list_ndims(type_into) =>
872 {
873 Some(type_into.clone())
874 }
875 (
877 FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
878 FixedSizeList(f_from, size_from),
879 ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
880 Some(data_type) if &data_type != f_into.data_type() => {
881 let new_field =
882 Arc::new(f_into.as_ref().clone().with_data_type(data_type));
883 Some(FixedSizeList(new_field, *size_from))
884 }
885 Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
886 _ => None,
887 },
888 (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
889 match type_from {
890 Timestamp(_, Some(from_tz)) => {
891 Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
892 }
893 Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
894 Some(Timestamp(*unit, Some("+00".into())))
896 }
897 _ => None,
898 }
899 }
900 (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
901 Some(type_into.clone())
902 }
903 _ => None,
904 }
905}
906
#[cfg(test)]
mod tests {

    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::assert_contains;

    /// `Utf8` and `LargeUtf8` must both be coercible to `Utf8View`.
    #[test]
    fn test_string_conversion() {
        // (type_into, type_from, expected)
        let cases = vec![
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for case in cases {
            assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
        }
    }

    /// Exercises `maybe_data_types` with (valid_types, current_types, expected).
    #[test]
    fn test_maybe_data_types() {
        let cases = vec![
            // Identical lists are returned unchanged.
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // UInt8 widens to UInt16.
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // Empty lists match trivially.
            (vec![], vec![], Some(vec![])),
            // UInt8 cannot be coerced to Boolean.
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // Only the second position needs widening.
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // Strings coerce to timestamps; the wildcard timezone "+TZ"
            // resolves to "+00" when the source carries no timezone.
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for case in cases {
            assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
        }
    }

    /// `TypeSignature::Numeric` picks a common numeric type and rejects
    /// non-numeric arguments.
    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        let get_valid_types_flatten =
            |function_name: &str,
             signature: &TypeSignature,
             current_types: &[DataType]| {
                get_valid_types(function_name, signature, current_types)
                    .unwrap()
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>()
            };

        // A single numeric argument is kept as is.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Int32],
        );
        assert_eq!(got, [DataType::Int32]);

        // Two integers widen to the larger one.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Int64],
        );
        assert_eq!(got, [DataType::Int64, DataType::Int64]);

        // Mixing integers with a float yields the float type.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(3),
            &[DataType::Int32, DataType::Int64, DataType::Float64],
        );
        assert_eq!(
            got,
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // A string argument is rejected.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // A lone null falls back to Float64.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Null],
        );
        assert_eq!(got, [DataType::Float64]);

        // A timestamp is not numeric.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
        );

        Ok(())
    }

    /// `OneOf` returns only the candidates of the alternatives that match.
    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither alternative.
        let invalid_types = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )?;
        assert_eq!(invalid_types.len(), 0);

        let args = vec![DataType::Int32, DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        let args = vec![DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        Ok(())
    }

    /// Length-checked signatures report the expected and received arity.
    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        let err = get_valid_types("test", &signature, &[]).unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 0"
        );

        let err = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 3"
        );

        Ok(())
    }

    /// A wildcard-sized `FixedSizeList` accepts any size; a concrete size
    /// must match exactly.
    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
        let current_types = vec![
            DataType::FixedSizeList(Arc::clone(&inner), 2),
        ];

        let signature = Signature::exact(
            vec![DataType::FixedSizeList(
                Arc::clone(&inner),
                FIXED_SIZE_LIST_WILDCARD,
            )],
            Volatility::Stable,
        );

        let coerced_data_types = data_types("test", &current_types, &signature)?;
        assert_eq!(coerced_data_types, current_types);

        // A different concrete size must be rejected.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
            Volatility::Stable,
        );
        let coerced_data_types = data_types("test", &current_types, &signature);
        assert!(coerced_data_types.is_err());

        // The same concrete size is accepted unchanged.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
            Volatility::Stable,
        );
        let coerced_data_types = data_types("test", &current_types, &signature).unwrap();
        assert_eq!(coerced_data_types, current_types);

        Ok(())
    }

    /// Wildcard sizes are resolved at every nesting level, and the element
    /// type is coerced to the target's.
    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        let type_into = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int32, false)),
                    FIXED_SIZE_LIST_WILDCARD,
                ),
                false,
            )),
            FIXED_SIZE_LIST_WILDCARD,
        );

        let type_from = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int8, false)),
                    4,
                ),
                false,
            )),
            3,
        );

        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(DataType::FixedSizeList(
                Arc::new(Field::new_list_field(
                    DataType::FixedSizeList(
                        Arc::new(Field::new_list_field(DataType::Int32, false)),
                        4,
                    ),
                    false,
                )),
                3,
            ))
        );

        Ok(())
    }

    /// Dictionary coercion follows the dictionary's value type in both
    /// directions.
    #[test]
    fn test_coerced_from_dictionary() {
        // Int64 cannot be coerced into a dictionary of UInt32 values.
        let type_into =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_from = DataType::Int64;
        assert_eq!(coerced_from(&type_into, &type_from), None);

        // A dictionary of UInt32 values can be coerced into Int64.
        let type_from =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_into = DataType::Int64;
        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(type_into.clone())
        );
    }
}