datafusion_optimizer/analyzer/
type_coercion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Optimizer rule for type validation and coercion
19
20use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::izip;
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32    exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
33    DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34};
35use datafusion_expr::expr::{
36    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
37    InSubquery, Like, ScalarFunction, Sort, WindowFunction,
38};
39use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
40use datafusion_expr::expr_schema::cast_subquery;
41use datafusion_expr::logical_plan::Subquery;
42use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
43use datafusion_expr::type_coercion::functions::{
44    data_types_with_scalar_udf, fields_with_aggregate_udf,
45};
46use datafusion_expr::type_coercion::other::{
47    get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52    is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
53    AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection,
54    ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
55};
56
/// Analyzer rule that determines the schema visible to each plan node and
/// rewrites the node's expressions so that their types agree.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates the rule. It carries no state, so this is the same as
    /// [`TypeCoercion::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
67
68/// Coerce output schema based upon optimizer config.
69fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
70    if !config.optimizer.expand_views_at_output {
71        return Ok(plan);
72    }
73
74    let outer_refs = plan.expressions();
75    if outer_refs.is_empty() {
76        return Ok(plan);
77    }
78
79    if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
80        coerce_plan_expr_for_schema(plan, &dfschema?)
81    } else {
82        Ok(plan)
83    }
84}
85
86impl AnalyzerRule for TypeCoercion {
87    fn name(&self) -> &str {
88        "type_coercion"
89    }
90
91    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
92        let empty_schema = DFSchema::empty();
93
94        // recurse
95        let transformed_plan = plan
96            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
97            .data;
98
99        // finish
100        coerce_output(transformed_plan, config)
101    }
102}
103
104/// use the external schema to handle the correlated subqueries case
105///
106/// Assumes that children have already been optimized
107fn analyze_internal(
108    external_schema: &DFSchema,
109    plan: LogicalPlan,
110) -> Result<Transformed<LogicalPlan>> {
111    // get schema representing all available input fields. This is used for data type
112    // resolution only, so order does not matter here
113    let mut schema = merge_schema(&plan.inputs());
114
115    if let LogicalPlan::TableScan(ts) = &plan {
116        let source_schema = DFSchema::try_from_qualified_schema(
117            ts.table_name.clone(),
118            &ts.source.schema(),
119        )?;
120        schema.merge(&source_schema);
121    }
122
123    // merge the outer schema for correlated subqueries
124    // like case:
125    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
126    schema.merge(external_schema);
127
128    // Coerce filter predicates to boolean (handles `WHERE NULL`)
129    let plan = if let LogicalPlan::Filter(mut filter) = plan {
130        filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
131        LogicalPlan::Filter(filter)
132    } else {
133        plan
134    };
135
136    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
137
138    let name_preserver = NamePreserver::new(&plan);
139    // apply coercion rewrite all expressions in the plan individually
140    plan.map_expressions(|expr| {
141        let original_name = name_preserver.save(&expr);
142        expr.rewrite(&mut expr_rewrite)
143            .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
144    })?
145    // some plans need extra coercion after their expressions are coerced
146    .map_data(|plan| expr_rewrite.coerce_plan(plan))?
147    // recompute the schema after the expressions have been rewritten as the types may have changed
148    .map_data(|plan| plan.recompute_schema())
149}
150
151/// Rewrite expressions to apply type coercion.
152pub struct TypeCoercionRewriter<'a> {
153    pub(crate) schema: &'a DFSchema,
154}
155
156impl<'a> TypeCoercionRewriter<'a> {
157    /// Create a new [`TypeCoercionRewriter`] with a provided schema
158    /// representing both the inputs and output of the [`LogicalPlan`] node.
159    pub fn new(schema: &'a DFSchema) -> Self {
160        Self { schema }
161    }
162
163    /// Coerce the [`LogicalPlan`].
164    ///
165    /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`]
166    /// for type-coercion approach.
167    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
168        match plan {
169            LogicalPlan::Join(join) => self.coerce_join(join),
170            LogicalPlan::Union(union) => Self::coerce_union(union),
171            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
172            _ => Ok(plan),
173        }
174    }
175
176    /// Coerce join equality expressions and join filter
177    ///
178    /// Joins must be treated specially as their equality expressions are stored
179    /// as a parallel list of left and right expressions, rather than a single
180    /// equality expression
181    ///
182    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
183    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
184    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
185        join.on = join
186            .on
187            .into_iter()
188            .map(|(lhs, rhs)| {
189                // coerce the arguments as though they were a single binary equality
190                // expression
191                let left_schema = join.left.schema();
192                let right_schema = join.right.schema();
193                let (lhs, rhs) = self.coerce_binary_op(
194                    lhs,
195                    left_schema,
196                    Operator::Eq,
197                    rhs,
198                    right_schema,
199                )?;
200                Ok((lhs, rhs))
201            })
202            .collect::<Result<Vec<_>>>()?;
203
204        // Join filter must be boolean
205        join.filter = join
206            .filter
207            .map(|expr| self.coerce_join_filter(expr))
208            .transpose()?;
209
210        Ok(LogicalPlan::Join(join))
211    }
212
213    /// Coerce the union’s inputs to a common schema compatible with all inputs.
214    /// This occurs after wildcard expansion and the coercion of the input expressions.
215    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
216        let union_schema = Arc::new(coerce_union_schema_with_schema(
217            &union_plan.inputs,
218            &union_plan.schema,
219        )?);
220        let new_inputs = union_plan
221            .inputs
222            .into_iter()
223            .map(|p| {
224                let plan =
225                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
226                match plan {
227                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
228                        Ok(Arc::new(project_with_column_index(
229                            expr,
230                            input,
231                            Arc::clone(&union_schema),
232                        )?))
233                    }
234                    other_plan => Ok(Arc::new(other_plan)),
235                }
236            })
237            .collect::<Result<Vec<_>>>()?;
238        Ok(LogicalPlan::Union(Union {
239            inputs: new_inputs,
240            schema: union_schema,
241        }))
242    }
243
244    /// Coerce the fetch and skip expression to Int64 type.
245    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
246        fn coerce_limit_expr(
247            expr: Expr,
248            schema: &DFSchema,
249            expr_name: &str,
250        ) -> Result<Expr> {
251            let dt = expr.get_type(schema)?;
252            if dt.is_integer() || dt.is_null() {
253                expr.cast_to(&DataType::Int64, schema)
254            } else {
255                plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}")
256            }
257        }
258
259        let empty_schema = DFSchema::empty();
260        let new_fetch = limit
261            .fetch
262            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
263            .transpose()?;
264        let new_skip = limit
265            .skip
266            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
267            .transpose()?;
268        Ok(LogicalPlan::Limit(Limit {
269            input: limit.input,
270            fetch: new_fetch.map(Box::new),
271            skip: new_skip.map(Box::new),
272        }))
273    }
274
275    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
276        let expr_type = expr.get_type(self.schema)?;
277        match expr_type {
278            DataType::Boolean => Ok(expr),
279            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
280            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
281        }
282    }
283
284    fn coerce_binary_op(
285        &self,
286        left: Expr,
287        left_schema: &DFSchema,
288        op: Operator,
289        right: Expr,
290        right_schema: &DFSchema,
291    ) -> Result<(Expr, Expr)> {
292        let (left_type, right_type) = BinaryTypeCoercer::new(
293            &left.get_type(left_schema)?,
294            &op,
295            &right.get_type(right_schema)?,
296        )
297        .get_input_types()?;
298
299        Ok((
300            left.cast_to(&left_type, left_schema)?,
301            right.cast_to(&right_type, right_schema)?,
302        ))
303    }
304}
305
306impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
307    type Node = Expr;
308
309    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
310        match expr {
311            Expr::Unnest(_) => not_impl_err!(
312                "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
313            ),
314            Expr::ScalarSubquery(Subquery {
315                subquery,
316                outer_ref_columns,
317                spans,
318            }) => {
319                let new_plan =
320                    analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
321                Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
322                    subquery: Arc::new(new_plan),
323                    outer_ref_columns,
324                    spans,
325                })))
326            }
327            Expr::Exists(Exists { subquery, negated }) => {
328                let new_plan = analyze_internal(
329                    self.schema,
330                    Arc::unwrap_or_clone(subquery.subquery),
331                )?
332                .data;
333                Ok(Transformed::yes(Expr::Exists(Exists {
334                    subquery: Subquery {
335                        subquery: Arc::new(new_plan),
336                        outer_ref_columns: subquery.outer_ref_columns,
337                        spans: subquery.spans,
338                    },
339                    negated,
340                })))
341            }
342            Expr::InSubquery(InSubquery {
343                expr,
344                subquery,
345                negated,
346            }) => {
347                let new_plan = analyze_internal(
348                    self.schema,
349                    Arc::unwrap_or_clone(subquery.subquery),
350                )?
351                .data;
352                let expr_type = expr.get_type(self.schema)?;
353                let subquery_type = new_plan.schema().field(0).data_type();
354                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
355                        "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
356                    ),
357                )?;
358                let new_subquery = Subquery {
359                    subquery: Arc::new(new_plan),
360                    outer_ref_columns: subquery.outer_ref_columns,
361                    spans: subquery.spans,
362                };
363                Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
364                    Box::new(expr.cast_to(&common_type, self.schema)?),
365                    cast_subquery(new_subquery, &common_type)?,
366                    negated,
367                ))))
368            }
369            Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
370                *expr,
371                self.schema,
372            )?))),
373            Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
374                get_casted_expr_for_bool_op(*expr, self.schema)?,
375            ))),
376            Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
377                get_casted_expr_for_bool_op(*expr, self.schema)?,
378            ))),
379            Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
380                get_casted_expr_for_bool_op(*expr, self.schema)?,
381            ))),
382            Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
383                get_casted_expr_for_bool_op(*expr, self.schema)?,
384            ))),
385            Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
386                get_casted_expr_for_bool_op(*expr, self.schema)?,
387            ))),
388            Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
389                get_casted_expr_for_bool_op(*expr, self.schema)?,
390            ))),
391            Expr::Like(Like {
392                negated,
393                expr,
394                pattern,
395                escape_char,
396                case_insensitive,
397            }) => {
398                let left_type = expr.get_type(self.schema)?;
399                let right_type = pattern.get_type(self.schema)?;
400                let coerced_type = like_coercion(&left_type,  &right_type).ok_or_else(|| {
401                    let op_name = if case_insensitive {
402                        "ILIKE"
403                    } else {
404                        "LIKE"
405                    };
406                    plan_datafusion_err!(
407                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
408                    )
409                })?;
410                let expr = match left_type {
411                    DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
412                    _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
413                };
414                let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
415                Ok(Transformed::yes(Expr::Like(Like::new(
416                    negated,
417                    expr,
418                    pattern,
419                    escape_char,
420                    case_insensitive,
421                ))))
422            }
423            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
424                let (left, right) =
425                    self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
426                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
427                    Box::new(left),
428                    op,
429                    Box::new(right),
430                ))))
431            }
432            Expr::Between(Between {
433                expr,
434                negated,
435                low,
436                high,
437            }) => {
438                let expr_type = expr.get_type(self.schema)?;
439                let low_type = low.get_type(self.schema)?;
440                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
441                    .ok_or_else(|| {
442                        DataFusionError::Internal(format!(
443                            "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
444                        ))
445                    })?;
446                let high_type = high.get_type(self.schema)?;
447                let high_coerced_type = comparison_coercion(&expr_type, &high_type)
448                    .ok_or_else(|| {
449                        DataFusionError::Internal(format!(
450                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
451                        ))
452                    })?;
453                let coercion_type =
454                    comparison_coercion(&low_coerced_type, &high_coerced_type)
455                        .ok_or_else(|| {
456                            DataFusionError::Internal(format!(
457                                "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
458                            ))
459                        })?;
460                Ok(Transformed::yes(Expr::Between(Between::new(
461                    Box::new(expr.cast_to(&coercion_type, self.schema)?),
462                    negated,
463                    Box::new(low.cast_to(&coercion_type, self.schema)?),
464                    Box::new(high.cast_to(&coercion_type, self.schema)?),
465                ))))
466            }
467            Expr::InList(InList {
468                expr,
469                list,
470                negated,
471            }) => {
472                let expr_data_type = expr.get_type(self.schema)?;
473                let list_data_types = list
474                    .iter()
475                    .map(|list_expr| list_expr.get_type(self.schema))
476                    .collect::<Result<Vec<_>>>()?;
477                let result_type =
478                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
479                match result_type {
480                    None => plan_err!(
481                        "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}"
482                    ),
483                    Some(coerced_type) => {
484                        // find the coerced type
485                        let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
486                        let cast_list_expr = list
487                            .into_iter()
488                            .map(|list_expr| {
489                                list_expr.cast_to(&coerced_type, self.schema)
490                            })
491                            .collect::<Result<Vec<_>>>()?;
492                        Ok(Transformed::yes(Expr::InList(InList ::new(
493                             Box::new(cast_expr),
494                             cast_list_expr,
495                            negated,
496                        ))))
497                    }
498                }
499            }
500            Expr::Case(case) => {
501                let case = coerce_case_expression(case, self.schema)?;
502                Ok(Transformed::yes(Expr::Case(case)))
503            }
504            Expr::ScalarFunction(ScalarFunction { func, args }) => {
505                let new_expr = coerce_arguments_for_signature_with_scalar_udf(
506                    args,
507                    self.schema,
508                    &func,
509                )?;
510                Ok(Transformed::yes(Expr::ScalarFunction(
511                    ScalarFunction::new_udf(func, new_expr),
512                )))
513            }
514            Expr::AggregateFunction(expr::AggregateFunction {
515                func,
516                params:
517                    AggregateFunctionParams {
518                        args,
519                        distinct,
520                        filter,
521                        order_by,
522                        null_treatment,
523                    },
524            }) => {
525                let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
526                    args,
527                    self.schema,
528                    &func,
529                )?;
530                Ok(Transformed::yes(Expr::AggregateFunction(
531                    expr::AggregateFunction::new_udf(
532                        func,
533                        new_expr,
534                        distinct,
535                        filter,
536                        order_by,
537                        null_treatment,
538                    ),
539                )))
540            }
541            Expr::WindowFunction(window_fun) => {
542                let WindowFunction {
543                    fun,
544                    params:
545                        expr::WindowFunctionParams {
546                            args,
547                            partition_by,
548                            order_by,
549                            window_frame,
550                            filter,
551                            null_treatment,
552                            distinct,
553                        },
554                } = *window_fun;
555                let window_frame =
556                    coerce_window_frame(window_frame, self.schema, &order_by)?;
557
558                let args = match &fun {
559                    expr::WindowFunctionDefinition::AggregateUDF(udf) => {
560                        coerce_arguments_for_signature_with_aggregate_udf(
561                            args,
562                            self.schema,
563                            udf,
564                        )?
565                    }
566                    _ => args,
567                };
568
569                let new_expr = Expr::from(WindowFunction {
570                    fun,
571                    params: expr::WindowFunctionParams {
572                        args,
573                        partition_by,
574                        order_by,
575                        window_frame,
576                        filter,
577                        null_treatment,
578                        distinct,
579                    },
580                });
581                Ok(Transformed::yes(new_expr))
582            }
583            // TODO: remove the next line after `Expr::Wildcard` is removed
584            #[expect(deprecated)]
585            Expr::Alias(_)
586            | Expr::Column(_)
587            | Expr::ScalarVariable(_, _)
588            | Expr::Literal(_, _)
589            | Expr::SimilarTo(_)
590            | Expr::IsNotNull(_)
591            | Expr::IsNull(_)
592            | Expr::Negative(_)
593            | Expr::Cast(_)
594            | Expr::TryCast(_)
595            | Expr::Wildcard { .. }
596            | Expr::GroupingSet(_)
597            | Expr::Placeholder(_)
598            | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
599        }
600    }
601}
602
603/// Transform a schema to use non-view types for Utf8View and BinaryView
604fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
605    let metadata = dfschema.as_arrow().metadata.clone();
606    let mut transformed = false;
607
608    let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
609        dfschema
610            .iter()
611            .map(|(qualifier, field)| match field.data_type() {
612                DataType::Utf8View => {
613                    transformed = true;
614                    (
615                        qualifier.cloned() as Option<TableReference>,
616                        Arc::new(Field::new(
617                            field.name(),
618                            DataType::LargeUtf8,
619                            field.is_nullable(),
620                        )),
621                    )
622                }
623                DataType::BinaryView => {
624                    transformed = true;
625                    (
626                        qualifier.cloned() as Option<TableReference>,
627                        Arc::new(Field::new(
628                            field.name(),
629                            DataType::LargeBinary,
630                            field.is_nullable(),
631                        )),
632                    )
633                }
634                _ => (
635                    qualifier.cloned() as Option<TableReference>,
636                    Arc::clone(field),
637                ),
638            })
639            .unzip();
640
641    if !transformed {
642        return None;
643    }
644
645    let schema = Schema::new_with_metadata(transformed_fields, metadata);
646    Some(DFSchema::from_field_specific_qualified_schema(
647        qualifiers,
648        &Arc::new(schema),
649    ))
650}
651
652/// Casts the given `value` to `target_type`. Note that this function
653/// only considers `Null` or `Utf8` values.
654fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
655    match value {
656        // Coerce Utf8 values:
657        ScalarValue::Utf8(Some(val)) => {
658            ScalarValue::try_from_string(val.clone(), target_type)
659        }
660        s => {
661            if s.is_null() {
662                // Coerce `Null` values:
663                ScalarValue::try_from(target_type)
664            } else {
665                // Values except `Utf8`/`Null` variants already have the right type
666                // (casted before) since we convert `sqlparser` outputs to `Utf8`
667                // for all possible cases. Therefore, we return a clone here.
668                Ok(s.clone())
669            }
670        }
671    }
672}
673
674/// This function coerces `value` to `target_type` in a range-aware fashion.
675/// If the coercion is successful, we return an `Ok` value with the result.
676/// If the coercion fails because `target_type` is not wide enough (i.e. we
677/// can not coerce to `target_type`, but we can to a wider type in the same
678/// family), we return a `Null` value of this type to signal this situation.
679/// Downstream code uses this signal to treat these values as *unbounded*.
680fn coerce_scalar_range_aware(
681    target_type: &DataType,
682    value: &ScalarValue,
683) -> Result<ScalarValue> {
684    coerce_scalar(target_type, value).or_else(|err| {
685        // If type coercion fails, check if the largest type in family works:
686        if let Some(largest_type) = get_widest_type_in_family(target_type) {
687            coerce_scalar(largest_type, value).map_or_else(
688                |_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
689                |_| ScalarValue::try_from(target_type),
690            )
691        } else {
692            Err(err)
693        }
694    })
695}
696
697/// This function returns the widest type in the family of `given_type`.
698/// If the given type is already the widest type, it returns `None`.
699/// For example, if `given_type` is `Int8`, it returns `Int64`.
700fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
701    match given_type {
702        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
703        DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
704        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
705        _ => None,
706    }
707}
708
709/// Coerces the given (window frame) `bound` to `target_type`.
710fn coerce_frame_bound(
711    target_type: &DataType,
712    bound: WindowFrameBound,
713) -> Result<WindowFrameBound> {
714    match bound {
715        WindowFrameBound::Preceding(v) => {
716            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
717        }
718        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
719        WindowFrameBound::Following(v) => {
720            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
721        }
722    }
723}
724
725fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
726    if col_type.is_numeric()
727        || is_utf8_or_utf8view_or_large_utf8(col_type)
728        || matches!(col_type, DataType::List(_))
729        || matches!(col_type, DataType::LargeList(_))
730        || matches!(col_type, DataType::FixedSizeList(_, _))
731        || matches!(col_type, DataType::Null)
732        || matches!(col_type, DataType::Boolean)
733    {
734        Ok(col_type.clone())
735    } else if is_datetime(col_type) {
736        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
737    } else if let DataType::Dictionary(_, value_type) = col_type {
738        extract_window_frame_target_type(value_type)
739    } else {
740        internal_err!("Cannot run range queries on datatype: {col_type:?}")
741    }
742}
743
744// Coerces the given `window_frame` to use appropriate natural types.
745// For example, ROWS and GROUPS frames use `UInt64` during calculations.
746fn coerce_window_frame(
747    window_frame: WindowFrame,
748    schema: &DFSchema,
749    expressions: &[Sort],
750) -> Result<WindowFrame> {
751    let mut window_frame = window_frame;
752    let target_type = match window_frame.units {
753        WindowFrameUnits::Range => {
754            let current_types = expressions
755                .first()
756                .map(|s| s.expr.get_type(schema))
757                .transpose()?;
758            if let Some(col_type) = current_types {
759                extract_window_frame_target_type(&col_type)?
760            } else {
761                return internal_err!("ORDER BY column cannot be empty");
762            }
763        }
764        WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
765    };
766    window_frame.start_bound =
767        coerce_frame_bound(&target_type, window_frame.start_bound)?;
768    window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
769    Ok(window_frame)
770}
771
772// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
773// The above op will be rewrite to the binary op when creating the physical op.
774fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
775    let left_type = expr.get_type(schema)?;
776    BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
777        .get_input_types()?;
778    expr.cast_to(&DataType::Boolean, schema)
779}
780
781/// Returns `expressions` coerced to types compatible with
782/// `signature`, if possible.
783///
784/// See the module level documentation for more detail on coercion.
785fn coerce_arguments_for_signature_with_scalar_udf(
786    expressions: Vec<Expr>,
787    schema: &DFSchema,
788    func: &ScalarUDF,
789) -> Result<Vec<Expr>> {
790    if expressions.is_empty() {
791        return Ok(expressions);
792    }
793
794    let current_types = expressions
795        .iter()
796        .map(|e| e.get_type(schema))
797        .collect::<Result<Vec<_>>>()?;
798
799    let new_types = data_types_with_scalar_udf(&current_types, func)?;
800
801    expressions
802        .into_iter()
803        .enumerate()
804        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
805        .collect()
806}
807
808/// Returns `expressions` coerced to types compatible with
809/// `signature`, if possible.
810///
811/// See the module level documentation for more detail on coercion.
812fn coerce_arguments_for_signature_with_aggregate_udf(
813    expressions: Vec<Expr>,
814    schema: &DFSchema,
815    func: &AggregateUDF,
816) -> Result<Vec<Expr>> {
817    if expressions.is_empty() {
818        return Ok(expressions);
819    }
820
821    let current_fields = expressions
822        .iter()
823        .map(|e| e.to_field(schema).map(|(_, f)| f))
824        .collect::<Result<Vec<_>>>()?;
825
826    let new_types = fields_with_aggregate_udf(&current_fields, func)?
827        .into_iter()
828        .map(|f| f.data_type().clone())
829        .collect::<Vec<_>>();
830
831    expressions
832        .into_iter()
833        .enumerate()
834        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
835        .collect()
836}
837
838fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
839    // Given expressions like:
840    //
841    // CASE a1
842    //   WHEN a2 THEN b1
843    //   WHEN a3 THEN b2
844    //   ELSE b3
845    // END
846    //
847    // or:
848    //
849    // CASE
850    //   WHEN x1 THEN b1
851    //   WHEN x2 THEN b2
852    //   ELSE b3
853    // END
854    //
855    // Then all aN (a1, a2, a3) must be converted to a common data type in the first example
856    // (case-when expression coercion)
857    //
858    // All xN (x1, x2) must be converted to a boolean data type in the second example
859    // (when-boolean expression coercion)
860    //
861    // And all bN (b1, b2, b3) must be converted to a common data type in both examples
862    // (then-else expression coercion)
863    //
864    // If any fail to find and cast to a common/specific data type, will return error
865    //
866    // Note that case-when and when-boolean expression coercions are mutually exclusive
867    // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur
868
869    // prepare types
870    let case_type = case
871        .expr
872        .as_ref()
873        .map(|expr| expr.get_type(schema))
874        .transpose()?;
875    let then_types = case
876        .when_then_expr
877        .iter()
878        .map(|(_when, then)| then.get_type(schema))
879        .collect::<Result<Vec<_>>>()?;
880    let else_type = case
881        .else_expr
882        .as_ref()
883        .map(|expr| expr.get_type(schema))
884        .transpose()?;
885
886    // find common coercible types
887    let case_when_coerce_type = case_type
888        .as_ref()
889        .map(|case_type| {
890            let when_types = case
891                .when_then_expr
892                .iter()
893                .map(|(when, _then)| when.get_type(schema))
894                .collect::<Result<Vec<_>>>()?;
895            let coerced_type =
896                get_coerce_type_for_case_expression(&when_types, Some(case_type));
897            coerced_type.ok_or_else(|| {
898                plan_datafusion_err!(
899                    "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \
900                     to common types in CASE WHEN expression"
901                )
902            })
903        })
904        .transpose()?;
905    let then_else_coerce_type =
906        get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
907            || {
908                plan_datafusion_err!(
909                    "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \
910                     to common types in CASE WHEN expression"
911                )
912            },
913        )?;
914
915    // do cast if found common coercible types
916    let case_expr = case
917        .expr
918        .zip(case_when_coerce_type.as_ref())
919        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
920        .transpose()?
921        .map(Box::new);
922    let when_then = case
923        .when_then_expr
924        .into_iter()
925        .map(|(when, then)| {
926            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
927            let when = when.cast_to(when_type, schema).map_err(|e| {
928                DataFusionError::Context(
929                    format!(
930                        "WHEN expressions in CASE couldn't be \
931                         converted to common type ({when_type})"
932                    ),
933                    Box::new(e),
934                )
935            })?;
936            let then = then.cast_to(&then_else_coerce_type, schema)?;
937            Ok((Box::new(when), Box::new(then)))
938        })
939        .collect::<Result<Vec<_>>>()?;
940    let else_expr = case
941        .else_expr
942        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
943        .transpose()?
944        .map(Box::new);
945
946    Ok(Case::new(case_expr, when_then, else_expr))
947}
948
949/// Get a common schema that is compatible with all inputs of UNION.
950///
951/// This method presumes that the wildcard expansion is unneeded, or has already
952/// been applied.
953///
954/// ## Schema and Field Handling in Union Coercion
955///
956/// **Processing order**: The function starts with the base schema (first input) and then
957/// processes remaining inputs sequentially, with later inputs taking precedence in merging.
958///
959/// **Schema-level metadata merging**: Later schemas take precedence for duplicate keys.
960///
961/// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys.
962///
963/// **Type coercion precedence**: The coerced type is determined by iteratively applying
964/// `comparison_coercion()` between the accumulated type and each new input's type. The
965/// result depends on type coercion rules, not input order.
966///
967/// **Nullability merging**: Nullability is accumulated using logical OR (`||`).
968/// Once any input field is nullable, the result field becomes nullable permanently.
969/// Later inputs can make a field nullable but cannot make it non-nullable.
970///
971/// **Field precedence**: Field names come from the first (base) schema, but the field properties
972/// (nullability and field-level metadata) have later schemas taking precedence.
973///
974/// **Example**:
975/// ```sql
976/// SELECT a, b FROM table1  -- a: Int32, metadata {"source": "t1"}, nullable=false
977/// UNION
978/// SELECT a, b FROM table2  -- a: Int64, metadata {"source": "t2"}, nullable=true
979/// UNION
980/// SELECT a, b FROM table3  -- a: Int32, metadata {"encoding": "utf8"}, nullable=false
981/// -- Result:
982/// -- a: Int64 (from type coercion), nullable=true (from table2),
983/// -- metadata: {"source": "t2", "encoding": "utf8"} (later inputs take precedence)
984/// ```
985///
986/// **Precedence Summary**:
987/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order
988/// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR)
989/// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics)
990pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
991    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
992}
993fn coerce_union_schema_with_schema(
994    inputs: &[Arc<LogicalPlan>],
995    base_schema: &DFSchemaRef,
996) -> Result<DFSchema> {
997    let mut union_datatypes = base_schema
998        .fields()
999        .iter()
1000        .map(|f| f.data_type().clone())
1001        .collect::<Vec<_>>();
1002    let mut union_nullabilities = base_schema
1003        .fields()
1004        .iter()
1005        .map(|f| f.is_nullable())
1006        .collect::<Vec<_>>();
1007    let mut union_field_meta = base_schema
1008        .fields()
1009        .iter()
1010        .map(|f| f.metadata().clone())
1011        .collect::<Vec<_>>();
1012
1013    let mut metadata = base_schema.metadata().clone();
1014
1015    for (i, plan) in inputs.iter().enumerate() {
1016        let plan_schema = plan.schema();
1017        metadata.extend(plan_schema.metadata().clone());
1018
1019        if plan_schema.fields().len() != base_schema.fields().len() {
1020            return plan_err!(
1021                "Union schemas have different number of fields: \
1022                query 1 has {} fields whereas query {} has {} fields",
1023                base_schema.fields().len(),
1024                i + 1,
1025                plan_schema.fields().len()
1026            );
1027        }
1028
1029        // coerce data type and nullability for each field
1030        for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1031            union_datatypes.iter_mut(),
1032            union_nullabilities.iter_mut(),
1033            union_field_meta.iter_mut(),
1034            plan_schema.fields().iter()
1035        ) {
1036            let coerced_type =
1037                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1038                    || {
1039                        plan_datafusion_err!(
1040                            "Incompatible inputs for Union: Previous inputs were \
1041                            of type {}, but got incompatible type {} on column '{}'",
1042                            union_datatype,
1043                            plan_field.data_type(),
1044                            plan_field.name()
1045                        )
1046                    },
1047                )?;
1048
1049            *union_datatype = coerced_type;
1050            *union_nullable = *union_nullable || plan_field.is_nullable();
1051            union_field_map.extend(plan_field.metadata().clone());
1052        }
1053    }
1054    let union_qualified_fields = izip!(
1055        base_schema.fields(),
1056        union_datatypes.into_iter(),
1057        union_nullabilities,
1058        union_field_meta.into_iter()
1059    )
1060    .map(|(field, datatype, nullable, metadata)| {
1061        let mut field = Field::new(field.name().clone(), datatype, nullable);
1062        field.set_metadata(metadata);
1063        (None, field.into())
1064    })
1065    .collect::<Vec<_>>();
1066
1067    DFSchema::new_with_metadata(union_qualified_fields, metadata)
1068}
1069
1070/// See `<https://github.com/apache/datafusion/pull/2108>`
1071fn project_with_column_index(
1072    expr: Vec<Expr>,
1073    input: Arc<LogicalPlan>,
1074    schema: DFSchemaRef,
1075) -> Result<LogicalPlan> {
1076    let alias_expr = expr
1077        .into_iter()
1078        .enumerate()
1079        .map(|(i, e)| match e {
1080            Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1081                Ok(e.unalias().alias(schema.field(i).name()))
1082            }
1083            Expr::Column(Column {
1084                relation: _,
1085                ref name,
1086                spans: _,
1087            }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1088            Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1089            #[expect(deprecated)]
1090            Expr::Wildcard { .. } => {
1091                plan_err!("Wildcard should be expanded before type coercion")
1092            }
1093            _ => Ok(e.alias(schema.field(i).name())),
1094        })
1095        .collect::<Result<Vec<_>>>()?;
1096
1097    Projection::try_new_with_schema(alias_expr, input, schema)
1098        .map(LogicalPlan::Projection)
1099}
1100
1101#[cfg(test)]
1102mod test {
1103    use std::any::Any;
1104    use std::sync::Arc;
1105
1106    use arrow::datatypes::DataType::Utf8;
1107    use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1108    use insta::assert_snapshot;
1109
1110    use crate::analyzer::type_coercion::{
1111        coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1112    };
1113    use crate::analyzer::Analyzer;
1114    use crate::assert_analyzed_plan_with_config_eq_snapshot;
1115    use datafusion_common::config::ConfigOptions;
1116    use datafusion_common::tree_node::{TransformedResult, TreeNode};
1117    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1118    use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1119    use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1120    use datafusion_expr::test::function_stub::avg_udaf;
1121    use datafusion_expr::{
1122        cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1123        BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1124        Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1125        SimpleAggregateUDF, Subquery, Union, Volatility,
1126    };
1127    use datafusion_functions_aggregate::average::AvgAccumulator;
1128    use datafusion_sql::TableReference;
1129
1130    fn empty() -> Arc<LogicalPlan> {
1131        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1132            produce_one_row: false,
1133            schema: Arc::new(DFSchema::empty()),
1134        }))
1135    }
1136
1137    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1138        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1139            produce_one_row: false,
1140            schema: Arc::new(
1141                DFSchema::from_unqualified_fields(
1142                    vec![Field::new("a", data_type, true)].into(),
1143                    std::collections::HashMap::new(),
1144                )
1145                .unwrap(),
1146            ),
1147        }))
1148    }
1149
/// Analyzes `$plan` with the `TypeCoercion` rule under the default
/// `ConfigOptions` and compares the result with the inline snapshot
/// `$expected`, by delegating to
/// `assert_analyzed_plan_with_config_eq_snapshot!`.
macro_rules! assert_analyzed_plan_eq {
    (
        $plan: expr,
        @ $expected: literal $(,)?
    ) => {{
        let default_config = ConfigOptions::default();
        let coercion_rule = Arc::new(TypeCoercion::new());
        assert_analyzed_plan_with_config_eq_snapshot!(
            default_config,
            coercion_rule,
            $plan,
            @ $expected,
        )
    }};
}
1165
/// Analyzes `$plan` with the `TypeCoercion` rule and compares the result with
/// the inline snapshot `$expected`.
///
/// When `$is_viewtype` is true, `optimizer.expand_views_at_output` is switched
/// on before analysis; otherwise the default `ConfigOptions` are used as-is.
macro_rules! coerce_on_output_if_viewtype {
    (
        $is_viewtype: expr,
        $plan: expr,
        @ $expected: literal $(,)?
    ) => {{
        let mut config = ConfigOptions::default();
        // Opt in to coercing view types at the plan output.
        if $is_viewtype {
            config.optimizer.expand_views_at_output = true;
        }
        let coercion_rule = Arc::new(TypeCoercion::new());

        assert_analyzed_plan_with_config_eq_snapshot!(
            config,
            coercion_rule,
            $plan,
            @ $expected,
        )
    }};
}
1185
1186    fn assert_type_coercion_error(
1187        plan: LogicalPlan,
1188        expected_substr: &str,
1189    ) -> Result<()> {
1190        let options = ConfigOptions::default();
1191        let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1192
1193        match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1194            Ok(succeeded_plan) => {
1195                panic!(
1196                    "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1197                );
1198            }
1199            Err(e) => {
1200                let msg = e.to_string();
1201                assert!(
1202                    msg.contains(expected_substr),
1203                    "Error did not contain expected substring.\n  expected to find: `{expected_substr}`\n  actual error: `{msg}`"
1204                );
1205            }
1206        }
1207
1208        Ok(())
1209    }
1210
/// A `UInt32` literal compared with a `Float64` column is cast to `Float64`.
#[test]
fn simple_case() -> Result<()> {
    let input = empty_with_type(DataType::Float64);
    let comparison = col("a").lt(lit(2_u32));
    let projection =
        LogicalPlan::Projection(Projection::try_new(vec![comparison], input)?);

    assert_analyzed_plan_eq!(
        projection,
        @r"
        Projection: a < CAST(UInt32(2) AS Float64)
          EmptyRelation: rows=0
        "
    )
}
1225
/// A union of an `Int32` and an `Int64` column casts the `Int32` side to
/// `Int64` through an added projection.
#[test]
fn test_coerce_union() -> Result<()> {
    // Both inputs expose a non-nullable column `a` under the same qualifier;
    // only the column type differs.
    let relation_with = |column_type: DataType| {
        let schema = DFSchema::try_from_qualified_schema(
            TableReference::full("datafusion", "test", "foo"),
            &Schema::new(vec![Field::new("a", column_type, false)]),
        )
        .unwrap();
        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(schema),
        }))
    };
    let left_plan = relation_with(DataType::Int32);
    let right_plan = relation_with(DataType::Int64);

    let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
        left_plan, right_plan,
    ])?);

    // Coerce the union on its own first, then wrap it in a projection.
    let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
    let analyzed_union =
        analyzer.execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
    let top_level_plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(analyzed_union),
    )?);

    assert_analyzed_plan_eq!(
        top_level_plan,
        @r"
        Projection: a
          Union
            Projection: CAST(datafusion.test.foo.a AS Int64) AS a
              EmptyRelation: rows=0
            EmptyRelation: rows=0
        "
    )
}
1269
/// Checks output coercion of `Utf8View` columns to `LargeUtf8` when
/// `expand_views_at_output` is enabled, across four plan shapes.
#[test]
fn coerce_utf8view_output() -> Result<()> {
    // Plan A
    // scenario: outermost utf8view projection
    let expr = col("a");
    let empty = empty_with_type(DataType::Utf8View);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![expr.clone()],
        Arc::clone(&empty),
    )?);

    // Plan A: no coerce
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
        Projection: a
          EmptyRelation: rows=0
        "
    )?;

    // Plan A: coerce requested: Utf8View => LargeUtf8
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
        Projection: CAST(a AS LargeUtf8)
          EmptyRelation: rows=0
        "
    )?;

    // Plan B
    // scenario: outermost bool projection
    let bool_expr = col("a").lt(lit("foo"));
    let bool_plan = LogicalPlan::Projection(Projection::try_new(
        vec![bool_expr],
        Arc::clone(&empty),
    )?);

    // Plan B: no coerce
    coerce_on_output_if_viewtype!(
        false,
        bool_plan.clone(),
        @r#"
        Projection: a < CAST(Utf8("foo") AS Utf8View)
          EmptyRelation: rows=0
        "#
    )?;

    // Plan B: coerce requested: no coercion applied, because the output
    // column is a boolean rather than a view type. This step previously
    // re-checked plan A and never exercised `bool_plan` with coercion on.
    coerce_on_output_if_viewtype!(
        true,
        bool_plan.clone(),
        @r#"
        Projection: a < CAST(Utf8("foo") AS Utf8View)
          EmptyRelation: rows=0
        "#
    )?;

    // Plan C
    // scenario: with a non-projection root logical plan node
    let sort_expr = expr.sort(true, true);
    let sort_plan = LogicalPlan::Sort(Sort {
        expr: vec![sort_expr],
        input: Arc::new(plan),
        fetch: None,
    });

    // Plan C: no coerce
    coerce_on_output_if_viewtype!(
        false,
        sort_plan.clone(),
        @r"
        Sort: a ASC NULLS FIRST
          Projection: a
            EmptyRelation: rows=0
        "
    )?;

    // Plan C: coerce requested: Utf8View => LargeUtf8
    coerce_on_output_if_viewtype!(
        true,
        sort_plan.clone(),
        @r"
        Projection: CAST(a AS LargeUtf8)
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    // Plan D
    // scenario: two layers of projections with view types
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sort_plan),
    )?);

    // Plan D: no coerce
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
        Projection: a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    // Plan D: coerce requested: Utf8View => LargeUtf8 only on outermost
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
        Projection: CAST(a AS LargeUtf8)
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    Ok(())
}
1400
/// Checks output coercion of `BinaryView` columns to `LargeBinary` when
/// `expand_views_at_output` is enabled, across four plan shapes.
#[test]
fn coerce_binaryview_output() -> Result<()> {
    let column_a = col("a");
    let input = empty_with_type(DataType::BinaryView);

    // Plan A: the outermost projection exposes the BinaryView column.
    let view_projection = LogicalPlan::Projection(Projection::try_new(
        vec![column_a.clone()],
        Arc::clone(&input),
    )?);

    // A, coercion off: plan is unchanged.
    coerce_on_output_if_viewtype!(
        false,
        view_projection.clone(),
        @r"
        Projection: a
          EmptyRelation: rows=0
        "
    )?;

    // A, coercion on: BinaryView => LargeBinary.
    coerce_on_output_if_viewtype!(
        true,
        view_projection.clone(),
        @r"
        Projection: CAST(a AS LargeBinary)
          EmptyRelation: rows=0
        "
    )?;

    // Plan B: the outermost projection yields a boolean comparison.
    let comparison = col("a").lt(lit(vec![8, 1, 8, 1]));
    let bool_projection = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        Arc::clone(&input),
    )?);

    // B, coercion off.
    coerce_on_output_if_viewtype!(
        false,
        bool_projection.clone(),
        @r#"
        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
          EmptyRelation: rows=0
        "#
    )?;

    // B, coercion on: same plan, no output coercion applied.
    coerce_on_output_if_viewtype!(
        true,
        bool_projection.clone(),
        @r#"
        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
          EmptyRelation: rows=0
        "#
    )?;

    // Plan C: the root node is a Sort rather than a Projection.
    let sorted = LogicalPlan::Sort(Sort {
        expr: vec![column_a.sort(true, true)],
        input: Arc::new(view_projection),
        fetch: None,
    });

    // C, coercion off.
    coerce_on_output_if_viewtype!(
        false,
        sorted.clone(),
        @r"
        Sort: a ASC NULLS FIRST
          Projection: a
            EmptyRelation: rows=0
        "
    )?;

    // C, coercion on: BinaryView => LargeBinary via an added projection.
    coerce_on_output_if_viewtype!(
        true,
        sorted.clone(),
        @r"
        Projection: CAST(a AS LargeBinary)
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    // Plan D: two layers of projections over the view column.
    let nested_projection = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sorted),
    )?);

    // D, coercion off.
    coerce_on_output_if_viewtype!(
        false,
        nested_projection.clone(),
        @r"
        Projection: a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    // D, coercion on: BinaryView => LargeBinary only on the outermost projection.
    coerce_on_output_if_viewtype!(
        true,
        nested_projection.clone(),
        @r"
        Projection: CAST(a AS LargeBinary)
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    Ok(())
}
1524
/// Coercion is applied to both operands of an `OR` of two comparisons.
#[test]
fn nested_case() -> Result<()> {
    let input = empty_with_type(DataType::Float64);
    let comparison = col("a").lt(lit(2_u32));
    let disjunction = comparison.clone().or(comparison);

    let projection =
        LogicalPlan::Projection(Projection::try_new(vec![disjunction], input)?);

    assert_analyzed_plan_eq!(
        projection,
        @r"
        Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
          EmptyRelation: rows=0
        "
    )
}
1543
1544    #[derive(Debug, PartialEq, Eq, Hash)]
1545    struct TestScalarUDF {
1546        signature: Signature,
1547    }
1548
1549    impl ScalarUDFImpl for TestScalarUDF {
1550        fn as_any(&self) -> &dyn Any {
1551            self
1552        }
1553
1554        fn name(&self) -> &str {
1555            "TestScalarUDF"
1556        }
1557
1558        fn signature(&self) -> &Signature {
1559            &self.signature
1560        }
1561
1562        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1563            Ok(Utf8)
1564        }
1565
1566        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1567            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1568        }
1569    }
1570
/// An `Int32` literal passed to a UDF that takes `Float32` is cast to `Float32`.
#[test]
fn scalar_udf() -> Result<()> {
    let signature =
        Signature::uniform(1, vec![DataType::Float32], Volatility::Stable);
    let call = ScalarUDF::from(TestScalarUDF { signature }).call(vec![lit(123_i32)]);
    let projection =
        LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        projection,
        @r"
        Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
          EmptyRelation: rows=0
        "
    )
}
1589
/// Building a projection over a `Float32` UDF called with a string literal
/// fails.
#[test]
fn scalar_udf_invalid_input() -> Result<()> {
    let signature =
        Signature::uniform(1, vec![DataType::Float32], Volatility::Stable);
    let call = ScalarUDF::from(TestScalarUDF { signature }).call(vec![lit("Apple")]);

    let attempt = Projection::try_new(vec![call], empty());
    attempt.expect_err("Expected an error due to incorrect function input");

    Ok(())
}
1602
/// Automatic argument coercion also applies to a scalar function expression
/// built directly with `ScalarFunction::new_udf`.
#[test]
fn scalar_function() -> Result<()> {
    let udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let arguments = vec![lit(10i64)];
    let call = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), arguments));

    let projection =
        LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        projection,
        @r"
        Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
          EmptyRelation: rows=0
        "
    )
}
1626
/// An `Int64` literal passed to a UDAF that takes `Float64` is cast to `Float64`.
#[test]
fn agg_udaf() -> Result<()> {
    let my_avg = create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    );
    let aggregate = expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit(10i64)],
        false,
        None,
        vec![],
        None,
    );
    let projection = LogicalPlan::Projection(Projection::try_new(
        vec![Expr::AggregateFunction(aggregate)],
        empty(),
    )?);

    assert_analyzed_plan_eq!(
        projection,
        @r"
        Projection: MY_AVG(CAST(Int64(10) AS Float64))
          EmptyRelation: rows=0
        "
    )
}
1656
/// Building a projection over a `Float64` UDAF called with a string literal
/// fails with a planning error that names the function and its signature.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let state_fields = vec![
        Field::new("count", DataType::UInt64, true).into(),
        Field::new("avg", DataType::Float64, true).into(),
    ];
    let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
        "MY_AVG",
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
        DataType::Float64,
        accumulator,
        state_fields,
    ));
    let aggregate = expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit("10")],
        false,
        None,
        vec![],
        None,
    );

    let err = Projection::try_new(vec![Expr::AggregateFunction(aggregate)], empty())
        .err()
        .unwrap();
    let message = err.strip_backtrace();
    assert!(
        message.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed")
    );
    Ok(())
}
1688
1689    #[test]
    // `avg` over arguments that are already Float64: in both scenarios the
    // expected plan prints the aggregate argument exactly as it was supplied.
