1use arrow::compute::can_cast_types;
21use datafusion_expr::binary::BinaryTypeCoercer;
22use itertools::{Itertools as _, izip};
23use std::sync::Arc;
24
25use crate::analyzer::AnalyzerRule;
26use crate::utils::NamePreserver;
27
28use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
29use arrow::temporal_conversions::SECONDS_IN_DAY;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
32use datafusion_common::{
33 Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
35 plan_err,
36};
37use datafusion_expr::expr::{
38 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
39 InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction,
40};
41use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
42use datafusion_expr::expr_schema::cast_subquery;
43use datafusion_expr::logical_plan::Subquery;
44use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
45use datafusion_expr::type_coercion::functions::{UDFCoercionExt, fields_with_udf};
46use datafusion_expr::type_coercion::is_datetime;
47use datafusion_expr::type_coercion::other::{
48 get_coerce_type_for_case_expression, get_coerce_type_for_list,
49};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52 Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union,
53 WindowFrame, WindowFrameBound, WindowFrameUnits, is_false, is_not_false, is_not_true,
54 is_not_unknown, is_true, is_unknown, lit, not,
55};
56
/// Analyzer rule that inserts the casts needed so that operands of
/// expressions, function arguments, join keys, UNION inputs and LIMIT/OFFSET
/// values end up with compatible types.
#[derive(Debug, Default)]
pub struct TypeCoercion {}
61
62impl TypeCoercion {
63 pub fn new() -> Self {
64 Self {}
65 }
66}
67
68fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
70 if !config.optimizer.expand_views_at_output {
71 return Ok(plan);
72 }
73
74 let outer_refs = plan.expressions();
75 if outer_refs.is_empty() {
76 return Ok(plan);
77 }
78
79 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
80 coerce_plan_expr_for_schema(plan, &dfschema?)
81 } else {
82 Ok(plan)
83 }
84}
85
86impl AnalyzerRule for TypeCoercion {
87 fn name(&self) -> &str {
88 "type_coercion"
89 }
90
91 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
92 let empty_schema = DFSchema::empty();
93
94 let transformed_plan = plan
96 .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
97 .data;
98
99 coerce_output(transformed_plan, config)
101 }
102}
103
104fn analyze_internal(
108 external_schema: &DFSchema,
109 plan: LogicalPlan,
110) -> Result<Transformed<LogicalPlan>> {
111 let mut schema = merge_schema(&plan.inputs());
114
115 if let LogicalPlan::TableScan(ts) = &plan {
116 let source_schema = DFSchema::try_from_qualified_schema(
117 ts.table_name.clone(),
118 &ts.source.schema(),
119 )?;
120 schema.merge(&source_schema);
121 }
122
123 schema.merge(external_schema);
127
128 let plan = if let LogicalPlan::Filter(mut filter) = plan {
130 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
131 LogicalPlan::Filter(filter)
132 } else {
133 plan
134 };
135
136 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
137
138 let name_preserver = NamePreserver::new(&plan);
139 plan.map_expressions(|expr| {
141 let original_name = name_preserver.save(&expr);
142 expr.rewrite(&mut expr_rewrite)
143 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
144 })?
145 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
147 .map_data(|plan| plan.recompute_schema())
149}
150
/// Expression rewriter that inserts casts so operands have compatible types.
///
/// All expression types are resolved against a single schema.
pub struct TypeCoercionRewriter<'a> {
    /// Schema used to type every expression visited by this rewriter.
    pub(crate) schema: &'a DFSchema,
}
155
impl<'a> TypeCoercionRewriter<'a> {
    /// Creates a rewriter that types expressions against `schema`.
    pub fn new(schema: &'a DFSchema) -> Self {
        Self { schema }
    }

