datafusion_expr/type_coercion/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::FieldRef;
21use arrow::{
22    compute::can_cast_types,
23    datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27    base_type, coerced_fixed_size_list_to_list, ListCoercion,
28};
29use datafusion_common::{
30    exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35    signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36    type_coercion::binary::comparison_coercion_numeric,
37    type_coercion::binary::string_coercion,
38};
39use std::sync::Arc;
40
41/// Performs type coercion for scalar function arguments.
42///
43/// Returns the data types to which each argument must be coerced to
44/// match `signature`.
45///
46/// For more details on coercion in general, please see the
47/// [`type_coercion`](crate::type_coercion) module.
48pub fn data_types_with_scalar_udf(
49    current_types: &[DataType],
50    func: &ScalarUDF,
51) -> Result<Vec<DataType>> {
52    let signature = func.signature();
53    let type_signature = &signature.type_signature;
54
55    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
56        if type_signature.supports_zero_argument() {
57            return Ok(vec![]);
58        } else if type_signature.used_to_support_zero_arguments() {
59            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
60            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
61        } else {
62            return plan_err!("'{}' does not support zero arguments", func.name());
63        }
64    }
65
66    let valid_types =
67        get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
68
69    if valid_types
70        .iter()
71        .any(|data_type| data_type == current_types)
72    {
73        return Ok(current_types.to_vec());
74    }
75
76    try_coerce_types(func.name(), valid_types, current_types, type_signature)
77}
78
79/// Performs type coercion for aggregate function arguments.
80///
81/// Returns the fields to which each argument must be coerced to
82/// match `signature`.
83///
84/// For more details on coercion in general, please see the
85/// [`type_coercion`](crate::type_coercion) module.
86pub fn fields_with_aggregate_udf(
87    current_fields: &[FieldRef],
88    func: &AggregateUDF,
89) -> Result<Vec<FieldRef>> {
90    let signature = func.signature();
91    let type_signature = &signature.type_signature;
92
93    if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
94        if type_signature.supports_zero_argument() {
95            return Ok(vec![]);
96        } else if type_signature.used_to_support_zero_arguments() {
97            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
98            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
99        } else {
100            return plan_err!("'{}' does not support zero arguments", func.name());
101        }
102    }
103    let current_types = current_fields
104        .iter()
105        .map(|f| f.data_type())
106        .cloned()
107        .collect::<Vec<_>>();
108
109    let valid_types =
110        get_valid_types_with_aggregate_udf(type_signature, &current_types, func)?;
111    if valid_types
112        .iter()
113        .any(|data_type| data_type == &current_types)
114    {
115        return Ok(current_fields.to_vec());
116    }
117
118    let updated_types =
119        try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
120
121    Ok(current_fields
122        .iter()
123        .zip(updated_types)
124        .map(|(current_field, new_type)| {
125            current_field.as_ref().clone().with_data_type(new_type)
126        })
127        .map(Arc::new)
128        .collect())
129}
130
131/// Performs type coercion for window function arguments.
132///
133/// Returns the data types to which each argument must be coerced to
134/// match `signature`.
135///
136/// For more details on coercion in general, please see the
137/// [`type_coercion`](crate::type_coercion) module.
138pub fn fields_with_window_udf(
139    current_fields: &[FieldRef],
140    func: &WindowUDF,
141) -> Result<Vec<FieldRef>> {
142    let signature = func.signature();
143    let type_signature = &signature.type_signature;
144
145    if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
146        if type_signature.supports_zero_argument() {
147            return Ok(vec![]);
148        } else if type_signature.used_to_support_zero_arguments() {
149            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
150            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
151        } else {
152            return plan_err!("'{}' does not support zero arguments", func.name());
153        }
154    }
155
156    let current_types = current_fields
157        .iter()
158        .map(|f| f.data_type())
159        .cloned()
160        .collect::<Vec<_>>();
161    let valid_types =
162        get_valid_types_with_window_udf(type_signature, &current_types, func)?;
163    if valid_types
164        .iter()
165        .any(|data_type| data_type == &current_types)
166    {
167        return Ok(current_fields.to_vec());
168    }
169
170    let updated_types =
171        try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
172
173    Ok(current_fields
174        .iter()
175        .zip(updated_types)
176        .map(|(current_field, new_type)| {
177            current_field.as_ref().clone().with_data_type(new_type)
178        })
179        .map(Arc::new)
180        .collect())
181}
182
183/// Performs type coercion for function arguments.
184///
185/// Returns the data types to which each argument must be coerced to
186/// match `signature`.
187///
188/// For more details on coercion in general, please see the
189/// [`type_coercion`](crate::type_coercion) module.
190pub fn data_types(
191    function_name: impl AsRef<str>,
192    current_types: &[DataType],
193    signature: &Signature,
194) -> Result<Vec<DataType>> {
195    let type_signature = &signature.type_signature;
196
197    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
198        if type_signature.supports_zero_argument() {
199            return Ok(vec![]);
200        } else if type_signature.used_to_support_zero_arguments() {
201            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
202            return plan_err!(
203                "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
204                function_name.as_ref()
205            );
206        } else {
207            return plan_err!(
208                "Function '{}' has signature {type_signature:?} which does not support zero arguments",
209                function_name.as_ref()
210            );
211        }
212    }
213
214    let valid_types =
215        get_valid_types(function_name.as_ref(), type_signature, current_types)?;
216    if valid_types
217        .iter()
218        .any(|data_type| data_type == current_types)
219    {
220        return Ok(current_types.to_vec());
221    }
222
223    try_coerce_types(
224        function_name.as_ref(),
225        valid_types,
226        current_types,
227        type_signature,
228    )
229}
230
231fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
232    if let TypeSignature::OneOf(signatures) = type_signature {
233        return signatures.iter().all(is_well_supported_signature);
234    }
235
236    matches!(
237        type_signature,
238        TypeSignature::UserDefined
239            | TypeSignature::Numeric(_)
240            | TypeSignature::String(_)
241            | TypeSignature::Coercible(_)
242            | TypeSignature::Any(_)
243            | TypeSignature::Nullary
244            | TypeSignature::Comparable(_)
245    )
246}
247
248fn try_coerce_types(
249    function_name: &str,
250    valid_types: Vec<Vec<DataType>>,
251    current_types: &[DataType],
252    type_signature: &TypeSignature,
253) -> Result<Vec<DataType>> {
254    let mut valid_types = valid_types;
255
256    // Well-supported signature that returns exact valid types.
257    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
258        // There may be many valid types if valid signature is OneOf
259        // Otherwise, there should be only one valid type
260        if !type_signature.is_one_of() {
261            assert_eq!(valid_types.len(), 1);
262        }
263
264        let valid_types = valid_types.swap_remove(0);
265        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
266            return Ok(t);
267        }
268    } else {
269        // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
270        // Try and coerce the argument types to match the signature, returning the
271        // coerced types from the first matching signature.
272        for valid_types in valid_types {
273            if let Some(types) = maybe_data_types(&valid_types, current_types) {
274                return Ok(types);
275            }
276        }
277    }
278
279    // none possible -> Error
280    plan_err!(
281        "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {current_types:?} to the signature {type_signature:?} failed"
282    )
283}
284
285fn get_valid_types_with_scalar_udf(
286    signature: &TypeSignature,
287    current_types: &[DataType],
288    func: &ScalarUDF,
289) -> Result<Vec<Vec<DataType>>> {
290    match signature {
291        TypeSignature::UserDefined => match func.coerce_types(current_types) {
292            Ok(coerced_types) => Ok(vec![coerced_types]),
293            Err(e) => exec_err!(
294                "Function '{}' user-defined coercion failed with {:?}",
295                func.name(),
296                e.strip_backtrace()
297            ),
298        },
299        TypeSignature::OneOf(signatures) => {
300            let mut res = vec![];
301            let mut errors = vec![];
302            for sig in signatures {
303                match get_valid_types_with_scalar_udf(sig, current_types, func) {
304                    Ok(valid_types) => {
305                        res.extend(valid_types);
306                    }
307                    Err(e) => {
308                        errors.push(e.to_string());
309                    }
310                }
311            }
312
313            // Every signature failed, return the joined error
314            if res.is_empty() {
315                internal_err!(
316                    "Function '{}' failed to match any signature, errors: {}",
317                    func.name(),
318                    errors.join(",")
319                )
320            } else {
321                Ok(res)
322            }
323        }
324        _ => get_valid_types(func.name(), signature, current_types),
325    }
326}
327
328fn get_valid_types_with_aggregate_udf(
329    signature: &TypeSignature,
330    current_types: &[DataType],
331    func: &AggregateUDF,
332) -> Result<Vec<Vec<DataType>>> {
333    let valid_types = match signature {
334        TypeSignature::UserDefined => match func.coerce_types(current_types) {
335            Ok(coerced_types) => vec![coerced_types],
336            Err(e) => {
337                return exec_err!(
338                    "Function '{}' user-defined coercion failed with {:?}",
339                    func.name(),
340                    e.strip_backtrace()
341                )
342            }
343        },
344        TypeSignature::OneOf(signatures) => signatures
345            .iter()
346            .filter_map(|t| {
347                get_valid_types_with_aggregate_udf(t, current_types, func).ok()
348            })
349            .flatten()
350            .collect::<Vec<_>>(),
351        _ => get_valid_types(func.name(), signature, current_types)?,
352    };
353
354    Ok(valid_types)
355}
356
357fn get_valid_types_with_window_udf(
358    signature: &TypeSignature,
359    current_types: &[DataType],
360    func: &WindowUDF,
361) -> Result<Vec<Vec<DataType>>> {
362    let valid_types = match signature {
363        TypeSignature::UserDefined => match func.coerce_types(current_types) {
364            Ok(coerced_types) => vec![coerced_types],
365            Err(e) => {
366                return exec_err!(
367                    "Function '{}' user-defined coercion failed with {:?}",
368                    func.name(),
369                    e.strip_backtrace()
370                )
371            }
372        },
373        TypeSignature::OneOf(signatures) => signatures
374            .iter()
375            .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
376            .flatten()
377            .collect::<Vec<_>>(),
378        _ => get_valid_types(func.name(), signature, current_types)?,
379    };
380
381    Ok(valid_types)
382}
383
/// Returns a Vec of all possible valid argument types for the given signature.
///
/// Each inner `Vec<DataType>` is one candidate list of argument types. What
/// is produced depends on the signature variant:
///
/// * `Variadic` / `Uniform`: one candidate per listed type, with that type
///   repeated (once per current argument for `Variadic`, `number` times for
///   `Uniform`).
/// * `String` / `Numeric` / `Comparable`: a single candidate in which one
///   common type, derived from `current_types`, is repeated for every
///   argument.
/// * `Coercible`: a single candidate computed argument by argument from the
///   parameter descriptions.
/// * `Exact`: the listed types as the only candidate.
/// * `VariadicAny` / `Any`: `current_types` itself as the only candidate.
/// * `Nullary`: a single empty candidate.
/// * `ArraySignature`: a single candidate, which is empty when the arguments
///   do not fit the array signature.
/// * `OneOf`: the concatenation of the candidates of every nested signature
///   that resolves; nested signatures that fail are skipped.
///
/// # Errors
///
/// Returns a plan error when the number or kind of arguments does not fit
/// the signature, and an internal error for `UserDefined` (which must be
/// handled by the caller) and for a `Coercible` argument that matches
/// neither the desired type nor any allowed source type.
fn get_valid_types(
    function_name: &str,
    signature: &TypeSignature,
    current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
    /// Computes the candidate types for `ArrayFunctionSignature::Array`.
    ///
    /// Returns a single empty candidate when the argument count does not
    /// match `arguments` or when no common element type can be found.
    fn array_valid_types(
        function_name: &str,
        current_types: &[DataType],
        arguments: &[ArrayFunctionArgument],
        array_coercion: Option<&ListCoercion>,
    ) -> Result<Vec<Vec<DataType>>> {
        if current_types.len() != arguments.len() {
            return Ok(vec![vec![]]);
        }

        // `large_list` becomes true once any array argument is a LargeList.
        let mut large_list = false;
        // `fixed_size` starts out true unless FixedSizedListToList coercion was
        // requested, and is switched off as soon as a List or LargeList
        // argument is seen.
        let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
        // Sizes of the FixedSizeList arguments, in argument order.
        let mut list_sizes = Vec::with_capacity(arguments.len());
        // Element types of all array arguments plus the types of all
        // `Element` arguments; these are unified below.
        let mut element_types = Vec::with_capacity(arguments.len());
        for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
            match argument {
                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (),
                ArrayFunctionArgument::Element => {
                    element_types.push(current_type.clone())
                }
                ArrayFunctionArgument::Array => match current_type {
                    DataType::Null => element_types.push(DataType::Null),
                    DataType::List(field) => {
                        element_types.push(field.data_type().clone());
                        fixed_size = false;
                    }
                    DataType::LargeList(field) => {
                        element_types.push(field.data_type().clone());
                        large_list = true;
                        fixed_size = false;
                    }
                    DataType::FixedSizeList(field, size) => {
                        element_types.push(field.data_type().clone());
                        list_sizes.push(*size)
                    }
                    arg_type => {
                        plan_err!("{function_name} does not support type {arg_type}")?
                    }
                },
            }
        }

        // A single element type shared by every array and element argument.
        let Some(element_type) = type_union_resolution(&element_types) else {
            return Ok(vec![vec![]]);
        };

        // Without `fixed_size` no argument keeps its fixed-size list type.
        if !fixed_size {
            list_sizes.clear()
        }

        // The sizes are handed out in order, one per non-null array argument,
        // by the lazily evaluated closure below.
        let mut list_sizes = list_sizes.into_iter();
        let valid_types = arguments.iter().zip(current_types.iter()).map(
            |(argument_type, current_type)| match argument_type {
                ArrayFunctionArgument::Index => DataType::Int64,
                ArrayFunctionArgument::String => DataType::Utf8,
                ArrayFunctionArgument::Element => element_type.clone(),
                ArrayFunctionArgument::Array => {
                    // Null arrays stay Null; otherwise LargeList takes
                    // precedence over FixedSizeList, which takes precedence
                    // over List.
                    if current_type.is_null() {
                        DataType::Null
                    } else if large_list {
                        DataType::new_large_list(element_type.clone(), true)
                    } else if let Some(size) = list_sizes.next() {
                        DataType::new_fixed_size_list(element_type.clone(), size, true)
                    } else {
                        DataType::new_list(element_type.clone(), true)
                    }
                }
            },
        );

        Ok(vec![valid_types.collect()])
    }

    /// Returns the result of `coerced_fixed_size_list_to_list` for list-like
    /// types and `None` for every other type.
    fn recursive_array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_)
            | DataType::LargeList(_)
            | DataType::FixedSizeList(_, _) => {
                let array_type = coerced_fixed_size_list_to_list(array_type);
                Some(array_type)
            }
            _ => None,
        }
    }

    /// Returns a plan error unless `length` equals `expected_length`.
    fn function_length_check(
        function_name: &str,
        length: usize,
        expected_length: usize,
    ) -> Result<()> {
        if length != expected_length {
            return plan_err!(
                "Function '{function_name}' expects {expected_length} arguments but received {length}"
            );
        }
        Ok(())
    }

    let valid_types = match signature {
        // One candidate per allowed type, repeated once per current argument.
        TypeSignature::Variadic(valid_types) => valid_types
            .iter()
            .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
            .collect(),
        TypeSignature::String(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // Keep string-like types as they are, map Null to Utf8 and reject
            // everything else.
            let mut new_types = Vec::with_capacity(current_types.len());
            for data_type in current_types.iter() {
                let logical_data_type: NativeType = data_type.into();
                if logical_data_type == NativeType::String {
                    new_types.push(data_type.to_owned());
                } else if logical_data_type == NativeType::Null {
                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
                    new_types.push(DataType::Utf8);
                } else {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
                    );
                }
            }

            // Find the common string type for the given types.
            // Dictionary types are compared through their value types.
            fn find_common_type(
                function_name: &str,
                lhs_type: &DataType,
                rhs_type: &DataType,
            ) -> Result<DataType> {
                match (lhs_type, rhs_type) {
                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
                        find_common_type(function_name, lhs, rhs)
                    }
                    (DataType::Dictionary(_, v), other)
                    | (other, DataType::Dictionary(_, v)) => {
                        find_common_type(function_name, v, other)
                    }
                    _ => {
                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
                            Ok(coerced_type)
                        } else {
                            plan_err!(
                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
                            )
                        }
                    }
                }
            }

            // Length checked above, safe to unwrap
            // NOTE(review): this holds only when `number` is non-zero; an
            // empty `current_types` with `String(0)` would panic here.
            let mut coerced_type = new_types.first().unwrap().to_owned();
            for t in new_types.iter().skip(1) {
                coerced_type = find_common_type(function_name, &coerced_type, t)?;
            }

            // Strips any (possibly nested) dictionary wrapper and returns the
            // innermost value type.
            fn base_type_or_default_type(data_type: &DataType) -> DataType {
                if let DataType::Dictionary(_, v) = data_type {
                    base_type_or_default_type(v)
                } else {
                    data_type.to_owned()
                }
            }

            vec![vec![base_type_or_default_type(&coerced_type); *number]]
        }
        TypeSignature::Numeric(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // Find common numeric type among given types except string
            // NOTE(review): `unwrap` relies on `number` being non-zero.
            let mut valid_type = current_types.first().unwrap().to_owned();
            for t in current_types.iter().skip(1) {
                let logical_data_type: NativeType = t.into();
                // Null arguments do not influence the common type.
                if logical_data_type == NativeType::Null {
                    continue;
                }

                if !logical_data_type.is_numeric() {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                    );
                }

                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
                    valid_type = coerced_type;
                } else {
                    return plan_err!(
                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
                    );
                }
            }

            // The first argument was not validated inside the loop, so the
            // final type is checked here.
            let logical_data_type: NativeType = valid_type.clone().into();
            // Fallback to default type if we don't know which type to coerced to
            // f64 is chosen since most of the math functions utilize Signature::numeric,
            // and their default type is double precision
            if logical_data_type == NativeType::Null {
                valid_type = DataType::Float64;
            } else if !logical_data_type.is_numeric() {
                return plan_err!(
                    "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                );
            }

            vec![vec![valid_type; *number]]
        }
        TypeSignature::Comparable(num) => {
            function_length_check(function_name, current_types.len(), *num)?;
            // Fold all argument types into one comparable type.
            // NOTE(review): indexing `[0]` relies on `num` being non-zero.
            let mut target_type = current_types[0].to_owned();
            for data_type in current_types.iter().skip(1) {
                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
                    target_type = dt;
                } else {
                    return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
                }
            }
            // Convert null to String type.
            if target_type.is_null() {
                vec![vec![DataType::Utf8View; *num]]
            } else {
                vec![vec![target_type; *num]]
            }
        }
        TypeSignature::Coercible(param_types) => {
            function_length_check(function_name, current_types.len(), param_types.len())?;

            let mut new_types = Vec::with_capacity(current_types.len());
            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
                let current_native_type: NativeType = current_type.into();

                // First preference: the argument already matches the desired
                // type of the parameter.
                if param.desired_type().matches_native_type(&current_native_type) {
                    let casted_type = param.desired_type().default_casted_type(
                        &current_native_type,
                        current_type,
                    )?;

                    new_types.push(casted_type);
                } else if param
                .allowed_source_types()
                .iter()
                .any(|t| t.matches_native_type(&current_native_type)) {
                    // If the condition is met, `implicit coercion` is provided, so we can safely unwrap
                    let default_casted_type = param.default_casted_type().unwrap();
                    let casted_type = default_casted_type.default_cast_for(current_type)?;
                    new_types.push(casted_type);
                } else {
                    return internal_err!(
                        "Expect {} but received {}, DataType: {}",
                        param.desired_type(),
                        current_native_type,
                        current_type
                    );
                }
            }

            vec![new_types]
        }
        TypeSignature::Uniform(number, valid_types) => {
            if *number == 0 {
                return plan_err!("The function '{function_name}' expected at least one argument");
            }

            // One candidate per allowed type, repeated `number` times.
            valid_types
                .iter()
                .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
                .collect()
        }
        TypeSignature::UserDefined => {
            return internal_err!(
                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
            )
        }
        TypeSignature::VariadicAny => {
            if current_types.is_empty() {
                return plan_err!(
                    "Function '{function_name}' expected at least one argument but received 0"
                );
            }
            // Any types are accepted, so the current types are the candidate.
            vec![current_types.to_vec()]
        }
        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
        TypeSignature::ArraySignature(ref function_signature) => match function_signature {
            ArrayFunctionSignature::Array { arguments, array_coercion, } => {
                array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
            }
            ArrayFunctionSignature::RecursiveArray => {
                // Exactly one list-like argument is expected; anything else
                // yields a single empty candidate.
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }
                recursive_array(&current_types[0])
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
            }
            ArrayFunctionSignature::MapArray => {
                // Exactly one Map argument is expected; anything else yields
                // a single empty candidate.
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }

                match &current_types[0] {
                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
                    _ => vec![vec![]],
                }
            }
        },
        TypeSignature::Nullary => {
            if !current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected zero argument but received {}",
                    current_types.len()
                );
            }
            vec![vec![]]
        }
        TypeSignature::Any(number) => {
            if current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected at least one argument but received 0"
                );
            }

            if current_types.len() != *number {
                return plan_err!(
                    "The function '{function_name}' expected {number} arguments but received {}",
                    current_types.len()
                );
            }
            // The length was checked above, so this copies every current type.
            vec![(0..*number).map(|i| current_types[i].clone()).collect()]
        }
        // Nested signatures that fail to resolve are dropped rather than
        // reported.
        TypeSignature::OneOf(types) => types
            .iter()
            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
            .flatten()
            .collect::<Vec<_>>(),
    };

    Ok(valid_types)
}
723
724/// Try to coerce the current argument types to match the given `valid_types`.
725///
726/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
727/// but was called with `(int32, int64)`, this function could match the
728/// valid_types by coercing the first argument to `int64`, and would return
729/// `Some([int64, int64])`.
730fn maybe_data_types(
731    valid_types: &[DataType],
732    current_types: &[DataType],
733) -> Option<Vec<DataType>> {
734    if valid_types.len() != current_types.len() {
735        return None;
736    }
737
738    let mut new_type = Vec::with_capacity(valid_types.len());
739    for (i, valid_type) in valid_types.iter().enumerate() {
740        let current_type = &current_types[i];
741
742        if current_type == valid_type {
743            new_type.push(current_type.clone())
744        } else {
745            // attempt to coerce.
746            // TODO: Replace with `can_cast_types` after failing cases are resolved
747            // (they need new signature that returns exactly valid types instead of list of possible valid types).
748            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
749                new_type.push(coerced_type)
750            } else {
751                // not possible
752                return None;
753            }
754        }
755    }
756    Some(new_type)
757}
758
759/// Check if the current argument types can be coerced to match the given `valid_types`
760/// unlike `maybe_data_types`, this function does not coerce the types.
761/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
762fn maybe_data_types_without_coercion(
763    valid_types: &[DataType],
764    current_types: &[DataType],
765) -> Option<Vec<DataType>> {
766    if valid_types.len() != current_types.len() {
767        return None;
768    }
769
770    let mut new_type = Vec::with_capacity(valid_types.len());
771    for (i, valid_type) in valid_types.iter().enumerate() {
772        let current_type = &current_types[i];
773
774        if current_type == valid_type {
775            new_type.push(current_type.clone())
776        } else if can_cast_types(current_type, valid_type) {
777            // validate the valid type is castable from the current type
778            new_type.push(valid_type.clone())
779        } else {
780            return None;
781        }
782    }
783    Some(new_type)
784}
785
786/// Return true if a value of type `type_from` can be coerced
787/// (losslessly converted) into a value of `type_to`
788///
789/// See the module level documentation for more detail on coercion.
790pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
791    if type_into == type_from {
792        return true;
793    }
794    if let Some(coerced) = coerced_from(type_into, type_from) {
795        return coerced == *type_into;
796    }
797    false
798}
799
800/// Find the coerced type for the given `type_into` and `type_from`.
801/// Returns `None` if coercion is not possible.
802///
803/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
804///
805/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion.
806fn coerced_from<'a>(
807    type_into: &'a DataType,
808    type_from: &'a DataType,
809) -> Option<DataType> {
810    use self::DataType::*;
811
812    // match Dictionary first
813    match (type_into, type_from) {
814        // coerced dictionary first
815        (_, Dictionary(_, value_type))
816            if coerced_from(type_into, value_type).is_some() =>
817        {
818            Some(type_into.clone())
819        }
820        (Dictionary(_, value_type), _)
821            if coerced_from(value_type, type_from).is_some() =>
822        {
823            Some(type_into.clone())
824        }
825        // coerced into type_into
826        (Int8, Null | Int8) => Some(type_into.clone()),
827        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
828        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
829        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
830            Some(type_into.clone())
831        }
832        (UInt8, Null | UInt8) => Some(type_into.clone()),
833        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
834        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
835        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
836        (
837            Float32,
838            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
839            | Float32,
840        ) => Some(type_into.clone()),
841        (
842            Float64,
843            Null
844            | Int8
845            | Int16
846            | Int32
847            | Int64
848            | UInt8
849            | UInt16
850            | UInt32
851            | UInt64
852            | Float32
853            | Float64
854            | Decimal128(_, _),
855        ) => Some(type_into.clone()),
856        (
857            Timestamp(TimeUnit::Nanosecond, None),
858            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
859        ) => Some(type_into.clone()),
860        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
861        // We can go into a Utf8View from a Utf8 or LargeUtf8
862        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
863        // Any type can be coerced into strings
864        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
865        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
866
867        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
868
869        // Only accept list and largelist with the same number of dimensions unless the type is Null.
870        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
871        (List(_) | LargeList(_), _)
872            if base_type(type_from).is_null()
873                || list_ndims(type_from) == list_ndims(type_into) =>
874        {
875            Some(type_into.clone())
876        }
877        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
878        (
879            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
880            FixedSizeList(f_from, size_from),
881        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
882            Some(data_type) if &data_type != f_into.data_type() => {
883                let new_field =
884                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
885                Some(FixedSizeList(new_field, *size_from))
886            }
887            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
888            _ => None,
889        },
890        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
891            match type_from {
892                Timestamp(_, Some(from_tz)) => {
893                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
894                }
895                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
896                    // In the absence of any other information assume the time zone is "+00" (UTC).
897                    Some(Timestamp(*unit, Some("+00".into())))
898                }
899                _ => None,
900            }
901        }
902        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
903            Some(type_into.clone())
904        }
905        _ => None,
906    }
907}
908
#[cfg(test)]
mod tests {
    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::assert_contains;

