Skip to main content

datafusion_optimizer/analyzer/
type_coercion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Optimizer rule for type validation and coercion
19
20use arrow::compute::can_cast_types;
21use datafusion_expr::binary::BinaryTypeCoercer;
22use itertools::{Itertools as _, izip};
23use std::sync::Arc;
24
25use crate::analyzer::AnalyzerRule;
26use crate::utils::NamePreserver;
27
28use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
29use arrow::temporal_conversions::SECONDS_IN_DAY;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
32use datafusion_common::{
33    Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34    exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
35    plan_err,
36};
37use datafusion_expr::expr::{
38    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
39    InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction,
40};
41use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
42use datafusion_expr::expr_schema::cast_subquery;
43use datafusion_expr::logical_plan::Subquery;
44use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
45use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf};
46use datafusion_expr::type_coercion::is_datetime;
47use datafusion_expr::type_coercion::other::{
48    get_coerce_type_for_case_expression, get_coerce_type_for_list,
49};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52    Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union,
53    WindowFrame, WindowFrameBound, WindowFrameUnits, is_false, is_not_false, is_not_true,
54    is_not_unknown, is_true, is_unknown, lit, not,
55};
56
/// Analyzer rule that validates expression types and inserts the casts needed
/// for every operator and function to receive compatible inputs.
///
/// The rule is stateless; all information it needs comes from the plan being
/// analyzed and the supplied configuration.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Builds the (stateless) rule; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
67
68/// Coerce output schema based upon optimizer config.
69fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
70    if !config.optimizer.expand_views_at_output {
71        return Ok(plan);
72    }
73
74    let outer_refs = plan.expressions();
75    if outer_refs.is_empty() {
76        return Ok(plan);
77    }
78
79    if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
80        coerce_plan_expr_for_schema(plan, &dfschema?)
81    } else {
82        Ok(plan)
83    }
84}
85
86impl AnalyzerRule for TypeCoercion {
87    fn name(&self) -> &str {
88        "type_coercion"
89    }
90
91    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
92        let empty_schema = DFSchema::empty();
93
94        // recurse
95        let transformed_plan = plan
96            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
97            .data;
98
99        // finish
100        coerce_output(transformed_plan, config)
101    }
102}
103
104/// use the external schema to handle the correlated subqueries case
105///
106/// Assumes that children have already been optimized
107fn analyze_internal(
108    external_schema: &DFSchema,
109    plan: LogicalPlan,
110) -> Result<Transformed<LogicalPlan>> {
111    // get schema representing all available input fields. This is used for data type
112    // resolution only, so order does not matter here
113    let mut schema = merge_schema(&plan.inputs());
114
115    if let LogicalPlan::TableScan(ts) = &plan {
116        let source_schema = DFSchema::try_from_qualified_schema(
117            ts.table_name.clone(),
118            &ts.source.schema(),
119        )?;
120        schema.merge(&source_schema);
121    }
122
123    // merge the outer schema for correlated subqueries
124    // like case:
125    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
126    schema.merge(external_schema);
127
128    // Coerce filter predicates to boolean (handles `WHERE NULL`)
129    let plan = if let LogicalPlan::Filter(mut filter) = plan {
130        filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
131        LogicalPlan::Filter(filter)
132    } else {
133        plan
134    };
135
136    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
137
138    let name_preserver = NamePreserver::new(&plan);
139    // apply coercion rewrite all expressions in the plan individually
140    plan.map_expressions(|expr| {
141        let original_name = name_preserver.save(&expr);
142        expr.rewrite(&mut expr_rewrite)
143            .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
144    })?
145    // some plans need extra coercion after their expressions are coerced
146    .map_data(|plan| expr_rewrite.coerce_plan(plan))?
147    // recompute the schema after the expressions have been rewritten as the types may have changed
148    .map_data(|plan| plan.recompute_schema())
149}
150
151/// Rewrite expressions to apply type coercion.
152pub struct TypeCoercionRewriter<'a> {
153    pub(crate) schema: &'a DFSchema,
154}
155
156impl<'a> TypeCoercionRewriter<'a> {
157    /// Create a new [`TypeCoercionRewriter`] with a provided schema
158    /// representing both the inputs and output of the [`LogicalPlan`] node.
159    pub fn new(schema: &'a DFSchema) -> Self {
160        Self { schema }
161    }
162
163    /// Coerce the [`LogicalPlan`].
164    ///
165    /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`]
166    /// for type-coercion approach.
167    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
168        match plan {
169            LogicalPlan::Join(join) => self.coerce_join(join),
170            LogicalPlan::Union(union) => Self::coerce_union(union),
171            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
172            _ => Ok(plan),
173        }
174    }
175
176    /// Coerce join equality expressions and join filter
177    ///
178    /// Joins must be treated specially as their equality expressions are stored
179    /// as a parallel list of left and right expressions, rather than a single
180    /// equality expression
181    ///
182    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
183    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
184    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
185        join.on = join
186            .on
187            .into_iter()
188            .map(|(lhs, rhs)| {
189                // coerce the arguments as though they were a single binary equality
190                // expression
191                let left_schema = join.left.schema();
192                let right_schema = join.right.schema();
193                let (lhs, rhs) = self.coerce_binary_op(
194                    lhs,
195                    left_schema,
196                    Operator::Eq,
197                    rhs,
198                    right_schema,
199                )?;
200                Ok((lhs, rhs))
201            })
202            .collect::<Result<Vec<_>>>()?;
203
204        // Join filter must be boolean
205        join.filter = join
206            .filter
207            .map(|expr| self.coerce_join_filter(expr))
208            .transpose()?;
209
210        Ok(LogicalPlan::Join(join))
211    }
212
213    /// Coerce the union’s inputs to a common schema compatible with all inputs.
214    /// This occurs after wildcard expansion and the coercion of the input expressions.
215    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
216        let union_schema = Arc::new(coerce_union_schema_with_schema(
217            &union_plan.inputs,
218            &union_plan.schema,
219        )?);
220        let new_inputs = union_plan
221            .inputs
222            .into_iter()
223            .map(|p| {
224                let plan =
225                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
226                match plan {
227                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
228                        Ok(Arc::new(project_with_column_index(
229                            expr,
230                            input,
231                            Arc::clone(&union_schema),
232                        )?))
233                    }
234                    other_plan => Ok(Arc::new(other_plan)),
235                }
236            })
237            .collect::<Result<Vec<_>>>()?;
238        Ok(LogicalPlan::Union(Union {
239            inputs: new_inputs,
240            schema: union_schema,
241        }))
242    }
243
244    /// Coerce the fetch and skip expression to Int64 type.
245    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
246        fn coerce_limit_expr(
247            expr: Expr,
248            schema: &DFSchema,
249            expr_name: &str,
250        ) -> Result<Expr> {
251            let dt = expr.get_type(schema)?;
252            if dt.is_integer() || dt.is_null() {
253                expr.cast_to(&DataType::Int64, schema)
254            } else {
255                plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
256            }
257        }
258
259        let empty_schema = DFSchema::empty();
260        let new_fetch = limit
261            .fetch
262            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
263            .transpose()?;
264        let new_skip = limit
265            .skip
266            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
267            .transpose()?;
268        Ok(LogicalPlan::Limit(Limit {
269            input: limit.input,
270            fetch: new_fetch.map(Box::new),
271            skip: new_skip.map(Box::new),
272        }))
273    }
274
275    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
276        let expr_type = expr.get_type(self.schema)?;
277        match expr_type {
278            DataType::Boolean => Ok(expr),
279            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
280            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
281        }
282    }
283
284    fn coerce_binary_op(
285        &self,
286        left: Expr,
287        left_schema: &DFSchema,
288        op: Operator,
289        right: Expr,
290        right_schema: &DFSchema,
291    ) -> Result<(Expr, Expr)> {
292        let left_data_type = left.get_type(left_schema)?;
293        let right_data_type = right.get_type(right_schema)?;
294        let (left_type, right_type) =
295            BinaryTypeCoercer::new(&left_data_type, &op, &right_data_type)
296                .get_input_types()?;
297        let left_cast_ok = can_cast_types(&left_data_type, &left_type);
298        let right_cast_ok = can_cast_types(&right_data_type, &right_type);
299
300        // handle special cases for
301        // * Date +/- int => Date
302        // * Date + time => Timestamp
303        let left_expr = if !left_cast_ok {
304            Self::coerce_date_time_math_op(
305                left,
306                &op,
307                &left_data_type,
308                &left_type,
309                &right_type,
310            )?
311        } else {
312            left.cast_to(&left_type, left_schema)?
313        };
314
315        let right_expr = if !right_cast_ok {
316            Self::coerce_date_time_math_op(
317                right,
318                &op,
319                &right_data_type,
320                &right_type,
321                &left_type,
322            )?
323        } else {
324            right.cast_to(&right_type, right_schema)?
325        };
326
327        Ok((left_expr, right_expr))
328    }
329
330    fn coerce_date_time_math_op(
331        expr: Expr,
332        op: &Operator,
333        left_current_type: &DataType,
334        left_target_type: &DataType,
335        right_target_type: &DataType,
336    ) -> Result<Expr, DataFusionError> {
337        use DataType::*;
338
339        fn cast(expr: Expr, target_type: DataType) -> Expr {
340            Expr::Cast(Cast::new(Box::new(expr), target_type))
341        }
342
343        fn time_to_nanos(
344            expr: Expr,
345            expr_type: &DataType,
346        ) -> Result<Expr, DataFusionError> {
347            let expr = match expr_type {
348                Time32(TimeUnit::Second) => {
349                    cast(cast(expr, Int32), Int64)
350                        * lit(ScalarValue::Int64(Some(1_000_000_000)))
351                }
352                Time32(TimeUnit::Millisecond) => {
353                    cast(cast(expr, Int32), Int64)
354                        * lit(ScalarValue::Int64(Some(1_000_000)))
355                }
356                Time64(TimeUnit::Microsecond) => {
357                    cast(expr, Int64) * lit(ScalarValue::Int64(Some(1_000)))
358                }
359                Time64(TimeUnit::Nanosecond) => cast(expr, Int64),
360                t => return internal_err!("Unexpected time data type {t}"),
361            };
362
363            Ok(expr)
364        }
365
366        let e = match (
367            &op,
368            &left_current_type,
369            &left_target_type,
370            &right_target_type,
371        ) {
372            // int +/- date => date
373            (
374                Operator::Plus | Operator::Minus,
375                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
376                Interval(IntervalUnit::MonthDayNano),
377                Date32 | Date64,
378            ) => {
379                // cast to i64 first
380                let expr = match *left_current_type {
381                    Int64 => expr,
382                    _ => cast(expr, Int64),
383                };
384                // next, multiply by 86400 to get seconds
385                let expr = expr * lit(ScalarValue::from(SECONDS_IN_DAY));
386                // cast to duration
387                let expr = cast(expr, Duration(TimeUnit::Second));
388                // finally cast to interval
389                cast(expr, Interval(IntervalUnit::MonthDayNano))
390            }
391            // These might seem to be a bit convoluted, however for arrow to do date + time arithmetic
392            // date must be cast to Timestamp(Nanosecond) and time cast to Duration(Nanosecond)
393            // (they must be the same timeunit).
394            //
395            // For Time32/64 we first need to cast to an Int64, convert that to nanoseconds based
396            // on the time unit, then cast that to duration.
397            //
398            // Time + date -> timestamp or
399            (
400                Operator::Plus | Operator::Minus,
401                Time32(_) | Time64(_),
402                Duration(TimeUnit::Nanosecond),
403                Timestamp(TimeUnit::Nanosecond, None),
404            ) => {
405                // cast to int64, convert to nanoseconds
406                let expr = time_to_nanos(expr, left_current_type)?;
407                // cast to duration
408                cast(expr, Duration(TimeUnit::Nanosecond))
409            }
410            // Similar to above, for arrow to do time - time we need to convert to an interval.
411            // To do that we first need to cast to an Int64, convert that to nanoseconds based
412            // on the time unit, then cast that to duration, then finally cast to an interval.
413            //
414            // Time - time -> timestamp
415            (
416                Operator::Plus | Operator::Minus,
417                Time32(_) | Time64(_),
418                Interval(IntervalUnit::MonthDayNano),
419                Interval(IntervalUnit::MonthDayNano),
420            ) => {
421                // cast to int64, convert to nanoseconds
422                let expr = time_to_nanos(expr, left_current_type)?;
423                // cast to duration
424                let expr = cast(expr, Duration(TimeUnit::Nanosecond));
425                // finally cast to interval
426                cast(expr, Interval(IntervalUnit::MonthDayNano))
427            }
428            _ => {
429                return plan_err!(
430                    "Cannot automatically convert {left_current_type} to {left_target_type}"
431                );
432            }
433        };
434
435        Ok(e)
436    }
437}
438
439impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
440    type Node = Expr;
441
442    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
443        match expr {
444            Expr::Unnest(_) => not_impl_err!(
445                "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
446            ),
447            Expr::ScalarSubquery(Subquery {
448                subquery,
449                outer_ref_columns,
450                spans,
451            }) => {
452                let new_plan =
453                    analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
454                Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
455                    subquery: Arc::new(new_plan),
456                    outer_ref_columns,
457                    spans,
458                })))
459            }
460            Expr::Exists(Exists { subquery, negated }) => {
461                let new_plan = analyze_internal(
462                    self.schema,
463                    Arc::unwrap_or_clone(subquery.subquery),
464                )?
465                .data;
466                Ok(Transformed::yes(Expr::Exists(Exists {
467                    subquery: Subquery {
468                        subquery: Arc::new(new_plan),
469                        outer_ref_columns: subquery.outer_ref_columns,
470                        spans: subquery.spans,
471                    },
472                    negated,
473                })))
474            }
475            Expr::InSubquery(InSubquery {
476                expr,
477                subquery,
478                negated,
479            }) => {
480                let new_plan = analyze_internal(
481                    self.schema,
482                    Arc::unwrap_or_clone(subquery.subquery),
483                )?
484                .data;
485                let expr_type = expr.get_type(self.schema)?;
486                let subquery_type = new_plan.schema().field(0).data_type();
487                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
488                    plan_datafusion_err!(
489                    "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
490                ),
491                )?;
492                let new_subquery = Subquery {
493                    subquery: Arc::new(new_plan),
494                    outer_ref_columns: subquery.outer_ref_columns,
495                    spans: subquery.spans,
496                };
497                Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
498                    Box::new(expr.cast_to(&common_type, self.schema)?),
499                    cast_subquery(new_subquery, &common_type)?,
500                    negated,
501                ))))
502            }
503            Expr::SetComparison(SetComparison {
504                expr,
505                subquery,
506                op,
507                quantifier,
508            }) => {
509                let new_plan = analyze_internal(
510                    self.schema,
511                    Arc::unwrap_or_clone(subquery.subquery),
512                )?
513                .data;
514                let expr_type = expr.get_type(self.schema)?;
515                let subquery_type = new_plan.schema().field(0).data_type();
516                if (expr_type.is_numeric() && subquery_type.is_string())
517                    || (subquery_type.is_numeric() && expr_type.is_string())
518                {
519                    return plan_err!(
520                        "expr type {expr_type} can't cast to {subquery_type} in SetComparison"
521                    );
522                }
523                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
524                    plan_datafusion_err!(
525                        "expr type {expr_type} can't cast to {subquery_type} in SetComparison"
526                    ),
527                )?;
528                let new_subquery = Subquery {
529                    subquery: Arc::new(new_plan),
530                    outer_ref_columns: subquery.outer_ref_columns,
531                    spans: subquery.spans,
532                };
533                Ok(Transformed::yes(Expr::SetComparison(SetComparison::new(
534                    Box::new(expr.cast_to(&common_type, self.schema)?),
535                    cast_subquery(new_subquery, &common_type)?,
536                    op,
537                    quantifier,
538                ))))
539            }
540            Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
541                *expr,
542                self.schema,
543            )?))),
544            Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
545                get_casted_expr_for_bool_op(*expr, self.schema)?,
546            ))),
547            Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
548                get_casted_expr_for_bool_op(*expr, self.schema)?,
549            ))),
550            Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
551                get_casted_expr_for_bool_op(*expr, self.schema)?,
552            ))),
553            Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
554                get_casted_expr_for_bool_op(*expr, self.schema)?,
555            ))),
556            Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
557                get_casted_expr_for_bool_op(*expr, self.schema)?,
558            ))),
559            Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
560                get_casted_expr_for_bool_op(*expr, self.schema)?,
561            ))),
562            Expr::Like(Like {
563                negated,
564                expr,
565                pattern,
566                escape_char,
567                case_insensitive,
568            }) => {
569                let left_type = expr.get_type(self.schema)?;
570                let right_type = pattern.get_type(self.schema)?;
571                let coerced_type = like_coercion(&left_type,  &right_type).ok_or_else(|| {
572                    let op_name = if case_insensitive {
573                        "ILIKE"
574                    } else {
575                        "LIKE"
576                    };
577                    plan_datafusion_err!(
578                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
579                    )
580                })?;
581                let expr = match left_type {
582                    DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
583                    _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
584                };
585                let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
586                Ok(Transformed::yes(Expr::Like(Like::new(
587                    negated,
588                    expr,
589                    pattern,
590                    escape_char,
591                    case_insensitive,
592                ))))
593            }
594            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
595                let (left, right) =
596                    self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
597                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
598                    Box::new(left),
599                    op,
600                    Box::new(right),
601                ))))
602            }
603            Expr::Between(Between {
604                expr,
605                negated,
606                low,
607                high,
608            }) => {
609                let expr_type = expr.get_type(self.schema)?;
610                let low_type = low.get_type(self.schema)?;
611                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
612                    .ok_or_else(|| {
613                        internal_datafusion_err!(
614                            "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
615                        )
616                    })?;
617                let high_type = high.get_type(self.schema)?;
618                let high_coerced_type = comparison_coercion(&expr_type, &high_type)
619                    .ok_or_else(|| {
620                        internal_datafusion_err!(
621                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
622                        )
623                    })?;
624                let coercion_type =
625                    comparison_coercion(&low_coerced_type, &high_coerced_type)
626                        .ok_or_else(|| {
627                            internal_datafusion_err!(
628                                "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
629                            )
630                        })?;
631                Ok(Transformed::yes(Expr::Between(Between::new(
632                    Box::new(expr.cast_to(&coercion_type, self.schema)?),
633                    negated,
634                    Box::new(low.cast_to(&coercion_type, self.schema)?),
635                    Box::new(high.cast_to(&coercion_type, self.schema)?),
636                ))))
637            }
638            Expr::InList(InList {
639                expr,
640                list,
641                negated,
642            }) => {
643                let expr_data_type = expr.get_type(self.schema)?;
644                let list_data_types = list
645                    .iter()
646                    .map(|list_expr| list_expr.get_type(self.schema))
647                    .collect::<Result<Vec<_>>>()?;
648                let result_type =
649                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
650                match result_type {
651                    None => plan_err!(
652                        "Can not find compatible types to compare {expr_data_type} with [{}]",
653                        list_data_types.iter().join(", ")
654                    ),
655                    Some(coerced_type) => {
656                        // find the coerced type
657                        let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
658                        let cast_list_expr = list
659                            .into_iter()
660                            .map(|list_expr| {
661                                list_expr.cast_to(&coerced_type, self.schema)
662                            })
663                            .collect::<Result<Vec<_>>>()?;
664                        Ok(Transformed::yes(Expr::InList(InList::new(
665                            Box::new(cast_expr),
666                            cast_list_expr,
667                            negated,
668                        ))))
669                    }
670                }
671            }
672            Expr::Case(case) => {
673                let case = coerce_case_expression(case, self.schema)?;
674                Ok(Transformed::yes(Expr::Case(case)))
675            }
676            Expr::ScalarFunction(ScalarFunction { func, args }) => {
677                let new_expr =
678                    coerce_arguments_for_signature(args, self.schema, func.as_ref())?;
679                Ok(Transformed::yes(Expr::ScalarFunction(
680                    ScalarFunction::new_udf(func, new_expr),
681                )))
682            }
683            Expr::AggregateFunction(expr::AggregateFunction {
684                func,
685                params:
686                    AggregateFunctionParams {
687                        args,
688                        distinct,
689                        filter,
690                        order_by,
691                        null_treatment,
692                    },
693            }) => {
694                let new_expr =
695                    coerce_arguments_for_signature(args, self.schema, func.as_ref())?;
696                Ok(Transformed::yes(Expr::AggregateFunction(
697                    expr::AggregateFunction::new_udf(
698                        func,
699                        new_expr,
700                        distinct,
701                        filter,
702                        order_by,
703                        null_treatment,
704                    ),
705                )))
706            }
707            Expr::WindowFunction(window_fun) => {
708                let WindowFunction {
709                    fun,
710                    params:
711                        expr::WindowFunctionParams {
712                            args,
713                            partition_by,
714                            order_by,
715                            window_frame,
716                            filter,
717                            null_treatment,
718                            distinct,
719                        },
720                } = *window_fun;
721                let window_frame =
722                    coerce_window_frame(window_frame, self.schema, &order_by)?;
723
724                let args = match &fun {
725                    expr::WindowFunctionDefinition::AggregateUDF(udf) => {
726                        coerce_arguments_for_signature(args, self.schema, udf.as_ref())?
727                    }
728                    expr::WindowFunctionDefinition::WindowUDF(udf) => {
729                        coerce_arguments_for_signature(args, self.schema, udf.as_ref())?
730                    }
731                };
732
733                let new_expr = Expr::from(WindowFunction {
734                    fun,
735                    params: expr::WindowFunctionParams {
736                        args,
737                        partition_by,
738                        order_by,
739                        window_frame,
740                        filter,
741                        null_treatment,
742                        distinct,
743                    },
744                });
745                Ok(Transformed::yes(new_expr))
746            }
747            // TODO: remove the next line after `Expr::Wildcard` is removed
748            #[expect(deprecated)]
749            Expr::Alias(_)
750            | Expr::Column(_)
751            | Expr::ScalarVariable(_, _)
752            | Expr::Literal(_, _)
753            | Expr::SimilarTo(_)
754            | Expr::IsNotNull(_)
755            | Expr::IsNull(_)
756            | Expr::Negative(_)
757            | Expr::Cast(_)
758            | Expr::TryCast(_)
759            | Expr::Wildcard { .. }
760            | Expr::GroupingSet(_)
761            | Expr::Placeholder(_)
762            | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
763        }
764    }
765}
766
767/// Transform a schema to use non-view types for Utf8View and BinaryView
768fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
769    let metadata = dfschema.as_arrow().metadata.clone();
770    let mut transformed = false;
771
772    let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
773        dfschema
774            .iter()
775            .map(|(qualifier, field)| match field.data_type() {
776                DataType::Utf8View => {
777                    transformed = true;
778                    (
779                        qualifier.cloned() as Option<TableReference>,
780                        Arc::new(Field::new(
781                            field.name(),
782                            DataType::LargeUtf8,
783                            field.is_nullable(),
784                        )),
785                    )
786                }
787                DataType::BinaryView => {
788                    transformed = true;
789                    (
790                        qualifier.cloned() as Option<TableReference>,
791                        Arc::new(Field::new(
792                            field.name(),
793                            DataType::LargeBinary,
794                            field.is_nullable(),
795                        )),
796                    )
797                }
798                _ => (
799                    qualifier.cloned() as Option<TableReference>,
800                    Arc::clone(field),
801                ),
802            })
803            .unzip();
804
805    if !transformed {
806        return None;
807    }
808
809    let schema = Schema::new_with_metadata(transformed_fields, metadata);
810    Some(DFSchema::from_field_specific_qualified_schema(
811        qualifiers,
812        &Arc::new(schema),
813    ))
814}
815
816/// Casts the given `value` to `target_type`. Note that this function
817/// only considers `Null` or `Utf8` values.
818fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
819    match value {
820        // Coerce Utf8 values:
821        ScalarValue::Utf8(Some(val)) => {
822            ScalarValue::try_from_string(val.clone(), target_type)
823        }
824        s => {
825            if s.is_null() {
826                // Coerce `Null` values:
827                ScalarValue::try_from(target_type)
828            } else {
829                // Values except `Utf8`/`Null` variants already have the right type
830                // (casted before) since we convert `sqlparser` outputs to `Utf8`
831                // for all possible cases. Therefore, we return a clone here.
832                Ok(s.clone())
833            }
834        }
835    }
836}
837
838/// This function coerces `value` to `target_type` in a range-aware fashion.
839/// If the coercion is successful, we return an `Ok` value with the result.
840/// If the coercion fails because `target_type` is not wide enough (i.e. we
841/// can not coerce to `target_type`, but we can to a wider type in the same
842/// family), we return a `Null` value of this type to signal this situation.
843/// Downstream code uses this signal to treat these values as *unbounded*.
844fn coerce_scalar_range_aware(
845    target_type: &DataType,
846    value: &ScalarValue,
847) -> Result<ScalarValue> {
848    coerce_scalar(target_type, value).or_else(|err| {
849        // If type coercion fails, check if the largest type in family works:
850        if let Some(largest_type) = get_widest_type_in_family(target_type) {
851            coerce_scalar(largest_type, value).map_or_else(
852                |_| exec_err!("Cannot cast {value:?} to {target_type}"),
853                |_| ScalarValue::try_from(target_type),
854            )
855        } else {
856            Err(err)
857        }
858    })
859}
860
861/// This function returns the widest type in the family of `given_type`.
862/// If the given type is already the widest type, it returns `None`.
863/// For example, if `given_type` is `Int8`, it returns `Int64`.
864fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
865    match given_type {
866        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
867        DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
868        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
869        _ => None,
870    }
871}
872
873/// Coerces the given (window frame) `bound` to `target_type`.
874fn coerce_frame_bound(
875    target_type: &DataType,
876    bound: WindowFrameBound,
877) -> Result<WindowFrameBound> {
878    match bound {
879        WindowFrameBound::Preceding(v) => {
880            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
881        }
882        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
883        WindowFrameBound::Following(v) => {
884            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
885        }
886    }
887}
888
889fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
890    if col_type.is_numeric()
891        || col_type.is_string()
892        || col_type.is_null()
893        || matches!(
894            col_type,
895            DataType::List(_)
896                | DataType::LargeList(_)
897                | DataType::FixedSizeList(_, _)
898                | DataType::Boolean
899        )
900    {
901        Ok(col_type.clone())
902    } else if is_datetime(col_type) {
903        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
904    } else if let DataType::Dictionary(_, value_type) = col_type {
905        extract_window_frame_target_type(value_type)
906    } else {
907        internal_err!("Cannot run range queries on datatype: {col_type}")
908    }
909}
910
911// Coerces the given `window_frame` to use appropriate natural types.
912// For example, ROWS and GROUPS frames use `UInt64` during calculations.
913fn coerce_window_frame(
914    window_frame: WindowFrame,
915    schema: &DFSchema,
916    expressions: &[Sort],
917) -> Result<WindowFrame> {
918    let mut window_frame = window_frame;
919    let target_type = match window_frame.units {
920        WindowFrameUnits::Range => {
921            let current_types = expressions
922                .first()
923                .map(|s| s.expr.get_type(schema))
924                .transpose()?;
925            if let Some(col_type) = current_types {
926                extract_window_frame_target_type(&col_type)?
927            } else {
928                return internal_err!("ORDER BY column cannot be empty");
929            }
930        }
931        WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
932    };
933    window_frame.start_bound =
934        coerce_frame_bound(&target_type, window_frame.start_bound)?;
935    window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
936    Ok(window_frame)
937}
938
939// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
940// The above op will be rewrite to the binary op when creating the physical op.
941fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
942    let left_type = expr.get_type(schema)?;
943    BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
944        .get_input_types()?;
945    expr.cast_to(&DataType::Boolean, schema)
946}
947
948/// Returns `expressions` coerced to types compatible with
949/// `signature`, if possible.
950///
951/// See the module level documentation for more detail on coercion.
952fn coerce_arguments_for_signature<F: UDFCoercionExt>(
953    expressions: Vec<Expr>,
954    schema: &DFSchema,
955    func: &F,
956) -> Result<Vec<Expr>> {
957    let current_fields = expressions
958        .iter()
959        .map(|e| e.to_field(schema).map(|(_, f)| f))
960        .collect::<Result<Vec<_>>>()?;
961
962    let coerced_types = fields_with_udf(&current_fields, func)?
963        .into_iter()
964        .map(|f| f.data_type().clone())
965        .collect::<Vec<_>>();
966
967    expressions
968        .into_iter()
969        .enumerate()
970        .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
971        .collect()
972}
973
974fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
975    // Given expressions like:
976    //
977    // CASE a1
978    //   WHEN a2 THEN b1
979    //   WHEN a3 THEN b2
980    //   ELSE b3
981    // END
982    //
983    // or:
984    //
985    // CASE
986    //   WHEN x1 THEN b1
987    //   WHEN x2 THEN b2
988    //   ELSE b3
989    // END
990    //
991    // Then all aN (a1, a2, a3) must be converted to a common data type in the first example
992    // (case-when expression coercion)
993    //
994    // All xN (x1, x2) must be converted to a boolean data type in the second example
995    // (when-boolean expression coercion)
996    //
997    // And all bN (b1, b2, b3) must be converted to a common data type in both examples
998    // (then-else expression coercion)
999    //
1000    // If any fail to find and cast to a common/specific data type, will return error
1001    //
1002    // Note that case-when and when-boolean expression coercions are mutually exclusive
1003    // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur
1004
1005    // prepare types
1006    let case_type = case
1007        .expr
1008        .as_ref()
1009        .map(|expr| expr.get_type(schema))
1010        .transpose()?;
1011    let then_types = case
1012        .when_then_expr
1013        .iter()
1014        .map(|(_when, then)| then.get_type(schema))
1015        .collect::<Result<Vec<_>>>()?;
1016    let else_type = case
1017        .else_expr
1018        .as_ref()
1019        .map(|expr| expr.get_type(schema))
1020        .transpose()?;
1021
1022    // find common coercible types
1023    let case_when_coerce_type = case_type
1024        .as_ref()
1025        .map(|case_type| {
1026            let when_types = case
1027                .when_then_expr
1028                .iter()
1029                .map(|(when, _then)| when.get_type(schema))
1030                .collect::<Result<Vec<_>>>()?;
1031            let coerced_type =
1032                get_coerce_type_for_case_expression(&when_types, Some(case_type));
1033            coerced_type.ok_or_else(|| {
1034                plan_datafusion_err!(
1035                    "Failed to coerce case ({case_type}) and when ({}) \
1036                     to common types in CASE WHEN expression",
1037                    when_types.iter().join(", ")
1038                )
1039            })
1040        })
1041        .transpose()?;
1042    let then_else_coerce_type =
1043        get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
1044            || {
1045                if let Some(else_type) = else_type {
1046                    plan_datafusion_err!(
1047                        "Failed to coerce then ({}) and else ({else_type}) \
1048                         to common types in CASE WHEN expression",
1049                        then_types.iter().join(", ")
1050                    )
1051                } else {
1052                    plan_datafusion_err!(
1053                        "Failed to coerce then ({}) and else (None) \
1054                         to common types in CASE WHEN expression",
1055                        then_types.iter().join(", ")
1056                    )
1057                }
1058            },
1059        )?;
1060
1061    // do cast if found common coercible types
1062    let case_expr = case
1063        .expr
1064        .zip(case_when_coerce_type.as_ref())
1065        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
1066        .transpose()?
1067        .map(Box::new);
1068    let when_then = case
1069        .when_then_expr
1070        .into_iter()
1071        .map(|(when, then)| {
1072            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
1073            let when = when.cast_to(when_type, schema).map_err(|e| {
1074                DataFusionError::Context(
1075                    format!(
1076                        "WHEN expressions in CASE couldn't be \
1077                         converted to common type ({when_type})"
1078                    ),
1079                    Box::new(e),
1080                )
1081            })?;
1082            let then = then.cast_to(&then_else_coerce_type, schema)?;
1083            Ok((Box::new(when), Box::new(then)))
1084        })
1085        .collect::<Result<Vec<_>>>()?;
1086    let else_expr = case
1087        .else_expr
1088        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
1089        .transpose()?
1090        .map(Box::new);
1091
1092    Ok(Case::new(case_expr, when_then, else_expr))
1093}
1094
1095/// Get a common schema that is compatible with all inputs of UNION.
1096///
1097/// This method presumes that the wildcard expansion is unneeded, or has already
1098/// been applied.
1099///
1100/// ## Schema and Field Handling in Union Coercion
1101///
1102/// **Processing order**: The function starts with the base schema (first input) and then
1103/// processes remaining inputs sequentially, with later inputs taking precedence in merging.
1104///
1105/// **Schema-level metadata merging**: Later schemas take precedence for duplicate keys.
1106///
1107/// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys.
1108///
1109/// **Type coercion precedence**: The coerced type is determined by iteratively applying
1110/// `comparison_coercion()` between the accumulated type and each new input's type. The
1111/// result depends on type coercion rules, not input order.
1112///
1113/// **Nullability merging**: Nullability is accumulated using logical OR (`||`).
1114/// Once any input field is nullable, the result field becomes nullable permanently.
1115/// Later inputs can make a field nullable but cannot make it non-nullable.
1116///
1117/// **Field precedence**: Field names come from the first (base) schema, but the field properties
1118/// (nullability and field-level metadata) have later schemas taking precedence.
1119///
1120/// **Example**:
1121/// ```sql
1122/// SELECT a, b FROM table1  -- a: Int32, metadata {"source": "t1"}, nullable=false
1123/// UNION
1124/// SELECT a, b FROM table2  -- a: Int64, metadata {"source": "t2"}, nullable=true
1125/// UNION
1126/// SELECT a, b FROM table3  -- a: Int32, metadata {"encoding": "utf8"}, nullable=false
1127/// -- Result:
1128/// -- a: Int64 (from type coercion), nullable=true (from table2),
1129/// -- metadata: {"source": "t2", "encoding": "utf8"} (later inputs take precedence)
1130/// ```
1131///
1132/// **Precedence Summary**:
1133/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order
1134/// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR)
1135/// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics)
1136pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1137    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1138}
1139fn coerce_union_schema_with_schema(
1140    inputs: &[Arc<LogicalPlan>],
1141    base_schema: &DFSchemaRef,
1142) -> Result<DFSchema> {
1143    let mut union_datatypes = base_schema
1144        .fields()
1145        .iter()
1146        .map(|f| f.data_type().clone())
1147        .collect::<Vec<_>>();
1148    let mut union_nullabilities = base_schema
1149        .fields()
1150        .iter()
1151        .map(|f| f.is_nullable())
1152        .collect::<Vec<_>>();
1153    let mut union_field_meta = base_schema
1154        .fields()
1155        .iter()
1156        .map(|f| f.metadata().clone())
1157        .collect::<Vec<_>>();
1158
1159    let mut metadata = base_schema.metadata().clone();
1160
1161    for (i, plan) in inputs.iter().enumerate() {
1162        let plan_schema = plan.schema();
1163        metadata.extend(plan_schema.metadata().clone());
1164
1165        if plan_schema.fields().len() != base_schema.fields().len() {
1166            return plan_err!(
1167                "Union schemas have different number of fields: \
1168                query 1 has {} fields whereas query {} has {} fields",
1169                base_schema.fields().len(),
1170                i + 1,
1171                plan_schema.fields().len()
1172            );
1173        }
1174
1175        // coerce data type and nullability for each field
1176        for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1177            union_datatypes.iter_mut(),
1178            union_nullabilities.iter_mut(),
1179            union_field_meta.iter_mut(),
1180            plan_schema.fields().iter()
1181        ) {
1182            let coerced_type =
1183                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1184                    || {
1185                        plan_datafusion_err!(
1186                            "Incompatible inputs for Union: Previous inputs were \
1187                            of type {}, but got incompatible type {} on column '{}'",
1188                            union_datatype,
1189                            plan_field.data_type(),
1190                            plan_field.name()
1191                        )
1192                    },
1193                )?;
1194
1195            *union_datatype = coerced_type;
1196            *union_nullable = *union_nullable || plan_field.is_nullable();
1197            union_field_map.extend(plan_field.metadata().clone());
1198        }
1199    }
1200    let union_qualified_fields = izip!(
1201        base_schema.fields(),
1202        union_datatypes.into_iter(),
1203        union_nullabilities,
1204        union_field_meta.into_iter()
1205    )
1206    .map(|(field, datatype, nullable, metadata)| {
1207        let mut field = Field::new(field.name().clone(), datatype, nullable);
1208        field.set_metadata(metadata);
1209        (None, field.into())
1210    })
1211    .collect::<Vec<_>>();
1212
1213    DFSchema::new_with_metadata(union_qualified_fields, metadata)
1214}
1215
1216/// See `<https://github.com/apache/datafusion/pull/2108>`
1217fn project_with_column_index(
1218    expr: Vec<Expr>,
1219    input: Arc<LogicalPlan>,
1220    schema: DFSchemaRef,
1221) -> Result<LogicalPlan> {
1222    let alias_expr = expr
1223        .into_iter()
1224        .enumerate()
1225        .map(|(i, e)| match e {
1226            Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1227                Ok(e.unalias().alias(schema.field(i).name()))
1228            }
1229            Expr::Column(Column {
1230                relation: _,
1231                ref name,
1232                spans: _,
1233            }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1234            Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1235            #[expect(deprecated)]
1236            Expr::Wildcard { .. } => {
1237                plan_err!("Wildcard should be expanded before type coercion")
1238            }
1239            _ => Ok(e.alias(schema.field(i).name())),
1240        })
1241        .collect::<Result<Vec<_>>>()?;
1242
1243    Projection::try_new_with_schema(alias_expr, input, schema)
1244        .map(LogicalPlan::Projection)
1245}
1246
1247#[cfg(test)]
1248mod test {
1249    use std::any::Any;
1250    use std::sync::Arc;
1251
1252    use arrow::datatypes::DataType::Utf8;
1253    use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1254    use insta::assert_snapshot;
1255
1256    use crate::analyzer::Analyzer;
1257    use crate::analyzer::type_coercion::{
1258        TypeCoercion, TypeCoercionRewriter, coerce_case_expression,
1259    };
1260    use crate::assert_analyzed_plan_with_config_eq_snapshot;
1261    use datafusion_common::config::ConfigOptions;
1262    use datafusion_common::tree_node::{TransformedResult, TreeNode};
1263    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1264    use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1265    use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1266    use datafusion_expr::test::function_stub::avg_udaf;
1267    use datafusion_expr::{
1268        AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr,
1269        ExprSchemable, Filter, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF,
1270        ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Union, Volatility, cast,
1271        col, create_udaf, is_true, lit,
1272    };
1273    use datafusion_functions_aggregate::average::AvgAccumulator;
1274    use datafusion_sql::TableReference;
1275
1276    fn empty() -> Arc<LogicalPlan> {
1277        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1278            produce_one_row: false,
1279            schema: Arc::new(DFSchema::empty()),
1280        }))
1281    }
1282
1283    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1284        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1285            produce_one_row: false,
1286            schema: Arc::new(
1287                DFSchema::from_unqualified_fields(
1288                    vec![Field::new("a", data_type, true)].into(),
1289                    std::collections::HashMap::new(),
1290                )
1291                .unwrap(),
1292            ),
1293        }))
1294    }
1295
1296    macro_rules! assert_analyzed_plan_eq {
1297        (
1298            $plan: expr,
1299            @ $expected: literal $(,)?
1300        ) => {{
1301            let options = ConfigOptions::default();
1302            let rule = Arc::new(TypeCoercion::new());
1303            assert_analyzed_plan_with_config_eq_snapshot!(
1304                options,
1305                rule,
1306                $plan,
1307                @ $expected,
1308            )
1309            }};
1310    }
1311
1312    macro_rules! coerce_on_output_if_viewtype {
1313        (
1314            $is_viewtype: expr,
1315            $plan: expr,
1316            @ $expected: literal $(,)?
1317        ) => {{
1318            let mut options = ConfigOptions::default();
1319            // coerce on output
1320            if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1321            let rule = Arc::new(TypeCoercion::new());
1322
1323            assert_analyzed_plan_with_config_eq_snapshot!(
1324                options,
1325                rule,
1326                $plan,
1327                @ $expected,
1328            )
1329        }};
1330    }
1331
1332    fn assert_type_coercion_error(
1333        plan: LogicalPlan,
1334        expected_substr: &str,
1335    ) -> Result<()> {
1336        let options = ConfigOptions::default();
1337        let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1338
1339        match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1340            Ok(succeeded_plan) => {
1341                panic!(
1342                    "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1343                );
1344            }
1345            Err(e) => {
1346                let msg = e.to_string();
1347                assert!(
1348                    msg.contains(expected_substr),
1349                    "Error did not contain expected substring.\n  expected to find: `{expected_substr}`\n  actual error: `{msg}`"
1350                );
1351            }
1352        }
1353
1354        Ok(())
1355    }
1356
1357    #[test]
1358    fn simple_case() -> Result<()> {
1359        let expr = col("a").lt(lit(2_u32));
1360        let empty = empty_with_type(DataType::Float64);
1361        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1362
1363        assert_analyzed_plan_eq!(
1364            plan,
1365            @r"
1366        Projection: a < CAST(UInt32(2) AS Float64)
1367          EmptyRelation: rows=0
1368        "
1369        )
1370    }
1371
1372    #[test]
1373    fn test_coerce_union() -> Result<()> {
1374        let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1375            produce_one_row: false,
1376            schema: Arc::new(
1377                DFSchema::try_from_qualified_schema(
1378                    TableReference::full("datafusion", "test", "foo"),
1379                    &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1380                )
1381                .unwrap(),
1382            ),
1383        }));
1384        let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1385            produce_one_row: false,
1386            schema: Arc::new(
1387                DFSchema::try_from_qualified_schema(
1388                    TableReference::full("datafusion", "test", "foo"),
1389                    &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1390                )
1391                .unwrap(),
1392            ),
1393        }));
1394        let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1395            left_plan, right_plan,
1396        ])?);
1397        let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1398            .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1399        let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1400            vec![col("a")],
1401            Arc::new(analyzed_union),
1402        )?);
1403
1404        assert_analyzed_plan_eq!(
1405            top_level_plan,
1406            @r"
1407        Projection: a
1408          Union
1409            Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1410              EmptyRelation: rows=0
1411            EmptyRelation: rows=0
1412        "
1413        )
1414    }
1415
1416    #[test]
1417    fn coerce_utf8view_output() -> Result<()> {
1418        // Plan A
1419        // scenario: outermost utf8view projection
1420        let expr = col("a");
1421        let empty = empty_with_type(DataType::Utf8View);
1422        let plan = LogicalPlan::Projection(Projection::try_new(
1423            vec![expr.clone()],
1424            Arc::clone(&empty),
1425        )?);
1426
1427        // Plan A: no coerce
1428        coerce_on_output_if_viewtype!(
1429            false,
1430            plan.clone(),
1431            @r"
1432        Projection: a
1433          EmptyRelation: rows=0
1434        "
1435        )?;
1436
1437        // Plan A: coerce requested: Utf8View => LargeUtf8
1438        coerce_on_output_if_viewtype!(
1439            true,
1440            plan.clone(),
1441            @r"
1442        Projection: CAST(a AS LargeUtf8) AS a
1443          EmptyRelation: rows=0
1444        "
1445        )?;
1446
1447        // Plan B
1448        // scenario: outermost bool projection
1449        let bool_expr = col("a").lt(lit("foo"));
1450        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1451            vec![bool_expr],
1452            Arc::clone(&empty),
1453        )?);
1454        // Plan B: no coerce
1455        coerce_on_output_if_viewtype!(
1456            false,
1457            bool_plan.clone(),
1458            @r#"
1459        Projection: a < CAST(Utf8("foo") AS Utf8View)
1460          EmptyRelation: rows=0
1461        "#
1462        )?;
1463
1464        coerce_on_output_if_viewtype!(
1465            false,
1466            plan.clone(),
1467            @r"
1468        Projection: a
1469          EmptyRelation: rows=0
1470        "
1471        )?;
1472
1473        // Plan B: coerce requested: no coercion applied
1474        coerce_on_output_if_viewtype!(
1475            true,
1476            plan.clone(),
1477            @r"
1478        Projection: CAST(a AS LargeUtf8) AS a
1479          EmptyRelation: rows=0
1480        "
1481        )?;
1482
1483        // Plan C
1484        // scenario: with a non-projection root logical plan node
1485        let sort_expr = expr.sort(true, true);
1486        let sort_plan = LogicalPlan::Sort(Sort {
1487            expr: vec![sort_expr],
1488            input: Arc::new(plan),
1489            fetch: None,
1490        });
1491
1492        // Plan C: no coerce
1493        coerce_on_output_if_viewtype!(
1494            false,
1495            sort_plan.clone(),
1496            @r"
1497        Sort: a ASC NULLS FIRST
1498          Projection: a
1499            EmptyRelation: rows=0
1500        "
1501        )?;
1502
1503        // Plan C: coerce requested: Utf8View => LargeUtf8
1504        coerce_on_output_if_viewtype!(
1505            true,
1506            sort_plan.clone(),
1507            @r"
1508        Projection: CAST(a AS LargeUtf8) AS a
1509          Sort: a ASC NULLS FIRST
1510            Projection: a
1511              EmptyRelation: rows=0
1512        "
1513        )?;
1514
1515        // Plan D
1516        // scenario: two layers of projections with view types
1517        let plan = LogicalPlan::Projection(Projection::try_new(
1518            vec![col("a")],
1519            Arc::new(sort_plan),
1520        )?);
1521        // Plan D: no coerce
1522        coerce_on_output_if_viewtype!(
1523            false,
1524            plan.clone(),
1525            @r"
1526        Projection: a
1527          Sort: a ASC NULLS FIRST
1528            Projection: a
1529              EmptyRelation: rows=0
1530        "
1531        )?;
1532        // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost
1533        coerce_on_output_if_viewtype!(
1534            true,
1535            plan.clone(),
1536            @r"
1537        Projection: CAST(a AS LargeUtf8) AS a
1538          Sort: a ASC NULLS FIRST
1539            Projection: a
1540              EmptyRelation: rows=0
1541        "
1542        )?;
1543
1544        Ok(())
1545    }
1546
1547    #[test]
1548    fn coerce_binaryview_output() -> Result<()> {
1549        // Plan A
1550        // scenario: outermost binaryview projection
1551        let expr = col("a");
1552        let empty = empty_with_type(DataType::BinaryView);
1553        let plan = LogicalPlan::Projection(Projection::try_new(
1554            vec![expr.clone()],
1555            Arc::clone(&empty),
1556        )?);
1557
1558        // Plan A: no coerce
1559        coerce_on_output_if_viewtype!(
1560            false,
1561            plan.clone(),
1562            @r"
1563        Projection: a
1564          EmptyRelation: rows=0
1565        "
1566        )?;
1567
1568        // Plan A: coerce requested: BinaryView => LargeBinary
1569        coerce_on_output_if_viewtype!(
1570            true,
1571            plan.clone(),
1572            @r"
1573        Projection: CAST(a AS LargeBinary) AS a
1574          EmptyRelation: rows=0
1575        "
1576        )?;
1577
1578        // Plan B
1579        // scenario: outermost bool projection
1580        let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1581        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1582            vec![bool_expr],
1583            Arc::clone(&empty),
1584        )?);
1585
1586        // Plan B: no coerce
1587        coerce_on_output_if_viewtype!(
1588            false,
1589            bool_plan.clone(),
1590            @r#"
1591        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1592          EmptyRelation: rows=0
1593        "#
1594        )?;
1595
1596        // Plan B: coerce requested: no coercion applied
1597        coerce_on_output_if_viewtype!(
1598            true,
1599            bool_plan.clone(),
1600            @r#"
1601        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1602          EmptyRelation: rows=0
1603        "#
1604        )?;
1605
1606        // Plan C
1607        // scenario: with a non-projection root logical plan node
1608        let sort_expr = expr.sort(true, true);
1609        let sort_plan = LogicalPlan::Sort(Sort {
1610            expr: vec![sort_expr],
1611            input: Arc::new(plan),
1612            fetch: None,
1613        });
1614
1615        // Plan C: no coerce
1616        coerce_on_output_if_viewtype!(
1617            false,
1618            sort_plan.clone(),
1619            @r"
1620        Sort: a ASC NULLS FIRST
1621          Projection: a
1622            EmptyRelation: rows=0
1623        "
1624        )?;
1625        // Plan C: coerce requested: BinaryView => LargeBinary
1626        coerce_on_output_if_viewtype!(
1627            true,
1628            sort_plan.clone(),
1629            @r"
1630        Projection: CAST(a AS LargeBinary) AS a
1631          Sort: a ASC NULLS FIRST
1632            Projection: a
1633              EmptyRelation: rows=0
1634        "
1635        )?;
1636
1637        // Plan D
1638        // scenario: two layers of projections with view types
1639        let plan = LogicalPlan::Projection(Projection::try_new(
1640            vec![col("a")],
1641            Arc::new(sort_plan),
1642        )?);
1643
1644        // Plan D: no coerce
1645        coerce_on_output_if_viewtype!(
1646            false,
1647            plan.clone(),
1648            @r"
1649        Projection: a
1650          Sort: a ASC NULLS FIRST
1651            Projection: a
1652              EmptyRelation: rows=0
1653        "
1654        )?;
1655
1656        // Plan B: coerce requested: BinaryView => LargeBinary only on outermost
1657        coerce_on_output_if_viewtype!(
1658            true,
1659            plan.clone(),
1660            @r"
1661        Projection: CAST(a AS LargeBinary) AS a
1662          Sort: a ASC NULLS FIRST
1663            Projection: a
1664              EmptyRelation: rows=0
1665        "
1666        )?;
1667
1668        Ok(())
1669    }
1670
1671    #[test]
1672    fn nested_case() -> Result<()> {
1673        let expr = col("a").lt(lit(2_u32));
1674        let empty = empty_with_type(DataType::Float64);
1675
1676        let plan = LogicalPlan::Projection(Projection::try_new(
1677            vec![expr.clone().or(expr)],
1678            empty,
1679        )?);
1680
1681        assert_analyzed_plan_eq!(
1682            plan,
1683            @r"
1684        Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1685          EmptyRelation: rows=0
1686        "
1687        )
1688    }
1689
1690    #[derive(Debug, PartialEq, Eq, Hash)]
1691    struct TestScalarUDF {
1692        signature: Signature,
1693    }
1694
1695    impl ScalarUDFImpl for TestScalarUDF {
1696        fn as_any(&self) -> &dyn Any {
1697            self
1698        }
1699
1700        fn name(&self) -> &str {
1701            "TestScalarUDF"
1702        }
1703
1704        fn signature(&self) -> &Signature {
1705            &self.signature
1706        }
1707
1708        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1709            Ok(Utf8)
1710        }
1711
1712        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1713            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1714        }
1715    }
1716
1717    #[test]
1718    fn scalar_udf() -> Result<()> {
1719        let empty = empty();
1720
1721        let udf = ScalarUDF::from(TestScalarUDF {
1722            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1723        })
1724        .call(vec![lit(123_i32)]);
1725        let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1726
1727        assert_analyzed_plan_eq!(
1728            plan,
1729            @r"
1730        Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1731          EmptyRelation: rows=0
1732        "
1733        )
1734    }
1735
1736    #[test]
1737    fn scalar_udf_invalid_input() -> Result<()> {
1738        let empty = empty();
1739        let udf = ScalarUDF::from(TestScalarUDF {
1740            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1741        })
1742        .call(vec![lit("Apple")]);
1743        Projection::try_new(vec![udf], empty)
1744            .expect_err("Expected an error due to incorrect function input");
1745
1746        Ok(())
1747    }
1748
1749    #[test]
1750    fn scalar_function() -> Result<()> {
1751        // test that automatic argument type coercion for scalar functions work
1752        let empty = empty();
1753        let lit_expr = lit(10i64);
1754        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1755            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1756        });
1757        let scalar_function_expr =
1758            Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1759        let plan = LogicalPlan::Projection(Projection::try_new(
1760            vec![scalar_function_expr],
1761            empty,
1762        )?);
1763
1764        assert_analyzed_plan_eq!(
1765            plan,
1766            @r"
1767        Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1768          EmptyRelation: rows=0
1769        "
1770        )
1771    }
1772
1773    #[test]
1774    fn agg_udaf() -> Result<()> {
1775        let empty = empty();
1776        let my_avg = create_udaf(
1777            "MY_AVG",
1778            vec![DataType::Float64],
1779            Arc::new(DataType::Float64),
1780            Volatility::Immutable,
1781            Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1782            Arc::new(vec![DataType::UInt64, DataType::Float64]),
1783        );
1784        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1785            Arc::new(my_avg),
1786            vec![lit(10i64)],
1787            false,
1788            None,
1789            vec![],
1790            None,
1791        ));
1792        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1793
1794        assert_analyzed_plan_eq!(
1795            plan,
1796            @r"
1797        Projection: MY_AVG(CAST(Int64(10) AS Float64))
1798          EmptyRelation: rows=0
1799        "
1800        )
1801    }
1802
1803    #[test]
1804    fn agg_udaf_invalid_input() -> Result<()> {
1805        let empty = empty();
1806        let return_type = DataType::Float64;
1807        let accumulator: AccumulatorFactoryFunction =
1808            Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1809        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1810            "MY_AVG",
1811            Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1812            return_type,
1813            accumulator,
1814            vec![
1815                Field::new("count", DataType::UInt64, true).into(),
1816                Field::new("avg", DataType::Float64, true).into(),
1817            ],
1818        ));
1819        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1820            Arc::new(my_avg),
1821            vec![lit("10")],
1822            false,
1823            None,
1824            vec![],
1825            None,
1826        ));
1827
1828        let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1829        assert!(
1830            err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed")
1831        );
1832        Ok(())
1833    }
1834
1835    #[test]
1836    fn agg_function_case() -> Result<()> {
1837        let empty = empty();
1838        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1839            avg_udaf(),
1840            vec![lit(12f64)],
1841            false,
1842            None,
1843            vec![],
1844            None,
1845        ));
1846        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1847
1848        assert_analyzed_plan_eq!(
1849            plan,
1850            @r"
1851        Projection: avg(Float64(12))
1852          EmptyRelation: rows=0
1853        "
1854        )?;
1855
1856        let empty = empty_with_type(DataType::Int32);
1857        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1858            avg_udaf(),
1859            vec![cast(col("a"), DataType::Float64)],
1860            false,
1861            None,
1862            vec![],
1863            None,
1864        ));
1865        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1866
1867        assert_analyzed_plan_eq!(
1868            plan,
1869            @r"
1870        Projection: avg(CAST(a AS Float64))
1871          EmptyRelation: rows=0
1872        "
1873        )
1874    }
1875
1876    #[test]
1877    fn agg_function_invalid_input_avg() -> Result<()> {
1878        let empty = empty();
1879        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1880            avg_udaf(),
1881            vec![lit("1")],
1882            false,
1883            None,
1884            vec![],
1885            None,
1886        ));
1887        let err = Projection::try_new(vec![agg_expr], empty)
1888            .err()
1889            .unwrap()
1890            .strip_backtrace();
1891        assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float16, Float32, Float64]) failed"));
1892        Ok(())
1893    }
1894
1895    #[test]
1896    fn binary_op_date32_op_interval() -> Result<()> {
1897        // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
1898        let expr = cast(lit("1998-03-18"), DataType::Date32)
1899            + lit(ScalarValue::new_interval_dt(123, 456));
1900        let empty = empty();
1901        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1902
1903        assert_analyzed_plan_eq!(
1904            plan,
1905            @r#"
1906        Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1907          EmptyRelation: rows=0
1908        "#
1909        )
1910    }
1911
1912    #[test]
1913    fn inlist_case() -> Result<()> {
1914        // a in (1,4,8), a is int64
1915        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1916        let empty = empty_with_type(DataType::Int64);
1917        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1918        assert_analyzed_plan_eq!(
1919            plan,
1920            @r"
1921        Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1922          EmptyRelation: rows=0
1923        ")?;
1924
1925        // a in (1,4,8), a is decimal
1926        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1927        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1928            produce_one_row: false,
1929            schema: Arc::new(DFSchema::from_unqualified_fields(
1930                vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1931                std::collections::HashMap::new(),
1932            )?),
1933        }));
1934        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1935        assert_analyzed_plan_eq!(
1936            plan,
1937            @r"
1938        Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1939          EmptyRelation: rows=0
1940        ")
1941    }
1942
1943    #[test]
1944    fn between_case() -> Result<()> {
1945        let expr = col("a").between(
1946            lit("2002-05-08"),
1947            // (cast('2002-05-08' as date) + interval '1 months')
1948            cast(lit("2002-05-08"), DataType::Date32)
1949                + lit(ScalarValue::new_interval_ym(0, 1)),
1950        );
1951        let empty = empty_with_type(Utf8);
1952        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1953
1954        assert_analyzed_plan_eq!(
1955            plan,
1956            @r#"
1957        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1958          EmptyRelation: rows=0
1959        "#
1960        )
1961    }
1962
1963    #[test]
1964    fn between_infer_cheap_type() -> Result<()> {
1965        let expr = col("a").between(
1966            // (cast('2002-05-08' as date) + interval '1 months')
1967            cast(lit("2002-05-08"), DataType::Date32)
1968                + lit(ScalarValue::new_interval_ym(0, 1)),
1969            lit("2002-12-08"),
1970        );
1971        let empty = empty_with_type(Utf8);
1972        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1973
1974        // TODO: we should cast col(a).
1975        assert_analyzed_plan_eq!(
1976            plan,
1977            @r#"
1978        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1979          EmptyRelation: rows=0
1980        "#
1981        )
1982    }
1983
1984    #[test]
1985    fn between_null() -> Result<()> {
1986        let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1987        let empty = empty();
1988        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1989
1990        assert_analyzed_plan_eq!(
1991            plan,
1992            @r"
1993        Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1994          EmptyRelation: rows=0
1995        "
1996        )
1997    }
1998
1999    #[test]
2000    fn is_bool_for_type_coercion() -> Result<()> {
2001        // is true
2002        let expr = col("a").is_true();
2003        let empty = empty_with_type(DataType::Boolean);
2004        let plan =
2005            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2006
2007        assert_analyzed_plan_eq!(
2008            plan,
2009            @r"
2010        Projection: a IS TRUE
2011          EmptyRelation: rows=0
2012        "
2013        )?;
2014
2015        let empty = empty_with_type(DataType::Int64);
2016        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2017        assert_type_coercion_error(
2018            plan,
2019            "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean",
2020        )?;
2021
2022        // is not true
2023        let expr = col("a").is_not_true();
2024        let empty = empty_with_type(DataType::Boolean);
2025        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2026
2027        assert_analyzed_plan_eq!(
2028            plan,
2029            @r"
2030        Projection: a IS NOT TRUE
2031          EmptyRelation: rows=0
2032        "
2033        )?;
2034
2035        // is false
2036        let expr = col("a").is_false();
2037        let empty = empty_with_type(DataType::Boolean);
2038        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2039
2040        assert_analyzed_plan_eq!(
2041            plan,
2042            @r"
2043        Projection: a IS FALSE
2044          EmptyRelation: rows=0
2045        "
2046        )?;
2047
2048        // is not false
2049        let expr = col("a").is_not_false();
2050        let empty = empty_with_type(DataType::Boolean);
2051        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2052
2053        assert_analyzed_plan_eq!(
2054            plan,
2055            @r"
2056        Projection: a IS NOT FALSE
2057          EmptyRelation: rows=0
2058        "
2059        )
2060    }
2061
2062    #[test]
2063    fn like_for_type_coercion() -> Result<()> {
2064        // like : utf8 like "abc"
2065        let expr = Box::new(col("a"));
2066        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2067        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2068        let empty = empty_with_type(Utf8);
2069        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2070
2071        assert_analyzed_plan_eq!(
2072            plan,
2073            @r#"
2074        Projection: a LIKE Utf8("abc")
2075          EmptyRelation: rows=0
2076        "#
2077        )?;
2078
2079        let expr = Box::new(col("a"));
2080        let pattern = Box::new(lit(ScalarValue::Null));
2081        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2082        let empty = empty_with_type(Utf8);
2083        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2084
2085        assert_analyzed_plan_eq!(
2086            plan,
2087            @r"
2088        Projection: a LIKE CAST(NULL AS Utf8)
2089          EmptyRelation: rows=0
2090        "
2091        )?;
2092
2093        let expr = Box::new(col("a"));
2094        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2095        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2096        let empty = empty_with_type(DataType::Int64);
2097        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2098        assert_type_coercion_error(
2099            plan,
2100            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
2101        )?;
2102
2103        // ilike
2104        let expr = Box::new(col("a"));
2105        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2106        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2107        let empty = empty_with_type(Utf8);
2108        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2109
2110        assert_analyzed_plan_eq!(
2111            plan,
2112            @r#"
2113        Projection: a ILIKE Utf8("abc")
2114          EmptyRelation: rows=0
2115        "#
2116        )?;
2117
2118        let expr = Box::new(col("a"));
2119        let pattern = Box::new(lit(ScalarValue::Null));
2120        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2121        let empty = empty_with_type(Utf8);
2122        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2123
2124        assert_analyzed_plan_eq!(
2125            plan,
2126            @r"
2127        Projection: a ILIKE CAST(NULL AS Utf8)
2128          EmptyRelation: rows=0
2129        "
2130        )?;
2131
2132        let expr = Box::new(col("a"));
2133        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2134        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2135        let empty = empty_with_type(DataType::Int64);
2136        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2137        assert_type_coercion_error(
2138            plan,
2139            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
2140        )?;
2141
2142        Ok(())
2143    }
2144
2145    #[test]
2146    fn unknown_for_type_coercion() -> Result<()> {
2147        // unknown
2148        let expr = col("a").is_unknown();
2149        let empty = empty_with_type(DataType::Boolean);
2150        let plan =
2151            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2152
2153        assert_analyzed_plan_eq!(
2154            plan,
2155            @r"
2156        Projection: a IS UNKNOWN
2157          EmptyRelation: rows=0
2158        "
2159        )?;
2160
2161        let empty = empty_with_type(Utf8);
2162        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2163        assert_type_coercion_error(
2164            plan,
2165            "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean",
2166        )?;
2167
2168        // is not unknown
2169        let expr = col("a").is_not_unknown();
2170        let empty = empty_with_type(DataType::Boolean);
2171        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2172
2173        assert_analyzed_plan_eq!(
2174            plan,
2175            @r"
2176        Projection: a IS NOT UNKNOWN
2177          EmptyRelation: rows=0
2178        "
2179        )
2180    }
2181
2182    #[test]
2183    fn concat_for_type_coercion() -> Result<()> {
2184        let empty = empty_with_type(Utf8);
2185        let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2186
2187        // concat-type signature
2188        let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2189            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2190        })
2191        .call(args.to_vec());
2192        let plan =
2193            LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2194        assert_analyzed_plan_eq!(
2195            plan,
2196            @r#"
2197        Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2198          EmptyRelation: rows=0
2199        "#
2200        )
2201    }
2202
2203    #[test]
2204    fn test_type_coercion_rewrite() -> Result<()> {
2205        // gt
2206        let schema = Arc::new(DFSchema::from_unqualified_fields(
2207            vec![Field::new("a", DataType::Int64, true)].into(),
2208            std::collections::HashMap::new(),
2209        )?);
2210        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2211        let expr = is_true(lit(12i32).gt(lit(13i64)));
2212        let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2213        let result = expr.rewrite(&mut rewriter).data()?;
2214        assert_eq!(expected, result);
2215
2216        // eq
2217        let schema = Arc::new(DFSchema::from_unqualified_fields(
2218            vec![Field::new("a", DataType::Int64, true)].into(),
2219            std::collections::HashMap::new(),
2220        )?);
2221        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2222        let expr = is_true(lit(12i32).eq(lit(13i64)));
2223        let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2224        let result = expr.rewrite(&mut rewriter).data()?;
2225        assert_eq!(expected, result);
2226
2227        // lt
2228        let schema = Arc::new(DFSchema::from_unqualified_fields(
2229            vec![Field::new("a", DataType::Int64, true)].into(),
2230            std::collections::HashMap::new(),
2231        )?);
2232        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2233        let expr = is_true(lit(12i32).lt(lit(13i64)));
2234        let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2235        let result = expr.rewrite(&mut rewriter).data()?;
2236        assert_eq!(expected, result);
2237
2238        Ok(())
2239    }
2240
2241    #[test]
2242    fn binary_op_date32_eq_ts() -> Result<()> {
2243        let expr = cast(
2244            lit("1998-03-18"),
2245            DataType::Timestamp(TimeUnit::Nanosecond, None),
2246        )
2247        .eq(cast(lit("1998-03-18"), DataType::Date32));
2248        let empty = empty();
2249        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2250
2251        assert_analyzed_plan_eq!(
2252            plan,
2253            @r#"
2254        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
2255          EmptyRelation: rows=0
2256        "#
2257        )
2258    }
2259
2260    fn cast_if_not_same_type(
2261        expr: Box<Expr>,
2262        data_type: &DataType,
2263        schema: &DFSchemaRef,
2264    ) -> Box<Expr> {
2265        if &expr.get_type(schema).unwrap() != data_type {
2266            Box::new(cast(*expr, data_type.clone()))
2267        } else {
2268            expr
2269        }
2270    }
2271
2272    fn cast_helper(
2273        case: Case,
2274        case_when_type: &DataType,
2275        then_else_type: &DataType,
2276        schema: &DFSchemaRef,
2277    ) -> Case {
2278        let expr = case
2279            .expr
2280            .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2281        let when_then_expr = case
2282            .when_then_expr
2283            .into_iter()
2284            .map(|(when, then)| {
2285                (
2286                    cast_if_not_same_type(when, case_when_type, schema),
2287                    cast_if_not_same_type(then, then_else_type, schema),
2288                )
2289            })
2290            .collect::<Vec<_>>();
2291        let else_expr = case
2292            .else_expr
2293            .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2294
2295        Case {
2296            expr,
2297            when_then_expr,
2298            else_expr,
2299        }
2300    }
2301
2302    #[test]
2303    fn test_case_expression_coercion() -> Result<()> {
2304        let schema = Arc::new(DFSchema::from_unqualified_fields(
2305            vec![
2306                Field::new("boolean", DataType::Boolean, true),
2307                Field::new("integer", DataType::Int32, true),
2308                Field::new("float", DataType::Float32, true),
2309                Field::new(
2310                    "timestamp",
2311                    DataType::Timestamp(TimeUnit::Nanosecond, None),
2312                    true,
2313                ),
2314                Field::new("date", DataType::Date32, true),
2315                Field::new(
2316                    "interval",
2317                    DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2318                    true,
2319                ),
2320                Field::new("binary", DataType::Binary, true),
2321                Field::new("string", Utf8, true),
2322                Field::new("decimal", DataType::Decimal128(10, 10), true),
2323            ]
2324            .into(),
2325            std::collections::HashMap::new(),
2326        )?);
2327
2328        let case = Case {
2329            expr: None,
2330            when_then_expr: vec![
2331                (Box::new(col("boolean")), Box::new(col("integer"))),
2332                (Box::new(col("integer")), Box::new(col("float"))),
2333                (Box::new(col("string")), Box::new(col("string"))),
2334            ],
2335            else_expr: None,
2336        };
2337        let case_when_common_type = DataType::Boolean;
2338        let then_else_common_type = Utf8;
2339        let expected = cast_helper(
2340            case.clone(),
2341            &case_when_common_type,
2342            &then_else_common_type,
2343            &schema,
2344        );
2345        let actual = coerce_case_expression(case, &schema)?;
2346        assert_eq!(expected, actual);
2347
2348        let case = Case {
2349            expr: Some(Box::new(col("string"))),
2350            when_then_expr: vec![
2351                (Box::new(col("float")), Box::new(col("integer"))),
2352                (Box::new(col("integer")), Box::new(col("float"))),
2353                (Box::new(col("string")), Box::new(col("string"))),
2354            ],
2355            else_expr: Some(Box::new(col("string"))),
2356        };
2357        let case_when_common_type = Utf8;
2358        let then_else_common_type = Utf8;
2359        let expected = cast_helper(
2360            case.clone(),
2361            &case_when_common_type,
2362            &then_else_common_type,
2363            &schema,
2364        );
2365        let actual = coerce_case_expression(case, &schema)?;
2366        assert_eq!(expected, actual);
2367
2368        let case = Case {
2369            expr: Some(Box::new(col("interval"))),
2370            when_then_expr: vec![
2371                (Box::new(col("float")), Box::new(col("integer"))),
2372                (Box::new(col("binary")), Box::new(col("float"))),
2373                (Box::new(col("string")), Box::new(col("string"))),
2374            ],
2375            else_expr: Some(Box::new(col("string"))),
2376        };
2377        let err = coerce_case_expression(case, &schema).unwrap_err();
2378        assert_snapshot!(
2379            err.strip_backtrace(),
2380            @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
2381        );
2382
2383        let case = Case {
2384            expr: Some(Box::new(col("string"))),
2385            when_then_expr: vec![
2386                (Box::new(col("float")), Box::new(col("date"))),
2387                (Box::new(col("string")), Box::new(col("float"))),
2388                (Box::new(col("string")), Box::new(col("binary"))),
2389            ],
2390            else_expr: Some(Box::new(col("timestamp"))),
2391        };
2392        let err = coerce_case_expression(case, &schema).unwrap_err();
2393        assert_snapshot!(
2394            err.strip_backtrace(),
2395            @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
2396        );
2397
2398        Ok(())
2399    }
2400
2401    macro_rules! test_case_expression {
2402        ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2403            let case = Case {
2404                expr: $expr.map(|e| Box::new(col(e))),
2405                when_then_expr: $when_then,
2406                else_expr: None,
2407            };
2408
2409            let expected =
2410                cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2411
2412            let actual = coerce_case_expression(case, &$schema)?;
2413            assert_eq!(expected, actual);
2414        };
2415    }
2416
2417    #[test]
2418    fn tes_case_when_list() -> Result<()> {
2419        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2420        let schema = Arc::new(DFSchema::from_unqualified_fields(
2421            vec![
2422                Field::new(
2423                    "large_list",
2424                    DataType::LargeList(Arc::clone(&inner_field)),
2425                    true,
2426                ),
2427                Field::new(
2428                    "fixed_list",
2429                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2430                    true,
2431                ),
2432                Field::new("list", DataType::List(inner_field), true),
2433            ]
2434            .into(),
2435            std::collections::HashMap::new(),
2436        )?);
2437
2438        test_case_expression!(
2439            Some("list"),
2440            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2441            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2442            Utf8,
2443            schema
2444        );
2445
2446        test_case_expression!(
2447            Some("large_list"),
2448            vec![(Box::new(col("list")), Box::new(lit("1")))],
2449            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2450            Utf8,
2451            schema
2452        );
2453
2454        test_case_expression!(
2455            Some("list"),
2456            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2457            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2458            Utf8,
2459            schema
2460        );
2461
2462        test_case_expression!(
2463            Some("fixed_list"),
2464            vec![(Box::new(col("list")), Box::new(lit("1")))],
2465            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2466            Utf8,
2467            schema
2468        );
2469
2470        test_case_expression!(
2471            Some("fixed_list"),
2472            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2473            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2474            Utf8,
2475            schema
2476        );
2477
2478        test_case_expression!(
2479            Some("large_list"),
2480            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2481            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2482            Utf8,
2483            schema
2484        );
2485        Ok(())
2486    }
2487
2488    #[test]
2489    fn test_then_else_list() -> Result<()> {
2490        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2491        let schema = Arc::new(DFSchema::from_unqualified_fields(
2492            vec![
2493                Field::new("boolean", DataType::Boolean, true),
2494                Field::new(
2495                    "large_list",
2496                    DataType::LargeList(Arc::clone(&inner_field)),
2497                    true,
2498                ),
2499                Field::new(
2500                    "fixed_list",
2501                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2502                    true,
2503                ),
2504                Field::new("list", DataType::List(inner_field), true),
2505            ]
2506            .into(),
2507            std::collections::HashMap::new(),
2508        )?);
2509
2510        // large list and list
2511        test_case_expression!(
2512            None::<String>,
2513            vec![
2514                (Box::new(col("boolean")), Box::new(col("large_list"))),
2515                (Box::new(col("boolean")), Box::new(col("list")))
2516            ],
2517            DataType::Boolean,
2518            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2519            schema
2520        );
2521
2522        test_case_expression!(
2523            None::<String>,
2524            vec![
2525                (Box::new(col("boolean")), Box::new(col("list"))),
2526                (Box::new(col("boolean")), Box::new(col("large_list")))
2527            ],
2528            DataType::Boolean,
2529            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2530            schema
2531        );
2532
2533        // fixed list and list
2534        test_case_expression!(
2535            None::<String>,
2536            vec![
2537                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2538                (Box::new(col("boolean")), Box::new(col("list")))
2539            ],
2540            DataType::Boolean,
2541            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2542            schema
2543        );
2544
2545        test_case_expression!(
2546            None::<String>,
2547            vec![
2548                (Box::new(col("boolean")), Box::new(col("list"))),
2549                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2550            ],
2551            DataType::Boolean,
2552            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2553            schema
2554        );
2555
2556        // fixed list and large list
2557        test_case_expression!(
2558            None::<String>,
2559            vec![
2560                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2561                (Box::new(col("boolean")), Box::new(col("large_list")))
2562            ],
2563            DataType::Boolean,
2564            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2565            schema
2566        );
2567
2568        test_case_expression!(
2569            None::<String>,
2570            vec![
2571                (Box::new(col("boolean")), Box::new(col("large_list"))),
2572                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2573            ],
2574            DataType::Boolean,
2575            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2576            schema
2577        );
2578        Ok(())
2579    }
2580
2581    #[test]
2582    fn test_map_with_diff_name() -> Result<()> {
2583        let mut builder = SchemaBuilder::new();
2584        builder.push(Field::new("key", Utf8, false));
2585        builder.push(Field::new("value", DataType::Float64, true));
2586        let struct_fields = builder.finish().fields;
2587
2588        let fields =
2589            Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2590        let map_type_entries = DataType::Map(Arc::new(fields), false);
2591
2592        let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2593        let may_type_custom = DataType::Map(Arc::new(fields), false);
2594
2595        let expr = col("a").eq(cast(col("a"), may_type_custom));
2596        let empty = empty_with_type(map_type_entries);
2597        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2598
2599        assert_analyzed_plan_eq!(
2600            plan,
2601            @r#"
2602        Projection: a = CAST(CAST(a AS Map("key_value": non-null Struct("key": non-null Utf8, "value": Float64), unsorted)) AS Map("entries": non-null Struct("key": non-null Utf8, "value": Float64), unsorted))
2603          EmptyRelation: rows=0
2604        "#
2605        )
2606    }
2607
2608    #[test]
2609    fn interval_plus_timestamp() -> Result<()> {
2610        // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
2611        let expr = Expr::BinaryExpr(BinaryExpr::new(
2612            Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2613            Operator::Plus,
2614            Box::new(cast(
2615                lit("2000-01-01T00:00:00"),
2616                DataType::Timestamp(TimeUnit::Nanosecond, None),
2617            )),
2618        ));
2619        let empty = empty();
2620        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2621
2622        assert_analyzed_plan_eq!(
2623            plan,
2624            @r#"
2625        Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
2626          EmptyRelation: rows=0
2627        "#
2628        )
2629    }
2630
2631    #[test]
2632    fn timestamp_subtract_timestamp() -> Result<()> {
2633        let expr = Expr::BinaryExpr(BinaryExpr::new(
2634            Box::new(cast(
2635                lit("1998-03-18"),
2636                DataType::Timestamp(TimeUnit::Nanosecond, None),
2637            )),
2638            Operator::Minus,
2639            Box::new(cast(
2640                lit("1998-03-18"),
2641                DataType::Timestamp(TimeUnit::Nanosecond, None),
2642            )),
2643        ));
2644        let empty = empty();
2645        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2646
2647        assert_analyzed_plan_eq!(
2648            plan,
2649            @r#"
2650        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
2651          EmptyRelation: rows=0
2652        "#
2653        )
2654    }
2655
2656    #[test]
2657    fn in_subquery_cast_subquery() -> Result<()> {
2658        let empty_int32 = empty_with_type(DataType::Int32);
2659        let empty_int64 = empty_with_type(DataType::Int64);
2660
2661        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2662            Box::new(col("a")),
2663            Subquery {
2664                subquery: empty_int32,
2665                outer_ref_columns: vec![],
2666                spans: Spans::new(),
2667            },
2668            false,
2669        ));
2670        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2671        // add cast for subquery
2672
2673        assert_analyzed_plan_eq!(
2674            plan,
2675            @r"
2676        Filter: a IN (<subquery>)
2677          Subquery:
2678            Projection: CAST(a AS Int64)
2679              EmptyRelation: rows=0
2680          EmptyRelation: rows=0
2681        "
2682        )
2683    }
2684
2685    #[test]
2686    fn in_subquery_cast_expr() -> Result<()> {
2687        let empty_int32 = empty_with_type(DataType::Int32);
2688        let empty_int64 = empty_with_type(DataType::Int64);
2689
2690        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2691            Box::new(col("a")),
2692            Subquery {
2693                subquery: empty_int64,
2694                outer_ref_columns: vec![],
2695                spans: Spans::new(),
2696            },
2697            false,
2698        ));
2699        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2700
2701        // add cast for subquery
2702        assert_analyzed_plan_eq!(
2703            plan,
2704            @r"
2705        Filter: CAST(a AS Int64) IN (<subquery>)
2706          Subquery:
2707            EmptyRelation: rows=0
2708          EmptyRelation: rows=0
2709        "
2710        )
2711    }
2712
2713    #[test]
2714    fn in_subquery_cast_all() -> Result<()> {
2715        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2716        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2717
2718        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2719            Box::new(col("a")),
2720            Subquery {
2721                subquery: empty_inside,
2722                outer_ref_columns: vec![],
2723                spans: Spans::new(),
2724            },
2725            false,
2726        ));
2727        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2728
2729        // add cast for subquery
2730        assert_analyzed_plan_eq!(
2731            plan,
2732            @r"
2733        Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
2734          Subquery:
2735            Projection: CAST(a AS Decimal128(13, 8))
2736              EmptyRelation: rows=0
2737          EmptyRelation: rows=0
2738        "
2739        )
2740    }
2741}