    /// Applies coercions that belong to the plan node itself rather than to
    /// its individual expressions.
    ///
    /// `Join` keys and filter, `Union` inputs, and `Limit` fetch/skip are
    /// coerced. Every other node is returned unchanged.
    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
        match plan {
            LogicalPlan::Join(join) => self.coerce_join(join),
            LogicalPlan::Union(union) => Self::coerce_union(union),
            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
            _ => Ok(plan),
        }
    }

    /// Coerces a join's equi-join key pairs and its optional filter.
    ///
    /// Each `(lhs, rhs)` pair in `join.on` is coerced as the comparison
    /// `lhs = rhs`. The filter, if present, is checked by
    /// `coerce_join_filter`.
    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
        join.on = join
            .on
            .into_iter()
            .map(|(lhs, rhs)| {
                // The two key sides come from different inputs, so each side
                // is typed against its own input's schema.
                let left_schema = join.left.schema();
                let right_schema = join.right.schema();
                let (lhs, rhs) = self.coerce_binary_op(
                    lhs,
                    left_schema,
                    Operator::Eq,
                    rhs,
                    right_schema,
                )?;
                Ok((lhs, rhs))
            })
            .collect::<Result<Vec<_>>>()?;

        join.filter = join
            .filter
            .map(|expr| self.coerce_join_filter(expr))
            .transpose()?;

        Ok(LogicalPlan::Join(join))
    }

    /// Coerces every input of a UNION to one shared schema.
    ///
    /// The target schema comes from `coerce_union_schema_with_schema`, seeded
    /// with the union's current schema. Each input's expressions are coerced
    /// to it with `coerce_plan_expr_for_schema`. Inputs that end up as a
    /// `Projection` are rebuilt by `project_with_column_index`, so their
    /// output column names match the union schema.
    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
        let union_schema = Arc::new(coerce_union_schema_with_schema(
            &union_plan.inputs,
            &union_plan.schema,
        )?);
        let new_inputs = union_plan
            .inputs
            .into_iter()
            .map(|p| {
                let plan =
                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
                match plan {
                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
                        Ok(Arc::new(project_with_column_index(
                            expr,
                            input,
                            Arc::clone(&union_schema),
                        )?))
                    }
                    other_plan => Ok(Arc::new(other_plan)),
                }
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(LogicalPlan::Union(Union {
            inputs: new_inputs,
            schema: union_schema,
        }))
    }

    /// Casts the LIMIT (`fetch`) and OFFSET (`skip`) expressions to `Int64`.
    ///
    /// Both are typed against an empty schema (no columns in scope). Each must
    /// be an integer type or `Null`; anything else is a plan error.
    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
        /// Casts one LIMIT/OFFSET expression to `Int64`. `expr_name` is used
        /// only in the error message.
        fn coerce_limit_expr(
            expr: Expr,
            schema: &DFSchema,
            expr_name: &str,
        ) -> Result<Expr> {
            let dt = expr.get_type(schema)?;
            if dt.is_integer() || dt.is_null() {
                expr.cast_to(&DataType::Int64, schema)
            } else {
                plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
            }
        }

        let empty_schema = DFSchema::empty();
        let new_fetch = limit
            .fetch
            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
            .transpose()?;
        let new_skip = limit
            .skip
            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
            .transpose()?;
        Ok(LogicalPlan::Limit(Limit {
            input: limit.input,
            fetch: new_fetch.map(Box::new),
            skip: new_skip.map(Box::new),
        }))
    }

    /// Validates a join filter's type against `self.schema`.
    ///
    /// A `Boolean` filter is kept as is, and a `Null` filter is cast to
    /// `Boolean`. Any other type is a plan error.
    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
        let expr_type = expr.get_type(self.schema)?;
        match expr_type {
            DataType::Boolean => Ok(expr),
            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
        }
    }

    /// Coerces both operands of `left op right`.
    ///
    /// `BinaryTypeCoercer` picks the target input type for each side, with
    /// `left` typed against `left_schema` and `right` against `right_schema`.
    /// An operand whose current type arrow can cast to its target
    /// (`can_cast_types`) is cast with `cast_to`. Otherwise the conversion is
    /// delegated to `coerce_date_time_math_op`, which handles a few date/time
    /// arithmetic cases and returns a plan error for everything else.
    fn coerce_binary_op(
        &self,
        left: Expr,
        left_schema: &DFSchema,
        op: Operator,
        right: Expr,
        right_schema: &DFSchema,
    ) -> Result<(Expr, Expr)> {
        let left_data_type = left.get_type(left_schema)?;
        let right_data_type = right.get_type(right_schema)?;
        let (left_type, right_type) =
            BinaryTypeCoercer::new(&left_data_type, &op, &right_data_type)
                .get_input_types()?;
        let left_cast_ok = can_cast_types(&left_data_type, &left_type);
        let right_cast_ok = can_cast_types(&right_data_type, &right_type);

        let left_expr = if !left_cast_ok {
            // No direct arrow cast exists: try the date/time arithmetic
            // conversions, passing the other side's target type as context.
            Self::coerce_date_time_math_op(
                left,
                &op,
                &left_data_type,
                &left_type,
                &right_type,
            )?
        } else {
            left.cast_to(&left_type, left_schema)?
        };

        let right_expr = if !right_cast_ok {
            // Same fallback for the right side; note the argument order: this
            // side's current/target types first, then the left target type.
            Self::coerce_date_time_math_op(
                right,
                &op,
                &right_data_type,
                &right_type,
                &left_type,
            )?
        } else {
            right.cast_to(&right_type, right_schema)?
        };

        Ok((left_expr, right_expr))
    }

    /// Builds an explicit conversion for an operand that arrow cannot cast
    /// directly to its target type in date/time arithmetic.
    ///
    /// Despite the `left_` prefix, `left_current_type` and `left_target_type`
    /// describe whichever operand is being converted (`expr`).
    /// `right_target_type` is the other operand's target type (see the two
    /// calls in `coerce_binary_op`).
    ///
    /// Handled cases (all for `+` / `-`):
    /// * integer operand, target `Interval(MonthDayNano)`, other side a date:
    ///   the integer is multiplied by `SECONDS_IN_DAY`, i.e. treated as a
    ///   number of days, then cast to `Duration(Second)` and to the interval;
    /// * `Time32`/`Time64` operand, target `Duration(Nanosecond)`, other side
    ///   `Timestamp(Nanosecond, None)`: the time is converted to nanoseconds;
    /// * `Time32`/`Time64` operand, target and other side both
    ///   `Interval(MonthDayNano)`: nanoseconds -> `Duration` -> interval.
    ///
    /// Any other combination is a plan error.
    fn coerce_date_time_math_op(
        expr: Expr,
        op: &Operator,
        left_current_type: &DataType,
        left_target_type: &DataType,
        right_target_type: &DataType,
    ) -> Result<Expr, DataFusionError> {
        use DataType::*;

        /// Wraps `expr` in an explicit `Cast` to `target_type`.
        fn cast(expr: Expr, target_type: DataType) -> Expr {
            Expr::Cast(Cast::new(Box::new(expr), target_type))
        }

        /// Converts a `Time32`/`Time64` expression to an `Int64` count of
        /// nanoseconds, scaling by the unit's factor. `Time32` values go
        /// through `Int32` first (presumably because arrow casts `Time32` only
        /// to its `Int32` storage type — confirm).
        fn time_to_nanos(
            expr: Expr,
            expr_type: &DataType,
        ) -> Result<Expr, DataFusionError> {
            let expr = match expr_type {
                Time32(TimeUnit::Second) => {
                    cast(cast(expr, Int32), Int64)
                        * lit(ScalarValue::Int64(Some(1_000_000_000)))
                }
                Time32(TimeUnit::Millisecond) => {
                    cast(cast(expr, Int32), Int64)
                        * lit(ScalarValue::Int64(Some(1_000_000)))
                }
                Time64(TimeUnit::Microsecond) => {
                    cast(expr, Int64) * lit(ScalarValue::Int64(Some(1_000)))
                }
                Time64(TimeUnit::Nanosecond) => cast(expr, Int64),
                t => return internal_err!("Unexpected time data type {t}"),
            };

            Ok(expr)
        }

        let e = match (
            &op,
            &left_current_type,
            &left_target_type,
            &right_target_type,
        ) {
            (
                // integer +/- date: the integer becomes an interval of days.
                Operator::Plus | Operator::Minus,
                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
                Interval(IntervalUnit::MonthDayNano),
                Date32 | Date64,
            ) => {
                // Widen to Int64 before scaling to seconds.
                let expr = match *left_current_type {
                    Int64 => expr,
                    _ => cast(expr, Int64),
                };
                // days -> seconds
                let expr = expr * lit(ScalarValue::from(SECONDS_IN_DAY));
                // seconds -> Duration(Second)
                let expr = cast(expr, Duration(TimeUnit::Second));
                // Duration -> Interval(MonthDayNano)
                cast(expr, Interval(IntervalUnit::MonthDayNano))
            }
            (
                // time +/- Timestamp(ns): the time becomes a Duration(ns).
                Operator::Plus | Operator::Minus,
                Time32(_) | Time64(_),
                Duration(TimeUnit::Nanosecond),
                Timestamp(TimeUnit::Nanosecond, None),
            ) => {
                let expr = time_to_nanos(expr, left_current_type)?;
                // nanoseconds -> Duration(Nanosecond)
                cast(expr, Duration(TimeUnit::Nanosecond))
            }
            (
                // time +/- interval: the time becomes an Interval(MonthDayNano).
                Operator::Plus | Operator::Minus,
                Time32(_) | Time64(_),
                Interval(IntervalUnit::MonthDayNano),
                Interval(IntervalUnit::MonthDayNano),
            ) => {
                let expr = time_to_nanos(expr, left_current_type)?;
                // nanoseconds -> Duration(Nanosecond)
                let expr = cast(expr, Duration(TimeUnit::Nanosecond));
                // Duration -> Interval(MonthDayNano)
                cast(expr, Interval(IntervalUnit::MonthDayNano))
            }
            _ => {
                return plan_err!(
                    "Cannot automatically convert {left_current_type} to {left_target_type}"
                );
            }
        };

        Ok(e)
    }
}
438
439impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
440 type Node = Expr;
441
442 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
443 match expr {
444 Expr::Unnest(_) => not_impl_err!(
445 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
446 ),
447 Expr::ScalarSubquery(Subquery {
448 subquery,
449 outer_ref_columns,
450 spans,
451 }) => {
452 let new_plan =
453 analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
454 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
455 subquery: Arc::new(new_plan),
456 outer_ref_columns,
457 spans,
458 })))
459 }
460 Expr::Exists(Exists { subquery, negated }) => {
461 let new_plan = analyze_internal(
462 self.schema,
463 Arc::unwrap_or_clone(subquery.subquery),
464 )?
465 .data;
466 Ok(Transformed::yes(Expr::Exists(Exists {
467 subquery: Subquery {
468 subquery: Arc::new(new_plan),
469 outer_ref_columns: subquery.outer_ref_columns,
470 spans: subquery.spans,
471 },
472 negated,
473 })))
474 }
475 Expr::InSubquery(InSubquery {
476 expr,
477 subquery,
478 negated,
479 }) => {
480 let new_plan = analyze_internal(
481 self.schema,
482 Arc::unwrap_or_clone(subquery.subquery),
483 )?
484 .data;
485 let expr_type = expr.get_type(self.schema)?;
486 let subquery_type = new_plan.schema().field(0).data_type();
487 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
488 plan_datafusion_err!(
489 "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
490 ),
491 )?;
492 let new_subquery = Subquery {
493 subquery: Arc::new(new_plan),
494 outer_ref_columns: subquery.outer_ref_columns,
495 spans: subquery.spans,
496 };
497 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
498 Box::new(expr.cast_to(&common_type, self.schema)?),
499 cast_subquery(new_subquery, &common_type)?,
500 negated,
501 ))))
502 }
503 Expr::SetComparison(SetComparison {
504 expr,
505 subquery,
506 op,
507 quantifier,
508 }) => {
509 let new_plan = analyze_internal(
510 self.schema,
511 Arc::unwrap_or_clone(subquery.subquery),
512 )?
513 .data;
514 let expr_type = expr.get_type(self.schema)?;
515 let subquery_type = new_plan.schema().field(0).data_type();
516 if (expr_type.is_numeric() && subquery_type.is_string())
517 || (subquery_type.is_numeric() && expr_type.is_string())
518 {
519 return plan_err!(
520 "expr type {expr_type} can't cast to {subquery_type} in SetComparison"
521 );
522 }
523 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
524 plan_datafusion_err!(
525 "expr type {expr_type} can't cast to {subquery_type} in SetComparison"
526 ),
527 )?;
528 let new_subquery = Subquery {
529 subquery: Arc::new(new_plan),
530 outer_ref_columns: subquery.outer_ref_columns,
531 spans: subquery.spans,
532 };
533 Ok(Transformed::yes(Expr::SetComparison(SetComparison::new(
534 Box::new(expr.cast_to(&common_type, self.schema)?),
535 cast_subquery(new_subquery, &common_type)?,
536 op,
537 quantifier,
538 ))))
539 }
540 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
541 *expr,
542 self.schema,
543 )?))),
544 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
545 get_casted_expr_for_bool_op(*expr, self.schema)?,
546 ))),
547 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
548 get_casted_expr_for_bool_op(*expr, self.schema)?,
549 ))),
550 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
551 get_casted_expr_for_bool_op(*expr, self.schema)?,
552 ))),
553 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
554 get_casted_expr_for_bool_op(*expr, self.schema)?,
555 ))),
556 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
557 get_casted_expr_for_bool_op(*expr, self.schema)?,
558 ))),
559 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
560 get_casted_expr_for_bool_op(*expr, self.schema)?,
561 ))),
562 Expr::Like(Like {
563 negated,
564 expr,
565 pattern,
566 escape_char,
567 case_insensitive,
568 }) => {
569 let left_type = expr.get_type(self.schema)?;
570 let right_type = pattern.get_type(self.schema)?;
571 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
572 let op_name = if case_insensitive {
573 "ILIKE"
574 } else {
575 "LIKE"
576 };
577 plan_datafusion_err!(
578 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
579 )
580 })?;
581 let expr = match left_type {
582 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
583 _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
584 };
585 let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
586 Ok(Transformed::yes(Expr::Like(Like::new(
587 negated,
588 expr,
589 pattern,
590 escape_char,
591 case_insensitive,
592 ))))
593 }
594 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
595 let (left, right) =
596 self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
597 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
598 Box::new(left),
599 op,
600 Box::new(right),
601 ))))
602 }
603 Expr::Between(Between {
604 expr,
605 negated,
606 low,
607 high,
608 }) => {
609 let expr_type = expr.get_type(self.schema)?;
610 let low_type = low.get_type(self.schema)?;
611 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
612 .ok_or_else(|| {
613 internal_datafusion_err!(
614 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
615 )
616 })?;
617 let high_type = high.get_type(self.schema)?;
618 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
619 .ok_or_else(|| {
620 internal_datafusion_err!(
621 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
622 )
623 })?;
624 let coercion_type =
625 comparison_coercion(&low_coerced_type, &high_coerced_type)
626 .ok_or_else(|| {
627 internal_datafusion_err!(
628 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
629 )
630 })?;
631 Ok(Transformed::yes(Expr::Between(Between::new(
632 Box::new(expr.cast_to(&coercion_type, self.schema)?),
633 negated,
634 Box::new(low.cast_to(&coercion_type, self.schema)?),
635 Box::new(high.cast_to(&coercion_type, self.schema)?),
636 ))))
637 }
638 Expr::InList(InList {
639 expr,
640 list,
641 negated,
642 }) => {
643 let expr_data_type = expr.get_type(self.schema)?;
644 let list_data_types = list
645 .iter()
646 .map(|list_expr| list_expr.get_type(self.schema))
647 .collect::<Result<Vec<_>>>()?;
648 let result_type =
649 get_coerce_type_for_list(&expr_data_type, &list_data_types);
650 match result_type {
651 None => plan_err!(
652 "Can not find compatible types to compare {expr_data_type} with [{}]",
653 list_data_types.iter().join(", ")
654 ),
655 Some(coerced_type) => {
656 let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
658 let cast_list_expr = list
659 .into_iter()
660 .map(|list_expr| {
661 list_expr.cast_to(&coerced_type, self.schema)
662 })
663 .collect::<Result<Vec<_>>>()?;
664 Ok(Transformed::yes(Expr::InList(InList::new(
665 Box::new(cast_expr),
666 cast_list_expr,
667 negated,
668 ))))
669 }
670 }
671 }
672 Expr::Case(case) => {
673 let case = coerce_case_expression(case, self.schema)?;
674 Ok(Transformed::yes(Expr::Case(case)))
675 }
676 Expr::ScalarFunction(ScalarFunction { func, args }) => {
677 let new_expr =
678 coerce_arguments_for_signature(args, self.schema, func.as_ref())?;
679 Ok(Transformed::yes(Expr::ScalarFunction(
680 ScalarFunction::new_udf(func, new_expr),
681 )))
682 }
683 Expr::AggregateFunction(expr::AggregateFunction {
684 func,
685 params:
686 AggregateFunctionParams {
687 args,
688 distinct,
689 filter,
690 order_by,
691 null_treatment,
692 },
693 }) => {
694 let new_expr =
695 coerce_arguments_for_signature(args, self.schema, func.as_ref())?;
696 Ok(Transformed::yes(Expr::AggregateFunction(
697 expr::AggregateFunction::new_udf(
698 func,
699 new_expr,
700 distinct,
701 filter,
702 order_by,
703 null_treatment,
704 ),
705 )))
706 }
707 Expr::WindowFunction(window_fun) => {
708 let WindowFunction {
709 fun,
710 params:
711 expr::WindowFunctionParams {
712 args,
713 partition_by,
714 order_by,
715 window_frame,
716 filter,
717 null_treatment,
718 distinct,
719 },
720 } = *window_fun;
721 let window_frame =
722 coerce_window_frame(window_frame, self.schema, &order_by)?;
723
724 let args = match &fun {
725 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
726 coerce_arguments_for_signature(args, self.schema, udf.as_ref())?
727 }
728 expr::WindowFunctionDefinition::WindowUDF(udf) => {
729 coerce_arguments_for_signature(args, self.schema, udf.as_ref())?
730 }
731 };
732
733 let new_expr = Expr::from(WindowFunction {
734 fun,
735 params: expr::WindowFunctionParams {
736 args,
737 partition_by,
738 order_by,
739 window_frame,
740 filter,
741 null_treatment,
742 distinct,
743 },
744 });
745 Ok(Transformed::yes(new_expr))
746 }
747 #[expect(deprecated)]
749 Expr::Alias(_)
750 | Expr::Column(_)
751 | Expr::ScalarVariable(_, _)
752 | Expr::Literal(_, _)
753 | Expr::SimilarTo(_)
754 | Expr::IsNotNull(_)
755 | Expr::IsNull(_)
756 | Expr::Negative(_)
757 | Expr::Cast(_)
758 | Expr::TryCast(_)
759 | Expr::Wildcard { .. }
760 | Expr::GroupingSet(_)
761 | Expr::Placeholder(_)
762 | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
763 }
764 }
765}
766
767fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
769 let metadata = dfschema.as_arrow().metadata.clone();
770 let mut transformed = false;
771
772 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
773 dfschema
774 .iter()
775 .map(|(qualifier, field)| match field.data_type() {
776 DataType::Utf8View => {
777 transformed = true;
778 (
779 qualifier.cloned() as Option<TableReference>,
780 Arc::new(Field::new(
781 field.name(),
782 DataType::LargeUtf8,
783 field.is_nullable(),
784 )),
785 )
786 }
787 DataType::BinaryView => {
788 transformed = true;
789 (
790 qualifier.cloned() as Option<TableReference>,
791 Arc::new(Field::new(
792 field.name(),
793 DataType::LargeBinary,
794 field.is_nullable(),
795 )),
796 )
797 }
798 _ => (
799 qualifier.cloned() as Option<TableReference>,
800 Arc::clone(field),
801 ),
802 })
803 .unzip();
804
805 if !transformed {
806 return None;
807 }
808
809 let schema = Schema::new_with_metadata(transformed_fields, metadata);
810 Some(DFSchema::from_field_specific_qualified_schema(
811 qualifiers,
812 &Arc::new(schema),
813 ))
814}
815
816fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
819 match value {
820 ScalarValue::Utf8(Some(val)) => {
822 ScalarValue::try_from_string(val.clone(), target_type)
823 }
824 s => {
825 if s.is_null() {
826 ScalarValue::try_from(target_type)
828 } else {
829 Ok(s.clone())
833 }
834 }
835 }
836}
837
838fn coerce_scalar_range_aware(
845 target_type: &DataType,
846 value: &ScalarValue,
847) -> Result<ScalarValue> {
848 coerce_scalar(target_type, value).or_else(|err| {
849 if let Some(largest_type) = get_widest_type_in_family(target_type) {
851 coerce_scalar(largest_type, value).map_or_else(
852 |_| exec_err!("Cannot cast {value:?} to {target_type}"),
853 |_| ScalarValue::try_from(target_type),
854 )
855 } else {
856 Err(err)
857 }
858 })
859}
860
861fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
865 match given_type {
866 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
867 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
868 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
869 _ => None,
870 }
871}
872
873fn coerce_frame_bound(
875 target_type: &DataType,
876 bound: WindowFrameBound,
877) -> Result<WindowFrameBound> {
878 match bound {
879 WindowFrameBound::Preceding(v) => {
880 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
881 }
882 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
883 WindowFrameBound::Following(v) => {
884 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
885 }
886 }
887}
888
889fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
890 if col_type.is_numeric()
891 || col_type.is_string()
892 || col_type.is_null()
893 || matches!(
894 col_type,
895 DataType::List(_)
896 | DataType::LargeList(_)
897 | DataType::FixedSizeList(_, _)
898 | DataType::Boolean
899 )
900 {
901 Ok(col_type.clone())
902 } else if is_datetime(col_type) {
903 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
904 } else if let DataType::Dictionary(_, value_type) = col_type {
905 extract_window_frame_target_type(value_type)
906 } else {
907 internal_err!("Cannot run range queries on datatype: {col_type}")
908 }
909}
910
911fn coerce_window_frame(
914 window_frame: WindowFrame,
915 schema: &DFSchema,
916 expressions: &[Sort],
917) -> Result<WindowFrame> {
918 let mut window_frame = window_frame;
919 let target_type = match window_frame.units {
920 WindowFrameUnits::Range => {
921 let current_types = expressions
922 .first()
923 .map(|s| s.expr.get_type(schema))
924 .transpose()?;
925 if let Some(col_type) = current_types {
926 extract_window_frame_target_type(&col_type)?
927 } else {
928 return internal_err!("ORDER BY column cannot be empty");
929 }
930 }
931 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
932 };
933 window_frame.start_bound =
934 coerce_frame_bound(&target_type, window_frame.start_bound)?;
935 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
936 Ok(window_frame)
937}
938
939fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
942 let left_type = expr.get_type(schema)?;
943 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
944 .get_input_types()?;
945 expr.cast_to(&DataType::Boolean, schema)
946}
947
948fn coerce_arguments_for_signature<F: UDFCoercionExt>(
953 expressions: Vec<Expr>,
954 schema: &DFSchema,
955 func: &F,
956) -> Result<Vec<Expr>> {
957 let current_fields = expressions
958 .iter()
959 .map(|e| e.to_field(schema).map(|(_, f)| f))
960 .collect::<Result<Vec<_>>>()?;
961
962 let coerced_types = fields_with_udf(¤t_fields, func)?
963 .into_iter()
964 .map(|f| f.data_type().clone())
965 .collect::<Vec<_>>();
966
967 expressions
968 .into_iter()
969 .enumerate()
970 .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
971 .collect()
972}
973
974fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
975 let case_type = case
1007 .expr
1008 .as_ref()
1009 .map(|expr| expr.get_type(schema))
1010 .transpose()?;
1011 let then_types = case
1012 .when_then_expr
1013 .iter()
1014 .map(|(_when, then)| then.get_type(schema))
1015 .collect::<Result<Vec<_>>>()?;
1016 let else_type = case
1017 .else_expr
1018 .as_ref()
1019 .map(|expr| expr.get_type(schema))
1020 .transpose()?;
1021
1022 let case_when_coerce_type = case_type
1024 .as_ref()
1025 .map(|case_type| {
1026 let when_types = case
1027 .when_then_expr
1028 .iter()
1029 .map(|(when, _then)| when.get_type(schema))
1030 .collect::<Result<Vec<_>>>()?;
1031 let coerced_type =
1032 get_coerce_type_for_case_expression(&when_types, Some(case_type));
1033 coerced_type.ok_or_else(|| {
1034 plan_datafusion_err!(
1035 "Failed to coerce case ({case_type}) and when ({}) \
1036 to common types in CASE WHEN expression",
1037 when_types.iter().join(", ")
1038 )
1039 })
1040 })
1041 .transpose()?;
1042 let then_else_coerce_type =
1043 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
1044 || {
1045 if let Some(else_type) = else_type {
1046 plan_datafusion_err!(
1047 "Failed to coerce then ({}) and else ({else_type}) \
1048 to common types in CASE WHEN expression",
1049 then_types.iter().join(", ")
1050 )
1051 } else {
1052 plan_datafusion_err!(
1053 "Failed to coerce then ({}) and else (None) \
1054 to common types in CASE WHEN expression",
1055 then_types.iter().join(", ")
1056 )
1057 }
1058 },
1059 )?;
1060
1061 let case_expr = case
1063 .expr
1064 .zip(case_when_coerce_type.as_ref())
1065 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
1066 .transpose()?
1067 .map(Box::new);
1068 let when_then = case
1069 .when_then_expr
1070 .into_iter()
1071 .map(|(when, then)| {
1072 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
1073 let when = when.cast_to(when_type, schema).map_err(|e| {
1074 DataFusionError::Context(
1075 format!(
1076 "WHEN expressions in CASE couldn't be \
1077 converted to common type ({when_type})"
1078 ),
1079 Box::new(e),
1080 )
1081 })?;
1082 let then = then.cast_to(&then_else_coerce_type, schema)?;
1083 Ok((Box::new(when), Box::new(then)))
1084 })
1085 .collect::<Result<Vec<_>>>()?;
1086 let else_expr = case
1087 .else_expr
1088 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
1089 .transpose()?
1090 .map(Box::new);
1091
1092 Ok(Case::new(case_expr, when_then, else_expr))
1093}
1094
1095pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1137 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1138}
1139fn coerce_union_schema_with_schema(
1140 inputs: &[Arc<LogicalPlan>],
1141 base_schema: &DFSchemaRef,
1142) -> Result<DFSchema> {
1143 let mut union_datatypes = base_schema
1144 .fields()
1145 .iter()
1146 .map(|f| f.data_type().clone())
1147 .collect::<Vec<_>>();
1148 let mut union_nullabilities = base_schema
1149 .fields()
1150 .iter()
1151 .map(|f| f.is_nullable())
1152 .collect::<Vec<_>>();
1153 let mut union_field_meta = base_schema
1154 .fields()
1155 .iter()
1156 .map(|f| f.metadata().clone())
1157 .collect::<Vec<_>>();
1158
1159 let mut metadata = base_schema.metadata().clone();
1160
1161 for (i, plan) in inputs.iter().enumerate() {
1162 let plan_schema = plan.schema();
1163 metadata.extend(plan_schema.metadata().clone());
1164
1165 if plan_schema.fields().len() != base_schema.fields().len() {
1166 return plan_err!(
1167 "Union schemas have different number of fields: \
1168 query 1 has {} fields whereas query {} has {} fields",
1169 base_schema.fields().len(),
1170 i + 1,
1171 plan_schema.fields().len()
1172 );
1173 }
1174
1175 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1177 union_datatypes.iter_mut(),
1178 union_nullabilities.iter_mut(),
1179 union_field_meta.iter_mut(),
1180 plan_schema.fields().iter()
1181 ) {
1182 let coerced_type =
1183 comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1184 || {
1185 plan_datafusion_err!(
1186 "Incompatible inputs for Union: Previous inputs were \
1187 of type {}, but got incompatible type {} on column '{}'",
1188 union_datatype,
1189 plan_field.data_type(),
1190 plan_field.name()
1191 )
1192 },
1193 )?;
1194
1195 *union_datatype = coerced_type;
1196 *union_nullable = *union_nullable || plan_field.is_nullable();
1197 union_field_map.extend(plan_field.metadata().clone());
1198 }
1199 }
1200 let union_qualified_fields = izip!(
1201 base_schema.fields(),
1202 union_datatypes.into_iter(),
1203 union_nullabilities,
1204 union_field_meta.into_iter()
1205 )
1206 .map(|(field, datatype, nullable, metadata)| {
1207 let mut field = Field::new(field.name().clone(), datatype, nullable);
1208 field.set_metadata(metadata);
1209 (None, field.into())
1210 })
1211 .collect::<Vec<_>>();
1212
1213 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1214}
1215
1216fn project_with_column_index(
1218 expr: Vec<Expr>,
1219 input: Arc<LogicalPlan>,
1220 schema: DFSchemaRef,
1221) -> Result<LogicalPlan> {
1222 let alias_expr = expr
1223 .into_iter()
1224 .enumerate()
1225 .map(|(i, e)| match e {
1226 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1227 Ok(e.unalias().alias(schema.field(i).name()))
1228 }
1229 Expr::Column(Column {
1230 relation: _,
1231 ref name,
1232 spans: _,
1233 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1234 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1235 #[expect(deprecated)]
1236 Expr::Wildcard { .. } => {
1237 plan_err!("Wildcard should be expanded before type coercion")
1238 }
1239 _ => Ok(e.alias(schema.field(i).name())),
1240 })
1241 .collect::<Result<Vec<_>>>()?;
1242
1243 Projection::try_new_with_schema(alias_expr, input, schema)
1244 .map(LogicalPlan::Projection)
1245}
1246
1247#[cfg(test)]
1248mod test {
1249 use std::any::Any;
1250 use std::sync::Arc;
1251
1252 use arrow::datatypes::DataType::Utf8;
1253 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1254 use insta::assert_snapshot;
1255
1256 use crate::analyzer::Analyzer;
1257 use crate::analyzer::type_coercion::{
1258 TypeCoercion, TypeCoercionRewriter, coerce_case_expression,
1259 };
1260 use crate::assert_analyzed_plan_with_config_eq_snapshot;
1261 use datafusion_common::config::ConfigOptions;
1262 use datafusion_common::tree_node::{TransformedResult, TreeNode};
1263 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1264 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1265 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1266 use datafusion_expr::test::function_stub::avg_udaf;
1267 use datafusion_expr::{
1268 AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr,
1269 ExprSchemable, Filter, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF,
1270 ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Union, Volatility, cast,
1271 col, create_udaf, is_true, lit,
1272 };
1273 use datafusion_functions_aggregate::average::AvgAccumulator;
1274 use datafusion_sql::TableReference;
1275
1276 fn empty() -> Arc<LogicalPlan> {
1277 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1278 produce_one_row: false,
1279 schema: Arc::new(DFSchema::empty()),
1280 }))
1281 }
1282
1283 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1284 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1285 produce_one_row: false,
1286 schema: Arc::new(
1287 DFSchema::from_unqualified_fields(
1288 vec![Field::new("a", data_type, true)].into(),
1289 std::collections::HashMap::new(),
1290 )
1291 .unwrap(),
1292 ),
1293 }))
1294 }
1295
// Applies the `TypeCoercion` analyzer rule to `$plan` with default
// `ConfigOptions` and checks the result against the inline snapshot
// `$expected`. Expands to the `Result` produced by
// `assert_analyzed_plan_with_config_eq_snapshot!`.
macro_rules! assert_analyzed_plan_eq {
    (
        $plan: expr,
        @ $expected: literal $(,)?
    ) => {{
        let options = ConfigOptions::default();
        let coercion_rule = Arc::new(TypeCoercion::new());
        assert_analyzed_plan_with_config_eq_snapshot!(
            options,
            coercion_rule,
            $plan,
            @ $expected,
        )
    }};
}
1311
// Like `assert_analyzed_plan_eq!`, but first turns on
// `optimizer.expand_views_at_output` when `$is_viewtype` is true; otherwise
// the default `ConfigOptions` are used unchanged.
macro_rules! coerce_on_output_if_viewtype {
    (
        $is_viewtype: expr,
        $plan: expr,
        @ $expected: literal $(,)?
    ) => {{
        let mut options = ConfigOptions::default();
        if $is_viewtype {
            options.optimizer.expand_views_at_output = true;
        }
        let coercion_rule = Arc::new(TypeCoercion::new());
        assert_analyzed_plan_with_config_eq_snapshot!(
            options,
            coercion_rule,
            $plan,
            @ $expected,
        )
    }};
}
1331
1332 fn assert_type_coercion_error(
1333 plan: LogicalPlan,
1334 expected_substr: &str,
1335 ) -> Result<()> {
1336 let options = ConfigOptions::default();
1337 let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1338
1339 match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1340 Ok(succeeded_plan) => {
1341 panic!(
1342 "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1343 );
1344 }
1345 Err(e) => {
1346 let msg = e.to_string();
1347 assert!(
1348 msg.contains(expected_substr),
1349 "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`"
1350 );
1351 }
1352 }
1353
1354 Ok(())
1355 }
1356
/// Comparing a `Float64` column with a `UInt32` literal casts the literal
/// to `Float64`.
#[test]
fn simple_case() -> Result<()> {
    let predicate = col("a").lt(lit(2_u32));
    let input = empty_with_type(DataType::Float64);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![predicate], input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: a < CAST(UInt32(2) AS Float64)
          EmptyRelation: rows=0
        "
    )
}
1371
/// A union of `Int32` and `Int64` inputs (same qualifier and column name)
/// gets a projection casting the `Int32` side to `Int64`.
#[test]
fn test_coerce_union() -> Result<()> {
    // Empty relation `datafusion.test.foo` with one non-nullable column `a`.
    let foo_relation = |data_type: DataType| -> Arc<LogicalPlan> {
        let schema = DFSchema::try_from_qualified_schema(
            TableReference::full("datafusion", "test", "foo"),
            &Schema::new(vec![Field::new("a", data_type, false)]),
        )
        .unwrap();
        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(schema),
        }))
    };

    let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
        foo_relation(DataType::Int32),
        foo_relation(DataType::Int64),
    ])?);
    let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
        .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;

    let top_level_plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(analyzed_union),
    )?);

    assert_analyzed_plan_eq!(
        top_level_plan,
        @r"
        Projection: a
          Union
            Projection: CAST(datafusion.test.foo.a AS Int64) AS a
              EmptyRelation: rows=0
            EmptyRelation: rows=0
        "
    )
}
1415
/// Checks `optimizer.expand_views_at_output` for `Utf8View` columns.
///
/// When the option is enabled, `Utf8View` outputs of the root plan are cast
/// to `LargeUtf8`. Plans whose output is not a view type (e.g. a Boolean
/// comparison) are left unchanged whether or not the option is set.
#[test]
fn coerce_utf8view_output() -> Result<()> {
    let expr = col("a");
    let empty = empty_with_type(DataType::Utf8View);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![expr.clone()],
        Arc::clone(&empty),
    )?);

    // Plan A: projection of the Utf8View column, no coercion requested.
    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
        Projection: a
          EmptyRelation: rows=0
        "
    )?;

    // Plan A: coercion requested: Utf8View => LargeUtf8.
    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
        Projection: CAST(a AS LargeUtf8) AS a
          EmptyRelation: rows=0
        "
    )?;

    // Plan B: the outermost projection yields a Boolean, not a view type.
    let bool_expr = col("a").lt(lit("foo"));
    let bool_plan = LogicalPlan::Projection(Projection::try_new(
        vec![bool_expr],
        Arc::clone(&empty),
    )?);

    // Plan B: no coercion requested.
    coerce_on_output_if_viewtype!(
        false,
        bool_plan.clone(),
        @r#"
        Projection: a < CAST(Utf8("foo") AS Utf8View)
          EmptyRelation: rows=0
        "#
    )?;

    // Plan B: coercion requested, but the Boolean output needs no output cast.
    // (Mirrors the equivalent BinaryView scenario.)
    coerce_on_output_if_viewtype!(
        true,
        bool_plan,
        @r#"
        Projection: a < CAST(Utf8("foo") AS Utf8View)
          EmptyRelation: rows=0
        "#
    )?;

    // Plan C: Sort over the Utf8View projection. When coercion is requested,
    // a casting projection is placed on top of the sort.
    let sort_expr = expr.sort(true, true);
    let sort_plan = LogicalPlan::Sort(Sort {
        expr: vec![sort_expr],
        input: Arc::new(plan),
        fetch: None,
    });

    coerce_on_output_if_viewtype!(
        false,
        sort_plan.clone(),
        @r"
        Sort: a ASC NULLS FIRST
          Projection: a
            EmptyRelation: rows=0
        "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        sort_plan.clone(),
        @r"
        Projection: CAST(a AS LargeUtf8) AS a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    // Plan D: projection over the sort. When coercion is requested, the root
    // projection itself carries the cast.
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sort_plan),
    )?);

    coerce_on_output_if_viewtype!(
        false,
        plan.clone(),
        @r"
        Projection: a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        plan.clone(),
        @r"
        Projection: CAST(a AS LargeUtf8) AS a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    Ok(())
}
1546
/// Checks `optimizer.expand_views_at_output` for `BinaryView` columns.
///
/// When the option is enabled, `BinaryView` outputs of the root plan are
/// cast to `LargeBinary`. A Boolean-valued projection is left unchanged
/// either way.
#[test]
fn coerce_binaryview_output() -> Result<()> {
    let column_a = col("a");
    let input = empty_with_type(DataType::BinaryView);
    let view_projection = LogicalPlan::Projection(Projection::try_new(
        vec![column_a.clone()],
        Arc::clone(&input),
    )?);

    // Projection of the BinaryView column: cast only when requested.
    coerce_on_output_if_viewtype!(
        false,
        view_projection.clone(),
        @r"
        Projection: a
          EmptyRelation: rows=0
        "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        view_projection.clone(),
        @r"
        Projection: CAST(a AS LargeBinary) AS a
          EmptyRelation: rows=0
        "
    )?;

    // Boolean-valued projection: the literal is coerced to BinaryView for
    // the comparison, and the Boolean output is never re-cast.
    let comparison = col("a").lt(lit(vec![8, 1, 8, 1]));
    let comparison_projection = LogicalPlan::Projection(Projection::try_new(
        vec![comparison],
        Arc::clone(&input),
    )?);

    coerce_on_output_if_viewtype!(
        false,
        comparison_projection.clone(),
        @r#"
        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
          EmptyRelation: rows=0
        "#
    )?;

    coerce_on_output_if_viewtype!(
        true,
        comparison_projection.clone(),
        @r#"
        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
          EmptyRelation: rows=0
        "#
    )?;

    // Sort over the view projection: when requested, a casting projection is
    // placed on top of the sort.
    let sorted = LogicalPlan::Sort(Sort {
        expr: vec![column_a.sort(true, true)],
        input: Arc::new(view_projection),
        fetch: None,
    });

    coerce_on_output_if_viewtype!(
        false,
        sorted.clone(),
        @r"
        Sort: a ASC NULLS FIRST
          Projection: a
            EmptyRelation: rows=0
        "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        sorted.clone(),
        @r"
        Projection: CAST(a AS LargeBinary) AS a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    // Projection over the sort: when requested, the root projection itself
    // carries the cast.
    let projection_over_sort = LogicalPlan::Projection(Projection::try_new(
        vec![col("a")],
        Arc::new(sorted),
    )?);

    coerce_on_output_if_viewtype!(
        false,
        projection_over_sort.clone(),
        @r"
        Projection: a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    coerce_on_output_if_viewtype!(
        true,
        projection_over_sort.clone(),
        @r"
        Projection: CAST(a AS LargeBinary) AS a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
    )?;

    Ok(())
}
1670
/// An `OR` of two identical `Float64 < UInt32` comparisons: the literal is
/// cast to `Float64` inside both operands.
#[test]
fn nested_case() -> Result<()> {
    let comparison = col("a").lt(lit(2_u32));
    let disjunction = comparison.clone().or(comparison);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![disjunction],
        empty_with_type(DataType::Float64),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
          EmptyRelation: rows=0
        "
    )
}
1689
1690 #[derive(Debug, PartialEq, Eq, Hash)]
1691 struct TestScalarUDF {
1692 signature: Signature,
1693 }
1694
1695 impl ScalarUDFImpl for TestScalarUDF {
1696 fn as_any(&self) -> &dyn Any {
1697 self
1698 }
1699
1700 fn name(&self) -> &str {
1701 "TestScalarUDF"
1702 }
1703
1704 fn signature(&self) -> &Signature {
1705 &self.signature
1706 }
1707
1708 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1709 Ok(Utf8)
1710 }
1711
1712 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1713 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1714 }
1715 }
1716
/// An `Int32` argument to a UDF whose signature accepts only `Float32` is
/// cast to `Float32`.
#[test]
fn scalar_udf() -> Result<()> {
    let float32_udf = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = float32_udf.call(vec![lit(123_i32)]);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
          EmptyRelation: rows=0
        "
    )
}
1735
/// A `Utf8` argument to a UDF whose signature accepts only `Float32` makes
/// projection construction fail.
#[test]
fn scalar_udf_invalid_input() -> Result<()> {
    let float32_udf = ScalarUDF::from(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    });
    let call = float32_udf.call(vec![lit("Apple")]);

    Projection::try_new(vec![call], empty())
        .expect_err("Expected an error due to incorrect function input");
    Ok(())
}
1748
/// Same as `scalar_udf`, but builds the `Expr::ScalarFunction` node
/// directly: the `Int64` literal is cast to `Float32`.
#[test]
fn scalar_function() -> Result<()> {
    let udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
    }));
    let call = Expr::ScalarFunction(ScalarFunction::new_udf(udf, vec![lit(10i64)]));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
          EmptyRelation: rows=0
        "
    )
}
1772
/// An `Int64` argument to a user-defined aggregate declared over `Float64`
/// is cast to `Float64`.
#[test]
fn agg_udaf() -> Result<()> {
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let my_avg = Arc::new(create_udaf(
        "MY_AVG",
        vec![DataType::Float64],
        Arc::new(DataType::Float64),
        Volatility::Immutable,
        accumulator,
        Arc::new(vec![DataType::UInt64, DataType::Float64]),
    ));
    let call = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        my_avg,
        vec![lit(10i64)],
        false,
        None,
        vec![],
        None,
    ));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![call], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: MY_AVG(CAST(Int64(10) AS Float64))
          EmptyRelation: rows=0
        "
    )
}
1802
/// A `Utf8` argument to a user-defined aggregate whose signature is
/// `Uniform(1, [Float64])` is rejected when the projection is built.
#[test]
fn agg_udaf_invalid_input() -> Result<()> {
    let empty = empty();
    let return_type = DataType::Float64;
    let accumulator: AccumulatorFactoryFunction =
        Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
    let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
        "MY_AVG",
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
        return_type,
        accumulator,
        vec![
            Field::new("count", DataType::UInt64, true).into(),
            Field::new("avg", DataType::Float64, true).into(),
        ],
    ));
    let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        Arc::new(my_avg),
        vec![lit("10")],
        false,
        None,
        vec![],
        None,
    ));

    // `expect_err` reports clearly if planning unexpectedly succeeds,
    // instead of an opaque `Option::unwrap()` panic.
    let err = Projection::try_new(vec![udaf], empty)
        .expect_err("MY_AVG over a Utf8 argument should fail to plan")
        .strip_backtrace();
    let expected_prefix = "Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed";
    assert!(
        err.starts_with(expected_prefix),
        "unexpected error message: {err}"
    );
    Ok(())
}
1834
/// `avg` with arguments that already match its signature is left as-is:
/// a `Float64` literal, and an explicit `Float64` cast of an `Int32` column.
#[test]
fn agg_function_case() -> Result<()> {
    // `avg(arg)` with no DISTINCT, filter, ordering or null treatment.
    let avg_of = |arg: Expr| {
        Expr::AggregateFunction(expr::AggregateFunction::new_udf(
            avg_udaf(),
            vec![arg],
            false,
            None,
            vec![],
            None,
        ))
    };

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of(lit(12f64))],
        empty(),
    )?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: avg(Float64(12))
          EmptyRelation: rows=0
        "
    )?;

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![avg_of(cast(col("a"), DataType::Float64))],
        empty_with_type(DataType::Int32),
    )?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: avg(CAST(a AS Float64))
          EmptyRelation: rows=0
        "
    )
}
1875
/// `avg` over a `Utf8` literal is rejected when the projection is built,
/// because no numeric signature variant accepts `Utf8`.
#[test]
fn agg_function_invalid_input_avg() -> Result<()> {
    let empty = empty();
    let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
        avg_udaf(),
        vec![lit("1")],
        false,
        None,
        vec![],
        None,
    ));

    // `expect_err` reports clearly if planning unexpectedly succeeds,
    // instead of an opaque `Option::unwrap()` panic.
    let err = Projection::try_new(vec![agg_expr], empty)
        .expect_err("avg over a Utf8 argument should fail to plan")
        .strip_backtrace();
    let expected_prefix = "Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float16, Float32, Float64]) failed";
    assert!(
        err.starts_with(expected_prefix),
        "unexpected error message: {err}"
    );
    Ok(())
}
1894
/// `Date32 + IntervalDayTime` needs no extra casts.
#[test]
fn binary_op_date32_op_interval() -> Result<()> {
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let interval = lit(ScalarValue::new_interval_dt(123, 456));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![date + interval], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
        Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
          EmptyRelation: rows=0
        "#
    )
}
1911
/// `a IN (Int32, Int8, Int64)` coerces the list to the column's type. For
/// a `Decimal128(12, 4)` column, both the column and every literal are cast
/// to `Decimal128(24, 4)`.
#[test]
fn inlist_case() -> Result<()> {
    let mixed_int_in_list =
        || col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![mixed_int_in_list()],
        empty_with_type(DataType::Int64),
    )?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
          EmptyRelation: rows=0
        "
    )?;

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![mixed_int_in_list()],
        empty_with_type(DataType::Decimal128(12, 4)),
    )?);
    assert_analyzed_plan_eq!(
        plan,
        @r"
        Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
          EmptyRelation: rows=0
        "
    )
}
1942
/// `Utf8 BETWEEN Utf8 AND Date32 + interval` coerces the column and the
/// low bound to `Date32`.
#[test]
fn between_case() -> Result<()> {
    let low = lit("2002-05-08");
    let high = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let predicate = col("a").between(low, high);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty_with_type(Utf8))?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
          EmptyRelation: rows=0
        "#
    )
}
1962
/// Like `between_case` with the bounds swapped: the `Date32` type of the
/// low bound is used for the column and the `Utf8` high bound.
#[test]
fn between_infer_cheap_type() -> Result<()> {
    let low = cast(lit("2002-05-08"), DataType::Date32)
        + lit(ScalarValue::new_interval_ym(0, 1));
    let high = lit("2002-12-08");
    let predicate = col("a").between(low, high);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty_with_type(Utf8))?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
          EmptyRelation: rows=0
        "#
    )
}
1983
/// `NULL BETWEEN NULL AND Int64` casts both NULLs to `Int64`.
#[test]
fn between_null() -> Result<()> {
    let null = || lit(ScalarValue::Null);
    let predicate = null().between(null(), lit(2i64));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
        Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
          EmptyRelation: rows=0
        "
    )
}
1998
/// `IS [NOT] TRUE/FALSE` need no coercion on a Boolean column. On an `Int64`
/// column, analysis fails because no common type with Boolean exists.
#[test]
fn is_bool_for_type_coercion() -> Result<()> {
    // Builds `SELECT <expr>` over a single column `a` of `column_type`.
    fn project(expr: Expr, column_type: DataType) -> Result<LogicalPlan> {
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![expr],
            empty_with_type(column_type),
        )?))
    }

    let is_true_expr = col("a").is_true();
    assert_analyzed_plan_eq!(
        project(is_true_expr.clone(), DataType::Boolean)?,
        @r"
        Projection: a IS TRUE
          EmptyRelation: rows=0
        "
    )?;

    assert_type_coercion_error(
        project(is_true_expr, DataType::Int64)?,
        "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean",
    )?;

    assert_analyzed_plan_eq!(
        project(col("a").is_not_true(), DataType::Boolean)?,
        @r"
        Projection: a IS NOT TRUE
          EmptyRelation: rows=0
        "
    )?;

    assert_analyzed_plan_eq!(
        project(col("a").is_false(), DataType::Boolean)?,
        @r"
        Projection: a IS FALSE
          EmptyRelation: rows=0
        "
    )?;

    assert_analyzed_plan_eq!(
        project(col("a").is_not_false(), DataType::Boolean)?,
        @r"
        Projection: a IS NOT FALSE
          EmptyRelation: rows=0
        "
    )
}
2061
/// LIKE / ILIKE coercion:
/// - a `Utf8` pattern on a `Utf8` column is kept;
/// - a NULL pattern is cast to `Utf8`;
/// - an `Int64` column fails analysis.
#[test]
fn like_for_type_coercion() -> Result<()> {
    // Builds `SELECT a [I]LIKE <pattern>` over a single column `a` of
    // `column_type` (not negated, no escape character).
    fn like_plan(
        pattern: ScalarValue,
        case_insensitive: bool,
        column_type: DataType,
    ) -> Result<LogicalPlan> {
        let like_expr = Expr::Like(Like::new(
            false,
            Box::new(col("a")),
            Box::new(lit(pattern)),
            None,
            case_insensitive,
        ));
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![like_expr],
            empty_with_type(column_type),
        )?))
    }

    // LIKE
    assert_analyzed_plan_eq!(
        like_plan(ScalarValue::new_utf8("abc"), false, Utf8)?,
        @r#"
        Projection: a LIKE Utf8("abc")
          EmptyRelation: rows=0
        "#
    )?;

    assert_analyzed_plan_eq!(
        like_plan(ScalarValue::Null, false, Utf8)?,
        @r"
        Projection: a LIKE CAST(NULL AS Utf8)
          EmptyRelation: rows=0
        "
    )?;

    assert_type_coercion_error(
        like_plan(ScalarValue::new_utf8("abc"), false, DataType::Int64)?,
        "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
    )?;

    // ILIKE
    assert_analyzed_plan_eq!(
        like_plan(ScalarValue::new_utf8("abc"), true, Utf8)?,
        @r#"
        Projection: a ILIKE Utf8("abc")
          EmptyRelation: rows=0
        "#
    )?;

    assert_analyzed_plan_eq!(
        like_plan(ScalarValue::Null, true, Utf8)?,
        @r"
        Projection: a ILIKE CAST(NULL AS Utf8)
          EmptyRelation: rows=0
        "
    )?;

    assert_type_coercion_error(
        like_plan(ScalarValue::new_utf8("abc"), true, DataType::Int64)?,
        "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
    )?;

    Ok(())
}
2144
/// `IS [NOT] UNKNOWN` need no coercion on a Boolean column. On a `Utf8`
/// column, analysis fails because no common type with Boolean exists.
#[test]
fn unknown_for_type_coercion() -> Result<()> {
    // Builds `SELECT <expr>` over a single column `a` of `column_type`.
    fn project(expr: Expr, column_type: DataType) -> Result<LogicalPlan> {
        Ok(LogicalPlan::Projection(Projection::try_new(
            vec![expr],
            empty_with_type(column_type),
        )?))
    }

    let is_unknown_expr = col("a").is_unknown();
    assert_analyzed_plan_eq!(
        project(is_unknown_expr.clone(), DataType::Boolean)?,
        @r"
        Projection: a IS UNKNOWN
          EmptyRelation: rows=0
        "
    )?;

    assert_type_coercion_error(
        project(is_unknown_expr, Utf8)?,
        "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean",
    )?;

    assert_analyzed_plan_eq!(
        project(col("a").is_not_unknown(), DataType::Boolean)?,
        @r"
        Projection: a IS NOT UNKNOWN
          EmptyRelation: rows=0
        "
    )
}
2181
/// A variadic `Utf8` signature casts every non-`Utf8` argument (Boolean,
/// Int32) to `Utf8`, leaving the `Utf8` ones untouched.
#[test]
fn concat_for_type_coercion() -> Result<()> {
    let variadic_utf8 = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
    });
    let call = variadic_utf8.call(vec![
        col("a"),
        lit("b"),
        lit(true),
        lit(false),
        lit(13),
    ]);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![call],
        empty_with_type(Utf8),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
        Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
          EmptyRelation: rows=0
        "#
    )
}
2202
/// Runs `TypeCoercionRewriter` directly over `IS TRUE (Int32 <op> Int64)`
/// for `>`, `=` and `<`, and expects the `Int32` literal to be cast to `Int64`.
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![Field::new("a", DataType::Int64, true)].into(),
        std::collections::HashMap::new(),
    )?);
    let comparisons: [fn(Expr, Expr) -> Expr; 3] = [
        |lhs: Expr, rhs: Expr| lhs.gt(rhs),
        |lhs: Expr, rhs: Expr| lhs.eq(rhs),
        |lhs: Expr, rhs: Expr| lhs.lt(rhs),
    ];

    for compare in comparisons {
        let mut rewriter = TypeCoercionRewriter { schema: &schema };
        let input = is_true(compare(lit(12i32), lit(13i64)));
        let expected = is_true(compare(cast(lit(12i32), DataType::Int64), lit(13i64)));
        let rewritten = input.rewrite(&mut rewriter).data()?;
        assert_eq!(expected, rewritten);
    }

    Ok(())
}
2240
/// `Timestamp(ns) = Date32`: the `Date32` side is cast to `Timestamp(ns)`.
#[test]
fn binary_op_date32_eq_ts() -> Result<()> {
    let timestamp = cast(
        lit("1998-03-18"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let date = cast(lit("1998-03-18"), DataType::Date32);
    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![timestamp.eq(date)],
        empty(),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
          EmptyRelation: rows=0
        "#
    )
}
2259
2260 fn cast_if_not_same_type(
2261 expr: Box<Expr>,
2262 data_type: &DataType,
2263 schema: &DFSchemaRef,
2264 ) -> Box<Expr> {
2265 if &expr.get_type(schema).unwrap() != data_type {
2266 Box::new(cast(*expr, data_type.clone()))
2267 } else {
2268 expr
2269 }
2270 }
2271
2272 fn cast_helper(
2273 case: Case,
2274 case_when_type: &DataType,
2275 then_else_type: &DataType,
2276 schema: &DFSchemaRef,
2277 ) -> Case {
2278 let expr = case
2279 .expr
2280 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2281 let when_then_expr = case
2282 .when_then_expr
2283 .into_iter()
2284 .map(|(when, then)| {
2285 (
2286 cast_if_not_same_type(when, case_when_type, schema),
2287 cast_if_not_same_type(then, then_else_type, schema),
2288 )
2289 })
2290 .collect::<Vec<_>>();
2291 let else_expr = case
2292 .else_expr
2293 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2294
2295 Case {
2296 expr,
2297 when_then_expr,
2298 else_expr,
2299 }
2300 }
2301
/// `coerce_case_expression` over a schema with one column per type. It
/// succeeds when WHEN/THEN types share a common type, and fails with a
/// planning error otherwise.
#[test]
fn test_case_expression_coercion() -> Result<()> {
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("boolean", DataType::Boolean, true),
            Field::new("integer", DataType::Int32, true),
            Field::new("float", DataType::Float32, true),
            Field::new(
                "timestamp",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                true,
            ),
            Field::new("date", DataType::Date32, true),
            Field::new(
                "interval",
                DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
                true,
            ),
            Field::new("binary", DataType::Binary, true),
            Field::new("string", Utf8, true),
            Field::new("decimal", DataType::Decimal128(10, 10), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // Coerces `case` and compares it with `cast_helper`'s expectation for
    // the given WHEN and THEN/ELSE target types.
    let assert_coerced =
        |case: Case, when_type: DataType, then_type: DataType| -> Result<()> {
            let expected = cast_helper(case.clone(), &when_type, &then_type, &schema);
            let actual = coerce_case_expression(case, &schema)?;
            assert_eq!(expected, actual);
            Ok(())
        };

    // Searched CASE (no operand): WHENs coerce to Boolean, THENs to Utf8.
    assert_coerced(
        Case {
            expr: None,
            when_then_expr: vec![
                (Box::new(col("boolean")), Box::new(col("integer"))),
                (Box::new(col("integer")), Box::new(col("float"))),
                (Box::new(col("string")), Box::new(col("string"))),
            ],
            else_expr: None,
        },
        DataType::Boolean,
        Utf8,
    )?;

    // Simple CASE on a Utf8 operand: operand/WHENs and THEN/ELSE all coerce to Utf8.
    assert_coerced(
        Case {
            expr: Some(Box::new(col("string"))),
            when_then_expr: vec![
                (Box::new(col("float")), Box::new(col("integer"))),
                (Box::new(col("integer")), Box::new(col("float"))),
                (Box::new(col("string")), Box::new(col("string"))),
            ],
            else_expr: Some(Box::new(col("string"))),
        },
        Utf8,
        Utf8,
    )?;

    // No common type between the Interval operand and its WHENs.
    let err = coerce_case_expression(
        Case {
            expr: Some(Box::new(col("interval"))),
            when_then_expr: vec![
                (Box::new(col("float")), Box::new(col("integer"))),
                (Box::new(col("binary")), Box::new(col("float"))),
                (Box::new(col("string")), Box::new(col("string"))),
            ],
            else_expr: Some(Box::new(col("string"))),
        },
        &schema,
    )
    .unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
    );

    // No common type between the THEN values and the ELSE value.
    let err = coerce_case_expression(
        Case {
            expr: Some(Box::new(col("string"))),
            when_then_expr: vec![
                (Box::new(col("float")), Box::new(col("date"))),
                (Box::new(col("string")), Box::new(col("float"))),
                (Box::new(col("string")), Box::new(col("binary"))),
            ],
            else_expr: Some(Box::new(col("timestamp"))),
        },
        &schema,
    )
    .unwrap_err();
    assert_snapshot!(
        err.strip_backtrace(),
        @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
    );

    Ok(())
}
2400
// Builds a CASE expression and checks its coercion against `$schema`.
// The CASE has an optional operand column named by `$expr`, WHEN/THEN pairs
// `$when_then`, and no ELSE. The coerced result must equal `cast_helper`
// applied with `$case_when_type` / `$then_else_type`. Uses `?`, so it must
// be invoked inside a function returning `Result`.
macro_rules! test_case_expression {
    ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
        let case = Case {
            expr: $expr.map(|operand| Box::new(col(operand))),
            when_then_expr: $when_then,
            else_expr: None,
        };
        let when_type = $case_when_type;
        let then_type = $then_else_type;
        let expected = cast_helper(case.clone(), &when_type, &then_type, &$schema);
        let actual = coerce_case_expression(case, &$schema)?;
        assert_eq!(expected, actual);
    };
}
2416
/// CASE operand/WHEN coercion across list types. `List`/`LargeList` and
/// `FixedSizeList`/`LargeList` pairs coerce to `LargeList`;
/// `List`/`FixedSizeList` pairs coerce to `List`. THEN values stay `Utf8`.
#[test]
fn test_case_when_list() -> Result<()> {
    let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new(
                "large_list",
                DataType::LargeList(Arc::clone(&inner_field)),
                true,
            ),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(Arc::clone(&inner_field), 3),
                true,
            ),
            Field::new("list", DataType::List(inner_field), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // List operand, LargeList WHEN => LargeList
    test_case_expression!(
        Some("list"),
        vec![(Box::new(col("large_list")), Box::new(lit("1")))],
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
        Utf8,
        schema
    );

    // LargeList operand, List WHEN => LargeList
    test_case_expression!(
        Some("large_list"),
        vec![(Box::new(col("list")), Box::new(lit("1")))],
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
        Utf8,
        schema
    );

    // List operand, FixedSizeList WHEN => List
    test_case_expression!(
        Some("list"),
        vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
        DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
        Utf8,
        schema
    );

    // FixedSizeList operand, List WHEN => List
    test_case_expression!(
        Some("fixed_list"),
        vec![(Box::new(col("list")), Box::new(lit("1")))],
        DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
        Utf8,
        schema
    );

    // FixedSizeList operand, LargeList WHEN => LargeList
    test_case_expression!(
        Some("fixed_list"),
        vec![(Box::new(col("large_list")), Box::new(lit("1")))],
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
        Utf8,
        schema
    );

    // LargeList operand, FixedSizeList WHEN => LargeList
    test_case_expression!(
        Some("large_list"),
        vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
        DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
        Utf8,
        schema
    );
    Ok(())
}
2487
/// A searched `CASE WHEN <bool> THEN <list> ...` whose THEN branches are
/// different list kinds must coerce every branch to a single list type.
/// The order in which the two branches appear must not affect the result.
#[test]
fn test_then_else_list() -> Result<()> {
    // Shared element field for every list column and expected type.
    let item = Arc::new(Field::new_list_field(DataType::Int64, true));
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        vec![
            Field::new("boolean", DataType::Boolean, true),
            Field::new("large_list", DataType::LargeList(Arc::clone(&item)), true),
            Field::new(
                "fixed_list",
                DataType::FixedSizeList(Arc::clone(&item), 3),
                true,
            ),
            Field::new("list", DataType::List(Arc::clone(&item)), true),
        ]
        .into(),
        std::collections::HashMap::new(),
    )?);

    // Two `WHEN boolean THEN <column>` branches, in the given order.
    let branches = |first: &str, second: &str| {
        vec![
            (Box::new(col("boolean")), Box::new(col(first))),
            (Box::new(col("boolean")), Box::new(col(second))),
        ]
    };
    let large_list = || DataType::LargeList(Arc::clone(&item));
    let list = || DataType::List(Arc::clone(&item));

    // LargeList and List -> LargeList, in both orders.
    test_case_expression!(
        None::<String>,
        branches("large_list", "list"),
        DataType::Boolean,
        large_list(),
        schema
    );
    test_case_expression!(
        None::<String>,
        branches("list", "large_list"),
        DataType::Boolean,
        large_list(),
        schema
    );

    // FixedSizeList and List -> List, in both orders.
    test_case_expression!(
        None::<String>,
        branches("fixed_list", "list"),
        DataType::Boolean,
        list(),
        schema
    );
    test_case_expression!(
        None::<String>,
        branches("list", "fixed_list"),
        DataType::Boolean,
        list(),
        schema
    );

    // FixedSizeList and LargeList -> LargeList, in both orders.
    test_case_expression!(
        None::<String>,
        branches("fixed_list", "large_list"),
        DataType::Boolean,
        large_list(),
        schema
    );
    test_case_expression!(
        None::<String>,
        branches("large_list", "fixed_list"),
        DataType::Boolean,
        large_list(),
        schema
    );
    Ok(())
}
2580
/// Two map types that differ only in the name of their entries field
/// (`entries` vs `key_value`) can be compared.
///
/// The analyzer casts the right-hand side back to the column's own map type.
#[test]
fn test_map_with_diff_name() -> Result<()> {
    let mut entry_builder = SchemaBuilder::new();
    entry_builder.push(Field::new("key", Utf8, false));
    entry_builder.push(Field::new("value", DataType::Float64, true));
    let entry_fields = entry_builder.finish().fields;

    // Identical entry struct, two different names for the entries field.
    let column_map_type = DataType::Map(
        Arc::new(Field::new(
            "entries",
            DataType::Struct(entry_fields.clone()),
            false,
        )),
        false,
    );
    let custom_map_type = DataType::Map(
        Arc::new(Field::new(
            "key_value",
            DataType::Struct(entry_fields),
            false,
        )),
        false,
    );

    let plan = LogicalPlan::Projection(Projection::try_new(
        vec![col("a").eq(cast(col("a"), custom_map_type))],
        empty_with_type(column_map_type),
    )?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a = CAST(CAST(a AS Map("key_value": non-null Struct("key": non-null Utf8, "value": Float64), unsorted)) AS Map("entries": non-null Struct("key": non-null Utf8, "value": Float64), unsorted))
      EmptyRelation: rows=0
    "#
    )
}
2607
/// `IntervalYearMonth + Timestamp(ns)` is left as written.
///
/// The analyzer adds no casts beyond the one already in the expression.
#[test]
fn interval_plus_timestamp() -> Result<()> {
    let twelve_months = lit(ScalarValue::IntervalYearMonth(Some(12)));
    let start = cast(
        lit("2000-01-01T00:00:00"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let sum = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(twelve_months),
        Operator::Plus,
        Box::new(start),
    ));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![sum], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2630
/// Subtracting two `Timestamp(ns)` values needs no additional coercion.
/// The plan is unchanged apart from the casts already present.
#[test]
fn timestamp_subtract_timestamp() -> Result<()> {
    // Both operands are the same string literal cast to a nanosecond timestamp.
    let timestamp = || {
        cast(
            lit("1998-03-18"),
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
    };
    let difference = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(timestamp()),
        Operator::Minus,
        Box::new(timestamp()),
    ));
    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![difference], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2655
/// For `a IN (<subquery>)` where the outer column is Int64 and the subquery
/// yields Int32, the subquery side is widened.
///
/// The widening is done by wrapping the subquery in a casting projection.
#[test]
fn in_subquery_cast_subquery() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Int32),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate = Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let outer = empty_with_type(DataType::Int64);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: a IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Int64)
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2684
/// For `a IN (<subquery>)` where the outer column is Int32 and the subquery
/// yields Int64, the outer expression is cast instead.
///
/// The subquery plan is left untouched.
#[test]
fn in_subquery_cast_expr() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Int64),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate = Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let outer = empty_with_type(DataType::Int32);
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Int64) IN (<subquery>)
      Subquery:
        EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2712
/// Neither `Decimal128(8, 8)` (outer) nor `Decimal128(10, 5)` (subquery) can
/// hold the other, so both sides are cast.
///
/// Both end up as `Decimal128(13, 8)`.
#[test]
fn in_subquery_cast_all() -> Result<()> {
    let subquery = Subquery {
        subquery: empty_with_type(DataType::Decimal128(10, 5)),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let predicate = Expr::InSubquery(InSubquery::new(Box::new(col("a")), subquery, false));
    let outer = empty_with_type(DataType::Decimal128(8, 8));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer)?);

    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Decimal128(13, 8))
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2741}