1690    fn agg_function_case() -> Result<()> {
        // Scenario 1: Float64 literal argument.
1691        let empty = empty();
1692        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1693            avg_udaf(),
1694            vec![lit(12f64)],
1695            false,
1696            None,
1697            vec![],
1698            None,
1699        ));
1700        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1701
1702        assert_analyzed_plan_eq!(
1703            plan,
1704            @r"
1705        Projection: avg(Float64(12))
1706          EmptyRelation: rows=0
1707        "
1708        )?;
1709
        // Scenario 2: column `a` (schema built with Int32) explicitly cast to
        // Float64 by the caller; the expected plan shows that single cast only.
1710        let empty = empty_with_type(DataType::Int32);
1711        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1712            avg_udaf(),
1713            vec![cast(col("a"), DataType::Float64)],
1714            false,
1715            None,
1716            vec![],
1717            None,
1718        ));
1719        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1720
1721        assert_analyzed_plan_eq!(
1722            plan,
1723            @r"
1724        Projection: avg(CAST(a AS Float64))
1725          EmptyRelation: rows=0
1726        "
1727        )
1728    }
1729
1730    #[test]
    // Negative case for the built-in `avg`: a Utf8 literal argument makes
    // `Projection::try_new` return a planning error. Only the prefix of the
    // message (function name, input types and the numeric signature) is
    // checked.
