1use super::binary::binary_numeric_coercion;
19use crate::{
20 AggregateUDF, HigherOrderTypeSignature, HigherOrderUDF, ScalarUDF, Signature,
21 TypeSignature, ValueOrLambda, WindowUDF,
22};
23use arrow::datatypes::{Field, FieldRef};
24use arrow::{
25 compute::can_cast_types,
26 datatypes::{DataType, TimeUnit},
27};
28use datafusion_common::internal_datafusion_err;
29use datafusion_common::types::LogicalType;
30use datafusion_common::utils::{
31 ListCoercion, base_type, coerced_fixed_size_list_to_list,
32};
33use datafusion_common::{
34 Result, exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims,
35};
36use datafusion_expr_common::signature::ArrayFunctionArgument;
37use datafusion_expr_common::type_coercion::binary::type_union_resolution;
38use datafusion_expr_common::{
39 signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
40 type_coercion::binary::comparison_coercion,
41 type_coercion::binary::string_coercion,
42};
43use itertools::Itertools as _;
44use std::sync::Arc;
45
/// Common view over the UDF kinds that the coercion routines in this module
/// operate on (implemented below for scalar, aggregate and window UDFs).
pub trait UDFCoercionExt {
    /// Name of the function; used when building coercion error messages.
    fn name(&self) -> &str;
    /// Signature describing the argument types the function accepts.
    fn signature(&self) -> &Signature;
    /// Function-specific coercion of `arg_types`; invoked by this module when
    /// the signature is `TypeSignature::UserDefined`.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>>;
}
58
59impl UDFCoercionExt for ScalarUDF {
60 fn name(&self) -> &str {
61 self.name()
62 }
63
64 fn signature(&self) -> &Signature {
65 self.signature()
66 }
67
68 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
69 self.coerce_types(arg_types)
70 }
71}
72
73impl UDFCoercionExt for AggregateUDF {
74 fn name(&self) -> &str {
75 self.name()
76 }
77
78 fn signature(&self) -> &Signature {
79 self.signature()
80 }
81
82 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
83 self.coerce_types(arg_types)
84 }
85}
86
87impl UDFCoercionExt for WindowUDF {
88 fn name(&self) -> &str {
89 self.name()
90 }
91
92 fn signature(&self) -> &Signature {
93 self.signature()
94 }
95
96 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
97 self.coerce_types(arg_types)
98 }
99}
100
101pub fn fields_with_udf<F: UDFCoercionExt>(
109 current_fields: &[FieldRef],
110 func: &F,
111) -> Result<Vec<FieldRef>> {
112 let signature = func.signature();
113 let type_signature = &signature.type_signature;
114
115 if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
116 if type_signature.supports_zero_argument() {
117 return Ok(vec![]);
118 } else if type_signature.used_to_support_zero_arguments() {
119 return plan_err!(
121 "'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
122 func.name()
123 );
124 } else {
125 return plan_err!("'{}' does not support zero arguments", func.name());
126 }
127 }
128 let current_types = current_fields
129 .iter()
130 .map(|f| f.data_type())
131 .cloned()
132 .collect::<Vec<_>>();
133
134 let valid_types = get_valid_types_with_udf(type_signature, ¤t_types, func)?;
135 if valid_types
136 .iter()
137 .any(|data_type| data_type == ¤t_types)
138 {
139 return Ok(current_fields.to_vec());
140 }
141
142 let updated_types =
143 try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?;
144
145 Ok(current_fields
146 .iter()
147 .zip(updated_types)
148 .map(|(current_field, new_type)| {
149 current_field.as_ref().clone().with_data_type(new_type)
150 })
151 .map(Arc::new)
152 .collect())
153}
154
155pub fn value_fields_with_higher_order_udf<L: Clone>(
168 current_fields: &[ValueOrLambda<FieldRef, L>],
169 func: &HigherOrderUDF,
170) -> Result<Vec<ValueOrLambda<FieldRef, L>>> {
171 match func.signature().type_signature {
172 HigherOrderTypeSignature::UserDefined => {
173 let arg_types = current_fields
174 .iter()
175 .filter_map(|p| match p {
176 ValueOrLambda::Value(field) => Some(field.data_type().clone()),
177 ValueOrLambda::Lambda(_) => None,
178 })
179 .collect::<Vec<_>>();
180
181 let coerced_types = func.coerce_value_types(&arg_types)?;
182
183 if coerced_types.len() != arg_types.len() {
184 return plan_err!(
185 "{} coerce_value_types should have returned {} items but returned {}",
186 func.name(),
187 arg_types.len(),
188 coerced_types.len()
189 );
190 }
191
192 let mut coerced_types = coerced_types.into_iter();
200
201 current_fields
202 .iter()
203 .map(|current_field| match current_field {
204 ValueOrLambda::Value(field) => {
205 let data_type = coerced_types.next().ok_or_else(|| {
206 internal_datafusion_err!(
207 "coerced_types len should have been checked above"
208 )
209 })?;
210
211 Ok(ValueOrLambda::Value(Arc::new(
212 field.as_ref().clone().with_data_type(data_type),
213 )))
214 }
215 ValueOrLambda::Lambda(lambda) => {
216 Ok(ValueOrLambda::Lambda(lambda.clone()))
217 }
218 })
219 .collect()
220 }
221 HigherOrderTypeSignature::VariadicAny => Ok(current_fields.to_vec()),
222 HigherOrderTypeSignature::Any(number) => {
223 if current_fields.len() != number {
224 return plan_err!(
225 "The function '{}' expected {number} arguments but received {}",
226 func.name(),
227 current_fields.len()
228 );
229 }
230
231 Ok(current_fields.to_vec())
232 }
233 HigherOrderTypeSignature::Exact(ref expected) => {
234 if current_fields.len() != expected.len() {
235 let name = func.name();
236 let expected_len = expected.len();
237 let actual_len = current_fields.len();
238 return plan_err!(
239 "The function '{name}' expected {expected_len} argument(s) but received {actual_len}"
240 );
241 }
242
243 for (i, (actual, expected)) in
244 current_fields.iter().zip(expected.iter()).enumerate()
245 {
246 match (actual, expected) {
247 (ValueOrLambda::Value(_), ValueOrLambda::Value(_)) => {}
248 (ValueOrLambda::Lambda(_), ValueOrLambda::Lambda(_)) => {}
249 (ValueOrLambda::Value(_), ValueOrLambda::Lambda(_)) => {
250 let name = func.name();
251 return plan_err!(
252 "The function '{name}' expected a lambda at position {i} but received a value"
253 );
254 }
255 (ValueOrLambda::Lambda(_), ValueOrLambda::Value(_)) => {
256 let name = func.name();
257 return plan_err!(
258 "The function '{name}' expected a value at position {i} but received a lambda"
259 );
260 }
261 }
262 }
263
264 let arg_types = current_fields
265 .iter()
266 .filter_map(|p| match p {
267 ValueOrLambda::Value(field) => Some(field.data_type().clone()),
268 ValueOrLambda::Lambda(_) => None,
269 })
270 .collect::<Vec<_>>();
271
272 let coerced_types = func.coerce_value_types(&arg_types)?;
273
274 if coerced_types.len() != arg_types.len() {
275 return plan_err!(
276 "{} coerce_value_types should have returned {} items but returned {}",
277 func.name(),
278 arg_types.len(),
279 coerced_types.len()
280 );
281 }
282
283 let mut coerced_types = coerced_types.into_iter();
284
285 current_fields
286 .iter()
287 .map(|current_field| match current_field {
288 ValueOrLambda::Value(field) => {
289 let data_type = coerced_types.next().ok_or_else(|| {
290 internal_datafusion_err!(
291 "coerced_types len should have been checked above"
292 )
293 })?;
294
295 Ok(ValueOrLambda::Value(Arc::new(
296 field.as_ref().clone().with_data_type(data_type),
297 )))
298 }
299 ValueOrLambda::Lambda(lambda) => {
300 Ok(ValueOrLambda::Lambda(lambda.clone()))
301 }
302 })
303 .collect()
304 }
305 }
306}
307
308pub fn value_fields_with_higher_order_udf_and_lambdas(
321 current_fields: &[ValueOrLambda<FieldRef, FieldRef>],
322 func: &HigherOrderUDF,
323) -> Result<Vec<ValueOrLambda<FieldRef, FieldRef>>> {
324 let mut new_fields = value_fields_with_higher_order_udf(current_fields, func)?;
325
326 let new_types = new_fields
327 .iter()
328 .map(|f| match f {
329 ValueOrLambda::Value(f) => ValueOrLambda::Value(f.data_type().clone()),
330 ValueOrLambda::Lambda(f) => ValueOrLambda::Lambda(f.data_type().clone()),
331 })
332 .collect::<Vec<_>>();
333
334 if let Some(new_value_types) = func.coerce_values_for_lambdas(&new_types)? {
335 let mut new_value_types = new_value_types.into_iter();
336
337 let value_types_count = new_types
338 .iter()
339 .filter(|e| matches!(e, ValueOrLambda::Value(_)))
340 .count();
341
342 if new_value_types.len() != value_types_count {
343 return plan_err!(
344 "{} coerce_values_for_lambdas returned {} values but {value_types_count} expected",
345 func.name(),
346 new_value_types.len()
347 );
348 }
349
350 for new_field in &mut new_fields {
351 match new_field {
352 ValueOrLambda::Value(value) => {
353 let coerce_to = new_value_types.next().ok_or_else(|| {
354 internal_datafusion_err!(
355 "new_value_types len should have been checked above"
356 )
357 })?;
358
359 if value.data_type() != &coerce_to {
360 Arc::make_mut(value).set_data_type(coerce_to);
361 }
362 }
363 ValueOrLambda::Lambda(_) => {}
364 }
365 }
366 };
367
368 Ok(new_fields)
369}
370
371#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
379pub fn data_types_with_scalar_udf(
380 current_types: &[DataType],
381 func: &ScalarUDF,
382) -> Result<Vec<DataType>> {
383 let current_fields = current_types
384 .iter()
385 .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
386 .collect::<Vec<_>>();
387 Ok(fields_with_udf(¤t_fields, func)?
388 .iter()
389 .map(|f| f.data_type().clone())
390 .collect())
391}
392
393#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
401pub fn fields_with_aggregate_udf(
402 current_fields: &[FieldRef],
403 func: &AggregateUDF,
404) -> Result<Vec<FieldRef>> {
405 fields_with_udf(current_fields, func)
406}
407
408#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
416pub fn fields_with_window_udf(
417 current_fields: &[FieldRef],
418 func: &WindowUDF,
419) -> Result<Vec<FieldRef>> {
420 fields_with_udf(current_fields, func)
421}
422
423#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
431pub fn data_types(
432 function_name: impl AsRef<str>,
433 current_types: &[DataType],
434 signature: &Signature,
435) -> Result<Vec<DataType>> {
436 let type_signature = &signature.type_signature;
437
438 if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
439 if type_signature.supports_zero_argument() {
440 return Ok(vec![]);
441 } else if type_signature.used_to_support_zero_arguments() {
442 return plan_err!(
444 "function '{}' has signature {type_signature} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
445 function_name.as_ref()
446 );
447 } else {
448 return plan_err!(
449 "Function '{}' has signature {type_signature} which does not support zero arguments",
450 function_name.as_ref()
451 );
452 }
453 }
454
455 let valid_types =
456 get_valid_types(function_name.as_ref(), type_signature, current_types)?;
457 if valid_types
458 .iter()
459 .any(|data_type| data_type == current_types)
460 {
461 return Ok(current_types.to_vec());
462 }
463
464 try_coerce_types(
465 function_name.as_ref(),
466 valid_types,
467 current_types,
468 type_signature,
469 )
470}
471
472fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
473 match type_signature {
474 TypeSignature::OneOf(type_signatures) => {
475 type_signatures.iter().all(is_well_supported_signature)
476 }
477 TypeSignature::UserDefined
478 | TypeSignature::Numeric(_)
479 | TypeSignature::String(_)
480 | TypeSignature::Coercible(_)
481 | TypeSignature::Any(_)
482 | TypeSignature::Nullary
483 | TypeSignature::Comparable(_) => true,
484 TypeSignature::Variadic(_)
485 | TypeSignature::VariadicAny
486 | TypeSignature::Uniform(_, _)
487 | TypeSignature::Exact(_)
488 | TypeSignature::ArraySignature(_) => false,
489 }
490}
491
492fn try_coerce_types(
493 function_name: &str,
494 valid_types: Vec<Vec<DataType>>,
495 current_types: &[DataType],
496 type_signature: &TypeSignature,
497) -> Result<Vec<DataType>> {
498 let mut valid_types = valid_types;
499
500 if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
502 if !type_signature.is_one_of() {
505 assert_eq!(valid_types.len(), 1);
506 }
507
508 let valid_types = valid_types.swap_remove(0);
509 if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
510 return Ok(t);
511 }
512 } else {
513 for valid_types in valid_types {
517 if let Some(types) = maybe_data_types(&valid_types, current_types) {
518 return Ok(types);
519 }
520 }
521 }
522
523 plan_err!(
525 "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature} failed",
526 current_types.iter().join(", ")
527 )
528}
529
530fn get_valid_types_with_udf<F: UDFCoercionExt>(
531 signature: &TypeSignature,
532 current_types: &[DataType],
533 func: &F,
534) -> Result<Vec<Vec<DataType>>> {
535 let valid_types = match signature {
536 TypeSignature::UserDefined => match func.coerce_types(current_types) {
537 Ok(coerced_types) => vec![coerced_types],
538 Err(e) => {
539 return exec_err!(
540 "Function '{}' user-defined coercion failed with: {}",
541 func.name(),
542 e.strip_backtrace()
543 );
544 }
545 },
546 TypeSignature::OneOf(signatures) => {
547 let mut res = vec![];
548 let mut errors = vec![];
549 for sig in signatures {
550 match get_valid_types_with_udf(sig, current_types, func) {
551 Ok(valid_types) => {
552 res.extend(valid_types);
553 }
554 Err(e) => {
555 errors.push(e.to_string());
556 }
557 }
558 }
559
560 if res.is_empty() {
562 return internal_err!(
563 "Function '{}' failed to match any signature, errors: {}",
564 func.name(),
565 errors.join(",")
566 );
567 } else {
568 res
569 }
570 }
571 _ => get_valid_types(func.name(), signature, current_types)?,
572 };
573
574 Ok(valid_types)
575}
576
577fn get_valid_types(
579 function_name: &str,
580 signature: &TypeSignature,
581 current_types: &[DataType],
582) -> Result<Vec<Vec<DataType>>> {
583 fn array_valid_types(
584 function_name: &str,
585 current_types: &[DataType],
586 arguments: &[ArrayFunctionArgument],
587 array_coercion: Option<&ListCoercion>,
588 ) -> Result<Vec<Vec<DataType>>> {
589 if current_types.len() != arguments.len() {
590 return Ok(vec![vec![]]);
591 }
592
593 let mut large_list = false;
594 let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
595 let mut list_sizes = Vec::with_capacity(arguments.len());
596 let mut element_types = Vec::with_capacity(arguments.len());
597 let mut nested_item_nullability = Vec::with_capacity(arguments.len());
598 for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
599 match argument {
600 ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
601 nested_item_nullability.push(None);
602 }
603 ArrayFunctionArgument::Element => {
604 element_types.push(current_type.clone());
605 nested_item_nullability.push(None);
606 }
607 ArrayFunctionArgument::Array => match current_type {
608 DataType::Null => {
609 element_types.push(DataType::Null);
610 nested_item_nullability.push(None);
611 }
612 DataType::List(field) | DataType::ListView(field) => {
613 element_types.push(field.data_type().clone());
614 nested_item_nullability.push(Some(field.is_nullable()));
615 fixed_size = false;
616 }
617 DataType::LargeList(field) | DataType::LargeListView(field) => {
618 element_types.push(field.data_type().clone());
619 nested_item_nullability.push(Some(field.is_nullable()));
620 large_list = true;
621 fixed_size = false;
622 }
623 DataType::FixedSizeList(field, size) => {
624 element_types.push(field.data_type().clone());
625 nested_item_nullability.push(Some(field.is_nullable()));
626 list_sizes.push(*size)
627 }
628 arg_type => {
629 plan_err!("{function_name} does not support type {arg_type}")?
630 }
631 },
632 }
633 }
634
635 debug_assert_eq!(nested_item_nullability.len(), arguments.len());
636
637 let Some(element_type) = type_union_resolution(&element_types) else {
638 return Ok(vec![vec![]]);
639 };
640
641 if !fixed_size {
642 list_sizes.clear()
643 };
644
645 let mut list_sizes = list_sizes.into_iter();
646 let valid_types = arguments
647 .iter()
648 .zip(current_types.iter())
649 .zip(nested_item_nullability)
650 .map(|((argument_type, current_type), is_nested_item_nullable)| {
651 match argument_type {
652 ArrayFunctionArgument::Index => DataType::Int64,
653 ArrayFunctionArgument::String => DataType::Utf8,
654 ArrayFunctionArgument::Element => element_type.clone(),
655 ArrayFunctionArgument::Array => {
658 if current_type.is_null() {
659 DataType::Null
660 } else if large_list {
661 DataType::new_large_list(
662 element_type.clone(),
663 is_nested_item_nullable.unwrap_or(true),
664 )
665 } else if let Some(size) = list_sizes.next() {
666 DataType::new_fixed_size_list(
667 element_type.clone(),
668 size,
669 is_nested_item_nullable.unwrap_or(true),
670 )
671 } else {
672 DataType::new_list(
673 element_type.clone(),
674 is_nested_item_nullable.unwrap_or(true),
675 )
676 }
677 }
678 }
679 });
680
681 Ok(vec![valid_types.collect()])
682 }
683
684 fn recursive_array(array_type: &DataType) -> Option<DataType> {
685 match array_type {
686 DataType::List(_)
687 | DataType::LargeList(_)
688 | DataType::ListView(_)
689 | DataType::LargeListView(_)
690 | DataType::FixedSizeList(_, _) => {
691 let array_type = coerced_fixed_size_list_to_list(array_type);
692 Some(array_type)
693 }
694 _ => None,
695 }
696 }
697
698 fn function_length_check(
699 function_name: &str,
700 length: usize,
701 expected_length: usize,
702 ) -> Result<()> {
703 if length != expected_length {
704 return plan_err!(
705 "Function '{function_name}' expects {expected_length} arguments but received {length}"
706 );
707 }
708 Ok(())
709 }
710
711 let valid_types = match signature {
712 TypeSignature::Variadic(valid_types) => valid_types
713 .iter()
714 .map(|valid_type| vec![valid_type.clone(); current_types.len()])
715 .collect(),
716 TypeSignature::String(number) => {
717 function_length_check(function_name, current_types.len(), *number)?;
718
719 let mut new_types = Vec::with_capacity(current_types.len());
720 for data_type in current_types.iter() {
721 let logical_data_type: NativeType = data_type.into();
722 if logical_data_type == NativeType::String {
723 new_types.push(data_type.to_owned());
724 } else if logical_data_type == NativeType::Null {
725 new_types.push(DataType::Utf8);
727 } else {
728 return plan_err!(
729 "Function '{function_name}' expects String but received {logical_data_type}"
730 );
731 }
732 }
733
734 fn find_common_type(
736 function_name: &str,
737 lhs_type: &DataType,
738 rhs_type: &DataType,
739 ) -> Result<DataType> {
740 match (lhs_type, rhs_type) {
741 (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
742 find_common_type(function_name, lhs, rhs)
743 }
744 (DataType::Dictionary(_, v), other)
745 | (other, DataType::Dictionary(_, v)) => {
746 find_common_type(function_name, v, other)
747 }
748 _ => {
749 if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
750 Ok(coerced_type)
751 } else {
752 plan_err!(
753 "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
754 )
755 }
756 }
757 }
758 }
759
760 let mut coerced_type = new_types.first().unwrap().to_owned();
762 for t in new_types.iter().skip(1) {
763 coerced_type = find_common_type(function_name, &coerced_type, t)?;
764 }
765
766 fn base_type_or_default_type(data_type: &DataType) -> DataType {
767 if let DataType::Dictionary(_, v) = data_type {
768 base_type_or_default_type(v)
769 } else {
770 data_type.to_owned()
771 }
772 }
773
774 vec![vec![base_type_or_default_type(&coerced_type); *number]]
775 }
776 TypeSignature::Numeric(number) => {
777 function_length_check(function_name, current_types.len(), *number)?;
778
779 let mut valid_type = current_types.first().unwrap().to_owned();
781 for t in current_types.iter().skip(1) {
782 let logical_data_type: NativeType = t.into();
783 if logical_data_type == NativeType::Null {
784 continue;
785 }
786
787 if !logical_data_type.is_numeric() {
788 return plan_err!(
789 "Function '{function_name}' expects Numeric but received {logical_data_type}"
790 );
791 }
792
793 if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
794 valid_type = coerced_type;
795 } else {
796 return plan_err!(
797 "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
798 );
799 }
800 }
801
802 let logical_data_type: NativeType = valid_type.clone().into();
803 if logical_data_type == NativeType::Null {
807 valid_type = DataType::Float64;
808 } else if !logical_data_type.is_numeric() {
809 return plan_err!(
810 "Function '{function_name}' expects Numeric but received {logical_data_type}"
811 );
812 }
813
814 vec![vec![valid_type; *number]]
815 }
816 TypeSignature::Comparable(num) => {
817 function_length_check(function_name, current_types.len(), *num)?;
818 let mut target_type = current_types[0].to_owned();
819 for data_type in current_types.iter().skip(1) {
820 if let Some(dt) = comparison_coercion(&target_type, data_type) {
821 target_type = dt;
822 } else {
823 return plan_err!(
824 "For function '{function_name}' {target_type} and {data_type} is not comparable"
825 );
826 }
827 }
828 if target_type.is_null() {
830 vec![vec![DataType::Utf8View; *num]]
831 } else {
832 vec![vec![target_type; *num]]
833 }
834 }
835 TypeSignature::Coercible(param_types) => {
836 function_length_check(function_name, current_types.len(), param_types.len())?;
837
838 let mut new_types = Vec::with_capacity(current_types.len());
839 for (current_type, param) in current_types.iter().zip(param_types.iter()) {
840 let current_native_type: NativeType = current_type.into();
841
842 if param
843 .desired_type()
844 .matches_native_type(¤t_native_type)
845 {
846 let casted_type = param
847 .desired_type()
848 .default_casted_type(¤t_native_type, current_type)?;
849
850 new_types.push(casted_type);
851 } else if param
852 .allowed_source_types()
853 .iter()
854 .any(|t| t.matches_native_type(¤t_native_type))
855 {
856 let default_casted_type = param.default_casted_type().unwrap();
858 let casted_type =
859 default_casted_type.default_cast_for(current_type)?;
860 new_types.push(casted_type);
861 } else {
862 let hint = if matches!(current_native_type, NativeType::Binary) {
863 "\n\nHint: Binary types are not automatically coerced to String. Use CAST(column AS VARCHAR) to convert Binary data to String."
864 } else {
865 ""
866 };
867 return plan_err!(
868 "Function '{function_name}' requires {}, but received {} (DataType: {}).{hint}",
869 param.desired_type(),
870 current_native_type,
871 current_type
872 );
873 }
874 }
875
876 vec![new_types]
877 }
878 TypeSignature::Uniform(number, valid_types) => {
879 if *number == 0 {
880 return plan_err!(
881 "The function '{function_name}' expected at least one argument"
882 );
883 }
884
885 valid_types
886 .iter()
887 .map(|valid_type| vec![valid_type.clone(); *number])
888 .collect()
889 }
890 TypeSignature::UserDefined => {
891 return internal_err!(
892 "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
893 );
894 }
895 TypeSignature::VariadicAny => {
896 if current_types.is_empty() {
897 return plan_err!(
898 "Function '{function_name}' expected at least one argument but received 0"
899 );
900 }
901 vec![current_types.to_vec()]
902 }
903 TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
904 TypeSignature::ArraySignature(function_signature) => match function_signature {
905 ArrayFunctionSignature::Array {
906 arguments,
907 array_coercion,
908 } => array_valid_types(
909 function_name,
910 current_types,
911 arguments,
912 array_coercion.as_ref(),
913 )?,
914 ArrayFunctionSignature::RecursiveArray => {
915 if current_types.len() != 1 {
916 return Ok(vec![vec![]]);
917 }
918 recursive_array(¤t_types[0])
919 .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
920 }
921 ArrayFunctionSignature::MapArray => {
922 if current_types.len() != 1 {
923 return Ok(vec![vec![]]);
924 }
925
926 match ¤t_types[0] {
927 DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
928 _ => vec![vec![]],
929 }
930 }
931 },
932 TypeSignature::Nullary => {
933 if !current_types.is_empty() {
934 return plan_err!(
935 "The function '{function_name}' expected zero argument but received {}",
936 current_types.len()
937 );
938 }
939 vec![vec![]]
940 }
941 TypeSignature::Any(number) => {
942 if current_types.is_empty() {
943 return plan_err!(
944 "The function '{function_name}' expected at least one argument but received 0"
945 );
946 }
947
948 if current_types.len() != *number {
949 return plan_err!(
950 "The function '{function_name}' expected {number} arguments but received {}",
951 current_types.len()
952 );
953 }
954 vec![current_types.to_vec()]
955 }
956 TypeSignature::OneOf(types) => types
957 .iter()
958 .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
959 .flatten()
960 .collect::<Vec<_>>(),
961 };
962
963 Ok(valid_types)
964}
965
966fn maybe_data_types(
973 valid_types: &[DataType],
974 current_types: &[DataType],
975) -> Option<Vec<DataType>> {
976 if valid_types.len() != current_types.len() {
977 return None;
978 }
979
980 let mut new_type = Vec::with_capacity(valid_types.len());
981 for (i, valid_type) in valid_types.iter().enumerate() {
982 let current_type = ¤t_types[i];
983
984 if current_type == valid_type {
985 new_type.push(current_type.clone())
986 } else {
987 if let Some(coerced_type) = coerced_from(valid_type, current_type) {
991 new_type.push(coerced_type)
992 } else {
993 return None;
995 }
996 }
997 }
998 Some(new_type)
999}
1000
1001fn maybe_data_types_without_coercion(
1005 valid_types: &[DataType],
1006 current_types: &[DataType],
1007) -> Option<Vec<DataType>> {
1008 if valid_types.len() != current_types.len() {
1009 return None;
1010 }
1011
1012 let mut new_type = Vec::with_capacity(valid_types.len());
1013 for (i, valid_type) in valid_types.iter().enumerate() {
1014 let current_type = ¤t_types[i];
1015
1016 if current_type == valid_type {
1017 new_type.push(current_type.clone())
1018 } else if can_cast_types(current_type, valid_type) {
1019 new_type.push(valid_type.clone())
1021 } else {
1022 return None;
1023 }
1024 }
1025 Some(new_type)
1026}
1027
1028#[deprecated(since = "53.0.0", note = "Unused internal function")]
1033pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
1034 if type_into == type_from {
1035 return true;
1036 }
1037 if let Some(coerced) = coerced_from(type_into, type_from) {
1038 return coerced == *type_into;
1039 }
1040 false
1041}
1042
/// Determines the type that a value of type `type_from` is coerced to when
/// the target is `type_into`.
///
/// Returns `Some(type)` when the coercion is allowed and `None` otherwise.
/// The returned type is usually `type_into` itself, but for wildcard targets
/// (a `FixedSizeList` with `FIXED_SIZE_LIST_WILDCARD` size, or a `Timestamp`
/// with the `TIMEZONE_WILDCARD` time zone) the wildcard is replaced with a
/// concrete value.
///
/// The arms are evaluated top to bottom and several of them overlap, so their
/// order is significant.
fn coerced_from<'a>(
    type_into: &'a DataType,
    type_from: &'a DataType,
) -> Option<DataType> {
    use self::DataType::*;

    match (type_into, type_from) {
        // A dictionary source is coercible whenever its value type is.
        (_, Dictionary(_, value_type))
            if coerced_from(type_into, value_type).is_some() =>
        {
            Some(type_into.clone())
        }
        // A dictionary target accepts whatever its value type accepts.
        (Dictionary(_, value_type), _)
            if coerced_from(value_type, type_from).is_some() =>
        {
            Some(type_into.clone())
        }
        // Numeric targets: each arm lists the source types that are accepted.
        (Int8, Null | Int8) => Some(type_into.clone()),
        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
            Some(type_into.clone())
        }
        (UInt8, Null | UInt8) => Some(type_into.clone()),
        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
        (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => {
            Some(type_into.clone())
        }
        (
            Float32,
            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
            | Float16 | Float32,
        ) => Some(type_into.clone()),
        // Float64 additionally accepts every decimal width.
        (
            Float64,
            Null
            | Int8
            | Int16
            | Int32
            | Int64
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float16
            | Float32
            | Float64
            | Decimal32(_, _)
            | Decimal64(_, _)
            | Decimal128(_, _)
            | Decimal256(_, _),
        ) => Some(type_into.clone()),
        // Nanosecond timestamp without time zone.
        (
            Timestamp(TimeUnit::Nanosecond, None),
            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
        ) => Some(type_into.clone()),
        (Interval(_), Null | Utf8 | LargeUtf8) => Some(type_into.clone()),
        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
        // Utf8 / LargeUtf8 targets accept any source type.
        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
        (BinaryView, Binary | LargeBinary | Null) => Some(type_into.clone()),
        // Null target: allowed when Arrow can cast the source to Null.
        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),

        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),

        // List-like targets accept sources whose base type is null or whose
        // list nesting depth matches.
        (List(_) | LargeList(_) | ListView(_) | LargeListView(_), _)
            if base_type(type_from).is_null()
                || list_ndims(type_from) == list_ndims(type_into) =>
        {
            Some(type_into.clone())
        }
        // Wildcard-size fixed-size list target: the element type is coerced
        // recursively and the size is taken from the source.
        (
            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
            FixedSizeList(f_from, size_from),
        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
            // The element coercion resolved to a type different from the
            // target field's type: rebuild the field with it.
            Some(data_type) if &data_type != f_into.data_type() => {
                let new_field =
                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
                Some(FixedSizeList(new_field, *size_from))
            }
            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
            _ => None,
        },
        // Wildcard time zone target: adopt the source's time zone if it has
        // one, otherwise use "+00".
        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
            match type_from {
                Timestamp(_, Some(from_tz)) => {
                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
                }
                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
                    Some(Timestamp(*unit, Some("+00".into())))
                }
                _ => None,
            }
        }
        // Timestamp target with a concrete time zone.
        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
            Some(type_into.clone())
        }
        // Fallback for Null sources: allowed whenever Arrow can cast Null to
        // the target type.
        (_, Null) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
        _ => None,
    }
}
1167
1168#[cfg(test)]
1169mod tests {
1170 use crate::{
1171 HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderSignature,
1172 HigherOrderUDFImpl, Volatility,
1173 };
1174
1175 use super::*;
1176 use arrow::datatypes::IntervalUnit;
1177 use datafusion_common::{
1178 assert_contains,
1179 types::{logical_binary, logical_int64},
1180 };
1181 use datafusion_expr_common::{
1182 columnar_value::ColumnarValue,
1183 signature::{Coercion, TypeSignatureClass},
1184 };
1185
1186 #[test]
1187 fn test_string_conversion() {
1188 let cases = vec![
1189 (DataType::Utf8View, DataType::Utf8),
1190 (DataType::Utf8View, DataType::LargeUtf8),
1191 (DataType::Utf8View, DataType::Null),
1192 ];
1193
1194 for case in cases {
1195 assert_eq!(coerced_from(&case.0, &case.1), Some(case.0));
1196 }
1197 }
1198
1199 #[test]
1200 fn test_binary_conversion() {
1201 let cases = vec![
1202 (DataType::BinaryView, DataType::Binary),
1203 (DataType::BinaryView, DataType::LargeBinary),
1204 (DataType::BinaryView, DataType::Null),
1205 ];
1206
1207 for case in cases {
1208 assert_eq!(coerced_from(&case.0, &case.1), Some(case.0));
1209 }
1210 }
1211
1212 #[test]
1213 fn test_coerced_from_null() {
1214 assert_eq!(
1216 coerced_from(
1217 &DataType::Interval(IntervalUnit::MonthDayNano),
1218 &DataType::Null
1219 ),
1220 Some(DataType::Interval(IntervalUnit::MonthDayNano))
1221 );
1222
1223 assert_eq!(
1225 coerced_from(&DataType::Date32, &DataType::Null),
1226 Some(DataType::Date32)
1227 );
1228
1229 assert_eq!(
1231 coerced_from(
1232 &DataType::Timestamp(TimeUnit::Microsecond, Some("+00".into())),
1233 &DataType::Null
1234 ),
1235 Some(DataType::Timestamp(
1236 TimeUnit::Microsecond,
1237 Some("+00".into())
1238 ))
1239 );
1240 }
1241
/// Table-driven check of `maybe_data_types`.
///
/// Every case is `(first argument, second argument, expected result)`.
#[test]
fn test_maybe_data_types() {
    // Nanosecond timestamp with an optional timezone string.
    fn ts_nanos(tz: Option<&str>) -> DataType {
        DataType::Timestamp(TimeUnit::Nanosecond, tz.map(|name| name.into()))
    }

    let cases = vec![
        // Both lists are identical.
        (
            vec![DataType::UInt8, DataType::UInt16],
            vec![DataType::UInt8, DataType::UInt16],
            Some(vec![DataType::UInt8, DataType::UInt16]),
        ),
        // UInt8 in the second list against UInt16 in the first.
        (
            vec![DataType::UInt16, DataType::UInt16],
            vec![DataType::UInt8, DataType::UInt16],
            Some(vec![DataType::UInt16, DataType::UInt16]),
        ),
        // Two empty lists.
        (vec![], vec![], Some(vec![])),
        // Boolean in the first list against UInt8 in the second: no result.
        (
            vec![DataType::Boolean, DataType::UInt16],
            vec![DataType::UInt8, DataType::UInt16],
            None,
        ),
        // UInt16 in the second list against UInt32 in the first.
        (
            vec![DataType::Boolean, DataType::UInt32],
            vec![DataType::Boolean, DataType::UInt16],
            Some(vec![DataType::Boolean, DataType::UInt32]),
        ),
        // Utf8 against timestamps; the "+TZ" timezone comes back as "+00".
        (
            vec![ts_nanos(None), ts_nanos(Some("+TZ")), ts_nanos(Some("+01"))],
            vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
            Some(vec![
                ts_nanos(None),
                ts_nanos(Some("+00")),
                ts_nanos(Some("+01")),
            ]),
        ),
    ];

    for (first, second, expected) in cases {
        assert_eq!(maybe_data_types(&first, &second), expected);
    }
}
1292
/// Exercises `get_valid_types` with `TypeSignature::Numeric` for matching,
/// mixed-width, `Null`, string and timestamp inputs.
#[test]
fn test_get_valid_types_numeric() -> Result<()> {
    // Successful lookup for function "test", flattened into one list.
    fn flattened(signature: &TypeSignature, current_types: &[DataType]) -> Vec<DataType> {
        get_valid_types("test", signature, current_types)
            .unwrap()
            .into_iter()
            .flatten()
            .collect()
    }

    // Error text of a lookup for function "test" that must fail.
    fn failure(signature: &TypeSignature, current_types: &[DataType]) -> String {
        get_valid_types("test", signature, current_types)
            .unwrap_err()
            .to_string()
    }

    // A single Int32 stays Int32.
    assert_eq!(
        flattened(&TypeSignature::Numeric(1), &[DataType::Int32]),
        [DataType::Int32]
    );

    // Int32 + Int64 -> both Int64.
    assert_eq!(
        flattened(
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Int64]
        ),
        [DataType::Int64, DataType::Int64]
    );

    // Int32 + Int64 + Float64 -> all Float64.
    assert_eq!(
        flattened(
            &TypeSignature::Numeric(3),
            &[DataType::Int32, DataType::Int64, DataType::Float64]
        ),
        [DataType::Float64, DataType::Float64, DataType::Float64]
    );

    // A string argument is rejected.
    assert_contains!(
        failure(
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8]
        ),
        "Function 'test' expects Numeric but received String"
    );

    // A lone Null yields Float64.
    assert_eq!(
        flattened(&TypeSignature::Numeric(1), &[DataType::Null]),
        [DataType::Float64]
    );

    // A timestamp argument is rejected.
    assert_contains!(
        failure(
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)]
        ),
        "Function 'test' expects Numeric but received Timestamp(s)"
    );

    Ok(())
}
1367
/// `OneOf(Any(1), Any(2))` yields no candidates for three arguments and
/// exactly one candidate, equal to the input, for two or one arguments.
#[test]
fn test_get_valid_types_one_of() -> Result<()> {
    let one_or_two =
        TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

    let three_args = vec![DataType::Int32; 3];
    let rejected = get_valid_types("test", &one_or_two, &three_args)?;
    assert_eq!(rejected.len(), 0);

    for arity in [2, 1] {
        let args = vec![DataType::Int32; arity];
        let accepted = get_valid_types("test", &one_or_two, &args)?;
        assert_eq!(accepted.len(), 1);
        assert_eq!(accepted[0], args);
    }

    Ok(())
}
1392
/// `Numeric(1)` must reject zero and three arguments with an error that
/// states the expected and received argument counts.
#[test]
fn test_get_valid_types_length_check() -> Result<()> {
    let signature = TypeSignature::Numeric(1);

    let cases = [
        (vec![], "Function 'test' expects 1 arguments but received 0"),
        (
            vec![DataType::Int32; 3],
            "Function 'test' expects 1 arguments but received 3",
        ),
    ];

    for (args, expected_message) in cases {
        let err = get_valid_types("test", &signature, &args).unwrap_err();
        assert_contains!(err.to_string(), expected_message);
    }

    Ok(())
}
1416
1417 struct MockUdf(Signature);
1418
1419 impl UDFCoercionExt for MockUdf {
1420 fn name(&self) -> &str {
1421 "test"
1422 }
1423 fn signature(&self) -> &Signature {
1424 &self.0
1425 }
1426 fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
1427 unimplemented!()
1428 }
1429 }
1430
/// Checks how an exact signature over `FixedSizeList` treats the list
/// length when coercing a single `FixedSizeList(Int32, 2)` field through
/// `fields_with_udf`:
///
/// * `FIXED_SIZE_LIST_WILDCARD` as the signature length leaves the field
///   unchanged,
/// * a different concrete length (3) is an error,
/// * the same concrete length (2) leaves the field unchanged.
#[test]
fn test_fixed_list_wildcard_coerce() -> Result<()> {
    let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
    // The one argument field passed in every scenario below.
    let current_fields = vec![Arc::new(Field::new(
        "t",
        DataType::FixedSizeList(Arc::clone(&inner), 2),
        true,
    ))];

    // Wildcard length in the signature: the input comes back as-is.
    let signature = Signature::exact(
        vec![DataType::FixedSizeList(
            Arc::clone(&inner),
            FIXED_SIZE_LIST_WILDCARD,
        )],
        Volatility::Stable,
    );

    let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature))?;
    assert_eq!(coerced_fields, current_fields);

    // Signature length 3 against an input of length 2: coercion fails.
    let signature = Signature::exact(
        vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
        Volatility::Stable,
    );
    let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature));
    assert!(coerced_fields.is_err());

    // Signature length 2 against an input of length 2: unchanged.
    let signature = Signature::exact(
        vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
        Volatility::Stable,
    );
    let coerced_fields =
        fields_with_udf(&current_fields, &MockUdf(signature)).unwrap();
    assert_eq!(coerced_fields, current_fields);

    Ok(())
}
1471
/// A two-level `FixedSizeList` target with wildcard lengths at both levels
/// takes its concrete lengths (4 inside, 3 outside) from the source while
/// keeping the target's `Int32` element type.
#[test]
fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
    // FixedSizeList<FixedSizeList<element; inner_len>; outer_len>, with both
    // list fields created via `Field::new_list_field(.., false)`.
    let nested = |element: DataType, inner_len: i32, outer_len: i32| {
        let inner = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(element, false)),
            inner_len,
        );
        DataType::FixedSizeList(
            Arc::new(Field::new_list_field(inner, false)),
            outer_len,
        )
    };

    let type_into = nested(
        DataType::Int32,
        FIXED_SIZE_LIST_WILDCARD,
        FIXED_SIZE_LIST_WILDCARD,
    );
    let type_from = nested(DataType::Int8, 4, 3);
    let expected = nested(DataType::Int32, 4, 3);

    assert_eq!(coerced_from(&type_into, &type_from), Some(expected));

    Ok(())
}
1512
/// `Int64` does not coerce into a `Dictionary(Int32, UInt32)` target, but
/// that dictionary does coerce into an `Int64` target.
#[test]
fn test_coerced_from_dictionary() {
    let dictionary =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
    let plain = DataType::Int64;

    // Dictionary as the target, Int64 as the source: rejected.
    assert_eq!(coerced_from(&dictionary, &plain), None);

    // Int64 as the target, dictionary as the source: accepted as Int64.
    assert_eq!(coerced_from(&plain, &dictionary), Some(plain.clone()));
}
1528
/// Table-driven check of a two-array signature built with
/// `ListCoercion::FixedSizedListToList`: each pair of argument types must
/// produce exactly one candidate in which both arguments share one type.
#[test]
fn test_get_valid_types_array_and_array() -> Result<()> {
    let function = "array_and_array";
    let signature = Signature::arrays(
        2,
        Some(ListCoercion::FixedSizedListToList),
        Volatility::Immutable,
    );

    let list = |element: DataType| DataType::new_list(element, true);
    let large_list = |element: DataType| DataType::new_large_list(element, true);
    let fixed = |element: DataType, len: i32| {
        DataType::new_fixed_size_list(element, len, true)
    };
    let list_view =
        || DataType::ListView(Field::new_list_field(DataType::Int32, true).into());
    let large_list_view = || {
        DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into())
    };

    // (argument types, the single type both arguments are expected to take)
    let cases = [
        (
            vec![list(DataType::Int32), large_list(DataType::Float64)],
            large_list(DataType::Float64),
        ),
        (
            vec![fixed(DataType::Int64, 3), fixed(DataType::Int32, 5)],
            list(DataType::Int64),
        ),
        (
            vec![fixed(DataType::Null, 3), large_list(DataType::Utf8)],
            large_list(DataType::Utf8),
        ),
        (
            vec![list_view(), list(DataType::Int32)],
            list(DataType::Int32),
        ),
        (
            vec![large_list_view(), list(DataType::Int32)],
            large_list(DataType::Int32),
        ),
        (vec![list_view(), list_view()], list(DataType::Int32)),
        (
            vec![large_list_view(), large_list_view()],
            large_list(DataType::Int32),
        ),
    ];

    for (data_types, common) in cases {
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![common.clone(), common]]
        );
    }

    Ok(())
}
1624
/// Table-driven check of `Signature::array_and_element`: each
/// `(array, element)` input must produce exactly the listed candidate.
#[test]
fn test_get_valid_types_array_and_element() -> Result<()> {
    let function = "array_and_element";
    let signature = Signature::array_and_element(Volatility::Immutable);

    // (argument types, expected single candidate)
    let cases = [
        (
            vec![DataType::new_list(DataType::Int32, true), DataType::Float64],
            vec![
                DataType::new_list(DataType::Float64, true),
                DataType::Float64,
            ],
        ),
        (
            vec![
                DataType::new_large_list(DataType::Int32, true),
                DataType::Null,
            ],
            vec![
                DataType::new_large_list(DataType::Int32, true),
                DataType::Int32,
            ],
        ),
        (
            vec![
                DataType::new_fixed_size_list(DataType::Null, 3, true),
                DataType::Utf8,
            ],
            vec![DataType::new_list(DataType::Utf8, true), DataType::Utf8],
        ),
    ];

    for (data_types, expected) in cases {
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![expected]
        );
    }

    Ok(())
}
1666
/// `Signature::element_and_array` with a `LargeList<Null>` element and a
/// `List<List<Int64>>` array must produce the single candidate
/// `(LargeList<Int64>, List<LargeList<Int64>>)`.
#[test]
fn test_get_valid_types_element_and_array() -> Result<()> {
    let function = "element_and_array";
    let signature = Signature::element_and_array(Volatility::Immutable);

    let element = DataType::new_large_list(DataType::Null, false);
    let array = DataType::new_list(DataType::new_list(DataType::Int64, true), true);
    let data_types = vec![element, array];

    let coerced_element = DataType::new_large_list(DataType::Int64, true);
    let coerced_array = DataType::new_list(coerced_element.clone(), true);
    let expected = vec![vec![coerced_element, coerced_array]];

    let valid_types =
        get_valid_types(function, &signature.type_signature, &data_types)?;
    assert_eq!(valid_types, expected);

    Ok(())
}
1686
/// Runs a single `Null` field through one-argument coercible signatures.
///
/// With a `Native(int64)` class the result is `Int64`; with the `Integer`
/// class the result stays `Null`, for both exact and implicit coercions.
#[test]
fn test_coercible_nulls() -> Result<()> {
    // Data types `fields_with_udf` returns for one nullable `Null` field.
    fn null_input(coercion: Coercion) -> Result<Vec<DataType>> {
        let udf = MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable));
        let input = [Arc::new(Field::new("field", DataType::Null, true))];
        let coerced = fields_with_udf(&input, &udf)?;
        Ok(coerced
            .iter()
            .map(|field| field.data_type().clone())
            .collect())
    }

    let int64_class = || TypeSignatureClass::Native(logical_int64());

    // (coercion under test, expected output type)
    let cases = [
        (Coercion::new_exact(int64_class()), DataType::Int64),
        (
            Coercion::new_implicit(int64_class(), vec![], NativeType::Int64),
            DataType::Int64,
        ),
        (
            Coercion::new_exact(TypeSignatureClass::Integer),
            DataType::Null,
        ),
        (
            Coercion::new_implicit(
                TypeSignatureClass::Integer,
                vec![],
                NativeType::Int64,
            ),
            DataType::Null,
        ),
    ];

    for (coercion, expected) in cases {
        assert_eq!(vec![expected], null_input(coercion)?);
    }

    Ok(())
}
1723
/// Runs a single `Dictionary(Int8, Int64)` field through one-argument
/// coercible signatures.
///
/// With a `Native(int64)` class the result is `Int64`; with the `Integer`
/// class the dictionary type is returned unchanged, for both exact and
/// implicit coercions.
#[test]
fn test_coercible_dictionary() -> Result<()> {
    let dictionary =
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int64));

    // Data types `fields_with_udf` returns for one nullable dictionary field.
    let dictionary_input = |coercion: Coercion| -> Result<Vec<DataType>> {
        let udf = MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable));
        let input = [Arc::new(Field::new("field", dictionary.clone(), true))];
        let coerced = fields_with_udf(&input, &udf)?;
        Ok(coerced
            .iter()
            .map(|field| field.data_type().clone())
            .collect())
    };

    let int64_class = || TypeSignatureClass::Native(logical_int64());

    // (coercion under test, expected output type)
    let cases = [
        (Coercion::new_exact(int64_class()), DataType::Int64),
        (
            Coercion::new_implicit(int64_class(), vec![], NativeType::Int64),
            DataType::Int64,
        ),
        (
            Coercion::new_exact(TypeSignatureClass::Integer),
            dictionary.clone(),
        ),
        (
            Coercion::new_implicit(
                TypeSignatureClass::Integer,
                vec![],
                NativeType::Int64,
            ),
            dictionary.clone(),
        ),
    ];

    for (coercion, expected) in cases {
        assert_eq!(vec![expected], dictionary_input(coercion)?);
    }

    Ok(())
}
1770
/// Runs a single `RunEndEncoded(Int16 run ends, Int64 values)` field
/// through one-argument coercible signatures.
///
/// With a `Native(int64)` class the result is `Int64`; with the `Integer`
/// class the run-end-encoded type is returned unchanged, for both exact
/// and implicit coercions.
#[test]
fn test_coercible_run_end_encoded() -> Result<()> {
    let run_end_encoded = DataType::RunEndEncoded(
        Field::new("run_ends", DataType::Int16, false).into(),
        Field::new("values", DataType::Int64, true).into(),
    );

    // Data types `fields_with_udf` returns for one nullable REE field.
    let run_end_encoded_input = |coercion: Coercion| -> Result<Vec<DataType>> {
        let udf = MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable));
        let input = [Arc::new(Field::new("field", run_end_encoded.clone(), true))];
        let coerced = fields_with_udf(&input, &udf)?;
        Ok(coerced
            .iter()
            .map(|field| field.data_type().clone())
            .collect())
    };

    let int64_class = || TypeSignatureClass::Native(logical_int64());

    // (coercion under test, expected output type)
    let cases = [
        (Coercion::new_exact(int64_class()), DataType::Int64),
        (
            Coercion::new_implicit(int64_class(), vec![], NativeType::Int64),
            DataType::Int64,
        ),
        (
            Coercion::new_exact(TypeSignatureClass::Integer),
            run_end_encoded.clone(),
        ),
        (
            Coercion::new_implicit(
                TypeSignatureClass::Integer,
                vec![],
                NativeType::Int64,
            ),
            run_end_encoded.clone(),
        ),
    ];

    for (coercion, expected) in cases {
        assert_eq!(vec![expected], run_end_encoded_input(coercion)?);
    }

    Ok(())
}
1820
/// An exact `Native(binary)` coercion must return `Binary`, `BinaryView`
/// and `LargeBinary` inputs unchanged as the only candidate.
#[test]
fn test_get_valid_types_coercible_binary() -> Result<()> {
    let binary_class = TypeSignatureClass::Native(logical_binary());
    let signature = Signature::coercible(
        vec![Coercion::new_exact(binary_class)],
        Volatility::Immutable,
    );

    let binary_types = [
        DataType::Binary,
        DataType::BinaryView,
        DataType::LargeBinary,
    ];

    for binary_type in binary_types {
        let args = [binary_type];
        let valid_types = get_valid_types("", &signature.type_signature, &args)?;
        assert_eq!(valid_types, vec![Vec::from(args)]);
    }

    Ok(())
}
1844
/// Table-driven check of a two-array signature built without a
/// `ListCoercion` (`Signature::arrays(2, None, ..)`).
#[test]
fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
    let function = "fixed_size_arrays";
    let signature = Signature::arrays(2, None, Volatility::Immutable);

    // (argument types, expected single candidate)
    let cases = [
        // Two fixed-size lists: element types unify, lengths are kept.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int32, 5, true),
            ],
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int64, 5, true),
            ],
        ),
        // Fixed-size list + list: both become List<Int64>.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_list(DataType::Int32, true),
            ],
            vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ],
        ),
        // Utf8 elements against a nested Int32 list: an empty candidate.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Utf8, 3, true),
                DataType::new_list(DataType::new_list(DataType::Int32, true), true),
            ],
            vec![],
        ),
        // Same as the second case, with `false` passed for nullability.
        (
            vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, false),
                DataType::new_list(DataType::Int32, false),
            ],
            vec![
                DataType::new_list(DataType::Int64, false),
                DataType::new_list(DataType::Int64, false),
            ],
        ),
    ];

    for (data_types, expected) in cases {
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![expected]
        );
    }

    Ok(())
}
1897
1898 #[derive(Debug, PartialEq, Eq, Hash)]
1899 struct MockHigherOrderUDF {
1900 signature: HigherOrderSignature,
1901 coerced_value_types: Vec<DataType>,
1902 }
1903
1904 impl HigherOrderUDFImpl for MockHigherOrderUDF {
1905 fn name(&self) -> &str {
1906 "mock_higher_order_function"
1907 }
1908
1909 fn signature(&self) -> &HigherOrderSignature {
1910 &self.signature
1911 }
1912
1913 fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1914 if arg_types.len() != 1 {
1915 return plan_err!(
1916 "mock_higher_order_function expects 1 value arguments, got {}",
1917 arg_types.len()
1918 );
1919 }
1920 Ok(self.coerced_value_types.clone())
1921 }
1922
1923 fn coerce_values_for_lambdas(
1924 &self,
1925 fields: &[ValueOrLambda<DataType, DataType>],
1926 ) -> Result<Option<Vec<DataType>>> {
1927 let [
1929 ValueOrLambda::Value(list),
1930 ValueOrLambda::Value(_initial),
1931 ValueOrLambda::Lambda(merge),
1932 ] = fields
1933 else {
1934 unreachable!()
1935 };
1936
1937 Ok(Some(vec![list.clone(), merge.clone()]))
1938 }
1939
1940 fn lambda_parameters(
1941 &self,
1942 _step: usize,
1943 _fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
1944 ) -> Result<crate::LambdaParametersProgress> {
1945 unimplemented!("mock_higher_order_function")
1946 }
1947
1948 fn return_field_from_args(
1949 &self,
1950 _args: HigherOrderReturnFieldArgs,
1951 ) -> Result<FieldRef> {
1952 unimplemented!("mock_higher_order_function")
1953 }
1954
1955 fn invoke_with_args(
1956 &self,
1957 _args: HigherOrderFunctionArgs,
1958 ) -> Result<ColumnarValue> {
1959 unimplemented!("mock_higher_order_function")
1960 }
1961 }
1962
/// With a user-defined signature, the value argument takes the type
/// returned by the mock's `coerce_value_types` (`List<Int32>` becomes
/// `LargeList<Int32>`), and the lambda argument is passed through.
#[test]
fn test_higher_order_function_user_defined_type_coercion() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::user_defined(Volatility::Immutable),
        coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let element = || Field::new_list_field(DataType::Int32, false);
    let args = [
        ValueOrLambda::Value(Arc::new(Field::new_list("", element(), false))),
        ValueOrLambda::Lambda(()),
    ];
    let expected = vec![
        ValueOrLambda::Value(Arc::new(Field::new_large_list("", element(), false))),
        ValueOrLambda::Lambda(()),
    ];

    let new_fields = value_fields_with_higher_order_udf(&args, &fun).unwrap();

    assert_eq!(new_fields, expected);
}
1996
/// With `[List<Float32>, Int32, lambda: Float32]` as input, the second
/// value is expected to come back as `Float32` while the list and the
/// lambda field are unchanged.
#[test]
fn test_higher_order_function_coerce_values_for_lambdas() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
        coerced_value_types: vec![],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    // Nullable, unnamed List<Float32> field.
    let float_list = || {
        Arc::new(Field::new_list(
            "",
            Field::new_list_field(DataType::Float32, true),
            true,
        ))
    };
    // Nullable, unnamed field of the given type.
    let scalar = |data_type: DataType| Arc::new(Field::new("", data_type, true));

    let args = [
        ValueOrLambda::Value(float_list()),
        ValueOrLambda::Value(scalar(DataType::Int32)),
        ValueOrLambda::Lambda(scalar(DataType::Float32)),
    ];
    let expected = vec![
        ValueOrLambda::Value(float_list()),
        ValueOrLambda::Value(scalar(DataType::Float32)),
        ValueOrLambda::Lambda(scalar(DataType::Float32)),
    ];

    let new_fields =
        value_fields_with_higher_order_udf_and_lambdas(&args, &fun).unwrap();

    assert_eq!(new_fields, expected);
}
2032
/// With a user-defined signature and no arguments, the error raised by
/// the mock's `coerce_value_types` must be surfaced.
#[test]
fn test_higher_order_function_user_defined_type_coercion_bad_args() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::user_defined(Volatility::Immutable),
        coerced_value_types: vec![DataType::Int32],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let message = value_fields_with_higher_order_udf::<()>(&[], &fun)
        .unwrap_err()
        .to_string();

    assert_contains!(
        message,
        "mock_higher_order_function expects 1 value arguments, got 0"
    );
}
2047
/// A `coerce_value_types` implementation that returns two types for a
/// single value argument must be reported as an error.
#[test]
fn test_higher_order_function_faulty_user_defined_type_coercion() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::user_defined(Volatility::Immutable),
        coerced_value_types: vec![DataType::Int32, DataType::Int32],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let int_field = Arc::new(Field::new("", DataType::Int32, false));
    let message =
        value_fields_with_higher_order_udf::<()>(&[ValueOrLambda::Value(int_field)], &fun)
            .unwrap_err()
            .to_string();

    assert_contains!(
        message,
        "mock_higher_order_function coerce_value_types should have returned 1 items but returned 2"
    );
}
2070
/// An `any(1)` signature accepts a single lambda argument and returns it
/// unchanged.
#[test]
fn test_higher_order_function_any_signature() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::any(1, Volatility::Immutable),
        coerced_value_types: vec![],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let new_fields =
        value_fields_with_higher_order_udf(&[ValueOrLambda::Lambda(())], &fun)
            .unwrap();

    let expected = vec![ValueOrLambda::Lambda(())];
    assert_eq!(new_fields, expected);
}
2085
/// An `any(1)` signature called with no arguments must fail with an
/// argument-count error.
#[test]
fn test_higher_order_function_any_signature_bad_args() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::any(1, Volatility::Immutable),
        coerced_value_types: vec![],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let message = value_fields_with_higher_order_udf::<()>(&[], &fun)
        .unwrap_err()
        .to_string();

    assert_contains!(
        message,
        "The function 'mock_higher_order_function' expected 1 arguments but received 0"
    );
}
2100
/// With an exact `[value, lambda]` signature, the value argument takes the
/// type returned by the mock's `coerce_value_types` (`List<Int32>` becomes
/// `LargeList<Int32>`), and the lambda argument is passed through.
#[test]
fn test_higher_order_function_exact_signature() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::exact(
            vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
            Volatility::Immutable,
        ),
        coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let element = || Field::new_list_field(DataType::Int32, false);
    let args = [
        ValueOrLambda::Value(Arc::new(Field::new_list("", element(), false))),
        ValueOrLambda::Lambda(()),
    ];
    let expected = vec![
        ValueOrLambda::Value(Arc::new(Field::new_large_list("", element(), false))),
        ValueOrLambda::Lambda(()),
    ];

    let new_fields = value_fields_with_higher_order_udf(&args, &fun).unwrap();

    assert_eq!(new_fields, expected);
}
2137
/// An exact `[value, lambda]` signature called with `[lambda, lambda]`
/// must fail, naming position 0 as the place where a value was expected.
#[test]
fn test_higher_order_function_exact_signature_wrong_value_count() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::exact(
            vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
            Volatility::Immutable,
        ),
        coerced_value_types: vec![],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let message = value_fields_with_higher_order_udf::<()>(
        &[ValueOrLambda::Lambda(()), ValueOrLambda::Lambda(())],
        &fun,
    )
    .unwrap_err()
    .to_string();

    assert_contains!(
        message,
        "expected a value at position 0 but received a lambda"
    );
}
2159
/// An exact `[value, lambda]` signature called with `[value, value]` must
/// fail, naming position 1 as the place where a lambda was expected.
#[test]
fn test_higher_order_function_exact_signature_wrong_lambda_count() {
    let mock = MockHigherOrderUDF {
        signature: HigherOrderSignature::exact(
            vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
            Volatility::Immutable,
        ),
        coerced_value_types: vec![],
    };
    let fun = HigherOrderUDF::new_from_impl(mock);

    let int_field = || Arc::new(Field::new("", DataType::Int32, false));
    let message = value_fields_with_higher_order_udf::<()>(
        &[
            ValueOrLambda::Value(int_field()),
            ValueOrLambda::Value(int_field()),
        ],
        &fun,
    )
    .unwrap_err()
    .to_string();

    assert_contains!(
        message,
        "expected a lambda at position 1 but received a value"
    );
}
2184}