datafusion_expr/type_coercion/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::binary::{binary_numeric_coercion, comparison_coercion};
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::{
21    compute::can_cast_types,
22    datatypes::{DataType, Field, TimeUnit},
23};
24use datafusion_common::types::LogicalType;
25use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
26use datafusion_common::{
27    exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType,
28    utils::list_ndims, Result,
29};
30use datafusion_expr_common::signature::ArrayFunctionArgument;
31use datafusion_expr_common::{
32    signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
33    type_coercion::binary::comparison_coercion_numeric,
34    type_coercion::binary::string_coercion,
35};
36use std::sync::Arc;
37
38/// Performs type coercion for scalar function arguments.
39///
40/// Returns the data types to which each argument must be coerced to
41/// match `signature`.
42///
43/// For more details on coercion in general, please see the
44/// [`type_coercion`](crate::type_coercion) module.
45pub fn data_types_with_scalar_udf(
46    current_types: &[DataType],
47    func: &ScalarUDF,
48) -> Result<Vec<DataType>> {
49    let signature = func.signature();
50    let type_signature = &signature.type_signature;
51
52    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
53        if type_signature.supports_zero_argument() {
54            return Ok(vec![]);
55        } else if type_signature.used_to_support_zero_arguments() {
56            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
57            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
58        } else {
59            return plan_err!("'{}' does not support zero arguments", func.name());
60        }
61    }
62
63    let valid_types =
64        get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
65
66    if valid_types
67        .iter()
68        .any(|data_type| data_type == current_types)
69    {
70        return Ok(current_types.to_vec());
71    }
72
73    try_coerce_types(func.name(), valid_types, current_types, type_signature)
74}
75
76/// Performs type coercion for aggregate function arguments.
77///
78/// Returns the data types to which each argument must be coerced to
79/// match `signature`.
80///
81/// For more details on coercion in general, please see the
82/// [`type_coercion`](crate::type_coercion) module.
83pub fn data_types_with_aggregate_udf(
84    current_types: &[DataType],
85    func: &AggregateUDF,
86) -> Result<Vec<DataType>> {
87    let signature = func.signature();
88    let type_signature = &signature.type_signature;
89
90    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
91        if type_signature.supports_zero_argument() {
92            return Ok(vec![]);
93        } else if type_signature.used_to_support_zero_arguments() {
94            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
95            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
96        } else {
97            return plan_err!("'{}' does not support zero arguments", func.name());
98        }
99    }
100
101    let valid_types =
102        get_valid_types_with_aggregate_udf(type_signature, current_types, func)?;
103    if valid_types
104        .iter()
105        .any(|data_type| data_type == current_types)
106    {
107        return Ok(current_types.to_vec());
108    }
109
110    try_coerce_types(func.name(), valid_types, current_types, type_signature)
111}
112
113/// Performs type coercion for window function arguments.
114///
115/// Returns the data types to which each argument must be coerced to
116/// match `signature`.
117///
118/// For more details on coercion in general, please see the
119/// [`type_coercion`](crate::type_coercion) module.
120pub fn data_types_with_window_udf(
121    current_types: &[DataType],
122    func: &WindowUDF,
123) -> Result<Vec<DataType>> {
124    let signature = func.signature();
125    let type_signature = &signature.type_signature;
126
127    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
128        if type_signature.supports_zero_argument() {
129            return Ok(vec![]);
130        } else if type_signature.used_to_support_zero_arguments() {
131            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
132            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
133        } else {
134            return plan_err!("'{}' does not support zero arguments", func.name());
135        }
136    }
137
138    let valid_types =
139        get_valid_types_with_window_udf(type_signature, current_types, func)?;
140    if valid_types
141        .iter()
142        .any(|data_type| data_type == current_types)
143    {
144        return Ok(current_types.to_vec());
145    }
146
147    try_coerce_types(func.name(), valid_types, current_types, type_signature)
148}
149
150/// Performs type coercion for function arguments.
151///
152/// Returns the data types to which each argument must be coerced to
153/// match `signature`.
154///
155/// For more details on coercion in general, please see the
156/// [`type_coercion`](crate::type_coercion) module.
157pub fn data_types(
158    function_name: impl AsRef<str>,
159    current_types: &[DataType],
160    signature: &Signature,
161) -> Result<Vec<DataType>> {
162    let type_signature = &signature.type_signature;
163
164    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
165        if type_signature.supports_zero_argument() {
166            return Ok(vec![]);
167        } else if type_signature.used_to_support_zero_arguments() {
168            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
169            return plan_err!(
170                "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
171                function_name.as_ref()
172            );
173        } else {
174            return plan_err!(
175                "Function '{}' has signature {type_signature:?} which does not support zero arguments",
176                function_name.as_ref()
177            );
178        }
179    }
180
181    let valid_types =
182        get_valid_types(function_name.as_ref(), type_signature, current_types)?;
183    if valid_types
184        .iter()
185        .any(|data_type| data_type == current_types)
186    {
187        return Ok(current_types.to_vec());
188    }
189
190    try_coerce_types(
191        function_name.as_ref(),
192        valid_types,
193        current_types,
194        type_signature,
195    )
196}
197
198fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
199    if let TypeSignature::OneOf(signatures) = type_signature {
200        return signatures.iter().all(is_well_supported_signature);
201    }
202
203    matches!(
204        type_signature,
205        TypeSignature::UserDefined
206            | TypeSignature::Numeric(_)
207            | TypeSignature::String(_)
208            | TypeSignature::Coercible(_)
209            | TypeSignature::Any(_)
210            | TypeSignature::Nullary
211            | TypeSignature::Comparable(_)
212    )
213}
214
215fn try_coerce_types(
216    function_name: &str,
217    valid_types: Vec<Vec<DataType>>,
218    current_types: &[DataType],
219    type_signature: &TypeSignature,
220) -> Result<Vec<DataType>> {
221    let mut valid_types = valid_types;
222
223    // Well-supported signature that returns exact valid types.
224    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
225        // There may be many valid types if valid signature is OneOf
226        // Otherwise, there should be only one valid type
227        if !type_signature.is_one_of() {
228            assert_eq!(valid_types.len(), 1);
229        }
230
231        let valid_types = valid_types.swap_remove(0);
232        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
233            return Ok(t);
234        }
235    } else {
236        // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
237        // Try and coerce the argument types to match the signature, returning the
238        // coerced types from the first matching signature.
239        for valid_types in valid_types {
240            if let Some(types) = maybe_data_types(&valid_types, current_types) {
241                return Ok(types);
242            }
243        }
244    }
245
246    // none possible -> Error
247    plan_err!(
248        "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {current_types:?} to the signature {type_signature:?} failed"
249    )
250}
251
252fn get_valid_types_with_scalar_udf(
253    signature: &TypeSignature,
254    current_types: &[DataType],
255    func: &ScalarUDF,
256) -> Result<Vec<Vec<DataType>>> {
257    match signature {
258        TypeSignature::UserDefined => match func.coerce_types(current_types) {
259            Ok(coerced_types) => Ok(vec![coerced_types]),
260            Err(e) => exec_err!(
261                "Function '{}' user-defined coercion failed with {:?}",
262                func.name(),
263                e.strip_backtrace()
264            ),
265        },
266        TypeSignature::OneOf(signatures) => {
267            let mut res = vec![];
268            let mut errors = vec![];
269            for sig in signatures {
270                match get_valid_types_with_scalar_udf(sig, current_types, func) {
271                    Ok(valid_types) => {
272                        res.extend(valid_types);
273                    }
274                    Err(e) => {
275                        errors.push(e.to_string());
276                    }
277                }
278            }
279
280            // Every signature failed, return the joined error
281            if res.is_empty() {
282                internal_err!(
283                    "Function '{}' failed to match any signature, errors: {}",
284                    func.name(),
285                    errors.join(",")
286                )
287            } else {
288                Ok(res)
289            }
290        }
291        _ => get_valid_types(func.name(), signature, current_types),
292    }
293}
294
295fn get_valid_types_with_aggregate_udf(
296    signature: &TypeSignature,
297    current_types: &[DataType],
298    func: &AggregateUDF,
299) -> Result<Vec<Vec<DataType>>> {
300    let valid_types = match signature {
301        TypeSignature::UserDefined => match func.coerce_types(current_types) {
302            Ok(coerced_types) => vec![coerced_types],
303            Err(e) => {
304                return exec_err!(
305                    "Function '{}' user-defined coercion failed with {:?}",
306                    func.name(),
307                    e.strip_backtrace()
308                )
309            }
310        },
311        TypeSignature::OneOf(signatures) => signatures
312            .iter()
313            .filter_map(|t| {
314                get_valid_types_with_aggregate_udf(t, current_types, func).ok()
315            })
316            .flatten()
317            .collect::<Vec<_>>(),
318        _ => get_valid_types(func.name(), signature, current_types)?,
319    };
320
321    Ok(valid_types)
322}
323
324fn get_valid_types_with_window_udf(
325    signature: &TypeSignature,
326    current_types: &[DataType],
327    func: &WindowUDF,
328) -> Result<Vec<Vec<DataType>>> {
329    let valid_types = match signature {
330        TypeSignature::UserDefined => match func.coerce_types(current_types) {
331            Ok(coerced_types) => vec![coerced_types],
332            Err(e) => {
333                return exec_err!(
334                    "Function '{}' user-defined coercion failed with {:?}",
335                    func.name(),
336                    e.strip_backtrace()
337                )
338            }
339        },
340        TypeSignature::OneOf(signatures) => signatures
341            .iter()
342            .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
343            .flatten()
344            .collect::<Vec<_>>(),
345        _ => get_valid_types(func.name(), signature, current_types)?,
346    };
347
348    Ok(valid_types)
349}
350
/// Returns a Vec of all possible valid argument types for the given signature.
///
/// Each inner `Vec` is one candidate list of argument types. Depending on the
/// signature variant the candidates are either taken verbatim from the
/// signature (e.g. `Exact`, `Uniform`, `Variadic`) or computed from
/// `current_types` (e.g. `String`, `Numeric`, `Comparable`, `Coercible`,
/// array signatures).
///
/// `function_name` is only used to build error messages.
///
/// # Errors
///
/// * Plan errors for argument-count mismatches and for argument types that
///   the signature cannot accept (see the individual match arms).
/// * An internal error for `TypeSignature::UserDefined`, which this function
///   does not handle.
///
/// Note that for `TypeSignature::OneOf`, alternatives that return an error
/// are skipped rather than reported.
fn get_valid_types(
    function_name: &str,
    signature: &TypeSignature,
    current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
    /// Computes the single candidate for an `ArrayFunctionSignature::Array`
    /// signature, where `arguments` describes the role of every parameter.
    ///
    /// Returns `vec![vec![]]` (one empty candidate) when the argument count
    /// differs from `arguments`, when an `Array` argument is not array-like
    /// (see `array` below), or when the array arguments do not coerce to the
    /// same type. Returns an internal error when `arguments` contains no
    /// `Array` entry or when the element types have no common coerced type.
    fn array_valid_types(
        function_name: &str,
        current_types: &[DataType],
        arguments: &[ArrayFunctionArgument],
        array_coercion: Option<&ListCoercion>,
    ) -> Result<Vec<Vec<DataType>>> {
        if current_types.len() != arguments.len() {
            return Ok(vec![vec![]]);
        }

        // Position of the first `Array` parameter; its argument type seeds the
        // base-type computation below.
        let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| {
            if *arg == ArrayFunctionArgument::Array {
                Some(idx)
            } else {
                None
            }
        });
        let Some(array_idx) = array_idx else {
            return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument"));
        };
        let Some(array_type) = array(&current_types[array_idx]) else {
            return Ok(vec![vec![]]);
        };

        // We need to find the coerced base type, mainly for cases like:
        // `array_append(List(null), i64)` -> `List(i64)`
        let mut new_base_type = datafusion_common::utils::base_type(&array_type);
        for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
            match argument_type {
                // Both element and array arguments contribute to the common base type.
                ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => {
                    new_base_type =
                        coerce_array_types(function_name, current_type, &new_base_type)?;
                }
                // Index and string arguments do not influence the base type.
                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {}
            }
        }
        let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
            &array_type,
            &new_base_type,
            array_coercion,
        );

        // The element type of the coerced array is what `Element` arguments
        // are expected to be.
        let new_elem_type = match new_array_type {
            DataType::List(ref field)
            | DataType::LargeList(ref field)
            | DataType::FixedSizeList(ref field, _) => field.data_type(),
            _ => return Ok(vec![vec![]]),
        };

        // Build the expected type of every argument according to its role.
        let mut valid_types = Vec::with_capacity(arguments.len());
        for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
            let valid_type = match argument_type {
                ArrayFunctionArgument::Element => new_elem_type.clone(),
                ArrayFunctionArgument::Index => DataType::Int64,
                ArrayFunctionArgument::String => DataType::Utf8,
                ArrayFunctionArgument::Array => {
                    let Some(current_type) = array(current_type) else {
                        return Ok(vec![vec![]]);
                    };
                    let new_type =
                        datafusion_common::utils::coerced_type_with_base_type_only(
                            &current_type,
                            &new_base_type,
                            array_coercion,
                        );
                    // All array arguments must be coercible to the same type
                    if new_type != new_array_type {
                        return Ok(vec![vec![]]);
                    }
                    new_type
                }
            };
            valid_types.push(valid_type);
        }

        Ok(vec![valid_types])
    }

    /// Normalizes an argument type to the array type used for coercion:
    /// `List`/`LargeList` are kept as is, `FixedSizeList` becomes a `List`
    /// over the same field, and `Null` becomes a `List` of `Int64`.
    /// Any other type yields `None`.
    fn array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
            DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
            DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field(
                DataType::Int64,
                true,
            )))),
            _ => None,
        }
    }

    /// Combines `base_type` with the base type of `current_type` via
    /// `comparison_coercion`, returning an internal error when the two have
    /// no common coerced type.
    fn coerce_array_types(
        function_name: &str,
        current_type: &DataType,
        base_type: &DataType,
    ) -> Result<DataType> {
        let current_base_type = datafusion_common::utils::base_type(current_type);
        let new_base_type = comparison_coercion(base_type, &current_base_type);
        new_base_type.ok_or_else(|| {
            internal_datafusion_err!(
                "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}"
            )
        })
    }

    /// For list-like types (`List`, `LargeList`, `FixedSizeList`) returns the
    /// result of `coerced_fixed_size_list_to_list`; any other type yields `None`.
    fn recursive_array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_)
            | DataType::LargeList(_)
            | DataType::FixedSizeList(_, _) => {
                let array_type = coerced_fixed_size_list_to_list(array_type);
                Some(array_type)
            }
            _ => None,
        }
    }

    /// Returns a plan error unless `length` equals `expected_length`.
    fn function_length_check(
        function_name: &str,
        length: usize,
        expected_length: usize,
    ) -> Result<()> {
        if length != expected_length {
            return plan_err!(
                "Function '{function_name}' expects {expected_length} arguments but received {length}"
            );
        }
        Ok(())
    }

    let valid_types = match signature {
        // One candidate per allowed type, repeated once for every argument passed.
        TypeSignature::Variadic(valid_types) => valid_types
            .iter()
            .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
            .collect(),
        // Exactly `number` string arguments, all coerced to one common string type.
        TypeSignature::String(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // Keep string-typed arguments, replace nulls by `Utf8`, and reject
            // everything else.
            let mut new_types = Vec::with_capacity(current_types.len());
            for data_type in current_types.iter() {
                let logical_data_type: NativeType = data_type.into();
                if logical_data_type == NativeType::String {
                    new_types.push(data_type.to_owned());
                } else if logical_data_type == NativeType::Null {
                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
                    new_types.push(DataType::Utf8);
                } else {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
                    );
                }
            }

            // Find the common string type for the given types
            // (dictionary types are compared through their value types).
            fn find_common_type(
                function_name: &str,
                lhs_type: &DataType,
                rhs_type: &DataType,
            ) -> Result<DataType> {
                match (lhs_type, rhs_type) {
                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
                        find_common_type(function_name, lhs, rhs)
                    }
                    (DataType::Dictionary(_, v), other)
                    | (other, DataType::Dictionary(_, v)) => {
                        find_common_type(function_name, v, other)
                    }
                    _ => {
                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
                            Ok(coerced_type)
                        } else {
                            plan_err!(
                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
                            )
                        }
                    }
                }
            }

            // Length checked above, safe to unwrap
            // NOTE(review): this relies on `number` being at least 1 (or on callers
            // never passing an empty `current_types`) — confirm.
            let mut coerced_type = new_types.first().unwrap().to_owned();
            for t in new_types.iter().skip(1) {
                coerced_type = find_common_type(function_name, &coerced_type, t)?;
            }

            // Unwraps (possibly nested) dictionary types down to their value type.
            fn base_type_or_default_type(data_type: &DataType) -> DataType {
                if let DataType::Dictionary(_, v) = data_type {
                    base_type_or_default_type(v)
                } else {
                    data_type.to_owned()
                }
            }

            vec![vec![base_type_or_default_type(&coerced_type); *number]]
        }
        // Exactly `number` numeric arguments, all coerced to one common numeric type.
        TypeSignature::Numeric(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // Find common numeric type among given types except string
            // NOTE(review): `unwrap` relies on `current_types` being non-empty here — confirm.
            let mut valid_type = current_types.first().unwrap().to_owned();
            for t in current_types.iter().skip(1) {
                let logical_data_type: NativeType = t.into();
                // Null arguments do not take part in finding the common type.
                if logical_data_type == NativeType::Null {
                    continue;
                }

                if !logical_data_type.is_numeric() {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                    );
                }

                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
                    valid_type = coerced_type;
                } else {
                    return plan_err!(
                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
                    );
                }
            }

            let logical_data_type: NativeType = valid_type.clone().into();
            // Fallback to default type if we don't know which type to coerced to
            // f64 is chosen since most of the math functions utilize Signature::numeric,
            // and their default type is double precision
            if logical_data_type == NativeType::Null {
                valid_type = DataType::Float64;
            } else if !logical_data_type.is_numeric() {
                // The first argument is not checked inside the loop above, so a
                // non-numeric result is rejected here.
                return plan_err!(
                    "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                );
            }

            vec![vec![valid_type; *number]]
        }
        // Exactly `num` arguments, all coerced to one common comparable type.
        TypeSignature::Comparable(num) => {
            function_length_check(function_name, current_types.len(), *num)?;
            // NOTE(review): indexing relies on `current_types` being non-empty here — confirm.
            let mut target_type = current_types[0].to_owned();
            for data_type in current_types.iter().skip(1) {
                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
                    target_type = dt;
                } else {
                    return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
                }
            }
            // Convert null to String type.
            if target_type.is_null() {
                vec![vec![DataType::Utf8View; *num]]
            } else {
                vec![vec![target_type; *num]]
            }
        }
        // One argument per parameter; each argument must match the parameter's
        // desired type or one of its allowed source types.
        TypeSignature::Coercible(param_types) => {
            function_length_check(function_name, current_types.len(), param_types.len())?;

            let mut new_types = Vec::with_capacity(current_types.len());
            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
                let current_native_type: NativeType = current_type.into();

                if param.desired_type().matches_native_type(&current_native_type) {
                    let casted_type = param.desired_type().default_casted_type(
                        &current_native_type,
                        current_type,
                    )?;

                    new_types.push(casted_type);
                } else if param
                .allowed_source_types()
                .iter()
                .any(|t| t.matches_native_type(&current_native_type)) {
                    // If the condition is met which means `implicit coercion` is provided so we can safely unwrap
                    let default_casted_type = param.default_casted_type().unwrap();
                    let casted_type = default_casted_type.default_cast_for(current_type)?;
                    new_types.push(casted_type);
                } else {
                    return internal_err!(
                        "Expect {} but received {}, DataType: {}",
                        param.desired_type(),
                        current_native_type,
                        current_type
                    );
                }
            }

            vec![new_types]
        }
        // One candidate per allowed type, repeated `number` times. The number of
        // arguments actually passed is not checked in this arm.
        TypeSignature::Uniform(number, valid_types) => {
            if *number == 0 {
                return plan_err!("The function '{function_name}' expected at least one argument");
            }

            valid_types
                .iter()
                .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
                .collect()
        }
        // User-defined signatures are not resolved by this function.
        TypeSignature::UserDefined => {
            return internal_err!(
                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
            )
        }
        // Any non-empty argument list is accepted unchanged.
        TypeSignature::VariadicAny => {
            if current_types.is_empty() {
                return plan_err!(
                    "Function '{function_name}' expected at least one argument but received 0"
                );
            }
            vec![current_types.to_vec()]
        }
        // The signature lists the one and only candidate.
        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
        TypeSignature::ArraySignature(ref function_signature) => match function_signature {
            ArrayFunctionSignature::Array { arguments, array_coercion, } => {
                array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
            }
            // A single list-like argument; otherwise one empty candidate.
            ArrayFunctionSignature::RecursiveArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }
                recursive_array(&current_types[0])
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
            }
            // A single `Map` argument, kept unchanged; otherwise one empty candidate.
            ArrayFunctionSignature::MapArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }

                match &current_types[0] {
                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
                    _ => vec![vec![]],
                }
            }
        },
        // Only an empty argument list is valid.
        TypeSignature::Nullary => {
            if !current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected zero argument but received {}",
                    current_types.len()
                );
            }
            vec![vec![]]
        }
        // Exactly `number` arguments of any type, accepted unchanged.
        TypeSignature::Any(number) => {
            if current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected at least one argument but received 0"
                );
            }

            if current_types.len() != *number {
                return plan_err!(
                    "The function '{function_name}' expected {number} arguments but received {}",
                    current_types.len()
                );
            }
            vec![(0..*number).map(|i| current_types[i].clone()).collect()]
        }
        // Union of the candidates of every alternative; alternatives that
        // return an error are skipped.
        TypeSignature::OneOf(types) => types
            .iter()
            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
            .flatten()
            .collect::<Vec<_>>(),
    };

    Ok(valid_types)
}
721
722/// Try to coerce the current argument types to match the given `valid_types`.
723///
724/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
725/// but was called with `(int32, int64)`, this function could match the
726/// valid_types by coercing the first argument to `int64`, and would return
727/// `Some([int64, int64])`.
728fn maybe_data_types(
729    valid_types: &[DataType],
730    current_types: &[DataType],
731) -> Option<Vec<DataType>> {
732    if valid_types.len() != current_types.len() {
733        return None;
734    }
735
736    let mut new_type = Vec::with_capacity(valid_types.len());
737    for (i, valid_type) in valid_types.iter().enumerate() {
738        let current_type = &current_types[i];
739
740        if current_type == valid_type {
741            new_type.push(current_type.clone())
742        } else {
743            // attempt to coerce.
744            // TODO: Replace with `can_cast_types` after failing cases are resolved
745            // (they need new signature that returns exactly valid types instead of list of possible valid types).
746            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
747                new_type.push(coerced_type)
748            } else {
749                // not possible
750                return None;
751            }
752        }
753    }
754    Some(new_type)
755}
756
757/// Check if the current argument types can be coerced to match the given `valid_types`
758/// unlike `maybe_data_types`, this function does not coerce the types.
759/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
760fn maybe_data_types_without_coercion(
761    valid_types: &[DataType],
762    current_types: &[DataType],
763) -> Option<Vec<DataType>> {
764    if valid_types.len() != current_types.len() {
765        return None;
766    }
767
768    let mut new_type = Vec::with_capacity(valid_types.len());
769    for (i, valid_type) in valid_types.iter().enumerate() {
770        let current_type = &current_types[i];
771
772        if current_type == valid_type {
773            new_type.push(current_type.clone())
774        } else if can_cast_types(current_type, valid_type) {
775            // validate the valid type is castable from the current type
776            new_type.push(valid_type.clone())
777        } else {
778            return None;
779        }
780    }
781    Some(new_type)
782}
783
784/// Return true if a value of type `type_from` can be coerced
785/// (losslessly converted) into a value of `type_to`
786///
787/// See the module level documentation for more detail on coercion.
788pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
789    if type_into == type_from {
790        return true;
791    }
792    if let Some(coerced) = coerced_from(type_into, type_from) {
793        return coerced == *type_into;
794    }
795    false
796}
797
798/// Find the coerced type for the given `type_into` and `type_from`.
799/// Returns `None` if coercion is not possible.
800///
801/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
802///
803/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion.
804fn coerced_from<'a>(
805    type_into: &'a DataType,
806    type_from: &'a DataType,
807) -> Option<DataType> {
808    use self::DataType::*;
809
810    // match Dictionary first
811    match (type_into, type_from) {
812        // coerced dictionary first
813        (_, Dictionary(_, value_type))
814            if coerced_from(type_into, value_type).is_some() =>
815        {
816            Some(type_into.clone())
817        }
818        (Dictionary(_, value_type), _)
819            if coerced_from(value_type, type_from).is_some() =>
820        {
821            Some(type_into.clone())
822        }
823        // coerced into type_into
824        (Int8, Null | Int8) => Some(type_into.clone()),
825        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
826        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
827        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
828            Some(type_into.clone())
829        }
830        (UInt8, Null | UInt8) => Some(type_into.clone()),
831        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
832        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
833        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
834        (
835            Float32,
836            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
837            | Float32,
838        ) => Some(type_into.clone()),
839        (
840            Float64,
841            Null
842            | Int8
843            | Int16
844            | Int32
845            | Int64
846            | UInt8
847            | UInt16
848            | UInt32
849            | UInt64
850            | Float32
851            | Float64
852            | Decimal128(_, _),
853        ) => Some(type_into.clone()),
854        (
855            Timestamp(TimeUnit::Nanosecond, None),
856            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
857        ) => Some(type_into.clone()),
858        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
859        // We can go into a Utf8View from a Utf8 or LargeUtf8
860        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
861        // Any type can be coerced into strings
862        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
863        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
864
865        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
866
867        // Only accept list and largelist with the same number of dimensions unless the type is Null.
868        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
869        (List(_) | LargeList(_), _)
870            if datafusion_common::utils::base_type(type_from).eq(&Null)
871                || list_ndims(type_from) == list_ndims(type_into) =>
872        {
873            Some(type_into.clone())
874        }
875        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
876        (
877            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
878            FixedSizeList(f_from, size_from),
879        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
880            Some(data_type) if &data_type != f_into.data_type() => {
881                let new_field =
882                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
883                Some(FixedSizeList(new_field, *size_from))
884            }
885            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
886            _ => None,
887        },
888        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
889            match type_from {
890                Timestamp(_, Some(from_tz)) => {
891                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
892                }
893                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
894                    // In the absence of any other information assume the time zone is "+00" (UTC).
895                    Some(Timestamp(*unit, Some("+00".into())))
896                }
897                _ => None,
898            }
899        }
900        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
901            Some(type_into.clone())
902        }
903        _ => None,
904    }
905}
906
#[cfg(test)]
mod tests {

    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::assert_contains;

