datafusion_expr/type_coercion/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::{Field, FieldRef};
21use arrow::{
22    compute::can_cast_types,
23    datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27    ListCoercion, base_type, coerced_fixed_size_list_to_list,
28};
29use datafusion_common::{
30    Result, exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35    signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36    type_coercion::binary::comparison_coercion_numeric,
37    type_coercion::binary::string_coercion,
38};
39use itertools::Itertools as _;
40use std::sync::Arc;
41
42/// Extension trait to unify common functionality between [`ScalarUDF`], [`AggregateUDF`]
43/// and [`WindowUDF`] for use by signature coercion functions.
44pub trait UDFCoercionExt {
45    /// Should delegate to [`ScalarUDF::name`], [`AggregateUDF::name`] or [`WindowUDF::name`].
46    fn name(&self) -> &str;
47    /// Should delegate to [`ScalarUDF::signature`], [`AggregateUDF::signature`]
48    /// or [`WindowUDF::signature`].
49    fn signature(&self) -> &Signature;
50    /// Should delegate to [`ScalarUDF::coerce_types`], [`AggregateUDF::coerce_types`]
51    /// or [`WindowUDF::coerce_types`].
52    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>>;
53}
54
55impl UDFCoercionExt for ScalarUDF {
56    fn name(&self) -> &str {
57        self.name()
58    }
59
60    fn signature(&self) -> &Signature {
61        self.signature()
62    }
63
64    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
65        self.coerce_types(arg_types)
66    }
67}
68
69impl UDFCoercionExt for AggregateUDF {
70    fn name(&self) -> &str {
71        self.name()
72    }
73
74    fn signature(&self) -> &Signature {
75        self.signature()
76    }
77
78    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
79        self.coerce_types(arg_types)
80    }
81}
82
83impl UDFCoercionExt for WindowUDF {
84    fn name(&self) -> &str {
85        self.name()
86    }
87
88    fn signature(&self) -> &Signature {
89        self.signature()
90    }
91
92    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
93        self.coerce_types(arg_types)
94    }
95}
96
97/// Performs type coercion for scalar function arguments.
98///
99/// Returns the data types to which each argument must be coerced to
100/// match `signature`.
101///
102/// For more details on coercion in general, please see the
103/// [`type_coercion`](crate::type_coercion) module.
104#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
105pub fn data_types_with_scalar_udf(
106    current_types: &[DataType],
107    func: &ScalarUDF,
108) -> Result<Vec<DataType>> {
109    let current_fields = current_types
110        .iter()
111        .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
112        .collect::<Vec<_>>();
113    Ok(fields_with_udf(&current_fields, func)?
114        .iter()
115        .map(|f| f.data_type().clone())
116        .collect())
117}
118
119/// Performs type coercion for aggregate function arguments.
120///
121/// Returns the fields to which each argument must be coerced to
122/// match `signature`.
123///
124/// For more details on coercion in general, please see the
125/// [`type_coercion`](crate::type_coercion) module.
126#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
127pub fn fields_with_aggregate_udf(
128    current_fields: &[FieldRef],
129    func: &AggregateUDF,
130) -> Result<Vec<FieldRef>> {
131    fields_with_udf(current_fields, func)
132}
133
134/// Performs type coercion for window function arguments.
135///
136/// Returns the data types to which each argument must be coerced to
137/// match `signature`.
138///
139/// For more details on coercion in general, please see the
140/// [`type_coercion`](crate::type_coercion) module.
141#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
142pub fn fields_with_window_udf(
143    current_fields: &[FieldRef],
144    func: &WindowUDF,
145) -> Result<Vec<FieldRef>> {
146    fields_with_udf(current_fields, func)
147}
148
149/// Performs type coercion for UDF arguments.
150///
151/// Returns the data types to which each argument must be coerced to
152/// match `signature`.
153///
154/// For more details on coercion in general, please see the
155/// [`type_coercion`](crate::type_coercion) module.
156pub fn fields_with_udf<F: UDFCoercionExt>(
157    current_fields: &[FieldRef],
158    func: &F,
159) -> Result<Vec<FieldRef>> {
160    let signature = func.signature();
161    let type_signature = &signature.type_signature;
162
163    if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
164        if type_signature.supports_zero_argument() {
165            return Ok(vec![]);
166        } else if type_signature.used_to_support_zero_arguments() {
167            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
168            return plan_err!(
169                "'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
170                func.name()
171            );
172        } else {
173            return plan_err!("'{}' does not support zero arguments", func.name());
174        }
175    }
176    let current_types = current_fields
177        .iter()
178        .map(|f| f.data_type())
179        .cloned()
180        .collect::<Vec<_>>();
181
182    let valid_types = get_valid_types_with_udf(type_signature, &current_types, func)?;
183    if valid_types
184        .iter()
185        .any(|data_type| data_type == &current_types)
186    {
187        return Ok(current_fields.to_vec());
188    }
189
190    let updated_types =
191        try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
192
193    Ok(current_fields
194        .iter()
195        .zip(updated_types)
196        .map(|(current_field, new_type)| {
197            current_field.as_ref().clone().with_data_type(new_type)
198        })
199        .map(Arc::new)
200        .collect())
201}
202
203/// Performs type coercion for function arguments.
204///
205/// Returns the data types to which each argument must be coerced to
206/// match `signature`.
207///
208/// For more details on coercion in general, please see the
209/// [`type_coercion`](crate::type_coercion) module.
210#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
211pub fn data_types(
212    function_name: impl AsRef<str>,
213    current_types: &[DataType],
214    signature: &Signature,
215) -> Result<Vec<DataType>> {
216    let type_signature = &signature.type_signature;
217
218    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
219        if type_signature.supports_zero_argument() {
220            return Ok(vec![]);
221        } else if type_signature.used_to_support_zero_arguments() {
222            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
223            return plan_err!(
224                "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
225                function_name.as_ref()
226            );
227        } else {
228            return plan_err!(
229                "Function '{}' has signature {type_signature:?} which does not support zero arguments",
230                function_name.as_ref()
231            );
232        }
233    }
234
235    let valid_types =
236        get_valid_types(function_name.as_ref(), type_signature, current_types)?;
237    if valid_types
238        .iter()
239        .any(|data_type| data_type == current_types)
240    {
241        return Ok(current_types.to_vec());
242    }
243
244    try_coerce_types(
245        function_name.as_ref(),
246        valid_types,
247        current_types,
248        type_signature,
249    )
250}
251
252fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
253    match type_signature {
254        TypeSignature::OneOf(type_signatures) => {
255            type_signatures.iter().all(is_well_supported_signature)
256        }
257        TypeSignature::UserDefined
258        | TypeSignature::Numeric(_)
259        | TypeSignature::String(_)
260        | TypeSignature::Coercible(_)
261        | TypeSignature::Any(_)
262        | TypeSignature::Nullary
263        | TypeSignature::Comparable(_) => true,
264        TypeSignature::Variadic(_)
265        | TypeSignature::VariadicAny
266        | TypeSignature::Uniform(_, _)
267        | TypeSignature::Exact(_)
268        | TypeSignature::ArraySignature(_) => false,
269    }
270}
271
272fn try_coerce_types(
273    function_name: &str,
274    valid_types: Vec<Vec<DataType>>,
275    current_types: &[DataType],
276    type_signature: &TypeSignature,
277) -> Result<Vec<DataType>> {
278    let mut valid_types = valid_types;
279
280    // Well-supported signature that returns exact valid types.
281    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
282        // There may be many valid types if valid signature is OneOf
283        // Otherwise, there should be only one valid type
284        if !type_signature.is_one_of() {
285            assert_eq!(valid_types.len(), 1);
286        }
287
288        let valid_types = valid_types.swap_remove(0);
289        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
290            return Ok(t);
291        }
292    } else {
293        // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
294        // Try and coerce the argument types to match the signature, returning the
295        // coerced types from the first matching signature.
296        for valid_types in valid_types {
297            if let Some(types) = maybe_data_types(&valid_types, current_types) {
298                return Ok(types);
299            }
300        }
301    }
302
303    // none possible -> Error
304    plan_err!(
305        "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed",
306        current_types.iter().join(", ")
307    )
308}
309
310fn get_valid_types_with_udf<F: UDFCoercionExt>(
311    signature: &TypeSignature,
312    current_types: &[DataType],
313    func: &F,
314) -> Result<Vec<Vec<DataType>>> {
315    let valid_types = match signature {
316        TypeSignature::UserDefined => match func.coerce_types(current_types) {
317            Ok(coerced_types) => vec![coerced_types],
318            Err(e) => {
319                return exec_err!(
320                    "Function '{}' user-defined coercion failed with {:?}",
321                    func.name(),
322                    e.strip_backtrace()
323                );
324            }
325        },
326        TypeSignature::OneOf(signatures) => {
327            let mut res = vec![];
328            let mut errors = vec![];
329            for sig in signatures {
330                match get_valid_types_with_udf(sig, current_types, func) {
331                    Ok(valid_types) => {
332                        res.extend(valid_types);
333                    }
334                    Err(e) => {
335                        errors.push(e.to_string());
336                    }
337                }
338            }
339
340            // Every signature failed, return the joined error
341            if res.is_empty() {
342                return internal_err!(
343                    "Function '{}' failed to match any signature, errors: {}",
344                    func.name(),
345                    errors.join(",")
346                );
347            } else {
348                res
349            }
350        }
351        _ => get_valid_types(func.name(), signature, current_types)?,
352    };
353
354    Ok(valid_types)
355}
356
357/// Returns a Vec of all possible valid argument types for the given signature.
358fn get_valid_types(
359    function_name: &str,
360    signature: &TypeSignature,
361    current_types: &[DataType],
362) -> Result<Vec<Vec<DataType>>> {
363    fn array_valid_types(
364        function_name: &str,
365        current_types: &[DataType],
366        arguments: &[ArrayFunctionArgument],
367        array_coercion: Option<&ListCoercion>,
368    ) -> Result<Vec<Vec<DataType>>> {
369        if current_types.len() != arguments.len() {
370            return Ok(vec![vec![]]);
371        }
372
373        let mut large_list = false;
374        let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
375        let mut list_sizes = Vec::with_capacity(arguments.len());
376        let mut element_types = Vec::with_capacity(arguments.len());
377        let mut nested_item_nullability = Vec::with_capacity(arguments.len());
378        for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
379            match argument {
380                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
381                    nested_item_nullability.push(None);
382                }
383                ArrayFunctionArgument::Element => {
384                    element_types.push(current_type.clone());
385                    nested_item_nullability.push(None);
386                }
387                ArrayFunctionArgument::Array => match current_type {
388                    DataType::Null => {
389                        element_types.push(DataType::Null);
390                        nested_item_nullability.push(None);
391                    }
392                    DataType::List(field) => {
393                        element_types.push(field.data_type().clone());
394                        nested_item_nullability.push(Some(field.is_nullable()));
395                        fixed_size = false;
396                    }
397                    DataType::LargeList(field) => {
398                        element_types.push(field.data_type().clone());
399                        nested_item_nullability.push(Some(field.is_nullable()));
400                        large_list = true;
401                        fixed_size = false;
402                    }
403                    DataType::FixedSizeList(field, size) => {
404                        element_types.push(field.data_type().clone());
405                        nested_item_nullability.push(Some(field.is_nullable()));
406                        list_sizes.push(*size)
407                    }
408                    arg_type => {
409                        plan_err!("{function_name} does not support type {arg_type}")?
410                    }
411                },
412            }
413        }
414
415        debug_assert_eq!(nested_item_nullability.len(), arguments.len());
416
417        let Some(element_type) = type_union_resolution(&element_types) else {
418            return Ok(vec![vec![]]);
419        };
420
421        if !fixed_size {
422            list_sizes.clear()
423        };
424
425        let mut list_sizes = list_sizes.into_iter();
426        let valid_types = arguments
427            .iter()
428            .zip(current_types.iter())
429            .zip(nested_item_nullability)
430            .map(|((argument_type, current_type), is_nested_item_nullable)| {
431                match argument_type {
432                    ArrayFunctionArgument::Index => DataType::Int64,
433                    ArrayFunctionArgument::String => DataType::Utf8,
434                    ArrayFunctionArgument::Element => element_type.clone(),
435                    ArrayFunctionArgument::Array => {
436                        if current_type.is_null() {
437                            DataType::Null
438                        } else if large_list {
439                            DataType::new_large_list(
440                                element_type.clone(),
441                                is_nested_item_nullable.unwrap_or(true),
442                            )
443                        } else if let Some(size) = list_sizes.next() {
444                            DataType::new_fixed_size_list(
445                                element_type.clone(),
446                                size,
447                                is_nested_item_nullable.unwrap_or(true),
448                            )
449                        } else {
450                            DataType::new_list(
451                                element_type.clone(),
452                                is_nested_item_nullable.unwrap_or(true),
453                            )
454                        }
455                    }
456                }
457            });
458
459        Ok(vec![valid_types.collect()])
460    }
461
462    fn recursive_array(array_type: &DataType) -> Option<DataType> {
463        match array_type {
464            DataType::List(_)
465            | DataType::LargeList(_)
466            | DataType::FixedSizeList(_, _) => {
467                let array_type = coerced_fixed_size_list_to_list(array_type);
468                Some(array_type)
469            }
470            _ => None,
471        }
472    }
473
474    fn function_length_check(
475        function_name: &str,
476        length: usize,
477        expected_length: usize,
478    ) -> Result<()> {
479        if length != expected_length {
480            return plan_err!(
481                "Function '{function_name}' expects {expected_length} arguments but received {length}"
482            );
483        }
484        Ok(())
485    }
486
487    let valid_types = match signature {
488        TypeSignature::Variadic(valid_types) => valid_types
489            .iter()
490            .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
491            .collect(),
492        TypeSignature::String(number) => {
493            function_length_check(function_name, current_types.len(), *number)?;
494
495            let mut new_types = Vec::with_capacity(current_types.len());
496            for data_type in current_types.iter() {
497                let logical_data_type: NativeType = data_type.into();
498                if logical_data_type == NativeType::String {
499                    new_types.push(data_type.to_owned());
500                } else if logical_data_type == NativeType::Null {
501                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
502                    new_types.push(DataType::Utf8);
503                } else {
504                    return plan_err!(
505                        "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}"
506                    );
507                }
508            }
509
510            // Find the common string type for the given types
511            fn find_common_type(
512                function_name: &str,
513                lhs_type: &DataType,
514                rhs_type: &DataType,
515            ) -> Result<DataType> {
516                match (lhs_type, rhs_type) {
517                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
518                        find_common_type(function_name, lhs, rhs)
519                    }
520                    (DataType::Dictionary(_, v), other)
521                    | (other, DataType::Dictionary(_, v)) => {
522                        find_common_type(function_name, v, other)
523                    }
524                    _ => {
525                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
526                            Ok(coerced_type)
527                        } else {
528                            plan_err!(
529                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
530                            )
531                        }
532                    }
533                }
534            }
535
536            // Length checked above, safe to unwrap
537            let mut coerced_type = new_types.first().unwrap().to_owned();
538            for t in new_types.iter().skip(1) {
539                coerced_type = find_common_type(function_name, &coerced_type, t)?;
540            }
541
542            fn base_type_or_default_type(data_type: &DataType) -> DataType {
543                if let DataType::Dictionary(_, v) = data_type {
544                    base_type_or_default_type(v)
545                } else {
546                    data_type.to_owned()
547                }
548            }
549
550            vec![vec![base_type_or_default_type(&coerced_type); *number]]
551        }
552        TypeSignature::Numeric(number) => {
553            function_length_check(function_name, current_types.len(), *number)?;
554
555            // Find common numeric type among given types except string
556            let mut valid_type = current_types.first().unwrap().to_owned();
557            for t in current_types.iter().skip(1) {
558                let logical_data_type: NativeType = t.into();
559                if logical_data_type == NativeType::Null {
560                    continue;
561                }
562
563                if !logical_data_type.is_numeric() {
564                    return plan_err!(
565                        "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
566                    );
567                }
568
569                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
570                    valid_type = coerced_type;
571                } else {
572                    return plan_err!(
573                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
574                    );
575                }
576            }
577
578            let logical_data_type: NativeType = valid_type.clone().into();
579            // Fallback to default type if we don't know which type to coerced to
580            // f64 is chosen since most of the math functions utilize Signature::numeric,
581            // and their default type is double precision
582            if logical_data_type == NativeType::Null {
583                valid_type = DataType::Float64;
584            } else if !logical_data_type.is_numeric() {
585                return plan_err!(
586                    "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
587                );
588            }
589
590            vec![vec![valid_type; *number]]
591        }
592        TypeSignature::Comparable(num) => {
593            function_length_check(function_name, current_types.len(), *num)?;
594            let mut target_type = current_types[0].to_owned();
595            for data_type in current_types.iter().skip(1) {
596                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
597                    target_type = dt;
598                } else {
599                    return plan_err!(
600                        "For function '{function_name}' {target_type} and {data_type} is not comparable"
601                    );
602                }
603            }
604            // Convert null to String type.
605            if target_type.is_null() {
606                vec![vec![DataType::Utf8View; *num]]
607            } else {
608                vec![vec![target_type; *num]]
609            }
610        }
611        TypeSignature::Coercible(param_types) => {
612            function_length_check(function_name, current_types.len(), param_types.len())?;
613
614            let mut new_types = Vec::with_capacity(current_types.len());
615            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
616                let current_native_type: NativeType = current_type.into();
617
618                if param
619                    .desired_type()
620                    .matches_native_type(&current_native_type)
621                {
622                    let casted_type = param
623                        .desired_type()
624                        .default_casted_type(&current_native_type, current_type)?;
625
626                    new_types.push(casted_type);
627                } else if param
628                    .allowed_source_types()
629                    .iter()
630                    .any(|t| t.matches_native_type(&current_native_type))
631                {
632                    // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap
633                    let default_casted_type = param.default_casted_type().unwrap();
634                    let casted_type =
635                        default_casted_type.default_cast_for(current_type)?;
636                    new_types.push(casted_type);
637                } else {
638                    return internal_err!(
639                        "Expect {} but received NativeType::{}, DataType: {}",
640                        param.desired_type(),
641                        current_native_type,
642                        current_type
643                    );
644                }
645            }
646
647            vec![new_types]
648        }
649        TypeSignature::Uniform(number, valid_types) => {
650            if *number == 0 {
651                return plan_err!(
652                    "The function '{function_name}' expected at least one argument"
653                );
654            }
655
656            valid_types
657                .iter()
658                .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
659                .collect()
660        }
661        TypeSignature::UserDefined => {
662            return internal_err!(
663                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
664            );
665        }
666        TypeSignature::VariadicAny => {
667            if current_types.is_empty() {
668                return plan_err!(
669                    "Function '{function_name}' expected at least one argument but received 0"
670                );
671            }
672            vec![current_types.to_vec()]
673        }
674        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
675        TypeSignature::ArraySignature(function_signature) => match function_signature {
676            ArrayFunctionSignature::Array {
677                arguments,
678                array_coercion,
679            } => array_valid_types(
680                function_name,
681                current_types,
682                arguments,
683                array_coercion.as_ref(),
684            )?,
685            ArrayFunctionSignature::RecursiveArray => {
686                if current_types.len() != 1 {
687                    return Ok(vec![vec![]]);
688                }
689                recursive_array(&current_types[0])
690                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
691            }
692            ArrayFunctionSignature::MapArray => {
693                if current_types.len() != 1 {
694                    return Ok(vec![vec![]]);
695                }
696
697                match &current_types[0] {
698                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
699                    _ => vec![vec![]],
700                }
701            }
702        },
703        TypeSignature::Nullary => {
704            if !current_types.is_empty() {
705                return plan_err!(
706                    "The function '{function_name}' expected zero argument but received {}",
707                    current_types.len()
708                );
709            }
710            vec![vec![]]
711        }
712        TypeSignature::Any(number) => {
713            if current_types.is_empty() {
714                return plan_err!(
715                    "The function '{function_name}' expected at least one argument but received 0"
716                );
717            }
718
719            if current_types.len() != *number {
720                return plan_err!(
721                    "The function '{function_name}' expected {number} arguments but received {}",
722                    current_types.len()
723                );
724            }
725            vec![(0..*number).map(|i| current_types[i].clone()).collect()]
726        }
727        TypeSignature::OneOf(types) => types
728            .iter()
729            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
730            .flatten()
731            .collect::<Vec<_>>(),
732    };
733
734    Ok(valid_types)
735}
736
737/// Try to coerce the current argument types to match the given `valid_types`.
738///
739/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
740/// but was called with `(int32, int64)`, this function could match the
741/// valid_types by coercing the first argument to `int64`, and would return
742/// `Some([int64, int64])`.
743fn maybe_data_types(
744    valid_types: &[DataType],
745    current_types: &[DataType],
746) -> Option<Vec<DataType>> {
747    if valid_types.len() != current_types.len() {
748        return None;
749    }
750
751    let mut new_type = Vec::with_capacity(valid_types.len());
752    for (i, valid_type) in valid_types.iter().enumerate() {
753        let current_type = &current_types[i];
754
755        if current_type == valid_type {
756            new_type.push(current_type.clone())
757        } else {
758            // attempt to coerce.
759            // TODO: Replace with `can_cast_types` after failing cases are resolved
760            // (they need new signature that returns exactly valid types instead of list of possible valid types).
761            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
762                new_type.push(coerced_type)
763            } else {
764                // not possible
765                return None;
766            }
767        }
768    }
769    Some(new_type)
770}
771
772/// Check if the current argument types can be coerced to match the given `valid_types`
773/// unlike `maybe_data_types`, this function does not coerce the types.
774/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
775fn maybe_data_types_without_coercion(
776    valid_types: &[DataType],
777    current_types: &[DataType],
778) -> Option<Vec<DataType>> {
779    if valid_types.len() != current_types.len() {
780        return None;
781    }
782
783    let mut new_type = Vec::with_capacity(valid_types.len());
784    for (i, valid_type) in valid_types.iter().enumerate() {
785        let current_type = &current_types[i];
786
787        if current_type == valid_type {
788            new_type.push(current_type.clone())
789        } else if can_cast_types(current_type, valid_type) {
790            // validate the valid type is castable from the current type
791            new_type.push(valid_type.clone())
792        } else {
793            return None;
794        }
795    }
796    Some(new_type)
797}
798
799/// Return true if a value of type `type_from` can be coerced
800/// (losslessly converted) into a value of `type_to`
801///
802/// See the module level documentation for more detail on coercion.
803pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
804    if type_into == type_from {
805        return true;
806    }
807    if let Some(coerced) = coerced_from(type_into, type_from) {
808        return coerced == *type_into;
809    }
810    false
811}
812
813/// Find the coerced type for the given `type_into` and `type_from`.
814/// Returns `None` if coercion is not possible.
815///
816/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
817///
818/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion.
819fn coerced_from<'a>(
820    type_into: &'a DataType,
821    type_from: &'a DataType,
822) -> Option<DataType> {
823    use self::DataType::*;
824
825    // match Dictionary first
826    match (type_into, type_from) {
827        // coerced dictionary first
828        (_, Dictionary(_, value_type))
829            if coerced_from(type_into, value_type).is_some() =>
830        {
831            Some(type_into.clone())
832        }
833        (Dictionary(_, value_type), _)
834            if coerced_from(value_type, type_from).is_some() =>
835        {
836            Some(type_into.clone())
837        }
838        // coerced into type_into
839        (Int8, Null | Int8) => Some(type_into.clone()),
840        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
841        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
842        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
843            Some(type_into.clone())
844        }
845        (UInt8, Null | UInt8) => Some(type_into.clone()),
846        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
847        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
848        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
849        (
850            Float32,
851            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
852            | Float32,
853        ) => Some(type_into.clone()),
854        (
855            Float64,
856            Null
857            | Int8
858            | Int16
859            | Int32
860            | Int64
861            | UInt8
862            | UInt16
863            | UInt32
864            | UInt64
865            | Float32
866            | Float64
867            | Decimal32(_, _)
868            | Decimal64(_, _)
869            | Decimal128(_, _)
870            | Decimal256(_, _),
871        ) => Some(type_into.clone()),
872        (
873            Timestamp(TimeUnit::Nanosecond, None),
874            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
875        ) => Some(type_into.clone()),
876        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
877        // We can go into a Utf8View from a Utf8 or LargeUtf8
878        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
879        // Any type can be coerced into strings
880        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
881        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
882
883        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
884
885        // Only accept list and largelist with the same number of dimensions unless the type is Null.
886        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
887        (List(_) | LargeList(_), _)
888            if base_type(type_from).is_null()
889                || list_ndims(type_from) == list_ndims(type_into) =>
890        {
891            Some(type_into.clone())
892        }
893        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
894        (
895            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
896            FixedSizeList(f_from, size_from),
897        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
898            Some(data_type) if &data_type != f_into.data_type() => {
899                let new_field =
900                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
901                Some(FixedSizeList(new_field, *size_from))
902            }
903            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
904            _ => None,
905        },
906        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
907            match type_from {
908                Timestamp(_, Some(from_tz)) => {
909                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
910                }
911                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
912                    // In the absence of any other information assume the time zone is "+00" (UTC).
913                    Some(Timestamp(*unit, Some("+00".into())))
914                }
915                _ => None,
916            }
917        }
918        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
919            Some(type_into.clone())
920        }
921        _ => None,
922    }
923}
924
#[cfg(test)]
mod tests {
    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::{assert_contains, types::logical_binary};
    use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};

    #[test]
    fn test_string_conversion() {
        // (type_into, type_from, expected `can_coerce_from` result)
        let cases = vec![
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
            // Any type can be coerced into Utf8 ...
            (DataType::Utf8, DataType::Int32, true),
            // ... but Utf8View only accepts string (or null) inputs.
            (DataType::Utf8View, DataType::Int32, false),
        ];

        for case in cases {
            assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
        }
    }

    #[test]
    fn test_maybe_data_types() {
        // this vec contains: arg1, arg2, expected result
        let cases = vec![
            // 2 entries, same values
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // 2 entries, can coerce values
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // 0 entries, all good
            (vec![], vec![], Some(vec![])),
            // 2 entries, can't coerce
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // u16 -> u32 is possible
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // UTF8 -> Timestamp
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for case in cases {
            assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
        }
    }

    #[test]
    fn test_maybe_data_types_without_coercion() {
        // Identical types are kept unchanged.
        assert_eq!(
            maybe_data_types_without_coercion(
                &[DataType::Int32, DataType::Utf8],
                &[DataType::Int32, DataType::Utf8],
            ),
            Some(vec![DataType::Int32, DataType::Utf8])
        );

        // Castable types are replaced by the valid type.
        assert_eq!(
            maybe_data_types_without_coercion(
                &[DataType::Int64, DataType::Utf8],
                &[DataType::Int32, DataType::Int32],
            ),
            Some(vec![DataType::Int64, DataType::Utf8])
        );

        // An arity mismatch is rejected.
        assert_eq!(
            maybe_data_types_without_coercion(&[DataType::Int64], &[]),
            None
        );
    }

    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        let get_valid_types_flatten =
            |function_name: &str,
             signature: &TypeSignature,
             current_types: &[DataType]| {
                get_valid_types(function_name, signature, current_types)
                    .unwrap()
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>()
            };

        // Trivial case.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Int32],
        );
        assert_eq!(got, [DataType::Int32]);

        // Args are coerced into a common numeric type.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Int64],
        );
        assert_eq!(got, [DataType::Int64, DataType::Int64]);

        // Args are coerced into a common numeric type, specifically, int would be coerced to float.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(3),
            &[DataType::Int32, DataType::Int64, DataType::Float64],
        );
        assert_eq!(
            got,
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // Cannot coerce args to a common numeric type.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // Falls back to float64 if the arg is of type null.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Null],
        );
        assert_eq!(got, [DataType::Float64]);

        // Rejects non-numeric arg.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        let invalid_types = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )?;
        assert_eq!(invalid_types.len(), 0);

        let args = vec![DataType::Int32, DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        let args = vec![DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        Ok(())
    }

    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        let err = get_valid_types("test", &signature, &[]).unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 0"
        );

        let err = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 3"
        );

        Ok(())
    }

    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        struct MockUdf(Signature);

        impl UDFCoercionExt for MockUdf {
            fn name(&self) -> &str {
                "test"
            }
            fn signature(&self) -> &Signature {
                &self.0
            }
            fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
                unimplemented!()
            }
        }

        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
        // able to coerce for any size
        let current_fields = vec![Arc::new(Field::new(
            "t",
            DataType::FixedSizeList(Arc::clone(&inner), 2),
            true,
        ))];

        let signature = Signature::exact(
            vec![DataType::FixedSizeList(
                Arc::clone(&inner),
                FIXED_SIZE_LIST_WILDCARD,
            )],
            Volatility::Stable,
        );

        let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature))?;
        assert_eq!(coerced_fields, current_fields);

        // make sure it can't coerce to a different size
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
            Volatility::Stable,
        );
        let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature));
        assert!(coerced_fields.is_err());

        // make sure it works with the same type.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
            Volatility::Stable,
        );
        let coerced_fields =
            fields_with_udf(&current_fields, &MockUdf(signature)).unwrap();
        assert_eq!(coerced_fields, current_fields);

        Ok(())
    }

    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        let type_into = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int32, false)),
                    FIXED_SIZE_LIST_WILDCARD,
                ),
                false,
            )),
            FIXED_SIZE_LIST_WILDCARD,
        );

        let type_from = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int8, false)),
                    4,
                ),
                false,
            )),
            3,
        );

        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(DataType::FixedSizeList(
                Arc::new(Field::new_list_field(
                    DataType::FixedSizeList(
                        Arc::new(Field::new_list_field(DataType::Int32, false)),
                        4,
                    ),
                    false,
                )),
                3,
            ))
        );

        Ok(())
    }

    #[test]
    fn test_coerced_from_dictionary() {
        let type_into =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_from = DataType::Int64;
        assert_eq!(coerced_from(&type_into, &type_from), None);

        let type_from =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_into = DataType::Int64;
        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(type_into.clone())
        );
    }

    #[test]
    fn test_get_valid_types_array_and_array() -> Result<()> {
        let function = "array_and_array";
        let signature = Signature::arrays(
            2,
            Some(ListCoercion::FixedSizedListToList),
            Volatility::Immutable,
        );

        let data_types = vec![
            DataType::new_list(DataType::Int32, true),
            DataType::new_large_list(DataType::Float64, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Float64, true),
                DataType::new_large_list(DataType::Float64, true),
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, true),
            DataType::new_fixed_size_list(DataType::Int32, 5, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Null, 3, true),
            DataType::new_large_list(DataType::Utf8, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Utf8, true),
                DataType::new_large_list(DataType::Utf8, true),
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_array_and_element() -> Result<()> {
        let function = "array_and_element";
        let signature = Signature::array_and_element(Volatility::Immutable);

        let data_types =
            vec![DataType::new_list(DataType::Int32, true), DataType::Float64];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Float64, true),
                DataType::Float64,
            ]]
        );

        let data_types = vec![
            DataType::new_large_list(DataType::Int32, true),
            DataType::Null,
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Int32, true),
                DataType::Int32,
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Null, 3, true),
            DataType::Utf8,
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Utf8, true),
                DataType::Utf8,
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_element_and_array() -> Result<()> {
        let function = "element_and_array";
        let signature = Signature::element_and_array(Volatility::Immutable);

        let data_types = vec![
            DataType::new_large_list(DataType::Null, false),
            DataType::new_list(DataType::new_list(DataType::Int64, true), true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Int64, true),
                DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_coercible_binary() -> Result<()> {
        let signature = Signature::coercible(
            vec![Coercion::new_exact(TypeSignatureClass::Native(
                logical_binary(),
            ))],
            Volatility::Immutable,
        );

        // Binary types should stay their original selves
        for t in [
            DataType::Binary,
            DataType::BinaryView,
            DataType::LargeBinary,
        ] {
            assert_eq!(
                get_valid_types("", &signature.type_signature, std::slice::from_ref(&t))?,
                vec![vec![t]]
            );
        }

        Ok(())
    }

    #[test]
    fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
        let function = "fixed_size_arrays";
        let signature = Signature::arrays(2, None, Volatility::Immutable);

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, true),
            DataType::new_fixed_size_list(DataType::Int32, 5, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int64, 5, true),
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, true),
            DataType::new_list(DataType::Int32, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Utf8, 3, true),
            DataType::new_list(DataType::new_list(DataType::Int32, true), true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, false),
            DataType::new_list(DataType::Int32, false),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Int64, false),
                DataType::new_list(DataType::Int64, false),
            ]]
        );

        Ok(())
    }
}