1731    fn agg_function_invalid_input_avg() -> Result<()> {
1732        let empty = empty();
1733        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1734            avg_udaf(),
1735            vec![lit("1")],
1736            false,
1737            None,
1738            vec![],
1739            None,
1740        ));
1741        let err = Projection::try_new(vec![agg_expr], empty)
1742            .err()
1743            .unwrap()
1744            .strip_backtrace();
1745        assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1746        Ok(())
1747    }
1748
1749    #[test]
    // Date32 + IntervalDayTime: the expected plan prints the expression
    // exactly as built, i.e. no additional casts appear on either operand.
1750    fn binary_op_date32_op_interval() -> Result<()> {
1751        // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
1752        let expr = cast(lit("1998-03-18"), DataType::Date32)
1753            + lit(ScalarValue::new_interval_dt(123, 456));
1754        let empty = empty();
1755        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1756
1757        assert_analyzed_plan_eq!(
1758            plan,
1759            @r#"
1760        Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1761          EmptyRelation: rows=0
1762        "#
1763        )
1764    }
1765
1766    #[test]
    // IN-list coercion with mixed integer literals (Int32, Int8, Int64)
    // against two different column types.
1767    fn inlist_case() -> Result<()> {
1768        // a in (1,4,8), a is int64
        // Expected: only the narrower literals are cast to Int64; `a` and the
        // Int64 literal are left as they are.
1769        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1770        let empty = empty_with_type(DataType::Int64);
1771        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1772        assert_analyzed_plan_eq!(
1773            plan,
1774            @r"
1775        Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1776          EmptyRelation: rows=0
1777        ")?;
1778
1779        // a in (1,4,8), a is decimal
        // Here `a` is Decimal128(12, 4). Expected: the column and every list
        // element are all cast to Decimal128(24, 4).
1780        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1781        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1782            produce_one_row: false,
1783            schema: Arc::new(DFSchema::from_unqualified_fields(
1784                vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1785                std::collections::HashMap::new(),
1786            )?),
1787        }));
1788        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1789        assert_analyzed_plan_eq!(
1790            plan,
1791            @r"
1792        Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1793          EmptyRelation: rows=0
1794        ")
1795    }
1796
1797    #[test]
    // BETWEEN where the column and the lower bound are Utf8 and the upper
    // bound is a Date32 + interval expression. Expected: the column and the
    // Utf8 lower bound are cast to Date32; the upper bound is printed as built.