    #[test]
    fn test_string_conversion() {
        // Each case is (type_into, type_from, expected `can_coerce_from` result).
        let cases = [
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for (type_into, type_from, expected) in &cases {
            assert_eq!(can_coerce_from(type_into, type_from), *expected);
        }
    }

    #[test]
    fn test_maybe_data_types() {
        // Each case is (valid_types, current_types, expected result).
        let cases: Vec<(Vec<DataType>, Vec<DataType>, Option<Vec<DataType>>)> = vec![
            // 2 entries, same values
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // 2 entries, can coerce values
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // 0 entries, all good
            (vec![], vec![], Some(vec![])),
            // 2 entries, can't coerce
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // u16 -> u32 is possible
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // UTF8 -> Timestamp; the wildcard time zone "+TZ" resolves to "+00"
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for (valid_types, current_types, expected) in cases {
            assert_eq!(maybe_data_types(&valid_types, &current_types), expected)
        }
    }

    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        /// Runs `get_valid_types`, panicking on error, and flattens all
        /// returned candidate type lists into a single vector.
        fn flattened_valid_types(
            function_name: &str,
            signature: &TypeSignature,
            current_types: &[DataType],
        ) -> Vec<DataType> {
            get_valid_types(function_name, signature, current_types)
                .unwrap()
                .into_iter()
                .flatten()
                .collect()
        }

        // Trivial case.
        let single = flattened_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Int32],
        );
        assert_eq!(single, [DataType::Int32]);

