1use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::FieldRef;
21use arrow::{
22 compute::can_cast_types,
23 datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27 base_type, coerced_fixed_size_list_to_list, ListCoercion,
28};
29use datafusion_common::{
30 exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35 signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36 type_coercion::binary::comparison_coercion_numeric,
37 type_coercion::binary::string_coercion,
38};
39use std::sync::Arc;
40
41pub fn data_types_with_scalar_udf(
49 current_types: &[DataType],
50 func: &ScalarUDF,
51) -> Result<Vec<DataType>> {
52 let signature = func.signature();
53 let type_signature = &signature.type_signature;
54
55 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
56 if type_signature.supports_zero_argument() {
57 return Ok(vec![]);
58 } else if type_signature.used_to_support_zero_arguments() {
59 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
61 } else {
62 return plan_err!("'{}' does not support zero arguments", func.name());
63 }
64 }
65
66 let valid_types =
67 get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
68
69 if valid_types
70 .iter()
71 .any(|data_type| data_type == current_types)
72 {
73 return Ok(current_types.to_vec());
74 }
75
76 try_coerce_types(func.name(), valid_types, current_types, type_signature)
77}
78
79pub fn fields_with_aggregate_udf(
87 current_fields: &[FieldRef],
88 func: &AggregateUDF,
89) -> Result<Vec<FieldRef>> {
90 let signature = func.signature();
91 let type_signature = &signature.type_signature;
92
93 if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
94 if type_signature.supports_zero_argument() {
95 return Ok(vec![]);
96 } else if type_signature.used_to_support_zero_arguments() {
97 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
99 } else {
100 return plan_err!("'{}' does not support zero arguments", func.name());
101 }
102 }
103 let current_types = current_fields
104 .iter()
105 .map(|f| f.data_type())
106 .cloned()
107 .collect::<Vec<_>>();
108
109 let valid_types =
110 get_valid_types_with_aggregate_udf(type_signature, ¤t_types, func)?;
111 if valid_types
112 .iter()
113 .any(|data_type| data_type == ¤t_types)
114 {
115 return Ok(current_fields.to_vec());
116 }
117
118 let updated_types =
119 try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?;
120
121 Ok(current_fields
122 .iter()
123 .zip(updated_types)
124 .map(|(current_field, new_type)| {
125 current_field.as_ref().clone().with_data_type(new_type)
126 })
127 .map(Arc::new)
128 .collect())
129}
130
131pub fn fields_with_window_udf(
139 current_fields: &[FieldRef],
140 func: &WindowUDF,
141) -> Result<Vec<FieldRef>> {
142 let signature = func.signature();
143 let type_signature = &signature.type_signature;
144
145 if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
146 if type_signature.supports_zero_argument() {
147 return Ok(vec![]);
148 } else if type_signature.used_to_support_zero_arguments() {
149 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
151 } else {
152 return plan_err!("'{}' does not support zero arguments", func.name());
153 }
154 }
155
156 let current_types = current_fields
157 .iter()
158 .map(|f| f.data_type())
159 .cloned()
160 .collect::<Vec<_>>();
161 let valid_types =
162 get_valid_types_with_window_udf(type_signature, ¤t_types, func)?;
163 if valid_types
164 .iter()
165 .any(|data_type| data_type == ¤t_types)
166 {
167 return Ok(current_fields.to_vec());
168 }
169
170 let updated_types =
171 try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?;
172
173 Ok(current_fields
174 .iter()
175 .zip(updated_types)
176 .map(|(current_field, new_type)| {
177 current_field.as_ref().clone().with_data_type(new_type)
178 })
179 .map(Arc::new)
180 .collect())
181}
182
183pub fn data_types(
191 function_name: impl AsRef<str>,
192 current_types: &[DataType],
193 signature: &Signature,
194) -> Result<Vec<DataType>> {
195 let type_signature = &signature.type_signature;
196
197 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
198 if type_signature.supports_zero_argument() {
199 return Ok(vec![]);
200 } else if type_signature.used_to_support_zero_arguments() {
201 return plan_err!(
203 "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
204 function_name.as_ref()
205 );
206 } else {
207 return plan_err!(
208 "Function '{}' has signature {type_signature:?} which does not support zero arguments",
209 function_name.as_ref()
210 );
211 }
212 }
213
214 let valid_types =
215 get_valid_types(function_name.as_ref(), type_signature, current_types)?;
216 if valid_types
217 .iter()
218 .any(|data_type| data_type == current_types)
219 {
220 return Ok(current_types.to_vec());
221 }
222
223 try_coerce_types(
224 function_name.as_ref(),
225 valid_types,
226 current_types,
227 type_signature,
228 )
229}
230
231fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
232 if let TypeSignature::OneOf(signatures) = type_signature {
233 return signatures.iter().all(is_well_supported_signature);
234 }
235
236 matches!(
237 type_signature,
238 TypeSignature::UserDefined
239 | TypeSignature::Numeric(_)
240 | TypeSignature::String(_)
241 | TypeSignature::Coercible(_)
242 | TypeSignature::Any(_)
243 | TypeSignature::Nullary
244 | TypeSignature::Comparable(_)
245 )
246}
247
248fn try_coerce_types(
249 function_name: &str,
250 valid_types: Vec<Vec<DataType>>,
251 current_types: &[DataType],
252 type_signature: &TypeSignature,
253) -> Result<Vec<DataType>> {
254 let mut valid_types = valid_types;
255
256 if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
258 if !type_signature.is_one_of() {
261 assert_eq!(valid_types.len(), 1);
262 }
263
264 let valid_types = valid_types.swap_remove(0);
265 if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
266 return Ok(t);
267 }
268 } else {
269 for valid_types in valid_types {
273 if let Some(types) = maybe_data_types(&valid_types, current_types) {
274 return Ok(types);
275 }
276 }
277 }
278
279 plan_err!(
281 "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {current_types:?} to the signature {type_signature:?} failed"
282 )
283}
284
285fn get_valid_types_with_scalar_udf(
286 signature: &TypeSignature,
287 current_types: &[DataType],
288 func: &ScalarUDF,
289) -> Result<Vec<Vec<DataType>>> {
290 match signature {
291 TypeSignature::UserDefined => match func.coerce_types(current_types) {
292 Ok(coerced_types) => Ok(vec![coerced_types]),
293 Err(e) => exec_err!(
294 "Function '{}' user-defined coercion failed with {:?}",
295 func.name(),
296 e.strip_backtrace()
297 ),
298 },
299 TypeSignature::OneOf(signatures) => {
300 let mut res = vec![];
301 let mut errors = vec![];
302 for sig in signatures {
303 match get_valid_types_with_scalar_udf(sig, current_types, func) {
304 Ok(valid_types) => {
305 res.extend(valid_types);
306 }
307 Err(e) => {
308 errors.push(e.to_string());
309 }
310 }
311 }
312
313 if res.is_empty() {
315 internal_err!(
316 "Function '{}' failed to match any signature, errors: {}",
317 func.name(),
318 errors.join(",")
319 )
320 } else {
321 Ok(res)
322 }
323 }
324 _ => get_valid_types(func.name(), signature, current_types),
325 }
326}
327
328fn get_valid_types_with_aggregate_udf(
329 signature: &TypeSignature,
330 current_types: &[DataType],
331 func: &AggregateUDF,
332) -> Result<Vec<Vec<DataType>>> {
333 let valid_types = match signature {
334 TypeSignature::UserDefined => match func.coerce_types(current_types) {
335 Ok(coerced_types) => vec![coerced_types],
336 Err(e) => {
337 return exec_err!(
338 "Function '{}' user-defined coercion failed with {:?}",
339 func.name(),
340 e.strip_backtrace()
341 )
342 }
343 },
344 TypeSignature::OneOf(signatures) => signatures
345 .iter()
346 .filter_map(|t| {
347 get_valid_types_with_aggregate_udf(t, current_types, func).ok()
348 })
349 .flatten()
350 .collect::<Vec<_>>(),
351 _ => get_valid_types(func.name(), signature, current_types)?,
352 };
353
354 Ok(valid_types)
355}
356
357fn get_valid_types_with_window_udf(
358 signature: &TypeSignature,
359 current_types: &[DataType],
360 func: &WindowUDF,
361) -> Result<Vec<Vec<DataType>>> {
362 let valid_types = match signature {
363 TypeSignature::UserDefined => match func.coerce_types(current_types) {
364 Ok(coerced_types) => vec![coerced_types],
365 Err(e) => {
366 return exec_err!(
367 "Function '{}' user-defined coercion failed with {:?}",
368 func.name(),
369 e.strip_backtrace()
370 )
371 }
372 },
373 TypeSignature::OneOf(signatures) => signatures
374 .iter()
375 .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
376 .flatten()
377 .collect::<Vec<_>>(),
378 _ => get_valid_types(func.name(), signature, current_types)?,
379 };
380
381 Ok(valid_types)
382}
383
384fn get_valid_types(
386 function_name: &str,
387 signature: &TypeSignature,
388 current_types: &[DataType],
389) -> Result<Vec<Vec<DataType>>> {
390 fn array_valid_types(
391 function_name: &str,
392 current_types: &[DataType],
393 arguments: &[ArrayFunctionArgument],
394 array_coercion: Option<&ListCoercion>,
395 ) -> Result<Vec<Vec<DataType>>> {
396 if current_types.len() != arguments.len() {
397 return Ok(vec![vec![]]);
398 }
399
400 let mut large_list = false;
401 let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
402 let mut list_sizes = Vec::with_capacity(arguments.len());
403 let mut element_types = Vec::with_capacity(arguments.len());
404 for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
405 match argument {
406 ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (),
407 ArrayFunctionArgument::Element => {
408 element_types.push(current_type.clone())
409 }
410 ArrayFunctionArgument::Array => match current_type {
411 DataType::Null => element_types.push(DataType::Null),
412 DataType::List(field) => {
413 element_types.push(field.data_type().clone());
414 fixed_size = false;
415 }
416 DataType::LargeList(field) => {
417 element_types.push(field.data_type().clone());
418 large_list = true;
419 fixed_size = false;
420 }
421 DataType::FixedSizeList(field, size) => {
422 element_types.push(field.data_type().clone());
423 list_sizes.push(*size)
424 }
425 arg_type => {
426 plan_err!("{function_name} does not support type {arg_type}")?
427 }
428 },
429 }
430 }
431
432 let Some(element_type) = type_union_resolution(&element_types) else {
433 return Ok(vec![vec![]]);
434 };
435
436 if !fixed_size {
437 list_sizes.clear()
438 }
439
440 let mut list_sizes = list_sizes.into_iter();
441 let valid_types = arguments.iter().zip(current_types.iter()).map(
442 |(argument_type, current_type)| match argument_type {
443 ArrayFunctionArgument::Index => DataType::Int64,
444 ArrayFunctionArgument::String => DataType::Utf8,
445 ArrayFunctionArgument::Element => element_type.clone(),
446 ArrayFunctionArgument::Array => {
447 if current_type.is_null() {
448 DataType::Null
449 } else if large_list {
450 DataType::new_large_list(element_type.clone(), true)
451 } else if let Some(size) = list_sizes.next() {
452 DataType::new_fixed_size_list(element_type.clone(), size, true)
453 } else {
454 DataType::new_list(element_type.clone(), true)
455 }
456 }
457 },
458 );
459
460 Ok(vec![valid_types.collect()])
461 }
462
463 fn recursive_array(array_type: &DataType) -> Option<DataType> {
464 match array_type {
465 DataType::List(_)
466 | DataType::LargeList(_)
467 | DataType::FixedSizeList(_, _) => {
468 let array_type = coerced_fixed_size_list_to_list(array_type);
469 Some(array_type)
470 }
471 _ => None,
472 }
473 }
474
475 fn function_length_check(
476 function_name: &str,
477 length: usize,
478 expected_length: usize,
479 ) -> Result<()> {
480 if length != expected_length {
481 return plan_err!(
482 "Function '{function_name}' expects {expected_length} arguments but received {length}"
483 );
484 }
485 Ok(())
486 }
487
488 let valid_types = match signature {
489 TypeSignature::Variadic(valid_types) => valid_types
490 .iter()
491 .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
492 .collect(),
493 TypeSignature::String(number) => {
494 function_length_check(function_name, current_types.len(), *number)?;
495
496 let mut new_types = Vec::with_capacity(current_types.len());
497 for data_type in current_types.iter() {
498 let logical_data_type: NativeType = data_type.into();
499 if logical_data_type == NativeType::String {
500 new_types.push(data_type.to_owned());
501 } else if logical_data_type == NativeType::Null {
502 new_types.push(DataType::Utf8);
504 } else {
505 return plan_err!(
506 "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
507 );
508 }
509 }
510
511 fn find_common_type(
513 function_name: &str,
514 lhs_type: &DataType,
515 rhs_type: &DataType,
516 ) -> Result<DataType> {
517 match (lhs_type, rhs_type) {
518 (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
519 find_common_type(function_name, lhs, rhs)
520 }
521 (DataType::Dictionary(_, v), other)
522 | (other, DataType::Dictionary(_, v)) => {
523 find_common_type(function_name, v, other)
524 }
525 _ => {
526 if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
527 Ok(coerced_type)
528 } else {
529 plan_err!(
530 "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
531 )
532 }
533 }
534 }
535 }
536
537 let mut coerced_type = new_types.first().unwrap().to_owned();
539 for t in new_types.iter().skip(1) {
540 coerced_type = find_common_type(function_name, &coerced_type, t)?;
541 }
542
543 fn base_type_or_default_type(data_type: &DataType) -> DataType {
544 if let DataType::Dictionary(_, v) = data_type {
545 base_type_or_default_type(v)
546 } else {
547 data_type.to_owned()
548 }
549 }
550
551 vec![vec![base_type_or_default_type(&coerced_type); *number]]
552 }
553 TypeSignature::Numeric(number) => {
554 function_length_check(function_name, current_types.len(), *number)?;
555
556 let mut valid_type = current_types.first().unwrap().to_owned();
558 for t in current_types.iter().skip(1) {
559 let logical_data_type: NativeType = t.into();
560 if logical_data_type == NativeType::Null {
561 continue;
562 }
563
564 if !logical_data_type.is_numeric() {
565 return plan_err!(
566 "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
567 );
568 }
569
570 if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
571 valid_type = coerced_type;
572 } else {
573 return plan_err!(
574 "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
575 );
576 }
577 }
578
579 let logical_data_type: NativeType = valid_type.clone().into();
580 if logical_data_type == NativeType::Null {
584 valid_type = DataType::Float64;
585 } else if !logical_data_type.is_numeric() {
586 return plan_err!(
587 "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
588 );
589 }
590
591 vec![vec![valid_type; *number]]
592 }
593 TypeSignature::Comparable(num) => {
594 function_length_check(function_name, current_types.len(), *num)?;
595 let mut target_type = current_types[0].to_owned();
596 for data_type in current_types.iter().skip(1) {
597 if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
598 target_type = dt;
599 } else {
600 return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
601 }
602 }
603 if target_type.is_null() {
605 vec![vec![DataType::Utf8View; *num]]
606 } else {
607 vec![vec![target_type; *num]]
608 }
609 }
610 TypeSignature::Coercible(param_types) => {
611 function_length_check(function_name, current_types.len(), param_types.len())?;
612
613 let mut new_types = Vec::with_capacity(current_types.len());
614 for (current_type, param) in current_types.iter().zip(param_types.iter()) {
615 let current_native_type: NativeType = current_type.into();
616
617 if param.desired_type().matches_native_type(¤t_native_type) {
618 let casted_type = param.desired_type().default_casted_type(
619 ¤t_native_type,
620 current_type,
621 )?;
622
623 new_types.push(casted_type);
624 } else if param
625 .allowed_source_types()
626 .iter()
627 .any(|t| t.matches_native_type(¤t_native_type)) {
628 let default_casted_type = param.default_casted_type().unwrap();
630 let casted_type = default_casted_type.default_cast_for(current_type)?;
631 new_types.push(casted_type);
632 } else {
633 return internal_err!(
634 "Expect {} but received {}, DataType: {}",
635 param.desired_type(),
636 current_native_type,
637 current_type
638 );
639 }
640 }
641
642 vec![new_types]
643 }
644 TypeSignature::Uniform(number, valid_types) => {
645 if *number == 0 {
646 return plan_err!("The function '{function_name}' expected at least one argument");
647 }
648
649 valid_types
650 .iter()
651 .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
652 .collect()
653 }
654 TypeSignature::UserDefined => {
655 return internal_err!(
656 "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
657 )
658 }
659 TypeSignature::VariadicAny => {
660 if current_types.is_empty() {
661 return plan_err!(
662 "Function '{function_name}' expected at least one argument but received 0"
663 );
664 }
665 vec![current_types.to_vec()]
666 }
667 TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
668 TypeSignature::ArraySignature(ref function_signature) => match function_signature {
669 ArrayFunctionSignature::Array { arguments, array_coercion, } => {
670 array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
671 }
672 ArrayFunctionSignature::RecursiveArray => {
673 if current_types.len() != 1 {
674 return Ok(vec![vec![]]);
675 }
676 recursive_array(¤t_types[0])
677 .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
678 }
679 ArrayFunctionSignature::MapArray => {
680 if current_types.len() != 1 {
681 return Ok(vec![vec![]]);
682 }
683
684 match ¤t_types[0] {
685 DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
686 _ => vec![vec![]],
687 }
688 }
689 },
690 TypeSignature::Nullary => {
691 if !current_types.is_empty() {
692 return plan_err!(
693 "The function '{function_name}' expected zero argument but received {}",
694 current_types.len()
695 );
696 }
697 vec![vec![]]
698 }
699 TypeSignature::Any(number) => {
700 if current_types.is_empty() {
701 return plan_err!(
702 "The function '{function_name}' expected at least one argument but received 0"
703 );
704 }
705
706 if current_types.len() != *number {
707 return plan_err!(
708 "The function '{function_name}' expected {number} arguments but received {}",
709 current_types.len()
710 );
711 }
712 vec![(0..*number).map(|i| current_types[i].clone()).collect()]
713 }
714 TypeSignature::OneOf(types) => types
715 .iter()
716 .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
717 .flatten()
718 .collect::<Vec<_>>(),
719 };
720
721 Ok(valid_types)
722}
723
724fn maybe_data_types(
731 valid_types: &[DataType],
732 current_types: &[DataType],
733) -> Option<Vec<DataType>> {
734 if valid_types.len() != current_types.len() {
735 return None;
736 }
737
738 let mut new_type = Vec::with_capacity(valid_types.len());
739 for (i, valid_type) in valid_types.iter().enumerate() {
740 let current_type = ¤t_types[i];
741
742 if current_type == valid_type {
743 new_type.push(current_type.clone())
744 } else {
745 if let Some(coerced_type) = coerced_from(valid_type, current_type) {
749 new_type.push(coerced_type)
750 } else {
751 return None;
753 }
754 }
755 }
756 Some(new_type)
757}
758
759fn maybe_data_types_without_coercion(
763 valid_types: &[DataType],
764 current_types: &[DataType],
765) -> Option<Vec<DataType>> {
766 if valid_types.len() != current_types.len() {
767 return None;
768 }
769
770 let mut new_type = Vec::with_capacity(valid_types.len());
771 for (i, valid_type) in valid_types.iter().enumerate() {
772 let current_type = ¤t_types[i];
773
774 if current_type == valid_type {
775 new_type.push(current_type.clone())
776 } else if can_cast_types(current_type, valid_type) {
777 new_type.push(valid_type.clone())
779 } else {
780 return None;
781 }
782 }
783 Some(new_type)
784}
785
786pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
791 if type_into == type_from {
792 return true;
793 }
794 if let Some(coerced) = coerced_from(type_into, type_from) {
795 return coerced == *type_into;
796 }
797 false
798}
799
800fn coerced_from<'a>(
807 type_into: &'a DataType,
808 type_from: &'a DataType,
809) -> Option<DataType> {
810 use self::DataType::*;
811
812 match (type_into, type_from) {
814 (_, Dictionary(_, value_type))
816 if coerced_from(type_into, value_type).is_some() =>
817 {
818 Some(type_into.clone())
819 }
820 (Dictionary(_, value_type), _)
821 if coerced_from(value_type, type_from).is_some() =>
822 {
823 Some(type_into.clone())
824 }
825 (Int8, Null | Int8) => Some(type_into.clone()),
827 (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
828 (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
829 (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
830 Some(type_into.clone())
831 }
832 (UInt8, Null | UInt8) => Some(type_into.clone()),
833 (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
834 (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
835 (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
836 (
837 Float32,
838 Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
839 | Float32,
840 ) => Some(type_into.clone()),
841 (
842 Float64,
843 Null
844 | Int8
845 | Int16
846 | Int32
847 | Int64
848 | UInt8
849 | UInt16
850 | UInt32
851 | UInt64
852 | Float32
853 | Float64
854 | Decimal128(_, _),
855 ) => Some(type_into.clone()),
856 (
857 Timestamp(TimeUnit::Nanosecond, None),
858 Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
859 ) => Some(type_into.clone()),
860 (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
861 (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
863 (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
865 (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
866
867 (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
868
869 (List(_) | LargeList(_), _)
872 if base_type(type_from).is_null()
873 || list_ndims(type_from) == list_ndims(type_into) =>
874 {
875 Some(type_into.clone())
876 }
877 (
879 FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
880 FixedSizeList(f_from, size_from),
881 ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
882 Some(data_type) if &data_type != f_into.data_type() => {
883 let new_field =
884 Arc::new(f_into.as_ref().clone().with_data_type(data_type));
885 Some(FixedSizeList(new_field, *size_from))
886 }
887 Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
888 _ => None,
889 },
890 (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
891 match type_from {
892 Timestamp(_, Some(from_tz)) => {
893 Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
894 }
895 Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
896 Some(Timestamp(*unit, Some("+00".into())))
898 }
899 _ => None,
900 }
901 }
902 (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
903 Some(type_into.clone())
904 }
905 _ => None,
906 }
907}
908
909#[cfg(test)]
910mod tests {
911 use crate::Volatility;
912
913 use super::*;
914 use arrow::datatypes::Field;
915 use datafusion_common::assert_contains;
916
917 #[test]
918 fn test_string_conversion() {
919 let cases = vec![
920 (DataType::Utf8View, DataType::Utf8, true),
921 (DataType::Utf8View, DataType::LargeUtf8, true),
922 ];
923
924 for case in cases {
925 assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
926 }
927 }
928
929 #[test]
930 fn test_maybe_data_types() {
931 let cases = vec![
933 (
935 vec![DataType::UInt8, DataType::UInt16],
936 vec![DataType::UInt8, DataType::UInt16],
937 Some(vec![DataType::UInt8, DataType::UInt16]),
938 ),
939 (
941 vec![DataType::UInt16, DataType::UInt16],
942 vec![DataType::UInt8, DataType::UInt16],
943 Some(vec![DataType::UInt16, DataType::UInt16]),
944 ),
945 (vec![], vec![], Some(vec![])),
947 (
949 vec![DataType::Boolean, DataType::UInt16],
950 vec![DataType::UInt8, DataType::UInt16],
951 None,
952 ),
953 (
955 vec![DataType::Boolean, DataType::UInt32],
956 vec![DataType::Boolean, DataType::UInt16],
957 Some(vec![DataType::Boolean, DataType::UInt32]),
958 ),
959 (
961 vec![
962 DataType::Timestamp(TimeUnit::Nanosecond, None),
963 DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
964 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
965 ],
966 vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
967 Some(vec![
968 DataType::Timestamp(TimeUnit::Nanosecond, None),
969 DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
970 DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
971 ]),
972 ),
973 ];
974
975 for case in cases {
976 assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
977 }
978 }
979
980 #[test]
981 fn test_get_valid_types_numeric() -> Result<()> {
982 let get_valid_types_flatten =
983 |function_name: &str,
984 signature: &TypeSignature,
985 current_types: &[DataType]| {
986 get_valid_types(function_name, signature, current_types)
987 .unwrap()
988 .into_iter()
989 .flatten()
990 .collect::<Vec<_>>()
991 };
992
993 let got = get_valid_types_flatten(
995 "test",
996 &TypeSignature::Numeric(1),
997 &[DataType::Int32],
998 );
999 assert_eq!(got, [DataType::Int32]);
1000
1001 let got = get_valid_types_flatten(
1003 "test",
1004 &TypeSignature::Numeric(2),
1005 &[DataType::Int32, DataType::Int64],
1006 );
1007 assert_eq!(got, [DataType::Int64, DataType::Int64]);
1008
1009 let got = get_valid_types_flatten(
1011 "test",
1012 &TypeSignature::Numeric(3),
1013 &[DataType::Int32, DataType::Int64, DataType::Float64],
1014 );
1015 assert_eq!(
1016 got,
1017 [DataType::Float64, DataType::Float64, DataType::Float64]
1018 );
1019
1020 let got = get_valid_types(
1022 "test",
1023 &TypeSignature::Numeric(2),
1024 &[DataType::Int32, DataType::Utf8],
1025 )
1026 .unwrap_err();
1027 assert_contains!(
1028 got.to_string(),
1029 "Function 'test' expects NativeType::Numeric but received NativeType::String"
1030 );
1031
1032 let got = get_valid_types_flatten(
1034 "test",
1035 &TypeSignature::Numeric(1),
1036 &[DataType::Null],
1037 );
1038 assert_eq!(got, [DataType::Float64]);
1039
1040 let got = get_valid_types(
1042 "test",
1043 &TypeSignature::Numeric(1),
1044 &[DataType::Timestamp(TimeUnit::Second, None)],
1045 )
1046 .unwrap_err();
1047 assert_contains!(
1048 got.to_string(),
1049 "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
1050 );
1051
1052 Ok(())
1053 }
1054
1055 #[test]
1056 fn test_get_valid_types_one_of() -> Result<()> {
1057 let signature =
1058 TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);
1059
1060 let invalid_types = get_valid_types(
1061 "test",
1062 &signature,
1063 &[DataType::Int32, DataType::Int32, DataType::Int32],
1064 )?;
1065 assert_eq!(invalid_types.len(), 0);
1066
1067 let args = vec![DataType::Int32, DataType::Int32];
1068 let valid_types = get_valid_types("test", &signature, &args)?;
1069 assert_eq!(valid_types.len(), 1);
1070 assert_eq!(valid_types[0], args);
1071
1072 let args = vec![DataType::Int32];
1073 let valid_types = get_valid_types("test", &signature, &args)?;
1074 assert_eq!(valid_types.len(), 1);
1075 assert_eq!(valid_types[0], args);
1076
1077 Ok(())
1078 }
1079
1080 #[test]
1081 fn test_get_valid_types_length_check() -> Result<()> {
1082 let signature = TypeSignature::Numeric(1);
1083
1084 let err = get_valid_types("test", &signature, &[]).unwrap_err();
1085 assert_contains!(
1086 err.to_string(),
1087 "Function 'test' expects 1 arguments but received 0"
1088 );
1089
1090 let err = get_valid_types(
1091 "test",
1092 &signature,
1093 &[DataType::Int32, DataType::Int32, DataType::Int32],
1094 )
1095 .unwrap_err();
1096 assert_contains!(
1097 err.to_string(),
1098 "Function 'test' expects 1 arguments but received 3"
1099 );
1100
1101 Ok(())
1102 }
1103
1104 #[test]
1105 fn test_fixed_list_wildcard_coerce() -> Result<()> {
1106 let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
1107 let current_types = vec![
1108 DataType::FixedSizeList(Arc::clone(&inner), 2), ];
1110
1111 let signature = Signature::exact(
1112 vec![DataType::FixedSizeList(
1113 Arc::clone(&inner),
1114 FIXED_SIZE_LIST_WILDCARD,
1115 )],
1116 Volatility::Stable,
1117 );
1118
1119 let coerced_data_types = data_types("test", ¤t_types, &signature)?;
1120 assert_eq!(coerced_data_types, current_types);
1121
1122 let signature = Signature::exact(
1124 vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
1125 Volatility::Stable,
1126 );
1127 let coerced_data_types = data_types("test", ¤t_types, &signature);
1128 assert!(coerced_data_types.is_err());
1129
1130 let signature = Signature::exact(
1132 vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
1133 Volatility::Stable,
1134 );
1135 let coerced_data_types = data_types("test", ¤t_types, &signature).unwrap();
1136 assert_eq!(coerced_data_types, current_types);
1137
1138 Ok(())
1139 }
1140
1141 #[test]
1142 fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
1143 let type_into = DataType::FixedSizeList(
1144 Arc::new(Field::new_list_field(
1145 DataType::FixedSizeList(
1146 Arc::new(Field::new_list_field(DataType::Int32, false)),
1147 FIXED_SIZE_LIST_WILDCARD,
1148 ),
1149 false,
1150 )),
1151 FIXED_SIZE_LIST_WILDCARD,
1152 );
1153
1154 let type_from = DataType::FixedSizeList(
1155 Arc::new(Field::new_list_field(
1156 DataType::FixedSizeList(
1157 Arc::new(Field::new_list_field(DataType::Int8, false)),
1158 4,
1159 ),
1160 false,
1161 )),
1162 3,
1163 );
1164
1165 assert_eq!(
1166 coerced_from(&type_into, &type_from),
1167 Some(DataType::FixedSizeList(
1168 Arc::new(Field::new_list_field(
1169 DataType::FixedSizeList(
1170 Arc::new(Field::new_list_field(DataType::Int32, false)),
1171 4,
1172 ),
1173 false,
1174 )),
1175 3,
1176 ))
1177 );
1178
1179 Ok(())
1180 }
1181
1182 #[test]
1183 fn test_coerced_from_dictionary() {
1184 let type_into =
1185 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1186 let type_from = DataType::Int64;
1187 assert_eq!(coerced_from(&type_into, &type_from), None);
1188
1189 let type_from =
1190 DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1191 let type_into = DataType::Int64;
1192 assert_eq!(
1193 coerced_from(&type_into, &type_from),
1194 Some(type_into.clone())
1195 );
1196 }
1197
1198 #[test]
1199 fn test_get_valid_types_array_and_array() -> Result<()> {
1200 let function = "array_and_array";
1201 let signature = Signature::arrays(
1202 2,
1203 Some(ListCoercion::FixedSizedListToList),
1204 Volatility::Immutable,
1205 );
1206
1207 let data_types = vec![
1208 DataType::new_list(DataType::Int32, true),
1209 DataType::new_large_list(DataType::Float64, true),
1210 ];
1211 assert_eq!(
1212 get_valid_types(function, &signature.type_signature, &data_types)?,
1213 vec![vec![
1214 DataType::new_large_list(DataType::Float64, true),
1215 DataType::new_large_list(DataType::Float64, true),
1216 ]]
1217 );
1218
1219 let data_types = vec![
1220 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1221 DataType::new_fixed_size_list(DataType::Int32, 5, true),
1222 ];
1223 assert_eq!(
1224 get_valid_types(function, &signature.type_signature, &data_types)?,
1225 vec![vec![
1226 DataType::new_list(DataType::Int64, true),
1227 DataType::new_list(DataType::Int64, true),
1228 ]]
1229 );
1230
1231 let data_types = vec![
1232 DataType::new_fixed_size_list(DataType::Null, 3, true),
1233 DataType::new_large_list(DataType::Utf8, true),
1234 ];
1235 assert_eq!(
1236 get_valid_types(function, &signature.type_signature, &data_types)?,
1237 vec![vec![
1238 DataType::new_large_list(DataType::Utf8, true),
1239 DataType::new_large_list(DataType::Utf8, true),
1240 ]]
1241 );
1242
1243 Ok(())
1244 }
1245
1246 #[test]
1247 fn test_get_valid_types_array_and_element() -> Result<()> {
1248 let function = "array_and_element";
1249 let signature = Signature::array_and_element(Volatility::Immutable);
1250
1251 let data_types =
1252 vec![DataType::new_list(DataType::Int32, true), DataType::Float64];
1253 assert_eq!(
1254 get_valid_types(function, &signature.type_signature, &data_types)?,
1255 vec![vec![
1256 DataType::new_list(DataType::Float64, true),
1257 DataType::Float64,
1258 ]]
1259 );
1260
1261 let data_types = vec![
1262 DataType::new_large_list(DataType::Int32, true),
1263 DataType::Null,
1264 ];
1265 assert_eq!(
1266 get_valid_types(function, &signature.type_signature, &data_types)?,
1267 vec![vec![
1268 DataType::new_large_list(DataType::Int32, true),
1269 DataType::Int32,
1270 ]]
1271 );
1272
1273 let data_types = vec![
1274 DataType::new_fixed_size_list(DataType::Null, 3, true),
1275 DataType::Utf8,
1276 ];
1277 assert_eq!(
1278 get_valid_types(function, &signature.type_signature, &data_types)?,
1279 vec![vec![
1280 DataType::new_list(DataType::Utf8, true),
1281 DataType::Utf8,
1282 ]]
1283 );
1284
1285 Ok(())
1286 }
1287
1288 #[test]
1289 fn test_get_valid_types_element_and_array() -> Result<()> {
1290 let function = "element_and_array";
1291 let signature = Signature::element_and_array(Volatility::Immutable);
1292
1293 let data_types = vec![
1294 DataType::new_large_list(DataType::Null, false),
1295 DataType::new_list(DataType::new_list(DataType::Int64, true), true),
1296 ];
1297 assert_eq!(
1298 get_valid_types(function, &signature.type_signature, &data_types)?,
1299 vec![vec![
1300 DataType::new_large_list(DataType::Int64, true),
1301 DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
1302 ]]
1303 );
1304
1305 Ok(())
1306 }
1307
1308 #[test]
1309 fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
1310 let function = "fixed_size_arrays";
1311 let signature = Signature::arrays(2, None, Volatility::Immutable);
1312
1313 let data_types = vec![
1314 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1315 DataType::new_fixed_size_list(DataType::Int32, 5, true),
1316 ];
1317 assert_eq!(
1318 get_valid_types(function, &signature.type_signature, &data_types)?,
1319 vec![vec![
1320 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1321 DataType::new_fixed_size_list(DataType::Int64, 5, true),
1322 ]]
1323 );
1324
1325 let data_types = vec![
1326 DataType::new_fixed_size_list(DataType::Int64, 3, true),
1327 DataType::new_list(DataType::Int32, true),
1328 ];
1329 assert_eq!(
1330 get_valid_types(function, &signature.type_signature, &data_types)?,
1331 vec![vec![
1332 DataType::new_list(DataType::Int64, true),
1333 DataType::new_list(DataType::Int64, true),
1334 ]]
1335 );
1336
1337 let data_types = vec![
1338 DataType::new_fixed_size_list(DataType::Utf8, 3, true),
1339 DataType::new_list(DataType::new_list(DataType::Int32, true), true),
1340 ];
1341 assert_eq!(
1342 get_valid_types(function, &signature.type_signature, &data_types)?,
1343 vec![vec![]]
1344 );
1345
1346 Ok(())
1347 }
1348}