Skip to main content

datafusion_expr/type_coercion/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::binary::binary_numeric_coercion;
19use crate::{
20    AggregateUDF, HigherOrderTypeSignature, HigherOrderUDF, ScalarUDF, Signature,
21    TypeSignature, ValueOrLambda, WindowUDF,
22};
23use arrow::datatypes::{Field, FieldRef};
24use arrow::{
25    compute::can_cast_types,
26    datatypes::{DataType, TimeUnit},
27};
28use datafusion_common::internal_datafusion_err;
29use datafusion_common::types::LogicalType;
30use datafusion_common::utils::{
31    ListCoercion, base_type, coerced_fixed_size_list_to_list,
32};
33use datafusion_common::{
34    Result, exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims,
35};
36use datafusion_expr_common::signature::ArrayFunctionArgument;
37use datafusion_expr_common::type_coercion::binary::type_union_resolution;
38use datafusion_expr_common::{
39    signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
40    type_coercion::binary::comparison_coercion,
41    type_coercion::binary::string_coercion,
42};
43use itertools::Itertools as _;
44use std::sync::Arc;
45
46/// Extension trait to unify common functionality between [`ScalarUDF`], [`AggregateUDF`]
47/// and [`WindowUDF`] for use by signature coercion functions.
48pub trait UDFCoercionExt {
49    /// Should delegate to [`ScalarUDF::name`], [`AggregateUDF::name`] or [`WindowUDF::name`].
50    fn name(&self) -> &str;
51    /// Should delegate to [`ScalarUDF::signature`], [`AggregateUDF::signature`]
52    /// or [`WindowUDF::signature`].
53    fn signature(&self) -> &Signature;
54    /// Should delegate to [`ScalarUDF::coerce_types`], [`AggregateUDF::coerce_types`]
55    /// or [`WindowUDF::coerce_types`].
56    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>>;
57}
58
59impl UDFCoercionExt for ScalarUDF {
60    fn name(&self) -> &str {
61        self.name()
62    }
63
64    fn signature(&self) -> &Signature {
65        self.signature()
66    }
67
68    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
69        self.coerce_types(arg_types)
70    }
71}
72
73impl UDFCoercionExt for AggregateUDF {
74    fn name(&self) -> &str {
75        self.name()
76    }
77
78    fn signature(&self) -> &Signature {
79        self.signature()
80    }
81
82    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
83        self.coerce_types(arg_types)
84    }
85}
86
87impl UDFCoercionExt for WindowUDF {
88    fn name(&self) -> &str {
89        self.name()
90    }
91
92    fn signature(&self) -> &Signature {
93        self.signature()
94    }
95
96    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
97        self.coerce_types(arg_types)
98    }
99}
100
101/// Performs type coercion for UDF arguments.
102///
103/// Returns the data types to which each argument must be coerced to
104/// match `signature`.
105///
106/// For more details on coercion in general, please see the
107/// [`type_coercion`](crate::type_coercion) module.
108pub fn fields_with_udf<F: UDFCoercionExt>(
109    current_fields: &[FieldRef],
110    func: &F,
111) -> Result<Vec<FieldRef>> {
112    let signature = func.signature();
113    let type_signature = &signature.type_signature;
114
115    if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
116        if type_signature.supports_zero_argument() {
117            return Ok(vec![]);
118        } else if type_signature.used_to_support_zero_arguments() {
119            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
120            return plan_err!(
121                "'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
122                func.name()
123            );
124        } else {
125            return plan_err!("'{}' does not support zero arguments", func.name());
126        }
127    }
128    let current_types = current_fields
129        .iter()
130        .map(|f| f.data_type())
131        .cloned()
132        .collect::<Vec<_>>();
133
134    let valid_types = get_valid_types_with_udf(type_signature, &current_types, func)?;
135    if valid_types
136        .iter()
137        .any(|data_type| data_type == &current_types)
138    {
139        return Ok(current_fields.to_vec());
140    }
141
142    let updated_types =
143        try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
144
145    Ok(current_fields
146        .iter()
147        .zip(updated_types)
148        .map(|(current_field, new_type)| {
149            current_field.as_ref().clone().with_data_type(new_type)
150        })
151        .map(Arc::new)
152        .collect())
153}
154
155/// Performs type coercion for higher order function arguments.
156///
157/// For value arguments, returns the field to which each
158/// argument must be coerced to match `signature`.
159/// For lambda arguments, returns a clone of the associated data
160///
161/// Note this does not invokes [crate::HigherOrderUDFImpl::coerce_values_for_lambdas].
162/// If that's required, use [value_fields_with_higher_order_udf_and_lambdas]
163/// instead
164///
165/// For more details on coercion in general, please see the
166/// [`type_coercion`](crate::type_coercion) module.
167pub fn value_fields_with_higher_order_udf<L: Clone>(
168    current_fields: &[ValueOrLambda<FieldRef, L>],
169    func: &HigherOrderUDF,
170) -> Result<Vec<ValueOrLambda<FieldRef, L>>> {
171    match func.signature().type_signature {
172        HigherOrderTypeSignature::UserDefined => {
173            let arg_types = current_fields
174                .iter()
175                .filter_map(|p| match p {
176                    ValueOrLambda::Value(field) => Some(field.data_type().clone()),
177                    ValueOrLambda::Lambda(_) => None,
178                })
179                .collect::<Vec<_>>();
180
181            let coerced_types = func.coerce_value_types(&arg_types)?;
182
183            if coerced_types.len() != arg_types.len() {
184                return plan_err!(
185                    "{} coerce_value_types should have returned {} items but returned {}",
186                    func.name(),
187                    arg_types.len(),
188                    coerced_types.len()
189                );
190            }
191
192            // coerced_types has been partitioned from current_fields
193            // and refers only to values and not to lambdas, so instead
194            // of zipping them, we iterate over current_fields and only
195            // consume from coerced_types when a given argument is a value
196            // to reconstruct the arguments list with the correct order
197            // this supports any value and lambda positioning including
198            // multiple lambdas interleaved with values
199            let mut coerced_types = coerced_types.into_iter();
200
201            current_fields
202                .iter()
203                .map(|current_field| match current_field {
204                    ValueOrLambda::Value(field) => {
205                        let data_type = coerced_types.next().ok_or_else(|| {
206                            internal_datafusion_err!(
207                                "coerced_types len should have been checked above"
208                            )
209                        })?;
210
211                        Ok(ValueOrLambda::Value(Arc::new(
212                            field.as_ref().clone().with_data_type(data_type),
213                        )))
214                    }
215                    ValueOrLambda::Lambda(lambda) => {
216                        Ok(ValueOrLambda::Lambda(lambda.clone()))
217                    }
218                })
219                .collect()
220        }
221        HigherOrderTypeSignature::VariadicAny => Ok(current_fields.to_vec()),
222        HigherOrderTypeSignature::Any(number) => {
223            if current_fields.len() != number {
224                return plan_err!(
225                    "The function '{}' expected {number} arguments but received {}",
226                    func.name(),
227                    current_fields.len()
228                );
229            }
230
231            Ok(current_fields.to_vec())
232        }
233        HigherOrderTypeSignature::Exact(ref expected) => {
234            if current_fields.len() != expected.len() {
235                let name = func.name();
236                let expected_len = expected.len();
237                let actual_len = current_fields.len();
238                return plan_err!(
239                    "The function '{name}' expected {expected_len} argument(s) but received {actual_len}"
240                );
241            }
242
243            for (i, (actual, expected)) in
244                current_fields.iter().zip(expected.iter()).enumerate()
245            {
246                match (actual, expected) {
247                    (ValueOrLambda::Value(_), ValueOrLambda::Value(_)) => {}
248                    (ValueOrLambda::Lambda(_), ValueOrLambda::Lambda(_)) => {}
249                    (ValueOrLambda::Value(_), ValueOrLambda::Lambda(_)) => {
250                        let name = func.name();
251                        return plan_err!(
252                            "The function '{name}' expected a lambda at position {i} but received a value"
253                        );
254                    }
255                    (ValueOrLambda::Lambda(_), ValueOrLambda::Value(_)) => {
256                        let name = func.name();
257                        return plan_err!(
258                            "The function '{name}' expected a value at position {i} but received a lambda"
259                        );
260                    }
261                }
262            }
263
264            let arg_types = current_fields
265                .iter()
266                .filter_map(|p| match p {
267                    ValueOrLambda::Value(field) => Some(field.data_type().clone()),
268                    ValueOrLambda::Lambda(_) => None,
269                })
270                .collect::<Vec<_>>();
271
272            let coerced_types = func.coerce_value_types(&arg_types)?;
273
274            if coerced_types.len() != arg_types.len() {
275                return plan_err!(
276                    "{} coerce_value_types should have returned {} items but returned {}",
277                    func.name(),
278                    arg_types.len(),
279                    coerced_types.len()
280                );
281            }
282
283            let mut coerced_types = coerced_types.into_iter();
284
285            current_fields
286                .iter()
287                .map(|current_field| match current_field {
288                    ValueOrLambda::Value(field) => {
289                        let data_type = coerced_types.next().ok_or_else(|| {
290                            internal_datafusion_err!(
291                                "coerced_types len should have been checked above"
292                            )
293                        })?;
294
295                        Ok(ValueOrLambda::Value(Arc::new(
296                            field.as_ref().clone().with_data_type(data_type),
297                        )))
298                    }
299                    ValueOrLambda::Lambda(lambda) => {
300                        Ok(ValueOrLambda::Lambda(lambda.clone()))
301                    }
302                })
303                .collect()
304        }
305    }
306}
307
308/// Performs type coercion for higher order function arguments,
309/// including those defined by [crate::HigherOrderUDFImpl::coerce_values_for_lambdas],
310/// if it returns `Some(...)` instead of the default `None`. Note that
311/// compared to [value_fields_with_higher_order_udf], this function requires
312/// the [ValueOrLambda::Lambda] variant to contain the output field of the lambda.
313///
314/// For value arguments, returns the field to which each
315/// argument must be coerced to match `signature`.
316/// For lambda arguments, returns a clone of the output field
317///
318/// For more details on coercion in general, please see the
319/// [`type_coercion`](crate::type_coercion) module.
320pub fn value_fields_with_higher_order_udf_and_lambdas(
321    current_fields: &[ValueOrLambda<FieldRef, FieldRef>],
322    func: &HigherOrderUDF,
323) -> Result<Vec<ValueOrLambda<FieldRef, FieldRef>>> {
324    let mut new_fields = value_fields_with_higher_order_udf(current_fields, func)?;
325
326    let new_types = new_fields
327        .iter()
328        .map(|f| match f {
329            ValueOrLambda::Value(f) => ValueOrLambda::Value(f.data_type().clone()),
330            ValueOrLambda::Lambda(f) => ValueOrLambda::Lambda(f.data_type().clone()),
331        })
332        .collect::<Vec<_>>();
333
334    if let Some(new_value_types) = func.coerce_values_for_lambdas(&new_types)? {
335        let mut new_value_types = new_value_types.into_iter();
336
337        let value_types_count = new_types
338            .iter()
339            .filter(|e| matches!(e, ValueOrLambda::Value(_)))
340            .count();
341
342        if new_value_types.len() != value_types_count {
343            return plan_err!(
344                "{} coerce_values_for_lambdas returned {} values but {value_types_count} expected",
345                func.name(),
346                new_value_types.len()
347            );
348        }
349
350        for new_field in &mut new_fields {
351            match new_field {
352                ValueOrLambda::Value(value) => {
353                    let coerce_to = new_value_types.next().ok_or_else(|| {
354                        internal_datafusion_err!(
355                            "new_value_types len should have been checked above"
356                        )
357                    })?;
358
359                    if value.data_type() != &coerce_to {
360                        Arc::make_mut(value).set_data_type(coerce_to);
361                    }
362                }
363                ValueOrLambda::Lambda(_) => {}
364            }
365        }
366    };
367
368    Ok(new_fields)
369}
370
371/// Performs type coercion for scalar function arguments.
372///
373/// Returns the data types to which each argument must be coerced to
374/// match `signature`.
375///
376/// For more details on coercion in general, please see the
377/// [`type_coercion`](crate::type_coercion) module.
378#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
379pub fn data_types_with_scalar_udf(
380    current_types: &[DataType],
381    func: &ScalarUDF,
382) -> Result<Vec<DataType>> {
383    let current_fields = current_types
384        .iter()
385        .map(|dt| Arc::new(Field::new("f", dt.clone(), true)))
386        .collect::<Vec<_>>();
387    Ok(fields_with_udf(&current_fields, func)?
388        .iter()
389        .map(|f| f.data_type().clone())
390        .collect())
391}
392
393/// Performs type coercion for aggregate function arguments.
394///
395/// Returns the fields to which each argument must be coerced to
396/// match `signature`.
397///
398/// For more details on coercion in general, please see the
399/// [`type_coercion`](crate::type_coercion) module.
400#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
401pub fn fields_with_aggregate_udf(
402    current_fields: &[FieldRef],
403    func: &AggregateUDF,
404) -> Result<Vec<FieldRef>> {
405    fields_with_udf(current_fields, func)
406}
407
408/// Performs type coercion for window function arguments.
409///
410/// Returns the data types to which each argument must be coerced to
411/// match `signature`.
412///
413/// For more details on coercion in general, please see the
414/// [`type_coercion`](crate::type_coercion) module.
415#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
416pub fn fields_with_window_udf(
417    current_fields: &[FieldRef],
418    func: &WindowUDF,
419) -> Result<Vec<FieldRef>> {
420    fields_with_udf(current_fields, func)
421}
422
423/// Performs type coercion for function arguments.
424///
425/// Returns the data types to which each argument must be coerced to
426/// match `signature`.
427///
428/// For more details on coercion in general, please see the
429/// [`type_coercion`](crate::type_coercion) module.
430#[deprecated(since = "52.0.0", note = "use fields_with_udf")]
431pub fn data_types(
432    function_name: impl AsRef<str>,
433    current_types: &[DataType],
434    signature: &Signature,
435) -> Result<Vec<DataType>> {
436    let type_signature = &signature.type_signature;
437
438    if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
439        if type_signature.supports_zero_argument() {
440            return Ok(vec![]);
441        } else if type_signature.used_to_support_zero_arguments() {
442            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
443            return plan_err!(
444                "function '{}' has signature {type_signature} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
445                function_name.as_ref()
446            );
447        } else {
448            return plan_err!(
449                "Function '{}' has signature {type_signature} which does not support zero arguments",
450                function_name.as_ref()
451            );
452        }
453    }
454
455    let valid_types =
456        get_valid_types(function_name.as_ref(), type_signature, current_types)?;
457    if valid_types
458        .iter()
459        .any(|data_type| data_type == current_types)
460    {
461        return Ok(current_types.to_vec());
462    }
463
464    try_coerce_types(
465        function_name.as_ref(),
466        valid_types,
467        current_types,
468        type_signature,
469    )
470}
471
472fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
473    match type_signature {
474        TypeSignature::OneOf(type_signatures) => {
475            type_signatures.iter().all(is_well_supported_signature)
476        }
477        TypeSignature::UserDefined
478        | TypeSignature::Numeric(_)
479        | TypeSignature::String(_)
480        | TypeSignature::Coercible(_)
481        | TypeSignature::Any(_)
482        | TypeSignature::Nullary
483        | TypeSignature::Comparable(_) => true,
484        TypeSignature::Variadic(_)
485        | TypeSignature::VariadicAny
486        | TypeSignature::Uniform(_, _)
487        | TypeSignature::Exact(_)
488        | TypeSignature::ArraySignature(_) => false,
489    }
490}
491
492fn try_coerce_types(
493    function_name: &str,
494    valid_types: Vec<Vec<DataType>>,
495    current_types: &[DataType],
496    type_signature: &TypeSignature,
497) -> Result<Vec<DataType>> {
498    let mut valid_types = valid_types;
499
500    // Well-supported signature that returns exact valid types.
501    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
502        // There may be many valid types if valid signature is OneOf
503        // Otherwise, there should be only one valid type
504        if !type_signature.is_one_of() {
505            assert_eq!(valid_types.len(), 1);
506        }
507
508        let valid_types = valid_types.swap_remove(0);
509        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
510            return Ok(t);
511        }
512    } else {
513        // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
514        // Try and coerce the argument types to match the signature, returning the
515        // coerced types from the first matching signature.
516        for valid_types in valid_types {
517            if let Some(types) = maybe_data_types(&valid_types, current_types) {
518                return Ok(types);
519            }
520        }
521    }
522
523    // none possible -> Error
524    plan_err!(
525        "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature} failed",
526        current_types.iter().join(", ")
527    )
528}
529
530fn get_valid_types_with_udf<F: UDFCoercionExt>(
531    signature: &TypeSignature,
532    current_types: &[DataType],
533    func: &F,
534) -> Result<Vec<Vec<DataType>>> {
535    let valid_types = match signature {
536        TypeSignature::UserDefined => match func.coerce_types(current_types) {
537            Ok(coerced_types) => vec![coerced_types],
538            Err(e) => {
539                return exec_err!(
540                    "Function '{}' user-defined coercion failed with: {}",
541                    func.name(),
542                    e.strip_backtrace()
543                );
544            }
545        },
546        TypeSignature::OneOf(signatures) => {
547            let mut res = vec![];
548            let mut errors = vec![];
549            for sig in signatures {
550                match get_valid_types_with_udf(sig, current_types, func) {
551                    Ok(valid_types) => {
552                        res.extend(valid_types);
553                    }
554                    Err(e) => {
555                        errors.push(e.to_string());
556                    }
557                }
558            }
559
560            // Every signature failed, return the joined error
561            if res.is_empty() {
562                return internal_err!(
563                    "Function '{}' failed to match any signature, errors: {}",
564                    func.name(),
565                    errors.join(",")
566                );
567            } else {
568                res
569            }
570        }
571        _ => get_valid_types(func.name(), signature, current_types)?,
572    };
573
574    Ok(valid_types)
575}
576
577/// Returns a Vec of all possible valid argument types for the given signature.
578fn get_valid_types(
579    function_name: &str,
580    signature: &TypeSignature,
581    current_types: &[DataType],
582) -> Result<Vec<Vec<DataType>>> {
583    fn array_valid_types(
584        function_name: &str,
585        current_types: &[DataType],
586        arguments: &[ArrayFunctionArgument],
587        array_coercion: Option<&ListCoercion>,
588    ) -> Result<Vec<Vec<DataType>>> {
589        if current_types.len() != arguments.len() {
590            return Ok(vec![vec![]]);
591        }
592
593        let mut large_list = false;
594        let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
595        let mut list_sizes = Vec::with_capacity(arguments.len());
596        let mut element_types = Vec::with_capacity(arguments.len());
597        let mut nested_item_nullability = Vec::with_capacity(arguments.len());
598        for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
599            match argument {
600                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
601                    nested_item_nullability.push(None);
602                }
603                ArrayFunctionArgument::Element => {
604                    element_types.push(current_type.clone());
605                    nested_item_nullability.push(None);
606                }
607                ArrayFunctionArgument::Array => match current_type {
608                    DataType::Null => {
609                        element_types.push(DataType::Null);
610                        nested_item_nullability.push(None);
611                    }
612                    DataType::List(field) | DataType::ListView(field) => {
613                        element_types.push(field.data_type().clone());
614                        nested_item_nullability.push(Some(field.is_nullable()));
615                        fixed_size = false;
616                    }
617                    DataType::LargeList(field) | DataType::LargeListView(field) => {
618                        element_types.push(field.data_type().clone());
619                        nested_item_nullability.push(Some(field.is_nullable()));
620                        large_list = true;
621                        fixed_size = false;
622                    }
623                    DataType::FixedSizeList(field, size) => {
624                        element_types.push(field.data_type().clone());
625                        nested_item_nullability.push(Some(field.is_nullable()));
626                        list_sizes.push(*size)
627                    }
628                    arg_type => {
629                        plan_err!("{function_name} does not support type {arg_type}")?
630                    }
631                },
632            }
633        }
634
635        debug_assert_eq!(nested_item_nullability.len(), arguments.len());
636
637        let Some(element_type) = type_union_resolution(&element_types) else {
638            return Ok(vec![vec![]]);
639        };
640
641        if !fixed_size {
642            list_sizes.clear()
643        };
644
645        let mut list_sizes = list_sizes.into_iter();
646        let valid_types = arguments
647            .iter()
648            .zip(current_types.iter())
649            .zip(nested_item_nullability)
650            .map(|((argument_type, current_type), is_nested_item_nullable)| {
651                match argument_type {
652                    ArrayFunctionArgument::Index => DataType::Int64,
653                    ArrayFunctionArgument::String => DataType::Utf8,
654                    ArrayFunctionArgument::Element => element_type.clone(),
655                    // TODO: support maintaining ListView types here
656                    // https://github.com/apache/datafusion/issues/21777
657                    ArrayFunctionArgument::Array => {
658                        if current_type.is_null() {
659                            DataType::Null
660                        } else if large_list {
661                            DataType::new_large_list(
662                                element_type.clone(),
663                                is_nested_item_nullable.unwrap_or(true),
664                            )
665                        } else if let Some(size) = list_sizes.next() {
666                            DataType::new_fixed_size_list(
667                                element_type.clone(),
668                                size,
669                                is_nested_item_nullable.unwrap_or(true),
670                            )
671                        } else {
672                            DataType::new_list(
673                                element_type.clone(),
674                                is_nested_item_nullable.unwrap_or(true),
675                            )
676                        }
677                    }
678                }
679            });
680
681        Ok(vec![valid_types.collect()])
682    }
683
684    fn recursive_array(array_type: &DataType) -> Option<DataType> {
685        match array_type {
686            DataType::List(_)
687            | DataType::LargeList(_)
688            | DataType::ListView(_)
689            | DataType::LargeListView(_)
690            | DataType::FixedSizeList(_, _) => {
691                let array_type = coerced_fixed_size_list_to_list(array_type);
692                Some(array_type)
693            }
694            _ => None,
695        }
696    }
697
698    fn function_length_check(
699        function_name: &str,
700        length: usize,
701        expected_length: usize,
702    ) -> Result<()> {
703        if length != expected_length {
704            return plan_err!(
705                "Function '{function_name}' expects {expected_length} arguments but received {length}"
706            );
707        }
708        Ok(())
709    }
710
711    let valid_types = match signature {
712        TypeSignature::Variadic(valid_types) => valid_types
713            .iter()
714            .map(|valid_type| vec![valid_type.clone(); current_types.len()])
715            .collect(),
716        TypeSignature::String(number) => {
717            function_length_check(function_name, current_types.len(), *number)?;
718
719            let mut new_types = Vec::with_capacity(current_types.len());
720            for data_type in current_types.iter() {
721                let logical_data_type: NativeType = data_type.into();
722                if logical_data_type == NativeType::String {
723                    new_types.push(data_type.to_owned());
724                } else if logical_data_type == NativeType::Null {
725                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
726                    new_types.push(DataType::Utf8);
727                } else {
728                    return plan_err!(
729                        "Function '{function_name}' expects String but received {logical_data_type}"
730                    );
731                }
732            }
733
734            // Find the common string type for the given types
735            fn find_common_type(
736                function_name: &str,
737                lhs_type: &DataType,
738                rhs_type: &DataType,
739            ) -> Result<DataType> {
740                match (lhs_type, rhs_type) {
741                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
742                        find_common_type(function_name, lhs, rhs)
743                    }
744                    (DataType::Dictionary(_, v), other)
745                    | (other, DataType::Dictionary(_, v)) => {
746                        find_common_type(function_name, v, other)
747                    }
748                    _ => {
749                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
750                            Ok(coerced_type)
751                        } else {
752                            plan_err!(
753                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
754                            )
755                        }
756                    }
757                }
758            }
759
760            // Length checked above, safe to unwrap
761            let mut coerced_type = new_types.first().unwrap().to_owned();
762            for t in new_types.iter().skip(1) {
763                coerced_type = find_common_type(function_name, &coerced_type, t)?;
764            }
765
766            fn base_type_or_default_type(data_type: &DataType) -> DataType {
767                if let DataType::Dictionary(_, v) = data_type {
768                    base_type_or_default_type(v)
769                } else {
770                    data_type.to_owned()
771                }
772            }
773
774            vec![vec![base_type_or_default_type(&coerced_type); *number]]
775        }
776        TypeSignature::Numeric(number) => {
777            function_length_check(function_name, current_types.len(), *number)?;
778
779            // Find common numeric type among given types except string
780            let mut valid_type = current_types.first().unwrap().to_owned();
781            for t in current_types.iter().skip(1) {
782                let logical_data_type: NativeType = t.into();
783                if logical_data_type == NativeType::Null {
784                    continue;
785                }
786
787                if !logical_data_type.is_numeric() {
788                    return plan_err!(
789                        "Function '{function_name}' expects Numeric but received {logical_data_type}"
790                    );
791                }
792
793                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
794                    valid_type = coerced_type;
795                } else {
796                    return plan_err!(
797                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
798                    );
799                }
800            }
801
802            let logical_data_type: NativeType = valid_type.clone().into();
803            // Fallback to default type if we don't know which type to coerced to
804            // f64 is chosen since most of the math functions utilize Signature::numeric,
805            // and their default type is double precision
806            if logical_data_type == NativeType::Null {
807                valid_type = DataType::Float64;
808            } else if !logical_data_type.is_numeric() {
809                return plan_err!(
810                    "Function '{function_name}' expects Numeric but received {logical_data_type}"
811                );
812            }
813
814            vec![vec![valid_type; *number]]
815        }
816        TypeSignature::Comparable(num) => {
817            function_length_check(function_name, current_types.len(), *num)?;
818            let mut target_type = current_types[0].to_owned();
819            for data_type in current_types.iter().skip(1) {
820                if let Some(dt) = comparison_coercion(&target_type, data_type) {
821                    target_type = dt;
822                } else {
823                    return plan_err!(
824                        "For function '{function_name}' {target_type} and {data_type} is not comparable"
825                    );
826                }
827            }
828            // Convert null to String type.
829            if target_type.is_null() {
830                vec![vec![DataType::Utf8View; *num]]
831            } else {
832                vec![vec![target_type; *num]]
833            }
834        }
835        TypeSignature::Coercible(param_types) => {
836            function_length_check(function_name, current_types.len(), param_types.len())?;
837
838            let mut new_types = Vec::with_capacity(current_types.len());
839            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
840                let current_native_type: NativeType = current_type.into();
841
842                if param
843                    .desired_type()
844                    .matches_native_type(&current_native_type)
845                {
846                    let casted_type = param
847                        .desired_type()
848                        .default_casted_type(&current_native_type, current_type)?;
849
850                    new_types.push(casted_type);
851                } else if param
852                    .allowed_source_types()
853                    .iter()
854                    .any(|t| t.matches_native_type(&current_native_type))
855                {
856                    // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap
857                    let default_casted_type = param.default_casted_type().unwrap();
858                    let casted_type =
859                        default_casted_type.default_cast_for(current_type)?;
860                    new_types.push(casted_type);
861                } else {
862                    let hint = if matches!(current_native_type, NativeType::Binary) {
863                        "\n\nHint: Binary types are not automatically coerced to String. Use CAST(column AS VARCHAR) to convert Binary data to String."
864                    } else {
865                        ""
866                    };
867                    return plan_err!(
868                        "Function '{function_name}' requires {}, but received {} (DataType: {}).{hint}",
869                        param.desired_type(),
870                        current_native_type,
871                        current_type
872                    );
873                }
874            }
875
876            vec![new_types]
877        }
878        TypeSignature::Uniform(number, valid_types) => {
879            if *number == 0 {
880                return plan_err!(
881                    "The function '{function_name}' expected at least one argument"
882                );
883            }
884
885            valid_types
886                .iter()
887                .map(|valid_type| vec![valid_type.clone(); *number])
888                .collect()
889        }
890        TypeSignature::UserDefined => {
891            return internal_err!(
892                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
893            );
894        }
895        TypeSignature::VariadicAny => {
896            if current_types.is_empty() {
897                return plan_err!(
898                    "Function '{function_name}' expected at least one argument but received 0"
899                );
900            }
901            vec![current_types.to_vec()]
902        }
903        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
904        TypeSignature::ArraySignature(function_signature) => match function_signature {
905            ArrayFunctionSignature::Array {
906                arguments,
907                array_coercion,
908            } => array_valid_types(
909                function_name,
910                current_types,
911                arguments,
912                array_coercion.as_ref(),
913            )?,
914            ArrayFunctionSignature::RecursiveArray => {
915                if current_types.len() != 1 {
916                    return Ok(vec![vec![]]);
917                }
918                recursive_array(&current_types[0])
919                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
920            }
921            ArrayFunctionSignature::MapArray => {
922                if current_types.len() != 1 {
923                    return Ok(vec![vec![]]);
924                }
925
926                match &current_types[0] {
927                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
928                    _ => vec![vec![]],
929                }
930            }
931        },
932        TypeSignature::Nullary => {
933            if !current_types.is_empty() {
934                return plan_err!(
935                    "The function '{function_name}' expected zero argument but received {}",
936                    current_types.len()
937                );
938            }
939            vec![vec![]]
940        }
941        TypeSignature::Any(number) => {
942            if current_types.is_empty() {
943                return plan_err!(
944                    "The function '{function_name}' expected at least one argument but received 0"
945                );
946            }
947
948            if current_types.len() != *number {
949                return plan_err!(
950                    "The function '{function_name}' expected {number} arguments but received {}",
951                    current_types.len()
952                );
953            }
954            vec![current_types.to_vec()]
955        }
956        TypeSignature::OneOf(types) => types
957            .iter()
958            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
959            .flatten()
960            .collect::<Vec<_>>(),
961    };
962
963    Ok(valid_types)
964}
965
966/// Try to coerce the current argument types to match the given `valid_types`.
967///
968/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
969/// but was called with `(int32, int64)`, this function could match the
970/// valid_types by coercing the first argument to `int64`, and would return
971/// `Some([int64, int64])`.
972fn maybe_data_types(
973    valid_types: &[DataType],
974    current_types: &[DataType],
975) -> Option<Vec<DataType>> {
976    if valid_types.len() != current_types.len() {
977        return None;
978    }
979
980    let mut new_type = Vec::with_capacity(valid_types.len());
981    for (i, valid_type) in valid_types.iter().enumerate() {
982        let current_type = &current_types[i];
983
984        if current_type == valid_type {
985            new_type.push(current_type.clone())
986        } else {
987            // attempt to coerce.
988            // TODO: Replace with `can_cast_types` after failing cases are resolved
989            // (they need new signature that returns exactly valid types instead of list of possible valid types).
990            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
991                new_type.push(coerced_type)
992            } else {
993                // not possible
994                return None;
995            }
996        }
997    }
998    Some(new_type)
999}
1000
1001/// Check if the current argument types can be coerced to match the given `valid_types`
1002/// unlike `maybe_data_types`, this function does not coerce the types.
1003/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
1004fn maybe_data_types_without_coercion(
1005    valid_types: &[DataType],
1006    current_types: &[DataType],
1007) -> Option<Vec<DataType>> {
1008    if valid_types.len() != current_types.len() {
1009        return None;
1010    }
1011
1012    let mut new_type = Vec::with_capacity(valid_types.len());
1013    for (i, valid_type) in valid_types.iter().enumerate() {
1014        let current_type = &current_types[i];
1015
1016        if current_type == valid_type {
1017            new_type.push(current_type.clone())
1018        } else if can_cast_types(current_type, valid_type) {
1019            // validate the valid type is castable from the current type
1020            new_type.push(valid_type.clone())
1021        } else {
1022            return None;
1023        }
1024    }
1025    Some(new_type)
1026}
1027
1028/// Return true if a value of type `type_from` can be coerced
1029/// (losslessly converted) into a value of `type_to`
1030///
1031/// See the module level documentation for more detail on coercion.
1032#[deprecated(since = "53.0.0", note = "Unused internal function")]
1033pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
1034    if type_into == type_from {
1035        return true;
1036    }
1037    if let Some(coerced) = coerced_from(type_into, type_from) {
1038        return coerced == *type_into;
1039    }
1040    false
1041}
1042
1043/// Find the coerced type for the given `type_into` and `type_from`.
1044/// Returns `None` if coercion is not possible.
1045///
1046/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
1047///
1048/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion.
1049fn coerced_from<'a>(
1050    type_into: &'a DataType,
1051    type_from: &'a DataType,
1052) -> Option<DataType> {
1053    use self::DataType::*;
1054
1055    // match Dictionary first
1056    match (type_into, type_from) {
1057        // coerced dictionary first
1058        (_, Dictionary(_, value_type))
1059            if coerced_from(type_into, value_type).is_some() =>
1060        {
1061            Some(type_into.clone())
1062        }
1063        (Dictionary(_, value_type), _)
1064            if coerced_from(value_type, type_from).is_some() =>
1065        {
1066            Some(type_into.clone())
1067        }
1068        // coerced into type_into
1069        (Int8, Null | Int8) => Some(type_into.clone()),
1070        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
1071        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
1072        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
1073            Some(type_into.clone())
1074        }
1075        (UInt8, Null | UInt8) => Some(type_into.clone()),
1076        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
1077        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
1078        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
1079        (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => {
1080            Some(type_into.clone())
1081        }
1082        (
1083            Float32,
1084            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
1085            | Float16 | Float32,
1086        ) => Some(type_into.clone()),
1087        (
1088            Float64,
1089            Null
1090            | Int8
1091            | Int16
1092            | Int32
1093            | Int64
1094            | UInt8
1095            | UInt16
1096            | UInt32
1097            | UInt64
1098            | Float16
1099            | Float32
1100            | Float64
1101            | Decimal32(_, _)
1102            | Decimal64(_, _)
1103            | Decimal128(_, _)
1104            | Decimal256(_, _),
1105        ) => Some(type_into.clone()),
1106        (
1107            Timestamp(TimeUnit::Nanosecond, None),
1108            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
1109        ) => Some(type_into.clone()),
1110        (Interval(_), Null | Utf8 | LargeUtf8) => Some(type_into.clone()),
1111        // We can go into a Utf8View from a Utf8 or LargeUtf8
1112        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
1113        // Any type can be coerced into strings
1114        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
1115        // We can go into a BinaryView from a Binary or LargeBinary
1116        (BinaryView, Binary | LargeBinary | Null) => Some(type_into.clone()),
1117        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
1118
1119        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
1120
1121        // Only accept list and largelist with the same number of dimensions unless the type is Null.
1122        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
1123        (List(_) | LargeList(_) | ListView(_) | LargeListView(_), _)
1124            if base_type(type_from).is_null()
1125                || list_ndims(type_from) == list_ndims(type_into) =>
1126        {
1127            Some(type_into.clone())
1128        }
1129        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
1130        (
1131            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
1132            FixedSizeList(f_from, size_from),
1133        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
1134            Some(data_type) if &data_type != f_into.data_type() => {
1135                let new_field =
1136                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
1137                Some(FixedSizeList(new_field, *size_from))
1138            }
1139            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
1140            _ => None,
1141        },
1142        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
1143            match type_from {
1144                Timestamp(_, Some(from_tz)) => {
1145                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
1146                }
1147                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
1148                    // In the absence of any other information assume the time zone is "+00" (UTC).
1149                    Some(Timestamp(*unit, Some("+00".into())))
1150                }
1151                _ => None,
1152            }
1153        }
1154        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
1155            Some(type_into.clone())
1156        }
1157        // Null can be coerced to any target type, provided the cast is valid.
1158        // This mirrors null_coercion() in binary comparison coercion
1159        // (expr-common/src/type_coercion/binary.rs) and is the symmetric
1160        // counterpart of the (Null, _) arm above. Without this, untyped
1161        // placeholders ($1, $foo) inside function calls fail signature matching
1162        // because their Null type doesn't match any Exact(...) variant.
1163        (_, Null) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
1164        _ => None,
1165    }
1166}
1167
1168#[cfg(test)]
1169mod tests {
1170    use crate::{
1171        HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderSignature,
1172        HigherOrderUDFImpl, Volatility,
1173    };
1174
1175    use super::*;
1176    use arrow::datatypes::IntervalUnit;
1177    use datafusion_common::{
1178        assert_contains,
1179        types::{logical_binary, logical_int64},
1180    };
1181    use datafusion_expr_common::{
1182        columnar_value::ColumnarValue,
1183        signature::{Coercion, TypeSignatureClass},
1184    };
1185
1186    #[test]
1187    fn test_string_conversion() {
1188        let cases = vec![
1189            (DataType::Utf8View, DataType::Utf8),
1190            (DataType::Utf8View, DataType::LargeUtf8),
1191            (DataType::Utf8View, DataType::Null),
1192        ];
1193
1194        for case in cases {
1195            assert_eq!(coerced_from(&case.0, &case.1), Some(case.0));
1196        }
1197    }
1198
1199    #[test]
1200    fn test_binary_conversion() {
1201        let cases = vec![
1202            (DataType::BinaryView, DataType::Binary),
1203            (DataType::BinaryView, DataType::LargeBinary),
1204            (DataType::BinaryView, DataType::Null),
1205        ];
1206
1207        for case in cases {
1208            assert_eq!(coerced_from(&case.0, &case.1), Some(case.0));
1209        }
1210    }
1211
1212    #[test]
1213    fn test_coerced_from_null() {
1214        // Null should coerce to Interval (the motivating case)
1215        assert_eq!(
1216            coerced_from(
1217                &DataType::Interval(IntervalUnit::MonthDayNano),
1218                &DataType::Null
1219            ),
1220            Some(DataType::Interval(IntervalUnit::MonthDayNano))
1221        );
1222
1223        // Null should coerce to Date32
1224        assert_eq!(
1225            coerced_from(&DataType::Date32, &DataType::Null),
1226            Some(DataType::Date32)
1227        );
1228
1229        // Null should coerce to Timestamp with timezone
1230        assert_eq!(
1231            coerced_from(
1232                &DataType::Timestamp(TimeUnit::Microsecond, Some("+00".into())),
1233                &DataType::Null
1234            ),
1235            Some(DataType::Timestamp(
1236                TimeUnit::Microsecond,
1237                Some("+00".into())
1238            ))
1239        );
1240    }
1241
1242    #[test]
1243    fn test_maybe_data_types() {
1244        // this vec contains: arg1, arg2, expected result
1245        let cases = vec![
1246            // 2 entries, same values
1247            (
1248                vec![DataType::UInt8, DataType::UInt16],
1249                vec![DataType::UInt8, DataType::UInt16],
1250                Some(vec![DataType::UInt8, DataType::UInt16]),
1251            ),
1252            // 2 entries, can coerce values
1253            (
1254                vec![DataType::UInt16, DataType::UInt16],
1255                vec![DataType::UInt8, DataType::UInt16],
1256                Some(vec![DataType::UInt16, DataType::UInt16]),
1257            ),
1258            // 0 entries, all good
1259            (vec![], vec![], Some(vec![])),
1260            // 2 entries, can't coerce
1261            (
1262                vec![DataType::Boolean, DataType::UInt16],
1263                vec![DataType::UInt8, DataType::UInt16],
1264                None,
1265            ),
1266            // u32 -> u16 is possible
1267            (
1268                vec![DataType::Boolean, DataType::UInt32],
1269                vec![DataType::Boolean, DataType::UInt16],
1270                Some(vec![DataType::Boolean, DataType::UInt32]),
1271            ),
1272            // UTF8 -> Timestamp
1273            (
1274                vec![
1275                    DataType::Timestamp(TimeUnit::Nanosecond, None),
1276                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
1277                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
1278                ],
1279                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
1280                Some(vec![
1281                    DataType::Timestamp(TimeUnit::Nanosecond, None),
1282                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
1283                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
1284                ]),
1285            ),
1286        ];
1287
1288        for case in cases {
1289            assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
1290        }
1291    }
1292
1293    #[test]
1294    fn test_get_valid_types_numeric() -> Result<()> {
1295        let get_valid_types_flatten =
1296            |function_name: &str,
1297             signature: &TypeSignature,
1298             current_types: &[DataType]| {
1299                get_valid_types(function_name, signature, current_types)
1300                    .unwrap()
1301                    .into_iter()
1302                    .flatten()
1303                    .collect::<Vec<_>>()
1304            };
1305
1306        // Trivial case.
1307        let got = get_valid_types_flatten(
1308            "test",
1309            &TypeSignature::Numeric(1),
1310            &[DataType::Int32],
1311        );
1312        assert_eq!(got, [DataType::Int32]);
1313
1314        // Args are coerced into a common numeric type.
1315        let got = get_valid_types_flatten(
1316            "test",
1317            &TypeSignature::Numeric(2),
1318            &[DataType::Int32, DataType::Int64],
1319        );
1320        assert_eq!(got, [DataType::Int64, DataType::Int64]);
1321
1322        // Args are coerced into a common numeric type, specifically, int would be coerced to float.
1323        let got = get_valid_types_flatten(
1324            "test",
1325            &TypeSignature::Numeric(3),
1326            &[DataType::Int32, DataType::Int64, DataType::Float64],
1327        );
1328        assert_eq!(
1329            got,
1330            [DataType::Float64, DataType::Float64, DataType::Float64]
1331        );
1332
1333        // Cannot coerce args to a common numeric type.
1334        let got = get_valid_types(
1335            "test",
1336            &TypeSignature::Numeric(2),
1337            &[DataType::Int32, DataType::Utf8],
1338        )
1339        .unwrap_err();
1340        assert_contains!(
1341            got.to_string(),
1342            "Function 'test' expects Numeric but received String"
1343        );
1344
1345        // Fallbacks to float64 if the arg is of type null.
1346        let got = get_valid_types_flatten(
1347            "test",
1348            &TypeSignature::Numeric(1),
1349            &[DataType::Null],
1350        );
1351        assert_eq!(got, [DataType::Float64]);
1352
1353        // Rejects non-numeric arg.
1354        let got = get_valid_types(
1355            "test",
1356            &TypeSignature::Numeric(1),
1357            &[DataType::Timestamp(TimeUnit::Second, None)],
1358        )
1359        .unwrap_err();
1360        assert_contains!(
1361            got.to_string(),
1362            "Function 'test' expects Numeric but received Timestamp(s)"
1363        );
1364
1365        Ok(())
1366    }
1367
1368    #[test]
1369    fn test_get_valid_types_one_of() -> Result<()> {
1370        let signature =
1371            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);
1372
1373        let invalid_types = get_valid_types(
1374            "test",
1375            &signature,
1376            &[DataType::Int32, DataType::Int32, DataType::Int32],
1377        )?;
1378        assert_eq!(invalid_types.len(), 0);
1379
1380        let args = vec![DataType::Int32, DataType::Int32];
1381        let valid_types = get_valid_types("test", &signature, &args)?;
1382        assert_eq!(valid_types.len(), 1);
1383        assert_eq!(valid_types[0], args);
1384
1385        let args = vec![DataType::Int32];
1386        let valid_types = get_valid_types("test", &signature, &args)?;
1387        assert_eq!(valid_types.len(), 1);
1388        assert_eq!(valid_types[0], args);
1389
1390        Ok(())
1391    }
1392
1393    #[test]
1394    fn test_get_valid_types_length_check() -> Result<()> {
1395        let signature = TypeSignature::Numeric(1);
1396
1397        let err = get_valid_types("test", &signature, &[]).unwrap_err();
1398        assert_contains!(
1399            err.to_string(),
1400            "Function 'test' expects 1 arguments but received 0"
1401        );
1402
1403        let err = get_valid_types(
1404            "test",
1405            &signature,
1406            &[DataType::Int32, DataType::Int32, DataType::Int32],
1407        )
1408        .unwrap_err();
1409        assert_contains!(
1410            err.to_string(),
1411            "Function 'test' expects 1 arguments but received 3"
1412        );
1413
1414        Ok(())
1415    }
1416
1417    struct MockUdf(Signature);
1418
1419    impl UDFCoercionExt for MockUdf {
1420        fn name(&self) -> &str {
1421            "test"
1422        }
1423        fn signature(&self) -> &Signature {
1424            &self.0
1425        }
1426        fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
1427            unimplemented!()
1428        }
1429    }
1430
1431    #[test]
1432    fn test_fixed_list_wildcard_coerce() -> Result<()> {
1433        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
1434        // able to coerce for any size
1435        let current_fields = vec![Arc::new(Field::new(
1436            "t",
1437            DataType::FixedSizeList(Arc::clone(&inner), 2),
1438            true,
1439        ))];
1440
1441        let signature = Signature::exact(
1442            vec![DataType::FixedSizeList(
1443                Arc::clone(&inner),
1444                FIXED_SIZE_LIST_WILDCARD,
1445            )],
1446            Volatility::Stable,
1447        );
1448
1449        let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature))?;
1450        assert_eq!(coerced_fields, current_fields);
1451
1452        // make sure it can't coerce to a different size
1453        let signature = Signature::exact(
1454            vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
1455            Volatility::Stable,
1456        );
1457        let coerced_fields = fields_with_udf(&current_fields, &MockUdf(signature));
1458        assert!(coerced_fields.is_err());
1459
1460        // make sure it works with the same type.
1461        let signature = Signature::exact(
1462            vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
1463            Volatility::Stable,
1464        );
1465        let coerced_fields =
1466            fields_with_udf(&current_fields, &MockUdf(signature)).unwrap();
1467        assert_eq!(coerced_fields, current_fields);
1468
1469        Ok(())
1470    }
1471
1472    #[test]
1473    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
1474        let type_into = DataType::FixedSizeList(
1475            Arc::new(Field::new_list_field(
1476                DataType::FixedSizeList(
1477                    Arc::new(Field::new_list_field(DataType::Int32, false)),
1478                    FIXED_SIZE_LIST_WILDCARD,
1479                ),
1480                false,
1481            )),
1482            FIXED_SIZE_LIST_WILDCARD,
1483        );
1484
1485        let type_from = DataType::FixedSizeList(
1486            Arc::new(Field::new_list_field(
1487                DataType::FixedSizeList(
1488                    Arc::new(Field::new_list_field(DataType::Int8, false)),
1489                    4,
1490                ),
1491                false,
1492            )),
1493            3,
1494        );
1495
1496        assert_eq!(
1497            coerced_from(&type_into, &type_from),
1498            Some(DataType::FixedSizeList(
1499                Arc::new(Field::new_list_field(
1500                    DataType::FixedSizeList(
1501                        Arc::new(Field::new_list_field(DataType::Int32, false)),
1502                        4,
1503                    ),
1504                    false,
1505                )),
1506                3,
1507            ))
1508        );
1509
1510        Ok(())
1511    }
1512
1513    #[test]
1514    fn test_coerced_from_dictionary() {
1515        let type_into =
1516            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1517        let type_from = DataType::Int64;
1518        assert_eq!(coerced_from(&type_into, &type_from), None);
1519
1520        let type_from =
1521            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
1522        let type_into = DataType::Int64;
1523        assert_eq!(
1524            coerced_from(&type_into, &type_from),
1525            Some(type_into.clone())
1526        );
1527    }
1528
1529    #[test]
1530    fn test_get_valid_types_array_and_array() -> Result<()> {
1531        let function = "array_and_array";
1532        let signature = Signature::arrays(
1533            2,
1534            Some(ListCoercion::FixedSizedListToList),
1535            Volatility::Immutable,
1536        );
1537
1538        let data_types = vec![
1539            DataType::new_list(DataType::Int32, true),
1540            DataType::new_large_list(DataType::Float64, true),
1541        ];
1542        assert_eq!(
1543            get_valid_types(function, &signature.type_signature, &data_types)?,
1544            vec![vec![
1545                DataType::new_large_list(DataType::Float64, true),
1546                DataType::new_large_list(DataType::Float64, true),
1547            ]]
1548        );
1549
1550        let data_types = vec![
1551            DataType::new_fixed_size_list(DataType::Int64, 3, true),
1552            DataType::new_fixed_size_list(DataType::Int32, 5, true),
1553        ];
1554        assert_eq!(
1555            get_valid_types(function, &signature.type_signature, &data_types)?,
1556            vec![vec![
1557                DataType::new_list(DataType::Int64, true),
1558                DataType::new_list(DataType::Int64, true),
1559            ]]
1560        );
1561
1562        let data_types = vec![
1563            DataType::new_fixed_size_list(DataType::Null, 3, true),
1564            DataType::new_large_list(DataType::Utf8, true),
1565        ];
1566        assert_eq!(
1567            get_valid_types(function, &signature.type_signature, &data_types)?,
1568            vec![vec![
1569                DataType::new_large_list(DataType::Utf8, true),
1570                DataType::new_large_list(DataType::Utf8, true),
1571            ]]
1572        );
1573
1574        let data_types = vec![
1575            DataType::ListView(Field::new_list_field(DataType::Int32, true).into()),
1576            DataType::new_list(DataType::Int32, true),
1577        ];
1578        assert_eq!(
1579            get_valid_types(function, &signature.type_signature, &data_types)?,
1580            vec![vec![
1581                DataType::new_list(DataType::Int32, true),
1582                DataType::new_list(DataType::Int32, true),
1583            ]]
1584        );
1585
1586        let data_types = vec![
1587            DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()),
1588            DataType::new_list(DataType::Int32, true),
1589        ];
1590        assert_eq!(
1591            get_valid_types(function, &signature.type_signature, &data_types)?,
1592            vec![vec![
1593                DataType::new_large_list(DataType::Int32, true),
1594                DataType::new_large_list(DataType::Int32, true),
1595            ]]
1596        );
1597
1598        let data_types = vec![
1599            DataType::ListView(Field::new_list_field(DataType::Int32, true).into()),
1600            DataType::ListView(Field::new_list_field(DataType::Int32, true).into()),
1601        ];
1602        assert_eq!(
1603            get_valid_types(function, &signature.type_signature, &data_types)?,
1604            vec![vec![
1605                DataType::new_list(DataType::Int32, true),
1606                DataType::new_list(DataType::Int32, true),
1607            ]]
1608        );
1609
1610        let data_types = vec![
1611            DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()),
1612            DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()),
1613        ];
1614        assert_eq!(
1615            get_valid_types(function, &signature.type_signature, &data_types)?,
1616            vec![vec![
1617                DataType::new_large_list(DataType::Int32, true),
1618                DataType::new_large_list(DataType::Int32, true),
1619            ]]
1620        );
1621
1622        Ok(())
1623    }
1624
1625    #[test]
1626    fn test_get_valid_types_array_and_element() -> Result<()> {
1627        let function = "array_and_element";
1628        let signature = Signature::array_and_element(Volatility::Immutable);
1629
1630        let data_types =
1631            vec![DataType::new_list(DataType::Int32, true), DataType::Float64];
1632        assert_eq!(
1633            get_valid_types(function, &signature.type_signature, &data_types)?,
1634            vec![vec![
1635                DataType::new_list(DataType::Float64, true),
1636                DataType::Float64,
1637            ]]
1638        );
1639
1640        let data_types = vec![
1641            DataType::new_large_list(DataType::Int32, true),
1642            DataType::Null,
1643        ];
1644        assert_eq!(
1645            get_valid_types(function, &signature.type_signature, &data_types)?,
1646            vec![vec![
1647                DataType::new_large_list(DataType::Int32, true),
1648                DataType::Int32,
1649            ]]
1650        );
1651
1652        let data_types = vec![
1653            DataType::new_fixed_size_list(DataType::Null, 3, true),
1654            DataType::Utf8,
1655        ];
1656        assert_eq!(
1657            get_valid_types(function, &signature.type_signature, &data_types)?,
1658            vec![vec![
1659                DataType::new_list(DataType::Utf8, true),
1660                DataType::Utf8,
1661            ]]
1662        );
1663
1664        Ok(())
1665    }
1666
1667    #[test]
1668    fn test_get_valid_types_element_and_array() -> Result<()> {
1669        let function = "element_and_array";
1670        let signature = Signature::element_and_array(Volatility::Immutable);
1671
1672        let data_types = vec![
1673            DataType::new_large_list(DataType::Null, false),
1674            DataType::new_list(DataType::new_list(DataType::Int64, true), true),
1675        ];
1676        assert_eq!(
1677            get_valid_types(function, &signature.type_signature, &data_types)?,
1678            vec![vec![
1679                DataType::new_large_list(DataType::Int64, true),
1680                DataType::new_list(DataType::new_large_list(DataType::Int64, true), true),
1681            ]]
1682        );
1683
1684        Ok(())
1685    }
1686
1687    #[test]
1688    fn test_coercible_nulls() -> Result<()> {
1689        fn null_input(coercion: Coercion) -> Result<Vec<DataType>> {
1690            fields_with_udf(
1691                &[Field::new("field", DataType::Null, true).into()],
1692                &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
1693            )
1694            .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
1695        }
1696
1697        // Casts Null to Int64 if we use TypeSignatureClass::Native
1698        let output = null_input(Coercion::new_exact(TypeSignatureClass::Native(
1699            logical_int64(),
1700        )))?;
1701        assert_eq!(vec![DataType::Int64], output);
1702
1703        let output = null_input(Coercion::new_implicit(
1704            TypeSignatureClass::Native(logical_int64()),
1705            vec![],
1706            NativeType::Int64,
1707        ))?;
1708        assert_eq!(vec![DataType::Int64], output);
1709
1710        // Null gets passed through if we use TypeSignatureClass apart from Native
1711        let output = null_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
1712        assert_eq!(vec![DataType::Null], output);
1713
1714        let output = null_input(Coercion::new_implicit(
1715            TypeSignatureClass::Integer,
1716            vec![],
1717            NativeType::Int64,
1718        ))?;
1719        assert_eq!(vec![DataType::Null], output);
1720
1721        Ok(())
1722    }
1723
1724    #[test]
1725    fn test_coercible_dictionary() -> Result<()> {
1726        let dictionary =
1727            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int64));
1728        fn dictionary_input(coercion: Coercion) -> Result<Vec<DataType>> {
1729            fields_with_udf(
1730                &[Field::new(
1731                    "field",
1732                    DataType::Dictionary(
1733                        Box::new(DataType::Int8),
1734                        Box::new(DataType::Int64),
1735                    ),
1736                    true,
1737                )
1738                .into()],
1739                &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
1740            )
1741            .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
1742        }
1743
1744        // Casts Dictionary to Int64 if we use TypeSignatureClass::Native
1745        let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Native(
1746            logical_int64(),
1747        )))?;
1748        assert_eq!(vec![DataType::Int64], output);
1749
1750        let output = dictionary_input(Coercion::new_implicit(
1751            TypeSignatureClass::Native(logical_int64()),
1752            vec![],
1753            NativeType::Int64,
1754        ))?;
1755        assert_eq!(vec![DataType::Int64], output);
1756
1757        // Dictionary gets passed through if we use TypeSignatureClass apart from Native
1758        let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
1759        assert_eq!(vec![dictionary.clone()], output);
1760
1761        let output = dictionary_input(Coercion::new_implicit(
1762            TypeSignatureClass::Integer,
1763            vec![],
1764            NativeType::Int64,
1765        ))?;
1766        assert_eq!(vec![dictionary.clone()], output);
1767
1768        Ok(())
1769    }
1770
1771    #[test]
1772    fn test_coercible_run_end_encoded() -> Result<()> {
1773        let run_end_encoded = DataType::RunEndEncoded(
1774            Field::new("run_ends", DataType::Int16, false).into(),
1775            Field::new("values", DataType::Int64, true).into(),
1776        );
1777        fn run_end_encoded_input(coercion: Coercion) -> Result<Vec<DataType>> {
1778            fields_with_udf(
1779                &[Field::new(
1780                    "field",
1781                    DataType::RunEndEncoded(
1782                        Field::new("run_ends", DataType::Int16, false).into(),
1783                        Field::new("values", DataType::Int64, true).into(),
1784                    ),
1785                    true,
1786                )
1787                .into()],
1788                &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)),
1789            )
1790            .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect())
1791        }
1792
1793        // Casts REE to Int64 if we use TypeSignatureClass::Native
1794        let output = run_end_encoded_input(Coercion::new_exact(
1795            TypeSignatureClass::Native(logical_int64()),
1796        ))?;
1797        assert_eq!(vec![DataType::Int64], output);
1798
1799        let output = run_end_encoded_input(Coercion::new_implicit(
1800            TypeSignatureClass::Native(logical_int64()),
1801            vec![],
1802            NativeType::Int64,
1803        ))?;
1804        assert_eq!(vec![DataType::Int64], output);
1805
1806        // REE gets passed through if we use TypeSignatureClass apart from Native
1807        let output =
1808            run_end_encoded_input(Coercion::new_exact(TypeSignatureClass::Integer))?;
1809        assert_eq!(vec![run_end_encoded.clone()], output);
1810
1811        let output = run_end_encoded_input(Coercion::new_implicit(
1812            TypeSignatureClass::Integer,
1813            vec![],
1814            NativeType::Int64,
1815        ))?;
1816        assert_eq!(vec![run_end_encoded.clone()], output);
1817
1818        Ok(())
1819    }
1820
1821    #[test]
1822    fn test_get_valid_types_coercible_binary() -> Result<()> {
1823        let signature = Signature::coercible(
1824            vec![Coercion::new_exact(TypeSignatureClass::Native(
1825                logical_binary(),
1826            ))],
1827            Volatility::Immutable,
1828        );
1829
1830        // Binary types should stay their original selves
1831        for t in [
1832            DataType::Binary,
1833            DataType::BinaryView,
1834            DataType::LargeBinary,
1835        ] {
1836            assert_eq!(
1837                get_valid_types("", &signature.type_signature, std::slice::from_ref(&t))?,
1838                vec![vec![t]]
1839            );
1840        }
1841
1842        Ok(())
1843    }
1844
1845    #[test]
1846    fn test_get_valid_types_fixed_size_arrays() -> Result<()> {
1847        let function = "fixed_size_arrays";
1848        let signature = Signature::arrays(2, None, Volatility::Immutable);
1849
1850        let data_types = vec![
1851            DataType::new_fixed_size_list(DataType::Int64, 3, true),
1852            DataType::new_fixed_size_list(DataType::Int32, 5, true),
1853        ];
1854        assert_eq!(
1855            get_valid_types(function, &signature.type_signature, &data_types)?,
1856            vec![vec![
1857                DataType::new_fixed_size_list(DataType::Int64, 3, true),
1858                DataType::new_fixed_size_list(DataType::Int64, 5, true),
1859            ]]
1860        );
1861
1862        let data_types = vec![
1863            DataType::new_fixed_size_list(DataType::Int64, 3, true),
1864            DataType::new_list(DataType::Int32, true),
1865        ];
1866        assert_eq!(
1867            get_valid_types(function, &signature.type_signature, &data_types)?,
1868            vec![vec![
1869                DataType::new_list(DataType::Int64, true),
1870                DataType::new_list(DataType::Int64, true),
1871            ]]
1872        );
1873
1874        let data_types = vec![
1875            DataType::new_fixed_size_list(DataType::Utf8, 3, true),
1876            DataType::new_list(DataType::new_list(DataType::Int32, true), true),
1877        ];
1878        assert_eq!(
1879            get_valid_types(function, &signature.type_signature, &data_types)?,
1880            vec![vec![]]
1881        );
1882
1883        let data_types = vec![
1884            DataType::new_fixed_size_list(DataType::Int64, 3, false),
1885            DataType::new_list(DataType::Int32, false),
1886        ];
1887        assert_eq!(
1888            get_valid_types(function, &signature.type_signature, &data_types)?,
1889            vec![vec![
1890                DataType::new_list(DataType::Int64, false),
1891                DataType::new_list(DataType::Int64, false),
1892            ]]
1893        );
1894
1895        Ok(())
1896    }
1897
1898    #[derive(Debug, PartialEq, Eq, Hash)]
1899    struct MockHigherOrderUDF {
1900        signature: HigherOrderSignature,
1901        coerced_value_types: Vec<DataType>,
1902    }
1903
1904    impl HigherOrderUDFImpl for MockHigherOrderUDF {
1905        fn name(&self) -> &str {
1906            "mock_higher_order_function"
1907        }
1908
1909        fn signature(&self) -> &HigherOrderSignature {
1910            &self.signature
1911        }
1912
1913        fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1914            if arg_types.len() != 1 {
1915                return plan_err!(
1916                    "mock_higher_order_function expects 1 value arguments, got {}",
1917                    arg_types.len()
1918                );
1919            }
1920            Ok(self.coerced_value_types.clone())
1921        }
1922
1923        fn coerce_values_for_lambdas(
1924            &self,
1925            fields: &[ValueOrLambda<DataType, DataType>],
1926        ) -> Result<Option<Vec<DataType>>> {
1927            // thoerical impl of array_reduce without finish
1928            let [
1929                ValueOrLambda::Value(list),
1930                ValueOrLambda::Value(_initial),
1931                ValueOrLambda::Lambda(merge),
1932            ] = fields
1933            else {
1934                unreachable!()
1935            };
1936
1937            Ok(Some(vec![list.clone(), merge.clone()]))
1938        }
1939
1940        fn lambda_parameters(
1941            &self,
1942            _step: usize,
1943            _fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
1944        ) -> Result<crate::LambdaParametersProgress> {
1945            unimplemented!("mock_higher_order_function")
1946        }
1947
1948        fn return_field_from_args(
1949            &self,
1950            _args: HigherOrderReturnFieldArgs,
1951        ) -> Result<FieldRef> {
1952            unimplemented!("mock_higher_order_function")
1953        }
1954
1955        fn invoke_with_args(
1956            &self,
1957            _args: HigherOrderFunctionArgs,
1958        ) -> Result<ColumnarValue> {
1959            unimplemented!("mock_higher_order_function")
1960        }
1961    }
1962
1963    #[test]
1964    fn test_higher_order_function_user_defined_type_coercion() {
1965        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
1966            signature: HigherOrderSignature::user_defined(Volatility::Immutable),
1967            coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
1968        });
1969
1970        let new_fields = value_fields_with_higher_order_udf(
1971            &[
1972                ValueOrLambda::Value(Arc::new(Field::new_list(
1973                    "",
1974                    Field::new_list_field(DataType::Int32, false),
1975                    false,
1976                ))),
1977                ValueOrLambda::Lambda(()),
1978            ],
1979            &fun,
1980        )
1981        .unwrap();
1982
1983        // from List(Int32) to LargeList(Int32)
1984        assert_eq!(
1985            new_fields,
1986            vec![
1987                ValueOrLambda::Value(Arc::new(Field::new_large_list(
1988                    "",
1989                    Field::new_list_field(DataType::Int32, false),
1990                    false
1991                ))),
1992                ValueOrLambda::Lambda(()),
1993            ]
1994        )
1995    }
1996
1997    #[test]
1998    fn test_higher_order_function_coerce_values_for_lambdas() {
1999        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2000            signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
2001            coerced_value_types: vec![],
2002        });
2003
2004        let new_fields = value_fields_with_higher_order_udf_and_lambdas(
2005            &[
2006                ValueOrLambda::Value(Arc::new(Field::new_list(
2007                    "",
2008                    Field::new_list_field(DataType::Float32, true),
2009                    true,
2010                ))),
2011                ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))),
2012                ValueOrLambda::Lambda(Arc::new(Field::new("", DataType::Float32, true))),
2013            ],
2014            &fun,
2015        )
2016        .unwrap();
2017
2018        // second parameter from Int32 to Float32
2019        assert_eq!(
2020            new_fields,
2021            vec![
2022                ValueOrLambda::Value(Arc::new(Field::new_list(
2023                    "",
2024                    Field::new_list_field(DataType::Float32, true),
2025                    true,
2026                ))),
2027                ValueOrLambda::Value(Arc::new(Field::new("", DataType::Float32, true))),
2028                ValueOrLambda::Lambda(Arc::new(Field::new("", DataType::Float32, true))),
2029            ]
2030        )
2031    }
2032
2033    #[test]
2034    fn test_higher_order_function_user_defined_type_coercion_bad_args() {
2035        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2036            signature: HigherOrderSignature::user_defined(Volatility::Immutable),
2037            coerced_value_types: vec![DataType::Int32],
2038        });
2039
2040        let err = value_fields_with_higher_order_udf::<()>(&[], &fun).unwrap_err();
2041
2042        assert_contains!(
2043            err.to_string(),
2044            "mock_higher_order_function expects 1 value arguments, got 0"
2045        );
2046    }
2047
2048    #[test]
2049    fn test_higher_order_function_faulty_user_defined_type_coercion() {
2050        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2051            signature: HigherOrderSignature::user_defined(Volatility::Immutable),
2052            coerced_value_types: vec![DataType::Int32, DataType::Int32],
2053        });
2054
2055        let err = value_fields_with_higher_order_udf::<()>(
2056            &[ValueOrLambda::Value(Arc::new(Field::new(
2057                "",
2058                DataType::Int32,
2059                false,
2060            )))],
2061            &fun,
2062        )
2063        .unwrap_err();
2064
2065        assert_contains!(
2066            err.to_string(),
2067            "mock_higher_order_function coerce_value_types should have returned 1 items but returned 2"
2068        );
2069    }
2070
2071    #[test]
2072    fn test_higher_order_function_any_signature() {
2073        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2074            signature: HigherOrderSignature::any(1, Volatility::Immutable),
2075            coerced_value_types: vec![],
2076        });
2077
2078        let new_fields =
2079            value_fields_with_higher_order_udf(&[ValueOrLambda::Lambda(())], &fun)
2080                .unwrap();
2081
2082        // no coercion, just number of args checked
2083        assert_eq!(new_fields, vec![ValueOrLambda::Lambda(())])
2084    }
2085
2086    #[test]
2087    fn test_higher_order_function_any_signature_bad_args() {
2088        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2089            signature: HigherOrderSignature::any(1, Volatility::Immutable),
2090            coerced_value_types: vec![],
2091        });
2092
2093        let err = value_fields_with_higher_order_udf::<()>(&[], &fun).unwrap_err();
2094
2095        assert_contains!(
2096            err.to_string(),
2097            "The function 'mock_higher_order_function' expected 1 arguments but received 0"
2098        );
2099    }
2100
2101    #[test]
2102    fn test_higher_order_function_exact_signature() {
2103        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2104            signature: HigherOrderSignature::exact(
2105                vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2106                Volatility::Immutable,
2107            ),
2108            coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
2109        });
2110
2111        let new_fields = value_fields_with_higher_order_udf(
2112            &[
2113                ValueOrLambda::Value(Arc::new(Field::new_list(
2114                    "",
2115                    Field::new_list_field(DataType::Int32, false),
2116                    false,
2117                ))),
2118                ValueOrLambda::Lambda(()),
2119            ],
2120            &fun,
2121        )
2122        .unwrap();
2123
2124        // type coercion applied: List(Int32) -> LargeList(Int32)
2125        assert_eq!(
2126            new_fields,
2127            vec![
2128                ValueOrLambda::Value(Arc::new(Field::new_large_list(
2129                    "",
2130                    Field::new_list_field(DataType::Int32, false),
2131                    false
2132                ))),
2133                ValueOrLambda::Lambda(()),
2134            ]
2135        )
2136    }
2137
2138    #[test]
2139    fn test_higher_order_function_exact_signature_wrong_value_count() {
2140        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2141            signature: HigherOrderSignature::exact(
2142                vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2143                Volatility::Immutable,
2144            ),
2145            coerced_value_types: vec![],
2146        });
2147
2148        let err = value_fields_with_higher_order_udf::<()>(
2149            &[ValueOrLambda::Lambda(()), ValueOrLambda::Lambda(())],
2150            &fun,
2151        )
2152        .unwrap_err();
2153
2154        assert_contains!(
2155            err.to_string(),
2156            "expected a value at position 0 but received a lambda"
2157        );
2158    }
2159
2160    #[test]
2161    fn test_higher_order_function_exact_signature_wrong_lambda_count() {
2162        let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
2163            signature: HigherOrderSignature::exact(
2164                vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2165                Volatility::Immutable,
2166            ),
2167            coerced_value_types: vec![],
2168        });
2169
2170        let err = value_fields_with_higher_order_udf::<()>(
2171            &[
2172                ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
2173                ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
2174            ],
2175            &fun,
2176        )
2177        .unwrap_err();
2178
2179        assert_contains!(
2180            err.to_string(),
2181            "expected a lambda at position 1 but received a value"
2182        );
2183    }
2184}