1798    fn between_case() -> Result<()> {
1799        let expr = col("a").between(
1800            lit("2002-05-08"),
1801            // (cast('2002-05-08' as date) + interval '1 months')
1802            cast(lit("2002-05-08"), DataType::Date32)
1803                + lit(ScalarValue::new_interval_ym(0, 1)),
1804        );
1805        let empty = empty_with_type(Utf8);
1806        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1807
1808        assert_analyzed_plan_eq!(
1809            plan,
1810            @r#"
1811        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1812          EmptyRelation: rows=0
1813        "#
1814        )
1815    }
1816
1817    #[test]
    // Mirror of `between_case` with the bounds swapped: the lower bound is the
    // Date32 + interval expression and the upper bound is a Utf8 literal.
1818    fn between_infer_cheap_type() -> Result<()> {
1819        let expr = col("a").between(
1820            // (cast('2002-05-08' as date) + interval '1 months')
1821            cast(lit("2002-05-08"), DataType::Date32)
1822                + lit(ScalarValue::new_interval_ym(0, 1)),
1823            lit("2002-12-08"),
1824        );
1825        let empty = empty_with_type(Utf8);
1826        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1827
1828        // NOTE(review): this used to read "TODO: we should cast col(a)", but the
        // expected plan below already contains `CAST(a AS Date32)` (and casts the
        // Utf8 upper bound to Date32 as well), so that TODO appears to be resolved.
1829        assert_analyzed_plan_eq!(
1830            plan,
1831            @r#"
1832        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1833          EmptyRelation: rows=0
1834        "#
1835        )
1836    }
1837
1838    #[test]
    // BETWEEN with untyped NULL for both the tested value and the lower bound
    // and an Int64 upper bound. Expected: both NULLs are cast to Int64.
