datafusion_optimizer/analyzer/
type_coercion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Optimizer rule for type validation and coercion
19
20use arrow::compute::can_cast_types;
21use datafusion_expr::binary::BinaryTypeCoercer;
22use itertools::{Itertools as _, izip};
23use std::sync::Arc;
24
25use crate::analyzer::AnalyzerRule;
26use crate::utils::NamePreserver;
27
28use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
29use arrow::temporal_conversions::SECONDS_IN_DAY;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
32use datafusion_common::{
33    Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34    exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
35    plan_err,
36};
37use datafusion_expr::expr::{
38    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
39    InSubquery, Like, ScalarFunction, Sort, WindowFunction,
40};
41use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
42use datafusion_expr::expr_schema::cast_subquery;
43use datafusion_expr::logical_plan::Subquery;
44use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
45use datafusion_expr::type_coercion::functions::fields_with_udf;
46use datafusion_expr::type_coercion::other::{
47    get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52    AggregateUDF, Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator,
53    Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
54    is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, lit, not,
55};
56
/// Analyzer rule that validates and coerces expression types.
///
/// For every plan node it works out which input fields are visible and then
/// rewrites the node's expressions (inserting casts where needed) so that the
/// operand types agree.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Build the rule; it carries no configuration of its own.
    pub fn new() -> Self {
        Self::default()
    }
}
67
68/// Coerce output schema based upon optimizer config.
69fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
70    if !config.optimizer.expand_views_at_output {
71        return Ok(plan);
72    }
73
74    let outer_refs = plan.expressions();
75    if outer_refs.is_empty() {
76        return Ok(plan);
77    }
78
79    if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
80        coerce_plan_expr_for_schema(plan, &dfschema?)
81    } else {
82        Ok(plan)
83    }
84}
85
86impl AnalyzerRule for TypeCoercion {
87    fn name(&self) -> &str {
88        "type_coercion"
89    }
90
91    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
92        let empty_schema = DFSchema::empty();
93
94        // recurse
95        let transformed_plan = plan
96            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
97            .data;
98
99        // finish
100        coerce_output(transformed_plan, config)
101    }
102}
103
104/// use the external schema to handle the correlated subqueries case
105///
106/// Assumes that children have already been optimized
107fn analyze_internal(
108    external_schema: &DFSchema,
109    plan: LogicalPlan,
110) -> Result<Transformed<LogicalPlan>> {
111    // get schema representing all available input fields. This is used for data type
112    // resolution only, so order does not matter here
113    let mut schema = merge_schema(&plan.inputs());
114
115    if let LogicalPlan::TableScan(ts) = &plan {
116        let source_schema = DFSchema::try_from_qualified_schema(
117            ts.table_name.clone(),
118            &ts.source.schema(),
119        )?;
120        schema.merge(&source_schema);
121    }
122
123    // merge the outer schema for correlated subqueries
124    // like case:
125    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
126    schema.merge(external_schema);
127
128    // Coerce filter predicates to boolean (handles `WHERE NULL`)
129    let plan = if let LogicalPlan::Filter(mut filter) = plan {
130        filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
131        LogicalPlan::Filter(filter)
132    } else {
133        plan
134    };
135
136    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
137
138    let name_preserver = NamePreserver::new(&plan);
139    // apply coercion rewrite all expressions in the plan individually
140    plan.map_expressions(|expr| {
141        let original_name = name_preserver.save(&expr);
142        expr.rewrite(&mut expr_rewrite)
143            .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
144    })?
145    // some plans need extra coercion after their expressions are coerced
146    .map_data(|plan| expr_rewrite.coerce_plan(plan))?
147    // recompute the schema after the expressions have been rewritten as the types may have changed
148    .map_data(|plan| plan.recompute_schema())
149}
150
151/// Rewrite expressions to apply type coercion.
152pub struct TypeCoercionRewriter<'a> {
153    pub(crate) schema: &'a DFSchema,
154}
155
156impl<'a> TypeCoercionRewriter<'a> {
157    /// Create a new [`TypeCoercionRewriter`] with a provided schema
158    /// representing both the inputs and output of the [`LogicalPlan`] node.
159    pub fn new(schema: &'a DFSchema) -> Self {
160        Self { schema }
161    }
162
163    /// Coerce the [`LogicalPlan`].
164    ///
165    /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`]
166    /// for type-coercion approach.
167    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
168        match plan {
169            LogicalPlan::Join(join) => self.coerce_join(join),
170            LogicalPlan::Union(union) => Self::coerce_union(union),
171            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
172            _ => Ok(plan),
173        }
174    }
175
176    /// Coerce join equality expressions and join filter
177    ///
178    /// Joins must be treated specially as their equality expressions are stored
179    /// as a parallel list of left and right expressions, rather than a single
180    /// equality expression
181    ///
182    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
183    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
184    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
185        join.on = join
186            .on
187            .into_iter()
188            .map(|(lhs, rhs)| {
189                // coerce the arguments as though they were a single binary equality
190                // expression
191                let left_schema = join.left.schema();
192                let right_schema = join.right.schema();
193                let (lhs, rhs) = self.coerce_binary_op(
194                    lhs,
195                    left_schema,
196                    Operator::Eq,
197                    rhs,
198                    right_schema,
199                )?;
200                Ok((lhs, rhs))
201            })
202            .collect::<Result<Vec<_>>>()?;
203
204        // Join filter must be boolean
205        join.filter = join
206            .filter
207            .map(|expr| self.coerce_join_filter(expr))
208            .transpose()?;
209
210        Ok(LogicalPlan::Join(join))
211    }
212
213    /// Coerce the union’s inputs to a common schema compatible with all inputs.
214    /// This occurs after wildcard expansion and the coercion of the input expressions.
215    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
216        let union_schema = Arc::new(coerce_union_schema_with_schema(
217            &union_plan.inputs,
218            &union_plan.schema,
219        )?);
220        let new_inputs = union_plan
221            .inputs
222            .into_iter()
223            .map(|p| {
224                let plan =
225                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
226                match plan {
227                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
228                        Ok(Arc::new(project_with_column_index(
229                            expr,
230                            input,
231                            Arc::clone(&union_schema),
232                        )?))
233                    }
234                    other_plan => Ok(Arc::new(other_plan)),
235                }
236            })
237            .collect::<Result<Vec<_>>>()?;
238        Ok(LogicalPlan::Union(Union {
239            inputs: new_inputs,
240            schema: union_schema,
241        }))
242    }
243
244    /// Coerce the fetch and skip expression to Int64 type.
245    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
246        fn coerce_limit_expr(
247            expr: Expr,
248            schema: &DFSchema,
249            expr_name: &str,
250        ) -> Result<Expr> {
251            let dt = expr.get_type(schema)?;
252            if dt.is_integer() || dt.is_null() {
253                expr.cast_to(&DataType::Int64, schema)
254            } else {
255                plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
256            }
257        }
258
259        let empty_schema = DFSchema::empty();
260        let new_fetch = limit
261            .fetch
262            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
263            .transpose()?;
264        let new_skip = limit
265            .skip
266            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
267            .transpose()?;
268        Ok(LogicalPlan::Limit(Limit {
269            input: limit.input,
270            fetch: new_fetch.map(Box::new),
271            skip: new_skip.map(Box::new),
272        }))
273    }
274
275    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
276        let expr_type = expr.get_type(self.schema)?;
277        match expr_type {
278            DataType::Boolean => Ok(expr),
279            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
280            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
281        }
282    }
283
284    fn coerce_binary_op(
285        &self,
286        left: Expr,
287        left_schema: &DFSchema,
288        op: Operator,
289        right: Expr,
290        right_schema: &DFSchema,
291    ) -> Result<(Expr, Expr)> {
292        let left_data_type = left.get_type(left_schema)?;
293        let right_data_type = right.get_type(right_schema)?;
294        let (left_type, right_type) =
295            BinaryTypeCoercer::new(&left_data_type, &op, &right_data_type)
296                .get_input_types()?;
297        let left_cast_ok = can_cast_types(&left_data_type, &left_type);
298        let right_cast_ok = can_cast_types(&right_data_type, &right_type);
299
300        // handle special cases for
301        // * Date +/- int => Date
302        // * Date + time => Timestamp
303        let left_expr = if !left_cast_ok {
304            Self::coerce_date_time_math_op(
305                left,
306                &op,
307                &left_data_type,
308                &left_type,
309                &right_type,
310            )?
311        } else {
312            left.cast_to(&left_type, left_schema)?
313        };
314
315        let right_expr = if !right_cast_ok {
316            Self::coerce_date_time_math_op(
317                right,
318                &op,
319                &right_data_type,
320                &right_type,
321                &left_type,
322            )?
323        } else {
324            right.cast_to(&right_type, right_schema)?
325        };
326
327        Ok((left_expr, right_expr))
328    }
329
330    fn coerce_date_time_math_op(
331        expr: Expr,
332        op: &Operator,
333        left_current_type: &DataType,
334        left_target_type: &DataType,
335        right_target_type: &DataType,
336    ) -> Result<Expr, DataFusionError> {
337        use DataType::*;
338
339        fn cast(expr: Expr, target_type: DataType) -> Expr {
340            Expr::Cast(Cast::new(Box::new(expr), target_type))
341        }
342
343        fn time_to_nanos(
344            expr: Expr,
345            expr_type: &DataType,
346        ) -> Result<Expr, DataFusionError> {
347            let expr = match expr_type {
348                Time32(TimeUnit::Second) => {
349                    cast(cast(expr, Int32), Int64)
350                        * lit(ScalarValue::Int64(Some(1_000_000_000)))
351                }
352                Time32(TimeUnit::Millisecond) => {
353                    cast(cast(expr, Int32), Int64)
354                        * lit(ScalarValue::Int64(Some(1_000_000)))
355                }
356                Time64(TimeUnit::Microsecond) => {
357                    cast(expr, Int64) * lit(ScalarValue::Int64(Some(1_000)))
358                }
359                Time64(TimeUnit::Nanosecond) => cast(expr, Int64),
360                t => return internal_err!("Unexpected time data type {t}"),
361            };
362
363            Ok(expr)
364        }
365
366        let e = match (
367            &op,
368            &left_current_type,
369            &left_target_type,
370            &right_target_type,
371        ) {
372            // int +/- date => date
373            (
374                Operator::Plus | Operator::Minus,
375                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
376                Interval(IntervalUnit::MonthDayNano),
377                Date32 | Date64,
378            ) => {
379                // cast to i64 first
380                let expr = match *left_current_type {
381                    Int64 => expr,
382                    _ => cast(expr, Int64),
383                };
384                // next, multiply by 86400 to get seconds
385                let expr = expr * lit(ScalarValue::from(SECONDS_IN_DAY));
386                // cast to duration
387                let expr = cast(expr, Duration(TimeUnit::Second));
388                // finally cast to interval
389                cast(expr, Interval(IntervalUnit::MonthDayNano))
390            }
391            // These might seem to be a bit convoluted, however for arrow to do date + time arithmetic
392            // date must be cast to Timestamp(Nanosecond) and time cast to Duration(Nanosecond)
393            // (they must be the same timeunit).
394            //
395            // For Time32/64 we first need to cast to an Int64, convert that to nanoseconds based
396            // on the time unit, then cast that to duration.
397            //
398            // Time + date -> timestamp or
399            (
400                Operator::Plus | Operator::Minus,
401                Time32(_) | Time64(_),
402                Duration(TimeUnit::Nanosecond),
403                Timestamp(TimeUnit::Nanosecond, None),
404            ) => {
405                // cast to int64, convert to nanoseconds
406                let expr = time_to_nanos(expr, left_current_type)?;
407                // cast to duration
408                cast(expr, Duration(TimeUnit::Nanosecond))
409            }
410            // Similar to above, for arrow to do time - time we need to convert to an interval.
411            // To do that we first need to cast to an Int64, convert that to nanoseconds based
412            // on the time unit, then cast that to duration, then finally cast to an interval.
413            //
414            // Time - time -> timestamp
415            (
416                Operator::Plus | Operator::Minus,
417                Time32(_) | Time64(_),
418                Interval(IntervalUnit::MonthDayNano),
419                Interval(IntervalUnit::MonthDayNano),
420            ) => {
421                // cast to int64, convert to nanoseconds
422                let expr = time_to_nanos(expr, left_current_type)?;
423                // cast to duration
424                let expr = cast(expr, Duration(TimeUnit::Nanosecond));
425                // finally cast to interval
426                cast(expr, Interval(IntervalUnit::MonthDayNano))
427            }
428            _ => {
429                return plan_err!(
430                    "Cannot automatically convert {left_current_type} to {left_target_type}"
431                );
432            }
433        };
434
435        Ok(e)
436    }
437}
438
439impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
440    type Node = Expr;
441
442    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
443        match expr {
444            Expr::Unnest(_) => not_impl_err!(
445                "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
446            ),
447            Expr::ScalarSubquery(Subquery {
448                subquery,
449                outer_ref_columns,
450                spans,
451            }) => {
452                let new_plan =
453                    analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
454                Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
455                    subquery: Arc::new(new_plan),
456                    outer_ref_columns,
457                    spans,
458                })))
459            }
460            Expr::Exists(Exists { subquery, negated }) => {
461                let new_plan = analyze_internal(
462                    self.schema,
463                    Arc::unwrap_or_clone(subquery.subquery),
464                )?
465                .data;
466                Ok(Transformed::yes(Expr::Exists(Exists {
467                    subquery: Subquery {
468                        subquery: Arc::new(new_plan),
469                        outer_ref_columns: subquery.outer_ref_columns,
470                        spans: subquery.spans,
471                    },
472                    negated,
473                })))
474            }
475            Expr::InSubquery(InSubquery {
476                expr,
477                subquery,
478                negated,
479            }) => {
480                let new_plan = analyze_internal(
481                    self.schema,
482                    Arc::unwrap_or_clone(subquery.subquery),
483                )?
484                .data;
485                let expr_type = expr.get_type(self.schema)?;
486                let subquery_type = new_plan.schema().field(0).data_type();
487                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
488                    plan_datafusion_err!(
489                    "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
490                ),
491                )?;
492                let new_subquery = Subquery {
493                    subquery: Arc::new(new_plan),
494                    outer_ref_columns: subquery.outer_ref_columns,
495                    spans: subquery.spans,
496                };
497                Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
498                    Box::new(expr.cast_to(&common_type, self.schema)?),
499                    cast_subquery(new_subquery, &common_type)?,
500                    negated,
501                ))))
502            }
503            Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
504                *expr,
505                self.schema,
506            )?))),
507            Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
508                get_casted_expr_for_bool_op(*expr, self.schema)?,
509            ))),
510            Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
511                get_casted_expr_for_bool_op(*expr, self.schema)?,
512            ))),
513            Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
514                get_casted_expr_for_bool_op(*expr, self.schema)?,
515            ))),
516            Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
517                get_casted_expr_for_bool_op(*expr, self.schema)?,
518            ))),
519            Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
520                get_casted_expr_for_bool_op(*expr, self.schema)?,
521            ))),
522            Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
523                get_casted_expr_for_bool_op(*expr, self.schema)?,
524            ))),
525            Expr::Like(Like {
526                negated,
527                expr,
528                pattern,
529                escape_char,
530                case_insensitive,
531            }) => {
532                let left_type = expr.get_type(self.schema)?;
533                let right_type = pattern.get_type(self.schema)?;
534                let coerced_type = like_coercion(&left_type,  &right_type).ok_or_else(|| {
535                    let op_name = if case_insensitive {
536                        "ILIKE"
537                    } else {
538                        "LIKE"
539                    };
540                    plan_datafusion_err!(
541                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
542                    )
543                })?;
544                let expr = match left_type {
545                    DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
546                    _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
547                };
548                let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
549                Ok(Transformed::yes(Expr::Like(Like::new(
550                    negated,
551                    expr,
552                    pattern,
553                    escape_char,
554                    case_insensitive,
555                ))))
556            }
557            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
558                let (left, right) =
559                    self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
560                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
561                    Box::new(left),
562                    op,
563                    Box::new(right),
564                ))))
565            }
566            Expr::Between(Between {
567                expr,
568                negated,
569                low,
570                high,
571            }) => {
572                let expr_type = expr.get_type(self.schema)?;
573                let low_type = low.get_type(self.schema)?;
574                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
575                    .ok_or_else(|| {
576                        internal_datafusion_err!(
577                            "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
578                        )
579                    })?;
580                let high_type = high.get_type(self.schema)?;
581                let high_coerced_type = comparison_coercion(&expr_type, &high_type)
582                    .ok_or_else(|| {
583                        internal_datafusion_err!(
584                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
585                        )
586                    })?;
587                let coercion_type =
588                    comparison_coercion(&low_coerced_type, &high_coerced_type)
589                        .ok_or_else(|| {
590                            internal_datafusion_err!(
591                                "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
592                            )
593                        })?;
594                Ok(Transformed::yes(Expr::Between(Between::new(
595                    Box::new(expr.cast_to(&coercion_type, self.schema)?),
596                    negated,
597                    Box::new(low.cast_to(&coercion_type, self.schema)?),
598                    Box::new(high.cast_to(&coercion_type, self.schema)?),
599                ))))
600            }
601            Expr::InList(InList {
602                expr,
603                list,
604                negated,
605            }) => {
606                let expr_data_type = expr.get_type(self.schema)?;
607                let list_data_types = list
608                    .iter()
609                    .map(|list_expr| list_expr.get_type(self.schema))
610                    .collect::<Result<Vec<_>>>()?;
611                let result_type =
612                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
613                match result_type {
614                    None => plan_err!(
615                        "Can not find compatible types to compare {expr_data_type} with [{}]",
616                        list_data_types.iter().join(", ")
617                    ),
618                    Some(coerced_type) => {
619                        // find the coerced type
620                        let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
621                        let cast_list_expr = list
622                            .into_iter()
623                            .map(|list_expr| {
624                                list_expr.cast_to(&coerced_type, self.schema)
625                            })
626                            .collect::<Result<Vec<_>>>()?;
627                        Ok(Transformed::yes(Expr::InList(InList::new(
628                            Box::new(cast_expr),
629                            cast_list_expr,
630                            negated,
631                        ))))
632                    }
633                }
634            }
635            Expr::Case(case) => {
636                let case = coerce_case_expression(case, self.schema)?;
637                Ok(Transformed::yes(Expr::Case(case)))
638            }
639            Expr::ScalarFunction(ScalarFunction { func, args }) => {
640                let new_expr = coerce_arguments_for_signature_with_scalar_udf(
641                    args,
642                    self.schema,
643                    &func,
644                )?;
645                Ok(Transformed::yes(Expr::ScalarFunction(
646                    ScalarFunction::new_udf(func, new_expr),
647                )))
648            }
649            Expr::AggregateFunction(expr::AggregateFunction {
650                func,
651                params:
652                    AggregateFunctionParams {
653                        args,
654                        distinct,
655                        filter,
656                        order_by,
657                        null_treatment,
658                    },
659            }) => {
660                let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
661                    args,
662                    self.schema,
663                    &func,
664                )?;
665                Ok(Transformed::yes(Expr::AggregateFunction(
666                    expr::AggregateFunction::new_udf(
667                        func,
668                        new_expr,
669                        distinct,
670                        filter,
671                        order_by,
672                        null_treatment,
673                    ),
674                )))
675            }
676            Expr::WindowFunction(window_fun) => {
677                let WindowFunction {
678                    fun,
679                    params:
680                        expr::WindowFunctionParams {
681                            args,
682                            partition_by,
683                            order_by,
684                            window_frame,
685                            filter,
686                            null_treatment,
687                            distinct,
688                        },
689                } = *window_fun;
690                let window_frame =
691                    coerce_window_frame(window_frame, self.schema, &order_by)?;
692
693                let args = match &fun {
694                    expr::WindowFunctionDefinition::AggregateUDF(udf) => {
695                        coerce_arguments_for_signature_with_aggregate_udf(
696                            args,
697                            self.schema,
698                            udf,
699                        )?
700                    }
701                    _ => args,
702                };
703
704                let new_expr = Expr::from(WindowFunction {
705                    fun,
706                    params: expr::WindowFunctionParams {
707                        args,
708                        partition_by,
709                        order_by,
710                        window_frame,
711                        filter,
712                        null_treatment,
713                        distinct,
714                    },
715                });
716                Ok(Transformed::yes(new_expr))
717            }
718            // TODO: remove the next line after `Expr::Wildcard` is removed
719            #[expect(deprecated)]
720            Expr::Alias(_)
721            | Expr::Column(_)
722            | Expr::ScalarVariable(_, _)
723            | Expr::Literal(_, _)
724            | Expr::SimilarTo(_)
725            | Expr::IsNotNull(_)
726            | Expr::IsNull(_)
727            | Expr::Negative(_)
728            | Expr::Cast(_)
729            | Expr::TryCast(_)
730            | Expr::Wildcard { .. }
731            | Expr::GroupingSet(_)
732            | Expr::Placeholder(_)
733            | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
734        }
735    }
736}
737
738/// Transform a schema to use non-view types for Utf8View and BinaryView
739fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
740    let metadata = dfschema.as_arrow().metadata.clone();
741    let mut transformed = false;
742
743    let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
744        dfschema
745            .iter()
746            .map(|(qualifier, field)| match field.data_type() {
747                DataType::Utf8View => {
748                    transformed = true;
749                    (
750                        qualifier.cloned() as Option<TableReference>,
751                        Arc::new(Field::new(
752                            field.name(),
753                            DataType::LargeUtf8,
754                            field.is_nullable(),
755                        )),
756                    )
757                }
758                DataType::BinaryView => {
759                    transformed = true;
760                    (
761                        qualifier.cloned() as Option<TableReference>,
762                        Arc::new(Field::new(
763                            field.name(),
764                            DataType::LargeBinary,
765                            field.is_nullable(),
766                        )),
767                    )
768                }
769                _ => (
770                    qualifier.cloned() as Option<TableReference>,
771                    Arc::clone(field),
772                ),
773            })
774            .unzip();
775
776    if !transformed {
777        return None;
778    }
779
780    let schema = Schema::new_with_metadata(transformed_fields, metadata);
781    Some(DFSchema::from_field_specific_qualified_schema(
782        qualifiers,
783        &Arc::new(schema),
784    ))
785}
786
787/// Casts the given `value` to `target_type`. Note that this function
788/// only considers `Null` or `Utf8` values.
789fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
790    match value {
791        // Coerce Utf8 values:
792        ScalarValue::Utf8(Some(val)) => {
793            ScalarValue::try_from_string(val.clone(), target_type)
794        }
795        s => {
796            if s.is_null() {
797                // Coerce `Null` values:
798                ScalarValue::try_from(target_type)
799            } else {
800                // Values except `Utf8`/`Null` variants already have the right type
801                // (casted before) since we convert `sqlparser` outputs to `Utf8`
802                // for all possible cases. Therefore, we return a clone here.
803                Ok(s.clone())
804            }
805        }
806    }
807}
808
809/// This function coerces `value` to `target_type` in a range-aware fashion.
810/// If the coercion is successful, we return an `Ok` value with the result.
811/// If the coercion fails because `target_type` is not wide enough (i.e. we
812/// can not coerce to `target_type`, but we can to a wider type in the same
813/// family), we return a `Null` value of this type to signal this situation.
814/// Downstream code uses this signal to treat these values as *unbounded*.
815fn coerce_scalar_range_aware(
816    target_type: &DataType,
817    value: &ScalarValue,
818) -> Result<ScalarValue> {
819    coerce_scalar(target_type, value).or_else(|err| {
820        // If type coercion fails, check if the largest type in family works:
821        if let Some(largest_type) = get_widest_type_in_family(target_type) {
822            coerce_scalar(largest_type, value).map_or_else(
823                |_| exec_err!("Cannot cast {value:?} to {target_type}"),
824                |_| ScalarValue::try_from(target_type),
825            )
826        } else {
827            Err(err)
828        }
829    })
830}
831
832/// This function returns the widest type in the family of `given_type`.
833/// If the given type is already the widest type, it returns `None`.
834/// For example, if `given_type` is `Int8`, it returns `Int64`.
835fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
836    match given_type {
837        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
838        DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
839        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
840        _ => None,
841    }
842}
843
844/// Coerces the given (window frame) `bound` to `target_type`.
845fn coerce_frame_bound(
846    target_type: &DataType,
847    bound: WindowFrameBound,
848) -> Result<WindowFrameBound> {
849    match bound {
850        WindowFrameBound::Preceding(v) => {
851            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
852        }
853        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
854        WindowFrameBound::Following(v) => {
855            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
856        }
857    }
858}
859
860fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
861    if col_type.is_numeric()
862        || is_utf8_or_utf8view_or_large_utf8(col_type)
863        || matches!(col_type, DataType::List(_))
864        || matches!(col_type, DataType::LargeList(_))
865        || matches!(col_type, DataType::FixedSizeList(_, _))
866        || matches!(col_type, DataType::Null)
867        || matches!(col_type, DataType::Boolean)
868    {
869        Ok(col_type.clone())
870    } else if is_datetime(col_type) {
871        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
872    } else if let DataType::Dictionary(_, value_type) = col_type {
873        extract_window_frame_target_type(value_type)
874    } else {
875        internal_err!("Cannot run range queries on datatype: {col_type}")
876    }
877}
878
879// Coerces the given `window_frame` to use appropriate natural types.
880// For example, ROWS and GROUPS frames use `UInt64` during calculations.
881fn coerce_window_frame(
882    window_frame: WindowFrame,
883    schema: &DFSchema,
884    expressions: &[Sort],
885) -> Result<WindowFrame> {
886    let mut window_frame = window_frame;
887    let target_type = match window_frame.units {
888        WindowFrameUnits::Range => {
889            let current_types = expressions
890                .first()
891                .map(|s| s.expr.get_type(schema))
892                .transpose()?;
893            if let Some(col_type) = current_types {
894                extract_window_frame_target_type(&col_type)?
895            } else {
896                return internal_err!("ORDER BY column cannot be empty");
897            }
898        }
899        WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
900    };
901    window_frame.start_bound =
902        coerce_frame_bound(&target_type, window_frame.start_bound)?;
903    window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
904    Ok(window_frame)
905}
906
907// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
908// The above op will be rewrite to the binary op when creating the physical op.
909fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
910    let left_type = expr.get_type(schema)?;
911    BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
912        .get_input_types()?;
913    expr.cast_to(&DataType::Boolean, schema)
914}
915
916/// Returns `expressions` coerced to types compatible with
917/// `signature`, if possible.
918///
919/// See the module level documentation for more detail on coercion.
920fn coerce_arguments_for_signature_with_scalar_udf(
921    expressions: Vec<Expr>,
922    schema: &DFSchema,
923    func: &ScalarUDF,
924) -> Result<Vec<Expr>> {
925    if expressions.is_empty() {
926        return Ok(expressions);
927    }
928
929    let current_fields = expressions
930        .iter()
931        .map(|e| e.to_field(schema).map(|(_, f)| f))
932        .collect::<Result<Vec<_>>>()?;
933
934    let coerced_types = fields_with_udf(&current_fields, func)?
935        .into_iter()
936        .map(|f| f.data_type().clone())
937        .collect::<Vec<_>>();
938
939    expressions
940        .into_iter()
941        .enumerate()
942        .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
943        .collect()
944}
945
946/// Returns `expressions` coerced to types compatible with
947/// `signature`, if possible.
948///
949/// See the module level documentation for more detail on coercion.
950fn coerce_arguments_for_signature_with_aggregate_udf(
951    expressions: Vec<Expr>,
952    schema: &DFSchema,
953    func: &AggregateUDF,
954) -> Result<Vec<Expr>> {
955    if expressions.is_empty() {
956        return Ok(expressions);
957    }
958
959    let current_fields = expressions
960        .iter()
961        .map(|e| e.to_field(schema).map(|(_, f)| f))
962        .collect::<Result<Vec<_>>>()?;
963
964    let coerced_types = fields_with_udf(&current_fields, func)?
965        .into_iter()
966        .map(|f| f.data_type().clone())
967        .collect::<Vec<_>>();
968
969    expressions
970        .into_iter()
971        .enumerate()
972        .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
973        .collect()
974}
975
976fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
977    // Given expressions like:
978    //
979    // CASE a1
980    //   WHEN a2 THEN b1
981    //   WHEN a3 THEN b2
982    //   ELSE b3
983    // END
984    //
985    // or:
986    //
987    // CASE
988    //   WHEN x1 THEN b1
989    //   WHEN x2 THEN b2
990    //   ELSE b3
991    // END
992    //
993    // Then all aN (a1, a2, a3) must be converted to a common data type in the first example
994    // (case-when expression coercion)
995    //
996    // All xN (x1, x2) must be converted to a boolean data type in the second example
997    // (when-boolean expression coercion)
998    //
999    // And all bN (b1, b2, b3) must be converted to a common data type in both examples
1000    // (then-else expression coercion)
1001    //
1002    // If any fail to find and cast to a common/specific data type, will return error
1003    //
1004    // Note that case-when and when-boolean expression coercions are mutually exclusive
1005    // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur
1006
1007    // prepare types
1008    let case_type = case
1009        .expr
1010        .as_ref()
1011        .map(|expr| expr.get_type(schema))
1012        .transpose()?;
1013    let then_types = case
1014        .when_then_expr
1015        .iter()
1016        .map(|(_when, then)| then.get_type(schema))
1017        .collect::<Result<Vec<_>>>()?;
1018    let else_type = case
1019        .else_expr
1020        .as_ref()
1021        .map(|expr| expr.get_type(schema))
1022        .transpose()?;
1023
1024    // find common coercible types
1025    let case_when_coerce_type = case_type
1026        .as_ref()
1027        .map(|case_type| {
1028            let when_types = case
1029                .when_then_expr
1030                .iter()
1031                .map(|(when, _then)| when.get_type(schema))
1032                .collect::<Result<Vec<_>>>()?;
1033            let coerced_type =
1034                get_coerce_type_for_case_expression(&when_types, Some(case_type));
1035            coerced_type.ok_or_else(|| {
1036                plan_datafusion_err!(
1037                    "Failed to coerce case ({case_type}) and when ({}) \
1038                     to common types in CASE WHEN expression",
1039                    when_types.iter().join(", ")
1040                )
1041            })
1042        })
1043        .transpose()?;
1044    let then_else_coerce_type =
1045        get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
1046            || {
1047                if let Some(else_type) = else_type {
1048                    plan_datafusion_err!(
1049                        "Failed to coerce then ({}) and else ({else_type}) \
1050                         to common types in CASE WHEN expression",
1051                        then_types.iter().join(", ")
1052                    )
1053                } else {
1054                    plan_datafusion_err!(
1055                        "Failed to coerce then ({}) and else (None) \
1056                         to common types in CASE WHEN expression",
1057                        then_types.iter().join(", ")
1058                    )
1059                }
1060            },
1061        )?;
1062
1063    // do cast if found common coercible types
1064    let case_expr = case
1065        .expr
1066        .zip(case_when_coerce_type.as_ref())
1067        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
1068        .transpose()?
1069        .map(Box::new);
1070    let when_then = case
1071        .when_then_expr
1072        .into_iter()
1073        .map(|(when, then)| {
1074            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
1075            let when = when.cast_to(when_type, schema).map_err(|e| {
1076                DataFusionError::Context(
1077                    format!(
1078                        "WHEN expressions in CASE couldn't be \
1079                         converted to common type ({when_type})"
1080                    ),
1081                    Box::new(e),
1082                )
1083            })?;
1084            let then = then.cast_to(&then_else_coerce_type, schema)?;
1085            Ok((Box::new(when), Box::new(then)))
1086        })
1087        .collect::<Result<Vec<_>>>()?;
1088    let else_expr = case
1089        .else_expr
1090        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
1091        .transpose()?
1092        .map(Box::new);
1093
1094    Ok(Case::new(case_expr, when_then, else_expr))
1095}
1096
1097/// Get a common schema that is compatible with all inputs of UNION.
1098///
1099/// This method presumes that the wildcard expansion is unneeded, or has already
1100/// been applied.
1101///
1102/// ## Schema and Field Handling in Union Coercion
1103///
1104/// **Processing order**: The function starts with the base schema (first input) and then
1105/// processes remaining inputs sequentially, with later inputs taking precedence in merging.
1106///
1107/// **Schema-level metadata merging**: Later schemas take precedence for duplicate keys.
1108///
1109/// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys.
1110///
1111/// **Type coercion precedence**: The coerced type is determined by iteratively applying
1112/// `comparison_coercion()` between the accumulated type and each new input's type. The
1113/// result depends on type coercion rules, not input order.
1114///
1115/// **Nullability merging**: Nullability is accumulated using logical OR (`||`).
1116/// Once any input field is nullable, the result field becomes nullable permanently.
1117/// Later inputs can make a field nullable but cannot make it non-nullable.
1118///
1119/// **Field precedence**: Field names come from the first (base) schema, but the field properties
1120/// (nullability and field-level metadata) have later schemas taking precedence.
1121///
1122/// **Example**:
1123/// ```sql
1124/// SELECT a, b FROM table1  -- a: Int32, metadata {"source": "t1"}, nullable=false
1125/// UNION
1126/// SELECT a, b FROM table2  -- a: Int64, metadata {"source": "t2"}, nullable=true
1127/// UNION
1128/// SELECT a, b FROM table3  -- a: Int32, metadata {"encoding": "utf8"}, nullable=false
1129/// -- Result:
1130/// -- a: Int64 (from type coercion), nullable=true (from table2),
1131/// -- metadata: {"source": "t2", "encoding": "utf8"} (later inputs take precedence)
1132/// ```
1133///
1134/// **Precedence Summary**:
1135/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order
1136/// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR)
1137/// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics)
1138pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1139    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1140}
1141fn coerce_union_schema_with_schema(
1142    inputs: &[Arc<LogicalPlan>],
1143    base_schema: &DFSchemaRef,
1144) -> Result<DFSchema> {
1145    let mut union_datatypes = base_schema
1146        .fields()
1147        .iter()
1148        .map(|f| f.data_type().clone())
1149        .collect::<Vec<_>>();
1150    let mut union_nullabilities = base_schema
1151        .fields()
1152        .iter()
1153        .map(|f| f.is_nullable())
1154        .collect::<Vec<_>>();
1155    let mut union_field_meta = base_schema
1156        .fields()
1157        .iter()
1158        .map(|f| f.metadata().clone())
1159        .collect::<Vec<_>>();
1160
1161    let mut metadata = base_schema.metadata().clone();
1162
1163    for (i, plan) in inputs.iter().enumerate() {
1164        let plan_schema = plan.schema();
1165        metadata.extend(plan_schema.metadata().clone());
1166
1167        if plan_schema.fields().len() != base_schema.fields().len() {
1168            return plan_err!(
1169                "Union schemas have different number of fields: \
1170                query 1 has {} fields whereas query {} has {} fields",
1171                base_schema.fields().len(),
1172                i + 1,
1173                plan_schema.fields().len()
1174            );
1175        }
1176
1177        // coerce data type and nullability for each field
1178        for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1179            union_datatypes.iter_mut(),
1180            union_nullabilities.iter_mut(),
1181            union_field_meta.iter_mut(),
1182            plan_schema.fields().iter()
1183        ) {
1184            let coerced_type =
1185                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1186                    || {
1187                        plan_datafusion_err!(
1188                            "Incompatible inputs for Union: Previous inputs were \
1189                            of type {}, but got incompatible type {} on column '{}'",
1190                            union_datatype,
1191                            plan_field.data_type(),
1192                            plan_field.name()
1193                        )
1194                    },
1195                )?;
1196
1197            *union_datatype = coerced_type;
1198            *union_nullable = *union_nullable || plan_field.is_nullable();
1199            union_field_map.extend(plan_field.metadata().clone());
1200        }
1201    }
1202    let union_qualified_fields = izip!(
1203        base_schema.fields(),
1204        union_datatypes.into_iter(),
1205        union_nullabilities,
1206        union_field_meta.into_iter()
1207    )
1208    .map(|(field, datatype, nullable, metadata)| {
1209        let mut field = Field::new(field.name().clone(), datatype, nullable);
1210        field.set_metadata(metadata);
1211        (None, field.into())
1212    })
1213    .collect::<Vec<_>>();
1214
1215    DFSchema::new_with_metadata(union_qualified_fields, metadata)
1216}
1217
1218/// See `<https://github.com/apache/datafusion/pull/2108>`
1219fn project_with_column_index(
1220    expr: Vec<Expr>,
1221    input: Arc<LogicalPlan>,
1222    schema: DFSchemaRef,
1223) -> Result<LogicalPlan> {
1224    let alias_expr = expr
1225        .into_iter()
1226        .enumerate()
1227        .map(|(i, e)| match e {
1228            Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1229                Ok(e.unalias().alias(schema.field(i).name()))
1230            }
1231            Expr::Column(Column {
1232                relation: _,
1233                ref name,
1234                spans: _,
1235            }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1236            Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1237            #[expect(deprecated)]
1238            Expr::Wildcard { .. } => {
1239                plan_err!("Wildcard should be expanded before type coercion")
1240            }
1241            _ => Ok(e.alias(schema.field(i).name())),
1242        })
1243        .collect::<Result<Vec<_>>>()?;
1244
1245    Projection::try_new_with_schema(alias_expr, input, schema)
1246        .map(LogicalPlan::Projection)
1247}
1248
1249#[cfg(test)]
1250mod test {
1251    use std::any::Any;
1252    use std::sync::Arc;
1253
1254    use arrow::datatypes::DataType::Utf8;
1255    use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1256    use insta::assert_snapshot;
1257
1258    use crate::analyzer::Analyzer;
1259    use crate::analyzer::type_coercion::{
1260        TypeCoercion, TypeCoercionRewriter, coerce_case_expression,
1261    };
1262    use crate::assert_analyzed_plan_with_config_eq_snapshot;
1263    use datafusion_common::config::ConfigOptions;
1264    use datafusion_common::tree_node::{TransformedResult, TreeNode};
1265    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1266    use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1267    use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1268    use datafusion_expr::test::function_stub::avg_udaf;
1269    use datafusion_expr::{
1270        AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr,
1271        ExprSchemable, Filter, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF,
1272        ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Union, Volatility, cast,
1273        col, create_udaf, is_true, lit,
1274    };
1275    use datafusion_functions_aggregate::average::AvgAccumulator;
1276    use datafusion_sql::TableReference;
1277
1278    fn empty() -> Arc<LogicalPlan> {
1279        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1280            produce_one_row: false,
1281            schema: Arc::new(DFSchema::empty()),
1282        }))
1283    }
1284
1285    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1286        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1287            produce_one_row: false,
1288            schema: Arc::new(
1289                DFSchema::from_unqualified_fields(
1290                    vec![Field::new("a", data_type, true)].into(),
1291                    std::collections::HashMap::new(),
1292                )
1293                .unwrap(),
1294            ),
1295        }))
1296    }
1297
    // Runs the `TypeCoercion` analyzer rule with default `ConfigOptions` on
    // `$plan` and compares the analyzed plan with the inline snapshot
    // `$expected`. Expands to the result of
    // `assert_analyzed_plan_with_config_eq_snapshot!`.
    macro_rules! assert_analyzed_plan_eq {
        (
            $plan: expr,
            @ $expected: literal $(,)?
        ) => {{
            let options = ConfigOptions::default();
            let rule = Arc::new(TypeCoercion::new());
            assert_analyzed_plan_with_config_eq_snapshot!(
                options,
                rule,
                $plan,
                @ $expected,
            )
            }};
    }