    /// `Utf8` and `LargeUtf8` are both coercible into `Utf8View`.
    #[test]
    fn test_string_conversion() {
        // Each case is (type_into, type_from, expected result).
        let cases = vec![
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for (type_into, type_from, expected) in cases {
            assert_eq!(can_coerce_from(&type_into, &type_from), expected);
        }
    }

    /// Exercises `maybe_data_types` on matching, coercible and non-coercible inputs.
    #[test]
    fn test_maybe_data_types() {
        // Each case is (valid_types, current_types, expected result).
        let cases = vec![
            // 2 entries, same values
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // 2 entries, can coerce values
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // 0 entries, all good
            (vec![], vec![], Some(vec![])),
            // 2 entries, can't coerce
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // u16 -> u32 is possible
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // UTF8 -> Timestamp; the "+TZ" entry is expected to come back as "+00"
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for (valid_types, current_types, expected) in cases {
            assert_eq!(maybe_data_types(&valid_types, &current_types), expected)
        }
    }

    /// Covers `TypeSignature::Numeric`: common-type coercion, the null
    /// fallback and rejection of non-numeric arguments.
    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        // Resolves the signature for function "test" and flattens all
        // candidate type lists into a single list.
        let flattened = |signature: &TypeSignature, current_types: &[DataType]| {
            get_valid_types("test", signature, current_types)
                .unwrap()
                .into_iter()
                .flatten()
                .collect::<Vec<_>>()
        };

        // Trivial case.
        assert_eq!(
            flattened(&TypeSignature::Numeric(1), &[DataType::Int32]),
            [DataType::Int32]
        );

        // Args are coerced into a common numeric type.
        assert_eq!(
            flattened(
                &TypeSignature::Numeric(2),
                &[DataType::Int32, DataType::Int64],
            ),
            [DataType::Int64, DataType::Int64]
        );

        // Args are coerced into a common numeric type, specifically, int would be coerced to float.
        assert_eq!(
            flattened(
                &TypeSignature::Numeric(3),
                &[DataType::Int32, DataType::Int64, DataType::Float64],
            ),
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // Cannot coerce args to a common numeric type.
        let err = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // Falls back to float64 if the arg is of type null.
        assert_eq!(
            flattened(&TypeSignature::Numeric(1), &[DataType::Null]),
            [DataType::Float64]
        );

        // Rejects non-numeric arg.
        let err = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
        );

        Ok(())
    }

    /// `OneOf` yields only the candidates whose arity matches the arguments.
    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither `Any(1)` nor `Any(2)`.
        let no_candidates = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )?;
        assert_eq!(no_candidates.len(), 0);