1839    fn between_null() -> Result<()> {
1840        let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1841        let empty = empty();
1842        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1843
1844        assert_analyzed_plan_eq!(
1845            plan,
1846            @r"
1847        Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1848          EmptyRelation: rows=0
1849        "
1850        )
1851    }
1852
1853    #[test]
    // IS [NOT] TRUE / IS [NOT] FALSE: over a Boolean column the expected plan
    // prints the predicate unchanged; `IS TRUE` over an Int64 column is
    // expected to be reported as a type coercion error.
1854    fn is_bool_for_type_coercion() -> Result<()> {
1855        // is true
1856        let expr = col("a").is_true();
1857        let empty = empty_with_type(DataType::Boolean);
        // `expr` is cloned because it is reused for the Int64 error case below.
1858        let plan =
1859            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1860
1861        assert_analyzed_plan_eq!(
1862            plan,
1863            @r"
1864        Projection: a IS TRUE
1865          EmptyRelation: rows=0
1866        "
1867        )?;
1868
        // Same `IS TRUE` expression, but `a` is now Int64: checked through
        // `assert_type_coercion_error` with the expected message text.
1869        let empty = empty_with_type(DataType::Int64);
1870        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1871        assert_type_coercion_error(
1872            plan,
1873            "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
1874        )?;
1875
1876        // is not true
1877        let expr = col("a").is_not_true();
1878        let empty = empty_with_type(DataType::Boolean);
1879        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1880
1881        assert_analyzed_plan_eq!(
1882            plan,
1883            @r"
1884        Projection: a IS NOT TRUE
1885          EmptyRelation: rows=0
1886        "
1887        )?;
1888
1889        // is false
1890        let expr = col("a").is_false();
1891        let empty = empty_with_type(DataType::Boolean);
1892        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1893
1894        assert_analyzed_plan_eq!(
1895            plan,
1896            @r"
1897        Projection: a IS FALSE
1898          EmptyRelation: rows=0
1899        "
1900        )?;
1901
1902        // is not false
1903        let expr = col("a").is_not_false();
1904        let empty = empty_with_type(DataType::Boolean);
1905        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1906
1907        assert_analyzed_plan_eq!(
1908            plan,
1909            @r"
1910        Projection: a IS NOT FALSE
1911          EmptyRelation: rows=0
1912        "
1913        )
1914    }
1915
1916    #[test]
    // LIKE / ILIKE coercion. For each operator three scenarios are checked:
    //   1. Utf8 column with a Utf8 pattern   -> plan printed unchanged
    //   2. Utf8 column with a NULL pattern   -> NULL is cast to Utf8
    //   3. Int64 column with a Utf8 pattern  -> type coercion error
    // The last argument of `Like::new` selects the operator: with `false` the
    // plan prints `LIKE`, with `true` it prints `ILIKE`.