1313
    // Same as `assert_analyzed_plan_eq!`, except that when `$is_viewtype` is
    // true it enables `optimizer.expand_views_at_output` before running the
    // `TypeCoercion` rule on `$plan`.
    macro_rules! coerce_on_output_if_viewtype {
        (
            $is_viewtype: expr,
            $plan: expr,
            @ $expected: literal $(,)?
        ) => {{
            let mut options = ConfigOptions::default();
            // coerce on output
            if $is_viewtype {options.optimizer.expand_views_at_output = true;}
            let rule = Arc::new(TypeCoercion::new());

            assert_analyzed_plan_with_config_eq_snapshot!(
                options,
                rule,
                $plan,
                @ $expected,
            )
        }};
    }
1333
1334    fn assert_type_coercion_error(
1335        plan: LogicalPlan,
1336        expected_substr: &str,
1337    ) -> Result<()> {
1338        let options = ConfigOptions::default();
1339        let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1340
1341        match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1342            Ok(succeeded_plan) => {
1343                panic!(
1344                    "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1345                );
1346            }
1347            Err(e) => {
1348                let msg = e.to_string();
1349                assert!(
1350                    msg.contains(expected_substr),
1351                    "Error did not contain expected substring.\n  expected to find: `{expected_substr}`\n  actual error: `{msg}`"
1352                );
1353            }
1354        }
1355
1356        Ok(())
1357    }
1358
    /// A `Float64` column compared with a `UInt32` literal: the analyzer casts
    /// the literal side to `Float64`.
    #[test]
    fn simple_case() -> Result<()> {
        let expr = col("a").lt(lit(2_u32));
        let empty = empty_with_type(DataType::Float64);
        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);

        assert_analyzed_plan_eq!(
            plan,
            @r"
        Projection: a < CAST(UInt32(2) AS Float64)
          EmptyRelation: rows=0
        "
        )
    }