        // Two arguments, then one: each matches exactly one alternative and
        // comes back unchanged.
        for args in [vec![DataType::Int32, DataType::Int32], vec![DataType::Int32]] {
            let candidates = get_valid_types("test", &signature, &args)?;
            assert_eq!(candidates.len(), 1);
            assert_eq!(candidates[0], args);
        }

        Ok(())
    }

    /// A wrong number of arguments is reported with both the expected and
    /// the received count.
    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        let too_few = get_valid_types("test", &signature, &[]).unwrap_err();
        assert_contains!(
            too_few.to_string(),
            "Function 'test' expects 1 arguments but received 0"
        );

        let too_many = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )
        .unwrap_err();
        assert_contains!(
            too_many.to_string(),
            "Function 'test' expects 1 arguments but received 3"
        );

        Ok(())
    }

    /// A wildcard-sized fixed size list signature accepts any size, while a
    /// concrete size only accepts itself.
    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
        let current_types = vec![
            DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size
        ];

        // Builds an exact signature taking one fixed size list of `size`.
        let exact_signature = |size: i32| {
            Signature::exact(
                vec![DataType::FixedSizeList(Arc::clone(&inner), size)],
                Volatility::Stable,
            )
        };

        let wildcard = exact_signature(FIXED_SIZE_LIST_WILDCARD);
        let coerced = data_types("test", &current_types, &wildcard)?;
        assert_eq!(coerced, current_types);

        // make sure it can't coerce to a different size
        let other_size = exact_signature(3);
        let coerced = data_types("test", &current_types, &other_size);
        assert!(coerced.is_err());

        // make sure it works with the same type.
        let same_size = exact_signature(2);
        let coerced = data_types("test", &current_types, &same_size).unwrap();
        assert_eq!(coerced, current_types);

        Ok(())
    }

    /// Wildcard sizes are resolved at every nesting level, and the element
    /// type is widened to the target's.
    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        // Builds a fixed size list of fixed size lists of `element`.
        let nested = |element: DataType, inner_size: i32, outer_size: i32| {
            DataType::FixedSizeList(
                Arc::new(Field::new_list_field(
                    DataType::FixedSizeList(
                        Arc::new(Field::new_list_field(element, false)),
                        inner_size,
                    ),
                    false,
                )),
                outer_size,
            )
        };

        let type_into = nested(
            DataType::Int32,
            FIXED_SIZE_LIST_WILDCARD,
            FIXED_SIZE_LIST_WILDCARD,
        );
        let type_from = nested(DataType::Int8, 4, 3);

        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(nested(DataType::Int32, 4, 3))
        );

        Ok(())
    }

    /// Dictionary coercion is decided by the dictionary's value type.
    #[test]
    fn test_coerced_from_dictionary() {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));

        // Int64 cannot be coerced into a dictionary with UInt32 values.
        assert_eq!(coerced_from(&dictionary, &DataType::Int64), None);

        // A dictionary with UInt32 values can be coerced into Int64.
        assert_eq!(
            coerced_from(&DataType::Int64, &dictionary),
            Some(DataType::Int64)
        );
    }

    /// Two array arguments are unified to a single list type.
    #[test]
    fn test_get_valid_types_array_and_array() -> Result<()> {
        let function = "array_and_array";
        let signature = Signature::arrays(
            2,
            Some(ListCoercion::FixedSizedListToList),
            Volatility::Immutable,
        );

        // Each case is (argument types, the single expected candidate).
        let cases = vec![
            (
                vec![
                    DataType::new_list(DataType::Int32, true),
                    DataType::new_large_list(DataType::Float64, true),
                ],
                vec![
                    DataType::new_large_list(DataType::Float64, true),
                    DataType::new_large_list(DataType::Float64, true),
                ],
            ),
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_fixed_size_list(DataType::Int32, 5, true),
                ],
                vec![
                    DataType::new_list(DataType::Int64, true),
                    DataType::new_list(DataType::Int64, true),
                ],
            ),
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Null, 3, true),
                    DataType::new_large_list(DataType::Utf8, true),
                ],
                vec![
                    DataType::new_large_list(DataType::Utf8, true),
                    DataType::new_large_list(DataType::Utf8, true),
                ],
            ),
        ];

        for (arg_types, expected) in cases {
            assert_eq!(
                get_valid_types(function, &signature.type_signature, &arg_types)?,
                vec![expected]
            );
        }

        Ok(())
    }

    /// An array followed by an element: the list's element type and the
    /// element argument are unified.
    #[test]
    fn test_get_valid_types_array_and_element() -> Result<()> {
        let function = "array_and_element";
        let signature = Signature::array_and_element(Volatility::Immutable);

        // Each case is (argument types, the single expected candidate).
        let cases = vec![
            (
                vec![DataType::new_list(DataType::Int32, true), DataType::Float64],
                vec![
                    DataType::new_list(DataType::Float64, true),
                    DataType::Float64,
                ],
            ),
            (
                vec![
                    DataType::new_large_list(DataType::Int32, true),
                    DataType::Null,
                ],
                vec![
                    DataType::new_large_list(DataType::Int32, true),
                    DataType::Int32,
                ],
            ),
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Null, 3, true),
                    DataType::Utf8,
                ],
                vec![DataType::new_list(DataType::Utf8, true), DataType::Utf8],
            ),
        ];

        for (arg_types, expected) in cases {
            assert_eq!(
                get_valid_types(function, &signature.type_signature, &arg_types)?,
                vec![expected]
            );
        }

        Ok(())
    }

    /// An element followed by an array, where the element is itself a list.
    #[test]
    fn test_get_valid_types_element_and_array() -> Result<()> {
        let function = "element_and_array";
        let signature = Signature::element_and_array(Volatility::Immutable);

        let arg_types = vec![
            DataType::new_large_list(DataType::Null, false),
            DataType::new_list(DataType::new_list(DataType::Int64, true), true),
        ];
        let expected = vec![
            DataType::new_large_list(DataType::Int64, true),
            DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &arg_types)?,
            vec![expected]
        );

        Ok(())
    }

    /// Array signature without a list coercion.
    #[test]
    fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
        let function = "fixed_size_arrays";
        let signature = Signature::arrays(2, None, Volatility::Immutable);

        // Each case is (argument types, the single expected candidate).
        let cases = vec![
            // Two fixed size lists keep their own sizes.
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_fixed_size_list(DataType::Int32, 5, true),
                ],
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_fixed_size_list(DataType::Int64, 5, true),
                ],
            ),
            // A fixed size list next to a list: both are expected as lists.
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Int64, 3, true),
                    DataType::new_list(DataType::Int32, true),
                ],
                vec![
                    DataType::new_list(DataType::Int64, true),
                    DataType::new_list(DataType::Int64, true),
                ],
            ),
            // Mismatched nesting depth: the single candidate is expected to be empty.
            (
                vec![
                    DataType::new_fixed_size_list(DataType::Utf8, 3, true),
                    DataType::new_list(DataType::new_list(DataType::Int32, true), true),
                ],
                vec![],
            ),
        ];

        for (arg_types, expected) in cases {
            assert_eq!(
                get_valid_types(function, &signature.type_signature, &arg_types)?,
                vec![expected]
            );
        }

        Ok(())
    }
}