datafusion_optimizer/analyzer/
type_coercion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Optimizer rule for type validation and coercion
19
20use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::izip;
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32    exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
33    DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34};
35use datafusion_expr::expr::{
36    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
37    InSubquery, Like, ScalarFunction, Sort, WindowFunction,
38};
39use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
40use datafusion_expr::expr_schema::cast_subquery;
41use datafusion_expr::logical_plan::Subquery;
42use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
43use datafusion_expr::type_coercion::functions::{
44    data_types_with_aggregate_udf, data_types_with_scalar_udf,
45};
46use datafusion_expr::type_coercion::other::{
47    get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52    is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
53    AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan,
54    Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound,
55    WindowFrameUnits,
56};
57
/// Analyzer rule that performs type coercion: for every plan node it builds the
/// schema of all fields visible to that node and rewrites the node's
/// expressions so that their types are compatible.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new [`TypeCoercion`] rule.
    pub fn new() -> Self {
        Self::default()
    }
}
68
69/// Coerce output schema based upon optimizer config.
70fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
71    if !config.optimizer.expand_views_at_output {
72        return Ok(plan);
73    }
74
75    let outer_refs = plan.expressions();
76    if outer_refs.is_empty() {
77        return Ok(plan);
78    }
79
80    if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
81        coerce_plan_expr_for_schema(plan, &dfschema?)
82    } else {
83        Ok(plan)
84    }
85}
86
87impl AnalyzerRule for TypeCoercion {
88    fn name(&self) -> &str {
89        "type_coercion"
90    }
91
92    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
93        let empty_schema = DFSchema::empty();
94
95        // recurse
96        let transformed_plan = plan
97            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
98            .data;
99
100        // finish
101        coerce_output(transformed_plan, config)
102    }
103}
104
105/// use the external schema to handle the correlated subqueries case
106///
107/// Assumes that children have already been optimized
108fn analyze_internal(
109    external_schema: &DFSchema,
110    plan: LogicalPlan,
111) -> Result<Transformed<LogicalPlan>> {
112    // get schema representing all available input fields. This is used for data type
113    // resolution only, so order does not matter here
114    let mut schema = merge_schema(&plan.inputs());
115
116    if let LogicalPlan::TableScan(ts) = &plan {
117        let source_schema = DFSchema::try_from_qualified_schema(
118            ts.table_name.clone(),
119            &ts.source.schema(),
120        )?;
121        schema.merge(&source_schema);
122    }
123
124    // merge the outer schema for correlated subqueries
125    // like case:
126    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
127    schema.merge(external_schema);
128
129    // Coerce filter predicates to boolean (handles `WHERE NULL`)
130    let plan = if let LogicalPlan::Filter(mut filter) = plan {
131        filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
132        LogicalPlan::Filter(filter)
133    } else {
134        plan
135    };
136
137    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
138
139    let name_preserver = NamePreserver::new(&plan);
140    // apply coercion rewrite all expressions in the plan individually
141    plan.map_expressions(|expr| {
142        let original_name = name_preserver.save(&expr);
143        expr.rewrite(&mut expr_rewrite)
144            .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
145    })?
146    // some plans need extra coercion after their expressions are coerced
147    .map_data(|plan| expr_rewrite.coerce_plan(plan))?
148    // recompute the schema after the expressions have been rewritten as the types may have changed
149    .map_data(|plan| plan.recompute_schema())
150}
151
152/// Rewrite expressions to apply type coercion.
153pub struct TypeCoercionRewriter<'a> {
154    pub(crate) schema: &'a DFSchema,
155}
156
157impl<'a> TypeCoercionRewriter<'a> {
158    /// Create a new [`TypeCoercionRewriter`] with a provided schema
159    /// representing both the inputs and output of the [`LogicalPlan`] node.
160    pub fn new(schema: &'a DFSchema) -> Self {
161        Self { schema }
162    }
163
164    /// Coerce the [`LogicalPlan`].
165    ///
166    /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`]
167    /// for type-coercion approach.
168    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
169        match plan {
170            LogicalPlan::Join(join) => self.coerce_join(join),
171            LogicalPlan::Union(union) => Self::coerce_union(union),
172            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
173            _ => Ok(plan),
174        }
175    }
176
177    /// Coerce join equality expressions and join filter
178    ///
179    /// Joins must be treated specially as their equality expressions are stored
180    /// as a parallel list of left and right expressions, rather than a single
181    /// equality expression
182    ///
183    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
184    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
185    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
186        join.on = join
187            .on
188            .into_iter()
189            .map(|(lhs, rhs)| {
190                // coerce the arguments as though they were a single binary equality
191                // expression
192                let left_schema = join.left.schema();
193                let right_schema = join.right.schema();
194                let (lhs, rhs) = self.coerce_binary_op(
195                    lhs,
196                    left_schema,
197                    Operator::Eq,
198                    rhs,
199                    right_schema,
200                )?;
201                Ok((lhs, rhs))
202            })
203            .collect::<Result<Vec<_>>>()?;
204
205        // Join filter must be boolean
206        join.filter = join
207            .filter
208            .map(|expr| self.coerce_join_filter(expr))
209            .transpose()?;
210
211        Ok(LogicalPlan::Join(join))
212    }
213
214    /// Coerce the union’s inputs to a common schema compatible with all inputs.
215    /// This occurs after wildcard expansion and the coercion of the input expressions.
216    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
217        let union_schema = Arc::new(coerce_union_schema_with_schema(
218            &union_plan.inputs,
219            &union_plan.schema,
220        )?);
221        let new_inputs = union_plan
222            .inputs
223            .into_iter()
224            .map(|p| {
225                let plan =
226                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
227                match plan {
228                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
229                        Ok(Arc::new(project_with_column_index(
230                            expr,
231                            input,
232                            Arc::clone(&union_schema),
233                        )?))
234                    }
235                    other_plan => Ok(Arc::new(other_plan)),
236                }
237            })
238            .collect::<Result<Vec<_>>>()?;
239        Ok(LogicalPlan::Union(Union {
240            inputs: new_inputs,
241            schema: union_schema,
242        }))
243    }
244
245    /// Coerce the fetch and skip expression to Int64 type.
246    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
247        fn coerce_limit_expr(
248            expr: Expr,
249            schema: &DFSchema,
250            expr_name: &str,
251        ) -> Result<Expr> {
252            let dt = expr.get_type(schema)?;
253            if dt.is_integer() || dt.is_null() {
254                expr.cast_to(&DataType::Int64, schema)
255            } else {
256                plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}")
257            }
258        }
259
260        let empty_schema = DFSchema::empty();
261        let new_fetch = limit
262            .fetch
263            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
264            .transpose()?;
265        let new_skip = limit
266            .skip
267            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
268            .transpose()?;
269        Ok(LogicalPlan::Limit(Limit {
270            input: limit.input,
271            fetch: new_fetch.map(Box::new),
272            skip: new_skip.map(Box::new),
273        }))
274    }
275
276    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
277        let expr_type = expr.get_type(self.schema)?;
278        match expr_type {
279            DataType::Boolean => Ok(expr),
280            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
281            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
282        }
283    }
284
285    fn coerce_binary_op(
286        &self,
287        left: Expr,
288        left_schema: &DFSchema,
289        op: Operator,
290        right: Expr,
291        right_schema: &DFSchema,
292    ) -> Result<(Expr, Expr)> {
293        let (left_type, right_type) = BinaryTypeCoercer::new(
294            &left.get_type(left_schema)?,
295            &op,
296            &right.get_type(right_schema)?,
297        )
298        .get_input_types()?;
299
300        Ok((
301            left.cast_to(&left_type, left_schema)?,
302            right.cast_to(&right_type, right_schema)?,
303        ))
304    }
305}
306
307impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
308    type Node = Expr;
309
310    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
311        match expr {
312            Expr::Unnest(_) => not_impl_err!(
313                "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
314            ),
315            Expr::ScalarSubquery(Subquery {
316                subquery,
317                outer_ref_columns,
318                spans,
319            }) => {
320                let new_plan =
321                    analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
322                Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
323                    subquery: Arc::new(new_plan),
324                    outer_ref_columns,
325                    spans,
326                })))
327            }
328            Expr::Exists(Exists { subquery, negated }) => {
329                let new_plan = analyze_internal(
330                    self.schema,
331                    Arc::unwrap_or_clone(subquery.subquery),
332                )?
333                .data;
334                Ok(Transformed::yes(Expr::Exists(Exists {
335                    subquery: Subquery {
336                        subquery: Arc::new(new_plan),
337                        outer_ref_columns: subquery.outer_ref_columns,
338                        spans: subquery.spans,
339                    },
340                    negated,
341                })))
342            }
343            Expr::InSubquery(InSubquery {
344                expr,
345                subquery,
346                negated,
347            }) => {
348                let new_plan = analyze_internal(
349                    self.schema,
350                    Arc::unwrap_or_clone(subquery.subquery),
351                )?
352                .data;
353                let expr_type = expr.get_type(self.schema)?;
354                let subquery_type = new_plan.schema().field(0).data_type();
355                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
356                        "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
357                    ),
358                )?;
359                let new_subquery = Subquery {
360                    subquery: Arc::new(new_plan),
361                    outer_ref_columns: subquery.outer_ref_columns,
362                    spans: subquery.spans,
363                };
364                Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
365                    Box::new(expr.cast_to(&common_type, self.schema)?),
366                    cast_subquery(new_subquery, &common_type)?,
367                    negated,
368                ))))
369            }
370            Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
371                *expr,
372                self.schema,
373            )?))),
374            Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
375                get_casted_expr_for_bool_op(*expr, self.schema)?,
376            ))),
377            Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
378                get_casted_expr_for_bool_op(*expr, self.schema)?,
379            ))),
380            Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
381                get_casted_expr_for_bool_op(*expr, self.schema)?,
382            ))),
383            Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
384                get_casted_expr_for_bool_op(*expr, self.schema)?,
385            ))),
386            Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
387                get_casted_expr_for_bool_op(*expr, self.schema)?,
388            ))),
389            Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
390                get_casted_expr_for_bool_op(*expr, self.schema)?,
391            ))),
392            Expr::Like(Like {
393                negated,
394                expr,
395                pattern,
396                escape_char,
397                case_insensitive,
398            }) => {
399                let left_type = expr.get_type(self.schema)?;
400                let right_type = pattern.get_type(self.schema)?;
401                let coerced_type = like_coercion(&left_type,  &right_type).ok_or_else(|| {
402                    let op_name = if case_insensitive {
403                        "ILIKE"
404                    } else {
405                        "LIKE"
406                    };
407                    plan_datafusion_err!(
408                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
409                    )
410                })?;
411                let expr = match left_type {
412                    DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
413                    _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
414                };
415                let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
416                Ok(Transformed::yes(Expr::Like(Like::new(
417                    negated,
418                    expr,
419                    pattern,
420                    escape_char,
421                    case_insensitive,
422                ))))
423            }
424            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
425                let (left, right) =
426                    self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
427                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
428                    Box::new(left),
429                    op,
430                    Box::new(right),
431                ))))
432            }
433            Expr::Between(Between {
434                expr,
435                negated,
436                low,
437                high,
438            }) => {
439                let expr_type = expr.get_type(self.schema)?;
440                let low_type = low.get_type(self.schema)?;
441                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
442                    .ok_or_else(|| {
443                        DataFusionError::Internal(format!(
444                            "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
445                        ))
446                    })?;
447                let high_type = high.get_type(self.schema)?;
448                let high_coerced_type = comparison_coercion(&expr_type, &high_type)
449                    .ok_or_else(|| {
450                        DataFusionError::Internal(format!(
451                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
452                        ))
453                    })?;
454                let coercion_type =
455                    comparison_coercion(&low_coerced_type, &high_coerced_type)
456                        .ok_or_else(|| {
457                            DataFusionError::Internal(format!(
458                                "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
459                            ))
460                        })?;
461                Ok(Transformed::yes(Expr::Between(Between::new(
462                    Box::new(expr.cast_to(&coercion_type, self.schema)?),
463                    negated,
464                    Box::new(low.cast_to(&coercion_type, self.schema)?),
465                    Box::new(high.cast_to(&coercion_type, self.schema)?),
466                ))))
467            }
468            Expr::InList(InList {
469                expr,
470                list,
471                negated,
472            }) => {
473                let expr_data_type = expr.get_type(self.schema)?;
474                let list_data_types = list
475                    .iter()
476                    .map(|list_expr| list_expr.get_type(self.schema))
477                    .collect::<Result<Vec<_>>>()?;
478                let result_type =
479                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
480                match result_type {
481                    None => plan_err!(
482                        "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}"
483                    ),
484                    Some(coerced_type) => {
485                        // find the coerced type
486                        let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
487                        let cast_list_expr = list
488                            .into_iter()
489                            .map(|list_expr| {
490                                list_expr.cast_to(&coerced_type, self.schema)
491                            })
492                            .collect::<Result<Vec<_>>>()?;
493                        Ok(Transformed::yes(Expr::InList(InList ::new(
494                             Box::new(cast_expr),
495                             cast_list_expr,
496                            negated,
497                        ))))
498                    }
499                }
500            }
501            Expr::Case(case) => {
502                let case = coerce_case_expression(case, self.schema)?;
503                Ok(Transformed::yes(Expr::Case(case)))
504            }
505            Expr::ScalarFunction(ScalarFunction { func, args }) => {
506                let new_expr = coerce_arguments_for_signature_with_scalar_udf(
507                    args,
508                    self.schema,
509                    &func,
510                )?;
511                Ok(Transformed::yes(Expr::ScalarFunction(
512                    ScalarFunction::new_udf(func, new_expr),
513                )))
514            }
515            Expr::AggregateFunction(expr::AggregateFunction {
516                func,
517                params:
518                    AggregateFunctionParams {
519                        args,
520                        distinct,
521                        filter,
522                        order_by,
523                        null_treatment,
524                    },
525            }) => {
526                let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
527                    args,
528                    self.schema,
529                    &func,
530                )?;
531                Ok(Transformed::yes(Expr::AggregateFunction(
532                    expr::AggregateFunction::new_udf(
533                        func,
534                        new_expr,
535                        distinct,
536                        filter,
537                        order_by,
538                        null_treatment,
539                    ),
540                )))
541            }
542            Expr::WindowFunction(WindowFunction {
543                fun,
544                params:
545                    expr::WindowFunctionParams {
546                        args,
547                        partition_by,
548                        order_by,
549                        window_frame,
550                        null_treatment,
551                    },
552            }) => {
553                let window_frame =
554                    coerce_window_frame(window_frame, self.schema, &order_by)?;
555
556                let args = match &fun {
557                    expr::WindowFunctionDefinition::AggregateUDF(udf) => {
558                        coerce_arguments_for_signature_with_aggregate_udf(
559                            args,
560                            self.schema,
561                            udf,
562                        )?
563                    }
564                    _ => args,
565                };
566
567                Ok(Transformed::yes(
568                    Expr::WindowFunction(WindowFunction::new(fun, args))
569                        .partition_by(partition_by)
570                        .order_by(order_by)
571                        .window_frame(window_frame)
572                        .null_treatment(null_treatment)
573                        .build()?,
574                ))
575            }
576            // TODO: remove the next line after `Expr::Wildcard` is removed
577            #[expect(deprecated)]
578            Expr::Alias(_)
579            | Expr::Column(_)
580            | Expr::ScalarVariable(_, _)
581            | Expr::Literal(_)
582            | Expr::SimilarTo(_)
583            | Expr::IsNotNull(_)
584            | Expr::IsNull(_)
585            | Expr::Negative(_)
586            | Expr::Cast(_)
587            | Expr::TryCast(_)
588            | Expr::Wildcard { .. }
589            | Expr::GroupingSet(_)
590            | Expr::Placeholder(_)
591            | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
592        }
593    }
594}
595
596/// Transform a schema to use non-view types for Utf8View and BinaryView
597fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
598    let metadata = dfschema.as_arrow().metadata.clone();
599    let mut transformed = false;
600
601    let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
602        dfschema
603            .iter()
604            .map(|(qualifier, field)| match field.data_type() {
605                DataType::Utf8View => {
606                    transformed = true;
607                    (
608                        qualifier.cloned() as Option<TableReference>,
609                        Arc::new(Field::new(
610                            field.name(),
611                            DataType::LargeUtf8,
612                            field.is_nullable(),
613                        )),
614                    )
615                }
616                DataType::BinaryView => {
617                    transformed = true;
618                    (
619                        qualifier.cloned() as Option<TableReference>,
620                        Arc::new(Field::new(
621                            field.name(),
622                            DataType::LargeBinary,
623                            field.is_nullable(),
624                        )),
625                    )
626                }
627                _ => (
628                    qualifier.cloned() as Option<TableReference>,
629                    Arc::clone(field),
630                ),
631            })
632            .unzip();
633
634    if !transformed {
635        return None;
636    }
637
638    let schema = Schema::new_with_metadata(transformed_fields, metadata);
639    Some(DFSchema::from_field_specific_qualified_schema(
640        qualifiers,
641        &Arc::new(schema),
642    ))
643}
644
645/// Casts the given `value` to `target_type`. Note that this function
646/// only considers `Null` or `Utf8` values.
647fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
648    match value {
649        // Coerce Utf8 values:
650        ScalarValue::Utf8(Some(val)) => {
651            ScalarValue::try_from_string(val.clone(), target_type)
652        }
653        s => {
654            if s.is_null() {
655                // Coerce `Null` values:
656                ScalarValue::try_from(target_type)
657            } else {
658                // Values except `Utf8`/`Null` variants already have the right type
659                // (casted before) since we convert `sqlparser` outputs to `Utf8`
660                // for all possible cases. Therefore, we return a clone here.
661                Ok(s.clone())
662            }
663        }
664    }
665}
666
667/// This function coerces `value` to `target_type` in a range-aware fashion.
668/// If the coercion is successful, we return an `Ok` value with the result.
669/// If the coercion fails because `target_type` is not wide enough (i.e. we
670/// can not coerce to `target_type`, but we can to a wider type in the same
671/// family), we return a `Null` value of this type to signal this situation.
672/// Downstream code uses this signal to treat these values as *unbounded*.
673fn coerce_scalar_range_aware(
674    target_type: &DataType,
675    value: &ScalarValue,
676) -> Result<ScalarValue> {
677    coerce_scalar(target_type, value).or_else(|err| {
678        // If type coercion fails, check if the largest type in family works:
679        if let Some(largest_type) = get_widest_type_in_family(target_type) {
680            coerce_scalar(largest_type, value).map_or_else(
681                |_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
682                |_| ScalarValue::try_from(target_type),
683            )
684        } else {
685            Err(err)
686        }
687    })
688}
689
690/// This function returns the widest type in the family of `given_type`.
691/// If the given type is already the widest type, it returns `None`.
692/// For example, if `given_type` is `Int8`, it returns `Int64`.
693fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
694    match given_type {
695        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
696        DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
697        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
698        _ => None,
699    }
700}
701
702/// Coerces the given (window frame) `bound` to `target_type`.
703fn coerce_frame_bound(
704    target_type: &DataType,
705    bound: WindowFrameBound,
706) -> Result<WindowFrameBound> {
707    match bound {
708        WindowFrameBound::Preceding(v) => {
709            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
710        }
711        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
712        WindowFrameBound::Following(v) => {
713            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
714        }
715    }
716}
717
718fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
719    if col_type.is_numeric()
720        || is_utf8_or_utf8view_or_large_utf8(col_type)
721        || matches!(col_type, DataType::Null)
722        || matches!(col_type, DataType::Boolean)
723    {
724        Ok(col_type.clone())
725    } else if is_datetime(col_type) {
726        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
727    } else if let DataType::Dictionary(_, value_type) = col_type {
728        extract_window_frame_target_type(value_type)
729    } else {
730        return internal_err!("Cannot run range queries on datatype: {col_type:?}");
731    }
732}
733
734// Coerces the given `window_frame` to use appropriate natural types.
735// For example, ROWS and GROUPS frames use `UInt64` during calculations.
736fn coerce_window_frame(
737    window_frame: WindowFrame,
738    schema: &DFSchema,
739    expressions: &[Sort],
740) -> Result<WindowFrame> {
741    let mut window_frame = window_frame;
742    let target_type = match window_frame.units {
743        WindowFrameUnits::Range => {
744            let current_types = expressions
745                .first()
746                .map(|s| s.expr.get_type(schema))
747                .transpose()?;
748            if let Some(col_type) = current_types {
749                extract_window_frame_target_type(&col_type)?
750            } else {
751                return internal_err!("ORDER BY column cannot be empty");
752            }
753        }
754        WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
755    };
756    window_frame.start_bound =
757        coerce_frame_bound(&target_type, window_frame.start_bound)?;
758    window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
759    Ok(window_frame)
760}
761
762// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
763// The above op will be rewrite to the binary op when creating the physical op.
764fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
765    let left_type = expr.get_type(schema)?;
766    BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
767        .get_input_types()?;
768    expr.cast_to(&DataType::Boolean, schema)
769}
770
771/// Returns `expressions` coerced to types compatible with
772/// `signature`, if possible.
773///
774/// See the module level documentation for more detail on coercion.
775fn coerce_arguments_for_signature_with_scalar_udf(
776    expressions: Vec<Expr>,
777    schema: &DFSchema,
778    func: &ScalarUDF,
779) -> Result<Vec<Expr>> {
780    if expressions.is_empty() {
781        return Ok(expressions);
782    }
783
784    let current_types = expressions
785        .iter()
786        .map(|e| e.get_type(schema))
787        .collect::<Result<Vec<_>>>()?;
788
789    let new_types = data_types_with_scalar_udf(&current_types, func)?;
790
791    expressions
792        .into_iter()
793        .enumerate()
794        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
795        .collect()
796}
797
798/// Returns `expressions` coerced to types compatible with
799/// `signature`, if possible.
800///
801/// See the module level documentation for more detail on coercion.
802fn coerce_arguments_for_signature_with_aggregate_udf(
803    expressions: Vec<Expr>,
804    schema: &DFSchema,
805    func: &AggregateUDF,
806) -> Result<Vec<Expr>> {
807    if expressions.is_empty() {
808        return Ok(expressions);
809    }
810
811    let current_types = expressions
812        .iter()
813        .map(|e| e.get_type(schema))
814        .collect::<Result<Vec<_>>>()?;
815
816    let new_types = data_types_with_aggregate_udf(&current_types, func)?;
817
818    expressions
819        .into_iter()
820        .enumerate()
821        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
822        .collect()
823}
824
825fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
826    // Given expressions like:
827    //
828    // CASE a1
829    //   WHEN a2 THEN b1
830    //   WHEN a3 THEN b2
831    //   ELSE b3
832    // END
833    //
834    // or:
835    //
836    // CASE
837    //   WHEN x1 THEN b1
838    //   WHEN x2 THEN b2
839    //   ELSE b3
840    // END
841    //
842    // Then all aN (a1, a2, a3) must be converted to a common data type in the first example
843    // (case-when expression coercion)
844    //
845    // All xN (x1, x2) must be converted to a boolean data type in the second example
846    // (when-boolean expression coercion)
847    //
848    // And all bN (b1, b2, b3) must be converted to a common data type in both examples
849    // (then-else expression coercion)
850    //
851    // If any fail to find and cast to a common/specific data type, will return error
852    //
853    // Note that case-when and when-boolean expression coercions are mutually exclusive
854    // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur
855
856    // prepare types
857    let case_type = case
858        .expr
859        .as_ref()
860        .map(|expr| expr.get_type(schema))
861        .transpose()?;
862    let then_types = case
863        .when_then_expr
864        .iter()
865        .map(|(_when, then)| then.get_type(schema))
866        .collect::<Result<Vec<_>>>()?;
867    let else_type = case
868        .else_expr
869        .as_ref()
870        .map(|expr| expr.get_type(schema))
871        .transpose()?;
872
873    // find common coercible types
874    let case_when_coerce_type = case_type
875        .as_ref()
876        .map(|case_type| {
877            let when_types = case
878                .when_then_expr
879                .iter()
880                .map(|(when, _then)| when.get_type(schema))
881                .collect::<Result<Vec<_>>>()?;
882            let coerced_type =
883                get_coerce_type_for_case_expression(&when_types, Some(case_type));
884            coerced_type.ok_or_else(|| {
885                plan_datafusion_err!(
886                    "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \
887                     to common types in CASE WHEN expression"
888                )
889            })
890        })
891        .transpose()?;
892    let then_else_coerce_type =
893        get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
894            || {
895                plan_datafusion_err!(
896                    "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \
897                     to common types in CASE WHEN expression"
898                )
899            },
900        )?;
901
902    // do cast if found common coercible types
903    let case_expr = case
904        .expr
905        .zip(case_when_coerce_type.as_ref())
906        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
907        .transpose()?
908        .map(Box::new);
909    let when_then = case
910        .when_then_expr
911        .into_iter()
912        .map(|(when, then)| {
913            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
914            let when = when.cast_to(when_type, schema).map_err(|e| {
915                DataFusionError::Context(
916                    format!(
917                        "WHEN expressions in CASE couldn't be \
918                         converted to common type ({when_type})"
919                    ),
920                    Box::new(e),
921                )
922            })?;
923            let then = then.cast_to(&then_else_coerce_type, schema)?;
924            Ok((Box::new(when), Box::new(then)))
925        })
926        .collect::<Result<Vec<_>>>()?;
927    let else_expr = case
928        .else_expr
929        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
930        .transpose()?
931        .map(Box::new);
932
933    Ok(Case::new(case_expr, when_then, else_expr))
934}
935
936/// Get a common schema that is compatible with all inputs of UNION.
937///
938/// This method presumes that the wildcard expansion is unneeded, or has already
939/// been applied.
940pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
941    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
942}
943fn coerce_union_schema_with_schema(
944    inputs: &[Arc<LogicalPlan>],
945    base_schema: &DFSchemaRef,
946) -> Result<DFSchema> {
947    let mut union_datatypes = base_schema
948        .fields()
949        .iter()
950        .map(|f| f.data_type().clone())
951        .collect::<Vec<_>>();
952    let mut union_nullabilities = base_schema
953        .fields()
954        .iter()
955        .map(|f| f.is_nullable())
956        .collect::<Vec<_>>();
957    let mut union_field_meta = base_schema
958        .fields()
959        .iter()
960        .map(|f| f.metadata().clone())
961        .collect::<Vec<_>>();
962
963    let mut metadata = base_schema.metadata().clone();
964
965    for (i, plan) in inputs.iter().enumerate() {
966        let plan_schema = plan.schema();
967        metadata.extend(plan_schema.metadata().clone());
968
969        if plan_schema.fields().len() != base_schema.fields().len() {
970            return plan_err!(
971                "Union schemas have different number of fields: \
972                query 1 has {} fields whereas query {} has {} fields",
973                base_schema.fields().len(),
974                i + 1,
975                plan_schema.fields().len()
976            );
977        }
978
979        // coerce data type and nullability for each field
980        for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
981            union_datatypes.iter_mut(),
982            union_nullabilities.iter_mut(),
983            union_field_meta.iter_mut(),
984            plan_schema.fields().iter()
985        ) {
986            let coerced_type =
987                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
988                    || {
989                        plan_datafusion_err!(
990                            "Incompatible inputs for Union: Previous inputs were \
991                            of type {}, but got incompatible type {} on column '{}'",
992                            union_datatype,
993                            plan_field.data_type(),
994                            plan_field.name()
995                        )
996                    },
997                )?;
998
999            *union_datatype = coerced_type;
1000            *union_nullable = *union_nullable || plan_field.is_nullable();
1001            union_field_map.extend(plan_field.metadata().clone());
1002        }
1003    }
1004    let union_qualified_fields = izip!(
1005        base_schema.fields(),
1006        union_datatypes.into_iter(),
1007        union_nullabilities,
1008        union_field_meta.into_iter()
1009    )
1010    .map(|(field, datatype, nullable, metadata)| {
1011        let mut field = Field::new(field.name().clone(), datatype, nullable);
1012        field.set_metadata(metadata);
1013        (None, field.into())
1014    })
1015    .collect::<Vec<_>>();
1016
1017    DFSchema::new_with_metadata(union_qualified_fields, metadata)
1018}
1019
1020/// See `<https://github.com/apache/datafusion/pull/2108>`
1021fn project_with_column_index(
1022    expr: Vec<Expr>,
1023    input: Arc<LogicalPlan>,
1024    schema: DFSchemaRef,
1025) -> Result<LogicalPlan> {
1026    let alias_expr = expr
1027        .into_iter()
1028        .enumerate()
1029        .map(|(i, e)| match e {
1030            Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1031                Ok(e.unalias().alias(schema.field(i).name()))
1032            }
1033            Expr::Column(Column {
1034                relation: _,
1035                ref name,
1036                spans: _,
1037            }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1038            Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1039            #[expect(deprecated)]
1040            Expr::Wildcard { .. } => {
1041                plan_err!("Wildcard should be expanded before type coercion")
1042            }
1043            _ => Ok(e.alias(schema.field(i).name())),
1044        })
1045        .collect::<Result<Vec<_>>>()?;
1046
1047    Projection::try_new_with_schema(alias_expr, input, schema)
1048        .map(LogicalPlan::Projection)
1049}
1050
1051#[cfg(test)]
1052mod test {
1053    use std::any::Any;
1054    use std::sync::Arc;
1055
1056    use arrow::datatypes::DataType::Utf8;
1057    use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1058
1059    use crate::analyzer::type_coercion::{
1060        coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1061    };
1062    use crate::analyzer::Analyzer;
1063    use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq};
1064    use datafusion_common::config::ConfigOptions;
1065    use datafusion_common::tree_node::{TransformedResult, TreeNode};
1066    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1067    use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1068    use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1069    use datafusion_expr::test::function_stub::avg_udaf;
1070    use datafusion_expr::{
1071        cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1072        BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1073        Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1074        SimpleAggregateUDF, Subquery, Union, Volatility,
1075    };
1076    use datafusion_functions_aggregate::average::AvgAccumulator;
1077    use datafusion_sql::TableReference;
1078
1079    fn empty() -> Arc<LogicalPlan> {
1080        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1081            produce_one_row: false,
1082            schema: Arc::new(DFSchema::empty()),
1083        }))
1084    }
1085
1086    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1087        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1088            produce_one_row: false,
1089            schema: Arc::new(
1090                DFSchema::from_unqualified_fields(
1091                    vec![Field::new("a", data_type, true)].into(),
1092                    std::collections::HashMap::new(),
1093                )
1094                .unwrap(),
1095            ),
1096        }))
1097    }
1098
1099    #[test]
1100    fn simple_case() -> Result<()> {
1101        let expr = col("a").lt(lit(2_u32));
1102        let empty = empty_with_type(DataType::Float64);
1103        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1104        let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n  EmptyRelation";
1105        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1106    }
1107
1108    #[test]
1109    fn test_coerce_union() -> Result<()> {
1110        let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1111            produce_one_row: false,
1112            schema: Arc::new(
1113                DFSchema::try_from_qualified_schema(
1114                    TableReference::full("datafusion", "test", "foo"),
1115                    &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1116                )
1117                .unwrap(),
1118            ),
1119        }));
1120        let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1121            produce_one_row: false,
1122            schema: Arc::new(
1123                DFSchema::try_from_qualified_schema(
1124                    TableReference::full("datafusion", "test", "foo"),
1125                    &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1126                )
1127                .unwrap(),
1128            ),
1129        }));
1130        let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1131            left_plan, right_plan,
1132        ])?);
1133        let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1134            .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1135        let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1136            vec![col("a")],
1137            Arc::new(analyzed_union),
1138        )?);
1139
1140        let expected = "Projection: a\n  Union\n    Projection: CAST(datafusion.test.foo.a AS Int64) AS a\n      EmptyRelation\n    EmptyRelation";
1141        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), top_level_plan, expected)
1142    }
1143
1144    fn coerce_on_output_if_viewtype(plan: LogicalPlan, expected: &str) -> Result<()> {
1145        let mut options = ConfigOptions::default();
1146        options.optimizer.expand_views_at_output = true;
1147
1148        assert_analyzed_plan_with_config_eq(
1149            options,
1150            Arc::new(TypeCoercion::new()),
1151            plan.clone(),
1152            expected,
1153        )
1154    }
1155
1156    fn do_not_coerce_on_output(plan: LogicalPlan, expected: &str) -> Result<()> {
1157        assert_analyzed_plan_with_config_eq(
1158            ConfigOptions::default(),
1159            Arc::new(TypeCoercion::new()),
1160            plan.clone(),
1161            expected,
1162        )
1163    }
1164
1165    #[test]
1166    fn coerce_utf8view_output() -> Result<()> {
1167        // Plan A
1168        // scenario: outermost utf8view projection
1169        let expr = col("a");
1170        let empty = empty_with_type(DataType::Utf8View);
1171        let plan = LogicalPlan::Projection(Projection::try_new(
1172            vec![expr.clone()],
1173            Arc::clone(&empty),
1174        )?);
1175        // Plan A: no coerce
1176        let if_not_coerced = "Projection: a\n  EmptyRelation";
1177        do_not_coerce_on_output(plan.clone(), if_not_coerced)?;
1178        // Plan A: coerce requested: Utf8View => LargeUtf8
1179        let if_coerced = "Projection: CAST(a AS LargeUtf8)\n  EmptyRelation";
1180        coerce_on_output_if_viewtype(plan.clone(), if_coerced)?;
1181
1182        // Plan B
1183        // scenario: outermost bool projection
1184        let bool_expr = col("a").lt(lit("foo"));
1185        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1186            vec![bool_expr],
1187            Arc::clone(&empty),
1188        )?);
1189        // Plan B: no coerce
1190        let if_not_coerced =
1191            "Projection: a < CAST(Utf8(\"foo\") AS Utf8View)\n  EmptyRelation";
1192        do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?;
1193        // Plan B: coerce requested: no coercion applied
1194        let if_coerced = if_not_coerced;
1195        coerce_on_output_if_viewtype(bool_plan, if_coerced)?;
1196
1197        // Plan C
1198        // scenario: with a non-projection root logical plan node
1199        let sort_expr = expr.sort(true, true);
1200        let sort_plan = LogicalPlan::Sort(Sort {
1201            expr: vec![sort_expr],
1202            input: Arc::new(plan),
1203            fetch: None,
1204        });
1205        // Plan C: no coerce
1206        let if_not_coerced =
1207            "Sort: a ASC NULLS FIRST\n  Projection: a\n    EmptyRelation";
1208        do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?;
1209        // Plan C: coerce requested: Utf8View => LargeUtf8
1210        let if_coerced = "Projection: CAST(a AS LargeUtf8)\n  Sort: a ASC NULLS FIRST\n    Projection: a\n      EmptyRelation";
1211        coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?;
1212
1213        // Plan D
1214        // scenario: two layers of projections with view types
1215        let plan = LogicalPlan::Projection(Projection::try_new(
1216            vec![col("a")],
1217            Arc::new(sort_plan),
1218        )?);
1219        // Plan D: no coerce
1220        let if_not_coerced = "Projection: a\n  Sort: a ASC NULLS FIRST\n    Projection: a\n      EmptyRelation";
1221        do_not_coerce_on_output(plan.clone(), if_not_coerced)?;
1222        // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost
1223        let if_coerced = "Projection: CAST(a AS LargeUtf8)\n  Sort: a ASC NULLS FIRST\n    Projection: a\n      EmptyRelation";
1224        coerce_on_output_if_viewtype(plan.clone(), if_coerced)?;
1225
1226        Ok(())
1227    }
1228
1229    #[test]
1230    fn coerce_binaryview_output() -> Result<()> {
1231        // Plan A
1232        // scenario: outermost binaryview projection
1233        let expr = col("a");
1234        let empty = empty_with_type(DataType::BinaryView);
1235        let plan = LogicalPlan::Projection(Projection::try_new(
1236            vec![expr.clone()],
1237            Arc::clone(&empty),
1238        )?);
1239        // Plan A: no coerce
1240        let if_not_coerced = "Projection: a\n  EmptyRelation";
1241        do_not_coerce_on_output(plan.clone(), if_not_coerced)?;
1242        // Plan A: coerce requested: BinaryView => LargeBinary
1243        let if_coerced = "Projection: CAST(a AS LargeBinary)\n  EmptyRelation";
1244        coerce_on_output_if_viewtype(plan.clone(), if_coerced)?;
1245
1246        // Plan B
1247        // scenario: outermost bool projection
1248        let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1249        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1250            vec![bool_expr],
1251            Arc::clone(&empty),
1252        )?);
1253        // Plan B: no coerce
1254        let if_not_coerced =
1255            "Projection: a < CAST(Binary(\"8,1,8,1\") AS BinaryView)\n  EmptyRelation";
1256        do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?;
1257        // Plan B: coerce requested: no coercion applied
1258        let if_coerced = if_not_coerced;
1259        coerce_on_output_if_viewtype(bool_plan, if_coerced)?;
1260
1261        // Plan C
1262        // scenario: with a non-projection root logical plan node
1263        let sort_expr = expr.sort(true, true);
1264        let sort_plan = LogicalPlan::Sort(Sort {
1265            expr: vec![sort_expr],
1266            input: Arc::new(plan),
1267            fetch: None,
1268        });
1269        // Plan C: no coerce
1270        let if_not_coerced =
1271            "Sort: a ASC NULLS FIRST\n  Projection: a\n    EmptyRelation";
1272        do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?;
1273        // Plan C: coerce requested: BinaryView => LargeBinary
1274        let if_coerced = "Projection: CAST(a AS LargeBinary)\n  Sort: a ASC NULLS FIRST\n    Projection: a\n      EmptyRelation";
1275        coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?;
1276
1277        // Plan D
1278        // scenario: two layers of projections with view types
1279        let plan = LogicalPlan::Projection(Projection::try_new(
1280            vec![col("a")],
1281            Arc::new(sort_plan),
1282        )?);
1283        // Plan D: no coerce
1284        let if_not_coerced = "Projection: a\n  Sort: a ASC NULLS FIRST\n    Projection: a\n      EmptyRelation";
1285        do_not_coerce_on_output(plan.clone(), if_not_coerced)?;
1286        // Plan B: coerce requested: BinaryView => LargeBinary only on outermost
1287        let if_coerced = "Projection: CAST(a AS LargeBinary)\n  Sort: a ASC NULLS FIRST\n    Projection: a\n      EmptyRelation";
1288        coerce_on_output_if_viewtype(plan.clone(), if_coerced)?;
1289
1290        Ok(())
1291    }
1292
1293    #[test]
1294    fn nested_case() -> Result<()> {
1295        let expr = col("a").lt(lit(2_u32));
1296        let empty = empty_with_type(DataType::Float64);
1297
1298        let plan = LogicalPlan::Projection(Projection::try_new(
1299            vec![expr.clone().or(expr)],
1300            empty,
1301        )?);
1302        let expected = "Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)\
1303            \n  EmptyRelation";
1304        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1305    }
1306
1307    #[derive(Debug, Clone)]
1308    struct TestScalarUDF {
1309        signature: Signature,
1310    }
1311
1312    impl ScalarUDFImpl for TestScalarUDF {
1313        fn as_any(&self) -> &dyn Any {
1314            self
1315        }
1316
1317        fn name(&self) -> &str {
1318            "TestScalarUDF"
1319        }
1320
1321        fn signature(&self) -> &Signature {
1322            &self.signature
1323        }
1324
1325        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1326            Ok(Utf8)
1327        }
1328
1329        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1330            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1331        }
1332    }
1333
1334    #[test]
1335    fn scalar_udf() -> Result<()> {
1336        let empty = empty();
1337
1338        let udf = ScalarUDF::from(TestScalarUDF {
1339            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1340        })
1341        .call(vec![lit(123_i32)]);
1342        let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1343        let expected =
1344            "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n  EmptyRelation";
1345        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1346    }
1347
1348    #[test]
1349    fn scalar_udf_invalid_input() -> Result<()> {
1350        let empty = empty();
1351        let udf = ScalarUDF::from(TestScalarUDF {
1352            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1353        })
1354        .call(vec![lit("Apple")]);
1355        Projection::try_new(vec![udf], empty)
1356            .expect_err("Expected an error due to incorrect function input");
1357
1358        Ok(())
1359    }
1360
1361    #[test]
1362    fn scalar_function() -> Result<()> {
1363        // test that automatic argument type coercion for scalar functions work
1364        let empty = empty();
1365        let lit_expr = lit(10i64);
1366        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1367            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1368        });
1369        let scalar_function_expr =
1370            Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1371        let plan = LogicalPlan::Projection(Projection::try_new(
1372            vec![scalar_function_expr],
1373            empty,
1374        )?);
1375        let expected =
1376            "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n  EmptyRelation";
1377        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1378    }
1379
1380    #[test]
1381    fn agg_udaf() -> Result<()> {
1382        let empty = empty();
1383        let my_avg = create_udaf(
1384            "MY_AVG",
1385            vec![DataType::Float64],
1386            Arc::new(DataType::Float64),
1387            Volatility::Immutable,
1388            Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1389            Arc::new(vec![DataType::UInt64, DataType::Float64]),
1390        );
1391        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1392            Arc::new(my_avg),
1393            vec![lit(10i64)],
1394            false,
1395            None,
1396            None,
1397            None,
1398        ));
1399        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1400        let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n  EmptyRelation";
1401        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1402    }
1403
1404    #[test]
1405    fn agg_udaf_invalid_input() -> Result<()> {
1406        let empty = empty();
1407        let return_type = DataType::Float64;
1408        let accumulator: AccumulatorFactoryFunction =
1409            Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1410        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1411            "MY_AVG",
1412            Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1413            return_type,
1414            accumulator,
1415            vec![
1416                Field::new("count", DataType::UInt64, true),
1417                Field::new("avg", DataType::Float64, true),
1418            ],
1419        ));
1420        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1421            Arc::new(my_avg),
1422            vec![lit("10")],
1423            false,
1424            None,
1425            None,
1426            None,
1427        ));
1428
1429        let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1430        assert!(
1431            err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed")
1432        );
1433        Ok(())
1434    }
1435
1436    #[test]
1437    fn agg_function_case() -> Result<()> {
1438        let empty = empty();
1439        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1440            avg_udaf(),
1441            vec![lit(12f64)],
1442            false,
1443            None,
1444            None,
1445            None,
1446        ));
1447        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1448        let expected = "Projection: avg(Float64(12))\n  EmptyRelation";
1449        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1450
1451        let empty = empty_with_type(DataType::Int32);
1452        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1453            avg_udaf(),
1454            vec![cast(col("a"), DataType::Float64)],
1455            false,
1456            None,
1457            None,
1458            None,
1459        ));
1460        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1461        let expected = "Projection: avg(CAST(a AS Float64))\n  EmptyRelation";
1462        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1463        Ok(())
1464    }
1465
1466    #[test]
1467    fn agg_function_invalid_input_avg() -> Result<()> {
1468        let empty = empty();
1469        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1470            avg_udaf(),
1471            vec![lit("1")],
1472            false,
1473            None,
1474            None,
1475            None,
1476        ));
1477        let err = Projection::try_new(vec![agg_expr], empty)
1478            .err()
1479            .unwrap()
1480            .strip_backtrace();
1481        assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1482        Ok(())
1483    }
1484
1485    #[test]
1486    fn binary_op_date32_op_interval() -> Result<()> {
1487        // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
1488        let expr = cast(lit("1998-03-18"), DataType::Date32)
1489            + lit(ScalarValue::new_interval_dt(123, 456));
1490        let empty = empty();
1491        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1492        let expected =
1493            "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n  EmptyRelation";
1494        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1495        Ok(())
1496    }
1497
1498    #[test]
1499    fn inlist_case() -> Result<()> {
1500        // a in (1,4,8), a is int64
1501        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1502        let empty = empty_with_type(DataType::Int64);
1503        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1504        let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n  EmptyRelation";
1505        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1506
1507        // a in (1,4,8), a is decimal
1508        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1509        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1510            produce_one_row: false,
1511            schema: Arc::new(DFSchema::from_unqualified_fields(
1512                vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1513                std::collections::HashMap::new(),
1514            )?),
1515        }));
1516        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1517        let expected = "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n  EmptyRelation";
1518        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1519    }
1520
1521    #[test]
1522    fn between_case() -> Result<()> {
1523        let expr = col("a").between(
1524            lit("2002-05-08"),
1525            // (cast('2002-05-08' as date) + interval '1 months')
1526            cast(lit("2002-05-08"), DataType::Date32)
1527                + lit(ScalarValue::new_interval_ym(0, 1)),
1528        );
1529        let empty = empty_with_type(Utf8);
1530        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1531        let expected =
1532            "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) AND CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\")\
1533            \n  EmptyRelation";
1534        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1535    }
1536
1537    #[test]
1538    fn between_infer_cheap_type() -> Result<()> {
1539        let expr = col("a").between(
1540            // (cast('2002-05-08' as date) + interval '1 months')
1541            cast(lit("2002-05-08"), DataType::Date32)
1542                + lit(ScalarValue::new_interval_ym(0, 1)),
1543            lit("2002-12-08"),
1544        );
1545        let empty = empty_with_type(Utf8);
1546        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1547        // TODO: we should cast col(a).
1548        let expected =
1549            "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\
1550            \n  EmptyRelation";
1551        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1552    }
1553
1554    #[test]
1555    fn between_null() -> Result<()> {
1556        let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1557        let empty = empty();
1558        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1559        let expected =
1560            "Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)\
1561            \n  EmptyRelation";
1562        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
1563    }
1564
1565    #[test]
1566    fn is_bool_for_type_coercion() -> Result<()> {
1567        // is true
1568        let expr = col("a").is_true();
1569        let empty = empty_with_type(DataType::Boolean);
1570        let plan =
1571            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1572        let expected = "Projection: a IS TRUE\n  EmptyRelation";
1573        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1574
1575        let empty = empty_with_type(DataType::Int64);
1576        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1577        let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, "");
1578        let err = ret.unwrap_err().to_string();
1579        assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}");
1580
1581        // is not true
1582        let expr = col("a").is_not_true();
1583        let empty = empty_with_type(DataType::Boolean);
1584        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1585        let expected = "Projection: a IS NOT TRUE\n  EmptyRelation";
1586        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1587
1588        // is false
1589        let expr = col("a").is_false();
1590        let empty = empty_with_type(DataType::Boolean);
1591        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1592        let expected = "Projection: a IS FALSE\n  EmptyRelation";
1593        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1594
1595        // is not false
1596        let expr = col("a").is_not_false();
1597        let empty = empty_with_type(DataType::Boolean);
1598        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1599        let expected = "Projection: a IS NOT FALSE\n  EmptyRelation";
1600        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1601
1602        Ok(())
1603    }
1604
1605    #[test]
1606    fn like_for_type_coercion() -> Result<()> {
1607        // like : utf8 like "abc"
1608        let expr = Box::new(col("a"));
1609        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1610        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1611        let empty = empty_with_type(Utf8);
1612        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1613        let expected = "Projection: a LIKE Utf8(\"abc\")\n  EmptyRelation";
1614        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1615
1616        let expr = Box::new(col("a"));
1617        let pattern = Box::new(lit(ScalarValue::Null));
1618        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1619        let empty = empty_with_type(Utf8);
1620        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1621        let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n  EmptyRelation";
1622        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1623
1624        let expr = Box::new(col("a"));
1625        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1626        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1627        let empty = empty_with_type(DataType::Int64);
1628        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1629        let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected);
1630        assert!(err.is_err());
1631        assert!(err.unwrap_err().to_string().contains(
1632            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression"
1633        ));
1634
1635        // ilike
1636        let expr = Box::new(col("a"));
1637        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1638        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1639        let empty = empty_with_type(Utf8);
1640        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1641        let expected = "Projection: a ILIKE Utf8(\"abc\")\n  EmptyRelation";
1642        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1643
1644        let expr = Box::new(col("a"));
1645        let pattern = Box::new(lit(ScalarValue::Null));
1646        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1647        let empty = empty_with_type(Utf8);
1648        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1649        let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n  EmptyRelation";
1650        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1651
1652        let expr = Box::new(col("a"));
1653        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1654        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1655        let empty = empty_with_type(DataType::Int64);
1656        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1657        let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected);
1658        assert!(err.is_err());
1659        assert!(err.unwrap_err().to_string().contains(
1660            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression"
1661        ));
1662        Ok(())
1663    }
1664
1665    #[test]
1666    fn unknown_for_type_coercion() -> Result<()> {
1667        // unknown
1668        let expr = col("a").is_unknown();
1669        let empty = empty_with_type(DataType::Boolean);
1670        let plan =
1671            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1672        let expected = "Projection: a IS UNKNOWN\n  EmptyRelation";
1673        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1674
1675        let empty = empty_with_type(Utf8);
1676        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1677        let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected);
1678        let err = ret.unwrap_err().to_string();
1679        assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}");
1680
1681        // is not unknown
1682        let expr = col("a").is_not_unknown();
1683        let empty = empty_with_type(DataType::Boolean);
1684        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1685        let expected = "Projection: a IS NOT UNKNOWN\n  EmptyRelation";
1686        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1687
1688        Ok(())
1689    }
1690
1691    #[test]
1692    fn concat_for_type_coercion() -> Result<()> {
1693        let empty = empty_with_type(Utf8);
1694        let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
1695
1696        // concat-type signature
1697        {
1698            let expr = ScalarUDF::new_from_impl(TestScalarUDF {
1699                signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
1700            })
1701            .call(args.to_vec());
1702            let plan = LogicalPlan::Projection(Projection::try_new(
1703                vec![expr],
1704                Arc::clone(&empty),
1705            )?);
1706            let expected =
1707                "Projection: TestScalarUDF(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n  EmptyRelation";
1708            assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1709        }
1710
1711        Ok(())
1712    }
1713
1714    #[test]
1715    fn test_type_coercion_rewrite() -> Result<()> {
1716        // gt
1717        let schema = Arc::new(DFSchema::from_unqualified_fields(
1718            vec![Field::new("a", DataType::Int64, true)].into(),
1719            std::collections::HashMap::new(),
1720        )?);
1721        let mut rewriter = TypeCoercionRewriter { schema: &schema };
1722        let expr = is_true(lit(12i32).gt(lit(13i64)));
1723        let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
1724        let result = expr.rewrite(&mut rewriter).data()?;
1725        assert_eq!(expected, result);
1726
1727        // eq
1728        let schema = Arc::new(DFSchema::from_unqualified_fields(
1729            vec![Field::new("a", DataType::Int64, true)].into(),
1730            std::collections::HashMap::new(),
1731        )?);
1732        let mut rewriter = TypeCoercionRewriter { schema: &schema };
1733        let expr = is_true(lit(12i32).eq(lit(13i64)));
1734        let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
1735        let result = expr.rewrite(&mut rewriter).data()?;
1736        assert_eq!(expected, result);
1737
1738        // lt
1739        let schema = Arc::new(DFSchema::from_unqualified_fields(
1740            vec![Field::new("a", DataType::Int64, true)].into(),
1741            std::collections::HashMap::new(),
1742        )?);
1743        let mut rewriter = TypeCoercionRewriter { schema: &schema };
1744        let expr = is_true(lit(12i32).lt(lit(13i64)));
1745        let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
1746        let result = expr.rewrite(&mut rewriter).data()?;
1747        assert_eq!(expected, result);
1748
1749        Ok(())
1750    }
1751
1752    #[test]
1753    fn binary_op_date32_eq_ts() -> Result<()> {
1754        let expr = cast(
1755            lit("1998-03-18"),
1756            DataType::Timestamp(TimeUnit::Nanosecond, None),
1757        )
1758        .eq(cast(lit("1998-03-18"), DataType::Date32));
1759        let empty = empty();
1760        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1761        let expected =
1762            "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n  EmptyRelation";
1763        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
1764        Ok(())
1765    }
1766
1767    fn cast_if_not_same_type(
1768        expr: Box<Expr>,
1769        data_type: &DataType,
1770        schema: &DFSchemaRef,
1771    ) -> Box<Expr> {
1772        if &expr.get_type(schema).unwrap() != data_type {
1773            Box::new(cast(*expr, data_type.clone()))
1774        } else {
1775            expr
1776        }
1777    }
1778
1779    fn cast_helper(
1780        case: Case,
1781        case_when_type: &DataType,
1782        then_else_type: &DataType,
1783        schema: &DFSchemaRef,
1784    ) -> Case {
1785        let expr = case
1786            .expr
1787            .map(|e| cast_if_not_same_type(e, case_when_type, schema));
1788        let when_then_expr = case
1789            .when_then_expr
1790            .into_iter()
1791            .map(|(when, then)| {
1792                (
1793                    cast_if_not_same_type(when, case_when_type, schema),
1794                    cast_if_not_same_type(then, then_else_type, schema),
1795                )
1796            })
1797            .collect::<Vec<_>>();
1798        let else_expr = case
1799            .else_expr
1800            .map(|e| cast_if_not_same_type(e, then_else_type, schema));
1801
1802        Case {
1803            expr,
1804            when_then_expr,
1805            else_expr,
1806        }
1807    }
1808
1809    #[test]
1810    fn test_case_expression_coercion() -> Result<()> {
1811        let schema = Arc::new(DFSchema::from_unqualified_fields(
1812            vec![
1813                Field::new("boolean", DataType::Boolean, true),
1814                Field::new("integer", DataType::Int32, true),
1815                Field::new("float", DataType::Float32, true),
1816                Field::new(
1817                    "timestamp",
1818                    DataType::Timestamp(TimeUnit::Nanosecond, None),
1819                    true,
1820                ),
1821                Field::new("date", DataType::Date32, true),
1822                Field::new(
1823                    "interval",
1824                    DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
1825                    true,
1826                ),
1827                Field::new("binary", DataType::Binary, true),
1828                Field::new("string", Utf8, true),
1829                Field::new("decimal", DataType::Decimal128(10, 10), true),
1830            ]
1831            .into(),
1832            std::collections::HashMap::new(),
1833        )?);
1834
1835        let case = Case {
1836            expr: None,
1837            when_then_expr: vec![
1838                (Box::new(col("boolean")), Box::new(col("integer"))),
1839                (Box::new(col("integer")), Box::new(col("float"))),
1840                (Box::new(col("string")), Box::new(col("string"))),
1841            ],
1842            else_expr: None,
1843        };
1844        let case_when_common_type = DataType::Boolean;
1845        let then_else_common_type = Utf8;
1846        let expected = cast_helper(
1847            case.clone(),
1848            &case_when_common_type,
1849            &then_else_common_type,
1850            &schema,
1851        );
1852        let actual = coerce_case_expression(case, &schema)?;
1853        assert_eq!(expected, actual);
1854
1855        let case = Case {
1856            expr: Some(Box::new(col("string"))),
1857            when_then_expr: vec![
1858                (Box::new(col("float")), Box::new(col("integer"))),
1859                (Box::new(col("integer")), Box::new(col("float"))),
1860                (Box::new(col("string")), Box::new(col("string"))),
1861            ],
1862            else_expr: Some(Box::new(col("string"))),
1863        };
1864        let case_when_common_type = Utf8;
1865        let then_else_common_type = Utf8;
1866        let expected = cast_helper(
1867            case.clone(),
1868            &case_when_common_type,
1869            &then_else_common_type,
1870            &schema,
1871        );
1872        let actual = coerce_case_expression(case, &schema)?;
1873        assert_eq!(expected, actual);
1874
1875        let case = Case {
1876            expr: Some(Box::new(col("interval"))),
1877            when_then_expr: vec![
1878                (Box::new(col("float")), Box::new(col("integer"))),
1879                (Box::new(col("binary")), Box::new(col("float"))),
1880                (Box::new(col("string")), Box::new(col("string"))),
1881            ],
1882            else_expr: Some(Box::new(col("string"))),
1883        };
1884        let err = coerce_case_expression(case, &schema).unwrap_err();
1885        assert_eq!(
1886            err.strip_backtrace(),
1887            "Error during planning: \
1888            Failed to coerce case (Interval(MonthDayNano)) and \
1889            when ([Float32, Binary, Utf8]) to common types in \
1890            CASE WHEN expression"
1891        );
1892
1893        let case = Case {
1894            expr: Some(Box::new(col("string"))),
1895            when_then_expr: vec![
1896                (Box::new(col("float")), Box::new(col("date"))),
1897                (Box::new(col("string")), Box::new(col("float"))),
1898                (Box::new(col("string")), Box::new(col("binary"))),
1899            ],
1900            else_expr: Some(Box::new(col("timestamp"))),
1901        };
1902        let err = coerce_case_expression(case, &schema).unwrap_err();
1903        assert_eq!(
1904            err.strip_backtrace(),
1905            "Error during planning: \
1906            Failed to coerce then ([Date32, Float32, Binary]) and \
1907            else (Some(Timestamp(Nanosecond, None))) to common types \
1908            in CASE WHEN expression"
1909        );
1910
1911        Ok(())
1912    }
1913
1914    macro_rules! test_case_expression {
1915        ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
1916            let case = Case {
1917                expr: $expr.map(|e| Box::new(col(e))),
1918                when_then_expr: $when_then,
1919                else_expr: None,
1920            };
1921
1922            let expected =
1923                cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
1924
1925            let actual = coerce_case_expression(case, &$schema)?;
1926            assert_eq!(expected, actual);
1927        };
1928    }
1929
1930    #[test]
1931    fn tes_case_when_list() -> Result<()> {
1932        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
1933        let schema = Arc::new(DFSchema::from_unqualified_fields(
1934            vec![
1935                Field::new(
1936                    "large_list",
1937                    DataType::LargeList(Arc::clone(&inner_field)),
1938                    true,
1939                ),
1940                Field::new(
1941                    "fixed_list",
1942                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
1943                    true,
1944                ),
1945                Field::new("list", DataType::List(inner_field), true),
1946            ]
1947            .into(),
1948            std::collections::HashMap::new(),
1949        )?);
1950
1951        test_case_expression!(
1952            Some("list"),
1953            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
1954            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
1955            Utf8,
1956            schema
1957        );
1958
1959        test_case_expression!(
1960            Some("large_list"),
1961            vec![(Box::new(col("list")), Box::new(lit("1")))],
1962            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
1963            Utf8,
1964            schema
1965        );
1966
1967        test_case_expression!(
1968            Some("list"),
1969            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
1970            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
1971            Utf8,
1972            schema
1973        );
1974
1975        test_case_expression!(
1976            Some("fixed_list"),
1977            vec![(Box::new(col("list")), Box::new(lit("1")))],
1978            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
1979            Utf8,
1980            schema
1981        );
1982
1983        test_case_expression!(
1984            Some("fixed_list"),
1985            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
1986            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
1987            Utf8,
1988            schema
1989        );
1990
1991        test_case_expression!(
1992            Some("large_list"),
1993            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
1994            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
1995            Utf8,
1996            schema
1997        );
1998        Ok(())
1999    }
2000
2001    #[test]
2002    fn test_then_else_list() -> Result<()> {
2003        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2004        let schema = Arc::new(DFSchema::from_unqualified_fields(
2005            vec![
2006                Field::new("boolean", DataType::Boolean, true),
2007                Field::new(
2008                    "large_list",
2009                    DataType::LargeList(Arc::clone(&inner_field)),
2010                    true,
2011                ),
2012                Field::new(
2013                    "fixed_list",
2014                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2015                    true,
2016                ),
2017                Field::new("list", DataType::List(inner_field), true),
2018            ]
2019            .into(),
2020            std::collections::HashMap::new(),
2021        )?);
2022
2023        // large list and list
2024        test_case_expression!(
2025            None::<String>,
2026            vec![
2027                (Box::new(col("boolean")), Box::new(col("large_list"))),
2028                (Box::new(col("boolean")), Box::new(col("list")))
2029            ],
2030            DataType::Boolean,
2031            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2032            schema
2033        );
2034
2035        test_case_expression!(
2036            None::<String>,
2037            vec![
2038                (Box::new(col("boolean")), Box::new(col("list"))),
2039                (Box::new(col("boolean")), Box::new(col("large_list")))
2040            ],
2041            DataType::Boolean,
2042            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2043            schema
2044        );
2045
2046        // fixed list and list
2047        test_case_expression!(
2048            None::<String>,
2049            vec![
2050                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2051                (Box::new(col("boolean")), Box::new(col("list")))
2052            ],
2053            DataType::Boolean,
2054            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2055            schema
2056        );
2057
2058        test_case_expression!(
2059            None::<String>,
2060            vec![
2061                (Box::new(col("boolean")), Box::new(col("list"))),
2062                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2063            ],
2064            DataType::Boolean,
2065            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2066            schema
2067        );
2068
2069        // fixed list and large list
2070        test_case_expression!(
2071            None::<String>,
2072            vec![
2073                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2074                (Box::new(col("boolean")), Box::new(col("large_list")))
2075            ],
2076            DataType::Boolean,
2077            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2078            schema
2079        );
2080
2081        test_case_expression!(
2082            None::<String>,
2083            vec![
2084                (Box::new(col("boolean")), Box::new(col("large_list"))),
2085                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2086            ],
2087            DataType::Boolean,
2088            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2089            schema
2090        );
2091        Ok(())
2092    }
2093
2094    #[test]
2095    fn test_map_with_diff_name() -> Result<()> {
2096        let mut builder = SchemaBuilder::new();
2097        builder.push(Field::new("key", Utf8, false));
2098        builder.push(Field::new("value", DataType::Float64, true));
2099        let struct_fields = builder.finish().fields;
2100
2101        let fields =
2102            Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2103        let map_type_entries = DataType::Map(Arc::new(fields), false);
2104
2105        let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2106        let may_type_cutsom = DataType::Map(Arc::new(fields), false);
2107
2108        let expr = col("a").eq(cast(col("a"), may_type_cutsom));
2109        let empty = empty_with_type(map_type_entries);
2110        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2111        let expected = "Projection: a = CAST(CAST(a AS Map(Field { name: \"key_value\", data_type: Struct([Field { name: \"key\", data_type: Utf8, \
2112        nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), \
2113        nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: \"entries\", data_type: Struct([Field { name: \"key\", data_type: Utf8, nullable: false, \
2114        dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))\n  \
2115        EmptyRelation";
2116        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)
2117    }
2118
2119    #[test]
2120    fn interval_plus_timestamp() -> Result<()> {
2121        // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
2122        let expr = Expr::BinaryExpr(BinaryExpr::new(
2123            Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2124            Operator::Plus,
2125            Box::new(cast(
2126                lit("2000-01-01T00:00:00"),
2127                DataType::Timestamp(TimeUnit::Nanosecond, None),
2128            )),
2129        ));
2130        let empty = empty();
2131        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2132        let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n  EmptyRelation";
2133        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
2134        Ok(())
2135    }
2136
2137    #[test]
2138    fn timestamp_subtract_timestamp() -> Result<()> {
2139        let expr = Expr::BinaryExpr(BinaryExpr::new(
2140            Box::new(cast(
2141                lit("1998-03-18"),
2142                DataType::Timestamp(TimeUnit::Nanosecond, None),
2143            )),
2144            Operator::Minus,
2145            Box::new(cast(
2146                lit("1998-03-18"),
2147                DataType::Timestamp(TimeUnit::Nanosecond, None),
2148            )),
2149        ));
2150        let empty = empty();
2151        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2152        let expected =
2153            "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n  EmptyRelation";
2154        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
2155        Ok(())
2156    }
2157
2158    #[test]
2159    fn in_subquery_cast_subquery() -> Result<()> {
2160        let empty_int32 = empty_with_type(DataType::Int32);
2161        let empty_int64 = empty_with_type(DataType::Int64);
2162
2163        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2164            Box::new(col("a")),
2165            Subquery {
2166                subquery: empty_int32,
2167                outer_ref_columns: vec![],
2168                spans: Spans::new(),
2169            },
2170            false,
2171        ));
2172        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2173        // add cast for subquery
2174        let expected = "\
2175        Filter: a IN (<subquery>)\
2176        \n  Subquery:\
2177        \n    Projection: CAST(a AS Int64)\
2178        \n      EmptyRelation\
2179        \n  EmptyRelation";
2180        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
2181        Ok(())
2182    }
2183
2184    #[test]
2185    fn in_subquery_cast_expr() -> Result<()> {
2186        let empty_int32 = empty_with_type(DataType::Int32);
2187        let empty_int64 = empty_with_type(DataType::Int64);
2188
2189        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2190            Box::new(col("a")),
2191            Subquery {
2192                subquery: empty_int64,
2193                outer_ref_columns: vec![],
2194                spans: Spans::new(),
2195            },
2196            false,
2197        ));
2198        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2199        // add cast for subquery
2200        let expected = "\
2201        Filter: CAST(a AS Int64) IN (<subquery>)\
2202        \n  Subquery:\
2203        \n    EmptyRelation\
2204        \n  EmptyRelation";
2205        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
2206        Ok(())
2207    }
2208
2209    #[test]
2210    fn in_subquery_cast_all() -> Result<()> {
2211        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2212        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2213
2214        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2215            Box::new(col("a")),
2216            Subquery {
2217                subquery: empty_inside,
2218                outer_ref_columns: vec![],
2219                spans: Spans::new(),
2220            },
2221            false,
2222        ));
2223        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2224        // add cast for subquery
2225        let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)\
2226        \n  Subquery:\
2227        \n    Projection: CAST(a AS Decimal128(13, 8))\
2228        \n      EmptyRelation\
2229        \n  EmptyRelation";
2230        assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
2231        Ok(())
2232    }
2233}