datafusion_expr/type_coercion/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::FieldRef;
21use arrow::{
22    compute::can_cast_types,
23    datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27    base_type, coerced_fixed_size_list_to_list, ListCoercion,
28};
29use datafusion_common::{
30    exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35    signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36    type_coercion::binary::comparison_coercion_numeric,
37    type_coercion::binary::string_coercion,
38};
39use itertools::Itertools as _;
40use std::sync::Arc;
41
42/// Performs type coercion for scalar function arguments.
43///
44/// Returns the data types to which each argument must be coerced to
45/// match `signature`.
46///
47/// For more details on coercion in general, please see the
48/// [`type_coercion`](crate::type_coercion) module.
49pub fn data_types_with_scalar_udf(
50    current_types: &[DataType],
51    func: &ScalarUDF,
52) -> Result<Vec<DataType>> {
53    let signature = func.signature();
54    let type_signature = &signature.type_signature;
55
56    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
57        if type_signature.supports_zero_argument() {
58            return Ok(vec![]);
59        } else if type_signature.used_to_support_zero_arguments() {
60            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
61            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
62        } else {
63            return plan_err!("'{}' does not support zero arguments", func.name());
64        }
65    }
66
67    let valid_types =
68        get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
69
70    if valid_types
71        .iter()
72        .any(|data_type| data_type == current_types)
73    {
74        return Ok(current_types.to_vec());
75    }
76
77    try_coerce_types(func.name(), valid_types, current_types, type_signature)
78}
79
80/// Performs type coercion for aggregate function arguments.
81///
82/// Returns the fields to which each argument must be coerced to
83/// match `signature`.
84///
85/// For more details on coercion in general, please see the
86/// [`type_coercion`](crate::type_coercion) module.
87pub fn fields_with_aggregate_udf(
88    current_fields: &[FieldRef],
89    func: &AggregateUDF,
90) -> Result<Vec<FieldRef>> {
91    let signature = func.signature();
92    let type_signature = &signature.type_signature;
93
94    if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
95        if type_signature.supports_zero_argument() {
96            return Ok(vec![]);
97        } else if type_signature.used_to_support_zero_arguments() {
98            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
99            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
100        } else {
101            return plan_err!("'{}' does not support zero arguments", func.name());
102        }
103    }
104    let current_types = current_fields
105        .iter()
106        .map(|f| f.data_type())
107        .cloned()
108        .collect::<Vec<_>>();
109
110    let valid_types =
111        get_valid_types_with_aggregate_udf(type_signature, &current_types, func)?;
112    if valid_types
113        .iter()
114        .any(|data_type| data_type == &current_types)
115    {
116        return Ok(current_fields.to_vec());
117    }
118
119    let updated_types =
120        try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
121
122    Ok(current_fields
123        .iter()
124        .zip(updated_types)
125        .map(|(current_field, new_type)| {
126            current_field.as_ref().clone().with_data_type(new_type)
127        })
128        .map(Arc::new)
129        .collect())
130}
131
132/// Performs type coercion for window function arguments.
133///
134/// Returns the data types to which each argument must be coerced to
135/// match `signature`.
136///
137/// For more details on coercion in general, please see the
138/// [`type_coercion`](crate::type_coercion) module.
139pub fn fields_with_window_udf(
140    current_fields: &[FieldRef],
141    func: &WindowUDF,
142) -> Result<Vec<FieldRef>> {
143    let signature = func.signature();
144    let type_signature = &signature.type_signature;
145
146    if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
147        if type_signature.supports_zero_argument() {
148            return Ok(vec![]);
149        } else if type_signature.used_to_support_zero_arguments() {
150            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
151            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
152        } else {
153            return plan_err!("'{}' does not support zero arguments", func.name());
154        }
155    }
156
157    let current_types = current_fields
158        .iter()
159        .map(|f| f.data_type())
160        .cloned()
161        .collect::<Vec<_>>();
162    let valid_types =
163        get_valid_types_with_window_udf(type_signature, &current_types, func)?;
164    if valid_types
165        .iter()
166        .any(|data_type| data_type == &current_types)
167    {
168        return Ok(current_fields.to_vec());
169    }
170
171    let updated_types =
172        try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
173
174    Ok(current_fields
175        .iter()
176        .zip(updated_types)
177        .map(|(current_field, new_type)| {
178            current_field.as_ref().clone().with_data_type(new_type)
179        })
180        .map(Arc::new)
181        .collect())
182}
183
184/// Performs type coercion for function arguments.
185///
186/// Returns the data types to which each argument must be coerced to
187/// match `signature`.
188///
189/// For more details on coercion in general, please see the
190/// [`type_coercion`](crate::type_coercion) module.
191pub fn data_types(
192    function_name: impl AsRef<str>,
193    current_types: &[DataType],
194    signature: &Signature,
195) -> Result<Vec<DataType>> {
196    let type_signature = &signature.type_signature;
197
198    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
199        if type_signature.supports_zero_argument() {
200            return Ok(vec![]);
201        } else if type_signature.used_to_support_zero_arguments() {
202            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
203            return plan_err!(
204                "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
205                function_name.as_ref()
206            );
207        } else {
208            return plan_err!(
209                "Function '{}' has signature {type_signature:?} which does not support zero arguments",
210                function_name.as_ref()
211            );
212        }
213    }
214
215    let valid_types =
216        get_valid_types(function_name.as_ref(), type_signature, current_types)?;
217    if valid_types
218        .iter()
219        .any(|data_type| data_type == current_types)
220    {
221        return Ok(current_types.to_vec());
222    }
223
224    try_coerce_types(
225        function_name.as_ref(),
226        valid_types,
227        current_types,
228        type_signature,
229    )
230}
231
232fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
233    if let TypeSignature::OneOf(signatures) = type_signature {
234        return signatures.iter().all(is_well_supported_signature);
235    }
236
237    matches!(
238        type_signature,
239        TypeSignature::UserDefined
240            | TypeSignature::Numeric(_)
241            | TypeSignature::String(_)
242            | TypeSignature::Coercible(_)
243            | TypeSignature::Any(_)
244            | TypeSignature::Nullary
245            | TypeSignature::Comparable(_)
246    )
247}
248
249fn try_coerce_types(
250    function_name: &str,
251    valid_types: Vec<Vec<DataType>>,
252    current_types: &[DataType],
253    type_signature: &TypeSignature,
254) -> Result<Vec<DataType>> {
255    let mut valid_types = valid_types;
256
257    // Well-supported signature that returns exact valid types.
258    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
259        // There may be many valid types if valid signature is OneOf
260        // Otherwise, there should be only one valid type
261        if !type_signature.is_one_of() {
262            assert_eq!(valid_types.len(), 1);
263        }
264
265        let valid_types = valid_types.swap_remove(0);
266        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
267            return Ok(t);
268        }
269    } else {
270        // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
271        // Try and coerce the argument types to match the signature, returning the
272        // coerced types from the first matching signature.
273        for valid_types in valid_types {
274            if let Some(types) = maybe_data_types(&valid_types, current_types) {
275                return Ok(types);
276            }
277        }
278    }
279
280    // none possible -> Error
281    plan_err!(
282        "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed",
283        current_types.iter().join(", ")
284    )
285}
286
287fn get_valid_types_with_scalar_udf(
288    signature: &TypeSignature,
289    current_types: &[DataType],
290    func: &ScalarUDF,
291) -> Result<Vec<Vec<DataType>>> {
292    match signature {
293        TypeSignature::UserDefined => match func.coerce_types(current_types) {
294            Ok(coerced_types) => Ok(vec![coerced_types]),
295            Err(e) => exec_err!(
296                "Function '{}' user-defined coercion failed with {:?}",
297                func.name(),
298                e.strip_backtrace()
299            ),
300        },
301        TypeSignature::OneOf(signatures) => {
302            let mut res = vec![];
303            let mut errors = vec![];
304            for sig in signatures {
305                match get_valid_types_with_scalar_udf(sig, current_types, func) {
306                    Ok(valid_types) => {
307                        res.extend(valid_types);
308                    }
309                    Err(e) => {
310                        errors.push(e.to_string());
311                    }
312                }
313            }
314
315            // Every signature failed, return the joined error
316            if res.is_empty() {
317                internal_err!(
318                    "Function '{}' failed to match any signature, errors: {}",
319                    func.name(),
320                    errors.join(",")
321                )
322            } else {
323                Ok(res)
324            }
325        }
326        _ => get_valid_types(func.name(), signature, current_types),
327    }
328}
329
330fn get_valid_types_with_aggregate_udf(
331    signature: &TypeSignature,
332    current_types: &[DataType],
333    func: &AggregateUDF,
334) -> Result<Vec<Vec<DataType>>> {
335    let valid_types = match signature {
336        TypeSignature::UserDefined => match func.coerce_types(current_types) {
337            Ok(coerced_types) => vec![coerced_types],
338            Err(e) => {
339                return exec_err!(
340                    "Function '{}' user-defined coercion failed with {:?}",
341                    func.name(),
342                    e.strip_backtrace()
343                )
344            }
345        },
346        TypeSignature::OneOf(signatures) => signatures
347            .iter()
348            .filter_map(|t| {
349                get_valid_types_with_aggregate_udf(t, current_types, func).ok()
350            })
351            .flatten()
352            .collect::<Vec<_>>(),
353        _ => get_valid_types(func.name(), signature, current_types)?,
354    };
355
356    Ok(valid_types)
357}
358
359fn get_valid_types_with_window_udf(
360    signature: &TypeSignature,
361    current_types: &[DataType],
362    func: &WindowUDF,
363) -> Result<Vec<Vec<DataType>>> {
364    let valid_types = match signature {
365        TypeSignature::UserDefined => match func.coerce_types(current_types) {
366            Ok(coerced_types) => vec![coerced_types],
367            Err(e) => {
368                return exec_err!(
369                    "Function '{}' user-defined coercion failed with {:?}",
370                    func.name(),
371                    e.strip_backtrace()
372                )
373            }
374        },
375        TypeSignature::OneOf(signatures) => signatures
376            .iter()
377            .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
378            .flatten()
379            .collect::<Vec<_>>(),
380        _ => get_valid_types(func.name(), signature, current_types)?,
381    };
382
383    Ok(valid_types)
384}
385
386/// Returns a Vec of all possible valid argument types for the given signature.
387fn get_valid_types(
388    function_name: &str,
389    signature: &TypeSignature,
390    current_types: &[DataType],
391) -> Result<Vec<Vec<DataType>>> {
392    fn array_valid_types(
393        function_name: &str,
394        current_types: &[DataType],
395        arguments: &[ArrayFunctionArgument],
396        array_coercion: Option<&ListCoercion>,
397    ) -> Result<Vec<Vec<DataType>>> {
398        if current_types.len() != arguments.len() {
399            return Ok(vec![vec![]]);
400        }
401
402        let mut large_list = false;
403        let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
404        let mut list_sizes = Vec::with_capacity(arguments.len());
405        let mut element_types = Vec::with_capacity(arguments.len());
406        let mut nested_item_nullability = Vec::with_capacity(arguments.len());
407        for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
408            match argument {
409                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
410                    nested_item_nullability.push(None);
411                }
412                ArrayFunctionArgument::Element => {
413                    element_types.push(current_type.clone());
414                    nested_item_nullability.push(None);
415                }
416                ArrayFunctionArgument::Array => match current_type {
417                    DataType::Null => {
418                        element_types.push(DataType::Null);
419                        nested_item_nullability.push(None);
420                    }
421                    DataType::List(field) => {
422                        element_types.push(field.data_type().clone());
423                        nested_item_nullability.push(Some(field.is_nullable()));
424                        fixed_size = false;
425                    }
426                    DataType::LargeList(field) => {
427                        element_types.push(field.data_type().clone());
428                        nested_item_nullability.push(Some(field.is_nullable()));
429                        large_list = true;
430                        fixed_size = false;
431                    }
432                    DataType::FixedSizeList(field, size) => {
433                        element_types.push(field.data_type().clone());
434                        nested_item_nullability.push(Some(field.is_nullable()));
435                        list_sizes.push(*size)
436                    }
437                    arg_type => {
438                        plan_err!("{function_name} does not support type {arg_type}")?
439                    }
440                },
441            }
442        }
443
444        debug_assert_eq!(nested_item_nullability.len(), arguments.len());
445
446        let Some(element_type) = type_union_resolution(&element_types) else {
447            return Ok(vec![vec![]]);
448        };
449
450        if !fixed_size {
451            list_sizes.clear()
452        };
453
454        let mut list_sizes = list_sizes.into_iter();
455        let valid_types = arguments
456            .iter()
457            .zip(current_types.iter())
458            .zip(nested_item_nullability)
459            .map(|((argument_type, current_type), is_nested_item_nullable)| {
460                match argument_type {
461                    ArrayFunctionArgument::Index => DataType::Int64,
462                    ArrayFunctionArgument::String => DataType::Utf8,
463                    ArrayFunctionArgument::Element => element_type.clone(),
464                    ArrayFunctionArgument::Array => {
465                        if current_type.is_null() {
466                            DataType::Null
467                        } else if large_list {
468                            DataType::new_large_list(
469                                element_type.clone(),
470                                is_nested_item_nullable.unwrap_or(true),
471                            )
472                        } else if let Some(size) = list_sizes.next() {
473                            DataType::new_fixed_size_list(
474                                element_type.clone(),
475                                size,
476                                is_nested_item_nullable.unwrap_or(true),
477                            )
478                        } else {
479                            DataType::new_list(
480                                element_type.clone(),
481                                is_nested_item_nullable.unwrap_or(true),
482                            )
483                        }
484                    }
485                }
486            });
487
488        Ok(vec![valid_types.collect()])
489    }
490
491    fn recursive_array(array_type: &DataType) -> Option<DataType> {
492        match array_type {
493            DataType::List(_)
494            | DataType::LargeList(_)
495            | DataType::FixedSizeList(_, _) => {
496                let array_type = coerced_fixed_size_list_to_list(array_type);
497                Some(array_type)
498            }
499            _ => None,
500        }
501    }
502
503    fn function_length_check(
504        function_name: &str,
505        length: usize,
506        expected_length: usize,
507    ) -> Result<()> {
508        if length != expected_length {
509            return plan_err!(
510                "Function '{function_name}' expects {expected_length} arguments but received {length}"
511            );
512        }
513        Ok(())
514    }
515
516    let valid_types = match signature {
517        TypeSignature::Variadic(valid_types) => valid_types
518            .iter()
519            .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
520            .collect(),
521        TypeSignature::String(number) => {
522            function_length_check(function_name, current_types.len(), *number)?;
523
524            let mut new_types = Vec::with_capacity(current_types.len());
525            for data_type in current_types.iter() {
526                let logical_data_type: NativeType = data_type.into();
527                if logical_data_type == NativeType::String {
528                    new_types.push(data_type.to_owned());
529                } else if logical_data_type == NativeType::Null {
530                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
531                    new_types.push(DataType::Utf8);
532                } else {
533                    return plan_err!(
534                        "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}"
535                    );
536                }
537            }
538
539            // Find the common string type for the given types
540            fn find_common_type(
541                function_name: &str,
542                lhs_type: &DataType,
543                rhs_type: &DataType,
544            ) -> Result<DataType> {
545                match (lhs_type, rhs_type) {
546                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
547                        find_common_type(function_name, lhs, rhs)
548                    }
549                    (DataType::Dictionary(_, v), other)
550                    | (other, DataType::Dictionary(_, v)) => {
551                        find_common_type(function_name, v, other)
552                    }
553                    _ => {
554                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
555                            Ok(coerced_type)
556                        } else {
557                            plan_err!(
558                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
559                            )
560                        }
561                    }
562                }
563            }
564
565            // Length checked above, safe to unwrap
566            let mut coerced_type = new_types.first().unwrap().to_owned();
567            for t in new_types.iter().skip(1) {
568                coerced_type = find_common_type(function_name, &coerced_type, t)?;
569            }
570
571            fn base_type_or_default_type(data_type: &DataType) -> DataType {
572                if let DataType::Dictionary(_, v) = data_type {
573                    base_type_or_default_type(v)
574                } else {
575                    data_type.to_owned()
576                }
577            }
578
579            vec![vec![base_type_or_default_type(&coerced_type); *number]]
580        }
581        TypeSignature::Numeric(number) => {
582            function_length_check(function_name, current_types.len(), *number)?;
583
584            // Find common numeric type among given types except string
585            let mut valid_type = current_types.first().unwrap().to_owned();
586            for t in current_types.iter().skip(1) {
587                let logical_data_type: NativeType = t.into();
588                if logical_data_type == NativeType::Null {
589                    continue;
590                }
591
592                if !logical_data_type.is_numeric() {
593                    return plan_err!(
594                        "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
595                    );
596                }
597
598                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
599                    valid_type = coerced_type;
600                } else {
601                    return plan_err!(
602                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
603                    );
604                }
605            }
606
607            let logical_data_type: NativeType = valid_type.clone().into();
608            // Fallback to default type if we don't know which type to coerced to
609            // f64 is chosen since most of the math functions utilize Signature::numeric,
610            // and their default type is double precision
611            if logical_data_type == NativeType::Null {
612                valid_type = DataType::Float64;
613            } else if !logical_data_type.is_numeric() {
614                return plan_err!(
615                    "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
616                );
617            }
618
619            vec![vec![valid_type; *number]]
620        }
621        TypeSignature::Comparable(num) => {
622            function_length_check(function_name, current_types.len(), *num)?;
623            let mut target_type = current_types[0].to_owned();
624            for data_type in current_types.iter().skip(1) {
625                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
626                    target_type = dt;
627                } else {
628                    return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
629                }
630            }
631            // Convert null to String type.
632            if target_type.is_null() {
633                vec![vec![DataType::Utf8View; *num]]
634            } else {
635                vec![vec![target_type; *num]]
636            }
637        }
638        TypeSignature::Coercible(param_types) => {
639            function_length_check(function_name, current_types.len(), param_types.len())?;
640
641            let mut new_types = Vec::with_capacity(current_types.len());
642            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
643                let current_native_type: NativeType = current_type.into();
644
645                if param.desired_type().matches_native_type(&current_native_type) {
646                    let casted_type = param.desired_type().default_casted_type(
647                        &current_native_type,
648                        current_type,
649                    )?;
650
651                    new_types.push(casted_type);
652                } else if param
653                .allowed_source_types()
654                .iter()
655                .any(|t| t.matches_native_type(&current_native_type)) {
656                    // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap
657                    let default_casted_type = param.default_casted_type().unwrap();
658                    let casted_type = default_casted_type.default_cast_for(current_type)?;
659                    new_types.push(casted_type);
660                } else {
661                    return internal_err!(
662                        "Expect {} but received NativeType::{}, DataType: {}",
663                        param.desired_type(),
664                        current_native_type,
665                        current_type
666                    );
667                }
668            }
669
670            vec![new_types]
671        }
672        TypeSignature::Uniform(number, valid_types) => {
673            if *number == 0 {
674                return plan_err!("The function '{function_name}' expected at least one argument");
675            }
676
677            valid_types
678                .iter()
679                .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
680                .collect()
681        }
682        TypeSignature::UserDefined => {
683            return internal_err!(
684                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
685            )
686        }
687        TypeSignature::VariadicAny => {
688            if current_types.is_empty() {
689                return plan_err!(
690                    "Function '{function_name}' expected at least one argument but received 0"
691                );
692            }
693            vec![current_types.to_vec()]
694        }
695        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
696        TypeSignature::ArraySignature(ref function_signature) => match function_signature {
697            ArrayFunctionSignature::Array { arguments, array_coercion, } => {
698                array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
699            }
700            ArrayFunctionSignature::RecursiveArray => {
701                if current_types.len() != 1 {
702                    return Ok(vec![vec![]]);
703                }
704                recursive_array(&current_types[0])
705                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
706            }
707            ArrayFunctionSignature::MapArray => {
708                if current_types.len() != 1 {
709                    return Ok(vec![vec![]]);
710                }
711
712                match &current_types[0] {
713                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
714                    _ => vec![vec![]],
715                }
716            }
717        },
718        TypeSignature::Nullary => {
719            if !current_types.is_empty() {
720                return plan_err!(
721                    "The function '{function_name}' expected zero argument but received {}",
722                    current_types.len()
723                );
724            }
725            vec![vec![]]
726        }
727        TypeSignature::Any(number) => {
728            if current_types.is_empty() {
729                return plan_err!(
730                    "The function '{function_name}' expected at least one argument but received 0"
731                );
732            }
733
734            if current_types.len() != *number {
735                return plan_err!(
736                    "The function '{function_name}' expected {number} arguments but received {}",
737                    current_types.len()
738                );
739            }
740            vec![(0..*number).map(|i| current_types[i].clone()).collect()]
741        }
742        TypeSignature::OneOf(types) => types
743            .iter()
744            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
745            .flatten()
746            .collect::<Vec<_>>(),
747    };
748
749    Ok(valid_types)
750}
751
752/// Try to coerce the current argument types to match the given `valid_types`.
753///
754/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
755/// but was called with `(int32, int64)`, this function could match the
756/// valid_types by coercing the first argument to `int64`, and would return
757/// `Some([int64, int64])`.
758fn maybe_data_types(
759    valid_types: &[DataType],
760    current_types: &[DataType],
761) -> Option<Vec<DataType>> {
762    if valid_types.len() != current_types.len() {
763        return None;
764    }
765
766    let mut new_type = Vec::with_capacity(valid_types.len());
767    for (i, valid_type) in valid_types.iter().enumerate() {
768        let current_type = &current_types[i];
769
770        if current_type == valid_type {
771            new_type.push(current_type.clone())
772        } else {
773            // attempt to coerce.
774            // TODO: Replace with `can_cast_types` after failing cases are resolved
775            // (they need new signature that returns exactly valid types instead of list of possible valid types).
776            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
777                new_type.push(coerced_type)
778            } else {
779                // not possible
780                return None;
781            }
782        }
783    }
784    Some(new_type)
785}
786
787/// Check if the current argument types can be coerced to match the given `valid_types`
788/// unlike `maybe_data_types`, this function does not coerce the types.
789/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
790fn maybe_data_types_without_coercion(
791    valid_types: &[DataType],
792    current_types: &[DataType],
793) -> Option<Vec<DataType>> {
794    if valid_types.len() != current_types.len() {
795        return None;
796    }
797
798    let mut new_type = Vec::with_capacity(valid_types.len());
799    for (i, valid_type) in valid_types.iter().enumerate() {
800        let current_type = &current_types[i];
801
802        if current_type == valid_type {
803            new_type.push(current_type.clone())
804        } else if can_cast_types(current_type, valid_type) {
805            // validate the valid type is castable from the current type
806            new_type.push(valid_type.clone())
807        } else {
808            return None;
809        }
810    }
811    Some(new_type)
812}
813
814/// Return true if a value of type `type_from` can be coerced
815/// (losslessly converted) into a value of `type_to`
816///
817/// See the module level documentation for more detail on coercion.
818pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
819    if type_into == type_from {
820        return true;
821    }
822    if let Some(coerced) = coerced_from(type_into, type_from) {
823        return coerced == *type_into;
824    }
825    false
826}
827
828/// Find the coerced type for the given `type_into` and `type_from`.
829/// Returns `None` if coercion is not possible.
830///
831/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
832///
833/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion.
834fn coerced_from<'a>(
835    type_into: &'a DataType,
836    type_from: &'a DataType,
837) -> Option<DataType> {
838    use self::DataType::*;
839
840    // match Dictionary first
841    match (type_into, type_from) {
842        // coerced dictionary first
843        (_, Dictionary(_, value_type))
844            if coerced_from(type_into, value_type).is_some() =>
845        {
846            Some(type_into.clone())
847        }
848        (Dictionary(_, value_type), _)
849            if coerced_from(value_type, type_from).is_some() =>
850        {
851            Some(type_into.clone())
852        }
853        // coerced into type_into
854        (Int8, Null | Int8) => Some(type_into.clone()),
855        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
856        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
857        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
858            Some(type_into.clone())
859        }
860        (UInt8, Null | UInt8) => Some(type_into.clone()),
861        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
862        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
863        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
864        (
865            Float32,
866            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
867            | Float32,
868        ) => Some(type_into.clone()),
869        (
870            Float64,
871            Null
872            | Int8
873            | Int16
874            | Int32
875            | Int64
876            | UInt8
877            | UInt16
878            | UInt32
879            | UInt64
880            | Float32
881            | Float64
882            | Decimal32(_, _)
883            | Decimal64(_, _)
884            | Decimal128(_, _)
885            | Decimal256(_, _),
886        ) => Some(type_into.clone()),
887        (
888            Timestamp(TimeUnit::Nanosecond, None),
889            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
890        ) => Some(type_into.clone()),
891        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
892        // We can go into a Utf8View from a Utf8 or LargeUtf8
893        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
894        // Any type can be coerced into strings
895        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
896        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
897
898        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
899
900        // Only accept list and largelist with the same number of dimensions unless the type is Null.
901        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
902        (List(_) | LargeList(_), _)
903            if base_type(type_from).is_null()
904                || list_ndims(type_from) == list_ndims(type_into) =>
905        {
906            Some(type_into.clone())
907        }
908        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
909        (
910            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
911            FixedSizeList(f_from, size_from),
912        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
913            Some(data_type) if &data_type != f_into.data_type() => {
914                let new_field =
915                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
916                Some(FixedSizeList(new_field, *size_from))
917            }
918            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
919            _ => None,
920        },
921        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
922            match type_from {
923                Timestamp(_, Some(from_tz)) => {
924                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
925                }
926                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
927                    // In the absence of any other information assume the time zone is "+00" (UTC).
928                    Some(Timestamp(*unit, Some("+00".into())))
929                }
930                _ => None,
931            }
932        }
933        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
934            Some(type_into.clone())
935        }
936        _ => None,
937    }
938}
939
#[cfg(test)]
mod tests {
    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::assert_contains;