1917    fn like_for_type_coercion() -> Result<()> {
1918        // like : utf8 like "abc"
1919        let expr = Box::new(col("a"));
1920        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1921        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1922        let empty = empty_with_type(Utf8);
1923        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1924
1925        assert_analyzed_plan_eq!(
1926            plan,
1927            @r#"
1928        Projection: a LIKE Utf8("abc")
1929          EmptyRelation: rows=0
1930        "#
1931        )?;
1932
        // like : utf8 like NULL
1933        let expr = Box::new(col("a"));
1934        let pattern = Box::new(lit(ScalarValue::Null));
1935        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1936        let empty = empty_with_type(Utf8);
1937        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1938
1939        assert_analyzed_plan_eq!(
1940            plan,
1941            @r"
1942        Projection: a LIKE CAST(NULL AS Utf8)
1943          EmptyRelation: rows=0
1944        "
1945        )?;
1946
        // like : int64 like "abc" is rejected
1947        let expr = Box::new(col("a"));
1948        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1949        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1950        let empty = empty_with_type(DataType::Int64);
1951        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1952        assert_type_coercion_error(
1953            plan,
1954            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
1955        )?;
1956
1957        // ilike
1958        let expr = Box::new(col("a"));
1959        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1960        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1961        let empty = empty_with_type(Utf8);
1962        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1963
1964        assert_analyzed_plan_eq!(
1965            plan,
1966            @r#"
1967        Projection: a ILIKE Utf8("abc")
1968          EmptyRelation: rows=0
1969        "#
1970        )?;
1971
        // ilike : utf8 ilike NULL
1972        let expr = Box::new(col("a"));
1973        let pattern = Box::new(lit(ScalarValue::Null));
1974        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1975        let empty = empty_with_type(Utf8);
1976        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1977
1978        assert_analyzed_plan_eq!(
1979            plan,
1980            @r"
1981        Projection: a ILIKE CAST(NULL AS Utf8)
1982          EmptyRelation: rows=0
1983        "
1984        )?;
1985
        // ilike : int64 ilike "abc" is rejected
1986        let expr = Box::new(col("a"));
1987        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1988        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1989        let empty = empty_with_type(DataType::Int64);
1990        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1991        assert_type_coercion_error(
1992            plan,
1993            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
1994        )?;
1995
1996        Ok(())
1997    }
1998
1999    #[test]
    // IS [NOT] UNKNOWN: over a Boolean column the expected plan prints the
    // predicate unchanged; `IS UNKNOWN` over a Utf8 column is expected to be
    // reported as a type coercion error.
2000    fn unknown_for_type_coercion() -> Result<()> {
2001        // unknown
2002        let expr = col("a").is_unknown();
2003        let empty = empty_with_type(DataType::Boolean);
        // `expr` is cloned because it is reused for the Utf8 error case below.
2004        let plan =
2005            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2006
2007        assert_analyzed_plan_eq!(
2008            plan,
2009            @r"
2010        Projection: a IS UNKNOWN
2011          EmptyRelation: rows=0
2012        "
2013        )?;
2014
        // Same `IS UNKNOWN` expression with `a` typed as Utf8.
2015        let empty = empty_with_type(Utf8);
2016        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2017        assert_type_coercion_error(
2018            plan,
2019            "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
2020        )?;
2021
2022        // is not unknown
2023        let expr = col("a").is_not_unknown();
2024        let empty = empty_with_type(DataType::Boolean);
2025        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2026
2027        assert_analyzed_plan_eq!(
2028            plan,
2029            @r"
2030        Projection: a IS NOT UNKNOWN
2031          EmptyRelation: rows=0
2032        "
2033        )
2034    }
2035
2036    #[test]
    // Scalar UDF with a `variadic([Utf8])` signature called with a mix of
    // Utf8, Boolean and Int32 arguments. Expected: the Utf8 arguments are left
    // alone and every other argument is cast to Utf8.
2037    fn concat_for_type_coercion() -> Result<()> {
2038        let empty = empty_with_type(Utf8);
2039        let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2040
2041        // concat-type signature
2042        let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2043            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2044        })
2045        .call(args.to_vec());
2046        let plan =
2047            LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2048        assert_analyzed_plan_eq!(
2049            plan,
2050            @r#"
2051        Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2052          EmptyRelation: rows=0
2053        "#
2054        )
2055    }
2056
2057    #[test]
    // Drives `TypeCoercionRewriter` directly on expressions (no logical plan).
    // For `>`, `=` and `<` between an Int32 and an Int64 literal nested inside
    // `is_true(..)`, the rewritten expression is expected to cast the Int32
    // literal to Int64 and leave the Int64 literal untouched.
