datafusion_optimizer/analyzer/
type_coercion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Optimizer rule for type validation and coercion
19
20use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use itertools::izip;
24
25use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
26
27use crate::analyzer::AnalyzerRule;
28use crate::utils::NamePreserver;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
31use datafusion_common::{
32    exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column,
33    DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34};
35use datafusion_expr::expr::{
36    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
37    InSubquery, Like, ScalarFunction, Sort, WindowFunction,
38};
39use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
40use datafusion_expr::expr_schema::cast_subquery;
41use datafusion_expr::logical_plan::Subquery;
42use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
43use datafusion_expr::type_coercion::functions::{
44    data_types_with_scalar_udf, fields_with_aggregate_udf,
45};
46use datafusion_expr::type_coercion::other::{
47    get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52    is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
53    AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan,
54    Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound,
55    WindowFrameUnits,
56};
57
/// Analyzer rule that validates expression types and inserts the casts
/// required to make them compatible.
///
/// Each plan node is coerced using a schema built from its inputs (plus any
/// outer query schema for correlated subqueries), after which the node's
/// output schema is recomputed.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new [`TypeCoercion`] rule; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
68
69/// Coerce output schema based upon optimizer config.
70fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
71    if !config.optimizer.expand_views_at_output {
72        return Ok(plan);
73    }
74
75    let outer_refs = plan.expressions();
76    if outer_refs.is_empty() {
77        return Ok(plan);
78    }
79
80    if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
81        coerce_plan_expr_for_schema(plan, &dfschema?)
82    } else {
83        Ok(plan)
84    }
85}
86
87impl AnalyzerRule for TypeCoercion {
88    fn name(&self) -> &str {
89        "type_coercion"
90    }
91
92    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
93        let empty_schema = DFSchema::empty();
94
95        // recurse
96        let transformed_plan = plan
97            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
98            .data;
99
100        // finish
101        coerce_output(transformed_plan, config)
102    }
103}
104
105/// use the external schema to handle the correlated subqueries case
106///
107/// Assumes that children have already been optimized
108fn analyze_internal(
109    external_schema: &DFSchema,
110    plan: LogicalPlan,
111) -> Result<Transformed<LogicalPlan>> {
112    // get schema representing all available input fields. This is used for data type
113    // resolution only, so order does not matter here
114    let mut schema = merge_schema(&plan.inputs());
115
116    if let LogicalPlan::TableScan(ts) = &plan {
117        let source_schema = DFSchema::try_from_qualified_schema(
118            ts.table_name.clone(),
119            &ts.source.schema(),
120        )?;
121        schema.merge(&source_schema);
122    }
123
124    // merge the outer schema for correlated subqueries
125    // like case:
126    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
127    schema.merge(external_schema);
128
129    // Coerce filter predicates to boolean (handles `WHERE NULL`)
130    let plan = if let LogicalPlan::Filter(mut filter) = plan {
131        filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
132        LogicalPlan::Filter(filter)
133    } else {
134        plan
135    };
136
137    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
138
139    let name_preserver = NamePreserver::new(&plan);
140    // apply coercion rewrite all expressions in the plan individually
141    plan.map_expressions(|expr| {
142        let original_name = name_preserver.save(&expr);
143        expr.rewrite(&mut expr_rewrite)
144            .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
145    })?
146    // some plans need extra coercion after their expressions are coerced
147    .map_data(|plan| expr_rewrite.coerce_plan(plan))?
148    // recompute the schema after the expressions have been rewritten as the types may have changed
149    .map_data(|plan| plan.recompute_schema())
150}
151
152/// Rewrite expressions to apply type coercion.
153pub struct TypeCoercionRewriter<'a> {
154    pub(crate) schema: &'a DFSchema,
155}
156
157impl<'a> TypeCoercionRewriter<'a> {
158    /// Create a new [`TypeCoercionRewriter`] with a provided schema
159    /// representing both the inputs and output of the [`LogicalPlan`] node.
160    pub fn new(schema: &'a DFSchema) -> Self {
161        Self { schema }
162    }
163
164    /// Coerce the [`LogicalPlan`].
165    ///
166    /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`]
167    /// for type-coercion approach.
168    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
169        match plan {
170            LogicalPlan::Join(join) => self.coerce_join(join),
171            LogicalPlan::Union(union) => Self::coerce_union(union),
172            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
173            _ => Ok(plan),
174        }
175    }
176
177    /// Coerce join equality expressions and join filter
178    ///
179    /// Joins must be treated specially as their equality expressions are stored
180    /// as a parallel list of left and right expressions, rather than a single
181    /// equality expression
182    ///
183    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
184    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
185    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
186        join.on = join
187            .on
188            .into_iter()
189            .map(|(lhs, rhs)| {
190                // coerce the arguments as though they were a single binary equality
191                // expression
192                let left_schema = join.left.schema();
193                let right_schema = join.right.schema();
194                let (lhs, rhs) = self.coerce_binary_op(
195                    lhs,
196                    left_schema,
197                    Operator::Eq,
198                    rhs,
199                    right_schema,
200                )?;
201                Ok((lhs, rhs))
202            })
203            .collect::<Result<Vec<_>>>()?;
204
205        // Join filter must be boolean
206        join.filter = join
207            .filter
208            .map(|expr| self.coerce_join_filter(expr))
209            .transpose()?;
210
211        Ok(LogicalPlan::Join(join))
212    }
213
214    /// Coerce the union’s inputs to a common schema compatible with all inputs.
215    /// This occurs after wildcard expansion and the coercion of the input expressions.
216    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
217        let union_schema = Arc::new(coerce_union_schema_with_schema(
218            &union_plan.inputs,
219            &union_plan.schema,
220        )?);
221        let new_inputs = union_plan
222            .inputs
223            .into_iter()
224            .map(|p| {
225                let plan =
226                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
227                match plan {
228                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
229                        Ok(Arc::new(project_with_column_index(
230                            expr,
231                            input,
232                            Arc::clone(&union_schema),
233                        )?))
234                    }
235                    other_plan => Ok(Arc::new(other_plan)),
236                }
237            })
238            .collect::<Result<Vec<_>>>()?;
239        Ok(LogicalPlan::Union(Union {
240            inputs: new_inputs,
241            schema: union_schema,
242        }))
243    }
244
245    /// Coerce the fetch and skip expression to Int64 type.
246    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
247        fn coerce_limit_expr(
248            expr: Expr,
249            schema: &DFSchema,
250            expr_name: &str,
251        ) -> Result<Expr> {
252            let dt = expr.get_type(schema)?;
253            if dt.is_integer() || dt.is_null() {
254                expr.cast_to(&DataType::Int64, schema)
255            } else {
256                plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}")
257            }
258        }
259
260        let empty_schema = DFSchema::empty();
261        let new_fetch = limit
262            .fetch
263            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
264            .transpose()?;
265        let new_skip = limit
266            .skip
267            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
268            .transpose()?;
269        Ok(LogicalPlan::Limit(Limit {
270            input: limit.input,
271            fetch: new_fetch.map(Box::new),
272            skip: new_skip.map(Box::new),
273        }))
274    }
275
276    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
277        let expr_type = expr.get_type(self.schema)?;
278        match expr_type {
279            DataType::Boolean => Ok(expr),
280            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
281            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
282        }
283    }
284
285    fn coerce_binary_op(
286        &self,
287        left: Expr,
288        left_schema: &DFSchema,
289        op: Operator,
290        right: Expr,
291        right_schema: &DFSchema,
292    ) -> Result<(Expr, Expr)> {
293        let (left_type, right_type) = BinaryTypeCoercer::new(
294            &left.get_type(left_schema)?,
295            &op,
296            &right.get_type(right_schema)?,
297        )
298        .get_input_types()?;
299
300        Ok((
301            left.cast_to(&left_type, left_schema)?,
302            right.cast_to(&right_type, right_schema)?,
303        ))
304    }
305}
306
307impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
308    type Node = Expr;
309
310    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
311        match expr {
312            Expr::Unnest(_) => not_impl_err!(
313                "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
314            ),
315            Expr::ScalarSubquery(Subquery {
316                subquery,
317                outer_ref_columns,
318                spans,
319            }) => {
320                let new_plan =
321                    analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
322                Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
323                    subquery: Arc::new(new_plan),
324                    outer_ref_columns,
325                    spans,
326                })))
327            }
328            Expr::Exists(Exists { subquery, negated }) => {
329                let new_plan = analyze_internal(
330                    self.schema,
331                    Arc::unwrap_or_clone(subquery.subquery),
332                )?
333                .data;
334                Ok(Transformed::yes(Expr::Exists(Exists {
335                    subquery: Subquery {
336                        subquery: Arc::new(new_plan),
337                        outer_ref_columns: subquery.outer_ref_columns,
338                        spans: subquery.spans,
339                    },
340                    negated,
341                })))
342            }
343            Expr::InSubquery(InSubquery {
344                expr,
345                subquery,
346                negated,
347            }) => {
348                let new_plan = analyze_internal(
349                    self.schema,
350                    Arc::unwrap_or_clone(subquery.subquery),
351                )?
352                .data;
353                let expr_type = expr.get_type(self.schema)?;
354                let subquery_type = new_plan.schema().field(0).data_type();
355                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
356                        "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
357                    ),
358                )?;
359                let new_subquery = Subquery {
360                    subquery: Arc::new(new_plan),
361                    outer_ref_columns: subquery.outer_ref_columns,
362                    spans: subquery.spans,
363                };
364                Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
365                    Box::new(expr.cast_to(&common_type, self.schema)?),
366                    cast_subquery(new_subquery, &common_type)?,
367                    negated,
368                ))))
369            }
370            Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
371                *expr,
372                self.schema,
373            )?))),
374            Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
375                get_casted_expr_for_bool_op(*expr, self.schema)?,
376            ))),
377            Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
378                get_casted_expr_for_bool_op(*expr, self.schema)?,
379            ))),
380            Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
381                get_casted_expr_for_bool_op(*expr, self.schema)?,
382            ))),
383            Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
384                get_casted_expr_for_bool_op(*expr, self.schema)?,
385            ))),
386            Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
387                get_casted_expr_for_bool_op(*expr, self.schema)?,
388            ))),
389            Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
390                get_casted_expr_for_bool_op(*expr, self.schema)?,
391            ))),
392            Expr::Like(Like {
393                negated,
394                expr,
395                pattern,
396                escape_char,
397                case_insensitive,
398            }) => {
399                let left_type = expr.get_type(self.schema)?;
400                let right_type = pattern.get_type(self.schema)?;
401                let coerced_type = like_coercion(&left_type,  &right_type).ok_or_else(|| {
402                    let op_name = if case_insensitive {
403                        "ILIKE"
404                    } else {
405                        "LIKE"
406                    };
407                    plan_datafusion_err!(
408                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
409                    )
410                })?;
411                let expr = match left_type {
412                    DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
413                    _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
414                };
415                let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
416                Ok(Transformed::yes(Expr::Like(Like::new(
417                    negated,
418                    expr,
419                    pattern,
420                    escape_char,
421                    case_insensitive,
422                ))))
423            }
424            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
425                let (left, right) =
426                    self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
427                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
428                    Box::new(left),
429                    op,
430                    Box::new(right),
431                ))))
432            }
433            Expr::Between(Between {
434                expr,
435                negated,
436                low,
437                high,
438            }) => {
439                let expr_type = expr.get_type(self.schema)?;
440                let low_type = low.get_type(self.schema)?;
441                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
442                    .ok_or_else(|| {
443                        DataFusionError::Internal(format!(
444                            "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
445                        ))
446                    })?;
447                let high_type = high.get_type(self.schema)?;
448                let high_coerced_type = comparison_coercion(&expr_type, &high_type)
449                    .ok_or_else(|| {
450                        DataFusionError::Internal(format!(
451                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
452                        ))
453                    })?;
454                let coercion_type =
455                    comparison_coercion(&low_coerced_type, &high_coerced_type)
456                        .ok_or_else(|| {
457                            DataFusionError::Internal(format!(
458                                "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
459                            ))
460                        })?;
461                Ok(Transformed::yes(Expr::Between(Between::new(
462                    Box::new(expr.cast_to(&coercion_type, self.schema)?),
463                    negated,
464                    Box::new(low.cast_to(&coercion_type, self.schema)?),
465                    Box::new(high.cast_to(&coercion_type, self.schema)?),
466                ))))
467            }
468            Expr::InList(InList {
469                expr,
470                list,
471                negated,
472            }) => {
473                let expr_data_type = expr.get_type(self.schema)?;
474                let list_data_types = list
475                    .iter()
476                    .map(|list_expr| list_expr.get_type(self.schema))
477                    .collect::<Result<Vec<_>>>()?;
478                let result_type =
479                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
480                match result_type {
481                    None => plan_err!(
482                        "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}"
483                    ),
484                    Some(coerced_type) => {
485                        // find the coerced type
486                        let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
487                        let cast_list_expr = list
488                            .into_iter()
489                            .map(|list_expr| {
490                                list_expr.cast_to(&coerced_type, self.schema)
491                            })
492                            .collect::<Result<Vec<_>>>()?;
493                        Ok(Transformed::yes(Expr::InList(InList ::new(
494                             Box::new(cast_expr),
495                             cast_list_expr,
496                            negated,
497                        ))))
498                    }
499                }
500            }
501            Expr::Case(case) => {
502                let case = coerce_case_expression(case, self.schema)?;
503                Ok(Transformed::yes(Expr::Case(case)))
504            }
505            Expr::ScalarFunction(ScalarFunction { func, args }) => {
506                let new_expr = coerce_arguments_for_signature_with_scalar_udf(
507                    args,
508                    self.schema,
509                    &func,
510                )?;
511                Ok(Transformed::yes(Expr::ScalarFunction(
512                    ScalarFunction::new_udf(func, new_expr),
513                )))
514            }
515            Expr::AggregateFunction(expr::AggregateFunction {
516                func,
517                params:
518                    AggregateFunctionParams {
519                        args,
520                        distinct,
521                        filter,
522                        order_by,
523                        null_treatment,
524                    },
525            }) => {
526                let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
527                    args,
528                    self.schema,
529                    &func,
530                )?;
531                Ok(Transformed::yes(Expr::AggregateFunction(
532                    expr::AggregateFunction::new_udf(
533                        func,
534                        new_expr,
535                        distinct,
536                        filter,
537                        order_by,
538                        null_treatment,
539                    ),
540                )))
541            }
542            Expr::WindowFunction(window_fun) => {
543                let WindowFunction {
544                    fun,
545                    params:
546                        expr::WindowFunctionParams {
547                            args,
548                            partition_by,
549                            order_by,
550                            window_frame,
551                            null_treatment,
552                        },
553                } = *window_fun;
554                let window_frame =
555                    coerce_window_frame(window_frame, self.schema, &order_by)?;
556
557                let args = match &fun {
558                    expr::WindowFunctionDefinition::AggregateUDF(udf) => {
559                        coerce_arguments_for_signature_with_aggregate_udf(
560                            args,
561                            self.schema,
562                            udf,
563                        )?
564                    }
565                    _ => args,
566                };
567
568                Ok(Transformed::yes(
569                    Expr::from(WindowFunction::new(fun, args))
570                        .partition_by(partition_by)
571                        .order_by(order_by)
572                        .window_frame(window_frame)
573                        .null_treatment(null_treatment)
574                        .build()?,
575                ))
576            }
577            // TODO: remove the next line after `Expr::Wildcard` is removed
578            #[expect(deprecated)]
579            Expr::Alias(_)
580            | Expr::Column(_)
581            | Expr::ScalarVariable(_, _)
582            | Expr::Literal(_, _)
583            | Expr::SimilarTo(_)
584            | Expr::IsNotNull(_)
585            | Expr::IsNull(_)
586            | Expr::Negative(_)
587            | Expr::Cast(_)
588            | Expr::TryCast(_)
589            | Expr::Wildcard { .. }
590            | Expr::GroupingSet(_)
591            | Expr::Placeholder(_)
592            | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
593        }
594    }
595}
596
597/// Transform a schema to use non-view types for Utf8View and BinaryView
598fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
599    let metadata = dfschema.as_arrow().metadata.clone();
600    let mut transformed = false;
601
602    let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
603        dfschema
604            .iter()
605            .map(|(qualifier, field)| match field.data_type() {
606                DataType::Utf8View => {
607                    transformed = true;
608                    (
609                        qualifier.cloned() as Option<TableReference>,
610                        Arc::new(Field::new(
611                            field.name(),
612                            DataType::LargeUtf8,
613                            field.is_nullable(),
614                        )),
615                    )
616                }
617                DataType::BinaryView => {
618                    transformed = true;
619                    (
620                        qualifier.cloned() as Option<TableReference>,
621                        Arc::new(Field::new(
622                            field.name(),
623                            DataType::LargeBinary,
624                            field.is_nullable(),
625                        )),
626                    )
627                }
628                _ => (
629                    qualifier.cloned() as Option<TableReference>,
630                    Arc::clone(field),
631                ),
632            })
633            .unzip();
634
635    if !transformed {
636        return None;
637    }
638
639    let schema = Schema::new_with_metadata(transformed_fields, metadata);
640    Some(DFSchema::from_field_specific_qualified_schema(
641        qualifiers,
642        &Arc::new(schema),
643    ))
644}
645
646/// Casts the given `value` to `target_type`. Note that this function
647/// only considers `Null` or `Utf8` values.
648fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
649    match value {
650        // Coerce Utf8 values:
651        ScalarValue::Utf8(Some(val)) => {
652            ScalarValue::try_from_string(val.clone(), target_type)
653        }
654        s => {
655            if s.is_null() {
656                // Coerce `Null` values:
657                ScalarValue::try_from(target_type)
658            } else {
659                // Values except `Utf8`/`Null` variants already have the right type
660                // (casted before) since we convert `sqlparser` outputs to `Utf8`
661                // for all possible cases. Therefore, we return a clone here.
662                Ok(s.clone())
663            }
664        }
665    }
666}
667
668/// This function coerces `value` to `target_type` in a range-aware fashion.
669/// If the coercion is successful, we return an `Ok` value with the result.
670/// If the coercion fails because `target_type` is not wide enough (i.e. we
671/// can not coerce to `target_type`, but we can to a wider type in the same
672/// family), we return a `Null` value of this type to signal this situation.
673/// Downstream code uses this signal to treat these values as *unbounded*.
674fn coerce_scalar_range_aware(
675    target_type: &DataType,
676    value: &ScalarValue,
677) -> Result<ScalarValue> {
678    coerce_scalar(target_type, value).or_else(|err| {
679        // If type coercion fails, check if the largest type in family works:
680        if let Some(largest_type) = get_widest_type_in_family(target_type) {
681            coerce_scalar(largest_type, value).map_or_else(
682                |_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
683                |_| ScalarValue::try_from(target_type),
684            )
685        } else {
686            Err(err)
687        }
688    })
689}
690
691/// This function returns the widest type in the family of `given_type`.
692/// If the given type is already the widest type, it returns `None`.
693/// For example, if `given_type` is `Int8`, it returns `Int64`.
694fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
695    match given_type {
696        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
697        DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
698        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
699        _ => None,
700    }
701}
702
703/// Coerces the given (window frame) `bound` to `target_type`.
704fn coerce_frame_bound(
705    target_type: &DataType,
706    bound: WindowFrameBound,
707) -> Result<WindowFrameBound> {
708    match bound {
709        WindowFrameBound::Preceding(v) => {
710            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
711        }
712        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
713        WindowFrameBound::Following(v) => {
714            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
715        }
716    }
717}
718
719fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
720    if col_type.is_numeric()
721        || is_utf8_or_utf8view_or_large_utf8(col_type)
722        || matches!(col_type, DataType::List(_))
723        || matches!(col_type, DataType::LargeList(_))
724        || matches!(col_type, DataType::FixedSizeList(_, _))
725        || matches!(col_type, DataType::Null)
726        || matches!(col_type, DataType::Boolean)
727    {
728        Ok(col_type.clone())
729    } else if is_datetime(col_type) {
730        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
731    } else if let DataType::Dictionary(_, value_type) = col_type {
732        extract_window_frame_target_type(value_type)
733    } else {
734        return internal_err!("Cannot run range queries on datatype: {col_type:?}");
735    }
736}
737
738// Coerces the given `window_frame` to use appropriate natural types.
739// For example, ROWS and GROUPS frames use `UInt64` during calculations.
740fn coerce_window_frame(
741    window_frame: WindowFrame,
742    schema: &DFSchema,
743    expressions: &[Sort],
744) -> Result<WindowFrame> {
745    let mut window_frame = window_frame;
746    let target_type = match window_frame.units {
747        WindowFrameUnits::Range => {
748            let current_types = expressions
749                .first()
750                .map(|s| s.expr.get_type(schema))
751                .transpose()?;
752            if let Some(col_type) = current_types {
753                extract_window_frame_target_type(&col_type)?
754            } else {
755                return internal_err!("ORDER BY column cannot be empty");
756            }
757        }
758        WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
759    };
760    window_frame.start_bound =
761        coerce_frame_bound(&target_type, window_frame.start_bound)?;
762    window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
763    Ok(window_frame)
764}
765
766// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
767// The above op will be rewrite to the binary op when creating the physical op.
768fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
769    let left_type = expr.get_type(schema)?;
770    BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
771        .get_input_types()?;
772    expr.cast_to(&DataType::Boolean, schema)
773}
774
775/// Returns `expressions` coerced to types compatible with
776/// `signature`, if possible.
777///
778/// See the module level documentation for more detail on coercion.
779fn coerce_arguments_for_signature_with_scalar_udf(
780    expressions: Vec<Expr>,
781    schema: &DFSchema,
782    func: &ScalarUDF,
783) -> Result<Vec<Expr>> {
784    if expressions.is_empty() {
785        return Ok(expressions);
786    }
787
788    let current_types = expressions
789        .iter()
790        .map(|e| e.get_type(schema))
791        .collect::<Result<Vec<_>>>()?;
792
793    let new_types = data_types_with_scalar_udf(&current_types, func)?;
794
795    expressions
796        .into_iter()
797        .enumerate()
798        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
799        .collect()
800}
801
802/// Returns `expressions` coerced to types compatible with
803/// `signature`, if possible.
804///
805/// See the module level documentation for more detail on coercion.
806fn coerce_arguments_for_signature_with_aggregate_udf(
807    expressions: Vec<Expr>,
808    schema: &DFSchema,
809    func: &AggregateUDF,
810) -> Result<Vec<Expr>> {
811    if expressions.is_empty() {
812        return Ok(expressions);
813    }
814
815    let current_fields = expressions
816        .iter()
817        .map(|e| e.to_field(schema).map(|(_, f)| f))
818        .collect::<Result<Vec<_>>>()?;
819
820    let new_types = fields_with_aggregate_udf(&current_fields, func)?
821        .into_iter()
822        .map(|f| f.data_type().clone())
823        .collect::<Vec<_>>();
824
825    expressions
826        .into_iter()
827        .enumerate()
828        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
829        .collect()
830}
831
/// Coerces all parts of a `CASE` expression to mutually compatible types.
///
/// * With a `CASE <expr>` operand, the operand and every `WHEN` value are cast
///   to one common type.
/// * Without an operand, every `WHEN` condition is cast to `Boolean`.
/// * In both forms, every `THEN` result and the `ELSE` result (if present) are
///   cast to one common type.
///
/// Common types are chosen by `get_coerce_type_for_case_expression`.
///
/// # Errors
///
/// Returns a planning error when no common type exists for the operand/`WHEN`
/// values or for the `THEN`/`ELSE` results. A failed `WHEN` cast is wrapped in
/// a `DataFusionError::Context`. Type resolution and other cast errors are
/// propagated unchanged.
fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
    // Given expressions like:
    //
    // CASE a1
    //   WHEN a2 THEN b1
    //   WHEN a3 THEN b2
    //   ELSE b3
    // END
    //
    // or:
    //
    // CASE
    //   WHEN x1 THEN b1
    //   WHEN x2 THEN b2
    //   ELSE b3
    // END
    //
    // Then all aN (a1, a2, a3) must be converted to a common data type in the first example
    // (case-when expression coercion)
    //
    // All xN (x1, x2) must be converted to a boolean data type in the second example
    // (when-boolean expression coercion)
    //
    // And all bN (b1, b2, b3) must be converted to a common data type in both examples
    // (then-else expression coercion)
    //
    // If no common/specific data type can be found, or a cast to it fails, an error is returned.
    //
    // Note that case-when and when-boolean expression coercions are mutually exclusive:
    // only one or the other can occur for a case expression, whilst then-else expression coercion will always occur.

    // prepare types
    let case_type = case
        .expr
        .as_ref()
        .map(|expr| expr.get_type(schema))
        .transpose()?;
    let then_types = case
        .when_then_expr
        .iter()
        .map(|(_when, then)| then.get_type(schema))
        .collect::<Result<Vec<_>>>()?;
    let else_type = case
        .else_expr
        .as_ref()
        .map(|expr| expr.get_type(schema))
        .transpose()?;

    // find common coercible types
    // (only computed when there is a CASE operand; `None` means "WHEN must be boolean")
    let case_when_coerce_type = case_type
        .as_ref()
        .map(|case_type| {
            let when_types = case
                .when_then_expr
                .iter()
                .map(|(when, _then)| when.get_type(schema))
                .collect::<Result<Vec<_>>>()?;
            let coerced_type =
                get_coerce_type_for_case_expression(&when_types, Some(case_type));
            coerced_type.ok_or_else(|| {
                plan_datafusion_err!(
                    "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \
                     to common types in CASE WHEN expression"
                )
            })
        })
        .transpose()?;
    let then_else_coerce_type =
        get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
            || {
                plan_datafusion_err!(
                    "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \
                     to common types in CASE WHEN expression"
                )
            },
        )?;

    // do cast if found common coercible types
    // `zip` yields Some only when both the operand and its common type exist.
    let case_expr = case
        .expr
        .zip(case_when_coerce_type.as_ref())
        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
        .transpose()?
        .map(Box::new);
    let when_then = case
        .when_then_expr
        .into_iter()
        .map(|(when, then)| {
            // Without a CASE operand, each WHEN is a condition and must be boolean.
            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
            let when = when.cast_to(when_type, schema).map_err(|e| {
                DataFusionError::Context(
                    format!(
                        "WHEN expressions in CASE couldn't be \
                         converted to common type ({when_type})"
                    ),
                    Box::new(e),
                )
            })?;
            let then = then.cast_to(&then_else_coerce_type, schema)?;
            Ok((Box::new(when), Box::new(then)))
        })
        .collect::<Result<Vec<_>>>()?;
    let else_expr = case
        .else_expr
        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
        .transpose()?
        .map(Box::new);

    Ok(Case::new(case_expr, when_then, else_expr))
}
942
943/// Get a common schema that is compatible with all inputs of UNION.
944///
945/// This method presumes that the wildcard expansion is unneeded, or has already
946/// been applied.
947pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
948    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
949}
950fn coerce_union_schema_with_schema(
951    inputs: &[Arc<LogicalPlan>],
952    base_schema: &DFSchemaRef,
953) -> Result<DFSchema> {
954    let mut union_datatypes = base_schema
955        .fields()
956        .iter()
957        .map(|f| f.data_type().clone())
958        .collect::<Vec<_>>();
959    let mut union_nullabilities = base_schema
960        .fields()
961        .iter()
962        .map(|f| f.is_nullable())
963        .collect::<Vec<_>>();
964    let mut union_field_meta = base_schema
965        .fields()
966        .iter()
967        .map(|f| f.metadata().clone())
968        .collect::<Vec<_>>();
969
970    let mut metadata = base_schema.metadata().clone();
971
972    for (i, plan) in inputs.iter().enumerate() {
973        let plan_schema = plan.schema();
974        metadata.extend(plan_schema.metadata().clone());
975
976        if plan_schema.fields().len() != base_schema.fields().len() {
977            return plan_err!(
978                "Union schemas have different number of fields: \
979                query 1 has {} fields whereas query {} has {} fields",
980                base_schema.fields().len(),
981                i + 1,
982                plan_schema.fields().len()
983            );
984        }
985
986        // coerce data type and nullability for each field
987        for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
988            union_datatypes.iter_mut(),
989            union_nullabilities.iter_mut(),
990            union_field_meta.iter_mut(),
991            plan_schema.fields().iter()
992        ) {
993            let coerced_type =
994                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
995                    || {
996                        plan_datafusion_err!(
997                            "Incompatible inputs for Union: Previous inputs were \
998                            of type {}, but got incompatible type {} on column '{}'",
999                            union_datatype,
1000                            plan_field.data_type(),
1001                            plan_field.name()
1002                        )
1003                    },
1004                )?;
1005
1006            *union_datatype = coerced_type;
1007            *union_nullable = *union_nullable || plan_field.is_nullable();
1008            union_field_map.extend(plan_field.metadata().clone());
1009        }
1010    }
1011    let union_qualified_fields = izip!(
1012        base_schema.fields(),
1013        union_datatypes.into_iter(),
1014        union_nullabilities,
1015        union_field_meta.into_iter()
1016    )
1017    .map(|(field, datatype, nullable, metadata)| {
1018        let mut field = Field::new(field.name().clone(), datatype, nullable);
1019        field.set_metadata(metadata);
1020        (None, field.into())
1021    })
1022    .collect::<Vec<_>>();
1023
1024    DFSchema::new_with_metadata(union_qualified_fields, metadata)
1025}
1026
1027/// See `<https://github.com/apache/datafusion/pull/2108>`
1028fn project_with_column_index(
1029    expr: Vec<Expr>,
1030    input: Arc<LogicalPlan>,
1031    schema: DFSchemaRef,
1032) -> Result<LogicalPlan> {
1033    let alias_expr = expr
1034        .into_iter()
1035        .enumerate()
1036        .map(|(i, e)| match e {
1037            Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1038                Ok(e.unalias().alias(schema.field(i).name()))
1039            }
1040            Expr::Column(Column {
1041                relation: _,
1042                ref name,
1043                spans: _,
1044            }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1045            Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1046            #[expect(deprecated)]
1047            Expr::Wildcard { .. } => {
1048                plan_err!("Wildcard should be expanded before type coercion")
1049            }
1050            _ => Ok(e.alias(schema.field(i).name())),
1051        })
1052        .collect::<Result<Vec<_>>>()?;
1053
1054    Projection::try_new_with_schema(alias_expr, input, schema)
1055        .map(LogicalPlan::Projection)
1056}
1057
1058#[cfg(test)]
1059mod test {
1060    use std::any::Any;
1061    use std::sync::Arc;
1062
1063    use arrow::datatypes::DataType::Utf8;
1064    use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1065    use insta::assert_snapshot;
1066
1067    use crate::analyzer::type_coercion::{
1068        coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1069    };
1070    use crate::analyzer::Analyzer;
1071    use crate::assert_analyzed_plan_with_config_eq_snapshot;
1072    use datafusion_common::config::ConfigOptions;
1073    use datafusion_common::tree_node::{TransformedResult, TreeNode};
1074    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1075    use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1076    use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1077    use datafusion_expr::test::function_stub::avg_udaf;
1078    use datafusion_expr::{
1079        cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1080        BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1081        Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1082        SimpleAggregateUDF, Subquery, Union, Volatility,
1083    };
1084    use datafusion_functions_aggregate::average::AvgAccumulator;
1085    use datafusion_sql::TableReference;
1086
1087    fn empty() -> Arc<LogicalPlan> {
1088        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1089            produce_one_row: false,
1090            schema: Arc::new(DFSchema::empty()),
1091        }))
1092    }
1093
1094    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1095        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1096            produce_one_row: false,
1097            schema: Arc::new(
1098                DFSchema::from_unqualified_fields(
1099                    vec![Field::new("a", data_type, true)].into(),
1100                    std::collections::HashMap::new(),
1101                )
1102                .unwrap(),
1103            ),
1104        }))
1105    }
1106
1107    macro_rules! assert_analyzed_plan_eq {
1108        (
1109            $plan: expr,
1110            @ $expected: literal $(,)?
1111        ) => {{
1112            let options = ConfigOptions::default();
1113            let rule = Arc::new(TypeCoercion::new());
1114            assert_analyzed_plan_with_config_eq_snapshot!(
1115                options,
1116                rule,
1117                $plan,
1118                @ $expected,
1119            )
1120            }};
1121    }
1122
1123    macro_rules! coerce_on_output_if_viewtype {
1124        (
1125            $is_viewtype: expr,
1126            $plan: expr,
1127            @ $expected: literal $(,)?
1128        ) => {{
1129            let mut options = ConfigOptions::default();
1130            // coerce on output
1131            if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1132            let rule = Arc::new(TypeCoercion::new());
1133
1134            assert_analyzed_plan_with_config_eq_snapshot!(
1135                options,
1136                rule,
1137                $plan,
1138                @ $expected,
1139            )
1140        }};
1141    }
1142
1143    fn assert_type_coercion_error(
1144        plan: LogicalPlan,
1145        expected_substr: &str,
1146    ) -> Result<()> {
1147        let options = ConfigOptions::default();
1148        let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1149
1150        match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1151            Ok(succeeded_plan) => {
1152                panic!(
1153                    "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1154                );
1155            }
1156            Err(e) => {
1157                let msg = e.to_string();
1158                assert!(
1159                    msg.contains(expected_substr),
1160                    "Error did not contain expected substring.\n  expected to find: `{expected_substr}`\n  actual error: `{msg}`"
1161                );
1162            }
1163        }
1164
1165        Ok(())
1166    }
1167
1168    #[test]
1169    fn simple_case() -> Result<()> {
1170        let expr = col("a").lt(lit(2_u32));
1171        let empty = empty_with_type(DataType::Float64);
1172        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1173
1174        assert_analyzed_plan_eq!(
1175            plan,
1176            @r"
1177        Projection: a < CAST(UInt32(2) AS Float64)
1178          EmptyRelation
1179        "
1180        )
1181    }
1182
1183    #[test]
1184    fn test_coerce_union() -> Result<()> {
1185        let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1186            produce_one_row: false,
1187            schema: Arc::new(
1188                DFSchema::try_from_qualified_schema(
1189                    TableReference::full("datafusion", "test", "foo"),
1190                    &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1191                )
1192                .unwrap(),
1193            ),
1194        }));
1195        let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1196            produce_one_row: false,
1197            schema: Arc::new(
1198                DFSchema::try_from_qualified_schema(
1199                    TableReference::full("datafusion", "test", "foo"),
1200                    &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1201                )
1202                .unwrap(),
1203            ),
1204        }));
1205        let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1206            left_plan, right_plan,
1207        ])?);
1208        let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1209            .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1210        let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1211            vec![col("a")],
1212            Arc::new(analyzed_union),
1213        )?);
1214
1215        assert_analyzed_plan_eq!(
1216            top_level_plan,
1217            @r"
1218        Projection: a
1219          Union
1220            Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1221              EmptyRelation
1222            EmptyRelation
1223        "
1224        )
1225    }
1226
1227    #[test]
1228    fn coerce_utf8view_output() -> Result<()> {
1229        // Plan A
1230        // scenario: outermost utf8view projection
1231        let expr = col("a");
1232        let empty = empty_with_type(DataType::Utf8View);
1233        let plan = LogicalPlan::Projection(Projection::try_new(
1234            vec![expr.clone()],
1235            Arc::clone(&empty),
1236        )?);
1237
1238        // Plan A: no coerce
1239        coerce_on_output_if_viewtype!(
1240            false,
1241            plan.clone(),
1242            @r"
1243        Projection: a
1244          EmptyRelation
1245        "
1246        )?;
1247
1248        // Plan A: coerce requested: Utf8View => LargeUtf8
1249        coerce_on_output_if_viewtype!(
1250            true,
1251            plan.clone(),
1252            @r"
1253        Projection: CAST(a AS LargeUtf8)
1254          EmptyRelation
1255        "
1256        )?;
1257
1258        // Plan B
1259        // scenario: outermost bool projection
1260        let bool_expr = col("a").lt(lit("foo"));
1261        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1262            vec![bool_expr],
1263            Arc::clone(&empty),
1264        )?);
1265        // Plan B: no coerce
1266        coerce_on_output_if_viewtype!(
1267            false,
1268            bool_plan.clone(),
1269            @r#"
1270        Projection: a < CAST(Utf8("foo") AS Utf8View)
1271          EmptyRelation
1272        "#
1273        )?;
1274
1275        coerce_on_output_if_viewtype!(
1276            false,
1277            plan.clone(),
1278            @r"
1279        Projection: a
1280          EmptyRelation
1281        "
1282        )?;
1283
1284        // Plan B: coerce requested: no coercion applied
1285        coerce_on_output_if_viewtype!(
1286            true,
1287            plan.clone(),
1288            @r"
1289        Projection: CAST(a AS LargeUtf8)
1290          EmptyRelation
1291        "
1292        )?;
1293
1294        // Plan C
1295        // scenario: with a non-projection root logical plan node
1296        let sort_expr = expr.sort(true, true);
1297        let sort_plan = LogicalPlan::Sort(Sort {
1298            expr: vec![sort_expr],
1299            input: Arc::new(plan),
1300            fetch: None,
1301        });
1302
1303        // Plan C: no coerce
1304        coerce_on_output_if_viewtype!(
1305            false,
1306            sort_plan.clone(),
1307            @r"
1308        Sort: a ASC NULLS FIRST
1309          Projection: a
1310            EmptyRelation
1311        "
1312        )?;
1313
1314        // Plan C: coerce requested: Utf8View => LargeUtf8
1315        coerce_on_output_if_viewtype!(
1316            true,
1317            sort_plan.clone(),
1318            @r"
1319        Projection: CAST(a AS LargeUtf8)
1320          Sort: a ASC NULLS FIRST
1321            Projection: a
1322              EmptyRelation
1323        "
1324        )?;
1325
1326        // Plan D
1327        // scenario: two layers of projections with view types
1328        let plan = LogicalPlan::Projection(Projection::try_new(
1329            vec![col("a")],
1330            Arc::new(sort_plan),
1331        )?);
1332        // Plan D: no coerce
1333        coerce_on_output_if_viewtype!(
1334            false,
1335            plan.clone(),
1336            @r"
1337        Projection: a
1338          Sort: a ASC NULLS FIRST
1339            Projection: a
1340              EmptyRelation
1341        "
1342        )?;
1343        // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost
1344        coerce_on_output_if_viewtype!(
1345            true,
1346            plan.clone(),
1347            @r"
1348        Projection: CAST(a AS LargeUtf8)
1349          Sort: a ASC NULLS FIRST
1350            Projection: a
1351              EmptyRelation
1352        "
1353        )?;
1354
1355        Ok(())
1356    }
1357
1358    #[test]
1359    fn coerce_binaryview_output() -> Result<()> {
1360        // Plan A
1361        // scenario: outermost binaryview projection
1362        let expr = col("a");
1363        let empty = empty_with_type(DataType::BinaryView);
1364        let plan = LogicalPlan::Projection(Projection::try_new(
1365            vec![expr.clone()],
1366            Arc::clone(&empty),
1367        )?);
1368
1369        // Plan A: no coerce
1370        coerce_on_output_if_viewtype!(
1371            false,
1372            plan.clone(),
1373            @r"
1374        Projection: a
1375          EmptyRelation
1376        "
1377        )?;
1378
1379        // Plan A: coerce requested: BinaryView => LargeBinary
1380        coerce_on_output_if_viewtype!(
1381            true,
1382            plan.clone(),
1383            @r"
1384        Projection: CAST(a AS LargeBinary)
1385          EmptyRelation
1386        "
1387        )?;
1388
1389        // Plan B
1390        // scenario: outermost bool projection
1391        let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1392        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1393            vec![bool_expr],
1394            Arc::clone(&empty),
1395        )?);
1396
1397        // Plan B: no coerce
1398        coerce_on_output_if_viewtype!(
1399            false,
1400            bool_plan.clone(),
1401            @r#"
1402        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1403          EmptyRelation
1404        "#
1405        )?;
1406
1407        // Plan B: coerce requested: no coercion applied
1408        coerce_on_output_if_viewtype!(
1409            true,
1410            bool_plan.clone(),
1411            @r#"
1412        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1413          EmptyRelation
1414        "#
1415        )?;
1416
1417        // Plan C
1418        // scenario: with a non-projection root logical plan node
1419        let sort_expr = expr.sort(true, true);
1420        let sort_plan = LogicalPlan::Sort(Sort {
1421            expr: vec![sort_expr],
1422            input: Arc::new(plan),
1423            fetch: None,
1424        });
1425
1426        // Plan C: no coerce
1427        coerce_on_output_if_viewtype!(
1428            false,
1429            sort_plan.clone(),
1430            @r"
1431        Sort: a ASC NULLS FIRST
1432          Projection: a
1433            EmptyRelation
1434        "
1435        )?;
1436        // Plan C: coerce requested: BinaryView => LargeBinary
1437        coerce_on_output_if_viewtype!(
1438            true,
1439            sort_plan.clone(),
1440            @r"
1441        Projection: CAST(a AS LargeBinary)
1442          Sort: a ASC NULLS FIRST
1443            Projection: a
1444              EmptyRelation
1445        "
1446        )?;
1447
1448        // Plan D
1449        // scenario: two layers of projections with view types
1450        let plan = LogicalPlan::Projection(Projection::try_new(
1451            vec![col("a")],
1452            Arc::new(sort_plan),
1453        )?);
1454
1455        // Plan D: no coerce
1456        coerce_on_output_if_viewtype!(
1457            false,
1458            plan.clone(),
1459            @r"
1460        Projection: a
1461          Sort: a ASC NULLS FIRST
1462            Projection: a
1463              EmptyRelation
1464        "
1465        )?;
1466
1467        // Plan B: coerce requested: BinaryView => LargeBinary only on outermost
1468        coerce_on_output_if_viewtype!(
1469            true,
1470            plan.clone(),
1471            @r"
1472        Projection: CAST(a AS LargeBinary)
1473          Sort: a ASC NULLS FIRST
1474            Projection: a
1475              EmptyRelation
1476        "
1477        )?;
1478
1479        Ok(())
1480    }
1481
1482    #[test]
1483    fn nested_case() -> Result<()> {
1484        let expr = col("a").lt(lit(2_u32));
1485        let empty = empty_with_type(DataType::Float64);
1486
1487        let plan = LogicalPlan::Projection(Projection::try_new(
1488            vec![expr.clone().or(expr)],
1489            empty,
1490        )?);
1491
1492        assert_analyzed_plan_eq!(
1493            plan,
1494            @r"
1495        Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1496          EmptyRelation
1497        "
1498        )
1499    }
1500
1501    #[derive(Debug, Clone)]
1502    struct TestScalarUDF {
1503        signature: Signature,
1504    }
1505
1506    impl ScalarUDFImpl for TestScalarUDF {
1507        fn as_any(&self) -> &dyn Any {
1508            self
1509        }
1510
1511        fn name(&self) -> &str {
1512            "TestScalarUDF"
1513        }
1514
1515        fn signature(&self) -> &Signature {
1516            &self.signature
1517        }
1518
1519        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1520            Ok(Utf8)
1521        }
1522
1523        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1524            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1525        }
1526    }
1527
1528    #[test]
1529    fn scalar_udf() -> Result<()> {
1530        let empty = empty();
1531
1532        let udf = ScalarUDF::from(TestScalarUDF {
1533            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1534        })
1535        .call(vec![lit(123_i32)]);
1536        let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1537
1538        assert_analyzed_plan_eq!(
1539            plan,
1540            @r"
1541        Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1542          EmptyRelation
1543        "
1544        )
1545    }
1546
1547    #[test]
1548    fn scalar_udf_invalid_input() -> Result<()> {
1549        let empty = empty();
1550        let udf = ScalarUDF::from(TestScalarUDF {
1551            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1552        })
1553        .call(vec![lit("Apple")]);
1554        Projection::try_new(vec![udf], empty)
1555            .expect_err("Expected an error due to incorrect function input");
1556
1557        Ok(())
1558    }
1559
1560    #[test]
1561    fn scalar_function() -> Result<()> {
1562        // test that automatic argument type coercion for scalar functions work
1563        let empty = empty();
1564        let lit_expr = lit(10i64);
1565        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1566            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1567        });
1568        let scalar_function_expr =
1569            Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1570        let plan = LogicalPlan::Projection(Projection::try_new(
1571            vec![scalar_function_expr],
1572            empty,
1573        )?);
1574
1575        assert_analyzed_plan_eq!(
1576            plan,
1577            @r"
1578        Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1579          EmptyRelation
1580        "
1581        )
1582    }
1583
1584    #[test]
1585    fn agg_udaf() -> Result<()> {
1586        let empty = empty();
1587        let my_avg = create_udaf(
1588            "MY_AVG",
1589            vec![DataType::Float64],
1590            Arc::new(DataType::Float64),
1591            Volatility::Immutable,
1592            Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1593            Arc::new(vec![DataType::UInt64, DataType::Float64]),
1594        );
1595        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1596            Arc::new(my_avg),
1597            vec![lit(10i64)],
1598            false,
1599            None,
1600            vec![],
1601            None,
1602        ));
1603        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1604
1605        assert_analyzed_plan_eq!(
1606            plan,
1607            @r"
1608        Projection: MY_AVG(CAST(Int64(10) AS Float64))
1609          EmptyRelation
1610        "
1611        )
1612    }
1613
1614    #[test]
1615    fn agg_udaf_invalid_input() -> Result<()> {
1616        let empty = empty();
1617        let return_type = DataType::Float64;
1618        let accumulator: AccumulatorFactoryFunction =
1619            Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1620        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1621            "MY_AVG",
1622            Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1623            return_type,
1624            accumulator,
1625            vec![
1626                Field::new("count", DataType::UInt64, true).into(),
1627                Field::new("avg", DataType::Float64, true).into(),
1628            ],
1629        ));
1630        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1631            Arc::new(my_avg),
1632            vec![lit("10")],
1633            false,
1634            None,
1635            vec![],
1636            None,
1637        ));
1638
1639        let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1640        assert!(
1641            err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed")
1642        );
1643        Ok(())
1644    }
1645
1646    #[test]
1647    fn agg_function_case() -> Result<()> {
1648        let empty = empty();
1649        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1650            avg_udaf(),
1651            vec![lit(12f64)],
1652            false,
1653            None,
1654            vec![],
1655            None,
1656        ));
1657        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1658
1659        assert_analyzed_plan_eq!(
1660            plan,
1661            @r"
1662        Projection: avg(Float64(12))
1663          EmptyRelation
1664        "
1665        )?;
1666
1667        let empty = empty_with_type(DataType::Int32);
1668        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1669            avg_udaf(),
1670            vec![cast(col("a"), DataType::Float64)],
1671            false,
1672            None,
1673            vec![],
1674            None,
1675        ));
1676        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1677
1678        assert_analyzed_plan_eq!(
1679            plan,
1680            @r"
1681        Projection: avg(CAST(a AS Float64))
1682          EmptyRelation
1683        "
1684        )
1685    }
1686
1687    #[test]
1688    fn agg_function_invalid_input_avg() -> Result<()> {
1689        let empty = empty();
1690        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1691            avg_udaf(),
1692            vec![lit("1")],
1693            false,
1694            None,
1695            vec![],
1696            None,
1697        ));
1698        let err = Projection::try_new(vec![agg_expr], empty)
1699            .err()
1700            .unwrap()
1701            .strip_backtrace();
1702        assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1703        Ok(())
1704    }
1705
1706    #[test]
1707    fn binary_op_date32_op_interval() -> Result<()> {
1708        // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
1709        let expr = cast(lit("1998-03-18"), DataType::Date32)
1710            + lit(ScalarValue::new_interval_dt(123, 456));
1711        let empty = empty();
1712        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1713
1714        assert_analyzed_plan_eq!(
1715            plan,
1716            @r#"
1717        Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1718          EmptyRelation
1719        "#
1720        )
1721    }
1722
1723    #[test]
1724    fn inlist_case() -> Result<()> {
1725        // a in (1,4,8), a is int64
1726        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1727        let empty = empty_with_type(DataType::Int64);
1728        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1729        assert_analyzed_plan_eq!(
1730            plan,
1731            @r"
1732        Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1733          EmptyRelation
1734        ")?;
1735
1736        // a in (1,4,8), a is decimal
1737        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1738        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1739            produce_one_row: false,
1740            schema: Arc::new(DFSchema::from_unqualified_fields(
1741                vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1742                std::collections::HashMap::new(),
1743            )?),
1744        }));
1745        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1746        assert_analyzed_plan_eq!(
1747            plan,
1748            @r"
1749        Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1750          EmptyRelation
1751        ")
1752    }
1753
1754    #[test]
1755    fn between_case() -> Result<()> {
1756        let expr = col("a").between(
1757            lit("2002-05-08"),
1758            // (cast('2002-05-08' as date) + interval '1 months')
1759            cast(lit("2002-05-08"), DataType::Date32)
1760                + lit(ScalarValue::new_interval_ym(0, 1)),
1761        );
1762        let empty = empty_with_type(Utf8);
1763        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1764
1765        assert_analyzed_plan_eq!(
1766            plan,
1767            @r#"
1768        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1769          EmptyRelation
1770        "#
1771        )
1772    }
1773
1774    #[test]
1775    fn between_infer_cheap_type() -> Result<()> {
1776        let expr = col("a").between(
1777            // (cast('2002-05-08' as date) + interval '1 months')
1778            cast(lit("2002-05-08"), DataType::Date32)
1779                + lit(ScalarValue::new_interval_ym(0, 1)),
1780            lit("2002-12-08"),
1781        );
1782        let empty = empty_with_type(Utf8);
1783        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1784
1785        // TODO: we should cast col(a).
1786        assert_analyzed_plan_eq!(
1787            plan,
1788            @r#"
1789        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1790          EmptyRelation
1791        "#
1792        )
1793    }
1794
1795    #[test]
1796    fn between_null() -> Result<()> {
1797        let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1798        let empty = empty();
1799        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1800
1801        assert_analyzed_plan_eq!(
1802            plan,
1803            @r"
1804        Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1805          EmptyRelation
1806        "
1807        )
1808    }
1809
1810    #[test]
1811    fn is_bool_for_type_coercion() -> Result<()> {
1812        // is true
1813        let expr = col("a").is_true();
1814        let empty = empty_with_type(DataType::Boolean);
1815        let plan =
1816            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1817
1818        assert_analyzed_plan_eq!(
1819            plan,
1820            @r"
1821        Projection: a IS TRUE
1822          EmptyRelation
1823        "
1824        )?;
1825
1826        let empty = empty_with_type(DataType::Int64);
1827        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1828        assert_type_coercion_error(
1829            plan,
1830            "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
1831        )?;
1832
1833        // is not true
1834        let expr = col("a").is_not_true();
1835        let empty = empty_with_type(DataType::Boolean);
1836        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1837
1838        assert_analyzed_plan_eq!(
1839            plan,
1840            @r"
1841        Projection: a IS NOT TRUE
1842          EmptyRelation
1843        "
1844        )?;
1845
1846        // is false
1847        let expr = col("a").is_false();
1848        let empty = empty_with_type(DataType::Boolean);
1849        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1850
1851        assert_analyzed_plan_eq!(
1852            plan,
1853            @r"
1854        Projection: a IS FALSE
1855          EmptyRelation
1856        "
1857        )?;
1858
1859        // is not false
1860        let expr = col("a").is_not_false();
1861        let empty = empty_with_type(DataType::Boolean);
1862        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1863
1864        assert_analyzed_plan_eq!(
1865            plan,
1866            @r"
1867        Projection: a IS NOT FALSE
1868          EmptyRelation
1869        "
1870        )
1871    }
1872
1873    #[test]
1874    fn like_for_type_coercion() -> Result<()> {
1875        // like : utf8 like "abc"
1876        let expr = Box::new(col("a"));
1877        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1878        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1879        let empty = empty_with_type(Utf8);
1880        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1881
1882        assert_analyzed_plan_eq!(
1883            plan,
1884            @r#"
1885        Projection: a LIKE Utf8("abc")
1886          EmptyRelation
1887        "#
1888        )?;
1889
1890        let expr = Box::new(col("a"));
1891        let pattern = Box::new(lit(ScalarValue::Null));
1892        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1893        let empty = empty_with_type(Utf8);
1894        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1895
1896        assert_analyzed_plan_eq!(
1897            plan,
1898            @r"
1899        Projection: a LIKE CAST(NULL AS Utf8)
1900          EmptyRelation
1901        "
1902        )?;
1903
1904        let expr = Box::new(col("a"));
1905        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1906        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1907        let empty = empty_with_type(DataType::Int64);
1908        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1909        assert_type_coercion_error(
1910            plan,
1911            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
1912        )?;
1913
1914        // ilike
1915        let expr = Box::new(col("a"));
1916        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1917        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1918        let empty = empty_with_type(Utf8);
1919        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1920
1921        assert_analyzed_plan_eq!(
1922            plan,
1923            @r#"
1924        Projection: a ILIKE Utf8("abc")
1925          EmptyRelation
1926        "#
1927        )?;
1928
1929        let expr = Box::new(col("a"));
1930        let pattern = Box::new(lit(ScalarValue::Null));
1931        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1932        let empty = empty_with_type(Utf8);
1933        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1934
1935        assert_analyzed_plan_eq!(
1936            plan,
1937            @r"
1938        Projection: a ILIKE CAST(NULL AS Utf8)
1939          EmptyRelation
1940        "
1941        )?;
1942
1943        let expr = Box::new(col("a"));
1944        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1945        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1946        let empty = empty_with_type(DataType::Int64);
1947        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1948        assert_type_coercion_error(
1949            plan,
1950            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
1951        )?;
1952
1953        Ok(())
1954    }
1955
1956    #[test]
1957    fn unknown_for_type_coercion() -> Result<()> {
1958        // unknown
1959        let expr = col("a").is_unknown();
1960        let empty = empty_with_type(DataType::Boolean);
1961        let plan =
1962            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1963
1964        assert_analyzed_plan_eq!(
1965            plan,
1966            @r"
1967        Projection: a IS UNKNOWN
1968          EmptyRelation
1969        "
1970        )?;
1971
1972        let empty = empty_with_type(Utf8);
1973        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1974        assert_type_coercion_error(
1975            plan,
1976            "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
1977        )?;
1978
1979        // is not unknown
1980        let expr = col("a").is_not_unknown();
1981        let empty = empty_with_type(DataType::Boolean);
1982        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1983
1984        assert_analyzed_plan_eq!(
1985            plan,
1986            @r"
1987        Projection: a IS NOT UNKNOWN
1988          EmptyRelation
1989        "
1990        )
1991    }
1992
1993    #[test]
1994    fn concat_for_type_coercion() -> Result<()> {
1995        let empty = empty_with_type(Utf8);
1996        let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
1997
1998        // concat-type signature
1999        let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2000            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2001        })
2002        .call(args.to_vec());
2003        let plan =
2004            LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2005        assert_analyzed_plan_eq!(
2006            plan,
2007            @r#"
2008        Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2009          EmptyRelation
2010        "#
2011        )
2012    }
2013
2014    #[test]
2015    fn test_type_coercion_rewrite() -> Result<()> {
2016        // gt
2017        let schema = Arc::new(DFSchema::from_unqualified_fields(
2018            vec![Field::new("a", DataType::Int64, true)].into(),
2019            std::collections::HashMap::new(),
2020        )?);
2021        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2022        let expr = is_true(lit(12i32).gt(lit(13i64)));
2023        let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2024        let result = expr.rewrite(&mut rewriter).data()?;
2025        assert_eq!(expected, result);
2026
2027        // eq
2028        let schema = Arc::new(DFSchema::from_unqualified_fields(
2029            vec![Field::new("a", DataType::Int64, true)].into(),
2030            std::collections::HashMap::new(),
2031        )?);
2032        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2033        let expr = is_true(lit(12i32).eq(lit(13i64)));
2034        let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2035        let result = expr.rewrite(&mut rewriter).data()?;
2036        assert_eq!(expected, result);
2037
2038        // lt
2039        let schema = Arc::new(DFSchema::from_unqualified_fields(
2040            vec![Field::new("a", DataType::Int64, true)].into(),
2041            std::collections::HashMap::new(),
2042        )?);
2043        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2044        let expr = is_true(lit(12i32).lt(lit(13i64)));
2045        let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2046        let result = expr.rewrite(&mut rewriter).data()?;
2047        assert_eq!(expected, result);
2048
2049        Ok(())
2050    }
2051
2052    #[test]
2053    fn binary_op_date32_eq_ts() -> Result<()> {
2054        let expr = cast(
2055            lit("1998-03-18"),
2056            DataType::Timestamp(TimeUnit::Nanosecond, None),
2057        )
2058        .eq(cast(lit("1998-03-18"), DataType::Date32));
2059        let empty = empty();
2060        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2061
2062        assert_analyzed_plan_eq!(
2063            plan,
2064            @r#"
2065        Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None))
2066          EmptyRelation
2067        "#
2068        )
2069    }
2070
2071    fn cast_if_not_same_type(
2072        expr: Box<Expr>,
2073        data_type: &DataType,
2074        schema: &DFSchemaRef,
2075    ) -> Box<Expr> {
2076        if &expr.get_type(schema).unwrap() != data_type {
2077            Box::new(cast(*expr, data_type.clone()))
2078        } else {
2079            expr
2080        }
2081    }
2082
2083    fn cast_helper(
2084        case: Case,
2085        case_when_type: &DataType,
2086        then_else_type: &DataType,
2087        schema: &DFSchemaRef,
2088    ) -> Case {
2089        let expr = case
2090            .expr
2091            .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2092        let when_then_expr = case
2093            .when_then_expr
2094            .into_iter()
2095            .map(|(when, then)| {
2096                (
2097                    cast_if_not_same_type(when, case_when_type, schema),
2098                    cast_if_not_same_type(then, then_else_type, schema),
2099                )
2100            })
2101            .collect::<Vec<_>>();
2102        let else_expr = case
2103            .else_expr
2104            .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2105
2106        Case {
2107            expr,
2108            when_then_expr,
2109            else_expr,
2110        }
2111    }
2112
2113    #[test]
2114    fn test_case_expression_coercion() -> Result<()> {
2115        let schema = Arc::new(DFSchema::from_unqualified_fields(
2116            vec![
2117                Field::new("boolean", DataType::Boolean, true),
2118                Field::new("integer", DataType::Int32, true),
2119                Field::new("float", DataType::Float32, true),
2120                Field::new(
2121                    "timestamp",
2122                    DataType::Timestamp(TimeUnit::Nanosecond, None),
2123                    true,
2124                ),
2125                Field::new("date", DataType::Date32, true),
2126                Field::new(
2127                    "interval",
2128                    DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2129                    true,
2130                ),
2131                Field::new("binary", DataType::Binary, true),
2132                Field::new("string", Utf8, true),
2133                Field::new("decimal", DataType::Decimal128(10, 10), true),
2134            ]
2135            .into(),
2136            std::collections::HashMap::new(),
2137        )?);
2138
2139        let case = Case {
2140            expr: None,
2141            when_then_expr: vec![
2142                (Box::new(col("boolean")), Box::new(col("integer"))),
2143                (Box::new(col("integer")), Box::new(col("float"))),
2144                (Box::new(col("string")), Box::new(col("string"))),
2145            ],
2146            else_expr: None,
2147        };
2148        let case_when_common_type = DataType::Boolean;
2149        let then_else_common_type = Utf8;
2150        let expected = cast_helper(
2151            case.clone(),
2152            &case_when_common_type,
2153            &then_else_common_type,
2154            &schema,
2155        );
2156        let actual = coerce_case_expression(case, &schema)?;
2157        assert_eq!(expected, actual);
2158
2159        let case = Case {
2160            expr: Some(Box::new(col("string"))),
2161            when_then_expr: vec![
2162                (Box::new(col("float")), Box::new(col("integer"))),
2163                (Box::new(col("integer")), Box::new(col("float"))),
2164                (Box::new(col("string")), Box::new(col("string"))),
2165            ],
2166            else_expr: Some(Box::new(col("string"))),
2167        };
2168        let case_when_common_type = Utf8;
2169        let then_else_common_type = Utf8;
2170        let expected = cast_helper(
2171            case.clone(),
2172            &case_when_common_type,
2173            &then_else_common_type,
2174            &schema,
2175        );
2176        let actual = coerce_case_expression(case, &schema)?;
2177        assert_eq!(expected, actual);
2178
2179        let case = Case {
2180            expr: Some(Box::new(col("interval"))),
2181            when_then_expr: vec![
2182                (Box::new(col("float")), Box::new(col("integer"))),
2183                (Box::new(col("binary")), Box::new(col("float"))),
2184                (Box::new(col("string")), Box::new(col("string"))),
2185            ],
2186            else_expr: Some(Box::new(col("string"))),
2187        };
2188        let err = coerce_case_expression(case, &schema).unwrap_err();
2189        assert_snapshot!(
2190            err.strip_backtrace(),
2191            @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when ([Float32, Binary, Utf8]) to common types in CASE WHEN expression"
2192        );
2193
2194        let case = Case {
2195            expr: Some(Box::new(col("string"))),
2196            when_then_expr: vec![
2197                (Box::new(col("float")), Box::new(col("date"))),
2198                (Box::new(col("string")), Box::new(col("float"))),
2199                (Box::new(col("string")), Box::new(col("binary"))),
2200            ],
2201            else_expr: Some(Box::new(col("timestamp"))),
2202        };
2203        let err = coerce_case_expression(case, &schema).unwrap_err();
2204        assert_snapshot!(
2205            err.strip_backtrace(),
2206            @"Error during planning: Failed to coerce then ([Date32, Float32, Binary]) and else (Some(Timestamp(Nanosecond, None))) to common types in CASE WHEN expression"
2207        );
2208
2209        Ok(())
2210    }
2211
2212    macro_rules! test_case_expression {
2213        ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2214            let case = Case {
2215                expr: $expr.map(|e| Box::new(col(e))),
2216                when_then_expr: $when_then,
2217                else_expr: None,
2218            };
2219
2220            let expected =
2221                cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2222
2223            let actual = coerce_case_expression(case, &$schema)?;
2224            assert_eq!(expected, actual);
2225        };
2226    }
2227
2228    #[test]
2229    fn tes_case_when_list() -> Result<()> {
2230        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2231        let schema = Arc::new(DFSchema::from_unqualified_fields(
2232            vec![
2233                Field::new(
2234                    "large_list",
2235                    DataType::LargeList(Arc::clone(&inner_field)),
2236                    true,
2237                ),
2238                Field::new(
2239                    "fixed_list",
2240                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2241                    true,
2242                ),
2243                Field::new("list", DataType::List(inner_field), true),
2244            ]
2245            .into(),
2246            std::collections::HashMap::new(),
2247        )?);
2248
2249        test_case_expression!(
2250            Some("list"),
2251            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2252            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2253            Utf8,
2254            schema
2255        );
2256
2257        test_case_expression!(
2258            Some("large_list"),
2259            vec![(Box::new(col("list")), Box::new(lit("1")))],
2260            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2261            Utf8,
2262            schema
2263        );
2264
2265        test_case_expression!(
2266            Some("list"),
2267            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2268            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2269            Utf8,
2270            schema
2271        );
2272
2273        test_case_expression!(
2274            Some("fixed_list"),
2275            vec![(Box::new(col("list")), Box::new(lit("1")))],
2276            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2277            Utf8,
2278            schema
2279        );
2280
2281        test_case_expression!(
2282            Some("fixed_list"),
2283            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2284            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2285            Utf8,
2286            schema
2287        );
2288
2289        test_case_expression!(
2290            Some("large_list"),
2291            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2292            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2293            Utf8,
2294            schema
2295        );
2296        Ok(())
2297    }
2298
2299    #[test]
2300    fn test_then_else_list() -> Result<()> {
2301        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2302        let schema = Arc::new(DFSchema::from_unqualified_fields(
2303            vec![
2304                Field::new("boolean", DataType::Boolean, true),
2305                Field::new(
2306                    "large_list",
2307                    DataType::LargeList(Arc::clone(&inner_field)),
2308                    true,
2309                ),
2310                Field::new(
2311                    "fixed_list",
2312                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2313                    true,
2314                ),
2315                Field::new("list", DataType::List(inner_field), true),
2316            ]
2317            .into(),
2318            std::collections::HashMap::new(),
2319        )?);
2320
2321        // large list and list
2322        test_case_expression!(
2323            None::<String>,
2324            vec![
2325                (Box::new(col("boolean")), Box::new(col("large_list"))),
2326                (Box::new(col("boolean")), Box::new(col("list")))
2327            ],
2328            DataType::Boolean,
2329            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2330            schema
2331        );
2332
2333        test_case_expression!(
2334            None::<String>,
2335            vec![
2336                (Box::new(col("boolean")), Box::new(col("list"))),
2337                (Box::new(col("boolean")), Box::new(col("large_list")))
2338            ],
2339            DataType::Boolean,
2340            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2341            schema
2342        );
2343
2344        // fixed list and list
2345        test_case_expression!(
2346            None::<String>,
2347            vec![
2348                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2349                (Box::new(col("boolean")), Box::new(col("list")))
2350            ],
2351            DataType::Boolean,
2352            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2353            schema
2354        );
2355
2356        test_case_expression!(
2357            None::<String>,
2358            vec![
2359                (Box::new(col("boolean")), Box::new(col("list"))),
2360                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2361            ],
2362            DataType::Boolean,
2363            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2364            schema
2365        );
2366
2367        // fixed list and large list
2368        test_case_expression!(
2369            None::<String>,
2370            vec![
2371                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2372                (Box::new(col("boolean")), Box::new(col("large_list")))
2373            ],
2374            DataType::Boolean,
2375            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2376            schema
2377        );
2378
2379        test_case_expression!(
2380            None::<String>,
2381            vec![
2382                (Box::new(col("boolean")), Box::new(col("large_list"))),
2383                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2384            ],
2385            DataType::Boolean,
2386            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2387            schema
2388        );
2389        Ok(())
2390    }
2391
2392    #[test]
2393    fn test_map_with_diff_name() -> Result<()> {
2394        let mut builder = SchemaBuilder::new();
2395        builder.push(Field::new("key", Utf8, false));
2396        builder.push(Field::new("value", DataType::Float64, true));
2397        let struct_fields = builder.finish().fields;
2398
2399        let fields =
2400            Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2401        let map_type_entries = DataType::Map(Arc::new(fields), false);
2402
2403        let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2404        let may_type_cutsom = DataType::Map(Arc::new(fields), false);
2405
2406        let expr = col("a").eq(cast(col("a"), may_type_cutsom));
2407        let empty = empty_with_type(map_type_entries);
2408        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2409
2410        assert_analyzed_plan_eq!(
2411            plan,
2412            @r#"
2413        Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))
2414          EmptyRelation
2415        "#
2416        )
2417    }
2418
2419    #[test]
2420    fn interval_plus_timestamp() -> Result<()> {
2421        // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
2422        let expr = Expr::BinaryExpr(BinaryExpr::new(
2423            Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2424            Operator::Plus,
2425            Box::new(cast(
2426                lit("2000-01-01T00:00:00"),
2427                DataType::Timestamp(TimeUnit::Nanosecond, None),
2428            )),
2429        ));
2430        let empty = empty();
2431        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2432
2433        assert_analyzed_plan_eq!(
2434            plan,
2435            @r#"
2436        Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None))
2437          EmptyRelation
2438        "#
2439        )
2440    }
2441
2442    #[test]
2443    fn timestamp_subtract_timestamp() -> Result<()> {
2444        let expr = Expr::BinaryExpr(BinaryExpr::new(
2445            Box::new(cast(
2446                lit("1998-03-18"),
2447                DataType::Timestamp(TimeUnit::Nanosecond, None),
2448            )),
2449            Operator::Minus,
2450            Box::new(cast(
2451                lit("1998-03-18"),
2452                DataType::Timestamp(TimeUnit::Nanosecond, None),
2453            )),
2454        ));
2455        let empty = empty();
2456        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2457
2458        assert_analyzed_plan_eq!(
2459            plan,
2460            @r#"
2461        Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None))
2462          EmptyRelation
2463        "#
2464        )
2465    }
2466
2467    #[test]
2468    fn in_subquery_cast_subquery() -> Result<()> {
2469        let empty_int32 = empty_with_type(DataType::Int32);
2470        let empty_int64 = empty_with_type(DataType::Int64);
2471
2472        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2473            Box::new(col("a")),
2474            Subquery {
2475                subquery: empty_int32,
2476                outer_ref_columns: vec![],
2477                spans: Spans::new(),
2478            },
2479            false,
2480        ));
2481        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2482        // add cast for subquery
2483
2484        assert_analyzed_plan_eq!(
2485            plan,
2486            @r"
2487        Filter: a IN (<subquery>)
2488          Subquery:
2489            Projection: CAST(a AS Int64)
2490              EmptyRelation
2491          EmptyRelation
2492        "
2493        )
2494    }
2495
2496    #[test]
2497    fn in_subquery_cast_expr() -> Result<()> {
2498        let empty_int32 = empty_with_type(DataType::Int32);
2499        let empty_int64 = empty_with_type(DataType::Int64);
2500
2501        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2502            Box::new(col("a")),
2503            Subquery {
2504                subquery: empty_int64,
2505                outer_ref_columns: vec![],
2506                spans: Spans::new(),
2507            },
2508            false,
2509        ));
2510        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2511
2512        // add cast for subquery
2513        assert_analyzed_plan_eq!(
2514            plan,
2515            @r"
2516        Filter: CAST(a AS Int64) IN (<subquery>)
2517          Subquery:
2518            EmptyRelation
2519          EmptyRelation
2520        "
2521        )
2522    }
2523
2524    #[test]
2525    fn in_subquery_cast_all() -> Result<()> {
2526        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2527        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2528
2529        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2530            Box::new(col("a")),
2531            Subquery {
2532                subquery: empty_inside,
2533                outer_ref_columns: vec![],
2534                spans: Spans::new(),
2535            },
2536            false,
2537        ));
2538        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2539
2540        // add cast for subquery
2541        assert_analyzed_plan_eq!(
2542            plan,
2543            @r"
2544        Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
2545          Subquery:
2546            Projection: CAST(a AS Decimal128(13, 8))
2547              EmptyRelation
2548          EmptyRelation
2549        "
2550        )
2551    }
2552}