    /// `Utf8` and `LargeUtf8` are both accepted where `Utf8View` is requested.
    #[test]
    fn test_string_conversion() {
        let cases = vec![
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for (type_into, type_from, expected) in cases {
            assert_eq!(can_coerce_from(&type_into, &type_from), expected);
        }
    }

    /// Table-driven checks of `maybe_data_types`.
    #[test]
    fn test_maybe_data_types() {
        // Each case is: valid types, current types, expected result.
        let cases = vec![
            // 2 entries, identical on both sides
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // 2 entries, the first one is widened from u8 to u16
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // 0 entries, trivially fine
            (vec![], vec![], Some(vec![])),
            // 2 entries, u8 cannot be coerced to boolean
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // current u16 is widened to the expected u32
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // UTF8 -> Timestamp, without a zone, with the wildcard zone and
            // with an explicit zone
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for (valid_types, current_types, expected) in cases {
            assert_eq!(maybe_data_types(&valid_types, &current_types), expected)
        }
    }

    /// Behaviour of `TypeSignature::Numeric` for matching, mixed, null and
    /// non-numeric arguments.
    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        // Resolves the valid types and flattens all candidates into one list.
        fn flattened_valid_types(
            function_name: &str,
            signature: &TypeSignature,
            current_types: &[DataType],
        ) -> Vec<DataType> {
            get_valid_types(function_name, signature, current_types)
                .unwrap()
                .into_iter()
                .flatten()
                .collect::<Vec<_>>()
        }

        // Trivial case.
        let resolved =
            flattened_valid_types("test", &TypeSignature::Numeric(1), &[DataType::Int32]);
        assert_eq!(resolved, [DataType::Int32]);

        // Args are coerced into a common numeric type.
        let resolved = flattened_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Int64],
        );
        assert_eq!(resolved, [DataType::Int64, DataType::Int64]);

