Skip to main content

datafusion_expr/type_coercion/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::binary::binary_numeric_coercion;
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::datatypes::{Field, FieldRef};
21use arrow::{
22    compute::can_cast_types,
23    datatypes::{DataType, TimeUnit},
24};
25use datafusion_common::types::LogicalType;
26use datafusion_common::utils::{
27    ListCoercion, base_type, coerced_fixed_size_list_to_list,
28};
29use datafusion_common::{
30    Result, exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims,
31};
32use datafusion_expr_common::signature::ArrayFunctionArgument;
33use datafusion_expr_common::type_coercion::binary::type_union_resolution;
34use datafusion_expr_common::{
35    signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
36    type_coercion::binary::comparison_coercion_numeric,
37    type_coercion::binary::string_coercion,
38};
39use itertools::Itertools as _;
40use std::sync::Arc;
41
42/// Extension trait to unify common functionality between [`ScalarUDF`], [`AggregateUDF`]
43/// and [`WindowUDF`] for use by signature coercion functions.
44pub trait UDFCoercionExt {
45    /// Should delegate to [`ScalarUDF::name`], [`AggregateUDF::name`] or [`WindowUDF::name`].
46    fn name(&self) -> &str;
47    /// Should delegate to [`ScalarUDF::signature`], [`AggregateUDF::signature`]
48    /// or [`WindowUDF::signature`].
49    fn signature(&self) -> &Signature;
50    /// Should delegate to [`ScalarUDF::coerce_types`], [`AggregateUDF::coerce_types`]
51    /// or [`WindowUDF::coerce_types`].
52    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>>;
53}
54
55impl UDFCoercionExt for ScalarUDF {
56    fn name(&self) -> &str {
57        self.name()
58    }
59
60    fn signature(&self) -> &Signature {
61        self.signature()
62    }
63
64    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
65        self.coerce_types(arg_types)
66    }
67}
68
69impl UDFCoercionExt for AggregateUDF {
70    fn name(&self) -> &str {
71        self.name()
72    }
73
74    fn signature(&self) -> &Signature {
75        self.signature()
76    }
77
78    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
79        self.coerce_types(arg_types)
80    }
81}
82
83impl UDFCoercionExt for WindowUDF {
84    fn name(&self) -> &str {
85        self.name()
86    }
87
88    fn signature(&self) -> &Signature {
89        self.signature()
90    }
91
92    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
93        self.coerce_types(arg_types)
94    }
95}
96
97/// Performs type coercion for UDF arguments.
98///
99/// Returns the data types to which each argument must be coerced to
100/// match `signature`.
101///
102/// For more details on coercion in general, please see the
103/// [`type_coercion`](crate::type_coercion) module.
104pub fn fields_with_udf<F: UDFCoercionExt>(
105    current_fields: &[FieldRef],
106    func: &F,
107) -> Result<Vec<FieldRef>> {
108    let signature = func.signature();
109    let type_signature = &signature.type_signature;
110
111    if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
112        if type_signature.supports_zero_argument() {
113            return Ok(vec![]);
114        } else if type_signature.used_to_support_zero_arguments() {
115            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
116            return plan_err!(
117                "'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
118                func.name()
119            );
120        } else {
121            return plan_err!("'{}' does not support zero arguments", func.name());
122        }
123    }
124    let current_types = current_fields
125        .iter()
126        .map(|f| f.data_type())
127        .cloned()
128        .collect::<Vec<_>>();
129
130    let valid_types = get_valid_types_with_udf(type_signature, &current_types, func)?;
131    if valid_types
132        .iter()
133        .any(|data_type| data_type == &current_types)
134    {
135        return Ok(current_fields.to_vec());
136    }
137
138    let updated_types =
139        try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
140
141    Ok(current_fields
142        .iter()
143        .zip(updated_types)
144        .map(|(current_field, new_type)| {
145            current_field.as_ref().clone().with_data_type(new_type)
146        })
147        .map(Arc::new)
148        .collect())
149}
150
151/// Performs type coercion for scalar function arguments.
152///
153/// Returns the data types to which each argument must be coerced to
154/// match `signature`.
155///
156/// For more details on coercion in general, please see the
157/// [`type_coercion`](crate::type_coercion) module.
158#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
159pub fn data_types_with_scalar_udf(
160    current_types: &[DataType],
161    func: &ScalarUDF,
162) -> Result<Vec<DataType>> {
163    let current_fields = current_types
164        .iter()
165        .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
166        .collect::<Vec<_>>();
167    Ok(fields_with_udf(&current_fields, func)?
168        .iter()
169        .map(|f| f.data_type().clone())
170        .collect())
171}
172
173/// Performs type coercion for aggregate function arguments.
174///
175/// Returns the fields to which each argument must be coerced to
176/// match `signature`.
177///
178/// For more details on coercion in general, please see the
179/// [`type_coercion`](crate::type_coercion) module.
180#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
181pub fn fields_with_aggregate_udf(
182    current_fields: &[FieldRef],
183    func: &AggregateUDF,
184) -> Result<Vec<FieldRef>> {
185    fields_with_udf(current_fields, func)
186}
187
188/// Performs type coercion for window function arguments.
189///
190/// Returns the data types to which each argument must be coerced to
191/// match `signature`.
192///
193/// For more details on coercion in general, please see the
194/// [`type_coercion`](crate::type_coercion) module.
195#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
196pub fn fields_with_window_udf(
197    current_fields: &[FieldRef],
198    func: &WindowUDF,
199) -> Result<Vec<FieldRef>> {
200    fields_with_udf(current_fields, func)
201}
202
203/// Performs type coercion for function arguments.
204///
205/// Returns the data types to which each argument must be coerced to
206/// match `signature`.
207///
208/// For more details on coercion in general, please see the
209/// [`type_coercion`](crate::type_coercion) module.
210#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
211pub fn data_types(
212    function_name: impl AsRef<str>,
213    current_types: &[DataType],
214    signature: &Signature,
215) -> Result<Vec<DataType>> {
216    let type_signature = &signature.type_signature;
217
218    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
219        if type_signature.supports_zero_argument() {
220            return Ok(vec![]);
221        } else if type_signature.used_to_support_zero_arguments() {
222            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
223            return plan_err!(
224                "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
225                function_name.as_ref()
226            );
227        } else {
228            return plan_err!(
229                "Function '{}' has signature {type_signature:?} which does not support zero arguments",
230                function_name.as_ref()
231            );
232        }
233    }
234
235    let valid_types =
236        get_valid_types(function_name.as_ref(), type_signature, current_types)?;
237    if valid_types
238        .iter()
239        .any(|data_type| data_type == current_types)
240    {
241        return Ok(current_types.to_vec());
242    }
243
244    try_coerce_types(
245        function_name.as_ref(),
246        valid_types,
247        current_types,
248        type_signature,
249    )
250}
251
252fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
253    match type_signature {
254        TypeSignature::OneOf(type_signatures) => {
255            type_signatures.iter().all(is_well_supported_signature)
256        }
257        TypeSignature::UserDefined
258        | TypeSignature::Numeric(_)
259        | TypeSignature::String(_)
260        | TypeSignature::Coercible(_)
261        | TypeSignature::Any(_)
262        | TypeSignature::Nullary
263        | TypeSignature::Comparable(_) => true,
264        TypeSignature::Variadic(_)
265        | TypeSignature::VariadicAny
266        | TypeSignature::Uniform(_, _)
267        | TypeSignature::Exact(_)
268        | TypeSignature::ArraySignature(_) => false,
269    }
270}
271
272fn try_coerce_types(
273    function_name: &str,
274    valid_types: Vec<Vec<DataType>>,
275    current_types: &[DataType],
276    type_signature: &TypeSignature,
277) -> Result<Vec<DataType>> {
278    let mut valid_types = valid_types;
279
280    // Well-supported signature that returns exact valid types.
281    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
282        // There may be many valid types if valid signature is OneOf
283        // Otherwise, there should be only one valid type
284        if !type_signature.is_one_of() {
285            assert_eq!(valid_types.len(), 1);
286        }
287
288        let valid_types = valid_types.swap_remove(0);
289        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
290            return Ok(t);
291        }
292    } else {
293        // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
294        // Try and coerce the argument types to match the signature, returning the
295        // coerced types from the first matching signature.
296        for valid_types in valid_types {
297            if let Some(types) = maybe_data_types(&valid_types, current_types) {
298                return Ok(types);
299            }
300        }
301    }
302
303    // none possible -> Error
304    plan_err!(
305        "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed",
306        current_types.iter().join(", ")
307    )
308}
309
310fn get_valid_types_with_udf<F: UDFCoercionExt>(
311    signature: &TypeSignature,
312    current_types: &[DataType],
313    func: &F,
314) -> Result<Vec<Vec<DataType>>> {
315    let valid_types = match signature {
316        TypeSignature::UserDefined => match func.coerce_types(current_types) {
317            Ok(coerced_types) => vec![coerced_types],
318            Err(e) => {
319                return exec_err!(
320                    "Function '{}' user-defined coercion failed with {:?}",
321                    func.name(),
322                    e.strip_backtrace()
323                );
324            }
325        },
326        TypeSignature::OneOf(signatures) => {
327            let mut res = vec![];
328            let mut errors = vec![];
329            for sig in signatures {
330                match get_valid_types_with_udf(sig, current_types, func) {
331                    Ok(valid_types) => {
332                        res.extend(valid_types);
333                    }
334                    Err(e) => {
335                        errors.push(e.to_string());
336                    }
337                }
338            }
339
340            // Every signature failed, return the joined error
341            if res.is_empty() {
342                return internal_err!(
343                    "Function '{}' failed to match any signature, errors: {}",
344                    func.name(),
345                    errors.join(",")
346                );
347            } else {
348                res
349            }
350        }
351        _ => get_valid_types(func.name(), signature, current_types)?,
352    };
353
354    Ok(valid_types)
355}
356
357/// Returns a Vec of all possible valid argument types for the given signature.
358fn get_valid_types(
359    function_name: &str,
360    signature: &TypeSignature,
361    current_types: &[DataType],
362) -> Result<Vec<Vec<DataType>>> {
363    fn array_valid_types(
364        function_name: &str,
365        current_types: &[DataType],
366        arguments: &[ArrayFunctionArgument],
367        array_coercion: Option<&ListCoercion>,
368    ) -> Result<Vec<Vec<DataType>>> {
369        if current_types.len() != arguments.len() {
370            return Ok(vec![vec![]]);
371        }
372
373        let mut large_list = false;
374        let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
375        let mut list_sizes = Vec::with_capacity(arguments.len());
376        let mut element_types = Vec::with_capacity(arguments.len());
377        let mut nested_item_nullability = Vec::with_capacity(arguments.len());
378        for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
379            match argument {
380                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
381                    nested_item_nullability.push(None);
382                }
383                ArrayFunctionArgument::Element => {
384                    element_types.push(current_type.clone());
385                    nested_item_nullability.push(None);
386                }
387                ArrayFunctionArgument::Array => match current_type {
388                    DataType::Null => {
389                        element_types.push(DataType::Null);
390                        nested_item_nullability.push(None);
391                    }
392                    DataType::List(field) => {
393                        element_types.push(field.data_type().clone());
394                        nested_item_nullability.push(Some(field.is_nullable()));
395                        fixed_size = false;
396                    }
397                    DataType::LargeList(field) => {
398                        element_types.push(field.data_type().clone());
399                        nested_item_nullability.push(Some(field.is_nullable()));
400                        large_list = true;
401                        fixed_size = false;
402                    }
403                    DataType::FixedSizeList(field, size) => {
404                        element_types.push(field.data_type().clone());
405                        nested_item_nullability.push(Some(field.is_nullable()));
406                        list_sizes.push(*size)
407                    }
408                    arg_type => {
409                        plan_err!("{function_name} does not support type {arg_type}")?
410                    }
411                },
412            }
413        }
414
415        debug_assert_eq!(nested_item_nullability.len(), arguments.len());
416
417        let Some(element_type) = type_union_resolution(&element_types) else {
418            return Ok(vec![vec![]]);
419        };
420
421        if !fixed_size {
422            list_sizes.clear()
423        };
424
425        let mut list_sizes = list_sizes.into_iter();
426        let valid_types = arguments
427            .iter()
428            .zip(current_types.iter())
429            .zip(nested_item_nullability)
430            .map(|((argument_type, current_type), is_nested_item_nullable)| {
431                match argument_type {
432                    ArrayFunctionArgument::Index => DataType::Int64,
433                    ArrayFunctionArgument::String => DataType::Utf8,
434                    ArrayFunctionArgument::Element => element_type.clone(),
435                    ArrayFunctionArgument::Array => {
436                        if current_type.is_null() {
437                            DataType::Null
438                        } else if large_list {
439                            DataType::new_large_list(
440                                element_type.clone(),
441                                is_nested_item_nullable.unwrap_or(true),
442                            )
443                        } else if let Some(size) = list_sizes.next() {
444                            DataType::new_fixed_size_list(
445                                element_type.clone(),
446                                size,
447                                is_nested_item_nullable.unwrap_or(true),
448                            )
449                        } else {
450                            DataType::new_list(
451                                element_type.clone(),
452                                is_nested_item_nullable.unwrap_or(true),
453                            )
454                        }
455                    }
456                }
457            });
458
459        Ok(vec![valid_types.collect()])
460    }
461
462    fn recursive_array(array_type: &DataType) -> Option<DataType> {
463        match array_type {
464            DataType::List(_)
465            | DataType::LargeList(_)
466            | DataType::FixedSizeList(_, _) => {
467                let array_type = coerced_fixed_size_list_to_list(array_type);
468                Some(array_type)
469            }
470            _ => None,
471        }
472    }
473
474    fn function_length_check(
475        function_name: &str,
476        length: usize,
477        expected_length: usize,
478    ) -> Result<()> {
479        if length != expected_length {
480            return plan_err!(
481                "Function '{function_name}' expects {expected_length} arguments but received {length}"
482            );
483        }
484        Ok(())
485    }
486
487    let valid_types = match signature {
488        TypeSignature::Variadic(valid_types) => valid_types
489            .iter()
490            .map(|valid_type| vec![valid_type.clone(); current_types.len()])
491            .collect(),
492        TypeSignature::String(number) => {
493            function_length_check(function_name, current_types.len(), *number)?;
494
495            let mut new_types = Vec::with_capacity(current_types.len());
496            for data_type in current_types.iter() {
497                let logical_data_type: NativeType = data_type.into();
498                if logical_data_type == NativeType::String {
499                    new_types.push(data_type.to_owned());
500                } else if logical_data_type == NativeType::Null {
501                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
502                    new_types.push(DataType::Utf8);
503                } else {
504                    return plan_err!(
505                        "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}"
506                    );
507                }
508            }
509
510            // Find the common string type for the given types
511            fn find_common_type(
512                function_name: &str,
513                lhs_type: &DataType,
514                rhs_type: &DataType,
515            ) -> Result<DataType> {
516                match (lhs_type, rhs_type) {
517                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
518                        find_common_type(function_name, lhs, rhs)
519                    }
520                    (DataType::Dictionary(_, v), other)
521                    | (other, DataType::Dictionary(_, v)) => {
522                        find_common_type(function_name, v, other)
523                    }
524                    _ => {
525                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
526                            Ok(coerced_type)
527                        } else {
528                            plan_err!(
529                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
530                            )
531                        }
532                    }
533                }
534            }
535
536            // Length checked above, safe to unwrap
537            let mut coerced_type = new_types.first().unwrap().to_owned();
538            for t in new_types.iter().skip(1) {
539                coerced_type = find_common_type(function_name, &coerced_type, t)?;
540            }
541
542            fn base_type_or_default_type(data_type: &DataType) -> DataType {
543                if let DataType::Dictionary(_, v) = data_type {
544                    base_type_or_default_type(v)
545                } else {
546                    data_type.to_owned()
547                }
548            }
549
550            vec![vec![base_type_or_default_type(&coerced_type); *number]]
551        }
552        TypeSignature::Numeric(number) => {
553            function_length_check(function_name, current_types.len(), *number)?;
554
555            // Find common numeric type among given types except string
556            let mut valid_type = current_types.first().unwrap().to_owned();
557            for t in current_types.iter().skip(1) {
558                let logical_data_type: NativeType = t.into();
559                if logical_data_type == NativeType::Null {
560                    continue;
561                }
562
563                if !logical_data_type.is_numeric() {
564                    return plan_err!(
565                        "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
566                    );
567                }
568
569                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
570                    valid_type = coerced_type;
571                } else {
572                    return plan_err!(
573                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
574                    );
575                }
576            }
577
578            let logical_data_type: NativeType = valid_type.clone().into();
579            // Fallback to default type if we don't know which type to coerced to
580            // f64 is chosen since most of the math functions utilize Signature::numeric,
581            // and their default type is double precision
582            if logical_data_type == NativeType::Null {
583                valid_type = DataType::Float64;
584            } else if !logical_data_type.is_numeric() {
585                return plan_err!(
586                    "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}"
587                );
588            }
589
590            vec![vec![valid_type; *number]]
591        }
592        TypeSignature::Comparable(num) => {
593            function_length_check(function_name, current_types.len(), *num)?;
594            let mut target_type = current_types[0].to_owned();
595            for data_type in current_types.iter().skip(1) {
596                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
597                    target_type = dt;
598                } else {
599                    return plan_err!(
600                        "For function '{function_name}' {target_type} and {data_type} is not comparable"
601                    );
602                }
603            }
604            // Convert null to String type.
605            if target_type.is_null() {
606                vec![vec![DataType::Utf8View; *num]]
607            } else {
608                vec![vec![target_type; *num]]
609            }
610        }
611        TypeSignature::Coercible(param_types) => {
612            function_length_check(function_name, current_types.len(), param_types.len())?;
613
614            let mut new_types = Vec::with_capacity(current_types.len());
615            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
616                let current_native_type: NativeType = current_type.into();
617
618                if param
619                    .desired_type()
620                    .matches_native_type(&current_native_type)
621                {
622                    let casted_type = param
623                        .desired_type()
624                        .default_casted_type(&current_native_type, current_type)?;
625
626                    new_types.push(casted_type);
627                } else if param
628                    .allowed_source_types()
629                    .iter()
630                    .any(|t| t.matches_native_type(&current_native_type))
631                {
632                    // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap
633                    let default_casted_type = param.default_casted_type().unwrap();
634                    let casted_type =
635                        default_casted_type.default_cast_for(current_type)?;
636                    new_types.push(casted_type);
637                } else {
638                    let hint = if matches!(current_native_type, NativeType::Binary) {
639                        "\n\nHint: Binary types are not automatically coerced to String. Use CAST(column AS VARCHAR) to convert Binary data to String."
640                    } else {
641                        ""
642                    };
643                    return plan_err!(
644                        "Function '{function_name}' requires {}, but received {} (DataType: {}).{hint}",
645                        param.desired_type(),
646                        current_native_type,
647                        current_type
648                    );
649                }
650            }
651
652            vec![new_types]
653        }
654        TypeSignature::Uniform(number, valid_types) => {
655            if *number == 0 {
656                return plan_err!(
657                    "The function '{function_name}' expected at least one argument"
658                );
659            }
660
661            valid_types
662                .iter()
663                .map(|valid_type| vec![valid_type.clone(); *number])
664                .collect()
665        }
666        TypeSignature::UserDefined => {
667            return internal_err!(
668                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
669            );
670        }
671        TypeSignature::VariadicAny => {
672            if current_types.is_empty() {
673                return plan_err!(
674                    "Function '{function_name}' expected at least one argument but received 0"
675                );
676            }
677            vec![current_types.to_vec()]
678        }
679        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
680        TypeSignature::ArraySignature(function_signature) => match function_signature {
681            ArrayFunctionSignature::Array {
682                arguments,
683                array_coercion,
684            } => array_valid_types(
685                function_name,
686                current_types,
687                arguments,
688                array_coercion.as_ref(),
689            )?,
690            ArrayFunctionSignature::RecursiveArray => {
691                if current_types.len() != 1 {
692                    return Ok(vec![vec![]]);
693                }
694                recursive_array(&current_types[0])
695                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
696            }
697            ArrayFunctionSignature::MapArray => {
698                if current_types.len() != 1 {
699                    return Ok(vec![vec![]]);
700                }
701
702                match &current_types[0] {
703                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
704                    _ => vec![vec![]],
705                }
706            }
707        },
708        TypeSignature::Nullary => {
709            if !current_types.is_empty() {
710                return plan_err!(
711                    "The function '{function_name}' expected zero argument but received {}",
712                    current_types.len()
713                );
714            }
715            vec![vec![]]
716        }
717        TypeSignature::Any(number) => {
718            if current_types.is_empty() {
719                return plan_err!(
720                    "The function '{function_name}' expected at least one argument but received 0"
721                );
722            }
723
724            if current_types.len() != *number {
725                return plan_err!(
726                    "The function '{function_name}' expected {number} arguments but received {}",
727                    current_types.len()
728                );
729            }
730            vec![current_types.to_vec()]
731        }
732        TypeSignature::OneOf(types) => types
733            .iter()
734            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
735            .flatten()
736            .collect::<Vec<_>>(),
737    };
738
739    Ok(valid_types)
740}
741
742/// Try to coerce the current argument types to match the given `valid_types`.
743///
744/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
745/// but was called with `(int32, int64)`, this function could match the
746/// valid_types by coercing the first argument to `int64`, and would return
747/// `Some([int64, int64])`.
748fn maybe_data_types(
749    valid_types: &[DataType],
750    current_types: &[DataType],
751) -> Option<Vec<DataType>> {
752    if valid_types.len() != current_types.len() {
753        return None;
754    }
755
756    let mut new_type = Vec::with_capacity(valid_types.len());
757    for (i, valid_type) in valid_types.iter().enumerate() {
758        let current_type = &current_types[i];
759
760        if current_type == valid_type {
761            new_type.push(current_type.clone())
762        } else {
763            // attempt to coerce.
764            // TODO: Replace with `can_cast_types` after failing cases are resolved
765            // (they need new signature that returns exactly valid types instead of list of possible valid types).
766            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
767                new_type.push(coerced_type)
768            } else {
769                // not possible
770                return None;
771            }
772        }
773    }
774    Some(new_type)
775}
776
777/// Check if the current argument types can be coerced to match the given `valid_types`
778/// unlike `maybe_data_types`, this function does not coerce the types.
779/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
780fn maybe_data_types_without_coercion(
781    valid_types: &[DataType],
782    current_types: &[DataType],
783) -> Option<Vec<DataType>> {
784    if valid_types.len() != current_types.len() {
785        return None;
786    }
787
788    let mut new_type = Vec::with_capacity(valid_types.len());
789    for (i, valid_type) in valid_types.iter().enumerate() {
790        let current_type = &current_types[i];
791
792        if current_type == valid_type {
793            new_type.push(current_type.clone())
794        } else if can_cast_types(current_type, valid_type) {
795            // validate the valid type is castable from the current type
796            new_type.push(valid_type.clone())
797        } else {
798            return None;
799        }
800    }
801    Some(new_type)
802}
803
804/// Return true if a value of type `type_from` can be coerced
805/// (losslessly converted) into a value of `type_to`
806///
807/// See the module level documentation for more detail on coercion.
808#[deprecated(since = "53.0.0", note = "Unused internal function")]
809pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
810    if type_into == type_from {
811        return true;
812    }
813    if let Some(coerced) = coerced_from(type_into, type_from) {
814        return coerced == *type_into;
815    }
816    false
817}
818
819/// Find the coerced type for the given `type_into` and `type_from`.
820/// Returns `None` if coercion is not possible.
821///
822/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
823///
824/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion.
825fn coerced_from<'a>(
826    type_into: &'a DataType,
827    type_from: &'a DataType,
828) -> Option<DataType> {
829    use self::DataType::*;
830
831    // match Dictionary first
832    match (type_into, type_from) {
833        // coerced dictionary first
834        (_, Dictionary(_, value_type))
835            if coerced_from(type_into, value_type).is_some() =>
836        {
837            Some(type_into.clone())
838        }
839        (Dictionary(_, value_type), _)
840            if coerced_from(value_type, type_from).is_some() =>
841        {
842            Some(type_into.clone())
843        }
844        // coerced into type_into
845        (Int8, Null | Int8) => Some(type_into.clone()),
846        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
847        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
848        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
849            Some(type_into.clone())
850        }
851        (UInt8, Null | UInt8) => Some(type_into.clone()),
852        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
853        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
854        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
855        (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => {
856            Some(type_into.clone())
857        }
858        (
859            Float32,
860            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
861            | Float16 | Float32,
862        ) => Some(type_into.clone()),
863        (
864            Float64,
865            Null
866            | Int8
867            | Int16
868            | Int32
869            | Int64
870            | UInt8
871            | UInt16
872            | UInt32
873            | UInt64
874            | Float16
875            | Float32
876            | Float64
877            | Decimal32(_, _)
878            | Decimal64(_, _)
879            | Decimal128(_, _)
880            | Decimal256(_, _),
881        ) => Some(type_into.clone()),
882        (
883            Timestamp(TimeUnit::Nanosecond, None),
884            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
885        ) => Some(type_into.clone()),
886        (Interval(_), Null | Utf8 | LargeUtf8) => Some(type_into.clone()),
887        // We can go into a Utf8View from a Utf8 or LargeUtf8
888        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
889        // Any type can be coerced into strings
890        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
891        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
892
893        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
894
895        // Only accept list and largelist with the same number of dimensions unless the type is Null.
896        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
897        (List(_) | LargeList(_), _)
898            if base_type(type_from).is_null()
899                || list_ndims(type_from) == list_ndims(type_into) =>
900        {
901            Some(type_into.clone())
902        }
903        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
904        (
905            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
906            FixedSizeList(f_from, size_from),
907        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
908            Some(data_type) if &data_type != f_into.data_type() => {
909                let new_field =
910                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
911                Some(FixedSizeList(new_field, *size_from))
912            }
913            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
914            _ => None,
915        },
916        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
917            match type_from {
918                Timestamp(_, Some(from_tz)) => {
919                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
920                }
921                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
922                    // In the absence of any other information assume the time zone is "+00" (UTC).
923                    Some(Timestamp(*unit, Some("+00".into())))
924                }
925                _ => None,
926            }
927        }
928        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
929            Some(type_into.clone())
930        }
931        _ => None,
932    }
933}
934
#[cfg(test)]
mod tests {
    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::{
        assert_contains,
        types::{logical_binary, logical_int64},
    };
    use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};

    #[test]
    fn test_string_conversion() {
        let cases = vec![
            (DataType::Utf8View, DataType::Utf8),
            (DataType::Utf8View, DataType::LargeUtf8),
        ];

        for case in cases {
            assert_eq!(coerced_from(&case.0, &case.1), Some(case.0));
        }
    }

    #[test]
    fn test_maybe_data_types() {
        // each case is: (valid_types, current_types, expected result)
        let cases = vec![
            // 2 entries, same values
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // 2 entries, can coerce values
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // 0 entries, all good
            (vec![], vec![], Some(vec![])),
            // 2 entries, can't coerce
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // u16 -> u32 is possible (widening)
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // UTF8 -> Timestamp; the "+TZ" wildcard timezone resolves to "+00"
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for case in cases {
            assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
        }
    }

    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        let get_valid_types_flatten =
            |function_name: &str,
             signature: &TypeSignature,
             current_types: &[DataType]| {
                get_valid_types(function_name, signature, current_types)
                    .unwrap()
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>()
            };

        // Trivial case.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Int32],
        );
        assert_eq!(got, [DataType::Int32]);

        // Args are coerced into a common numeric type.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Int64],
        );
        assert_eq!(got, [DataType::Int64, DataType::Int64]);

        // Args are coerced into a common numeric type, specifically, int would be coerced to float.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(3),
            &[DataType::Int32, DataType::Int64, DataType::Float64],
        );
        assert_eq!(
            got,
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // Cannot coerce args to a common numeric type.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // Fallbacks to float64 if the arg is of type null.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Null],
        );
        assert_eq!(got, [DataType::Float64]);

        // Rejects non-numeric arg.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(s)"
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // No alternative accepts 3 arguments, so the result is empty rather than an error.
        let invalid_types = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )?;
        assert_eq!(invalid_types.len(), 0);

        let args = vec![DataType::Int32, DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        let args = vec![DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        Ok(())
    }

    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        let err = get_valid_types("test", &signature, &[]).unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 0"
        );

        let err = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 3"
        );

        Ok(())
    }

    /// Minimal UDF wrapper exposing only a signature, for driving `fields_with_udf`.
    struct MockUdf(Signature);

    impl UDFCoercionExt for MockUdf {
        fn name(&self) -> &str {
            "test"
        }
        fn signature(&self) -> &Signature {
            &self.0
        }
        fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
        // able to coerce for any size
        let current_fields = vec![Arc::new(Field::new(
            "t",
            DataType::FixedSizeList(Arc::clone(&inner), 2),
            true,
        ))];

        let signature = Signature::exact(
            vec![DataType::FixedSizeList(
                Arc::clone(&inner),
                FIXED_SIZE_LIST_WILDCARD,
            )],
            Volatility::Stable,
        );

        let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature))?;
        assert_eq!(coerced_fields, current_fields);

        // make sure it can't coerce to a different size
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
            Volatility::Stable,
        );
        let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature));
        assert!(coerced_fields.is_err());

        // make sure it works with the same type.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
            Volatility::Stable,
        );
        let coerced_fields =
            fields_with_udf(&current_fields, &MockUdf(signature)).unwrap();
        assert_eq!(coerced_fields, current_fields);

        Ok(())
    }

    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        let type_into = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int32, false)),
                    FIXED_SIZE_LIST_WILDCARD,
                ),
                false,
            )),
            FIXED_SIZE_LIST_WILDCARD,
        );

        let type_from = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int8, false)),
                    4,
                ),
                false,
            )),
            3,
        );

        // Both wildcard sizes are taken from `type_from`; the element type is widened to Int32.
        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(DataType::FixedSizeList(
                Arc::new(Field::new_list_field(
                    DataType::FixedSizeList(
                        Arc::new(Field::new_list_field(DataType::Int32, false)),
                        4,
                    ),
                    false,
                )),
                3,
            ))
        );

        Ok(())
    }

    #[test]
    fn test_coerced_from_dictionary() {
        let type_into =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_from = DataType::Int64;
        assert_eq!(coerced_from(&type_into, &type_from), None);

        let type_from =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_into = DataType::Int64;
        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(type_into.clone())
        );
    }

    #[test]
    fn test_get_valid_types_array_and_array() -> Result<()> {
        let function = "array_and_array";
        let signature = Signature::arrays(
            2,
            Some(ListCoercion::FixedSizedListToList),
            Volatility::Immutable,
        );

        let data_types = vec![
            DataType::new_list(DataType::Int32, true),
            DataType::new_large_list(DataType::Float64, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Float64, true),
                DataType::new_large_list(DataType::Float64, true),
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, true),
            DataType::new_fixed_size_list(DataType::Int32, 5, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Null, 3, true),
            DataType::new_large_list(DataType::Utf8, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Utf8, true),
                DataType::new_large_list(DataType::Utf8, true),
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_array_and_element() -> Result<()> {
        let function = "array_and_element";
        let signature = Signature::array_and_element(Volatility::Immutable);

        let data_types =
            vec![DataType::new_list(DataType::Int32, true), DataType::Float64];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Float64, true),
                DataType::Float64,
            ]]
        );

        let data_types = vec![
            DataType::new_large_list(DataType::Int32, true),
            DataType::Null,
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Int32, true),
                DataType::Int32,
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Null, 3, true),
            DataType::Utf8,
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Utf8, true),
                DataType::Utf8,
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_element_and_array() -> Result<()> {
        let function = "element_and_array";
        let signature = Signature::element_and_array(Volatility::Immutable);

        let data_types = vec![
            DataType::new_large_list(DataType::Null, false),
            DataType::new_list(DataType::new_list(DataType::Int64, true), true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_large_list(DataType::Int64, true),
                DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
            ]]
        );

        Ok(())
    }

    #[test]
    fn test_coercible_nulls() -> Result<()> {
        fn null_input(coercion: Coercion) -> Result<Vec<DataType>> {
            fields_with_udf(
                &[Field::new("field", DataType::Null, true).into()],
                &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
            )
            .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
        }

        // Casts Null to Int64 if we use TypeSignatureClass::Native
        let output = null_input(Coercion::new_exact(TypeSignatureClass::Native(
            logical_int64(),
        )))?;
        assert_eq!(vec![DataType::Int64], output);

        let output = null_input(Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            vec![],
            NativeType::Int64,
        ))?;
        assert_eq!(vec![DataType::Int64], output);

        // Null gets passed through if we use TypeSignatureClass apart from Native
        let output = null_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
        assert_eq!(vec![DataType::Null], output);

        let output = null_input(Coercion::new_implicit(
            TypeSignatureClass::Integer,
            vec![],
            NativeType::Int64,
        ))?;
        assert_eq!(vec![DataType::Null], output);

        Ok(())
    }

    #[test]
    fn test_coercible_dictionary() -> Result<()> {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int64));
        fn dictionary_input(coercion: Coercion) -> Result<Vec<DataType>> {
            fields_with_udf(
                &[Field::new(
                    "field",
                    DataType::Dictionary(
                        Box::new(DataType::Int8),
                        Box::new(DataType::Int64),
                    ),
                    true,
                )
                .into()],
                &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
            )
            .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
        }

        // Casts Dictionary to Int64 if we use TypeSignatureClass::Native
        let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Native(
            logical_int64(),
        )))?;
        assert_eq!(vec![DataType::Int64], output);

        let output = dictionary_input(Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            vec![],
            NativeType::Int64,
        ))?;
        assert_eq!(vec![DataType::Int64], output);

        // Dictionary gets passed through if we use TypeSignatureClass apart from Native
        let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
        assert_eq!(vec![dictionary.clone()], output);

        let output = dictionary_input(Coercion::new_implicit(
            TypeSignatureClass::Integer,
            vec![],
            NativeType::Int64,
        ))?;
        assert_eq!(vec![dictionary.clone()], output);

        Ok(())
    }

    #[test]
    fn test_coercible_run_end_encoded() -> Result<()> {
        let run_end_encoded = DataType::RunEndEncoded(
            Field::new("run_ends", DataType::Int16, false).into(),
            Field::new("values", DataType::Int64, true).into(),
        );
        fn run_end_encoded_input(coercion: Coercion) -> Result<Vec<DataType>> {
            fields_with_udf(
                &[Field::new(
                    "field",
                    DataType::RunEndEncoded(
                        Field::new("run_ends", DataType::Int16, false).into(),
                        Field::new("values", DataType::Int64, true).into(),
                    ),
                    true,
                )
                .into()],
                &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
            )
            .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
        }

        // Casts REE to Int64 if we use TypeSignatureClass::Native
        let output = run_end_encoded_input(Coercion::new_exact(
            TypeSignatureClass::Native(logical_int64()),
        ))?;
        assert_eq!(vec![DataType::Int64], output);

        let output = run_end_encoded_input(Coercion::new_implicit(
            TypeSignatureClass::Native(logical_int64()),
            vec![],
            NativeType::Int64,
        ))?;
        assert_eq!(vec![DataType::Int64], output);

        // REE gets passed through if we use TypeSignatureClass apart from Native
        let output =
            run_end_encoded_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
        assert_eq!(vec![run_end_encoded.clone()], output);

        let output = run_end_encoded_input(Coercion::new_implicit(
            TypeSignatureClass::Integer,
            vec![],
            NativeType::Int64,
        ))?;
        assert_eq!(vec![run_end_encoded.clone()], output);

        Ok(())
    }

    #[test]
    fn test_get_valid_types_coercible_binary() -> Result<()> {
        let signature = Signature::coercible(
            vec![Coercion::new_exact(TypeSignatureClass::Native(
                logical_binary(),
            ))],
            Volatility::Immutable,
        );

        // Binary types should stay their original selves
        for t in [
            DataType::Binary,
            DataType::BinaryView,
            DataType::LargeBinary,
        ] {
            assert_eq!(
                get_valid_types("", &signature.type_signature, std::slice::from_ref(&t))?,
                vec![vec![t]]
            );
        }

        Ok(())
    }

    #[test]
    fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
        let function = "fixed_size_arrays";
        let signature = Signature::arrays(2, None, Volatility::Immutable);

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, true),
            DataType::new_fixed_size_list(DataType::Int32, 5, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_fixed_size_list(DataType::Int64, 3, true),
                DataType::new_fixed_size_list(DataType::Int64, 5, true),
            ]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, true),
            DataType::new_list(DataType::Int32, true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Int64, true),
                DataType::new_list(DataType::Int64, true),
            ]]
        );

        // Mismatched list dimensions: no valid coercion, signalled by an empty candidate.
        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Utf8, 3, true),
            DataType::new_list(DataType::new_list(DataType::Int32, true), true),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![]]
        );

        let data_types = vec![
            DataType::new_fixed_size_list(DataType::Int64, 3, false),
            DataType::new_list(DataType::Int32, false),
        ];
        assert_eq!(
            get_valid_types(function, &signature.type_signature, &data_types)?,
            vec![vec![
                DataType::new_list(DataType::Int64, false),
                DataType::new_list(DataType::Int64, false),
            ]]
        );

        Ok(())
    }
}