1use arrow::compute::can_cast_types;
21use datafusion_expr::binary::BinaryTypeCoercer;
22use itertools::{Itertools as _, izip};
23use std::sync::Arc;
24
25use crate::analyzer::AnalyzerRule;
26use crate::utils::NamePreserver;
27
28use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
29use arrow::temporal_conversions::SECONDS_IN_DAY;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
32use datafusion_common::{
33 Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference,
34 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
35 plan_err,
36};
37use datafusion_expr::expr::{
38 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
39 InSubquery, Like, ScalarFunction, Sort, WindowFunction,
40};
41use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
42use datafusion_expr::expr_schema::cast_subquery;
43use datafusion_expr::logical_plan::Subquery;
44use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
45use datafusion_expr::type_coercion::functions::fields_with_udf;
46use datafusion_expr::type_coercion::other::{
47 get_coerce_type_for_case_expression, get_coerce_type_for_list,
48};
49use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
50use datafusion_expr::utils::merge_schema;
51use datafusion_expr::{
52 AggregateUDF, Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator,
53 Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
54 is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, lit, not,
55};
56
/// Analyzer rule that applies type coercion to a [`LogicalPlan`].
///
/// The rule carries no state; the work is done by its [`AnalyzerRule`]
/// implementation.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates a new `TypeCoercion` rule.
    pub fn new() -> Self {
        TypeCoercion {}
    }
}
67
68fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
70 if !config.optimizer.expand_views_at_output {
71 return Ok(plan);
72 }
73
74 let outer_refs = plan.expressions();
75 if outer_refs.is_empty() {
76 return Ok(plan);
77 }
78
79 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
80 coerce_plan_expr_for_schema(plan, &dfschema?)
81 } else {
82 Ok(plan)
83 }
84}
85
86impl AnalyzerRule for TypeCoercion {
87 fn name(&self) -> &str {
88 "type_coercion"
89 }
90
91 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
92 let empty_schema = DFSchema::empty();
93
94 let transformed_plan = plan
96 .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
97 .data;
98
99 coerce_output(transformed_plan, config)
101 }
102}
103
104fn analyze_internal(
108 external_schema: &DFSchema,
109 plan: LogicalPlan,
110) -> Result<Transformed<LogicalPlan>> {
111 let mut schema = merge_schema(&plan.inputs());
114
115 if let LogicalPlan::TableScan(ts) = &plan {
116 let source_schema = DFSchema::try_from_qualified_schema(
117 ts.table_name.clone(),
118 &ts.source.schema(),
119 )?;
120 schema.merge(&source_schema);
121 }
122
123 schema.merge(external_schema);
127
128 let plan = if let LogicalPlan::Filter(mut filter) = plan {
130 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
131 LogicalPlan::Filter(filter)
132 } else {
133 plan
134 };
135
136 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
137
138 let name_preserver = NamePreserver::new(&plan);
139 plan.map_expressions(|expr| {
141 let original_name = name_preserver.save(&expr);
142 expr.rewrite(&mut expr_rewrite)
143 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
144 })?
145 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
147 .map_data(|plan| plan.recompute_schema())
149}
150
151pub struct TypeCoercionRewriter<'a> {
153 pub(crate) schema: &'a DFSchema,
154}
155
156impl<'a> TypeCoercionRewriter<'a> {
157 pub fn new(schema: &'a DFSchema) -> Self {
160 Self { schema }
161 }
162
163 pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
168 match plan {
169 LogicalPlan::Join(join) => self.coerce_join(join),
170 LogicalPlan::Union(union) => Self::coerce_union(union),
171 LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
172 _ => Ok(plan),
173 }
174 }
175
176 pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
185 join.on = join
186 .on
187 .into_iter()
188 .map(|(lhs, rhs)| {
189 let left_schema = join.left.schema();
192 let right_schema = join.right.schema();
193 let (lhs, rhs) = self.coerce_binary_op(
194 lhs,
195 left_schema,
196 Operator::Eq,
197 rhs,
198 right_schema,
199 )?;
200 Ok((lhs, rhs))
201 })
202 .collect::<Result<Vec<_>>>()?;
203
204 join.filter = join
206 .filter
207 .map(|expr| self.coerce_join_filter(expr))
208 .transpose()?;
209
210 Ok(LogicalPlan::Join(join))
211 }
212
213 pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
216 let union_schema = Arc::new(coerce_union_schema_with_schema(
217 &union_plan.inputs,
218 &union_plan.schema,
219 )?);
220 let new_inputs = union_plan
221 .inputs
222 .into_iter()
223 .map(|p| {
224 let plan =
225 coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
226 match plan {
227 LogicalPlan::Projection(Projection { expr, input, .. }) => {
228 Ok(Arc::new(project_with_column_index(
229 expr,
230 input,
231 Arc::clone(&union_schema),
232 )?))
233 }
234 other_plan => Ok(Arc::new(other_plan)),
235 }
236 })
237 .collect::<Result<Vec<_>>>()?;
238 Ok(LogicalPlan::Union(Union {
239 inputs: new_inputs,
240 schema: union_schema,
241 }))
242 }
243
244 fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
246 fn coerce_limit_expr(
247 expr: Expr,
248 schema: &DFSchema,
249 expr_name: &str,
250 ) -> Result<Expr> {
251 let dt = expr.get_type(schema)?;
252 if dt.is_integer() || dt.is_null() {
253 expr.cast_to(&DataType::Int64, schema)
254 } else {
255 plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
256 }
257 }
258
259 let empty_schema = DFSchema::empty();
260 let new_fetch = limit
261 .fetch
262 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
263 .transpose()?;
264 let new_skip = limit
265 .skip
266 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
267 .transpose()?;
268 Ok(LogicalPlan::Limit(Limit {
269 input: limit.input,
270 fetch: new_fetch.map(Box::new),
271 skip: new_skip.map(Box::new),
272 }))
273 }
274
275 fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
276 let expr_type = expr.get_type(self.schema)?;
277 match expr_type {
278 DataType::Boolean => Ok(expr),
279 DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
280 other => plan_err!("Join condition must be boolean type, but got {other:?}"),
281 }
282 }
283
284 fn coerce_binary_op(
285 &self,
286 left: Expr,
287 left_schema: &DFSchema,
288 op: Operator,
289 right: Expr,
290 right_schema: &DFSchema,
291 ) -> Result<(Expr, Expr)> {
292 let left_data_type = left.get_type(left_schema)?;
293 let right_data_type = right.get_type(right_schema)?;
294 let (left_type, right_type) =
295 BinaryTypeCoercer::new(&left_data_type, &op, &right_data_type)
296 .get_input_types()?;
297 let left_cast_ok = can_cast_types(&left_data_type, &left_type);
298 let right_cast_ok = can_cast_types(&right_data_type, &right_type);
299
300 let left_expr = if !left_cast_ok {
304 Self::coerce_date_time_math_op(
305 left,
306 &op,
307 &left_data_type,
308 &left_type,
309 &right_type,
310 )?
311 } else {
312 left.cast_to(&left_type, left_schema)?
313 };
314
315 let right_expr = if !right_cast_ok {
316 Self::coerce_date_time_math_op(
317 right,
318 &op,
319 &right_data_type,
320 &right_type,
321 &left_type,
322 )?
323 } else {
324 right.cast_to(&right_type, right_schema)?
325 };
326
327 Ok((left_expr, right_expr))
328 }
329
330 fn coerce_date_time_math_op(
331 expr: Expr,
332 op: &Operator,
333 left_current_type: &DataType,
334 left_target_type: &DataType,
335 right_target_type: &DataType,
336 ) -> Result<Expr, DataFusionError> {
337 use DataType::*;
338
339 fn cast(expr: Expr, target_type: DataType) -> Expr {
340 Expr::Cast(Cast::new(Box::new(expr), target_type))
341 }
342
343 fn time_to_nanos(
344 expr: Expr,
345 expr_type: &DataType,
346 ) -> Result<Expr, DataFusionError> {
347 let expr = match expr_type {
348 Time32(TimeUnit::Second) => {
349 cast(cast(expr, Int32), Int64)
350 * lit(ScalarValue::Int64(Some(1_000_000_000)))
351 }
352 Time32(TimeUnit::Millisecond) => {
353 cast(cast(expr, Int32), Int64)
354 * lit(ScalarValue::Int64(Some(1_000_000)))
355 }
356 Time64(TimeUnit::Microsecond) => {
357 cast(expr, Int64) * lit(ScalarValue::Int64(Some(1_000)))
358 }
359 Time64(TimeUnit::Nanosecond) => cast(expr, Int64),
360 t => return internal_err!("Unexpected time data type {t}"),
361 };
362
363 Ok(expr)
364 }
365
366 let e = match (
367 &op,
368 &left_current_type,
369 &left_target_type,
370 &right_target_type,
371 ) {
372 (
374 Operator::Plus | Operator::Minus,
375 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
376 Interval(IntervalUnit::MonthDayNano),
377 Date32 | Date64,
378 ) => {
379 let expr = match *left_current_type {
381 Int64 => expr,
382 _ => cast(expr, Int64),
383 };
384 let expr = expr * lit(ScalarValue::from(SECONDS_IN_DAY));
386 let expr = cast(expr, Duration(TimeUnit::Second));
388 cast(expr, Interval(IntervalUnit::MonthDayNano))
390 }
391 (
400 Operator::Plus | Operator::Minus,
401 Time32(_) | Time64(_),
402 Duration(TimeUnit::Nanosecond),
403 Timestamp(TimeUnit::Nanosecond, None),
404 ) => {
405 let expr = time_to_nanos(expr, left_current_type)?;
407 cast(expr, Duration(TimeUnit::Nanosecond))
409 }
410 (
416 Operator::Plus | Operator::Minus,
417 Time32(_) | Time64(_),
418 Interval(IntervalUnit::MonthDayNano),
419 Interval(IntervalUnit::MonthDayNano),
420 ) => {
421 let expr = time_to_nanos(expr, left_current_type)?;
423 let expr = cast(expr, Duration(TimeUnit::Nanosecond));
425 cast(expr, Interval(IntervalUnit::MonthDayNano))
427 }
428 _ => {
429 return plan_err!(
430 "Cannot automatically convert {left_current_type} to {left_target_type}"
431 );
432 }
433 };
434
435 Ok(e)
436 }
437}
438
439impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
440 type Node = Expr;
441
442 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
443 match expr {
444 Expr::Unnest(_) => not_impl_err!(
445 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
446 ),
447 Expr::ScalarSubquery(Subquery {
448 subquery,
449 outer_ref_columns,
450 spans,
451 }) => {
452 let new_plan =
453 analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data;
454 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
455 subquery: Arc::new(new_plan),
456 outer_ref_columns,
457 spans,
458 })))
459 }
460 Expr::Exists(Exists { subquery, negated }) => {
461 let new_plan = analyze_internal(
462 self.schema,
463 Arc::unwrap_or_clone(subquery.subquery),
464 )?
465 .data;
466 Ok(Transformed::yes(Expr::Exists(Exists {
467 subquery: Subquery {
468 subquery: Arc::new(new_plan),
469 outer_ref_columns: subquery.outer_ref_columns,
470 spans: subquery.spans,
471 },
472 negated,
473 })))
474 }
475 Expr::InSubquery(InSubquery {
476 expr,
477 subquery,
478 negated,
479 }) => {
480 let new_plan = analyze_internal(
481 self.schema,
482 Arc::unwrap_or_clone(subquery.subquery),
483 )?
484 .data;
485 let expr_type = expr.get_type(self.schema)?;
486 let subquery_type = new_plan.schema().field(0).data_type();
487 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
488 plan_datafusion_err!(
489 "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
490 ),
491 )?;
492 let new_subquery = Subquery {
493 subquery: Arc::new(new_plan),
494 outer_ref_columns: subquery.outer_ref_columns,
495 spans: subquery.spans,
496 };
497 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
498 Box::new(expr.cast_to(&common_type, self.schema)?),
499 cast_subquery(new_subquery, &common_type)?,
500 negated,
501 ))))
502 }
503 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
504 *expr,
505 self.schema,
506 )?))),
507 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
508 get_casted_expr_for_bool_op(*expr, self.schema)?,
509 ))),
510 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
511 get_casted_expr_for_bool_op(*expr, self.schema)?,
512 ))),
513 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
514 get_casted_expr_for_bool_op(*expr, self.schema)?,
515 ))),
516 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
517 get_casted_expr_for_bool_op(*expr, self.schema)?,
518 ))),
519 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
520 get_casted_expr_for_bool_op(*expr, self.schema)?,
521 ))),
522 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
523 get_casted_expr_for_bool_op(*expr, self.schema)?,
524 ))),
525 Expr::Like(Like {
526 negated,
527 expr,
528 pattern,
529 escape_char,
530 case_insensitive,
531 }) => {
532 let left_type = expr.get_type(self.schema)?;
533 let right_type = pattern.get_type(self.schema)?;
534 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
535 let op_name = if case_insensitive {
536 "ILIKE"
537 } else {
538 "LIKE"
539 };
540 plan_datafusion_err!(
541 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
542 )
543 })?;
544 let expr = match left_type {
545 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
546 _ => Box::new(expr.cast_to(&coerced_type, self.schema)?),
547 };
548 let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
549 Ok(Transformed::yes(Expr::Like(Like::new(
550 negated,
551 expr,
552 pattern,
553 escape_char,
554 case_insensitive,
555 ))))
556 }
557 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
558 let (left, right) =
559 self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?;
560 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
561 Box::new(left),
562 op,
563 Box::new(right),
564 ))))
565 }
566 Expr::Between(Between {
567 expr,
568 negated,
569 low,
570 high,
571 }) => {
572 let expr_type = expr.get_type(self.schema)?;
573 let low_type = low.get_type(self.schema)?;
574 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
575 .ok_or_else(|| {
576 internal_datafusion_err!(
577 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
578 )
579 })?;
580 let high_type = high.get_type(self.schema)?;
581 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
582 .ok_or_else(|| {
583 internal_datafusion_err!(
584 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
585 )
586 })?;
587 let coercion_type =
588 comparison_coercion(&low_coerced_type, &high_coerced_type)
589 .ok_or_else(|| {
590 internal_datafusion_err!(
591 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
592 )
593 })?;
594 Ok(Transformed::yes(Expr::Between(Between::new(
595 Box::new(expr.cast_to(&coercion_type, self.schema)?),
596 negated,
597 Box::new(low.cast_to(&coercion_type, self.schema)?),
598 Box::new(high.cast_to(&coercion_type, self.schema)?),
599 ))))
600 }
601 Expr::InList(InList {
602 expr,
603 list,
604 negated,
605 }) => {
606 let expr_data_type = expr.get_type(self.schema)?;
607 let list_data_types = list
608 .iter()
609 .map(|list_expr| list_expr.get_type(self.schema))
610 .collect::<Result<Vec<_>>>()?;
611 let result_type =
612 get_coerce_type_for_list(&expr_data_type, &list_data_types);
613 match result_type {
614 None => plan_err!(
615 "Can not find compatible types to compare {expr_data_type} with [{}]",
616 list_data_types.iter().join(", ")
617 ),
618 Some(coerced_type) => {
619 let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
621 let cast_list_expr = list
622 .into_iter()
623 .map(|list_expr| {
624 list_expr.cast_to(&coerced_type, self.schema)
625 })
626 .collect::<Result<Vec<_>>>()?;
627 Ok(Transformed::yes(Expr::InList(InList::new(
628 Box::new(cast_expr),
629 cast_list_expr,
630 negated,
631 ))))
632 }
633 }
634 }
635 Expr::Case(case) => {
636 let case = coerce_case_expression(case, self.schema)?;
637 Ok(Transformed::yes(Expr::Case(case)))
638 }
639 Expr::ScalarFunction(ScalarFunction { func, args }) => {
640 let new_expr = coerce_arguments_for_signature_with_scalar_udf(
641 args,
642 self.schema,
643 &func,
644 )?;
645 Ok(Transformed::yes(Expr::ScalarFunction(
646 ScalarFunction::new_udf(func, new_expr),
647 )))
648 }
649 Expr::AggregateFunction(expr::AggregateFunction {
650 func,
651 params:
652 AggregateFunctionParams {
653 args,
654 distinct,
655 filter,
656 order_by,
657 null_treatment,
658 },
659 }) => {
660 let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
661 args,
662 self.schema,
663 &func,
664 )?;
665 Ok(Transformed::yes(Expr::AggregateFunction(
666 expr::AggregateFunction::new_udf(
667 func,
668 new_expr,
669 distinct,
670 filter,
671 order_by,
672 null_treatment,
673 ),
674 )))
675 }
676 Expr::WindowFunction(window_fun) => {
677 let WindowFunction {
678 fun,
679 params:
680 expr::WindowFunctionParams {
681 args,
682 partition_by,
683 order_by,
684 window_frame,
685 filter,
686 null_treatment,
687 distinct,
688 },
689 } = *window_fun;
690 let window_frame =
691 coerce_window_frame(window_frame, self.schema, &order_by)?;
692
693 let args = match &fun {
694 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
695 coerce_arguments_for_signature_with_aggregate_udf(
696 args,
697 self.schema,
698 udf,
699 )?
700 }
701 _ => args,
702 };
703
704 let new_expr = Expr::from(WindowFunction {
705 fun,
706 params: expr::WindowFunctionParams {
707 args,
708 partition_by,
709 order_by,
710 window_frame,
711 filter,
712 null_treatment,
713 distinct,
714 },
715 });
716 Ok(Transformed::yes(new_expr))
717 }
718 #[expect(deprecated)]
720 Expr::Alias(_)
721 | Expr::Column(_)
722 | Expr::ScalarVariable(_, _)
723 | Expr::Literal(_, _)
724 | Expr::SimilarTo(_)
725 | Expr::IsNotNull(_)
726 | Expr::IsNull(_)
727 | Expr::Negative(_)
728 | Expr::Cast(_)
729 | Expr::TryCast(_)
730 | Expr::Wildcard { .. }
731 | Expr::GroupingSet(_)
732 | Expr::Placeholder(_)
733 | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)),
734 }
735 }
736}
737
738fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
740 let metadata = dfschema.as_arrow().metadata.clone();
741 let mut transformed = false;
742
743 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
744 dfschema
745 .iter()
746 .map(|(qualifier, field)| match field.data_type() {
747 DataType::Utf8View => {
748 transformed = true;
749 (
750 qualifier.cloned() as Option<TableReference>,
751 Arc::new(Field::new(
752 field.name(),
753 DataType::LargeUtf8,
754 field.is_nullable(),
755 )),
756 )
757 }
758 DataType::BinaryView => {
759 transformed = true;
760 (
761 qualifier.cloned() as Option<TableReference>,
762 Arc::new(Field::new(
763 field.name(),
764 DataType::LargeBinary,
765 field.is_nullable(),
766 )),
767 )
768 }
769 _ => (
770 qualifier.cloned() as Option<TableReference>,
771 Arc::clone(field),
772 ),
773 })
774 .unzip();
775
776 if !transformed {
777 return None;
778 }
779
780 let schema = Schema::new_with_metadata(transformed_fields, metadata);
781 Some(DFSchema::from_field_specific_qualified_schema(
782 qualifiers,
783 &Arc::new(schema),
784 ))
785}
786
787fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
790 match value {
791 ScalarValue::Utf8(Some(val)) => {
793 ScalarValue::try_from_string(val.clone(), target_type)
794 }
795 s => {
796 if s.is_null() {
797 ScalarValue::try_from(target_type)
799 } else {
800 Ok(s.clone())
804 }
805 }
806 }
807}
808
809fn coerce_scalar_range_aware(
816 target_type: &DataType,
817 value: &ScalarValue,
818) -> Result<ScalarValue> {
819 coerce_scalar(target_type, value).or_else(|err| {
820 if let Some(largest_type) = get_widest_type_in_family(target_type) {
822 coerce_scalar(largest_type, value).map_or_else(
823 |_| exec_err!("Cannot cast {value:?} to {target_type}"),
824 |_| ScalarValue::try_from(target_type),
825 )
826 } else {
827 Err(err)
828 }
829 })
830}
831
832fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
836 match given_type {
837 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
838 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
839 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
840 _ => None,
841 }
842}
843
844fn coerce_frame_bound(
846 target_type: &DataType,
847 bound: WindowFrameBound,
848) -> Result<WindowFrameBound> {
849 match bound {
850 WindowFrameBound::Preceding(v) => {
851 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
852 }
853 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
854 WindowFrameBound::Following(v) => {
855 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
856 }
857 }
858}
859
860fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
861 if col_type.is_numeric()
862 || is_utf8_or_utf8view_or_large_utf8(col_type)
863 || matches!(col_type, DataType::List(_))
864 || matches!(col_type, DataType::LargeList(_))
865 || matches!(col_type, DataType::FixedSizeList(_, _))
866 || matches!(col_type, DataType::Null)
867 || matches!(col_type, DataType::Boolean)
868 {
869 Ok(col_type.clone())
870 } else if is_datetime(col_type) {
871 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
872 } else if let DataType::Dictionary(_, value_type) = col_type {
873 extract_window_frame_target_type(value_type)
874 } else {
875 internal_err!("Cannot run range queries on datatype: {col_type}")
876 }
877}
878
879fn coerce_window_frame(
882 window_frame: WindowFrame,
883 schema: &DFSchema,
884 expressions: &[Sort],
885) -> Result<WindowFrame> {
886 let mut window_frame = window_frame;
887 let target_type = match window_frame.units {
888 WindowFrameUnits::Range => {
889 let current_types = expressions
890 .first()
891 .map(|s| s.expr.get_type(schema))
892 .transpose()?;
893 if let Some(col_type) = current_types {
894 extract_window_frame_target_type(&col_type)?
895 } else {
896 return internal_err!("ORDER BY column cannot be empty");
897 }
898 }
899 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
900 };
901 window_frame.start_bound =
902 coerce_frame_bound(&target_type, window_frame.start_bound)?;
903 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
904 Ok(window_frame)
905}
906
907fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
910 let left_type = expr.get_type(schema)?;
911 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
912 .get_input_types()?;
913 expr.cast_to(&DataType::Boolean, schema)
914}
915
916fn coerce_arguments_for_signature_with_scalar_udf(
921 expressions: Vec<Expr>,
922 schema: &DFSchema,
923 func: &ScalarUDF,
924) -> Result<Vec<Expr>> {
925 if expressions.is_empty() {
926 return Ok(expressions);
927 }
928
929 let current_fields = expressions
930 .iter()
931 .map(|e| e.to_field(schema).map(|(_, f)| f))
932 .collect::<Result<Vec<_>>>()?;
933
934 let coerced_types = fields_with_udf(¤t_fields, func)?
935 .into_iter()
936 .map(|f| f.data_type().clone())
937 .collect::<Vec<_>>();
938
939 expressions
940 .into_iter()
941 .enumerate()
942 .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
943 .collect()
944}
945
946fn coerce_arguments_for_signature_with_aggregate_udf(
951 expressions: Vec<Expr>,
952 schema: &DFSchema,
953 func: &AggregateUDF,
954) -> Result<Vec<Expr>> {
955 if expressions.is_empty() {
956 return Ok(expressions);
957 }
958
959 let current_fields = expressions
960 .iter()
961 .map(|e| e.to_field(schema).map(|(_, f)| f))
962 .collect::<Result<Vec<_>>>()?;
963
964 let coerced_types = fields_with_udf(¤t_fields, func)?
965 .into_iter()
966 .map(|f| f.data_type().clone())
967 .collect::<Vec<_>>();
968
969 expressions
970 .into_iter()
971 .enumerate()
972 .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
973 .collect()
974}
975
976fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
977 let case_type = case
1009 .expr
1010 .as_ref()
1011 .map(|expr| expr.get_type(schema))
1012 .transpose()?;
1013 let then_types = case
1014 .when_then_expr
1015 .iter()
1016 .map(|(_when, then)| then.get_type(schema))
1017 .collect::<Result<Vec<_>>>()?;
1018 let else_type = case
1019 .else_expr
1020 .as_ref()
1021 .map(|expr| expr.get_type(schema))
1022 .transpose()?;
1023
1024 let case_when_coerce_type = case_type
1026 .as_ref()
1027 .map(|case_type| {
1028 let when_types = case
1029 .when_then_expr
1030 .iter()
1031 .map(|(when, _then)| when.get_type(schema))
1032 .collect::<Result<Vec<_>>>()?;
1033 let coerced_type =
1034 get_coerce_type_for_case_expression(&when_types, Some(case_type));
1035 coerced_type.ok_or_else(|| {
1036 plan_datafusion_err!(
1037 "Failed to coerce case ({case_type}) and when ({}) \
1038 to common types in CASE WHEN expression",
1039 when_types.iter().join(", ")
1040 )
1041 })
1042 })
1043 .transpose()?;
1044 let then_else_coerce_type =
1045 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
1046 || {
1047 if let Some(else_type) = else_type {
1048 plan_datafusion_err!(
1049 "Failed to coerce then ({}) and else ({else_type}) \
1050 to common types in CASE WHEN expression",
1051 then_types.iter().join(", ")
1052 )
1053 } else {
1054 plan_datafusion_err!(
1055 "Failed to coerce then ({}) and else (None) \
1056 to common types in CASE WHEN expression",
1057 then_types.iter().join(", ")
1058 )
1059 }
1060 },
1061 )?;
1062
1063 let case_expr = case
1065 .expr
1066 .zip(case_when_coerce_type.as_ref())
1067 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
1068 .transpose()?
1069 .map(Box::new);
1070 let when_then = case
1071 .when_then_expr
1072 .into_iter()
1073 .map(|(when, then)| {
1074 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
1075 let when = when.cast_to(when_type, schema).map_err(|e| {
1076 DataFusionError::Context(
1077 format!(
1078 "WHEN expressions in CASE couldn't be \
1079 converted to common type ({when_type})"
1080 ),
1081 Box::new(e),
1082 )
1083 })?;
1084 let then = then.cast_to(&then_else_coerce_type, schema)?;
1085 Ok((Box::new(when), Box::new(then)))
1086 })
1087 .collect::<Result<Vec<_>>>()?;
1088 let else_expr = case
1089 .else_expr
1090 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
1091 .transpose()?
1092 .map(Box::new);
1093
1094 Ok(Case::new(case_expr, when_then, else_expr))
1095}
1096
1097pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1139 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1140}
1141fn coerce_union_schema_with_schema(
1142 inputs: &[Arc<LogicalPlan>],
1143 base_schema: &DFSchemaRef,
1144) -> Result<DFSchema> {
1145 let mut union_datatypes = base_schema
1146 .fields()
1147 .iter()
1148 .map(|f| f.data_type().clone())
1149 .collect::<Vec<_>>();
1150 let mut union_nullabilities = base_schema
1151 .fields()
1152 .iter()
1153 .map(|f| f.is_nullable())
1154 .collect::<Vec<_>>();
1155 let mut union_field_meta = base_schema
1156 .fields()
1157 .iter()
1158 .map(|f| f.metadata().clone())
1159 .collect::<Vec<_>>();
1160
1161 let mut metadata = base_schema.metadata().clone();
1162
1163 for (i, plan) in inputs.iter().enumerate() {
1164 let plan_schema = plan.schema();
1165 metadata.extend(plan_schema.metadata().clone());
1166
1167 if plan_schema.fields().len() != base_schema.fields().len() {
1168 return plan_err!(
1169 "Union schemas have different number of fields: \
1170 query 1 has {} fields whereas query {} has {} fields",
1171 base_schema.fields().len(),
1172 i + 1,
1173 plan_schema.fields().len()
1174 );
1175 }
1176
1177 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1179 union_datatypes.iter_mut(),
1180 union_nullabilities.iter_mut(),
1181 union_field_meta.iter_mut(),
1182 plan_schema.fields().iter()
1183 ) {
1184 let coerced_type =
1185 comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1186 || {
1187 plan_datafusion_err!(
1188 "Incompatible inputs for Union: Previous inputs were \
1189 of type {}, but got incompatible type {} on column '{}'",
1190 union_datatype,
1191 plan_field.data_type(),
1192 plan_field.name()
1193 )
1194 },
1195 )?;
1196
1197 *union_datatype = coerced_type;
1198 *union_nullable = *union_nullable || plan_field.is_nullable();
1199 union_field_map.extend(plan_field.metadata().clone());
1200 }
1201 }
1202 let union_qualified_fields = izip!(
1203 base_schema.fields(),
1204 union_datatypes.into_iter(),
1205 union_nullabilities,
1206 union_field_meta.into_iter()
1207 )
1208 .map(|(field, datatype, nullable, metadata)| {
1209 let mut field = Field::new(field.name().clone(), datatype, nullable);
1210 field.set_metadata(metadata);
1211 (None, field.into())
1212 })
1213 .collect::<Vec<_>>();
1214
1215 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1216}
1217
1218fn project_with_column_index(
1220 expr: Vec<Expr>,
1221 input: Arc<LogicalPlan>,
1222 schema: DFSchemaRef,
1223) -> Result<LogicalPlan> {
1224 let alias_expr = expr
1225 .into_iter()
1226 .enumerate()
1227 .map(|(i, e)| match e {
1228 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1229 Ok(e.unalias().alias(schema.field(i).name()))
1230 }
1231 Expr::Column(Column {
1232 relation: _,
1233 ref name,
1234 spans: _,
1235 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1236 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1237 #[expect(deprecated)]
1238 Expr::Wildcard { .. } => {
1239 plan_err!("Wildcard should be expanded before type coercion")
1240 }
1241 _ => Ok(e.alias(schema.field(i).name())),
1242 })
1243 .collect::<Result<Vec<_>>>()?;
1244
1245 Projection::try_new_with_schema(alias_expr, input, schema)
1246 .map(LogicalPlan::Projection)
1247}
1248
1249#[cfg(test)]
1250mod test {
1251 use std::any::Any;
1252 use std::sync::Arc;
1253
1254 use arrow::datatypes::DataType::Utf8;
1255 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1256 use insta::assert_snapshot;
1257
1258 use crate::analyzer::Analyzer;
1259 use crate::analyzer::type_coercion::{
1260 TypeCoercion, TypeCoercionRewriter, coerce_case_expression,
1261 };
1262 use crate::assert_analyzed_plan_with_config_eq_snapshot;
1263 use datafusion_common::config::ConfigOptions;
1264 use datafusion_common::tree_node::{TransformedResult, TreeNode};
1265 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1266 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1267 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1268 use datafusion_expr::test::function_stub::avg_udaf;
1269 use datafusion_expr::{
1270 AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr,
1271 ExprSchemable, Filter, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF,
1272 ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Union, Volatility, cast,
1273 col, create_udaf, is_true, lit,
1274 };
1275 use datafusion_functions_aggregate::average::AvgAccumulator;
1276 use datafusion_sql::TableReference;
1277
1278 fn empty() -> Arc<LogicalPlan> {
1279 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1280 produce_one_row: false,
1281 schema: Arc::new(DFSchema::empty()),
1282 }))
1283 }
1284
1285 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1286 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1287 produce_one_row: false,
1288 schema: Arc::new(
1289 DFSchema::from_unqualified_fields(
1290 vec![Field::new("a", data_type, true)].into(),
1291 std::collections::HashMap::new(),
1292 )
1293 .unwrap(),
1294 ),
1295 }))
1296 }
1297
1298 macro_rules! assert_analyzed_plan_eq {
1299 (
1300 $plan: expr,
1301 @ $expected: literal $(,)?
1302 ) => {{
1303 let options = ConfigOptions::default();
1304 let rule = Arc::new(TypeCoercion::new());
1305 assert_analyzed_plan_with_config_eq_snapshot!(
1306 options,
1307 rule,
1308 $plan,
1309 @ $expected,
1310 )
1311 }};
1312 }
1313
1314 macro_rules! coerce_on_output_if_viewtype {
1315 (
1316 $is_viewtype: expr,
1317 $plan: expr,
1318 @ $expected: literal $(,)?
1319 ) => {{
1320 let mut options = ConfigOptions::default();
1321 if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1323 let rule = Arc::new(TypeCoercion::new());
1324
1325 assert_analyzed_plan_with_config_eq_snapshot!(
1326 options,
1327 rule,
1328 $plan,
1329 @ $expected,
1330 )
1331 }};
1332 }
1333
1334 fn assert_type_coercion_error(
1335 plan: LogicalPlan,
1336 expected_substr: &str,
1337 ) -> Result<()> {
1338 let options = ConfigOptions::default();
1339 let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1340
1341 match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1342 Ok(succeeded_plan) => {
1343 panic!(
1344 "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1345 );
1346 }
1347 Err(e) => {
1348 let msg = e.to_string();
1349 assert!(
1350 msg.contains(expected_substr),
1351 "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`"
1352 );
1353 }
1354 }
1355
1356 Ok(())
1357 }
1358
1359 #[test]
1360 fn simple_case() -> Result<()> {
1361 let expr = col("a").lt(lit(2_u32));
1362 let empty = empty_with_type(DataType::Float64);
1363 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1364
1365 assert_analyzed_plan_eq!(
1366 plan,
1367 @r"
1368 Projection: a < CAST(UInt32(2) AS Float64)
1369 EmptyRelation: rows=0
1370 "
1371 )
1372 }
1373
1374 #[test]
1375 fn test_coerce_union() -> Result<()> {
1376 let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1377 produce_one_row: false,
1378 schema: Arc::new(
1379 DFSchema::try_from_qualified_schema(
1380 TableReference::full("datafusion", "test", "foo"),
1381 &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1382 )
1383 .unwrap(),
1384 ),
1385 }));
1386 let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1387 produce_one_row: false,
1388 schema: Arc::new(
1389 DFSchema::try_from_qualified_schema(
1390 TableReference::full("datafusion", "test", "foo"),
1391 &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1392 )
1393 .unwrap(),
1394 ),
1395 }));
1396 let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1397 left_plan, right_plan,
1398 ])?);
1399 let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1400 .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1401 let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1402 vec![col("a")],
1403 Arc::new(analyzed_union),
1404 )?);
1405
1406 assert_analyzed_plan_eq!(
1407 top_level_plan,
1408 @r"
1409 Projection: a
1410 Union
1411 Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1412 EmptyRelation: rows=0
1413 EmptyRelation: rows=0
1414 "
1415 )
1416 }
1417
1418 #[test]
1419 fn coerce_utf8view_output() -> Result<()> {
1420 let expr = col("a");
1423 let empty = empty_with_type(DataType::Utf8View);
1424 let plan = LogicalPlan::Projection(Projection::try_new(
1425 vec![expr.clone()],
1426 Arc::clone(&empty),
1427 )?);
1428
1429 coerce_on_output_if_viewtype!(
1431 false,
1432 plan.clone(),
1433 @r"
1434 Projection: a
1435 EmptyRelation: rows=0
1436 "
1437 )?;
1438
1439 coerce_on_output_if_viewtype!(
1441 true,
1442 plan.clone(),
1443 @r"
1444 Projection: CAST(a AS LargeUtf8) AS a
1445 EmptyRelation: rows=0
1446 "
1447 )?;
1448
1449 let bool_expr = col("a").lt(lit("foo"));
1452 let bool_plan = LogicalPlan::Projection(Projection::try_new(
1453 vec![bool_expr],
1454 Arc::clone(&empty),
1455 )?);
1456 coerce_on_output_if_viewtype!(
1458 false,
1459 bool_plan.clone(),
1460 @r#"
1461 Projection: a < CAST(Utf8("foo") AS Utf8View)
1462 EmptyRelation: rows=0
1463 "#
1464 )?;
1465
1466 coerce_on_output_if_viewtype!(
1467 false,
1468 plan.clone(),
1469 @r"
1470 Projection: a
1471 EmptyRelation: rows=0
1472 "
1473 )?;
1474
1475 coerce_on_output_if_viewtype!(
1477 true,
1478 plan.clone(),
1479 @r"
1480 Projection: CAST(a AS LargeUtf8) AS a
1481 EmptyRelation: rows=0
1482 "
1483 )?;
1484
1485 let sort_expr = expr.sort(true, true);
1488 let sort_plan = LogicalPlan::Sort(Sort {
1489 expr: vec![sort_expr],
1490 input: Arc::new(plan),
1491 fetch: None,
1492 });
1493
1494 coerce_on_output_if_viewtype!(
1496 false,
1497 sort_plan.clone(),
1498 @r"
1499 Sort: a ASC NULLS FIRST
1500 Projection: a
1501 EmptyRelation: rows=0
1502 "
1503 )?;
1504
1505 coerce_on_output_if_viewtype!(
1507 true,
1508 sort_plan.clone(),
1509 @r"
1510 Projection: CAST(a AS LargeUtf8) AS a
1511 Sort: a ASC NULLS FIRST
1512 Projection: a
1513 EmptyRelation: rows=0
1514 "
1515 )?;
1516
1517 let plan = LogicalPlan::Projection(Projection::try_new(
1520 vec![col("a")],
1521 Arc::new(sort_plan),
1522 )?);
1523 coerce_on_output_if_viewtype!(
1525 false,
1526 plan.clone(),
1527 @r"
1528 Projection: a
1529 Sort: a ASC NULLS FIRST
1530 Projection: a
1531 EmptyRelation: rows=0
1532 "
1533 )?;
1534 coerce_on_output_if_viewtype!(
1536 true,
1537 plan.clone(),
1538 @r"
1539 Projection: CAST(a AS LargeUtf8) AS a
1540 Sort: a ASC NULLS FIRST
1541 Projection: a
1542 EmptyRelation: rows=0
1543 "
1544 )?;
1545
1546 Ok(())
1547 }
1548
1549 #[test]
1550 fn coerce_binaryview_output() -> Result<()> {
1551 let expr = col("a");
1554 let empty = empty_with_type(DataType::BinaryView);
1555 let plan = LogicalPlan::Projection(Projection::try_new(
1556 vec![expr.clone()],
1557 Arc::clone(&empty),
1558 )?);
1559
1560 coerce_on_output_if_viewtype!(
1562 false,
1563 plan.clone(),
1564 @r"
1565 Projection: a
1566 EmptyRelation: rows=0
1567 "
1568 )?;
1569
1570 coerce_on_output_if_viewtype!(
1572 true,
1573 plan.clone(),
1574 @r"
1575 Projection: CAST(a AS LargeBinary) AS a
1576 EmptyRelation: rows=0
1577 "
1578 )?;
1579
1580 let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1583 let bool_plan = LogicalPlan::Projection(Projection::try_new(
1584 vec![bool_expr],
1585 Arc::clone(&empty),
1586 )?);
1587
1588 coerce_on_output_if_viewtype!(
1590 false,
1591 bool_plan.clone(),
1592 @r#"
1593 Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1594 EmptyRelation: rows=0
1595 "#
1596 )?;
1597
1598 coerce_on_output_if_viewtype!(
1600 true,
1601 bool_plan.clone(),
1602 @r#"
1603 Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1604 EmptyRelation: rows=0
1605 "#
1606 )?;
1607
1608 let sort_expr = expr.sort(true, true);
1611 let sort_plan = LogicalPlan::Sort(Sort {
1612 expr: vec![sort_expr],
1613 input: Arc::new(plan),
1614 fetch: None,
1615 });
1616
1617 coerce_on_output_if_viewtype!(
1619 false,
1620 sort_plan.clone(),
1621 @r"
1622 Sort: a ASC NULLS FIRST
1623 Projection: a
1624 EmptyRelation: rows=0
1625 "
1626 )?;
1627 coerce_on_output_if_viewtype!(
1629 true,
1630 sort_plan.clone(),
1631 @r"
1632 Projection: CAST(a AS LargeBinary) AS a
1633 Sort: a ASC NULLS FIRST
1634 Projection: a
1635 EmptyRelation: rows=0
1636 "
1637 )?;
1638
1639 let plan = LogicalPlan::Projection(Projection::try_new(
1642 vec![col("a")],
1643 Arc::new(sort_plan),
1644 )?);
1645
1646 coerce_on_output_if_viewtype!(
1648 false,
1649 plan.clone(),
1650 @r"
1651 Projection: a
1652 Sort: a ASC NULLS FIRST
1653 Projection: a
1654 EmptyRelation: rows=0
1655 "
1656 )?;
1657
1658 coerce_on_output_if_viewtype!(
1660 true,
1661 plan.clone(),
1662 @r"
1663 Projection: CAST(a AS LargeBinary) AS a
1664 Sort: a ASC NULLS FIRST
1665 Projection: a
1666 EmptyRelation: rows=0
1667 "
1668 )?;
1669
1670 Ok(())
1671 }
1672
1673 #[test]
1674 fn nested_case() -> Result<()> {
1675 let expr = col("a").lt(lit(2_u32));
1676 let empty = empty_with_type(DataType::Float64);
1677
1678 let plan = LogicalPlan::Projection(Projection::try_new(
1679 vec![expr.clone().or(expr)],
1680 empty,
1681 )?);
1682
1683 assert_analyzed_plan_eq!(
1684 plan,
1685 @r"
1686 Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1687 EmptyRelation: rows=0
1688 "
1689 )
1690 }
1691
1692 #[derive(Debug, PartialEq, Eq, Hash)]
1693 struct TestScalarUDF {
1694 signature: Signature,
1695 }
1696
1697 impl ScalarUDFImpl for TestScalarUDF {
1698 fn as_any(&self) -> &dyn Any {
1699 self
1700 }
1701
1702 fn name(&self) -> &str {
1703 "TestScalarUDF"
1704 }
1705
1706 fn signature(&self) -> &Signature {
1707 &self.signature
1708 }
1709
1710 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1711 Ok(Utf8)
1712 }
1713
1714 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1715 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1716 }
1717 }
1718
1719 #[test]
1720 fn scalar_udf() -> Result<()> {
1721 let empty = empty();
1722
1723 let udf = ScalarUDF::from(TestScalarUDF {
1724 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1725 })
1726 .call(vec![lit(123_i32)]);
1727 let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1728
1729 assert_analyzed_plan_eq!(
1730 plan,
1731 @r"
1732 Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1733 EmptyRelation: rows=0
1734 "
1735 )
1736 }
1737
1738 #[test]
1739 fn scalar_udf_invalid_input() -> Result<()> {
1740 let empty = empty();
1741 let udf = ScalarUDF::from(TestScalarUDF {
1742 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1743 })
1744 .call(vec![lit("Apple")]);
1745 Projection::try_new(vec![udf], empty)
1746 .expect_err("Expected an error due to incorrect function input");
1747
1748 Ok(())
1749 }
1750
1751 #[test]
1752 fn scalar_function() -> Result<()> {
1753 let empty = empty();
1755 let lit_expr = lit(10i64);
1756 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1757 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1758 });
1759 let scalar_function_expr =
1760 Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1761 let plan = LogicalPlan::Projection(Projection::try_new(
1762 vec![scalar_function_expr],
1763 empty,
1764 )?);
1765
1766 assert_analyzed_plan_eq!(
1767 plan,
1768 @r"
1769 Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1770 EmptyRelation: rows=0
1771 "
1772 )
1773 }
1774
1775 #[test]
1776 fn agg_udaf() -> Result<()> {
1777 let empty = empty();
1778 let my_avg = create_udaf(
1779 "MY_AVG",
1780 vec![DataType::Float64],
1781 Arc::new(DataType::Float64),
1782 Volatility::Immutable,
1783 Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1784 Arc::new(vec![DataType::UInt64, DataType::Float64]),
1785 );
1786 let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1787 Arc::new(my_avg),
1788 vec![lit(10i64)],
1789 false,
1790 None,
1791 vec![],
1792 None,
1793 ));
1794 let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1795
1796 assert_analyzed_plan_eq!(
1797 plan,
1798 @r"
1799 Projection: MY_AVG(CAST(Int64(10) AS Float64))
1800 EmptyRelation: rows=0
1801 "
1802 )
1803 }
1804
1805 #[test]
1806 fn agg_udaf_invalid_input() -> Result<()> {
1807 let empty = empty();
1808 let return_type = DataType::Float64;
1809 let accumulator: AccumulatorFactoryFunction =
1810 Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1811 let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1812 "MY_AVG",
1813 Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1814 return_type,
1815 accumulator,
1816 vec![
1817 Field::new("count", DataType::UInt64, true).into(),
1818 Field::new("avg", DataType::Float64, true).into(),
1819 ],
1820 ));
1821 let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1822 Arc::new(my_avg),
1823 vec![lit("10")],
1824 false,
1825 None,
1826 vec![],
1827 None,
1828 ));
1829
1830 let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1831 assert!(
1832 err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed")
1833 );
1834 Ok(())
1835 }
1836
1837 #[test]
1838 fn agg_function_case() -> Result<()> {
1839 let empty = empty();
1840 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1841 avg_udaf(),
1842 vec![lit(12f64)],
1843 false,
1844 None,
1845 vec![],
1846 None,
1847 ));
1848 let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1849
1850 assert_analyzed_plan_eq!(
1851 plan,
1852 @r"
1853 Projection: avg(Float64(12))
1854 EmptyRelation: rows=0
1855 "
1856 )?;
1857
1858 let empty = empty_with_type(DataType::Int32);
1859 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1860 avg_udaf(),
1861 vec![cast(col("a"), DataType::Float64)],
1862 false,
1863 None,
1864 vec![],
1865 None,
1866 ));
1867 let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1868
1869 assert_analyzed_plan_eq!(
1870 plan,
1871 @r"
1872 Projection: avg(CAST(a AS Float64))
1873 EmptyRelation: rows=0
1874 "
1875 )
1876 }
1877
1878 #[test]
1879 fn agg_function_invalid_input_avg() -> Result<()> {
1880 let empty = empty();
1881 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1882 avg_udaf(),
1883 vec![lit("1")],
1884 false,
1885 None,
1886 vec![],
1887 None,
1888 ));
1889 let err = Projection::try_new(vec![agg_expr], empty)
1890 .err()
1891 .unwrap()
1892 .strip_backtrace();
1893 assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1894 Ok(())
1895 }
1896
1897 #[test]
1898 fn binary_op_date32_op_interval() -> Result<()> {
1899 let expr = cast(lit("1998-03-18"), DataType::Date32)
1901 + lit(ScalarValue::new_interval_dt(123, 456));
1902 let empty = empty();
1903 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1904
1905 assert_analyzed_plan_eq!(
1906 plan,
1907 @r#"
1908 Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1909 EmptyRelation: rows=0
1910 "#
1911 )
1912 }
1913
1914 #[test]
1915 fn inlist_case() -> Result<()> {
1916 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1918 let empty = empty_with_type(DataType::Int64);
1919 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1920 assert_analyzed_plan_eq!(
1921 plan,
1922 @r"
1923 Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1924 EmptyRelation: rows=0
1925 ")?;
1926
1927 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1929 let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1930 produce_one_row: false,
1931 schema: Arc::new(DFSchema::from_unqualified_fields(
1932 vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1933 std::collections::HashMap::new(),
1934 )?),
1935 }));
1936 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1937 assert_analyzed_plan_eq!(
1938 plan,
1939 @r"
1940 Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1941 EmptyRelation: rows=0
1942 ")
1943 }
1944
1945 #[test]
1946 fn between_case() -> Result<()> {
1947 let expr = col("a").between(
1948 lit("2002-05-08"),
1949 cast(lit("2002-05-08"), DataType::Date32)
1951 + lit(ScalarValue::new_interval_ym(0, 1)),
1952 );
1953 let empty = empty_with_type(Utf8);
1954 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1955
1956 assert_analyzed_plan_eq!(
1957 plan,
1958 @r#"
1959 Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1960 EmptyRelation: rows=0
1961 "#
1962 )
1963 }
1964
1965 #[test]
1966 fn between_infer_cheap_type() -> Result<()> {
1967 let expr = col("a").between(
1968 cast(lit("2002-05-08"), DataType::Date32)
1970 + lit(ScalarValue::new_interval_ym(0, 1)),
1971 lit("2002-12-08"),
1972 );
1973 let empty = empty_with_type(Utf8);
1974 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1975
1976 assert_analyzed_plan_eq!(
1978 plan,
1979 @r#"
1980 Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1981 EmptyRelation: rows=0
1982 "#
1983 )
1984 }
1985
1986 #[test]
1987 fn between_null() -> Result<()> {
1988 let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1989 let empty = empty();
1990 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1991
1992 assert_analyzed_plan_eq!(
1993 plan,
1994 @r"
1995 Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1996 EmptyRelation: rows=0
1997 "
1998 )
1999 }
2000
2001 #[test]
2002 fn is_bool_for_type_coercion() -> Result<()> {
2003 let expr = col("a").is_true();
2005 let empty = empty_with_type(DataType::Boolean);
2006 let plan =
2007 LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2008
2009 assert_analyzed_plan_eq!(
2010 plan,
2011 @r"
2012 Projection: a IS TRUE
2013 EmptyRelation: rows=0
2014 "
2015 )?;
2016
2017 let empty = empty_with_type(DataType::Int64);
2018 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2019 assert_type_coercion_error(
2020 plan,
2021 "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean",
2022 )?;
2023
2024 let expr = col("a").is_not_true();
2026 let empty = empty_with_type(DataType::Boolean);
2027 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2028
2029 assert_analyzed_plan_eq!(
2030 plan,
2031 @r"
2032 Projection: a IS NOT TRUE
2033 EmptyRelation: rows=0
2034 "
2035 )?;
2036
2037 let expr = col("a").is_false();
2039 let empty = empty_with_type(DataType::Boolean);
2040 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2041
2042 assert_analyzed_plan_eq!(
2043 plan,
2044 @r"
2045 Projection: a IS FALSE
2046 EmptyRelation: rows=0
2047 "
2048 )?;
2049
2050 let expr = col("a").is_not_false();
2052 let empty = empty_with_type(DataType::Boolean);
2053 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2054
2055 assert_analyzed_plan_eq!(
2056 plan,
2057 @r"
2058 Projection: a IS NOT FALSE
2059 EmptyRelation: rows=0
2060 "
2061 )
2062 }
2063
2064 #[test]
2065 fn like_for_type_coercion() -> Result<()> {
2066 let expr = Box::new(col("a"));
2068 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2069 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2070 let empty = empty_with_type(Utf8);
2071 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2072
2073 assert_analyzed_plan_eq!(
2074 plan,
2075 @r#"
2076 Projection: a LIKE Utf8("abc")
2077 EmptyRelation: rows=0
2078 "#
2079 )?;
2080
2081 let expr = Box::new(col("a"));
2082 let pattern = Box::new(lit(ScalarValue::Null));
2083 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2084 let empty = empty_with_type(Utf8);
2085 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2086
2087 assert_analyzed_plan_eq!(
2088 plan,
2089 @r"
2090 Projection: a LIKE CAST(NULL AS Utf8)
2091 EmptyRelation: rows=0
2092 "
2093 )?;
2094
2095 let expr = Box::new(col("a"));
2096 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2097 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
2098 let empty = empty_with_type(DataType::Int64);
2099 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
2100 assert_type_coercion_error(
2101 plan,
2102 "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
2103 )?;
2104
2105 let expr = Box::new(col("a"));
2107 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2108 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2109 let empty = empty_with_type(Utf8);
2110 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2111
2112 assert_analyzed_plan_eq!(
2113 plan,
2114 @r#"
2115 Projection: a ILIKE Utf8("abc")
2116 EmptyRelation: rows=0
2117 "#
2118 )?;
2119
2120 let expr = Box::new(col("a"));
2121 let pattern = Box::new(lit(ScalarValue::Null));
2122 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2123 let empty = empty_with_type(Utf8);
2124 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2125
2126 assert_analyzed_plan_eq!(
2127 plan,
2128 @r"
2129 Projection: a ILIKE CAST(NULL AS Utf8)
2130 EmptyRelation: rows=0
2131 "
2132 )?;
2133
2134 let expr = Box::new(col("a"));
2135 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2136 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2137 let empty = empty_with_type(DataType::Int64);
2138 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2139 assert_type_coercion_error(
2140 plan,
2141 "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
2142 )?;
2143
2144 Ok(())
2145 }
2146
2147 #[test]
2148 fn unknown_for_type_coercion() -> Result<()> {
2149 let expr = col("a").is_unknown();
2151 let empty = empty_with_type(DataType::Boolean);
2152 let plan =
2153 LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2154
2155 assert_analyzed_plan_eq!(
2156 plan,
2157 @r"
2158 Projection: a IS UNKNOWN
2159 EmptyRelation: rows=0
2160 "
2161 )?;
2162
2163 let empty = empty_with_type(Utf8);
2164 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2165 assert_type_coercion_error(
2166 plan,
2167 "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean",
2168 )?;
2169
2170 let expr = col("a").is_not_unknown();
2172 let empty = empty_with_type(DataType::Boolean);
2173 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2174
2175 assert_analyzed_plan_eq!(
2176 plan,
2177 @r"
2178 Projection: a IS NOT UNKNOWN
2179 EmptyRelation: rows=0
2180 "
2181 )
2182 }
2183
2184 #[test]
2185 fn concat_for_type_coercion() -> Result<()> {
2186 let empty = empty_with_type(Utf8);
2187 let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2188
2189 let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2191 signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2192 })
2193 .call(args.to_vec());
2194 let plan =
2195 LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2196 assert_analyzed_plan_eq!(
2197 plan,
2198 @r#"
2199 Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2200 EmptyRelation: rows=0
2201 "#
2202 )
2203 }
2204
2205 #[test]
2206 fn test_type_coercion_rewrite() -> Result<()> {
2207 let schema = Arc::new(DFSchema::from_unqualified_fields(
2209 vec![Field::new("a", DataType::Int64, true)].into(),
2210 std::collections::HashMap::new(),
2211 )?);
2212 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2213 let expr = is_true(lit(12i32).gt(lit(13i64)));
2214 let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2215 let result = expr.rewrite(&mut rewriter).data()?;
2216 assert_eq!(expected, result);
2217
2218 let schema = Arc::new(DFSchema::from_unqualified_fields(
2220 vec![Field::new("a", DataType::Int64, true)].into(),
2221 std::collections::HashMap::new(),
2222 )?);
2223 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2224 let expr = is_true(lit(12i32).eq(lit(13i64)));
2225 let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2226 let result = expr.rewrite(&mut rewriter).data()?;
2227 assert_eq!(expected, result);
2228
2229 let schema = Arc::new(DFSchema::from_unqualified_fields(
2231 vec![Field::new("a", DataType::Int64, true)].into(),
2232 std::collections::HashMap::new(),
2233 )?);
2234 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2235 let expr = is_true(lit(12i32).lt(lit(13i64)));
2236 let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2237 let result = expr.rewrite(&mut rewriter).data()?;
2238 assert_eq!(expected, result);
2239
2240 Ok(())
2241 }
2242
2243 #[test]
2244 fn binary_op_date32_eq_ts() -> Result<()> {
2245 let expr = cast(
2246 lit("1998-03-18"),
2247 DataType::Timestamp(TimeUnit::Nanosecond, None),
2248 )
2249 .eq(cast(lit("1998-03-18"), DataType::Date32));
2250 let empty = empty();
2251 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2252
2253 assert_analyzed_plan_eq!(
2254 plan,
2255 @r#"
2256 Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
2257 EmptyRelation: rows=0
2258 "#
2259 )
2260 }
2261
2262 fn cast_if_not_same_type(
2263 expr: Box<Expr>,
2264 data_type: &DataType,
2265 schema: &DFSchemaRef,
2266 ) -> Box<Expr> {
2267 if &expr.get_type(schema).unwrap() != data_type {
2268 Box::new(cast(*expr, data_type.clone()))
2269 } else {
2270 expr
2271 }
2272 }
2273
2274 fn cast_helper(
2275 case: Case,
2276 case_when_type: &DataType,
2277 then_else_type: &DataType,
2278 schema: &DFSchemaRef,
2279 ) -> Case {
2280 let expr = case
2281 .expr
2282 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2283 let when_then_expr = case
2284 .when_then_expr
2285 .into_iter()
2286 .map(|(when, then)| {
2287 (
2288 cast_if_not_same_type(when, case_when_type, schema),
2289 cast_if_not_same_type(then, then_else_type, schema),
2290 )
2291 })
2292 .collect::<Vec<_>>();
2293 let else_expr = case
2294 .else_expr
2295 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2296
2297 Case {
2298 expr,
2299 when_then_expr,
2300 else_expr,
2301 }
2302 }
2303
2304 #[test]
2305 fn test_case_expression_coercion() -> Result<()> {
2306 let schema = Arc::new(DFSchema::from_unqualified_fields(
2307 vec![
2308 Field::new("boolean", DataType::Boolean, true),
2309 Field::new("integer", DataType::Int32, true),
2310 Field::new("float", DataType::Float32, true),
2311 Field::new(
2312 "timestamp",
2313 DataType::Timestamp(TimeUnit::Nanosecond, None),
2314 true,
2315 ),
2316 Field::new("date", DataType::Date32, true),
2317 Field::new(
2318 "interval",
2319 DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2320 true,
2321 ),
2322 Field::new("binary", DataType::Binary, true),
2323 Field::new("string", Utf8, true),
2324 Field::new("decimal", DataType::Decimal128(10, 10), true),
2325 ]
2326 .into(),
2327 std::collections::HashMap::new(),
2328 )?);
2329
2330 let case = Case {
2331 expr: None,
2332 when_then_expr: vec![
2333 (Box::new(col("boolean")), Box::new(col("integer"))),
2334 (Box::new(col("integer")), Box::new(col("float"))),
2335 (Box::new(col("string")), Box::new(col("string"))),
2336 ],
2337 else_expr: None,
2338 };
2339 let case_when_common_type = DataType::Boolean;
2340 let then_else_common_type = Utf8;
2341 let expected = cast_helper(
2342 case.clone(),
2343 &case_when_common_type,
2344 &then_else_common_type,
2345 &schema,
2346 );
2347 let actual = coerce_case_expression(case, &schema)?;
2348 assert_eq!(expected, actual);
2349
2350 let case = Case {
2351 expr: Some(Box::new(col("string"))),
2352 when_then_expr: vec![
2353 (Box::new(col("float")), Box::new(col("integer"))),
2354 (Box::new(col("integer")), Box::new(col("float"))),
2355 (Box::new(col("string")), Box::new(col("string"))),
2356 ],
2357 else_expr: Some(Box::new(col("string"))),
2358 };
2359 let case_when_common_type = Utf8;
2360 let then_else_common_type = Utf8;
2361 let expected = cast_helper(
2362 case.clone(),
2363 &case_when_common_type,
2364 &then_else_common_type,
2365 &schema,
2366 );
2367 let actual = coerce_case_expression(case, &schema)?;
2368 assert_eq!(expected, actual);
2369
2370 let case = Case {
2371 expr: Some(Box::new(col("interval"))),
2372 when_then_expr: vec![
2373 (Box::new(col("float")), Box::new(col("integer"))),
2374 (Box::new(col("binary")), Box::new(col("float"))),
2375 (Box::new(col("string")), Box::new(col("string"))),
2376 ],
2377 else_expr: Some(Box::new(col("string"))),
2378 };
2379 let err = coerce_case_expression(case, &schema).unwrap_err();
2380 assert_snapshot!(
2381 err.strip_backtrace(),
2382 @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
2383 );
2384
2385 let case = Case {
2386 expr: Some(Box::new(col("string"))),
2387 when_then_expr: vec![
2388 (Box::new(col("float")), Box::new(col("date"))),
2389 (Box::new(col("string")), Box::new(col("float"))),
2390 (Box::new(col("string")), Box::new(col("binary"))),
2391 ],
2392 else_expr: Some(Box::new(col("timestamp"))),
2393 };
2394 let err = coerce_case_expression(case, &schema).unwrap_err();
2395 assert_snapshot!(
2396 err.strip_backtrace(),
2397 @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
2398 );
2399
2400 Ok(())
2401 }
2402
2403 macro_rules! test_case_expression {
2404 ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2405 let case = Case {
2406 expr: $expr.map(|e| Box::new(col(e))),
2407 when_then_expr: $when_then,
2408 else_expr: None,
2409 };
2410
2411 let expected =
2412 cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2413
2414 let actual = coerce_case_expression(case, &$schema)?;
2415 assert_eq!(expected, actual);
2416 };
2417 }
2418
2419 #[test]
2420 fn tes_case_when_list() -> Result<()> {
2421 let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2422 let schema = Arc::new(DFSchema::from_unqualified_fields(
2423 vec![
2424 Field::new(
2425 "large_list",
2426 DataType::LargeList(Arc::clone(&inner_field)),
2427 true,
2428 ),
2429 Field::new(
2430 "fixed_list",
2431 DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2432 true,
2433 ),
2434 Field::new("list", DataType::List(inner_field), true),
2435 ]
2436 .into(),
2437 std::collections::HashMap::new(),
2438 )?);
2439
2440 test_case_expression!(
2441 Some("list"),
2442 vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2443 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2444 Utf8,
2445 schema
2446 );
2447
2448 test_case_expression!(
2449 Some("large_list"),
2450 vec![(Box::new(col("list")), Box::new(lit("1")))],
2451 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2452 Utf8,
2453 schema
2454 );
2455
2456 test_case_expression!(
2457 Some("list"),
2458 vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2459 DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2460 Utf8,
2461 schema
2462 );
2463
2464 test_case_expression!(
2465 Some("fixed_list"),
2466 vec![(Box::new(col("list")), Box::new(lit("1")))],
2467 DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2468 Utf8,
2469 schema
2470 );
2471
2472 test_case_expression!(
2473 Some("fixed_list"),
2474 vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2475 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2476 Utf8,
2477 schema
2478 );
2479
2480 test_case_expression!(
2481 Some("large_list"),
2482 vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2483 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2484 Utf8,
2485 schema
2486 );
2487 Ok(())
2488 }
2489
/// Exercises `CASE` coercion (no base expression) when the THEN values of two
/// `WHEN boolean` branches are list columns of different kinds, covering every
/// ordered pairing of `list`, `large_list`, and `fixed_list`.
#[test]
fn test_then_else_list() -> Result<()> {
    let item = Arc::new(Field::new_list_field(DataType::Int64, true));
    let columns = vec![
        Field::new("boolean", DataType::Boolean, true),
        Field::new("large_list", DataType::LargeList(Arc::clone(&item)), true),
        Field::new(
            "fixed_list",
            DataType::FixedSizeList(Arc::clone(&item), 3),
            true,
        ),
        Field::new("list", DataType::List(item), true),
    ];
    let schema = Arc::new(DFSchema::from_unqualified_fields(
        columns.into(),
        std::collections::HashMap::new(),
    )?);

    // One `(WHEN boolean, THEN <column>)` pair.
    let when_boolean_then =
        |column: &str| (Box::new(col("boolean")), Box::new(col(column)));
    // Expected THEN types; each call builds a fresh value for the macro.
    let int64_item = || Arc::new(Field::new_list_field(DataType::Int64, true));
    let large_list_type = || DataType::LargeList(int64_item());
    let list_type = || DataType::List(int64_item());

    // large_list paired with list, in both orders: expected THEN type is LargeList.
    test_case_expression!(
        None::<String>,
        vec![when_boolean_then("large_list"), when_boolean_then("list")],
        DataType::Boolean,
        large_list_type(),
        schema
    );
    test_case_expression!(
        None::<String>,
        vec![when_boolean_then("list"), when_boolean_then("large_list")],
        DataType::Boolean,
        large_list_type(),
        schema
    );

    // fixed_list paired with list, in both orders: expected THEN type is List.
    test_case_expression!(
        None::<String>,
        vec![when_boolean_then("fixed_list"), when_boolean_then("list")],
        DataType::Boolean,
        list_type(),
        schema
    );
    test_case_expression!(
        None::<String>,
        vec![when_boolean_then("list"), when_boolean_then("fixed_list")],
        DataType::Boolean,
        list_type(),
        schema
    );

    // fixed_list paired with large_list, in both orders: expected THEN type is LargeList.
    test_case_expression!(
        None::<String>,
        vec![
            when_boolean_then("fixed_list"),
            when_boolean_then("large_list")
        ],
        DataType::Boolean,
        large_list_type(),
        schema
    );
    test_case_expression!(
        None::<String>,
        vec![
            when_boolean_then("large_list"),
            when_boolean_then("fixed_list")
        ],
        DataType::Boolean,
        large_list_type(),
        schema
    );
    Ok(())
}
2582
/// Compares a map column against a cast of itself to a map type that differs
/// only in the name of its entries field (`entries` vs `key_value`); the
/// analyzed plan wraps the cast in a second cast back to the column's type.
#[test]
fn test_map_with_diff_name() -> Result<()> {
    // Struct fields shared by both map types: a non-null Utf8 key and a
    // nullable Float64 value.
    let struct_fields = {
        let mut builder = SchemaBuilder::new();
        builder.push(Field::new("key", Utf8, false));
        builder.push(Field::new("value", DataType::Float64, true));
        builder.finish().fields
    };

    // Type of column `a`: entries field named "entries".
    let entries_map = DataType::Map(
        Arc::new(Field::new(
            "entries",
            DataType::Struct(struct_fields.clone()),
            false,
        )),
        false,
    );
    // Cast target: same layout, entries field named "key_value".
    let key_value_map = DataType::Map(
        Arc::new(Field::new(
            "key_value",
            DataType::Struct(struct_fields),
            false,
        )),
        false,
    );

    let comparison = col("a").eq(cast(col("a"), key_value_map));
    let input = empty_with_type(entries_map);
    let plan = LogicalPlan::Projection(Projection::try_new(vec![comparison], input)?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: a = CAST(CAST(a AS Map("key_value": non-null Struct("key": non-null Utf8, "value": Float64), unsorted)) AS Map("entries": non-null Struct("key": non-null Utf8, "value": Float64), unsorted))
      EmptyRelation: rows=0
    "#
    )
}
2609
/// `interval + timestamp` (interval on the left) is left unchanged by the
/// analyzer: the snapshot shows no extra casts around either operand.
#[test]
fn interval_plus_timestamp() -> Result<()> {
    let interval = lit(ScalarValue::IntervalYearMonth(Some(12)));
    let timestamp = cast(
        lit("2000-01-01T00:00:00"),
        DataType::Timestamp(TimeUnit::Nanosecond, None),
    );
    let sum = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(interval),
        Operator::Plus,
        Box::new(timestamp),
    ));
    let plan = LogicalPlan::Projection(Projection::try_new(vec![sum], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2632
/// `timestamp - timestamp` with identical nanosecond timestamp operands is
/// left unchanged by the analyzer.
#[test]
fn timestamp_subtract_timestamp() -> Result<()> {
    // Both operands are the same string literal cast to a nanosecond timestamp.
    let timestamp = || {
        cast(
            lit("1998-03-18"),
            DataType::Timestamp(TimeUnit::Nanosecond, None),
        )
    };
    let difference = Expr::BinaryExpr(BinaryExpr::new(
        Box::new(timestamp()),
        Operator::Minus,
        Box::new(timestamp()),
    ));
    let plan =
        LogicalPlan::Projection(Projection::try_new(vec![difference], empty())?);

    assert_analyzed_plan_eq!(
        plan,
        @r#"
    Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
      EmptyRelation: rows=0
    "#
    )
}
2657
/// `a IN (<subquery>)` where the outer column is Int64 and the subquery
/// yields Int32: the cast is applied on the subquery side, via a projection.
#[test]
fn in_subquery_cast_subquery() -> Result<()> {
    let inner = Subquery {
        subquery: empty_with_type(DataType::Int32),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let outer_input = empty_with_type(DataType::Int64);

    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), inner, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    // The outer expression stays as-is; the subquery gains a casting projection.
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: a IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Int64)
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2686
/// `a IN (<subquery>)` where the outer column is Int32 and the subquery
/// yields Int64: the cast is applied to the outer expression only.
#[test]
fn in_subquery_cast_expr() -> Result<()> {
    let outer_input = empty_with_type(DataType::Int32);
    let inner = Subquery {
        subquery: empty_with_type(DataType::Int64),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };

    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), inner, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    // The subquery plan is untouched; only the left-hand expression is cast.
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Int64) IN (<subquery>)
      Subquery:
        EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2714
/// `a IN (<subquery>)` with Decimal128(8, 8) outside and Decimal128(10, 5)
/// inside: both sides are cast to the common type Decimal128(13, 8).
#[test]
fn in_subquery_cast_all() -> Result<()> {
    let inner = Subquery {
        subquery: empty_with_type(DataType::Decimal128(10, 5)),
        outer_ref_columns: vec![],
        spans: Spans::new(),
    };
    let outer_input = empty_with_type(DataType::Decimal128(8, 8));

    let predicate =
        Expr::InSubquery(InSubquery::new(Box::new(col("a")), inner, false));
    let plan = LogicalPlan::Filter(Filter::try_new(predicate, outer_input)?);

    // Both the outer expression and the subquery output are cast.
    assert_analyzed_plan_eq!(
        plan,
        @r"
    Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
      Subquery:
        Projection: CAST(a AS Decimal128(13, 8))
          EmptyRelation: rows=0
      EmptyRelation: rows=0
    "
    )
}
2743}