2058    fn test_type_coercion_rewrite() -> Result<()> {
2059        // gt
2060        let schema = Arc::new(DFSchema::from_unqualified_fields(
2061            vec![Field::new("a", DataType::Int64, true)].into(),
2062            std::collections::HashMap::new(),
2063        )?);
2064        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2065        let expr = is_true(lit(12i32).gt(lit(13i64)));
2066        let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2067        let result = expr.rewrite(&mut rewriter).data()?;
2068        assert_eq!(expected, result);
2069
2070        // eq
2071        let schema = Arc::new(DFSchema::from_unqualified_fields(
2072            vec![Field::new("a", DataType::Int64, true)].into(),
2073            std::collections::HashMap::new(),
2074        )?);
2075        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2076        let expr = is_true(lit(12i32).eq(lit(13i64)));
2077        let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2078        let result = expr.rewrite(&mut rewriter).data()?;
2079        assert_eq!(expected, result);
2080
2081        // lt
2082        let schema = Arc::new(DFSchema::from_unqualified_fields(
2083            vec![Field::new("a", DataType::Int64, true)].into(),
2084            std::collections::HashMap::new(),
2085        )?);
2086        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2087        let expr = is_true(lit(12i32).lt(lit(13i64)));
2088        let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2089        let result = expr.rewrite(&mut rewriter).data()?;
2090        assert_eq!(expected, result);
2091
2092        Ok(())
2093    }
2094
2095    #[test]
    // Equality between a Timestamp(Nanosecond, None) value and a Date32 value.
    // Expected: the Date32 side gets an additional cast to
    // Timestamp(Nanosecond, None); the timestamp side is printed as built.
2096    fn binary_op_date32_eq_ts() -> Result<()> {
2097        let expr = cast(
2098            lit("1998-03-18"),
2099            DataType::Timestamp(TimeUnit::Nanosecond, None),
2100        )
2101        .eq(cast(lit("1998-03-18"), DataType::Date32));
2102        let empty = empty();
2103        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2104
2105        assert_analyzed_plan_eq!(
2106            plan,
2107            @r#"
2108        Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None))
2109          EmptyRelation: rows=0
2110        "#
2111        )
2112    }
2113
2114    fn cast_if_not_same_type(
2115        expr: Box<Expr>,
2116        data_type: &DataType,
2117        schema: &DFSchemaRef,
2118    ) -> Box<Expr> {
        // Test helper: returns `expr` wrapped in a cast to `data_type`, unless
        // the type resolved for `expr` against `schema` already equals
        // `data_type`, in which case the original box is returned untouched.
        // Panics (via `unwrap`) if the expression's type cannot be resolved.
2119        if &expr.get_type(schema).unwrap() != data_type {
2120            Box::new(cast(*expr, data_type.clone()))
2121        } else {
2122            expr
2123        }
2124    }
2125
2126    fn cast_helper(
2127        case: Case,
2128        case_when_type: &DataType,
2129        then_else_type: &DataType,
2130        schema: &DFSchemaRef,
2131    ) -> Case {
        // Test helper that builds the expected result of coercing a CASE
        // expression by hand:
        //   * the optional base expression and every WHEN are brought to
        //     `case_when_type`;
        //   * every THEN and the optional ELSE are brought to `then_else_type`.
        // Each sub-expression goes through `cast_if_not_same_type`, so a cast
        // is only added where the type actually differs.
2132        let expr = case
2133            .expr
2134            .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2135        let when_then_expr = case
2136            .when_then_expr
2137            .into_iter()
2138            .map(|(when, then)| {
2139                (
2140                    cast_if_not_same_type(when, case_when_type, schema),
2141                    cast_if_not_same_type(then, then_else_type, schema),
2142                )
2143            })
2144            .collect::<Vec<_>>();
2145        let else_expr = case
2146            .else_expr
2147            .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2148
2149        Case {
2150            expr,
2151            when_then_expr,
2152            else_expr,
2153        }
2154    }
2155
2156    #[test]
    // Exercises `coerce_case_expression` directly against a schema with one
    // column per type family. Two scenarios compare the result with the
    // hand-built expectation from `cast_helper`; two scenarios assert the
    // planning error produced when no common type exists.
2157    fn test_case_expression_coercion() -> Result<()> {
2158        let schema = Arc::new(DFSchema::from_unqualified_fields(
2159            vec![
2160                Field::new("boolean", DataType::Boolean, true),
2161                Field::new("integer", DataType::Int32, true),
2162                Field::new("float", DataType::Float32, true),
2163                Field::new(
2164                    "timestamp",
2165                    DataType::Timestamp(TimeUnit::Nanosecond, None),
2166                    true,
2167                ),
2168                Field::new("date", DataType::Date32, true),
2169                Field::new(
2170                    "interval",
2171                    DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2172                    true,
2173                ),
2174                Field::new("binary", DataType::Binary, true),
2175                Field::new("string", Utf8, true),
2176                Field::new("decimal", DataType::Decimal128(10, 10), true),
2177            ]
2178            .into(),
2179            std::collections::HashMap::new(),
2180        )?);
2181
        // Scenario 1: no base expression. WHEN conditions are expected to be
        // brought to Boolean and THEN results (Int32 / Float32 / Utf8) to Utf8.
2182        let case = Case {
2183            expr: None,
2184            when_then_expr: vec![
2185                (Box::new(col("boolean")), Box::new(col("integer"))),
2186                (Box::new(col("integer")), Box::new(col("float"))),
2187                (Box::new(col("string")), Box::new(col("string"))),
2188            ],
2189            else_expr: None,
2190        };
2191        let case_when_common_type = DataType::Boolean;
2192        let then_else_common_type = Utf8;
2193        let expected = cast_helper(
2194            case.clone(),
2195            &case_when_common_type,
2196            &then_else_common_type,
2197            &schema,
2198        );
2199        let actual = coerce_case_expression(case, &schema)?;
2200        assert_eq!(expected, actual);
2201
        // Scenario 2: Utf8 base expression and Utf8 ELSE. Both the CASE/WHEN
        // side and the THEN/ELSE side are expected to settle on Utf8.
2202        let case = Case {
2203            expr: Some(Box::new(col("string"))),
2204            when_then_expr: vec![
2205                (Box::new(col("float")), Box::new(col("integer"))),
2206                (Box::new(col("integer")), Box::new(col("float"))),
2207                (Box::new(col("string")), Box::new(col("string"))),
2208            ],
2209            else_expr: Some(Box::new(col("string"))),
2210        };
2211        let case_when_common_type = Utf8;
2212        let then_else_common_type = Utf8;
2213        let expected = cast_helper(
2214            case.clone(),
2215            &case_when_common_type,
2216            &then_else_common_type,
2217            &schema,
2218        );
2219        let actual = coerce_case_expression(case, &schema)?;
2220        assert_eq!(expected, actual);
2221
        // Scenario 3: Interval base expression against Float32 / Binary / Utf8
        // WHEN values: expected to fail on the CASE/WHEN side.
2222        let case = Case {
2223            expr: Some(Box::new(col("interval"))),
2224            when_then_expr: vec![
2225                (Box::new(col("float")), Box::new(col("integer"))),
2226                (Box::new(col("binary")), Box::new(col("float"))),
2227                (Box::new(col("string")), Box::new(col("string"))),
2228            ],
2229            else_expr: Some(Box::new(col("string"))),
2230        };
2231        let err = coerce_case_expression(case, &schema).unwrap_err();
2232        assert_snapshot!(
2233            err.strip_backtrace(),
2234            @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when ([Float32, Binary, Utf8]) to common types in CASE WHEN expression"
2235        );
2236
        // Scenario 4: THEN values of Date32 / Float32 / Binary with a Timestamp
        // ELSE: expected to fail on the THEN/ELSE side.
2237        let case = Case {
2238            expr: Some(Box::new(col("string"))),
2239            when_then_expr: vec![
2240                (Box::new(col("float")), Box::new(col("date"))),
2241                (Box::new(col("string")), Box::new(col("float"))),
2242                (Box::new(col("string")), Box::new(col("binary"))),
2243            ],
2244            else_expr: Some(Box::new(col("timestamp"))),
2245        };
2246        let err = coerce_case_expression(case, &schema).unwrap_err();
2247        assert_snapshot!(
2248            err.strip_backtrace(),
2249            @"Error during planning: Failed to coerce then ([Date32, Float32, Binary]) and else (Some(Timestamp(Nanosecond, None))) to common types in CASE WHEN expression"
2250        );
2251
2252        Ok(())
2253    }
2254
2255    macro_rules! test_case_expression {
        // Builds a `Case` with no ELSE branch and checks that
        // `coerce_case_expression` yields the same result as `cast_helper`.
        //
        // Arguments:
        //   $expr           - `Option` of a column name used as the base
        //                     expression (`None` for a searched CASE)
        //   $when_then      - the list of (WHEN, THEN) expression pairs
        //   $case_when_type - type expected for the base expression and WHENs
        //   $then_else_type - type expected for the THENs
        //   $schema         - schema used to resolve expression types
        //
        // The expansion uses `?`, so it must be invoked inside a function
        // that returns a `Result`.
2256        ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2257            let case = Case {
2258                expr: $expr.map(|e| Box::new(col(e))),
2259                when_then_expr: $when_then,
2260                else_expr: None,
2261            };
2262
2263            let expected =
2264                cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2265
2266            let actual = coerce_case_expression(case, &$schema)?;
2267            assert_eq!(expected, actual);
2268        };
2269    }
2270
2271    #[test]
    // Common-type resolution between the CASE base expression and the WHEN
    // value for the three list flavours over Int64 items. Every pairing is
    // checked in both orders; the THEN value is always a Utf8 literal.
    // NOTE(review): the function name looks like a typo for
    // `test_case_when_list`.