        // Args are coerced into a common numeric type, specifically, int would be coerced to float.
        let resolved = flattened_valid_types(
            "test",
            &TypeSignature::Numeric(3),
            &[DataType::Int32, DataType::Int64, DataType::Float64],
        );
        assert_eq!(
            resolved,
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // Cannot coerce args to a common numeric type.
        let error = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            error.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // Falls back to float64 if the arg is of type null.
        let resolved =
            flattened_valid_types("test", &TypeSignature::Numeric(1), &[DataType::Null]);
        assert_eq!(resolved, [DataType::Float64]);

        // Rejects non-numeric arg.
        let error = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            error.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
        );

        Ok(())
    }

    /// `OneOf` yields only the alternatives that accept the given arguments.
    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither alternative.
        let three_args = [DataType::Int32, DataType::Int32, DataType::Int32];
        let no_match = get_valid_types("test", &signature, &three_args)?;
        assert_eq!(no_match.len(), 0);

        // Two arguments match `Any(2)` only.
        let two_args = vec![DataType::Int32, DataType::Int32];
        let matched = get_valid_types("test", &signature, &two_args)?;
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0], two_args);

        // One argument matches `Any(1)` only.
        let one_arg = vec![DataType::Int32];
        let matched = get_valid_types("test", &signature, &one_arg)?;
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0], one_arg);

        Ok(())
    }

    /// A wrong number of arguments is reported with both counts in the message.
    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        let too_few = get_valid_types("test", &signature, &[]).unwrap_err();
        assert_contains!(
            too_few.to_string(),
            "Function 'test' expects 1 arguments but received 0"
        );

        let too_many = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )
        .unwrap_err();
        assert_contains!(
            too_many.to_string(),
            "Function 'test' expects 1 arguments but received 3"
        );

        Ok(())
    }

    /// A wildcard-sized fixed size list in the signature accepts any concrete size.
    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let element = Arc::new(Field::new_list_field(DataType::Int32, false));
        let exact_signature = |size: i32| {
            Signature::exact(
                vec![DataType::FixedSizeList(Arc::clone(&element), size)],
                Volatility::Stable,
            )
        };

        // The argument is a fixed size list of two elements.
        let current_types = vec![DataType::FixedSizeList(Arc::clone(&element), 2)];

        // The wildcard is able to coerce for any size.
        let wildcard = exact_signature(FIXED_SIZE_LIST_WILDCARD);
        let coerced = data_types("test", &current_types, &wildcard)?;
        assert_eq!(coerced, current_types);

        // make sure it can't coerce to a different size
        let other_size = exact_signature(3);
        let coerced = data_types("test", &current_types, &other_size);
        assert!(coerced.is_err());

        // make sure it works with the same type.
        let same_size = exact_signature(2);
        let coerced = data_types("test", &current_types, &same_size).unwrap();
        assert_eq!(coerced, current_types);

        Ok(())
    }

    /// Wildcard sizes are resolved at every nesting level, and the element
    /// type of the target is kept.
    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        // Builds FixedSizeList<FixedSizeList<element; inner_len>; outer_len>
        // with non-nullable list fields.
        let nested = |element: DataType, inner_len: i32, outer_len: i32| {
            DataType::FixedSizeList(
                Arc::new(Field::new_list_field(
                    DataType::FixedSizeList(
                        Arc::new(Field::new_list_field(element, false)),
                        inner_len,
                    ),
                    false,
                )),
                outer_len,
            )
        };

        let type_into = nested(
            DataType::Int32,
            FIXED_SIZE_LIST_WILDCARD,
            FIXED_SIZE_LIST_WILDCARD,
        );
        let type_from = nested(DataType::Int8, 4, 3);

        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(nested(DataType::Int32, 4, 3))
        );

        Ok(())
    }

    /// Dictionary handling depends on which side the dictionary is on.
    #[test]
    fn test_coerced_from_dictionary() {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let int64 = DataType::Int64;

        // Int64 cannot be coerced into a dictionary of UInt32 values.
        assert_eq!(coerced_from(&dictionary, &int64), None);

        // A dictionary of UInt32 values can be coerced into Int64.
        assert_eq!(coerced_from(&int64, &dictionary), Some(int64.clone()));
    }

    /// Two array arguments, with fixed size lists coerced to lists.
    #[test]
    fn test_get_valid_types_array_and_array() -> Result<()> {
        let function = "array_and_array";
        let signature = Signature::arrays(
            2,
            Some(ListCoercion::FixedSizedListToList),
            Volatility::Immutable,
        );

        // Each case is: argument types, expected valid types.
        let cases = vec![
            (
                vec![
                    DataType::new_list(DataType::Int32, true),
                    DataType::new_large_list(DataType::Float64, true),
                ],
                vec![vec![
                    DataType::new_large_list(DataType::Float64, true),
                    DataType::new_large_list(DataType::Float64, true),
                ]],
            ),
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_fixed_size_list(DataType::Int32, 5, true),
                ],
                vec![vec![
                    DataType::new_list(DataType::Int64, true),
                    DataType::new_list(DataType::Int64, true),
                ]],
            ),
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Null, 3, true),
                    DataType::new_large_list(DataType::Utf8, true),
                ],
                vec![vec![
                    DataType::new_large_list(DataType::Utf8, true),
                    DataType::new_large_list(DataType::Utf8, true),
                ]],
            ),
        ];

        for (arg_types, expected) in cases {
            assert_eq!(
                get_valid_types(function, &signature.type_signature, &arg_types)?,
                expected
            );
        }

        Ok(())
    }

    /// An array followed by a single element.
    #[test]
    fn test_get_valid_types_array_and_element() -> Result<()> {
        let function = "array_and_element";
        let signature = Signature::array_and_element(Volatility::Immutable);

        // Each case is: argument types, expected valid types.
        let cases = vec![
            (
                vec![DataType::new_list(DataType::Int32, true), DataType::Float64],
                vec![vec![
                    DataType::new_list(DataType::Float64, true),
                    DataType::Float64,
                ]],
            ),
            (
                vec![
                    DataType::new_large_list(DataType::Int32, true),
                    DataType::Null,
                ],
                vec![vec![
                    DataType::new_large_list(DataType::Int32, true),
                    DataType::Int32,
                ]],
            ),
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Null, 3, true),
                    DataType::Utf8,
                ],
                vec![vec![
                    DataType::new_list(DataType::Utf8, true),
                    DataType::Utf8,
                ]],
            ),
        ];

        for (arg_types, expected) in cases {
            assert_eq!(
                get_valid_types(function, &signature.type_signature, &arg_types)?,
                expected
            );
        }

        Ok(())
    }

    /// A single element followed by an array.
    #[test]
    fn test_get_valid_types_element_and_array() -> Result<()> {
        let function = "element_and_array";
        let signature = Signature::element_and_array(Volatility::Immutable);

        let arg_types = vec![
            DataType::new_large_list(DataType::Null, false),
            DataType::new_list(DataType::new_list(DataType::Int64, true), true),
        ];
        let expected = vec![vec![
            DataType::new_large_list(DataType::Int64, true),
            DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
        ]];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &arg_types)?,
            expected
        );

        Ok(())
    }

    /// Two array arguments without any list coercion requested.
    #[test]
    fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
        let function = "fixed_size_arrays";
        let signature = Signature::arrays(2, None, Volatility::Immutable);

        // Each case is: argument types, expected valid types.
        let cases = vec![
            // Two fixed size lists keep their own sizes.
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_fixed_size_list(DataType::Int32, 5, true),
                ],
                vec![vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_fixed_size_list(DataType::Int64, 5, true),
                ]],
            ),
            // A fixed size list next to a list resolves to lists.
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_list(DataType::Int32, true),
                ],
                vec![vec![
                    DataType::new_list(DataType::Int64, true),
                    DataType::new_list(DataType::Int64, true),
                ]],
            ),
            // Mismatched nesting produces a single empty candidate.
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Utf8, 3, true),
                    DataType::new_list(DataType::new_list(DataType::Int32, true), true),
                ],
                vec![vec![]],
            ),
            // Same as the second case, with non-nullable elements.
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, false),
                    DataType::new_list(DataType::Int32, false),
                ],
                vec![vec![
                    DataType::new_list(DataType::Int64, false),
                    DataType::new_list(DataType::Int64, false),
                ]],
            ),
        ];

        for (arg_types, expected) in cases {
            assert_eq!(
                get_valid_types(function, &signature.type_signature, &arg_types)?,
                expected
            );
        }

        Ok(())
    }
}