        // Args are coerced into a common numeric type.
        let widened = flattened_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Int64],
        );
        assert_eq!(widened, [DataType::Int64, DataType::Int64]);

        // Args are coerced into a common numeric type, specifically, int would be coerced to float.
        let as_float = flattened_valid_types(
            "test",
            &TypeSignature::Numeric(3),
            &[DataType::Int32, DataType::Int64, DataType::Float64],
        );
        assert_eq!(
            as_float,
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // Cannot coerce args to a common numeric type.
        let mixed_err = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            mixed_err.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // Fallbacks to float64 if the arg is of type null.
        let from_null = flattened_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Null],
        );
        assert_eq!(from_null, [DataType::Float64]);

        // Rejects non-numeric arg.
        let timestamp_err = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            timestamp_err.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither alternative, so no candidates come back.
        let three_args = [DataType::Int32, DataType::Int32, DataType::Int32];
        let no_candidates = get_valid_types("test", &signature, &three_args)?;
        assert_eq!(no_candidates.len(), 0);

        // Two arguments match exactly one alternative.
        let two_args = vec![DataType::Int32, DataType::Int32];
        let candidates = get_valid_types("test", &signature, &two_args)?;
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0], two_args);

        // One argument matches exactly one alternative.
        let one_arg = vec![DataType::Int32];
        let candidates = get_valid_types("test", &signature, &one_arg)?;
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0], one_arg);

        Ok(())
    }

    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        // Each case is (current_types, expected fragment of the error message).
        let cases: Vec<(Vec<DataType>, &str)> = vec![
            (vec![], "Function 'test' expects 1 arguments but received 0"),
            (
                vec![DataType::Int32, DataType::Int32, DataType::Int32],
                "Function 'test' expects 1 arguments but received 3",
            ),
        ];

        for (current_types, expected_message) in cases {
            let err = get_valid_types("test", &signature, &current_types).unwrap_err();
            assert_contains!(err.to_string(), expected_message);
        }

        Ok(())
    }

    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
        let current_types = vec![
            DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size
        ];

        // Builds an exact signature taking one fixed size list of `size` elements.
        let exact_fixed_size_list = |size: i32| {
            Signature::exact(
                vec![DataType::FixedSizeList(Arc::clone(&inner), size)],
                Volatility::Stable,
            )
        };

        // The wildcard size accepts the concrete size of the argument.
        let wildcard = exact_fixed_size_list(FIXED_SIZE_LIST_WILDCARD);
        let coerced_data_types = data_types("test", &current_types, &wildcard)?;
        assert_eq!(coerced_data_types, current_types);

        // make sure it can't coerce to a different size
        let other_size = exact_fixed_size_list(3);
        let coerced_data_types = data_types("test", &current_types, &other_size);
        assert!(coerced_data_types.is_err());

        // make sure it works with the same type.
        let same_size = exact_fixed_size_list(2);
        let coerced_data_types = data_types("test", &current_types, &same_size).unwrap();
        assert_eq!(coerced_data_types, current_types);

        Ok(())
    }

    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        /// Builds a two-level fixed size list of non-nullable `element` values
        /// with the given inner and outer sizes.
        fn nested_fixed_size_list(
            element: DataType,
            inner_size: i32,
            outer_size: i32,
        ) -> DataType {
            let inner = DataType::FixedSizeList(
                Arc::new(Field::new_list_field(element, false)),
                inner_size,
            );
            DataType::FixedSizeList(
                Arc::new(Field::new_list_field(inner, false)),
                outer_size,
            )
        }

        let type_into = nested_fixed_size_list(
            DataType::Int32,
            FIXED_SIZE_LIST_WILDCARD,
            FIXED_SIZE_LIST_WILDCARD,
        );
        let type_from = nested_fixed_size_list(DataType::Int8, 4, 3);

        // Both wildcard sizes are taken from the source, while the element
        // type stays the one requested by the target.
        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(nested_fixed_size_list(DataType::Int32, 4, 3))
        );

        Ok(())
    }

    #[test]
    fn test_coerced_from_dictionary() {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let int64 = DataType::Int64;

        // Int64 cannot be coerced into the dictionary's UInt32 values.
        assert_eq!(coerced_from(&dictionary, &int64), None);

        // The dictionary's UInt32 values can be coerced into Int64.
        assert_eq!(coerced_from(&int64, &dictionary), Some(int64.clone()));
    }
}