datafusion_optimizer/analyzer/
type_coercion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Optimizer rule for type validation and coercion
19
20use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::{izip, Itertools as _};
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32    exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
33    plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
34    TableReference,
35};
36use datafusion_expr::expr::{
37    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
38    InSubquery, Like, ScalarFunction, Sort, WindowFunction,
39};
40use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
41use datafusion_expr::expr_schema::cast_subquery;
42use datafusion_expr::logical_plan::Subquery;
43use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
44use datafusion_expr::type_coercion::functions::{
45    data_types_with_scalar_udf, fields_with_aggregate_udf,
46};
47use datafusion_expr::type_coercion::other::{
48    get_coerce_type_for_case_expression, get_coerce_type_for_list,
49};
50use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
51use datafusion_expr::utils::merge_schema;
52use datafusion_expr::{
53    is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
54    AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection,
55    ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
56};
57
/// Analyzer rule that performs type coercion: it determines the schema
/// visible to each plan node and rewrites the node's expressions so that
/// their types line up.
#[derive(Debug, Default)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new [`TypeCoercion`] rule. The rule holds no state, so this
    /// is the same value as [`TypeCoercion::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
68
69/// Coerce output schema based upon optimizer config.
70fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
71    if !config.optimizer.expand_views_at_output {
72        return Ok(plan);
73    }
74
75    let outer_refs = plan.expressions();
76    if outer_refs.is_empty() {
77        return Ok(plan);
78    }
79
80    if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
81        coerce_plan_expr_for_schema(plan, &dfschema?)
82    } else {
83        Ok(plan)
84    }
85}
86
87impl AnalyzerRule for TypeCoercion {
88    fn name(&self) -> &str {
89        "type_coercion"
90    }
91
92    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
93        let empty_schema = DFSchema::empty();
94
95        // recurse
96        let transformed_plan = plan
97            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
98            .data;
99
100        // finish
101        coerce_output(transformed_plan, config)
102    }
103}
104
105/// use the external schema to handle the correlated subqueries case
106///
107/// Assumes that children have already been optimized
108fn analyze_internal(
109    external_schema: &DFSchema,
110    plan: LogicalPlan,
111) -> Result<Transformed<LogicalPlan>> {
112    // get schema representing all available input fields. This is used for data type
113    // resolution only, so order does not matter here
114    let mut schema = merge_schema(&plan.inputs());
115
116    if let LogicalPlan::TableScan(ts) = &plan {
117        let source_schema = DFSchema::try_from_qualified_schema(
118            ts.table_name.clone(),
119            &ts.source.schema(),
120        )?;
121        schema.merge(&source_schema);
122    }
123
124    // merge the outer schema for correlated subqueries
125    // like case:
126    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
127    schema.merge(external_schema);
128
129    // Coerce filter predicates to boolean (handles `WHERE NULL`)
130    let plan = if let LogicalPlan::Filter(mut filter) = plan {
131        filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
132        LogicalPlan::Filter(filter)
133    } else {
134        plan
135    };
136
137    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
138
139    let name_preserver = NamePreserver::new(&plan);
140    // apply coercion rewrite all expressions in the plan individually
141    plan.map_expressions(|expr| {
142        let original_name = name_preserver.save(&expr);
143        expr.rewrite(&mut expr_rewrite)
144            .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
145    })?
146    // some plans need extra coercion after their expressions are coerced
147    .map_data(|plan| expr_rewrite.coerce_plan(plan))?
148    // recompute the schema after the expressions have been rewritten as the types may have changed
149    .map_data(|plan| plan.recompute_schema())
150}
151
/// Rewrite expressions to apply type coercion.
///
/// Expression types are resolved against the borrowed [`DFSchema`], so the
/// rewriter cannot outlive that schema.
pub struct TypeCoercionRewriter<'a> {
    /// Schema used to look up the data types of the expressions being coerced.
    pub(crate) schema: &'a DFSchema,
}
156
157impl<'a> TypeCoercionRewriter<'a> {
158    /// Create a new [`TypeCoercionRewriter`] with a provided schema
159    /// representing both the inputs and output of the [`LogicalPlan`] node.
160    pub fn new(schema: &'a DFSchema) -> Self {
161        Self { schema }
162    }
163
164    /// Coerce the [`LogicalPlan`].
165    ///
166    /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`]
167    /// for type-coercion approach.
168    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
169        match plan {
170            LogicalPlan::Join(join) => self.coerce_join(join),
171            LogicalPlan::Union(union) => Self::coerce_union(union),
172            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
173            _ => Ok(plan),
174        }
175    }
176
177    /// Coerce join equality expressions and join filter
178    ///
179    /// Joins must be treated specially as their equality expressions are stored
180    /// as a parallel list of left and right expressions, rather than a single
181    /// equality expression
182    ///
183    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
184    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
185    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
186        join.on = join
187            .on
188            .into_iter()
189            .map(|(lhs, rhs)| {
190                // coerce the arguments as though they were a single binary equality
191                // expression
192                let left_schema = join.left.schema();
193                let right_schema = join.right.schema();
194                let (lhs, rhs) = self.coerce_binary_op(
195                    lhs,
196                    left_schema,
197                    Operator::Eq,
198                    rhs,
199                    right_schema,
200                )?;
201                Ok((lhs, rhs))
202            })
203            .collect::<Result<Vec<_>>>()?;
204
205        // Join filter must be boolean
206        join.filter = join
207            .filter
208            .map(|expr| self.coerce_join_filter(expr))
209            .transpose()?;
210
211        Ok(LogicalPlan::Join(join))
212    }
213
214    /// Coerce the union’s inputs to a common schema compatible with all inputs.
215    /// This occurs after wildcard expansion and the coercion of the input expressions.
216    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
217        let union_schema = Arc::new(coerce_union_schema_with_schema(
218            &union_plan.inputs,
219            &union_plan.schema,
220        )?);
221        let new_inputs = union_plan
222            .inputs
223            .into_iter()
224            .map(|p| {
225                let plan =
226                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
227                match plan {
228                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
229                        Ok(Arc::new(project_with_column_index(
230                            expr,
231                            input,
232                            Arc::clone(&union_schema),
233                        )?))
234                    }
235                    other_plan => Ok(Arc::new(other_plan)),
236                }
237            })
238            .collect::<Result<Vec<_>>>()?;
239        Ok(LogicalPlan::Union(Union {
240            inputs: new_inputs,
241            schema: union_schema,
242        }))
243    }
244
245    /// Coerce the fetch and skip expression to Int64 type.
246    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
247        fn coerce_limit_expr(
248            expr: Expr,
249            schema: &DFSchema,
250            expr_name: &str,
251        ) -> Result<Expr> {
252            let dt = expr.get_type(schema)?;
253            if dt.is_integer() || dt.is_null() {
254                expr.cast_to(&DataType::Int64, schema)
255            } else {
256                plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
257            }
258        }
259
260        let empty_schema = DFSchema::empty();
261        let new_fetch = limit
262            .fetch
263            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
264            .transpose()?;
265        let new_skip = limit
266            .skip
267            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
268            .transpose()?;
269        Ok(LogicalPlan::Limit(Limit {
270            input: limit.input,
271            fetch: new_fetch.map(Box::new),
272            skip: new_skip.map(Box::new),
273        }))
274    }
275
276    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
277        let expr_type = expr.get_type(self.schema)?;
278        match expr_type {
279            DataType::Boolean => Ok(expr),
280            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
281            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
282        }
283    }
284
285    fn coerce_binary_op(
286        &self,
287        left: Expr,
288        left_schema: &DFSchema,
289        op: Operator,
290        right: Expr,
291        right_schema: &DFSchema,
292    ) -> Result<(Expr, Expr)> {
293        let (left_type, right_type) = BinaryTypeCoercer::new(
294            &left.get_type(left_schema)?,
295            &op,
296            &right.get_type(right_schema)?,
297        )
298        .get_input_types()?;
299
300        Ok((
301            left.cast_to(&left_type, left_schema)?,
302            right.cast_to(&right_type, right_schema)?,
303        ))
304    }
305}
306
307impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
308    type Node = Expr;
309
310    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
311        match expr {
312            Expr::Unnest(_) => not_impl_err!(
313                "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
314            ),
315            Expr::ScalarSubquery(Subquery {
316                subquery,
317                outer_ref_columns,
318                spans,
319            }) => {
320                let new_plan =
321                    analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
322                Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
323                    subquery: Arc::new(new_plan),
324                    outer_ref_columns,
325                    spans,
326                })))
327            }
328            Expr::Exists(Exists { subquery, negated }) => {
329                let new_plan = analyze_internal(
330                    self.schema,
331                    Arc::unwrap_or_clone(subquery.subquery),
332                )?
333                .data;
334                Ok(Transformed::yes(Expr::Exists(Exists {
335                    subquery: Subquery {
336                        subquery: Arc::new(new_plan),
337                        outer_ref_columns: subquery.outer_ref_columns,
338                        spans: subquery.spans,
339                    },
340                    negated,
341                })))
342            }
343            Expr::InSubquery(InSubquery {
344                expr,
345                subquery,
346                negated,
347            }) => {
348                let new_plan = analyze_internal(
349                    self.schema,
350                    Arc::unwrap_or_clone(subquery.subquery),
351                )?
352                .data;
353                let expr_type = expr.get_type(self.schema)?;
354                let subquery_type = new_plan.schema().field(0).data_type();
355                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
356                    plan_datafusion_err!(
357                    "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
358                ),
359                )?;
360                let new_subquery = Subquery {
361                    subquery: Arc::new(new_plan),
362                    outer_ref_columns: subquery.outer_ref_columns,
363                    spans: subquery.spans,
364                };
365                Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
366                    Box::new(expr.cast_to(&common_type, self.schema)?),
367                    cast_subquery(new_subquery, &common_type)?,
368                    negated,
369                ))))
370            }
371            Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
372                *expr,
373                self.schema,
374            )?))),
375            Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
376                get_casted_expr_for_bool_op(*expr, self.schema)?,
377            ))),
378            Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
379                get_casted_expr_for_bool_op(*expr, self.schema)?,
380            ))),
381            Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
382                get_casted_expr_for_bool_op(*expr, self.schema)?,
383            ))),
384            Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
385                get_casted_expr_for_bool_op(*expr, self.schema)?,
386            ))),
387            Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
388                get_casted_expr_for_bool_op(*expr, self.schema)?,
389            ))),
390            Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
391                get_casted_expr_for_bool_op(*expr, self.schema)?,
392            ))),
393            Expr::Like(Like {
394                negated,
395                expr,
396                pattern,
397                escape_char,
398                case_insensitive,
399            }) => {
400                let left_type = expr.get_type(self.schema)?;
401                let right_type = pattern.get_type(self.schema)?;
402                let coerced_type = like_coercion(&left_type,  &right_type).ok_or_else(|| {
403                    let op_name = if case_insensitive {
404                        "ILIKE"
405                    } else {
406                        "LIKE"
407                    };
408                    plan_datafusion_err!(
409                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
410                    )
411                })?;
412                let expr = match left_type {
413                    DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
414                    _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
415                };
416                let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
417                Ok(Transformed::yes(Expr::Like(Like::new(
418                    negated,
419                    expr,
420                    pattern,
421                    escape_char,
422                    case_insensitive,
423                ))))
424            }
425            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
426                let (left, right) =
427                    self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
428                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
429                    Box::new(left),
430                    op,
431                    Box::new(right),
432                ))))
433            }
434            Expr::Between(Between {
435                expr,
436                negated,
437                low,
438                high,
439            }) => {
440                let expr_type = expr.get_type(self.schema)?;
441                let low_type = low.get_type(self.schema)?;
442                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
443                    .ok_or_else(|| {
444                        internal_datafusion_err!(
445                            "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
446                        )
447                    })?;
448                let high_type = high.get_type(self.schema)?;
449                let high_coerced_type = comparison_coercion(&expr_type, &high_type)
450                    .ok_or_else(|| {
451                        internal_datafusion_err!(
452                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
453                        )
454                    })?;
455                let coercion_type =
456                    comparison_coercion(&low_coerced_type, &high_coerced_type)
457                        .ok_or_else(|| {
458                            internal_datafusion_err!(
459                                "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
460                            )
461                        })?;
462                Ok(Transformed::yes(Expr::Between(Between::new(
463                    Box::new(expr.cast_to(&coercion_type, self.schema)?),
464                    negated,
465                    Box::new(low.cast_to(&coercion_type, self.schema)?),
466                    Box::new(high.cast_to(&coercion_type, self.schema)?),
467                ))))
468            }
469            Expr::InList(InList {
470                expr,
471                list,
472                negated,
473            }) => {
474                let expr_data_type = expr.get_type(self.schema)?;
475                let list_data_types = list
476                    .iter()
477                    .map(|list_expr| list_expr.get_type(self.schema))
478                    .collect::<Result<Vec<_>>>()?;
479                let result_type =
480                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
481                match result_type {
482                    None => plan_err!(
483                        "Can not find compatible types to compare {expr_data_type} with [{}]", list_data_types.iter().join(", ")
484                    ),
485                    Some(coerced_type) => {
486                        // find the coerced type
487                        let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
488                        let cast_list_expr = list
489                            .into_iter()
490                            .map(|list_expr| {
491                                list_expr.cast_to(&coerced_type, self.schema)
492                            })
493                            .collect::<Result<Vec<_>>>()?;
494                        Ok(Transformed::yes(Expr::InList(InList ::new(
495                             Box::new(cast_expr),
496                             cast_list_expr,
497                            negated,
498                        ))))
499                    }
500                }
501            }
502            Expr::Case(case) => {
503                let case = coerce_case_expression(case, self.schema)?;
504                Ok(Transformed::yes(Expr::Case(case)))
505            }
506            Expr::ScalarFunction(ScalarFunction { func, args }) => {
507                let new_expr = coerce_arguments_for_signature_with_scalar_udf(
508                    args,
509                    self.schema,
510                    &func,
511                )?;
512                Ok(Transformed::yes(Expr::ScalarFunction(
513                    ScalarFunction::new_udf(func, new_expr),
514                )))
515            }
516            Expr::AggregateFunction(expr::AggregateFunction {
517                func,
518                params:
519                    AggregateFunctionParams {
520                        args,
521                        distinct,
522                        filter,
523                        order_by,
524                        null_treatment,
525                    },
526            }) => {
527                let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
528                    args,
529                    self.schema,
530                    &func,
531                )?;
532                Ok(Transformed::yes(Expr::AggregateFunction(
533                    expr::AggregateFunction::new_udf(
534                        func,
535                        new_expr,
536                        distinct,
537                        filter,
538                        order_by,
539                        null_treatment,
540                    ),
541                )))
542            }
543            Expr::WindowFunction(window_fun) => {
544                let WindowFunction {
545                    fun,
546                    params:
547                        expr::WindowFunctionParams {
548                            args,
549                            partition_by,
550                            order_by,
551                            window_frame,
552                            filter,
553                            null_treatment,
554                            distinct,
555                        },
556                } = *window_fun;
557                let window_frame =
558                    coerce_window_frame(window_frame, self.schema, &order_by)?;
559
560                let args = match &fun {
561                    expr::WindowFunctionDefinition::AggregateUDF(udf) => {
562                        coerce_arguments_for_signature_with_aggregate_udf(
563                            args,
564                            self.schema,
565                            udf,
566                        )?
567                    }
568                    _ => args,
569                };
570
571                let new_expr = Expr::from(WindowFunction {
572                    fun,
573                    params: expr::WindowFunctionParams {
574                        args,
575                        partition_by,
576                        order_by,
577                        window_frame,
578                        filter,
579                        null_treatment,
580                        distinct,
581                    },
582                });
583                Ok(Transformed::yes(new_expr))
584            }
585            // TODO: remove the next line after `Expr::Wildcard` is removed
586            #[expect(deprecated)]
587            Expr::Alias(_)
588            | Expr::Column(_)
589            | Expr::ScalarVariable(_, _)
590            | Expr::Literal(_, _)
591            | Expr::SimilarTo(_)
592            | Expr::IsNotNull(_)
593            | Expr::IsNull(_)
594            | Expr::Negative(_)
595            | Expr::Cast(_)
596            | Expr::TryCast(_)
597            | Expr::Wildcard { .. }
598            | Expr::GroupingSet(_)
599            | Expr::Placeholder(_)
600            | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
601        }
602    }
603}
604
605/// Transform a schema to use non-view types for Utf8View and BinaryView
606fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
607    let metadata = dfschema.as_arrow().metadata.clone();
608    let mut transformed = false;
609
610    let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
611        dfschema
612            .iter()
613            .map(|(qualifier, field)| match field.data_type() {
614                DataType::Utf8View => {
615                    transformed = true;
616                    (
617                        qualifier.cloned() as Option<TableReference>,
618                        Arc::new(Field::new(
619                            field.name(),
620                            DataType::LargeUtf8,
621                            field.is_nullable(),
622                        )),
623                    )
624                }
625                DataType::BinaryView => {
626                    transformed = true;
627                    (
628                        qualifier.cloned() as Option<TableReference>,
629                        Arc::new(Field::new(
630                            field.name(),
631                            DataType::LargeBinary,
632                            field.is_nullable(),
633                        )),
634                    )
635                }
636                _ => (
637                    qualifier.cloned() as Option<TableReference>,
638                    Arc::clone(field),
639                ),
640            })
641            .unzip();
642
643    if !transformed {
644        return None;
645    }
646
647    let schema = Schema::new_with_metadata(transformed_fields, metadata);
648    Some(DFSchema::from_field_specific_qualified_schema(
649        qualifiers,
650        &Arc::new(schema),
651    ))
652}
653
654/// Casts the given `value` to `target_type`. Note that this function
655/// only considers `Null` or `Utf8` values.
656fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
657    match value {
658        // Coerce Utf8 values:
659        ScalarValue::Utf8(Some(val)) => {
660            ScalarValue::try_from_string(val.clone(), target_type)
661        }
662        s => {
663            if s.is_null() {
664                // Coerce `Null` values:
665                ScalarValue::try_from(target_type)
666            } else {
667                // Values except `Utf8`/`Null` variants already have the right type
668                // (casted before) since we convert `sqlparser` outputs to `Utf8`
669                // for all possible cases. Therefore, we return a clone here.
670                Ok(s.clone())
671            }
672        }
673    }
674}
675
676/// This function coerces `value` to `target_type` in a range-aware fashion.
677/// If the coercion is successful, we return an `Ok` value with the result.
678/// If the coercion fails because `target_type` is not wide enough (i.e. we
679/// can not coerce to `target_type`, but we can to a wider type in the same
680/// family), we return a `Null` value of this type to signal this situation.
681/// Downstream code uses this signal to treat these values as *unbounded*.
682fn coerce_scalar_range_aware(
683    target_type: &DataType,
684    value: &ScalarValue,
685) -> Result<ScalarValue> {
686    coerce_scalar(target_type, value).or_else(|err| {
687        // If type coercion fails, check if the largest type in family works:
688        if let Some(largest_type) = get_widest_type_in_family(target_type) {
689            coerce_scalar(largest_type, value).map_or_else(
690                |_| exec_err!("Cannot cast {value:?} to {target_type}"),
691                |_| ScalarValue::try_from(target_type),
692            )
693        } else {
694            Err(err)
695        }
696    })
697}
698
699/// This function returns the widest type in the family of `given_type`.
700/// If the given type is already the widest type, it returns `None`.
701/// For example, if `given_type` is `Int8`, it returns `Int64`.
702fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
703    match given_type {
704        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
705        DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
706        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
707        _ => None,
708    }
709}
710
711/// Coerces the given (window frame) `bound` to `target_type`.
712fn coerce_frame_bound(
713    target_type: &DataType,
714    bound: WindowFrameBound,
715) -> Result<WindowFrameBound> {
716    match bound {
717        WindowFrameBound::Preceding(v) => {
718            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
719        }
720        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
721        WindowFrameBound::Following(v) => {
722            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
723        }
724    }
725}
726
727fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
728    if col_type.is_numeric()
729        || is_utf8_or_utf8view_or_large_utf8(col_type)
730        || matches!(col_type, DataType::List(_))
731        || matches!(col_type, DataType::LargeList(_))
732        || matches!(col_type, DataType::FixedSizeList(_, _))
733        || matches!(col_type, DataType::Null)
734        || matches!(col_type, DataType::Boolean)
735    {
736        Ok(col_type.clone())
737    } else if is_datetime(col_type) {
738        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
739    } else if let DataType::Dictionary(_, value_type) = col_type {
740        extract_window_frame_target_type(value_type)
741    } else {
742        internal_err!("Cannot run range queries on datatype: {col_type}")
743    }
744}
745
746// Coerces the given `window_frame` to use appropriate natural types.
747// For example, ROWS and GROUPS frames use `UInt64` during calculations.
748fn coerce_window_frame(
749    window_frame: WindowFrame,
750    schema: &DFSchema,
751    expressions: &[Sort],
752) -> Result<WindowFrame> {
753    let mut window_frame = window_frame;
754    let target_type = match window_frame.units {
755        WindowFrameUnits::Range => {
756            let current_types = expressions
757                .first()
758                .map(|s| s.expr.get_type(schema))
759                .transpose()?;
760            if let Some(col_type) = current_types {
761                extract_window_frame_target_type(&col_type)?
762            } else {
763                return internal_err!("ORDER BY column cannot be empty");
764            }
765        }
766        WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
767    };
768    window_frame.start_bound =
769        coerce_frame_bound(&target_type, window_frame.start_bound)?;
770    window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
771    Ok(window_frame)
772}
773
774// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
775// The above op will be rewrite to the binary op when creating the physical op.
776fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
777    let left_type = expr.get_type(schema)?;
778    BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
779        .get_input_types()?;
780    expr.cast_to(&DataType::Boolean, schema)
781}
782
783/// Returns `expressions` coerced to types compatible with
784/// `signature`, if possible.
785///
786/// See the module level documentation for more detail on coercion.
787fn coerce_arguments_for_signature_with_scalar_udf(
788    expressions: Vec<Expr>,
789    schema: &DFSchema,
790    func: &ScalarUDF,
791) -> Result<Vec<Expr>> {
792    if expressions.is_empty() {
793        return Ok(expressions);
794    }
795
796    let current_types = expressions
797        .iter()
798        .map(|e| e.get_type(schema))
799        .collect::<Result<Vec<_>>>()?;
800
801    let new_types = data_types_with_scalar_udf(&current_types, func)?;
802
803    expressions
804        .into_iter()
805        .enumerate()
806        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
807        .collect()
808}
809
810/// Returns `expressions` coerced to types compatible with
811/// `signature`, if possible.
812///
813/// See the module level documentation for more detail on coercion.
814fn coerce_arguments_for_signature_with_aggregate_udf(
815    expressions: Vec<Expr>,
816    schema: &DFSchema,
817    func: &AggregateUDF,
818) -> Result<Vec<Expr>> {
819    if expressions.is_empty() {
820        return Ok(expressions);
821    }
822
823    let current_fields = expressions
824        .iter()
825        .map(|e| e.to_field(schema).map(|(_, f)| f))
826        .collect::<Result<Vec<_>>>()?;
827
828    let new_types = fields_with_aggregate_udf(&current_fields, func)?
829        .into_iter()
830        .map(|f| f.data_type().clone())
831        .collect::<Vec<_>>();
832
833    expressions
834        .into_iter()
835        .enumerate()
836        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
837        .collect()
838}
839
840fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
841    // Given expressions like:
842    //
843    // CASE a1
844    //   WHEN a2 THEN b1
845    //   WHEN a3 THEN b2
846    //   ELSE b3
847    // END
848    //
849    // or:
850    //
851    // CASE
852    //   WHEN x1 THEN b1
853    //   WHEN x2 THEN b2
854    //   ELSE b3
855    // END
856    //
857    // Then all aN (a1, a2, a3) must be converted to a common data type in the first example
858    // (case-when expression coercion)
859    //
860    // All xN (x1, x2) must be converted to a boolean data type in the second example
861    // (when-boolean expression coercion)
862    //
863    // And all bN (b1, b2, b3) must be converted to a common data type in both examples
864    // (then-else expression coercion)
865    //
866    // If any fail to find and cast to a common/specific data type, will return error
867    //
868    // Note that case-when and when-boolean expression coercions are mutually exclusive
869    // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur
870
871    // prepare types
872    let case_type = case
873        .expr
874        .as_ref()
875        .map(|expr| expr.get_type(schema))
876        .transpose()?;
877    let then_types = case
878        .when_then_expr
879        .iter()
880        .map(|(_when, then)| then.get_type(schema))
881        .collect::<Result<Vec<_>>>()?;
882    let else_type = case
883        .else_expr
884        .as_ref()
885        .map(|expr| expr.get_type(schema))
886        .transpose()?;
887
888    // find common coercible types
889    let case_when_coerce_type = case_type
890        .as_ref()
891        .map(|case_type| {
892            let when_types = case
893                .when_then_expr
894                .iter()
895                .map(|(when, _then)| when.get_type(schema))
896                .collect::<Result<Vec<_>>>()?;
897            let coerced_type =
898                get_coerce_type_for_case_expression(&when_types, Some(case_type));
899            coerced_type.ok_or_else(|| {
900                plan_datafusion_err!(
901                    "Failed to coerce case ({case_type}) and when ({}) \
902                     to common types in CASE WHEN expression",
903                    when_types.iter().join(", ")
904                )
905            })
906        })
907        .transpose()?;
908    let then_else_coerce_type =
909        get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
910            || {
911                if let Some(else_type) = else_type {
912                    plan_datafusion_err!(
913                        "Failed to coerce then ({}) and else ({else_type}) \
914                         to common types in CASE WHEN expression",
915                        then_types.iter().join(", ")
916                    )
917                } else {
918                    plan_datafusion_err!(
919                        "Failed to coerce then ({}) and else (None) \
920                         to common types in CASE WHEN expression",
921                        then_types.iter().join(", ")
922                    )
923                }
924            },
925        )?;
926
927    // do cast if found common coercible types
928    let case_expr = case
929        .expr
930        .zip(case_when_coerce_type.as_ref())
931        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
932        .transpose()?
933        .map(Box::new);
934    let when_then = case
935        .when_then_expr
936        .into_iter()
937        .map(|(when, then)| {
938            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
939            let when = when.cast_to(when_type, schema).map_err(|e| {
940                DataFusionError::Context(
941                    format!(
942                        "WHEN expressions in CASE couldn't be \
943                         converted to common type ({when_type})"
944                    ),
945                    Box::new(e),
946                )
947            })?;
948            let then = then.cast_to(&then_else_coerce_type, schema)?;
949            Ok((Box::new(when), Box::new(then)))
950        })
951        .collect::<Result<Vec<_>>>()?;
952    let else_expr = case
953        .else_expr
954        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
955        .transpose()?
956        .map(Box::new);
957
958    Ok(Case::new(case_expr, when_then, else_expr))
959}
960
961/// Get a common schema that is compatible with all inputs of UNION.
962///
963/// This method presumes that the wildcard expansion is unneeded, or has already
964/// been applied.
965///
966/// ## Schema and Field Handling in Union Coercion
967///
968/// **Processing order**: The function starts with the base schema (first input) and then
969/// processes remaining inputs sequentially, with later inputs taking precedence in merging.
970///
971/// **Schema-level metadata merging**: Later schemas take precedence for duplicate keys.
972///
973/// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys.
974///
975/// **Type coercion precedence**: The coerced type is determined by iteratively applying
976/// `comparison_coercion()` between the accumulated type and each new input's type. The
977/// result depends on type coercion rules, not input order.
978///
979/// **Nullability merging**: Nullability is accumulated using logical OR (`||`).
980/// Once any input field is nullable, the result field becomes nullable permanently.
981/// Later inputs can make a field nullable but cannot make it non-nullable.
982///
983/// **Field precedence**: Field names come from the first (base) schema, but the field properties
984/// (nullability and field-level metadata) have later schemas taking precedence.
985///
986/// **Example**:
987/// ```sql
988/// SELECT a, b FROM table1  -- a: Int32, metadata {"source": "t1"}, nullable=false
989/// UNION
990/// SELECT a, b FROM table2  -- a: Int64, metadata {"source": "t2"}, nullable=true
991/// UNION
992/// SELECT a, b FROM table3  -- a: Int32, metadata {"encoding": "utf8"}, nullable=false
993/// -- Result:
994/// -- a: Int64 (from type coercion), nullable=true (from table2),
995/// -- metadata: {"source": "t2", "encoding": "utf8"} (later inputs take precedence)
996/// ```
997///
998/// **Precedence Summary**:
999/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order
1000/// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR)
1001/// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics)
1002pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1003    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1004}
1005fn coerce_union_schema_with_schema(
1006    inputs: &[Arc<LogicalPlan>],
1007    base_schema: &DFSchemaRef,
1008) -> Result<DFSchema> {
1009    let mut union_datatypes = base_schema
1010        .fields()
1011        .iter()
1012        .map(|f| f.data_type().clone())
1013        .collect::<Vec<_>>();
1014    let mut union_nullabilities = base_schema
1015        .fields()
1016        .iter()
1017        .map(|f| f.is_nullable())
1018        .collect::<Vec<_>>();
1019    let mut union_field_meta = base_schema
1020        .fields()
1021        .iter()
1022        .map(|f| f.metadata().clone())
1023        .collect::<Vec<_>>();
1024
1025    let mut metadata = base_schema.metadata().clone();
1026
1027    for (i, plan) in inputs.iter().enumerate() {
1028        let plan_schema = plan.schema();
1029        metadata.extend(plan_schema.metadata().clone());
1030
1031        if plan_schema.fields().len() != base_schema.fields().len() {
1032            return plan_err!(
1033                "Union schemas have different number of fields: \
1034                query 1 has {} fields whereas query {} has {} fields",
1035                base_schema.fields().len(),
1036                i + 1,
1037                plan_schema.fields().len()
1038            );
1039        }
1040
1041        // coerce data type and nullability for each field
1042        for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1043            union_datatypes.iter_mut(),
1044            union_nullabilities.iter_mut(),
1045            union_field_meta.iter_mut(),
1046            plan_schema.fields().iter()
1047        ) {
1048            let coerced_type =
1049                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1050                    || {
1051                        plan_datafusion_err!(
1052                            "Incompatible inputs for Union: Previous inputs were \
1053                            of type {}, but got incompatible type {} on column '{}'",
1054                            union_datatype,
1055                            plan_field.data_type(),
1056                            plan_field.name()
1057                        )
1058                    },
1059                )?;
1060
1061            *union_datatype = coerced_type;
1062            *union_nullable = *union_nullable || plan_field.is_nullable();
1063            union_field_map.extend(plan_field.metadata().clone());
1064        }
1065    }
1066    let union_qualified_fields = izip!(
1067        base_schema.fields(),
1068        union_datatypes.into_iter(),
1069        union_nullabilities,
1070        union_field_meta.into_iter()
1071    )
1072    .map(|(field, datatype, nullable, metadata)| {
1073        let mut field = Field::new(field.name().clone(), datatype, nullable);
1074        field.set_metadata(metadata);
1075        (None, field.into())
1076    })
1077    .collect::<Vec<_>>();
1078
1079    DFSchema::new_with_metadata(union_qualified_fields, metadata)
1080}
1081
1082/// See `<https://github.com/apache/datafusion/pull/2108>`
1083fn project_with_column_index(
1084    expr: Vec<Expr>,
1085    input: Arc<LogicalPlan>,
1086    schema: DFSchemaRef,
1087) -> Result<LogicalPlan> {
1088    let alias_expr = expr
1089        .into_iter()
1090        .enumerate()
1091        .map(|(i, e)| match e {
1092            Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1093                Ok(e.unalias().alias(schema.field(i).name()))
1094            }
1095            Expr::Column(Column {
1096                relation: _,
1097                ref name,
1098                spans: _,
1099            }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1100            Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1101            #[expect(deprecated)]
1102            Expr::Wildcard { .. } => {
1103                plan_err!("Wildcard should be expanded before type coercion")
1104            }
1105            _ => Ok(e.alias(schema.field(i).name())),
1106        })
1107        .collect::<Result<Vec<_>>>()?;
1108
1109    Projection::try_new_with_schema(alias_expr, input, schema)
1110        .map(LogicalPlan::Projection)
1111}
1112
1113#[cfg(test)]
1114mod test {
1115    use std::any::Any;
1116    use std::sync::Arc;
1117
1118    use arrow::datatypes::DataType::Utf8;
1119    use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1120    use insta::assert_snapshot;
1121
1122    use crate::analyzer::type_coercion::{
1123        coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1124    };
1125    use crate::analyzer::Analyzer;
1126    use crate::assert_analyzed_plan_with_config_eq_snapshot;
1127    use datafusion_common::config::ConfigOptions;
1128    use datafusion_common::tree_node::{TransformedResult, TreeNode};
1129    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1130    use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1131    use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1132    use datafusion_expr::test::function_stub::avg_udaf;
1133    use datafusion_expr::{
1134        cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1135        BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1136        Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1137        SimpleAggregateUDF, Subquery, Union, Volatility,
1138    };
1139    use datafusion_functions_aggregate::average::AvgAccumulator;
1140    use datafusion_sql::TableReference;
1141
1142    fn empty() -> Arc<LogicalPlan> {
1143        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1144            produce_one_row: false,
1145            schema: Arc::new(DFSchema::empty()),
1146        }))
1147    }
1148
1149    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1150        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1151            produce_one_row: false,
1152            schema: Arc::new(
1153                DFSchema::from_unqualified_fields(
1154                    vec![Field::new("a", data_type, true)].into(),
1155                    std::collections::HashMap::new(),
1156                )
1157                .unwrap(),
1158            ),
1159        }))
1160    }
1161
1162    macro_rules! assert_analyzed_plan_eq {
1163        (
1164            $plan: expr,
1165            @ $expected: literal $(,)?
1166        ) => {{
1167            let options = ConfigOptions::default();
1168            let rule = Arc::new(TypeCoercion::new());
1169            assert_analyzed_plan_with_config_eq_snapshot!(
1170                options,
1171                rule,
1172                $plan,
1173                @ $expected,
1174            )
1175            }};
1176    }
1177
1178    macro_rules! coerce_on_output_if_viewtype {
1179        (
1180            $is_viewtype: expr,
1181            $plan: expr,
1182            @ $expected: literal $(,)?
1183        ) => {{
1184            let mut options = ConfigOptions::default();
1185            // coerce on output
1186            if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1187            let rule = Arc::new(TypeCoercion::new());
1188
1189            assert_analyzed_plan_with_config_eq_snapshot!(
1190                options,
1191                rule,
1192                $plan,
1193                @ $expected,
1194            )
1195        }};
1196    }
1197
1198    fn assert_type_coercion_error(
1199        plan: LogicalPlan,
1200        expected_substr: &str,
1201    ) -> Result<()> {
1202        let options = ConfigOptions::default();
1203        let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1204
1205        match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1206            Ok(succeeded_plan) => {
1207                panic!(
1208                    "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1209                );
1210            }
1211            Err(e) => {
1212                let msg = e.to_string();
1213                assert!(
1214                    msg.contains(expected_substr),
1215                    "Error did not contain expected substring.\n  expected to find: `{expected_substr}`\n  actual error: `{msg}`"
1216                );
1217            }
1218        }
1219
1220        Ok(())
1221    }
1222
1223    #[test]
1224    fn simple_case() -> Result<()> {
1225        let expr = col("a").lt(lit(2_u32));
1226        let empty = empty_with_type(DataType::Float64);
1227        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1228
1229        assert_analyzed_plan_eq!(
1230            plan,
1231            @r"
1232        Projection: a < CAST(UInt32(2) AS Float64)
1233          EmptyRelation: rows=0
1234        "
1235        )
1236    }
1237
1238    #[test]
1239    fn test_coerce_union() -> Result<()> {
1240        let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1241            produce_one_row: false,
1242            schema: Arc::new(
1243                DFSchema::try_from_qualified_schema(
1244                    TableReference::full("datafusion", "test", "foo"),
1245                    &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1246                )
1247                .unwrap(),
1248            ),
1249        }));
1250        let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1251            produce_one_row: false,
1252            schema: Arc::new(
1253                DFSchema::try_from_qualified_schema(
1254                    TableReference::full("datafusion", "test", "foo"),
1255                    &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1256                )
1257                .unwrap(),
1258            ),
1259        }));
1260        let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1261            left_plan, right_plan,
1262        ])?);
1263        let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1264            .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1265        let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1266            vec![col("a")],
1267            Arc::new(analyzed_union),
1268        )?);
1269
1270        assert_analyzed_plan_eq!(
1271            top_level_plan,
1272            @r"
1273        Projection: a
1274          Union
1275            Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1276              EmptyRelation: rows=0
1277            EmptyRelation: rows=0
1278        "
1279        )
1280    }
1281
1282    #[test]
1283    fn coerce_utf8view_output() -> Result<()> {
1284        // Plan A
1285        // scenario: outermost utf8view projection
1286        let expr = col("a");
1287        let empty = empty_with_type(DataType::Utf8View);
1288        let plan = LogicalPlan::Projection(Projection::try_new(
1289            vec![expr.clone()],
1290            Arc::clone(&empty),
1291        )?);
1292
1293        // Plan A: no coerce
1294        coerce_on_output_if_viewtype!(
1295            false,
1296            plan.clone(),
1297            @r"
1298        Projection: a
1299          EmptyRelation: rows=0
1300        "
1301        )?;
1302
1303        // Plan A: coerce requested: Utf8View => LargeUtf8
1304        coerce_on_output_if_viewtype!(
1305            true,
1306            plan.clone(),
1307            @r"
1308        Projection: CAST(a AS LargeUtf8)
1309          EmptyRelation: rows=0
1310        "
1311        )?;
1312
1313        // Plan B
1314        // scenario: outermost bool projection
1315        let bool_expr = col("a").lt(lit("foo"));
1316        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1317            vec![bool_expr],
1318            Arc::clone(&empty),
1319        )?);
1320        // Plan B: no coerce
1321        coerce_on_output_if_viewtype!(
1322            false,
1323            bool_plan.clone(),
1324            @r#"
1325        Projection: a < CAST(Utf8("foo") AS Utf8View)
1326          EmptyRelation: rows=0
1327        "#
1328        )?;
1329
1330        coerce_on_output_if_viewtype!(
1331            false,
1332            plan.clone(),
1333            @r"
1334        Projection: a
1335          EmptyRelation: rows=0
1336        "
1337        )?;
1338
1339        // Plan B: coerce requested: no coercion applied
1340        coerce_on_output_if_viewtype!(
1341            true,
1342            plan.clone(),
1343            @r"
1344        Projection: CAST(a AS LargeUtf8)
1345          EmptyRelation: rows=0
1346        "
1347        )?;
1348
1349        // Plan C
1350        // scenario: with a non-projection root logical plan node
1351        let sort_expr = expr.sort(true, true);
1352        let sort_plan = LogicalPlan::Sort(Sort {
1353            expr: vec![sort_expr],
1354            input: Arc::new(plan),
1355            fetch: None,
1356        });
1357
1358        // Plan C: no coerce
1359        coerce_on_output_if_viewtype!(
1360            false,
1361            sort_plan.clone(),
1362            @r"
1363        Sort: a ASC NULLS FIRST
1364          Projection: a
1365            EmptyRelation: rows=0
1366        "
1367        )?;
1368
1369        // Plan C: coerce requested: Utf8View => LargeUtf8
1370        coerce_on_output_if_viewtype!(
1371            true,
1372            sort_plan.clone(),
1373            @r"
1374        Projection: CAST(a AS LargeUtf8)
1375          Sort: a ASC NULLS FIRST
1376            Projection: a
1377              EmptyRelation: rows=0
1378        "
1379        )?;
1380
1381        // Plan D
1382        // scenario: two layers of projections with view types
1383        let plan = LogicalPlan::Projection(Projection::try_new(
1384            vec![col("a")],
1385            Arc::new(sort_plan),
1386        )?);
1387        // Plan D: no coerce
1388        coerce_on_output_if_viewtype!(
1389            false,
1390            plan.clone(),
1391            @r"
1392        Projection: a
1393          Sort: a ASC NULLS FIRST
1394            Projection: a
1395              EmptyRelation: rows=0
1396        "
1397        )?;
1398        // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost
1399        coerce_on_output_if_viewtype!(
1400            true,
1401            plan.clone(),
1402            @r"
1403        Projection: CAST(a AS LargeUtf8)
1404          Sort: a ASC NULLS FIRST
1405            Projection: a
1406              EmptyRelation: rows=0
1407        "
1408        )?;
1409
1410        Ok(())
1411    }
1412
1413    #[test]
1414    fn coerce_binaryview_output() -> Result<()> {
1415        // Plan A
1416        // scenario: outermost binaryview projection
1417        let expr = col("a");
1418        let empty = empty_with_type(DataType::BinaryView);
1419        let plan = LogicalPlan::Projection(Projection::try_new(
1420            vec![expr.clone()],
1421            Arc::clone(&empty),
1422        )?);
1423
1424        // Plan A: no coerce
1425        coerce_on_output_if_viewtype!(
1426            false,
1427            plan.clone(),
1428            @r"
1429        Projection: a
1430          EmptyRelation: rows=0
1431        "
1432        )?;
1433
1434        // Plan A: coerce requested: BinaryView => LargeBinary
1435        coerce_on_output_if_viewtype!(
1436            true,
1437            plan.clone(),
1438            @r"
1439        Projection: CAST(a AS LargeBinary)
1440          EmptyRelation: rows=0
1441        "
1442        )?;
1443
1444        // Plan B
1445        // scenario: outermost bool projection
1446        let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1447        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1448            vec![bool_expr],
1449            Arc::clone(&empty),
1450        )?);
1451
1452        // Plan B: no coerce
1453        coerce_on_output_if_viewtype!(
1454            false,
1455            bool_plan.clone(),
1456            @r#"
1457        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1458          EmptyRelation: rows=0
1459        "#
1460        )?;
1461
1462        // Plan B: coerce requested: no coercion applied
1463        coerce_on_output_if_viewtype!(
1464            true,
1465            bool_plan.clone(),
1466            @r#"
1467        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1468          EmptyRelation: rows=0
1469        "#
1470        )?;
1471
1472        // Plan C
1473        // scenario: with a non-projection root logical plan node
1474        let sort_expr = expr.sort(true, true);
1475        let sort_plan = LogicalPlan::Sort(Sort {
1476            expr: vec![sort_expr],
1477            input: Arc::new(plan),
1478            fetch: None,
1479        });
1480
1481        // Plan C: no coerce
1482        coerce_on_output_if_viewtype!(
1483            false,
1484            sort_plan.clone(),
1485            @r"
1486        Sort: a ASC NULLS FIRST
1487          Projection: a
1488            EmptyRelation: rows=0
1489        "
1490        )?;
1491        // Plan C: coerce requested: BinaryView => LargeBinary
1492        coerce_on_output_if_viewtype!(
1493            true,
1494            sort_plan.clone(),
1495            @r"
1496        Projection: CAST(a AS LargeBinary)
1497          Sort: a ASC NULLS FIRST
1498            Projection: a
1499              EmptyRelation: rows=0
1500        "
1501        )?;
1502
1503        // Plan D
1504        // scenario: two layers of projections with view types
1505        let plan = LogicalPlan::Projection(Projection::try_new(
1506            vec![col("a")],
1507            Arc::new(sort_plan),
1508        )?);
1509
1510        // Plan D: no coerce
1511        coerce_on_output_if_viewtype!(
1512            false,
1513            plan.clone(),
1514            @r"
1515        Projection: a
1516          Sort: a ASC NULLS FIRST
1517            Projection: a
1518              EmptyRelation: rows=0
1519        "
1520        )?;
1521
1522        // Plan B: coerce requested: BinaryView => LargeBinary only on outermost
1523        coerce_on_output_if_viewtype!(
1524            true,
1525            plan.clone(),
1526            @r"
1527        Projection: CAST(a AS LargeBinary)
1528          Sort: a ASC NULLS FIRST
1529            Projection: a
1530              EmptyRelation: rows=0
1531        "
1532        )?;
1533
1534        Ok(())
1535    }
1536
1537    #[test]
1538    fn nested_case() -> Result<()> {
1539        let expr = col("a").lt(lit(2_u32));
1540        let empty = empty_with_type(DataType::Float64);
1541
1542        let plan = LogicalPlan::Projection(Projection::try_new(
1543            vec![expr.clone().or(expr)],
1544            empty,
1545        )?);
1546
1547        assert_analyzed_plan_eq!(
1548            plan,
1549            @r"
1550        Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1551          EmptyRelation: rows=0
1552        "
1553        )
1554    }
1555
1556    #[derive(Debug, PartialEq, Eq, Hash)]
1557    struct TestScalarUDF {
1558        signature: Signature,
1559    }
1560
1561    impl ScalarUDFImpl for TestScalarUDF {
1562        fn as_any(&self) -> &dyn Any {
1563            self
1564        }
1565
1566        fn name(&self) -> &str {
1567            "TestScalarUDF"
1568        }
1569
1570        fn signature(&self) -> &Signature {
1571            &self.signature
1572        }
1573
1574        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1575            Ok(Utf8)
1576        }
1577
1578        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1579            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1580        }
1581    }
1582
1583    #[test]
1584    fn scalar_udf() -> Result<()> {
1585        let empty = empty();
1586
1587        let udf = ScalarUDF::from(TestScalarUDF {
1588            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1589        })
1590        .call(vec![lit(123_i32)]);
1591        let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1592
1593        assert_analyzed_plan_eq!(
1594            plan,
1595            @r"
1596        Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1597          EmptyRelation: rows=0
1598        "
1599        )
1600    }
1601
1602    #[test]
1603    fn scalar_udf_invalid_input() -> Result<()> {
1604        let empty = empty();
1605        let udf = ScalarUDF::from(TestScalarUDF {
1606            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1607        })
1608        .call(vec![lit("Apple")]);
1609        Projection::try_new(vec![udf], empty)
1610            .expect_err("Expected an error due to incorrect function input");
1611
1612        Ok(())
1613    }
1614
1615    #[test]
1616    fn scalar_function() -> Result<()> {
1617        // test that automatic argument type coercion for scalar functions work
1618        let empty = empty();
1619        let lit_expr = lit(10i64);
1620        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1621            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1622        });
1623        let scalar_function_expr =
1624            Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1625        let plan = LogicalPlan::Projection(Projection::try_new(
1626            vec![scalar_function_expr],
1627            empty,
1628        )?);
1629
1630        assert_analyzed_plan_eq!(
1631            plan,
1632            @r"
1633        Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1634          EmptyRelation: rows=0
1635        "
1636        )
1637    }
1638
1639    #[test]
1640    fn agg_udaf() -> Result<()> {
1641        let empty = empty();
1642        let my_avg = create_udaf(
1643            "MY_AVG",
1644            vec![DataType::Float64],
1645            Arc::new(DataType::Float64),
1646            Volatility::Immutable,
1647            Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1648            Arc::new(vec![DataType::UInt64, DataType::Float64]),
1649        );
1650        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1651            Arc::new(my_avg),
1652            vec![lit(10i64)],
1653            false,
1654            None,
1655            vec![],
1656            None,
1657        ));
1658        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1659
1660        assert_analyzed_plan_eq!(
1661            plan,
1662            @r"
1663        Projection: MY_AVG(CAST(Int64(10) AS Float64))
1664          EmptyRelation: rows=0
1665        "
1666        )
1667    }
1668
1669    #[test]
1670    fn agg_udaf_invalid_input() -> Result<()> {
1671        let empty = empty();
1672        let return_type = DataType::Float64;
1673        let accumulator: AccumulatorFactoryFunction =
1674            Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1675        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1676            "MY_AVG",
1677            Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1678            return_type,
1679            accumulator,
1680            vec![
1681                Field::new("count", DataType::UInt64, true).into(),
1682                Field::new("avg", DataType::Float64, true).into(),
1683            ],
1684        ));
1685        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1686            Arc::new(my_avg),
1687            vec![lit("10")],
1688            false,
1689            None,
1690            vec![],
1691            None,
1692        ));
1693
1694        let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1695        assert!(
1696            err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed")
1697        );
1698        Ok(())
1699    }
1700
1701    #[test]
1702    fn agg_function_case() -> Result<()> {
1703        let empty = empty();
1704        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1705            avg_udaf(),
1706            vec![lit(12f64)],
1707            false,
1708            None,
1709            vec![],
1710            None,
1711        ));
1712        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1713
1714        assert_analyzed_plan_eq!(
1715            plan,
1716            @r"
1717        Projection: avg(Float64(12))
1718          EmptyRelation: rows=0
1719        "
1720        )?;
1721
1722        let empty = empty_with_type(DataType::Int32);
1723        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1724            avg_udaf(),
1725            vec![cast(col("a"), DataType::Float64)],
1726            false,
1727            None,
1728            vec![],
1729            None,
1730        ));
1731        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1732
1733        assert_analyzed_plan_eq!(
1734            plan,
1735            @r"
1736        Projection: avg(CAST(a AS Float64))
1737          EmptyRelation: rows=0
1738        "
1739        )
1740    }
1741
1742    #[test]
1743    fn agg_function_invalid_input_avg() -> Result<()> {
1744        let empty = empty();
1745        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1746            avg_udaf(),
1747            vec![lit("1")],
1748            false,
1749            None,
1750            vec![],
1751            None,
1752        ));
1753        let err = Projection::try_new(vec![agg_expr], empty)
1754            .err()
1755            .unwrap()
1756            .strip_backtrace();
1757        assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1758        Ok(())
1759    }
1760
1761    #[test]
1762    fn binary_op_date32_op_interval() -> Result<()> {
1763        // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
1764        let expr = cast(lit("1998-03-18"), DataType::Date32)
1765            + lit(ScalarValue::new_interval_dt(123, 456));
1766        let empty = empty();
1767        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1768
1769        assert_analyzed_plan_eq!(
1770            plan,
1771            @r#"
1772        Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1773          EmptyRelation: rows=0
1774        "#
1775        )
1776    }
1777
1778    #[test]
1779    fn inlist_case() -> Result<()> {
1780        // a in (1,4,8), a is int64
1781        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1782        let empty = empty_with_type(DataType::Int64);
1783        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1784        assert_analyzed_plan_eq!(
1785            plan,
1786            @r"
1787        Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1788          EmptyRelation: rows=0
1789        ")?;
1790
1791        // a in (1,4,8), a is decimal
1792        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1793        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1794            produce_one_row: false,
1795            schema: Arc::new(DFSchema::from_unqualified_fields(
1796                vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1797                std::collections::HashMap::new(),
1798            )?),
1799        }));
1800        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1801        assert_analyzed_plan_eq!(
1802            plan,
1803            @r"
1804        Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1805          EmptyRelation: rows=0
1806        ")
1807    }
1808
1809    #[test]
1810    fn between_case() -> Result<()> {
1811        let expr = col("a").between(
1812            lit("2002-05-08"),
1813            // (cast('2002-05-08' as date) + interval '1 months')
1814            cast(lit("2002-05-08"), DataType::Date32)
1815                + lit(ScalarValue::new_interval_ym(0, 1)),
1816        );
1817        let empty = empty_with_type(Utf8);
1818        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1819
1820        assert_analyzed_plan_eq!(
1821            plan,
1822            @r#"
1823        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1824          EmptyRelation: rows=0
1825        "#
1826        )
1827    }
1828
1829    #[test]
1830    fn between_infer_cheap_type() -> Result<()> {
1831        let expr = col("a").between(
1832            // (cast('2002-05-08' as date) + interval '1 months')
1833            cast(lit("2002-05-08"), DataType::Date32)
1834                + lit(ScalarValue::new_interval_ym(0, 1)),
1835            lit("2002-12-08"),
1836        );
1837        let empty = empty_with_type(Utf8);
1838        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1839
1840        // TODO: we should cast col(a).
1841        assert_analyzed_plan_eq!(
1842            plan,
1843            @r#"
1844        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1845          EmptyRelation: rows=0
1846        "#
1847        )
1848    }
1849
1850    #[test]
1851    fn between_null() -> Result<()> {
1852        let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1853        let empty = empty();
1854        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1855
1856        assert_analyzed_plan_eq!(
1857            plan,
1858            @r"
1859        Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1860          EmptyRelation: rows=0
1861        "
1862        )
1863    }
1864
1865    #[test]
1866    fn is_bool_for_type_coercion() -> Result<()> {
1867        // is true
1868        let expr = col("a").is_true();
1869        let empty = empty_with_type(DataType::Boolean);
1870        let plan =
1871            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1872
1873        assert_analyzed_plan_eq!(
1874            plan,
1875            @r"
1876        Projection: a IS TRUE
1877          EmptyRelation: rows=0
1878        "
1879        )?;
1880
1881        let empty = empty_with_type(DataType::Int64);
1882        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1883        assert_type_coercion_error(
1884            plan,
1885            "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
1886        )?;
1887
1888        // is not true
1889        let expr = col("a").is_not_true();
1890        let empty = empty_with_type(DataType::Boolean);
1891        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1892
1893        assert_analyzed_plan_eq!(
1894            plan,
1895            @r"
1896        Projection: a IS NOT TRUE
1897          EmptyRelation: rows=0
1898        "
1899        )?;
1900
1901        // is false
1902        let expr = col("a").is_false();
1903        let empty = empty_with_type(DataType::Boolean);
1904        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1905
1906        assert_analyzed_plan_eq!(
1907            plan,
1908            @r"
1909        Projection: a IS FALSE
1910          EmptyRelation: rows=0
1911        "
1912        )?;
1913
1914        // is not false
1915        let expr = col("a").is_not_false();
1916        let empty = empty_with_type(DataType::Boolean);
1917        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1918
1919        assert_analyzed_plan_eq!(
1920            plan,
1921            @r"
1922        Projection: a IS NOT FALSE
1923          EmptyRelation: rows=0
1924        "
1925        )
1926    }
1927
1928    #[test]
1929    fn like_for_type_coercion() -> Result<()> {
1930        // like : utf8 like "abc"
1931        let expr = Box::new(col("a"));
1932        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1933        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1934        let empty = empty_with_type(Utf8);
1935        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1936
1937        assert_analyzed_plan_eq!(
1938            plan,
1939            @r#"
1940        Projection: a LIKE Utf8("abc")
1941          EmptyRelation: rows=0
1942        "#
1943        )?;
1944
1945        let expr = Box::new(col("a"));
1946        let pattern = Box::new(lit(ScalarValue::Null));
1947        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1948        let empty = empty_with_type(Utf8);
1949        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1950
1951        assert_analyzed_plan_eq!(
1952            plan,
1953            @r"
1954        Projection: a LIKE CAST(NULL AS Utf8)
1955          EmptyRelation: rows=0
1956        "
1957        )?;
1958
1959        let expr = Box::new(col("a"));
1960        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1961        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1962        let empty = empty_with_type(DataType::Int64);
1963        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1964        assert_type_coercion_error(
1965            plan,
1966            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
1967        )?;
1968
1969        // ilike
1970        let expr = Box::new(col("a"));
1971        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1972        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1973        let empty = empty_with_type(Utf8);
1974        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1975
1976        assert_analyzed_plan_eq!(
1977            plan,
1978            @r#"
1979        Projection: a ILIKE Utf8("abc")
1980          EmptyRelation: rows=0
1981        "#
1982        )?;
1983
1984        let expr = Box::new(col("a"));
1985        let pattern = Box::new(lit(ScalarValue::Null));
1986        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1987        let empty = empty_with_type(Utf8);
1988        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1989
1990        assert_analyzed_plan_eq!(
1991            plan,
1992            @r"
1993        Projection: a ILIKE CAST(NULL AS Utf8)
1994          EmptyRelation: rows=0
1995        "
1996        )?;
1997
1998        let expr = Box::new(col("a"));
1999        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2000        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2001        let empty = empty_with_type(DataType::Int64);
2002        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2003        assert_type_coercion_error(
2004            plan,
2005            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
2006        )?;
2007
2008        Ok(())
2009    }
2010
2011    #[test]
2012    fn unknown_for_type_coercion() -> Result<()> {
2013        // unknown
2014        let expr = col("a").is_unknown();
2015        let empty = empty_with_type(DataType::Boolean);
2016        let plan =
2017            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2018
2019        assert_analyzed_plan_eq!(
2020            plan,
2021            @r"
2022        Projection: a IS UNKNOWN
2023          EmptyRelation: rows=0
2024        "
2025        )?;
2026
2027        let empty = empty_with_type(Utf8);
2028        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2029        assert_type_coercion_error(
2030            plan,
2031            "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
2032        )?;
2033
2034        // is not unknown
2035        let expr = col("a").is_not_unknown();
2036        let empty = empty_with_type(DataType::Boolean);
2037        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2038
2039        assert_analyzed_plan_eq!(
2040            plan,
2041            @r"
2042        Projection: a IS NOT UNKNOWN
2043          EmptyRelation: rows=0
2044        "
2045        )
2046    }
2047
2048    #[test]
2049    fn concat_for_type_coercion() -> Result<()> {
2050        let empty = empty_with_type(Utf8);
2051        let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2052
2053        // concat-type signature
2054        let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2055            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2056        })
2057        .call(args.to_vec());
2058        let plan =
2059            LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2060        assert_analyzed_plan_eq!(
2061            plan,
2062            @r#"
2063        Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2064          EmptyRelation: rows=0
2065        "#
2066        )
2067    }
2068
2069    #[test]
2070    fn test_type_coercion_rewrite() -> Result<()> {
2071        // gt
2072        let schema = Arc::new(DFSchema::from_unqualified_fields(
2073            vec![Field::new("a", DataType::Int64, true)].into(),
2074            std::collections::HashMap::new(),
2075        )?);
2076        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2077        let expr = is_true(lit(12i32).gt(lit(13i64)));
2078        let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2079        let result = expr.rewrite(&mut rewriter).data()?;
2080        assert_eq!(expected, result);
2081
2082        // eq
2083        let schema = Arc::new(DFSchema::from_unqualified_fields(
2084            vec![Field::new("a", DataType::Int64, true)].into(),
2085            std::collections::HashMap::new(),
2086        )?);
2087        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2088        let expr = is_true(lit(12i32).eq(lit(13i64)));
2089        let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2090        let result = expr.rewrite(&mut rewriter).data()?;
2091        assert_eq!(expected, result);
2092
2093        // lt
2094        let schema = Arc::new(DFSchema::from_unqualified_fields(
2095            vec![Field::new("a", DataType::Int64, true)].into(),
2096            std::collections::HashMap::new(),
2097        )?);
2098        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2099        let expr = is_true(lit(12i32).lt(lit(13i64)));
2100        let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2101        let result = expr.rewrite(&mut rewriter).data()?;
2102        assert_eq!(expected, result);
2103
2104        Ok(())
2105    }
2106
2107    #[test]
2108    fn binary_op_date32_eq_ts() -> Result<()> {
2109        let expr = cast(
2110            lit("1998-03-18"),
2111            DataType::Timestamp(TimeUnit::Nanosecond, None),
2112        )
2113        .eq(cast(lit("1998-03-18"), DataType::Date32));
2114        let empty = empty();
2115        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2116
2117        assert_analyzed_plan_eq!(
2118            plan,
2119            @r#"
2120        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
2121          EmptyRelation: rows=0
2122        "#
2123        )
2124    }
2125
2126    fn cast_if_not_same_type(
2127        expr: Box<Expr>,
2128        data_type: &DataType,
2129        schema: &DFSchemaRef,
2130    ) -> Box<Expr> {
2131        if &expr.get_type(schema).unwrap() != data_type {
2132            Box::new(cast(*expr, data_type.clone()))
2133        } else {
2134            expr
2135        }
2136    }
2137
2138    fn cast_helper(
2139        case: Case,
2140        case_when_type: &DataType,
2141        then_else_type: &DataType,
2142        schema: &DFSchemaRef,
2143    ) -> Case {
2144        let expr = case
2145            .expr
2146            .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2147        let when_then_expr = case
2148            .when_then_expr
2149            .into_iter()
2150            .map(|(when, then)| {
2151                (
2152                    cast_if_not_same_type(when, case_when_type, schema),
2153                    cast_if_not_same_type(then, then_else_type, schema),
2154                )
2155            })
2156            .collect::<Vec<_>>();
2157        let else_expr = case
2158            .else_expr
2159            .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2160
2161        Case {
2162            expr,
2163            when_then_expr,
2164            else_expr,
2165        }
2166    }
2167
2168    #[test]
2169    fn test_case_expression_coercion() -> Result<()> {
2170        let schema = Arc::new(DFSchema::from_unqualified_fields(
2171            vec![
2172                Field::new("boolean", DataType::Boolean, true),
2173                Field::new("integer", DataType::Int32, true),
2174                Field::new("float", DataType::Float32, true),
2175                Field::new(
2176                    "timestamp",
2177                    DataType::Timestamp(TimeUnit::Nanosecond, None),
2178                    true,
2179                ),
2180                Field::new("date", DataType::Date32, true),
2181                Field::new(
2182                    "interval",
2183                    DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2184                    true,
2185                ),
2186                Field::new("binary", DataType::Binary, true),
2187                Field::new("string", Utf8, true),
2188                Field::new("decimal", DataType::Decimal128(10, 10), true),
2189            ]
2190            .into(),
2191            std::collections::HashMap::new(),
2192        )?);
2193
2194        let case = Case {
2195            expr: None,
2196            when_then_expr: vec![
2197                (Box::new(col("boolean")), Box::new(col("integer"))),
2198                (Box::new(col("integer")), Box::new(col("float"))),
2199                (Box::new(col("string")), Box::new(col("string"))),
2200            ],
2201            else_expr: None,
2202        };
2203        let case_when_common_type = DataType::Boolean;
2204        let then_else_common_type = Utf8;
2205        let expected = cast_helper(
2206            case.clone(),
2207            &case_when_common_type,
2208            &then_else_common_type,
2209            &schema,
2210        );
2211        let actual = coerce_case_expression(case, &schema)?;
2212        assert_eq!(expected, actual);
2213
2214        let case = Case {
2215            expr: Some(Box::new(col("string"))),
2216            when_then_expr: vec![
2217                (Box::new(col("float")), Box::new(col("integer"))),
2218                (Box::new(col("integer")), Box::new(col("float"))),
2219                (Box::new(col("string")), Box::new(col("string"))),
2220            ],
2221            else_expr: Some(Box::new(col("string"))),
2222        };
2223        let case_when_common_type = Utf8;
2224        let then_else_common_type = Utf8;
2225        let expected = cast_helper(
2226            case.clone(),
2227            &case_when_common_type,
2228            &then_else_common_type,
2229            &schema,
2230        );
2231        let actual = coerce_case_expression(case, &schema)?;
2232        assert_eq!(expected, actual);
2233
2234        let case = Case {
2235            expr: Some(Box::new(col("interval"))),
2236            when_then_expr: vec![
2237                (Box::new(col("float")), Box::new(col("integer"))),
2238                (Box::new(col("binary")), Box::new(col("float"))),
2239                (Box::new(col("string")), Box::new(col("string"))),
2240            ],
2241            else_expr: Some(Box::new(col("string"))),
2242        };
2243        let err = coerce_case_expression(case, &schema).unwrap_err();
2244        assert_snapshot!(
2245            err.strip_backtrace(),
2246            @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
2247        );
2248
2249        let case = Case {
2250            expr: Some(Box::new(col("string"))),
2251            when_then_expr: vec![
2252                (Box::new(col("float")), Box::new(col("date"))),
2253                (Box::new(col("string")), Box::new(col("float"))),
2254                (Box::new(col("string")), Box::new(col("binary"))),
2255            ],
2256            else_expr: Some(Box::new(col("timestamp"))),
2257        };
2258        let err = coerce_case_expression(case, &schema).unwrap_err();
2259        assert_snapshot!(
2260            err.strip_backtrace(),
2261            @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
2262        );
2263
2264        Ok(())
2265    }
2266
2267    macro_rules! test_case_expression {
2268        ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2269            let case = Case {
2270                expr: $expr.map(|e| Box::new(col(e))),
2271                when_then_expr: $when_then,
2272                else_expr: None,
2273            };
2274
2275            let expected =
2276                cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2277
2278            let actual = coerce_case_expression(case, &$schema)?;
2279            assert_eq!(expected, actual);
2280        };
2281    }
2282
2283    #[test]
2284    fn tes_case_when_list() -> Result<()> {
2285        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2286        let schema = Arc::new(DFSchema::from_unqualified_fields(
2287            vec![
2288                Field::new(
2289                    "large_list",
2290                    DataType::LargeList(Arc::clone(&inner_field)),
2291                    true,
2292                ),
2293                Field::new(
2294                    "fixed_list",
2295                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2296                    true,
2297                ),
2298                Field::new("list", DataType::List(inner_field), true),
2299            ]
2300            .into(),
2301            std::collections::HashMap::new(),
2302        )?);
2303
2304        test_case_expression!(
2305            Some("list"),
2306            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2307            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2308            Utf8,
2309            schema
2310        );
2311
2312        test_case_expression!(
2313            Some("large_list"),
2314            vec![(Box::new(col("list")), Box::new(lit("1")))],
2315            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2316            Utf8,
2317            schema
2318        );
2319
2320        test_case_expression!(
2321            Some("list"),
2322            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2323            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2324            Utf8,
2325            schema
2326        );
2327
2328        test_case_expression!(
2329            Some("fixed_list"),
2330            vec![(Box::new(col("list")), Box::new(lit("1")))],
2331            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2332            Utf8,
2333            schema
2334        );
2335
2336        test_case_expression!(
2337            Some("fixed_list"),
2338            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2339            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2340            Utf8,
2341            schema
2342        );
2343
2344        test_case_expression!(
2345            Some("large_list"),
2346            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2347            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2348            Utf8,
2349            schema
2350        );
2351        Ok(())
2352    }
2353
2354    #[test]
2355    fn test_then_else_list() -> Result<()> {
2356        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2357        let schema = Arc::new(DFSchema::from_unqualified_fields(
2358            vec![
2359                Field::new("boolean", DataType::Boolean, true),
2360                Field::new(
2361                    "large_list",
2362                    DataType::LargeList(Arc::clone(&inner_field)),
2363                    true,
2364                ),
2365                Field::new(
2366                    "fixed_list",
2367                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2368                    true,
2369                ),
2370                Field::new("list", DataType::List(inner_field), true),
2371            ]
2372            .into(),
2373            std::collections::HashMap::new(),
2374        )?);
2375
2376        // large list and list
2377        test_case_expression!(
2378            None::<String>,
2379            vec![
2380                (Box::new(col("boolean")), Box::new(col("large_list"))),
2381                (Box::new(col("boolean")), Box::new(col("list")))
2382            ],
2383            DataType::Boolean,
2384            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2385            schema
2386        );
2387
2388        test_case_expression!(
2389            None::<String>,
2390            vec![
2391                (Box::new(col("boolean")), Box::new(col("list"))),
2392                (Box::new(col("boolean")), Box::new(col("large_list")))
2393            ],
2394            DataType::Boolean,
2395            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2396            schema
2397        );
2398
2399        // fixed list and list
2400        test_case_expression!(
2401            None::<String>,
2402            vec![
2403                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2404                (Box::new(col("boolean")), Box::new(col("list")))
2405            ],
2406            DataType::Boolean,
2407            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2408            schema
2409        );
2410
2411        test_case_expression!(
2412            None::<String>,
2413            vec![
2414                (Box::new(col("boolean")), Box::new(col("list"))),
2415                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2416            ],
2417            DataType::Boolean,
2418            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2419            schema
2420        );
2421
2422        // fixed list and large list
2423        test_case_expression!(
2424            None::<String>,
2425            vec![
2426                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2427                (Box::new(col("boolean")), Box::new(col("large_list")))
2428            ],
2429            DataType::Boolean,
2430            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2431            schema
2432        );
2433
2434        test_case_expression!(
2435            None::<String>,
2436            vec![
2437                (Box::new(col("boolean")), Box::new(col("large_list"))),
2438                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2439            ],
2440            DataType::Boolean,
2441            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2442            schema
2443        );
2444        Ok(())
2445    }
2446
2447    #[test]
2448    fn test_map_with_diff_name() -> Result<()> {
2449        let mut builder = SchemaBuilder::new();
2450        builder.push(Field::new("key", Utf8, false));
2451        builder.push(Field::new("value", DataType::Float64, true));
2452        let struct_fields = builder.finish().fields;
2453
2454        let fields =
2455            Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2456        let map_type_entries = DataType::Map(Arc::new(fields), false);
2457
2458        let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2459        let may_type_custom = DataType::Map(Arc::new(fields), false);
2460
2461        let expr = col("a").eq(cast(col("a"), may_type_custom));
2462        let empty = empty_with_type(map_type_entries);
2463        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2464
2465        assert_analyzed_plan_eq!(
2466            plan,
2467            @r#"
2468        Projection: a = CAST(CAST(a AS Map("key_value": Struct("key": Utf8, "value": nullable Float64), unsorted)) AS Map("entries": Struct("key": Utf8, "value": nullable Float64), unsorted))
2469          EmptyRelation: rows=0
2470        "#
2471        )
2472    }
2473
2474    #[test]
2475    fn interval_plus_timestamp() -> Result<()> {
2476        // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
2477        let expr = Expr::BinaryExpr(BinaryExpr::new(
2478            Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2479            Operator::Plus,
2480            Box::new(cast(
2481                lit("2000-01-01T00:00:00"),
2482                DataType::Timestamp(TimeUnit::Nanosecond, None),
2483            )),
2484        ));
2485        let empty = empty();
2486        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2487
2488        assert_analyzed_plan_eq!(
2489            plan,
2490            @r#"
2491        Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
2492          EmptyRelation: rows=0
2493        "#
2494        )
2495    }
2496
2497    #[test]
2498    fn timestamp_subtract_timestamp() -> Result<()> {
2499        let expr = Expr::BinaryExpr(BinaryExpr::new(
2500            Box::new(cast(
2501                lit("1998-03-18"),
2502                DataType::Timestamp(TimeUnit::Nanosecond, None),
2503            )),
2504            Operator::Minus,
2505            Box::new(cast(
2506                lit("1998-03-18"),
2507                DataType::Timestamp(TimeUnit::Nanosecond, None),
2508            )),
2509        ));
2510        let empty = empty();
2511        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2512
2513        assert_analyzed_plan_eq!(
2514            plan,
2515            @r#"
2516        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
2517          EmptyRelation: rows=0
2518        "#
2519        )
2520    }
2521
2522    #[test]
2523    fn in_subquery_cast_subquery() -> Result<()> {
2524        let empty_int32 = empty_with_type(DataType::Int32);
2525        let empty_int64 = empty_with_type(DataType::Int64);
2526
2527        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2528            Box::new(col("a")),
2529            Subquery {
2530                subquery: empty_int32,
2531                outer_ref_columns: vec![],
2532                spans: Spans::new(),
2533            },
2534            false,
2535        ));
2536        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2537        // add cast for subquery
2538
2539        assert_analyzed_plan_eq!(
2540            plan,
2541            @r"
2542        Filter: a IN (<subquery>)
2543          Subquery:
2544            Projection: CAST(a AS Int64)
2545              EmptyRelation: rows=0
2546          EmptyRelation: rows=0
2547        "
2548        )
2549    }
2550
2551    #[test]
2552    fn in_subquery_cast_expr() -> Result<()> {
2553        let empty_int32 = empty_with_type(DataType::Int32);
2554        let empty_int64 = empty_with_type(DataType::Int64);
2555
2556        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2557            Box::new(col("a")),
2558            Subquery {
2559                subquery: empty_int64,
2560                outer_ref_columns: vec![],
2561                spans: Spans::new(),
2562            },
2563            false,
2564        ));
2565        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2566
2567        // add cast for subquery
2568        assert_analyzed_plan_eq!(
2569            plan,
2570            @r"
2571        Filter: CAST(a AS Int64) IN (<subquery>)
2572          Subquery:
2573            EmptyRelation: rows=0
2574          EmptyRelation: rows=0
2575        "
2576        )
2577    }
2578
2579    #[test]
2580    fn in_subquery_cast_all() -> Result<()> {
2581        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2582        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2583
2584        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2585            Box::new(col("a")),
2586            Subquery {
2587                subquery: empty_inside,
2588                outer_ref_columns: vec![],
2589                spans: Spans::new(),
2590            },
2591            false,
2592        ));
2593        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2594
2595        // add cast for subquery
2596        assert_analyzed_plan_eq!(
2597            plan,
2598            @r"
2599        Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
2600          Subquery:
2601            Projection: CAST(a AS Decimal128(13, 8))
2602              EmptyRelation: rows=0
2603          EmptyRelation: rows=0
2604        "
2605        )
2606    }
2607}