2272    fn tes_case_when_list() -> Result<()> {
2273        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2274        let schema = Arc::new(DFSchema::from_unqualified_fields(
2275            vec![
2276                Field::new(
2277                    "large_list",
2278                    DataType::LargeList(Arc::clone(&inner_field)),
2279                    true,
2280                ),
2281                Field::new(
2282                    "fixed_list",
2283                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2284                    true,
2285                ),
2286                Field::new("list", DataType::List(inner_field), true),
2287            ]
2288            .into(),
2289            std::collections::HashMap::new(),
2290        )?);
2291
        // List vs LargeList -> LargeList (both orders)
2292        test_case_expression!(
2293            Some("list"),
2294            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2295            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2296            Utf8,
2297            schema
2298        );
2299
2300        test_case_expression!(
2301            Some("large_list"),
2302            vec![(Box::new(col("list")), Box::new(lit("1")))],
2303            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2304            Utf8,
2305            schema
2306        );
2307
        // List vs FixedSizeList -> List (both orders)
2308        test_case_expression!(
2309            Some("list"),
2310            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2311            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2312            Utf8,
2313            schema
2314        );
2315
2316        test_case_expression!(
2317            Some("fixed_list"),
2318            vec![(Box::new(col("list")), Box::new(lit("1")))],
2319            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2320            Utf8,
2321            schema
2322        );
2323
        // FixedSizeList vs LargeList -> LargeList (both orders)
2324        test_case_expression!(
2325            Some("fixed_list"),
2326            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2327            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2328            Utf8,
2329            schema
2330        );
2331
2332        test_case_expression!(
2333            Some("large_list"),
2334            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2335            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2336            Utf8,
2337            schema
2338        );
2339        Ok(())
2340    }
2341
2342    #[test]
    // Common-type resolution across THEN branches for the three list flavours
    // over Int64 items. There is no base expression (`None::<String>`) and the
    // WHEN conditions are Boolean columns, so only the THEN side needs a
    // common type. Every pairing is checked in both orders.
2343    fn test_then_else_list() -> Result<()> {
2344        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2345        let schema = Arc::new(DFSchema::from_unqualified_fields(
2346            vec![
2347                Field::new("boolean", DataType::Boolean, true),
2348                Field::new(
2349                    "large_list",
2350                    DataType::LargeList(Arc::clone(&inner_field)),
2351                    true,
2352                ),
2353                Field::new(
2354                    "fixed_list",
2355                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2356                    true,
2357                ),
2358                Field::new("list", DataType::List(inner_field), true),
2359            ]
2360            .into(),
2361            std::collections::HashMap::new(),
2362        )?);
2363
2364        // large list and list
        // Expected THEN type: LargeList
2365        test_case_expression!(
2366            None::<String>,
2367            vec![
2368                (Box::new(col("boolean")), Box::new(col("large_list"))),
2369                (Box::new(col("boolean")), Box::new(col("list")))
2370            ],
2371            DataType::Boolean,
2372            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2373            schema
2374        );
2375
2376        test_case_expression!(
2377            None::<String>,
2378            vec![
2379                (Box::new(col("boolean")), Box::new(col("list"))),
2380                (Box::new(col("boolean")), Box::new(col("large_list")))
2381            ],
2382            DataType::Boolean,
2383            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2384            schema
2385        );
2386
2387        // fixed list and list
        // Expected THEN type: List
2388        test_case_expression!(
2389            None::<String>,
2390            vec![
2391                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2392                (Box::new(col("boolean")), Box::new(col("list")))
2393            ],
2394            DataType::Boolean,
2395            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2396            schema
2397        );
2398
2399        test_case_expression!(
2400            None::<String>,
2401            vec![
2402                (Box::new(col("boolean")), Box::new(col("list"))),
2403                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2404            ],
2405            DataType::Boolean,
2406            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2407            schema
2408        );
2409
2410        // fixed list and large list
        // Expected THEN type: LargeList
2411        test_case_expression!(
2412            None::<String>,
2413            vec![
2414                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2415                (Box::new(col("boolean")), Box::new(col("large_list")))
2416            ],
2417            DataType::Boolean,
2418            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2419            schema
2420        );
2421
2422        test_case_expression!(
2423            None::<String>,
2424            vec![
2425                (Box::new(col("boolean")), Box::new(col("large_list"))),
2426                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2427            ],
2428            DataType::Boolean,
2429            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2430            schema
2431        );
2432        Ok(())
2433    }
2434
2435    #[test]
    // Comparison between two Map types that share the same key/value struct
    // fields and differ only in the name of the entries field ("entries" vs
    // "key_value"). Column `a` uses the "entries" map; the right-hand side is
    // `a` cast to the "key_value" map. Expected: the right-hand side is
    // wrapped in a second cast back to the "entries" map type.
2436    fn test_map_with_diff_name() -> Result<()> {
2437        let mut builder = SchemaBuilder::new();
2438        builder.push(Field::new("key", Utf8, false));
2439        builder.push(Field::new("value", DataType::Float64, true));
2440        let struct_fields = builder.finish().fields;
2441
2442        let fields =
2443            Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2444        let map_type_entries = DataType::Map(Arc::new(fields), false);
2445
        // NOTE(review): `may_type_custom` looks like a typo for
        // `map_type_custom`.
2446        let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2447        let may_type_custom = DataType::Map(Arc::new(fields), false);
2448
2449        let expr = col("a").eq(cast(col("a"), may_type_custom));
2450        let empty = empty_with_type(map_type_entries);
2451        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2452
2453        assert_analyzed_plan_eq!(
2454            plan,
2455            @r#"
2456        Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))
2457          EmptyRelation: rows=0
2458        "#
2459        )
2460    }
2461
2462    #[test]
    // IntervalYearMonth + Timestamp(Nanosecond, None), with the interval as
    // the left operand. The expected plan prints the expression exactly as
    // built, i.e. no additional casts appear on either operand.
2463    fn interval_plus_timestamp() -> Result<()> {
2464        // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
2465        let expr = Expr::BinaryExpr(BinaryExpr::new(
2466            Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2467            Operator::Plus,
2468            Box::new(cast(
2469                lit("2000-01-01T00:00:00"),
2470                DataType::Timestamp(TimeUnit::Nanosecond, None),
2471            )),
2472        ));
2473        let empty = empty();
2474        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2475
2476        assert_analyzed_plan_eq!(
2477            plan,
2478            @r#"
2479        Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None))
2480          EmptyRelation: rows=0
2481        "#
2482        )
2483    }
2484
2485    #[test]
2486    fn timestamp_subtract_timestamp() -> Result<()> {
2487        let expr = Expr::BinaryExpr(BinaryExpr::new(
2488            Box::new(cast(
2489                lit("1998-03-18"),
2490                DataType::Timestamp(TimeUnit::Nanosecond, None),
2491            )),
2492            Operator::Minus,
2493            Box::new(cast(
2494                lit("1998-03-18"),
2495                DataType::Timestamp(TimeUnit::Nanosecond, None),
2496            )),
2497        ));
2498        let empty = empty();
2499        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2500
2501        assert_analyzed_plan_eq!(
2502            plan,
2503            @r#"
2504        Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None))
2505          EmptyRelation: rows=0
2506        "#
2507        )
2508    }
2509
2510    #[test]
2511    fn in_subquery_cast_subquery() -> Result<()> {
2512        let empty_int32 = empty_with_type(DataType::Int32);
2513        let empty_int64 = empty_with_type(DataType::Int64);
2514
2515        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2516            Box::new(col("a")),
2517            Subquery {
2518                subquery: empty_int32,
2519                outer_ref_columns: vec![],
2520                spans: Spans::new(),
2521            },
2522            false,
2523        ));
2524        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2525        // add cast for subquery
2526
2527        assert_analyzed_plan_eq!(
2528            plan,
2529            @r"
2530        Filter: a IN (<subquery>)
2531          Subquery:
2532            Projection: CAST(a AS Int64)
2533              EmptyRelation: rows=0
2534          EmptyRelation: rows=0
2535        "
2536        )
2537    }
2538
2539    #[test]
2540    fn in_subquery_cast_expr() -> Result<()> {
2541        let empty_int32 = empty_with_type(DataType::Int32);
2542        let empty_int64 = empty_with_type(DataType::Int64);
2543
2544        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2545            Box::new(col("a")),
2546            Subquery {
2547                subquery: empty_int64,
2548                outer_ref_columns: vec![],
2549                spans: Spans::new(),
2550            },
2551            false,
2552        ));
2553        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2554
2555        // add cast for subquery
2556        assert_analyzed_plan_eq!(
2557            plan,
2558            @r"
2559        Filter: CAST(a AS Int64) IN (<subquery>)
2560          Subquery:
2561            EmptyRelation: rows=0
2562          EmptyRelation: rows=0
2563        "
2564        )
2565    }
2566
2567    #[test]
2568    fn in_subquery_cast_all() -> Result<()> {
2569        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2570        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2571
2572        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2573            Box::new(col("a")),
2574            Subquery {
2575                subquery: empty_inside,
2576                outer_ref_columns: vec![],
2577                spans: Spans::new(),
2578            },
2579            false,
2580        ));
2581        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2582
2583        // add cast for subquery
2584        assert_analyzed_plan_eq!(
2585            plan,
2586            @r"
2587        Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
2588          Subquery:
2589            Projection: CAST(a AS Decimal128(13, 8))
2590              EmptyRelation: rows=0
2591          EmptyRelation: rows=0
2592        "
2593        )
2594    }
2595}