1373
    /// A UNION of an `Int32` input and an `Int64` input: after type coercion
    /// the `Int32` side is wrapped in a projection casting `a` to `Int64`.
    #[test]
    fn test_coerce_union() -> Result<()> {
        let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(
                DFSchema::try_from_qualified_schema(
                    TableReference::full("datafusion", "test", "foo"),
                    &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
                )
                .unwrap(),
            ),
        }));
        let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(
                DFSchema::try_from_qualified_schema(
                    TableReference::full("datafusion", "test", "foo"),
                    &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
                )
                .unwrap(),
            ),
        }));
        let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
            left_plan, right_plan,
        ])?);
        // Coerce the union itself first, then analyze a projection on top of it.
        let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
            .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
        let top_level_plan = LogicalPlan::Projection(Projection::try_new(
            vec![col("a")],
            Arc::new(analyzed_union),
        )?);

        assert_analyzed_plan_eq!(
            top_level_plan,
            @r"
        Projection: a
          Union
            Projection: CAST(datafusion.test.foo.a AS Int64) AS a
              EmptyRelation: rows=0
            EmptyRelation: rows=0
        "
        )
    }
1417
1418    #[test]
1419    fn coerce_utf8view_output() -> Result<()> {
1420        // Plan A
1421        // scenario: outermost utf8view projection
1422        let expr = col("a");
1423        let empty = empty_with_type(DataType::Utf8View);
1424        let plan = LogicalPlan::Projection(Projection::try_new(
1425            vec![expr.clone()],
1426            Arc::clone(&empty),
1427        )?);
1428
1429        // Plan A: no coerce
1430        coerce_on_output_if_viewtype!(
1431            false,
1432            plan.clone(),
1433            @r"
1434        Projection: a
1435          EmptyRelation: rows=0
1436        "
1437        )?;
1438
1439        // Plan A: coerce requested: Utf8View => LargeUtf8
1440        coerce_on_output_if_viewtype!(
1441            true,
1442            plan.clone(),
1443            @r"
1444        Projection: CAST(a AS LargeUtf8) AS a
1445          EmptyRelation: rows=0
1446        "
1447        )?;
1448
1449        // Plan B
1450        // scenario: outermost bool projection
1451        let bool_expr = col("a").lt(lit("foo"));
1452        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1453            vec![bool_expr],
1454            Arc::clone(&empty),
1455        )?);
1456        // Plan B: no coerce
1457        coerce_on_output_if_viewtype!(
1458            false,
1459            bool_plan.clone(),
1460            @r#"
1461        Projection: a < CAST(Utf8("foo") AS Utf8View)
1462          EmptyRelation: rows=0
1463        "#
1464        )?;
1465
1466        coerce_on_output_if_viewtype!(
1467            false,
1468            plan.clone(),
1469            @r"
1470        Projection: a
1471          EmptyRelation: rows=0
1472        "
1473        )?;
1474
1475        // Plan B: coerce requested: no coercion applied
1476        coerce_on_output_if_viewtype!(
1477            true,
1478            plan.clone(),
1479            @r"
1480        Projection: CAST(a AS LargeUtf8) AS a
1481          EmptyRelation: rows=0
1482        "
1483        )?;
1484
1485        // Plan C
1486        // scenario: with a non-projection root logical plan node
1487        let sort_expr = expr.sort(true, true);
1488        let sort_plan = LogicalPlan::Sort(Sort {
1489            expr: vec![sort_expr],
1490            input: Arc::new(plan),
1491            fetch: None,
1492        });
1493
1494        // Plan C: no coerce
1495        coerce_on_output_if_viewtype!(
1496            false,
1497            sort_plan.clone(),
1498            @r"
1499        Sort: a ASC NULLS FIRST
1500          Projection: a
1501            EmptyRelation: rows=0
1502        "
1503        )?;
1504
1505        // Plan C: coerce requested: Utf8View => LargeUtf8
1506        coerce_on_output_if_viewtype!(
1507            true,
1508            sort_plan.clone(),
1509            @r"
1510        Projection: CAST(a AS LargeUtf8) AS a
1511          Sort: a ASC NULLS FIRST
1512            Projection: a
1513              EmptyRelation: rows=0
1514        "
1515        )?;
1516
1517        // Plan D
1518        // scenario: two layers of projections with view types
1519        let plan = LogicalPlan::Projection(Projection::try_new(
1520            vec![col("a")],
1521            Arc::new(sort_plan),
1522        )?);
1523        // Plan D: no coerce
1524        coerce_on_output_if_viewtype!(
1525            false,
1526            plan.clone(),
1527            @r"
1528        Projection: a
1529          Sort: a ASC NULLS FIRST
1530            Projection: a
1531              EmptyRelation: rows=0
1532        "
1533        )?;
1534        // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost
1535        coerce_on_output_if_viewtype!(
1536            true,
1537            plan.clone(),
1538            @r"
1539        Projection: CAST(a AS LargeUtf8) AS a
1540          Sort: a ASC NULLS FIRST
1541            Projection: a
1542              EmptyRelation: rows=0
1543        "
1544        )?;
1545
1546        Ok(())
1547    }
1548
1549    #[test]
1550    fn coerce_binaryview_output() -> Result<()> {
1551        // Plan A
1552        // scenario: outermost binaryview projection
1553        let expr = col("a");
1554        let empty = empty_with_type(DataType::BinaryView);
1555        let plan = LogicalPlan::Projection(Projection::try_new(
1556            vec![expr.clone()],
1557            Arc::clone(&empty),
1558        )?);
1559
1560        // Plan A: no coerce
1561        coerce_on_output_if_viewtype!(
1562            false,
1563            plan.clone(),
1564            @r"
1565        Projection: a
1566          EmptyRelation: rows=0
1567        "
1568        )?;
1569
1570        // Plan A: coerce requested: BinaryView => LargeBinary
1571        coerce_on_output_if_viewtype!(
1572            true,
1573            plan.clone(),
1574            @r"
1575        Projection: CAST(a AS LargeBinary) AS a
1576          EmptyRelation: rows=0
1577        "
1578        )?;
1579
1580        // Plan B
1581        // scenario: outermost bool projection
1582        let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1583        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1584            vec![bool_expr],
1585            Arc::clone(&empty),
1586        )?);
1587
1588        // Plan B: no coerce
1589        coerce_on_output_if_viewtype!(
1590            false,
1591            bool_plan.clone(),
1592            @r#"
1593        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1594          EmptyRelation: rows=0
1595        "#
1596        )?;
1597
1598        // Plan B: coerce requested: no coercion applied
1599        coerce_on_output_if_viewtype!(
1600            true,
1601            bool_plan.clone(),
1602            @r#"
1603        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1604          EmptyRelation: rows=0
1605        "#
1606        )?;
1607
1608        // Plan C
1609        // scenario: with a non-projection root logical plan node
1610        let sort_expr = expr.sort(true, true);
1611        let sort_plan = LogicalPlan::Sort(Sort {
1612            expr: vec![sort_expr],
1613            input: Arc::new(plan),
1614            fetch: None,
1615        });
1616
1617        // Plan C: no coerce
1618        coerce_on_output_if_viewtype!(
1619            false,
1620            sort_plan.clone(),
1621            @r"
1622        Sort: a ASC NULLS FIRST
1623          Projection: a
1624            EmptyRelation: rows=0
1625        "
1626        )?;
1627        // Plan C: coerce requested: BinaryView => LargeBinary
1628        coerce_on_output_if_viewtype!(
1629            true,
1630            sort_plan.clone(),
1631            @r"
1632        Projection: CAST(a AS LargeBinary) AS a
1633          Sort: a ASC NULLS FIRST
1634            Projection: a
1635              EmptyRelation: rows=0
1636        "
1637        )?;
1638
1639        // Plan D
1640        // scenario: two layers of projections with view types
1641        let plan = LogicalPlan::Projection(Projection::try_new(
1642            vec![col("a")],
1643            Arc::new(sort_plan),
1644        )?);
1645
1646        // Plan D: no coerce
1647        coerce_on_output_if_viewtype!(
1648            false,
1649            plan.clone(),
1650            @r"
1651        Projection: a
1652          Sort: a ASC NULLS FIRST
1653            Projection: a
1654              EmptyRelation: rows=0
1655        "
1656        )?;
1657
1658        // Plan B: coerce requested: BinaryView => LargeBinary only on outermost
1659        coerce_on_output_if_viewtype!(
1660            true,
1661            plan.clone(),
1662            @r"
1663        Projection: CAST(a AS LargeBinary) AS a
1664          Sort: a ASC NULLS FIRST
1665            Projection: a
1666              EmptyRelation: rows=0
1667        "
1668        )?;
1669
1670        Ok(())
1671    }
1672
1673    #[test]
1674    fn nested_case() -> Result<()> {
1675        let expr = col("a").lt(lit(2_u32));
1676        let empty = empty_with_type(DataType::Float64);
1677
1678        let plan = LogicalPlan::Projection(Projection::try_new(
1679            vec![expr.clone().or(expr)],
1680            empty,
1681        )?);
1682
1683        assert_analyzed_plan_eq!(
1684            plan,
1685            @r"
1686        Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1687          EmptyRelation: rows=0
1688        "
1689        )
1690    }
1691
1692    #[derive(Debug, PartialEq, Eq, Hash)]
1693    struct TestScalarUDF {
1694        signature: Signature,
1695    }
1696
1697    impl ScalarUDFImpl for TestScalarUDF {
1698        fn as_any(&self) -> &dyn Any {
1699            self
1700        }
1701
1702        fn name(&self) -> &str {
1703            "TestScalarUDF"
1704        }
1705
1706        fn signature(&self) -> &Signature {
1707            &self.signature
1708        }
1709
1710        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1711            Ok(Utf8)
1712        }
1713
1714        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1715            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1716        }
1717    }
1718
1719    #[test]
1720    fn scalar_udf() -> Result<()> {
1721        let empty = empty();
1722
1723        let udf = ScalarUDF::from(TestScalarUDF {
1724            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1725        })
1726        .call(vec![lit(123_i32)]);
1727        let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1728
1729        assert_analyzed_plan_eq!(
1730            plan,
1731            @r"
1732        Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1733          EmptyRelation: rows=0
1734        "
1735        )
1736    }
1737
1738    #[test]
1739    fn scalar_udf_invalid_input() -> Result<()> {
1740        let empty = empty();
1741        let udf = ScalarUDF::from(TestScalarUDF {
1742            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1743        })
1744        .call(vec![lit("Apple")]);
1745        Projection::try_new(vec![udf], empty)
1746            .expect_err("Expected an error due to incorrect function input");
1747
1748        Ok(())
1749    }
1750
1751    #[test]
1752    fn scalar_function() -> Result<()> {
1753        // test that automatic argument type coercion for scalar functions work
1754        let empty = empty();
1755        let lit_expr = lit(10i64);
1756        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1757            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1758        });
1759        let scalar_function_expr =
1760            Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1761        let plan = LogicalPlan::Projection(Projection::try_new(
1762            vec![scalar_function_expr],
1763            empty,
1764        )?);
1765
1766        assert_analyzed_plan_eq!(
1767            plan,
1768            @r"
1769        Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1770          EmptyRelation: rows=0
1771        "
1772        )
1773    }
1774
1775    #[test]
1776    fn agg_udaf() -> Result<()> {
1777        let empty = empty();
1778        let my_avg = create_udaf(
1779            "MY_AVG",
1780            vec![DataType::Float64],
1781            Arc::new(DataType::Float64),
1782            Volatility::Immutable,
1783            Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1784            Arc::new(vec![DataType::UInt64, DataType::Float64]),
1785        );
1786        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1787            Arc::new(my_avg),
1788            vec![lit(10i64)],
1789            false,
1790            None,
1791            vec![],
1792            None,
1793        ));
1794        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1795
1796        assert_analyzed_plan_eq!(
1797            plan,
1798            @r"
1799        Projection: MY_AVG(CAST(Int64(10) AS Float64))
1800          EmptyRelation: rows=0
1801        "
1802        )
1803    }
1804
1805    #[test]
1806    fn agg_udaf_invalid_input() -> Result<()> {
1807        let empty = empty();
1808        let return_type = DataType::Float64;
1809        let accumulator: AccumulatorFactoryFunction =
1810            Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1811        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1812            "MY_AVG",
1813            Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1814            return_type,
1815            accumulator,
1816            vec![
1817                Field::new("count", DataType::UInt64, true).into(),
1818                Field::new("avg", DataType::Float64, true).into(),
1819            ],
1820        ));
1821        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1822            Arc::new(my_avg),
1823            vec![lit("10")],
1824            false,
1825            None,
1826            vec![],
1827            None,
1828        ));
1829
1830        let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1831        assert!(
1832            err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed")
1833        );
1834        Ok(())
1835    }
1836
1837    #[test]
1838    fn agg_function_case() -> Result<()> {
1839        let empty = empty();
1840        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1841            avg_udaf(),
1842            vec![lit(12f64)],
1843            false,
1844            None,
1845            vec![],
1846            None,
1847        ));
1848        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1849
1850        assert_analyzed_plan_eq!(
1851            plan,
1852            @r"
1853        Projection: avg(Float64(12))
1854          EmptyRelation: rows=0
1855        "
1856        )?;
1857
1858        let empty = empty_with_type(DataType::Int32);
1859        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1860            avg_udaf(),
1861            vec![cast(col("a"), DataType::Float64)],
1862            false,
1863            None,
1864            vec![],
1865            None,
1866        ));
1867        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1868
1869        assert_analyzed_plan_eq!(
1870            plan,
1871            @r"
1872        Projection: avg(CAST(a AS Float64))
1873          EmptyRelation: rows=0
1874        "
1875        )
1876    }
1877
1878    #[test]
1879    fn agg_function_invalid_input_avg() -> Result<()> {
1880        let empty = empty();
1881        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1882            avg_udaf(),
1883            vec![lit("1")],
1884            false,
1885            None,
1886            vec![],
1887            None,
1888        ));
1889        let err = Projection::try_new(vec![agg_expr], empty)
1890            .err()
1891            .unwrap()
1892            .strip_backtrace();
1893        assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1894        Ok(())
1895    }
1896
1897    #[test]
1898    fn binary_op_date32_op_interval() -> Result<()> {
1899        // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
1900        let expr = cast(lit("1998-03-18"), DataType::Date32)
1901            + lit(ScalarValue::new_interval_dt(123, 456));
1902        let empty = empty();
1903        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1904
1905        assert_analyzed_plan_eq!(
1906            plan,
1907            @r#"
1908        Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1909          EmptyRelation: rows=0
1910        "#
1911        )
1912    }
1913
1914    #[test]
1915    fn inlist_case() -> Result<()> {
1916        // a in (1,4,8), a is int64
1917        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1918        let empty = empty_with_type(DataType::Int64);
1919        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1920        assert_analyzed_plan_eq!(
1921            plan,
1922            @r"
1923        Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1924          EmptyRelation: rows=0
1925        ")?;
1926
1927        // a in (1,4,8), a is decimal
1928        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1929        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1930            produce_one_row: false,
1931            schema: Arc::new(DFSchema::from_unqualified_fields(
1932                vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1933                std::collections::HashMap::new(),
1934            )?),
1935        }));
1936        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1937        assert_analyzed_plan_eq!(
1938            plan,
1939            @r"
1940        Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1941          EmptyRelation: rows=0
1942        ")
1943    }
1944
1945    #[test]
1946    fn between_case() -> Result<()> {
1947        let expr = col("a").between(
1948            lit("2002-05-08"),
1949            // (cast('2002-05-08' as date) + interval '1 months')
1950            cast(lit("2002-05-08"), DataType::Date32)
1951                + lit(ScalarValue::new_interval_ym(0, 1)),
1952        );
1953        let empty = empty_with_type(Utf8);
1954        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1955
1956        assert_analyzed_plan_eq!(
1957            plan,
1958            @r#"
1959        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1960          EmptyRelation: rows=0
1961        "#
1962        )
1963    }
1964
1965    #[test]
1966    fn between_infer_cheap_type() -> Result<()> {
1967        let expr = col("a").between(
1968            // (cast('2002-05-08' as date) + interval '1 months')
1969            cast(lit("2002-05-08"), DataType::Date32)
1970                + lit(ScalarValue::new_interval_ym(0, 1)),
1971            lit("2002-12-08"),
1972        );
1973        let empty = empty_with_type(Utf8);
1974        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1975
1976        // TODO: we should cast col(a).
1977        assert_analyzed_plan_eq!(
1978            plan,
1979            @r#"
1980        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1981          EmptyRelation: rows=0
1982        "#
1983        )
1984    }
1985
1986    #[test]
1987    fn between_null() -> Result<()> {
1988        let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1989        let empty = empty();
1990        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1991
1992        assert_analyzed_plan_eq!(
1993            plan,
1994            @r"
1995        Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1996          EmptyRelation: rows=0
1997        "
1998        )
1999    }
2000
2001    #[test]
2002    fn is_bool_for_type_coercion() -> Result<()> {
2003        // is true
2004        let expr = col("a").is_true();
2005        let empty = empty_with_type(DataType::Boolean);
2006        let plan =
2007            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2008
2009        assert_analyzed_plan_eq!(
2010            plan,
2011            @r"
2012        Projection: a IS TRUE
2013          EmptyRelation: rows=0
2014        "
2015        )?;
2016
2017        let empty = empty_with_type(DataType::Int64);
2018        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2019        assert_type_coercion_error(
2020            plan,
2021            "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean",
2022        )?;
2023
2024        // is not true
2025        let expr = col("a").is_not_true();
2026        let empty = empty_with_type(DataType::Boolean);
2027        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2028
2029        assert_analyzed_plan_eq!(
2030            plan,
2031            @r"
2032        Projection: a IS NOT TRUE
2033          EmptyRelation: rows=0
2034        "
2035        )?;
2036
2037        // is false
2038        let expr = col("a").is_false();
2039        let empty = empty_with_type(DataType::Boolean);
2040        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2041
2042        assert_analyzed_plan_eq!(
2043            plan,
2044            @r"
2045        Projection: a IS FALSE
2046          EmptyRelation: rows=0
2047        "
2048        )?;
2049
2050        // is not false
2051        let expr = col("a").is_not_false();
2052        let empty = empty_with_type(DataType::Boolean);
2053        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2054
2055        assert_analyzed_plan_eq!(
2056            plan,
2057            @r"
2058        Projection: a IS NOT FALSE
2059          EmptyRelation: rows=0
2060        "
2061        )
2062    }
2063
2064    #[test]
2065    fn like_for_type_coercion() -> Result<()> {
2066        // like : utf8 like "abc"
2067        let expr = Box::new(col("a"));
2068        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2069        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2070        let empty = empty_with_type(Utf8);
2071        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2072
2073        assert_analyzed_plan_eq!(
2074            plan,
2075            @r#"
2076        Projection: a LIKE Utf8("abc")
2077          EmptyRelation: rows=0
2078        "#
2079        )?;
2080
2081        let expr = Box::new(col("a"));
2082        let pattern = Box::new(lit(ScalarValue::Null));
2083        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2084        let empty = empty_with_type(Utf8);
2085        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2086
2087        assert_analyzed_plan_eq!(
2088            plan,
2089            @r"
2090        Projection: a LIKE CAST(NULL AS Utf8)
2091          EmptyRelation: rows=0
2092        "
2093        )?;
2094
2095        let expr = Box::new(col("a"));
2096        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2097        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2098        let empty = empty_with_type(DataType::Int64);
2099        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2100        assert_type_coercion_error(
2101            plan,
2102            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
2103        )?;
2104
2105        // ilike
2106        let expr = Box::new(col("a"));
2107        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2108        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2109        let empty = empty_with_type(Utf8);
2110        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2111
2112        assert_analyzed_plan_eq!(
2113            plan,
2114            @r#"
2115        Projection: a ILIKE Utf8("abc")
2116          EmptyRelation: rows=0
2117        "#
2118        )?;
2119
2120        let expr = Box::new(col("a"));
2121        let pattern = Box::new(lit(ScalarValue::Null));
2122        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2123        let empty = empty_with_type(Utf8);
2124        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2125
2126        assert_analyzed_plan_eq!(
2127            plan,
2128            @r"
2129        Projection: a ILIKE CAST(NULL AS Utf8)
2130          EmptyRelation: rows=0
2131        "
2132        )?;
2133
2134        let expr = Box::new(col("a"));
2135        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2136        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2137        let empty = empty_with_type(DataType::Int64);
2138        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2139        assert_type_coercion_error(
2140            plan,
2141            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
2142        )?;
2143
2144        Ok(())
2145    }
2146
2147    #[test]
2148    fn unknown_for_type_coercion() -> Result<()> {
2149        // unknown
2150        let expr = col("a").is_unknown();
2151        let empty = empty_with_type(DataType::Boolean);
2152        let plan =
2153            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2154
2155        assert_analyzed_plan_eq!(
2156            plan,
2157            @r"
2158        Projection: a IS UNKNOWN
2159          EmptyRelation: rows=0
2160        "
2161        )?;
2162
2163        let empty = empty_with_type(Utf8);
2164        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2165        assert_type_coercion_error(
2166            plan,
2167            "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean",
2168        )?;
2169
2170        // is not unknown
2171        let expr = col("a").is_not_unknown();
2172        let empty = empty_with_type(DataType::Boolean);
2173        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2174
2175        assert_analyzed_plan_eq!(
2176            plan,
2177            @r"
2178        Projection: a IS NOT UNKNOWN
2179          EmptyRelation: rows=0
2180        "
2181        )
2182    }
2183
2184    #[test]
2185    fn concat_for_type_coercion() -> Result<()> {
2186        let empty = empty_with_type(Utf8);
2187        let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2188
2189        // concat-type signature
2190        let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2191            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2192        })
2193        .call(args.to_vec());
2194        let plan =
2195            LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2196        assert_analyzed_plan_eq!(
2197            plan,
2198            @r#"
2199        Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2200          EmptyRelation: rows=0
2201        "#
2202        )
2203    }
2204
2205    #[test]
2206    fn test_type_coercion_rewrite() -> Result<()> {
2207        // gt
2208        let schema = Arc::new(DFSchema::from_unqualified_fields(
2209            vec![Field::new("a", DataType::Int64, true)].into(),
2210            std::collections::HashMap::new(),
2211        )?);
2212        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2213        let expr = is_true(lit(12i32).gt(lit(13i64)));
2214        let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2215        let result = expr.rewrite(&mut rewriter).data()?;
2216        assert_eq!(expected, result);
2217
2218        // eq
2219        let schema = Arc::new(DFSchema::from_unqualified_fields(
2220            vec![Field::new("a", DataType::Int64, true)].into(),
2221            std::collections::HashMap::new(),
2222        )?);
2223        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2224        let expr = is_true(lit(12i32).eq(lit(13i64)));
2225        let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2226        let result = expr.rewrite(&mut rewriter).data()?;
2227        assert_eq!(expected, result);
2228
2229        // lt
2230        let schema = Arc::new(DFSchema::from_unqualified_fields(
2231            vec![Field::new("a", DataType::Int64, true)].into(),
2232            std::collections::HashMap::new(),
2233        )?);
2234        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2235        let expr = is_true(lit(12i32).lt(lit(13i64)));
2236        let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2237        let result = expr.rewrite(&mut rewriter).data()?;
2238        assert_eq!(expected, result);
2239
2240        Ok(())
2241    }
2242
2243    #[test]
2244    fn binary_op_date32_eq_ts() -> Result<()> {
2245        let expr = cast(
2246            lit("1998-03-18"),
2247            DataType::Timestamp(TimeUnit::Nanosecond, None),
2248        )
2249        .eq(cast(lit("1998-03-18"), DataType::Date32));
2250        let empty = empty();
2251        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2252
2253        assert_analyzed_plan_eq!(
2254            plan,
2255            @r#"
2256        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
2257          EmptyRelation: rows=0
2258        "#
2259        )
2260    }
2261
2262    fn cast_if_not_same_type(
2263        expr: Box<Expr>,
2264        data_type: &DataType,
2265        schema: &DFSchemaRef,
2266    ) -> Box<Expr> {
2267        if &expr.get_type(schema).unwrap() != data_type {
2268            Box::new(cast(*expr, data_type.clone()))
2269        } else {
2270            expr
2271        }
2272    }
2273
2274    fn cast_helper(
2275        case: Case,
2276        case_when_type: &DataType,
2277        then_else_type: &DataType,
2278        schema: &DFSchemaRef,
2279    ) -> Case {
2280        let expr = case
2281            .expr
2282            .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2283        let when_then_expr = case
2284            .when_then_expr
2285            .into_iter()
2286            .map(|(when, then)| {
2287                (
2288                    cast_if_not_same_type(when, case_when_type, schema),
2289                    cast_if_not_same_type(then, then_else_type, schema),
2290                )
2291            })
2292            .collect::<Vec<_>>();
2293        let else_expr = case
2294            .else_expr
2295            .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2296
2297        Case {
2298            expr,
2299            when_then_expr,
2300            else_expr,
2301        }
2302    }
2303
2304    #[test]
2305    fn test_case_expression_coercion() -> Result<()> {
2306        let schema = Arc::new(DFSchema::from_unqualified_fields(
2307            vec![
2308                Field::new("boolean", DataType::Boolean, true),
2309                Field::new("integer", DataType::Int32, true),
2310                Field::new("float", DataType::Float32, true),
2311                Field::new(
2312                    "timestamp",
2313                    DataType::Timestamp(TimeUnit::Nanosecond, None),
2314                    true,
2315                ),
2316                Field::new("date", DataType::Date32, true),
2317                Field::new(
2318                    "interval",
2319                    DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2320                    true,
2321                ),
2322                Field::new("binary", DataType::Binary, true),
2323                Field::new("string", Utf8, true),
2324                Field::new("decimal", DataType::Decimal128(10, 10), true),
2325            ]
2326            .into(),
2327            std::collections::HashMap::new(),
2328        )?);
2329
2330        let case = Case {
2331            expr: None,
2332            when_then_expr: vec![
2333                (Box::new(col("boolean")), Box::new(col("integer"))),
2334                (Box::new(col("integer")), Box::new(col("float"))),
2335                (Box::new(col("string")), Box::new(col("string"))),
2336            ],
2337            else_expr: None,
2338        };
2339        let case_when_common_type = DataType::Boolean;
2340        let then_else_common_type = Utf8;
2341        let expected = cast_helper(
2342            case.clone(),
2343            &case_when_common_type,
2344            &then_else_common_type,
2345            &schema,
2346        );
2347        let actual = coerce_case_expression(case, &schema)?;
2348        assert_eq!(expected, actual);
2349
2350        let case = Case {
2351            expr: Some(Box::new(col("string"))),
2352            when_then_expr: vec![
2353                (Box::new(col("float")), Box::new(col("integer"))),
2354                (Box::new(col("integer")), Box::new(col("float"))),
2355                (Box::new(col("string")), Box::new(col("string"))),
2356            ],
2357            else_expr: Some(Box::new(col("string"))),
2358        };
2359        let case_when_common_type = Utf8;
2360        let then_else_common_type = Utf8;
2361        let expected = cast_helper(
2362            case.clone(),
2363            &case_when_common_type,
2364            &then_else_common_type,
2365            &schema,
2366        );
2367        let actual = coerce_case_expression(case, &schema)?;
2368        assert_eq!(expected, actual);
2369
2370        let case = Case {
2371            expr: Some(Box::new(col("interval"))),
2372            when_then_expr: vec![
2373                (Box::new(col("float")), Box::new(col("integer"))),
2374                (Box::new(col("binary")), Box::new(col("float"))),
2375                (Box::new(col("string")), Box::new(col("string"))),
2376            ],
2377            else_expr: Some(Box::new(col("string"))),
2378        };
2379        let err = coerce_case_expression(case, &schema).unwrap_err();
2380        assert_snapshot!(
2381            err.strip_backtrace(),
2382            @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
2383        );
2384
2385        let case = Case {
2386            expr: Some(Box::new(col("string"))),
2387            when_then_expr: vec![
2388                (Box::new(col("float")), Box::new(col("date"))),
2389                (Box::new(col("string")), Box::new(col("float"))),
2390                (Box::new(col("string")), Box::new(col("binary"))),
2391            ],
2392            else_expr: Some(Box::new(col("timestamp"))),
2393        };
2394        let err = coerce_case_expression(case, &schema).unwrap_err();
2395        assert_snapshot!(
2396            err.strip_backtrace(),
2397            @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
2398        );
2399
2400        Ok(())
2401    }
2402
2403    macro_rules! test_case_expression {
2404        ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2405            let case = Case {
2406                expr: $expr.map(|e| Box::new(col(e))),
2407                when_then_expr: $when_then,
2408                else_expr: None,
2409            };
2410
2411            let expected =
2412                cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2413
2414            let actual = coerce_case_expression(case, &$schema)?;
2415            assert_eq!(expected, actual);
2416        };
2417    }
2418
2419    #[test]
2420    fn tes_case_when_list() -> Result<()> {
2421        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2422        let schema = Arc::new(DFSchema::from_unqualified_fields(
2423            vec![
2424                Field::new(
2425                    "large_list",
2426                    DataType::LargeList(Arc::clone(&inner_field)),
2427                    true,
2428                ),
2429                Field::new(
2430                    "fixed_list",
2431                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2432                    true,
2433                ),
2434                Field::new("list", DataType::List(inner_field), true),
2435            ]
2436            .into(),
2437            std::collections::HashMap::new(),
2438        )?);
2439
2440        test_case_expression!(
2441            Some("list"),
2442            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2443            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2444            Utf8,
2445            schema
2446        );
2447
2448        test_case_expression!(
2449            Some("large_list"),
2450            vec![(Box::new(col("list")), Box::new(lit("1")))],
2451            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2452            Utf8,
2453            schema
2454        );
2455
2456        test_case_expression!(
2457            Some("list"),
2458            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2459            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2460            Utf8,
2461            schema
2462        );
2463
2464        test_case_expression!(
2465            Some("fixed_list"),
2466            vec![(Box::new(col("list")), Box::new(lit("1")))],
2467            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2468            Utf8,
2469            schema
2470        );
2471
2472        test_case_expression!(
2473            Some("fixed_list"),
2474            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2475            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2476            Utf8,
2477            schema
2478        );
2479
2480        test_case_expression!(
2481            Some("large_list"),
2482            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2483            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2484            Utf8,
2485            schema
2486        );
2487        Ok(())
2488    }
2489
2490    #[test]
2491    fn test_then_else_list() -> Result<()> {
2492        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2493        let schema = Arc::new(DFSchema::from_unqualified_fields(
2494            vec![
2495                Field::new("boolean", DataType::Boolean, true),
2496                Field::new(
2497                    "large_list",
2498                    DataType::LargeList(Arc::clone(&inner_field)),
2499                    true,
2500                ),
2501                Field::new(
2502                    "fixed_list",
2503                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2504                    true,
2505                ),
2506                Field::new("list", DataType::List(inner_field), true),
2507            ]
2508            .into(),
2509            std::collections::HashMap::new(),
2510        )?);
2511
2512        // large list and list
2513        test_case_expression!(
2514            None::<String>,
2515            vec![
2516                (Box::new(col("boolean")), Box::new(col("large_list"))),
2517                (Box::new(col("boolean")), Box::new(col("list")))
2518            ],
2519            DataType::Boolean,
2520            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2521            schema
2522        );
2523
2524        test_case_expression!(
2525            None::<String>,
2526            vec![
2527                (Box::new(col("boolean")), Box::new(col("list"))),
2528                (Box::new(col("boolean")), Box::new(col("large_list")))
2529            ],
2530            DataType::Boolean,
2531            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2532            schema
2533        );
2534
2535        // fixed list and list
2536        test_case_expression!(
2537            None::<String>,
2538            vec![
2539                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2540                (Box::new(col("boolean")), Box::new(col("list")))
2541            ],
2542            DataType::Boolean,
2543            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2544            schema
2545        );
2546
2547        test_case_expression!(
2548            None::<String>,
2549            vec![
2550                (Box::new(col("boolean")), Box::new(col("list"))),
2551                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2552            ],
2553            DataType::Boolean,
2554            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2555            schema
2556        );
2557
2558        // fixed list and large list
2559        test_case_expression!(
2560            None::<String>,
2561            vec![
2562                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2563                (Box::new(col("boolean")), Box::new(col("large_list")))
2564            ],
2565            DataType::Boolean,
2566            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2567            schema
2568        );
2569
2570        test_case_expression!(
2571            None::<String>,
2572            vec![
2573                (Box::new(col("boolean")), Box::new(col("large_list"))),
2574                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2575            ],
2576            DataType::Boolean,
2577            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2578            schema
2579        );
2580        Ok(())
2581    }
2582
2583    #[test]
2584    fn test_map_with_diff_name() -> Result<()> {
2585        let mut builder = SchemaBuilder::new();
2586        builder.push(Field::new("key", Utf8, false));
2587        builder.push(Field::new("value", DataType::Float64, true));
2588        let struct_fields = builder.finish().fields;
2589
2590        let fields =
2591            Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2592        let map_type_entries = DataType::Map(Arc::new(fields), false);
2593
2594        let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2595        let may_type_custom = DataType::Map(Arc::new(fields), false);
2596
2597        let expr = col("a").eq(cast(col("a"), may_type_custom));
2598        let empty = empty_with_type(map_type_entries);
2599        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2600
2601        assert_analyzed_plan_eq!(
2602            plan,
2603            @r#"
2604        Projection: a = CAST(CAST(a AS Map("key_value": non-null Struct("key": non-null Utf8, "value": Float64), unsorted)) AS Map("entries": non-null Struct("key": non-null Utf8, "value": Float64), unsorted))
2605          EmptyRelation: rows=0
2606        "#
2607        )
2608    }
2609
2610    #[test]
2611    fn interval_plus_timestamp() -> Result<()> {
2612        // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
2613        let expr = Expr::BinaryExpr(BinaryExpr::new(
2614            Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2615            Operator::Plus,
2616            Box::new(cast(
2617                lit("2000-01-01T00:00:00"),
2618                DataType::Timestamp(TimeUnit::Nanosecond, None),
2619            )),
2620        ));
2621        let empty = empty();
2622        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2623
2624        assert_analyzed_plan_eq!(
2625            plan,
2626            @r#"
2627        Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
2628          EmptyRelation: rows=0
2629        "#
2630        )
2631    }
2632
2633    #[test]
2634    fn timestamp_subtract_timestamp() -> Result<()> {
2635        let expr = Expr::BinaryExpr(BinaryExpr::new(
2636            Box::new(cast(
2637                lit("1998-03-18"),
2638                DataType::Timestamp(TimeUnit::Nanosecond, None),
2639            )),
2640            Operator::Minus,
2641            Box::new(cast(
2642                lit("1998-03-18"),
2643                DataType::Timestamp(TimeUnit::Nanosecond, None),
2644            )),
2645        ));
2646        let empty = empty();
2647        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2648
2649        assert_analyzed_plan_eq!(
2650            plan,
2651            @r#"
2652        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
2653          EmptyRelation: rows=0
2654        "#
2655        )
2656    }
2657
2658    #[test]
2659    fn in_subquery_cast_subquery() -> Result<()> {
2660        let empty_int32 = empty_with_type(DataType::Int32);
2661        let empty_int64 = empty_with_type(DataType::Int64);
2662
2663        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2664            Box::new(col("a")),
2665            Subquery {
2666                subquery: empty_int32,
2667                outer_ref_columns: vec![],
2668                spans: Spans::new(),
2669            },
2670            false,
2671        ));
2672        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2673        // add cast for subquery
2674
2675        assert_analyzed_plan_eq!(
2676            plan,
2677            @r"
2678        Filter: a IN (<subquery>)
2679          Subquery:
2680            Projection: CAST(a AS Int64)
2681              EmptyRelation: rows=0
2682          EmptyRelation: rows=0
2683        "
2684        )
2685    }
2686
2687    #[test]
2688    fn in_subquery_cast_expr() -> Result<()> {
2689        let empty_int32 = empty_with_type(DataType::Int32);
2690        let empty_int64 = empty_with_type(DataType::Int64);
2691
2692        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2693            Box::new(col("a")),
2694            Subquery {
2695                subquery: empty_int64,
2696                outer_ref_columns: vec![],
2697                spans: Spans::new(),
2698            },
2699            false,
2700        ));
2701        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2702
2703        // add cast for subquery
2704        assert_analyzed_plan_eq!(
2705            plan,
2706            @r"
2707        Filter: CAST(a AS Int64) IN (<subquery>)
2708          Subquery:
2709            EmptyRelation: rows=0
2710          EmptyRelation: rows=0
2711        "
2712        )
2713    }
2714
2715    #[test]
2716    fn in_subquery_cast_all() -> Result<()> {
2717        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2718        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2719
2720        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2721            Box::new(col("a")),
2722            Subquery {
2723                subquery: empty_inside,
2724                outer_ref_columns: vec![],
2725                spans: Spans::new(),
2726            },
2727            false,
2728        ));
2729        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2730
2731        // add cast for subquery
2732        assert_analyzed_plan_eq!(
2733            plan,
2734            @r"
2735        Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
2736          Subquery:
2737            Projection: CAST(a AS Decimal128(13, 8))
2738              EmptyRelation: rows=0
2739          EmptyRelation: rows=0
2740        "
2741        )
2742    }
2743}